In [29]:
import lightly
import torchvision
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from pathlib import Path
import json
import os
from PIL import Image
import sklearn
from sklearn.model_selection import StratifiedShuffleSplit
from shutil import copyfile

In [31]:
# utils

def make_label_unlabeled_split(**kwargs):
    #split U and L from labeled dataset
    
    data_path = kwargs.get('data_path')
    saving_unlabeled_path = kwargs.get('U_path')
    saving_labeled_path = kwargs.get('L_path')
    sample = kwargs.get('L_size',0.01)

    saving_unlabeled_path.mkdir(parents=True, exist_ok=True)
    saving_labeled_path.mkdir(parents=True, exist_ok=True)

    folder_list = [folder for folder in data_path.iterdir() if folder.is_dir()]
    for folder in folder_list:

        dest_cat_path = saving_labeled_path.joinpath(folder.name)
        dest_cat_path.mkdir(parents = True,exist_ok=True)

        img_list = [filepath for filepath in folder.iterdir()]

        all_indexes = np.arange(len(img_list))
        label_idx = np.random.randint(len(img_list),size = int(sample*len(img_list)))
        unlabel_idx = np.array([idx for idx in all_indexes if idx not in label_idx])
        label_imgs = [img_list[i] for i in label_idx]
        unlabel_imgs = [img_list[i] for i in unlabel_idx]

        for img in label_imgs:
            dest_label_img_path = dest_cat_path.joinpath(img.name)
            copyfile(img,dest_label_img_path)

        for img in unlabel_imgs:
            dest_unlabel_img_path = saving_unlabeled_path.joinpath(img.name)
            copyfile(img,dest_unlabel_img_path)
            

class UnlabeledDataset(torch.utils.data.Dataset):
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.all_imgs = os.listdir(main_dir)
        
    def __len__(self):
        return len(self.all_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image
    
def class_distribution(dataset,class_index_dict):
    count_dict = {}
    for img,label in dataset:
        if class_index_dict[label] not in count_dict:
            count_dict.update({class_index_dict[label]:0})
        else:
            count_dict[class_index_dict[label]] += 1
            
    return count_dict
    
def sample_dataset(dataset,sample):
    sample_idx = np.random.randint(len(unlabeled_dataset),size = int(sample*len(unlabeled_dataset)))
    return torch.utils.data.Subset(dataset, sample_idx)
    
def load_simcrl(simclr_results_path,n_categories,model_size = 18):
    
    #load config
    conf_path = simclr_results_path.joinpath('conf.json')
    with open(conf_path,'r') as f:
        conf = json.load(f)

    #load model
    model_path = simclr_results_path.joinpath('checkpoint.pth')

    num_ftrs = conf['num_ftrs']

    resnet = lightly.models.ResNetGenerator('resnet-'+str(model_size))
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1)
    )

    model = lightly.models.SimCLR(backbone, num_ftrs=num_ftrs)

    encoder = lightly.embedding.SelfSupervisedEmbedding(
        model,
        None,
        None,
        None
    )

    encoder.model.load_state_dict(torch.load(model_path))
    teacher = Teacher(encoder.model,num_ftrs,n_categories).to(device)
    return teacher
    
    
def evaluate(model,testloader,loss_function):
  val_loss = 0
  total = 0
  correct = 0
  ground_truth_list = []
  predictions_list =  []
  for image,label in testloader:
      image, label = image.to(device), label.to(device)
      outputs = model(image)
      probabilities, predicted = torch.max(outputs.data, 1)
      val_loss += loss_function(outputs, label.long()).item()
      total += label.size(0)
      correct += (predicted == label).sum().item()
      ground_truth_list += list(label.cpu())
      predictions_list += list(predicted.cpu())

  acc = sklearn.metrics.accuracy_score(ground_truth_list,predictions_list)
  f1 = sklearn.metrics.f1_score(ground_truth_list,predictions_list,average = 'macro')
  precision = sklearn.metrics.precision_score(ground_truth_list,predictions_list,average = 'macro')
  recall = sklearn.metrics.recall_score(ground_truth_list,predictions_list,average = 'macro')
  print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')

  metrics_dict = {'val_loss':val_loss,'acc':acc,'f1':f1,'precision':precision,'recall':recall}

  return metrics_dict



# Big-self supervised models are strong semi-supervised radiologists


