In [None]:
import os
import ssl
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models

sys.path.append('../')
sns.set_style("whitegrid")

In [None]:
# One seed for every random number generator this notebook draws from.
SEED = 42

# Evaluation pipeline: resize each image to 224x224 and convert it to a tensor.
# NOTE(review): no mean/std normalisation is applied although the model built
# later in this notebook starts from pretrained ResNet-18 weights -- confirm
# that this is intentional.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Training pipeline: currently identical to the evaluation one; the CutMix
# augmentation is applied per batch inside the training loop instead.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


train_ds = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=transform_train
)

test_ds = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=transform_test
)

# Seed torch (shuffling, randperm) AND numpy: the CutMix code samples its
# probability, mixing ratio and box position from np.random, so seeding torch
# alone does not make a run repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def plot_data(dataset_augm):
    """Display a 3x3 grid of randomly picked images from ``dataset_augm``.

    Args:
        dataset_augm: Any object supporting ``len()`` and integer indexing.
            Each item may be either a channels-first image tensor (for
            example one row of a batch tensor) or an ``(image, label)`` pair
            such as a labelled dataset returns. When a label is present it is
            shown as the subplot title.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    labels_map = {
        0: 'plane',
        1: 'car',
        2: 'bird',
        3: 'cat',
        4: 'deer',
        5: 'dog',
        6: 'frog',
        7: 'horse',
        8: 'ship',
        9: 'truck'
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(dataset_augm), size=(1,)).item()
        sample = dataset_augm[sample_idx]
        # Indexing a batch tensor yields a bare image, whereas indexing a
        # labelled dataset yields an (image, label) pair.
        if isinstance(sample, (tuple, list)):
            img, label = sample[0], sample[1]
        else:
            img, label = sample, None

        ax = figure.add_subplot(rows, cols, i)
        if label is not None:
            label_id = int(label)
            ax.set_title(labels_map.get(label_id, str(label_id)))
        ax.axis("off")

        # detach() keeps numpy() from failing on tensors that track gradients.
        img = img.detach().cpu().numpy()
        # Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
        img = np.transpose(img, (1, 2, 0))
        # squeeze() drops a trailing single channel; cmap only affects 2-D input.
        ax.imshow(img.squeeze(), cmap="gray")
    plt.show()

In [None]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

def rand_bbox(size, lam):
    """Sample a random axis-aligned box for CutMix.

    Args:
        size: Shape of the batch; only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing ratio. Each side of the box is the matching extent scaled
            by ``sqrt(1 - lam)`` and truncated to an integer.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The ``x`` pair is clipped to
        ``[0, size[2]]`` and the ``y`` pair to ``[0, size[3]]``, so the box
        may be smaller than requested when its centre falls near an edge.
    """
    extent_x, extent_y = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # Centre is uniform over the full extent; x is drawn before y.
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    low_x = np.clip(centre_x - half_x, 0, extent_x)
    low_y = np.clip(centre_y - half_y, 0, extent_y)
    high_x = np.clip(centre_x + half_x, 0, extent_x)
    high_y = np.clip(centre_y + half_y, 0, extent_y)

    return low_x, low_y, high_x, high_y

def train(dataloader, model, loss_fn, optimizer, beta, cutmix_prob):
    """Train ``model`` for one epoch, applying CutMix to a random share of batches.

    Batches are moved to the module-level ``device``. For each batch, CutMix
    is applied with probability ``cutmix_prob`` (provided ``beta > 0``): a
    box from a shuffled copy of the batch is pasted into the batch and the
    loss becomes the ``lam``-weighted sum of the losses against both label
    sets.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used as the accuracy denominator.
        model: Network to optimise; switched to training mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.
        optimizer: Optimizer stepping the model's parameters.
        beta: Both shape parameters of the Beta distribution from which the
            mixing ratio is drawn. CutMix is skipped when ``beta <= 0``.
        cutmix_prob: Per-batch probability of applying CutMix.

    Returns:
        ``(accuracy, avg_loss)``: the fraction of samples whose argmax
        prediction equals the *original* label ``y`` (also for mixed batches),
        and the loss averaged over batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()
    train_loss, correct = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        if beta > 0 and np.random.rand() < cutmix_prob:
            lam = np.random.beta(beta, beta)
            # Move the permutation to wherever the batch lives instead of
            # assuming CUDA, so this branch also runs on CPU-only hosts.
            rand_index = torch.randperm(X.size(0)).to(X.device)
            target_a = y
            target_b = y[rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(X.size(), lam)
            # The right-hand side is gathered into a copy before the write,
            # so the in-place assignment never reads already-pasted pixels.
            X[:, :, bbx1:bbx2, bby1:bby2] = X[rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio (the box may have
            # been clipped at the image border)
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (X.size()[-1] * X.size()[-2]))

            # plot_data(X)

            pred = model(X)
            loss = loss_fn(pred, target_a) * lam + loss_fn(pred, target_b) * (1. - lam)
        else:
            pred = model(X)
            loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    train_loss /= num_batches
    correct /= size
    print(f" Train Accuracy: {(100*correct):>0.1f}%, Train Avg loss {train_loss:>8f} \n")
    return correct, train_loss


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is the accuracy denominator and ``len(dataloader)`` the loss one.
        model: Network to evaluate; switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.

    Returns:
        ``(accuracy, avg_loss)``: fraction of correctly classified samples
        and the loss averaged over batches. Batches are moved to the
        module-level ``device`` before the forward pass.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()

    loss_total = 0
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()
            matches = logits.argmax(1) == targets
            hits += matches.type(torch.float).sum().item()

    test_loss = loss_total / n_batches
    correct = hits / n_samples

    print(f"Test Accuracy: {(100*correct):>0.1f}%, Test Avg loss: {test_loss:>8f} \n")
    return correct, test_loss
    

In [None]:
class Resnet18_2(nn.Module):
    """ResNet-18 regrouped into an early stack, ``layer4``, pooling and a new head.

    Every child of torchvision's ResNet-18 that precedes ``layer4`` is
    collected, in order, into ``conv_layers_frozen``. ``layer4`` and
    ``avgpool`` are reused as-is, and the original classifier is replaced by a
    fresh linear layer ``fc1``.

    Note that the class itself does not freeze anything: callers that want
    the early stack fixed must set ``requires_grad = False`` on
    ``conv_layers_frozen.parameters()`` themselves.

    Args:
        num_classes: Output size of the new linear head (default 10).
        pretrained: Load the ImageNet weights for the backbone (default True).
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        resnet18_pre = self._load_backbone(pretrained)

        self.conv_layers_frozen = nn.ModuleList()
        for name, module in resnet18_pre.named_children():
            # Stop at layer4: it, avgpool and fc are handled separately below.
            if name == "layer4":
                break
            self.conv_layers_frozen.append(module)

        self.layer_to_train = resnet18_pre.layer4
        self.pool = resnet18_pre.avgpool
        self.fc1 = nn.Linear(resnet18_pre.fc.in_features, num_classes)

    @staticmethod
    def _load_backbone(pretrained):
        """Build torchvision's ResNet-18 using whichever weights API is available."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only know the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolves to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    def forward(self, x):
        """Run ``x`` through the early stack, ``layer4``, pooling and the head."""
        for layer in self.conv_layers_frozen:
            x = layer(x)
        x = self.layer_to_train(x)
        x = self.pool(x)
        # Collapse everything but the batch dimension before the linear head.
        x = torch.flatten(x, 1)
        x = self.fc1(x)

        return x


In [None]:
# WARNING: from this point on, HTTPS connections opened through the standard
# library's default context skip certificate verification for the rest of the
# kernel session. Presumably a workaround for certificate errors while
# downloading the pretrained weights -- confirm, and prefer fixing the local
# CA bundle over shipping this.
# `ssl` is already imported in the first cell, so it is not re-imported here.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# --- Hyper-parameters -------------------------------------------------------
epochs = 30
learning_rate = 1e-3

batch_size = 8
step_size = 5      # StepLR: epochs between learning-rate decays
gamma = 0.5        # StepLR: multiplicative decay factor
weight_decay = 0.01

cutmix_prob = 0.6  # per-batch probability of applying CutMix
beta = 1           # Beta(beta, beta) distribution of the CutMix mixing ratio

# Create the checkpoint folder up front so the first torch.save cannot fail
# on a missing directory after a whole epoch of training.
checkpoint_dir = "drive/My Drive/DL/Resnet18_2/cutmix"
os.makedirs(checkpoint_dir, exist_ok=True)

# Float-typed from the start: the per-epoch metrics are floats, and writing
# them into integer columns relies on implicit upcasting that pandas has
# deprecated.
data = np.zeros((epochs, 4))
adam_wd_schedul_history = pd.DataFrame(data, columns = ["test_acc", "train_acc", "test_loss", "train_loss"])


train_dataloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True
)
test_dataloader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False
)

