In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import Simulation, groupModel, groupTrainer, summary_plot

In [None]:
# Synthetic problem used throughout the notebook. The true support is the nine
# consecutive indices 5..13. The two positional arguments (150, 300) are
# presumably the sample size and the dimension p (sim.p is read below);
# confirm against src.group_model.Simulation.
sim = Simulation(
    150,
    300,
    support=np.arange(5, 14),  # == np.array([5, 6, ..., 13])
    std=0.5,
    seed=6,
)

In [None]:
# Optional diagnostic (uncomment to display): sim.lst_est_err

In [None]:
# Group-structured model over the sim.p features, partitioned into groups of
# size 3. The exact meaning of `depth` is defined in src.group_model.
model = groupModel(
    p=sim.p,
    group_size=3,
    depth=2,
)

In [None]:
# Initialisation: every parameter starts at one, then
#   * the group-magnitude layer `u` is shrunk to a tiny scale `init`, and
#   * each of the first `num_groups` direction layers in `model.vs` is scaled
#     by 1/sqrt(group_size). If vs[i].weight holds group_size entries (TODO
#     confirm), this gives each direction unit L2 norm.
# Done under no_grad with in-place ops instead of mutating `.data`.
init = 1e-6
v_scale = 1 / np.sqrt(model.group_size)
with torch.no_grad():
    for param in model.parameters():
        torch.nn.init.ones_(param)
    model.u.weight.mul_(init)
    for group_idx in range(model.num_groups):
        model.vs[group_idx].weight.mul_(v_scale)

In [None]:
# Trainer configuration. The semantics of tol_on_u / is_two_lr /
# is_small_train live in src.group_model.groupTrainer; values are unchanged.
trainer = groupTrainer(
    model,
    sim,
    lr=0.001,
    tol_on_u=1e-1,
    is_two_lr=True,
    is_small_train=False,
)

In [None]:
# The figures below label their x-axis "epochs", so this is presumably the
# number of training epochs. Confirm in groupTrainer.train.
n_epochs = 1_000
trainer.train(n_epochs)

In [None]:
# Global figure style: LaTeX text rendering (requires a working LaTeX install
# with the amsmath package), thick lines, larger font, frameless legends.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
def _plot_support_vs_rest(ax, series, support_label, rest_label,
                          support_color='C0', rest_color='C3', legend_loc=(.6, .1)):
    """Plot support trajectories against the off-support trajectory on `ax`.

    `series` is anything `Axes.plot` accepts as y-data. Every column except
    the last is drawn in `support_color` with `support_label`. The last column
    (the max over off-support coordinates, per the labels used by the callers)
    is drawn in `rest_color` with `rest_label`. The legend shows one entry per
    category, however many support columns there are, so no count is
    hard-coded.

    Returns the list of Line2D objects created by `ax.plot`.
    """
    lines = ax.plot(series)
    *support_lines, rest_line = lines
    for line in support_lines:
        line.set_color(support_color)
        line.set_label(support_label)
    rest_line.set_color(rest_color)
    rest_line.set_label(rest_label)
    # One legend entry per category: the first support line stands in for all.
    # Kept local so nothing shadows IPython's built-in `display`.
    handles = support_lines[:1] + [rest_line]
    labels = [support_label] * (len(handles) - 1) + [rest_label]
    ax.legend(handles, labels, loc=legend_loc)
    return lines


fig, axes = plt.subplots(1, 3)
fig.set_size_inches(18, 5)

# Panel 1: recovered entries w_{li}(t); the last column is the off-support max.
_plot_support_vs_rest(axes[0], trainer.monitor['w'],
                      r'$w_{li}(t), l\in S$',
                      r'$\max\limits_{l\notin S} w_{li}(t)$')
axes[0].set_xlabel('epochs')
axes[0].set_ylabel(r'$w_{li}(t)$')
axes[0].set_title('Recovered entries')

# Panel 2: recovered group magnitudes u_l(t); the last column is the off-support max.
_plot_support_vs_rest(axes[1], trainer.monitor['u'],
                      r'$u_l(t), l\in S$',
                      r'$\max\limits_{l\notin S} u_l(t)$')
axes[1].set_title('Recovered group magnitudes')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel(r'$u_l(t)$')

# Panel 3: one curve per column of trainer.dir. Per the y-label, each is the
# inner product of a group direction with the ground-truth direction.
dir_lines = axes[2].plot(trainer.dir)
for k, line in enumerate(dir_lines, start=1):
    line.set_label('group' + str(k))
axes[2].set_title('Recovered group directions')
axes[2].set_xlabel('epochs')
axes[2].set_ylabel(r'$\langle \mathrm{v}_l(t), \mathrm{v}^\star\rangle$')
axes[2].legend()

fig.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('outputs', exist_ok=True)
fig.savefig('outputs/convergence_alg2.pdf')