In [None]:
import copy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.group_model import EarlyStopping, GaussianSimulation, Simulation, groupModel, groupTrainer, summary_plot

In [None]:
from yaglm.toy_data import sample_sparse_lin_reg

from yaglm.GlmTuned import GlmCV, GlmTrainMetric

from yaglm.config.loss import Huber
from yaglm.config.penalty import Lasso, GroupLasso
from yaglm.config.flavor import Adaptive, NonConvex

from yaglm.metrics.info_criteria import InfoCriteria
from yaglm.infer.Inferencer import Inferencer
from yaglm.infer.lin_reg_noise_var import ViaRidge

from yaglm.Glm import Glm

In [None]:
from iterreg.sparse import dual_primal
from iterreg.utils import datadriven_ratio

In [None]:
def one_fit_group_lasso(sim, penalty, group_size):
    """Fit a group-lasso linear regression at one penalty level.

    Features are partitioned into ``sim.p // group_size`` consecutive,
    non-overlapping groups of ``group_size`` columns; trailing columns beyond
    the last full group are not listed in any group.

    Args:
        sim: simulation object exposing ``p``, ``X``, ``y``, ``X_val``, ``y_val``.
        penalty: value passed as ``GroupLasso(pen_val=...)``.
        group_size: number of consecutive features per group.

    Returns:
        Tuple ``(val_error, penalty)``; ``val_error`` is the mean squared error
        of the fitted model on the validation split.
    """
    n_groups = sim.p // group_size
    feature_groups = [range(start, start + group_size)
                      for start in range(0, n_groups * group_size, group_size)]
    estimator = Glm(loss='lin_reg',
                    penalty=GroupLasso(groups=feature_groups, pen_val=penalty))
    fitted = estimator.fit(sim.X, sim.y)
    residuals = sim.y_val - fitted.decision_function(sim.X_val)
    return (residuals ** 2).mean().item(), penalty

