In [1]:
import datetime
import time

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%matplotlib inline

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation((-7, 7)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
])

In [4]:
DATA_PATH = 'data'
BATCH_SIZE = 512

In [5]:
dataset_train = tv.datasets.CIFAR10(DATA_PATH, train=True, download=True, transform=transform_train)
dataset_test = tv.datasets.CIFAR10(DATA_PATH, train=False, download=True, transform=transform_test)

loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class ResNet18(nn.Module):
    def __init__(self, n_labels) -> None:
        super(ResNet18, self).__init__()
        self.n_labels = n_labels

        # Load pre-trained resnet model
        resnet = tv.models.resnet18(weights=tv.models.ResNet18_Weights.IMAGENET1K_V1, progress=False)

        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool
        n_in_features = resnet.fc.in_features
        self.fc = nn.Linear(n_in_features, self.n_labels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [7]:
def train(model, dataloader, loss_fn, optimizer, device):
    n = len(dataloader.dataset)
    n_batches = len(dataloader)
    running_loss = 0.
    corrects = 0.

    model.train()
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        preds = outputs.max(1, keepdim=True)[1]
        corrects += preds.eq(y.view_as(preds)).sum().item()

        running_loss += loss.item()

    avg_loss = running_loss / n_batches
    acc = corrects / n
    return avg_loss, acc

In [8]:
def evaluate(model, dataloader, loss_fn, device):
    n = len(dataloader.dataset)
    n_batches = len(dataloader)
    running_loss = 0.
    corrects = 0.

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            x= x.to(device)
            y = y.to(device)
            outputs = model(x)
            loss = loss_fn(outputs, y)
            
            preds = outputs.max(1, keepdim=True)[1]
            corrects += preds.eq(y.view_as(preds)).sum().item()

            running_loss += loss.item()
    
    avg_loss = running_loss / n_batches
    acc = corrects / n
    return avg_loss, acc

In [9]:
class EarlyStopping:
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss == None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            # print(
            #     f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                # print('INFO: Early stopping')
                self.early_stop = True

In [10]:
LR = 0.05
EPOCHS = 50

In [11]:
model = ResNet18(10).to(device)
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
steps_per_epoch = len(loader_train)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 0.1, epochs=EPOCHS, steps_per_epoch=steps_per_epoch)

In [12]:
train_losses = []
train_accs = []
test_losses = []
test_accs = []


In [13]:
def train_model(epochs=EPOCHS):
    early_stopping = EarlyStopping()

    time_start = time.perf_counter()
    for e in range(epochs):
        epoch_start = time.perf_counter()
        tr_loss, tr_acc = train(model, loader_train, loss, optimizer, device)
        va_loss, va_acc = evaluate(model, loader_test, loss, device)
        scheduler.step()
        epoch_end = time.perf_counter() - epoch_start

        print('[{:3d}/{:d} T:{:s}] Train Loss: {:.4f} Acc: {:.4f}%, Test Loss: {:.4f} Acc: {:.4f}%'.format(
            e+1, epochs, str(datetime.timedelta(seconds=epoch_end)), tr_loss, tr_acc*100, va_loss, va_acc*100))

        early_stopping(tr_loss)
        if early_stopping.early_stop:
                break
        
        train_losses.append(tr_loss)
        train_accs.append(tr_acc)
        test_losses.append(va_loss)
        test_accs.append(va_acc)
    time_elapsed = time.perf_counter()
    print('Total training time: {}'.format(str(datetime.timedelta(seconds=time_elapsed))))


In [14]:
train_model()

[  1/50 T:0:00:14.313130] Train Loss: 1.4162 Acc: 49.5360%, Test Loss: 1.0154 Acc: 65.3800%
[  2/50 T:0:00:12.814520] Train Loss: 0.8725 Acc: 69.2320%, Test Loss: 0.7782 Acc: 73.0600%
[  3/50 T:0:00:12.852163] Train Loss: 0.7441 Acc: 73.9980%, Test Loss: 0.6956 Acc: 76.1400%
[  4/50 T:0:00:12.858415] Train Loss: 0.6647 Acc: 76.8160%, Test Loss: 0.6468 Acc: 77.6400%
[  5/50 T:0:00:12.979197] Train Loss: 0.6130 Acc: 78.4380%, Test Loss: 0.6126 Acc: 79.0600%
[  6/50 T:0:00:13.066810] Train Loss: 0.5700 Acc: 79.8220%, Test Loss: 0.5821 Acc: 79.7800%
[  7/50 T:0:00:12.887367] Train Loss: 0.5313 Acc: 81.3800%, Test Loss: 0.5683 Acc: 80.3800%
[  8/50 T:0:00:12.917864] Train Loss: 0.5109 Acc: 81.8940%, Test Loss: 0.5624 Acc: 80.9700%
[  9/50 T:0:00:12.929956] Train Loss: 0.4769 Acc: 83.2360%, Test Loss: 0.5514 Acc: 81.6500%
[ 10/50 T:0:00:12.821445] Train Loss: 0.4585 Acc: 83.7820%, Test Loss: 0.5386 Acc: 82.0200%
[ 11/50 T:0:00:12.961934] Train Loss: 0.4331 Acc: 84.8260%, Test Loss: 0.5403 Ac