In [1]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # put -1 to not use any

import sys
sys.path.append('/home/lumargot/SurgicalSAM/segment-anything')
sys.path.append('/home/lumargot/hysterectomy-coach/src/py')

sys.path.append('/home/lumargot/SurgicalSAM/surgicalSAM')

In [2]:
! pip install pytorch-metric-learning



In [3]:
import pytorch_lightning as pl
import torch
from torch.nn import CrossEntropyLoss
from loss import DiceLoss
from torch.utils.data import Dataset, DataLoader

from pytorch_metric_learning import losses as pml_losses
from model_forward import model_forward_function
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# ! pip uninstall -y segment_anything

In [5]:
import pandas as pd
import albumentations as A
import torch.nn.functional as F


  check_for_updates()


In [6]:
import SimpleITK as sitk
import numpy as np
import random
from model import Learnable_Prototypes, Prototype_Prompt_Encoder


In [7]:
class HystDatasetSAMPrompt(Dataset):
    def __init__(self, df, mount_point="./", transform=None, img_column="img_path", seg_column='seg_path',
                 class_column='class_column', predictor=None):
        self.df = df
        self.mount_point = mount_point
        self.transform = transform
        self.img_column = img_column
        self.seg_column = seg_column
        self.class_column = class_column
        self.predictor = predictor

        self.df_subject = self.df[self.img_column].drop_duplicates().reset_index()

    def __len__(self):
        return len(self.df_subject)

    def __getitem__(self, idx):
        subject = self.df_subject.iloc[idx][self.img_column]
        img_path = os.path.join(self.mount_point, subject)
        df_patches = self.df.loc[self.df[self.img_column] == subject]

        # Load image
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)).astype(np.float32).squeeze() / 255.0
        shape = img.shape[:2]

        masks, bboxes, labels = [], [], []

        for _, row in df_patches.iterrows():
            label = row[self.class_column]
            seg_path = os.path.join(self.mount_point, row[self.seg_column])
            seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_path)).astype(np.float32).squeeze()

            x, y, w, h = row['x'] * shape[1], row['y'] * shape[0], row['w'] * shape[1], row['h'] * shape[0]
            bbox = [np.clip(x, 0, shape[1] - 5),
                    np.clip(y, 0, shape[0] - 5),
                    np.clip(x + w + 1, 5, shape[1]),
                    np.clip(y + h + 1, 5, shape[0])]

            masks.append(seg)
            bboxes.append(bbox)
            labels.append(label)

        if len(masks) == 0:
            return self.__getitem__(random.randint(0, len(self) - 1))

        d_aug = self.transform(image=img, bboxes=bboxes, category_ids=labels, mask=np.stack(masks))
        img = d_aug['image']
        masks = d_aug['mask']
        bboxes = d_aug['bboxes']

        img_tensor = torch.tensor(img).permute(2, 0, 1)  # (1, 3, H, W)
        mask_tensor = torch.tensor(masks).permute(2, 0, 1)

        # obtain SAM feature of the augmented frame
        self.predictor.set_image(img_tensor)
        feat = self.predictor.features.squeeze() #


        class_embeddings = []

        for mask, label in zip(masks, labels):
            mask_tensor = torch.tensor(mask, dtype=torch.float32)

            mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            mask_resized = F.interpolate(mask_tensor, size=(64,64), mode='bilinear', align_corners=False)
            mask_resized = mask_resized.squeeze()  # shape [H, W]

            if torch.sum(mask_resized > 0.1) == 0:
                continue

            foreground_mask = mask_resized > 0.1  # Threshold to binarize
            class_feat_vectors = feat[:, foreground_mask]  # shape [C, N_pixels]
            cls_emb = class_feat_vectors.mean(dim=1)  # shape [C]


            class_embeddings.append(cls_emb)

        class_embeddings = torch.stack(class_embeddings).unsqueeze(0).transpose(1,2)
        return feat.permute(1, 2, 0).unsqueeze(0), torch.tensor(labels).unsqueeze(0), masks, class_embeddings, img_tensor.shape


