# Initial Experiment and Setup
SimCLRv1 notebook code originally from Spjijkervet/SimCLR. Most of the experiments here are an effort to use pretrained weights provided from Google's original implementation with pre-trained weights and conversion to v2. Useful for record, but a bit messy.

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import os
import sys
import argparse
from pathlib import Path
from pprint import pprint

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.utils.tensorboard import SummaryWriter


sys.path.insert(0, '../')

from model import save_model, load_optimizer
from simclr.modules import LogisticRegression
from simclr import SimCLR
from simclr.modules import get_resnet, NT_Xent
from simclr.modules.transformations import TransformsSimCLR
from utils import yaml_config_hook

ModuleNotFoundError: No module named 'model'

In [10]:
parser = argparse.ArgumentParser(description="SimCLR")
tensorboard_logdir = Path("/home/kaipak/models/runs")
config = yaml_config_hook("../config/config.yaml")

for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
    
args = parser.parse_args([])
args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

NameError: name 'yaml_config_hook' is not defined

In [None]:
args.batch_size = 256
args.resnet = "resnet50"
args.epochs = 400
args.gpus = 4
pprint(vars(args))

In [None]:
torch.manual_seed(args.seed)
np.random.seed(args.seed)

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="unlabeled",
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
else:
    raise NotImplementedError

if args.nodes > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
    )
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

In [None]:
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features

# init model
model = SimCLR(encoder, args.projection_dim, n_features)

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

if args.reload:
    model_fp = os.path.join(
        args.model_path, f"checkpoint_{args.epoch_num}.tar"
    )
    model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

model = model.to(args.device)
optimizer, scheduler = load_optimizer(args, model)
criterion = NT_Xent(args.batch_size, args.temperature, world_size=1)
writer = SummaryWriter(log_dir='/home/kaipak/models/runs')

In [None]:
def train(args, train_loader, model, criterion, optimizer, writer, display_every=50):
    """Train function"""
    epoch_loss = 0
    
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.cuda(non_blocking=True)
        x_j = x_j.cuda(non_blocking=True)
        
        # Positive pair with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)
        
        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()
        
        if step % display_every == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")
        
        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        epoch_loss += loss.item()
        args.global_step += 1
    
    return epoch_loss

In [None]:
args.global_step = 0
args.current_epoch = 0
tb_writer = SummaryWriter(log_dir=f'/home/kai/models/runs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')

for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    epoch_loss = train(args, train_loader, model, criterion, optimizer, tb_writer)
    
    if scheduler:
        scheduler.step()
    
    if epoch % 10 == 0:
        save_model(args, model, optimizer)
    
    writer.add_scalar("Loss/train", epoch_loss / len(train_loader), epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)

    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {epoch_loss / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

save_model(args, model, optimizer)

In [None]:
torch.cuda.empty_cache()

## Linear Evaluation
From example, LE using logistic regression, from weights from frozen trained model from previous cells.

In [36]:
def train(args, loader, simclr_model, model, criterion, optimizer, writer):
    """Train evaluation model"""
    epoch_loss = 0
    epoch_accuracy = 0
    model.eval()
    
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        
        x = x.to(args.device)
        y = y.to(args.device)
        
        output = model(x)
        step_loss = criterion(output, y)
        
        predicted = output.argmax(1)
        step_accuracy = (predicted == y).sum().item() / y.size(0)
        epoch_accuracy += step_accuracy
        
        step_loss.backward()
        optimizer.step()
        
        epoch_loss += step_loss
        writer.add_scalar("Accuracy/train_step", step_accuracy, args.global_step)
        args.global_step += 1
        
    writer.add_scalar("Accuracy/train_epoch", step_accuracy, args.current_epoch)
    writer.add_scalar("Loss/train_epoch", epoch_loss, args.current_epoch)

    return epoch_loss, epoch_accuracy

def test(args, loader, simclr_model, model, criterion, optimizer):
    epoch_loss = 0
    epoch_accuracy = 0
    model.eval()
    
    for step, (x, y) in enumerate(loader):
        model.zero_grad()
        
        x = x.to(args.device)
        y = y.to(args.device)
        
        output = model(x)
        step_loss = criterion(output, y)
        
        predicted = output.argmax(1)
        step_accuracy = (predicted == y).sum().item() / y.size(0)
        epoch_accuracy += step_accuracy
        
        epoch_loss += step_loss.item()
    
    return epoch_loss, epoch_accuracy


In [37]:
from pprint import pprint
from utils import yaml_config_hook

use_tpu = False

parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("../config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

args = parser.parse_args([])

if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

args.batch_size = 32
args.dataset = "CIFAR10"
args.resnet = "resnet50"
args.model_path = "../../../models/SimCLRv2/save"
args.epoch_num = 100
args.logistic_epochs = 500

In [38]:
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="train",
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
    test_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="test",
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        train=True,
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
    test_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        train=False,
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=args.workers,
)

Files already downloaded and verified
Files already downloaded and verified


In [59]:
criterion

CrossEntropyLoss()

In [None]:
import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        self.downsample = downsample  # hack: moving downsample to the first to make order correct
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, width_mult=1):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64 * width_mult
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64 * width_mult, layers[0])
        self.layer2 = self._make_layer(block, 128 * width_mult, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256 * width_mult, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512 * width_mult, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion * width_mult, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def resnet50x1(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], width_mult=1)

