# Downloading dataset and model

In [None]:
!wget https://zenodo.org/record/7851339/files/Pubic%20Symphysis-Fetal%20Head%20Segmentation%20and%20Angle%20of%20Progression.zip?download=1

In [3]:
!unzip 'Pubic Symphysis-Fetal Head Segmentation and Angle of Progression.zip?download=1'

Archive:  Pubic Symphysis-Fetal Head Segmentation and Angle of Progression.zip?download=1
   creating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00001.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00002.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00003.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00004.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00005.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00006.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00007.mha  
  inflating: Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/image_mha/00008.mha  
  inflating: Pubic Sy

In [5]:
!rm -rf 'Pubic Symphysis-Fetal Head Segmentation and Angle of Progression.zip?download=1'

In [None]:
!pip install -r requirements.txt

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
!mkdir checkpoints
!mv sam_vit_h_4b8939.pth checkpoints/
!mv epoch_159.pth checkpoints/

In [None]:
!gdown https://drive.google.com/u/0/uc?id=1Kx_vx9bcxJaiMYWAgljNtwtHcooUsq8m
!mv epoch_299.pth checkpoints/

# Training Configuration

In [None]:
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandFlipd,
    RandRotated,
    RandZoomd,
    ScaleIntensityd,
    EnsureTyped,
    Resized,
    RandGaussianNoised,
    RandGaussianSmoothd,
    Rand2DElasticd,
    RandAffined,
    OneOf,
    NormalizeIntensity,
    AsChannelFirstd,
    EnsureType,
    LabelToMaskd,
    EnsureChannelFirstd
    
)

import numpy as np
# Keys and per-key interpolation modes shared by the dictionary transforms
# below. In every mode pair the second entry ('nearest') is the one paired
# with the "label" key.
_IMG_AND_LABEL = ["image", "label"]
_BILINEAR_NEAREST = ['bilinear', 'nearest']
_AREA_NEAREST = ['area', 'nearest']
_OUT_SIZE = (512, 512)

# Settings common to the two spatial warps (elastic and affine).
_warp_common = dict(
    keys=_IMG_AND_LABEL,
    prob=0.5,
    scale_range=(0.2, 0.2),
    translate_range=(20, 20),
    padding_mode="zeros",
    mode=_BILINEAR_NEAREST,
)

# Random 2D elastic deformation applied jointly to image and label.
deform = Rand2DElasticd(
    spacing=(7, 7),
    magnitude_range=(1, 2),
    rotate_range=(np.pi / 6,),
    **_warp_common,
)

# Random affine warp applied jointly to image and label.
affine = RandAffined(rotate_range=(np.pi / 6), **_warp_common)

# TODO joaquin check transforms again
# Training pipeline. No LoadImaged step is used here: the transforms receive
# already-loaded arrays under the "image" and "label" keys.
_train_steps = [
    # Give the label a channel axis; the image is left as it arrives.
    AddChanneld(keys=['label']),
    ScaleIntensityd(keys=["image"]),

    # Geometric augmentation applied to image and label together.
    RandRotated(
        keys=_IMG_AND_LABEL,
        range_x=(-np.pi / 12, np.pi / 12),
        prob=0.5,
        keep_size=True,
        mode=_BILINEAR_NEAREST,
    ),
    RandFlipd(keys=_IMG_AND_LABEL, spatial_axis=1, prob=0.5),
    RandZoomd(keys=_IMG_AND_LABEL, min_zoom=0.9, max_zoom=1.1, prob=0.5, mode=_AREA_NEAREST),

    # Intensity augmentation on the image only.
    RandGaussianSmoothd(keys=["image"], prob=0.1, sigma_x=(0.25, 1.5), sigma_y=(0.25, 1.5)),
    RandGaussianNoised(keys=["image"], mean=0, std=0.1, prob=0.5),

    # Exactly one of the two warps is picked per sample (80% affine, 20% elastic).
    OneOf(transforms=[affine, deform], weights=[0.8, 0.2]),

    Resized(keys=_IMG_AND_LABEL, spatial_size=_OUT_SIZE, mode=_AREA_NEAREST),
    EnsureTyped(keys=["image"]),
]
train_transform = Compose(_train_steps)

