In [None]:
import os
import sys

import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional.classification import binary_jaccard_index
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50

import logging
from pathlib import Path
from PIL import Image
import numpy as np
# from architectures.unet_model_xB import UNet as UNet_ml
# from architectures.unet_model_seg import UNet as UNet_seg
import matplotlib.pyplot as plt
from utils.BUSI_multiloss import BUSIDataset

### Set device

In [None]:
# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS,
# then plain CPU. The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(device)

### Set seed

In [None]:
# Seed PyTorch's global random number generator so that runs from a fresh
# kernel are repeatable.
manual_seed = 0
torch.manual_seed(manual_seed)

### Load dataset

In [None]:
# Dataset root: the parent of the current working directory (resolved to an
# absolute path). Printed so the reader can verify it.
root_dir = Path().resolve().parent
print(root_dir)

# Batch size used when building the data loaders.
batch_size=16

In [None]:
def get_dataloaders(root_dir,
                    val_percent=0.1):
    """Build the training and validation loaders for the BUSI dataset.

    The dataset is created with ``im_res=224`` and ``threshold=100`` and then
    randomly split into a training part and a validation part.

    Args:
        root_dir: Directory handed to ``BUSIDataset``.
        val_percent: Fraction of the dataset reserved for validation.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader does not and uses pinned memory.

    Side effects:
        Sets the module-level ``n_train`` and ``n_val`` to the split sizes.
        Reads the module-level ``batch_size``.
    """
    global n_train, n_val

    full_dataset = BUSIDataset(root_dir, im_res=224, threshold=100)
    total = len(full_dataset)

    # Validation size is truncated; the training split gets the remainder.
    n_val = int(total * val_percent)
    n_train = total - n_val
    train_subset, val_subset = random_split(full_dataset, [n_train, n_val])

    shared_kwargs = dict(batch_size=batch_size, num_workers=2)
    train_loader = DataLoader(train_subset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, pin_memory=True, **shared_kwargs)

    return train_loader, val_loader

In [None]:
# Only the validation loader is kept; the training loader is discarded.
_, val_loader = get_dataloaders(root_dir)

### Load Model

Load weights of the model

In [None]:
# Checkpoint whose weights are loaded into the model below. The commented
# paths are alternative checkpoints kept for reference.
ml_model_path = '../testModels/2024/10-08_BUSI/CP_Trial30_Epoch160.pth'
# ml_model_path = '../testModels/busi_ml/multiloss/CP_Trial30_Epoch160.pth'
# seg_model_path = '../pets_final/supervised_segmentation/CP_epoch60.pth'
# model_path = 'checkpoints/pascalVOC/multiloss/04-30/17-25-10/CP_epoch2.pth'

In [None]:
# #Enter the correct arguments for the UNet
# net_ml = UNet_ml(n_channels=3, n_classes=1, bilinear=True)

# net_ml.load_state_dict(
#             torch.load(ml_model_path, map_location=device)
#         )
# logging.info(f'Model loaded from {ml_model_path}')
# # net_ml.eval()

In [None]:
# #Enter the correct arguments for the UNet
# net_seg = UNet_seg(n_channels=3, n_classes=1, bilinear=True)

# net_seg.load_state_dict(
#             torch.load(seg_model_path, map_location=device)
#         )
# logging.info(f'Model loaded from {seg_model_path}')
# # net_seg.eval()

Build the DeepLabV3-based architecture and load its weights

In [None]:
def get_model():
    """Create the DeepLabV3-ResNet50 network used in this notebook.

    The network is built with a single output class and an auxiliary head.
    The stock auxiliary classifier is replaced with a custom head that maps
    1024 input channels to 3 sigmoid-activated channels, and a sigmoid is
    appended to the main classifier so both outputs lie in [0, 1].

    Returns:
        The model, moved to the module-level ``device``.
    """
    # (An FCN-ResNet50 backbone was an earlier alternative: fcn_resnet50(aux_loss=True).)
    net = deeplabv3_resnet50(num_classes=1, aux_loss=True)

    aux_layers = [
        nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.1, inplace=False),
        nn.Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1)),
        nn.Sigmoid(),
    ]
    net.aux_classifier = nn.Sequential(*aux_layers)

    # Squash the main segmentation output into [0, 1] as well.
    net.classifier.append(nn.Sigmoid())

    net.to(device=device)
    return net

