In [6]:
import numpy as np
import torch
import torch.nn
from pathlib import Path
from torch.utils.data import DataLoader
from datasets import lits17_no_slice, brats20_no_slice, kits21_no_slice
from tqdm.notebook import tqdm
from netcal.metrics import ECE

In [7]:
def eval_ece(pred, target, bins=10):
    """Expected Calibration Error of `pred` against `target`.

    Thin wrapper around ``netcal.metrics.ECE``.

    Args:
        pred: predicted confidences; the evaluation loop in this notebook
            passes a flattened 1-D array of per-voxel probabilities.
        target: ground-truth labels with one entry per element of `pred`;
            the evaluation loop passes a flattened boolean mask.
        bins: number of confidence bins handed to ``ECE``. Defaults to 10,
            the value that was previously hard-coded.

    Returns:
        Whatever ``ECE.measure`` returns for these inputs (a scalar in the
        outputs recorded in this notebook).
    """
    ece = ECE(bins=bins)
    return ece.measure(pred, target)

In [8]:
def eval_uncertainty(pred, target, uncertainty):
    """Mean uncertainty over correct and over incorrect predictions.

    Args:
        pred: boolean mask of predicted positives.
        target: boolean mask of true positives, same shape as `pred`.
        uncertainty: per-element uncertainty values, indexable by a boolean
            mask of that shape.

    Returns:
        Tuple ``(mean_correct, mean_incorrect)`` of Python floats: the mean
        of `uncertainty` where `pred` agrees with `target` (TP or TN) and
        where it disagrees (FP or FN). If either group is empty its mean is
        undefined (NaN for numpy/torch inputs).
    """
    not_pred = ~pred
    not_target = ~target

    # Correct = true positives or true negatives.
    correct = (pred & target) | (not_pred & not_target)
    # Incorrect = false positives or false negatives.
    incorrect = (pred & not_target) | (not_pred & target)

    mean_correct = uncertainty[correct].mean().item()
    mean_incorrect = uncertainty[incorrect].mean().item()
    return mean_correct, mean_incorrect

