### Image Size Changes:

1. Update args dictionary in notebook
2. Update GlaucomaDataset to resize to appropriate size
3. Update feat_sizes
4. Update the sam_prompt_encoder (via notebook cell 3)
5. Update _bb_feat_sizes in sam2_image_predictor.py

In [1]:
# General packages
import torch
import numpy as np
import pandas as pd
import sys
import os
from PIL import Image
import cv2
import matplotlib.pyplot as plt

# Medical-SAM2 imports
from sam2_train.sam2_image_predictor import SAM2ImagePredictor
from sam2_train.build_sam import build_sam2
from sam2_train.modeling.sam2_utils import MLP

# Data processing
sys.path.append("../unet")
from GlaucomaDataset import GlaucomaDatasetBoundingBoxes
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Training imports
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

KeyboardInterrupt: 

In [2]:
# Load in the base SAM2 small model
med_sam_2 = build_sam2("sam2_hiera_s", "./checkpoints/sam2_hiera_small.pt", device="cuda")

In [3]:
med_sam_2.sam_mask_decoder.num_multimask_outputs = 2
med_sam_2.sam_mask_decoder.num_mask_tokens = 1 + med_sam_2.sam_mask_decoder.num_multimask_outputs
med_sam_2.sam_mask_decoder.mask_tokens = nn.Embedding(med_sam_2.sam_mask_decoder.num_mask_tokens, med_sam_2.sam_mask_decoder.transformer_dim)
med_sam_2.sam_mask_decoder.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(med_sam_2.sam_mask_decoder.transformer_dim, med_sam_2.sam_mask_decoder.transformer_dim, med_sam_2.sam_mask_decoder.transformer_dim // 8, 3)
                for i in range(med_sam_2.sam_mask_decoder.num_mask_tokens)
            ]
        )
med_sam_2.sam_mask_decoder.iou_prediction_head = MLP(
            med_sam_2.sam_mask_decoder.transformer_dim,
            256,
            med_sam_2.sam_mask_decoder.num_mask_tokens,
            3,
            sigmoid_output=True,
        )

In [4]:
# Model architecture updates for working with different image sizes

# 256x256 input images 
# med_sam_2.sam_prompt_encoder.image_embedding_size = (16, 16)
# med_sam_2.sam_prompt_encoder.input_image_size = (256, 256)
# med_sam_2.sam_prompt_encoder.mask_input_size = (64, 64)
# med_sam_2.image_size = 256
# med_sam_2.sam_mask_decoder.num_mask_tokens = 2


# 512x512 input images
# med_sam_2.sam_prompt_encoder.image_embedding_size = (32, 32)
# med_sam_2.sam_prompt_encoder.input_image_size = (512, 512)
# med_sam_2.sam_prompt_encoder.mask_input_size = (128, 128)
# med_sam_2.image_size = 512
# med_sam_2.sam_mask_decoder.num_mask_tokens = 2

