## MetaSeg implementation for 2D MRI segmentation with 5 classes (4 foreground classes plus background).

Dataset: We use the [OASIS-MRI Neurite dataset](https://github.com/adalca/medical-datasets/blob/master/neurite-oasis.md).


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Restrict this process to physical GPU 4. CUDA_VISIBLE_DEVICES is only
# honoured if it is set before a CUDA context exists, so set it before
# anything (seeding included) can touch the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import torch

# Seed torch's random number generators (on all devices) for reproducibility.
torch.manual_seed(422)

In [None]:
import os, os.path as osp
import sys

import numpy as np
import json
import glob
import matplotlib.pyplot as plt
import torch
from copy import deepcopy
import torch.nn as nn

In [None]:
import numpy as np
import os, os.path as osp
import torch

from matplotlib import pyplot as plt

import alpine

# import libINR.models, libINR.utils
from tqdm.autonotebook import tqdm

In [None]:
import sys
sys.path.append('../modules')
sys.path.append("..")
from learner import INRMetaLearner
import dataloaders

import models
import loss_functions
import metrics
import utils
import vis

In [None]:
# JSON manifest describing the train/val/test splits (entries carry an 'img' key).
config_file = '../config/oasis_splits.json'


In [None]:
# Load the split manifest. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open(config_file, 'r') as split_fh:
    files_data = json.load(split_fh)
train_files = files_data['train']
val_files = files_data['val']
test_files = files_data['test']


# check overlap for any samples (by image path).
print(len(train_files), len(val_files), len(test_files))
set_train = {x['img'] for x in train_files}
set_val = {x['img'] for x in val_files}
set_test = {x['img'] for x in test_files}

overlap_train_val = set_train & set_val
overlap_train_test = set_train & set_test
overlap_val_test = set_val & set_test

if overlap_train_val or overlap_train_test or overlap_val_test:
    print("WARNING: OVERLAPPING DATA SPLITS")
    # Report how many images leak between each pair of splits.
    print(f"train/val={len(overlap_train_val)}, train/test={len(overlap_train_test)}, val/test={len(overlap_val_test)}")
else:
    print("No overlap in data splits. GOOD TO GO!!!!!!!!!!!")

In [None]:
# --- Experiment hyperparameters ---
INNER_STEPS = 2                # passed to INRMetaLearner as inner_steps
RANDOM_AUGMENT = False         # value of the dataloaders' 'augment' option
RES = 192                      # images are handled on a RES x RES grid
VAL_STEPS = 100                # inner-loop steps used when validating the meta-learner
TEST_RUN_STEPS = VAL_STEPS     # fitting epochs for per-image INRs (finetune data / test)
VAL_META_STEPS = 50            # validation interval (meta-training steps / finetune epochs)
OUTER_LOOP_ITERATIONS = 5000   # meta-training budget; divided by len(train_dl) to get epochs
NUM_CLASSES = 4                # foreground classes
NUM_CLASSES_AND_ONE = NUM_CLASSES + 1  # foreground classes + background

GAMMA = 1.0                    # focusing parameter of FocalSemanticLoss

NORMALIZE_FEATURES = False     # L2-normalise features before the segmentation head

In [None]:
# Hidden width shared by the INR and the segmentation head.
HF = 128
nonlin = 'siren'  # INR family name

# Coordinate network: 2 input coordinates -> 1 output channel.
inr_config = dict(
    in_features=2,
    out_features=1,
    hidden_features=HF,
    hidden_layers=4,
)
# Segmentation head: hidden widths given as a list, one logit per class
# (foreground classes + background).
segmentation_config = dict(
    hidden_features=[HF],
    output_features=NUM_CLASSES_AND_ONE,
)

In [None]:
# Joint model: a SIREN INR plus a segmentation head (exposed as
# `.segmentation_head`), cast to float32 and placed on the GPU.
seg_model_kwargs = dict(
    inr_type='siren',
    inr_config=inr_config,
    segmentation_config=segmentation_config,
    normalize_features=NORMALIZE_FEATURES,
)
inr_seg_model = models.SirenSegINR(**seg_model_kwargs)
inr_seg_model = inr_seg_model.float().cuda()

In [None]:
# Display a layer-by-layer parameter summary of the joint model.
from torchinfo import summary

model_summary = summary(inr_seg_model)
model_summary

In [None]:
# Meta-learning objective: MSE on the reconstructed image plus a focal loss on
# the segmentation logits. The same combined loss drives the inner loop, so no
# separate inner_loop_loss_fn is supplied.
# Optional extra term: 'dice_loss': loss_functions.DiceLossMonai()
meta_loss_terms = {
    'mse_loss': loss_functions.MSELoss(alpha=1),
    'focal_loss': loss_functions.FocalSemanticLoss(gamma=GAMMA),
}
meta_learner = INRMetaLearner(
    model=inr_seg_model,
    inner_steps=INNER_STEPS,
    config={'inner_lr': 1e-4, 'outer_lr': 1e-4},
    custom_loss_fn=loss_functions.LossFunction(meta_loss_terms),
    inner_loop_loss_fn=None,
)

In [None]:
# RES x RES grid of 2-D coordinates on the GPU, with a leading batch axis.
coords_tmp = alpine.utils.coords.get_coords2d(RES, RES)
coords_tmp = coords_tmp.float().cuda().unsqueeze(0)
print(coords_tmp.shape)

In [None]:
# All splits share the same coordinate grid. Random augmentation is applied to
# the training split only: augmenting val/test would make checkpoint selection
# and the reported test metrics noisy and non-reproducible.
train_ds = dataloaders.TorchMRIDataloader(json_file=config_file, mode='train', resolution=RES, coords=coords_tmp, config={'augment': RANDOM_AUGMENT})
val_ds = dataloaders.TorchMRIDataloader(json_file=config_file, mode='val', resolution=RES, coords=coords_tmp, config={'augment': False})
test_ds = dataloaders.TorchMRIDataloader(json_file=config_file, mode='test', resolution=RES, coords=coords_tmp, config={'augment': False})
print(len(train_ds), len(val_ds), len(test_ds))

# One image per step.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=1, shuffle=False)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=1, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False)

In [None]:
# Bookkeeping for checkpoint selection during meta-training; the best
# snapshot is the one with the highest mean validation Dice.
best_weights = deepcopy(meta_learner.model_params)
best_inr_weights, best_classifier_weights = None, None
val_dice_score, val_iou_scores, val_psnr_scores = [], [], []
best_val_psnr, best_val_dice_score = 0, 0

In [None]:
def _evaluate_meta_learner_on_val():
    """Adapt the current meta-learned initialisation to each validation image and score it.

    For every batch in ``val_dl`` the meta-learner runs ``VAL_STEPS`` inner-loop
    steps via ``render_inner_loop``. Predicted labels are the argmax of the
    softmaxed segmentation logits. They are one-hot encoded and scored with
    ``metrics.multiclass_dice_score`` against ``val['seg']`` (presumably one-hot
    as well). PSNR is computed from the pixel-wise MSE of the reconstruction.

    Returns:
        tuple[list[float], list[float]]: per-sample PSNR values and Dice scores.
    """
    psnr_scores = []
    dice_scores = []
    for val in val_dl:
        val_img = val['img'].float().cuda()
        val_coords = val['coords'].float().cuda()

        # The inner loop needs gradients, so this is deliberately not under no_grad.
        render = meta_learner.render_inner_loop(val_coords, val_img, inner_loop_steps=VAL_STEPS)
        seg_logits = render['output']['segmentation_output'].detach()
        seg_probs = nn.functional.softmax(seg_logits, dim=-1)
        seg_labels = seg_probs.argmax(dim=-1).reshape(RES, RES, -1).cpu().numpy()
        seg_onehot = utils.convert_tensor_to_onehot(torch.tensor(seg_labels.squeeze()), num_classes=NUM_CLASSES_AND_ONE)
        img_recon = render['output']['inr_output'][0].reshape(RES, RES, -1).detach().cpu().numpy()

        mse_val = np.mean((img_recon.flatten() - val_img[0].detach().cpu().numpy().flatten()) ** 2)
        psnr_scores.append(float(-10 * np.log10(mse_val)))

        dice = metrics.multiclass_dice_score(seg_onehot, val['seg'].detach().reshape(RES, RES, -1).cpu(), num_classes=NUM_CLASSES_AND_ONE)
        dice_scores.append(float(dice.item()))
    return psnr_scores, dice_scores


steps_per_epoch = len(train_dl)
# Run at least one pass even if the loader is longer than the iteration budget.
num_meta_epochs = max(1, OUTER_LOOP_ITERATIONS // steps_per_epoch)

for i in range(num_meta_epochs):
    pbar = tqdm(enumerate(train_dl), total=steps_per_epoch)
    for ix, data in pbar:
        # Validation is scheduled on a step counter that runs across epochs.
        # Keyed on the per-epoch index alone, it never ran when the loader had
        # <= VAL_META_STEPS batches, which left best_inr_weights as None.
        global_step = i * steps_per_epoch + ix

        img = data['img'].float().cuda()
        seg = data['seg'].float().cuda()
        coords = data['coords'].float().cuda()

        loss, loss_info = meta_learner.forward(coords, {'gt': img, 'seg': seg, 'resolution': data['resolution'], 'seg_integer': data['seg_integer'].cuda().long()})
        # Training PSNR from the reported MSE term (0.01 used if the term is absent).
        psnr = -10 * np.log10(loss_info.get('mse_loss', 0.01))
        pbar.set_description(f"Loss: {loss.item():.5f} PSNR = {psnr.item():.5f} Dice={loss_info.get('dice_loss',-1):.4f} FL={loss_info.get('focal_loss', -1):.5f} TV={loss_info.get('tv_loss',-1):.5f}")
        pbar.refresh()

        if global_step % VAL_META_STEPS == 0 and global_step > 0:
            val_psnr_scores, val_dice_score = _evaluate_meta_learner_on_val()
            val_iou_scores = []  # IoU is not computed; reset for parity with the other lists.

            if np.mean(val_dice_score) > best_val_dice_score:
                best_val_psnr = np.mean(val_psnr_scores)
                best_val_dice_score = np.mean(val_dice_score)
                best_weights = deepcopy(meta_learner.model_params)
                best_inr_weights = deepcopy(meta_learner.get_inr_parameters())
                best_classifier_weights = deepcopy(meta_learner.get_segmentation_parameters())
                best_idx = global_step  # global step of the selected checkpoint
                print('updated dice score to ', best_val_dice_score)

            print(f"Mean PSNR={np.mean(val_psnr_scores):.5f} +/- {np.std(val_psnr_scores):.5f}")
            print(f"Mean Dice={np.mean(val_dice_score):.5f} +/- {np.std(val_dice_score):.5f}")

if best_inr_weights is None:
    # No validation round ran or improved on the initial score. Fall back to
    # the final meta-learned weights so the fine-tuning and test cells still
    # have an initialisation to start from.
    print("WARNING: no best checkpoint recorded during meta-training; using the final weights.")
    best_weights = deepcopy(meta_learner.model_params)
    best_inr_weights = deepcopy(meta_learner.get_inr_parameters())
    best_classifier_weights = deepcopy(meta_learner.get_segmentation_parameters())

print("Best weights from Dice=", best_val_dice_score)

## Fine-tune MetaSeg's segmentation head

In [None]:
# Build the classifier fine-tuning set. For every training image, fit a fresh
# INR from the best meta-learned initialisation for TEST_RUN_STEPS epochs and
# cache its features[-2] (the classifier's input) together with the labels.
if best_inr_weights is None:
    raise RuntimeError("best_inr_weights is None: meta-training did not record a checkpoint to start from.")

classifier_finetune_data = []

inr_initialization_weights = deepcopy(best_inr_weights)
# Strip the joint-model "inr." prefix once, outside the loop. load_state_dict
# copies these tensors into each new model, so one prepared dict can be reused
# for every sample.
inr_init_state = {k.replace("inr.", ""): v.detach().clone() for k, v in inr_initialization_weights.items()}

pbar_gen_inr_traindata = tqdm(enumerate(train_dl), total=len(train_dl), position=1)
for train_ix, train_data in pbar_gen_inr_traindata:
    train_coords = train_data['coords'].float().cuda()
    train_img = train_data['img'].float().cuda()
    train_seg = train_data['seg'].float().cuda()

    inr_render_model = models.INR(**inr_config).float().cuda()
    inr_render_model.load_state_dict(inr_init_state)
    inr_render_model.compile()

    inr_render_model.fit(train_coords, train_img, epochs=TEST_RUN_STEPS, disable_tqdm=True)
    train_img_output, train_img_features = inr_render_model.forward_w_features(train_coords)

    classifier_finetune_data.append({
        'seg': train_seg.detach().clone(),
        'img': train_img_output.detach().clone(),
        'coords': train_coords.detach().clone(),
        'resolution': train_data['resolution'],
        'features': train_img_features[-2].detach().clone()  # input to classifier
    })

In [None]:


# Build the classifier validation set the same way as the fine-tuning set. For
# every validation image, fit a fresh INR from the meta-learned initialisation
# and cache its features[-2] (the classifier's input) with the labels.
classifier_validation_data = []

# Strip the joint-model "inr." prefix once; load_state_dict copies the tensors
# into each new model, so the prepared dict is safely reused per sample.
inr_init_state_val = {k.replace("inr.", ""): v.detach().clone() for k, v in inr_initialization_weights.items()}

pbar_gen_inr_valdata = tqdm(enumerate(val_dl), total=len(val_dl), position=1)
for val_ix, val_data in pbar_gen_inr_valdata:
    val_coords = val_data['coords'].float().cuda()
    val_img = val_data['img'].float().cuda()
    val_seg = val_data['seg'].float().cuda()

    inr_render_model_val = models.INR(**inr_config).float().cuda()
    inr_render_model_val.load_state_dict(inr_init_state_val)
    inr_render_model_val.compile()

    inr_render_model_val.fit(val_coords, val_img, epochs=TEST_RUN_STEPS, disable_tqdm=True)
    val_img_output, val_img_features = inr_render_model_val.forward_w_features(val_coords)

    classifier_validation_data.append({
        'seg': val_seg.detach().clone(),
        'img': val_img_output.detach().clone(),
        'coords': val_coords.detach().clone(),
        'resolution': val_data['resolution'],
        'features': val_img_features[-2].detach().clone()  # input to classifier
    })

In [None]:
CLASSIFIER_FINETUNE_EPOCHS = 4001


def _classifier_logits(features):
    """Run ``classifier_model`` on cached INR features.

    The features are cast to float32 on the GPU. When ``NORMALIZE_FEATURES`` is
    set, they are first L2-normalised along the last axis.
    """
    classifier_input = features.float().cuda()
    if NORMALIZE_FEATURES:
        classifier_input = nn.functional.normalize(classifier_input, dim=-1)
    return classifier_model(classifier_input)


# prepare classifier: start from the segmentation head selected during
# meta-training, renaming keys from the joint-model layout to the head's own.
classifier_weights = {
    k.replace("segmentation_head.segmentation_head", "segmentation_head"): v.detach().clone()
    for k, v in best_classifier_weights.items()
}
classifier_model = deepcopy(inr_seg_model.segmentation_head)
classifier_model.load_state_dict(classifier_weights)

classifier_opt = torch.optim.Adam(classifier_model.parameters(), lr=5e-5)
print(classifier_model, list(classifier_weights.keys()), list(classifier_model.state_dict().keys()))

finetune_classifier_loss_fn = loss_functions.LossFunction({'focal_loss': loss_functions.FocalSemanticLoss(gamma=GAMMA)})

final_classifier_weights = None
best_val_score = float('inf')

pbar_epochs = tqdm(range(CLASSIFIER_FINETUNE_EPOCHS), position=0)
for epoch in pbar_epochs:
    # Validation below switches to eval(); switch back before training.
    classifier_model.train()
    avg_loss_per_set = 0.0

    for train_data in classifier_finetune_data:
        train_seg = train_data['seg'].float().cuda()

        classifier_opt.zero_grad()
        classifier_output = _classifier_logits(train_data['features'])
        loss, loss_info = finetune_classifier_loss_fn({'output': {'segmentation_output': classifier_output.squeeze(1)}, 'seg': train_seg})
        loss.backward()
        classifier_opt.step()

        avg_loss_per_set += float(loss.item())
    # Average over the cached samples that were actually iterated.
    avg_loss_per_set /= len(classifier_finetune_data)
    pbar_epochs.set_description(f"Loss (clf): {avg_loss_per_set:.5f}. Best Val Loss(clf): {best_val_score:.5f}")
    pbar_epochs.refresh()

    if epoch % VAL_META_STEPS == 0:
        # eval() so mode-dependent layers (e.g. dropout/batch-norm, if any) behave as at test time.
        classifier_model.eval()
        with torch.no_grad():
            avg_val_loss = 0.0
            for val_data in classifier_validation_data:
                val_seg = val_data['seg'].float().cuda()
                classifier_val_output = _classifier_logits(val_data['features'])
                val_loss, val_loss_info = finetune_classifier_loss_fn({'output': {'segmentation_output': classifier_val_output.squeeze(1)}, 'seg': val_seg})
                avg_val_loss += float(val_loss.item())
            avg_val_loss /= len(classifier_validation_data)

        pbar_epochs.set_description(f"Loss (clf): {avg_loss_per_set:.5f} Val Loss (clf): {avg_val_loss:.5f}")
        pbar_epochs.refresh()

        if avg_val_loss < best_val_score:
            best_val_score = avg_val_loss
            final_classifier_weights = deepcopy(classifier_model.state_dict())
            tqdm.write(f'updated best val score to {best_val_score}')
        else:
            tqdm.write(f'No improvement in val score {epoch}')




## Test time: fit pixels, get labels!

In [None]:
# Test time: for every test image, fit a fresh INR from the meta-learned
# initialisation and classify its features[-2] with the fine-tuned head.
inr_initialization_weights = deepcopy(best_inr_weights)
# Strip the joint-model "inr." prefix once; load_state_dict copies the tensors
# into each new model, so the prepared dict is safely reused per sample.
inr_init_state = {k.replace("inr.", ""): v.detach().clone() for k, v in inr_initialization_weights.items()}

classifier_model.load_state_dict(final_classifier_weights)
classifier_model.eval()

pbar_test = tqdm(enumerate(test_dl), total=len(test_dl))
dice_scores = []
psnr_scores = []
# TEST_RUN_STEPS is intentionally NOT overridden here. The classifier was
# fine-tuned on features of INRs fitted for TEST_RUN_STEPS epochs, so test
# INRs must use the same fitting budget.
for test_ix, test_data in pbar_test:
    test_coords = test_data['coords'].float().cuda()
    test_img = test_data['img'].float().cuda()
    test_seg = test_data['seg'].float().cuda()

    test_inr_render_model = models.INR(**inr_config).float().cuda()
    test_inr_render_model.load_state_dict(inr_init_state)
    test_inr_render_model.compile()

    test_inr_render_model.fit(test_coords, test_img, epochs=TEST_RUN_STEPS, disable_tqdm=True)
    test_img_output, test_img_features = test_inr_render_model.forward_w_features(test_coords)

    with torch.no_grad():
        classifier_input = test_img_features[-2].detach().clone()
        if NORMALIZE_FEATURES:
            classifier_input = nn.functional.normalize(classifier_input, dim=-1)
        classifier_output = classifier_model(classifier_input)
        seg_probs = nn.functional.softmax(classifier_output, dim=-1)
        # Per-pixel predicted class indices on a (RES, RES, 1) grid.
        seg_labels = seg_probs.argmax(dim=-1).reshape(RES, RES, -1).cpu().numpy()

        seg_labels_onehot = utils.convert_tensor_to_onehot(torch.tensor(seg_labels.squeeze(-1)), num_classes=NUM_CLASSES_AND_ONE)

    psnr = metrics.psnr2(test_img.reshape(RES, RES).detach().cpu().numpy(), test_img_output.reshape(RES, RES).detach().cpu().numpy())
    dice = metrics.multiclass_dice_score(seg_labels_onehot, test_seg.reshape(RES, RES, -1).cpu(), num_classes=NUM_CLASSES_AND_ONE)
    psnr_scores.append(float(psnr))
    dice_scores.append(float(dice))
    # vis.plot_result_row(titles=['Seg', 'Seg(GT)'],
    #                     imgs=[
    #                             seg_labels,
    #                             test_data['seg_integer'].reshape(RES, RES,-1).cpu().numpy()
    #                         ], save=None, show=True)

    pbar_test.set_description(f"Test Sample: {test_ix}/{len(test_dl)} PSNR = {float(psnr):.5f} Dice={float(dice):.5f}")
    pbar_test.refresh()


print("Average Reconstruction PSNR = ", np.mean(psnr_scores), "+/-", np.std(psnr_scores))
print("Average Segmentation Dice = ", np.mean(dice_scores), "+/-", np.std(dice_scores))
