In [None]:
import torch
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from train import SimCLR_Train
from models import SimCLR_Model

In [None]:
# install required packages
!pip install pytorch-ignite
!pip install torchlars

In [None]:
# Preprocessing: convert images to tensors, then normalize each channel with the
# CIFAR-10 per-channel mean/std.
# Credit for the normalization values: https://github.com/kuangliu/pytorch-cifar/issues/19
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

# Candidate batch sizes; one of them is selected below.
batch_size_list = [256, 512, 1024, 2048, 4096]

# NOTE(review): a previous comment here said "the maximum batchsize is 1024", but
# index 3 selects 2048. The selection is kept unchanged; confirm that 2048 fits in
# device memory, or switch to batch_size_list[2] (== 1024).
batch_size = batch_size_list[3]

# NOTE(review): the CIFAR-10 training split (50,000 images) is not a multiple of any
# size in batch_size_list, so the final training batch is smaller than `batch_size`.
# If SimCLR_Train builds fixed-size masks from `batch_size`, consider drop_last=True
# — TODO confirm against train.py.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

# Evaluation does not need shuffling; a fixed order keeps test-time results reproducible.
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fetch a randomly initialised (pretrained=False) CIFAR-10 ResNet-20 from torch.hub,
# wrap it as the SimCLR base encoder, and move the whole model to `device`.
resnet20 = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=False)
net = SimCLR_Model(base_encoder=resnet20).to(device)

In [None]:
# SimCLR training with the LARS optimizer and a cosine LR schedule.
# Hyperparameter values noted for sweeps:
#   temp      in [0.1, 0.5, 1.0]
#   lr        in [0.5, 1.0, 1.5]
#   num_epoch in [100, 200, 300]
# Adam alternative (use instead of the LARS call below):
#   train_loss_list = SimCLR_Train(net, device, batch_size, train_loader, num_epoch=100, temp=0.5, lr=1e-4, optim='adam', lr_scheduler='None')
lars_config = dict(
    num_epoch=100,
    temp=0.5,
    lr=1,
    optim='lars',
    lr_scheduler='cosine',
)
train_loss_list = SimCLR_Train(net, device, batch_size, train_loader, **lars_config)

In [None]:
# NOTE(review): this cell was an exact copy of the training cell above. Running it
# unconditionally (e.g. via Restart & Run All) passes the same `net` object that the
# previous cell already trained, so it would run a second 100-epoch pass on top of
# the first (presumably with a freshly created optimizer/LR schedule, since lr/optim/
# lr_scheduler are passed per call — confirm in train.py) and overwrite the first
# run's `train_loss_list`. It is therefore opt-in.
# Set CONTINUE_TRAINING = True only if a second training pass is intended.
#
# Adam alternative:
#   train_loss_list = SimCLR_Train(net, device, batch_size, train_loader, num_epoch=100, temp=0.5, lr=1e-4, optim='adam', lr_scheduler='None')
# Hyperparameter values noted for sweeps (LARS):
#   temp = [0.1, 0.5, 1.0], lr = [0.5, 1.0, 1.5], epoch = [100, 200, 300]
CONTINUE_TRAINING = False

if CONTINUE_TRAINING:
    train_loss_list = SimCLR_Train(net, device, batch_size, train_loader, num_epoch=100, temp=0.5, lr=1, optim='lars', lr_scheduler='cosine')