In [1]:
import datetime
import time

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from torchinfo import summary

In [2]:
%matplotlib inline

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation((-7, 7)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261)),
])

In [4]:
DATA_PATH = 'data'
BATCH_SIZE = 256

In [5]:
dataset_train = tv.datasets.CIFAR10(DATA_PATH, train=True, download=True, transform=transform_train)
dataset_test = tv.datasets.CIFAR10(DATA_PATH, train=False, download=True, transform=transform_test)

loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# class ResNet18(nn.Module):
#     def __init__(self, n_labels) -> None:
#         super(ResNet18, self).__init__()
#         self.n_labels = n_labels

#         # Load pre-trained resnet model
#         resnet = tv.models.resnet18(weights=tv.models.ResNet18_Weights.IMAGENET1K_V1, progress=False)

#         self.conv1 = resnet.conv1
#         self.bn1 = resnet.bn1
#         self.relu = resnet.relu
#         self.maxpool = resnet.maxpool
#         self.layer1 = resnet.layer1
#         self.layer2 = resnet.layer2
#         self.layer3 = resnet.layer3
#         self.layer4 = resnet.layer4
#         self.avgpool = resnet.avgpool
#         n_in_features = resnet.fc.in_features
#         self.fc = nn.Linear(n_in_features, self.n_labels)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)

#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)

#         x = self.avgpool(x)
#         x = torch.flatten(x, 1)
#         x = self.fc(x)
#         return x

In [7]:
# i = iter(dataset_train)
# x, y = next(i)
# input_size = tuple([BATCH_SIZE] + list(x.size()))
# print('input_size:', input_size)

# summary(model=ResNet18(10), input_size=input_size)

In [8]:
def create_model():
    model = tv.models.resnet18(weights=None, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model

In [9]:
# model = create_model()
# summary(model, input_size=input_size)

In [10]:
def train(model, dataloader, loss_fn, optimizer, device):
    n = len(dataloader.dataset)
    n_batches = len(dataloader)
    running_loss = 0.
    corrects = 0.

    model.train()
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        preds = outputs.max(1, keepdim=True)[1]
        corrects += preds.eq(y.view_as(preds)).sum().item()

        running_loss += loss.item()

    avg_loss = running_loss / n_batches
    acc = corrects / n
    return avg_loss, acc

In [11]:
def evaluate(model, dataloader, loss_fn, device):
    n = len(dataloader.dataset)
    n_batches = len(dataloader)
    running_loss = 0.
    corrects = 0.

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            x= x.to(device)
            y = y.to(device)
            outputs = model(x)
            loss = loss_fn(outputs, y)
            
            preds = outputs.max(1, keepdim=True)[1]
            corrects += preds.eq(y.view_as(preds)).sum().item()

            running_loss += loss.item()
    
    avg_loss = running_loss / n_batches
    acc = corrects / n
    return avg_loss, acc

In [12]:
class EarlyStopping:
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss == None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            # print(
            #     f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                # print('INFO: Early stopping')
                self.early_stop = True

In [13]:
LR = 0.05
EPOCHS = 50

PYTORCH LIGHTNING CIFAR10 ~94% BASELINE TUTORIA [URL](https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/cifar10-baseline.html)

In [14]:
model = create_model().to(device)
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
steps_per_epoch = len(loader_train)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 0.1, epochs=EPOCHS, steps_per_epoch=steps_per_epoch)

In [15]:
train_losses = []
train_accs = []
test_losses = []
test_accs = []


In [16]:
def train_model(epochs=EPOCHS):
    early_stopping = EarlyStopping()

    time_start = time.perf_counter()
    for e in range(epochs):
        epoch_start = time.perf_counter()
        tr_loss, tr_acc = train(model, loader_train, loss, optimizer, device)
        va_loss, va_acc = evaluate(model, loader_test, loss, device)
        scheduler.step()
        epoch_end = time.perf_counter() - epoch_start

        print('[{:3d}/{:d} T:{:s}] Train Loss: {:.4f} Acc: {:.4f}%, Test Loss: {:.4f} Acc: {:.4f}%'.format(
            e+1, epochs, str(datetime.timedelta(seconds=epoch_end)), tr_loss, tr_acc*100, va_loss, va_acc*100))

        early_stopping(tr_loss)
        if early_stopping.early_stop:
                break
        
        train_losses.append(tr_loss)
        train_accs.append(tr_acc)
        test_losses.append(va_loss)
        test_accs.append(va_acc)
    time_elapsed = time.perf_counter()
    print('Total training time: {}'.format(str(datetime.timedelta(seconds=time_elapsed))))


In [17]:
train_model()

[  1/50 T:0:00:36.450871] Train Loss: 1.7544 Acc: 34.1340%, Test Loss: 1.5617 Acc: 42.3500%
[  2/50 T:0:00:35.660056] Train Loss: 1.3262 Acc: 51.5280%, Test Loss: 1.2019 Acc: 56.2400%
[  3/50 T:0:00:34.768923] Train Loss: 1.1093 Acc: 59.9660%, Test Loss: 1.0699 Acc: 62.9000%
[  4/50 T:0:00:35.035917] Train Loss: 0.9778 Acc: 65.1060%, Test Loss: 1.0295 Acc: 64.2700%
[  5/50 T:0:00:34.952071] Train Loss: 0.8822 Acc: 68.7700%, Test Loss: 0.9051 Acc: 68.1500%
[  6/50 T:0:00:34.731951] Train Loss: 0.8114 Acc: 71.0500%, Test Loss: 0.8415 Acc: 70.8000%
[  7/50 T:0:00:35.513442] Train Loss: 0.7361 Acc: 73.8560%, Test Loss: 0.6907 Acc: 75.7700%
[  8/50 T:0:00:35.180229] Train Loss: 0.6729 Acc: 76.3320%, Test Loss: 0.6986 Acc: 76.2700%
[  9/50 T:0:00:34.916122] Train Loss: 0.6347 Acc: 77.5920%, Test Loss: 0.7625 Acc: 74.7000%
[ 10/50 T:0:00:35.761907] Train Loss: 0.5841 Acc: 79.5480%, Test Loss: 0.6076 Acc: 79.6600%
[ 11/50 T:0:00:35.134597] Train Loss: 0.5427 Acc: 80.8960%, Test Loss: 0.5851 Ac

In [18]:
import os

PATH_OUTPUTS = 'outputs'
if not os.path.exists(PATH_OUTPUTS):
    os.mkdir(PATH_OUTPUTS)
PATH_MODEL = os.path.join(PATH_OUTPUTS, 'resnet18_cifar10.pt')
torch.save(model, PATH_MODEL)