def resnet50x2(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], width_mult=2)


def resnet50x4(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], width_mult=4)

In [None]:
# Load ResNet encoder / SimCLR and load model weights
#encoder = get_resnet(args.resnet, pretrained=False)
#n_features = encoder.fc.in_features

# load pre-trained model from checkpoint
encoder = resnet50x4()
encoder.load_state_dict(torch.load('/home/kaipak/models/SimCLRv1/resnet50-4x.pth', map_location='cpu')['state_dict'])
n_features = encoder.fc.in_features
simclr_model = SimCLR(encoder, args.projection_dim, n_features)
#model_fp = os.path.join(args.model_path, "checkpoint_40.pt")
#simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  simclr_model_ngpu = nn.DataParallel(simclr_model)

simclr_model = simclr_model.to(args.device)

In [39]:
# %load /home/kaipak/dev/SimCLRv2-Pytorch/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F

BATCH_NORM_EPSILON = 1e-5
BATCH_NORM_DECAY = 0.9  # == pytorch's default value as well


class BatchNormRelu(nn.Sequential):
    def __init__(self, num_channels, relu=True):
        super().__init__(nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON), nn.ReLU() if relu else nn.Identity())


def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     stride=stride, padding=(kernel_size - 1) // 2, bias=bias)


class SelectiveKernel(nn.Module):
    def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
        super().__init__()
        assert sk_ratio > 0.0
        self.main_conv = nn.Sequential(conv(in_channels, 2 * out_channels, stride=stride),
                                       BatchNormRelu(2 * out_channels))
        mid_dim = max(int(out_channels * sk_ratio), min_dim)
        self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1), BatchNormRelu(mid_dim),
                                         conv(mid_dim, 2 * out_channels, kernel_size=1))

    def forward(self, x):
        x = self.main_conv(x)
        x = torch.stack(torch.chunk(x, 2, dim=1), dim=0)  # 2, B, C, H, W
        g = x.sum(dim=0).mean(dim=[2, 3], keepdim=True)
        m = self.mixing_conv(g)
        m = torch.stack(torch.chunk(m, 2, dim=1), dim=0)  # 2, B, C, 1, 1
        return (x * F.softmax(m, dim=0)).sum(dim=0)


class Projection(nn.Module):
    def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        if sk_ratio > 0:
            self.shortcut = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)),
                                          # kernel_size = 2 => padding = 1
                                          nn.AvgPool2d(kernel_size=2, stride=stride, padding=0),
                                          conv(in_channels, out_channels, kernel_size=1))
        else:
            self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = BatchNormRelu(out_channels, relu=False)

    def forward(self, x):
        return self.bn(self.shortcut(x))


