In [1]:
! nvidia-smi

Tue Nov 29 14:29:36 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:02:00.0 Off |                    0 |
| N/A   24C    P0    29W / 250W |      0MiB / 12198MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:03:00.0 Off |                    0 |
| N/A   28C    P0    25W / 250W |      0MiB / 12198MiB |      3%      Default |
|       

# imports

In [2]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import os
import torch
import torch.nn as nn
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [3]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fd82ec5c150>

In [4]:
BATCH_SIZE = 128
EPOCHS = 50
LR = 3e-4
TEMP = 0.5
pretrained_path = '/home/users/zg78/ece661_final/simclr/models/simclr_model_0_200_128_0.5_0.5.pth'

# model

In [5]:
class ResNet_Block(nn.Module):
    def __init__(self, in_chs, out_chs, strides):
        super(ResNet_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                      stride=strides, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_chs, out_channels=out_chs,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(out_chs)
        )

        if in_chs != out_chs:
            self.id_mapping = nn.Sequential(
                nn.Conv2d(in_channels=in_chs, out_channels=out_chs,
                          stride=strides, padding=0, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_chs))
        else:
            self.id_mapping = None
        self.final_activation = nn.ReLU(True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.id_mapping is not None:
            x_ = self.id_mapping(x)
        else:
            x_ = x
        return self.final_activation(x_ + out)

class ResNet20Encoder(nn.Module):
    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64)):
        super(ResNet20Encoder, self).__init__()
        self.num_layers = num_layers
        self.head_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_stem_conv,
                      stride=1, padding=1, kernel_size=3, bias=False),
            nn.BatchNorm2d(num_stem_conv),
            nn.ReLU(True)
        )
        num_layers_per_stage = (num_layers - 2) // 6
        self.body_op = []
        num_inputs = num_stem_conv
        for i in range(len(config)):
            for j in range(num_layers_per_stage):
                if j == 0 and i != 0:
                    strides = 2
                else:
                    strides = 1
                self.body_op.append(ResNet_Block(num_inputs, config[i], strides))
                num_inputs = config[i]
        self.body_op = nn.Sequential(*self.body_op)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.final_fc = nn.Linear(config[-1], 10)

    def forward(self, x):
        out = self.head_conv(x)
        out = self.body_op(out)
        features = self.avg_pool(out)
        return features

    
class SimCLR(nn.Module):
    def __init__(self, num_layers=20, num_stem_conv=16, config=(16, 32, 64), projection_dim=20):
        super(SimCLR, self).__init__()
        self.encoder = ResNet20Encoder(num_layers=num_layers, num_stem_conv=num_stem_conv, config=config)
        self.linear1 = nn.Linear(config[-1], config[-1], bias=False)
        self.linear2 = nn.Linear(config[-1], projection_dim, bias=False)
        
    def forward(self, aug1, aug2):
        aug1_out = self.encoder(aug1).squeeze()
        aug1_out = self.linear1(aug1_out)
        aug1_out = nn.ReLU()(aug1_out)
        aug1_out = self.linear2(aug1_out)
        
        aug2_out = self.encoder(aug2).squeeze()
        aug2_out = self.linear1(aug2_out)
        aug2_out = nn.ReLU()(aug2_out)
        aug2_out = self.linear2(aug2_out)
        
        return aug1_out, aug2_out
    
    
class SimCLR_linear_eval(nn.Module):
    def __init__(self, encoder):
        super(SimCLR_linear_eval, self).__init__()
        self.encoder = encoder
        self.linear = nn.Linear(64, 10, bias=False)
        
        for param in self.encoder.parameters():
            param.requires_grad = False
        
    def forward(self, x):
        out = self.encoder(x).squeeze()
        out = self.linear(out)
        return out

# dataset

In [6]:
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)

In [7]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainset, valset = torch.utils.data.random_split(trainset, [int(len(trainset)*0.8), int(len(trainset)*0.2)])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valoader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


# loss

In [8]:
criterion = nn.CrossEntropyLoss()

# loop

In [30]:
simclr_model = SimCLR(projection_dim=64)
simclr_model.load_state_dict(torch.load(pretrained_path))
simclr_model.cuda()

simclr_linear_eval_model = SimCLR_linear_eval(simclr_model.encoder)
simclr_linear_eval_model.cuda()

optimizer = torch.optim.Adam(simclr_linear_eval_model.parameters(), lr=LR)

best_loss = 9999999

