In [1]:
import os
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as T

from torchmetrics.aggregation import MeanMetric
from torchmetrics.functional.classification import accuracy

In [2]:
# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reproducibility: seed torch's global RNG so weight initialisation and
# DataLoader shuffling repeat across Restart & Run All (GPU kernels may still
# introduce small nondeterminism).
seed = 42
torch.manual_seed(seed)

# Dataset
dataset_dir = 'data'
batch_size = 64

# Training
base_lr = 0.001
epochs = 20

# Save
checkpoint_dir = 'checkpoint'

In [3]:
# Build dataset
def _fashion_mnist(train):
    """Return the FashionMNIST split stored under ``dataset_dir``.

    ``train=True`` selects the training split, ``train=False`` the test split.
    Files are downloaded if missing; each image is converted with ``T.ToTensor()``.
    """
    return torchvision.datasets.FashionMNIST(
        root=dataset_dir,
        train=train,
        download=True,
        transform=T.ToTensor(),
    )


train_set = _fashion_mnist(train=True)
# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

val_set = _fashion_mnist(train=False)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)

In [4]:
# Define model
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 512 -> 10 with ReLU in between.

    Each input sample is flattened to 28 * 28 = 784 features. ``forward``
    returns the raw outputs of the last linear layer (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Keep the creation order fc1 -> fc2 -> fc3. Parameter initialisation then
        # draws from the RNG in the same sequence, and state_dict keys stay fc1/fc2/fc3.
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        # Collapse every dimension except the leading (batch) one.
        hidden = x.reshape(x.size(0), -1)
        for layer in (self.fc1, self.fc2):
            hidden = self.act(layer(hidden))
        return self.fc3(hidden)

# Build the model and place it on the target device in one step
# (nn.Module.to returns the module itself).
model = MLP().to(device)

In [5]:
# Build optimizer
optimizer = optim.Adam(model.parameters(), lr=base_lr)

# Build scheduler: cosine decay over the total number of optimizer steps, because
# `train` calls `scheduler.step()` once per batch rather than once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader))

# Build loss function
loss_fn = nn.CrossEntropyLoss()

# Build metric function.
# torchmetrics >= 0.11 requires `task` (and `num_classes` for multiclass) for
# `accuracy`; calling it with only (preds, target) raises a TypeError.
# num_classes=10 matches MLP.fc3's output size. average='micro' gives the overall
# fraction of correct predictions, as the pre-0.11 default did.
metric_fn = partial(accuracy, task='multiclass', num_classes=10, average='micro')

In [6]:
# Define training function
def train(loader, model, optimizer, scheduler, loss_fn, metric_fn, device):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: module to train; it is switched to train mode.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per batch.
        loss_fn: callable ``(outputs, targets) -> loss tensor``.
        metric_fn: callable ``(outputs, targets) -> metric tensor``.
        device: device that the batches and the running-mean metrics are moved to.

    Returns:
        dict with keys ``'loss'`` and ``'metric'``: means over the epoch in which
        each batch is weighted by its number of samples.
    """
    # Set model to train mode
    model.train()

    # Running means of loss and metric over the epoch
    loss_mean = MeanMetric().to(device)
    metric_mean = MeanMetric().to(device)

    # Train model for one epoch
    for inputs, targets in loader:
        # Move data to device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        metric = metric_fn(outputs.detach(), targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics. Weight each batch by its size; otherwise a short final
        # batch counts as much as a full one and skews the epoch mean.
        num_samples = targets.shape[0]
        loss_mean.update(loss.detach(), weight=num_samples)
        metric_mean.update(metric, weight=num_samples)

        # Update learning rate (per-iteration schedule)
        scheduler.step()

    # Summarize statistics
    return {'loss': loss_mean.compute(), 'metric': metric_mean.compute()}

In [7]:
# Define evaluation function
def evaluate(loader, model, loss_fn, metric_fn, device):
    """Evaluate ``model`` on every batch of ``loader`` without updating it.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: module to evaluate; it is switched to eval mode.
        loss_fn: callable ``(outputs, targets) -> loss tensor``.
        metric_fn: callable ``(outputs, targets) -> metric tensor``.
        device: device that the batches and the running-mean metrics are moved to.

    Returns:
        dict with keys ``'loss'`` and ``'metric'``: means over the whole loader in
        which each batch is weighted by its number of samples.
    """
    # Set model to evaluation mode
    model.eval()

    # Running means of loss and metric over the loader
    loss_mean = MeanMetric().to(device)
    metric_mean = MeanMetric().to(device)

    # Nothing here needs gradients, so disable autograd for the whole pass
    # (forward, loss and metric), not just the model call.
    with torch.no_grad():
        for inputs, targets in loader:
            # Move data to device
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            metric = metric_fn(outputs, targets)

            # Update statistics, weighting each batch by its size so the final
            # partial batch does not bias the mean.
            num_samples = targets.shape[0]
            loss_mean.update(loss, weight=num_samples)
            metric_mean.update(metric, weight=num_samples)

    # Summarize statistics
    return {'loss': loss_mean.compute(), 'metric': metric_mean.compute()}

In [8]:
# Main loop: one training pass followed by one evaluation pass per epoch
for epoch in range(epochs):
    train_summary = train(train_loader, model, optimizer, scheduler, loss_fn, metric_fn, device)
    val_summary = evaluate(val_loader, model, loss_fn, metric_fn, device)

    # val_loader is built from FashionMNIST's test split (train=False), hence the "Test" labels.
    print(
        f'Epoch {epoch + 1}: '
        f'Train Loss {train_summary["loss"]:.04f}, '
        f'Train Accuracy {train_summary["metric"]:.04f}, '
        f'Test Loss {val_summary["loss"]:.04f}, '
        f'Test Accuracy {val_summary["metric"]:.04f}'
    )
    

Epoch 1: Train Loss 0.4911, Train Accuracy 0.8235, Test Loss 0.3911, Test Accuracy 0.8593
Epoch 2: Train Loss 0.3595, Train Accuracy 0.8685, Test Loss 0.3660, Test Accuracy 0.8687
Epoch 3: Train Loss 0.3200, Train Accuracy 0.8814, Test Loss 0.3671, Test Accuracy 0.8685
Epoch 4: Train Loss 0.2954, Train Accuracy 0.8880, Test Loss 0.3479, Test Accuracy 0.8710
Epoch 5: Train Loss 0.2720, Train Accuracy 0.8974, Test Loss 0.3176, Test Accuracy 0.8856
Epoch 6: Train Loss 0.2534, Train Accuracy 0.9044, Test Loss 0.3343, Test Accuracy 0.8831
Epoch 7: Train Loss 0.2381, Train Accuracy 0.9087, Test Loss 0.3085, Test Accuracy 0.8921
Epoch 8: Train Loss 0.2191, Train Accuracy 0.9159, Test Loss 0.3450, Test Accuracy 0.8860
Epoch 9: Train Loss 0.2055, Train Accuracy 0.9217, Test Loss 0.3276, Test Accuracy 0.8921
Epoch 10: Train Loss 0.1899, Train Accuracy 0.9285, Test Loss 0.3179, Test Accuracy 0.8974
Epoch 11: Train Loss 0.1728, Train Accuracy 0.9342, Test Loss 0.3187, Test Accuracy 0.8953
Epoch 12

In [9]:
# Save checkpoint: only the model's state_dict (parameters and buffers) is written.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, 'model.pth')

torch.save(model.state_dict(), checkpoint_path)

In [10]:
# Load checkpoint into a freshly constructed (CPU) model
model_pretrained = MLP()

# map_location='cpu' lets a checkpoint saved from CUDA tensors load on a machine
# without a GPU. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs.
model_pretrained.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))

<All keys matched successfully>

In [11]:
# Comparison with randomly initialized model.
# An untrained 10-class model should score roughly at chance level (~0.1).
model_random = MLP()

model_random.to(device)
model_pretrained.to(device)

random_summary = evaluate(val_loader, model_random, loss_fn, metric_fn, device)
pretrained_summary = evaluate(val_loader, model_pretrained, loss_fn, metric_fn, device)

print(f'[Random] Test Acc {random_summary["metric"]:.04f}')
print(f'[Pretrained] Test Acc {pretrained_summary["metric"]:.04f}')

[Random] Test Acc 0.0877
[Pretrained] Test Acc 0.9020
