<a href="https://colab.research.google.com/github/dajiro-repo/blog/blob/master/sam_class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the SAM optimizer implementation; `from sam.sam import SAM` in the imports cell loads it from this clone.
!git clone https://github.com/davda54/sam.git

Cloning into 'sam'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (95/95), done.[K
remote: Total 116 (delta 48), reused 47 (delta 16), pack-reused 0[K
Receiving objects: 100% (116/116), 634.08 KiB | 5.51 MiB/s, done.
Resolving deltas: 100% (48/48), done.


In [None]:
# Download the CIFAR-10 python-format archive and extract it into ./cifar-10-batches-py/
# (data_batch_1..5, test_batch, batches.meta), which the CIFAR10 class reads by relative path.
!curl -LO https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
!tar -xzvf cifar-10-python.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  162M  100  162M    0     0  43.5M      0  0:00:03  0:00:03 --:--:-- 43.5M
cifar-10-batches-py/
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1


In [None]:
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as utils
import torchvision
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim

from sam.sam import SAM

In [None]:
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    Used here to read the CIFAR-10 python batch files; the dataset class looks
    up their ``"data"`` and ``"labels"`` entries.

    ``encoding='latin1'`` tells Python 3 how to decode 8-bit ``str`` objects
    written by Python 2, so string keys come back as ``str`` (not ``bytes``)
    and can be indexed with plain string literals.

    Security note: unpickling can execute arbitrary code. Only call this on
    files from a source you trust.
    """
    import pickle
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')

In [None]:
###############
# Input parameters
###############
# Mini-batch size
batch_size = 32
# Number of epochs
epochs = 50
# Number of GPUs to use; 0 forces CPU (only compared against 0 below)
ngpu = 1
# Learning rate
lr = 0.001
# Seed for torch's global RNG, which weight initialisation and DataLoader
# shuffling draw from by default; fixing it makes runs reproducible.
seed = 0
torch.manual_seed(seed)
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
def get_data_loaders(batch_size):
    """Return ``(train_loader, test_loader)`` DataLoaders over CIFAR-10.

    Each image is converted to a tensor and normalized per channel with
    mean 0.5 and std 0.5, mapping ToTensor's [0, 1] range to [-1, 1].

    Args:
        batch_size: mini-batch size used by both loaders.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_data = CIFAR10(train=True, transform=transform)
    test_data = CIFAR10(train=False, transform=transform)

    # Only the training set needs shuffling. Evaluation metrics do not depend
    # on order, and a fixed test order keeps evaluation deterministic.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                              shuffle=False)

    return train_loader, test_loader

In [None]:
class CIFAR10():
    """Minimal in-memory CIFAR-10 dataset built from the python pickle batches.

    All selected batch files are loaded eagerly in ``__init__``. Indexing
    returns ``(img, target)``. ``img`` is a ``PIL.Image`` made from one HWC
    pixel array, passed through ``transform`` if one was given. ``target`` is
    the matching entry of the batches' ``"labels"`` lists.
    """

    # Batch file paths, relative to the ``root`` argument of __init__.
    train_list = [
        "cifar-10-batches-py/data_batch_1",
        "cifar-10-batches-py/data_batch_2",
        "cifar-10-batches-py/data_batch_3",
        "cifar-10-batches-py/data_batch_4",
        "cifar-10-batches-py/data_batch_5",
    ]
    test_list = [
        "cifar-10-batches-py/test_batch"
    ]

    def __init__(
        self,
        train: bool = True,
        transform: Optional[Callable] = None,
        root: str = "",
    ) -> None:
        """Load the training or test split into memory.

        Args:
            train: load ``train_list`` when True, otherwise ``test_list``.
            transform: optional callable applied to each ``PIL.Image`` in
                ``__getitem__``.
            root: directory containing ``cifar-10-batches-py/``. The default
                ``""`` leaves the paths relative to the current working
                directory, matching the previous behaviour.
        """
        self.transform = transform
        self.train = train
        data_list = self.train_list if self.train else self.test_list

        batches = []
        self.targets = []
        for filename in data_list:
            entry = unpickle(os.path.join(root, filename))
            batches.append(entry["data"])
            self.targets.extend(entry["labels"])

        # Stack all batches and view each row as a channel-first 3x32x32
        # image. Then move channels last, because Image.fromarray expects
        # (H, W, C).
        # NOTE(review): assumes each "data" row holds one image's 3*32*32
        # values in channel-major order; confirm against the dataset docs.
        self.data: Any = np.vstack(batches).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(img, target)`` for the sample at ``index``."""
        img, target = self.data[index], self.targets[index]

        # Convert to PIL so torchvision-style transforms (e.g. ToTensor) apply.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Number of images loaded for this split."""
        return len(self.data)

In [None]:
class Net(nn.Module):
    """Small LeNet-style CNN that outputs 10 class logits.

    Expects input of shape (N, 3, 32, 32). There are two 5x5 convolutions
    without padding, each followed by ReLU and 2x2 max-pooling. They shrink
    the spatial size 32 -> 28 -> 14 -> 10 -> 5, leaving 16 * 5 * 5 features
    for the fully connected head.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw logits of shape (N, 10) for a batch ``x`` of shape (N, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. An input of the
        # wrong spatial size then fails loudly in fc1 instead of being
        # reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Place the model on the target device before creating the optimizer, as the
# torch.optim docs recommend, so SAM is built over the parameters that will
# actually be trained.
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
base_optimizer = optim.SGD
# SAM receives the base optimizer class plus its hyperparameters.
# NOTE(review): assumes lr/momentum are forwarded to optim.SGD; confirm in sam/sam.py.
optimizer = SAM(model.parameters(), base_optimizer, lr = lr, momentum=0.9)

In [None]:
# Alternative setups, kept disabled. Uncomment one in place of the SAM + SGD cell to compare:
#   1) SAM wrapping Adam as the base optimizer (first four lines);
#   2) plain SGD with momentum and no SAM, as a baseline (last three lines).
# NOTE: plain optim.SGD has no first_step/second_step, so the training loop's
# SAM calls would have to be replaced by a single optimizer.step() for option 2.
# model = Net()
# criterion = nn.CrossEntropyLoss()
# base_optimizer = optim.Adam
# optimizer = SAM(model.parameters(), base_optimizer, lr = lr)
# model = Net()
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# Build both loaders once and key them by phase name for the training loop.
train_loader, test_loader = get_data_loaders(batch_size)
dataloaders_dict = {"train": train_loader, "test": test_loader}

In [None]:
# Train with SAM and evaluate on the test split every epoch.
# The training phase of epoch 0 is skipped on purpose. That epoch's test line
# therefore reports the untrained model as a baseline, and training runs for
# epochs - 1 passes.
accuracy, accuracy_val = [], []  # per-epoch train / test accuracy history
model.to(device)
for epoch in range(epochs):

    for phase in ["train", "test"]:
        if (epoch == 0) and (phase == "train"):
            continue

        if phase == "train":
            model.train()
        else:
            model.eval()

        loss_epoch = 0.0
        correct = 0  # plain int so the epoch statistics are Python numbers

        for inputs, labels in dataloaders_dict[phase]:
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                labels = labels.to(device)
                inputs = inputs.to(device)
                outputs = model(inputs)

                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                if phase == "train":
                    # First forward-backward pass, then SAM's first step.
                    loss.backward()
                    optimizer.first_step(zero_grad=True)

                    # Second forward-backward pass (after first_step), then
                    # SAM's second step.
                    criterion(model(inputs), labels).backward()
                    optimizer.second_step(zero_grad=True)

            # Statistics use the loss and predictions of the first forward pass.
            loss_epoch += loss.item() * inputs.size(0)
            correct += (preds == labels).sum().item()

        n_samples = len(dataloaders_dict[phase].dataset)
        loss_epoch = loss_epoch / n_samples
        acc_epoch = correct / n_samples
        if phase == "train":
            accuracy.append(acc_epoch)
        else:
            accuracy_val.append(acc_epoch)

        print(f"phase: {phase}",
              f"epoch: {epoch}",
              f"loss: {loss_epoch:.4f}",
              f"accuracy: {acc_epoch:.4f}")

print('Finished Training')