In [None]:
import torch
import numpy as np
import copy
import matplotlib.pyplot as plt
from src.group_model import Simulation, groupModel, groupTrainer, summary_plot

In [None]:
class GaussianSimulation:
    """Synthetic sparse linear-regression problem with a Gaussian design matrix.

    The ground-truth weight vector ``w_star`` has ``support`` as its first
    ``k`` entries and zeros in the remaining ``p - k`` positions.  Responses
    are generated as ``y = X @ w_star + noise`` where ``X`` has i.i.d.
    standard-normal entries and ``noise`` is zero-mean Gaussian with standard
    deviation ``std``.

    Parameters
    ----------
    m : int
        Number of samples (rows of ``X``).
    p : int
        Number of features (columns of ``X``).
    seed : int
        Seed passed to ``np.random.seed``.  This reseeds NumPy's *global*
        generator as a side effect of constructing the object.
    support : array-like, 1-D
        Leading (non-zero) entries of ``w_star``.  Lists, tuples and arrays
        are accepted; the values are copied on entry.
    std : float
        Standard deviation of the additive noise.

    Attributes
    ----------
    m, p, seed, std
        The constructor arguments, stored unchanged.
    k : int
        Length of ``support``.
    support : numpy.ndarray
        Private copy of the ``support`` argument.
    signal, noise : numpy.ndarray
        Noise-free responses ``X @ w_star`` and the sampled noise vector.
    X, w_star, y : torch.Tensor
        float32 tensors holding the design, true weights and responses.
    w_lst : torch.Tensor
        Solution returned by ``torch.linalg.lstsq(X, y)``.
    lst_est_err : float
        Mean squared difference between ``w_lst`` and ``w_star``.
    snr : float
        Ratio of the Euclidean norms of ``signal`` and ``noise``.

    Raises
    ------
    ValueError
        If ``support`` is not one-dimensional or has more than ``p`` entries.
    """

    def __init__(self, m=200, p=500, seed=42, support=np.repeat([1, 2, 3, 4], 4),
                 std=0.5):
        # Copy on entry: accepts plain sequences, and keeps the instance from
        # aliasing either the caller's array or the shared default argument
        # (which is evaluated only once, at class-definition time).
        support = np.array(support)
        if support.ndim != 1:
            raise ValueError(
                f'support must be one-dimensional, got shape {support.shape}')
        if support.shape[0] > p:
            raise ValueError(
                f'support has {support.shape[0]} entries but p is only {p}')

        self.m = m
        self.p = p
        self.k = support.shape[0]
        self.support = support
        self.seed = seed
        self.std = std

        # The draw order (design first, then noise) is fixed so that a given
        # seed always reproduces the same data set.
        np.random.seed(seed)
        X = np.random.normal(size=(m, p))
        w_star = np.hstack((support, np.zeros(self.p - self.k)))

        signal = np.matmul(X, w_star)
        noise = np.random.normal(scale=std, size=m)
        y = signal + noise

        self.signal, self.noise = signal, noise
        self.X = torch.tensor(X, dtype=torch.float32)
        self.w_star = torch.tensor(w_star, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

        self._least_square()
        self._snr()

    def _least_square(self):
        """Store the least-squares fit ``w_lst`` and its mean squared error against ``w_star``."""
        self.w_lst = torch.linalg.lstsq(self.X, self.y).solution
        self.lst_est_err = ((self.w_lst - self.w_star) ** 2).mean().item()

    def _snr(self):
        """Compute ``||signal||_2 / ||noise||_2``, store it in ``self.snr`` and print it."""
        self.snr = np.sqrt((self.signal ** 2).sum()) / np.sqrt((self.noise ** 2).sum())
        print(f'SNR: {self.snr:.4f}')
        

# Convergence of Algorithm 1 (alg1)

In [None]:
# alg1 experiment data: 150 samples, 300 features, nine support entries all equal to 10.
sim = GaussianSimulation(
    m=150,
    p=300,
    seed=6,
    support=np.full(9, 10.0),
    std=0.1,
)

In [None]:
# Instantiate groupModel with the current simulation's feature count.
model = groupModel(
    p=sim.p,
    group_size=3,
    depth=2,
)

In [None]:
# Reset every parameter to 1, then rescale in place:
# u by the small constant `init`, each vs[i] by 1/sqrt(group_size).
init = 1e-6
v_scale = 1 / np.sqrt(model.group_size)

for weight in model.parameters():
    torch.nn.init.ones_(weight)

model.u.weight.data.mul_(init)
for group_idx in range(model.num_groups):
    model.vs[group_idx].weight.data.mul_(v_scale)

In [None]:
# Trainer for the alg1 run (is_two_lr=False).
trainer = groupTrainer(
    model,
    sim,
    lr=0.001,
    is_two_lr=False,
    is_small_train=False,
)

In [None]:
# Number of training iterations for the alg1 run
# (presumably epochs, matching the 'epochs' x-axis of the figures -- confirm against groupTrainer).
n_epochs = 2000
trainer.train(n_epochs)

In [None]:
# Element-wise natural log of the estimation errors recorded by the trainer.
log_err = list(map(np.log, trainer.params_est_err))

In [None]:
# Global matplotlib styling for every figure below.
# 'text.usetex' makes matplotlib render text through LaTeX.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
# Figure: recovery error (left) and group magnitudes u_l(t) (right) over the whole run.
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)
err_ax, mag_ax = axes

