In [None]:
# Automatically reload edited project modules (model, utils, augmentations, ...) before
# executing code, so changes there are picked up without restarting the kernel.
# Mode 3 behaves like mode 2 and additionally loads newly added objects (needs a recent IPython).
%load_ext autoreload
%autoreload 3

In [None]:
import datetime
import shutil
from pathlib import Path

import torch
from dotted_dict import DottedDict
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet50
from tqdm import tqdm

from augmentations import get_aug
from model import SimSiam
from utils import AverageMeter, get_dataset, get_backbone, get_optimizer, get_scheduler

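### TODO
- ~~tensorboard~~ (done: losses are logged with `SummaryWriter`)
- multi-GPU training
- resume from a checkpoint (`config.resume` is defined but not used yet)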

# Hyperparameters

In [None]:
# Suffix of the run directory (run_<timestamp>). While experimenting, a fixed "tmp" name is
# used so repeated runs share one scratch directory; set USE_TMP_RUN = False to get a
# unique, timestamped run directory instead.
USE_TMP_RUN = True
timestamp = "tmp" if USE_TMP_RUN else datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
config = DottedDict()
# --- paths ---
config.p_data = Path("/usr/data/pytorch")  # dataset root passed to get_dataset
config.p_train = Path("/usr/experiments/simsiam") / "run_{}".format(timestamp)
config.p_ckpts = config.p_train / "ckpts"
config.p_logs = config.p_train / "logs"
# checkpoint file name, formatted with (dataset, epoch zero-padded to 6 digits)
config.fs_ckpt = "model_{}_epoch_{:0>6}.ckpt"
# [per-channel means, per-channel stds]; these values match the commonly used ImageNet statistics
config.mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
config.dataset = "cifar10"
config.backbone = "resnet18"
config.batch_size = 512
config.num_epochs = 800
config.img_size = 32
config.optimizer = "sgd"
config.optimizer_args = {
    "lr": 0.03,
    "weight_decay": 0.0005,
    "momentum": 0.9,
}
config.scheduler = "cosine_decay"
config.scheduler_args = {
    # schedule length; kept equal to num_epochs (re-synced after the debug overrides below)
    "T_max": config.num_epochs,
    "eta_min": 0,
}
config.debug = False
config.num_workers = 8
config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config.resume = False  # NOTE(review): not read anywhere in this notebook yet
#
# debug settings: a tiny, fast smoke-test run
if config.debug:
    config.batch_size = 2
    config.num_epochs = 5  # a few short epochs; the data cell trims the dataset to a single batch
    config.num_workers = 1
#
# Keep the LR schedule length equal to the (possibly debug-overridden) number of epochs;
# otherwise a debug run would only cover the first few steps of an 800-long schedule.
config.scheduler_args["T_max"] = config.num_epochs

### Prepare Data

In [None]:
# Training-time augmentation for self-supervised pretraining (not the classifier variant);
# the loader below yields pairs of augmented views from it.
transform = get_aug(
    img_size=config.img_size,
    means_std=config.mean_std,
    train=True,
    train_classifier=False,
)

In [None]:
train_set = get_dataset(config.dataset, config.p_data, transform=transform)
if config.debug:
    # Debug runs train on exactly one batch worth of samples (the first batch_size items).
    train_set = torch.utils.data.Subset(train_set, indices=range(config.batch_size))

In [None]:
# Shuffled mini-batches of training samples; drop_last keeps every batch at exactly
# batch_size samples. Pinned host memory only speeds up host->GPU copies, so it is
# enabled only when training on CUDA (on CPU it is useless and triggers a warning).
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=str(config.device).startswith("cuda"),
    drop_last=True,
)

### Prepare model

In [None]:
# Build the SimSiam network around the configured backbone and move it to the device.
backbone = get_backbone(config.backbone)
model = SimSiam(backbone)
model = model.to(config.device)

### Prepare optimizer

In [None]:
# Optimizer built from the model and the optimizer name / kwargs stored in config.
optimizer = get_optimizer(
    config.optimizer,
    model,
    config.optimizer_args,
)

In [None]:
# Learning-rate scheduler attached to `optimizer`, selected by config.scheduler
# and parameterised by config.scheduler_args.
lr_scheduler = get_scheduler(
    config.scheduler,
    optimizer,
    config.scheduler_args,
)

In [None]:
# Running meter for the training loss (exposes .val and .avg); the training loop resets it
# at the start of every epoch.
loss_meter = AverageMeter("loss")

## Training run

In [None]:
# Create the run directories (no-op if they already exist).
config.p_logs.mkdir(exist_ok=True, parents=True)
config.p_ckpts.mkdir(exist_ok=True, parents=True)
#
# tensorboard writer
writer = SummaryWriter(config.p_logs)
print("tensorboard --logdir={} --host=0.0.0.0".format(str(config.p_logs)))
#
steps_per_epoch = len(train_loader)
try:
    for epoch in tqdm(range(1, config.num_epochs + 1), desc='Training'):
        loss_meter.reset()
        model.train()
        # Global step of this epoch's first batch (epochs are 1-based, steps start at 0).
        epoch_start_step = (epoch - 1) * steps_per_epoch
        # Learning rate used for every batch of this epoch.
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch_start_step)
        p_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{config.num_epochs}', leave=False)
        for idx, ((images1, images2), _labels) in enumerate(p_bar):
            model.zero_grad()
            # Call the module rather than .forward() so registered nn.Module hooks run.
            loss = model(images1.to(config.device), images2.to(config.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            p_bar.set_postfix({"loss": loss_meter.val, 'loss_avg': loss_meter.avg})
            writer.add_scalar('loss', loss_meter.val, epoch_start_step + idx)
            writer.add_scalar('avg_loss', loss_meter.avg, epoch_start_step + idx)

        # T_max in config.scheduler_args is set to the epoch count, so the schedule is
        # advanced once per epoch. Stepping it once per batch would finish the cosine period
        # after T_max batches (only a few epochs); assuming get_scheduler wraps torch's
        # CosineAnnealingLR, the LR would then oscillate for the rest of training.
        lr_scheduler.step()

        # Save checkpoint
        p_ckpt = config.p_ckpts / config.fs_ckpt.format(config.dataset, epoch)
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            # 'optimizer': optimizer.state_dict(), # will double the checkpoint file size
            'lr_scheduler': lr_scheduler.state_dict(),
            # `config` and `loss_meter` are pickled as Python objects, so loading this file
            # requires torch.load(..., weights_only=False).
            'config': config,
            'loss_meter': loss_meter,
        }, p_ckpt)
        print(f"Model saved to {p_ckpt}")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

# Test results: does the saved checkpoint reproduce the trained model?

In [None]:
# Two random input batches (4 images, 3 channels, training resolution) used to compare the
# in-memory model with the model reloaded from its final checkpoint.
X_test_1 = torch.rand(4, 3, config.img_size, config.img_size).to(config.device)
X_test_2 = torch.rand(X_test_1.shape).to(config.device)

In [None]:
# Loss of the in-memory trained model; eval mode makes layers such as BatchNorm/Dropout
# behave deterministically, and no_grad avoids building an autograd graph.
model = model.eval()
with torch.no_grad():
    L_test_1 = model(X_test_1, X_test_2)
#
# Rebuild the same architecture that was trained (config.backbone) in a separate variable,
# so the trained `model` is not overwritten and this cell can be re-run.
model_reloaded = SimSiam(get_backbone(config.backbone)).to(config.device)
#
# Final checkpoint written by the training loop (epoch == num_epochs).
p_model = config.p_ckpts / config.fs_ckpt.format(config.dataset, config.num_epochs)
# The checkpoint also pickles `config` and `loss_meter`, which torch>=2.6's default
# weights_only=True refuses to load; it is our own file, so full unpickling is acceptable.
ckpt = torch.load(p_model, map_location=config.device, weights_only=False)
model_reloaded.load_state_dict(ckpt["state_dict"])
model_reloaded = model_reloaded.eval()
#
with torch.no_grad():
    L_test_2 = model_reloaded(X_test_1, X_test_2)

In [None]:
# Difference between the trained model's loss and the reloaded checkpoint's loss on the
# same inputs; a value of (approximately) zero means the checkpoint round-trips correctly.
torch.sub(L_test_1, L_test_2)

In [None]:
# Display the loss of the in-memory trained model on the random test batches.
L_test_1

# Clean up the temporary run directory

In [None]:
# Remove the scratch run directory (config.p_train, i.e. .../run_tmp, which holds the logs
# and checkpoints). Only the throw-away "tmp" run is deleted, so a real timestamped run is
# never wiped by "Run All". ignore_errors mirrors `rm -rf` when the directory is missing.
if timestamp == "tmp":
    shutil.rmtree(config.p_train, ignore_errors=True)