# Freeze every layer that precedes layer4; layer4 and the new fc1 head stay
# trainable.
net = Resnet18_2().to(device)
for param in net.conv_layers_frozen.parameters():
    param.requires_grad = False


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

for t in range(epochs):

    print(f"Epoch {t+1}\n-------------------------------")

    train_acc, train_loss = train(train_dataloader, net, criterion, optimizer, beta, cutmix_prob)
    test_acc, test_loss = test(test_dataloader, net, criterion)
    lr_scheduler.step()
    adam_wd_schedul_history.loc[t, "train_acc"] = train_acc
    adam_wd_schedul_history.loc[t, "test_acc"] = test_acc
    adam_wd_schedul_history.loc[t, "train_loss"] = train_loss
    adam_wd_schedul_history.loc[t, "test_loss"] = test_loss

    # One checkpoint per epoch; the scheduler state is included so that a
    # resumed run continues with the right learning rate.
    torch.save({
            'epoch': t+1,
            'model_state_dict': net.state_dict(),
            'train_loss': train_loss,
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict()
            }, os.path.join(checkpoint_dir, "model_epoch_"+str(t+1)+".pt"))

# adam_wd_schedul_history.to_csv("cutmix_resnet18_2_history.csv")
# !cp cutmix_resnet18_2_history.csv "drive/My Drive/DL/Resnet18_2/cutmix/"

 
    