# Model Selection Pipeline

In [None]:
%pip install segment_anything
%pip install openpyxl
%pip install segmentation-models-pytorch
%pip install boto3

In [1]:
import json
import sys
import time
import boto3
import os
import subprocess
import openpyxl
import torch
import pandas as pd
import numpy as np

os.chdir("..")

from smart_open import open as smart_open
from torch.utils.data import DataLoader
import io
from segment_anything import sam_model_registry
from torch.nn import Linear
from torch.nn import Embedding
from segment_anything.modeling.mask_decoder import MLP
import torch.nn as nn
from segmentation_models_pytorch import Unet
from src.dino_helper import DINOv2Segmentation
from torchvision import transforms
from src.data_helper import calculate_iou
import src.data_helper as data_helper
import src.medsam_helper as medsam_helper
import src.credentials as credentials

### Prepare Patient Data

In [15]:
s3 = boto3.client('s3', aws_access_key_id=credentials.ACCESS_KEY, aws_secret_access_key=credentials.SECRET_KEY)

response = s3.get_object(Bucket='raw-data-mris-segs', Key='seg_list_test.xlsx') 
data = response['Body'].read()
patient_data = pd.read_excel(io.BytesIO(data))

patient_data

Unnamed: 0,MRI/Patient ID,Number of Brightness Levels,Number of Slices,Brightness Level 1,Brightness Level 2
0,ACRIN 6698_207837,2,320,5,6
1,ACRIN 6698_277831,2,480,6,7
2,Duke_062,2,492,1,2
3,Duke_077,2,522,2,3


In [16]:
cleaned_data = data_helper.clean_mri_data(patient_data)
cleaned_data

Unnamed: 0,index,patient_id,Number of Brightness Levels,Number of Slices,Brightness Level 1,Brightness Level 2,Total Images,Start_Index
0,0,ACRIN 6698_207837,2,320,5,6,640,0
1,1,ACRIN 6698_277831,2,480,6,7,960,640
2,2,Duke_062,2,492,1,2,984,1600
3,3,Duke_077,2,522,2,3,1044,2584


