In [1]:
!pip install -r req.txt

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
# importing base python dependencies
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from datetime import datetime, date
import os
import random
from pathlib import Path
import copy
from tqdm import tqdm
import multiprocessing
from skimage.io import imread
import pdb
import pytorch_lightning as pl
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch

import torch.nn as nn
import torch.distributed as dist
from torch import optim
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd

from lightly.models.modules.heads import DINOProjectionHead




In [3]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    """Return ``x`` as a pair.

    A tuple is handed back unchanged once it is confirmed to hold exactly two
    entries; an ``int`` is duplicated into ``(x, x)``. Anything else trips an
    assertion.
    """
    if not isinstance(x, tuple):
        # Scalar case: only integers are accepted
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If True (default) the output stays (B, N, D);
            otherwise it is reshaped to (B, H', W', D) where H' and W' are the
            spatial sizes produced by the projection.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        # Number of whole patches that fit along each spatial axis
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride both equal the patch size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Project an image batch (B, C, H, W) into patch tokens.

        Raises:
            AssertionError: if H or W is not a multiple of the patch size.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate the FLOPs of one forward pass at the configured image size.

        Counts the patch projection, plus one operation per output element
        when a real normalization layer is present.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is never None: ``__init__`` substitutes nn.Identity when no
        # norm layer is given, so test for the placeholder instead of None.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
    

