In [None]:
# import sys
# sys.path.append('/kaggle/input/alexnet/pytorch/baseline/1')
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

from baseline.alexnet_baseline import AlexNetBaseline
from baseline.dataset_preparation import indices_split
from baseline.data_transforms import preprocess, train_augment
from baseline.train import train

In [None]:
# Seed every RNG the notebook may draw from (Python, NumPy, PyTorch) so the
# random train/CV split and the model's initial weights are reproducible under
# Restart & Run All. torch.manual_seed seeds the CPU and all CUDA devices.
# NOTE(review): which RNG `indices_split` uses is not visible here; seeding all
# three covers the common cases. Confirm in baseline/dataset_preparation.py.
# Seeding alone does not make cuDNN kernels deterministic; see PyTorch's
# reproducibility notes if bit-exact GPU runs are required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# CIFAR-10's training files are loaded twice: the training split gets
# augmentation, while the cross-validation split sees only the deterministic
# preprocessing. Both copies share one index space, which is partitioned below.
CIFAR10_ROOT = 'datasets/cifar10'

# Compose applies `preprocess` then `train_augment`, the same order as the
# original lambda. Unlike a lambda it can be pickled, which DataLoader worker
# processes need when they are started with "spawn".
# NOTE(review): assumes `preprocess` and `train_augment` are module-level,
# picklable callables. Confirm in baseline/data_transforms.py.
train_transform = transforms.Compose([preprocess, train_augment])

cifar_train_augmented = datasets.CIFAR10(
    CIFAR10_ROOT, train=True, download=True, transform=train_transform)
cifar_train_plain = datasets.CIFAR10(
    CIFAR10_ROOT, train=True, download=True, transform=preprocess)
test_dataset = datasets.CIFAR10(
    CIFAR10_ROOT, train=False, download=True, transform=preprocess)

cv_indices, train_indices = indices_split(len(cifar_train_augmented), ratio=0.2)
train_dataset = torch.utils.data.Subset(cifar_train_augmented, train_indices)
cv_dataset = torch.utils.data.Subset(cifar_train_plain, cv_indices)

# Any index present in both splits would leak CV samples into training.
# int() normalises list / ndarray / tensor elements before comparing.
assert not set(map(int, train_indices)) & set(map(int, cv_indices)), \
    "train and cross-validation indices overlap"

print(f'Train: {len(train_dataset)}')
print(f'Cross Validation: {len(cv_dataset)}')
print(f'Test: {len(test_dataset)}')

In [None]:
# Build the AlexNet baseline with one output per CIFAR-10 class, wrap it in
# DataParallel when several GPUs are visible, move it to `device`, and train it
# on the train/CV subsets. `train` returns three histories, which are plotted below.
# NOTE(review): `device` is not passed to `train`; confirm that `train` puts
# batches on the same device as the model.
# NOTE(review): `test_dataset` is only consulted for its class list here; this
# notebook never evaluates on the test split.
num_classes = len(test_dataset.classes)
model_baseline = AlexNetBaseline(num_classes)

gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f'Using {gpu_count} GPUs.')
    # DataParallel splits each input batch across the visible GPUs.
    model_baseline = nn.DataParallel(model_baseline)

# nn.Module.to moves parameters in place and returns the same module object.
model_baseline = model_baseline.to(device)

costs_baseline, cv_error_rates_baseline, learning_rates_baseline = train(
    model_baseline,
    train_dataset=train_dataset,
    cv_dataset=cv_dataset,
)

In [None]:
# Training cost history returned by `train`.
# NOTE(review): confirm whether `train` records one cost per batch or per epoch
# and tighten the x-axis label accordingly.
fig, ax = plt.subplots()
ax.plot(costs_baseline)
ax.set(title="AlexNet baseline: training cost", xlabel="Recorded step", ylabel="Cost")
ax.grid(True)
plt.show()

In [None]:
# Cross-validation error-rate history returned by `train`.
# NOTE(review): confirm the recording interval (likely once per epoch) and
# whether the rate is a fraction or a percentage, then refine the labels.
fig, ax = plt.subplots()
ax.plot(cv_error_rates_baseline)
ax.set(title="AlexNet baseline: cross-validation error rate", xlabel="Recorded step", ylabel="Error rate")
ax.grid(True)
plt.show()

In [None]:
# Learning-rate schedule recorded by `train`.
# NOTE(review): if the schedule decays multiplicatively, ax.set_yscale("log")
# will make the steps easier to read.
fig, ax = plt.subplots()
ax.plot(learning_rates_baseline)
ax.set(title="AlexNet baseline: learning rate", xlabel="Recorded step", ylabel="Learning rate")
ax.grid(True)
plt.show()