In [312]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm

In [313]:
# Run configuration: every tunable used by the cells below lives here.
BATCH_SIZE = 128  # mini-batch size shared by the train / val / test DataLoaders
EPOCHS = 20  # number of full passes over the training split
LEARNING_RATE = 1e-4  # step size handed to the Adam optimizer
CHECKPOINT_FOLDER = 'checkpoint'  # directory the per-epoch checkpoint files are written into
DEVICE = torch.device('cpu')  # device the model is moved to

In [314]:
# Pre-processing applied to every image: convert to a float tensor, then
# standardise each RGB channel with a fixed mean / std
# (presumably the CIFAR-10 training-set statistics — TODO confirm).
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# Seed for the train/validation partition. Without a seeded generator,
# random_split draws from the global RNG and yields a different split on
# every kernel restart, so validation losses and saved checkpoints would
# not be comparable between runs.
SPLIT_SEED = 42

# 50,000 official training images -> 45,000 for training, 5,000 held out
# for validation.
cifar_iter = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms)
train_iter, val_iter = torch.utils.data.random_split(
    cifar_iter,
    [45000, 5000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Only the training loader is shuffled; evaluation order does not matter.
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, shuffle=False)

# Official CIFAR-10 test split, used only for the final score.
test_iter = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms)
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [315]:
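# One-off snippet that appears to have produced the per-channel mean / std
# passed to Normalize above. It must read from the full training dataset
# (`cifar_iter`): `train_iter` is a Subset returned by random_split and has
# no `.data` attribute.
# print(cifar_iter.data.shape)
# data = torch.tensor(cifar_iter.data, dtype=torch.float32)
# mean = torch.mean(data, dim=(0, 1, 2)) / 255
# std = torch.std(data, dim=(0, 1, 2)) / 255
# print(mean, std)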

In [316]:
class ImageClassifier(nn.Module):
    """Small three-stage CNN followed by an MLP head producing 10 logits.

    Each stage is ``conv(5x5, padding=2) -> ReLU -> 2x2 max-pool``, so the
    spatial size is halved three times while channels grow 3 -> 8 -> 16 -> 32.
    The head expects 512 flattened features, which matches a 3x32x32 input
    (32 channels * 4 * 4).
    """

    def __init__(self):
        super().__init__()
        # Padding of 2 with a 5x5 kernel keeps height/width unchanged, so
        # only the pooling layer reduces resolution.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        # Stateless, so one instance is shared by all three stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classification head: 512 -> 64 -> 32 -> 10 (raw logits, no softmax).
        self.linear = nn.Sequential(
            nn.Linear(in_features=512, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=10),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        # Collapse everything except the batch dimension before the MLP head.
        features = x.flatten(start_dim=1)
        return self.linear(features)

In [317]:
# Build the network and place it on the configured device
# (nn.Module.to moves parameters in place and returns the module itself).
model = ImageClassifier().to(DEVICE)

# Cross-entropy over raw logits; Adam with L2-style weight decay of 0.01.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

In [318]:
def train_epoch(train_dataloader: DataLoader, model: ImageClassifier, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Uses the module-level ``loss_fn`` and ``DEVICE``.

    Args:
        train_dataloader: yields ``(images, labels)`` batches.
        model: network to train; switched to training mode here.
        optimizer: optimizer that owns ``model``'s parameters.

    Returns:
        The mean of the per-batch losses over the epoch (a Python float).
    """
    model.train()
    losses = 0

    for images, labels in tqdm(train_dataloader):
        # Tensor.to() is NOT in-place: it returns the moved tensor, so the
        # result has to be rebound or the batch stays on its original device.
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Clear gradients before the backward pass so anything left over from
        # earlier computation cannot leak into this step.
        optimizer.zero_grad()

        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        losses += loss.item()

    return losses / len(train_dataloader)

In [319]:
def evaluate(val_dataloader: DataLoader, model: ImageClassifier):
    """Compute the mean per-batch loss of ``model`` on ``val_dataloader``.

    Uses the module-level ``loss_fn`` and ``DEVICE``. No parameters are
    updated.

    Args:
        val_dataloader: yields ``(images, labels)`` batches.
        model: network to evaluate; switched to eval mode here.

    Returns:
        The mean of the per-batch losses (a Python float).
    """
    model.eval()
    losses = 0

    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch and lowers memory use.
    with torch.no_grad():
        for images, labels in tqdm(val_dataloader):
            # Tensor.to() returns the moved tensor; rebind the result.
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            loss = loss_fn(logits, labels)

            losses += loss.item()

    return losses / len(val_dataloader)

In [323]:
def score(test_dataloader: DataLoader, model: ImageClassifier):
    """Compute mean per-batch loss and top-1 accuracy on ``test_dataloader``.

    Uses the module-level ``loss_fn`` and ``DEVICE``.

    Args:
        test_dataloader: yields ``(images, labels)`` batches.
        model: network to evaluate; switched to eval mode here.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the mean of the
        per-batch losses and ``accuracy`` is the fraction of samples seen
        whose arg-max logit equals the label.
    """
    model.eval()
    losses = 0
    correct = 0
    total = 0

    # No gradients are needed for scoring.
    with torch.no_grad():
        for images, labels in test_dataloader:
            # Tensor.to() returns the moved tensor; rebind the result.
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(images)
            loss = loss_fn(logits, labels)

            losses += loss.item()
            # Predicted class = index of the largest logit (avoids shadowing
            # the builtin ``max``).
            predictions = logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            # Count the samples actually iterated so accuracy is correct for
            # whichever dataloader is passed in, not only the global test set.
            total += labels.size(0)

    return losses / len(test_dataloader), correct / total

In [321]:
# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first epoch finishes.
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)

for epoch in range(EPOCHS):
    train_loss = train_epoch(train_dataloader, model, optimizer)
    val_loss = evaluate(val_dataloader, model)
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")
    # One checkpoint per epoch. Both model and optimizer state are stored so
    # training can be resumed; epoch and val_loss are recorded so a checkpoint
    # is self-describing without parsing its file name.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, os.path.join(CHECKPOINT_FOLDER, f'cifar_epoch{epoch}.pt'))

100%|██████████| 352/352 [00:17<00:00, 20.59it/s]
100%|██████████| 40/40 [00:01<00:00, 39.16it/s]


Epoch: 0, Train loss: 2.143, Val loss: 1.929


100%|██████████| 352/352 [00:18<00:00, 19.21it/s]
100%|██████████| 40/40 [00:00<00:00, 40.20it/s]


Epoch: 1, Train loss: 1.849, Val loss: 1.793


100%|██████████| 352/352 [00:16<00:00, 21.70it/s]
100%|██████████| 40/40 [00:01<00:00, 39.64it/s]


Epoch: 2, Train loss: 1.759, Val loss: 1.742


100%|██████████| 352/352 [00:16<00:00, 21.06it/s]
100%|██████████| 40/40 [00:00<00:00, 42.07it/s]


Epoch: 3, Train loss: 1.708, Val loss: 1.700


100%|██████████| 352/352 [00:16<00:00, 21.96it/s]
100%|██████████| 40/40 [00:00<00:00, 42.43it/s]


Epoch: 4, Train loss: 1.665, Val loss: 1.673


100%|██████████| 352/352 [00:16<00:00, 21.19it/s]
100%|██████████| 40/40 [00:01<00:00, 38.92it/s]


Epoch: 5, Train loss: 1.632, Val loss: 1.634


100%|██████████| 352/352 [00:16<00:00, 20.96it/s]
100%|██████████| 40/40 [00:00<00:00, 42.88it/s]


Epoch: 6, Train loss: 1.602, Val loss: 1.604


100%|██████████| 352/352 [00:16<00:00, 21.35it/s]
100%|██████████| 40/40 [00:01<00:00, 35.02it/s]


Epoch: 7, Train loss: 1.577, Val loss: 1.583


100%|██████████| 352/352 [00:16<00:00, 21.84it/s]
100%|██████████| 40/40 [00:00<00:00, 43.23it/s]


Epoch: 8, Train loss: 1.553, Val loss: 1.580


100%|██████████| 352/352 [00:16<00:00, 21.14it/s]
100%|██████████| 40/40 [00:00<00:00, 43.51it/s]


Epoch: 9, Train loss: 1.533, Val loss: 1.552


100%|██████████| 352/352 [00:17<00:00, 20.69it/s]
100%|██████████| 40/40 [00:01<00:00, 37.47it/s]


Epoch: 10, Train loss: 1.513, Val loss: 1.525


100%|██████████| 352/352 [00:18<00:00, 19.25it/s]
100%|██████████| 40/40 [00:01<00:00, 36.78it/s]


Epoch: 11, Train loss: 1.496, Val loss: 1.510


100%|██████████| 352/352 [00:16<00:00, 21.12it/s]
100%|██████████| 40/40 [00:00<00:00, 42.48it/s]


Epoch: 12, Train loss: 1.480, Val loss: 1.504


100%|██████████| 352/352 [00:16<00:00, 21.74it/s]
100%|██████████| 40/40 [00:00<00:00, 41.44it/s]


Epoch: 13, Train loss: 1.465, Val loss: 1.487


100%|██████████| 352/352 [00:16<00:00, 21.64it/s]
100%|██████████| 40/40 [00:00<00:00, 40.74it/s]


Epoch: 14, Train loss: 1.448, Val loss: 1.471


100%|██████████| 352/352 [00:16<00:00, 21.72it/s]
100%|██████████| 40/40 [00:00<00:00, 42.69it/s]


Epoch: 15, Train loss: 1.435, Val loss: 1.475


100%|██████████| 352/352 [00:16<00:00, 21.14it/s]
100%|██████████| 40/40 [00:01<00:00, 38.85it/s]


Epoch: 16, Train loss: 1.421, Val loss: 1.441


100%|██████████| 352/352 [00:16<00:00, 21.94it/s]
100%|██████████| 40/40 [00:00<00:00, 41.42it/s]


Epoch: 17, Train loss: 1.408, Val loss: 1.442


100%|██████████| 352/352 [00:17<00:00, 19.96it/s]
100%|██████████| 40/40 [00:01<00:00, 39.84it/s]


Epoch: 18, Train loss: 1.395, Val loss: 1.419


100%|██████████| 352/352 [00:16<00:00, 21.15it/s]
100%|██████████| 40/40 [00:00<00:00, 42.02it/s]

Epoch: 19, Train loss: 1.382, Val loss: 1.407





In [324]:
# Final evaluation on the CIFAR-10 test split (train=False).
# Prints the mean per-batch loss followed by the accuracy returned by score().
test_loss, test_acc = score(test_dataloader, model)
print(test_loss, test_acc)

1.380893566940404 0.5059
