In [None]:
import functools
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import Simulation, groupModel, groupTrainer, summary_plot
from src.model import sparseModel, sparseTrainer
%matplotlib inline

In [None]:
import asyncio
import pickle

In [None]:
def background(f):
    """Decorator: run ``f`` in the event loop's default executor.

    Calling the decorated function schedules ``f(*args, **kwargs)`` via
    ``loop.run_in_executor`` and immediately returns the resulting asyncio
    future instead of blocking. Await the future (or ``asyncio.gather`` a
    collection of them) to wait for completion and to surface exceptions.

    Args:
        f: The callable to run off the calling thread.

    Returns:
        A wrapper with the same name/docstring as ``f`` that returns a future.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor forwards positional arguments only; passing **kwargs
        # to it directly raises TypeError, so bind everything with partial.
        return asyncio.get_event_loop().run_in_executor(
            None, functools.partial(f, *args, **kwargs)
        )

    return wrapped

@background
def one_fit(seed, holder):
    """Fit one repetition of the experiment and store its trainer.

    Decorated with ``@background``, so calling it returns a future right away
    and the fit itself runs in the event loop's default executor.

    Args:
        seed: Repetition index; the simulation is seeded with ``seed * 100``.
        holder: List that receives the fitted ``groupTrainer`` once training
            finishes. Fits append from worker threads, so entries arrive in
            completion order, not seed order.
    """
    # Use the `seed` argument, never the notebook-global loop variable `rep`:
    # this body runs later on a worker thread, when the launching loop has
    # usually moved on, so reading `rep` gave duplicated/skipped seeds.
    sim = Simulation(m=80, p=200, seed=seed * 100, support=np.repeat([1, -1, -1, 1, 1], 1), std=.5)

    model = groupModel(p=sim.p, group_size=1, depth=2)

    # Every parameter starts at 1; then `u` is shrunk to `init` and each
    # group's `vs[i]` weight is divided by sqrt(group_size).
    init = 1e-6
    for param in model.parameters():
        torch.nn.init.ones_(param)
    model.u.weight.data *= init
    for i in range(model.num_groups):
        model.vs[i].weight.data *= 1 / np.sqrt(model.group_size)

    trainer = groupTrainer(model, sim, lr=0.05, tol_on_u=3e-2, is_two_lr=False, is_small_train=False)

    trainer.train(1000)
    # list.append is atomic under CPython's GIL, so concurrent fits may share `holder`.
    holder.append(trainer)
    print(seed)

In [None]:
# Launch 30 independent fits in the background. Each `one_fit` call returns an
# asyncio future immediately and `holder` fills up as fits complete (completion
# order, not seed order). The futures are kept so a later cell can wait for all
# fits -- and see any exception raised inside one -- before reading `holder`,
# e.g. `await asyncio.gather(*fit_futures)`.
holder = []
fit_futures = []
for rep in range(30):
    fit_futures.append(one_fit(rep, holder))

In [None]:
# Simulated problem for the single reference run plotted below (fixed seed 42).
sim = Simulation(
    m=80,
    p=200,
    seed=42,
    support=np.repeat([1, -1, -1, 1, 1], 1),
    std=.5,
)

In [None]:
# sim.lst_est_err

In [None]:
# Model for the reference run, sized by the simulation's `p`.
model = groupModel(
    p=sim.p,
    group_size=1,
    depth=2,
)

In [None]:
# Same initialisation as inside `one_fit`: every parameter starts at 1, then
# `u` is shrunk to `init` and each group's `vs[...]` weight is divided by
# sqrt(group_size).
init = 1e-6
for param in model.parameters():
    torch.nn.init.ones_(param)
model.u.weight.data *= init
group_scale = 1 / np.sqrt(model.group_size)
for group_idx in range(model.num_groups):
    model.vs[group_idx].weight.data *= group_scale

In [None]:
# Trainer with the same settings used for the background repetitions in `one_fit`.
trainer = groupTrainer(
    model,
    sim,
    lr=0.05,
    tol_on_u=3e-2,
    is_two_lr=False,
    is_small_train=False,
)

In [None]:
# Train the reference run; the epoch count is named so it is easy to adjust.
n_train_epochs = 1000
trainer.train(n_train_epochs)

In [None]:
# summary_plot(trainer, n_groups=sim.support.shape[0], group_size=model.group_size)

In [None]:
# Global figure styling: LaTeX text rendering with amsmath in the preamble,
# thicker lines, larger font, and legends without a frame.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
# Two-panel figure, saved to outputs/degenerated_case.pdf:
#   left  -- monitored entries w(t) of the single reference run (`trainer`);
#            per the labels, the first five curves are support entries and the
#            sixth is the max over off-support entries.
#   right -- recovery error over the background repetitions in `holder`:
#            mean of the log2 error with a 5th-95th percentile band.
# NOTE(review): the repetitions in `holder` are filled by background threads;
# make sure they have all finished before running this cell, otherwise the band
# is computed from a partial, run-dependent set of repetitions.
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)

# --- Left panel: recovered entries ------------------------------------------
ax_w = axes[0]
w_history = trainer.monitor['w']
n_w_steps = len(w_history)
# assumes monitor['w'] has exactly 6 columns (5 support + 1 off-support max) -- TODO confirm
colors = ['C0'] * 5 + ['C3']
ax_w.plot(w_history, label=[r'$w_{li}(t), l\in S$'] * 5 + [r'$\max\limits_{l\notin S} |w_{li}(t)|$'])
for line, color in zip(ax_w.lines, colors):
    line.set_color(color)
ax_w.set_xlabel('epochs')
ax_w.set_ylabel(r'$w_{li}(t)$')
ax_w.set_title('Recovered entries')
# One legend entry per curve group. Named `legend_idx` (not `display`) so the
# IPython built-in display() is not shadowed in the notebook namespace.
handles, labels = ax_w.get_legend_handles_labels()
legend_idx = [0, 5]
ax_w.legend([h for k, h in enumerate(handles) if k in legend_idx],
            [lab for k, lab in enumerate(labels) if k in legend_idx], loc=(.6, .2))
# Zero reference line spanning the recorded trajectory rather than a hard-coded length.
ax_w.hlines(0, 0, n_w_steps - 1, color='black', linestyle='dashed')

# --- Right panel: recovery error across repetitions -------------------------
ax = axes[1]

# log2 so the plotted values match the \log_{2} in the y-axis label
# (np.log, the natural log, was used before).
errs = np.vstack([np.log2(x.params_est_err) for x in holder])
epochs = errs.shape[1]

means = np.mean(errs, axis=0)
lower_percentile = np.percentile(errs, 5, axis=0)
upper_percentile = np.percentile(errs, 95, axis=0)

epochs_vec = np.arange(epochs)
ax.plot(epochs_vec, means)
ax.fill_between(epochs_vec, lower_percentile, upper_percentile, alpha=.15)

ax.set_ylabel(r'$\log_{2} ||\mathbf{w}_{t} - \mathbf{w}^{\star}||_{2}^{2}$')
ax.set_xlabel('epochs')
ax.set_title('Recovery error')

fig.tight_layout()
# savefig does not create missing directories.
Path('outputs').mkdir(exist_ok=True)
fig.savefig('outputs/degenerated_case.pdf')