In [204]:
import deep_learning_playground.papers.alexnet.model

In [205]:
alexnet = deep_learning_playground.papers.alexnet.model.AlexNet()

In [209]:
import torch
import torchvision
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, ColorJitter, PILToTensor, Normalize, ConvertImageDtype,Resize
from torch.utils.data import DataLoader, random_split

In [210]:
train_augmentation = Compose([
    # 28 because in original paper, we crop 224 patches from 256 images. Cropping 28 from 32 is the same ratio.
    PILToTensor(),
    ConvertImageDtype(torch.float),
    # Assume uniform dist between 0 and 1
    RandomCrop(28),
    RandomHorizontalFlip(p=0.5),
    ColorJitter(brightness = 0.1),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    Resize((224,224))
])

test_augmentation = Compose([
    PILToTensor(),
    ConvertImageDtype(torch.float),
    Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    Resize((224, 224))
])

train_dataset_full = torchvision.datasets.CIFAR100('../data/CIFAR100/', transform = train_augmentation)
val_dataset_full = torchvision.datasets.CIFAR100('../data/CIFAR100/', transform = test_augmentation)

test_dataset = torchvision.datasets.CIFAR100('../data/CIFAR100/', train = False, transform=test_augmentation)


In [211]:
# Very ugly, but if we just use random split, we cannot define separate augmentations for train and val. So in the above
# cell, we initialize train and val to be the same dataset, then in this cell, we split them identically
generator=torch.Generator().manual_seed(42)
train_dataset, _ = random_split(train_dataset_full, lengths = [0.8, 0.2], generator=generator)
generator=torch.Generator().manual_seed(42)
_, val_dataset = random_split(val_dataset_full, lengths = [0.8, 0.2], generator=generator)

assert len(set(train_dataset.indices).intersection(set(val_dataset.indices))) == 0

In [212]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [214]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(alexnet.parameters(), lr = 0.01, momentum = 0.9, weight_decay=0.0005)

In [215]:
n_epochs = 5

In [221]:
def train_one_epoch(epoch_index=0):
    alexnet.train()
    running_loss = 0
    for i, data in enumerate(train_dataloader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = alexnet(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1 == 0:
            print(f"Batch {i} / {len(train_dataloader)}")

In [222]:
train_one_epoch()

Batch 0 / 313
Batch 1 / 313


KeyboardInterrupt: 