Inspired by [this paper](https://arxiv.org/pdf/2006.10029.pdf)


We have a small subset of labeled data $L$ and a large pool of unlabeled data $U$. The goal is to make the most out of $U$ for training a classifier for solving the task on $L$.

The procedure has three steps: 

* Pretrain a big SimCLR model on $U$
* Fine-tune on $L$
* Use the resulting model as a teacher for a smaller model, which is trained on the predictions of the teacher model on $U$



We will split $L$ into a training and testing set using stratified sampling for getting splits with the same proportions of labels. 





In [32]:
# results path
results_path = Path('/projects/self_supervised/results/sortifier_distillation')
results_path.mkdir(parents=True, exist_ok=True)


In [43]:
# load data 

sample_unlabeled = 1.0
test_size = 0.8

data_path = Path('/projects/self_supervised/data/sortifier')
unlabeled_path = Path('/projects/self_supervised/data/sortifier_unlabeled')
labeled_path = Path('/projects/self_supervised/data/sortifier_labeled')

# #not required if you already splitted the data
# make_label_unlabeled_split(
#     data_path = data_path,
#     U_path = unlabeled_path,
#     L_path = labeled_path,
#     L_size = 0.02
# )


input_size = 64
batch_size = 16
num_workers = 2

transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
    )
])

labeled_dataset = datasets.ImageFolder(root=labeled_path, transform=transform)

sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)

img_paths = [item[0] for item in labeled_dataset.imgs]
labels = [item[1] for item in labeled_dataset.imgs]
idx_train, idx_test = next(sss.split(img_paths,labels))

trainset = torch.utils.data.Subset(labeled_dataset, idx_train)
trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