In [4]:
class DINOViT(pl.LightningModule):
    """Student/teacher DINO model built from pretrained DINOv2 ViT-B/14 (register) blocks.

    The student is a custom ``PatchEmbed``, a learnable positional embedding,
    the pretrained CLS/register tokens, the pretrained transformer blocks and a
    ``DINOProjectionHead``. The teacher backbone is a frozen deep copy of the
    student backbone with its own frozen head; both are moved towards the
    student by an exponential moving average in ``training_step``.

    Args:
        lr: Base learning rate passed to AdamW.
        max_epoch_number: ``T_max`` of the cosine annealing scheduler.
        num_register_tokens: Only checked to be non-negative; the register
            tokens themselves are taken from the pretrained model.
        num_patches: Number of patch tokens, i.e. the length of ``pos_embed``.
        proj_dim: Output dimension of both projection heads and of the loss center.
    """

    def __init__(self,
                 lr=1e-3,
                 max_epoch_number=500,
                 num_register_tokens=4,
                 num_patches=1369, #1369   - [518/14 = 37 * 37 =1369]
                 proj_dim=2048):
        super().__init__()
        self.lr = lr
        self.max_epoch_number = max_epoch_number
        # Learnable positional embedding, zero-initialised (not taken from the pretrained model)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, 768))
        assert num_register_tokens >= 0
        # NOTE(review): ``.cuda()`` makes construction require a GPU — confirm this is intended
        dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg').cuda()
        self.cls_token = dinov2_model.cls_token
        self.register_tokens = dinov2_model.register_tokens
        # Custom Patch Embed
        self.patch_embed = PatchEmbed(img_size=518, patch_size=14)
        # self.patch_embed = dinov2_model.patch_embed
        self.student_backbone = nn.Sequential(*dinov2_model.blocks)
        self.student_head = DINOProjectionHead(768, 2048, 256, proj_dim, batch_norm=False)
        self.teacher_backbone = copy.deepcopy(self.student_backbone)
        self.teacher_head = DINOProjectionHead(768, 2048, 256, proj_dim, batch_norm=False)
        # The teacher requires no gradients: it is only updated through the EMA
        for param in self.teacher_backbone.parameters():
            param.requires_grad = False
        for param in self.teacher_head.parameters():
            param.requires_grad = False
        self.dino_loss = DINOLossSingleViews(output_dim=proj_dim, warmup_teacher_temp_epochs=5)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options on ``parent_parser`` and return it."""
        # NOTE(review): the --proj_dim and --num_patches defaults differ from the
        # ``__init__`` defaults (2048 and 1369) — confirm which are intended.
        parser = parent_parser.add_argument_group("DINOViT")
        parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for training.")
        parser.add_argument("--max_epoch_number", type=int, default=500, help="Maximum number of epochs.")
        parser.add_argument("--proj_dim", type=int, default=1024, help="The embedding dimension from the DINO head.")
        parser.add_argument("--num_register_tokens", type=int, default=4, help="Number of register tokens.")
        parser.add_argument("--num_patches", type=int, default=256, help="The number of patch tokens for the ViT.")
        return parent_parser

    def forward(self, X):
        """Student pass: tokenize ``X``, run the student backbone, project the CLS token."""
        X = self.prepare_tokens_with_masks(X)
        h = self.student_backbone(X)
        # Extracting out the cls token
        h = h[:, 0]
        z = self.student_head(h)
        return z

    def forward_teacher(self, X):
        """Teacher pass: same tokenization as the student, teacher backbone and head."""
        X = self.prepare_tokens_with_masks(X)
        h = self.teacher_backbone(X)
        # Extracting out the cls token
        h = h[:, 0]
        z = self.teacher_head(h)
        return z

    def prepare_tokens_with_masks(self, x, masks=None):
        """
        Build the token sequence [CLS, register tokens, patch tokens] for an image batch.

        This function is adapted from DINOV2 from meta.
        https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
        """
        x = self.patch_embed(x)
        # ``pos_embed`` has a leading dimension of 1 and broadcasts over the batch
        x = x + self.pos_embed
        if masks is not None:
            # NOTE(review): ``mask_token`` is never created in ``__init__``, so passing
            # ``masks`` currently raises AttributeError. No caller in this class passes it.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        if self.register_tokens is not None:
            # Insert the register tokens between the CLS token and the patch tokens
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )
        return x

    @torch.no_grad()
    def update_momentum(self, model: nn.Module, model_ema: nn.Module, m: float):
        """
        Updates model_ema with Exponential Moving Average of model
        """
        for ema_param, param in zip(model_ema.parameters(), model.parameters()):
            ema_param.data = ema_param.data * m + param.data * (1.0 - m)

    @torch.no_grad()
    def _calculate_nuc_norm(self, embeddings):
        """Return the negated sum of singular values of ``embeddings`` (cast to float32)."""
        embeddings = embeddings.to(torch.float)
        _, S, _ = torch.linalg.svd(embeddings)
        return -1 * S.sum()

    def training_step(self, batch, _):
        """EMA-update the teacher, then return the DINO loss for one batch of (X, X_t)."""
        # Imported locally: the name is used here but not provided by the notebook's import cell
        from lightly.utils.scheduler import cosine_schedule

        X, X_t = batch
        # NOTE(review): the schedule length is hard-coded to 10 while
        # ``max_epoch_number`` defaults to 500 — confirm this is intended.
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        # EMA update of backbone and head
        self.update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        self.update_momentum(self.student_head, self.teacher_head, m=momentum)
        # Forward Pass through teacher
        teacher_out = self.forward_teacher(X_t)
        # Detaching teacher output for stop gradient
        teacher_out = teacher_out.detach()
        # Forward Pass for student
        student_out = self.forward(X)
        # Calculating the loss
        loss = self.dino_loss(teacher_out, student_out, epoch=self.current_epoch)
        # Calculating the negative nuclear norm to assess representational collapse
        neg_nuclear_norm = self._calculate_nuc_norm(torch.vstack([student_out, teacher_out]))
        loss_dict = {'train_loss': loss, 'train_nuc_norm': neg_nuclear_norm}
        self.log_dict(loss_dict, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, _):
        """Log the DINO loss and negative nuclear norm for one validation batch."""
        X, X_t = batch
        # Forward Pass through teacher
        teacher_out = self.forward_teacher(X_t)
        # Detaching teacher output for stop gradient
        teacher_out = teacher_out.detach()
        # Forward Pass for student
        student_out = self.forward(X)
        # validation=True keeps validation batches from shifting the teacher center
        loss = self.dino_loss(teacher_out, student_out, epoch=self.current_epoch, validation=True)
        # Calculating the negative nuclear norm to assess representational collapse
        neg_nuclear_norm = self._calculate_nuc_norm(torch.vstack([student_out, teacher_out]))
        loss_dict = {'val_loss': loss, 'val_nuc_norm': neg_nuclear_norm}
        self.log_dict(loss_dict, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Return AdamW plus a per-step linear warmup and a per-epoch cosine annealing scheduler."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        warmup_steps = 500

        def warmup_fn(step):
            # LambdaLR multiplies the base lr by the returned value, so this must be
            # a factor in (0, 1] and not an absolute learning rate.
            if step == 0:
                # Half a step's worth, so the first update is not made with lr == 0
                return 0.5 / warmup_steps
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0

        # Adding warmup scheduler
        warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)
        # After warmup, use a scheduler of your choice
        # NOTE(review): LambdaLR recomputes the lr from the base lr on every step, so it
        # likely overrides the per-epoch cosine annealing below — confirm the two compose.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.max_epoch_number)
        return [optimizer], [{"scheduler": warmup_scheduler, "interval": "step"}, {"scheduler": scheduler, "interval": "epoch"}]



In [5]:
class WSIPathSSLDataset(Dataset):
    """Dataset yielding (image, randomly transformed image) pairs for self-supervised training.

    Image paths are read from the first column of ``csv_file``. Both images of a
    pair are resized to 518x518 and divided by 255.
    """

    def __init__(self, csv_file):
        self.img_files = pd.read_csv(csv_file)
        self.length = len(self.img_files)
        # 518 x 518 matches the image size the model's patch embedding is built for
        self.resize_transform = transforms.Resize((518, 518), antialias=True)

    def _random_transform(self, image):
        """Apply one to three randomly chosen flips/rotations (with replacement) to ``image``."""
        import warnings

        transform_functions = [
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomRotation([-90, 90]),
            transforms.RandomVerticalFlip(p=1),
        ]

        n_transforms = random.choice([1, 2, 3])
        for _ in range(n_transforms):
            transform_function = random.choice(transform_functions)
            try:
                image = transform_function(image)
            except Exception as exc:
                # Best effort: keep the image as it is, but report the failure rather
                # than hiding it (and let KeyboardInterrupt/SystemExit propagate).
                warnings.warn(
                    f"Skipping {type(transform_function).__name__}: {exc!r}"
                )
        return image

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, transformed_image)`` as channel-first tensors."""
        image_path = self.img_files.iloc[idx, 0]  # Assuming the first column contains image file paths
        im = imread(image_path)

        im = torch.from_numpy(im)
        im = im.permute(2, 0, 1) # bring channels first

        t_im = self._random_transform(im)
        im = self.resize_transform(im)/255
        t_im = self.resize_transform(t_im)/255
        return im, t_im

    def __len__(self):
        """Number of rows in the CSV file."""
        return self.length



