In [None]:
from inference import segmentation_scores
import os
from torch.utils import data
import utils.globals as globals
import torch
from data import InferenceDataset,get_preprocessing_fn_without_normalization,get_preprocessing, PandaDataset
import numpy as np
import pandas as pd
from statsmodels.stats.descriptivestats import sign_test
import json

# Pick the compute device. The notebook prefers GPU index 1, but
# torch.cuda.set_device(1) raises "invalid device ordinal" on a machine with a
# single GPU, so fall back to GPU 0 there instead of crashing.
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
else:
    device = 'cpu'
print("Running on " + device)

In [None]:
###### BREAST #########
# Configuration for the breast test set: 4 classes, where class 0 ('other')
# gets a loss weight of 0.0 in c_weights.
class_names = ['other', 'tumor', 'stroma', 'inflammation']
class_no = 4
ignore_classes = 0
metrics_names = ['macro_dice', 'micro_dice', 'accuracy', 'brier']
c_weights = [0.0, 0.597, 0.655, 0.862]

# Dataset root, checkpoint of the chosen experiment run, and the Test split.
# NOTE: `dir` shadows the builtin of the same name inside this notebook.
dir = '/datasets/breast/expert/'
mode = 'smooth'
folder = 3
model_path = f'/experiments/{mode}/{folder}/models/best_model.pth'
imgs_dir = f'{dir}patches/Test'
masks_dir = f'{dir}masks/Test/expert'

# Per-image metric accumulators; every key starts with its own empty list.
all_results = {key: [] for key in (
    'macro_dice',
    'micro_dice',
    'dice_class_1_tumor',
    'dice_class_2_stroma',
    'dice_class_3_inflammation',
    'f1_class_1_tumor',
    'f1_class_2_stroma',
    'f1_class_3_inflammation',
    'prec_class_1_tumor',
    'prec_class_2_stroma',
    'prec_class_3_inflammation',
    'recall_class_1_tumor',
    'recall_class_2_stroma',
    'recall_class_3_inflammation',
    'brier',
    'ce',
    'accuracy',
)}

# Test-time pipeline without normalization; one image per batch, in file order.
preprocessing_fn = get_preprocessing_fn_without_normalization()
preprocessing = get_preprocessing(preprocessing_fn)
test_dataset = InferenceDataset(imgs_dir, masks_dir, preprocessing=preprocessing)
testloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

In [None]:
######### PANDA #########
# Configuration for the PANDA test patches: background plus four Gleason
# classes.
class_names = ['background', 'gleason1', 'gleason2', 'gleason3', 'gleason4']
class_no = 5
ignore_classes = 0
metrics_names = ['macro_dice', 'micro_dice', 'accuracy', 'brier']
# NOTE(review): c_weights has 6 entries while class_no is 5. CrossEntropyLoss
# needs exactly one weight per model output channel, so confirm how many
# channels the PANDA model actually produces.
c_weights = [0.1, 0.128, 0.9773, 0.9796, 0.9234, 0]

# Dataset root and checkpoint of the chosen experiment run.
# NOTE: `dir` shadows the builtin of the same name inside this notebook.
dir = '/datasets/PANDA/patches/'
mode = 'supervised_PANDA'
folder = 2
model_path = f'/experiments/{mode}/{folder}/models/best_model.pth'
imgs_dir = f'{dir}imgs'
masks_dir = f'{dir}masks/original'

# Per-image metric accumulators; every key starts with its own empty list.
all_results = {key: [] for key in (
    'macro_dice',
    'micro_dice',
    'dice_class_1_gleason1',
    'dice_class_2_gleason2',
    'dice_class_3_gleason3',
    'dice_class_4_gleason4',
    'f1_class_1_gleason1',
    'f1_class_2_gleason2',
    'f1_class_3_gleason3',
    'f1_class_4_gleason4',
    'prec_class_1_gleason1',
    'prec_class_2_gleason2',
    'prec_class_3_gleason3',
    'prec_class_4_gleason4',
    'recall_class_1_gleason1',
    'recall_class_2_gleason2',
    'recall_class_3_gleason3',
    'recall_class_4_gleason4',
    'brier',
    'ce',
    'accuracy',
)}

