In [89]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

from datetime import datetime
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import lightly.models as models
import lightly.loss as loss

Sun Jun 27 11:35:52 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             Off  | 00000000:3B:00.0 Off |                  N/A |
| 28%   36C    P2    38W / 250W |   5818MiB / 12066MiB |     13%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN V             Off  | 00000000:5E:00.0 Off |                  N/A |
| 28%   38C    P2    35W / 250W |   6208MiB / 12066MiB |     33%      Default |
|       

## Create dataloader, model

In [None]:
parser = argparse.ArgumentParser(description='Train SimSiam on CIFAR-10')

parser.add_argument('-a', '--arch', default='resnet18')

# lr: 0.06 for batch 512 (or 0.03 for batch 256)
parser.add_argument('--lr', '--learning-rate', default=0.06, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')

parser.add_argument('--batch-size', default=512, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')

parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')

# knn monitor
parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; could be different with moco-t')

# utils
parser.add_argument('--resume', default='../output/CIFAR10-2021-06-27-11-37-16-moco/model_last.pth', 
                    type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='', type=str, metavar='PATH', help='path to cache (default: none)')

'''
args = parser.parse_args()  # running in command line
'''
args = parser.parse_args('')  # running in ipynb

# set command line arguments here when running in ipynb
args.epochs = 800
args.cos = True
args.schedule = []  # cos in use
args.symmetric = False
if args.results_dir == '':
    args.results_dir = '../output/CIFAR10-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")

print(args)

In [91]:
class CIFAR10Pair(CIFAR10):
    """CIFAR10 Dataset.
    """
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)

        return im_1, im_2

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

# data prepare
train_data = CIFAR10Pair(root='data', train=True, transform=train_transform, download=True)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)

memory_data = CIFAR10(root='data', train=True, transform=test_transform, download=True)
memory_loader = DataLoader(memory_data, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

test_data = CIFAR10(root='data', train=False, transform=test_transform, download=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Model

In [98]:
# use a resnet50 backbone
resnet = torchvision.models.resnet.resnet18()
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])

# build the simsiam model
model = models.SimSiam(resnet, num_ftrs=512)
model = model.cuda()

# use a criterion for self-supervised learning
criterion = loss.SymNegCosineSimilarityLoss()

In [99]:
# train for one epoch
def train(net, data_loader, train_optimizer, epoch, args):
    net.train()
    adjust_learning_rate(optimizer, epoch, args)

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for im_1, im_2 in train_bar:
        im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)

        y0, y1 = net(im_1, im_2)
        # backpropagation
        loss = criterion(y0, y1)
        
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += data_loader.batch_size
        total_loss += loss.item() * data_loader.batch_size
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, optimizer.param_groups[0]['lr'], total_loss / total_num))

    return total_loss / total_num

# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [100]:
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, args):
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature = net(data.cuda(non_blocking=True)).squeeze()
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature = net(data).squeeze()
            feature = F.normalize(feature, dim=1)
            
            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100))

    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels

In [103]:
# define optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)

# load model if resume
epoch_start = 200
if args.resume is not '':
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))

Loaded from: ../output/CIFAR10-2021-06-27-11-37-16-moco/model_last.pth


In [111]:
# logging
results = {'train_loss': [], 'test_acc@1': []}
if not os.path.exists(args.results_dir):
    os.mkdir(args.results_dir)
    
# dump args
with open(args.results_dir + '/args.json', 'w') as fid:
    json.dump(args.__dict__, fid, indent=2)


In [112]:
# training loop
for epoch in range(epoch_start, args.epochs + 1):
    test_acc_1 = test(model.backbone, memory_loader, test_loader, epoch, args)
    train_loss = train(model, train_loader, optimizer, epoch, args)
    results['train_loss'].append(train_loss)
    test_acc_1 = test(model.backbone, memory_loader, test_loader, epoch, args)
    results['test_acc@1'].append(test_acc_1)
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
    # save model
    torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict(),}, args.results_dir + '/model_last.pth')

Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 28.00it/s]
Test Epoch: [700/800] Acc@1:78.44%: 100%|██████████| 20/20 [00:02<00:00,  8.09it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 29.17it/s]
Test Epoch: [701/800] Acc@1:78.44%: 100%|██████████| 20/20 [00:02<00:00,  7.18it/s]
Train Epoch: [701/800], lr: 0.002239, Loss: -0.8365: 100%|██████████| 97/97 [00:16<00:00,  6.04it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 29.14it/s]
Test Epoch: [701/800] Acc@1:78.15%: 100%|██████████| 20/20 [00:02<00:00,  7.93it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 28.53it/s]
Test Epoch: [702/800] Acc@1:78.15%: 100%|██████████| 20/20 [00:02<00:00,  7.83it/s]
Train Epoch: [702/800], lr: 0.002194, Loss: -0.8358: 100%|██████████| 97/97 [00:15<00:00,  6.15it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.76it/s]
Test Epoch: [702/800] Acc@1:78.60%: 100%|██████████| 20/20 [00:02<00:00,  9.03it/s]
Feature extracting: 100%|██████████| 9