In [4]:
def dice_score(pred_mask: torch.Tensor, gt_mask: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
    """
    Calculate the Dice score for a single pair of predicted and ground truth masks.

    Args:
        pred_mask (torch.Tensor): Predicted mask of shape [H, W] with values between 0 and 1.
        gt_mask (torch.Tensor): Ground truth mask of shape [H, W] with binary values {0, 1}.
        smooth (float): A small constant added to avoid division by zero.

    Returns:
        torch.Tensor: The Dice score (scalar tensor).
    """
    # Flatten the masks to 1D vectors
    pred_flat = pred_mask.contiguous().view(-1)
    gt_flat   = gt_mask.contiguous().view(-1)
    
    # Compute the intersection and the sums of masks
    intersection = (pred_flat * gt_flat).sum()
    sum_masks = pred_flat.sum() + gt_flat.sum()
    
    # Compute Dice score
    dice = (2.0 * intersection + smooth) / (sum_masks + smooth)
    return dice

In [5]:
# Functions for displaying masks and prompts on final predicted masks

np.random.seed(42)

def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    num_masks = len(masks)
    fig, axs = plt.subplots(1, num_masks, figsize=(10 * num_masks, 10))  # Horizontal layout
    if num_masks == 1:
        axs = [axs]  # Ensure it's iterable
    for i, (mask, score) in enumerate(zip(masks, scores)):
        ax = axs[i]
        ax.imshow(image)
        show_mask(mask, ax, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, ax)
        if box_coords is not None:
            show_box(box_coords, ax)
        if num_masks > 1:
            ax.set_title("Optic Disc Mask" if i == 0 else "Optic Cup Mask", fontsize=18)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [7]:
# ChatGPT suggested changes to model architecture:
# med_sam_2.sam_mask_decoder.mask_tokens = nn.Embedding(2, 256)
# med_sam_2.sam_mask_decoder.output_hypernetworks_mlps = nn.ModuleList([MLP()])

In [None]:
# Training function for Medical-SAM2

GPUdevice = torch.device('cuda')
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
mask_type = torch.float32

torch.backends.cudnn.benchmark = True

def train_sam(args, net: nn.Module, optimizer, train_loader, epoch):
    
    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    
    # train mode
    net.train()
    optimizer.zero_grad()

    # init
    epoch_loss = 0
    epoch_dice = 0
    memory_bank_list = []
    lossfunc = criterion_G

    # Update to match the input image sizes
    feat_sizes = [(256, 256), (128, 128), (64, 64)] # 1024x1024 images
    # feat_sizes = [(64, 64), (32, 32), (16, 16)] # 256x256 images
    # feat_sizes = [(128,128), (64,64), (32,32)] # 512x512 images


    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for ind, pack in enumerate(train_loader):
            # print(f"Running for batch {ind}.")
            
            to_cat_memory = []
            to_cat_memory_pos = []
            to_cat_image_embed = []

            # input image and gt masks
            imgs = pack['image'].to(dtype = mask_type, device = GPUdevice)
            masks = pack['mask'].to(dtype = mask_type, device = GPUdevice)
            name = pack['image_filename']

            # click prompt: unsqueeze to indicate only one click, add more click across this dimension
            if 'bbox' in pack:
                # Assuming bbox shape is [batch, 4]; adjust unsqueeze if more than one box is expected.
                boxes = pack['bbox'].unsqueeze(1).to(device=GPUdevice, dtype=torch.float)
            else:
                boxes = None

            '''Train image encoder'''                    
            backbone_out = net.forward_image(imgs)
            _, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
            
            # dimension hint for your future use (in reference to 1024x1024 images)
            # vision_feats: list: length = 3
            # vision_feats[0]: torch.Size([65536, batch, 32])
            # vision_feats[1]: torch.Size([16384, batch, 64])
            # vision_feats[2]: torch.Size([4096, batch, 256])
            # vision_pos_embeds[0]: torch.Size([65536, batch, 256])
            # vision_pos_embeds[1]: torch.Size([16384, batch, 256])
            # vision_pos_embeds[2]: torch.Size([4096, batch, 256])
            
            

            '''Train memory attention to condition on meomory bank'''         
            B = vision_feats[-1].size(1)  # batch size 
            
            if len(memory_bank_list) == 0:
                vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
                vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
                
            else:
                for element in memory_bank_list:
                    to_cat_memory.append((element[0]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_features
                    to_cat_memory_pos.append((element[1]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_pos_enc
                    to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed

                memory_stack_ori = torch.stack(to_cat_memory, dim=0)
                memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
                image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)
 
                vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64) 
                vision_feats_temp = vision_feats_temp.reshape(B, -1)

                image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
                vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
                similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()
                
                similarity_scores = F.softmax(similarity_scores, dim=1) 
                sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1)  # Shape [batch_size, 16]

                memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))

                memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))


                vision_feats[-1] = net.memory_attention(
                    curr=[vision_feats[-1]],
                    curr_pos=[vision_pos_embeds[-1]],
                    memory=memory,
                    memory_pos=memory_pos,
                    num_obj_ptr_tokens=0
                    )


            feats = [feat.permute(1, 2, 0).reshape(B, -1, *feat_size) 
                     for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]
            
            image_embed = feats[-1]
            high_res_feats = feats[:-1]

            # for a, feat in enumerate(feats):
            #     print(f'feat[{a}]: {feat.shape}')
            
            # feats[0]: torch.Size([batch, 32, 256, 256]) #high_res_feats part1
            # feats[1]: torch.Size([batch, 64, 128, 128]) #high_res_feats part2
            # feats[2]: torch.Size([batch, 256, 64, 64]) #image_embed

            '''prompt encoder'''         
            with torch.no_grad():
                if (ind%5) == 0:
                    #points=(coords_torch, labels_torch) # input shape: ((batch, n, 2), (batch, n))
                    flag = True
                else:
                    points=None
                    flag = False

                se, de = net.sam_prompt_encoder(
                    points=None,   # No point prompts used
                    boxes=boxes,   # Use the bounding boxes from the data
                    masks=None,
                    batch_size=B,
                )
            # dimension hint for your future use
            # se: torch.Size([batch, n+1, 256])
            # de: torch.Size([batch, 256, 64, 64])
            print('se:', se.shape)
            print('de:', de.shape)
            print('image_embed:', image_embed.shape)

            '''train mask decoder'''       
            low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
                    image_embeddings=image_embed,
                    image_pe=net.sam_prompt_encoder.get_dense_pe(), 
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de, 
                    multimask_output=True, # args.multimask_output if you want multiple masks
                    repeat_image=False,  # the image is already batched
                    high_res_features = high_res_feats
                )
            # dimension hint for your future use
            # low_res_multimasks: torch.Size([batch, multimask_output, 256, 256])
            # iou_predictions.shape:torch.Size([batch, multimask_output])
            # sam_output_tokens.shape:torch.Size([batch, multimask_output, 256])
            # object_score_logits.shape:torch.Size([batch, 1])
            
            
            # resize prediction
            pred = F.interpolate(low_res_multimasks,size=(args['out_size'],args['out_size']))
            high_res_multimasks = F.interpolate(low_res_multimasks, size=(args['image_size'], args['image_size']),
                                                mode="bilinear", align_corners=False)
            
            '''memory encoder'''       
            # new caluculated memory features
            maskmem_features, maskmem_pos_enc = net._encode_new_memory(
                current_vision_feats=vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_multimasks,
                is_mask_from_pts=flag)  
            # dimension hint for your future use
            # maskmem_features: torch.Size([batch, 64, 64, 64])
            # maskmem_pos_enc: [torch.Size([batch, 64, 64, 64])]
                
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
            maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
            maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)


            # add single maskmem_features, maskmem_pos_enc, iou
            if len(memory_bank_list) < args['memory_bank_size']:
                for batch in range(maskmem_features.size(0)):
                    memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
                                             (maskmem_pos_enc[batch].unsqueeze(0)).detach(),
                                             iou_predictions[batch, 0],
                                             image_embed[batch].reshape(-1).detach()])
            
            else:
                for batch in range(maskmem_features.size(0)):
                    
                    # current simlarity matrix in existing memory bank
                    memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
                    memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)

                    # normalise
                    memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
                    current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
                                                         memory_bank_maskmem_features_norm.t())

                    # replace diagonal (diagnoal always simiarity = 1)
                    current_similarity_matrix_no_diag = current_similarity_matrix.clone()
                    diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
                    current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')

                    # first find the minimum similarity from memory feature and the maximum similarity from memory bank
                    single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
                    similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
                    min_similarity_index = torch.argmin(similarity_scores) 
                    max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])

                    # replace with less similar object
                    if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
                        # soft iou, not stricly greater than current iou
                        if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
                            memory_bank_list.pop(max_similarity_index) 
                            memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
                                                     (maskmem_pos_enc[batch].unsqueeze(0)).detach(),
                                                     iou_predictions[batch, 0],
                                                     image_embed[batch].reshape(-1).detach()])

            # backpropagation
            loss_disc = lossfunc(pred[:, 0, :, :], masks[:, 0, :, :])
            loss_cup  = lossfunc(pred[:, 1, :, :], masks[:, 1, :, :])
            dice_disc = dice_score(torch.sigmoid(pred[:, 0, :, :]), masks[:, 0, :, :])
            dice_cup  = dice_score(torch.sigmoid(pred[:, 1, :, :]), masks[:, 1, :, :])

            loss = .3 * loss_disc + .7 * loss_cup
            dice = (dice_disc + dice_cup) / 2

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()
            epoch_dice += dice.item()

            loss.backward()
            optimizer.step()
            
            optimizer.zero_grad()

            # Cleanup large intermediate tensors that are no longer needed.
            del backbone_out, vision_feats, vision_pos_embeds, feats, image_embed, high_res_feats, se, de, low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits, pred

            # Optionally free cached memory.
            if ind % 10 == 0:  # For instance, every 10 iterations
                torch.cuda.empty_cache()

            pbar.update()

    return epoch_loss/len(train_loader), epoch_dice/len(train_loader)


