<a href="https://colab.research.google.com/github/colinZejda/Summer2023_UCI_ML_Research/blob/main/3_using_resnet_on_imagenet_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm

In [None]:
# Optional CUDA debugging aid: un-commenting the assignment below makes CUDA
# kernel launches synchronous, so an error such as "device-side assert
# triggered" is reported at the operation that caused it rather than later.
# It only takes effect if set before CUDA is initialised.
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Device configuration: first GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# device = 'cpu'

In [None]:
# DATA LOADER FUNC
# Builds torchvision ImageFolder-based loaders for the ImageNet dataset.
def data_loader(train_data_dir,
                validation_data_dir,
                batch_size,
                random_seed=42,
                valid_size=0.1,
                shuffle=True,
                test=False):
    """Create DataLoaders for an ImageFolder-style (``<root>/<class>/<image>``) dataset.

    Args:
        train_data_dir: Root folder of the training images. A ``valid_size``
            fraction of it is held out as the validation subset.
        validation_data_dir: Root folder of the separate labelled split. It is
            only read when ``test=True``, where it is served as the test set.
        batch_size: Mini-batch size for every returned loader.
        random_seed: Seed for the shuffle that decides the train/validation split.
        valid_size: Fraction of the training folder, in ``[0, 1]``, held out
            for validation.
        shuffle: Shuffle indices before splitting. Each sampler additionally
            visits its subset in random order every epoch.
        test: When True, return a single, non-shuffled DataLoader over
            ``validation_data_dir`` instead of the train/validation pair.

    Returns:
        ``(train_loader, valid_loader)`` when ``test`` is False, otherwise a
        single test ``DataLoader``.

    Raises:
        ValueError: If ``valid_size`` is outside ``[0, 1]``.
    """
    # Per-channel ImageNet statistics (the values torchvision's ImageNet models
    # use). The previous numbers were CIFAR-10 statistics.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Random augmentation only for training; evaluation must be deterministic.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_valid = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    if test:
        test_dataset = ImageFolder(root=validation_data_dir, transform=transform_valid)
        return torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    if not 0.0 <= valid_size <= 1.0:
        raise ValueError('valid_size should be in the range [0, 1], got {}'.format(valid_size))

    # Two views of the SAME folder (train vs. eval transforms). ImageFolder
    # orders samples identically for the same root, so a held-out index refers
    # to the same image in both. The old code applied training-folder indices
    # to the separate val folder, which selected wrong images or went out of
    # range. Note: this scans the training folder twice.
    train_dataset = ImageFolder(root=train_data_dir, transform=transform_train)
    valid_dataset = ImageFolder(root=train_data_dir, transform=transform_valid)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # Honour random_seed and leave NumPy's global RNG untouched. For the
        # same seed this yields the same permutation as np.random.seed + shuffle.
        rng = np.random.RandomState(random_seed)
        rng.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    # SubsetRandomSampler visits its subset in a fresh random order each epoch.
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=1)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=1)

    return (train_loader, valid_loader)

In [None]:
# ImageNet dataset: expects <root>/train/<class>/... and <root>/val/<class>/...
path_to_imagenet = '~/Desktop/intern_folder:)/imageNet_data'
train_loader, valid_loader = data_loader(
    train_data_dir=f'{path_to_imagenet}/train',
    validation_data_dir=f'{path_to_imagenet}/val',
    batch_size=64,
)