In [8]:
class HystDataModuleSAMPrompt(pl.LightningDataModule):
    def __init__(self, df_train, df_val, df_test, mount_point="./", batch_size=16, num_workers=4,
                 img_column="img_path", seg_column="seg_path", class_column="class_column",
                 train_transform=None, valid_transform=None, test_transform=None,sam_checkpoint='model.ckpt',
                 drop_last=False):
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.df_test = df_test
        self.mount_point = mount_point
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_column = img_column
        self.seg_column = seg_column
        self.class_column = class_column
        self.train_transform = train_transform
        self.valid_transform = valid_transform
        self.test_transform = test_transform
        self.drop_last = drop_last

        
        sam = sam_model_registry[f"vit_h"](checkpoint=sam_checkpoint)
        sam.cuda()
        self.predictor = SamPredictor(sam)


    def setup(self, stage=None):
        self.train_ds = HystDatasetSAMPrompt(self.df_train, mount_point=self.mount_point,
                                             img_column=self.img_column, seg_column=self.seg_column,
                                             class_column=self.class_column, transform=self.train_transform,predictor=self.predictor,
                                             )

        self.val_ds = HystDatasetSAMPrompt(self.df_val, mount_point=self.mount_point,
                                           img_column=self.img_column, seg_column=self.seg_column,
                                           class_column=self.class_column, transform=self.valid_transform,predictor=self.predictor,
                                           )

        self.test_ds = HystDatasetSAMPrompt(self.df_test, mount_point=self.mount_point,
                                            img_column=self.img_column, seg_column=self.seg_column,
                                            class_column=self.class_column, transform=self.test_transform,predictor=self.predictor,
                                            )

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=self.drop_last)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, drop_last=self.drop_last)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers, drop_last=self.drop_last)

In [9]:
class BBXImageTrainTransform():
    def __init__(self):

        self.transform = A.Compose(
            [
                A.LongestMaxSize(max_size_hw=(480, None)),
                A.CenterCrop(height=480, width=836, pad_if_needed=True),
                A.HorizontalFlip(),
                A.GaussNoise(),
                A.OneOf(
                    [
                        A.MotionBlur(p=.2),
                        A.MedianBlur(blur_limit=3, p=0.1),
                        A.Blur(blur_limit=3, p=0.1),
                    ], p=0.2),
                A.OneOf(
                    [
                        A.OpticalDistortion(p=0.3),
                        A.GridDistortion(p=.1),
                        ], p=0.2),
                A.OneOf(
                    [
                        A.CLAHE(clip_limit=2),
                        A.RandomBrightnessContrast(),
                    ], p=0.3),
                A.HueSaturationValue(p=0.3),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),

            ], 
            bbox_params=A.BboxParams(format='pascal_voc', min_area=32, min_visibility=0.1, label_fields=['category_ids']),
            additional_targets={'mask': 'masks'}
        )

    def __call__(self, image, bboxes, category_ids, mask):
        return self.transform(image=image, bboxes=bboxes, category_ids=category_ids, mask=mask)


In [10]:
df_test = pd.read_csv('/CMF/data/lumargot/hysterectomy/mnt/surgery_tracking/csv/dataset_test.csv')
df_train = pd.read_csv('/CMF/data/lumargot/hysterectomy/mnt/surgery_tracking/csv/dataset_train_train.csv')
df_val = pd.read_csv('/CMF/data/lumargot/hysterectomy/mnt/surgery_tracking/csv/dataset_train_test.csv')
sam_checkpoint = "/home/lumargot/SurgicalSAM/ckp/sam/sam_vit_h_4b8939.pth"

df_labels = pd.concat([df_train, df_val, df_test])

img_column = 'img_path'
seg_column = 'seg_path'
class_column = 'simplified_class'
label_column = 'simplified_label'
mount_point = '/CMF/data/lumargot/hysterectomy/mnt/surgery_tracking/'


ttdata = HystDataModuleSAMPrompt( df_test, df_test, df_test, batch_size=1, num_workers=1, 
                            img_column=img_column,seg_column=seg_column, class_column=class_column, 
                            mount_point=mount_point,train_transform=BBXImageTrainTransform(),
                            valid_transform=BBXImageTrainTransform(), 
                            test_transform=BBXImageTrainTransform(), sam_checkpoint=sam_checkpoint)

ttdata.setup()

test_dl = ttdata.test_dataloader()
test_ds = ttdata.test_ds

In [121]:
import torch 
import torch.nn as nn 
from einops import rearrange
import torch.nn.functional as F

import pdb

class Learnable_Prototypes(nn.Module):
    def __init__(self, num_classes=7 , feat_dim=256):
        super(Learnable_Prototypes, self).__init__()
        self.class_embeddings = nn.Embedding(num_classes, feat_dim)
        
    def forward(self):
        return self.class_embeddings.weight