class BottleneckBlock(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
        super().__init__()
        if use_projection:
            self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio)
        else:
            self.projection = nn.Identity()
        ops = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
        if sk_ratio > 0:
            ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
        else:
            ops.append(conv(out_channels, out_channels, stride=stride))
            ops.append(BatchNormRelu(out_channels))
        ops.append(conv(out_channels, out_channels * 4, kernel_size=1))
        ops.append(BatchNormRelu(out_channels * 4, relu=False))
        self.net = nn.Sequential(*ops)

    def forward(self, x):
        shortcut = self.projection(x)
        return F.relu(shortcut + self.net(x))


class Blocks(nn.Module):
    def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        self.blocks = nn.ModuleList([BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)])
        self.channels_out = out_channels * BottleneckBlock.expansion
        for _ in range(num_blocks - 1):
            self.blocks.append(BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio))

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x


class Stem(nn.Sequential):
    def __init__(self, sk_ratio, width_multiplier):
        ops = []
        channels = 64 * width_multiplier // 2
        if sk_ratio > 0:
            ops.append(conv(3, channels, stride=2))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels))
            ops.append(BatchNormRelu(channels))
            ops.append(conv(channels, channels * 2))
        else:
            ops.append(conv(3, channels * 2, kernel_size=7, stride=2))
        ops.append(BatchNormRelu(channels * 2))
        ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        super().__init__(*ops)


class ResNet(nn.Module):
    def __init__(self, layers, width_multiplier, sk_ratio):
        super().__init__()
        ops = [Stem(sk_ratio, width_multiplier)]
        channels_in = 64 * width_multiplier
        ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio))
        channels_in = ops[-1].channels_out
        self.channels_out = channels_in
        self.net = nn.Sequential(*ops)
        self.fc = nn.Linear(channels_in, 1000)

    def forward(self, x, apply_fc=False):
        h = self.net(x).mean(dim=[2, 3])
        if apply_fc:
            h = self.fc(h)
        return h


class ContrastiveHead(nn.Module):
    def __init__(self, channels_in, out_dim=128, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            if i != num_layers - 1:
                dim, relu = channels_in, True
            else:
                dim, relu = out_dim, False
            self.layers.append(nn.Linear(channels_in, dim, bias=False))
            bn = nn.BatchNorm1d(dim, eps=BATCH_NORM_EPSILON, affine=True)
            if i == num_layers - 1:
                nn.init.zeros_(bn.bias)
            self.layers.append(bn)
            if relu:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        for b in self.layers:
            x = b(x)
        return x


def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}[depth]
    resnet = ResNet(layers, width_multiplier, sk_ratio)
    return resnet, ContrastiveHead(resnet.channels_out)


def name_to_params(checkpoint):
    sk_ratio = 0.0625 if '_sk1' in checkpoint else 0
    if 'r50_' in checkpoint:
        depth = 50
    elif 'r101_' in checkpoint:
        depth = 101
    elif 'r152_' in checkpoint:
        depth = 152
    else:
        raise NotImplementedError

    if '_1x_' in checkpoint:
        width = 1
    elif '_2x_' in checkpoint:
        width = 2
    elif '_3x_' in checkpoint:
        width = 3
    else:
        raise NotImplementedError

    return depth, width, sk_ratio


In [10]:
# Load ResNet encoder / SimCLR and load model weights
#encoder = get_resnet(args.resnet, pretrained=False)
#n_features = encoder.fc.in_features

# load pre-trained model from checkpoint
encoder, head = get_resnet(depth=101, width_multiplier=2, sk_ratio=0.0625)
encoder.load_state_dict(torch.load('/home/kaipak/models/SimCLRv2/r101_2x_sk1.pth', map_location='cpu')['resnet'])

n_features = encoder.fc.in_features
simclr_model = SimCLR(encoder, args.projection_dim, n_features)
#model_fp = os.path.join(args.model_path, "checkpoint_40.pt")
#simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  simclr_model_ngpu = nn.DataParallel(simclr_model)
  simclr_model_ngpu.n_features = n_features

simclr_model = simclr_model_ngpu.to(args.device)

