In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import math

import os
import argparse
import matplotlib.pyplot as plt
import numpy as np



In [2]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't work like the other optimizers, you should first call `first_step` and the `second_step`; see the documentation for more info.")

    def _grad_norm(self):
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lr = 0.1  #learning rate
best_acc = 0
num_epochs = 225

In [4]:
train_transforms = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transforms = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


In [5]:
train_data = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0)

test_data = torchvision.datasets.CIFAR100(root='./data', train = False, download=True, transform=test_transforms)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False, num_workers=0)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:04<00:00, 37539810.30it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs,100)

net = net.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 119MB/s]


In [7]:
criterion = nn.CrossEntropyLoss()
base_optimizer = torch.optim.SGD
optimizer = SAM(net.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9, weight_decay=5e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


In [8]:
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    H = []
    Y = []
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        Y.append(targets)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)

        loss = criterion(outputs, targets)
        loss.mean().backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward step
        criterion(net(inputs), targets).mean().backward()
        optimizer.second_step(zero_grad=True)

        train_loss += loss.mean().item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if(batch_idx % 50 == 0):
              print(str(batch_idx)+"/"+str(len(train_loader)) +"  Loss: " + str(train_loss/(batch_idx+1)) +"  Acc: "+ str(100.*correct/total))


        H.append(outputs.to('cpu'))

    H = torch.cat(H,0).detach().numpy()
    Y = torch.cat(Y,0).numpy()

    return train_loss / len(train_loader), 100.*correct/total, H, Y



In [9]:
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    H = []
    Y = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            Y.append(targets)
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # loss = smooth_crossentropy(outputs, targets)
            loss = criterion(outputs, targets)
            #loss = std_loss(outputs, targets)

            test_loss += loss.mean().item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if(batch_idx % 50 == 0):
              print(str(batch_idx)+"/"+str(len(test_loader)) +"  Loss: " + str(test_loss/(batch_idx+1)) +"  Acc: "+ str(100.*correct/total))

            H.append(outputs.to('cpu'))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    H = torch.cat(H,0).detach().numpy()
    Y = torch.cat(Y,0).numpy()

    return test_loss / len(test_loader), 100.*correct/total, H, Y


In [10]:
train_loss = []
train_accuracy = []
test_loss = []
test_accuracy = []


for epoch in range(0, num_epochs):
    tr_loss, tr_acc, _, _ = train(epoch)
    te_loss, te_acc, _, _ = test(epoch)

    train_loss.append(tr_loss)
    train_accuracy.append(tr_acc)
    test_loss.append(te_loss)
    test_accuracy.append(te_acc)
    print("Train Accuracy:  " , max(train_accuracy), "% ", "  Test Accuracy:  " , max(test_accuracy), "%")

    scheduler.step()


Epoch: 0
0/391  Loss: 4.899670124053955  Acc: 0.0
50/391  Loss: 4.988009957706227  Acc: 2.2518382352941178
100/391  Loss: 4.888335261014428  Acc: 2.4520420792079207
150/391  Loss: 4.740847698110619  Acc: 2.8197433774834435
200/391  Loss: 4.595420138752876  Acc: 3.626399253731343
250/391  Loss: 4.48030491369179  Acc: 4.4135956175298805
300/391  Loss: 4.38302342836247  Acc: 5.048276578073089
350/391  Loss: 4.302273831136546  Acc: 5.669070512820513
0/79  Loss: 3.6348345279693604  Acc: 13.28125
50/79  Loss: 3.759814094094669  Acc: 10.814950980392156
Saving..
Train Accuracy:   6.112 %    Test Accuracy:   10.66 %

Epoch: 1




0/391  Loss: 3.7717888355255127  Acc: 7.8125
50/391  Loss: 3.778241994334202  Acc: 10.217524509803921
100/391  Loss: 3.7152663221453675  Acc: 11.270111386138614
150/391  Loss: 3.6744755577567396  Acc: 11.863617549668874
200/391  Loss: 3.641258246863066  Acc: 12.26290422885572
250/391  Loss: 3.6096377885674102  Acc: 12.749003984063744
300/391  Loss: 3.573472305785778  Acc: 13.44736295681063
350/391  Loss: 3.5444630435389333  Acc: 13.913372507122507
0/79  Loss: 3.14821195602417  Acc: 19.53125
50/79  Loss: 3.292984780143289  Acc: 18.995098039215687
Saving..
Train Accuracy:   14.376 %    Test Accuracy:   18.9 %

Epoch: 2
0/391  Loss: 3.23058819770813  Acc: 21.875
50/391  Loss: 3.2357575566160914  Acc: 20.159313725490197
100/391  Loss: 3.327021889167257  Acc: 18.48700495049505
150/391  Loss: 3.3651897165159514  Acc: 17.528973509933774
200/391  Loss: 3.3440244162260595  Acc: 17.89101368159204
250/391  Loss: 3.31732375118362  Acc: 18.236429282868524
300/391  Loss: 3.2853705780054643  Acc: 18.

In [11]:
torch.save(net.state_dict(), 'ckpt_final.pth')