def val_group_lasso(sim, penalties, group_size=4):
    """Pick the group-lasso penalty with the lowest validation error and refit.

    Args:
        sim: simulation object exposing ``p``, ``X``, ``y``, ``X_val``,
            ``y_val`` and ``w_star`` (a torch tensor, converted via ``.numpy()``).
        penalties: iterable of candidate penalty values.
        group_size: number of consecutive features per group.

    Returns:
        ``(est_err, val_error, res, val_errors)``:
        ``est_err`` is ``||coef_ - w_star||^2`` of the refit model,
        ``val_error`` its validation MSE, ``res`` the fitted ``Glm`` and
        ``val_errors`` the ``(val_error, penalty)`` pairs sorted ascending.
    """
    val_errors = sorted(one_fit_group_lasso(sim, candidate, group_size)
                        for candidate in penalties)
    best_penalty = val_errors[0][1]

    feature_groups = [range(g * group_size, (g + 1) * group_size)
                      for g in range(sim.p // group_size)]
    res = Glm(loss='lin_reg',
              penalty=GroupLasso(groups=feature_groups, pen_val=best_penalty)).fit(sim.X, sim.y)

    residuals = sim.y_val - res.decision_function(sim.X_val)
    val_error = (residuals ** 2).mean().item()
    est_err = ((res.coef_ - sim.w_star.numpy()) ** 2).sum()
    return est_err, val_error, res, val_errors

In [None]:
class groupTrainer:
    """Trainer for a grouped reparameterized linear model.

    NOTE(review): this definition shadows the ``groupTrainer`` imported from
    ``src.group_model`` at the top of the notebook.

    Read from ``model``: ``u`` and ``vs`` (modules with a ``weight``;
    ``u.weight[0, i]`` is used as the scale of group ``i``), ``group_size``,
    ``num_groups``, ``depth``, ``parameters()`` and ``get_params()`` (the
    effective weight vector). Read from ``sim``: ``X``, ``y``, ``X_val``,
    ``y_val``, ``w_star``, ``p``, ``k`` and ``support``.

    :meth:`train` selects one of three schedules:

    * ``is_two_lr``: :meth:`_one_epoch` for the first ``switch_epoch`` epochs
      (200 by default), then :meth:`_small_lr_one_epoch`;
    * ``is_small_train``: :meth:`_small_lr_one_epoch` every epoch;
    * otherwise: :meth:`_one_epoch` every epoch.
    """

    def __init__(self, model, sim, is_two_lr=False, varepsilon=1e-12, tol_on_u=5e-3,
                 lr=0.01, is_small_train=False, is_monitor_u_diff=False, verbose=False):
        """
        Args:
            model: model to train (see class docstring for required attributes).
            sim: simulation with training/validation data and ground truth.
            is_two_lr: use the two-phase schedule of :meth:`_two_train`.
            varepsilon: added to the previous ``u`` in the relative-change
                denominator used by :meth:`_two_train`.
            tol_on_u: relative-change threshold on ``u`` reported by :meth:`_two_train`.
            lr: base SGD learning rate (for ``u``, and for all parameters in
                :meth:`_small_lr_one_epoch`).
            is_small_train: use plain SGD on all parameters every epoch.
            is_monitor_u_diff: print the relative change of ``u`` each adaptive epoch.
            verbose: print loss / estimation error every epoch.
        """
        self.model = model
        self.sim = sim
        self.lr = lr
        # SGD on u only (used by _one_epoch) and on every parameter (used by _small_lr_one_epoch).
        self.optimizer = torch.optim.SGD(self.model.u.parameters(), lr=lr)
        self.all_optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

        self.loss_criterion = torch.nn.MSELoss()
        # Per-epoch snapshots appended by _monitor.
        self.monitor = {'w': [], 'u': [], 'v': []}
        self.loss = []
        # NOTE(review): not referenced by any method of this class.
        self.early_stopping = EarlyStopping(patience=500)
        self.flag = True
        self.early_stopped_epoch = None
        self.early_stopped_model = None

        self.val_err = []
        self.params_est_err = []
        self.dir_monitor = []

        self.change_epoch = 0
        self.is_two_lr = is_two_lr
        self.tol_on_u = tol_on_u
        self.is_small_train = is_small_train
        self.is_monitor_u_diff = is_monitor_u_diff

        self.num_groups = self.sim.p // self.model.group_size
        # One optimizer per group so each v can get its own learning rate in _one_epoch.
        self.v_optimizers = [torch.optim.SGD(self.model.vs[i].parameters(), lr=lr)
                             for i in range(self.num_groups)]
        self.varepsilon = varepsilon
        self.verbose = verbose

    def _normalize_v(self, i):
        """Replace ``vs[i].weight`` with its rescaling to unit L2 norm."""
        w = self.model.vs[i].weight.data
        self.model.vs[i].weight.data = w / (w ** 2).sum().sqrt()

    def _one_epoch(self):
        """One full-batch step with per-group learning rates on ``vs``.

        Group ``i``'s ``v`` optimizer uses lr = 1 / u_i ** (2 * depth), with
        ``u_i`` read before this step; each ``v`` is renormalized to unit L2
        norm right after its update. ``u`` is then stepped with the base lr.

        Returns:
            ``(loss, est_err, u)`` where ``u`` is a detached copy of
            ``u.weight`` after the update.
        """
        y_pred = self.model(self.sim.X)
        loss = self.loss_criterion(y_pred, self.sim.y)
        self.optimizer.zero_grad()
        for v_optim in self.v_optimizers:
            v_optim.zero_grad()
        loss.backward()

        # u is only stepped after this loop, so a single snapshot yields the
        # same per-group learning rates as re-reading it every iteration.
        u_before = self.model.u.weight.data.detach().clone()
        for i, v_optim in enumerate(self.v_optimizers):
            for g in v_optim.param_groups:
                g['lr'] = 1. / u_before[0, i] ** (self.model.depth * 2)
            v_optim.step()
            self._normalize_v(i)

        self.optimizer.step()

        est_err = self._monitor()
        self.loss.append(loss.item())
        return loss.item(), est_err, self.model.u.weight.data.detach().clone()

    def _small_lr_one_epoch(self):
        """One full-batch SGD step on all parameters, then renormalize every ``v``.

        Returns:
            ``(loss, est_err)``.
        """
        y_pred = self.model(self.sim.X)
        loss = self.loss_criterion(y_pred, self.sim.y)
        self.all_optimizer.zero_grad()
        loss.backward()
        self.all_optimizer.step()

        for i in range(self.num_groups):
            self._normalize_v(i)

        est_err = self._monitor()
        self.loss.append(loss.item())
        return loss.item(), est_err

    def _monitor(self):
        """Record estimation/validation error and a compact parameter snapshot.

        Appends to ``self.val_err`` (mean squared error on ``sim``'s validation
        split), to ``self.params_est_err`` (``||w_star - w||_2^2``) and to
        ``self.monitor``: for ``w`` and ``v`` the first ``sim.k`` entries, for
        ``u`` the first ``sim.k // group_size`` entries, each followed by the
        largest absolute value among the remaining entries (0.0 if none remain).

        Returns:
            The estimation error as a Python float.
        """
        k = self.sim.k
        num = k // self.model.group_size
        # Detach once: Tensor.numpy() refuses tensors that require grad.
        w_hat = self.model.get_params().detach()

        est_err = ((self.sim.w_star - w_hat) ** 2).sum().item()
        # Use this trainer's own validation data, not a notebook-global ``sim``.
        sq_res = (self.sim.y_val - np.matmul(self.sim.X_val, w_hat.numpy())) ** 2
        val_err = sq_res.sum().item() / self.sim.y_val.shape[0]

        self.val_err.append(val_err)
        self.params_est_err.append(est_err)

        params = w_hat.squeeze(0).numpy().tolist()
        u = self.model.u.weight.detach().squeeze(0).numpy().tolist()
        v = [x
             for i in range(self.model.num_groups)
             for x in self.model.vs[i].weight.detach().squeeze(0).numpy().tolist()]

        self.monitor['w'].append([*params[:k], max((abs(x) for x in params[k:]), default=0.0)])
        self.monitor['u'].append([*u[:num], max((abs(x) for x in u[num:]), default=0.0)])
        self.monitor['v'].append([*v[:k], max((abs(x) for x in v[k:]), default=0.0)])

        return est_err

    def _small_train(self, epochs):
        """Run :meth:`_small_lr_one_epoch` for ``epochs`` epochs."""
        for epoch in range(epochs):
            loss, est_err = self._small_lr_one_epoch()
            if self.verbose:
                print(f'{epoch}/{epochs}, loss: {loss:.4f}, est error: {est_err:.4f}')

    def _optimal_train(self, epochs):
        """Run :meth:`_one_epoch` for ``epochs`` epochs."""
        for epoch in range(epochs):
            loss, est_err, _ = self._one_epoch()
            if self.verbose:
                print(f'{epoch}/{epochs}, loss: {loss:.4f}, est error: {est_err:.4f}')

    def _two_train(self, epochs, switch_epoch=200):
        """Two-phase training: :meth:`_one_epoch`, then :meth:`_small_lr_one_epoch`.

        While ``epoch < switch_epoch`` the largest relative change of ``u``
        between consecutive epochs is tracked; whenever it is below
        ``self.tol_on_u`` the epoch is printed and stored in
        ``self.change_epoch`` (overwritten on each such epoch). The phase
        switch itself happens at the fixed ``switch_epoch``, not at the
        tolerance crossing.
        """
        prev_u = 1
        for epoch in range(epochs):
            if epoch < switch_epoch:
                loss, est_err, u = self._one_epoch()
                u_diff = ((u - prev_u).abs() / np.abs(prev_u + self.varepsilon)).max()
                prev_u = u
                if self.is_monitor_u_diff:
                    print(f'{u_diff.item()}')
                if u_diff < self.tol_on_u:
                    print(f'Change epoch: {epoch} with {u_diff.item()}')
                    self.change_epoch = epoch
            else:
                loss, est_err = self._small_lr_one_epoch()

            if self.verbose:
                print(f'{epoch}/{epochs}, loss: {loss:.4f}, est error: {est_err:.4f}')

    def train(self, epochs=500):
        """Train for ``epochs`` epochs with the configured schedule, then run :meth:`get_dir`."""
        if self.is_two_lr:
            self._two_train(epochs=epochs)
        elif self.is_small_train:
            self._small_train(epochs=epochs)
        else:
            self._optimal_train(epochs=epochs)
        self.get_dir()

    def transpose_monitor(self):
        """Return ``self.monitor`` with each series transposed.

        Each value becomes one list per recorded coordinate (over epochs)
        instead of one list per epoch; an empty series maps to ``[]``.
        """
        return {key: [list(column) for column in zip(*rows)]
                for key, rows in self.monitor.items()}

    def get_dir(self):
        """Per-epoch, per-group cosine similarity between ``v`` and ``sim.support``.

        Uses the entries stored in ``self.monitor['v']`` without their trailing
        max-abs summary, compares each consecutive ``group_size`` slice with
        the matching slice of ``sim.support``, and stores the result in
        ``self.dir`` as a list (epochs) of lists (groups).
        """
        group_size = self.model.group_size
        inner_product = []
        for snapshot in self.monitor['v']:
            vec = np.array(snapshot[:-1])
            one_step = []
            for j in range(vec.shape[0] // group_size):
                tmp_vec = vec[j * group_size:(j + 1) * group_size]
                tmp_support = self.sim.support[j * group_size:(j + 1) * group_size]
                res = (tmp_vec * tmp_support).sum() / (
                    np.sqrt((tmp_vec ** 2).sum()) * np.sqrt((tmp_support ** 2).sum()))
                one_step.append(res)
            inner_product.append(one_step)
        self.dir = inner_product

In [None]:
def proximal_op(w, tau, group_size=4):
    """Group soft-thresholding (proximal operator of ``tau * sum_g ||w_g||_2``).

    ``w`` is split into consecutive, non-overlapping groups of ``group_size``
    entries; each group's L2 norm is shrunk by ``tau`` and groups whose norm
    is at most ``tau`` become exactly zero.

    Args:
        w: 1-D array whose length is a multiple of ``group_size``.
        tau: non-negative threshold.
        group_size: entries per group (defaults to 4, the previously
            hard-coded value).

    Returns:
        A new flat array with the same length as ``w``.
    """
    groups = np.asarray(w).reshape(-1, group_size)
    norms = np.sqrt((groups ** 2).sum(axis=1))
    # Avoid 0/0 for all-zero groups (e.g. tau == 0): such groups stay zero
    # whatever the scale. Clipping at 0 gives the "norm <= tau -> 0" rule.
    safe_norms = np.where(norms > 0, norms, 1.0)
    scale = np.maximum(1.0 - tau / safe_norms, 0.0)
    return (scale[:, None] * groups).reshape(-1)

In [None]:
# SNR of the simulated data for noise std = 0, 1, ..., 99, one row per seed (30 seeds).
# m / p match the experiment cells below; defined here so this cell runs on a fresh kernel.
# NOTE(review): std=0 is included; confirm GaussianSimulation handles noiseless data.
m = 150
p = 300
snrs = []
for run in range(30):
    snr = [GaussianSimulation(m, p, support=np.full(12, 10.), std=std, seed=run * 100).snr
           for std in range(100)]
    snrs.append(snr)

In [None]:
# Collapse the per-run table into one mean SNR per noise level (column-wise mean over runs).
snrs = [np.mean([run_snrs[level] for run_snrs in snrs]) for level in range(len(snrs[0]))]
snrs

In [None]:
# Pair each averaged SNR with its position 0..99.
[(snr_value, idx) for snr_value, idx in zip(snrs, range(100))]

In [None]:
# Mean SNR over 30 seeds for the explicit noise levels used by the second experiment cell.
# m / p match the experiment cells; defined here so this cell runs on a fresh kernel.
m = 150
p = 300
noise_stds = [1, 2, 3, 4, 5, 6, 8, 11, 17, 34, 68]
snrs = []
for run in range(30):
    snrs.append([GaussianSimulation(m, p, support=np.full(12, 10.), std=std, seed=run * 100).snr
                 for std in noise_stds])
# Column-wise mean over runs: one value per noise level.
snrs = [np.mean([run_snrs[level] for run_snrs in snrs]) for level in range(len(snrs[0]))]
snrs

In [None]:
# Compare three estimators over 10 log-spaced noise levels (std = exp(linspace(0, 3, 10))),
# repeated for 30 seeds:
#   * group lasso with the penalty chosen by validation error (val_group_lasso),
#   * primal-dual iterative regularization (dual_primal with group soft-thresholding),
#   * the reparameterized DGLNN model trained with groupTrainer.
# For each method and noise level we keep the estimation error ||w - w_star||^2 of the
# candidate (penalty / iterate / epoch) with the smallest validation error.
penalties = 0.1 * np.exp(np.linspace(0.01, 3, 10))
m = 150
p = 300

outer_gres = []
outer_ires = []
outer_rres = []
snrs = []
for run in range(30):
    print('#######################')
    print('#######################')
    print(f'########{run}#########')
    print('#######################')
    print('#######################')
    seed = run * 100
    sims = [GaussianSimulation(m, p, support=np.full(12, 10.), std=std, seed=seed)
            for std in np.exp(np.linspace(0, 3, 10))]
    snr = [sim.snr for sim in sims]
    snrs.append(snr)

    # Group lasso: penalty selected on the validation split.
    gres = []
    gres_all = []
    for sim in sims:
        est_err, val_error, res, val_errors = val_group_lasso(sim, penalties)
        gres.append(est_err)
        gres_all.append([est_err, val_error, res, val_errors])
    outer_gres.append(gres)

    # Primal-dual: all_w holds the iterates returned with ret_all=True; keep the
    # iterate with the smallest validation error.
    iterate_err = []
    iterate_all = []
    for sim in sims:
        ratio = 10 * datadriven_ratio(sim.X, sim.y)
        _, _, _, all_w = dual_primal(
            sim.X.numpy(), sim.y.numpy(), step_ratio=ratio, rho=0.99, ret_all=True,
            prox=proximal_op,
            max_iter=100,
            f_store=1)
        val_err = (((sim.X_val @ all_w.transpose()) - sim.y_val[:, None]) ** 2).sum(axis=0).numpy()
        est_err = ((all_w - sim.w_star.numpy()[None, :]) ** 2).sum(axis=1)
        # Materialize: a bare zip iterator would be empty after its first use.
        err_pairs = list(zip(val_err, est_err))
        iterate_err.append(min(err_pairs))
        iterate_all.append(err_pairs)
    final_iterate_err = [x[1] for x in iterate_err]
    outer_ires.append(final_iterate_err)

    # DGLNN: u starts at 1e-6 and each v entry at 1/sqrt(group_size);
    # keep the epoch with the smallest validation error.
    reparam_res = []
    reparam_res_all = []
    for sim in sims:
        model = groupModel(p=sim.p, group_size=4, depth=2)
        init = 1e-6
        for param in model.parameters():
            torch.nn.init.ones_(param)
        model.u.weight.data *= init
        for i in range(model.num_groups):
            model.vs[i].weight.data *= 1 / np.sqrt(model.group_size)
        trainer = groupTrainer(model, sim, lr=0.001, is_two_lr=True, is_small_train=False)
        trainer.train(600)
        err_pairs = list(zip(trainer.val_err, trainer.params_est_err))
        reparam_res.append(min(err_pairs))
        reparam_res_all.append(err_pairs)
        print(reparam_res)

        fig_run, ax_run = plt.subplots()
        ax_run.plot(trainer.val_err, label='validation error')
        ax_run.plot(trainer.params_est_err, label='estimation error')
        ax_run.set_xlabel('epoch')
        ax_run.legend()
        plt.show()
        plt.close(fig_run)

    final_reparam_res = [x[1] for x in reparam_res]
    outer_rres.append(final_reparam_res)


In [None]:
# Mean SNR of each noise level across runs (rows of ``snrs`` are runs).
ssnrs = [np.mean([run_snrs[level] for run_snrs in snrs]) for level in range(len(snrs[0]))]
ssnrs

In [None]:
def cal(l):
    """Summarize errors on a log2 scale across runs.

    Args:
        l: 2-D array-like of positive errors; axis 0 is averaged over
            (presumably runs x noise levels, as built by the experiment cells).

    Returns:
        ``(means, stds)``: mean and population standard deviation (ddof=0,
        not the standard error) of ``log2(l)`` over axis 0, with axis 1
        reversed.
    """
    log_errs = np.log2(np.array(l))
    reversed_errs = np.flip(log_errs, axis=1)
    return reversed_errs.mean(axis=0), reversed_errs.std(axis=0)

In [None]:
# log2 estimation-error summaries (mean, std over runs) for group lasso,
# primal-dual iterates and DGLNN, in that order.
(gmeans, gstds), (imeans, istds), (rmeans, rstds) = (
    cal(method_results) for method_results in (outer_gres, outer_ires, outer_rres)
)

In [None]:
# Publication style: LaTeX text rendering (with amsmath), thick lines, larger font, frameless legends.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
    'lines.linewidth': 3,
    'font.size': 15,
    'legend.frameon': False,
})

In [None]:
from matplotlib.transforms import ScaledTranslation


In [None]:
# Comparison figure: log2 estimation error (mean +/- one std over runs) vs. SNR.
# cal() reverses the noise-level axis, so the x values use ssnrs[::-1] to match.
snr_axis = [round(x, 2) for x in ssnrs[::-1]]

fig, ax = plt.subplots()
for series_mean, series_std, fmt, label in [
    (gmeans, gstds, '--o', 'PGD'),
    (rmeans, rstds, '--s', 'DGLNN'),
    (imeans, istds, '--*', 'Primal-Dual'),
]:
    line, = ax.plot(snr_axis, series_mean, fmt, label=label)
    # Explicit color keeps each band matched to its line regardless of how the
    # installed matplotlib assigns fill_between colors.
    ax.fill_between(snr_axis, series_mean - series_std, series_mean + series_std,
                    alpha=0.25, color=line.get_color())

ax.legend()
ax.set_xlabel('SNR')
ax.set_ylabel(r'$\log_{2} ||\mathbf{w}_{t} - \mathbf{w}^{\star}||_{2}^{2}$')
ax.set_xlim(0, 35)
fig.tight_layout()

out_dir = Path('outputs')
out_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories
fig.savefig(out_dir / 'comparisons.pdf')

In [None]:
# Repeat the three-method comparison (group lasso, primal-dual, DGLNN) for the explicit
# noise levels std in [1, 2, 3, 4, 5, 6, 8, 11, 17, 34, 68], 30 seeds each.
# Results go to the *2 accumulators; appending to snrs / outer_gres / outer_ires /
# outer_rres would mix 11-level rows into the 10-level results of the other sweep.
penalties = 0.1 * np.exp(np.linspace(0.01, 3, 10))
m = 150
p = 300

outer_gres2 = []
outer_ires2 = []
outer_rres2 = []
snrs2 = []
for run in range(30):
    print('#######################')
    print('#######################')
    print(f'########{run}#########')
    print('#######################')
    print('#######################')
    seed = run * 100
    sims = [GaussianSimulation(m, p, support=np.full(12, 10.), std=std, seed=seed)
            for std in [1, 2, 3, 4, 5, 6, 8, 11, 17, 34, 68]]
    snr = [sim.snr for sim in sims]
    snrs2.append(snr)

    # Group lasso: penalty selected on the validation split.
    gres = []
    gres_all = []
    for sim in sims:
        est_err, val_error, res, val_errors = val_group_lasso(sim, penalties)
        gres.append(est_err)
        gres_all.append([est_err, val_error, res, val_errors])
    outer_gres2.append(gres)

    # Primal-dual: all_w holds the iterates returned with ret_all=True; keep the
    # iterate with the smallest validation error.
    iterate_err = []
    iterate_all = []
    for sim in sims:
        ratio = 10 * datadriven_ratio(sim.X, sim.y)
        _, _, _, all_w = dual_primal(
            sim.X.numpy(), sim.y.numpy(), step_ratio=ratio, rho=0.99, ret_all=True,
            prox=proximal_op,
            max_iter=100,
            f_store=1)
        val_err = (((sim.X_val @ all_w.transpose()) - sim.y_val[:, None]) ** 2).sum(axis=0).numpy()
        est_err = ((all_w - sim.w_star.numpy()[None, :]) ** 2).sum(axis=1)
        # Materialize: a bare zip iterator would be empty after its first use.
        err_pairs = list(zip(val_err, est_err))
        iterate_err.append(min(err_pairs))
        iterate_all.append(err_pairs)
    final_iterate_err = [x[1] for x in iterate_err]
    outer_ires2.append(final_iterate_err)

    # DGLNN: u starts at 1e-6 and each v entry at 1/sqrt(group_size);
    # keep the epoch with the smallest validation error.
    reparam_res = []
    reparam_res_all = []
    for sim in sims:
        model = groupModel(p=sim.p, group_size=4, depth=2)
        init = 1e-6
        for param in model.parameters():
            torch.nn.init.ones_(param)
        model.u.weight.data *= init
        for i in range(model.num_groups):
            model.vs[i].weight.data *= 1 / np.sqrt(model.group_size)
        trainer = groupTrainer(model, sim, lr=0.001, is_two_lr=True, is_small_train=False)
        trainer.train(600)
        err_pairs = list(zip(trainer.val_err, trainer.params_est_err))
        reparam_res.append(min(err_pairs))
        reparam_res_all.append(err_pairs)
        print(reparam_res)

        fig_run, ax_run = plt.subplots()
        ax_run.plot(trainer.val_err, label='validation error')
        ax_run.plot(trainer.params_est_err, label='estimation error')
        ax_run.set_xlabel('epoch')
        ax_run.legend()
        plt.show()
        plt.close(fig_run)

    final_reparam_res = [x[1] for x in reparam_res]
    outer_rres2.append(final_reparam_res)