# Validation pipeline: same preprocessing as training, without augmentation.
_val_steps = [
    AddChanneld(keys=['label']),
    ScaleIntensityd(keys=["image"]),
    Resized(keys=_IMG_AND_LABEL, spatial_size=_OUT_SIZE, mode=_AREA_NEAREST),
    EnsureTyped(keys=["image"]),
]
val_transform = Compose(_val_steps)



## Naive Dataloader

In [None]:
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
import SimpleITK
import matplotlib.pyplot as plt
# import h5py

# Root directory of the unzipped dataset; the image_mha/ and label_mha/ file
# lists built further down are resolved against it.
# NOTE(review): absolute path on the original author's machine — adjust it
# before running the notebook anywhere else.
root_path=Path('/home/interns/marawan/fetal_miccai2023/Pubic Symphysis-Fetal Head Segmentation and Angle of Progression/')

class Fetal_dataset(Dataset):
    """In-memory dataset of (image, label) arrays read from image files.

    All files are read eagerly with SimpleITK in ``__init__``. Each item is a
    dict with the transformed ``image``, the ``label`` with its channel axis
    removed, and ``low_res_label``, a 128x128 nearest-neighbour resize of the
    label (also without channel axis).

    Args:
        list_dir: pair ``(image_paths, label_paths)``; the two sequences are
            zipped, so they must be in the same order.
        transform: optional callable that takes and returns a dict with the
            keys ``"image"`` and ``"label"``. It is expected to return the
            label with a leading channel axis, because that axis is stripped
            again in ``__getitem__``.
    """

    def __init__(self, list_dir, transform=None):
        self.transform = transform  # using transform in torch!
        # should use h5py
        images = [SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(str(i))) for i in list_dir[0]]
        labels = [SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(str(i))) for i in list_dir[1]]

        # Keep the pairs in a plain list. Wrapping them in np.array() asks
        # NumPy to build one array out of images and labels; when their shapes
        # differ that is a ragged array, which recent NumPy versions reject.
        self.sample_list = list(zip(images, labels))

        self.resize = Compose([Resized(keys=["label"], spatial_size=(128, 128), mode=['nearest'])])

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        image, label = self.sample_list[idx]

        if self.transform:
            sample = self.transform({"image": image, "label": label})
        else:
            # Without a transform nothing has added the channel axis that the
            # resize and the [0] indexing below rely on, so add it here.
            sample = {"image": image, "label": label[None]}

        # Low-resolution label first (needs the channel axis), then drop the
        # channel axis from both label variants.
        sample['low_res_label'] = self.resize({"label": sample['label']})['label'][0]
        sample['label'] = sample['label'][0]
        return sample
    


# File paths 00001.mha ... 03999.mha (range(1, 4000) stops before 4000), kept
# as NumPy arrays so they can be selected with index arrays.
# NOTE(review): if the dataset also contains 04000.mha it is never listed
# here — confirm against the unzipped folder.
_case_names = [f"{case_id:05d}.mha" for case_id in range(1, 4000)]
image_files = np.array([root_path / "image_mha" / case_name for case_name in _case_names])
label_files = np.array([root_path / "label_mha" / case_name for case_name in _case_names])


# Training one fold

In [None]:
# Cross-validation fold trained in this run. Later cells derive the held-out
# index range and the output directories from this value.
fold_n=1 #train up to 5 folds

In [None]:
from segment_anything.modeling.mask_decoder import MLP,MaskDecoder
import numpy as np
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import DiceLoss, Focal_loss
from importlib import import_module
from segment_anything import sam_model_registry
import torch
import logging
import os


