In [None]:
!git clone https://github.com/tranhp98/SGDHess.git

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(font_scale=1.4)
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer

from typing import List, Optional

import torchvision
from torchvision.models import resnet18, resnet34
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

from transformers import ResNetConfig, ResNetModel

import copy
from tqdm.notebook import tqdm

from IPython.display import clear_output
import time

# SGDHess

In [None]:
class SGDHess:
    """SGD whose search direction is corrected with a Hessian-vector product.

    The first call to :meth:`step` takes a plain gradient step and remembers
    the gradient as the initial direction ``d``.  Every later call updates

        d <- (1 - momentum) * (d + H(x_t) (x_t - x_{t-1})) + momentum * g_t

    where ``g_t`` / ``H(x_t)`` are the gradient / Hessian of the loss handed
    to :meth:`step` and ``x_t - x_{t-1}`` is the previous parameter update.
    The parameters are then moved by ``-lr * d``.

    Args:
        parameters: iterable of tensors to optimise; ``loss`` must be
            differentiable (twice, after the first step) with respect to them.
        lr: step size.
        momentum: weight of the fresh gradient in the direction update;
            ``1.`` reduces the method to plain gradient descent.
        clipping: if not ``None``, the global L2 norm of the direction is
            rescaled to at most this value before the parameters are moved
            (not applied on the very first step).
    """

    def __init__(self, parameters, lr, momentum=1., clipping=None):
        self.parameters = list(parameters)
        self.lr = lr
        self.momentum = momentum
        self.clipping = clipping

        # Both are populated by the first call to step().
        self.prev_params = None
        self.direction = None

    def step(self, loss):
        """Perform one update using ``loss`` (a scalar still attached to its graph).

        Returns ``None``; the parameters are modified in place.
        """
        if self.prev_params is None:
            # Bootstrap: no previous iterate yet, so use the raw gradient.
            grad = torch.autograd.grad(outputs=loss, inputs=self.parameters)
            self.prev_params = [param.detach().clone() for param in self.parameters]
            self.direction = [g.detach().clone() for g in grad]
            with torch.no_grad():
                for param, direction in zip(self.parameters, self.direction):
                    param.add_(direction, alpha=-self.lr)
            return

        # create_graph=True keeps the gradient differentiable so that a
        # Hessian-vector product can be taken from it below.
        grad = torch.autograd.grad(outputs=loss, inputs=self.parameters, create_graph=True)

        # x_t - x_{t-1}: the update applied by the previous call.
        displacement = [
            param.detach() - prev
            for param, prev in zip(self.parameters, self.prev_params)
        ]
        hvp = torch.autograd.grad(
            outputs=grad,
            inputs=self.parameters,
            grad_outputs=displacement,
        )
        self.prev_params = [param.detach().clone() for param in self.parameters]

        keep = 1 - self.momentum
        with torch.no_grad():
            # The Hessian term enters with the same sign whether or not
            # clipping is enabled.
            for direction, g, hv in zip(self.direction, grad, hvp):
                direction.mul_(keep).add_(hv, alpha=keep).add_(g, alpha=self.momentum)

            if self.clipping is not None:
                self._clip_direction()

            for param, direction in zip(self.parameters, self.direction):
                param.add_(direction, alpha=-self.lr)

    def _clip_direction(self):
        """Rescale ``self.direction`` in place so its global L2 norm is <= ``clipping``."""
        sq_norm = torch.zeros((), device=self.parameters[0].device)
        for direction in self.direction:
            sq_norm += (direction ** 2).sum()
        norm = torch.sqrt(sq_norm)

        if norm > self.clipping:
            scale = self.clipping / norm
            for direction in self.direction:
                direction.mul_(scale)
                
                
                