# Test-time pipeline without normalization; one image per batch, in file order.
preprocessing_fn = get_preprocessing_fn_without_normalization()
preprocessing = get_preprocessing(preprocessing_fn)
test_dataset = InferenceDataset(imgs_dir, masks_dir, panda=True, preprocessing=preprocessing)
testloader = data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)

In [None]:
# Run the model over the test set and collect two kinds of results:
#   * per-image metrics in all_results, only for images whose label mask has
#     at least one non-zero pixel;
#   * flattened int8 label/prediction arrays in labels/preds, which are scored
#     together afterwards into final_results.
# map_location lets a checkpoint saved on a GPU be loaded on `device` (incl. CPU).
# NOTE(review): torch.load unpickles a whole model object. Recent torch releases
# default to weights_only=True and may need weights_only=False here. Confirm
# against the installed version, and only load trusted checkpoints.
model = torch.load(model_path, map_location=device).to(device)
model.eval()

# Reset the accumulators so re-running this cell does not duplicate rows.
for metric_values in all_results.values():
    metric_values.clear()

labels = []
preds = []

# Built once rather than per image. The weight tensor lives on `device`
# (the original hard-coded .cuda(), which crashed on CPU-only machines).
# ignore_index=0 excludes label-0 pixels from the loss.
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(c_weights, dtype=torch.float32, device=device),
    ignore_index=0,
    reduction='mean',
)

with torch.no_grad():
    for test_img, test_label, _, _ in testloader:
        test_img = test_img.to(device=device, dtype=torch.float32)
        pred = model(test_img)

        # Predicted class per pixel = arg-max over dim 1
        # (assumes pred is (batch, classes, H, W) — TODO confirm).
        test_pred_np = pred.argmax(dim=1).cpu().numpy()
        test_label_np = test_label.cpu().numpy()

        # astype already returns a fresh array, so no extra .copy() is needed.
        preds.append(test_pred_np.astype(np.int8).flatten())
        labels.append(test_label_np.astype(np.int8).flatten())

        # Images whose mask is entirely 0 are skipped for per-image metrics.
        # With ignore_index=0 their loss would have no contributing pixels, so
        # it is only computed when it is actually used.
        if test_label_np.sum() != 0:
            loss = criterion(pred, test_label.to(device))
            all_results['ce'].append(loss.item())

            results = segmentation_scores(test_label_np, test_pred_np, metrics_names,
                                          class_names=class_names, class_no=class_no,
                                          weights=c_weights, ignore_class=ignore_classes)
            for key, value in results.items():
                all_results[key].append(value)

# Metrics over all collected labels/predictions (not averaged per image).
final_results = segmentation_scores(labels, preds, metrics_names,
                                    class_names=class_names, class_no=class_no,
                                    weights=c_weights, ignore_class=ignore_classes)
df = pd.DataFrame.from_dict(all_results)
df.to_csv(f'/experiments/{mode}/{folder}/{mode}_results.csv')


In [None]:
#breast
# Template for the breast summary: one independent empty list per metric key,
# in the same order as the breast all_results accumulator.
final_result = {key: [] for key in (
    'macro_dice',
    'micro_dice',
    'dice_class_1_tumor',
    'dice_class_2_stroma',
    'dice_class_3_inflammation',
    'f1_class_1_tumor',
    'f1_class_2_stroma',
    'f1_class_3_inflammation',
    'prec_class_1_tumor',
    'prec_class_2_stroma',
    'prec_class_3_inflammation',
    'recall_class_1_tumor',
    'recall_class_2_stroma',
    'recall_class_3_inflammation',
    'brier',
    'ce',
    'accuracy',
)}