In [None]:
# DEFINE RESIDUAL BLOCK CLASS
# (reused several times by the ResNet model)
class ResidualBlock(nn.Module):
    """Basic residual block computing ``relu(conv2(conv1(x)) + shortcut(x))``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (2 halves the spatial size).
        downsample: Optional module applied to ``x`` so the shortcut matches the
            main path's shape. When it is None the input is added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super().__init__()
        # Submodule names and registration order are kept as-is so state_dict
        # keys stay compatible with saved checkpoints.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*first_stage)
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv2 = nn.Sequential(*second_stage)
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        shortcut = self.downsample(x) if self.downsample else x
        main_path += shortcut
        return self.relu(main_path)

In [None]:
# DEFINE RESNET CLASS
# model + feed-forward func
class ResNet(nn.Module):
    """ResNet built from basic (non-bottleneck) residual blocks.

    With ``ResidualBlock`` and ``layers=[3, 4, 6, 3]`` this is the ResNet-34 layout.

    Args:
        block: Block class or factory. The first block of each stage is called
            as ``block(in_channels, out_channels, stride, downsample)`` and the
            remaining ones as ``block(in_channels, out_channels)``. Blocks must
            output ``out_channels`` channels, because ``fc`` expects 512 input
            features.
        layers: Number of blocks in each of the four stages.
        num_classes: Number of output logits.
    """

    def __init__(self, block, layers, num_classes = 100):
        super(ResNet, self).__init__()
        self.inplanes = 64
        # Stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool (1/4 resolution).
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        # Each of the last three stages halves the resolution and doubles channels.
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
        # Adaptive pooling collapses any spatial size to 1x1. A 224x224 input
        # yields a 7x7 map here, so results match the former AvgPool2d(7)
        # exactly. Other input sizes no longer crash or feed `fc` a wrong size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks with ``planes`` output channels.

        Only the first block may change stride/channels; it gets a 1x1 conv + BN
        projection shortcut when it does. Updates ``self.inplanes`` for the next
        stage.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, 512, 1, 1) -> (batch, 512)
        x = self.fc(x)

        return x

In [None]:
# SETTING HYPERPARAMETERS
# Size the classifier from the dataset: ImageFolder assigns one label per class
# sub-folder, and CrossEntropyLoss requires every label to be < num_classes.
# A hard-coded value smaller than the real class count likely caused the CUDA
# "device-side assert triggered" error.
num_classes = len(train_loader.dataset.classes)
num_epochs = 20
batch_size = 64        # NOTE: the loaders were built with a literal batch_size=64; keep in sync
learning_rate = 0.01

# [3, 4, 6, 3] blocks per stage = ResNet-34 layout
model = ResNet(ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)

# Loss func (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.001, momentum = 0.9)

In [None]:
# TRAIN MODEL
import gc
total_step = len(train_loader)

for epoch in range(num_epochs):
    # Train mode: BatchNorm uses batch statistics and updates its running stats.
    model.train()
    # Accumulate on-device (detached) to avoid a GPU sync on every step.
    running_loss = 0.0
    for images, labels in tqdm(train_loader, total=total_step):
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()

    # Release cached allocator blocks once per epoch. Doing this on every step
    # only slowed training down.
    gc.collect()
    torch.cuda.empty_cache()

    # The epoch mean is far less noisy than the last mini-batch's loss.
    print('Epoch [{}/{}], Loss: {:.4f}'
          .format(epoch + 1, num_epochs, float(running_loss) / max(total_step, 1)))

    # VALIDATION (per epoch)
    # Eval mode: BatchNorm uses its running statistics and does not update them.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(valid_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total:
        print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total))
    else:
        print('No validation images were evaluated')

In [None]:
# SAVE THE MODEL
import os

# torch.save() opens the path as given and does not expand "~", so resolve the
# home directory explicitly. Also create the target folder so a finished
# training run is not lost to a missing directory.
# NOTE(review): this path spells the folder "intern-folder:)" while the dataset
# path uses "intern_folder:)" -- confirm which one is intended.
path_to_save = os.path.expanduser("~/Desktop/intern-folder:)/colin_code/resnet_model_on_imagenet_data.pth")
os.makedirs(os.path.dirname(path_to_save), exist_ok=True)

# Approach 1: save the entire (pickled) model object (architecture + parameters)
#     note: parameters are learnable elements that the model adjusts (weights + biases)
# torch.save(model, path_to_save)
# loaded_model = torch.load(path_to_save)

# Approach 2: save only the state dict (parameters plus buffers such as BatchNorm
# running statistics). The architecture must be re-created before loading.
torch.save(model.state_dict(), path_to_save)
# model = ResNet(ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
# model.load_state_dict(torch.load(path_to_save))


In [None]:
# TESTING MODEL
# ImageNet's labelled "val" folder serves as the test set. The loader was never
# defined before, so build it here with the same eval-time transform the
# validation loader uses.
test_dataset = ImageFolder(root=path_to_imagenet + '/val',
                           transform=valid_loader.dataset.transform)
# Labels come from sorted class-folder names; both splits must agree on them.
if test_dataset.class_to_idx != train_loader.dataset.class_to_idx:
    raise ValueError('Test and training folders define different classes; '
                     'label indices would not line up.')
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

model.eval()  # BatchNorm must use running statistics at evaluation time
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))
else:
    print('No test images were evaluated')