In [None]:
%matplotlib inline
%matplotlib widget

import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider

In [None]:

import operator
import cc3d
from fran.transforms.inferencetransforms import BacksampleMask
from fran.inference.inference_base import ArrayToSITK
from fran.utils.imageviewers import ImageMaskViewer
from fastai.vision.augment import ToTensor, Transform, store_attr, typedispatch
from fran.transforms.spatialtransforms import one_hot
from fran.utils.common import *
import SimpleITK as sitk
from monai.metrics import *
import functools as fl
import itertools as il

In [None]:


@ToTensor
def encodes(self,x:np.ndarray):
    """Convert a numpy array to a ``torch.uint8`` tensor.

    The array is first cast to ``np.uint8`` and the tensor is built from
    that cast copy.
    NOTE(review): values outside 0..255 do not survive the uint8 cast;
    confirm that callers only pass small label ids.
    """
    return torch.tensor(x.astype(np.uint8), dtype=torch.uint8)

@ToTensor
def encodes(self,x:sitk.Image):
    """Convert a SimpleITK image to a torch tensor via its numpy array form."""
    return torch.from_numpy(sitk.GetArrayFromImage(x))

In [None]:

@typedispatch
def img_shape(x:sitk.Image):
    """Return the size of a SimpleITK image, as reported by ``GetSize()``."""
    size = x.GetSize()
    return size

@typedispatch
def img_shape(x:torch.Tensor):
    """Return the shape of a tensor as a ``torch.Size``."""
    return x.size()

# The sitk.Image overload of ToTensor.encodes is registered once, in the earlier
# cell alongside the np.ndarray overload; it is not re-declared here.

In [None]:
def randomize_labels(labels):
    """Replace every non-zero label id with a random id drawn from 20..40.

    The array is modified in place and also returned. Label 0 is left
    untouched. Several components may draw the same random id.

    Args:
        labels: integer numpy array of component labels (0 = background).

    Returns:
        The same array object, relabelled.
    """
    if labels.size == 0:
        # Nothing to relabel; .max() would raise on an empty array.
        return labels
    # Match against a snapshot of the original ids: a freshly written id in
    # 20..40 must not be mistaken for an original label visited later.
    original = labels.copy()
    # +1 so the highest label id is relabelled too.
    for label in range(1, int(original.max()) + 1):
        labels[original == label] = random.randint(20, 40)
    return labels

In [None]:
# Script entry guard: the cells that follow are written one indent level deep,
# as the body of this guard.
if __name__ == "__main__":
    # Project handle for "lits"; proj_defaults is the summary object whose
    # folder and filename attributes the later cells read.
    P = Project(project_title="lits")
    proj_defaults= P.proj_summary

In [None]:
    # Load the experiment configuration (raytune disabled) and use its
    # metadata "fold" entry to look up the train/valid/test case ids in the
    # project's validation-folds file.
    configs_excel = ConfigMaker(proj_defaults.configuration_filename,raytune=False).config
    train_list, valid_list, test_list = get_fold_case_ids(
            fold=configs_excel['metadata']["fold"],
            json_fname=proj_defaults.validation_folds_filename,
        )



In [None]:
    # Gather the raw mask/image files and split them by the fold's case ids.
    _raw_folder = proj_defaults.raw_data_folder
    mask_files = list((_raw_folder / "masks").glob("*nii*"))
    img_files = list((_raw_folder / "images").glob("*nii*"))

    def _files_in_cases(files, case_ids):
        """Keep the files whose case id (parsed from the filename) is in case_ids."""
        return [
            fn for fn in files
            if get_case_id_from_filename(proj_defaults.project_title, fn) in case_ids
        ]

    masks_valid = _files_in_cases(mask_files, valid_list)
    masks_train = _files_in_cases(mask_files, train_list)
    # Validation images are located by reusing each validation mask's filename.
    imgs_valid = [_raw_folder / "images" / mask_file.name for mask_file in masks_valid]
    imgs_test = _files_in_cases(img_files, test_list)
    imgs_train = _files_in_cases(img_files, train_list)

In [None]:
    # Pick the predictions folder whose name ends with the run name, then take
    # the first file in it. Both [0] lookups raise IndexError when nothing
    # matches, and glob order is not guaranteed, so "first" is arbitrary.
    run_name = "LITS-122"
    preds_folder = list(proj_defaults.predictions_folder.glob(f"*{run_name}"))[0]
    pred_fns = list(preds_folder.glob("*"))
    pred_fn = pred_fns[0]
    # NOTE(review): case_id is not read by any cell visible in this notebook;
    # the mask/image lookup below uses a hard-coded case name instead.
    case_id = get_case_id_from_filename('lits',pred_fn)

In [None]:
    
    # Take the first training mask/image whose path contains 'lits-128'
    # (plain substring match, so e.g. 'lits-1280' would match as well).
    # NOTE(review): the case is hard-coded rather than derived from pred_fn;
    # confirm that the chosen prediction belongs to the same case.
    mask_fn = [fn for fn in masks_train if 'lits-128' in str(fn)][0]
    img_fn = [fn for fn in imgs_train if 'lits-128' in str(fn)][0]

In [None]:

    # Read the image volume and keep it in three forms:
    # SimpleITK image, numpy array, and torch tensor.
    img = sitk.ReadImage(img_fn)
    img_np= sitk.GetArrayFromImage(img)
    img_pt = torch.tensor(img_np)

