In [None]:
# python
import os
from pathlib import Path
import time

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# non torch
from dotted_dict import DottedDict
import pprint
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
import numpy as np

# local
from utils import AverageMeter, get_dataset, get_backbone, get_optimizer, get_scheduler
from augmentations import get_aug
from model import SimSiam, DownStreamClassifier
import utils
import configs

In [None]:
# Shared pretty-printer used below to display the training and evaluation configs.
PPRINT_INDENT = 4
pp = pprint.PrettyPrinter(indent=PPRINT_INDENT)

In [None]:
# Checkpoint of the pretrained SimSiam run to evaluate. Set the SIMSIAM_CKPT
# environment variable to point at a different run without editing this cell.
p_ckpt = Path(os.environ.get(
    "SIMSIAM_CKPT",
    "/mnt/experiments/simsiam/run_cifar10_resnet18_20201204-135350/ckpts/model_cifar10_epoch_000099.ckpt",
))
# Explicit check instead of `assert`: it is not stripped under `python -O`,
# and the error names the missing path.
if not p_ckpt.exists():
    raise FileNotFoundError("checkpoint not found: {}".format(p_ckpt))

In [None]:
# map_location="cpu" lets a checkpoint saved from a GPU process load on any
# machine, including CPU-only ones.
# NOTE(review): the checkpoint also carries a non-tensor config object
# (accessed below via attribute lookup). On PyTorch versions where torch.load
# defaults to weights_only=True, this may need weights_only=False. Only do that
# for checkpoints you trust, because it unpickles arbitrary objects.
ckpt = torch.load(p_ckpt, map_location="cpu")

# Configuration that was used for the pretraining run.
train_config = ckpt["config"]
pp.pprint(train_config)


In [None]:
# Evaluation-side config (train=False) for the dataset the backbone was pretrained on.
config = configs.get_config(
    train_config.dataset,
    train=False,
)
pp.pprint(config)

In [None]:
# prepare data
# NOTE(review): assumes get_aug's `train` flag chooses between random
# training-time augmentation (True) and the evaluation transform (False) when
# train_classifier=True. Confirm in augmentations.py.
train_set = get_dataset(
    train_config.dataset,
    train_config.p_data,
    transform=get_aug(train_config.img_size, train=True, train_classifier=True, means_std=train_config.mean_std),
    train=True,
    download=False
)
if train_config.dataset == "stl10":
    # stl10 has only 5000 labeled samples in its train set
    train_set = torch.utils.data.Subset(train_set, range(0, 5000))

# The test split uses the evaluation transform. Random training augmentation
# should not be applied when measuring test accuracy.
test_set = get_dataset(
    train_config.dataset,
    train_config.p_data,
    transform=get_aug(train_config.img_size, train=False, train_classifier=True, means_std=train_config.mean_std),
    train=False,
    download=False
)

# Settings shared by both loaders.
loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    shuffle=True,
    drop_last=True,
    **loader_kwargs
)
# drop_last=False so every test sample is scored, including the final partial batch.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    shuffle=False,
    drop_last=False,
    **loader_kwargs
)

In [None]:
# Batches per epoch: train loader first, then test loader.
for loader in (train_loader, test_loader):
    print(len(loader))

In [None]:
# create model with the same backbone / projector / predictor settings as the pretraining run
backbone = get_backbone(train_config.backbone)
model = SimSiam(backbone, train_config.projector_args, train_config.predictor_args)

# Load the pretrained SimSiam weights. Without this, the downstream classifier
# is built on a freshly initialised model. Set to False only for an explicit
# random-initialisation baseline.
LOAD_PRETRAINED = True
if LOAD_PRETRAINED:
    msg = model.load_state_dict(ckpt["state_dict"], strict=True)
    print("Loading weights: {}".format(msg))

In [None]:
# Positional sizes passed to DownStreamClassifier.
# NOTE(review): presumably (input feature dim, hidden dim, number of classes);
# confirm against DownStreamClassifier's signature in model.py.
FEATURE_DIM, HIDDEN_DIM, NUM_CLASSES = 2048, 512, 10

# Wrap the SimSiam model with the classification head and move it to the configured device.
model = DownStreamClassifier(model, FEATURE_DIM, HIDDEN_DIM, NUM_CLASSES).to(config.device)

In [None]:
# Optimizer and LR schedule are taken from `config` (the train=False config),
# not from the pretraining `train_config`.
optimizer = get_optimizer(config.optimizer, model, config.optimizer_args)
lr_scheduler = get_scheduler(config.scheduler, optimizer, config.scheduler_args)

criterion = nn.CrossEntropyLoss()

In [None]:
# Display the scheduler object to check which scheduler get_scheduler built.
lr_scheduler

In [None]:
def _train(epoch, train_loader, model, optimizer, criterion, device=None):
    """Run one training epoch, print a summary line, and return the metrics.

    Args:
        epoch: epoch number; used only in the printed log line.
        train_loader: iterable of ``(data, target)`` batches.
        model: module mapping ``data`` to class logits.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function. Assumed to return the batch *mean* (e.g.
            ``nn.CrossEntropyLoss()`` with the default reduction); it is
            re-weighted by batch size to get a per-sample mean.
        device: device to move each batch to. Defaults to the notebook-level
            ``config.device`` when omitted.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if device is None:
        device = config.device
    model.train()

    loss_sum, correct, total = 0., 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        logits = model(data)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Weight by batch size so a smaller batch does not skew the epoch mean.
        loss_sum += loss.item() * batch_size
        # argmax of the logits equals argmax of their softmax, so no softmax is needed.
        correct += logits.argmax(dim=-1).eq(target).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError(
            "train_loader yielded no samples "
            "(dataset smaller than batch_size with drop_last=True?)"
        )

    mean_loss = loss_sum / total
    acc = correct / total * 100.
    print('[Train Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, mean_loss, acc))
    return mean_loss, acc


def _eval(epoch, test_loader, model, criterion, device=None):
    """Evaluate ``model`` on ``test_loader`` without gradients, print and return the metrics.

    Args:
        epoch: epoch number; used only in the printed log line.
        test_loader: iterable of ``(data, target)`` batches.
        model: module mapping ``data`` to class logits; switched to eval mode.
        criterion: loss function. Assumed to return the batch *mean* (e.g.
            ``nn.CrossEntropyLoss()`` with the default reduction); it is
            re-weighted by batch size so a final partial batch counts correctly.
        device: device to move each batch to. Defaults to the notebook-level
            ``config.device`` when omitted.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples in ``test_loader``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        device = config.device
    model.eval()

    loss_sum, correct, total = 0., 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            logits = model(data)
            batch_size = target.size(0)
            # Per-sample weighting: the last batch may be smaller than the rest.
            loss_sum += criterion(logits, target).item() * batch_size
            # argmax of the logits equals argmax of their softmax, so no softmax is needed.
            correct += logits.argmax(dim=-1).eq(target).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    mean_loss = loss_sum / total
    acc = correct / total * 100.
    print('[Test Epoch: {0:4d}], loss: {1:.3f}, acc: {2:.3f}'.format(epoch, mean_loss, acc))
    return mean_loss, acc


In [None]:
# Train the downstream classifier for NUM_EPOCHS epochs (numbered 1..NUM_EPOCHS),
# evaluating on the test split after each one.
NUM_EPOCHS = 100
for epoch in range(1, NUM_EPOCHS + 1):
    _train(epoch, train_loader, model, optimizer, criterion)
    _eval(epoch, test_loader, model, criterion)
    # NOTE(review): assumes get_scheduler returns a scheduler meant to be stepped once per epoch; confirm.
    lr_scheduler.step()