In [None]:
import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import Simulation, groupModel, groupTrainer, groupTrainerWithoutWN, summary_plot
from src.model import sparseModel, sparseTrainer
%matplotlib inline

In [None]:
# Simulated problem shared by both runs below. `support` is 16 ones
# (np.repeat of four ones, each repeated 4 times) — presumably 4 groups of 4
# active coordinates; m / p presumably samples / dimension (TODO confirm).
sim = Simulation(
    m=100,
    p=500,
    seed=42,
    support=np.repeat([1, 1, 1, 1], 4),
    std=.5,
)

In [None]:
# sim.lst_est_err

In [None]:
# Model for the zero-initialization run; its dimension follows the simulation.
model = groupModel(
    p=sim.p,
    group_size=4,
    depth=2,
)

In [None]:
# Zero initialization of the group directions: every parameter starts at one,
# then u's weight is scaled down to `init` and each group's v weight is zeroed.
# Done under no_grad with in-place ops instead of mutating `.data`, which
# bypasses autograd bookkeeping and is discouraged by PyTorch.
init = 1e-3
with torch.no_grad():
    for param in model.parameters():
        torch.nn.init.ones_(param)
    model.u.weight.mul_(init)
    for i in range(model.num_groups):
        model.vs[i].weight.zero_()

In [None]:
# Trainer for the zero-initialized model.
# Options tried previously and currently disabled:
#   tol_on_u=3e-2, is_two_lr=True, is_small_train=False
trainer = groupTrainerWithoutWN(
    model,
    sim,
    lr=0.05,
)

In [None]:
# Train the zero-initialized model; 8000 is passed straight to `trainer.train`
# (the figure below labels this axis "epochs"). Long-running cell.
trainer.train(8000)

In [None]:
# summary_plot(trainer, n_groups=4, group_size=model.group_size)

In [None]:
# Fresh model (same architecture) for the small-initialization run.
model = groupModel(
    p=sim.p,
    group_size=4,
    depth=2,
)

In [None]:
# Small initialization: every parameter starts at one, u's weight is scaled to
# `init`, and each group's v weight is scaled to -init / sqrt(group_size).
# In-place ops under no_grad replace the discouraged `.data` mutation.
init = 1e-3
with torch.no_grad():
    for param in model.parameters():
        torch.nn.init.ones_(param)
    model.u.weight.mul_(init)
    for i in range(model.num_groups):
        model.vs[i].weight.mul_(-1 / np.sqrt(model.group_size) * init)

In [None]:
# Trainer for the small-initialization model (same settings as the first run).
# Options tried previously and currently disabled:
#   tol_on_u=3e-2, is_two_lr=True, is_small_train=False
trainer2 = groupTrainerWithoutWN(
    model,
    sim,
    lr=0.05,
)

In [None]:
# Train the small-initialization model for the same budget as the first run
# (8000 passed straight to `trainer2.train`). Long-running cell.
trainer2.train(8000)

In [None]:
# summary_plot(trainer2, n_groups=4, group_size=model.group_size)

In [None]:
# Global figure style: LaTeX text rendering with amsmath (requires a working
# LaTeX installation), thick lines, larger font, and frameless legends.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
def _plot_entry_trajectories(ax, w_history, title, n_support=16):
    """Plot the trajectories of the monitored w entries on `ax`.

    The first `n_support` curves are drawn in C0, and the following curve (the
    off-support maximum) in C3. Only one representative of each kind appears
    in the legend.

    Assumes `w_history` is 2-D with `n_support` support columns followed by one
    off-support-max column (TODO confirm against the trainer's `monitor['w']`).
    """
    ax.plot(w_history,
            label=[r'$w_{li}(t), l\in S$'] * n_support
                  + [r'$\max\limits_{l\notin S} w_{li}(t)$'])
    colors = ['C0'] * n_support + ['C3']
    for line, color in zip(ax.lines, colors):
        line.set_color(color)
    handles, labels = ax.get_legend_handles_labels()
    # One support curve plus the off-support max. Named `legend_idx` rather
    # than `display` so IPython's built-in display() is not shadowed.
    legend_idx = (0, n_support)
    ax.legend([h for k, h in enumerate(handles) if k in legend_idx],
              [lab for k, lab in enumerate(labels) if k in legend_idx])
    ax.set_title(title)
    ax.set_xlabel('epochs')
    ax.set_ylabel(r'$w_{li}(t)$')


fig, axes = plt.subplots(1, 3)
fig.set_size_inches(18, 5)

# Left: entries under zero initialization of the group directions.
_plot_entry_trajectories(axes[0], trainer.monitor['w'],
                         'Recovered entries with zero initialization')

# Middle: per-group direction alignment for the zero-initialization run.
group_labels = ['group' + str(i + 1) for i in range(4)]
axes[1].plot(trainer.dir, label=group_labels)
axes[1].set_title('Recovered group directions with zero initialization')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel(r'$\langle \mathrm{v}_l(t), \mathrm{v}^\star\rangle$')
axes[1].legend()

# Right: entries under small initialization.
_plot_entry_trajectories(axes[2], trainer2.monitor['w'],
                         'Recovered entries with small initialization')

fig.tight_layout()
# savefig does not create missing directories; ensure the target exists.
os.makedirs('outputs', exist_ok=True)
fig.savefig('outputs/three_stages.pdf')