In [None]:
    # Read the predicted segmentation and expose it as a numpy array.
    pred= sitk.ReadImage(pred_fn)
    pred_np = sitk.GetArrayFromImage(pred)

In [None]:
    # Load the ground-truth mask, zero the origin of image and mask, and export
    # both as NIfTI files under tmp/.
    import os  # local import: only needed to make sure the output folder exists
    mask = sitk.ReadImage(mask_fn)
    # SetOrigin mutates the objects in place, so later cells that reuse `img`
    # see the zeroed origin as well.
    img.SetOrigin((0,0,0))
    mask.SetOrigin((0,0,0))
    # The writer does not create missing folders; create tmp/ up front so the
    # cell also works in a fresh working directory.
    os.makedirs('tmp', exist_ok=True)
    sitk.WriteImage(img,'tmp/img.nii')
    sitk.WriteImage(mask,'tmp/mask.nii')

In [None]:
    # Array form of the ground-truth mask. mask_pt is built BEFORE label 1 is
    # zeroed below, and torch.tensor copies its input, so mask_pt keeps the
    # unmodified labels.
    mask_np= sitk.GetArrayFromImage(mask)
    mask_pt = torch.tensor(mask_np)
    # Remove label 1 from both volumes so only the remaining labels feed the
    # component analysis (presumably label 1 is the organ and higher labels
    # are lesions — TODO confirm).
    mask_np[mask_np==1] =0
    pred_np[pred_np==1] =0
    # Spacing is read from the prediction image.
    # NOTE(review): assumes the mask shares this spacing — confirm.
    spacings = pred.GetSpacing()
    voxvol = fl.reduce(operator.mul,spacings)  # product of the spacings = volume of one voxel (mm3)

In [None]:
    # Split ground truth and prediction into connected components; N_* is the
    # number of components found in each volume.
    labels_org, N_org = cc3d.connected_components(mask_np, return_N=True)
    labels_pred, N_pred = cc3d.connected_components(pred_np, return_N=True)
    # Scrambled copy of the predicted labels (used for display further down).
    lpr = randomize_labels(labels_pred.copy())

    # Per-component statistics; the first entry of each table is dropped so
    # that row i corresponds to component label i + 1.
    stats_org = cc3d.statistics(labels_org)
    stats_pred = cc3d.statistics(labels_pred)
    centroids_org = stats_org['centroids'][1:]
    centroids_pred = stats_pred['centroids'][1:]
    bbox_org = stats_org['bounding_boxes'][1:]
    bbox_pred = stats_pred['bounding_boxes'][1:]

In [None]:
    def respace_centroids_sitk_format(centroids_np,spacings_sitk):
        """Reverse the coordinate order of each centroid and scale by the spacing.

        Args:
            centroids_np: 2-D array with one centroid per row.
            spacings_sitk: per-axis spacing, given in the reversed axis order.

        Returns:
            Array of the same shape holding the rescaled centroids.
        """
        flipped = centroids_np[:, ::-1]  # reverse the columns of every row
        return flipped * spacings_sitk

In [None]:
    # Distance threshold applied to centroid distances further down.
    # NOTE(review): the original inline note said "3mm" but the value is 5 —
    # confirm which one is intended.
    z_max = 5
    # Centroids of the ground-truth (com) and predicted (cpm) components,
    # reordered and scaled by the voxel spacing.
    com = respace_centroids_sitk_format(centroids_org, spacings)
    cpm = respace_centroids_sitk_format(centroids_pred, spacings)

In [None]:
    # Viewer 1: image next to the ground-truth components.
    # Viewer 2: scrambled predicted labels next to the original predicted labels.
    ImageMaskViewer([img_np,labels_org],data_types=['img','mask']) # 
    ImageMaskViewer([lpr,labels_pred],data_types=['mask','mask']) # 

# Distances

In [None]:
    # Pick one predicted component by its 1-based label and measure the
    # Euclidean distance from its centroid to every ground-truth centroid.
    label_index = 7
    cp = cpm[label_index-1]
    distances = np.sqrt(np.square(cp - com).sum(axis=1))

In [None]:

    # Report the nearest ground-truth component (1-based) and its distance when
    # at least one lies within z_max; otherwise flag the predicted component.
    if (distances < z_max).any():
        print(np.argmin(distances)+1, distances.min())
    else:
        print("Label {} is a false positive lesion".format(label_index))

In [None]:

    # Jaccard: class count is the larger of the two component counts,
    # plus one for the background class.
    n_classes = np.maximum(N_pred,N_org)+ 1 # 1 bg class

In [None]:
    # Convert the ground-truth component labels to a SimpleITK image and write
    # it next to the other tmp/ exports. ArrayToSITK is constructed from img,
    # presumably to reuse its geometry — TODO confirm.
    AS = ArrayToSITK(img)
    label_sitk = AS.encodes(labels_org)
    sitk.WriteImage(label_sitk,'tmp/labels_org.nii')

In [None]:
    # Turn the predicted component labels into a tensor and pass it through
    # BacksampleMask (constructed from img).
    # NOTE(review): if the np.ndarray overload of ToTensor.encodes defined
    # earlier in this notebook handles these calls, labels are cast to uint8,
    # so component ids above 255 would not survive — confirm.
    T = ToTensor()
    BM = BacksampleMask(img)
    y = T.encodes(labels_pred)
    y = BM(y)
    # Same conversion again, this time through the transform's call interface.
    lt = T(labels_pred) 