In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import Simulation, groupModel, groupTrainer, summary_plot

In [None]:
# Simulated problem: nine `support` values, all equal to 10, with noise std 0.5 and a
# fixed seed. The two positional arguments (150, 300) are passed straight through to
# Simulation — presumably sample count and dimension; TODO confirm in src.group_model.
sim = Simulation(
    150,
    300,
    support=np.array([10.0] * 9),
    std=0.5,
    seed=6,
)

In [None]:
# sim.lst_est_err

In [None]:
# Group-structured model over all sim.p coordinates, with groups of 3 entries and
# depth 2 (the meaning of `depth` is defined in src.group_model).
model = groupModel(
    p=sim.p,
    group_size=3,
    depth=2,
)

In [None]:
init = 1e-6  # multiplier applied to u's weights after the all-ones reset

# Re-initialise the model deterministically:
#   * every parameter is set to 1,
#   * u's weights are then shrunk by `init`,
#   * each group's v weights are scaled by 1/sqrt(group_size).
# The in-place ops run under no_grad so autograd does not track the re-initialisation.
with torch.no_grad():
    for weight in model.parameters():
        weight.fill_(1.0)
    model.u.weight.mul_(init)
    v_scale = 1 / np.sqrt(model.group_size)
    for group_idx in range(model.num_groups):
        model.vs[group_idx].weight.mul_(v_scale)

In [None]:
# Trainer for `model` on `sim` with learning rate 1e-3. The two boolean flags are
# left off; their exact meaning is defined in src.group_model.
trainer = groupTrainer(
    model,
    sim,
    lr=1e-3,
    is_two_lr=False,
    is_small_train=False,
)

In [None]:
n_epochs = 2000  # number of epochs requested from trainer.train
# Presumably populates trainer.params_est_err / trainer.monitor, which are read for plotting.
trainer.train(n_epochs)

In [None]:
# Natural log of every recorded parameter-estimation error, kept as a list.
log_err = list(map(np.log, trainer.params_est_err))

In [None]:
# summary_plot(trainer, n_groups=3, group_size=model.group_size)

In [None]:
# Global styling for all subsequently created figures:
#   * LaTeX text rendering with amsmath loaded (requires a LaTeX installation),
#   * thick lines,
#   * 15pt font,
#   * frameless legends.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
# Convergence over the full run: recovery error (left) and the tracked group
# magnitudes trainer.monitor['u'] (right).
# NOTE(review): n_groups is assumed to be the number of support groups (the
# Simulation `support` array has 9 entries, groups have size 3). The last curve in
# monitor['u'] is labelled as the max over non-support groups; keep in sync with `sim`.
n_groups = 3
group_size = model.group_size  # same value (3) that was passed to groupModel

fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)

# log_err stores log-errors, so exponentiate back to the raw error for plotting.
axes[0].plot([np.exp(x) for x in log_err])
axes[0].set_xlabel('epochs')
axes[0].set_ylabel(r'$||\mathrm{w}(t) - \mathrm{w}^\star||^2$')
axes[0].set_title(r'Recovery error')

# All support groups share one colour; the trailing non-support curve gets its own.
colors = ['C0'] * n_groups + [f'C{n_groups}']
u_labels = [r'$u_l(t), l\in S$'] * n_groups + [r'$\max\limits_{l\notin S} u_l(t)$']
axes[1].plot(trainer.monitor['u'], label=u_labels)
for line, color in zip(axes[1].lines, colors):
    line.set_color(color)
axes[1].set_title('Recovered group magnitudes')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel(r'$u_l(t)$')

# The legend shows one entry for the support groups (first curve) and one for the
# non-support curve (index n_groups). The index list is named legend_idx, not
# `display`, so IPython's display() builtin is not shadowed.
handles, labels = axes[1].get_legend_handles_labels()
legend_idx = [0, n_groups]
axes[1].legend([handles[i] for i in legend_idx], [labels[i] for i in legend_idx])
fig.tight_layout()
# fig.savefig('outputs/convergence_alg1.pdf')

