In [1]:
! nvidia-smi

Mon Nov 28 18:21:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:81:00.0 Off |                  N/A |
| 35%   28C    P0    50W / 250W |      0MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# imports

In [2]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [3]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f76e9c1b150>

In [32]:
BATCH_SIZE = 128
EPOCHS = 100
LR = 1e-4

# model

In [34]:
class ResNet_Block(nn.Module):
    def __init__(self, in_chs, out_chs, strides):
        super(ResNet_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                      stride=strides, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_chs, out_channels=out_chs,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs)
        )

        if in_chs != out_chs:
            self.id_mapping = nn.Sequential(
                nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                          stride=strides, padding=0, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_chs))
        else:
            self.id_mapping = None
        self.final_activation = nn.ReLU(True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.id_mapping is not None:
            x_ = self.id_mapping(x)
        else:
            x_ = x
        return self.final_activation(x_ + out)

class ResNet20Encoder(nn.Module):
    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64)):
        super(ResNet20Encoder, self).__init__()
        self.num_layers = num_layers
        self.head_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_stem_conv,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(num_stem_conv),
            nn.ReLU(True)
        )
        num_layers_per_stage = (num_layers - 2) // 6
        self.body_op = []
        num_inputs = num_stem_conv
        for i in range(len(config)):
            for j in range(num_layers_per_stage):
                if j == 0 and i != 0:
                    strides = 2
                else:
                    strides = 1
                self.body_op.append(ResNet_Block(num_inputs, config[i], strides))
                num_inputs = config[i]
        self.body_op = nn.Sequential(*self.body_op)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.final_fc = nn.Linear(config[-1], 10)

    def forward(self, x):
        out = self.head_conv(x)
        out = self.body_op(out)
        features = self.avg_pool(out)
        return features

    
class RotNet(nn.Module):
    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64), projection_dim=20):
        super(RotNet, self).__init__()
        self.encoder = ResNet20Encoder(num_layers=num_layers, num_stem_conv=num_stem_conv, config=config)
        self.linear1 = nn.Linear(config[-1], 4, bias=False)
        
    def forward(self, x):
        out = self.encoder(x).squeeze()
        out = self.linear1(out)
        
        return out

# dataset

In [17]:
class CIFAR10Rot(Dataset):

    def __init__(self, base_dataset):

        self.base_dataset = base_dataset
        self.transformed = self.rot()
        
    def rot(self):
        roted_x = []
        roted_y = []
        for img, _ in tqdm(self.base_dataset):
            for idx, angle in enumerate([0, 90, 180, 270]):
                rot_im = torchvision.transforms.functional.rotate(img, angle)
                rot_label = idx
        
                roted_x.append(rot_im)
                roted_y.append(rot_label)
            
        return roted_x, roted_y
        
    
    def __len__(self):
        return len(self.base_dataset) * 4

    def __getitem__(self, idx):
        return self.transformed[0][idx], self.transformed[1][idx]

In [41]:
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainset, valset = torch.utils.data.random_split(trainset, [int(len(trainset)*0.8), int(len(trainset)*0.2)])
trainset = CIFAR10Rot(trainset)
valset = CIFAR10Rot(valset)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

testset = CIFAR10Rot(torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform))
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified


  0%|          | 0/40000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Files already downloaded and verified


  0%|          | 0/10000 [00:00<?, ?it/s]

# loss

In [42]:
criterion = nn.CrossEntropyLoss()

# loop

In [None]:
rot_model = RotNet().cuda()
optimizer = torch.optim.Adam(rot_model.parameters(), lr=LR)

best_loss = 9999999

for epoch_idx in range(EPOCHS):
    epoch_losses = 0
    epoch_correts = 0
    rot_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()
        
        rot_model.zero_grad()
        out = rot_model(image)
        
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss
    
        pred = torch.argmax(out, dim=1)
        epoch_correts += torch.sum(pred == label).item()
    
    epoch_losses /= len(trainloader)
    epoch_correts /= len(trainset)
        
    
    with torch.no_grad():
        rot_model.eval()
        
        val_epoch_losses = 0
        val_epoch_correts = 0

        for batch_idx, data in enumerate(tqdm(valloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_model(image)

            loss = criterion(out, label)

            val_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            val_epoch_correts += torch.sum(pred == label).item()

        val_epoch_losses /= len(valloader)
        val_epoch_correts /= len(valset)

        if val_epoch_losses < best_loss:
            best_loss = val_epoch_losses
            torch.save(rot_model.state_dict(), f'models/rotnet_base_{EPOCHS}_{BATCH_SIZE}_{LR}.pth')
    
    print(f'Train Loss {epoch_losses} Acc {epoch_correts} ; Val Loss {val_epoch_losses} Acc {val_epoch_correts}')

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 1.019081711769104 Acc 0.56398125 ; Val Loss 0.9233811497688293 Acc 0.616075


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.8653017282485962 Acc 0.6424 ; Val Loss 0.9000920653343201 Acc 0.6323


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.7955254912376404 Acc 0.67485 ; Val Loss 0.7955080270767212 Acc 0.6746


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.7421718239784241 Acc 0.69943125 ; Val Loss 0.7497031688690186 Acc 0.69675


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.6983997821807861 Acc 0.71959375 ; Val Loss 0.7250745892524719 Acc 0.707775


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.6590319871902466 Acc 0.73735625 ; Val Loss 0.7111856937408447 Acc 0.715425


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.6210445761680603 Acc 0.75461875 ; Val Loss 0.6831956505775452 Acc 0.72725


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.5898740291595459 Acc 0.76814375 ; Val Loss 0.6989306211471558 Acc 0.7232


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.5614681243896484 Acc 0.7806125 ; Val Loss 0.667412281036377 Acc 0.73935


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.5329116582870483 Acc 0.79293125 ; Val Loss 0.6851778030395508 Acc 0.731275


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.5056622624397278 Acc 0.8047 ; Val Loss 0.7098149061203003 Acc 0.72765


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.48000267148017883 Acc 0.8156375 ; Val Loss 0.6802400946617126 Acc 0.739225


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.4549916982650757 Acc 0.8257 ; Val Loss 0.7284902334213257 Acc 0.7303


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.42969080805778503 Acc 0.83525625 ; Val Loss 0.7409656643867493 Acc 0.725175


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.40768712759017944 Acc 0.8457 ; Val Loss 0.7327340245246887 Acc 0.74015


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.3852309584617615 Acc 0.85438125 ; Val Loss 0.9040912389755249 Acc 0.699775


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.36184826493263245 Acc 0.86385625 ; Val Loss 0.7617225050926208 Acc 0.729825


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.34049397706985474 Acc 0.871575 ; Val Loss 0.9039665460586548 Acc 0.711275


  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Train Loss 0.3232710361480713 Acc 0.87794375 ; Val Loss 0.7934513092041016 Acc 0.736325


  0%|          | 0/1250 [00:00<?, ?it/s]