class Prototype_Prompt_Encoder(nn.Module):
    def __init__(self, feat_dim=256, 
                        hidden_dim_dense=128, 
                        hidden_dim_sparse=128, 
                        size=64, 
                        num_tokens=8,
                        num_classes=1):
                
        super(Prototype_Prompt_Encoder, self).__init__()
        self.dense_fc_1 = nn.Conv2d(feat_dim, hidden_dim_dense, 1)
        self.dense_fc_2 = nn.Conv2d(hidden_dim_dense, feat_dim, 1)
        
        self.relu = nn.ReLU()

        self.sparse_fc_1 = nn.Conv1d(size*size, hidden_dim_sparse, 1)
        self.sparse_fc_2 = nn.Conv1d(hidden_dim_sparse, num_tokens, 1)
        self.num_classes = num_classes
        
        
        pn_cls_embeddings = [nn.Embedding(num_tokens, feat_dim) for _ in range(2)] # one for positive and one for negative 

            
        self.pn_cls_embeddings = nn.ModuleList(pn_cls_embeddings)
                
    def forward(self, feat, prototypes, cls_ids):
  
        cls_prompts = prototypes
        # cls_prompts = torch.stack([cls_prompts for _ in range(feat.size(0))], dim=0)
        
        # feat = torch.stack([feat for _ in range(cls_prompts.size(1))], dim=1)

        print(feat.shape, cls_prompts.shape, cls_ids.shape)
        # compute similarity matrix 
        sim = torch.matmul(feat, cls_prompts)
        feat = feat.unsqueeze(3)              # [1, 1, 4096, 1, 256]
        sim = sim.unsqueeze(-1)               # [1, 1, 4096, 3, 1]

        # compute class-activated feature
        feat =  feat + feat*sim
        feat_sparse = feat.clone()        

        feat_selected = feat[:, 0, :, cls_ids, :] 

        feat = rearrange(feat_selected,'b (h w) c -> b c h w', h=64, w=64)
        dense_embeddings = self.dense_fc_2(self.relu(self.dense_fc_1(feat)))
        
        # compute sparse embeddings
        feat_sparse = feat_sparse.squeeze(1)
        feat_sparse = rearrange(feat_sparse,'b hw num_cls c -> (b num_cls) hw c')
        sparse_embeddings = self.sparse_fc_2(self.relu(self.sparse_fc_1(feat_sparse)))
        sparse_embeddings = rearrange(sparse_embeddings,'(b num_cls) n c -> b num_cls n c', num_cls=self.num_classes)
        
        pos_embed = self.pn_cls_embeddings[1].weight.unsqueeze(0).unsqueeze(0)
        neg_embed = self.pn_cls_embeddings[0].weight.unsqueeze(0).unsqueeze(0)
        

        sparse_embeddings = sparse_embeddings + pos_embed.detach() + neg_embed.detach()
            
        sparse_embeddings = rearrange(sparse_embeddings,'b num_cls n c -> b (num_cls n) c')
        
        return dense_embeddings, sparse_embeddings

class Learnable_Prototypes(nn.Module):
    def __init__(self, num_classes=7 , feat_dim=256):
        super(Learnable_Prototypes, self).__init__()
        self.class_embeddings = nn.Embedding(num_classes, feat_dim)
        
    def forward(self):
        return self.class_embeddings.weight