In [None]:
def calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, dice_weight:float=0.8):
    """Blend cross-entropy and Dice loss on the low-resolution logits.

    Args:
        outputs: model output mapping; only ``outputs['low_res_logits']`` is read.
        low_res_label_batch: label tensor matching the low-resolution logits.
            It is cast to ``long`` for the cross-entropy term and passed as-is
            to the Dice term.
        ce_loss: callable ``(logits, long_labels) -> loss``.
        dice_loss: callable ``(logits, labels, softmax=True) -> loss``.
        dice_weight: weight of the Dice term; cross-entropy gets ``1 - dice_weight``.

    Returns:
        Tuple ``(combined_loss, ce_term, dice_term)``.
    """
    logits = outputs['low_res_logits']

    ce_term = ce_loss(logits, low_res_label_batch.long())
    dice_term = dice_loss(logits, low_res_label_batch, softmax=True)

    ce_weight = 1 - dice_weight
    combined = ce_weight * ce_term + dice_weight * dice_term
    return combined, ce_term, dice_term

In [None]:
# Hyper-parameters and device selection read by the training cells below.
base_lr = 0.001  # learning rate reached at the end of warm-up; start of the decay
num_classes = 2  # DiceLoss is built with num_classes + 1; the mask decoder gets num_multimask_outputs=num_classes
batch_size = 16  # used for both the training and the validation loader
multimask_output = True  # forwarded unchanged to every model call
warmup=1  # truthy -> linear learning-rate warm-up over warmup_period iterations
max_epoch = 400
save_interval = 5  # NOTE(review): not referenced by the visible training loop
warmup_period=500  # warm-up length in iterations (not epochs)
# NOTE(review): despite its name this is the exponent of the polynomial
# learning-rate decay (and part of the snapshot folder name); the AdamW
# optimizer is created with a hard-coded weight_decay=0.1 instead.
weight_decay=7
device=6  # device index handed to .to(device)
devices=[6,7]  # device_ids handed to torch.nn.DataParallel

In [None]:
n_size = len(image_files)
indices = []
all_indices = np.arange(4000)

# Held-out fold: a contiguous index range starting at 800 * fold_n.
# NOTE(review): np.arange excludes its stop value, so the "- 1" makes the fold
# 799 samples long, and the last index of the 800-wide window ends up in the
# training split — confirm this is intended.
_fold_first = 800 * fold_n
_fold_stop = 800 * (fold_n + 1) - 1
test_index = np.arange(_fold_first, _fold_stop)

# Everything not held out, minus the first and last ten remaining indices.
# NOTE(review): the reason for the [10:-10] trim is not visible here — confirm.
_not_held_out = np.setxor1d(all_indices, test_index)
train_index = _not_held_out[10:-10]

# Checkpoints for this run are written below this folder.
snapshot_path = f'b{batch_size}_wd{weight_decay}_results/{fold_n}'
os.makedirs(snapshot_path, exist_ok=True)

_train_files = (image_files[train_index], label_files[train_index])
_val_files = (image_files[test_index], label_files[test_index])
db_train = Fetal_dataset(list_dir=_train_files, transform=train_transform)
db_val = Fetal_dataset(list_dir=_val_files, transform=val_transform)

# Both loaders use the same settings (the validation loader shuffles as well).
_loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
trainloader = DataLoader(db_train, **_loader_kwargs)
valloader = DataLoader(db_val, **_loader_kwargs)


# Build the SAM ViT-H backbone for 512x512 inputs from the downloaded
# checkpoint. The registry entry used here returns the model together with its
# image-embedding size.
sam, img_embedding_size = sam_model_registry['vit_h'](image_size=512,
                                                            num_classes=8,
                                                            checkpoint='checkpoints/sam_vit_h_4b8939.pth', pixel_mean=[0, 0, 0],
                                                            pixel_std=[1, 1, 1])
