In [1]:
%load_ext autoreload
%autoreload 2

### Import

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
import time
import numpy as np

### MNIST 

In [3]:
def get_dataloaders_mnist(batch_size,
                           num_workers=0,
                           root='data',
                           validation_fraction=0.1,
                           train_transforms=None,
                           test_transforms=None):

    if train_transforms is None:
        train_transforms = torchvision.transforms.ToTensor()

    if test_transforms is None:
        test_transforms = torchvision.transforms.ToTensor()

    # Load training data.
    train_dataset = torchvision.datasets.MNIST(
        root=root,
        train=True,
        transform=train_transforms,
        download=True
    )

    # Load validation data.
    valid_dataset = torchvision.datasets.MNIST(
        root=root,
        train=True,
        transform=test_transforms
    )

    # Load test data.
    test_dataset = torchvision.datasets.MNIST(
        root=root,
        train=False,
        transform=test_transforms
    )

    # Perform index-based train-validation split of original training data.
    total = len(train_dataset)  # Get overall number of samples in original training data.
    idx = list(range(total))  # Make index list.
    np.random.shuffle(idx)  # Shuffle indices.
    vnum = int(validation_fraction * total)  # Determine number of validation samples from validation split.
    train_indices, valid_indices = idx[vnum:], idx[0:vnum]  # Extract train and validation indices.

    # Get samplers.
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)

    # Get data loaders.
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=valid_sampler
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
        sampler=train_sampler
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    return train_loader, valid_loader, test_loader


### ResNet18

In [4]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        logits = out
        probas = F.softmax(logits, dim=1)
        return logits, probas

        return out

def ResNet18(num_classes):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

### Training settings

In [5]:
def set_all_seeds(seed): # exclude nondeterminism
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [6]:
seed = 1
batch_size    = 256 
num_epochs    = 20 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
set_all_seeds(seed=seed)

### Preparation

In [7]:
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((70, 70)),
    torchvision.transforms.RandomCrop((64, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((70, 70)),        
    torchvision.transforms.CenterCrop((64, 64)),            
    torchvision.transforms.ToTensor(),                
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_loader, valid_loader, test_loader = get_dataloaders_mnist(batch_size=batch_size)

for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print('Class labels of 10 examples:', labels[:10])
    break

Image batch dimensions: torch.Size([256, 1, 28, 28])
Image label dimensions: torch.Size([256])
Class labels of 10 examples: tensor([7, 0, 4, 7, 2, 9, 6, 4, 7, 8])


### Training functions

In [8]:
def compute_accuracy(model, data_loader, device):

    with torch.no_grad(): # Context-manager that disables gradient calculation to reduce memory consumption.

        # Initialize number of correctly predicted samples + overall number of samples.
        correct_pred, num_samples = 0, 0

        for i, (features, targets) in enumerate(data_loader):
            features = features.to(device) 
            targets = targets.to(device)
            
            result = model(features)
            _  , predictions = torch.max(result[0], dim=1)
            num_samples += targets.size(0)
            correct_pred += (predictions == targets).sum()
    
    return correct_pred.float() / num_samples * 100 

In [9]:
def train_model(model, num_epochs, train_loader,
                valid_loader, test_loader, optimizer,
                device, logging_interval=50,
                scheduler=None):

    start = time.time()
    loss_fn = F.cross_entropy
    
    loss_history, train_acc_history, valid_acc_history = [], [], []
    
    for epoch in range(num_epochs): # Loop over epochs.
      
        # Training
        
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader): # Loop over mini batches.
            
            features = features.to(device) # Convert features and targets to used device.
            targets = targets.to(device)
            
            result = model(features) # Forward pass
            loss = loss_fn(result[0], targets)
            
            optimizer.zero_grad()
            
            loss.backward() # Backward pass
            
            optimizer.step() # Update model parameters
            
            loss_history.append(loss.item())
            
            if not batch_idx % logging_interval:
                print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                      f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                      f'| Loss: {loss:.4f}')
                
        # Validation
        
        model.eval()
        
        with torch.no_grad(): 
            
            train_acc = compute_accuracy(model, train_loader, device)
            valid_acc = compute_accuracy(model, valid_loader, device)
            
            print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                  f'| Train: {train_acc :.2f}% '
                  f'| Validation: {valid_acc :.2f}%')
            
            valid_acc_history.append(valid_acc)
            train_acc_history.append(train_acc)
            
        elapsed = time.time() - start
        print(f'Time elapsed: {elapsed:.2f}s')
        
        if scheduler is not None: scheduler.step(valid_acc_history[-1])
        
    elapsed = time.time() - start
    print(f'Total Training Time: {elapsed:.2f}s')
    
    
    
    test_acc = compute_accuracy(model, test_loader, device)
    print(f'Test accuracy: {test_acc :.2f}%')
    
    return loss_history, train_acc_history, valid_acc_history


## Training

In [10]:
model = ResNet18(num_classes=10)
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)  

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, mode='max', verbose=True)

loss_list, train_acc_list, valid_acc_list = [], [], []
loss_list, train_acc_list, valid_acc_list = train_model(model=model, 
                                                        num_epochs=num_epochs,
                                                        train_loader=train_loader, 
                                                        valid_loader=valid_loader, 
                                                        test_loader=test_loader,
                                                        optimizer=optimizer,
                                                        device=device, 
                                                        scheduler=scheduler,
                                                        logging_interval=50)


Epoch: 001/020 | Batch 0000/0210 | Loss: 2.5828
Epoch: 001/020 | Batch 0050/0210 | Loss: 0.1752
Epoch: 001/020 | Batch 0100/0210 | Loss: 0.1288
Epoch: 001/020 | Batch 0150/0210 | Loss: 0.0503
Epoch: 001/020 | Batch 0200/0210 | Loss: 0.0782
Epoch: 001/020 | Train: 97.41% | Validation: 96.97%
Time elapsed: 16.33s
Epoch: 002/020 | Batch 0000/0210 | Loss: 0.0508
Epoch: 002/020 | Batch 0050/0210 | Loss: 0.0427
Epoch: 002/020 | Batch 0100/0210 | Loss: 0.0700
Epoch: 002/020 | Batch 0150/0210 | Loss: 0.0415
Epoch: 002/020 | Batch 0200/0210 | Loss: 0.0316
Epoch: 002/020 | Train: 98.60% | Validation: 98.25%
Time elapsed: 27.48s
Epoch: 003/020 | Batch 0000/0210 | Loss: 0.0259
Epoch: 003/020 | Batch 0050/0210 | Loss: 0.0563
Epoch: 003/020 | Batch 0100/0210 | Loss: 0.0754
Epoch: 003/020 | Batch 0150/0210 | Loss: 0.0599
Epoch: 003/020 | Batch 0200/0210 | Loss: 0.0471
Epoch: 003/020 | Train: 99.06% | Validation: 98.57%
Time elapsed: 38.72s
Epoch: 004/020 | Batch 0000/0210 | Loss: 0.0707
Epoch: 004/02