class SGD:
    """Gradient descent along an exponential moving average of gradients.

    The first call to :meth:`step` uses the raw gradient as the direction
    ``d``.  Every later call updates

        d <- (1 - momentum) * d + momentum * g_t

    and the parameters are moved by ``-lr * d``.

    Args:
        parameters: iterable of tensors to optimise.
        lr: step size.
        momentum: weight of the fresh gradient in the moving average;
            ``1.`` gives plain gradient descent.
        clipping: if not ``None``, the global L2 norm of the direction is
            rescaled to at most this value (not applied on the first step).
    """

    def __init__(self, parameters, lr, momentum=1., clipping=None):
        self.parameters = list(parameters)
        self.lr = lr
        self.momentum = momentum
        self.clipping = clipping

        # Populated by the first call to step().
        self.direction = None

    def step(self, loss):
        """Perform one update using ``loss`` (a scalar still attached to its graph).

        Returns ``None``; the parameters are modified in place.
        """
        # Only first-order information is used, so the gradient graph is not
        # retained (no create_graph).
        grad = torch.autograd.grad(outputs=loss, inputs=self.parameters)

        with torch.no_grad():
            if self.direction is None:
                self.direction = [g.detach().clone() for g in grad]
            else:
                keep = 1 - self.momentum
                for direction, g in zip(self.direction, grad):
                    direction.mul_(keep).add_(g, alpha=self.momentum)
                if self.clipping is not None:
                    self._clip_direction()

            for param, direction in zip(self.parameters, self.direction):
                param.add_(direction, alpha=-self.lr)

    def _clip_direction(self):
        """Rescale ``self.direction`` in place so its global L2 norm is <= ``clipping``."""
        sq_norm = torch.zeros((), device=self.parameters[0].device)
        for direction in self.direction:
            sq_norm += (direction ** 2).sum()
        norm = torch.sqrt(sq_norm)

        if norm > self.clipping:
            scale = self.clipping / norm
            for direction in self.direction:
                direction.mul_(scale)
                
            

In [None]:
class SGDHess_original(Optimizer):
    """SGD with momentum whose buffer is corrected by a Hessian-vector product.

    Before the usual momentum update the buffer receives ``H @ displacement``,
    where ``displacement`` is the update applied to the parameter by the
    previous call to :meth:`step`.  The product is obtained by differentiating
    ``p.grad``, so gradients must still carry their graph when :meth:`step`
    runs (e.g. ``loss.backward(create_graph=True)``).

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        momentum: momentum factor; ``0`` disables the buffer entirely.
        dampening: dampening applied to the fresh gradient.
        weight_decay: L2 penalty coefficient.
        nesterov: use a Nesterov-style look-ahead step.
        clip: if true, the momentum buffer of each parameter is rescaled so
            its norm does not exceed ``(1 - dampening) / (1 - momentum)``
            times the largest gradient norm seen so far for that parameter.

    Raises:
        ValueError: on a negative ``lr``, ``momentum`` or ``weight_decay``, or
            when ``nesterov`` is requested without momentum / with dampening.
    """

    def __init__(self, params, lr=1e-2, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, clip=False):
        if lr and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        self.clip = clip
        self.iteration = -1
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDHess_original, self).__init__(params, defaults)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            for p in group['params']:
                state = self.state[p]
                # Last update applied to p (zero before the first step).
                state['displacement'] = torch.zeros_like(p)
                # Running clipping bound for the momentum buffer of p.
                state['max_grad'] = torch.zeros(1, device=p.device)

    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        self.iteration += 1
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            params = [p for p in group['params'] if p.grad is not None]
            if not params:
                # Nothing to differentiate for this group.
                continue
            grads = [p.grad for p in params]
            displacements = [self.state[p]['displacement'] for p in params]

            # Hessian-vector product H @ displacement for every parameter.
            hvp = torch.autograd.grad(outputs=grads, inputs=params, grad_outputs=displacements)

            with torch.no_grad():
                for p, hv in zip(params, hvp):
                    state = self.state[p]
                    displacement, max_grad = state['displacement'], state['max_grad']

                    d_p = p.grad
                    if weight_decay != 0:
                        d_p = d_p.add(p, alpha=weight_decay)
                    if momentum != 0:
                        if 'momentum_buffer' not in state:
                            buf = state['momentum_buffer'] = torch.clone(d_p).detach()
                        else:
                            buf = state['momentum_buffer']
                            # weight_decay * displacement is the Hessian-vector
                            # product of the L2 penalty term.
                            buf.add_(hv).add_(displacement, alpha=weight_decay).mul_(momentum).add_(d_p, alpha=1 - dampening)
                            if self.clip:
                                # Refresh the bound first so that it already
                                # accounts for the current gradient (it is
                                # zero until the first update).
                                bound = (1 - dampening) / (1 - momentum) * torch.norm(d_p)
                                max_grad.copy_(torch.maximum(bound, max_grad))
                                # clip_grad_norm_ only touches ``.grad`` and is
                                # a no-op on a plain buffer, so rescale the
                                # buffer itself.
                                clip_coef = max_grad[0] / (torch.norm(buf) + 1e-6)
                                buf.mul_(torch.clamp(clip_coef, max=1.0))
                        if nesterov:
                            d_p = d_p.add(buf, alpha=momentum)
                        else:
                            d_p = buf
                    displacement.copy_(d_p).mul_(-group['lr'])
                    p.add_(displacement)

        return loss