for epoch_idx in range(EPOCHS):
    epoch_losses = 0
    epoch_correts = 0
    simclr_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()
        
        simclr_linear_eval_model.zero_grad()
        out = simclr_linear_eval_model(image)
        
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss
    
        pred = torch.argmax(out, dim=1)
        epoch_correts += torch.sum(pred == label).item()
    
    epoch_losses /= len(trainloader)
    epoch_correts /= len(trainset)
        
    
    with torch.no_grad():
        simclr_linear_eval_model.eval()
        
        val_epoch_losses = 0
        val_epoch_correts = 0

        for batch_idx, data in enumerate(tqdm(valoader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = simclr_linear_eval_model(image)

            loss = criterion(out, label)

            val_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            val_epoch_correts += torch.sum(pred == label).item()

        val_epoch_losses /= len(valoader)
        val_epoch_correts /= len(valset)

        if val_epoch_losses < best_loss:
            best_loss = val_epoch_losses
            torch.save(simclr_linear_eval_model.state_dict(), f'simclr_model_linear_eval_{EPOCHS}_{BATCH_SIZE}_{TEMP}_{LR}.pth')
    
    print(f'Train Loss {epoch_losses} Acc {epoch_correts} ; Val Loss {val_epoch_losses} Acc {val_epoch_correts}')

  0%|          | 0/391 [00:00<?, ?it/s]

2.0605947971343994 0.25462


  0%|          | 0/391 [00:00<?, ?it/s]

1.8160773515701294 0.34252


  0%|          | 0/391 [00:00<?, ?it/s]

1.764581322669983 0.35662


  0%|          | 0/391 [00:00<?, ?it/s]

1.741021990776062 0.36252


  0%|          | 0/391 [00:00<?, ?it/s]

1.725546956062317 0.36742


  0%|          | 0/391 [00:00<?, ?it/s]

1.7149920463562012 0.36994


  0%|          | 0/391 [00:00<?, ?it/s]

1.7085360288619995 0.37376


  0%|          | 0/391 [00:00<?, ?it/s]

1.70162832736969 0.37588


  0%|          | 0/391 [00:00<?, ?it/s]

1.6974321603775024 0.37672


  0%|          | 0/391 [00:00<?, ?it/s]

1.6920963525772095 0.37888


  0%|          | 0/391 [00:00<?, ?it/s]

1.6891642808914185 0.37892


  0%|          | 0/391 [00:00<?, ?it/s]

1.6853586435317993 0.38228


  0%|          | 0/391 [00:00<?, ?it/s]

1.6815391778945923 0.3835


  0%|          | 0/391 [00:00<?, ?it/s]

1.677746057510376 0.38566


  0%|          | 0/391 [00:00<?, ?it/s]

1.6759467124938965 0.3846


  0%|          | 0/391 [00:00<?, ?it/s]

1.6721161603927612 0.38782


  0%|          | 0/391 [00:00<?, ?it/s]

1.6700966358184814 0.38904


  0%|          | 0/391 [00:00<?, ?it/s]

1.6686663627624512 0.38724


  0%|          | 0/391 [00:00<?, ?it/s]

1.6651744842529297 0.38924


  0%|          | 0/391 [00:00<?, ?it/s]

1.6621557474136353 0.39038


  0%|          | 0/391 [00:00<?, ?it/s]

1.6616922616958618 0.39046


  0%|          | 0/391 [00:00<?, ?it/s]

1.6584020853042603 0.39166


  0%|          | 0/391 [00:00<?, ?it/s]

1.6580966711044312 0.39124


  0%|          | 0/391 [00:00<?, ?it/s]

1.6558587551116943 0.39226


  0%|          | 0/391 [00:00<?, ?it/s]

1.6540682315826416 0.39236


  0%|          | 0/391 [00:00<?, ?it/s]

1.6510957479476929 0.39604


  0%|          | 0/391 [00:00<?, ?it/s]

1.6510989665985107 0.3952


  0%|          | 0/391 [00:00<?, ?it/s]

1.648826003074646 0.39716


  0%|          | 0/391 [00:00<?, ?it/s]

1.6467725038528442 0.39688


  0%|          | 0/391 [00:00<?, ?it/s]

1.6441770792007446 0.3972


  0%|          | 0/391 [00:00<?, ?it/s]

1.6442784070968628 0.39798


  0%|          | 0/391 [00:00<?, ?it/s]

1.642459750175476 0.39946


  0%|          | 0/391 [00:00<?, ?it/s]

1.6413681507110596 0.39882


  0%|          | 0/391 [00:00<?, ?it/s]

1.6398017406463623 0.39928


  0%|          | 0/391 [00:00<?, ?it/s]

1.6378265619277954 0.402


  0%|          | 0/391 [00:00<?, ?it/s]

1.636799693107605 0.40196


  0%|          | 0/391 [00:00<?, ?it/s]

1.6359562873840332 0.40164


  0%|          | 0/391 [00:00<?, ?it/s]

1.6339596509933472 0.40124


  0%|          | 0/391 [00:00<?, ?it/s]

1.6343388557434082 0.40102


  0%|          | 0/391 [00:00<?, ?it/s]

1.6331089735031128 0.40234


  0%|          | 0/391 [00:00<?, ?it/s]

1.6322360038757324 0.4035


  0%|          | 0/391 [00:00<?, ?it/s]

1.6300920248031616 0.40346


  0%|          | 0/391 [00:00<?, ?it/s]

1.629189372062683 0.40374


  0%|          | 0/391 [00:00<?, ?it/s]

1.6277587413787842 0.40568


  0%|          | 0/391 [00:00<?, ?it/s]

1.6275852918624878 0.40424


  0%|          | 0/391 [00:00<?, ?it/s]

1.62680184841156 0.40386


  0%|          | 0/391 [00:00<?, ?it/s]

1.6250429153442383 0.4047


  0%|          | 0/391 [00:00<?, ?it/s]

1.6250852346420288 0.40548


  0%|          | 0/391 [00:00<?, ?it/s]

1.62477445602417 0.40508


  0%|          | 0/391 [00:00<?, ?it/s]

1.6228222846984863 0.40744


In [32]:
with torch.no_grad():
    simclr_linear_eval_model.eval()
    epoch_losses = 0
    epoch_correts = 0
    
    simclr_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        
        image = image.cuda()
        label = label.cuda()
        
        out = simclr_linear_eval_model(image)
        loss = criterion(out, label)
        
        epoch_losses += loss
    
        pred = torch.argmax(out, dim=1)
        epoch_correts += torch.sum(pred == label).item()
    
    epoch_losses /= len(trainloader)
    epoch_correts /= len(trainset)
    
    print(epoch_losses.item(), epoch_correts)

  0%|          | 0/391 [00:00<?, ?it/s]

1.6222840547561646 0.40568