In [9]:
def eval_dice(pred, target):
    """Dice coefficient between two boolean masks.

    Computes ``2 * |pred & target| / (|pred| + |target|)``.

    Args:
        pred: boolean mask of predicted positives.
        target: boolean mask of true positives, same shape as `pred`.

    Returns:
        The Dice score as a Python float in [0, 1]. When both masks are
        empty the ratio is 0/0; that case returns 1.0 (the two masks agree
        perfectly) instead of NaN, so one empty volume cannot poison a mean
        taken over many volumes.
    """
    tp = (pred & target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        # Nothing predicted and nothing to find: define as perfect overlap.
        return 1.0
    return (2 * tp / total).item()

In [10]:
def eval_brier(pred, target):
    """Brier score: mean squared difference between `target` and `pred`.

    Args:
        pred: array of predicted probabilities.
        target: array of ground-truth labels, broadcastable against `pred`
            (the evaluation loop passes a boolean mask).

    Returns:
        The mean of ``(target - pred) ** 2`` as a numpy scalar.
    """
    residual = target - pred
    return np.mean(residual * residual)

In [11]:
# NOTE(review): hard-coded absolute cluster path; presumably the directory
# that holds the training run outputs. Every task's 'root' is built from it.
ROOT = Path('/scratch/zc2357/cv/final/nyu-cv2271-final/baseline/runs/')

# One entry per evaluated model. Keys:
#   'name'      - label printed in the results; also the sub-directory of
#                 SAVEROOT from which the evaluation cell loads predictions.
#   'dataset'   - dataset object used for the train/val split and the
#                 validation DataLoader.
#   'enabled'   - only enabled tasks get a split and a 'val_dataloader'.
#   'in_channels', 'n_classes', 'root', 'encoder_path',
#   'decoder_mean_path', 'decoder_var_path'
#               - not read by any cell visible in this notebook; presumably
#                 model shape and checkpoint locations consumed by the
#                 inference code that produced the saved predictions
#                 (TODO confirm).
# The three '*_cotraining' entries share the same 'root' run directory.
tasks = [
    {
        'name': 'lits17_baseline',
        'dataset': lits17_no_slice,
        'in_channels': 1,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec12_04-00-51_gr017.nyu.cluster_lits17_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'lits17_epoch_55_step_5824_encoder.pth',
        'decoder_mean_path': 'lits17_epoch_55_epoch_55_loss_0.04442_decoder_mean.pth',
        'decoder_var_path': 'lits17_epoch_55_epoch_55_loss_0.04442_decoder_var.pth',
    },
    {
        'name': 'lits17_cotraining',
        'dataset': lits17_no_slice,
        'in_channels': 1,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec11_19-20-25_gr011.nyu.cluster_lits17_brats20_kits21_cotraining_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'lits17_epoch_59_step_6241_encoder.pth',
        'decoder_mean_path': 'lits17_epoch_59_epoch_59_loss_0.04083_decoder_mean.pth',
        'decoder_var_path': 'lits17_epoch_59_epoch_59_loss_0.04083_decoder_var.pth',
    },
    {
        'name': 'brats20_baseline',
        'dataset': brats20_no_slice,
        'in_channels': 4,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec11_14-53-08_gr011.nyu.cluster_brats20_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'brats20_epoch_21_step_6490_encoder.pth',
        'decoder_mean_path': 'brats20_epoch_21_epoch_21_loss_0.09565_decoder_mean.pth',
        'decoder_var_path': 'brats20_epoch_21_epoch_21_loss_0.09565_decoder_var.pth',
    },
    {
        'name': 'brats20_cotraining',
        'dataset': brats20_no_slice,
        'in_channels': 4,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec11_19-20-25_gr011.nyu.cluster_lits17_brats20_kits21_cotraining_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'brats20_epoch_14_step_4596_encoder.pth',
        'decoder_mean_path': 'brats20_epoch_14_epoch_14_loss_0.09117_decoder_mean.pth',
        'decoder_var_path': 'brats20_epoch_14_epoch_14_loss_0.09117_decoder_var.pth',
    },
    {
        'name': 'kits21_baseline',
        'dataset': kits21_no_slice,
        'in_channels': 1,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec11_05-24-43_gr038.nyu.cluster_kits21_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'kits21_epoch_25_step_6240_encoder.pth',
        'decoder_mean_path': 'kits21_epoch_25_epoch_25_loss_0.05592_decoder_mean.pth',
        'decoder_var_path': 'kits21_epoch_25_epoch_25_loss_0.05592_decoder_var.pth',
    },
    {
        'name': 'kits21_cotraining',
        'dataset': kits21_no_slice,
        'in_channels': 1,
        'n_classes': 2,
        'enabled': True,
        'root': ROOT / 'Dec11_19-20-25_gr011.nyu.cluster_lits17_brats20_kits21_cotraining_baseline_lr1.0e-04_weightDecay1.0e-02',
        'encoder_path': 'kits21_epoch_25_step_6241_encoder.pth',
        'decoder_mean_path': 'kits21_epoch_25_epoch_25_loss_0.05474_decoder_mean.pth',
        'decoder_var_path': 'kits21_epoch_25_epoch_25_loss_0.05474_decoder_var.pth',
    },
]

# DO NOT CHANGE; CHANGING THESE BREAKS REPLICATION
SEED = 42
TRAIN_VAL_SPLIT = 0.8  # 80% training
# /DO NOT CHANGE

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Give every enabled task a reproducible train/validation split and a
# DataLoader over its validation indices.
for task in tasks:
    if not task['enabled']:
        continue

    n_samples = len(task['dataset'])

    # A fresh RandomState per task, so each split depends only on SEED and
    # the dataset size, not on how many tasks were processed before it.
    shuffled = np.arange(n_samples)
    np.random.RandomState(SEED).shuffle(shuffled)

    cut = int(n_samples * TRAIN_VAL_SPLIT)
    task['train_idx'] = shuffled[:cut]
    task['val_idx'] = shuffled[cut:]

    # The index array itself serves as the sampler, so the loader yields the
    # validation samples one per batch in this fixed shuffled order.
    task['val_dataloader'] = DataLoader(
        task['dataset'],
        sampler=task['val_idx'],
        batch_size=1,
        num_workers=1,
    )

In [12]:
# NOTE(review): hard-coded absolute cluster path holding the saved
# per-volume predictions, one sub-directory per task name.
SAVEROOT = Path('/scratch/zc2357/cv/final/nyu-cv2271-final/baseline/inference/')

# Score the saved predictions of every enabled task, one validation volume at
# a time, and store the per-volume metric lists back on the task dict.
for task in tasks:
    # Only enabled tasks are given a 'val_dataloader' by the setup cell, so a
    # disabled task would raise KeyError below.
    if not task['enabled']:
        continue

    print(task['name'])
    pred_dir = SAVEROOT / task['name']

    # NOTE: despite the name this holds Dice *scores* (higher is better); the
    # name and the dict key are kept because the summary cell reads them.
    val_dice_loss = []
    val_uq_t = []
    val_uq_f = []
    val_brier = []
    val_ece = []
    # X is not used; the loader is iterated only for the ground truth y.
    for i, (X, y) in tqdm(enumerate(task['val_dataloader']), total=len(task['val_dataloader'])):
        # Saved arrays are looked up by enumeration index.
        # NOTE(review): assumes they were written in this same loader order - confirm.
        alea = np.load(pred_dir / ('alea_%03d.npy' % i))
        epis = np.load(pred_dir / ('epis_%03d.npy' % i))
        yhat = np.load(pred_dir / ('yhat_%03d.npy' % i))
        total_uncertainty = alea + epis
        yhat_bool = yhat > 0.5
        # Any label >= 1 counts as foreground; [0,0] drops the two leading axes
        # (batch_size=1 plus, presumably, a singleton channel axis).
        target = (y.cpu().numpy() >= 1)[0,0]

        # Flatten once and reuse for both calibration metrics.
        yhat_flat = yhat.flatten()
        target_flat = target.flatten()

        this_val_dice_loss = eval_dice(yhat_bool, target)
        val_dice_loss.append(this_val_dice_loss)

        this_val_uq_t, this_val_uq_f = eval_uncertainty(yhat_bool, target, total_uncertainty)
        val_uq_t.append(this_val_uq_t)
        val_uq_f.append(this_val_uq_f)

        this_val_brier = eval_brier(yhat_flat, target_flat)
        val_brier.append(this_val_brier)

        this_val_ece = eval_ece(yhat_flat, target_flat)
        val_ece.append(this_val_ece)

        print(this_val_dice_loss, this_val_uq_t, this_val_uq_f, this_val_brier, this_val_ece)

    task['val_dice_loss'] = val_dice_loss
    task['val_uq_t'] = val_uq_t
    task['val_uq_f'] = val_uq_f
    task['val_brier'] = val_brier
    task['val_ece'] = val_ece

lits17_baseline


  0%|          | 0/27 [00:00<?, ?it/s]

0.9344466055210187 0.008380726911127567 0.4584302008152008 0.004838357 0.0038510670903276487
0.9609969289175758 0.0072391098365187645 0.447378009557724 0.0034273937 0.002056968837211408
0.8937703634367181 0.009192978963255882 0.45450910925865173 0.00699488 0.006553716269382322
0.9461381012541504 0.0067254831083118916 0.5126076340675354 0.003058055 0.0027068106162251903
0.9291439080561339 0.01100191380828619 0.42051541805267334 0.006442762 0.005172451371251592
0.9548815198499424 0.008156044408679008 0.4121928811073303 0.0044301827 0.0031794464397117144
0.9298268663314685 0.009471721947193146 0.47167178988456726 0.004156042 0.0033890899251908992
0.9512317763974751 0.007682390045374632 0.4154490828514099 0.004562042 0.004311940746775625
0.900304497871276 0.013120855204761028 0.5031876564025879 0.010197328 0.011880778869423361
0.7602161325074064 0.011418555863201618 0.25638991594314575 0.0284871 0.02746512036373662
0.9249234132520243 0.005891166161745787 0.5403608083724976 0.0036589026 0.0

  0%|          | 0/27 [00:00<?, ?it/s]

0.9401986186674333 0.0054496522061526775 0.47381818294525146 0.0043715197 0.003338806477444286
0.9588342268337436 0.004974425304681063 0.4600936770439148 0.0035615247 0.0028409306273648547
0.9293451836754816 0.006423479877412319 0.5042099356651306 0.004298346 0.00341775130850975
0.9414437897224346 0.0056578353978693485 0.48113811016082764 0.0034344406 0.0032732359068719666
0.9321478193968379 0.009881701320409775 0.49968844652175903 0.0060153017 0.00537651895544863
0.9500368186458549 0.007012215908616781 0.43571752309799194 0.004925388 0.004317089071412794
0.9127049558379263 0.009921764023602009 0.49384403228759766 0.005112678 0.004345253359337782
0.9592483563712906 0.004957269877195358 0.4724712073802948 0.0036398678 0.002888121773922411
0.9069781579160668 0.010679945349693298 0.5193970799446106 0.009587669 0.0107043875707761
0.7708492782764342 0.008879635483026505 0.24310855567455292 0.025312489 0.02401079913321881
0.9570709175344345 0.004873075056821108 0.51103675365448 0.0022095963 

  0%|          | 0/74 [00:00<?, ?it/s]

0.9411498703895955 0.004053816199302673 0.540231466293335 0.0015387199 0.0011608095013141853
0.8607474540229495 0.00648141885176301 0.517217218875885 0.0034133308 0.00433004005917834
0.912797739200646 0.0021713220048695803 0.6042444109916687 0.00090023276 0.0007346253215853988
0.959146700768546 0.003075364977121353 0.6346980333328247 0.0010622203 0.000587325088855964
0.9168507387022016 0.003341040341183543 0.6218405961990356 0.0015262832 0.0009947056217502362
0.9484863643333492 0.004082457162439823 0.5446654558181763 0.0016951632 0.0012274743959778559
0.8288977633722655 0.006110924296081066 0.6445692181587219 0.003197552 0.0041940494725725965
0.91627725171055 0.005707348696887493 0.6738898158073425 0.002058624 0.0015832795949769312
0.9084182527214109 0.0066076950170099735 0.7250993251800537 0.0022827655 0.0012300627008504051
0.8874661167744833 0.003417902858927846 0.5593584179878235 0.0017105424 0.001611866711266132
0.9426595584411798 0.002961642574518919 0.5655246376991272 0.000974891

  0%|          | 0/74 [00:00<?, ?it/s]

0.943423184927694 0.006549411453306675 0.6188982725143433 0.0014979905 0.0004419498237040984
0.919970219206171 0.007038329262286425 0.6343194246292114 0.002197439 0.0009455698718857304
0.8203113461140815 0.008159779943525791 0.612375020980835 0.0021597433 0.003630841280282335
0.9588190010415252 0.0053583974950015545 0.7331980466842651 0.001120148 0.0009820100571155544
0.9206053901850362 0.006397091783583164 0.6154395937919617 0.0016507625 0.0021571671723204806
0.9519856714772845 0.006150558590888977 0.7014436721801758 0.0015692565 0.0005686052009073384
0.775204957742914 0.011915406212210655 0.679207980632782 0.004401627 0.006795758352875618
0.8793858728811097 0.0068206945434212685 0.5735806822776794 0.0032569796 0.003974188770730091
0.8671107700056362 0.01325242780148983 0.6540113687515259 0.003885001 0.005791117517477679
0.9242452739074962 0.005934530403465033 0.6475434303283691 0.0012496174 0.0005894491319087725
0.9453534178904004 0.005373718682676554 0.7628120183944702 0.000989819 0

  0%|          | 0/60 [00:00<?, ?it/s]

0.9344449788894233 0.0044127875007689 0.5262192487716675 0.001299802 0.0011530634038017002
0.9595173381207556 0.0045671830885112286 0.47905048727989197 0.0014801879 0.001151898647418629
0.9006549316645855 0.002739187330007553 0.42897897958755493 0.0020543963 0.0015915071418902706
0.9492724751332212 0.00209419266320765 0.44230225682258606 0.001158599 0.0009654050312034789
0.9484963268146485 0.0034819277934730053 0.5810431838035583 0.0011854133 0.000569470022809396
0.8825969099340617 0.00867080595344305 0.5327261686325073 0.0025725565 0.0012701959095748497
0.9080910845538328 0.002510919002816081 0.2399367392063141 0.0032211505 0.00343969289304938
0.9480324321063105 0.002844954142346978 0.4856770932674408 0.0013086149 0.000940324873144024
0.9457521539229037 0.002965449122712016 0.5567314028739929 0.0016020562 0.0010848022759871096
0.9570257749123511 0.003242677077651024 0.551547646522522 0.0008975637 0.0012173392887325047
0.9397954113891743 0.005422518122941256 0.5455324649810791 0.000927

  0%|          | 0/60 [00:00<?, ?it/s]

0.9154868620139257 0.0066172960214316845 0.4961584508419037 0.0017745034 0.002543278565380434
0.9642320981368795 0.006384300999343395 0.5500259399414062 0.0013649181 0.00039939683681740443
0.9079729607755603 0.006008959375321865 0.37300702929496765 0.0020762528 0.0014472756485037958
0.8763989605407262 0.005456016398966312 0.5345309972763062 0.002699399 0.002583922630992378
0.7906779478820358 0.0080393236130476 0.46280401945114136 0.005396814 0.005774561766907865
0.7922195343775921 0.015604643151164055 0.6073015332221985 0.0047444575 0.006380106406505703
0.9177359197789294 0.0031468530651181936 0.22178848087787628 0.0029403344 0.0027248174595354484
0.9514445025535448 0.0033560378942638636 0.4851606488227844 0.001263042 0.0006360319893490662
0.9426013944875626 0.006628467235714197 0.5413895845413208 0.0017873532 0.00090466033200183
0.9610625047275662 0.0029928868170827627 0.5044963359832764 0.0007334817 0.0010626125178097078
0.7901228387601223 0.008226252160966396 0.3220846354961395 0.00

In [13]:
# Print the per-task mean of every per-volume metric.
for task in tasks:
    # Disabled tasks are never evaluated, so they carry no 'val_*' lists;
    # skip them rather than raise KeyError.
    if not task['enabled']:
        continue
    print(task['name'])
    print('\tDice:\t\t\t', np.mean(task['val_dice_loss']))
    # nanmean: eval_uncertainty returns NaN for a volume whose correct (or
    # incorrect) group is empty, and one NaN would otherwise turn the whole
    # task average into NaN. Identical to np.mean when no NaN is present.
    print('\tEntropy, T pred:\t', np.nanmean(task['val_uq_t']))
    print('\tEntropy, F pred:\t', np.nanmean(task['val_uq_f']))
    print('\tBrier score:\t\t', np.mean(task['val_brier']))
    print('\tECE:\t\t\t', np.mean(task['val_ece']))

lits17_baseline
	Dice:			 0.9197711241835075
	Entropy, T pred:	 0.010655689035990724
	Entropy, F pred:	 0.44242752260631985
	Brier score:		 0.007616109
	ECE:			 0.006655385713672428
lits17_cotraining
	Dice:			 0.918851795453653
	Entropy, T pred:	 0.008248843545852988
	Entropy, F pred:	 0.4476901309357749
	Brier score:		 0.007604486
	ECE:			 0.006729530265604857
brats20_baseline
	Dice:			 0.8308539899157409
	Entropy, T pred:	 0.0042847375671782005
	Entropy, F pred:	 0.552987730583629
	Brier score:		 0.0026166819
	ECE:			 0.0025080163081123074
brats20_cotraining
	Dice:			 0.7852598200116716
	Entropy, T pred:	 0.008808589250956839
	Entropy, F pred:	 0.5937974686558182
	Brier score:		 0.0036324786
	ECE:			 0.004424604106504179
kits21_baseline
	Dice:			 0.9073150491174469
	Entropy, T pred:	 0.004616886749863625
	Entropy, F pred:	 0.4909048224488894
	Brier score:		 0.0026897795
	ECE:			 0.0024189402502915315
kits21_cotraining
	Dice:			 0.8929010340926725
	Entropy, T pred:	 0.0072447173472028

In [None]:
# Scratch cell: a bare literal with no side effects; safe to delete.
1

In [None]:
# Debug peek at the last prediction mask left behind by the evaluation loop;
# only works with that leftover kernel state. Safe to delete.
yhat_bool.shape

In [45]:
# Debug peek at the last `target` left behind by the evaluation loop.
# NOTE(review): the recorded 5-D output appears to predate the `[0,0]`
# indexing now applied in that loop - re-run or clear this cell.
target.shape

(1, 1, 400, 400, 208)

In [None]:
# Toy sanity-check input: six elements, the first three predicted positive.
pred = torch.Tensor([1,1,1,0,0,0]) == 1

In [None]:
# Toy ground truth: only element 2 is positive, so elements 0 and 1 of `pred`
# are false positives.
# NOTE: rebinds `target`, the name the evaluation loop also assigns.
target = torch.Tensor([0,0,1,0,0,0]) == 1

In [None]:
# Toy uncertainty: 0.5 on the predicted-positive elements, 0.1 elsewhere.
uncertainty = torch.Tensor([0.5,0.5,0.5,0.1,0.1,0.1])

In [7]:
# Sanity check on the toy tensors: the correct elements (2..5) average
# (0.5 + 3 * 0.1) / 4 = 0.2 and the incorrect ones (0, 1) average 0.5.
# NOTE(review): the execution count restarts here (In [7] after In [13]), so
# this output comes from a different kernel session than the cells above.
eval_uncertainty(pred, target, uncertainty)

(0.20000001788139343, 0.5)

In [8]:
# Sanity check on the toy tensors with four leading singleton axes added:
# one true positive, |pred| = 3, |target| = 1, so Dice = 2 * 1 / (3 + 1) = 0.5.
eval_dice(pred[None,None,None,None,...], target[None,None,None,None,...])

0.5