In [None]:
import os
import random
import torch
import numpy as np
from nnunet.inference.predict import load_model_and_checkpoint_files
from sklearn.metrics import f1_score

#os.system('jupyter nbconvert --to python Pytorch_model.ipynb')
from Pytorch_model import nnUnet, Multitask_nnUNet

#os.system('jupyter nbconvert --to python Pytorch_train_multitask.ipynb')
from Pytorch_train_multitask import train, evaluation

# Seed every RNG used by the notebook (PyTorch, Python's `random`, legacy NumPy global)
# so that data augmentation and weight initialisation are reproducible.
seed = 0
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn

In [None]:
# --- Data ---
# Volume size as multiples of 32: (128, 96, 192).
dim = tuple(32 * factor for factor in (4, 3, 6))
spacing = 4

# --- Image generator ---
# (low, high) ranges passed to the image generator.
scale = (0.80, 1.20)
sigma = (0.3, 0.8)
# Pretrained nnU-Net task id:
#   001: PET SUV - 002: PET without SUV - 003: PET/CT 2 channels - 004: PET/CT fusion
task = '001'

# --- Model ---
learning_rate = 1e-4
score_weight = 3
drop_encode = 0.5
l1_lambda_fc1 = l2_lambda_fc1 = 1e-3
weight_decay = 3e-5

# --- Training loop ---
batch_size = 1
nb_epoch = 40
num_workers = 10
patience = 3

In [None]:
# Locate the trained nnU-Net 3d_fullres model folder for the task selected in the
# config cell (`task`, e.g. '001' matches a folder such as 'Task001_...').
base = '/media/nguyen-k/nnUNet_trained_models/nnUNet/3d_fullres'
list_task = os.listdir(base)

# Collect every match (sorted, since os.listdir order is arbitrary) instead of silently
# keeping the last one, and fail here rather than leaving `folders` undefined/stale,
# which would otherwise only surface later in the training cell.
matching_tasks = sorted(t for t in list_task if task in t)
if not matching_tasks:
    raise FileNotFoundError(f"No nnU-Net task folder containing {task!r} in {base}")
if len(matching_tasks) > 1:
    raise ValueError(f"Task id {task!r} is ambiguous in {base}: {matching_tasks}")
folders = os.path.join(base, matching_tasks[0], 'nnUNetTrainerV2__nnUNetPlansv2.1')

In [None]:
# Cross-validation: for each patient split and each of the 5 folds, build a multitask
# network on top of the nnU-Net loaded from `folders`, train it, then evaluate it on the
# fold's validation list and on the shared test list.
for ind_list in range(1):

    test = 'Multitask_PET_L'+str(ind_list)
    dir_base = '/home/nguyen-k/Bureau/segCassiopet2/Comparatif/Archives_MTL3/'+test
    # Idempotent on re-run, and creates missing parents instead of printing an
    # OSError and continuing with a directory that does not exist.
    os.makedirs(dir_base, exist_ok=True)

    # NOTE(review): the split is read from .../segCassiopet/... here, but the ensemble
    # cell reads .../segCassiopet2/... — confirm which location is intended.
    path_list = '/home/nguyen-k/Bureau/segCassiopet/List_Patient_'+str(ind_list)
    list_test = list(np.load(path_list + '/Test/list_test.npy'))
    test_label_classe = np.load(path_list + '/Test/test_label_classe.npy')
    # Keep a copy of the exact test split next to the results.
    np.save(dir_base+'/list_test.npy', list_test)
    np.save(dir_base+'/test_label_classe.npy', test_label_classe)

    for fold in range(1, 6):
        print('LIST PATIENT', ind_list, ' - FOLD', fold)

        dir_p = dir_base+'/Fold'+str(fold)
        dir_p_1 = dir_p+'/Fig_seg_val'
        dir_p_2 = dir_p+'/Fig_seg_test'
        for directory in (dir_p, dir_p_1, dir_p_2):
            os.makedirs(directory, exist_ok=True)

        #TRAIN
        # The model is rebuilt every fold so each fold starts from the checkpoint,
        # not from weights trained during a previous fold.
        trainer, params_tr = load_model_and_checkpoint_files(folders, folds=None, mixed_precision=None, checkpoint_name="model_best")
        # In nnU-Net v1 this function only *reads* the checkpoints into params_tr;
        # trainer.network is still freshly initialised. Load the first checkpoint
        # (same call nnU-Net's own predict code uses) so the backbone is pretrained.
        trainer.load_checkpoint_ram(params_tr[0], False)

        nn_Unet = nnUnet(trainer.network)
        state_dict = trainer.network.state_dict()
        nn_Unet.load_state_dict(state_dict)
        MultitaskNet = Multitask_nnUNet(nn_Unet, drop_encode=drop_encode, n_classes=3).cuda()

        train(fold, MultitaskNet, nb_epoch, score_weight, l1_lambda_fc1, l2_lambda_fc1, dim, spacing, scale, sigma,
                    num_workers, drop_encode, batch_size, learning_rate, patience, weight_decay, dir_p, path_list, seed)

        #EVALUATION
        print('LIST PATIENT', ind_list, ' - FOLD', fold, ' - VALIDATION')
        path_train_val = os.path.join(path_list, 'Fold'+ str(fold))
        list_val = list(np.load(path_train_val+'/list_val.npy'))
        val_label_classe = np.load(path_train_val+'/val_label_classe.npy')
        evaluation(MultitaskNet, list_val, val_label_classe, scale, sigma, dim, spacing, num_workers, dir_p_1)

        #TEST
        print('LIST PATIENT', ind_list, ' - FOLD', fold, ' - TEST')
        test_prob = evaluation(MultitaskNet, list_test, test_label_classe, scale, sigma, dim, spacing, num_workers, dir_p_2)
        # Per-fold test probabilities, averaged later by the ensemble cell.
        np.save(os.path.join(dir_p, 'test_prob.npy'), test_prob)
        print(' ')

        # `del MultitaskNet` alone frees nothing: nn_Unet, trainer (trainer.network)
        # and state_dict (tensors sharing storage with the params) still reference
        # the fold's weights. Drop them all, then return cached blocks to the driver.
        del MultitaskNet, nn_Unet, trainer, params_tr, state_dict
        torch.cuda.empty_cache()

