In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import copy
import time
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
#from torch.utils.tensorboard import SummaryWriter
from torch import linalg as la

# Small LeNet-style CNN used to classify MNIST digits.
class MNIST_net(nn.Module):
    """Two convolutional stages followed by three fully connected layers.

    ``forward`` flattens the convolutional features to ``16*5*5`` values per
    sample, which is what a 1x32x32 input produces after the two
    (5x5 conv -> 2x2 max-pool) stages.

    Args:
        classes: number of output logits (default 10).
    """

    def __init__(self, classes = 10):
        super().__init__()
        self.classes = classes

        def conv_stage(c_in, c_out):
            # 5x5 convolution -> batch norm -> ReLU -> 2x2 max-pool
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 5),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        def dense_stage(f_in, f_out):
            # linear layer followed by ReLU
            return nn.Sequential(nn.Linear(f_in, f_out), nn.ReLU())

        # Modules are created in the same order as their attribute names so
        # that seeded initialisation and state-dict keys stay stable.
        self.conv1 = conv_stage(1, 6)
        self.conv2 = conv_stage(6, 16)
        self.fc1 = dense_stage(16 * 5 * 5, 120)
        self.fc2 = dense_stage(120, 84)
        self.fc3 = nn.Linear(84, self.classes)

    def forward(self, x):
        """Return the raw class logits for the image batch ``x``."""
        features = self.conv2(self.conv1(x))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = self.fc2(self.fc1(flat))
        return self.fc3(hidden)


def zero_grad(params):
    """Detach and zero the gradients of ``params`` in place.

    Args:
        params: iterable of tensors (e.g. ``list(model.parameters())``).
            Entries whose ``.grad`` is ``None`` are left untouched.
    """
    for p in params:
        if p.grad is not None:
            # `detach_()` is the in-place variant. The out-of-place `detach()`
            # returns a new tensor and would leave `p.grad` attached to any
            # autograd graph it belongs to.
            p.grad.detach_()
            p.grad.zero_()

def total_loss(model, loss_fns, dataloader, n):
    """Accumulate ``loss_fns(model(x), y)`` over ``dataloader`` and scale by ``1/n``.

    Args:
        model: callable mapping an input batch to predictions.
        loss_fns: callable ``(out, y) -> scalar tensor``.
        dataloader: iterable of ``(x, y)`` batches.
        n: divisor for the accumulated loss (callers in this notebook pass
            the dataset size together with a sum-reduced loss).

    Returns:
        float: the sum of the per-batch losses multiplied by ``1/n``.
    """
    running = 0
    # Evaluation only: never build an autograd graph, even when the caller
    # forgets to wrap the call in torch.no_grad().
    with torch.no_grad():
        for x, y in dataloader:
            running = running + loss_fns(model(x), y).item()
    return running * (1/n)

def test_loss(model, loss_fns, test_dataloader, n):
    test_loss = 0
    for x, y in test_dataloader:
        out = model(x)
        loss = loss_fns(out, y)
        test_loss = test_loss + loss.item()
    return test_loss * (1/n)
        
def total_grad(model, loss_fns, dataloader, n):
    """Euclidean norm of the gradient accumulated over ``dataloader``, scaled by ``1/n``.

    Gradients are zeroed, ``loss_fns(model(x), y)`` is back-propagated for
    every batch so the per-batch gradients accumulate in ``p.grad``, and the
    result is ``sqrt(sum_p ||p.grad * (1/n)||^2)``. Gradients are zeroed again
    before returning.

    Returns:
        0-dim tensor holding the norm.
    """
    zero_grad(list(model.parameters()))
    for x, y in dataloader:
        loss_fns(model(x), y).backward()

    sq_norm = 0
    for p in list(model.parameters()):
        scaled = torch.mul(torch.clone(p.grad.data).detach(), (1/n))
        sq_norm = sq_norm + scaled.square().sum()

    zero_grad(list(model.parameters()))
    return torch.sqrt(torch.clone(sq_norm).detach())

