<a href="https://www.kaggle.com/code/hsw1212/cirfar10-classification-resnet50-81?scriptVersionId=139383501" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [25]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Load CIFAR-10 through Keras; the images come back as uint8 NHWC arrays
# (printed below as (50000, 32, 32, 3) for the training split).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# One-hot encode the integer class ids into 10-wide vectors; the PyTorch loss
# later consumes these directly as class-probability targets.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

print(x_train.shape)
print(y_train.shape)

# Human-readable class names in CIFAR-10 class-id order. Named `class_names`
# rather than `labels` so it is not confused with (or shadowed by) the
# per-batch label tensors used in the training/evaluation loops.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

(50000, 32, 32, 3)
(50000, 10)


In [26]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights

# Seed torch's RNGs (all devices) so the random initialisation of the new
# classifier head, and DataLoader shuffling in later cells, are reproducible
# on Restart & Run All.
SEED = 100
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-50 with ImageNet-pretrained weights. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint `pretrained=True`
# loaded, so the starting point is unchanged. No parameters are frozen, so the
# whole network is fine-tuned.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Replace the 1000-way ImageNet classifier with a small MLP head for the
# 10 CIFAR-10 classes; 2048 is the width of ResNet-50's pooled feature vector.
model.fc = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 10),
)
model.to(device)

model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [27]:
from sklearn.model_selection import train_test_split

# Set data loader
class cifar10Dataset(Dataset):
    def __init__(self, imgs, labels, transform=None):
        self.imgs = imgs
        self.labels = labels
        self.transforms = transform

    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, index):
        x = self.imgs[index]
        y = self.labels[index]

        if self.transforms:
            x = self.transforms(x)

        x = x.float()
        return x, y

# Hold out 20% of the training images for validation. The result goes to new
# names so re-running this cell does not re-split an already-split training set.
x_train_split, x_val, y_train_split, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=100,
    shuffle=True,
)

learningRate = 0.001
batch_size = 64

# Targets are one-hot float vectors (from to_categorical); CrossEntropyLoss
# accepts these as class-probability targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learningRate)

# NOTE(review): these are the commonly quoted CIFAR-10 per-channel mean/std.
# The std triple is believed to understate the true training-set std
# (~0.247, 0.243, 0.261), and the backbone was pretrained with ImageNet
# statistics. Confirm which normalisation is intended before changing it.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# inplace=True is safe: ToTensor always produces a fresh tensor.
# TODO: if augmentation (e.g. RandomCrop(32, padding=4, padding_mode='reflect'),
# RandomHorizontalFlip()) is added, use a separate transform for the training
# set only, so validation/test images stay un-augmented.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats, inplace=True),
])

train_dataset = cifar10Dataset(x_train_split, y_train_split, transform)
val_dataset = cifar10Dataset(x_val, y_val, transform)
test_dataset = cifar10Dataset(x_test, y_test, transform)

# Only the training loader needs shuffling; evaluation order does not affect
# accuracy, and a fixed order keeps evaluation deterministic.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [28]:
# Train for a fixed number of epochs. After each epoch, measure accuracy on the
# validation split and checkpoint the weights whenever it improves.
num_epoch = 25
best_val_accuracy = 0.0

for epoch in range(num_epoch):
    model.train()
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            predicted = model(images).argmax(dim=1)
            # Targets are one-hot rows; argmax recovers the class index.
            true_class = targets.argmax(dim=1)
            total += true_class.size(0)
            correct += (predicted == true_class).sum().item()

    val_accuracy = correct / total
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')

    print(f'Epoch [{epoch+1}/{num_epoch}], Validation Accuracy: {val_accuracy:.4f}, BEST Accuracy: {best_val_accuracy:.4f}')

print('Training Finished!')


Epoch [1/25], Validation Accuracy: 0.7224, BEST Accuracy: 0.7224
Epoch [2/25], Validation Accuracy: 0.7503, BEST Accuracy: 0.7503
Epoch [3/25], Validation Accuracy: 0.7716, BEST Accuracy: 0.7716
Epoch [4/25], Validation Accuracy: 0.7982, BEST Accuracy: 0.7982
Epoch [5/25], Validation Accuracy: 0.7992, BEST Accuracy: 0.7992
Epoch [6/25], Validation Accuracy: 0.7952, BEST Accuracy: 0.7992
Epoch [7/25], Validation Accuracy: 0.8004, BEST Accuracy: 0.8004
Epoch [8/25], Validation Accuracy: 0.7898, BEST Accuracy: 0.8004
Epoch [9/25], Validation Accuracy: 0.8091, BEST Accuracy: 0.8091
Epoch [10/25], Validation Accuracy: 0.8051, BEST Accuracy: 0.8091
Epoch [11/25], Validation Accuracy: 0.8170, BEST Accuracy: 0.8170
Epoch [12/25], Validation Accuracy: 0.8084, BEST Accuracy: 0.8170
Epoch [13/25], Validation Accuracy: 0.7975, BEST Accuracy: 0.8170
Epoch [14/25], Validation Accuracy: 0.8041, BEST Accuracy: 0.8170
Epoch [15/25], Validation Accuracy: 0.7998, BEST Accuracy: 0.8170
Epoch [16/25], Vali

In [29]:
# Restore the checkpoint with the best validation accuracy and score it on the
# test set. map_location lets a checkpoint saved on GPU load on a CPU-only kernel.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

total = 0
correct = 0

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for images, targets in test_loader:
        images, targets = images.to(device), targets.to(device)
        predicted = model(images).argmax(dim=1)
        # Targets are one-hot rows; argmax recovers the class index.
        true_class = targets.argmax(dim=1)
        total += true_class.size(0)
        correct += (predicted == true_class).sum().item()

test_accuracy = correct / total
print('Test accuracy: ', test_accuracy)

Test accuracy:  0.812
