In [None]:
!git clone https://github.com/spijkervet/SimCLR.git
%cd SimCLR
!mkdir -p logs && cd logs && wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar && cd ../
!sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
!pip install  pyyaml --upgrade
!pip install gdown

In [None]:
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import save_model, load_optimizer
from modules import SimCLR, get_resnet, NT_Xent
from modules.transformations import TransformsSimCLR

In [None]:
# Run configuration shared by every later cell.
# `args` has to exist before attributes can be assigned to it; a plain namespace is enough
# because the rest of the notebook only reads attributes from it.
args = argparse.Namespace()

args.batch_size = 64
args.dataset = "local" # make sure to check this with the (pre-)trained checkpoint
args.resnet = "resnet50" # make sure to check this with the (pre-)trained checkpoint
args.model_path = "logs"
args.epoch_num = 200
args.logistic_epochs = 500

# Settings that later cells read (dataset transforms, data loaders, checkpoint loading).
# NOTE(review): assumed from the 'kneeKL224' dataset folder name - confirm.
args.image_size = 224
# NOTE(review): reuses batch_size for the fine-tuning loaders - confirm the intended value.
args.logistic_batch_size = args.batch_size
# NOTE(review): conservative number of loader worker processes - tune for the host.
args.workers = 2
# Must be a torch.device: the checkpoint-loading cell reads `args.device.type`.
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): SimCLR(args, ...) may read further fields from `args` that are not visible
# in this notebook - verify against the cloned repository and add them here if needed.

In [None]:
import gdown

# Fetch the checkpoint from Google Drive into the model directory.
# The file name has to match "checkpoint_{args.epoch_num}.tar", which is the path the
# checkpoint-loading cell builds (epoch_num is 200 in the configuration cell).
file_id="1nb__5N4HRDEJt-SILcyUBjPXcXqb2jPT"
url = f'https://drive.google.com/uc?id={file_id}'
checkpoint_path = f'{args.model_path}/checkpoint_200.tar'

# Make the cell safe to re-run: ensure the target directory exists and skip the
# download when the checkpoint is already on disk.
os.makedirs(args.model_path, exist_ok=True)
if not os.path.exists(checkpoint_path):
    gdown.download(url, checkpoint_path, quiet=False)

In [None]:
# Per-split transforms: the training split uses the `train_transform` attribute of
# TransformsSimCLR, the validation split uses its `test_transform` attribute.
train_transform = TransformsSimCLR(size=args.image_size).train_transform
eval_transform = TransformsSimCLR(size=args.image_size).test_transform

# One ImageFolder dataset per split, keyed by the phase names the training loop iterates over.
image_datasets = {
    'train': torchvision.datasets.ImageFolder('kneeKL224/train', transform=train_transform),
    'validation': torchvision.datasets.ImageFolder('kneeKL224/val', transform=eval_transform),
}

# One loader per split with identical settings, except that only 'train' is shuffled.
dataloaders = {
    split: DataLoader(
        split_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=(split == 'train'),
        num_workers=args.workers,
    )
    for split, split_dataset in image_datasets.items()
}

In [None]:
# Backbone is built with pretrained=False: its weights come from the SimCLR checkpoint
# loaded below, not from the PyTorch model zoo.
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features  # input width of the backbone's fc layer

# Wrap the backbone in the SimCLR module, then restore the saved weights into it.
simclr_model = SimCLR(args, encoder, n_features)

checkpoint_name = "checkpoint_{}.tar".format(args.epoch_num)
model_fp = os.path.join(args.model_path, checkpoint_name)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint_state = torch.load(model_fp, map_location=args.device.type)
simclr_model.load_state_dict(checkpoint_state)

simclr_model = simclr_model.to(args.device)

In [None]:
# Input width of the first layer of the current projection head; the replacement
# classifier head built in the next cell takes inputs of this size.
first_projector_layer = simclr_model.projector[0]
output_feature_dim = first_projector_layer.in_features

In [None]:
# Replace the projection head with a small classifier that emits log-probabilities
# (LogSoftmax output pairs with the NLLLoss criterion used for training).
# The head has 5 outputs - presumably one per class folder in kneeKL224; verify.
hidden_dim = output_feature_dim // 2
classifier_layers = [
    nn.Linear(output_feature_dim, hidden_dim),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.4),
    nn.Linear(hidden_dim, 5),
    nn.LogSoftmax(dim=1),
]
simclr_model.projector = nn.Sequential(*classifier_layers).to(args.device)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=3, loaders=None):
    """Train `model` and evaluate it after every epoch, printing loss and accuracy.

    Each epoch runs a 'train' phase (with weight updates) followed by a
    'validation' phase (no updates, no gradient tracking).

    Args:
        model: module called as `model(inputs)`; it must return one tensor of
            per-class scores with the class dimension at index 1.
        criterion: loss called as `criterion(outputs, labels)`; it is assumed to
            return the mean loss over the batch (it is re-weighted by batch size).
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of train/validation rounds.
        loaders: optional mapping with 'train' and 'validation' iterables of
            `(inputs, labels)` batches. Defaults to the notebook-level
            `dataloaders` dict.

    Returns:
        The same `model` object, trained in place.
    """
    if loaders is None:
        # Fall back to the loaders built in the data-loading cell.
        loaders = dataloaders

    # Move batches to wherever the model lives instead of relying on a global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode otherwise

            running_loss = 0.0
            running_correct = 0
            samples_seen = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                # Skip autograd bookkeeping during validation to save memory and time.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                preds = outputs.argmax(dim=1)
                running_loss += loss.item() * batch_size
                running_correct += (preds == labels.view_as(preds)).sum().item()
                samples_seen += batch_size

            # The scheduler must be stepped after the optimizer; stepping it first
            # skips the initial learning-rate value and triggers a PyTorch warning.
            if is_train:
                scheduler.step()

            # Average over the samples actually processed; guard against empty loaders.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_correct / denominator

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))
    return model

In [None]:
# Show the model's module tree (the cell's last expression is rendered as its output);
# useful for locating the sub-modules that the optimizer's parameter groups refer to.
simclr_model

In [None]:
# Fine-tune the SimCLR model that carries the new classifier head.
# Only the two parameter groups below are handed to the optimizer, so all other weights
# keep their checkpoint values. Each group has its own learning rate, which overrides
# the optimizer-level default.
# NOTE(review): `encoder[7]` keeps the original index and presumes `simclr_model.encoder`
# is an indexable container (e.g. nn.Sequential) - check it against the printed model.
plist = [
    {'params': simclr_model.encoder[7].parameters(), 'lr': 1e-5},
    {'params': simclr_model.projector.parameters(), 'lr': 5e-3},
]
optimizer_ft = optim.Adam(plist, lr=0.001)
criterion = nn.NLLLoss()  # pairs with the LogSoftmax output of the classifier head
lr_sch = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

# NOTE(review): train_model expects `model(inputs)` to return a single tensor of class
# scores - confirm that SimCLR.forward returns the projector output in that form.
model_ft = train_model(simclr_model,
                       criterion,
                       optimizer_ft,
                       lr_sch,
                       num_epochs=3)

# Persist only the weights (state dict), not the whole module object.
torch.save(model_ft.state_dict(), "model.bin")