# log_err stores log-errors; exponentiate element-wise to plot on the original scale.
err_ax.plot(list(map(np.exp, log_err)))
err_ax.set_title(r'Recovery error')
err_ax.set_xlabel('epochs')
err_ax.set_ylabel(r'$||\mathrm{w}(t) - \mathrm{w}^\star||^2$')

n_groups = 3
group_size = 3
# The first n_groups curves share one colour/label; the last curve gets its own.
# Assumes trainer.monitor['u'] yields n_groups + 1 curves -- confirm against groupTrainer.
colors = ['C0'] * n_groups + [f'C{n_groups}']
curve_labels = [r'$u_l(t), l\in S$'] * n_groups + [r'$\max\limits_{l\notin S} u_l(t)$']
mag_ax.plot(trainer.monitor['u'], label=curve_labels)
for idx, line in enumerate(mag_ax.lines):
    line.set_color(colors[idx])
mag_ax.set_title('Recovered group magnitudes')
mag_ax.set_xlabel('epochs')
mag_ax.set_ylabel(r'$u_l(t)$')

# Legend: keep only the entries at the positions listed in `display`.
handles, labels = mag_ax.get_legend_handles_labels()
display = [0, 3]
mag_ax.legend(
    [handles[k] for k in range(len(handles)) if k in display],
    [labels[k] for k in range(len(labels)) if k in display],
)

fig.tight_layout()
fig.savefig('outputs/gaussian_convergence_alg1.pdf')

In [None]:
# Figure: same two panels as the convergence figure, restricted to the first 500 recorded steps.
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(12, 4)
err_ax, mag_ax = axes

# log_err stores log-errors; exponentiate element-wise to plot on the original scale.
err_ax.plot(list(map(np.exp, log_err[:500])))
err_ax.set_title('Recovery error')
err_ax.set_xlabel('epochs')
err_ax.set_ylabel(r'$||\mathrm{w}(t) - \mathrm{w}^\star||^2$')

n_groups = 3
group_size = 3
# The first n_groups curves share one colour/label; the last curve gets its own.
# Assumes trainer.monitor['u'] yields n_groups + 1 curves -- confirm against groupTrainer.
colors = ['C0'] * n_groups + [f'C{n_groups}']
curve_labels = [r'$u_l(t), l\in S$'] * n_groups + [r'$\max\limits_{l\notin S} u_l(t)$']
mag_ax.plot(trainer.monitor['u'][:500], label=curve_labels)
for idx, line in enumerate(mag_ax.lines):
    line.set_color(colors[idx])
mag_ax.set_title('Recovered group magnitudes')
mag_ax.set_xlabel('epochs')
mag_ax.set_ylabel(r'$u_l(t)$')

# Legend: keep only the entries at the positions listed in `display`.
handles, labels = mag_ax.get_legend_handles_labels()
display = [0, 3]
mag_ax.legend(
    [handles[k] for k in range(len(handles)) if k in display],
    [labels[k] for k in range(len(labels)) if k in display],
)

fig.tight_layout()
fig.savefig('outputs/gaussian_stability_alg1.pdf')

In [None]:
# Figure: recovered entries w (left) and direction parameters v (right),
# restricted to the first 500 recorded steps.
#
# n_groups / group_size are assigned before their first use so that this cell
# does not depend on values left in the kernel by an earlier plotting cell.
n_groups = 3
group_size = 3

fig, axes = plt.subplots(1,2)
fig.set_size_inches(12,4)

# One colour and one label per curve: repeated within each group, followed by
# a single trailing "non support" curve.
# Assumes monitor['w'] and monitor['v'] each yield n_groups*group_size + 1 curves
# -- confirm against groupTrainer.
colors = ['C'+str(i) for i in range(n_groups) for j in range(group_size)] + [f'C{n_groups}']
entry_labels = ['group'+str(i+1) for i in range(n_groups) for j in range(group_size)] + ['non support']
# Legend positions: the first curve of every group plus the trailing curve.
display = np.arange(0,n_groups*group_size+1,group_size)

# Left panel: individual recovered entries.
ax = axes[0]
ax.plot(trainer.monitor['w'][:500], label = entry_labels)
for i, j in enumerate(ax.lines):
    j.set_color(colors[i])
