In [34]:
import os
import sys
import datetime
import math
import numpy as np
import matplotlib.pyplot as plt
import random
from imageio import imread
import json
import torch
torch.cuda.empty_cache()

from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

current_dir = os.path.dirname(os.path.realpath('__file__'))
import utils
from utils import plot_3d_slices
from utils import set_seeds
from utils import set_device
#from utils import get_optimizer_nn
from utils import init_weights_xavier
from utils import get_patch_size,generate_rgb_array
from utils import Log
from utils import topk_accuracy
data_dir = os.path.join(current_dir, 'data')

from training_pipnet_LR import get_network,get_optimizer_nn
sys.path.append(data_dir)
#from make_dataset import get_dataloaders
import make_dataset_LR
from make_dataset_LR import get_dataloaders,getAllDataloader,getAllDataset
# Construct the path to the models directory
models_dir = os.path.join(current_dir, 'models')

# Add the models directory to sys.path
sys.path.append(models_dir)
from resnet_features import video_resnet18_features
from pipnet import PIPNet,NonNegLinear
from train_model_custom import train_pipnet

from test_model import eval_pipnet

vis_dir=os.path.join(current_dir, 'visualization')
sys.path.append(vis_dir)
import vis_pipnet
#from vis_pipnet import visualize, visualize_topk
from vis_pipnet import get_img_coordinates,plot_rgb_slices,plot_local_explanation
import plotly.graph_objects as go
import xarray as xr
import plotly.express as px

from scipy.ndimage import binary_erosion
from monai.transforms import (
    Compose,
    Resize,
    RandRotate,
    Affine,
    RandAffine,
    RandGaussianNoise,
    RandZoom,
    RepeatChannel,
)
import math
import joblib
import h5py
from importlib import reload


In [35]:

# Hyper-parameters and data settings passed to get_network / get_optimizer_nn.
# NOTE: 'log_dir' used to be declared twice ('logs/balls_tan2_backbone1en4_bs30'
# first, 'logs/kFold3' later). A dict literal silently keeps the last value, so
# only 'logs/kFold3' was ever used; the dead duplicate has been removed and the
# effective value kept in the original (first) key position.
args={
    'log_dir':'logs/kFold3',
    'num_classes':1,
    'seed':42,
    'experiment_folder':'data/experiment_1',
    'lr':.0001,
    'lr_net':.0001,
    'lr_block':.0001,
    'lr_class':.0001,
    'lr_backbone':.0001,
    'weight_decay':0,
    'gamma':.1,
    'step_size':1,
    'batch_size':15,
    'epochs':160,
    'epochs_pretrain':30,
    'freeze_epochs':0,
    'epochs_finetune':10,
    'channels':3,
    'net':"3Dresnet18",
    'num_features':0,
    'bias':False,
    'out_shape':1,
    'disable_pretrained':False,
    'optimizer':'Adam',
    'state_dict_dir_net':'',
    "dic_classes":{False:0,True:1},
    'val_split':.05,
    'test_split':.2,
    'defaultFinetune':True,
    'lr_finetune':.05,
    'flipTrain':False,
    'stratSampling':True,
    'excludePatients':['735','322','531','523','876','552'],
    'log_power':1,
    'img_shape':[54,121,74],
    'wshape':5, # this is assigned mid-script and doesn't matter here
    'hshape':8, # these matter and should be changed to the correct values for the analyzed network
    'dshape':7,
    'backboneStrides':[1,2,2,2],
}

