# Assumes the dataset and model have already been downloaded by train_sam.ipynb

In [2]:
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandFlipd,
    RandRotated,
    RandZoomd,
    ScaleIntensityd,
    EnsureTyped,
    Resized,
    RandGaussianNoised,
    RandGaussianSmoothd,
    Rand2DElasticd,
    RandAffined,
    OneOf,
    NormalizeIntensity,
    AsChannelFirstd,
    EnsureType,
    LabelToMaskd,
    EnsureChannelFirstd
    
)

# Validation preprocessing applied to {"image", "label"} sample dicts:
#   1. give the label a leading channel axis,
#   2. rescale image intensities (ScaleIntensityd defaults),
#   3. resize both to 512x512 -- 'area' for the image, 'nearest' for the label so that
#      no new label values are interpolated,
#   4. emit the image as a tensor (EnsureTyped).
# NOTE(review): AddChanneld is deprecated in newer MONAI releases; the usual replacement is
# EnsureChannelFirstd(keys=['label'], channel_dim='no_channel') -- verify before switching.
val_transform = Compose([
    AddChanneld(keys=['label']),
    ScaleIntensityd(keys=["image"]),
    Resized(
        keys=["image", "label"],
        spatial_size=(512, 512),
        mode=['area', 'nearest'],
    ),
    EnsureTyped(keys=["image"]),
])





In [3]:
# import h5py
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
import SimpleITK

import os  # imported here: the notebook's general `import os` cell runs after this one

# Dataset root. Set the FETAL_DATA_ROOT environment variable to point elsewhere instead of
# editing this absolute, machine-specific default.
root_path = Path(os.environ.get(
    "FETAL_DATA_ROOT",
    '/home/marawan/fetal_miccai2023/Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/',
))


class Fetal_dataset(Dataset):
    """Map-style dataset over (image, label) pairs read eagerly from .mha files.

    Each item is a dict with ``image``, ``label`` and ``low_res_label`` (the label resized to
    256x256 with nearest interpolation). ``label`` has its leading channel axis removed.

    Args:
        list_dir: pair ``(image_paths, label_paths)``. The two sequences are zipped, so they
            must be aligned element-wise.
        transform: optional dictionary transform applied to ``{"image": ..., "label": ...}``.
            It is expected to give the label a leading channel axis (e.g. via AddChanneld),
            because the resize and the ``[0]`` squeeze in ``__getitem__`` rely on it.
    """

    def __init__(self, list_dir, transform=None):
        self.transform = transform
        images = [SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(str(path))) for path in list_dir[0]]
        labels = [SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(str(path))) for path in list_dir[1]]
        # Plain list of tuples. np.array() over (image, label) pairs of different shapes builds a
        # ragged object array, which raises ValueError on NumPy >= 1.24.
        self.sample_list = list(zip(images, labels))
        # Resizes a channel-first label to 256x256. 'nearest' avoids interpolating new label values.
        self.resize = Compose([Resized(keys=["label"], spatial_size=(256, 256), mode=['nearest'])])

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        image, label = self.sample_list[idx]
        if self.transform:
            sample = self.transform({"image": image, "label": label})
        else:
            # Without a transform `sample` used to be unbound (UnboundLocalError). Nothing else
            # adds the label's channel axis, so add it here for the resize / `[0]` below.
            sample = {"image": image, "label": label[np.newaxis]}

        sample['low_res_label'] = self.resize({"label": sample['label']})['label'][0]
        # Drop the channel axis again so `label` and `low_res_label` have matching rank.
        sample['label'] = sample['label'][0]
        return sample


# Paths for cases 1..3999 (file stems zero-padded to five digits), one array per sub-folder.
image_files, label_files = (
    np.array([root_path / folder / f"{case:05d}.mha" for case in range(1, 4000)])
    for folder in ("image_mha", "label_mha")
)


In [7]:
# Fold to evaluate and the saved checkpoint stem (file name without .pth) to load for it.
fold_n, epoch_to_output = 1, 'epoch_32'

In [8]:
# Run configuration. Several of these values are not read in the visible inference cells;
# NOTE(review): they presumably mirror the training configuration -- confirm.
base_lr, weight_decay = 0.001, 7
num_classes = 2
batch_size = 64
multimask_output = True
warmup, warmup_period = 1, 500
max_epoch, save_interval = 400, 5
iter_num = 0
device, devices = 6, [6, 7]  # primary CUDA device index, and the DataParallel device list

In [20]:
import numpy as np
from torch.utils.data import DataLoader
from importlib import import_module
from segment_anything import sam_model_registry
import torch
import os
import cv2
from torchvision.transforms import GaussianBlur,RandomHorizontalFlip,RandomVerticalFlip


