In [1]:
import torch 
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
import numpy as np
from tensorboardX import SummaryWriter
from tqdm import tqdm_notebook as tqdm 
import os
import time
from datetime import datetime 
from knn_monitor import knn_monitor
from model import ContrastiveLearner
from data import Loader, cifar_test_transforms, cifar_train_transforms
from logger import Logger
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ContrastiveLoss(nn.Module):

    def __init__(self, temp=0.5, normalize= False):
        super().__init__()
        self.temp = temp
        self.normalize = normalize

    def forward(self,xi,xj):

        z1 = F.normalize(xi, dim=1)
        z2 = F.normalize(xj, dim=1)
        
        N, Z = z1.shape 
        device = z1.device 
        
        representations = torch.cat([z1, z2], dim=0)
        similarity_matrix = torch.mm(representations, representations.T) / self.temp

        # create positive matches
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)
        
        # get the values of every pair that's a mismatch
        diag = torch.eye(2*N, dtype=torch.bool, device=device)
        diag[N:,:N] = diag[:N,N:] = diag[:N,:N]        
        negatives = similarity_matrix[~diag].view(2*N, -1)
        
#         print(positives)
#         print(negatives)
        
        exp_upper = (torch.exp(torch.sum(positives, dim=1)))
        exp_lower = (torch.exp((torch.sum(negatives,dim=1))))
        
        loss = (torch.mean(-torch.log(exp_upper/exp_lower)))
        
        return loss / (2 * N)

# main = torch.rand(4,256)
# augm = torch.rand(4,256) 
# loss = ContrastiveLoss()
# loss(main,augm)

In [3]:
import torch.nn as nn
from torchvision.models import resnet50


def get_backbone(backbone, castrate=True):
    if castrate:
        backbone.output_dim = backbone.fc.in_features
        backbone.fc = torch.nn.Identity()
    return backbone

class ProjectionHead(nn.Module):
    def __init__(self,in_shape,out_shape=256):
        super().__init__()
        hidden_shape = in_shape//2

        self.layer_1 = nn.Sequential(
            nn.Linear(in_shape,hidden_shape),
            nn.ReLU(inplace=True)
        )

        self.layer_2 = nn.Sequential(
            nn.Linear(hidden_shape,hidden_shape),
            nn.ReLU(inplace=True)
        )

        self.layer_3 = nn.Linear(hidden_shape,out_shape)

    def forward(self,x):
        x = self.layer_1(x)
#         x = self.layer_2(x)
        x = self.layer_3(x)

        return x


class ContrastiveLearner(nn.Module):
    def __init__(self, backbone=resnet50(), projection_head=None):
        super().__init__()

        self.backbone = get_backbone(backbone)
        self.projection_head = ProjectionHead(backbone.output_dim)
        self.loss = ContrastiveLoss(temp=0.5, normalize= True)

        self.encoder = nn.Sequential(
            self.backbone,
            self.projection_head
        )

    def forward(self,x,x_):
        
        z   = self.encoder(x)

        z_  = self.encoder(x_)
        loss= self.loss(z,z_)

        return loss


In [4]:
uid = 'SimCLR'
dataset_name = 'CIFAR10C'
data_dir = 'dataset'
ckpt_dir = "./ckpt"
features = 128
batch = 64
epochs = 15
lr = 1e-3
use_cuda = True
device_id = 0
wt_decay  = 0.9
 

In [5]:
if use_cuda:
    dtype = torch.cuda.FloatTensor
    device = torch.device("cuda")
    # torch.cuda.set_device(device_id)
    print('GPU')
else:
    dtype = torch.FloatTensor
    device = torch.device("cpu")


# Setup tensorboard
log_dir = "./tb" 

#create dataset folder 
if not os.path.exists('dataset'):
    os.makedirs('dataset')
# Setup asset directories
if not os.path.exists('models'):
    os.makedirs('models')

if not os.path.exists('runs'):
    os.makedirs('runs')

GPU


In [6]:
logger = Logger(log_dir=log_dir, tensorboard=True, matplotlib=True)


in_channel = 3
train_transform = cifar_train_transforms()
test_transform = cifar_test_transforms()
target_transform = None


loader = Loader(dataset_name, data_dir,True, 
                batch, train_transform, test_transform,
                target_transform, use_cuda)


train_loader = loader.train_loader
test_loader = loader.test_loader

Files already downloaded and verified
Files already downloaded and verified


In [7]:
model = ContrastiveLearner().to(device)
optimizer = optim.Adam(model.parameters(), 
            lr=lr,
            weight_decay=wt_decay) 
scheduler = ExponentialLR(optimizer, gamma=wt_decay)

accuracy = 0

# start training 
global_progress = tqdm(range(0, epochs), desc=f'Training')
data_dict = {"loss": 100}
with torch.autograd.detect_anomaly():
    for epoch in global_progress:
        model.train()   

        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')

        for idx, (image, aug_image, label) in enumerate(local_progress):

            model.zero_grad()
            loss = model.forward(image.to(device, non_blocking=True),aug_image.to(device, non_blocking=True))

            # loss =  data_dict['loss'].mean()
            data_dict['loss'] = (loss.mean()).item() 
            loss.backward()
            optimizer.step()
            scheduler.step()
            data_dict.update({'lr': scheduler.get_lr()[0]})
            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        epoch_dict = {'epoch':epoch, 'accuracy':accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

    model_path = os.path.join(ckpt_dir, f"{uid}_{datetime.now().strftime('%m%d%H%M%S')}.pth")
    torch.save({
        'epoch':epoch+1,
        'state_dict': model.module.state_dict()
            }, model_path)
    print(f'Model saved at: {model_path}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  global_progress = tqdm(range(0, epochs), desc=f'Training')


Training:   0%|          | 0/15 [00:00<?, ?it/s]

  with torch.autograd.detect_anomaly():
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')


Epoch 0/15:   0%|          | 0/782 [00:00<?, ?it/s]

RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.