In [None]:
import sys
sys.path.append('/kaggle/input/alexnet/pytorch/baseline/1')
import torch
import torch.nn as nn
from torchvision import datasets
from baseline.model import AlexNetBaseline, init_params
from baseline.data_transforms import prepreprocess, get_preprocess, get_train_augment
from baseline.train import train
import matplotlib.pyplot as plt
from baseline.eval import top1_k 

In [None]:
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)

# Use CUDA when PyTorch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Pass 1: load the CIFAR-100 training split (downloading it if necessary) with
# the project's `prepreprocess` transform, and let `get_preprocess` build the
# final preprocessing transform from it.
# NOTE(review): presumably this computes dataset statistics — confirm in
# baseline.data_transforms.
train_dataset = datasets.CIFAR100(
    'datasets/cifar100', train=True, download=True, transform=prepreprocess)
preprocess = get_preprocess(train_dataset)

# Pass 2: reload the training split with `preprocess` applied so that
# `get_train_augment` can build the training-time transform from it.
train_dataset = datasets.CIFAR100(
    'datasets/cifar100', train=True, transform=preprocess)
train_augment = get_train_augment(train_dataset, preprocess)

# Pass 3: the dataset actually handed to the training loop.
train_dataset = datasets.CIFAR100(
    'datasets/cifar100', train=True, transform=train_augment)

# The official CIFAR-100 test split (train=False), using `preprocess` rather
# than the training augmentation, is divided in half: one half for validation
# during training, the other kept back as the final test set.
heldout_dataset = datasets.CIFAR100(
    'datasets/cifar100', train=False, transform=preprocess)

# Use a dedicated, fixed-seed generator for the split. Relying on the global
# RNG makes the partition depend on how much randomness was consumed before
# this cell ran; re-running the cell would reshuffle it and could move test
# samples into the validation half that training has already seen.
split_generator = torch.Generator().manual_seed(0)
val_dataset, test_dataset = torch.utils.data.random_split(
    heldout_dataset, [0.5, 0.5], generator=split_generator)

print(f'Number of classes: {len(train_dataset.classes)}')
print(f'Train samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')

In [None]:
# The network is constructed with the number of classes reported by the dataset.
num_classes = len(train_dataset.classes)
model = AlexNetBaseline(num_classes)

# Wrap the network in DataParallel only when more than one CUDA device is visible.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f'Using {torch.cuda.device_count()} GPUs.')
    model = nn.DataParallel(model)

model.to(device)

# Parameter initialisation is delegated to the project's `init_params` helper.
init_params(model)

# torch.compile is currently switched off (see the commented-out lines), so
# `compiled_model` is simply another name for `model`.
compiled_model = model
# if torch.cuda.is_available():
#     compiled_model = torch.compile(model)

In [None]:
# Settings forwarded as keyword arguments to the project's `train` helper.
train_kwargs = dict(
    train_dataset=train_dataset,
    cv_dataset=val_dataset,
    batch_size=128,
    num_epochs=100,
    initial_lr=0.01,
    num_workers=3,
    patience=10,
)

# `train` returns three values, unpacked here into the cost history, the
# validation error rates and the learning rates.
costs, val_error_rates, learning_rates = train(compiled_model, **train_kwargs)

In [None]:
# When several GPUs are present `model` is an nn.DataParallel wrapper, and its
# state_dict keys are prefixed with "module.". Saving the wrapped network's
# state_dict instead keeps the checkpoint loadable into a plain
# AlexNetBaseline regardless of how many GPUs were used for training.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'baseline_cifar100.model.pt')

# These two calls pickle the transform objects themselves (not tensors), so
# loading them back requires the `baseline` package to be importable and, on
# PyTorch versions where `torch.load` defaults to weights_only=True, an
# explicit `weights_only=False`.
torch.save(preprocess, 'baseline_cifar100.preprocess.pt')
torch.save(train_augment, 'baseline_cifar100.train_augment.pt')

In [None]:
# Score the held-out test split with the project's `top1_k` helper (k=5).
test_scores = top1_k(compiled_model, test_dataset, k=5)
test_top1, test_top5 = test_scores
print(f'Test Top 1: {test_top1}', f'Test Top 5: {test_top5}', sep='\n')

# Same measurement on the validation split that was passed to `train`.
val_scores = top1_k(compiled_model, val_dataset, k=5)
val_top1, val_top5 = val_scores
print(f'Val Top 1: {val_top1}', f'Val Top 5: {val_top5}', sep='\n')

# Score the training images through `preprocess` (not `train_augment`), i.e.
# with the same transform used for the validation and test splits.
train_eval_dataset = datasets.CIFAR100(
    'datasets/cifar100', train=True, transform=preprocess)
train_scores = top1_k(compiled_model, train_eval_dataset, k=5)
train_top1, train_top5 = train_scores
print(f'Train Top 1: {train_top1}', f'Train Top 5: {train_top5}', sep='\n')

In [None]:
# Draw the cost history returned by `train` on the current Axes.
costs_ax = plt.gca()
costs_ax.plot(costs)
costs_ax.set_xlabel("Epoch")
costs_ax.set_title('Costs')

In [None]:
# Draw the validation error rates returned by `train` on the current Axes.
val_error_ax = plt.gca()
val_error_ax.plot(val_error_rates)
val_error_ax.set_xlabel("Epoch")
val_error_ax.set_title('Validation Error Rates')

In [None]:
import math

# Plot the base-10 logarithm of every learning rate returned by `train`.
# The transform runs before any Axes is touched, so a non-positive learning
# rate raises from math.log10 without leaving a half-drawn figure behind.
log_learning_rates = list(map(math.log10, learning_rates))

lr_ax = plt.gca()
lr_ax.plot(log_learning_rates)
lr_ax.set_ylabel("Log10(lr)")
lr_ax.set_xlabel("Epoch")
lr_ax.set_title('Learning Rates')