In [1]:
import lightly
import torchvision
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from pathlib import Path
import json
import os
from PIL import Image
import sklearn
from sklearn.model_selection import StratifiedShuffleSplit
from shutil import copyfile

In [2]:
# utils

            
def make_train_test_split(**kwargs):
    test_size = kwargs.get('test_size')
    labeled_path = kwargs.get('labeled_path')
    train_path = kwargs.get('train_path')
    test_path = kwargs.get('test_path')
    
    train_path.mkdir(parents=True, exist_ok=True)
    test_path.mkdir(parents=True, exist_ok=True)
    
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)

    img_paths = []
    labels = []
    for label in labeled_path.iterdir():
        for img in labeled_path.joinpath(label).iterdir():
            img_paths.append(img)
            labels.append(label)

    idx_train, idx_test = next(sss.split(img_paths,labels))

    train_imgs = [img_paths[i] for i in idx_train]
    train_labels = [labels[i] for i in idx_train]

    test_imgs = [img_paths[i] for i in idx_test]
    test_labels = [labels[i] for i in idx_test]
    
    for img,label in zip(train_imgs,train_labels):        
        dest_cat_path = train_path.joinpath(label.name)
        dest_cat_path.mkdir(parents=True, exist_ok=True)
        copyfile(img,dest_cat_path.joinpath(img.name))
        
    for img,label in zip(test_imgs,test_labels):        
        dest_cat_path = test_path.joinpath(label.name)
        dest_cat_path.mkdir(parents=True, exist_ok=True)
        copyfile(img,dest_cat_path.joinpath(img.name))
            
    
def class_distribution(dataset,class_index_dict):
    count_dict = {}
    for img,label in dataset:
        if class_index_dict[label] not in count_dict:
            count_dict.update({class_index_dict[label]:0})
        else:
            count_dict[class_index_dict[label]] += 1
            
    return count_dict
    
def sample_dataset(dataset,sample):
    sample_idx = np.random.randint(len(unlabeled_dataset),size = int(sample*len(unlabeled_dataset)))
    return torch.utils.data.Subset(dataset, sample_idx)
    
def load_simcrl(simclr_results_path,n_categories,model_size = 18):
    
    #load config
    conf_path = simclr_results_path.joinpath('conf.json')
    with open(conf_path,'r') as f:
        conf = json.load(f)

    #load model
    model_path = simclr_results_path.joinpath('checkpoint.pth')

    num_ftrs = conf['num_ftrs']
    
    model_name = conf['model_name']

    resnet = lightly.models.ResNetGenerator('resnet-'+str(model_size))
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1)
    )

    if model_name == 'simclr':
        model = lightly.models.SimCLR(backbone, num_ftrs=num_ftrs)
    elif model_name == 'moco':
        model = lightly.models.MoCo(backbone, num_ftrs=num_ftrs, m=0.99, batch_shuffle=True)
        

    encoder = lightly.embedding.SelfSupervisedEmbedding(
        model,
        None,
        None,
        None
    )

    encoder.model.load_state_dict(torch.load(model_path))
    teacher = Teacher(encoder.model,num_ftrs,n_categories).to(device)
    return teacher
    
    
def evaluate(model,testloader,loss_function):
  val_loss = 0
  total = 0
  correct = 0
  ground_truth_list = []
  predictions_list =  []
  for image,label in testloader:
      image, label = image.to(device), label.to(device)
      outputs = model(image)
      probabilities, predicted = torch.max(outputs.data, 1)
      val_loss += loss_function(outputs, label.long()).item()
      total += label.size(0)
      correct += (predicted == label).sum().item()
      ground_truth_list += list(label.cpu())
      predictions_list += list(predicted.cpu())

  acc = sklearn.metrics.accuracy_score(ground_truth_list,predictions_list)
  f1 = sklearn.metrics.f1_score(ground_truth_list,predictions_list,average = 'macro')
  precision = sklearn.metrics.precision_score(ground_truth_list,predictions_list,average = 'macro')
  recall = sklearn.metrics.recall_score(ground_truth_list,predictions_list,average = 'macro')
  print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')

  metrics_dict = {'val_loss':val_loss,'acc':acc,'f1':f1,'precision':precision,'recall':recall}

  return metrics_dict