# Wrap SAM with the LoRA adapter defined in the local sam_lora_image_encoder
# module (second argument is presumably the LoRA rank — TODO confirm), then
# load the LoRA weights from checkpoints/epoch_299.pth.
pkg = import_module('sam_lora_image_encoder')
net = pkg.LoRA_Sam(sam, 4).to(device)
net.load_lora_parameters('checkpoints/epoch_299.pth')
#     net.sam.mask_decoder.iou_prediction_head = MLP(
#             net.sam.mask_decoder.transformer_dim,256, num_classes+1, 3
#         )
# Swap in a new MaskDecoder sized for this task. It reuses the existing decoder
# transformer; everything else in the decoder is a newly constructed module.
# NOTE(review): this happens after load_lora_parameters, so any decoder weights
# restored by that call (other than the transformer) are discarded — confirm
# this is intended.
net.sam.mask_decoder = MaskDecoder(transformer=net.sam.mask_decoder.transformer,
    transformer_dim=net.sam.mask_decoder.transformer_dim,num_multimask_outputs=num_classes
)

# Replicate the model across the GPUs listed in `devices`.
net = torch.nn.DataParallel(net, device_ids=devices)

model=net
# The replacement decoder was created after the earlier .to(device) call, so
# the whole model is moved to the device again here.
model.to(device)
print("The length of train set is: {}".format(len(db_train)))



model.train()

# Loss terms that calc_loss combines.
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes + 1)

_iters_per_epoch = len(trainloader)
max_iterations = max_epoch * _iters_per_epoch

# With warm-up enabled the optimizer starts at base_lr / warmup_period; the
# training loop then overwrites the learning rate after every step.
b_lr = base_lr / warmup_period if warmup else base_lr

# Only parameters that still require gradients are optimised.
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(_trainable_params, lr=b_lr, betas=(0.9, 0.999), weight_decay=0.1)

writer = SummaryWriter(f'results/{fold_n}/logs')
iter_num = 0
logging.info("{} iterations per epoch. {} max iterations ".format(_iters_per_epoch, max_iterations))

# Lowest mean validation Dice loss seen so far; 100 is just a large start value.
best_performance = 100
iterator = tqdm(range(max_epoch), ncols=70)

# NOTE(review): this scheduler is created but never stepped in the visible
# training loop, which sets the learning rate by hand instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)

# Gradient scaler for the float16 autocast forward pass.
scaler = torch.cuda.amp.GradScaler(enabled=True)

def _save_lora_checkpoint(net_to_save, path):
    """Save the LoRA parameters of ``net_to_save`` to ``path``.

    A DataParallel wrapper does not expose ``save_lora_parameters`` itself, so
    in that case the call is made on the wrapped ``.module``. Any error raised
    by the save itself propagates to the caller.
    """
    if hasattr(net_to_save, 'save_lora_parameters'):
        target = net_to_save
    else:
        target = net_to_save.module
    target.save_lora_parameters(path)