Let's use 4 GPUs!


In [58]:
head

ContrastiveHead(
  (layers): ModuleList(
    (0): Linear(in_features=4096, out_features=4096, bias=False)
    (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=4096, out_features=4096, bias=False)
    (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=4096, out_features=128, bias=False)
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [11]:
n_features

4096

In [52]:
from simclr.modules import LARS

n_classes = 10
model = LogisticRegression(simclr_model.n_features, n_classes)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
#optimizer = LARS(model.parameters(), lr=3e-3)
criterion = torch.nn.CrossEntropyLoss()
writer = SummaryWriter(log_dir='/home/kaipak/models/runs')

Let's use 4 GPUs!


In [65]:
# Helper functions to map all input data X to their latent representations 
# h that are used in linear evaluation (they only have to be computed once)
# Should be part of the processing step before FT.
def inference(loader, simclr_model, device):
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding
        with torch.no_grad():
            h, _, z, _ = simclr_model(x, x)

        h = h.detach()

        feature_vector.extend(h.cpu().detach().numpy())
        labels_vector.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector

# TODO: Instead of getting representation from Resnet, need to get from projection head layer 2
# This is sorta misleading and should be wrapped into model.
def get_features(context_model, train_loader, test_loader, device):
    train_X, train_y = inference(train_loader, context_model, device)
    test_X, test_y = inference(test_loader, context_model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader

[autoreload of simclr failed: Traceback (most recent call last):
  File "/home/kaipak/miniconda3/envs/pytorch/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/kaipak/miniconda3/envs/pytorch/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/home/kaipak/miniconda3/envs/pytorch/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/home/kaipak/miniconda3/envs/pytorch/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "../simclr/__init__.py", line 2, in <module>
    from .simclr import SimCLRv2
ImportError: cannot import name 'SimCLRv2' 

In [45]:
(train_X, train_y, test_X, test_y) = get_features(
    simclr_model, train_loader, test_loader, args.device
)

arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, args.logistic_batch_size
)

Step [0/195]	 Computing features...
Step [20/195]	 Computing features...
Step [40/195]	 Computing features...
Step [60/195]	 Computing features...
Step [80/195]	 Computing features...
Step [100/195]	 Computing features...
Step [120/195]	 Computing features...
Step [140/195]	 Computing features...
Step [160/195]	 Computing features...
Step [180/195]	 Computing features...
Features shape (49920, 4096)
Step [0/39]	 Computing features...
Step [20/39]	 Computing features...
Features shape (9984, 4096)


In [51]:
%time
args.global_step = 0
args.current_epoch = 0
args.logistic_epochs = 400

#for epoch in range(args.logistic_epochs):
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(args, arr_train_loader, simclr_model, model, criterion, optimizer, writer)
    
    if epoch % 20 == 0:
        print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")
    
    args.current_epoch += 1
        
loss_epoch, accuracy_epoch = test(
    args, arr_test_loader, simclr_model, model, criterion, optimizer
)

print(
    f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}"
)

CPU times: user 0 ns, sys: 7 µs, total: 7 µs
Wall time: 13.1 µs
Epoch [0/400]	 Loss: 2.2319798469543457	 Accuracy: 0.47682291666666665
Epoch [20/400]	 Loss: 2.2054662704467773	 Accuracy: 0.5101762820512821
Epoch [40/400]	 Loss: 2.1790452003479004	 Accuracy: 0.526582532051282
Epoch [60/400]	 Loss: 2.1526060104370117	 Accuracy: 0.5374399038461538
Epoch [80/400]	 Loss: 2.126054525375366	 Accuracy: 0.544491185897436
Epoch [100/400]	 Loss: 2.0993082523345947	 Accuracy: 0.5503004807692308
Epoch [120/400]	 Loss: 2.0722970962524414	 Accuracy: 0.5547475961538462


KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()

In [60]:
arr_train_loader

<torch.utils.data.dataloader.DataLoader at 0x78ff544f7390>

In [None]:
# Try converting pretrained weights from Google's TF Implementation