In [1]:
import lightly
import torchvision
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from pathlib import Path
import json
import os
from PIL import Image
import sklearn
from sklearn.model_selection import StratifiedShuffleSplit
from shutil import copyfile

In [2]:
# utils

            
def make_split(**kwargs):
    test_size = kwargs.get('test_size')
    labeled_path = kwargs.get('labeled_path')
    train_path = kwargs.get('train_path')
    test_path = kwargs.get('test_path')
    
    train_path.mkdir(parents=True, exist_ok=True)
    test_path.mkdir(parents=True, exist_ok=True)
    
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)

    img_paths = []
    labels = []
    for label in labeled_path.iterdir():
        for img in labeled_path.joinpath(label).iterdir():
            img_paths.append(img)
            labels.append(label)

    idx_train, idx_test = next(sss.split(img_paths,labels))

    train_imgs = [img_paths[i] for i in idx_train]
    train_labels = [labels[i] for i in idx_train]

    test_imgs = [img_paths[i] for i in idx_test]
    test_labels = [labels[i] for i in idx_test]
    
    for img,label in zip(train_imgs,train_labels):        
        dest_cat_path = train_path.joinpath(label.name)
        dest_cat_path.mkdir(parents=True, exist_ok=True)
        copyfile(img,dest_cat_path.joinpath(img.name))
        
    for img,label in zip(test_imgs,test_labels):        
        dest_cat_path = test_path.joinpath(label.name)
        dest_cat_path.mkdir(parents=True, exist_ok=True)
        copyfile(img,dest_cat_path.joinpath(img.name))
            

# class UnlabeledDataset(torch.utils.data.Dataset):
#     def __init__(self, main_dir, transform):
#         self.main_dir = main_dir
#         self.transform = transform
#         self.all_imgs = os.listdir(main_dir)
        
#     def __len__(self):
#         return len(self.all_imgs)

#     def __getitem__(self, idx):
#         img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
#         image = Image.open(img_loc).convert("RGB")
#         tensor_image = self.transform(image)
#         return tensor_image
    
def class_distribution(dataset,class_index_dict):
    count_dict = {}
    for img,label in dataset:
        if class_index_dict[label] not in count_dict:
            count_dict.update({class_index_dict[label]:0})
        else:
            count_dict[class_index_dict[label]] += 1
            
    return count_dict
    
def sample_dataset(dataset,sample):
    sample_idx = np.random.randint(len(unlabeled_dataset),size = int(sample*len(unlabeled_dataset)))
    return torch.utils.data.Subset(dataset, sample_idx)
    
def load_simcrl(simclr_results_path,n_categories,model_size = 18):
    
    #load config
    conf_path = simclr_results_path.joinpath('conf.json')
    with open(conf_path,'r') as f:
        conf = json.load(f)

    #load model
    model_path = simclr_results_path.joinpath('checkpoint.pth')

    num_ftrs = conf['num_ftrs']
    
    model_name = conf['model_name']

    resnet = lightly.models.ResNetGenerator('resnet-'+str(model_size))
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1)
    )

    if model_name == 'simclr':
        model = lightly.models.SimCLR(backbone, num_ftrs=num_ftrs)
    elif model_name == 'moco':
        model = lightly.models.MoCo(backbone, num_ftrs=num_ftrs, m=0.99, batch_shuffle=True)
        

    encoder = lightly.embedding.SelfSupervisedEmbedding(
        model,
        None,
        None,
        None
    )

    encoder.model.load_state_dict(torch.load(model_path))
    teacher = Teacher(encoder.model,num_ftrs,n_categories).to(device)
    return teacher
    
    
def evaluate(model,testloader,loss_function):
  val_loss = 0
  total = 0
  correct = 0
  ground_truth_list = []
  predictions_list =  []
  for image,label in testloader:
      image, label = image.to(device), label.to(device)
      outputs = model(image)
      probabilities, predicted = torch.max(outputs.data, 1)
      val_loss += loss_function(outputs, label.long()).item()
      total += label.size(0)
      correct += (predicted == label).sum().item()
      ground_truth_list += list(label.cpu())
      predictions_list += list(predicted.cpu())

  acc = sklearn.metrics.accuracy_score(ground_truth_list,predictions_list)
  f1 = sklearn.metrics.f1_score(ground_truth_list,predictions_list,average = 'macro')
  precision = sklearn.metrics.precision_score(ground_truth_list,predictions_list,average = 'macro')
  recall = sklearn.metrics.recall_score(ground_truth_list,predictions_list,average = 'macro')
  print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')

  metrics_dict = {'val_loss':val_loss,'acc':acc,'f1':f1,'precision':precision,'recall':recall}

  return metrics_dict



