<a href="https://colab.research.google.com/github/mlellouch/contrastive-learning/blob/main/colab_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Learning for Class Detection
We will train a siamese network with contrastive learning to detect whether two images belong to the same class.

Please note the following:

*   Google Colab can open a notebook directly from GitHub. That is the easiest way to run this one.
*   This notebook mounts your Google Drive, and saves its logs and trained models there.


## Google Drive Setup:

In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
import os

project_dir = "gdrive/MyDrive/cont_learning_example"
os.makedirs(project_dir, exist_ok=True)

### Clone the repo to access the data
If you plan on using your own data, don't forget to upload it. Some of your options are:

1.   Upload it from the Colab file browser.
2.   Write a cell that copies the data from Google Drive to the local working directory (`/content`).

Either way, remember that:


*   The local working directory is wiped at the start of every new Colab session.
*   Reading data directly from the mounted Google Drive is very slow, so it is better to copy the data to the local directory at the start of each session.
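Option 2 above can be as small as the sketch below. It assumes your data already sits on Drive in the one-directory-per-class layout described in the next section; the helper name `copy_dataset_to_local` and the example paths are hypothetical.

```python
import os
import shutil


def copy_dataset_to_local(drive_dir, local_dir):
  """
  Copy a dataset folder from the (slow) mounted Google Drive to the (fast)
  local Colab disk. Skips the copy if `local_dir` already exists, so the
  cell is safe to re-run within the same session.
  """
  if os.path.isdir(local_dir):
    print(f"{local_dir} already exists, skipping copy")
    return local_dir
  # copytree also creates any missing parent directories of local_dir
  shutil.copytree(drive_dir, local_dir)
  return local_dir


# Example usage (hypothetical paths, adjust to where your data actually lives):
# copy_dataset_to_local("/content/gdrive/MyDrive/my_dataset/train", "/content/my_dataset/train")
```

Since the local disk is wiped between sessions, run this cell once at the start of every session, before creating the dataset.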





In [3]:
!git clone https://github.com/mlellouch/contrastive-learning.git

fatal: destination path 'contrastive-learning' already exists and is not an empty directory.


In [4]:
# init python
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
import torch
import csv
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision.models as models
import tqdm
from PIL import Image
import random

# The Data
This project expects data laid out like `data/animals/train`, where each class has its own directory containing all of that class's images.
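Before training, it can help to check that your own data matches this layout. The sketch below counts the images in each class directory and fails early if pairs can't be sampled: a different-class pair needs at least two classes, and a same-class pair needs at least two images in the chosen class. The function name `summarize_class_dirs` and the `IMAGE_EXTENSIONS` list are assumptions, not part of the repo.

```python
import os

# assumption: adjust this tuple to the image formats your data actually uses
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def summarize_class_dirs(source_dir):
  """
  Return {class_name: image_count} for a directory laid out as
  source_dir/<class_name>/<image files>, and raise ValueError if random
  pairs can't be sampled from it.
  """
  counts = {}
  for class_name in sorted(os.listdir(source_dir)):
    class_dir = os.path.join(source_dir, class_name)
    if not os.path.isdir(class_dir):
      continue  # ignore stray files such as .DS_Store
    images = [f for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS)]
    counts[class_name] = len(images)

  # a different-class pair needs at least 2 classes,
  # a same-class pair needs at least 2 images in the class
  if len(counts) < 2:
    raise ValueError("need at least two class directories")
  too_small = [c for c, n in counts.items() if n < 2]
  if too_small:
    raise ValueError(f"classes with fewer than 2 images: {too_small}")
  return counts
```

For the bundled data this would be called as `summarize_class_dirs('./contrastive-learning/data/animals/train')`.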

# Prepare The Dataset
The dataset class defines how the images are read and how they are labeled.

In our case, each sample is a pair of images, labeled 1 if both images belong to the same class and 0 otherwise. Same-class and different-class pairs are drawn with equal probability.
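The sampling rule itself does not depend on images at all, so it can be sketched on file names alone. The function below mirrors the rule used by `PairsDataset.__getitem__`: flip a fair coin, then draw either two distinct images from one class (label 1) or one image from each of two different classes (label 0). The name `sample_pair` and the `files_by_class` input format are hypothetical, chosen for illustration.

```python
import random


def sample_pair(files_by_class, rng=random):
  """
  Sample one labeled pair of (class, file) entries.
  files_by_class: dict mapping class name -> list of image file names.
  Returns (entry1, entry2, label), with label 1 for a same-class pair, 0 otherwise.
  """
  classes = list(files_by_class)
  base_class = rng.choice(classes)
  if rng.randint(0, 1):
    # same class: two distinct images of base_class
    img1, img2 = rng.sample(files_by_class[base_class], 2)
    return (base_class, img1), (base_class, img2), 1
  # different classes: one image from base_class, one from any other class
  other_class = rng.choice([c for c in classes if c != base_class])
  img1 = rng.choice(files_by_class[base_class])
  img2 = rng.choice(files_by_class[other_class])
  return (base_class, img1), (other_class, img2), 0
```

Drawing many pairs with a seeded `random.Random(0)` is a quick way to confirm that roughly half of them are positives and that every label matches its classes.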


In [5]:
BATCH_SIZE = 16

class PairsDataset(Dataset):

  def __init__(self, source_dir, transform, pairs_per_epoch=BATCH_SIZE):
    """
    source_dir: the root dir, with one sub-directory of images per class
    transform: a torchvision image transform. Use it to "enrich" the data with new data points derived from the originals, for example by slightly cropping or mirroring the images.
    pairs_per_epoch: how many random pairs make up one epoch. Pairs are re-sampled on every access, so this number is arbitrary (the default gives one batch per epoch).
    """

    self.source_dir = source_dir
    self.transform = transform
    self.pairs_per_epoch = pairs_per_epoch
    # only keep directories, so stray files (e.g. .DS_Store) are not treated as classes
    self.classes = [c for c in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, c))]

  def __len__(self):
    return self.pairs_per_epoch

  def process_images(self, *images):
    """
    given image paths, loads them, converts them to RGB and passes them through the transform
    """
    # convert('RGB') guarantees 3 channels, even for grayscale or RGBA (e.g. PNG) files
    return [self.transform(Image.open(image_path).convert('RGB')) for image_path in images]

  def __getitem__(self, index):
    """
    Called every time a new pair is requested for training.
    The index is ignored: every call samples a fresh random pair.
    """
    # randomly select whether we return a same-class pair (label 1) or not (label 0)
    should_get_same_class = random.randint(0, 1)
    base_class = random.choice(self.classes)
    base_class_dir = os.path.join(self.source_dir, base_class)
    if should_get_same_class:
      img1, img2 = random.sample(os.listdir(base_class_dir), 2)
      img1, img2 = os.path.join(base_class_dir, img1), os.path.join(base_class_dir, img2)
      return self.process_images(img1, img2), 1

    else:
      img1 = random.choice(os.listdir(base_class_dir))

      # choose an image from a different class
      new_class = random.choice([c for c in self.classes if c != base_class])
      new_class_dir = os.path.join(self.source_dir, new_class)
      img2 = random.choice(os.listdir(new_class_dir))

      img1, img2 = os.path.join(base_class_dir, img1), os.path.join(new_class_dir, img2)
      return self.process_images(img1, img2), 0


In [6]:
# Sanity test the dataset
sanity_test_transforms = transforms.Compose([
                                             transforms.Resize(400),
                                             transforms.RandomCrop(360),
                                             transforms.RandomHorizontalFlip(),
                                             
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]),
])
sanity_test_dataset = PairsDataset('./contrastive-learning/data/animals/train', sanity_test_transforms)
a = sanity_test_dataset[0]
assert type(a[1]) == int
assert list(a[0][0].shape) == [3, 360, 360]

# Define The Model
This model uses a pretrained network as its backbone, which makes training much shorter. If you plan on using it for non-natural images (like screenshots of web pages), the pretrained weights should still help, but expect training to take longer.

## Contrastive Loss vs Classification Loss
Reminder: with this method, the model outputs a vector for each image. From there, we have two options (selected with the `contrastive_loss` flag in the code):


1.   **Contrastive Loss:** we train the model to output nearby vectors for images of the same class, and distant vectors for images of different classes. To classify a new pair, compute the distance between the two vectors; if it is below some threshold, the pair is predicted to be of the same class. This is useful if you intend to use the vectors for more than classifying pairs (for example, to find the images most similar to a given one).
2.   **Classification Loss:** we train an additional part of the network that takes the two vectors and predicts whether they come from the same class.
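The two options differ only in how the final same/different decision is made. A minimal sketch in plain Python, assuming the embeddings are already available as lists of floats; the function names and the threshold values are hypothetical, and in practice the distance threshold has to be tuned on held-out pairs.

```python
import math


def same_class_by_distance(vec1, vec2, threshold):
  """Option 1 (contrastive loss): same class if the Euclidean distance is below `threshold`."""
  return math.dist(vec1, vec2) < threshold


def same_class_by_probability(prob, threshold=0.5):
  """Option 2 (classification loss): the head outputs P(same class), a number in [0, 1]."""
  return prob >= threshold
```

With option 1 the threshold lives outside the network, so it can be changed after training without retraining; with option 2 the network itself produces a probability, and 0.5 is the natural default cut-off.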



In [7]:
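class SiameseNetwork(torch.nn.Module):

  def __init__(self, backbone='resnet18', contrastive_loss=True):
    super(SiameseNetwork, self).__init__()
    self.contrastive_loss = contrastive_loss

    available_backbones = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50
    }

    # ImageNet-pretrained weights. `weights="DEFAULT"` requires torchvision >= 0.13;
    # older versions use `pretrained=True` instead.
    # The backbone keeps its 1000-way ImageNet output layer, so each image becomes a 1000-d vector.
    self.backbone = available_backbones[backbone](weights="DEFAULT")

    # the fully connected head, only used with classification loss.
    # Its input is the two 1000-d vectors concatenated, hence 2 * 1000 = 2000 features.
    self.fc1 = nn.Sequential(
        nn.Linear(2000, 256),
        nn.ReLU(inplace=True),

        nn.Linear(256, 1),
        nn.Sigmoid()  # probability that the two images are of the same class
    )

  def forward(self, image1, image2):
    # both images go through the same backbone (shared weights)
    out1 = self.backbone(image1)
    out2 = self.backbone(image2)
    if self.contrastive_loss:
      return out1, out2
    else:
      return self.fc1(torch.cat((out1, out2), 1))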

# Define the Criteria (Losses)



In [8]:
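class ContrastiveLoss(torch.nn.Module):
    """
    Same-class pairs (y=1) are penalized by their squared distance;
    different-class pairs (y=0) are penalized only while closer than `margin`.
    """

    def __init__(self, margin=1.0, eps=1e-9):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = eps  # keeps sqrt away from 0, where its gradient is infinite (NaN losses)

    def forward(self, output, y):
        out1, out2 = output
        # labels must be on the same device, and of the same dtype, as the embeddings
        y = y.to(device=out1.device, dtype=out1.dtype)
        dist_sq = torch.sum(torch.pow(out1 - out2, 2), 1)
        dist = torch.sqrt(dist_sq + self.eps)
        hinge = torch.clamp(self.margin - dist, min=0.0)
        loss = y * dist_sq + (1 - y) * torch.pow(hinge, 2)
        # halve and average over the batch
        return torch.sum(loss) / 2.0 / out1.size(0)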


class BCELoss(torch.nn.Module):

  def __init__(self):
    super(BCELoss, self).__init__()
    self.loss = torch.nn.BCELoss()

  def forward(self, output, y):
    output = output.reshape(y.shape)
    y = y.to(device=output.device, dtype=torch.float32)  # labels must match the output's device and dtype
    return self.loss(output, y)
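To make the contrastive loss formula concrete, here is the same per-pair computation as `ContrastiveLoss.forward`, written in plain Python on hand-picked 2-d vectors. This is an illustrative sketch: the function name `contrastive_loss` is hypothetical and not part of the notebook.

```python
import math


def contrastive_loss(pairs, margin=1.0):
  """
  pairs: list of (vec1, vec2, y), with y=1 for a same-class pair and y=0 otherwise.
  Same-class pairs pay dist^2; different-class pairs pay max(margin - dist, 0)^2.
  The total is halved and averaged over the batch.
  """
  total = 0.0
  for vec1, vec2, y in pairs:
    dist = math.dist(vec1, vec2)
    hinge = max(margin - dist, 0.0)
    total += y * dist ** 2 + (1 - y) * hinge ** 2
  return total / 2.0 / len(pairs)


same_near = ([0.0, 0.0], [0.0, 0.5], 1)  # same class, distance 0.5
diff_near = ([0.0, 0.0], [0.0, 0.5], 0)  # different classes, distance 0.5
diff_far = ([0.0, 0.0], [0.0, 2.0], 0)   # different classes, distance 2.0
print(contrastive_loss([same_near]))  # 0.125
print(contrastive_loss([diff_near]))  # 0.125
print(contrastive_loss([diff_far]))   # 0.0
```

With the default margin of 1.0, a same-class pair at distance 0.5 and a different-class pair at distance 0.5 cost the same (0.5² / 2 = 0.125), while a different-class pair already beyond the margin costs nothing. This is why the margin matters: different classes are only pushed apart until they are "far enough".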


# Define The Training Loop

In [9]:
def train(model, dataset, criterion, optimizer, epochs=10, run_name='test'):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)
  model.train()  # training mode (e.g. BatchNorm updates its running statistics)

  train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

  # log the average training loss of every epoch to a CSV file on Google Drive
  log_file = os.path.join(project_dir, f'log_{run_name}.csv')
  with open(log_file, 'w', newline='') as f:
    log = csv.DictWriter(f, fieldnames=['epoch', 'loss'])
    log.writeheader()

    for epoch in tqdm.tqdm(range(epochs)):
      epoch_loss, num_batches = 0.0, 0
      for (img1_set, img2_set), labels in train_loader:
        # the images AND the labels must be on the same device as the model
        img1_set = img1_set.to(device)
        img2_set = img2_set.to(device)
        labels = labels.to(device)

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        output = model(img1_set, img2_set)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

      log.writerow({'epoch': epoch, 'loss': epoch_loss / num_batches})
      f.flush()  # keep the log readable while training is still running

  # Save the trained model
  model_file_name = os.path.join(project_dir, f'trained_model_{run_name}.pth')
  torch.save(model.state_dict(), model_file_name)
  print("Saved model at {}".format(model_file_name))
  return model
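The training loop writes one row per epoch to `log_<run_name>.csv` in the project directory. A minimal sketch for reading that log back, for example to compare runs or to plot the loss curve; the helper name `read_loss_log` is hypothetical.

```python
import csv


def read_loss_log(log_file):
  """Read a CSV log with 'epoch' and 'loss' columns into a list of (epoch, loss) tuples."""
  with open(log_file, newline='') as f:
    return [(int(row['epoch']), float(row['loss'])) for row in csv.DictReader(f)]


# Example usage ('test' is the default run name used in this notebook):
# for epoch, loss in read_loss_log(os.path.join(project_dir, 'log_test.csv')):
#   print(f"epoch {epoch}: loss {loss:.4f}")
```

Because the log lives on Google Drive, it survives the end of the Colab session, unlike the local working directory.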

# Let's Train It

In [11]:
# This cell runs a single experiment. Copy and paste it if you want to play with the parameters

contrastive_loss = True
if contrastive_loss:
  criterion = ContrastiveLoss()
else:
  criterion = BCELoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNetwork(contrastive_loss=contrastive_loss)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

run_transforms = transforms.Compose([
  transforms.Resize(400),
  transforms.RandomAffine(30),
  transforms.RandomCrop(360),
  transforms.RandomHorizontalFlip(),
  
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]),
])

dataset = PairsDataset('./contrastive-learning/data/animals/train', run_transforms)

model = train(model, dataset, criterion, optimizer, epochs=15)

100%|██████████| 15/15 [05:29<00:00, 21.98s/it]


Saved model at gdrive/MyDrive/cont_learning_example/trained_model_test.pth
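The model above was trained with the contrastive loss, so using it as a classifier still requires a distance threshold. A minimal sketch for choosing one, assuming you have already computed the embedding distances for a labeled set of held-out pairs (this notebook does not build such a validation set; the function name `best_threshold` is hypothetical). It tries a threshold just above every observed distance and keeps the one with the highest accuracy.

```python
def best_threshold(distances, labels):
  """
  distances: embedding distance of each validation pair.
  labels: 1 if the pair is of the same class, 0 otherwise.
  A pair is predicted "same class" when its distance is strictly below the threshold.
  Returns (threshold, accuracy).
  """
  # candidates: 0 (predict "different" for every pair) and just above each distance
  candidates = [0.0] + sorted(d + 1e-9 for d in distances)
  best_t, best_acc = 0.0, -1.0
  for t in candidates:
    correct = sum(int(d < t) == y for d, y in zip(distances, labels))
    acc = correct / len(labels)
    if acc > best_acc:
      best_t, best_acc = t, acc
  return best_t, best_acc
```

This brute-force search is quadratic in the number of pairs, which is fine for a few thousand validation pairs; for larger sets, sorting the pairs once and sweeping the threshold would be faster.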
