In [None]:
import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda,Compose
from torchvision.transforms.functional import rotate

%matplotlib inline

### Transformation to permute labels

In [None]:
labels = np.arange(10)
class Tlabels:
  def __init__(self,p):
    self.p = p

  def __call__(self,y):
    newy = y
    if  np.random.rand() <= self.p: 
      L = list( set(labels) - set([y]) )
      newy = np.random.choice(L)
      return newy
    return newy

### Dataset

In [None]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

In [None]:
train_dataloaderF = DataLoader(training_data, batch_size=256,shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

### Model: a simple FFNN

In [None]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

### Loss function

In [None]:
# Cross-entropy loss passed to the training functions below, which use it both
# for the parameter updates and for the reported test loss.
loss_fn = nn.CrossEntropyLoss()

### AUXMOM algorithm

In [None]:
def AUXMOM(K,lrs,a, model, loss_fn,epochs):
    """Train `model` in place with the AUXMOM update and return the logged metrics.

    For every batch of `train_dataloaderF` the function:
      1. snapshots the current model (`backup`);
      2. computes the gradient of `loss_fn` on that batch and folds it into the
         momentum buffer, `m_f = (1-a)*m_f + a*grad` (at the first batch of each
         epoch the buffer is reset to the raw gradient);
      3. draws one batch from `train_dataloaderH` (restarting that loader when it
         is exhausted) and performs `K` steps
         `param -= lr * (grad_h(model) - grad_h(backup) + m_f)` on it.

    Reads the notebook-level globals `train_dataloaderF`, `train_dataloaderH`,
    `test_dataloader` and `device`.

    Args:
        K: number of inner steps per batch of `train_dataloaderF`.
        lrs: learning rates indexed by `t * len(train_dataloaderF) + batch`.
        a: mixing weight of the newest gradient in the momentum buffer.
        model: network to train; its parameters are modified in place.
        loss_fn: callable `loss_fn(pred, y)` returning a scalar tensor.
        epochs: number of passes over `train_dataloaderF`.

    Returns:
        dict with the lists "train_loss", "test_loss" and "test_acc"; one entry
        is appended to each every 50th batch. "train_loss" is the last loss
        evaluated before logging (the `train_dataloaderH` batch loss of the
        final inner step, or the `train_dataloaderF` batch loss when K == 0).
    """
    Res = {"train_loss":[],"test_loss":[],"test_acc":[]}
    steps_per_epoch = len(train_dataloaderF)
    for t in range(epochs):
        # The momentum buffer is rebuilt from scratch at the start of each epoch.
        m_f = [None for _ in model.parameters()]
        dloadH = iter(train_dataloaderH)
        for batch, (X, y) in enumerate(train_dataloaderF):
            # Plain Python float so the update does not depend on NumPy-scalar * Tensor dispatch.
            lr = float(lrs[t*steps_per_epoch + batch])
            # Frozen copy of the parameters at the start of this outer step.
            backup = copy.deepcopy(model).to(device)

            X, y = X.to(device), y.to(device)
            loss = loss_fn(model(X), y)
            model.zero_grad()
            loss.backward()
            with torch.no_grad():
                for i, p in enumerate(model.parameters()):
                    grad_f = p.grad.detach()
                    # clone(): storing p.grad itself would alias the live gradient, which a
                    # zero_grad() that zeroes in place (set_to_none=False) overwrites below.
                    m_f[i] = grad_f.clone() if batch == 0 else (1-a)*m_f[i] + a*grad_f

            try:
                X_h, y_h = next(dloadH)
            except StopIteration:
                dloadH = iter(train_dataloaderH)
                X_h, y_h = next(dloadH)
            X_h, y_h = X_h.to(device), y_h.to(device)

            # Gradient at the snapshot: `backup` and the batch are fixed during the
            # inner loop, so it is computed once instead of once per inner step.
            backup.zero_grad()
            loss_fn(backup(X_h), y_h).backward()
            grad_hx = [q.grad for q in backup.parameters()]

            for k in range(K):
                loss = loss_fn(model(X_h), y_h)
                model.zero_grad()
                loss.backward()
                with torch.no_grad():
                    for i, param in enumerate(model.parameters()):
                        param -= lr * (param.grad - grad_hx[i] + m_f[i])

            if batch % 50 == 0:
                Res["train_loss"].append(loss.item())
                # Full pass over the test set.
                n_samples = len(test_dataloader.dataset)
                n_batches = len(test_dataloader)
                test_loss, correct = 0, 0
                with torch.no_grad():
                    for X_t, y_t in test_dataloader:
                        X_t, y_t = X_t.to(device), y_t.to(device)
                        pred = model(X_t)
                        test_loss += loss_fn(pred, y_t).item()
                        correct += (pred.argmax(1) == y_t).type(torch.float).sum().item()
                Res["test_loss"].append(test_loss / n_batches)
                Res["test_acc"].append(correct / n_samples)
    return Res

In [None]:
# AUXMOM sweep: for each label-replacement probability p, train `trials` freshly
# initialised networks and store the metrics returned by AUXMOM in Res[p][trial].
# p must be numeric: Tlabels compares it against np.random.rand().
ps = [0.5, 0.8, 1]
K = 10
trials = 10
epochs = 10
T = np.arange(len(train_dataloaderF)*epochs) + 1
lrs = 1e-2 * np.ones_like(T) #1e-1 / np.sqrt(T)
a = 0.1
Res = {p: {} for p in ps}
# Create the output directory up front so a missing folder cannot fail the run after training.
os.makedirs('results', exist_ok=True)
for p in ps:
  print("p=",p)
  # The dataset and loader depend only on p, so they are built once per p;
  # AUXMOM reads `train_dataloaderH` as a global.
  training_data_Tlabels = datasets.MNIST(
                                 root="data",
                                 train=True,
                                 download=True,
                                 transform=Compose([ToTensor()]),
                                 target_transform = Lambda(Tlabels(p))
                                          )
  train_dataloaderH = DataLoader(training_data_Tlabels, batch_size=64,shuffle=True)
  for t in range(trials):
     print("---------t=",t)
     model = NeuralNetwork().to(device)
     Res[p][t] = AUXMOM(K,lrs,a, model, loss_fn,epochs)
  # NOTE(review): the whole Res dict (all p values so far) is written to each per-p file,
  # not just Res[p] — left unchanged in case downstream loaders rely on it; confirm intent.
  with open(f'results/MnistExp_Labels{p}.pkl', 'wb') as f:
     pickle.dump(Res, f)


In [None]:
def NaiveAlg(K,lrs,a, model, loss_fn,epochs):
    """Train `model` in place with the naive baseline update and return the logged metrics.

    For every batch of `train_dataloaderF` the function:
      1. computes the gradient of `loss_fn` on that batch and folds it into the
         momentum buffer, `m_f = (1-a)*m_f + a*grad` (at the first batch of each
         epoch the buffer is reset to the raw gradient);
      2. draws one batch from `train_dataloaderH` (restarting that loader when it
         is exhausted) and performs `K` steps on it: the first step is
         `param -= lr * m_f`, every later step is `param -= lr * grad_h(model)`.

    Reads the notebook-level globals `train_dataloaderF`, `train_dataloaderH`,
    `test_dataloader` and `device`.

    Args:
        K: number of inner steps per batch of `train_dataloaderF`.
        lrs: learning rates indexed by `t * len(train_dataloaderF) + batch`.
        a: mixing weight of the newest gradient in the momentum buffer.
        model: network to train; its parameters are modified in place.
        loss_fn: callable `loss_fn(pred, y)` returning a scalar tensor.
        epochs: number of passes over `train_dataloaderF`.

    Returns:
        dict with the lists "train_loss", "test_loss" and "test_acc"; one entry
        is appended to each every 50th batch. "train_loss" is the last loss
        evaluated before logging (the `train_dataloaderH` batch loss of the
        final inner step, or the `train_dataloaderF` batch loss when K == 0).
    """
    Res = {"train_loss":[],"test_loss":[],"test_acc":[]}
    steps_per_epoch = len(train_dataloaderF)
    for t in range(epochs):
        # The momentum buffer is rebuilt from scratch at the start of each epoch.
        m_f = [None for _ in model.parameters()]
        dloadH = iter(train_dataloaderH)
        for batch, (X, y) in enumerate(train_dataloaderF):
            # Plain Python float so the update does not depend on NumPy-scalar * Tensor dispatch.
            lr = float(lrs[t*steps_per_epoch + batch])

            X, y = X.to(device), y.to(device)
            loss = loss_fn(model(X), y)
            model.zero_grad()
            loss.backward()
            with torch.no_grad():
                for i, p in enumerate(model.parameters()):
                    grad_f = p.grad.detach()
                    # clone(): storing p.grad itself would alias the live gradient, which a
                    # zero_grad() that zeroes in place (set_to_none=False) overwrites below.
                    m_f[i] = grad_f.clone() if batch == 0 else (1-a)*m_f[i] + a*grad_f

            try:
                X_h, y_h = next(dloadH)
            except StopIteration:
                dloadH = iter(train_dataloaderH)
                X_h, y_h = next(dloadH)
            X_h, y_h = X_h.to(device), y_h.to(device)

            # No snapshot model is kept here: this update rule never uses the
            # gradient at the snapshot, so the deep copy and its passes were dead work.
            for k in range(K):
                loss = loss_fn(model(X_h), y_h)
                model.zero_grad()
                loss.backward()
                with torch.no_grad():
                    for i, param in enumerate(model.parameters()):
                        if k == 0:
                            param -= lr * m_f[i]
                        else:
                            param -= lr * param.grad

            if batch % 50 == 0:
                Res["train_loss"].append(loss.item())
                # Full pass over the test set.
                n_samples = len(test_dataloader.dataset)
                n_batches = len(test_dataloader)
                test_loss, correct = 0, 0
                with torch.no_grad():
                    for X_t, y_t in test_dataloader:
                        X_t, y_t = X_t.to(device), y_t.to(device)
                        pred = model(X_t)
                        test_loss += loss_fn(pred, y_t).item()
                        correct += (pred.argmax(1) == y_t).type(torch.float).sum().item()
                Res["test_loss"].append(test_loss / n_batches)
                Res["test_acc"].append(correct / n_samples)
    return Res

In [None]:
# NaiveAlg sweep: for each label-replacement probability p, train `trials` freshly
# initialised networks and store the metrics returned by NaiveAlg in Res[p][trial].
ps = [0.2,0.5,0.8,1]
K = 10
trials = 10
epochs = 10
T = np.arange(len(train_dataloaderF)*epochs) + 1
lrs = 1e-2 * np.ones_like(T) #1e-1 / np.sqrt(T)
a = 0.1
Res = {p: {} for p in ps}
# Create the output directory up front so a missing folder cannot fail the run after training.
os.makedirs('results', exist_ok=True)
for p in ps:
  print("p=",p)
  # The dataset and loader depend only on p, so they are built once per p;
  # NaiveAlg reads `train_dataloaderH` as a global.
  training_data_Tlabels = datasets.MNIST(
                                 root="data",
                                 train=True,
                                 download=True,
                                 transform=Compose([ToTensor()]),
                                 target_transform = Lambda(Tlabels(p))
                                          )
  train_dataloaderH = DataLoader(training_data_Tlabels, batch_size=64,shuffle=True)
  for t in range(trials):
     print("---------t=",t)
     model = NeuralNetwork().to(device)
     Res[p][t] = NaiveAlg(K,lrs,a, model, loss_fn,epochs)
  # NOTE(review): the whole Res dict (all p values so far) is written to each per-p file,
  # not just Res[p] — left unchanged in case downstream loaders rely on it; confirm intent.
  with open(f'results/MnistExpNaive_Labels{p}.pkl', 'wb') as f:
     pickle.dump(Res, f)