In [None]:
# Instantiate the network; its weights are still untrained at this point.
net_ml = get_model()

In [None]:
# Load the checkpoint into the model, remapping tensors to the active device.
# NOTE(review): torch.load unpickles the file; on PyTorch versions where
# `weights_only` defaults to False this can execute arbitrary code from an
# untrusted checkpoint. Only load checkpoints you trust, or consider passing
# weights_only=True if the installed PyTorch supports it.
net_ml.load_state_dict(torch.load(ml_model_path, map_location=device))

### Visualization Helpers

In [None]:
def gt_torchToPIL_img(img):
    """Convert a channels-first image tensor into an RGB PIL image.

    Args:
        img: Tensor that squeezes to shape (3, H, W). Values are scaled by
            255, so they are presumably expected to lie in [0, 1] — confirm
            against the dataset's preprocessing.

    Returns:
        A PIL image in mode 'RGB'.
    """
    pixels = img.squeeze().cpu().numpy()
    # PIL expects channels last: (C, H, W) -> (H, W, C).
    pixels = pixels.transpose((1, 2, 0))
    # Clip before the uint8 cast, as the mask helpers do: without it any value
    # outside [0, 1] wraps around and shows up as speckle in the rendered image.
    pixels = np.clip(pixels, 0, 1)
    return Image.fromarray((pixels * 255).astype(np.uint8), 'RGB')