class DINOLossSingleViews(nn.Module):
    """DINO cross-entropy loss for one teacher view and one student view.

    The teacher output is centered with a running ``center`` buffer and
    sharpened with a temperature that follows a linear warmup over the first
    ``warmup_teacher_temp_epochs`` epochs.
    """

    def __init__(
        self,
        output_dim: int = 65536,
        warmup_teacher_temp: float = 0.04,
        teacher_temp: float = 0.07,
        warmup_teacher_temp_epochs: int = 10,
        student_temp: float = 0.1,
        center_momentum: float = 0.9,
    ):
        super().__init__()
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Running center of the teacher outputs; a buffer, so it is stored in the state dict
        self.register_buffer("center", torch.zeros(1, output_dim))
        # One temperature per warmup epoch, from the warmup value up to the final one
        self.teacher_temp_schedule = torch.linspace(
            start=warmup_teacher_temp,
            end=teacher_temp,
            steps=warmup_teacher_temp_epochs,
        )

    def forward(
        self,
        teacher_out: torch.Tensor,
        student_out: torch.Tensor,
        epoch: int,
        validation: bool = False,
    ) -> torch.Tensor:
        """Cross-entropy between the teacher's softmax and the student's log-softmax.

        Parameters
        ----------
            teacher_out:
                Features from the teacher model for one view of the batch;
                the first dimension is the batch.
            student_out:
                Features from the student model for one view of the batch;
                the first dimension is the batch.
            epoch:
                The current training epoch; selects the teacher temperature.
            validation:
                When True the running center is left untouched.

        Returns
        -------
            The cross-entropy summed over the batch and divided by the batch size.

        """
        # Teacher temperature: scheduled during warmup, constant afterwards
        in_warmup = epoch < self.warmup_teacher_temp_epochs
        current_temp = self.teacher_temp_schedule[epoch] if in_warmup else self.teacher_temp

        # Add a leading view axis of size one to both distributions
        teacher_probs = F.softmax((teacher_out - self.center) / current_temp, dim=-1).unsqueeze(0)
        student_log_probs = F.log_softmax(student_out / self.student_temp, dim=-1).unsqueeze(0)

        # t -> n_views_teacher, s -> n_views_student, b -> batch_size, d -> output_dim
        cross_entropy = -torch.einsum("tbd,sbd->ts", teacher_probs, student_log_probs).squeeze()
        batch_size = teacher_out.shape[0]
        cross_entropy = cross_entropy / batch_size

        if not validation:
            self.update_center(teacher_out)
        return cross_entropy

    @torch.no_grad()
    def update_center(self, teacher_out: torch.Tensor) -> None:
        """Exponential-moving-average update of the teacher center.

        Args:
            teacher_out:
                Output of the teacher model; averaged over its first dimension.

        """
        mean_out = teacher_out.mean(dim=0, keepdim=True)
        running_distributed = dist.is_available() and dist.is_initialized()
        if running_distributed:
            # Sum the per-process means, then divide by the number of processes
            dist.all_reduce(mean_out)
            mean_out = mean_out / dist.get_world_size()

        keep = self.center_momentum
        self.center = self.center * keep + mean_out * (1 - keep)

