## Preliminary

In [None]:
!pip install -q pytorch-metric-learning[with-hooks]
!git clone https://github.com/manuel-tran/s5cl.git
%cd s5cl

In [None]:
import os
import sys
import random
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms

from pytorch_metric_learning import samplers
from pytorch_metric_learning.utils import common_functions
from pytorch_metric_learning import losses, miners

from s5cl.models import MLP
from s5cl.methods import s5cl
from s5cl.transforms import Transform
from s5cl.datasets import make_dataset

In [None]:
# Put the Python, NumPy and PyTorch random number generators into a known state
# so the stochastic steps of this notebook start identically on every
# "Restart & Run All".
# NOTE(review): which of these generators `make_dataset` and the training loop
# draw from is not visible here, and GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Dataset

In [2]:
# To download the datasets, uncomment the shell lines below and run them once.
# !wget https://zenodo.org/record/1214456/files/NCT-CRC-HE-100K.zip?download=1
# !unzip -q NCT-CRC-HE-100K.zip?download=1
# !rm NCT-CRC-HE-100K.zip?download=1
#
# !wget https://zenodo.org/record/1214456/files/CRC-VAL-HE-7K.zip?download=1
# !unzip -q CRC-VAL-HE-7K.zip?download=1
# !rm CRC-VAL-HE-7K.zip?download=1

# settings of this cell, gathered in one place so they are easy to tune
TRAIN_ROOT = "NCT-CRC-HE-100K"   # folder the labeled / unlabeled split is drawn from
TEST_ROOT = "CRC-VAL-HE-7K"      # folder used for evaluation
IMAGES_PER_CLASS = 5             # forwarded to make_dataset(images_per_class=...)
SAMPLES_PER_CLASS = 4            # `m` of MPerClassSampler
BATCH_SIZE_LABELED = 32
BATCH_SIZE_UNLABELED = 128
BATCH_SIZE_TEST = 128
NUM_WORKERS = 4                  # worker processes per DataLoader

# mean and std of NCT-CRC-HE-100K
mean, std = [0.7406, 0.5331, 0.7059], [0.1279, 0.1606, 0.1191]

# define training and test augmentation
transform_t = Transform(mean, std)
transform_v = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# get test dataset and test dataloader (fixed order, no shuffling)
dataset_v = datasets.ImageFolder(root=TEST_ROOT, transform=transform_v)
dataloader_v = DataLoader(dataset_v, batch_size=BATCH_SIZE_TEST, shuffle=False, num_workers=NUM_WORKERS)

# get training set and labeled / unlabeled split
dataset_t = datasets.ImageFolder(root=TRAIN_ROOT, transform=transform_t)
targets_l, dataset_l, dataset_u = make_dataset(dataset_t, images_per_class=IMAGES_PER_CLASS)

# sample four images per class from the labeled subset
sampler = samplers.MPerClassSampler(targets_l, m=SAMPLES_PER_CLASS, length_before_new_iter=len(dataset_l))

# get labeled and unlabeled dataloader
dataloader_l = DataLoader(dataset_l, batch_size=BATCH_SIZE_LABELED, sampler=sampler, num_workers=NUM_WORKERS)
dataloader_u = DataLoader(dataset_u, batch_size=BATCH_SIZE_UNLABELED, shuffle=True, num_workers=NUM_WORKERS)

# check dataset and dataloader length
print(len(dataset_v), len(dataset_t), len(dataset_l), len(dataset_u))
print(len(dataloader_v), len(dataloader_l), len(dataloader_u))

7180 100000 45 99955
57 2 781


## Models

In [3]:
# Encoder: torchvision ResNet-18 created with pretrained=True. Its final `fc`
# layer is replaced by an identity module, so the encoder outputs the features
# that used to feed that layer.
encoder = torchvision.models.resnet18(pretrained=True)
encoder_output_size = encoder.fc.in_features
encoder.fc = common_functions.Identity()
encoder = encoder.to(device)

# Embedder maps encoder features to a 64-dimensional embedding;
# classifier maps that embedding to 9 outputs.
embedder = MLP([encoder_output_size, 64]).to(device)
classifier = MLP([64, 9]).to(device)

# Losses: supervised (l), self-supervised (u), semi-supervised (p) contrastive
# terms, each with its own temperature, plus cross-entropy for classification (c).
criterion_l = losses.SupConLoss(temperature=0.2)
criterion_u = losses.SupConLoss(temperature=0.7)
criterion_p = losses.SupConLoss(temperature=0.7)
criterion_c = torch.nn.CrossEntropyLoss()

# All four loss terms are weighted equally.
weight_l = weight_u = weight_p = weight_c = 1.0

# One Adamax optimizer per network, all with the same hyper-parameters.
optimizer_enc, optimizer_emb, optimizer_cls = (
    torch.optim.Adamax(net.parameters(), lr=0.0001, weight_decay=0.0001)
    for net in (encoder, embedder, classifier)
)

## Training

In [4]:
# Train with S5CL.
# The schedule is counted in steps; one "epoch" here is one pass over the
# unlabeled dataloader.
epoch = 5
threshold = len(dataloader_u)
total_steps = epoch * threshold

# NOTE(review): the meaning of each key is defined by `s5cl`, which is not
# visible here. Passing total_steps + 1 with eval_step == total_steps presumably
# makes the run end with a single evaluation - confirm against s5cl.methods.
args = dict(
    start_step=0,
    total_steps=total_steps + 1,
    eval_step=total_steps,
    threshold=threshold,
)

s5cl(
    args,
    encoder, embedder, classifier,
    optimizer_enc, optimizer_emb, optimizer_cls,
    criterion_l, criterion_u, criterion_p, criterion_c,
    weight_l, weight_u, weight_p, weight_c,
    dataloader_l, dataloader_u, dataloader_v,
    device,
)


Average test loss: 0.4013  Accuracy: 6562/ 7180 (91.39%)