In [None]:
# Ensemble evaluation: soft-vote the per-fold test probabilities (mean over the folds),
# then save ROC / precision-recall figures and print F1 scores on the test split.
os.system('jupyter nbconvert --to python roc_Precision_Recall.ipynb')
# NOTE(review): wildcard import kept because this cell relies on it for compute_ROC_auc,
# plot_ROC_curve, compute_precision_recall, plot_precision_recall_curve and `plt`
# (plt is not imported anywhere else in this notebook).
from roc_Precision_Recall import *

n_folds = 5
n_classes = 3
class_colors = ['blue', 'red', 'black']

for ind_list in range(1):

    print('LIST PATIENT', ind_list)

    test = 'Multitask_PET_L'+str(ind_list)
    dir_base = '/home/nguyen-k/Bureau/segCassiopet2/Comparatif/Archives_MTL3/'+test

    # Mean of the fold-wise probabilities; the divisor follows the folds actually
    # loaded instead of a separately hard-coded 5 (and no longer shadows `sum`).
    fold_probs = [np.load(os.path.join(dir_base+'/Fold'+str(fold), 'test_prob.npy'))
                  for fold in range(1, n_folds + 1)]
    test_prob = np.mean(fold_probs, axis=0)

    path_list = '/home/nguyen-k/Bureau/segCassiopet2/List_Patient_'+str(ind_list)
    list_test = list(np.load(path_list + '/Test/list_test.npy'))
    # Prefer the labels the training cell saved next to its results: those are exactly
    # the labels test_prob was produced for (the training cell read its split from
    # .../segCassiopet/..., whereas path_list here points to .../segCassiopet2/...).
    saved_label_file = os.path.join(dir_base, 'test_label_classe.npy')
    if os.path.exists(saved_label_file):
        label_file = saved_label_file
    else:
        label_file = path_list + '/Test/test_label_classe.npy'
    test_label_classe = np.array(np.load(label_file), dtype=np.uint8)
    assert test_prob.shape == (len(test_label_classe), n_classes), \
        f'test_prob shape {test_prob.shape} does not match {len(test_label_classe)} labels x {n_classes} classes'

    # Predicted class = most probable class; one-hot matrix of the true labels.
    pred = np.argmax(test_prob, axis=1)
    mat_label = np.eye(n_classes)[test_label_classe]

    roc_auc, fpr, tpr = compute_ROC_auc(y_label=mat_label, y_predicted=test_prob, n_classes=n_classes)
    plt.clf()
    for classe, color in enumerate(class_colors):
        plot_ROC_curve(fpr, tpr, roc_auc, classe=classe, color=color)
    plt.savefig(dir_base+'/ROC.png')

    precision, recall, average_precision = compute_precision_recall(y_label=mat_label, y_predicted=test_prob, n_classes=n_classes)
    # Clear the ROC curves first; otherwise they are drawn into the PR figure as well.
    plt.clf()
    # Previously n_classes=2 here, inconsistent with the 3 classes computed above and
    # the 3 colours supplied; plot all classes.
    plot_precision_recall_curve(precision, recall, average_precision, n_classes=n_classes, color=class_colors)
    plt.savefig(dir_base+'/AUC.png')

    print('Micro F1 score ', f1_score(y_true=test_label_classe, y_pred=pred, average='micro'))
    print('Macro F1 score ', f1_score(y_true=test_label_classe, y_pred=pred, average='macro'))
    print('Weighted F1 score ', f1_score(y_true=test_label_classe, y_pred=pred, average='weighted'))