In [None]:
# -*- coding: utf-8 -*-

"""
Simple and easy way to use ResNet with PyTorch.

[Deep Residual Learning for Image Recognition] https://arxiv.org/abs/1512.03385
"""

import os
import random
import numpy as np
import torch
import torch.optim as optim
import torchvision.transforms as transforms

from multiprocessing import cpu_count
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.nn import CrossEntropyLoss
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18 as resnet

In [None]:
# For reproducibility: every RNG used by the notebook is seeded from this single value.
seed = 42

# NOTE: hash randomization is fixed at interpreter start-up, so this only
# influences Python processes launched after this point.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise select different algorithms from run to run on GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Kept as the last expression so the cell still displays the returned Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f9ffeb0e9d0>

In [None]:
# Convert each image to a tensor, then normalise the three channels with
# mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

data_root = './cifar10'

# Only the first construction requests a download; the other two read from the
# same root directory. `validset` is built from the same train=True split as
# `trainset`; the two are told apart later by index-based samplers.
trainset = CIFAR10(root=data_root, train=True, transform=transform, download=True)
validset = CIFAR10(root=data_root, train=True, transform=transform)
testset = CIFAR10(root=data_root, train=False, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10


In [None]:
# Train set, validation set split: a stratified 90/10 split over the indices of
# the training data. The split reuses the notebook-wide `seed` so that changing
# the seed in one place changes every source of randomness consistently.
all_indices = np.arange(len(trainset))
train_idx, valid_idx = train_test_split(
    all_indices,
    test_size=0.1,
    random_state=seed,
    shuffle=True,
    stratify=trainset.targets,
)

batch_size = 1000
num_workers = cpu_count() // 2    # half of the available cores (0 on a single-core machine)

# Train and validation loaders draw from the same underlying training data and
# are restricted to their own index subsets by the samplers.
train_loader = DataLoader(trainset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx), num_workers=num_workers)
valid_loader = DataLoader(validset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx), num_workers=num_workers)
# No sampler here: the test set is iterated in dataset order.
test_loader = DataLoader(testset, batch_size=batch_size, num_workers=num_workers)

In [None]:
# Device setting: prefer the GPU when one is visible to PyTorch, otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# NOTE(review): resnet() is built with its default arguments, so the classifier
# head is not sized explicitly for the 10 CIFAR-10 classes - confirm this is
# intended. Changing it also requires the same change where the checkpoint is
# reloaded, otherwise load_state_dict will not match.
net = resnet().to(device)

criterion = CrossEntropyLoss()

# Adam is run at its customary 1e-3 step size; the previous lr=0.1 made the
# validation loss oscillate from the first epochs.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3, weight_decay=0.0001)

# Lower the learning rate when the monitored validation loss stops improving.
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and the
# training loop already prints the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=3)

In [None]:
# Sample counts, used to turn the number of correct predictions into an accuracy
train_total = len(train_idx)
valid_total = len(valid_idx)

# Batches per epoch, used to average the per-batch losses
train_batches = len(train_loader)
valid_batches = len(valid_loader)

# Variables for lr scheduling and early stopping
best_valid_loss = float('inf')    # +inf so the first finite loss always counts as an improvement
patience = 0    # Bad epoch counter

In [None]:
# %%time

max_epochs = 5
early_stop_patience = 2    # Consecutive epochs without improvement before stopping

# NOTE(review): the scheduler is created with patience=3, which is larger than
# `early_stop_patience`, so a learning-rate reduction looks unlikely to happen
# before early stopping ends the run - confirm the intended values.

