In [None]:
import sys
sys.path.append('/kaggle/input/alexnet/pytorch/baseline/1')
import torch
import torch.nn as nn
from torchvision import datasets
from baseline.model import AlexNetBaseline, init_params
from baseline.dataset_preparation import indices_split
from baseline.data_transforms import prepreprocess, get_preprocess, get_train_augment
from baseline.train import train
import matplotlib.pyplot as plt
from baseline.eval import accuracy, topk 

In [None]:
# Fix PyTorch's global RNG seed so repeated runs start from the same random state.
torch.manual_seed(0)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Every CIFAR-10 split is read from (and, on first use, downloaded to) this directory.
CIFAR10_ROOT = 'datasets/cifar10'


def _cifar10(train, transform, download=False):
    """Build a torchvision CIFAR10 dataset rooted at CIFAR10_ROOT."""
    return datasets.CIFAR10(
        CIFAR10_ROOT, train=train, download=download, transform=transform)


# The evaluation transform is derived from the training split loaded with
# `prepreprocess`; this is the only call that requests a download.
preprocess = get_preprocess(_cifar10(True, prepreprocess, download=True))

# The CV and training datasets both wrap the train=True split and differ only in
# their transform. The training augmentation is derived from the (not yet
# subset) CV dataset.
cv_dataset = _cifar10(True, preprocess)
train_augment = get_train_augment(cv_dataset)
train_dataset = _cifar10(True, train_augment)
test_dataset = _cifar10(False, preprocess)

# Partition the training split's indices into disjoint CV / train subsets.
# NOTE(review): ratio=0.2 is presumably the CV fraction — confirm in indices_split.
cv_indices, train_indices = indices_split(len(train_dataset), ratio=0.2)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
cv_dataset = torch.utils.data.Subset(cv_dataset, cv_indices)

# Report the size of each split.
for split_label, split in (
    ('Train', train_dataset),
    ('Cross Validation', cv_dataset),
    ('Test', test_dataset),
):
    print(f'{split_label}: {len(split)}')

In [None]:
# Size the network's output by the number of classes in the test dataset.
num_classes = len(test_dataset.classes)
model = AlexNetBaseline(num_classes)

# Wrap the model in DataParallel only when more than one GPU is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f'Using {gpu_count} GPUs.')
    model = nn.DataParallel(model)

# Move the parameters to the selected device, then initialise them.
model.to(device)
init_params(model)

# torch.compile is currently disabled, so `compiled_model` is just another
# name for `model`. The disabled variant is kept for reference:
# if torch.cuda.is_available():
#     compiled_model = torch.compile(model)
compiled_model = model

In [None]:
# Settings forwarded unchanged to baseline.train.train as keyword arguments.
train_kwargs = dict(
    train_dataset=train_dataset,
    cv_dataset=cv_dataset,
    batch_size=128,
    num_epochs=90,
    initial_lr=0.01,
    num_workers=3,
)

# train() returns three histories, which the plotting cells chart against "Epoch".
costs, cv_error_rates, learning_rates = train(compiled_model, **train_kwargs)

In [None]:
# When several GPUs are present the model is wrapped in nn.DataParallel, whose
# state_dict keys carry a "module." prefix. Save the underlying module instead
# so the checkpoint loads into a plain, unwrapped model on any machine.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'baseline_cifar10.model.pt')

# Persist the state of both transforms next to the weights so the same
# preprocessing / augmentation can be restored later.
torch.save(preprocess.state_dict(), 'baseline_cifar10.preprocess.pt')
torch.save(train_augment.state_dict(), 'baseline_cifar10.train_augment.pt')

In [None]:
# Score the trained model on the test dataset: top-1 via accuracy(), top-3 via topk().
top1 = accuracy(compiled_model, test_dataset)
top3 = topk(compiled_model, test_dataset, k=3)

# These must be f-strings: without the prefix the literal text "{top1}" /
# "{top3}" is printed instead of the computed metrics.
print(f'Top 1: {top1}')
print(f'Top 3: {top3}')

In [None]:
# Chart the `costs` history returned by train() on the current axes.
costs_ax = plt.gca()
costs_ax.plot(costs)
costs_ax.set_xlabel("Epoch")
costs_ax.set_title('Costs')

In [None]:
# Chart the `cv_error_rates` history returned by train() on the current axes.
cv_ax = plt.gca()
cv_ax.plot(cv_error_rates)
cv_ax.set_xlabel("Epoch")
cv_ax.set_title('Cross Validation Error Rates')

In [None]:
# Chart the `learning_rates` history returned by train() on the current axes.
lr_ax = plt.gca()
lr_ax.plot(learning_rates)
lr_ax.set_xlabel("Epoch")
lr_ax.set_title('Learning Rates')