In [None]:
# Early phase of training (first n_show epochs): recovery error (left) and the
# tracked group magnitudes trainer.monitor['u'] (right); saved as a PDF.
# NOTE(review): n_groups is assumed to be the number of support groups (the
# Simulation `support` array has 9 entries, groups have size 3). The last curve in
# monitor['u'] is labelled as the max over non-support groups; keep in sync with `sim`.
n_groups = 3
group_size = model.group_size  # same value (3) that was passed to groupModel
n_show = 500  # number of leading epochs to display

fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)

# log_err stores log-errors, so exponentiate back to the raw error for plotting.
axes[0].plot([np.exp(x) for x in log_err[:n_show]])
axes[0].set_xlabel('epochs')
axes[0].set_ylabel(r'$||\mathrm{w}(t) - \mathrm{w}^\star||^2$')
axes[0].set_title('Recovery error')

# All support groups share one colour; the trailing non-support curve gets its own.
colors = ['C0'] * n_groups + [f'C{n_groups}']
u_labels = [r'$u_l(t), l\in S$'] * n_groups + [r'$\max\limits_{l\notin S} u_l(t)$']
axes[1].plot(trainer.monitor['u'][:n_show], label=u_labels)
for line, color in zip(axes[1].lines, colors):
    line.set_color(color)
axes[1].set_title('Recovered group magnitudes')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel(r'$u_l(t)$')

# The legend shows one entry for the support groups (first curve) and one for the
# non-support curve (index n_groups). The index list is named legend_idx, not
# `display`, so IPython's display() builtin is not shadowed.
handles, labels = axes[1].get_legend_handles_labels()
legend_idx = [0, n_groups]
axes[1].legend([handles[i] for i in legend_idx], [labels[i] for i in legend_idx])
fig.tight_layout()

# savefig does not create missing directories; make sure outputs/ exists first.
os.makedirs('outputs', exist_ok=True)
fig.savefig('outputs/stability_alg1.pdf')

In [None]:
# Early-phase (first n_show epochs) trajectories of every monitored entry of w (left)
# and of the direction parameters v (right), coloured by group; saved as a PDF.
# All sizes are defined here so the cell does not depend on earlier cells' state.
# NOTE(review): assumes trainer.monitor['w'] / ['v'] hold n_groups * group_size support
# columns followed by one non-support column, matching the labels below — confirm.
n_groups = 3
group_size = model.group_size  # same value (3) that was passed to groupModel
n_show = 500  # number of leading epochs to display

# One colour/label per group, repeated for each entry of the group; the last curve is
# the non-support one.
colors = [f'C{g}' for g in range(n_groups) for _ in range(group_size)] + [f'C{n_groups}']
entry_labels = (
    ['group' + str(g + 1) for g in range(n_groups) for _ in range(group_size)]
    + ['non support']
)
# The legend keeps only the first entry of each group plus the non-support curve.
# Named legend_idx, not `display`, so IPython's display() builtin is not shadowed.
legend_idx = range(0, n_groups * group_size + 1, group_size)

fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)

# (axis, monitor key, title, legend location); None keeps matplotlib's default loc.
panels = [
    (axes[0], 'w', 'Recovered entries', None),
    (axes[1], 'v', r'Recovered direction parameters ($\mathbf{v}$)', (0.05, 0)),
]
for ax, key, title, legend_loc in panels:
    ax.plot(trainer.monitor[key][:n_show], label=entry_labels)
    for line, color in zip(ax.lines, colors):
        line.set_color(color)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend([handles[i] for i in legend_idx], [labels[i] for i in legend_idx], loc=legend_loc)
    ax.set_title(title)
    ax.set_xlabel('epochs')

fig.tight_layout()

# savefig does not create missing directories; make sure outputs/ exists first.
os.makedirs('outputs', exist_ok=True)
fig.savefig('outputs/instability_alg1.pdf')