for epoch in range(max_epochs):
    # Train
    net.train()

    train_loss = 0
    train_correct = 0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        # Clear gradients left over from the previous step before the new backward pass
        optimizer.zero_grad()

        outputs = net(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_correct += predicted.eq(y).sum().item()

    # Mean of the per-batch losses / fraction of correctly classified samples
    train_loss = train_loss / train_batches
    train_acc = train_correct / train_total

    # Validate
    net.eval()

    valid_loss = 0
    valid_correct = 0

    with torch.no_grad():
        for x, y in valid_loader:
            x = x.to(device)
            y = y.to(device)
            outputs = net(x)
            loss = criterion(outputs, y)

            valid_loss += loss.item()
            _, predicted = outputs.max(1)
            valid_correct += predicted.eq(y).sum().item()

    valid_loss = valid_loss / valid_batches
    valid_acc = valid_correct / valid_total

    # Save best model; the bad-epoch counter only grows on epochs that fail to improve
    if valid_loss < best_valid_loss:
        torch.save(net.state_dict(), './best_resnet.pth')
        best_valid_loss = valid_loss
        patience = 0
    else:
        patience += 1

    print('[%2d] TRAIN loss: %.3f, acc: %.3f, lr: %f .... VALID loss: %.3F, acc: %.3f, best_loss: %.3f .... PATIENCE %d' % (epoch+1, train_loss, train_acc, optimizer.param_groups[0]['lr'], valid_loss, valid_acc, best_valid_loss, patience))

    scheduler.step(metrics=valid_loss)

    # Break training loop after `early_stop_patience` consecutive epochs without improvement
    if patience >= early_stop_patience:
        break

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[ 1] TRAIN loss: 2.948, acc: 0.177, lr: 0.100000 .... VALID loss: 2.190, acc: 0.179, best_loss: 2.190 .... PATIENCE 0
[ 2] TRAIN loss: 1.827, acc: 0.309, lr: 0.100000 .... VALID loss: 1.812, acc: 0.339, best_loss: 1.812 .... PATIENCE 0
[ 3] TRAIN loss: 1.568, acc: 0.421, lr: 0.100000 .... VALID loss: 1.870, acc: 0.307, best_loss: 1.812 .... PATIENCE 1
[ 4] TRAIN loss: 1.636, acc: 0.429, lr: 0.100000 .... VALID loss: 2.354, acc: 0.303, best_loss: 1.812 .... PATIENCE 2


In [None]:
# Load best model
loaded = resnet().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written during a GPU run can still be restored on a CPU-only machine.
loaded.load_state_dict(torch.load('./best_resnet.pth', map_location=device))

<All keys matched successfully>

In [None]:
# %%time

# Test
loaded.eval()

test_loss = 0
test_correct = 0
batch_preds = []    # per-batch predictions, concatenated once after the loop

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        outputs = loaded(x)
        loss = criterion(outputs, y)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        test_correct += predicted.eq(y).sum().item()

        # Move to host memory right away; appending avoids re-copying all
        # previous predictions on every iteration.
        batch_preds.append(predicted.cpu())

# test_loader has no sampler, so the predictions are in dataset order and line
# up with testset.targets.
test_preds = torch.cat(batch_preds, dim=0)

print('TEST loss: %.4f, acc: %.4f' % (test_loss/len(test_loader), test_correct/len(testset)))

TEST loss: 1.8122, acc: 0.3345


In [None]:
# Per-class precision / recall / F1 of the test predictions, labelled with the dataset's class names
report = classification_report(
    testset.targets,
    test_preds,
    target_names=testset.classes,
)
print(report)

              precision    recall  f1-score   support

    airplane       0.36      0.54      0.43      1000
  automobile       0.48      0.42      0.45      1000
        bird       0.21      0.48      0.29      1000
         cat       0.24      0.25      0.25      1000
        deer       0.00      0.00      0.00      1000
         dog       0.29      0.44      0.35      1000
        frog       0.45      0.13      0.20      1000
       horse       0.86      0.13      0.23      1000
        ship       0.44      0.35      0.39      1000
       truck       0.37      0.60      0.46      1000

    accuracy                           0.33     10000
   macro avg       0.37      0.33      0.31     10000
weighted avg       0.37      0.33      0.31     10000



  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
from torchvision.models import resnet

In [None]:
# Modify ResNet
class ResNet_mod(resnet.ResNet):
    """ResNet variant whose first residual stage uses 128 planes.

    Args:
        block: residual block class used to build every stage.
        layers: sequence with the number of blocks in each of the four stages.
        num_classes: size of the final classification layer.
    """

    def __init__(self, block, layers, num_classes=10):
        # The parent constructor requires `block` and `layers`; forwarding
        # `num_classes` makes the argument actually take effect.
        super().__init__(block, layers, num_classes=num_classes)

        # `_make_layer` reads and updates `self.inplanes`. Restart from the
        # stem's 64 channels and rebuild the stages in order so that each
        # stage's input width matches the previous stage's output width.
        # NOTE(review): the rebuilt stages are created after the parent's
        # constructor has run, so they keep PyTorch's default initialisation -
        # confirm whether the parent's weight init should be re-applied.
        self.inplanes = 64
        self.layer1 = self._make_layer(block, 128, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
class Bottleneck_mod(resnet.Bottleneck):
    """Bottleneck block whose first convolution is 3x3 instead of the parent's.

    Args:
        in_planes: number of input channels of the block.
        out_planes: base width of the block, forwarded to the parent as its planes.
        *args, **kwargs: forwarded unchanged to the parent constructor, so the
            class can also be constructed with the parent's extra arguments
            (for example when used as the `block` of a ResNet).
    """

    def __init__(self, in_planes, out_planes, *args, **kwargs):
        # The parent constructor needs the channel arguments; it builds every
        # sub-layer, after which only conv1 is swapped out.
        super().__init__(in_planes, out_planes, *args, **kwargs)

        # Keep the output width the parent chose for conv1 so the following
        # layers still fit (equal to out_planes with the parent's defaults).
        # padding=1 preserves the spatial size, which the residual addition
        # needs. `torch.nn` is used because no `nn` alias is imported.
        self.conv1 = torch.nn.Conv2d(in_planes, self.conv1.out_channels, kernel_size=3, padding=1)

In [None]:
# Get pre-trained model (with ImageNet)
# `resnet` was rebound to the torchvision.models.resnet module by the import a
# few cells above, so it is no longer callable; the resnet18 builder is reached
# as an attribute of that module instead.
pretrained = resnet.resnet18(pretrained=True)

TypeError: ignored