In [48]:
! nvidia-smi

Sat Dec 10 21:18:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P0    33W /  70W |   2368MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [49]:
!mkdir models

mkdir: cannot create directory ‘models’: File exists


# imports

In [50]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [51]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7ff378c3be30>

In [52]:
BATCH_SIZE = 128
EPOCHS = 30
#DECAY = 5e-4
#MOMENTUM = 0.9
LR = 1e-2
label_percentage = 0.01
pretrained_path = "./models/rotnet_base_100_128_0.005.pth"

# Model

In [53]:
class DirectFwd(nn.Module):
    def __init__(self):
        super(DirectFwd, self).__init__()

    def forward(self, x):
        return x

In [54]:
class RotNet(nn.Module):
    def __init__(self):
        super(RotNet, self).__init__()
        self.encoder = torchvision.models.resnet18(pretrained=False)
        self.encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.maxpool = DirectFwd()
        self.encoder.fc = DirectFwd()
        self.linear1 = nn.Linear(512, 4, bias=False)
        
    def forward(self, x):
        out = self.encoder(x)
        out = self.linear1(out)
  
        return out


class RotNet_linear_eval(nn.Module):
    def __init__(self, encoder):
        super(RotNet_linear_eval, self).__init__()
        self.encoder = encoder
        self.linear = nn.Linear(512, 10, bias=False)
        
    def forward(self, x):
        out = self.encoder(x)
        out = self.linear(out)
        return out

# dataset

In [55]:
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainset, _ = torch.utils.data.random_split(trainset, [int(len(trainset)*label_percentage), int(len(trainset)*(1-label_percentage))])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


# Loss

In [56]:
criterion = nn.CrossEntropyLoss()

# Loop

In [57]:
from torch.optim import optimizer
rot_model = RotNet()
rot_model.load_state_dict(torch.load(pretrained_path))
rot_model.cuda()

rot_linear_eval_model = RotNet_linear_eval(rot_model.encoder)
rot_linear_eval_model.cuda()

#optimizer = torch.optim.SGD(rot_linear_eval_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=DECAY)
optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

best_loss = 9999999

for epoch_idx in range(EPOCHS):
    epoch_losses = 0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()
        
        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)
        
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss
    
        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()
    
    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)
        
    
    with torch.no_grad():
        rot_linear_eval_model.eval()
        
        test_epoch_losses = 0
        test_epoch_corrects = 0

        for batch_idx, data in enumerate(tqdm(testloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'models/rot_model_semi_sup_{EPOCHS}_{BATCH_SIZE}_{LR}_{label_percentage}.pth')
    
    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 2.1531174182891846 Acc 0.252 ; Val Loss 2.0751307010650635 Acc 0.2466


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 1.5047729015350342 Acc 0.528 ; Val Loss 1.6585148572921753 Acc 0.4199


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.9031744599342346 Acc 0.738 ; Val Loss 1.4936102628707886 Acc 0.52


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.4214649498462677 Acc 0.908 ; Val Loss 1.6024634838104248 Acc 0.5409


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.17891748249530792 Acc 0.948 ; Val Loss 1.8630796670913696 Acc 0.5297


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.05379066243767738 Acc 0.992 ; Val Loss 1.9889124631881714 Acc 0.5414


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.010983077809214592 Acc 1.0 ; Val Loss 1.89837646484375 Acc 0.5692


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.003858822165057063 Acc 1.0 ; Val Loss 2.1829710006713867 Acc 0.5657


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0027969875372946262 Acc 1.0 ; Val Loss 2.2610130310058594 Acc 0.5771


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.002052253345027566 Acc 1.0 ; Val Loss 2.3022801876068115 Acc 0.5833


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0005860516685061157 Acc 1.0 ; Val Loss 2.199678421020508 Acc 0.5968


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0006780865951441228 Acc 1.0 ; Val Loss 2.1579043865203857 Acc 0.6007


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.001122780842706561 Acc 1.0 ; Val Loss 2.1329846382141113 Acc 0.6035


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0005516715464182198 Acc 1.0 ; Val Loss 2.1224422454833984 Acc 0.6039


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0009124903590418398 Acc 1.0 ; Val Loss 2.117398262023926 Acc 0.6044


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00041496119229122996 Acc 1.0 ; Val Loss 2.1125755310058594 Acc 0.6054


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0004535777843557298 Acc 1.0 ; Val Loss 2.1102778911590576 Acc 0.6065


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0004502456868067384 Acc 1.0 ; Val Loss 2.1149680614471436 Acc 0.6064


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0006300032837316394 Acc 1.0 ; Val Loss 2.1055383682250977 Acc 0.6067


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00030525229522027075 Acc 1.0 ; Val Loss 2.1049656867980957 Acc 0.6071


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0005707810632884502 Acc 1.0 ; Val Loss 2.104503870010376 Acc 0.6068


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0003050006926059723 Acc 1.0 ; Val Loss 2.1024017333984375 Acc 0.6069


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00031779558048583567 Acc 1.0 ; Val Loss 2.1016809940338135 Acc 0.6062


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0002817433560267091 Acc 1.0 ; Val Loss 2.101166248321533 Acc 0.6063


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00036498880945146084 Acc 1.0 ; Val Loss 2.102275848388672 Acc 0.6063


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00038944807602092624 Acc 1.0 ; Val Loss 2.100924491882324 Acc 0.6066


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.00039114910759963095 Acc 1.0 ; Val Loss 2.102259635925293 Acc 0.6064


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0004550269222818315 Acc 1.0 ; Val Loss 2.1004250049591064 Acc 0.6063


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0006380460108630359 Acc 1.0 ; Val Loss 2.1032395362854004 Acc 0.6066


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Train Loss 0.0005558226257562637 Acc 1.0 ; Val Loss 2.1037962436676025 Acc 0.6073