# Number of input channels fed to the backbone; the single-channel volume is
# repeated this many times. Taken from args so the two cannot drift apart.
channels=args['channels']
aug_prob = 1
rand_rot = 10                       # random rotation range [deg]
rand_rot_rad = rand_rot*math.pi/180 # random rotation range [rad]
rand_noise_std = 0.01               # std random Gaussian noise
rand_shift = 5                      # px random shift (max, per spatial axis)
min_zoom = 0.9
max_zoom = 1.1
transforms_dic = {
    'train': Compose([
        RandRotate(range_x=rand_rot_rad,
                   range_y=rand_rot_rad,
                   range_z=rand_rot_rad,
                   prob=aug_prob),
        RandGaussianNoise(std=rand_noise_std, prob=aug_prob),
        # Random translation of up to +/- rand_shift voxels along each axis.
        # The previous `Affine(translate_params=...)` applied the SAME fixed
        # shift to every training sample (a systematic offset w.r.t. the
        # un-shifted val/test data) rather than a random augmentation.
        RandAffine(translate_range=(rand_shift,
                                    rand_shift,
                                    rand_shift),
                   prob=aug_prob),
        RandZoom(min_zoom=min_zoom, max_zoom=max_zoom, prob=aug_prob),
        RepeatChannel(repeats=channels),
    ]),
    # All non-training splits only repeat the channel (no augmentation);
    # each gets its own Compose instance.
    **{split: Compose([RepeatChannel(repeats=channels)])
       for split in ('train_noaug',
                     'project_noaug',
                     'val',
                     'test',
                     'test_projection')},
}

# Pre-processing parameters encoded in the real-data HDF5 file names. In this
# notebook they only appear in the commented-out alternative paths below.
downSample=3.2
lowerBound=.15
# Alternative inputs (real data):
#inputData=f'data/FP923_LR_avgCrop_DS{int(downSample*10)}_point{int(lowerBound*100)}Thresh.h5'
#inputData=f'data/FP_LR_OPNorm_avgcrop_DS{int(downSample*10)}_point{int(lowerBound*100)}Thresh.h5'
# Synthetic "balls" dataset currently in use (plain string: it has no placeholders).
inputData='data/syntheticData_balls_LR_fixed.h5'


In [36]:
def sliceViewer(images,labels, key: str, title: str, height:int):
    """Build an interactive Plotly slice viewer with one animation frame per slice.

    :param images: either a 3-D torch tensor whose first axis is the slice axis
        (it is turned into RGB by repeating it over a new trailing channel axis),
        or an array already laid out as (slice, row, col, rgb).
    :param labels: coordinate values for the slice axis, one per slice.
    :param key: name of the slice dimension; also used as the animation axis.
    :param title: figure title.
    :param height: figure height in pixels (previously accepted but ignored).
    :return: the Plotly figure produced by ``px.imshow``.
    """
    if len(images.shape)==3:
        # (S, R, C) -> (1, S, R, C) -> repeat channel -> (3, S, R, C) -> (S, R, C, 3)
        newIm=RepeatChannel(3)(images.unsqueeze(0))
        newIm=torch.moveaxis(newIm,[0,1,2,3],[-1,0,1,2])
    else:
        newIm=images
    xrData = xr.DataArray(
        data   = newIm,
        dims   = [key, 'row', 'col', 'rgb'],
        coords = {key: labels}
    )
    # To hide the axes, also pass: yaxis_visible=False, yaxis_showticklabels=False,
    # xaxis_visible=False, xaxis_showticklabels=False
    layout_dict = dict(height=height)
    return px.imshow(xrData, title=title, animation_frame=key).update_layout(layout_dict)


In [37]:
# Select the compute device. Fall back to the CPU when CUDA is not available
# instead of failing on the first `.to(device)` call.
useGPU=True
devID=0
if useGPU and torch.cuda.is_available():
    device=torch.device(f'cuda:{devID}')
else:
    if useGPU:
        print("CUDA is not available; falling back to the CPU.")
    device=torch.device('cpu')

# Split fractions come from args so they stay in sync with the config cell.
dataloaders=get_dataloaders(dataset_h5path=inputData,
                            k_fold=5,
                            test_p=args['test_split'],
                            val_p=args['val_split'],
                            batchSize=args['batch_size'],
                            seed=args['seed'],
                            kMeansSaveDir="data/kMeans_DS32.json")