In [7]:
# Validation function for Medical-SAM2
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):

    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


    # eval mode
    net.eval()

    n_val = len(val_loader) 
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(0))

    # init
    lossfunc = criterion_G
    memory_bank_list = []
    feat_sizes = [(256, 256), (128, 128), (64, 64)] # 1024x1024
    # feat_sizes = [(64, 64), (32, 32), (16, 16)] # 256x256
    # feat_sizes = [(128,128), (64,64), (32,32)] # 512x512
    total_loss = 0
    total_eiou = 0
    total_dice = 0


    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            to_cat_memory = []
            to_cat_memory_pos = []
            to_cat_image_embed = []

            # input image and gt masks
            imgs = pack['image'].to(dtype = mask_type, device = GPUdevice)
            masks = pack['mask'].to(dtype = mask_type, device = GPUdevice)
            name = pack['image_filename']

            # click prompt: unsqueeze to indicate only one click, add more click across this dimension
            if 'bbox' in pack:
                # Assuming bbox shape is [batch, 4]; adjust unsqueeze if more than one box is expected.
                bbox_temp = pack['bbox'].to(device=GPUdevice, dtype=torch.float)
                boxes = bbox_temp.unsqueeze(1)  # Ensure proper dimensions if needed
            else:
                boxes = None



            '''test'''
            with torch.no_grad():

                """ image encoder """
                backbone_out = net.forward_image(imgs)
                _, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
                B = vision_feats[-1].size(1) 

                """ memory condition """
                if len(memory_bank_list) == 0:
                    vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
                    vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")

                else:
                    for element in memory_bank_list:
                        maskmem_features = element[0]
                        maskmem_pos_enc = element[1]
                        to_cat_memory.append(maskmem_features.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
                        to_cat_memory_pos.append(maskmem_pos_enc.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
                        to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed
                        
                    memory_stack_ori = torch.stack(to_cat_memory, dim=0)
                    memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
                    image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)

                    vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64) 
                    vision_feats_temp = vision_feats_temp.reshape(B, -1)

                    image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
                    vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
                    similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()

                    similarity_scores = F.softmax(similarity_scores, dim=1) 
                    sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1)  # Shape [batch_size, 16]

                    memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                    memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))

                    memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                    memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))



                    vision_feats[-1] = net.memory_attention(
                        curr=[vision_feats[-1]],
                        curr_pos=[vision_pos_embeds[-1]],
                        memory=memory,
                        memory_pos=memory_pos,
                        num_obj_ptr_tokens=0
                        )

                feats = [feat.permute(1, 2, 0).reshape(B, -1, *feat_size) 
                        for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]
                
                image_embed = feats[-1]
                high_res_feats = feats[:-1]

                """ prompt encoder """
                if (ind%5) == 0:
                    flag = True
                    # points = (coords_torch, labels_torch)

                else:
                    flag = False
                    points = None

                se, de = net.sam_prompt_encoder(
                    points=None, 
                    boxes=boxes,
                    masks=None,
                    batch_size=B,
                )

                low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
                    image_embeddings=image_embed,
                    image_pe=net.sam_prompt_encoder.get_dense_pe(), 
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de, 
                    multimask_output=False, 
                    repeat_image=False,  
                    high_res_features = high_res_feats
                )

                # prediction
                pred = F.interpolate(low_res_multimasks,size=(args['out_size'],args['out_size']))
                high_res_multimasks = F.interpolate(low_res_multimasks, size=(args['image_size'], args['image_size']),
                                                mode="bilinear", align_corners=False)
            
                """ memory encoder """
                maskmem_features, maskmem_pos_enc = net._encode_new_memory( 
                    current_vision_feats=vision_feats,
                    feat_sizes=feat_sizes,
                    pred_masks_high_res=high_res_multimasks[:, 0:1, :, :],
                    is_mask_from_pts=flag)  
                    
                maskmem_features = maskmem_features.to(torch.bfloat16)
                maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
                maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
                maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)


                """ memory bank """
                if len(memory_bank_list) < 16:
                    for batch in range(maskmem_features.size(0)):
                        memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
                                                 (maskmem_pos_enc[batch].unsqueeze(0)),
                                                 iou_predictions[batch, 0],
                                                 image_embed[batch].reshape(-1).detach()])
                
                else:
                    for batch in range(maskmem_features.size(0)):
                        
                        memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
                        memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)

                        memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
                        current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
                                                             memory_bank_maskmem_features_norm.t())

                        current_similarity_matrix_no_diag = current_similarity_matrix.clone()
                        diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
                        current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')

                        single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
                        similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
                        min_similarity_index = torch.argmin(similarity_scores) 
                        max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])

                        if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
                            if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
                                memory_bank_list.pop(max_similarity_index) 
                                memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
                                                         (maskmem_pos_enc[batch].unsqueeze(0)),
                                                         iou_predictions[batch, 0],
                                                         image_embed[batch].reshape(-1).detach()])

                # binary mask and calculate loss, iou, dice
                # total_loss += lossfunc(pred, masks)
                # pred = (pred> 0.5).float()
                # # For disc (first mask)
                # dice_disc = eval_dice(pred[:, 0, :, :], masks[:, 0, :, :])
                # iou_disc = eval_iou(pred[:, 0, :, :], masks[:, 0, :, :])
                # # For cup (second mask)
                # dice_cup = eval_dice(pred[:, 1, :, :], masks[:, 1, :, :])
                # iou_cup = eval_iou(pred[:, 1, :, :], masks[:, 1, :, :])

                loss_disc = lossfunc(pred[:, 0, :, :], masks[:, 0, :, :])
                loss_cup  = lossfunc(pred[:, 1, :, :], masks[:, 1, :, :])
                dice_disc = dice_score(torch.sigmoid(pred[:, 0, :, :]), masks[:, 0, :, :])
                dice_cup  = dice_score(torch.sigmoid(pred[:, 1, :, :]), masks[:, 1, :, :])

                loss = .3 * loss_disc + .7 * loss_cup
                dice = (dice_disc + dice_cup) / 2

                # weight the loss from the cup more because it is worse in predictions
                loss = .3 * loss_disc + .7 * loss_cup
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                total_loss += loss.item()
                total_dice += dice.item()


                '''vis images'''
                # if ind % args.vis == 0:
                #     namecat = 'Test'
                #     for na in name:
                #         img_name = na
                #         namecat = namecat + img_name + '+'
                #     vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=None)
            # Cleanup large intermediate tensors that are no longer needed.
            del backbone_out, vision_feats, vision_pos_embeds, feats, image_embed, high_res_feats, se, de, low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits, pred

            # Optionally free cached memory.
            if ind % 10 == 0:  # For instance, every 10 iterations
                torch.cuda.empty_cache()    
            pbar.update()

    return total_loss/ n_val, total_dice / n_val#, tuple([total_eiou/n_val, total_dice/n_val])

