In [1]:
import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms.functional import rotate

%matplotlib inline

### Define a transformation that rotates images by a given angle

In [2]:
class ROTATE:
    """Image transform that rotates every image by one fixed angle.

    Instances are callables, so they can be chained inside a torchvision
    ``Compose``; the work is delegated to torchvision's functional ``rotate``.
    """

    def __init__(self, angle):
        # Forwarded unchanged to ``rotate(..., angle=...)`` (degrees, per the
        # torchvision documentation).
        self.angle = angle

    def __call__(self, img):
        rotated = rotate(img, angle=self.angle)
        return rotated

### Dataset

In [3]:
# Load both MNIST splits as tensors. With download=True, torchvision fetches the
# files into ./data only if they are not already present.
training_data, test_data = (
    datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 10098133.29it/s]


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 26460396.20it/s]


Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 8623026.43it/s]


Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]

Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw






In [4]:
# Original (unrotated) MNIST training split, reshuffled at every pass.
train_dataloaderF = DataLoader(dataset=training_data, batch_size=256, shuffle=True)
# Test split, read in fixed order in batches of 64.
test_dataloader = DataLoader(dataset=test_data, batch_size=64)

In [5]:
# Run on the first CUDA GPU if there is one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


### Model: a simple feed-forward neural network (FFNN)

In [6]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 512 -> 10, with ReLU between layers.

    ``forward`` flattens every dimension after the batch dimension (28*28 = 784
    inputs are expected) and returns the raw, unnormalised class scores (logits).
    """

    def __init__(self):
        super().__init__()
        hidden = 512
        self.flatten = nn.Flatten()
        # Layer order (and hence the state_dict keys linear_relu_stack.{0,2,4})
        # matters for saved checkpoints and for seeded initialisation.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))

### Loss

In [7]:
# Cross-entropy on the model's logits, averaged over the batch
# ("mean" is the default reduction, written out for clarity).
loss_fn = nn.CrossEntropyLoss(reduction="mean")

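### Implementation of AUXMOM

Training interleaves batches from the original MNIST training set (`train_dataloaderF`) with local steps on batches from a rotated copy (`train_dataloaderH`). Two versions follow. The second cell redefines `AUXMOM`, so only that version is used by the experiments below.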

In [14]:
def compute_grad(model, X, y, device=device):
    """Gradient snapshot of the module-level ``loss_fn`` on one batch.

    Moves ``X`` and ``y`` to ``device``, computes ``loss_fn(model(X), y)``,
    clears ``model``'s gradients and back-propagates, leaving ``model``'s
    ``.grad`` fields populated with this batch's gradients.

    Args:
        model: network to differentiate.
        X: input batch.
        y: targets for ``X``.
        device: where the batch is moved. Defaults to the notebook-wide
            ``device`` as it was when this cell was run.

    Returns:
        ``(grads, loss)``. ``grads`` has one entry per tensor in
        ``model.parameters()``, in the same order (``None`` for a parameter
        that received no gradient). ``loss`` is the scalar loss tensor.

    The gradients are detached *copies*, not the ``param.grad`` objects
    themselves. If the buffers were returned directly, a later
    ``zero_grad()`` / ``backward()`` on the same model could silently
    overwrite tensors the caller still holds. ``zero_grad`` zeroes in place
    on PyTorch < 2.0, and whenever ``set_to_none=False`` is used.
    """
    X, y = X.to(device), y.to(device)
    loss = loss_fn(model(X), y)
    model.zero_grad()
    loss.backward()
    grads = [
        None if param.grad is None else param.grad.detach().clone()
        for param in model.parameters()
    ]
    return grads, loss

In [17]:
def AUXMOM(K, lrs, a, model, loss_fn, epochs):
    """First AUXMOM variant: local H steps corrected with an SVRG-style anchor.

    NOTE(review): a later cell redefines ``AUXMOM`` with a different update
    rule. Under Restart & Run All this version is shadowed and never called by
    the experiment cell; rename it if it should stay reachable.

    For every batch of ``train_dataloaderF``:
      1. snapshot the current weights into ``backup`` (the anchor);
      2. fold the F-batch gradient ``g_f`` into ``m_f``. ``m_f`` is reset every
         epoch: on the first batch ``m_f = g_f``, afterwards
         ``m_f = (1-a)*m_f + a*g_f``;
      3. take K local steps, each on a fresh batch of ``train_dataloaderH``
         (restarted when exhausted):
         ``param -= lr * (grad_H(model) - grad_H(backup) + m_f)``.
    Every 50 F-batches, the latest local-step loss is recorded and the model is
    evaluated on ``test_dataloader``.

    Reads the notebook globals ``train_dataloaderF``, ``train_dataloaderH``,
    ``test_dataloader``, ``device`` and ``compute_grad``.

    Args:
        K: number of local H steps per F batch; must be >= 1.
        lrs: step sizes, indexed by the global F step
            ``t * len(train_dataloaderF) + batch``.
        a: weight of the newest gradient in the moving average ``m_f``.
        model: network, trained in place.
        loss_fn: criterion used for test-set evaluation. The training gradients
            come from ``compute_grad``, which uses the module-level ``loss_fn``.
        epochs: number of passes over ``train_dataloaderF``.

    Returns:
        dict with lists ``"train_loss"``, ``"test_loss"`` and ``"test_acc"``,
        with one entry per evaluation.

    Raises:
        ValueError: if ``K < 1``; there would then be no local-step loss to report.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    Res = {"train_loss": [], "test_loss": [], "test_acc": []}
    n_steps_per_epoch = len(train_dataloaderF)
    h_iter = None

    def next_h_batch():
        # Next batch of train_dataloaderH, restarting the loader when it runs out.
        nonlocal h_iter
        try:
            return next(h_iter)
        except StopIteration:
            h_iter = iter(train_dataloaderH)
            return next(h_iter)

    def evaluate():
        # Mean of per-batch test losses (a smaller final batch counts as much as
        # a full one) and overall accuracy on test_dataloader.
        total_loss, n_correct = 0.0, 0.0
        with torch.no_grad():
            for X_test, y_test in test_dataloader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                pred = model(X_test)
                total_loss += loss_fn(pred, y_test).item()
                n_correct += (pred.argmax(1) == y_test).type(torch.float).sum().item()
        return total_loss / len(test_dataloader), n_correct / len(test_dataloader.dataset)

    for t in range(epochs):
        m_f = [None for _ in model.parameters()]
        h_iter = iter(train_dataloaderH)
        for batch, (X, y) in enumerate(train_dataloaderF):
            lr = lrs[t * n_steps_per_epoch + batch]
            # Frozen copy of the weights at the start of this round (the anchor).
            backup = copy.deepcopy(model).to(device)

            g_f, _ = compute_grad(model, X, y, device=device)
            with torch.no_grad():
                for i, g in enumerate(g_f):
                    # Clone so m_f never aliases the model's .grad buffers, which
                    # later backward passes may overwrite in place.
                    g = g.detach().clone()
                    m_f[i] = g if batch == 0 else (1 - a) * m_f[i] + a * g

            for _ in range(K):
                X, y = next_h_batch()  # compute_grad moves the batch to `device`
                grad_hx, _ = compute_grad(backup, X, y, device=device)
                grad_hy, loss = compute_grad(model, X, y, device=device)
                with torch.no_grad():
                    for param, g_y, g_x, m in zip(model.parameters(), grad_hy, grad_hx, m_f):
                        param -= lr * (g_y - g_x + m)

            if batch % 50 == 0:
                Res["train_loss"].append(loss.item())
                test_loss, test_acc = evaluate()
                Res["test_loss"].append(test_loss)
                Res["test_acc"].append(test_acc)
    return Res

In [22]:
def AUXMOM(K, lrs, a, model, loss_fn, epochs):
    """AUXMOM: local H steps corrected by a momentum of (F gradient - H gradient).

    For every batch of ``train_dataloaderF``:
      1. at the current weights, compute ``g_f`` on the F batch and ``g_h`` on
         one batch of ``train_dataloaderH``;
      2. fold ``g_f - g_h`` into the momentum ``m``. ``m`` is reset every epoch:
         on the first batch ``m = g_f - g_h``, afterwards
         ``m = (1-a)*m + a*(g_f - g_h)``;
      3. take K local steps, each on a fresh H batch:
         ``param -= lr * (grad_H + m)``.
    H batches are drawn from ``train_dataloaderH``, which is restarted when
    exhausted. Every 50 F-batches, the latest local-step loss is recorded and
    the model is evaluated on ``test_dataloader``.

    Reads the notebook globals ``train_dataloaderF``, ``train_dataloaderH``
    (assigned in the experiment cell), ``test_dataloader``, ``device`` and
    ``compute_grad``.

    Args:
        K: number of local H steps per F batch; must be >= 1.
        lrs: step sizes, indexed by the global F step
            ``t * len(train_dataloaderF) + batch``.
        a: weight of the newest ``g_f - g_h`` in the moving average ``m``.
        model: network, trained in place.
        loss_fn: criterion used for test-set evaluation. The training gradients
            come from ``compute_grad``, which uses the module-level ``loss_fn``.
        epochs: number of passes over ``train_dataloaderF``.

    Returns:
        dict with lists ``"train_loss"``, ``"test_loss"`` and ``"test_acc"``,
        with one entry per evaluation.

    Raises:
        ValueError: if ``K < 1``; there would then be no local-step loss to report.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    Res = {"train_loss": [], "test_loss": [], "test_acc": []}
    n_steps_per_epoch = len(train_dataloaderF)
    h_iter = None

    def next_h_batch():
        # Next batch of train_dataloaderH, restarting the loader when it runs out.
        nonlocal h_iter
        try:
            return next(h_iter)
        except StopIteration:
            h_iter = iter(train_dataloaderH)
            return next(h_iter)

    def evaluate():
        # Mean of per-batch test losses (a smaller final batch counts as much as
        # a full one) and overall accuracy on test_dataloader.
        total_loss, n_correct = 0.0, 0.0
        with torch.no_grad():
            for X_test, y_test in test_dataloader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                pred = model(X_test)
                total_loss += loss_fn(pred, y_test).item()
                n_correct += (pred.argmax(1) == y_test).type(torch.float).sum().item()
        return total_loss / len(test_dataloader), n_correct / len(test_dataloader.dataset)

    for t in range(epochs):
        m = [None for _ in model.parameters()]
        h_iter = iter(train_dataloaderH)
        for batch, (X, y) in enumerate(train_dataloaderF):
            lr = lrs[t * n_steps_per_epoch + batch]

            g_f, _ = compute_grad(model, X, y, device=device)
            # Snapshot g_f before the next backward pass on the same model.
            # Otherwise it may alias model.grad and be overwritten by g_h,
            # making g_f - g_h identically zero.
            g_f = [g.detach().clone() for g in g_f]
            X, y = next_h_batch()  # compute_grad moves the batch to `device`
            g_h, _ = compute_grad(model, X, y, device=device)

            with torch.no_grad():
                for i, (gf, gh) in enumerate(zip(g_f, g_h)):
                    diff = gf - gh
                    m[i] = diff if batch == 0 else (1 - a) * m[i] + a * diff

            for _ in range(K):
                X, y = next_h_batch()
                grad_hy, loss = compute_grad(model, X, y, device=device)
                with torch.no_grad():
                    for param, g, m_i in zip(model.parameters(), grad_hy, m):
                        param -= lr * (g + m_i)

            if batch % 50 == 0:
                Res["train_loss"].append(loss.item())
                test_loss, test_acc = evaluate()
                Res["test_loss"].append(test_loss)
                Res["test_acc"].append(test_acc)
    return Res

In [23]:
Ks = [1, 5, 10]
Angles = [0, 45, 90, 180]
trials = 1
epochs = 10
SEED = 0  # trial t of every (angle, K) run starts from torch.manual_seed(SEED + t)
RESULTS_DIR = "results"

T = np.arange(len(train_dataloaderF) * epochs) + 1
# One step size per F batch over all epochs: constant 1e-2
# (alternative decaying schedule: 1e-1 / np.sqrt(T)).
lrs = 1e-2 * np.ones_like(T)
a = 0.1  # weight of the newest gradient in AUXMOM's moving average

# Create the output directory up front, so the pickle.dump below cannot fail
# only after all training has finished.
os.makedirs(RESULTS_DIR, exist_ok=True)

for angle in Angles:
    print("Angle=", angle)
    Res = {K: {} for K in Ks}
    training_data_rotated = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=Compose([ToTensor(), ROTATE(angle)]),
    )
    # AUXMOM reads this notebook global as its H loader.
    train_dataloaderH = DataLoader(training_data_rotated, batch_size=64, shuffle=True)
    for K in Ks:
        print("K=", K)
        for t in range(trials):
            print("---------t=", t)
            # Seeds weight init and DataLoader shuffling, so runs are
            # reproducible and different K share the same init per trial.
            torch.manual_seed(SEED + t)
            model = NeuralNetwork().to(device)
            Res[K][t] = AUXMOM(K, lrs, a, model, loss_fn, epochs)
    with open(os.path.join(RESULTS_DIR, f"MnistExp_angle{angle}.pkl"), "wb") as f:
        pickle.dump(Res, f)

Angle= 0
K= 1
---------t= 0
K= 5
---------t= 0
K= 10
---------t= 0
Angle= 45
K= 1
---------t= 0
K= 5
---------t= 0
K= 10
---------t= 0
Angle= 90
K= 1
---------t= 0
K= 5
---------t= 0
K= 10
---------t= 0
Angle= 180
K= 1
---------t= 0
K= 5
---------t= 0
K= 10
---------t= 0


### Naive algorithm implementation

In [None]:
def NaiveAlg(K, lrs, a, model, loss_fn, epochs):
    """Naive baseline: one F-momentum step followed by plain gradient steps on H.

    For every batch of ``train_dataloaderF``:
      1. compute the F-batch gradient and fold it into ``m_f``. ``m_f`` is reset
         every epoch: on the first batch ``m_f = g_f``, afterwards
         ``m_f = (1-a)*m_f + a*g_f``;
      2. draw ONE batch from ``train_dataloaderH``, restarting the loader when
         it is exhausted;
      3. take K local steps on that same H batch. Step 0 moves by ``-lr * m_f``;
         steps 1..K-1 move by ``-lr * grad_H``.
    NOTE(review): unlike AUXMOM, all K local steps reuse a single H batch;
    confirm this is intended.

    Every 50 F-batches, the most recent training loss is recorded and the model
    is evaluated on ``test_dataloader``. The recorded loss is the H loss at the
    last local step (before its update), or the F loss when ``K == 0``.

    Reads the notebook globals ``train_dataloaderF``, ``train_dataloaderH``
    (assigned in the experiment cell), ``test_dataloader`` and ``device``.

    Args:
        K: number of local steps per F batch.
        lrs: step sizes, indexed by the global F step
            ``t * len(train_dataloaderF) + batch``.
        a: weight of the newest gradient in the moving average ``m_f``.
        model: network, trained in place.
        loss_fn: criterion used for both training and evaluation.
        epochs: number of passes over ``train_dataloaderF``.

    Returns:
        dict with lists ``"train_loss"``, ``"test_loss"`` and ``"test_acc"``,
        with one entry per evaluation.
    """
    Res = {"train_loss": [], "test_loss": [], "test_acc": []}
    n_steps_per_epoch = len(train_dataloaderF)
    h_iter = None

    def next_h_batch():
        # Next batch of train_dataloaderH, restarting the loader when it runs out.
        nonlocal h_iter
        try:
            return next(h_iter)
        except StopIteration:
            h_iter = iter(train_dataloaderH)
            return next(h_iter)

    def evaluate():
        # Mean of per-batch test losses (a smaller final batch counts as much as
        # a full one) and overall accuracy on test_dataloader.
        total_loss, n_correct = 0.0, 0.0
        with torch.no_grad():
            for X_test, y_test in test_dataloader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                pred = model(X_test)
                total_loss += loss_fn(pred, y_test).item()
                n_correct += (pred.argmax(1) == y_test).type(torch.float).sum().item()
        return total_loss / len(test_dataloader), n_correct / len(test_dataloader.dataset)

    for t in range(epochs):
        m_f = [None for _ in model.parameters()]
        h_iter = iter(train_dataloaderH)
        for batch, (X, y) in enumerate(train_dataloaderF):
            lr = lrs[t * n_steps_per_epoch + batch]

            # 1. F gradient -> momentum.
            X, y = X.to(device), y.to(device)
            loss = loss_fn(model(X), y)
            model.zero_grad()
            loss.backward()
            with torch.no_grad():
                for i, param in enumerate(model.parameters()):
                    # Clone: param.grad is reused by later backward passes
                    # (zeroed in place on PyTorch < 2.0), so m_f must not alias it.
                    g = param.grad.detach().clone()
                    m_f[i] = g if batch == 0 else (1 - a) * m_f[i] + a * g

            # 2. One H batch, shared by all K local steps.
            X, y = next_h_batch()
            X, y = X.to(device), y.to(device)

            # 3. K local steps.
            for k in range(K):
                if k == 0:
                    # The step direction is m_f, so no H gradient is needed; the
                    # H loss is only kept for reporting.
                    with torch.no_grad():
                        loss = loss_fn(model(X), y)
                        for param, m in zip(model.parameters(), m_f):
                            param -= lr * m
                else:
                    loss = loss_fn(model(X), y)
                    model.zero_grad()
                    loss.backward()
                    with torch.no_grad():
                        for param in model.parameters():
                            param -= lr * param.grad

            if batch % 50 == 0:
                Res["train_loss"].append(loss.item())
                test_loss, test_acc = evaluate()
                Res["test_loss"].append(test_loss)
                Res["test_acc"].append(test_acc)
    return Res

In [None]:
Ks = [1, 5, 10]
Angles = [0, 45, 90, 180]
# NOTE(review): the AUXMOM experiment cell uses trials = 1; confirm the
# mismatch with this baseline is intended.
trials = 10
epochs = 10
SEED = 0  # trial t of every (angle, K) run starts from torch.manual_seed(SEED + t)
RESULTS_DIR = "results"

T = np.arange(len(train_dataloaderF) * epochs) + 1
# One step size per F batch over all epochs: constant 1e-2
# (alternative decaying schedule: 1e-1 / np.sqrt(T)).
lrs = 1e-2 * np.ones_like(T)
a = 0.1  # weight of the newest gradient in NaiveAlg's moving average

# Create the output directory up front, so the pickle.dump below cannot fail
# only after all training has finished.
os.makedirs(RESULTS_DIR, exist_ok=True)

for angle in Angles:
    print("Angle=", angle)
    Res = {K: {} for K in Ks}
    training_data_rotated = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=Compose([ToTensor(), ROTATE(angle)]),
    )
    # NaiveAlg reads this notebook global as its H loader.
    train_dataloaderH = DataLoader(training_data_rotated, batch_size=64, shuffle=True)
    for K in Ks:
        print("K=", K)
        for t in range(trials):
            print("---------t=", t)
            # Seeds weight init and DataLoader shuffling, so runs are
            # reproducible and comparable trial-by-trial with the AUXMOM runs.
            torch.manual_seed(SEED + t)
            model = NeuralNetwork().to(device)
            Res[K][t] = NaiveAlg(K, lrs, a, model, loss_fn, epochs)
    with open(os.path.join(RESULTS_DIR, f"MnistExpNaiveangle{angle}.pkl"), "wb") as f:
        pickle.dump(Res, f)