# get_dataloaders returns the loaders in this order.
(trainloader,
 trainloader_pretraining,
 trainloader_normal,
 trainloader_normal_augment,
 projectloader,
 valloader,
 testloader,
 test_projectloader) = dataloaders[:8]

# Full (unsplit) dataset, indexable by sample key, e.g. allData['20_L'].
allData=getAllDataset(inputData)
inputKeys=allData.subsetKeys

In [38]:
# Load one sample directly from the full dataset by key.
#arr,label=projectloader.dataset["109_R"]
arr,label=allData['20_L']
# Store the spatial shape (everything after the leading channel axis) in args;
# presumably read by get_network when the network is built in the next cell.
args['img_shape']=list(arr.shape[1:])
# Repeat the channel axis 3 times to get a 3-channel volume.
arr=RepeatChannel(repeats=3)(arr)
arr.shape



torch.Size([3, 54, 121, 74])

In [39]:
# Build the PIP-Net components from args and assemble the model.
network_layers = get_network(num_classes=args['num_classes'], args=args)
(feature_net,
 add_on_layers,
 pool_layer,
 classification_layer,
 num_prototypes) = network_layers[:5]
net = PIPNet(
        num_classes = args['num_classes'],
        num_prototypes = num_prototypes,
        feature_net = feature_net,
        args = args,
        add_on_layers = add_on_layers,
        pool_layer = pool_layer,
        classification_layer = classification_layer
        )
net = net.to(device=device)
# Use the same GPU index as `device`. A hard-coded [0] breaks as soon as
# devID != 0, because DataParallel expects the module on device_ids[0].
net = nn.DataParallel(net, device_ids = [devID])

Number of prototypes:  512


In [40]:
# Create the optimizers. get_optimizer_nn returns a sequence; its first five
# entries are the two optimizers and three parameter groups.
optimizer = get_optimizer_nn(net, args)
(optimizer_net,
 optimizer_classifier,
 params_to_freeze,
 params_to_train,
 params_backbone) = optimizer[:5]

Network is  3Dresnet18


In [41]:
# Inspect one sample: the full forward pass plus the intermediate backbone features
# and add-on feature maps. This cell only inspects values, so no autograd graph is
# kept (saves GPU memory).
# NOTE(review): net.eval() is only called in the evaluation cell further down, so the
# model is presumably still in train mode here; any BatchNorm layers would update
# their running statistics on these passes — TODO confirm this is intended.
key="20_R"
arr,label=projectloader.dataset[key]
xs=arr.unsqueeze(0).to(device)   # add a batch dimension of 1
with torch.no_grad():
    proto_features, pooled, out = net(xs)
    features = net.module._net(xs)
    proto_features = net.module._add_on(features)

### To implement a logit-style output (one unit per class, or a single unit for the binary case), the `get_network` function has to be recreated

In [42]:
# Standalone 3-D ResNet-18 feature extractor built with backboneStrides=[1,1,1,1]
# (the main network uses args['backboneStrides'], i.e. [1,2,2,2]).
# NOTE(review): `featureNet` is not used by the later cells here — presumably an
# experiment toward the logit-style output described above.
featureNet=video_resnet18_features(
        pretrained = not args['disable_pretrained'],
        backboneStrides=[1,1,1,1])

In [43]:
# Shape of the single-sample input batch (batch dimension of 1 added via unsqueeze(0)).
xs.shape

torch.Size([1, 3, 54, 121, 74])

In [44]:
# Run the backbone on its own, without gradients. feature_net is the module passed to
# PIPNet, presumably stored as net.module._net, so `test` should match `features`.
with torch.no_grad():
    test=feature_net(xs)

In [45]:
# Backbone output computed in the previous cell.
test