## Data

In [8]:
def update_image_path(path: str) -> str:
    """
    Pulls the file name from a file path.

    Args:
        path (str): file path

    Returns:
        str: the file name
    """
    split_path = path.split("/")
    return split_path[-1]

In [9]:
# Read in the fundus images and ground truth masks
origa_path = os.path.join('..', '..', "data", "ORIGA")
images_path = os.path.join(origa_path, "Images_Square")
masks_path = os.path.join(origa_path, "Masks_Square")

img_filenames = sorted(os.listdir(images_path))
mask_filenames = sorted(os.listdir(masks_path))

# Read in the bounding boxes
bb_df = pd.read_csv("../../data/ORIGA/bounding_boxes.csv")
bb_df['image_path'] = bb_df['image_path'].apply(update_image_path)

# Update the bounding box coordinates based on the image size. Boxes were created on 512x512 images
if med_sam_2.image_size == 256:
    bb_df[['x1', 'y1', 'x2', 'y2']] //= 2
elif med_sam_2.image_size == 1024:
    bb_df[['x1', 'y1', 'x2', 'y2']] *= 2

In [10]:
# Split into train, validation, and test sets (70, 15, 15)
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    img_filenames, mask_filenames, test_size=0.3, random_state=42)

val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs, temp_masks, test_size=0.5, random_state=42)