# Big-self supervised models are strong semi-supervised radiologists


Inspired by [this paper](https://arxiv.org/pdf/2006.10029.pdf)


We have a small subset of labeled data $L$ and a large pool of unlabeled data $U$. The goal is to make the most out of $U$ for training a classifier for solving the task on $L$.

The procedure has three steps: 

* Pretrain a big SimCLR model on $U$
* Fine-tune on $L$
* Use the resulting model as a teacher for a smaller model, which is trained on the predictions of the teacher model on $U$



We will split $L$ into a training and testing set using stratified sampling for getting splits with the same proportions of labels. 





In [3]:
# results path
results_path = Path('/projects/self_supervised/results/bw_caries_distillation')
results_path.mkdir(parents=True, exist_ok=True)



In [4]:
# load data 

sample_unlabeled = 1.0

data_path = Path('/projects/self_supervised/data/bitewings_caries/classification_dataset')
unlabeled_path = Path('/projects/self_supervised/data/bitewings_caries/bw_U')
labeled_path = Path('/projects/self_supervised/data/bitewings_caries/bw_L')
train_path = Path('/projects/self_supervised/data/bitewings_caries/bw_train')
test_path = Path('/projects/self_supervised/data/bitewings_caries/bw_test')  

# #not required if you already splitted the data


# make_train_test_split(
#     labeled_path = data_path,
#     train_path = unlabeled_path,
#     test_path = labeled_path,
#     test_size = 0.1
# ) 


# make_train_test_split(
#     labeled_path = labeled_path,
#     train_path = train_path,
#     test_path = test_path,
#     test_size = 0.5
# )  


input_size = 128
batch_size = 16
num_workers = 2


train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(360),
    transforms.ToTensor(),
    torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
    )
])


test_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
    )
])


trainset = datasets.ImageFolder(root=train_path, transform=train_transform)

trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

testset = datasets.ImageFolder(root=test_path, transform=test_transform)