In [None]:
def gt_torchToPIL_mask(mask):
    """Convert a ground-truth mask tensor into a greyscale PIL image.

    Singleton dimensions are dropped, values are clipped to [0, 1], scaled to
    0-255 and truncated to uint8.

    Args:
        mask: Mask tensor (must not require grad, since it is not detached).

    Returns:
        A PIL image in mode 'L'.
    """
    mask_np = mask.squeeze().cpu().numpy()
    pixels = (np.clip(mask_np, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(pixels, 'L')

In [None]:
def ml_torchToPIL_mask(mask):
    """Convert a predicted (soft) mask tensor into a greyscale PIL image.

    The tensor is detached from the autograd graph, so model outputs can be
    passed in directly. Values are clipped to [0, 1], scaled to 0-255 and
    truncated to uint8; no thresholding is applied.

    Args:
        mask: Predicted mask tensor.

    Returns:
        A PIL image in mode 'L'.
    """
    mask_np = mask.detach().cpu().squeeze().numpy()
    pixels = (np.clip(mask_np, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(pixels, 'L')

In [None]:
def seg_torchToPIL_mask(mask):
    """Convert a predicted mask tensor into a binarised greyscale PIL image.

    Values are rounded to the nearest integer and clipped to {0, 1} before
    scaling, so the result contains only 0 and 255.

    Args:
        mask: Predicted mask tensor (detached internally).

    Returns:
        A PIL image in mode 'L'.
    """
    mask_np = mask.detach().cpu().squeeze().numpy()
    binary = np.clip(np.round(mask_np), 0, 1)
    pixels = (binary * 255).astype(np.uint8)
    return Image.fromarray(pixels, 'L')

In [None]:
def combined_display_pic(images, spacing=10):
    """Lay out PIL images side by side, separated by white vertical strips.

    The canvas is sized from the images themselves rather than assuming
    224x224 tiles, so inputs of any (even mixed) size are laid out correctly.
    For 224x224 inputs and the default spacing the result is pixel-identical
    to the previous fixed-size layout.

    Args:
        images: Non-empty sequence of PIL images, pasted left to right.
        spacing: Width in pixels of the white gap between neighbours.

    Returns:
        A new 'RGB' PIL image containing all inputs in a single row. Areas
        below images shorter than the tallest one stay black.

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("combined_display_pic needs at least one image")

    total_width = sum(im.width for im in images) + spacing * (len(images) - 1)
    max_height = max(im.height for im in images)

    combined_image = Image.new('RGB', (total_width, max_height))
    # One reusable spacer strip; only needed when there is a gap to fill.
    white_spacing = None
    if len(images) > 1 and spacing > 0:
        white_spacing = Image.new('RGB', (spacing, max_height), color='white')

    x_offset = 0
    for position, im in enumerate(images):
        # Gaps go between images only, never after the last one.
        if position > 0:
            if white_spacing is not None:
                combined_image.paste(white_spacing, (x_offset, 0))
            x_offset += spacing
        combined_image.paste(im, (x_offset, 0))
        x_offset += im.width

    return combined_image

In [None]:
# batch = next(iter(val_loader))

### Evaluation Metrics

In [None]:
# display(torchToPIL_mask(batch['mask'][0]))

In [None]:
def single_iou(pred_mask, mask):
    """Return the binary Jaccard index (IoU) between a prediction and its target.

    Thin wrapper around ``binary_jaccard_index`` with its default settings.

    Args:
        pred_mask: Predicted mask tensor.
        mask: Ground-truth mask tensor.

    Returns:
        Whatever ``binary_jaccard_index`` returns for the pair.
    """
    return binary_jaccard_index(pred_mask, mask)

In [None]:
def batch_iou():
    # NOTE(review): unfinished placeholder. It takes no inputs and returns the
    # function object itself, not an IoU value. Presumably meant to compute a
    # batch-level IoU; implement or remove before relying on it.
    return batch_iou

In [None]:
def random_pred_expected_iou(perc):
    """Return the baseline IoU ``perc / (2 - perc)`` for a mask coverage value.

    Args:
        perc: Mask coverage; presumably the foreground fraction in [0, 1]
            (the formula yields 0 at 0 and 1 at 1) — confirm against the
            dataset's ``mask_perc`` field.

    Returns:
        The value ``perc / (2 - perc)``, same type as the arithmetic on
        ``perc`` produces.
    """
    return perc / (2 - perc)

In [None]:
# Evaluate the model on the validation set one image at a time: show the input,
# ground-truth mask and predicted mask side by side, and report per-image and
# per-batch IoU next to the baseline IoU computed from the mask coverage.

# Switch dropout / batch-norm to inference behaviour once, before iterating.
net_ml.eval()

for batch_idx, batch in enumerate(val_loader):
    print("Batch #",batch_idx)

    # The final batch can contain fewer samples than the configured batch
    # size, so size the inner loop and the average by what was actually loaded.
    n_samples = len(batch['image'])
    batch_iou_ml = 0

    for i in range(n_samples):
        img = batch['image'][i].to(device=device, dtype=torch.float32)
        mask = batch['mask'][i].to(device=device, dtype=torch.float32)
        perc = batch['mask_perc'][i].to(device=device, dtype=torch.float32)

        print(batch['image_ID'][i])

        gt_image = gt_torchToPIL_img(img)
        gt_mask = gt_torchToPIL_mask(mask)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outs = net_ml(torch.unsqueeze(img, 0))
        # Only the main segmentation head is evaluated; outs['aux'] is unused.
        ml_mask = outs['out']

        ml_mask_disp = ml_torchToPIL_mask(ml_mask)
        combined_image = combined_display_pic([gt_image, gt_mask, ml_mask_disp])
        display(combined_image)

        # Drop the batch dimension that was added for the forward pass.
        ml_iou = single_iou(ml_mask[0], mask)
        random_iou = random_pred_expected_iou(perc)

        batch_iou_ml += ml_iou

        print("Image Metrics: ", "ML IoU: ", ml_iou, "Random Pred IoU: ", random_iou)

    batch_iou_ml = batch_iou_ml / n_samples
    print("Batch Metrics: ", "ML IoU: ", batch_iou_ml)