In [17]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
test_dataset = data_helper.CancerDataset(labels=cleaned_data, path='cleaned-mri-data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [5]:
test = next(iter(test_loader))

### Load Models

In [6]:
model_bucket = "medseg-models"
base_path = "base_medsam_model_06-10-2025.pth"
tuned_path = "tuned_medsam_model_06-10-2025.pth"
unet_path = "unet_model_06-11-2025.pth"
dino_path = "dinov2_model_06-11-2025.pth"
medsam_path = "/home/ra-ugrad/Documents/Haleigh/MedicalImage/models/medsam_vit_b.pth"

In [7]:
base_model = sam_model_registry['vit_b'](checkpoint=medsam_path)

base_model.mask_decoder.num_mask_tokens = 8
base_model.mask_decoder.num_multimask_outputs = 7
base_model.image_encoder.patch_embed.proj = nn.Conv2d(3, 768, kernel_size = (35, 35), stride = (3, 3))
base_model.mask_decoder.mask_tokens = Embedding(base_model.mask_decoder.num_mask_tokens, 256)
base_model.mask_decoder.output_hypernetworks_mlps = nn.ModuleList([MLP(256, 256, 32, 3) for i in range(base_model.mask_decoder.num_mask_tokens)])
base_model.mask_decoder.iou_prediction_head.layers[2] = Linear(in_features=256, out_features=base_model.mask_decoder.num_mask_tokens, bias=True)

tuned_model = sam_model_registry['vit_b'](checkpoint=medsam_path)

tuned_model.mask_decoder.num_mask_tokens = 8
tuned_model.mask_decoder.num_multimask_outputs = 7
tuned_model.image_encoder.patch_embed.proj = nn.Conv2d(3, 768, kernel_size = (35, 35), stride = (3, 3))
tuned_model.mask_decoder.mask_tokens = Embedding(tuned_model.mask_decoder.num_mask_tokens, 256)
tuned_model.mask_decoder.output_hypernetworks_mlps = nn.ModuleList([MLP(256, 256, 32, 3) for i in range(tuned_model.mask_decoder.num_mask_tokens)])
tuned_model.mask_decoder.iou_prediction_head.layers[2] = Linear(in_features=256, out_features=tuned_model.mask_decoder.num_mask_tokens, bias=True)

unet_model = Unet(encoder_name='resnet34', encoder_weights='imagenet', in_channels=1, classes=7)
dino_model = DINOv2Segmentation()

Using cache found in /home/ra-ugrad/.cache/torch/hub/facebookresearch_dinov2_main


In [8]:
base_response = s3.get_object(Bucket=model_bucket, Key=base_path) 
base_data = base_response['Body'].read()
base_model.load_state_dict(torch.load(io.BytesIO(base_data)))

tuned_response = s3.get_object(Bucket=model_bucket, Key=tuned_path) 
tuned_data = tuned_response['Body'].read()
tuned_model.load_state_dict(torch.load(io.BytesIO(tuned_data)))

unet_response = s3.get_object(Bucket=model_bucket, Key=unet_path) 
unet_data = unet_response['Body'].read()
unet_model.load_state_dict(torch.load(io.BytesIO(unet_data)))

dino_response = s3.get_object(Bucket=model_bucket, Key=dino_path) 
dino_data = dino_response['Body'].read()
dino_model.load_state_dict(torch.load(io.BytesIO(dino_data)))

<All keys matched successfully>

### Getting results

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# base model
base_results = None 
base_model = base_model.to(device)
for img, seg, patient, b_level in test_loader:
    img = img.to(device)
    seg = seg.to(device)
    # print(img.size())
    B, C, H, W = img.size()
    img_3c = img.repeat(3, 1, 1, 1).view(B, 3, H, W).to(device)

    box_np = torch.Tensor(np.array([[0, 0, W, H]])).to(device)

    with torch.no_grad():
        image_embedding = base_model.image_encoder(img_3c) 
    
    medsam_seg = medsam_helper.medsam_inference(base_model, image_embedding, box_np, H, W)
    pred = torch.argmax(medsam_seg, dim=1).to(device)

    acc = list((pred == seg).float().mean(dim =(1, 2)).cpu().numpy()) #TODO: would be helpful to see acc by mask
    mean_iou, class_iou = calculate_iou(medsam_seg.cpu(), seg.cpu(), 7)

    if base_results is None:
        base_results = pd.DataFrame({"Patient": patient, 
                                    "Brightness": b_level, 
                                    "Accuracy": acc,
                                    "IoU_0": class_iou[0],
                                    "IoU_1": class_iou[1],
                                    "IoU_2": class_iou[2],
                                    "IoU_3": class_iou[3], 
                                    "IoU_4": class_iou[4],
                                    "IoU_5": class_iou[5],
                                    "IoU_6": class_iou[6],
                                    "IoU_mean": mean_iou,
                                    })
    else:
        base_results = pd.concat([base_results, pd.DataFrame({"Patient": patient, 
                                                                "Brightness": b_level, 
                                                                "Accuracy": acc,
                                                                "IoU_0": class_iou[0],
                                                                "IoU_1": class_iou[1],
                                                                "IoU_2": class_iou[2],
                                                                "IoU_3": class_iou[3], 
                                                                "IoU_4": class_iou[4],
                                                                "IoU_5": class_iou[5],
                                                                "IoU_6": class_iou[6],
                                                                "IoU_mean": mean_iou,
                                                                })]) 
        
grouped_base_results = base_results.groupby(["Patient", "Brightness"]).mean()

In [33]:
# tuned model
tuned_results = None 
tuned_model = tuned_model.to(device)
for img, seg, patient, b_level in test_loader:
    img = img.to(device)
    seg = seg.to(device)
    # print(img.size())
    B, C, H, W = img.size()
    img_3c = img

    box_np = torch.Tensor(np.array([[0, 0, W, H]])).to(device)

    with torch.no_grad():
        image_embedding = tuned_model.image_encoder(img_3c) 
    
    medsam_seg = medsam_helper.medsam_inference(tuned_model, image_embedding, box_np, H, W)
    pred = torch.argmax(medsam_seg, dim=1).to(device)

    acc = list((pred == seg).float().mean(dim =(1, 2)).cpu().numpy()) #TODO: would be helpful to see acc by mask
    mean_iou, class_iou = calculate_iou(medsam_seg.cpu(), seg.cpu(), 7)

    if tuned_results is None:
        tuned_results = pd.DataFrame({"Patient": patient, 
                                        "Brightness": b_level, 
                                        "Accuracy": acc,
                                        "IoU_0": class_iou[0],
                                        "IoU_1": class_iou[1],
                                        "IoU_2": class_iou[2],
                                        "IoU_3": class_iou[3], 
                                        "IoU_4": class_iou[4],
                                        "IoU_5": class_iou[5],
                                        "IoU_6": class_iou[6],
                                        "IoU_mean": mean_iou,
                                        })
    else:
        tuned_results = pd.concat([tuned_results, pd.DataFrame({"Patient": patient, 
                                                                "Brightness": b_level, 
                                                                "Accuracy": acc,
                                                                "IoU_0": class_iou[0],
                                                                "IoU_1": class_iou[1],
                                                                "IoU_2": class_iou[2],
                                                                "IoU_3": class_iou[3], 
                                                                "IoU_4": class_iou[4],
                                                                "IoU_5": class_iou[5],
                                                                "IoU_6": class_iou[6],
                                                                "IoU_mean": mean_iou,
                                                                })]) 
grouped_tuned_results = tuned_results.groupby(["Patient", "Brightness"]).mean()

In [34]:
grouped_tuned_results

Unnamed: 0_level_0,Unnamed: 1_level_0,Accuracy,IoU_0,IoU_1,IoU_2,IoU_3,IoU_4,IoU_5,IoU_6,IoU_mean
Patient,Brightness,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
ACRIN 6698_207837,Brightness level 1,0.59375,0.59375,0.0,0.0,0.0,0.0,0.0,0.0,0.084821
ACRIN 6698_277831,Brightness level 1,0.518177,0.518177,0.0,0.0,0.0,0.0,0.0,0.0,0.074025
ACRIN 6698_277831,Brightness level 2,0.518178,0.518178,0.0,0.0,0.0,0.0,0.0,0.0,0.074025
Duke_062,Brightness level 1,0.562607,0.562607,0.0,0.0,0.0,0.0,0.0,0.0,0.080372
Duke_062,Brightness level 2,0.562608,0.562608,0.0,0.0,0.0,0.0,0.0,0.0,0.080373
Duke_077,Brightness level 1,0.89631,0.89631,0.0,0.0,0.0,0.0,0.0,0.0,0.128044
Duke_077,Brightness level 2,0.896414,0.896414,0.0,0.0,0.0,0.0,0.0,0.0,0.128059


In [22]:
# unet model
unet_results = None
unet_model.eval()
unet_model = unet_model.to(device)
with torch.no_grad():
    for img, seg, patient, b_level in test_loader:
        img = img[:,:1,:,:] # gets just one channel, since img is black and white
        img = img.to(device)
        seg = seg.to(device)
        outputs = unet_model(img)
        pred = torch.argmax(outputs, dim=1)

        acc = list((pred == seg).float().mean(dim =(1, 2)).cpu().numpy())
        mean_iou, class_iou = calculate_iou(outputs.cpu(), seg.cpu(), 7)

        if unet_results is None:
            unet_results = pd.DataFrame({"Patient": patient, 
                                            "Brightness": b_level, 
                                            "Accuracy": acc,
                                            "IoU_0": class_iou[0],
                                            "IoU_1": class_iou[1],
                                            "IoU_2": class_iou[2],
                                            "IoU_3": class_iou[3], 
                                            "IoU_4": class_iou[4],
                                            "IoU_5": class_iou[5],
                                            "IoU_6": class_iou[6],
                                            "IoU_mean": mean_iou,
                                            })
        else:
            unet_results = pd.concat([unet_results, pd.DataFrame({"Patient": patient, 
                                                                    "Brightness": b_level, 
                                                                    "Accuracy": acc,
                                                                    "IoU_0": class_iou[0],
                                                                    "IoU_1": class_iou[1],
                                                                    "IoU_2": class_iou[2],
                                                                    "IoU_3": class_iou[3], 
                                                                    "IoU_4": class_iou[4],
                                                                    "IoU_5": class_iou[5],
                                                                    "IoU_6": class_iou[6],
                                                                    "IoU_mean": mean_iou,
                                                                    })]) 

grouped_unet_results = unet_results.groupby(["Patient", "Brightness"]).mean()

In [31]:
grouped_unet_results

Unnamed: 0_level_0,Unnamed: 1_level_0,Accuracy,IoU_0,IoU_1,IoU_2,IoU_3,IoU_4,IoU_5,IoU_6,IoU_mean
Patient,Brightness,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
ACRIN 6698_207837,Brightness level 1,0.968192,0.999966,0.924449,0.016204,0.055725,0.0,0.0,0.0,0.285192
ACRIN 6698_277831,Brightness level 1,0.906245,0.884344,0.881266,0.089409,0.153542,0.018955,0.0,0.0,0.289645
ACRIN 6698_277831,Brightness level 2,0.901625,0.876127,0.87193,0.074528,0.136824,0.014961,0.0,0.0,0.282053
Duke_062,Brightness level 1,0.88361,0.94195,0.767924,0.003329,0.361183,0.036348,0.0,0.0,0.301533
Duke_062,Brightness level 2,0.881961,0.938617,0.765434,0.004412,0.356685,0.033019,0.0,0.0,0.299738
Duke_077,Brightness level 1,0.551798,0.576738,0.002442,0.048026,0.075747,0.0,0.0,0.0,0.100422
Duke_077,Brightness level 2,0.549595,0.573675,0.002812,0.042144,0.075254,0.0,0.0,0.0,0.099126


In [29]:
# dino model
dino_model = dino_model.to(device)
dino_model.eval()
dino_results = None
with torch.no_grad():
    for img, seg, patient, b_level in test_loader:
        img = img.to(device)
        seg = seg.to(device)

        B, C, H, W = img.size()
        img_3c = img
        
        outputs = dino_model(img_3c)
        outputs = outputs.view(-1, 7, 224, 224)
        pred = torch.argmax(outputs, dim=1)
        
        acc = list((pred == seg).float().mean(dim =(1, 2)).cpu().numpy())
        mean_iou, class_iou = calculate_iou(outputs, seg, 7)

        if dino_results is None:
            dino_results = pd.DataFrame({"Patient": patient, 
                                            "Brightness": b_level, 
                                            "Accuracy": acc,
                                            "IoU_0": class_iou[0],
                                            "IoU_1": class_iou[1],
                                            "IoU_2": class_iou[2],
                                            "IoU_3": class_iou[3], 
                                            "IoU_4": class_iou[4],
                                            "IoU_5": class_iou[5],
                                            "IoU_6": class_iou[6],
                                            "IoU_mean": mean_iou,
                                            })
        else:
            dino_results = pd.concat([dino_results, pd.DataFrame({"Patient": patient, 
                                                                    "Brightness": b_level, 
                                                                    "Accuracy": acc,
                                                                    "IoU_0": class_iou[0],
                                                                    "IoU_1": class_iou[1],
                                                                    "IoU_2": class_iou[2],
                                                                    "IoU_3": class_iou[3], 
                                                                    "IoU_4": class_iou[4],
                                                                    "IoU_5": class_iou[5],
                                                                    "IoU_6": class_iou[6],
                                                                    "IoU_mean": mean_iou,
                                                                    })]) 

grouped_dino_results = dino_results.groupby(["Patient", "Brightness"]).mean()

In [30]:
grouped_dino_results

Unnamed: 0_level_0,Unnamed: 1_level_0,Accuracy,IoU_0,IoU_1,IoU_2,IoU_3,IoU_4,IoU_5,IoU_6,IoU_mean
Patient,Brightness,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
ACRIN 6698_207837,Brightness level 1,0.797493,0.711022,0.801339,0.0,0.098772,0.0,0.0,0.0,0.230162
ACRIN 6698_277831,Brightness level 1,0.896901,0.920844,0.823574,0.004487,0.044393,0.000118,0.0,0.0,0.256202
ACRIN 6698_277831,Brightness level 2,0.892904,0.914096,0.818852,0.006611,0.044342,0.00024,0.0,0.0,0.254877
Duke_062,Brightness level 1,0.876881,0.886846,0.809614,0.009751,0.054954,0.002547,0.0,0.0,0.251959
Duke_062,Brightness level 2,0.876084,0.883058,0.811471,0.008141,0.047041,0.001854,0.0,0.0,0.250224
Duke_077,Brightness level 1,0.511665,0.547549,0.012944,0.046877,0.031382,1.3e-05,0.0,0.0,0.091252
Duke_077,Brightness level 2,0.509251,0.54578,0.013036,0.042188,0.033066,0.000173,0.0,0.0,0.090606


In [35]:
all_results = pd.concat([grouped_tuned_results, grouped_unet_results, grouped_dino_results])

In [36]:
print(all_results)

                                      Accuracy     IoU_0     IoU_1     IoU_2  \
Patient           Brightness                                                   
ACRIN 6698_207837 Brightness level 1  0.593750  0.593750  0.000000  0.000000   
ACRIN 6698_277831 Brightness level 1  0.518177  0.518177  0.000000  0.000000   
                  Brightness level 2  0.518178  0.518178  0.000000  0.000000   
Duke_062          Brightness level 1  0.562607  0.562607  0.000000  0.000000   
                  Brightness level 2  0.562608  0.562608  0.000000  0.000000   
Duke_077          Brightness level 1  0.896310  0.896310  0.000000  0.000000   
                  Brightness level 2  0.896414  0.896414  0.000000  0.000000   
ACRIN 6698_207837 Brightness level 1  0.968192  0.999966  0.924449  0.016204   
ACRIN 6698_277831 Brightness level 1  0.906245  0.884344  0.881266  0.089409   
                  Brightness level 2  0.901625  0.876127  0.871930  0.074528   
Duke_062          Brightness level 1  0.

In [43]:
all_results.to_csv("results/results.csv")

### Evaluation Criteria

In [44]:
all_results = pd.read_csv('results/results.csv')

In [40]:
model = ['medsam', 'medsam', 'medsam', 'medsam', 'medsam', 'medsam', 'medsam', 
         'unet', 'unet', 'unet', 'unet', 'unet', 'unet', 'unet', 
         'dino', 'dino', 'dino', 'dino', 'dino', 'dino', 'dino']

In [46]:
all_results.drop(columns=["Unnamed: 0"], inplace=True)

In [49]:
avg_results = all_results[['Accuracy', 'IoU_0', 'IoU_1', 'IoU_2', 'IoU_3', 'IoU_4', 'IoU_5', 'IoU_6', 'IoU_mean', 'Model']].groupby('Model').mean()

In [54]:
dino_sum = avg_results.loc['dino'].sum()
unet_sum = avg_results.loc['unet'].sum()
medsam_sum = avg_results.loc['medsam'].sum()

In [55]:
print(f'dino_sum: {dino_sum}')
print(f'unet_sum: {unet_sum}')
print(f'medsam_sum: {medsam_sum}')

dino_sum: 2.3947766733654667
unet_sum: 2.700672260886672
medsam_sum: 1.3922584261042659


In [None]:
# the winner: unet!