In [229]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm

In [230]:
# --- Run configuration -------------------------------------------------------
# Samples per mini-batch for every DataLoader in this notebook.
BATCH_SIZE = 64
# Number of full passes over the training split.
EPOCHS = 30
# Step size handed to the optimizer.
LEARNING_RATE = 0.001
# Directory (relative to the working directory) that receives per-epoch checkpoints.
CHECKPOINT_FOLDER = "checkpoint"
# Device on which the model and the batches are placed.
DEVICE = torch.device("cpu")

In [231]:
# Convert PIL images to float tensors and normalise each RGB channel.
# NOTE(review): the constants look like the per-channel mean / std of the
# CIFAR-10 training images (cf. the commented-out computation in this notebook).
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

cifar_iter = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms)

# Hold out VAL_SIZE training images for validation. The split is drawn from a
# dedicated, seeded generator so that it is identical after every kernel
# restart: otherwise validation losses are not comparable between runs and
# resuming from a checkpoint would leak validation images into training.
VAL_SIZE = 5000
SPLIT_SEED = 42
train_iter, val_iter = torch.utils.data.random_split(
    cifar_iter,
    [len(cifar_iter) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Only the training batches are shuffled; evaluation order does not matter.
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, shuffle=False)

# Official CIFAR-10 test split, preprocessed exactly like the training data.
test_iter = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms)
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [232]:
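# One-off computation of the per-channel mean / std passed to Normalize above
# (presumably `.data` holds the raw 0-255 images, hence the division by 255).
# It must read `.data` from the CIFAR10 dataset itself (`cifar_iter`): the
# Subset objects returned by random_split do not expose a `.data` attribute.
# print(cifar_iter.data.shape)
# data = torch.tensor(cifar_iter.data, dtype=torch.float32)
# mean = torch.mean(data, dim=(0, 1, 2)) / 255
# std = torch.std(data, dim=(0, 1, 2)) / 255
# print(mean, std)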

In [233]:
class ImageClassifier(nn.Module):
    """Three-stage convolutional network followed by a small MLP head.

    Every stage is ``conv(5x5, padding=2) -> ReLU -> 2x2 max-pool``, so each
    stage keeps the spatial size through the convolution and halves it in the
    pool while the channels grow 3 -> 8 -> 16 -> 32. The head expects 512
    flattened features (e.g. 32 channels x 4 x 4, which is what a 32x32 input
    produces) and returns 10 unnormalised class scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        # A single parameter-free pooling layer is shared by all three stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear = nn.Sequential(
            nn.Linear(in_features=512, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=10),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        # Collapse everything except the batch dimension before the MLP head.
        return self.linear(x.flatten(start_dim=1))

In [234]:
# Instantiate the network directly on the target device
# (nn.Module.to returns the module itself, so the call can be chained).
model = ImageClassifier().to(DEVICE)

# Cross-entropy on the raw logits; Adam with weight decay updates all parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

In [235]:
def train_epoch(train_dataloader: DataLoader, model: ImageClassifier, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        train_dataloader: Yields ``(images, labels)`` batches.
        model: Network to train; it is switched to training mode here.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called once
            per batch.

    Returns:
        The mean of the per-batch losses (every batch is weighted equally).
    """
    model.train()
    losses = 0

    for images, labels in tqdm(train_dataloader):
        # Tensor.to() is NOT in-place: it returns the moved tensor. The result
        # has to be re-bound, otherwise the batch stays on its original device
        # and the forward pass fails as soon as DEVICE is not the CPU.
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Clear gradients *before* the backward pass so that nothing left over
        # from earlier work can leak into this update.
        optimizer.zero_grad()

        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        losses += loss.item()

    return losses / len(train_dataloader)

In [236]:
def evaluate(val_dataloader: DataLoader, model: ImageClassifier):
    """Compute the average loss of ``model`` over ``val_dataloader``.

    Args:
        val_dataloader: Yields ``(images, labels)`` batches.
        model: Network to evaluate; it is switched to evaluation mode here.

    Returns:
        The mean of the per-batch losses (every batch is weighted equally).
    """
    model.eval()
    losses = 0

    # No backward pass happens here, so autograd bookkeeping is disabled to
    # avoid building (and holding memory for) graphs that are never used.
    with torch.no_grad():
        for images, labels in val_dataloader:
            # Tensor.to() returns the moved tensor instead of modifying it,
            # so the result must be re-bound.
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            loss = loss_fn(logits, labels)

            losses += loss.item()

    return losses / len(val_dataloader)

In [None]:
# torch.save does not create the target directory; make sure it exists up
# front so that the first epoch's work is not lost to a missing folder.
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_epoch(train_dataloader, model, optimizer)
    val_loss = evaluate(val_dataloader, model)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")
    # One checkpoint per epoch. The epoch index and validation loss are stored
    # alongside the original keys so a run can be resumed, or the best epoch
    # selected, without re-evaluating every file.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, CHECKPOINT_FOLDER + f'/cifar_epoch{epoch}.pt')