testloader = torch.utils.data.DataLoader(
    testset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

class_index_dict = {v:k for k,v in testset.class_to_idx.items()}

#load unlabeled data

unlabeled_dataset = datasets.ImageFolder(root=unlabeled_path, transform=train_transform)
unlabeled_dataset = sample_dataset(unlabeled_dataset,sample_unlabeled)
unlabeledloader = torch.utils.data.DataLoader(
    unlabeled_dataset,
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers
)

In [5]:
print('training',len(trainset))
#print('training',class_distribution(trainset,class_index_dict))
print('testing',len(testset))
#print('testing',class_distribution(testset,class_index_dict))
print('unlabeled',len(unlabeled_dataset))

training 192
testing 193
unlabeled 3457


# Vanilla student model (resnet18)

We train on $L$ a resnet18 as a baseline.

We define the softmax-like function $P(y|x_{i})=\frac{exp(f(x_{i})[y]/\tau}{\sum_{y'} exp(f(x_{i})[y']/\tau}$, where $\tau$ is a temperature parameter and $f$ the model. 

We will use cross-entropy defined as $-\sum_{(x_{i},y_{i})} \left[ log P(y_{i}|x_{i})\right]$ as the loss function for training on $L$

In [6]:
class Student(nn.Module):
    def __init__(self,output_size):
        super().__init__()
        
        self.net = models.resnet18(pretrained=True)
        self.net.fc = nn.Linear(self.net.fc.in_features, output_size)

    def forward(self, x):
        out = self.net(x)
        return out
    
def P(x,tau = 1.0):
  return torch.exp(x/tau)/(torch.exp(x/tau).sum())

class CrossEntropyLoss(torch.nn.Module):

    def __init__(self,n_categories):
        super(CrossEntropyLoss,self).__init__()
        self.n_classes = n_categories

    def forward(self, prediction, label):
      label = torch.nn.functional.one_hot(label,num_classes=n_categories)
      loss = -label*torch.log(P(prediction))
      return loss.sum().sum()

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_categories = len([cat for cat in labeled_path.iterdir()])
crossent_loss = CrossEntropyLoss(n_categories)

In [8]:
#train the student model on L+U

def FolderDataset(data_path):
    img_path_list = []
    label_list = []
    for label in data_path.iterdir():
        for img in data_path.joinpath(label).iterdir():
            label_list.append(label)
            img_path_list.append(img)
            
    return np.array(img_path_list), np.array(label_list)
    

class FullDataset(torch.utils.data.Dataset):
    def __init__(self, imgs_path,labels, transform):
   
        self.imgs_path = imgs_path
        self.labels = labels                
        self.transform = transform
        
    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        img_loc = self.imgs_path[idx]
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        label = torch.tensor(self.labels[idx])
        return tensor_image,label
    
    
from sklearn import preprocessing
le = preprocessing.LabelEncoder()

train_imgs,train_labels = FolderDataset(train_path) 
U_imgs,U_labels = FolderDataset(train_path) 
imgs = np.concatenate((train_imgs,U_imgs))
labels = np.concatenate((train_labels,U_labels))
labels = le.fit_transform(labels)

full_dataset = FullDataset(imgs,labels,train_transform)

full_dataloader = torch.utils.data.DataLoader(
    full_dataset,
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers
)



In [9]:



student = Student(n_categories).to(device)
optimizer = optim.Adam(student.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in full_dataloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(student(image), label)
      loss.backward()
      optimizer.step()
        
  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        #torch.save(student.state_dict(),results_path.joinpath('student.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        
        count = 0
  else:
    count += 1
  if count > patience:
    break
        
        
print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')

acc:0.698 f1:0.515 precision:0.595 recall:0.537
acc:0.693 f1:0.521 precision:0.585 recall:0.538
acc:0.667 f1:0.536 precision:0.559 recall:0.540
acc:0.719 f1:0.548 precision:0.676 recall:0.564
acc:0.667 f1:0.521 precision:0.553 recall:0.531
acc:0.646 f1:0.442 precision:0.456 recall:0.482
acc:0.672 f1:0.488 precision:0.537 recall:0.516

Best metrics:
acc:0.667 f1:0.536 precision:0.559 recall:0.540


In [10]:

#train the student model on L

student = Student(n_categories).to(device)
optimizer = optim.Adam(student.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(student(image), label)
      loss.backward()
      optimizer.step()
        
  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        torch.save(student.state_dict(),results_path.joinpath('student.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        
        count = 0
  else:
    count += 1
  if count > patience:
    break
        
        
print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.630 f1:0.481 precision:0.493 recall:0.495
acc:0.651 f1:0.466 precision:0.491 recall:0.496
acc:0.661 f1:0.461 precision:0.496 recall:0.498
acc:0.646 f1:0.463 precision:0.483 recall:0.492
acc:0.667 f1:0.504 precision:0.542 recall:0.522

Best metrics:
acc:0.630 f1:0.481 precision:0.493 recall:0.495


# Big teacher (SimCLR)


We fine-tune on $L$ a SimCLR model with backbone resnet50 pretrained on $U$


In [11]:
class Teacher(nn.Module):
    def __init__(self, model,num_ftrs,output_dim):
        super().__init__()
        
        self.freeze = False
        self.net = model
        self.fc1 = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, output_dim)
        
    def forward(self, x):
        if self.freeze:
            with torch.no_grad():
                y_hat = self.net.backbone(x).squeeze()
                y_hat = self.fc1(y_hat)
                y_hat = self.relu(y_hat)
                y_hat = self.fc2(y_hat)
                return y_hat
        else:
            y_hat = self.net.backbone(x).squeeze()
            y_hat = self.fc1(y_hat)
            y_hat = self.relu(y_hat)
            y_hat = self.fc2(y_hat)
            return y_hat
            
    def freeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = False
      self.freeze = True
      return self

    def unfreeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = True
      self.freeze = False
      return self
                

In [12]:
#finetune simclr
simclr_results_path = Path('/projects/self_supervised/results/bitewings_caries_moco_50')

teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
optimizer = optim.Adam(teacher.parameters(), lr=0.000005)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(teacher(image), label.long())
      loss.backward()
      optimizer.step()

  metrics_dict = evaluate(teacher,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        torch.save(teacher.state_dict(),results_path.joinpath('teacher.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
        
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.490 f1:0.459 precision:0.475 recall:0.470
acc:0.490 f1:0.452 precision:0.464 recall:0.458
acc:0.505 f1:0.468 precision:0.478 recall:0.474
acc:0.500 f1:0.454 precision:0.462 recall:0.456
acc:0.521 f1:0.476 precision:0.483 recall:0.481
acc:0.547 f1:0.492 precision:0.495 recall:0.494
acc:0.583 f1:0.515 precision:0.515 recall:0.516
acc:0.542 f1:0.433 precision:0.431 recall:0.437
acc:0.547 f1:0.460 precision:0.460 recall:0.460
acc:0.562 f1:0.465 precision:0.464 recall:0.467
acc:0.630 f1:0.529 precision:0.534 recall:0.529
acc:0.568 f1:0.442 precision:0.440 recall:0.449
acc:0.589 f1:0.469 precision:0.470 recall:0.474
acc:0.583 f1:0.459 precision:0.459 recall:0.467
acc:0.583 f1:0.452 precision:0.451 recall:0.462
acc:0.589 f1:0.440 precision:0.438 recall:0.456
acc:0.604 f1:0.465 precision:0.468 recall:0.477
acc:0.589 f1:0.448 precision:0.447 recall:0.461
acc:0.620 f1:0.466 precision:0.474 recall:0.483
acc:0.594 f1:0.425 precision:0.419 recall:0.448
acc:0.625 f1:0.469 precision:0.479 recal

# Model distillation

The fine-tuned model often yields better performance than the vanilla model. Now that we used $L$ for finetuning the teacher model, we can use $U$ again for transfering the knowledge from the teacher to the student

The procedure is as follows:
* sample data from $U$
* predict with the teacher
* use the labels as targets for training the student

This algorithm is expected to yield an even better model. Notice that the resulting model would be a resnet18 with even better performance than a resnet50 pretrained on $U$ and fine-tuned on $L$.

We will use a distillation loss, which takes the output probabilities of the teacher model as the target for the student model.

$$ -\sum_{x_{i}} \left[ \sum_{y} P^{T}(y|x_{i};\tau) log P^{S}(y|x_{i};\tau) \right] $$

With this loss function the student doesn't only see the one-vs-zero encoding used in supervised learning, but a probability distribution as a target. This could help the student model to learn nuances of the data

In [13]:
class DistilLoss(torch.nn.Module):

    def __init__(self):
        super(DistilLoss,self).__init__()

    def forward(self, t_output, s_output):
      loss = -P(t_output)*torch.log(P(s_output))
      return loss.sum().sum()
     

In [14]:
#init student
student = Student(n_categories).to(device)

#load teacher checkpoint
teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
teacher.load_state_dict(torch.load(results_path.joinpath('teacher.pth')))
teacher.eval()

distill_loss = DistilLoss()
optimizer = optim.Adam(student.parameters(), lr=0.00001)

patience = 5
count = 0
best_loss = 1e9
for epoch in range(40):
    
#   for image,label in trainloader:
#       image, label = image.to(device), label.to(device)
#       optimizer.zero_grad()
#       loss = crossent_loss(student(image), label)
#       loss.backward()
#       optimizer.step()
        
  for image,_ in unlabeledloader:
    image = image.to(device)
    loss = distill_loss(teacher(image),student(image))
    loss.backward()
    optimizer.step()

  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:        
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.646 f1:0.580 precision:0.580 recall:0.580
acc:0.688 f1:0.508 precision:0.569 recall:0.529
acc:0.703 f1:0.413 precision:0.352 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.698 f1:0.411 precision:0.349 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.698 f1:0.411 precision:0.349 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.693 f1:0.409 precision:0.348 recall:0.496
acc:0.693 f1:0.425 precision:0.516 recall:0.501
acc:0.698 f1:0.411 precision:0.349 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.703 f1:0.429 precision:0.851 recall:0.509
acc:0.698 f1:0.411 precision:0.349 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.703 f1:0.413 precision:0.352 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.703 f1:0.445 precision:0.685 recall:0.514
acc:0.698 f1:0.411 precision:0.349 recall:0.500

Best metrics:
acc:0.693 f1:0.425 precision:0.516 recall:0.501


  _warn_prf(average, modifier, msg_start, len(result))