# Big-self supervised models are strong semi-supervised dentists


Inspired by [this paper](https://arxiv.org/pdf/2006.10029.pdf)


We have a small subset of labeled data $L$ and a large pool of unlabeled data $U$. The goal is to make the most out of $U$ for training a classifier for solving the task on $L$.

The procedure has three steps: 

* Pretrain a big SimCLR model on $U$
* Fine-tune on $L$
* Use the resulting model as a teacher for a smaller model, which is trained on the predictions of the teacher model on $U$



We will split $L$ into a training and testing set using stratified sampling for getting splits with the same proportions of labels. 





In [3]:
# results path
results_path = Path('/projects/self_supervised/results/apical_distillation')
results_path.mkdir(parents=True, exist_ok=True)



In [11]:
# load data 

sample_unlabeled = 1.0

data_path = Path('/projects/self_supervised/data/apical_lesion/full_dataset')
unlabeled_path = Path('/projects/self_supervised/data/apical_lesion/apical_U')
labeled_path = Path('/projects/self_supervised/data/apical_lesion/apical_L')
train_path = Path('/projects/self_supervised/data/apical_lesion/apical_train')
test_path = Path('/projects/self_supervised/data/apical_lesion/apical_test')  

# #not required if you already splitted the data


# make_split(
#     labeled_path = data_path,
#     train_path = unlabeled_path,
#     test_path = labeled_path,
#     test_size = 0.01
# ) 


# make_split(
#     labeled_path = labeled_path,
#     train_path = train_path,
#     test_path = test_path,
#     test_size = 0.5
# )  


input_size = 128
batch_size = 16
num_workers = 2


train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(360),
    transforms.ToTensor(),
    torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
    )
])


test_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
    )
])


trainset = datasets.ImageFolder(root=train_path, transform=train_transform)

trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

testset = datasets.ImageFolder(root=test_path, transform=test_transform)

testloader = torch.utils.data.DataLoader(
    testset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers,
    drop_last=True
)

class_index_dict = {v:k for k,v in testset.class_to_idx.items()}

unlabeled_dataset = datasets.ImageFolder(root=unlabeled_path, transform=train_transform)
unlabeled_dataset = sample_dataset(unlabeled_dataset,sample_unlabeled)
unlabeledloader = torch.utils.data.DataLoader(
    unlabeled_dataset,
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers
)

In [12]:
print('training',len(trainset))
#print('training',class_distribution(trainset,class_index_dict))
print('testing',len(testset))
#print('testing',class_distribution(testset,class_index_dict))
print('unlabeled',len(unlabeled_dataset))

training 54
testing 55
unlabeled 10768


# Vanilla student model (resnet18)

We train on $L$ a resnet18 as a baseline.