handles, labels = ax.get_legend_handles_labels()
ax.set_title('Recovered entries')
ax.legend([handle for i,handle in enumerate(handles) if i in display],
      [label for i,label in enumerate(labels) if i in display])

# Right panel: direction parameters.
ax = axes[1]
ax.plot(trainer.monitor['v'][:500], label = entry_labels)
for i, j in enumerate(ax.lines):
    j.set_color(colors[i])
handles, labels = ax.get_legend_handles_labels()
ax.legend([handle for i,handle in enumerate(handles) if i in display],
      [label for i,label in enumerate(labels) if i in display], loc=(0.05,0))
ax.set_title(r'Recovered direction parameters ($\mathbf{v}$)')

fig.tight_layout()
fig.savefig('outputs/gaussian_instability_alg1.pdf')

# Convergence of Algorithm 2 (alg2)

In [None]:
# alg2 experiment data: 150 samples, 300 features, support entries 5, 6, ..., 13.
sim = GaussianSimulation(
    m=150,
    p=300,
    seed=42,
    support=np.arange(5, 14),
    std=0.5,
)

In [None]:
# Fresh groupModel for the alg2 run, using the new simulation's feature count.
model = groupModel(
    p=sim.p,
    group_size=3,
    depth=2,
)

In [None]:
# Reset every parameter to 1, then rescale in place:
# u by the small constant `init`, each vs[i] by 1/sqrt(group_size).
init = 1e-6
v_scale = 1 / np.sqrt(model.group_size)

for weight in model.parameters():
    torch.nn.init.ones_(weight)

model.u.weight.data.mul_(init)
for group_idx in range(model.num_groups):
    model.vs[group_idx].weight.data.mul_(v_scale)

In [None]:
# Trainer for the alg2 run (is_two_lr=True, with tol_on_u supplied).
trainer = groupTrainer(
    model,
    sim,
    lr=0.001,
    tol_on_u=1e-1,
    is_two_lr=True,
    is_small_train=False,
)

In [None]:
# Number of training iterations for the alg2 run
# (presumably epochs, matching the 'epochs' x-axis of the figure -- confirm against groupTrainer).
n_epochs = 1000
trainer.train(n_epochs)

In [None]:
# Figure for alg2: recovered entries (left), group magnitudes (middle),
# group directions (right).
fig, axes = plt.subplots(1, 3)
fig.set_size_inches(18, 5)
entry_ax, mag_ax, dir_ax = axes

# --- Left panel: entries w_{li}(t) -------------------------------------------
# Nine curves share one colour/label; the tenth gets its own.
# Assumes trainer.monitor['w'] yields ten curves -- confirm against groupTrainer.
colors = ['C0' for _ in range(9)] + ['C3']
entry_labels = [r'$w_{li}(t), l\in S$' for _ in range(9)] + [r'$\max\limits_{l\notin S} w_{li}(t)$']
entry_ax.plot(trainer.monitor['w'], label=entry_labels)
for idx, line in enumerate(entry_ax.lines):
    line.set_color(colors[idx])
entry_ax.set_title('Recovered entries')
entry_ax.set_xlabel('epochs')
entry_ax.set_ylabel(r'$w_{li}(t)$')
handles, labels = entry_ax.get_legend_handles_labels()
display = [0, 9]
entry_ax.legend(
    [handles[k] for k in range(len(handles)) if k in display],
    [labels[k] for k in range(len(labels)) if k in display],
    loc=(.6, .1),
)

# --- Middle panel: group magnitudes u_l(t) -----------------------------------
n_groups = 3
group_size = 3
colors = ['C0'] * n_groups + [f'C{n_groups}']
mag_labels = [r'$u_l(t), l\in S$'] * n_groups + [r'$\max\limits_{l\notin S} u_l(t)$']
mag_ax.plot(trainer.monitor['u'], label=mag_labels)
for idx, line in enumerate(mag_ax.lines):
    line.set_color(colors[idx])
mag_ax.set_title('Recovered group magnitudes')
mag_ax.set_xlabel('epochs')
mag_ax.set_ylabel(r'$u_l(t)$')
handles, labels = mag_ax.get_legend_handles_labels()
display = [0, 3]
mag_ax.legend(
    [handles[k] for k in range(len(handles)) if k in display],
    [labels[k] for k in range(len(labels)) if k in display],
    loc=(.6, .1),
)

# --- Right panel: group directions, one labelled curve per group --------------
group_labels = [f'group{g + 1}' for g in range(n_groups)]
dir_ax.plot(trainer.dir, label=group_labels)
dir_ax.set_title('Recovered group directions')
dir_ax.set_xlabel('epochs')
dir_ax.set_ylabel(r'$\langle \mathrm{v}_l(t), \mathrm{v}^\star\rangle$')
dir_ax.legend()

fig.tight_layout()
fig.savefig('outputs/gaussian_convergence_alg2.pdf')