In [122]:
class SAMTrainer(pl.LightningModule):
    def __init__(self, sam_checkpoint='model.ckpt'):
        super().__init__()
        self.save_hyperparameters()
        self.weights=None
        
        model_type = "vit_h_no_image_encoder"
        # model_type = "vit_h"

        self.sam_prompt_encoder, self.sam_decoder = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam_prompt_encoder.cuda().eval().requires_grad_(False)
        self.sam_decoder.cuda()

        self.learnable_prototypes_model = Learnable_Prototypes(num_classes = 3, 
                                                               feat_dim = 256).cuda()
        
        self.prototype_prompt_encoder =  Prototype_Prompt_Encoder(feat_dim = 256, 
                                                            hidden_dim_dense = 128, 
                                                            hidden_dim_sparse = 128, 
                                                            size = 64,
                                                            num_classes=3,
                                                            num_tokens = 2).cuda()


        self.seg_loss = DiceLoss()
        self.class_loss = CrossEntropyLoss(weight=self.weights)
        self.contrastive_loss = pml_losses.NTXentLoss(temperature=0.07)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sam_feats, cls_ids, class_embeddings, img_size):
        prototypes = self.learnable_prototypes_model()
        
        if sam_feats.shape[0] == 1:
            prototypes = prototypes.unsqueeze(0).permute(0,2,1)        

        return self.model_forward_function(
            self.prototype_prompt_encoder,
            self.sam_prompt_encoder,
            self.sam_decoder,
            sam_feats,
            prototypes,
            cls_ids,
        )

    # forward process of the model
    def model_forward_function(self, prototype_prompt_encoder, 
                                sam_prompt_encoder, 
                                sam_decoder, 
                                sam_feats, 
                                prototypes, 
                                cls_ids): 

        sam_feats = rearrange(sam_feats, 'b h w c -> b (h w) c')
        
        dense_embeddings, sparse_embeddings = prototype_prompt_encoder(sam_feats, prototypes, cls_ids)

        pred = []
        pred_quality = []

        sam_feats = rearrange(sam_feats,'b (h w) c -> b c h w', h=64, w=64)
    
        for dense_embedding, sparse_embedding, features_per_image in zip(dense_embeddings.unsqueeze(1), sparse_embeddings.unsqueeze(1), sam_feats):    
            
            low_res_masks_per_image, mask_quality_per_image = sam_decoder(
                    image_embeddings=features_per_image.unsqueeze(0),
                    image_pe=sam_prompt_encoder.get_dense_pe(), 
                    sparse_prompt_embeddings=sparse_embedding,
                    dense_prompt_embeddings=dense_embedding, 
                    multimask_output=False,
                )
            
            pred.append(low_res_masks_per_image)
            pred_quality.append(mask_quality_per_image.detach().cpu())
            
        pred = torch.cat(pred,dim=0).squeeze(1)
        pred_quality = torch.cat(pred_quality,dim=0).squeeze(1)
        
        return pred, pred_quality


    def postprocess_masks(self, masks, input_size, original_size):
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
            masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
            input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
            original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (1024, 1024),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def training_step(self, batch, batch_idx):
        sam_feats, _, cls_ids, masks, class_embeddings, img_size = batch
        sam_feats, cls_ids, masks, class_embeddings = sam_feats.cuda(), cls_ids.cuda(), masks.cuda(), class_embeddings.cuda()

        preds, _, cls_probs = self.forward(sam_feats, cls_ids, class_embeddings, img_size[1:])

        cls_loss = self.class_loss(cls_probs, cls_ids)
        contrastive_loss = self.contrastive_loss(
            self.learnable_prototypes_model(),
            torch.arange(1, self.hparams.num_classes + 1).cuda(),
            ref_emb=class_embeddings,
            ref_labels=cls_ids
        )
        seg_loss = self.seg_loss(preds, masks / 255.)
        total_loss = seg_loss + contrastive_loss + cls_loss

        self.log("train/seg_loss", seg_loss)
        self.log("train/class_loss", cls_loss)
        self.log("train/contrastive_loss", contrastive_loss)
        self.log("train/loss", total_loss)

        return total_loss

    def validation_step(self, batch, batch_idx):
        sam_feats, mask_names, cls_ids, masks, class_embeddings = batch
        sam_feats, cls_ids = sam_feats.cuda(), cls_ids.cuda()
        preds, preds_quality, cls_probs = self.forward(sam_feats, cls_ids, class_embeddings, img_size[1:])

        cls_loss = self.class_loss(cls_probs, cls_ids)
        contrastive_loss = self.contrastive_loss(
            self.learnable_prototypes_model(),
            torch.arange(1, self.hparams.num_classes + 1).cuda(),
            ref_emb=class_embeddings,
            ref_labels=cls_ids
        )
        seg_loss = self.seg_loss(preds, masks.cuda() / 255.)
        total_loss = seg_loss + contrastive_loss + cls_loss

        self.log("val/seg_loss", seg_loss)
        self.log("val/class_loss", cls_loss)
        self.log("val/contrastive_loss", contrastive_loss)
        self.log("val/loss", total_loss)

        return total_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam([
            {'params': self.learnable_prototypes_model.parameters()},
            {'params': self.prototype_prompt_encoder.parameters()},
            {'params': self.sam_decoder.parameters()}
        ], lr=self.hparams.lr, weight_decay=0.0001)
        return optimizer


In [123]:
model = SAMTrainer(sam_checkpoint=sam_checkpoint)

In [124]:
feat, labels, masks, class_embedding, img_size = test_ds[0]
feat.shape, labels.shape, masks.shape, class_embedding.shape

(torch.Size([1, 64, 64, 256]),
 torch.Size([1, 3]),
 (3, 480, 836),
 torch.Size([1, 256, 3]))

In [125]:
class_embedding = class_embedding.permute(2,1,0)
labels = labels.permute(1,0)
feats = torch.cat([feat, feat, feat])

In [None]:
out = model(feats, labels, class_embedding, img_size)

torch.Size([3, 4096, 256]) torch.Size([3, 256]) torch.Size([3, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (12288x256 and 3x256)

: 

In [91]:
out[0].min()

tensor(-157.5040, device='cuda:0', grad_fn=<MinBackward1>)

In [92]:
out[1], masks.shape

(tensor([0.3917]), (3, 480, 836))

In [None]:
# compute loss 
contrastive_loss = model.contrastive_loss_model(prototypes, torch.tensor([i for i in range(1, prototypes.size()[0] + 1)]).cuda(), ref_emb = class_embeddings, ref_labels = cls_ids)
seg_loss = model.seg_loss_model(preds, masks/255)

loss = seg_loss + contrastive_loss