In [None]:
class SGD_AdaptiveLR(Optimizer):
    """Hessian-corrected momentum SGD with a per-parameter adaptive step size.

    For every parameter the step size is ``k / (w + G_cumsum) ** (1/3)`` where
    ``G_cumsum`` accumulates that parameter's squared gradient norms with a
    delay of two steps.  The ``lr`` stored in the parameter group is not used
    by :meth:`step`.  As the Hessian-vector product is obtained by
    differentiating ``p.grad``, gradients must still carry their graph when
    :meth:`step` runs (e.g. ``loss.backward(create_graph=True)``).

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: stored in the group defaults; not used by :meth:`step`.
        momentum: factor applied to the corrected buffer (shared by all groups).
        dampening: dampening applied to the fresh gradient.
        c: stored in the group defaults; not used by :meth:`step`.
        w: offset added to the accumulated squared gradient norms.
        k: numerator of the adaptive step size.
        weight_decay: L2 penalty coefficient.
        nesterov: use a Nesterov-style look-ahead step.
        clip: stored on the optimizer; not used by :meth:`step`.

    Raises:
        ValueError: on a negative ``lr`` or ``weight_decay``.
    """

    def __init__(self, params, lr=1e-2, momentum=1., dampening=0, c = 0.01, w = 10, k = 0.1,
                 weight_decay=0, nesterov=False, clip=False, ):
        if lr and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        self.clip = clip
        self.iteration = -1
        self.c = c
        self.w = w
        self.momentum = momentum
        defaults = dict(lr=lr, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov,
                        c = c, w = w, k = k)

        super(SGD_AdaptiveLR, self).__init__(params, defaults)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            for p in group['params']:
                state = self.state[p]
                # Last update applied to p (zero before the first step).
                state['displacement'] = torch.zeros_like(p)
                state['max_grad'] = torch.zeros(1, device = p.device)
                # G_2 / G_1: squared gradient norms of the latest / previous
                # step; G_cumsum: sum of all older ones.
                state['G_1'] = torch.tensor(0., device = p.device)
                state['G_2'] = torch.tensor(0., device = p.device)
                state['G_cumsum'] = torch.tensor(0., device = p.device)

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        self.iteration += 1
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            dampening = group['dampening']
            nesterov = group['nesterov']
            w = group['w']
            k = group['k']

            params = [p for p in group['params'] if p.grad is not None]
            if not params:
                # Nothing to differentiate for this group.
                continue
            grads = [p.grad for p in params]
            displacements = [self.state[p]['displacement'] for p in params]

            # Hessian-vector product H @ displacement for every parameter.
            hvp = torch.autograd.grad(outputs=grads, inputs=params, grad_outputs=displacements)

            with torch.no_grad():
                for p, hv in zip(params, hvp):
                    state = self.state[p]
                    displacement = state['displacement']

                    d_p = p.grad
                    if weight_decay != 0:
                        d_p = d_p.add(p, alpha=weight_decay)

                    # Shift the two-step window of squared gradient norms and
                    # derive the step size from everything that left it.
                    state['G_cumsum'].add_(state['G_1'])
                    state['G_1'], state['G_2'] = state['G_2'], (d_p ** 2).sum()
                    state['lr'] = k / torch.pow(w + state['G_cumsum'], 1 / 3)
                    state['momentum'] = self.momentum

                    if 'momentum_buffer' not in state:
                        # First gradient seen for this parameter (also covers
                        # parameters that only receive gradients later on).
                        state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.add_(hv).add_(displacement, alpha = weight_decay).mul_(state['momentum']).add_(d_p, alpha = 1-dampening)
                        if nesterov:
                            d_p = d_p.add(buf, alpha=state['momentum'])
                        else:
                            d_p = buf

                    # Move along exactly the vector that is recorded as the
                    # displacement, so that the next Hessian-vector product is
                    # taken along the true parameter change.
                    displacement.copy_(d_p).mul_(-state['lr'])
                    p.add_(displacement)

        return loss

