## Inference.

Before running this script, 

1. Please run step1.ipynb to generate the meta-learned initialization for the INR.
2. Please run step2.ipynb with `SAVE_FEATURE_VECS = True`. This will generate INR feature vectors for the train, val, and test sets and store them in './dumps/intermediate_vectors'. This is done for computational simplicity.
3. Re-run step2.ipynb with `SAVE_FEATURE_VECS = False`. This will start fine-tuning the segmentation head of metaseg using these saved feature vectors. Please note that training the segmentation head for 3D data is a long and tiring process; reaching convergence takes approximately a day. To get the best results, use smaller learning rates (<= 5e-5) and let the optimization run for longer. Depending on your data, you may have to adjust the $\gamma$ parameter of the focal loss. Please read how $\gamma$ controls the properties of the focal loss in "Focal Loss for Dense Object Detection" by Lin et al., ICCV 2017.

4. Once your model is trained, you can run the inference code below. It tests directly on the test-set feature vectors saved as the output of `step2.ipynb`. This script does not render the 3D point clouds.

5. In case you want to render 3D point clouds, please fit the signal using the initialization obtained from `step1.ipynb`.




In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only GPU 0 to CUDA. This cell runs ahead of the `import torch` in the
# next cell, so the restriction is set before PyTorch is loaded.
os.environ['CUDA_VISIBLE_DEVICES']  = '0'

In [None]:
import torch
# Seed PyTorch's default random number generator with a fixed value so every
# run of the notebook starts from the same RNG state.
torch.manual_seed(422)


In [None]:
import os, os.path as osp
import sys

import numpy as np
import glob
import matplotlib.pyplot as plt
import torch
import alpine
from copy import deepcopy
import torch.nn as nn

In [None]:
import numpy as np
import os, os.path as osp
import torch

from matplotlib import pyplot as plt
import libINR.models, libINR.utils
from tqdm.autonotebook import tqdm

In [None]:
import sys
sys.path.append('../../modules')
sys.path.append("../../")
from learner import INRMetaLearner
import dataloaders

import models
import loss_functions
import metrics
import utils
import vis

In [None]:
# Notebook-wide settings.
# NOTE(review): of these, the cells visible in this notebook read only
# NUM_CLASSES_AND_ONE, RES / VAL_RES and NORMALIZE_FEATURES; the remaining
# names are presumably kept for parity with the training notebooks.
INNER_STEPS = 2
RANDOM_AUGMENT = False
VAL_STEPS = 100
TEST_RUN_STEPS = VAL_STEPS
SKIP_PIXELS = 2
VAL_META_STEPS = 100
OUTER_LOOP_ITERATIONS = 5000  # 5000
NUM_CLASSES = 4
# One more output channel than NUM_CLASSES (presumably background - TODO confirm).
NUM_CLASSES_AND_ONE = NUM_CLASSES + 1
# RES = (160, 192, 224)

