# Imports

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import copy
import lightly
import numpy as np

import lightly.models as models
import lightly.loss as loss
import lightly.data as data

from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
from lightly.models.utils import batch_shuffle
from lightly.models.utils import batch_unshuffle

from classifier import Classifier
from dataloader import define_transforms, get_dataloader

from model import SimCLRModel, MocoModel, SimSiam
from configs.config import Config

# Configuration and random seeds

In [2]:
# Experiment settings (seed, data paths, epoch counts) come from configs.config.Config.
c = Config()

# Train on one GPU when CUDA is available, otherwise on CPU (0 GPUs).
gpus = int(torch.cuda.is_available())

# Lightning's seed_everything seeds Python `random`, NumPy and torch; NumPy's
# global RNG is then re-seeded explicitly with the same value.
pl.seed_everything(c.seed)
np.random.seed(c.seed)

Global seed set to 1


# Load data

In [3]:
# Argument passed to define_transforms — presumably the input image size in
# pixels; confirm in dataloader.define_transforms.
IMG_SIZE = 32
transforms_tr_classifier, transforms_tt = define_transforms(IMG_SIZE)

# Dataset over the raw SSL training directory from the config.
# NOTE(review): not referenced anywhere else in this notebook — get_dataloader
# builds its own loaders from `c`; drop it if it is not used interactively.
dataset_train_ssl = lightly.data.LightlyDataset(input_dir=c.DATAPATH_tr)

# Returned in order: SSL pre-training loader, classifier-training loader, and
# test loader (the latter is also passed as the validation loader when the
# classifier is trained).
train_ssl_loader, train_classifier_loader, test_loader = get_dataloader(
    c, transforms_tr_classifier, transforms_tt
)

# Define model

In [4]:
# Instantiate the SimSiam self-supervised model from the local `model` module.
# Only its `.backbone` is used below: weights are loaded into it from disk and
# it is then wrapped by the Classifier.
model = SimSiam()

# Train backbone (SSL model) — currently disabled; pre-trained weights are loaded from disk below

In [5]:
# trainer = pl.Trainer(max_epochs=c.max_epochs, gpus=gpus)
# trainer.fit(model, train_ssl_loader)

# Save trained backbone

In [6]:
# state_dict = {
#     'params': model.backbone.state_dict()
# }
# torch.save(state_dict, 'ssl_weight/simsiam.pth')

# Load trained backbone

In [7]:
# Load the pre-trained SSL backbone weights written by the (commented-out)
# training / saving cells above; the checkpoint stores them under 'params'.
SSL_WEIGHT_PATH = 'ssl_weight/simsiam.pth'
# map_location='cpu' lets a checkpoint saved from CUDA tensors load on a
# CPU-only machine (the gpus == 0 path of the config cell); load_state_dict
# then copies the tensors into the backbone's existing parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(SSL_WEIGHT_PATH, map_location='cpu')
model.backbone.load_state_dict(state_dict['params'])

<All keys matched successfully>

# Define classifier

In [8]:
# Wrap the (pre-trained) SSL backbone in the project's Classifier. save_path is
# forwarded to Classifier — presumably where its best weights are written; confirm
# in classifier.py.
# NOTE(review): the model summary printed during training reports 0
# non-trainable params, so the backbone is not frozen. If a linear-evaluation
# protocol is intended, freeze it first (e.g. with the already-imported
# deactivate_requires_grad) — confirm the intended protocol.
classifier = Classifier(
    c,
    model.backbone,
    save_path='clf_weight/simsiam_clf.pth',
)

# Train classifier

In [9]:
# Train the classifier for `c.clf_epochs` epochs on the device chosen in the
# configuration cell. test_loader is passed as the validation loader, so the
# validation metrics (and presumably classifier.best_valacc) are computed on the
# test split.
# NOTE(review): picking the best epoch on test data makes the final number
# optimistic; a held-out validation split would avoid this.
trainer = pl.Trainer(gpus=gpus, max_epochs=c.clf_epochs)
trainer.fit(classifier, train_classifier_loader, test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | backbone  | Sequential       | 11.2 M
1 | fc        | Linear           | 5.1 K 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 1


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Test performance of the self-supervised backbone

In [10]:
# Best validation accuracy recorded by the Classifier during training
# (presumably the maximum over epochs — confirm in classifier.py). Because
# test_loader was passed as the validation loader, this is test-split accuracy
# selected on the test split itself, i.e. an optimistic estimate.
classifier.best_valacc

array(0.81399995, dtype=float32)