## Data loading and preparation, CIFAR-10 (resized to 64×64, the Tiny ImageNet resolution)

In [None]:
# Preprocessing applied to every image: resize to 64x64, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

In [None]:
# Directory handed to torchvision as the dataset root.
DATA_PATH = '../output/'

# CIFAR-10 training split (download=True), preprocessed with `transform`.
dataset = CIFAR10(
    DATA_PATH,
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Random 40000 / 5000 / 5000 partition of `dataset` into train, validation and test subsets.
train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    lengths=[40000, 5000, 5000],
)

In [None]:
batch_size = 300

# One loader per split; only the training loader shuffles its samples.
train_dataloader, val_dataloader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    for split, shuffle in ((train_data, True), (val_data, False), (test_data, False))
)

## Model creation

In [None]:
class CustomResNet20_c10(nn.Module):
    """Image classifier: a Hugging Face ``ResNetModel`` backbone plus a linear head.

    The backbone is built from a default ``ResNetConfig()`` with randomly
    initialised weights; its pooled output is flattened, passed through a
    ReLU and mapped to ``num_classes`` logits.

    Args:
        num_classes: number of output logits.  Defaults to 200, the value that
            was previously hard-coded, so existing checkpoints keep loading;
            pass 10 when training on CIFAR-10.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        config = ResNetConfig()
        self.embedder = ResNetModel(config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.ReLU(),
            # Width of the last backbone stage (2048 for the default config).
            torch.nn.Linear(config.hidden_sizes[-1], num_classes),
        )

    def forward(self, X):
        """Return the class logits for a batch of images ``X``."""
        # NOTE(review): pooler_output is presumably (batch, channels, 1, 1),
        # which is why it is flattened by the head -- confirm against the
        # transformers documentation.
        embeddings = self.embedder(X).pooler_output
        return self.classifier(embeddings)

## Model training, CIFAR-10

In [None]:
# Set to False to force CPU execution even when CUDA is available.
USE_GPU = True

dtype = torch.float32

# Use CUDA only when it is both requested and available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

print('using device:', device)

In [None]:
# Fresh network on the selected device, trained with the Hessian-corrected optimizer.
model = CustomResNet20_c10().to(device)
loss_function = torch.nn.CrossEntropyLoss()
custom_optimizer = SGDHess_original(
    model.parameters(),
    lr=1e-2,
    momentum=0.95,
    clip=True,
)
num_epochs = 200

# Per-batch training loss and its per-epoch mean.
train_loss, mean_train_loss = [], []
# Per-epoch accuracies (the *_5 lists are for top-5 accuracy).
train_accuracy, train_accuracy_5 = [], []
val_accuracy, val_aacuracy_5 = [], []
# Wall-clock duration of the training part of every epoch.
times = []
best_val_accuracy = 0

In [None]:
# Training / validation loop for the SGDHess_original run.
# Epoch of the most recent learning-rate reduction.
lr_change_point = 0

for epoch in range(num_epochs):
    start_time = time.time()

    # Training
    model.train(True)

    acc_train = 0

    for X_batch, y_batch in tqdm(train_dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        logits = model(X_batch)
        loss = loss_function(logits, y_batch)
        custom_optimizer.zero_grad()

        predictions = torch.argmax(logits, dim=1)
        acc_train += (predictions == y_batch).sum().item() / X_batch.shape[0]
        # The optimizer differentiates p.grad to get Hessian-vector products,
        # so the gradients must keep their graph.
        loss.backward(create_graph=True)
        custom_optimizer.step()
        train_loss.append(loss.item())

    times.append(time.time() - start_time)

    train_accuracy.append(acc_train / len(train_dataloader))

    # Validation
    model.train(False)

    val_acc = 0

    # No gradients are needed for evaluation; skipping the graph saves memory.
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            predictions = torch.argmax(model(X_batch), dim=1)
            val_acc += (predictions == y_batch).sum().item() / X_batch.shape[0]

    val_accuracy.append(val_acc / len(val_dataloader))

    epoch_val_accuracy = val_accuracy[-1] * 100

    # Halve the learning rate when the validation accuracy has plateaued,
    # at most once every 5 epochs.
    if len(val_accuracy) > 3 and np.std(val_accuracy[-3:]) < 0.001 and epoch - lr_change_point >= 5:
        lr_change_point = epoch
        print('lr change')
        # Exactly one halving per group (not one per parameter tensor).
        for group in custom_optimizer.param_groups:
            group['lr'] /= 2.

    if epoch_val_accuracy > best_val_accuracy:
        torch.save(model.state_dict(), 'best_model.ml')
        best_val_accuracy = epoch_val_accuracy

    clear_output(True)

    plt.figure(figsize=(20, 8))

    plt.subplot(1, 2, 1)
    plt.plot(train_loss)
    plt.xlabel('номер батча')
    plt.ylabel('значение')
    plt.title('Функция потерь')

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracy, label='val top 1')
    plt.plot(train_accuracy, label='train top 1')
    plt.legend()
    plt.xlabel('Номер эпохи')
    plt.ylabel('Точность')
    plt.show()

    # Mean loss over the batches of the epoch that just finished.
    epoch_train_loss = np.mean(train_loss[-len(train_dataloader):])

    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss (in-iteration): \t{:.6f}".format(epoch_train_loss))
    print("  validation accuracy: \t\t\t{:.2f} %".format(epoch_val_accuracy))

    mean_train_loss.append(epoch_train_loss)

    # Overwritten after every epoch so partial runs can still be analysed.
    logs = {
        'val top 1': val_accuracy,
        'train top 1': train_accuracy,
        'epoch_time': times
    }

    torch.save(logs, 'logs_hess.pt')

In [None]:
# Fresh network on the selected device, trained with the first-order baseline
# (lr=0.1, momentum=0.8).
model = CustomResNet20_c10().to(device)
loss_function = torch.nn.CrossEntropyLoss()
custom_optimizer = SGD(model.parameters(), 0.1, 0.8)
num_epochs = 200

# Per-batch training loss.
train_loss = []
# Per-epoch accuracies (the *_5 lists are for top-5 accuracy).
train_accuracy, train_accuracy_5 = [], []
val_accuracy, val_aacuracy_5 = [], []
# Wall-clock duration of the training part of every epoch.
times = []
best_val_accuracy = 0

In [None]:


# Training / validation loop for the first-order SGD baseline.
# Epoch of the most recent learning-rate reduction; reset here so this run
# does not depend on a value left over from a previous cell.
lr_change_point = 0

for epoch in range(num_epochs):
    start_time = time.time()

    # Training
    model.train(True)

    acc_train_5 = 0
    acc_train = 0

    for X_batch, y_batch in tqdm(train_dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        logits = model(X_batch)
        loss = loss_function(logits, y_batch)

        predictions = torch.argmax(logits, dim=1)
        topk = torch.topk(logits, k=5, dim=1).indices
        acc_train += (predictions == y_batch).sum().item() / X_batch.shape[0]
        # A sample counts as a top-5 hit when its label is among the 5 largest logits.
        acc_train_5 += (topk == y_batch.unsqueeze(1)).any(dim=1).sum().item() / X_batch.shape[0]
        # The custom optimizer differentiates the loss itself.
        custom_optimizer.step(loss)
        train_loss.append(loss.item())

    times.append(time.time() - start_time)

    train_accuracy.append(acc_train / len(train_dataloader))
    train_accuracy_5.append(acc_train_5 / len(train_dataloader))

    # Validation
    model.train(False)

    val_acc = 0
    val_acc_5 = 0

    # No gradients are needed for evaluation; skipping the graph saves memory.
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            probabilities = model(X_batch)
            predictions = torch.argmax(probabilities, dim=1)
            val_acc += (predictions == y_batch).sum().item() / X_batch.shape[0]

            topk = torch.topk(probabilities, k=5, dim=1).indices
            val_acc_5 += (topk == y_batch.unsqueeze(1)).any(dim=1).sum().item() / X_batch.shape[0]

    val_accuracy.append(val_acc / len(val_dataloader))
    val_aacuracy_5.append(val_acc_5 / len(val_dataloader))

    epoch_val_accuracy = val_accuracy[-1] * 100

    if epoch_val_accuracy > best_val_accuracy:
        torch.save(model.state_dict(), 'best_model.ml')
        best_val_accuracy = epoch_val_accuracy

    # Halve the learning rate when the validation accuracy has plateaued,
    # at most once every 5 epochs.
    if len(val_accuracy) > 3 and np.std(val_accuracy[-3:]) < 0.001 and epoch - lr_change_point >= 5:
        lr_change_point = epoch
        print('lr change')
        if hasattr(custom_optimizer, 'param_groups'):
            # torch.optim-style optimizer: one halving per group.
            for group in custom_optimizer.param_groups:
                group['lr'] /= 2.
        else:
            # The hand-written optimizers keep a single `lr` attribute.
            custom_optimizer.lr /= 2.

    clear_output(True)

    plt.figure(figsize=(20, 8))

    plt.subplot(1, 2, 1)
    plt.plot(train_loss)
    plt.xlabel('номер батча')
    plt.ylabel('значение')
    plt.title('Функция потерь')

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracy, label='val top 1')
    plt.plot(val_aacuracy_5, label='val top 5')
    plt.plot(train_accuracy, label='train top 1')
    plt.plot(train_accuracy_5, label='train top 5')
    plt.legend()
    plt.xlabel('Номер эпохи')
    plt.ylabel('Точность')
    plt.show()

    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, num_epochs, time.time() - start_time))
    # Mean loss over the batches of the epoch that just finished.
    print("  training loss (in-iteration): \t{:.6f}".format(
        np.mean(train_loss[-len(train_dataloader):])))
    print("  validation accuracy: \t\t\t{:.2f} %".format(epoch_val_accuracy))

    # Overwritten after every epoch so partial runs can still be analysed.
    # The top-5 validation curve is stored under 'val top 5'.
    logs = {
        'val top 1': val_accuracy,
        'val top 5': val_aacuracy_5,
        'train top 1': train_accuracy,
        'train top 5': train_accuracy_5,
        'epoch_time': times
    }

    torch.save(logs, 'logs_sgd.pt')