# Full-resolution grid; it is replaced right away by the version that keeps
# every SKIP_PIXELS-th sample along each axis.
RES = (160, 160, 200)
RES = [axis_len // SKIP_PIXELS for axis_len in RES]
VAL_RES = RES  # deliberately the very same list object as RES

NORMALIZE_FEATURES = False

In [None]:
# Sanity check of the working resolution. RES and VAL_RES were bound to the
# same list in the settings cell, so the same value is printed twice.
print(RES, VAL_RES)

In [None]:
# JSON config file; judging by its name it holds the OASIS 3D dataset splits
# (TODO confirm).
# NOTE(review): no cell visible in this notebook reads `config_file`.
config_file = "../../config/oasis_splits_3d.json"

In [None]:
# NOTE(review): `nonlin` is not read by the visible cells; the model cell
# passes inr_type='siren' as a literal.
nonlin = 'siren'

# Keyword configuration handed to the INR backbone.
inr_config = {
    "in_features": 3,
    "out_features": 1,
    "hidden_features": 256,
    "hidden_layers": 4,
    # "first_omega_0": 200.0, "hidden_omega_0": 200.0,
}

# Keyword configuration handed to the segmentation head.
segmentation_config = {
    'hidden_features': [256],  # alternative tried: [128, 64]
    'output_features': NUM_CLASSES_AND_ONE,
}

In [None]:
# Build the combined INR + segmentation-head model from the two config dicts.
inr_seg_model = models.SirenSegINR(
    inr_type='siren',
    inr_config=inr_config,
    segmentation_config=segmentation_config,
    normalize_features=NORMALIZE_FEATURES,  # only change
)
# Cast to float and move to the GPU, as a separate step from construction.
inr_seg_model = inr_seg_model.float().cuda()



In [None]:
# 3D coordinate grid at the working resolution, cast to float, moved to the
# GPU and given a leading axis of size 1.
coords_tmp = alpine.utils.coords.get_coords3d(*RES).float().cuda().unsqueeze(0)
print(coords_tmp.shape)

In [None]:
# Location of the feature vectors dumped by step2.ipynb; only the "test"
# sub-folder is used here.
SAVE_PATHS = "./dumps/intermediate_vectors" # savepath from step2.ipynb
data_dir_feature_vecs = osp.join(SAVE_PATHS, "test")

test_clf_ds = dataloaders.CLFFeature(data_dir_feature_vecs, mode='test')

# One sample per batch, served in dataset order.
test_clf_dl = torch.utils.data.DataLoader(
    test_clf_ds,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
)

# Number of batches, then number of files the dataset reports.
print(len(test_clf_dl))
print(len(test_clf_ds.all_files))

In [None]:
# Checkpoint holding the fine-tuned segmentation-head weights (presumably
# written by the fine-tuning run of step2.ipynb - TODO confirm).
segmentation_clf_weights = "./weights/classifierfinal_weights_LR_5e-05_exp_gamma_3.0_INR_2it_skip_pixels_2_parallel_test.pth"

# Deserialise onto the CPU so loading does not depend on the GPU index the
# checkpoint was saved from; only one GPU is visible to this notebook.
# `load_state_dict` later copies the tensors into the model's own parameters,
# wherever those live.
# NOTE(security): torch.load unpickles the file - only load checkpoints you trust.
weights_saved = torch.load(segmentation_clf_weights, map_location="cpu")

# State dict of the segmentation head only.
final_clf_weights = weights_saved['final_clf_weights']

In [None]:
# Take an independent copy of the model's segmentation head (leaving
# `inr_seg_model` itself untouched), load the fine-tuned weights into the copy
# and switch it to evaluation mode for inference.
classifier_model = deepcopy(inr_seg_model.segmentation_head)
classifier_model.load_state_dict(final_clf_weights)
classifier_model.eval()

### Just compute DICE metrics for segmentation

In [None]:

# Run the fine-tuned segmentation head over every saved test-set feature
# volume. For each volume this cell
#   * predicts one label per voxel (argmax over the head's outputs),
#   * scores the whole volume with `metrics.multiclass_dice_score_3d` and
#     appends that score to `dice_scores`,
#   * scores every slice along the last spatial axis with
#     `metrics.multiclass_dice_score` and reports the mean of those scores.
dice_scores = []
num_test_samples = len(test_clf_dl)
pbar_test = tqdm(enumerate(test_clf_dl), total=num_test_samples)
for test_ix, test_data in pbar_test:

    # Drop the two leading singleton axes of the stored feature tensor.
    test_features = test_data['features'].float().cuda().squeeze(0).squeeze(0)
    test_img = test_data['img'].float().cuda()  # only inspected for its shape below
    test_seg = test_data['seg'].float().cuda()

    print(test_features.shape, test_img.shape, test_seg.shape)

    with torch.no_grad():
        classifier_output = classifier_model(test_features)
        # Softmax does not change which class has the largest score, so the
        # argmax is taken on the raw outputs. Everything stays on the GPU.
        pred_labels = classifier_output.argmax(dim=-1).reshape(VAL_RES)
        pred_onehot = nn.functional.one_hot(pred_labels, num_classes=NUM_CLASSES_AND_ONE)

    # Ground truth laid out as (X, Y, Z, channels) to line up with `pred_onehot`.
    test_seg_reshaped = test_seg.reshape(VAL_RES[0], VAL_RES[1], VAL_RES[2], -1)
    dice = metrics.multiclass_dice_score_3d(pred_onehot, test_seg_reshaped, num_classes=NUM_CLASSES_AND_ONE)
    dice_3d = float(dice.item())
    dice_scores.append(dice_3d)
    print("Dice score for test sample ", test_ix, " = ", dice_3d)

    # Per-slice 2D Dice along the last spatial axis.
    n_frames = pred_labels.shape[-1]
    local_dice = [
        float(
            metrics.multiclass_dice_score(
                nn.functional.one_hot(pred_labels[..., k], num_classes=NUM_CLASSES_AND_ONE),
                test_seg_reshaped[..., k, :],
                num_classes=NUM_CLASSES_AND_ONE,
            ).item()
        )
        for k in range(n_frames)
    ]

    print(f"Test Sample: {test_ix}/{num_test_samples} Dice={dice_3d:.5f}. Img Dice={np.mean(local_dice)}")
    pbar_test.set_description(f"Test Sample: {test_ix}/{num_test_samples} Dice={dice_3d:.5f}")
    pbar_test.refresh()
    print("---"*100)



In [None]:
# Summarise the per-sample 3D Dice scores collected in `dice_scores` by the
# test loop as mean +/- standard deviation.
print("Average Segmentation Dice = ", np.mean(dice_scores), "+/-", np.std(dice_scores))