In [2]:
! nvidia-smi

Sat Dec 10 20:32:02 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P0    31W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
!mkdir models

mkdir: cannot create directory ‘models’: File exists


# imports

In [4]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [5]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f098a5c9c10>

In [15]:
BATCH_SIZE = 192
EPOCHS = 50
DECAY = 5e-4
MOMENTUM = 0.9
LR = 0.01
pretrained_path = "./models/rotnet_base_100_128_0.005.pth"

# model

In [7]:
class DirectFwd(nn.Module):
    def __init__(self):
        super(DirectFwd, self).__init__()

    def forward(self, x):
        return x

In [8]:
class RotNet(nn.Module):
    def __init__(self):
        super(RotNet, self).__init__()
        self.encoder = torchvision.models.resnet18(pretrained=False)
        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.maxpool = DirectFwd()
        self.encoder.fc = DirectFwd()
        self.linear1 = nn.Linear(512, 4, bias=False)
        
    def forward(self, x):
        out = self.encoder(x)
        out = self.linear1(out)
  
        return out


class RotNet_linear_eval(nn.Module):
    def __init__(self, encoder):
        super(RotNet_linear_eval, self).__init__()
        self.encoder = encoder
        self.linear = nn.Linear(512, 10, bias=False)
        self.relu = nn.ReLU()
        
        for param in self.encoder.parameters():
            param.requires_grad = False
        
    def forward(self, x):
        out = self.encoder(x)
        out = self.linear(out)
        return out

# dataset

In [9]:
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Loss

In [10]:
criterion = nn.CrossEntropyLoss()

# Loop

In [16]:
rot_model = RotNet()
rot_model.load_state_dict(torch.load(pretrained_path))
rot_model.cuda()

rot_linear_eval_model = RotNet_linear_eval(rot_model.encoder)
#rot_linear_eval_model.load_state_dict(torch.load("models/rot_model_semi_sup_30_128_0.001_0.1.pth"))
rot_linear_eval_model.cuda()

#optimizer = torch.optim.SGD(rot_linear_eval_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=LR)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=0.2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

best_loss = 9999999

