In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import seaborn as sns
from torch import nn
import sys
sys.path.append('../')
from fun.models import *
sns.set_style("whitegrid")
from torch.utils.data import random_split
import pandas as pd

In [None]:

# Upsample every CIFAR-10 image to 224x224 and turn it into a tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# CIFAR-10 training split, downloaded into ../data on the first run.
dataset = torchvision.datasets.CIFAR10(root="../data", train=True, download=True, transform=transform)

# Hold out 10% of the training images for validation; the rest is used to train.
valid_size = int(0.1 * len(dataset))
train_size = len(dataset) - valid_size

# Seed the global RNG immediately before the split so the partition is reproducible.
torch.manual_seed(42)
train_ds, valid_ds = random_split(dataset, lengths=[train_size, valid_size])
len(train_ds), len(valid_ds)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader`` and report accuracy and loss.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        dataloader: iterable of ``(X, y)`` batches with ``len()`` support
            (a ``DataLoader`` in this notebook).
        model: module being trained; put into train mode and updated in place.
        loss_fn: criterion called as ``loss_fn(pred, y)``. Assumed to return the
            *mean* loss over the batch (the ``nn.CrossEntropyLoss`` default);
            the per-sample epoch average below relies on that.
        optimizer: optimizer over the model's trainable parameters.

    Returns:
        ``(accuracy, avg_loss)``. ``accuracy`` is the fraction of samples whose
        arg-max prediction equals ``y``. ``avg_loss`` is the loss averaged over
        every sample seen this epoch. Both are measured on the fly, before each
        batch's optimizer step.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    num_batches = len(dataloader)
    model.train()
    loss_sum, correct, n_seen = 0.0, 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = y.size(0)
        batch_loss = loss.item()
        # Weight each batch mean by its size so a smaller final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += batch_loss * batch_n
        correct += (pred.argmax(1) == y).sum().item()
        n_seen += batch_n

        if batch % 100 == 0:
            print(f"  batch {batch:>5d}/{num_batches}  loss {batch_loss:>8f}")

    if n_seen == 0:
        raise ValueError("train(): dataloader yielded no samples")
    train_loss = loss_sum / n_seen
    correct /= n_seen
    print(f" Train Accuracy: {(100*correct):>0.1f}%, Train Avg loss {train_loss:>8f} \n")
    return correct, train_loss


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Batches are moved to the module-level ``device``. Gradients are disabled
    and the model is put into eval mode.

    Args:
        dataloader: iterable of ``(X, y)`` batches (the validation split here).
        model: module to evaluate.
        loss_fn: criterion called as ``loss_fn(pred, y)``. Assumed to return the
            *mean* loss over the batch (the ``nn.CrossEntropyLoss`` default).

    Returns:
        ``(accuracy, avg_loss)``. ``accuracy`` is the fraction of correctly
        arg-max-classified samples. ``avg_loss`` is the loss averaged over
        samples (not over batches).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_n = y.size(0)
            # Size-weighted so a short last batch does not skew the average;
            # e.g. 5000 validation images leave a final batch of 8 for
            # batch size 128.
            loss_sum += loss_fn(pred, y).item() * batch_n
            correct += (pred.argmax(1) == y).sum().item()
            n_seen += batch_n

    if n_seen == 0:
        raise ValueError("test(): dataloader yielded no samples")
    test_loss = loss_sum / n_seen
    correct /= n_seen

    print(f"Test Accuracy: {(100*correct):>0.1f}%, Test Avg loss: {test_loss:>8f} \n")
    return correct, test_loss
    

In [None]:
epochs = 30
learning_rate = 1e-3

batch_sizes = [4, 8, 16, 32, 64, 128]

# One (epoch x batch_size) table per metric.
# - The template is float, so fractional accuracies and losses are stored
#   directly. An int template relies on pandas' deprecated int->float upcast
#   on every write.
# - Each frame gets its own copy. Without Copy-on-Write, pd.DataFrame(ndarray)
#   does not copy, so four frames built from one array would share a buffer.
data = np.zeros((epochs, len(batch_sizes)))
batch_train_acc_history = pd.DataFrame(data.copy(), columns=batch_sizes)
batch_test_acc_history = pd.DataFrame(data.copy(), columns=batch_sizes)
batch_train_loss_history = pd.DataFrame(data.copy(), columns=batch_sizes)
batch_test_loss_history = pd.DataFrame(data.copy(), columns=batch_sizes)

# Stateless, so one instance serves every run.
criterion = nn.CrossEntropyLoss()

for batch_size in batch_sizes:
    print(f"=== batch_size = {batch_size} ===")
    # Re-seed per run so every batch size starts from the same initial weights
    # and the same shuffling stream. Otherwise RNG state left over from the
    # previous run confounds the comparison.
    torch.manual_seed(42)

    train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # The "test" loader iterates the held-out validation split.
    test_dataloader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)

    net = Resnet18_2().to(device)
    # Freeze every parameter under net.conv_layers_frozen.
    # NOTE(review): an earlier note called these the "last layers (layer4)";
    # confirm in fun.models which ResNet stages this attribute actually holds.
    for param in net.conv_layers_frozen.parameters():
        param.requires_grad = False

    # Give the optimizer only the parameters that can still learn.
    optimizer = torch.optim.SGD(
        [p for p in net.parameters() if p.requires_grad], lr=learning_rate
    )

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")

        train_acc, train_loss = train(train_dataloader, net, criterion, optimizer)
        test_acc, test_loss = test(test_dataloader, net, criterion)
        batch_train_acc_history.loc[t, batch_size] = train_acc
        batch_test_acc_history.loc[t, batch_size] = test_acc
        batch_train_loss_history.loc[t, batch_size] = train_loss
        batch_test_loss_history.loc[t, batch_size] = test_loss

# batch_train_acc_history.to_csv("batch_train_acc_history.csv")
# !cp batch_train_acc_history.csv "drive/My Drive/DL/Resnet18_2/batch_exp/"

# batch_test_acc_history.to_csv("batch_test_acc_history.csv")
# !cp batch_test_acc_history.csv "drive/My Drive/DL/Resnet18_2/batch_exp/"

# batch_train_loss_history.to_csv("batch_train_loss_history.csv")
# !cp batch_train_loss_history.csv "drive/My Drive/DL/Resnet18_2/batch_exp/"

# batch_test_loss_history.to_csv("batch_test_loss_history.csv")
# !cp batch_test_loss_history.csv "drive/My Drive/DL/Resnet18_2/batch_exp/"
 
    