In [None]:
#PANDA
# Template for the PANDA summary: one independent empty list per metric key,
# in the same order as the PANDA all_results accumulator.
final_result = {key: [] for key in (
    'macro_dice',
    'micro_dice',
    'dice_class_1_gleason1',
    'dice_class_2_gleason2',
    'dice_class_3_gleason3',
    'dice_class_4_gleason4',
    'f1_class_1_gleason1',
    'f1_class_2_gleason2',
    'f1_class_3_gleason3',
    'f1_class_4_gleason4',
    'prec_class_1_gleason1',
    'prec_class_2_gleason2',
    'prec_class_3_gleason3',
    'prec_class_4_gleason4',
    'recall_class_1_gleason1',
    'recall_class_2_gleason2',
    'recall_class_3_gleason3',
    'recall_class_4_gleason4',
    'brier',
    'ce',
    'accuracy',
)}

In [None]:
# Summarise the per-image metrics as [mean, population std] per metric key.
# The dict is rebuilt rather than appended to, so re-running this cell does
# not keep growing the lists. It is keyed directly off all_results, so it works
# whichever dataset template cell ran last. Empty metrics become [nan, nan]
# without numpy's "mean of empty slice" RuntimeWarning.
final_result = {
    metric: ([float(np.mean(values)), float(np.std(values))]
             if len(values) else [float('nan'), float('nan')])
    for metric, values in all_results.items()
}

In [None]:
def _json_default(obj):
    """Make numpy scalars/arrays JSON-serializable; reject anything else.

    json.dump cannot encode numpy types such as float32 or ndarray. This hook
    converts them to the equivalent builtin Python values.
    """
    if isinstance(obj, (np.generic, np.ndarray)):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


# 'balanced_metrics': scores over all pooled labels/preds (final_results);
# 'all_images': per-image [mean, std] summary (final_result).
dict_res = {'balanced_metrics': final_results, 'all_images': final_result}
with open(f'/experiments/{mode}/{folder}/{mode}_results.json', "w") as outfile:
    json.dump(dict_res, outfile, default=_json_default)

In [None]:
# Persist the per-image metrics table (one column per metric key).
# NOTE: the inference cell already writes this same file.
df = pd.DataFrame(all_results)
df.to_csv(f'/experiments/{mode}/{folder}/{mode}_results.csv')

In [None]:
# Load the per-image metric tables of the two runs being compared.
# The inference cell wrote them with to_csv's default index column.
# index_col=0 reads that column back as the index instead of adding a
# spurious 'Unnamed: 0' column.
smooth = pd.read_csv('/experiments/smooth/3/smooth_results.csv', index_col=0)
sc = pd.read_csv('/experiments/sc/3/sc_results.csv', index_col=0)


In [None]:
# Paired sign test per metric. statsmodels' sign_test(samp, mu0) compares one
# sample against a scalar mu0, so test the per-image differences (smooth - sc)
# against 0 instead of passing sc's column as mu0. Rows are paired by position.
# Each value is (M, p-value).
{metric: sign_test(smooth[metric].to_numpy() - sc[metric].to_numpy())
 for metric in ('brier', 'accuracy', 'ce')}

In [None]:
# Paired sign test on the aggregate Dice scores: per-image differences
# (smooth - sc) tested against 0, since sign_test's second argument is a
# scalar mu0. Rows are paired by position. Each value is (M, p-value).
{metric: sign_test(smooth[metric].to_numpy() - sc[metric].to_numpy())
 for metric in ('micro_dice', 'macro_dice')}

In [None]:
# Paired sign test on per-class Dice (breast classes): per-image differences
# (smooth - sc) tested against 0, since sign_test's second argument is a
# scalar mu0. Rows are paired by position. Each value is (M, p-value).
{metric: sign_test(smooth[metric].to_numpy() - sc[metric].to_numpy())
 for metric in ('dice_class_1_tumor', 'dice_class_2_stroma', 'dice_class_3_inflammation')}