In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import make_regression
from torch.distributions.multivariate_normal import MultivariateNormal

# Fix the global PyTorch RNG seed so sampled data and parameter
# initialisation are repeatable from run to run.
seed = 7
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and explicitly switch off the cuDNN
# autotuner, which can otherwise pick different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Synthetic-data settings.
n_data = 2000

# Latent scalar per sample, drawn from a zero-mean normal with std 2.
z = torch.normal(0.0, 2.0, size=(n_data,), dtype=torch.float32)
# Fair coin per sample (1.0 or 0.0), used to pick which axis the sample uses.
xp = torch.bernoulli(torch.full((n_data,), 0.5))

# Standard basis vectors of the 2-D input space.
e1 = torch.tensor([1.0, 0.0])
e2 = torch.tensor([0.0, 1.0])

# Per-axis scales: the second axis is 2**6 times smaller than the first.
sigma1 = 1.0
sigma2 = 1.0 / 2 ** 6

In [None]:
# Every sample lies on exactly one coordinate axis: where xp == 1 the first
# coordinate is z*sigma1 and the second is 0; where xp == 0 the first is 0 and
# the second is z*sigma2. Computed with whole-tensor arithmetic rather than a
# Python loop over all n_data rows; the result has shape (n_data, 2).
x = torch.stack((z * sigma1 * xp, z * sigma2 * (1 - xp)), dim=1)
# Noise-free linear targets y = x @ w with fixed weights w = (3, 4).
w = torch.tensor([3, 4], dtype=torch.float)
y = x @ w

In [None]:
# Alternative data source: a zero-mean Gaussian with diagonal covariance
# diag(2*sigma1^2, 2*sigma2^2). Running this cell replaces the x, w and y
# produced by the axis-aligned construction in the preceding cell.
sig = torch.diag(torch.tensor([2 * sigma1 ** 2, 2 * sigma2 ** 2]))
x_dist = MultivariateNormal(torch.zeros(2), sig)
# One batched draw of shape (n_data, 2) instead of n_data separate sample()
# calls. NOTE: for a fixed seed the individual draws may differ from the
# per-sample loop, but they follow the same distribution.
x = x_dist.sample((n_data,))
# Noise-free linear targets y = x @ w with fixed weights w = (3, 4).
w = torch.tensor([3, 4], dtype=torch.float)
y = x @ w

In [None]:
# Display the shape of the target tensor as a quick sanity check.
y.shape

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Detached copies of the inputs and targets for training. They deliberately do
# NOT require gradients: only the model parameters are optimised, and tracking
# gradients on the data would make every backward pass also compute and
# accumulate gradients for these tensors.
x_train = x.clone().detach()
y_train = y.clone().detach()
train_ds = TensorDataset(x_train, y_train)


In [None]:
# Epoch count, plus one zero-filled tensor of length n_iter per optimiser
# (SGD, momentum, Nesterov, MaSS).
n_iter = 1
sgd_loss, mom_loss, nes_loss, mas_loss = (torch.zeros(n_iter) for _ in range(4))

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import mass
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
def fit(model_instance, loss_fn, optim, data_loader, n_iter = 1):
    """Train ``model_instance`` and record the loss of every mini-batch.

    Args:
        model_instance: ``torch.nn.Module`` to train; put into train mode at
            the start of each epoch.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optim: optimiser exposing ``step()`` and ``zero_grad()``.
        data_loader: iterable yielding ``(xb, yb)`` batches.
        n_iter: number of passes over ``data_loader``.

    Returns:
        A 1-D CPU tensor holding one loss value per processed batch, in
        processing order across all epochs (``n_iter * number_of_batches``
        entries; empty if no batch was processed).
    """
    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-level ``device`` variable; a parameter-free model leaves
    # the batches where they are.
    first_param = next(model_instance.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # Collected in a list so the result length always equals the number of
    # batches actually seen, independent of dataset size, batch size or the
    # number of epochs.
    batch_losses = []

    for epoch in range(n_iter):
        model_instance.train()
        running_loss = 0.0
        for xb, yb in data_loader:
            if target_device is not None:
                xb = xb.to(target_device)
                yb = yb.to(target_device)
            pred = model_instance(xb)
            # Align e.g. (B, 1) predictions with (B,) targets. Without this
            # the loss would silently broadcast to (B, B) for batches larger
            # than one.
            if pred.shape != yb.shape and pred.numel() == yb.numel():
                pred = pred.reshape(yb.shape)
            loss = loss_fn(pred, yb)

            loss.backward()
            optim.step()
            optim.zero_grad()

            loss_value = loss.item()
            running_loss += loss_value
            batch_losses.append(loss_value)
        print("Epoch %d, loss %4.2f" % (epoch, running_loss))
    return torch.tensor(batch_losses)

In [None]:
# Two step sizes and a contraction-style factor, the latter two depending on
# the small-axis scale sigma2.
eta1 = 1 / 6
eta2 = 5 / (36 + 6 * sigma2)
gamma = (6 - sigma2) / (6 + sigma2)

# Values later passed to mass.Mass as lr, alpha and kappa_t.
lr_mass = eta1
alpha_mass = (1 - gamma) / (1 + gamma)
kappa_denominator = eta1 - eta2 * (1 + alpha_mass)
kappa_mass = eta1 / kappa_denominator

In [None]:
# Display the learning rate that will be passed to the Mass optimiser.
lr_mass

In [None]:
# A single linear regressor (2 inputs -> 1 output) placed on `device`.
model = nn.Linear(in_features=2, out_features=1).to(device)

# Feed the training set one shuffled sample at a time.
batch_size = 1
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# NOTE: every optimiser below is built from the parameters of the same `model`,
# so running several of them back to back continues training one set of
# weights rather than comparing runs from a common starting point.
optSGD = torch.optim.SGD(model.parameters(), lr=1e-3)
optMOM = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optNES = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
)
optMAS = mass.Mass(
    model.parameters(),
    lr=lr_mass,
    alpha=alpha_mass,
    kappa_t=kappa_mass,
)
optASGD = torch.optim.ASGD(model.parameters(), lr=1e-3)