testset = torch.utils.data.Subset(labeled_dataset, idx_test)
testloader = torch.utils.data.DataLoader(
    testset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

class_index_dict = {v:k for k,v in testset.dataset.class_to_idx.items()}

#load unlabeled data
unlabeled_dataset = UnlabeledDataset(unlabeled_path,transform = transform)
unlabeled_dataset = sample_dataset(unlabeled_dataset,sample_unlabeled)
unlabeledloader = torch.utils.data.DataLoader(
    unlabeled_dataset,
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers
)

In [44]:
print('training',len(trainset))
#print('training',class_distribution(trainset,class_index_dict))
print('testing',len(testset))
#print('testing',class_distribution(testset,class_index_dict))
print('unlabeled',len(unlabeled_dataset))

training 107
testing 431
unlabeled 26640


# Vanilla student model (resnet18)

We train on $L$ a resnet18 as a baseline.

We define the softmax-like function $P(y|x_{i})=\frac{exp(f(x_{i})[y]/\tau}{\sum_{y'} exp(f(x_{i})[y']/\tau}$, where $\tau$ is a temperature parameter and $f$ the model. 

We will use cross-entropy defined as $-\sum_{(x_{i},y_{i})} \left[ log P(y_{i}|x_{i})\right]$ as the loss function for training on $L$

In [45]:
class Student(nn.Module):
    def __init__(self,output_size):
        super().__init__()
        
        self.net = models.resnet18(pretrained=True)
        self.net.fc = nn.Linear(self.net.fc.in_features, output_size)

    def forward(self, x):
        out = self.net(x)
        return out
    
def P(x,tau = 1.0):
  return torch.exp(x/tau)/(torch.exp(x/tau).sum())

class CrossEntropyLoss(torch.nn.Module):

    def __init__(self,n_categories):
        super(CrossEntropyLoss,self).__init__()
        self.n_classes = n_categories

    def forward(self, prediction, label):
      label = torch.nn.functional.one_hot(label,num_classes=n_categories)
      loss = -label*torch.log(P(prediction))
      return loss.sum().sum()

In [46]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_categories = len([cat for cat in labeled_path.iterdir()])
crossent_loss = CrossEntropyLoss(n_categories)

In [47]:

#train the student model on labeled data

student = Student(n_categories).to(device)
optimizer = optim.Adam(student.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(student(image), label)
      loss.backward()
      optimizer.step()
        
  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:        
        best_loss = val_loss
        best_metrics = metrics_dict
        
        count = 0
  else:
    count += 1
  if count > patience:
    break
        
        
print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.832 f1:0.604 precision:0.691 recall:0.619
acc:0.875 f1:0.689 precision:0.796 recall:0.691
acc:0.839 f1:0.656 precision:0.724 recall:0.676
acc:0.882 f1:0.667 precision:0.752 recall:0.681
acc:0.923 f1:0.755 precision:0.848 recall:0.731
acc:0.909 f1:0.765 precision:0.897 recall:0.723
acc:0.916 f1:0.715 precision:0.846 recall:0.698
acc:0.933 f1:0.776 precision:0.904 recall:0.746
acc:0.928 f1:0.766 precision:0.868 recall:0.734
acc:0.911 f1:0.729 precision:0.860 recall:0.684
acc:0.925 f1:0.736 precision:0.834 recall:0.715
acc:0.942 f1:0.773 precision:0.897 recall:0.745
acc:0.947 f1:0.811 precision:0.961 recall:0.780
acc:0.940 f1:0.783 precision:0.907 recall:0.756
acc:0.916 f1:0.736 precision:0.895 recall:0.693
acc:0.923 f1:0.754 precision:0.896 recall:0.719
acc:0.923 f1:0.732 precision:0.874 recall:0.703

Best metrics:
acc:0.947 f1:0.811 precision:0.961 recall:0.780


# Big teacher (SimCLR)


We fine-tune on $L$ a SimCLR model with backbone resnet50 pretrained on $U$


In [50]:
class Teacher(nn.Module):
    def __init__(self, model,num_ftrs,output_dim):
        super().__init__()
        
        self.freeze = False
        self.net = model
        self.fc1 = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, output_dim)
        
    def forward(self, x):
        if self.freeze:
            with torch.no_grad():
                y_hat = self.net.backbone(x).squeeze()
                y_hat = self.fc1(y_hat)
                y_hat = self.relu(y_hat)
                y_hat = self.fc2(y_hat)
                return y_hat
        else:
            y_hat = self.net.backbone(x).squeeze()
            y_hat = self.fc1(y_hat)
            y_hat = self.relu(y_hat)
            y_hat = self.fc2(y_hat)
            return y_hat
            
    def freeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = False
      self.freeze = True
      return self

    def unfreeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = True
      self.freeze = False
      return self
                

In [51]:
#finetune simclr
simclr_results_path = Path('/projects/self_supervised/results/sortifier_unlabeled')

teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
optimizer = optim.Adam(teacher.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(teacher(image), label.long())
      loss.backward()
      optimizer.step()

  metrics_dict = evaluate(teacher,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        torch.save(teacher.state_dict(),results_path.joinpath('teacher.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
        
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.911 f1:0.690 precision:0.698 recall:0.690
acc:0.938 f1:0.743 precision:0.963 recall:0.739
acc:0.945 f1:0.768 precision:0.967 recall:0.749
acc:0.909 f1:0.693 precision:0.951 recall:0.646
acc:0.935 f1:0.796 precision:0.928 recall:0.747
acc:0.945 f1:0.825 precision:0.907 recall:0.793
acc:0.947 f1:0.803 precision:0.968 recall:0.773
acc:0.945 f1:0.774 precision:0.965 recall:0.768
acc:0.942 f1:0.783 precision:0.867 recall:0.771
acc:0.964 f1:0.881 precision:0.952 recall:0.853
acc:0.954 f1:0.861 precision:0.923 recall:0.836
acc:0.957 f1:0.838 precision:0.905 recall:0.820
acc:0.959 f1:0.856 precision:0.921 recall:0.832
acc:0.952 f1:0.819 precision:0.894 recall:0.805

Best metrics:
acc:0.964 f1:0.881 precision:0.952 recall:0.853


# Model distillation

The fine-tuned model often yields better performance than the vanilla model. Now that we used $L$ for finetuning the teacher model, we can use $U$ again for transfering the knowledge from the teacher to the student

The procedure is as follows:
* sample data from $U$
* predict with the teacher
* use the labels as targets for training the student

This algorithm is expected to yield an even better model. Notice that the resulting model would be a resnet18 with even better performance than a resnet50 pretrained on $U$ and fine-tuned on $L$.

We will use a distillation loss, which takes the output probabilities of the teacher model as the target for the student model.

$$ -\sum_{x_{i}} \left[ \sum_{y} P^{T}(y|x_{i};\tau) log P^{S}(y|x_{i};\tau) \right] $$

With this loss function the student doesn't only see the one-vs-zero encoding used in supervised learning, but a probability distribution as a target. This could help the student model to learn nuances of the data

In [52]:
class DistilLoss(torch.nn.Module):

    def __init__(self):
        super(DistilLoss,self).__init__()

    def forward(self, t_output, s_output):
      loss = -P(t_output)*torch.log(P(s_output))
      return loss.sum().sum()
     

In [None]:
#init student
student = Student(n_categories).to(device)

#load teacher checkpoint
teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
teacher.load_state_dict(torch.load(results_path.joinpath('teacher.pth')))
teacher = teacher.freeze_weights()
teacher.eval()

distill_loss = DistilLoss()
optimizer = optim.Adam(student.parameters(), lr=0.00001)

patience = 5
count = 0
best_loss = 1e9
for epoch in range(40):
  for image in unlabeledloader:
    image = image.to(device)
    loss = distill_loss(teacher(image),student(image))
    loss.backward()
    optimizer.step()

  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:        
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.942 f1:0.745 precision:0.842 recall:0.743
acc:0.947 f1:0.725 precision:0.719 recall:0.733


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.952 f1:0.796 precision:0.974 recall:0.776
acc:0.950 f1:0.836 precision:0.871 recall:0.816
acc:0.957 f1:0.818 precision:0.972 recall:0.797