for epoch_idx in range(EPOCHS):
    epoch_losses = 0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()
        
        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)
        
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss
    
        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()
    
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)
        
    
    with torch.no_grad():
        rot_linear_eval_model.eval()
        
        test_epoch_losses = 0
        test_epoch_corrects = 0

        for batch_idx, data in enumerate(tqdm(testloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'models/rot_model_linear_eval_{EPOCHS}_{BATCH_SIZE}_{LR}.pth')
    
    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 1.263500452041626 Acc 0.56432 ; Val Loss 1.101849913597107 Acc 0.6173


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 1.0335747003555298 Acc 0.64002 ; Val Loss 1.0160841941833496 Acc 0.6444


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.9753923416137695 Acc 0.65382 ; Val Loss 0.9965764880180359 Acc 0.6414


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.9465436935424805 Acc 0.66394 ; Val Loss 0.9645799994468689 Acc 0.6544


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.9264123439788818 Acc 0.67022 ; Val Loss 0.9587141871452332 Acc 0.6565


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.9139575958251953 Acc 0.6744 ; Val Loss 0.9551270008087158 Acc 0.6604


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.9034285545349121 Acc 0.67944 ; Val Loss 0.9392256140708923 Acc 0.6671


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8973382711410522 Acc 0.6805 ; Val Loss 0.945587694644928 Acc 0.6598


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.890230119228363 Acc 0.6825 ; Val Loss 0.9433717131614685 Acc 0.6637


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8841882944107056 Acc 0.6846 ; Val Loss 0.9390118718147278 Acc 0.663


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8677332997322083 Acc 0.68988 ; Val Loss 0.9301914572715759 Acc 0.6691


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.866217851638794 Acc 0.69116 ; Val Loss 0.9308599829673767 Acc 0.6672


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8656818866729736 Acc 0.69242 ; Val Loss 0.9280788898468018 Acc 0.6667


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8642473816871643 Acc 0.69162 ; Val Loss 0.9278239011764526 Acc 0.668


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8637024164199829 Acc 0.69258 ; Val Loss 0.9300426840782166 Acc 0.6669


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.862601101398468 Acc 0.69352 ; Val Loss 0.9291447401046753 Acc 0.6644


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8629745841026306 Acc 0.69296 ; Val Loss 0.9283120632171631 Acc 0.6653


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8629223704338074 Acc 0.69244 ; Val Loss 0.9260956645011902 Acc 0.6667


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8619676232337952 Acc 0.69314 ; Val Loss 0.9251807928085327 Acc 0.6681


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8611854910850525 Acc 0.69228 ; Val Loss 0.9274789690971375 Acc 0.6669


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8569301962852478 Acc 0.69426 ; Val Loss 0.9268310070037842 Acc 0.6668


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8576672077178955 Acc 0.69426 ; Val Loss 0.9261863231658936 Acc 0.6676


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8572934865951538 Acc 0.69442 ; Val Loss 0.9261188507080078 Acc 0.6666


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8571248650550842 Acc 0.69334 ; Val Loss 0.9231473207473755 Acc 0.6705


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8558810949325562 Acc 0.6952 ; Val Loss 0.9256051182746887 Acc 0.6702


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8567264676094055 Acc 0.69466 ; Val Loss 0.9250961542129517 Acc 0.6691


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8557194471359253 Acc 0.69448 ; Val Loss 0.9255789518356323 Acc 0.6679


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8562222719192505 Acc 0.69452 ; Val Loss 0.924993634223938 Acc 0.6696


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8555382490158081 Acc 0.69526 ; Val Loss 0.9250975847244263 Acc 0.6687


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8549219369888306 Acc 0.69408 ; Val Loss 0.9250855445861816 Acc 0.6685


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8543279767036438 Acc 0.69508 ; Val Loss 0.9241154193878174 Acc 0.669


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8559141755104065 Acc 0.69454 ; Val Loss 0.9247305393218994 Acc 0.668


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.853996217250824 Acc 0.6957 ; Val Loss 0.924643337726593 Acc 0.6686


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8546482920646667 Acc 0.69564 ; Val Loss 0.9247221946716309 Acc 0.6679


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8542227745056152 Acc 0.69576 ; Val Loss 0.9253025650978088 Acc 0.6676


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.854163646697998 Acc 0.69572 ; Val Loss 0.9256680011749268 Acc 0.6676


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8542852401733398 Acc 0.69494 ; Val Loss 0.9244018793106079 Acc 0.6695


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8546914458274841 Acc 0.69564 ; Val Loss 0.9255918860435486 Acc 0.6687


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8556943535804749 Acc 0.6961 ; Val Loss 0.9242181181907654 Acc 0.6701


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8547452092170715 Acc 0.69558 ; Val Loss 0.92783522605896 Acc 0.6662


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8541870713233948 Acc 0.6945 ; Val Loss 0.9239725470542908 Acc 0.6696


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8541349172592163 Acc 0.69498 ; Val Loss 0.9250620603561401 Acc 0.6677


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8545016646385193 Acc 0.69518 ; Val Loss 0.9241833090782166 Acc 0.6694


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8540940284729004 Acc 0.69716 ; Val Loss 0.9244322776794434 Acc 0.6704


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8540934324264526 Acc 0.69422 ; Val Loss 0.9246264100074768 Acc 0.6699


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8548343181610107 Acc 0.69566 ; Val Loss 0.9251232743263245 Acc 0.6699


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8546153903007507 Acc 0.69572 ; Val Loss 0.9251518845558167 Acc 0.6686


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8539098501205444 Acc 0.69616 ; Val Loss 0.9260851144790649 Acc 0.6676


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8539233803749084 Acc 0.69542 ; Val Loss 0.9246349930763245 Acc 0.6677


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

Train Loss 0.8539857864379883 Acc 0.69568 ; Val Loss 0.9236862659454346 Acc 0.6697


# Test

In [None]:
rot_model = RotNet()
rot_model.load_state_dict(torch.load(pretrained_path))
rot_model.cuda()

rot_linear_eval_test_model = RotNet_linear_eval(rot_model.encoder)
rot_linear_eval_test_model.load_state_dict(torch.load('models/rot_model_linear_eval_30_192_0.01.pth'))
rot_linear_eval_test_model.cuda()

with torch.no_grad():
    rot_linear_eval_test_model.eval()
        
    test_epoch_losses = 0
    test_epoch_corrects = 0

    for batch_idx, data in enumerate(tqdm(testloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        out = rot_linear_eval_test_model(image)

        loss = criterion(out, label)

        test_epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        test_epoch_corrects += torch.sum(pred == label).item()

    test_epoch_losses /= len(testloader)
    test_epoch_corrects /= len(testset)

print(f'Test Loss {test_epoch_losses} Acc {test_epoch_corrects}')