In [6]:
# Rebuild the architecture and restore trained weights for inference.
# proj_dim must match the head size stored in the checkpoint (load_state_dict is strict).
model = DINOViT(proj_dim=1024)
GPU_NUM = 0
CHECKPOINT = '/workspace/Projects/logs/amyb-ssl/5wnhj6vp/checkpoints/last.ckpt'
# Deserialize onto the CPU so loading does not depend on the device the checkpoint
# was saved from; the model is moved to the target GPU below.
checkpoint = torch.load(CHECKPOINT, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model.to(GPU_NUM)

# Check if multiple GPUs are available
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

print('Done!')

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Done!


In [7]:
# CSV files listing the image paths of each split
train_img_path = '/workspace/Projects/semi-supervised-learning/train.csv'
val_img_path = '/workspace/Projects/semi-supervised-learning/val.csv'

train_dataset, val_dataset = (
    WSIPathSSLDataset(csv_path) for csv_path in (train_img_path, val_img_path)
)

# shuffle=False: batches come out in dataset index order for both splits
train_dl, val_dl = (
    DataLoader(dataset, shuffle=False, batch_size=16, num_workers=18, persistent_workers=True)
    for dataset in (train_dataset, val_dataset)
)
    

In [8]:
split = []
embeddings = []

# Inference only: without no_grad every forward pass would build an autograd graph
with torch.no_grad():
    # Train batches first, then validation, so row order is [train..., validation...]
    for split_name, loader in (('Train', train_dl), ('Validation', val_dl)):
        for X, _ in tqdm(loader):
            out = model(X.to(GPU_NUM))
            embeddings.append(out.cpu().numpy())
            split.append([split_name] * X.shape[0])

split = np.hstack(split)
# embeddings = np.load("embeddings.npy")
embeddings = np.vstack(embeddings)
# Normalised copy of the embeddings (F.normalize defaults), used for the UMAP projection
embeddings_norm = F.normalize(torch.Tensor(embeddings)).numpy()

100%|██████████| 3094/3094 [12:25<00:00,  4.15it/s]
100%|██████████| 344/344 [01:25<00:00,  4.02it/s]


In [9]:
# Persist the raw (un-normalised) embeddings so they can be reloaded without re-running inference
np.save("embeddings-518.npy", embeddings)


In [10]:
import umap
# print(umap.__version__)


In [11]:
import umap.umap_ as umap


# Project the normalised embeddings to 2-D using cosine distance.
# A fixed random_state makes the layout reproducible across runs.
u_obj = umap.UMAP(metric='cosine', random_state=42)
u = u_obj.fit_transform(embeddings_norm)

In [12]:
# Number of rows in the UMAP projection
len(u)

54996

In [None]:

from matplotlib.offsetbox import AnnotationBbox, OffsetImage

STRIDE = 40          # plot / sample every 40th point of the projection
N_THUMBNAILS = 1000  # upper bound on the number of image thumbnails drawn

# Seeded generator so the same thumbnails are picked on every run
rng = np.random.default_rng(42)
candidate_inds = np.arange(len(u))[::STRIDE]
# Never ask for more samples than there are candidates (replace=False would raise)
rand_inds = rng.choice(candidate_inds, min(N_THUMBNAILS, len(candidate_inds)), replace=False)

coordinates = u[rand_inds]

n_train = len(train_dataset)


def _image_at(index):
    """Return the untransformed image for a row of ``u`` as an (H, W, C) array.

    Rows below ``len(train_dataset)`` index the training set; the remaining rows
    index the validation set.
    """
    if index < n_train:
        dataset, local_index = train_dataset, index
    else:
        dataset, local_index = val_dataset, index - n_train
    return dataset[int(local_index)][0].permute(1, 2, 0).numpy()


images = [_image_at(i) for i in rand_inds]

fig, ax = plt.subplots(figsize=(20, 20))
sns.scatterplot(x=u[::STRIDE, 0], y=u[::STRIDE, 1], s=1, alpha=0.05, ax=ax)
for coord, img in zip(coordinates, images):
    imgbox = OffsetImage(img, zoom=0.08, resample=True, clip_path=None)
    ab = AnnotationBbox(imgbox, coord, frameon=False, pad=0, xycoords='data', boxcoords="data")
    ax.add_artist(ab)
fig.savefig('scatter-518.jpg')
plt.show()