We define the softmax-like function $P(y|x_{i})=\frac{exp(f(x_{i})[y]/\tau}{\sum_{y'} exp(f(x_{i})[y']/\tau}$, where $\tau$ is a temperature parameter and $f$ the model. 

We will use cross-entropy defined as $-\sum_{(x_{i},y_{i})} \left[ log P(y_{i}|x_{i})\right]$ as the loss function for training on $L$

In [13]:
class Student(nn.Module):
    def __init__(self,output_size):
        super().__init__()
        
        self.net = models.resnet18(pretrained=True)
        self.net.fc = nn.Linear(self.net.fc.in_features, output_size)

    def forward(self, x):
        out = self.net(x)
        return out
    
def P(x,tau = 1.0):
  return torch.exp(x/tau)/(torch.exp(x/tau).sum())

class CrossEntropyLoss(torch.nn.Module):

    def __init__(self,n_categories):
        super(CrossEntropyLoss,self).__init__()
        self.n_classes = n_categories

    def forward(self, prediction, label):
      label = torch.nn.functional.one_hot(label,num_classes=n_categories)
      loss = -label*torch.log(P(prediction))
      return loss.sum().sum()

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_categories = len([cat for cat in labeled_path.iterdir()])
crossent_loss = CrossEntropyLoss(n_categories)

In [17]:
n_categories

2

In [22]:
#train the student model on L+U

def FolderDataset(data_path):
    img_path_list = []
    label_list = []
    for label in data_path.iterdir():
        for img in data_path.joinpath(label).iterdir():
            label_list.append(label.name)
            img_path_list.append(img)
            
    return np.array(img_path_list), np.array(label_list)
    

class FullDataset(torch.utils.data.Dataset):
    def __init__(self, imgs_path,labels, transform):
   
        self.imgs_path = imgs_path
        self.labels = labels                
        self.transform = transform
        
    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        img_loc = self.imgs_path[idx]
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        label = torch.tensor(self.labels[idx])
        return tensor_image,label
    
    
from sklearn import preprocessing
le = preprocessing.LabelEncoder()

train_imgs,train_labels = FolderDataset(train_path) 
U_imgs,U_labels = FolderDataset(unlabeled_path) 


imgs = np.concatenate((train_imgs,U_imgs))
labels = np.concatenate((train_labels,U_labels))
labels = le.fit_transform(labels)

full_dataset = FullDataset(imgs,labels,train_transform)

full_dataloader = torch.utils.data.DataLoader(
    full_dataset,
    batch_size=batch_size,
    shuffle=True, 
    num_workers=num_workers
)



In [23]:
max(labels)


1

In [24]:



student = Student(n_categories).to(device)
optimizer = optim.Adam(student.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in full_dataloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(student(image), label)
      loss.backward()
      optimizer.step()
        
  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        #torch.save(student.state_dict(),results_path.joinpath('student.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        
        count = 0
  else:
    count += 1
  if count > patience:
    break
        
        
print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')

acc:0.979 f1:0.967 precision:0.950 recall:0.987


KeyboardInterrupt: 

In [None]:

#train the student model on L

student = Student(n_categories).to(device)
optimizer = optim.Adam(student.parameters(), lr=0.0001)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(student(image), label)
      loss.backward()
      optimizer.step()
        
  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        torch.save(student.state_dict(),results_path.joinpath('student.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        
        count = 0
  else:
    count += 1
  if count > patience:
    break
        
        
print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


# Big teacher (SimCLR)


We fine-tune on $L$ a SimCLR model with backbone resnet50 pretrained on $U$


In [28]:
class Teacher(nn.Module):
    def __init__(self, model,num_ftrs,output_dim):
        super().__init__()
        
        self.freeze = False
        self.net = model
        self.fc1 = nn.Linear(num_ftrs, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, output_dim)
        
    def forward(self, x):
        if self.freeze:
            with torch.no_grad():
                y_hat = self.net.backbone(x).squeeze()
                y_hat = self.fc1(y_hat)
                y_hat = self.relu(y_hat)
                y_hat = self.fc2(y_hat)
                return y_hat
        else:
            y_hat = self.net.backbone(x).squeeze()
            y_hat = self.fc1(y_hat)
            y_hat = self.relu(y_hat)
            y_hat = self.fc2(y_hat)
            return y_hat
            
    def freeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = False
      self.freeze = True
      return self

    def unfreeze_weights(self):
      for p in self.net.parameters():
          p.requires_grad = True
      self.freeze = False
      return self
                

In [29]:
#finetune simclr
simclr_results_path = Path('/projects/self_supervised/results/bitewings_caries_moco_50')

teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
optimizer = optim.Adam(teacher.parameters(), lr=0.000005)

patience = 3
count = 0
best_loss = 1e9
for epoch in range(40):
  for image,label in trainloader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()
      loss = crossent_loss(teacher(image), label.long())
      loss.backward()
      optimizer.step()

  metrics_dict = evaluate(teacher,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:
        torch.save(teacher.state_dict(),results_path.joinpath('teacher.pth'))
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
        
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.505 f1:0.490 precision:0.518 recall:0.521
acc:0.542 f1:0.525 precision:0.546 recall:0.554
acc:0.536 f1:0.514 precision:0.530 recall:0.536
acc:0.536 f1:0.514 precision:0.532 recall:0.539
acc:0.542 f1:0.519 precision:0.533 recall:0.540
acc:0.568 f1:0.526 precision:0.530 recall:0.534
acc:0.594 f1:0.546 precision:0.547 recall:0.552
acc:0.547 f1:0.488 precision:0.491 recall:0.490
acc:0.594 f1:0.535 precision:0.535 recall:0.538
acc:0.625 f1:0.555 precision:0.555 recall:0.555
acc:0.656 f1:0.584 precision:0.586 recall:0.583
acc:0.630 f1:0.535 precision:0.538 recall:0.534
acc:0.625 f1:0.500 precision:0.509 recall:0.506
acc:0.615 f1:0.506 precision:0.510 recall:0.509
acc:0.641 f1:0.517 precision:0.531 recall:0.523
acc:0.646 f1:0.521 precision:0.534 recall:0.525
acc:0.651 f1:0.518 precision:0.534 recall:0.524
acc:0.625 f1:0.469 precision:0.479 recall:0.487
acc:0.688 f1:0.535 precision:0.582 recall:0.545
acc:0.651 f1:0.485 precision:0.510 recall:0.506
acc:0.615 f1:0.455 precision:0.459 recal

# Model distillation

The fine-tuned model often yields better performance than the vanilla model. Now that we used $L$ for finetuning the teacher model, we can use $U$ again for transfering the knowledge from the teacher to the student

The procedure is as follows:
* sample data from $U$
* predict with the teacher
* use the labels as targets for training the student

This algorithm is expected to yield an even better model. Notice that the resulting model would be a resnet18 with even better performance than a resnet50 pretrained on $U$ and fine-tuned on $L$.

We will use a distillation loss, which takes the output probabilities of the teacher model as the target for the student model.

$$ -\sum_{x_{i}} \left[ \sum_{y} P^{T}(y|x_{i};\tau) log P^{S}(y|x_{i};\tau) \right] $$

With this loss function the student doesn't only see the one-vs-zero encoding used in supervised learning, but a probability distribution as a target. This could help the student model to learn nuances of the data

In [30]:
class DistilLoss(torch.nn.Module):

    def __init__(self):
        super(DistilLoss,self).__init__()

    def forward(self, t_output, s_output):
      loss = -P(t_output)*torch.log(P(s_output))
      return loss.sum().sum()
     

In [31]:
#init student
student = Student(n_categories).to(device)
#student.load_state_dict(torch.load(results_path.joinpath('student.pth')))

#load teacher checkpoint
teacher = load_simcrl(simclr_results_path,n_categories,model_size = 50)
teacher.load_state_dict(torch.load(results_path.joinpath('teacher.pth')))
teacher = teacher.freeze_weights()
teacher.eval()

distill_loss = DistilLoss()
optimizer = optim.Adam(student.parameters(), lr=0.00001)

patience = 5
count = 0
best_loss = 1e9
for epoch in range(40):
    
#   for image,label in trainloader:
#       image, label = image.to(device), label.to(device)
#       optimizer.zero_grad()
#       loss = crossent_loss(student(image), label)
#       loss.backward()
#       optimizer.step()
        
  for image,_ in unlabeledloader:
    image = image.to(device)
    loss = distill_loss(teacher(image),student(image))
    loss.backward()
    optimizer.step()

  metrics_dict = evaluate(student,testloader,crossent_loss)
  val_loss = metrics_dict['val_loss']
  if val_loss < best_loss:        
        best_loss = val_loss
        best_metrics = metrics_dict
        count = 0
  else:
    count += 1
  if count > patience:
    break

print('\nBest metrics:')
acc = best_metrics['acc']
f1 = best_metrics['f1']
recall = best_metrics['recall']
precision = best_metrics['precision']
print(f'acc:{acc:.3f} f1:{f1:.3f} precision:{precision:.3f} recall:{recall:.3f}')


acc:0.693 f1:0.425 precision:0.516 recall:0.501
acc:0.620 f1:0.563 precision:0.562 recall:0.566
acc:0.651 f1:0.476 precision:0.497 recall:0.498
acc:0.688 f1:0.407 precision:0.349 recall:0.489
acc:0.677 f1:0.458 precision:0.517 recall:0.505
acc:0.656 f1:0.410 precision:0.393 recall:0.475
acc:0.661 f1:0.438 precision:0.462 recall:0.489
acc:0.667 f1:0.415 precision:0.408 recall:0.483
acc:0.703 f1:0.429 precision:0.851 recall:0.509
acc:0.615 f1:0.500 precision:0.505 recall:0.504
acc:0.698 f1:0.427 precision:0.600 recall:0.505
acc:0.604 f1:0.567 precision:0.570 recall:0.580
acc:0.698 f1:0.411 precision:0.349 recall:0.500


  _warn_prf(average, modifier, msg_start, len(result))


acc:0.521 f1:0.517 precision:0.570 recall:0.578
acc:0.698 f1:0.411 precision:0.349 recall:0.500

Best metrics:
acc:0.703 f1:0.429 precision:0.851 recall:0.509


  _warn_prf(average, modifier, msg_start, len(result))
