In [None]:
from pathlib import Path
import torch
from augmentations import get_aug
from utils import get_dataset
from torchvision.models import resnet50
from model import SimSiam
import torch.nn as nn
import torch.nn.functional as F 
from utils import AverageMeter
from tqdm import tqdm
import time

In [None]:
from pathlib import Path
#
import torch
import torch.nn as nn
import torch.nn.functional as F 
#
from dotted_dict import DottedDict
import pprint
from tqdm import tqdm
#
from utils import AverageMeter, get_dataset, get_backbone, get_optimizer, get_scheduler
from augmentations import get_aug
from model import SimSiam

In [None]:
# Shared pretty-printer for nested config mappings; nested levels are
# indented by four spaces when the output has to wrap.
pp = pprint.PrettyPrinter(
    indent=4,
)

In [None]:
# Checkpoint of the SimSiam pre-training run to evaluate. It defaults to the
# original run's checkpoint but can be pointed elsewhere via the SIMSIAM_CKPT
# environment variable instead of editing this absolute path.
import os

p_ckpt = Path(os.environ.get(
    "SIMSIAM_CKPT",
    "/usr/experiments/simsiam/run_20201202-201102/ckpts/model_cifar10_epoch_000012.ckpt"))
assert p_ckpt.exists(), f"checkpoint not found: {p_ckpt}"

In [None]:
# Load the checkpoint onto the CPU regardless of the device it was saved from.
# The backbone is moved to `config.device` later. Without map_location, tensors
# are restored to their original device: that fails if the GPU is absent and
# otherwise occupies memory on it.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. On newer PyTorch versions `weights_only` defaults to True, which
# may reject the pickled object stored under "config" — confirm for your version.
ckpt = torch.load(p_ckpt, map_location="cpu")

In [None]:
# Settings the backbone was pre-trained with, as stored in the checkpoint.
# Dataset, data path, backbone name and image size are reused below.
train_config = ckpt["config"]
pp.pprint(train_config)

In [None]:
# Settings for this linear-evaluation run. The pre-training settings live in
# `train_config`; only img_size is taken from there.
config = DottedDict()
config.device = 'cuda:1'
config.optimizer = 'sgd'
# NOTE(review): lr=30 with no weight decay is far above usual fine-tuning
# values; presumably chosen for a linear probe on frozen features — confirm.
config.optimizer_args = {
    'lr': 30,
    'weight_decay': 0,
    'momentum': 0.9
}
config.batch_size = 256
# Reuse the image size recorded in the pre-training config.
config.img_size = train_config.img_size
# When True, the data-prep cell truncates each split to a single batch.
config.debug = False
config.num_workers = 8
config.num_epochs = 800

### Prepare data

In [None]:
# Datasets for linear evaluation. The training split uses classifier-training
# augmentations; the test split must use deterministic evaluation transforms.
train_set = get_dataset(
        train_config.dataset,
        train_config.p_data,
        transform=get_aug(config.img_size, train=True, train_classifier=True),
        train=True,
        download=False
    )
test_set = get_dataset(
        train_config.dataset,
        train_config.p_data,
        # NOTE(review): assumes get_aug(train=False, train_classifier=False)
        # returns the non-random eval transform — confirm in augmentations.get_aug.
        transform=get_aug(config.img_size, train=False, train_classifier=False),
        train=False,
        download=True # default is False
    )
if config.debug:
    # Quick smoke test: keep only one batch from each split.
    train_set = torch.utils.data.Subset(train_set, range(0, config.batch_size))
    test_set = torch.utils.data.Subset(test_set, range(0, config.batch_size))

train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True
    )
# Evaluate on the held-out split, and keep the final partial batch so the
# reported accuracy covers every test sample.
test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=False
    )

### Load the pre-trained backbone and build the linear classifier

In [None]:
# load backbone
# Build the same backbone architecture used during pre-training, and record the
# input width of its fc head; the linear classifier is sized from this value.
backbone = get_backbone(train_config["backbone"])
in_features = backbone.fc.in_features

In [None]:
# Drop the original fc head: with an identity in its place, the backbone
# returns the features that used to feed fc.
backbone.fc = nn.Identity()

In [None]:
# From here on the feature extractor is called `model` (same object as `backbone`).
model = backbone

In [None]:
# Keep only the backbone's entries from the checkpoint's state dict, and strip
# their "backbone." prefix so the keys match the bare backbone module.
state_dict = {
    name[len("backbone."):]: weight
    for name, weight in ckpt["state_dict"].items()
    if name.startswith("backbone.")
}

In [None]:
# strict=True: every key must match exactly, so a prefix mismatch fails loudly.
model.load_state_dict(state_dict=state_dict, strict=True)

In [None]:
# Move the frozen feature extractor to the evaluation device.
model = model.to(device=config.device)

In [None]:
# Linear classifier on the frozen backbone features, with one output per class.
# In debug mode the data-prep cell wraps train_set in torch.utils.data.Subset,
# which has no `.classes` attribute, so read the classes from the wrapped dataset.
base_train_set = train_set.dataset if isinstance(train_set, torch.utils.data.Subset) else train_set
classifier = nn.Linear(in_features=in_features, out_features=len(base_train_set.classes), bias=True)
classifier = classifier.to(config.device)

In [None]:
# Optimizer over the linear classifier only; the frozen backbone is not passed in.
optimizer = get_optimizer(
    config.optimizer,
    classifier,
    config.optimizer_args,
)

In [None]:
# Running averages shown in the progress bars: test accuracy and training loss.
acc_meter = AverageMeter(name='Accuracy')
loss_meter = AverageMeter(name='Loss')

In [None]:
for epoch in range(1, config.num_epochs + 1):
    #
    # TRAIN LOOP: the backbone stays frozen (eval mode, features computed under
    # no_grad); only the linear classifier receives gradient updates.
    #
    loss_meter.reset()
    model.eval()
    classifier.train()
    p_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{config.num_epochs}', position=1)
    for images, labels in p_bar:
        classifier.zero_grad()
        with torch.no_grad():
            feature = model(images.to(config.device))
        preds = classifier(feature)
        #
        loss = F.cross_entropy(preds, labels.to(config.device))
        # Populate the classifier's gradients before stepping. Without this
        # backward pass, optimizer.step() would never update the weights.
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
        p_bar.set_postfix({"loss": loss_meter.val, 'loss_avg': loss_meter.avg})
    #
    # EVAL LOOP: top-1 accuracy of the classifier over test_loader.
    #
    classifier.eval()
    acc_meter.reset()
    p_bar = tqdm(test_loader, desc=f'Test {epoch}/{config.num_epochs}')
    with torch.no_grad():
        for images, labels in p_bar:
            feature = model(images.to(config.device))
            preds = classifier(feature).argmax(dim=1)
            n_samples = labels.shape[0]
            correct = (preds == labels.to(config.device)).sum().item()
            # Weight each batch by its size so a smaller final batch does not
            # skew the mean. NOTE(review): assumes AverageMeter.update(val, n)
            # accumulates sum += val * n, as in the PyTorch ImageNet example — confirm.
            acc_meter.update(correct / n_samples, n_samples)
            p_bar.set_postfix({'accuracy': acc_meter.avg})