# Mean-squared error between predictions and targets.
loss_function = F.mse_loss

In [None]:
# Only the MaSS run is active. The other optimisers are left here as disabled
# alternatives; they all act on the same `model`, so enabling several at once
# would train one set of weights sequentially.
# sgd_batch_loss = fit(model, loss_function, optSGD, train_dl)
# mom_batch_loss = fit(model, loss_function, optMOM, train_dl)
# nes_batch_loss = fit(model, loss_function, optNES, train_dl)
mas_batch_loss = fit(
    model_instance=model,
    loss_fn=loss_function,
    optim=optMAS,
    data_loader=train_dl,
)
# asgd_batch_loss = fit(model, loss_function, optASGD, train_dl)

In [None]:
# Overlay the log10 per-batch loss of every optimiser that has actually been
# run. A curve whose `*_batch_loss` variable does not exist (its fit(...) call
# was not executed) is skipped instead of raising NameError.
loss_curves = [
    ('sgd_batch_loss', 'red', 'sgd'),
    ('mom_batch_loss', 'blue', 'momentum'),
    ('nes_batch_loss', 'green', 'nesterov'),
    ('mas_batch_loss', 'purple', 'mass'),
    ('asgd_batch_loss', 'cyan', 'asgd'),
]
plotted_any = False
for curve_var, curve_colour, curve_label in loss_curves:
    curve_losses = globals().get(curve_var)
    if curve_losses is None:
        continue
    plt.plot(torch.log10(curve_losses), c = curve_colour, label = curve_label)
    plotted_any = True

plt.xlabel('batch')
plt.ylabel('log10(loss)')
# Only draw a legend when at least one labelled curve exists.
if plotted_any:
    plt.legend()

In [None]:
# asgd_batch_loss exists only if the ASGD fit(...) call has been run; report
# that instead of failing with NameError.
if 'asgd_batch_loss' in globals():
    plt.plot(torch.log10(asgd_batch_loss), c = 'cyan', label = 'asgd')
else:
    print("asgd_batch_loss is not defined; run the ASGD fit first")

In [None]:
# sgd_batch_loss exists only if the SGD fit(...) call has been run; report
# that instead of failing with NameError.
if 'sgd_batch_loss' in globals():
    plt.plot(torch.log10(sgd_batch_loss), c = 'red', label = 'sgd')
else:
    print("sgd_batch_loss is not defined; run the SGD fit first")

In [None]:
# mom_batch_loss exists only if the momentum fit(...) call has been run;
# report that instead of failing with NameError.
if 'mom_batch_loss' in globals():
    plt.plot(torch.log10(mom_batch_loss), c = 'blue', label = 'momentum')
else:
    print("mom_batch_loss is not defined; run the momentum fit first")

In [None]:
# nes_batch_loss exists only if the Nesterov fit(...) call has been run;
# report that instead of failing with NameError.
if 'nes_batch_loss' in globals():
    plt.plot(torch.log10(nes_batch_loss), c = 'green', label = 'nesterov')
else:
    print("nes_batch_loss is not defined; run the Nesterov fit first")

In [None]:
# Log10 of the per-batch loss recorded for the MaSS run.
mas_log_loss = torch.log10(mas_batch_loss)
plt.plot(mas_log_loss, color='purple', label='mass')