Ref:
* https://www.kaggle.com/code/jirkaborovec/whale-dolphin-embedding-lit-flash-simclr
* https://www.kaggle.com/remekkinas/remove-background-salient-object-detection
* https://www.kaggle.com/abcd28s/simclr

In [1]:
import os
import time
import torch

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt

In [2]:
import torch.nn as nn
import torchvision.models as models

from torchvision.datasets import ImageFolder
from torchvision.transforms import (
    RandomResizedCrop,
    RandomHorizontalFlip,
    ColorJitter,
    RandomGrayscale,
    RandomApply,
    Compose,
    GaussianBlur,
    ToTensor,
)

In [3]:
print(f'Torch-Version {torch.__version__}')
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'DEVICE: {DEVICE}')

Torch-Version 1.9.1+cpu
DEVICE: cpu


In [4]:
!ls -l /kaggle/input/happy-whale-and-dolphin/

total 4668
-rw-r--r-- 1 nobody nogroup 2404234 Feb  1 16:58 sample_submission.csv
drwxr-xr-x 2 nobody nogroup       0 Feb  1 17:00 test_images
-rw-r--r-- 1 nobody nogroup 2371769 Feb  1 17:00 train.csv
drwxr-xr-x 2 nobody nogroup       0 Feb  1 17:04 train_images


In [5]:
!tree -d /kaggle/input/happy-whale-and-dolphin/

[01;34m/kaggle/input/happy-whale-and-dolphin/[00m
├── [01;34mtest_images[00m
└── [01;34mtrain_images[00m

2 directories


In [6]:
PATH_DATASET = '/kaggle/input/happy-whale-and-dolphin/'

In [7]:
df_train = pd.read_csv(PATH_DATASET + 'train.csv')
df_train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 51033 entries, 0 to 51032
Data columns (total 3 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   image          51033 non-null  object
 1   species        51033 non-null  object
 2   individual_id  51033 non-null  object
dtypes: object(3)
memory usage: 1.2+ MB


In [8]:
df_train.head()

Unnamed: 0,image,species,individual_id
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9
1,000562241d384d.jpg,humpback_whale,1a71fbb72250
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392


In [9]:
print(f"Dataset size: {len(df_train)}")
print(f"Unique ids: {len(df_train['individual_id'].unique())}")
print(f"# unique species: {len(df_train['species'].unique())}")

Dataset size: 51033
Unique ids: 15587
# unique species: 30


In [10]:
df_train = df_train[:128]  # needs to be removed

In [11]:
from PIL import Image
from torch.utils.data import Dataset


class HappyWhaleDataset(Dataset):
    def __init__(self, df: pd.DataFrame, path_folder: str, transform=None):
        """
        Args:
            df: dataframe of train csv 
            path_folder: path of train images
            transform: (callable, optional): optional transform to be applied on a sample
        """
        self.df = df
        self.transform = transform

        self.image_names = self.df["image"].values
        self.image_paths = [os.path.join(path_folder, name) for name in self.image_names]
        self.targets = list(self.df["individual_id"])
        self.uq_targets = sorted(set(self.targets))
        lut = {v: k for k, v in dict(enumerate(self.uq_targets)).items()}
        self.labels = [lut[ind] for ind in self.targets]

    def __len__(self) -> int:
        return len(self.df)
    
    def __getitem__(self, idx: int) -> tuple:
        """
        Args:
            idx: index for the dataset
        """
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)
            
        return img


In [12]:
def get_complete_transform(output_shape, kernel_size, s=1.0):
    """
    The SimCLR transform.
    
    Args:
        output_shape: output image shape, [int, int]
        kernel_size: kernel size for Gaussian blur, [int, int]
        s: strength parameter, float
    
    Returns:
        image_transform: the simclr transform
    """
    rnd_crop = RandomResizedCrop(output_shape)
    rnd_flip = RandomHorizontalFlip(p=0.5)
    
    color_jitter = ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = RandomApply([color_jitter], p=0.8)
    
    rnd_gray = RandomGrayscale(p=0.2)
    gaussian_blur = GaussianBlur(kernel_size=kernel_size)
    rnd_gaussian_blur = RandomApply([gaussian_blur], p=0.5)
    to_tensor = ToTensor()
    image_transform = Compose([
        to_tensor,
        rnd_crop,
        rnd_flip,
        rnd_color_jitter,
        rnd_gray,
        rnd_gaussian_blur,
    ])
    return image_transform


class ContrastiveLearningViewGenerator(object):
    """
    Take two random crops of one image as the query and key.
    
    Args:
        base_transform: the transform applied to the image
        n_views (optional): number of output images
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        views = [self.base_transform(x) for i in range(self.n_views)]
        return views

In [13]:
# The size of the images
output_shape = [224,224]
kernel_size = [21,21] # 10% of the output_shape

# The custom transform
base_transforms = get_complete_transform(output_shape=output_shape, kernel_size=kernel_size, s=1.0)
custom_transform = ContrastiveLearningViewGenerator(base_transform=base_transforms)

In [14]:
dataset = HappyWhaleDataset(
    df=df_train,
    path_folder=f"{PATH_DATASET}/train_images",
    transform=custom_transform
)


In [15]:
# BATCH_SIZE = 128
BATCH_SIZE = 16  # needs to be updated


# Building the data loader
train_dl = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=os.cpu_count(),
    drop_last=True,
    pin_memory=True,
)

In [16]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class SimCLR(nn.Module):
    def __init__(self, linear_eval=False):
        super().__init__()
        self.linear_eval = linear_eval
        resnet18 = models.resnet18(pretrained=False)
        resnet18.fc = Identity()
        self.encoder = resnet18
        self.projection = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )
    def forward(self, x):
        if not self.linear_eval:
            x = torch.cat(x, dim=0)
        
        encoding = self.encoder(x)
        projection = self.projection(encoding) 
        return projection

In [17]:
LABELS = torch.cat([torch.arange(BATCH_SIZE) for i in range(2)], dim=0)
LABELS = (LABELS.unsqueeze(0) == LABELS.unsqueeze(1)).float() # Creates a one-hot with broadcasting
LABELS = LABELS.to(DEVICE) #128,128

def cont_loss(features, temp):
    """
    The NTxent Loss.
    
    Args:
        z1: The projection of the first branch
        z2: The projeciton of the second branch
    
    Returns:
        the NTxent loss
    """
    similarity_matrix = torch.matmul(features, features.T) # 128, 128
    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(LABELS.shape[0], dtype=torch.bool).to(DEVICE)
    # ~mask is the negative of the mask
    # the view is required to bring the matrix back to shape
    labels = LABELS[~mask].view(LABELS.shape[0], -1) # 128, 127
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) # 128, 127

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) # 128, 1

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) # 128, 126

    logits = torch.cat([positives, negatives], dim=1) # 128, 127
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(DEVICE)  # vectorization

    logits = logits / temp
    return logits, labels

In [18]:
simclr_model = SimCLR().to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(simclr_model.parameters())

In [19]:
EPOCHS = 1
for epoch in range(EPOCHS):
    t0 = time.time()
    running_loss = 0.0
    for i, views in enumerate(train_dl):
        
        projections = simclr_model([view.to(DEVICE) for view in views])
        logits, labels = cont_loss(projections, temp=2)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if i % 10 == 9:    # print every 10 mini-batches
            print(f'EPOCH: {epoch+1} BATCH: {i+1} LOSS: {(running_loss/100):.4f} ')
            running_loss = 0.0
    print(f'Time taken: {((time.time()-t0)/60):.3f} mins')

Time taken: 0.852 mins