Test Epoch: [720/800] Acc@1:78.77%: 100%|██████████| 20/20 [00:02<00:00,  8.53it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 29.67it/s]
Test Epoch: [721/800] Acc@1:78.77%: 100%|██████████| 20/20 [00:02<00:00,  8.50it/s]
Train Epoch: [721/800], lr: 0.001432, Loss: -0.8432: 100%|██████████| 97/97 [00:16<00:00,  5.94it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.31it/s]
Test Epoch: [721/800] Acc@1:78.95%: 100%|██████████| 20/20 [00:02<00:00,  8.27it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.10it/s]
Test Epoch: [722/800] Acc@1:78.95%: 100%|██████████| 20/20 [00:02<00:00,  9.72it/s]
Train Epoch: [722/800], lr: 0.001396, Loss: -0.8442: 100%|██████████| 97/97 [00:16<00:00,  5.97it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 30.60it/s]
Test Epoch: [722/800] Acc@1:78.78%: 100%|██████████| 20/20 [00:02<00:00,  9.27it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 28.51it/s]
Test Epoch: [723/800] Acc@1:78.78%: 10

Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.37it/s]
Test Epoch: [741/800] Acc@1:78.89%: 100%|██████████| 20/20 [00:02<00:00,  7.89it/s]
Train Epoch: [741/800], lr: 0.000802, Loss: -0.8548: 100%|██████████| 97/97 [00:15<00:00,  6.29it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.27it/s]
Test Epoch: [741/800] Acc@1:78.89%: 100%|██████████| 20/20 [00:02<00:00,  8.45it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 25.81it/s]
Test Epoch: [742/800] Acc@1:78.89%: 100%|██████████| 20/20 [00:02<00:00,  7.96it/s]
Train Epoch: [742/800], lr: 0.000775, Loss: -0.8538: 100%|██████████| 97/97 [00:16<00:00,  5.94it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 25.40it/s]
Test Epoch: [742/800] Acc@1:78.89%: 100%|██████████| 20/20 [00:02<00:00,  7.83it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.40it/s]
Test Epoch: [743/800] Acc@1:78.89%: 100%|██████████| 20/20 [00:01<00:00, 10.78it/s]
Train Epoch: [743/800], lr: 0.000748, 

Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 26.41it/s]
Test Epoch: [782/800] Acc@1:79.51%: 100%|██████████| 20/20 [00:02<00:00,  8.07it/s]
Train Epoch: [782/800], lr: 0.000075, Loss: -0.8588: 100%|██████████| 97/97 [00:16<00:00,  5.97it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 28.06it/s]
Test Epoch: [782/800] Acc@1:79.43%: 100%|██████████| 20/20 [00:02<00:00,  8.27it/s]
Feature extracting: 100%|██████████| 98/98 [00:03<00:00, 27.75it/s]
Test Epoch: [783/800] Acc@1:79.43%: 100%|██████████| 20/20 [00:02<00:00,  8.48it/s]
Train Epoch: [783/800], lr: 0.000067, Loss: -0.8594: 100%|██████████| 97/97 [00:16<00:00,  6.02it/s]
Feature extracting: 100%|██████████| 98/98 [00:04<00:00, 24.28it/s]
Test Epoch: [783/800] Acc@1:79.47%: 100%|██████████| 20/20 [00:02<00:00,  8.23it/s]
Feature extracting: 100%|██████████| 98/98 [00:04<00:00, 22.34it/s]
Test Epoch: [784/800] Acc@1:79.47%: 100%|██████████| 20/20 [00:01<00:00, 10.33it/s]
Train Epoch: [784/800], lr: 0.000059, 

In [113]:
test_acc_1 = test(model.backbone, memory_loader, test_loader, epoch, args)

Feature extracting: 100%|██████████| 98/98 [00:04<00:00, 24.11it/s]
Test Epoch: [800/800] Acc@1:79.43%: 100%|██████████| 20/20 [00:02<00:00,  8.18it/s]
