In [6]:
import os
import sys
import argparse
from pprint import pprint

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, '../')

from model import save_model, load_optimizer
from simclr import SimCLR
from simclr.modules import get_resnet, NT_Xent
from simclr.modules.transformations import TransformsSimCLR
from utils import yaml_config_hook

In [7]:
parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("../config/config.yaml")

for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
    
args = parser.parse_args([])
args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
args.batch_size = 64
args.resnet = "resnet18"
pprint(vars(args))

{'batch_size': 64,
 'dataparallel': 0,
 'dataset': 'CIFAR10',
 'dataset_dir': './datasets',
 'device': device(type='cuda'),
 'epoch_num': 100,
 'epochs': 100,
 'gpus': 1,
 'image_size': 224,
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'save',
 'nodes': 1,
 'nr': 0,
 'optimizer': 'Adam',
 'pretrain': True,
 'projection_dim': 64,
 'reload': False,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 8}


In [9]:
torch.manual_seed(args.seed)
np.random.seed(args.seed)

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="unlabeled",
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
else:
    raise NotImplementedError

if args.nodes > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
    )
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Files already downloaded and verified


In [10]:
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features

# init model
model = SimCLR(encoder, args.projection_dim, n_features)

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

if args.reload:
    model_fp = os.path.join(
        args.model_path, f"checkpoint_{args.epoch_num}.tar"
    )
    model.load_state_dict(torch.load(model_fp, map_location=args.device.type))

model = model.to(args.device)
optimizer, scheduler = load_optimizer(args, model)
criterion = NT_Xent(args.batch_size, args.temperature, world_size=1)
writer = SummaryWriter()

Let's use 4 GPUs!


In [11]:
def train(args, train_loader, model, criterion, optimizer, writer, display_every=50):
    """Train function"""
    epoch_loss = 0
    
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.cuda(non_blocking=True)
        x_j = x_j.cuda(non_blocking=True)
        
        # Positive pair with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)
        
        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()
        
        if step % display_every == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")
        
        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        epoch_loss += loss.item()
        args.global_step += 1
    
    return epoch_loss

In [12]:
args.global_step = 0
args.current_epoch = 0

for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    epoch_loss = train(args, train_loader, model, criterion, optimizer, writer)
    
    if scheduler:
        scheduler.step()
    
    if epoch % 10 == 0:
        save_model(args, model, optimizer)
    
    writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)

    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

save_model(args, model, optimizer)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Step [0/781]	 Loss: 4.826760292053223
Step [50/781]	 Loss: 4.717038631439209
Step [100/781]	 Loss: 4.506049633026123
Step [150/781]	 Loss: 4.5885820388793945
Step [200/781]	 Loss: 4.671140670776367
Step [250/781]	 Loss: 4.622513771057129
Step [300/781]	 Loss: 4.636490345001221
Step [350/781]	 Loss: 4.625153064727783
Step [400/781]	 Loss: 4.531766891479492
Step [450/781]	 Loss: 4.612953186035156
Step [500/781]	 Loss: 4.49284029006958
Step [550/781]	 Loss: 4.513941764831543
Step [600/781]	 Loss: 4.399381637573242
Step [650/781]	 Loss: 4.435848236083984
Step [700/781]	 Loss: 4.199560642242432
Step [750/781]	 Loss: 4.257235050201416


FileNotFoundError: [Errno 2] No such file or directory: 'save/checkpoint_0.tar'

In [None]:
torch.cuda.empty_cache()