In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Conversion to SimCLRv2 and Converting TF Pretrained Weights
Pretrained weights, along with links to the conversion scripts, can be found in Google's [repo](https://github.com/google-research/simclr). Most of the initial work can be found in `spijkervet_prototypes.ipynb`; this notebook cleans up that spaghetti code and turns it into modules.

In [2]:
import os
import sys
import argparse
from pprint import pprint

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, '../')

from model import save_model, load_optimizer
from simclr.modules import LogisticRegression
from simclr import SimCLR, SimCLRv2
from simclr.modules import get_resnet_pt, get_resnet_v2, NT_Xent
from simclr.modules.transformations import TransformsSimCLR
from utils import yaml_config_hook

In [3]:
# Build an argparse namespace whose options and defaults mirror the YAML config.
config = yaml_config_hook("../config/config.yaml")
parser = argparse.ArgumentParser(description="SimCLR")

for key, default in config.items():
    # NOTE(review): type(default) is `bool` for boolean options, and bool("False") is
    # True; harmless here because no command line is actually parsed below.
    parser.add_argument(f"--{key}", default=default, type=type(default))

# Parsing an empty argv keeps the kernel's own command line out of the picture,
# so every option simply takes its config default.
args = parser.parse_args([])

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    args.device = torch.device("cuda")
else:
    args.device = torch.device("cpu")

In [4]:
# Notebook-specific overrides applied on top of the config defaults.
vars(args).update(
    batch_size=32,
    resnet="resnet50",
    epochs=400,
    gpus=4,
    optimizer='LARS',
)
# Show the effective configuration (pprint sorts the keys).
pprint(vars(args))

{'batch_size': 32,
 'dataparallel': 0,
 'dataset': 'CIFAR10',
 'dataset_dir': './datasets',
 'device': device(type='cuda'),
 'epoch_num': 100,
 'epochs': 400,
 'gpus': 4,
 'image_size': 224,
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'save',
 'nodes': 1,
 'nr': 0,
 'optimizer': 'LARS',
 'pretrain': True,
 'projection_dim': 64,
 'reload': False,
 'resnet': 'resnet50',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 8}


In [5]:
# Seed the torch and NumPy global RNGs from the configured seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Build the training dataset; every sample goes through the SimCLR transform
# (presumably yielding two augmented views per image -- TODO confirm in
# simclr.modules.transformations).
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="unlabeled",
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
else:
    raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

if args.nodes > 1:
    # Neither a world size nor a process rank is defined by the earlier cells of this
    # notebook, so resolve them explicitly and fail with a clear message instead of an
    # AttributeError/NameError. A module-level `rank`, if one exists, is still honoured.
    sampler_world_size = getattr(args, "world_size", None)
    sampler_rank = getattr(args, "rank", globals().get("rank"))
    if sampler_world_size is None or sampler_rank is None:
        raise ValueError(
            "args.nodes > 1 requires args.world_size and a process rank "
            "(args.rank) to be set before building the distributed sampler"
        )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=sampler_world_size,
        rank=sampler_rank,
        shuffle=True,
    )
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    # DataLoader rejects shuffle=True together with a sampler, so only shuffle here
    # when no sampler is in charge of the ordering.
    shuffle=(train_sampler is None),
    # Drop the final short batch so every batch has exactly args.batch_size samples.
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Files already downloaded and verified


## ResNet with V2 Contrastive Head


In [6]:
model = SimCLRv2()

if args.reload:
    # Restore weights into the bare model *before* any DataParallel wrapping, so the
    # load does not depend on how many GPUs happen to be visible right now.
    model_fp = os.path.join(
        args.model_path, f"checkpoint_{args.epoch_num}.tar"
    )
    state_dict = torch.load(model_fp, map_location=args.device.type)
    # A state dict saved straight from an nn.DataParallel wrapper prefixes every key
    # with "module."; strip it so both checkpoint flavours load into the bare model.
    wrapper_prefix = "module."
    if state_dict and all(key.startswith(wrapper_prefix) for key in state_dict):
        state_dict = {
            key[len(wrapper_prefix):]: value for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # DataParallel splits each batch along dim 0 across the visible GPUs,
    # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
    model = nn.DataParallel(model)

model = model.to(args.device)
optimizer, scheduler = load_optimizer(args, model)
# world_size=1: this notebook runs as a single process, so the loss only sees the
# local batch.
criterion = NT_Xent(args.batch_size, args.temperature, world_size=1)
writer = SummaryWriter()

Let's use 4 GPUs!


2021-09-22 17:10:11.083730: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.2
2021-09-22 17:10:11.930238: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer.so.7
2021-09-22 17:10:11.930365: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libnvinfer_plugin.so.7


In [7]:
# NOTE(review): this cell repeats the earlier data-loading cell. It also re-seeds the
# torch and NumPy global RNGs after the model has been constructed, so it is not a pure
# no-op; consider factoring both cells into one helper.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Build the training dataset; every sample goes through the SimCLR transform
# (presumably yielding two augmented views per image -- TODO confirm in
# simclr.modules.transformations).
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="unlabeled",
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
else:
    raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

if args.nodes > 1:
    # Neither a world size nor a process rank is defined by the earlier cells of this
    # notebook, so resolve them explicitly and fail with a clear message instead of an
    # AttributeError/NameError. A module-level `rank`, if one exists, is still honoured.
    sampler_world_size = getattr(args, "world_size", None)
    sampler_rank = getattr(args, "rank", globals().get("rank"))
    if sampler_world_size is None or sampler_rank is None:
        raise ValueError(
            "args.nodes > 1 requires args.world_size and a process rank "
            "(args.rank) to be set before building the distributed sampler"
        )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=sampler_world_size,
        rank=sampler_rank,
        shuffle=True,
    )
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    # DataLoader rejects shuffle=True together with a sampler, so only shuffle here
    # when no sampler is in charge of the ordering.
    shuffle=(train_sampler is None),
    # Drop the final short batch so every batch has exactly args.batch_size samples.
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Files already downloaded and verified