In [11]:
batch_size = 4
n_workers = 4

# Load raw data into custom PyTorch datasets
train_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, train_imgs, train_masks, bb_df, 1024)
val_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, val_imgs, val_masks, bb_df, 1024)
test_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, test_imgs, test_masks, bb_df, 1024)

# Load datasets into PyTorch DataLoaders
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)

In [12]:
# Training arguments
args = {
    'out_size': 1024,
    'image_size': 1024,
    'memory_bank_size': 16,
    'lr': 1e-4
}

# initialize Adam optimizer
optimizer = optim.Adam(med_sam_2.parameters(), lr=args['lr'], betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

for param in med_sam_2.image_encoder.parameters():
    param.requires_grad = False

# Set to max 50 epochs
epochs = 50

In [13]:
med_sam_2 = med_sam_2.to('cuda')

In [14]:
# Initialize early stopping conditions
best_val_dice = float('-inf')
patience = 5
patience_counter = 0

# Initialize loss lists to save loss at each epoch
train_loss_list = [0] * epochs
val_loss_list = [0] * epochs

for epoch in range(epochs):
    # Run one training epoch
    train_loss, train_dice = train_sam(args, med_sam_2, optimizer, train_loader, epoch)
    print(f'Train loss: {train_loss} || @ epoch {epoch}. Train Dice: {train_dice} || @ epoch {epoch}.')
    train_loss_list[epoch] = train_loss

    # Run one validation loop
    val_loss, val_dice = validation_sam(args, val_loader, epoch, med_sam_2)
    print(f'Validation loss: {val_loss} || @ epoch {epoch}. Validation dice: {val_dice} || @ epoch {epoch}.')
    val_loss_list[epoch] = val_loss

    # Check if our model is better than the best trained model
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        state_dict_to_save = med_sam_2.state_dict()
        patience_counter = 0  # Reset patience counter
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print(f"Early stopping triggered @ epoch {epoch}.")
            break

# Run a final test evaluation on a holdout set
med_sam_2.load_state_dict(state_dict_to_save)
test_loss, test_dice = validation_sam(args, test_loader, 0, med_sam_2)
print(f'Test loss: {test_loss}. Test Dice: {test_dice}.')

Epoch 0:   0%|          | 0/114 [00:00<?, ?img/s]

se: torch.Size([4, 2, 256])
de: torch.Size([4, 256, 64, 64])
image_embed: torch.Size([4, 16, 64, 64])
image_embeddings shape: torch.Size([4, 16, 64, 64])
dense_prompt_embeddings shape: torch.Size([4, 256, 64, 64])





RuntimeError: The size of tensor a (16) must match the size of tensor b (256) at non-singleton dimension 1

In [None]:
# Save the best trained model
torch.save(state_dict_to_save, "./medsam2-two-mask-512x512-50-epochs-weighted-loss-2.pth")

In [None]:
# Load the weights from the best trained model
med_sam_2.load_state_dict(torch.load("medsam2-two-mask-512x512-50-epochs-weighted-loss-2.pth", weights_only=True))

In [None]:
# Create the predictor object from the trained Medical-SAM2 model
predictor = SAM2ImagePredictor(med_sam_2)

In [None]:
# Load in a single image
img_path = "../../data/ORIGA/Images_Square/465.jpg"
img = Image.open(img_path).resize((512, 512))
img = np.array(img.convert('RGB'))

In [None]:
# Set the image for prediction. This must be done when using the predictor object to perform inference
predictor.set_image(img)

In [None]:
# Make predictions on the single image using the associated prompt
masks, scores, logits = predictor.predict(box=pack['bbox'][1], multimask_output=True)

In [None]:
# Display the predicted masks
show_masks(img, masks, scores, box_coords=pack['bbox'][1])

In [None]:
img_path = './data/Clinic/Images/Subject (14).png'
img = Image.open(img_path).resize((512, 512))
img = np.array(img.convert('RGB'))

In [None]:
predictor.set_image(img)
masks, scores, logits = predictor.predict(box=np.array([260, 65, 350, 160]), multimask_output=True)

In [None]:
show_masks(img, masks, scores, box_coords=np.array([260, 65, 350, 160]))