In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import Simulation, groupModel, groupTrainer, summary_plot
from src.model import sparseModel, sparseTrainer
%matplotlib inline

In [None]:
# Simulated problem shared by the group-sparse and plain-sparse runs below.
# `support` is an array of 16 ones (four ones, each repeated four times); presumably
# these are the nonzero ground-truth coefficients, i.e. 4 groups of size 4 -- confirm
# against src.group_model.Simulation.
sim = Simulation(
    m=100,
    p=500,
    seed=42,
    support=np.repeat([1, 1, 1, 1], 4),
    std=.5,
)

In [None]:
# sim.lst_est_err

In [None]:
# Group-sparse model over the sim.p features, with groups of 4 entries.
# The meaning of `depth` is defined in src.group_model.groupModel.
model = groupModel(
    p=sim.p,
    group_size=4,
    depth=2,
)

In [None]:
# Initialise every parameter of the group model to ones, then:
#   * shrink the `u` weights to the small scale `init`;
#   * rescale each group's `v` weights by 1/sqrt(group_size).
init = 1e-6
v_scale = 1 / np.sqrt(model.group_size)
with torch.no_grad():  # in-place edits of leaf parameters without autograd tracking
    for param in model.parameters():
        torch.nn.init.ones_(param)
    model.u.weight.mul_(init)
    for group_idx in range(model.num_groups):
        model.vs[group_idx].weight.mul_(v_scale)

In [None]:
# Trainer for the group model on `sim`. The semantics of tol_on_u, is_two_lr and
# is_small_train are defined in src.group_model.groupTrainer.
trainer = groupTrainer(
    model,
    sim,
    lr=0.01,
    tol_on_u=3e-2,
    is_two_lr=True,
    is_small_train=False,
)

In [None]:
# Run group-sparse training. 500 is passed positionally; presumably the number of
# epochs (the sparse baseline below passes epochs=500) -- confirm in groupTrainer.train.
trainer.train(
    500,
)

In [None]:
# summary_plot(trainer, n_groups=4, group_size=model.group_size)

In [None]:
# Plain (non-group) sparse baseline; only_pos is forwarded to src.model.sparseModel.
sparse_model = sparseModel(
    only_pos=True,
)

In [None]:
# Small initialisation for the sparse baseline: every parameter is set to the
# constant `init` (same result as filling with ones and then scaling by `init`).
init = 1e-6
with torch.no_grad():
    for weight in sparse_model.parameters():
        torch.nn.init.constant_(weight, init)

In [None]:
# Trainer for the sparse baseline on the same simulated problem `sim`.
sparse_trainer = sparseTrainer(
    sparse_model,
    sim,
    lr=0.01,
)

In [None]:
# Train the sparse baseline for 500 epochs.
sparse_trainer.train(
    epochs=500,
)

In [None]:
# fig, axes = plt.subplots(2,2)
# fig.set_size_inches(16, 10)
# axes[0,0].plot(sparse_trainer.loss)
# axes[0,0].set_title('loss')

# axes[0,1].plot(sparse_trainer.params_est_err)
# axes[0,1].set_title('estimation error')

# axes[1,0].plot(sparse_trainer.params_est_err[200:])
# axes[1,0].set_title('estimation error (zoom in)')

# colors = ['C0']*16+['C1']
# axes[1,1].plot(sparse_trainer.monitor, label = ['support']*16 + ['non support'])
# for i, j in enumerate(axes[1,1].lines):
#     j.set_color(colors[i])
# handles, labels = axes[1,1].get_legend_handles_labels()
# axes[1,1].set_title('recovered params')
# display = (0,16)
# _=axes[1,1].legend([handle for i,handle in enumerate(handles) if i in display],
#       [label for i,label in enumerate(labels) if i in display], loc=(1.04,0))

In [None]:
# Global figure styling for the comparison plot below: LaTeX text rendering
# (needs a working LaTeX installation) with amsmath in the preamble, thick lines,
# a larger base font, and frameless legends.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
# Side-by-side comparison of entry trajectories recovered with group sparsity (left)
# and plain sparsity (right). Each monitor is assumed to hold, per epoch, the support
# entries first, followed by one column with the largest off-support entry (as the
# legend labels state) -- TODO confirm against the trainers' monitor layout.
N_SUPPORT = 16  # must equal len(support) used to build `sim` (np.repeat([1,1,1,1], 4))
SUPPORT_LABEL = r'$w_{li}(t), l\in S$'
OFF_SUPPORT_LABEL = r'$\max\limits_{l\notin S} w_{li}(t)$'


def plot_recovered_entries(ax, trajectories, title, legend_loc, n_support=N_SUPPORT):
    """Plot entry trajectories on `ax`, colouring support vs. off-support lines.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    trajectories : array-like
        Passed straight to ``ax.plot``; one line is drawn per column.
    title : str
        Axes title.
    legend_loc : tuple or str
        ``loc`` argument for the legend.
    n_support : int
        Number of leading columns that belong to the true support.

    Returns
    -------
    list of matplotlib.lines.Line2D
        The plotted lines.
    """
    lines = ax.plot(trajectories)
    for idx, line in enumerate(lines):
        line.set_color('C0' if idx < n_support else 'C3')
    # Show one legend entry per line group (first support line, first off-support
    # line) instead of one entry per line; skip indices that were not plotted.
    legend_idx = [i for i in (0, n_support) if i < len(lines)]
    ax.legend(
        [lines[i] for i in legend_idx],
        [SUPPORT_LABEL if i < n_support else OFF_SUPPORT_LABEL for i in legend_idx],
        loc=legend_loc,
    )
    ax.set_title(title)
    ax.set_xlabel('epochs')
    ax.set_ylabel(r'$w_{li}(t)$')
    return lines


fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)
plot_recovered_entries(axes[0], trainer.monitor['w'],
                       'Recovered entries using group sparsity', legend_loc=(.5, .1))
plot_recovered_entries(axes[1], sparse_trainer.monitor,
                       'Recovered entries using sparsity', legend_loc=(.05, .6))
fig.tight_layout()
# savefig does not create missing directories, so make sure outputs/ exists.
os.makedirs('outputs', exist_ok=True)
fig.savefig('outputs/group_vs_sparse.pdf')