# Main loop: one training pass and one validation pass per epoch. LoRA weights
# are saved whenever the mean validation Dice loss improves, and once more
# after the final epoch.
for epoch_num in iterator:
    train_loss_ce = []
    train_loss_dice = []

    val_loss_ce = []
    val_loss_dice = []
    val_dice_score=[]
    for i_batch, sampled_batch in enumerate(trainloader):
        image_batch, label_batch = sampled_batch['image'], sampled_batch['label']  # [b, c, h, w], [b, h, w]
        low_res_label_batch = sampled_batch['low_res_label']
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        low_res_label_batch = low_res_label_batch.to(device)
        # Sanity check that intensities were rescaled by the transforms.
        assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}'


        # Mixed-precision forward pass; the loss is computed on the
        # low-resolution logits against the low-resolution label.
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            outputs = model(image_batch, multimask_output, 512)
            loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, 0.8)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()


        # Learning rate for the NEXT step: linear warm-up to base_lr, then a
        # polynomial decay whose exponent is the `weight_decay` setting.
        if warmup and iter_num < warmup_period:
            lr_ = base_lr * ((iter_num + 1) / warmup_period)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_
        else:
            if warmup:
                shift_iter = iter_num - warmup_period
                assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
            else:
                shift_iter = iter_num
            lr_ = base_lr * (1.0 - shift_iter / max_iterations) ** weight_decay  # learning rate adjustment depends on the max iterations
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

        writer.add_scalar('info/lr', lr_, iter_num)
        iter_num = iter_num + 1
        train_loss_ce.append(loss_ce.detach().cpu().numpy())
        train_loss_dice.append(loss_dice.detach().cpu().numpy())

        logging.info('iteration %d : loss : %f, loss_ce: %f, loss_dice: %f ,lr:%f' % (iter_num, loss.item(), loss_ce.item(), loss_dice.item(),
                                                                                 optimizer.param_groups[0]['lr']))


    train_loss_ce_mean = np.mean(train_loss_ce)
    train_loss_dice_mean = np.mean(train_loss_dice)

    # The logged "total" is the plain sum of the two epoch means, not the
    # weighted combination that is optimised.
    writer.add_scalar('info/total_loss', train_loss_ce_mean+train_loss_dice_mean, iter_num)
    writer.add_scalar('info/loss_ce', train_loss_ce_mean, iter_num)
    writer.add_scalar('info/loss_dice', train_loss_dice_mean, iter_num)
    model.eval()
    with torch.no_grad():
        for i_batch, sampled_batch in enumerate(valloader):
            image_batch, label_batch = sampled_batch["image"].to(device), sampled_batch["label"].to(
                device)
            low_res_label_batch = sampled_batch['low_res_label']
            low_res_label_batch = low_res_label_batch.to(device)

            assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}'
            outputs = model(image_batch, multimask_output, 512)
            loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, 0.8)

            val_loss_ce.append(loss_ce.detach().cpu().numpy())
            val_loss_dice.append(loss_dice.detach().cpu().numpy())

            # Show image, label and prediction for the first sample of every
            # 100th validation batch.
            if i_batch % 100 == 0:
                fig,ax = plt.subplots(1,3,figsize=(12,6))
                ax[0].imshow(sampled_batch['image'][0].cpu().numpy().transpose(1,2,0))
                ax[0].set_title('image')
                ax[1].imshow(sampled_batch['label'][0])
                ax[1].set_title('label')
                output_masks = outputs['masks']
                output_masks = torch.argmax(torch.softmax(output_masks, dim=1), dim=1, keepdim=True)

                ax[2].imshow(output_masks[0].cpu()[0])
                ax[2].set_title('prediction')
                plt.show()
                # Release the figure so figures do not pile up over the epochs.
                plt.close(fig)


        val_loss_ce_mean = np.mean(val_loss_ce)
        val_loss_dice_mean = np.mean(val_loss_dice)

        writer.add_scalar('info/val_total_loss', val_loss_ce_mean+val_loss_dice_mean, iter_num)
        writer.add_scalar('info/val_loss_ce', val_loss_ce_mean, iter_num)
        writer.add_scalar('info/val_loss_dice', val_loss_dice_mean, iter_num)

        logging.info('epoch %d : val loss : %f, val loss_ce: %f, val loss_dice: %f' % (epoch_num, val_loss_ce_mean+val_loss_dice_mean,
                                                                           val_loss_ce_mean, val_loss_dice_mean))

    # Keep a checkpoint for every new best (lowest) mean validation Dice loss.
    if val_loss_dice_mean < best_performance:
        best_performance=val_loss_dice_mean
        save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
        _save_lora_checkpoint(model, save_mode_path)
        logging.info("save model to {}".format(save_mode_path))

    # Always keep the weights of the final epoch.
    if epoch_num >= max_epoch - 1:
        save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
        _save_lora_checkpoint(model, save_mode_path)
        logging.info("save model to {}".format(save_mode_path))
        iterator.close()
    model.train()

writer.close()