In [8]:
def train(args, train_loader, model, criterion, optimizer, writer, display_every=50):
    """Run one pass over ``train_loader`` and return the summed per-step loss.

    Args:
        args: Namespace carrying ``global_step`` (read, logged and incremented once
            per step) and, optionally, ``device`` (where each batch is moved).
        train_loader: Iterable yielding ``((x_i, x_j), _)``; the second element of
            each item is ignored.
        model: Callable as ``model(x_i, x_j)`` returning a 4-tuple whose last two
            entries are passed to ``criterion``.
        criterion: Callable as ``criterion(z_i, z_j)`` returning a scalar tensor loss.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        display_every: Print the loss every this many steps; ``0`` or ``None``
            disables printing.

    Returns:
        The sum of the per-step loss values (a Python number), not their mean.
    """
    # Follow the configured device instead of assuming CUDA, so the function also
    # runs on CPU-only machines; fall back to auto-detection if args has no device.
    device = getattr(args, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    epoch_loss = 0

    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.to(device, non_blocking=True)
        x_j = x_j.to(device, non_blocking=True)

        # Only the last two outputs feed the loss; the first two are not needed here.
        _, _, z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        # Read the scalar once per step: every .item() call copies the value to host.
        loss_value = loss.item()

        if display_every and step % display_every == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss_value}")

        writer.add_scalar("Loss/train_epoch", loss_value, args.global_step)
        epoch_loss += loss_value
        args.global_step += 1

    return epoch_loss

In [None]:
args.global_step = 0
# Start the epoch counter at start_epoch so it stays aligned with the loop index when
# training is resumed from a non-zero epoch (identical to 0 for a fresh run).
# NOTE(review): presumably save_model names checkpoints after args.current_epoch --
# confirm in model.save_model.
args.current_epoch = args.start_epoch

for epoch in range(args.start_epoch, args.epochs):
    # Capture the learning rate used for this epoch, before the scheduler advances it.
    lr = optimizer.param_groups[0]["lr"]
    epoch_loss = train(args, train_loader, model, criterion, optimizer, writer)
    # train() returns the summed loss; normalise by the number of batches.
    mean_loss = epoch_loss / len(train_loader)

    if scheduler:
        scheduler.step()

    # Periodic checkpoint every 10th epoch (including the first one).
    if epoch % 10 == 0:
        save_model(args, model, optimizer)

    writer.add_scalar("Loss/train", mean_loss, epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)

    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

# Final checkpoint after the last epoch, then push buffered scalars to disk.
save_model(args, model, optimizer)
writer.flush()

Step [0/1562]	 Loss: 4.168707847595215
Step [50/1562]	 Loss: 4.085995674133301
Step [100/1562]	 Loss: 4.097858428955078
Step [150/1562]	 Loss: 4.022454261779785


In [None]:
# Release cached, currently unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# FIXME(review): `head` is not assigned in any earlier cell of this notebook, so this
# cell raises NameError on Restart & Run All -- define it first or drop the cell.
head

In [None]:
# Scratch check: build the depth-50 v2 ResNet; the call's two return values are unpacked
# (presumably the encoder and its contrastive head -- TODO confirm in simclr.modules).
foo, bar = get_resnet_v2(depth=50)

In [None]:
# Scratch check: build the "resnet50" model through get_resnet_pt
# (presumably the PyTorch/torchvision variant -- TODO confirm in simclr.modules).
baz = get_resnet_pt("resnet50")

In [None]:
# Display the repr of the object returned by get_resnet_pt in the previous cell.
baz