metatensor([[[[[0.0000e+00, 1.5382e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.0543e+00, 1.2814e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.4917e+00, 2.0564e+00, 9.5572e-01, 0.0000e+00, 0.0000e+00],
           ...,
           [1.2670e+00, 1.2737e+00, 1.2129e+00, 0.0000e+00, 0.0000e+00],
           [9.8586e-01, 1.3235e+00, 9.6965e-01, 0.0000e+00, 0.0000e+00],
           [1.2281e+00, 9.7179e-01, 4.8838e-01, 4.2377e-01, 0.0000e+00]],

          [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [9.2707e-01, 8.3819e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.9776e+00, 2.6259e+00, 1.5099e+00, 0.0000e+00, 0.0000e+00],
           ...,
           [2.6422e+00, 2.6556e+00, 1.3293e+00, 0.0000e+00, 0.0000e+00],
           [2.8108e+00, 3.2037e+00, 1.8240e+00, 2.6958e-01, 0.0000e+00],
           [2.1532e+00, 2.1020e+00, 1.2966e+00, 8.2301e-01, 3.9834e-01]],

          [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           

In [46]:
# Backbone features from net.module._net, detached for display.
features.detach()

metatensor([[[[[0.0000e+00, 1.5382e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.0543e+00, 1.2814e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.4917e+00, 2.0564e+00, 9.5572e-01, 0.0000e+00, 0.0000e+00],
           ...,
           [1.2670e+00, 1.2737e+00, 1.2129e+00, 0.0000e+00, 0.0000e+00],
           [9.8586e-01, 1.3235e+00, 9.6965e-01, 0.0000e+00, 0.0000e+00],
           [1.2281e+00, 9.7179e-01, 4.8838e-01, 4.2377e-01, 0.0000e+00]],

          [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [9.2707e-01, 8.3819e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [1.9776e+00, 2.6259e+00, 1.5099e+00, 0.0000e+00, 0.0000e+00],
           ...,
           [2.6422e+00, 2.6556e+00, 1.3293e+00, 0.0000e+00, 0.0000e+00],
           [2.8108e+00, 3.2037e+00, 1.8240e+00, 2.6958e-01, 0.0000e+00],
           [2.1532e+00, 2.1020e+00, 1.2966e+00, 8.2301e-01, 3.9834e-01]],

          [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           

In [47]:
# Add-on layer output computed from `features`.
proto_features

metatensor([[[[[3.1494e-04, 2.3234e-04, 6.7272e-04, 4.8546e-04, 7.9586e-04],
           [1.6236e-04, 1.8883e-04, 3.2239e-04, 1.6554e-04, 1.9886e-04],
           [9.2807e-04, 1.2435e-03, 1.1149e-03, 3.1123e-04, 5.0694e-04],
           ...,
           [1.2420e-03, 8.8812e-04, 1.6644e-03, 5.7298e-04, 9.2680e-04],
           [7.3252e-04, 4.3813e-04, 5.8092e-04, 2.1177e-04, 5.7553e-04],
           [1.9089e-03, 7.6891e-04, 8.1606e-04, 6.2296e-04, 8.7818e-04]],

          [[4.0600e-04, 3.1956e-04, 7.2677e-04, 6.2655e-04, 1.1489e-03],
           [2.8695e-04, 1.5419e-04, 3.4295e-04, 1.6680e-04, 4.1660e-04],
           [2.8212e-03, 3.7930e-03, 2.5336e-03, 3.9198e-04, 7.7235e-04],
           ...,
           [9.2153e-03, 5.7442e-03, 2.2163e-03, 5.6786e-04, 9.2189e-04],
           [7.8970e-03, 4.1134e-03, 1.5764e-03, 2.2814e-04, 4.3912e-04],
           [7.3897e-03, 3.1960e-03, 1.7070e-03, 9.2869e-04, 1.3890e-03]],

          [[4.1221e-04, 2.2994e-04, 6.2565e-04, 6.0654e-04, 1.0890e-03],
           

In [48]:
# Pooled output returned by the full forward pass net(xs).
pooled

tensor([[0.0165, 0.0244, 0.0305, 0.0260, 0.0477, 0.0322, 0.0327, 0.0339, 0.0297,
         0.0989, 0.0309, 0.0110, 0.0430, 0.0171, 0.0726, 0.0376, 0.0633, 0.0278,
         0.0347, 0.0141, 0.0179, 0.0500, 0.0145, 0.1788, 0.0104, 0.0390, 0.0172,
         0.0334, 0.0287, 0.0178, 0.0304, 0.0290, 0.0305, 0.0798, 0.0512, 0.0142,
         0.0905, 0.0451, 0.0433, 0.0791, 0.0293, 0.0059, 0.0247, 0.0109, 0.0186,
         0.0272, 0.0201, 0.0400, 0.0150, 0.0558, 0.2093, 0.0186, 0.0065, 0.0152,
         0.0681, 0.0442, 0.0507, 0.0230, 0.1390, 0.0158, 0.0264, 0.0340, 0.1425,
         0.0703, 0.1678, 0.0303, 0.0172, 0.0232, 0.0494, 0.0306, 0.0469, 0.0139,
         0.0532, 0.0484, 0.0456, 0.0235, 0.0362, 0.0180, 0.0205, 0.0154, 0.0098,
         0.0125, 0.0467, 0.0515, 0.0334, 0.0194, 0.1386, 0.0230, 0.0202, 0.0136,
         0.0106, 0.0225, 0.0178, 0.0228, 0.0283, 0.0525, 0.0753, 0.0235, 0.0744,
         0.0413, 0.0128, 0.0625, 0.0271, 0.1282, 0.0203, 0.0127, 0.0150, 0.0201,
         0.0402, 0.0434, 0.0

In [49]:
# Negative log-likelihood loss over log-probabilities; used by the tanh-based loss cell below.
criterion = nn.NLLLoss(reduction='mean').to(device)
# Scratch check: the first sample's single output paired with a 0 ("logit on this").
torch.tensor([out[0][0],0])### logit on this

tensor([0.5328, 0.0000])

In [50]:
# Draw one batch from the training loader.
batch=next(iter(trainloader))


In [51]:
# Shape of the input volumes (element 0 of the batch).
batch[0].shape

torch.Size([15, 3, 54, 121, 74])

In [52]:
# Forward pass on the full training batch (autograd enabled: no torch.no_grad() here).
proto_features, pooled, out = net(batch[0].to(device))

In [53]:
# Labels are element 2 of the training batch; move them to the model's device.
ys=batch[2].to(device)
ys

tensor([0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1], device='cuda:0')

In [54]:
# Raw model output for the batch: one row per sample, a single column here.
out

tensor([[0.5983],
        [0.3002],
        [0.4997],
        [0.5618],
        [0.4294],
        [0.5422],
        [0.3146],
        [0.3168],
        [0.5111],
        [0.5596],
        [0.4536],
        [0.3858],
        [0.5878],
        [0.4962],
        [0.3597]], device='cuda:0', grad_fn=<MmBackward0>)

In [55]:
# Binary loss for a single-output network: use p = tanh(out) as the class-1
# probability and feed log([1 - p, p]) to NLLLoss.
# NOTE(review): assumes out >= 0 so that p lies in [0, 1) — TODO confirm the
# classification layer cannot produce negative outputs.
smallNum=5e-8 # exact value doesn't really matter
# Squash tanh slightly so that p stays in (smallNum/2, 1 - smallNum/2) and both
# log(p) and log(1 - p) stay finite.
tanh_output = (1-smallNum)*torch.tanh(out)+smallNum/2
# Build the (batch, 2) log-probability matrix with tensor ops so that class_loss stays
# connected to the autograd graph. The previous list comprehension copied Python
# scalars into a new torch.tensor, which detached the loss (no grad_fn).
class_one_prob = tanh_output[:, :1]
criterion_input=torch.log(torch.cat([1-class_one_prob, class_one_prob], dim=1))
class_loss=criterion(criterion_input,ys)

In [56]:
class_loss

tensor(0.8716, device='cuda:0')

In [57]:
# tanh(out) > 0.5 exactly when out > atanh(0.5); this is the decision threshold used below.
torch.atanh(torch.tensor([.5]))

tensor([0.5493])

In [58]:
# Number of outputs per sample.
len(out[0])

1

In [59]:
# Predicted class per sample: 1 where tanh(out) > 0.5 (i.e. out > atanh(0.5)), else 0.
# Vectorized comparison on the device instead of a Python loop over rows; the
# result is an int64 tensor on out.device, as before.
atanhPoint5=math.atanh(0.5)
outPred=(out[:, 0] > atanhPoint5).long()
outPred

tensor([1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], device='cuda:0')

In [60]:
# Fraction of the batch whose predicted class matches the label.
correct = (outPred == ys).sum()
acc = correct.item() / float(len(ys))

In [61]:
# Batch accuracy from the previous cell.
acc

0.2

### Debugging issues in `eval_pipnet`

In [62]:
def acc_from_cm(cm: np.ndarray) -> float:
    """
    Accuracy implied by a square confusion matrix.

    :param cm: square confusion matrix; diagonal entries count correct predictions
    :return: trace(cm) / sum(cm), or 1 when the matrix contains no counts
    """
    assert len(cm.shape) == 2 and cm.shape[0] == cm.shape[1]

    total = np.sum(cm)
    if total == 0:
        # No samples seen: treat as perfect accuracy rather than dividing by zero.
        return 1
    return np.trace(cm) / total

In [63]:
# Step-by-step copy of the eval_pipnet evaluation loop, adapted to the
# single-output (num_classes < 2) binary setup.
net.eval()
# Keep an info dict about the procedure
info = dict()
# A single-output network is scored as a 2-class problem, so the confusion
# matrix is always at least 2x2.
n_cm_classes = max(2, net.module._num_classes)
single_output = net.module._num_classes < 2
# Single-output decision threshold: tanh(out) > 0.5 <=> out > atanh(0.5).
atanhPoint5 = 0.549306144334055
# Build a confusion matrix
cm = np.zeros((n_cm_classes, n_cm_classes), dtype = int)
global_top1acc = 0.
global_top5acc = 0.
global_sim_anz = 0.
global_anz = 0.
local_size_total = 0.
y_trues = []
y_preds = []
y_preds_classes = []
abstained = 0

test_iter = iter(testloader)

# Iterate through the test set
for i, (xs, ys) in enumerate(test_iter):

    xs, ys = xs.to(device), ys.to(device)

    with torch.no_grad():

        # Use the model to classify this batch of input data
        _, pooled, out = net(xs, inference = True)
        if single_output:
            # Expand the single score into two columns [threshold, score] so that
            # argmax yields class 1 exactly when score > threshold.
            score = out[:, :1]
            out = torch.cat([torch.full_like(score, atanhPoint5), score], dim = 1)
        max_out_score, ys_pred = torch.max(out, dim=1) # max, max_idx

        ys_pred_scores = torch.amax(F.softmax((torch.log1p(
            out**net.module._classification.normalization_multiplier)),
            dim = 1), dim = 1) # class confidence scores

        abstained += (max_out_score.shape[0] - \
                        torch.count_nonzero(max_out_score))

        # One copy of the classification weights per sample: (rows, batch, prototypes).
        repeated_weight = net.module._classification.weight.unsqueeze(
            1).repeat(1, pooled.shape[0], 1)

        sim_scores_anz = torch.count_nonzero(torch.gt(torch.abs(
            pooled*repeated_weight), 1e-3).float(), dim = 2).float()

        local_size = torch.count_nonzero(torch.gt(
            torch.relu((pooled*repeated_weight) - 1e-3).sum(dim = 1),
            0.).float(), dim = 1).float()

        local_size_total += local_size.sum().item()

        # sim_scores_anz has one row per classification-weight row. With a single
        # output there is presumably only row 0, so indexing by the predicted class
        # (0 or 1) would go out of range as soon as a sample is predicted as class 1.
        sim_rows = torch.zeros_like(ys_pred) if single_output else ys_pred
        correct_class_sim_scores_anz = torch.diagonal(torch.index_select(
            sim_scores_anz, dim = 0, index = sim_rows), 0)

        global_sim_anz += correct_class_sim_scores_anz.sum().item()

        almost_nz = torch.count_nonzero(torch.gt(
            torch.abs(pooled), 1e-3).float(), dim = 1).float()

        global_anz += almost_nz.sum().item()

        # Update the confusion matrix (rows: true class, columns: predicted class)
        cm_batch = np.zeros((n_cm_classes, n_cm_classes), dtype = int)

        for y_pred, y_true in zip(ys_pred, ys):

            cm[y_true][y_pred] += 1
            cm_batch[y_true][y_pred] += 1

        acc = acc_from_cm(cm_batch)

        # NOTE(review): with two score columns, top-5 relies on topk_accuracy
        # handling k > number of classes — TODO confirm in utils.topk_accuracy.
        (top1accs, top5accs) = topk_accuracy(out, ys, topk=[1,5])

        global_top1acc += torch.sum(top1accs).item()
        global_top5acc += torch.sum(top5accs).item()
        y_preds += ys_pred_scores.detach().tolist()     # predicted class' confidence scores
        y_trues += ys.detach().tolist()
        y_preds_classes += ys_pred.detach().tolist()    # predicted classes

# The last batch's `out`, `pooled`, `ys_pred`, ... are intentionally left in scope
# so the following cells can inspect them.

In [64]:
# Last test batch's scores from the evaluation loop (two columns in the single-output case: [threshold, output]).
out

tensor([[0.5493, 0.0385]], device='cuda:0')

In [68]:
# Recompute, for the last batch, each sample's count of non-negligible prototype
# contributions in its predicted class row. When sim_scores_anz has a single row
# (single-output network), every prediction maps to row 0; indexing by predicted
# class 1 would otherwise be out of range.
sim_rows = torch.zeros_like(ys_pred) if sim_scores_anz.shape[0] == 1 else ys_pred
correct_class_sim_scores_anz = torch.diagonal(torch.index_select(
                sim_scores_anz, dim = 0, index = sim_rows), 0)

In [69]:
# Per-sample count of prototype contributions above 1e-3 in the selected class row.
correct_class_sim_scores_anz

tensor([9.], device='cuda:0')

In [31]:
# Running total of the counts above, accumulated over all test batches.
global_sim_anz

2241.0

In [34]:
# Summarize the evaluation run above.
# `abstained` is a 0-d tensor after at least one batch, but still the int 0 if the
# test loader yielded nothing; int() handles both (.item() fails on a plain int).
print("PIP-Net abstained from a decision for",
        int(abstained),
        "images",
        flush = True)

# Classification weights at or below this value are treated as zero
# (same 1e-3 threshold as in the evaluation loop).
weight_eps = 1e-3
class_weight = net.module._classification.weight

info['num non-zero prototypes'] = torch.gt(
    class_weight, weight_eps).any(dim = 0).sum().item()

print(
    "sparsity ratio: ",
    (torch.numel(class_weight) - torch.count_nonzero(
        torch.nn.functional.relu(class_weight-weight_eps)
        ).item()) / torch.numel(class_weight),
    flush = True)

# Per-sample averages are normalized by the number of test samples.
n_test = len(testloader.dataset)
info['confusion_matrix'] = cm
info['test_accuracy'] = acc_from_cm(cm)
info['top1_accuracy'] = global_top1acc/n_test
info['top5_accuracy'] = global_top5acc/n_test
info['almost_sim_nonzeros'] = global_sim_anz/n_test
info['local_size_all_classes'] = local_size_total/n_test
info['almost_nonzeros'] = global_anz/n_test

PIP-Net abstained from a decision for 0 images
sparsity ratio:  0.00390625


In [26]:
# Confusion matrix accumulated over the test set (rows: true class, columns: predicted class).
cm

array([[0, 0],
       [1, 0]])