def full_grad(model, loss_fns, dataloader, n):
    """Gradient accumulated over ``dataloader``, scaled by ``1/n``.

    Gradients are zeroed, ``loss_fns(model(x), y)`` is back-propagated for
    every batch, and a detached copy of each ``p.grad`` multiplied by ``1/n``
    is collected. Gradients are zeroed again before returning.

    Returns:
        list of tensors, one per entry of ``model.parameters()`` and in the
        same order (these are gradients, not parameter values).
    """
    zero_grad(list(model.parameters()))
    for x, y in dataloader:
        loss_fns(model(x), y).backward()

    scaled_grads = [
        torch.mul(torch.clone(p.grad.data).detach(), (1/n))
        for p in list(model.parameters())
    ]

    zero_grad(list(model.parameters()))
    return scaled_grads

def get_default_device():
    """Return the ``cuda`` device when CUDA is available, otherwise ``cpu``."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are processed element by element (recursively) and are
    always returned as a ``list``; anything else is moved with
    ``data.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Wrap an iterable of batches so every batch is moved to ``device`` on iteration."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        # Same length as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        # Lazily yield each batch of the wrapped loader after moving it.
        yield from (to_device(batch, self.device) for batch in self.dl)
    
# ---- hyperparameters shared by the experiment cells ----
weight_decay = 1e-4  # coefficient of the L2 term added to the gradients in the training cells
nabla = 1e-4         # not referenced in the cells visible here
rbs = 1000           # batch size of the two evaluation loaders
epoches = 100        # number of passes over the training loader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _resize_and_tensorize():
    """Fresh [Resize to 32x32, ToTensor] steps used by both pipelines."""
    return [transforms.Resize((32,32)), transforms.ToTensor()]


# ---- input pipelines ----
# The training pipeline additionally applies a random horizontal flip.
train_trans = transforms.Compose([transforms.RandomHorizontalFlip(), *_resize_and_tensorize()])
test_trans = transforms.Compose(_resize_and_tensorize())

# ---- datasets (downloaded into 'test/' when missing) ----
MNIST_train = MNIST(root='test/', train = True, download=True, transform=train_trans)
MNIST_test = MNIST(root='test/', train = False, download=True, transform=test_trans)
n, n1 = len(MNIST_train), len(MNIST_test)

# ---- unshuffled loaders used for loss / gradient-norm evaluation ----
# NOTE(review): MNIST_train carries train_trans, so the evaluation pass over
# the training set also sees random flips - confirm this is intended.
MNIST_train_loader_eva = DeviceDataLoader(
    DataLoader(MNIST_train, shuffle = False, batch_size = rbs), device)
MNIST_test_loader_eva = DeviceDataLoader(
    DataLoader(MNIST_test, shuffle = False, batch_size = rbs), device)

# Sum-reduced loss: the evaluation helpers divide the accumulated value by n themselves.
loss_func_rec = nn.CrossEntropyLoss(reduction='sum')

In [None]:
#training stage: mnist training with l^2 regularization using ggd-adamas (two models version)
#
# Every step draws b*m samples and splits them into two half-batches of m
# samples. Each coordinate of the "grafted" gradient takes the importance
# re-weighted gradient of half-batch 0 with probability prob[0] and that of
# half-batch 1 otherwise; the grafted gradient then drives an Adam-style update.

#stage one: preparation, initialization and hyperparameter setting
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


CNN_MNIST_net = MNIST_net()
CNN_MNIST_net = CNN_MNIST_net.to(device)
loss_func = nn.CrossEntropyLoss()
#GPU TRAINING if possible
# copied_model receives the backward pass of the second half-batch, so the two
# half-batch gradients live in separate .grad buffers.
copied_model = copy.deepcopy(CNN_MNIST_net)
copied_model.to(device)

ilr0 = 1e-4  # learning rate used in the first epoch
flr0 = 1e-6  # learning rate reached in the last epoch
lr_schedule = 'Poly'
b = 2    # half-batches per step (the update below is written for exactly two)
m = 128  # samples per half-batch
MNIST_gadamas_loss_list = []
MNIST_gadamas_gradnorm_list = []
MNIST_gadamas_test_loss_list = []
MNIST_gadamas_iteration_list = []

# Adam constants: decay rates of the first/second moment and the term added
# to the denominator.
beta_1 = torch.tensor(0.9)
beta_2 = torch.tensor(0.999)
sigma = torch.tensor(1e-8)

#stage two: load training set

MNIST_train_loader = DataLoader(MNIST_train, shuffle = True, batch_size = b*m, drop_last = True)
MNIST_train_loader = DeviceDataLoader(MNIST_train_loader, device)


batch_idx = torch.tensor(0)  # number of completed steps, used for Adam bias correction
#stage three: train and test
# first (h_0) and second (v_0) moment estimates, one tensor per parameter
h_0 = [torch.zeros_like(paras) for paras in list(CNN_MNIST_net.parameters())]
v_0 = [torch.zeros_like(paras) for paras in list(CNN_MNIST_net.parameters())]
for epoch in range(epoches):
    CNN_MNIST_net.train()

    # The learning rate only depends on the epoch, so it is computed once per
    # epoch. 'Poly' interpolates linearly from ilr0 (first epoch) to flr0
    # (last epoch); max(..., 1) avoids a division by zero when epoches == 1.
    if lr_schedule == 'Poly':
        lr = ilr0 + ((flr0 - ilr0)/max(epoches - 1, 1)) * int(epoch)
    else:
        # Any other schedule keeps the initial rate (the previous code read an
        # undefined name `lr0` here and raised NameError).
        lr = ilr0

    for x_data, y_target in MNIST_train_loader:
        losst = torch.empty(2)
        xt = x_data.split(m, dim = 0)
        yt = y_target.split(m, dim = 0)
        #calculate losses for first to derive the resampling probability
        for i, x in enumerate(xt):
            with torch.no_grad():
                output = CNN_MNIST_net(x)
                losst[i] = loss_func(output, yt[i]).item()
        # prob[i] is proportional to the loss of half-batch i.
        # NOTE(review): if one half-batch loss is exactly 0, 1/prob below is inf
        # and the masked zeros turn into NaN - confirm this cannot happen.
        prob = losst/torch.sum(losst)
        zero_grad(list(CNN_MNIST_net.parameters()))
        zero_grad(list(copied_model.parameters()))
        #construct the adam-based grafting gradient
        output1 = CNN_MNIST_net(xt[0])
        loss1 = loss_func(output1, yt[0])
        loss1.backward()
        output2 = copied_model(xt[1])
        loss2 = loss_func(output2, yt[1])
        loss2.backward()
        for j, (p1, p2) in enumerate(zip(list(CNN_MNIST_net.parameters()), list(copied_model.parameters()))):
            d_p1 = p1.grad.data
            d_p2 = p2.grad.data
            if weight_decay != 0:
                # L2 regularisation: add weight_decay * parameter to each gradient
                d_p1.add_(p1.data, alpha = weight_decay)
                d_p2.add_(p2.data, alpha = weight_decay)
            # Per-coordinate Bernoulli(prob[0]) mask: True keeps the entry of
            # half-batch 0, False keeps the entry of half-batch 1.
            indices = torch.zeros_like(d_p1)
            indices = indices.bernoulli_(p = prob[0]).to(torch.bool)
            d_p1.masked_fill_(~indices, 0)
            d_p2.masked_fill_(indices, 0)
            # importance re-weighting by 1/(b * prob[i])
            d_p1.mul_(1/b).mul_(1/prob[0])
            d_p2.mul_(1/b).mul_(1/prob[1])
            ggd_1 = d_p1 + d_p2
            # Adam moment updates (in place on h_0[j] / v_0[j])
            exp_avg = h_0[j]
            exp_avg_sq = v_0[j]
            exp_avg.mul_(beta_1).add_(ggd_1, alpha = 1 - beta_1)
            exp_avg_sq.mul_(beta_2).addcmul_(ggd_1, ggd_1.conj(), value = 1 - beta_2)
            bias_correction1 = 1 - torch.pow(beta_1, (batch_idx+1))
            bias_correction2 = 1 - torch.pow(beta_2, (batch_idx+1))
            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()
            bias_correction2_sqrt = bias_correction2.sqrt()
            # The negated step size is folded into the denominator, so the
            # addcdiv_ below amounts to
            #   p1 -= step_size * exp_avg / (sqrt(exp_avg_sq)/sqrt(bc2) + sigma)
            denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(sigma / step_size_neg)
            p1.data.addcdiv_(exp_avg, denom)
        # Re-synchronise the helper model with the freshly updated parameters.
        copied_model = copy.deepcopy(CNN_MNIST_net)
        copied_model.to(device)
        batch_idx += 1
    # end of epoch: record gradient norm, training loss and test loss
    CNN_MNIST_net.eval()
    current_gradnorm = total_grad(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
    MNIST_gadamas_gradnorm_list.append(current_gradnorm)
    with torch.no_grad():
        current_loss = total_loss(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
        MNIST_gadamas_loss_list.append(current_loss)
        MNIST_gadamas_test_loss_list.append(test_loss(CNN_MNIST_net, loss_func_rec, MNIST_test_loader_eva, n1))
        current_iteration =  epoch
    print('Iteration: {}  Loss: {}  Gradnorm:{}'.format(current_iteration, current_loss, current_gradnorm))

In [None]:
#training stage: mnist training with l^2 regularization using ggdas (two model version)
#
# Every step draws b*m samples and splits them into two half-batches of m
# samples. Each coordinate of the "grafted" gradient takes the importance
# re-weighted gradient of half-batch 0 with probability prob[0] and that of
# half-batch 1 otherwise; the parameters then take a plain gradient step.

#stage one: preparation, initialization and hyperparameter setting
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


CNN_MNIST_net = MNIST_net()
CNN_MNIST_net = CNN_MNIST_net.to(device)
loss_func = nn.CrossEntropyLoss()
#GPU TRAINING if possible
# copied_model receives the backward pass of the second half-batch, so the two
# half-batch gradients live in separate .grad buffers.
copied_model = copy.deepcopy(CNN_MNIST_net)
copied_model.to(device)

ilr0 = 0.01  # learning rate used in the first epoch
flr0 = 1e-5  # learning rate reached in the last epoch
lr_schedule = 'Poly'
b = 2    # half-batches per step (the update below is written for exactly two)
m = 128  # samples per half-batch

MNIST_ggdas_loss_list = []
MNIST_ggdas_gradnorm_list = []
MNIST_ggdas_test_loss_list = []
MNIST_ggdas_iteration_list = []


#stage two: load training set

MNIST_train_loader = DataLoader(MNIST_train, shuffle = True, batch_size = b*m, drop_last = True)
MNIST_train_loader = DeviceDataLoader(MNIST_train_loader, device)


#stage three: train and test

for epoch in range(epoches):
    CNN_MNIST_net.train()

    # The learning rate only depends on the epoch, so it is computed once per
    # epoch. 'Poly' interpolates linearly from ilr0 (first epoch) to flr0
    # (last epoch); max(..., 1) avoids a division by zero when epoches == 1.
    if lr_schedule == 'Poly':
        lr = ilr0 + ((flr0 - ilr0)/max(epoches - 1, 1)) * int(epoch)
    else:
        # Any other schedule keeps the initial rate (the previous code read an
        # undefined name `lr0` here and raised NameError).
        lr = ilr0

    for x_data, y_target in MNIST_train_loader:
        losst = torch.empty(2)
        xt = x_data.split(m, dim = 0)
        yt = y_target.split(m, dim = 0)
        #calculate losses for first to derive the resampling probability
        for i, x in enumerate(xt):
            with torch.no_grad():
                output = CNN_MNIST_net(x)
                losst[i] = loss_func(output, yt[i]).item()
        # prob[i] is proportional to the loss of half-batch i.
        # NOTE(review): if one half-batch loss is exactly 0, 1/prob below is inf
        # and the masked zeros turn into NaN - confirm this cannot happen.
        prob = losst/torch.sum(losst)
        zero_grad(list(CNN_MNIST_net.parameters()))
        zero_grad(list(copied_model.parameters()))
        #construct the grafting gradient
        output1 = CNN_MNIST_net(xt[0])
        output2 = copied_model(xt[1])
        loss1 = loss_func(output1, yt[0])
        loss2 = loss_func(output2, yt[1])
        loss1.backward()
        loss2.backward()
        for  p1, p2 in zip(list(CNN_MNIST_net.parameters()), list(copied_model.parameters())):
            d_p1 = p1.grad.data
            d_p2 = p2.grad.data
            if weight_decay != 0:
                # L2 regularisation: add weight_decay * parameter to each gradient
                d_p1.add_(p1.data, alpha = weight_decay)
                d_p2.add_(p2.data, alpha = weight_decay)
            # Per-coordinate Bernoulli(prob[0]) mask: True keeps the entry of
            # half-batch 0, False keeps the entry of half-batch 1.
            indices = torch.zeros_like(d_p1)
            indices = indices.bernoulli_(p = prob[0]).to(torch.bool)
            d_p1.masked_fill_(~indices, 0)
            d_p2.masked_fill_(indices, 0)
            # importance re-weighting by 1/(b * prob[i])
            d_p1.mul_(1/b).mul_(1/prob[0])
            d_p2.mul_(1/b).mul_(1/prob[1])
            # plain gradient step along the grafted gradient
            p1.data.add_(torch.add(d_p1, d_p2), alpha = -lr)
        # Re-synchronise the helper model with the freshly updated parameters.
        copied_model = copy.deepcopy(CNN_MNIST_net)
        copied_model.to(device)
    # end of epoch: record gradient norm, training loss and test loss
    CNN_MNIST_net.eval()
    current_gradnorm = total_grad(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
    MNIST_ggdas_gradnorm_list.append(current_gradnorm)
    with torch.no_grad():
        current_loss = total_loss(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
        MNIST_ggdas_loss_list.append(current_loss)
        MNIST_ggdas_test_loss_list.append(test_loss(CNN_MNIST_net, loss_func_rec, MNIST_test_loader_eva, n1))
        current_iteration =  epoch
    print('Iteration: {}  Loss: {}  Gradnorm:{}'.format(current_iteration, current_loss, current_gradnorm))
 

In [None]:
#training stage: mnist training with ggd-svrg-as
#
# Variance-reduced variant: for each half-batch the gradient is evaluated at
# the current model and at a snapshot model; the two half-batches are grafted
# per coordinate, and the snapshot's full gradient `fg` is added back.

#stage one: preparation, initialization and hyperparameter setting
CNN_MNIST_net = MNIST_net()
CNN_MNIST_net = CNN_MNIST_net.to(device)
Snap_model = copy.deepcopy(CNN_MNIST_net)
Snap_model.to(device)
loss_func = nn.CrossEntropyLoss()
#initial learning rate
ilr0 = 100/(n ** (2/3))
#final learning rate
flr0 = 1/(n ** (2/3))
#learning schedule
lr_schedule = 'Poly'
b = 2    # half-batches per step (the update below is written for exactly two)
m = 128  # samples per half-batch
MNIST_gsvrg_loss_list = []
MNIST_gsvrg_gradnorm_list = []
MNIST_gsvrg_test_loss_list = []

q = 3*int(n/m) #update frequency: the snapshot is refreshed every q steps


#stage two: load training set
MNIST_train_loader = DataLoader(MNIST_train, shuffle = True, batch_size = b*m, drop_last = True)
MNIST_train_loader = DeviceDataLoader(MNIST_train_loader, device)


#stage three: train and test
# Full gradient at the snapshot.
# NOTE(review): full_grad does not add the weight-decay term, while the
# mini-batch snapshot gradients below do - confirm this mismatch is intended.
fg = full_grad(Snap_model, loss_func_rec, MNIST_train_loader_eva, n)
batch_idx = 0
for epoch in range(epoches):
    CNN_MNIST_net.train()

    # The learning rate only depends on the epoch, so it is computed once per
    # epoch. 'Poly' interpolates linearly from ilr0 (first epoch) to flr0
    # (last epoch); max(..., 1) avoids a division by zero when epoches == 1.
    if lr_schedule == 'Poly':
        lr = ilr0 + ((flr0 - ilr0)/max(epoches - 1, 1)) * int(epoch)
    else:
        # Any other schedule keeps the initial rate (the previous code read an
        # undefined name `lr0` here and raised NameError).
        lr = ilr0

    for x_data, y_target in MNIST_train_loader:
        g0 = []     # current-model gradients, half-batch 0
        g1 = []     # current-model gradients, half-batch 1
        gf_r0 = []  # snapshot-model gradients, half-batch 0
        gf_r1 = []  # snapshot-model gradients, half-batch 1
        ggd = []    # grafted, variance-reduced gradient (one tensor per parameter)
        xt = x_data.split(m, dim = 0)
        yt = y_target.split(m, dim = 0)
        #construct the svrg-based grafting gradient
        #part one: prepare for g_mb(bar{x})
        for i, x in enumerate(xt):
            output = Snap_model(x)
            loss_snap = loss_func(output, yt[i])
            loss_snap.backward()
            snap_store = gf_r0 if i == 0 else gf_r1
            for j, p in enumerate(list(Snap_model.parameters())):
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha = weight_decay)
                snap_store.append(torch.clone(d_p).detach())
            zero_grad(list(Snap_model.parameters()))
        norm_2 = torch.zeros(2)
        #deriving the sampling probability and preparing for g_mb(x^k_s)
        for i, x in enumerate(xt):
            output = CNN_MNIST_net(x)
            loss = loss_func(output, yt[i])
            loss.backward()
            # same processing for both half-batches; only the target lists differ
            snap_grads, cur_store = (gf_r0, g0) if i == 0 else (gf_r1, g1)
            for z, p in zip(snap_grads, list(CNN_MNIST_net.parameters())):
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha = weight_decay)
                cur_store.append(torch.clone(d_p).detach())
                # accumulate ||snapshot gradient - current gradient||^2 for half-batch i
                norm_2[i] = norm_2[i] + torch.sum(torch.square(torch.add(z, torch.clone(d_p).detach(), alpha = -1)))
            zero_grad(list(CNN_MNIST_net.parameters()))
        # If either difference is exactly zero (e.g. right after a snapshot
        # refresh), fall back to equal probabilities.
        if torch.min(norm_2) == 0:
            norm_2 = torch.ones(2)
        prob = torch.sqrt(norm_2)/torch.sum(torch.sqrt(norm_2))

        #constructing the grafting gradient \tilde{g}^k_mb
        # qr/qo: snapshot/current gradients of half-batch 0,
        # pr/po: snapshot/current gradients of half-batch 1.
        # NOTE(review): unlike the two cells above, no 1/(b*prob) re-weighting
        # is applied after masking - confirm this is intended.
        for qr, qo, pr, po, fg_p in zip(gf_r0, g0, gf_r1, g1, fg):
            indices = torch.zeros_like(qr)
            indices = indices.bernoulli_(p = prob[0]).to(torch.bool)
            qr.masked_fill_(~indices, 0)
            qo.masked_fill_(~indices, 0)
            pr.masked_fill_(indices, 0)
            po.masked_fill_(indices, 0)
            ggd.append(torch.add(po, qo) - torch.add(pr, qr) + fg_p)

        #update!
        for g, p in zip(ggd, list(CNN_MNIST_net.parameters())):
            p.data = torch.add(p.data, g, alpha = -lr)
        batch_idx += 1
        #Every q steps: refresh the snapshot model and its full gradient
        if  batch_idx  % q == 0:
            Snap_model = copy.deepcopy(CNN_MNIST_net)
            fg = full_grad(Snap_model, loss_func_rec, MNIST_train_loader_eva, n)
    # end of epoch: record gradient norm, training loss and test loss
    CNN_MNIST_net.eval()
    current_gradnorm = total_grad(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
    MNIST_gsvrg_gradnorm_list.append(current_gradnorm)
    with torch.no_grad():
        current_loss = total_loss(CNN_MNIST_net, loss_func_rec, MNIST_train_loader_eva, n)
        MNIST_gsvrg_loss_list.append(current_loss)
        MNIST_gsvrg_test_loss_list.append(test_loss(CNN_MNIST_net, loss_func_rec, MNIST_test_loader_eva, n1))
        current_iteration = epoch
    print('Iteration: {}  Loss: {}  Gradnorm:{}'.format(current_iteration, current_loss, current_gradnorm))