In [10]:
# Build SAM with the ViT-H backbone for 512x512 inputs and a `num_classes`-way mask output,
# starting from the original SAM checkpoint, then wrap it with LoRA adapters. The fine-tuned
# LoRA weights are loaded separately via `load_lora_parameters`.
# NOTE(review): pixel_mean=0 / pixel_std=1 presumably disables SAM's internal normalisation
# because ScaleIntensityd already rescales the image -- confirm against training.
sam, img_embedding_size = sam_model_registry['vit_h'](
    image_size=512,
    num_classes=num_classes,  # use the config value rather than a duplicated literal
    checkpoint='checkpoints/sam_vit_h_4b8939.pth',
    # checkpoint='./model_weights/sam_vit_b_01ec64.pth',
    pixel_mean=[0, 0, 0],
    pixel_std=[1, 1, 1],
)
pkg = import_module('sam_lora_image_encoder')
model = pkg.LoRA_Sam(sam, 4)  # NOTE(review): 4 is presumably the LoRA rank -- confirm in sam_lora_image_encoder

In [18]:
# LoRA checkpoint to evaluate (choose the epoch via epoch_to_output), and the directory that
# receives this checkpoint's predictions and reference labels.
path_to_epoch = "train/b16_wd7_results/{}/{}.pth".format(fold_n, epoch_to_output)
output_path = "sample_output/{}/{}".format(fold_n, epoch_to_output)

In [21]:
# Inference on the held-out fold. Writes predicted and reference masks as PNG files.
TTA = 1  # 1: average predictions on the original and horizontally flipped input; otherwise one pass
h_flip = RandomHorizontalFlip(1)  # p=1, so it always flips
v_flip = RandomVerticalFlip(1)    # not used below

n_size = len(image_files)
indices = []
all_indices = np.arange(0, 4000)

# Brings predictions to 256x256, the same size as the dataset's `low_res_label`.
# NOTE(review): assumes output masks are (B, 1, H, W) here. MONAI treats the first axis as the
# channel axis, hence the 3-D spatial size (1, 256, 256).
resize_to_256 = Compose([Resized(keys=["label"], spatial_size=(1, 256, 256), mode=['nearest'])])

# Fold `fold_n` covers dataset positions [800 * fold_n, 800 * (fold_n + 1) - 1).
# NOTE(review): the `- 1` drops the fold's last sample (799 instead of 800). It is kept so the split
# matches what the training notebook used; changing it alone could evaluate on a training sample.
test_index = np.arange(800 * fold_n, (fold_n + 1) * 800 - 1)
train_index = np.setxor1d(all_indices, test_index)[10:-10]  # not used below

db_val = Fetal_dataset(transform=val_transform, list_dir=(image_files[test_index], label_files[test_index]))

# shuffle=False: saved file k then corresponds to test_index[k]. Shuffling made the saved
# predictions untraceable to their source images and different on every run.
valloader = DataLoader(db_val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

print(f"The length of val set is: {len(db_val)} samples ({len(valloader)} batches)")

os.makedirs(os.path.join(output_path, 'labels'), exist_ok=True)
os.makedirs(os.path.join(output_path, 'prediction'), exist_ok=True)

# Make the cell re-runnable. nn.DataParallel does not forward custom methods such as
# load_lora_parameters, and wrapping again would nest DataParallel inside DataParallel.
if isinstance(model, torch.nn.DataParallel):
    model = model.module
model.load_lora_parameters(path_to_epoch)

model = torch.nn.DataParallel(model, device_ids=devices)
model.to(device)
model.eval()

with torch.no_grad():
    for i_batch, sampled_batch in enumerate(valloader):
        image_batch = sampled_batch["image"].to(device)
        label_batch = sampled_batch["label"].to(device)
        low_res_label_batch = sampled_batch['low_res_label'].to(device)

        # Sanity check on the input intensity range after ScaleIntensityd.
        assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}'

        if TTA == 1:
            # Flip test-time augmentation: predict on the flipped image, flip that prediction
            # back, and average it with the prediction on the original image.
            input_h_flipped = h_flip(image_batch)
            outputs = model(image_batch, multimask_output, 512)
            outputs_h_flip = model(input_h_flipped, multimask_output, 512)
            output_masks = (outputs['masks'] + h_flip(outputs_h_flip['masks'])) / 2
        else:
            outputs = model(image_batch, multimask_output, 512)
            output_masks = outputs['masks']

        # Per-pixel class index. Softmax is monotonic, so it does not change the argmax.
        output_masks = torch.argmax(torch.softmax(output_masks, dim=1), dim=1, keepdim=True)
        output_masks = resize_to_256({"label": output_masks})["label"]

        for i in range(image_batch.size(0)):
            # Position within the (unshuffled) validation set.
            sample_idx = (i_batch * batch_size) + i
            cv2.imwrite(f'{output_path}/prediction/out_{sample_idx}.png', output_masks[i][0].cpu().numpy())
            cv2.imwrite(f'{output_path}/labels/label_{sample_idx}.png', sampled_batch['low_res_label'][i].cpu().numpy())
        print(f'Batch {i_batch} Done')



The length of val set is: 13
Batch 12 Done
