In [None]:
import os
import sys
import datetime
import math
import numpy as np
import matplotlib.pyplot as plt
import random
from imageio import imread
import json
import torch
torch.cuda.empty_cache()

from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

current_dir = os.path.dirname(os.path.realpath('__file__'))
import utils
from utils import plot_3d_slices
from utils import set_seeds
from utils import set_device
#from utils import get_optimizer_nn
from utils import init_weights_xavier
from utils import get_patch_size,generate_rgb_array
from utils import Log
data_dir = os.path.join(current_dir, 'data')

from training_pipnet_LR import get_network,get_optimizer_nn
sys.path.append(data_dir)
#from make_dataset import get_dataloaders
import make_dataset_LR
from make_dataset_LR import get_dataloaders,getAllDataloader,getAllDataset
# Construct the path to the models directory
models_dir = os.path.join(current_dir, 'models')

# Add the models directory to sys.path
sys.path.append(models_dir)
from resnet_features import video_resnet18_features
from pipnet import PIPNet,NonNegLinear
from train_model_custom import train_pipnet

from test_model import eval_pipnet

vis_dir=os.path.join(current_dir, 'visualization')
sys.path.append(vis_dir)
import vis_pipnet
#from vis_pipnet import visualize, visualize_topk
from vis_pipnet import get_img_coordinates,plot_rgb_slices,plot_local_explanation
import plotly.graph_objects as go
import xarray as xr
import plotly.express as px

from scipy.ndimage import binary_erosion
from monai.transforms import (
    Compose,
    Resize,
    RandRotate,
    Affine,
    RandGaussianNoise,
    RandZoom,
    RepeatChannel,
)
import math
import joblib
import h5py
from importlib import reload


In [None]:

# Experiment / model configuration consumed by get_network, get_optimizer_nn and PIPNet below.
# Key order is significant only for readability; it mirrors the training configuration.
args = dict(
    # --- logging / reproducibility ---
    log_dir='logs/OPNorm9995_tan2_backbone1en4_fold1',  # checkpoints are read from <log_dir>/checkpoints
    seed=42,
    experiment_folder='data/experiment_1',
    # --- optimisation ---
    lr=.0001,
    lr_net=.0001,
    lr_block=.0001,
    lr_class=.0001,
    lr_backbone=.0001,
    weight_decay=0,
    gamma=.1,
    step_size=1,
    batch_size=15,
    epochs=160,
    epochs_pretrain=30,
    freeze_epochs=0,
    epochs_finetune=10,
    # --- architecture ---
    num_classes=2,
    channels=3,
    net="3Dresnet18",
    num_features=0,
    bias=False,
    out_shape=1,
    disable_pretrained=False,
    optimizer='Adam',
    state_dict_dir_net='',
    dic_classes={False: 0, True: 1},
    # --- data splits / sampling ---
    val_split=.05,
    test_split=.2,
    defaultFinetune=True,
    lr_finetune=.05,
    flipTrain=False,
    stratSampling=True,
    excludePatients=['735', '322', '531', '523', '876', '552'],
    log_power=1,
    # --- geometry ---
    img_shape=[54, 121, 74],  # overwritten later in this notebook from the loaded volumes
    # NOTE(review): per the original comments, wshape is assigned later in the script so its
    # value here does not matter, whereas hshape and dshape must match the latent grid of the
    # network being analysed — confirm against the training run.
    wshape=5,
    hshape=8,
    dshape=7,
    backboneStrides=[1, 2, 2, 2],
    patchsize=15,
)

# Augmentation settings. `channels` duplicates args['channels'].
channels = 3
aug_prob = 1                         # probability passed to every random transform
rand_rot = 10                        # random rotation range [deg]
rand_rot_rad = rand_rot*math.pi/180  # random rotation range [rad]
rand_noise_std = 0.01                # std of the random Gaussian noise
rand_shift = 5                       # translation [px] used by Affine below
min_zoom = 0.9
max_zoom = 1.1

# NOTE(review): monai `Affine` is the deterministic transform — it shifts every training volume
# by the same (rand_shift, rand_shift, rand_shift) px. `RandAffine(translate_range=...)` may be
# what was intended. transforms_dic is not referenced by the cells shown in this notebook.
transforms_dic = {
    'train': Compose([
        RandRotate(range_x=rand_rot_rad, range_y=rand_rot_rad, range_z=rand_rot_rad, prob=aug_prob),
        RandGaussianNoise(std=rand_noise_std, prob=aug_prob),
        Affine(translate_params=(rand_shift,) * 3, image_only=True),
        RandZoom(min_zoom=min_zoom, max_zoom=max_zoom, prob=aug_prob),
        RepeatChannel(repeats=channels),
    ]),
}
# Every non-training split only repeats the input channel `channels` times (separate instances).
transforms_dic.update(
    (split, Compose([RepeatChannel(repeats=channels)]))
    for split in ('train_noaug', 'project_noaug', 'val', 'test', 'test_projection')
)

# Preprocessed input volumes. The down-sampling factor and the lower intensity threshold are
# encoded in the file name (3.2 -> "DS32", .15 -> "point15").
downSample = 3.2
lowerBound = .15
inputData = 'data/FP_LR_OPNorm_avgcrop_DS{}_point{}Thresh.h5'.format(int(downSample*10), int(lowerBound*100))
# Alternative inputs used in other runs:
#   f'data/FP923_LR_avgCrop_DS{int(downSample*10)}_point{int(lowerBound*100)}Thresh.h5'
#   'data/syntheticData_balls_LR_fixed.h5'
#inputData=f'data/syntheticData_balls_LR_fixed.h5'

In [None]:
# Compute device: the requested GPU when available, otherwise CPU.
useGPU = True
devID = 0
device = torch.device(f'cuda:{devID}' if useGPU and torch.cuda.is_available() else 'cpu')

# Data: split fractions come from `args` so they cannot drift from the configuration above,
# and the k-means file name follows the same down-sampling tag as `inputData`.
dataloaders = get_dataloaders(dataset_h5path=inputData,
                              k_fold=5,
                              test_p=args['test_split'],
                              val_p=args['val_split'],
                              batchSize=args['batch_size'],
                              seed=args['seed'],
                              kMeansSaveDir=f"data/kMeans_DS{int(downSample*10)}.json")
(trainloader,
 trainloader_pretraining,
 trainloader_normal,
 trainloader_normal_augment,
 projectloader,
 valloader,
 testloader,
 test_projectloader) = dataloaders[:8]

allData = getAllDataset(inputData)
inputKeys = allData.subsetKeys

# Network
network_layers = get_network(num_classes=args['num_classes'], args=args)
feature_net, add_on_layers, pool_layer, classification_layer, num_prototypes = network_layers[:5]
newFeatures = feature_net
net = PIPNet(
        num_classes = args['num_classes'],
        num_prototypes = num_prototypes,
        feature_net = newFeatures,
        args = args,
        add_on_layers = add_on_layers,
        pool_layer = pool_layer,
        classification_layer = classification_layer
        )
net = net.to(device=device)
# DataParallel requires the parameters to live on device_ids[0], so this must be the device
# selected above (a hard-coded 0 breaks forward passes whenever devID != 0). When CUDA is not
# available DataParallel simply wraps the module.
net = nn.DataParallel(net, device_ids=[devID])

optimizer = get_optimizer_nn(net, args)
optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = optimizer[:5]

# Restore the trained weights.
checkpointFile = f"{args['log_dir']}/checkpoints/net_trained_last"
checkpoint = torch.load(checkpointFile, map_location=device)
net.load_state_dict(checkpoint['model_state_dict'], strict=True)
net.module._multiplier.requires_grad = False
# This notebook only analyses the trained model: switch to inference mode so any
# BatchNorm/Dropout layers behave deterministically and running statistics are not updated.
net.eval()
try:
    optimizer_net.load_state_dict(checkpoint['optimizer_net_state_dict'])
except Exception as err:  # best effort: the optimizer is not used by the cells shown here
    print(f"optimizer failed load: {err!r}")

In [None]:
def sliceViewer(images,labels, key: str, title: str, height:int):
    """Show a stack of slices as an animated Plotly image with a single slider.

    :param images: either a 3-D (slices, rows, cols) greyscale stack, which is replicated
        into 3 identical RGB channels, or a 4-D (slices, rows, cols, rgb) stack.
        numpy arrays and torch tensors (on any device) are accepted.
    :param labels: one coordinate label per slice, shown on the slider.
    :param key: name of the slice dimension; used as the animation/slider key.
    :param title: figure title.
    :param height: figure height in pixels.
    :return: the Plotly figure produced by ``px.imshow``.
    """
    if isinstance(images, torch.Tensor):
        # xarray/plotly need host memory; detach in case the tensor is tracked by autograd
        images = images.detach().cpu().numpy()
    images = np.asarray(images)
    if images.ndim == 3:
        # greyscale -> RGB: repeat the intensity along a new trailing channel axis
        images = np.repeat(images[..., np.newaxis], 3, axis=-1)

    xrData = xr.DataArray(
        data=images,
        dims=[key, 'row', 'col', 'rgb'],
        coords={key: labels},
    )
    return px.imshow(xrData, title=title, animation_frame=key).update_layout(height=height)


In [None]:
def doubleSliceViewer(images, labels1, labels2, key1: str, key2: str, title: str, height: int):
    """Plotly viewer for a stack of RGB volumes with two independent sliders.

    The first slider selects the volume (axis 0 of ``images``) and the second selects the
    slice (axis 1). Moving one slider keeps the current position of the other.

    :param images: array or tensor of shape (N, M, H, W, 3). If its maximum is <= 1 it is
        treated as [0, 1] data: clipped to [0, 1] and scaled to uint8.
    :param labels1: N labels for the first slider.
    :param labels2: M labels for the second slider.
    :param key1: prefix shown for the first slider's current value.
    :param key2: prefix shown for the second slider's current value.
    :param title: figure title.
    :param height: figure height in pixels.
    :return: ``plotly.graph_objects.Figure``.
    :raises ValueError: if ``images`` is not (N, M, H, W, 3) or the label counts do not
        match N and M.
    """
    if len(images.shape) != 5 or images.shape[-1] != 3:
        raise ValueError("Input images must be a 5D tensor (N, M, H, W, 3) where the last dimension represents RGB channels.")

    num_volumes, num_slices = images.shape[0], images.shape[1]
    labels1, labels2 = list(labels1), list(labels2)
    if len(labels1) != num_volumes or len(labels2) != num_slices:
        raise ValueError(
            f"Expected {num_volumes} labels1 and {num_slices} labels2, "
            f"got {len(labels1)} and {len(labels2)}."
        )

    # Convert to numpy; detach/cpu so CUDA tensors and autograd-tracked tensors also work
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    images = np.asarray(images)

    # [0, 1] data -> uint8 for go.Image; clip first so negative values cannot wrap in the cast
    if images.max() <= 1.0:
        images = (np.clip(images, 0.0, 1.0) * 255).astype(np.uint8)

    # One trace per volume, initially showing slice 0; only the first volume is visible.
    fig = go.Figure(data=[
        go.Image(z=images[i, 0], visible=(i == 0)) for i in range(num_volumes)
    ])

    # Slider steps are static once built, so a step cannot read the other slider's position.
    # Instead, frame j sets slice j on *every* trace, and the volume slider only toggles
    # which trace is visible, so each slider preserves the other's selection.
    fig.frames = [
        go.Frame(data=[go.Image(z=images[i, j]) for i in range(num_volumes)], name=str(j))
        for j in range(num_slices)
    ]

    volume_slider = {
        "active": 0,
        "currentvalue": {"prefix": f"{key1}: "},
        "pad": {"t": 100},
        "y": 0.20,
        "steps": [
            {
                "args": [{"visible": [k == i for k in range(num_volumes)]}],
                "label": str(label1),
                "method": "restyle",
            }
            for i, label1 in enumerate(labels1)
        ],
    }

    slice_slider = {
        "active": 0,
        "currentvalue": {"prefix": f"{key2}: "},
        "pad": {"t": 150},
        "y": 0,
        "steps": [
            {
                "args": [
                    [str(j)],
                    {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
                ],
                "label": str(label2),
                "method": "animate",
            }
            for j, label2 in enumerate(labels2)
        ],
    }

    fig.update_layout(
        title=title,
        sliders=[volume_slider, slice_slider],
        height=height,
        margin=dict(t=80, b=250, l=40, r=40),
    )
    return fig

def generate_unique_colormap(size, colormap_name='viridis'):
    """Sample ``size`` distinct colours from a matplotlib colormap.

    :param size: number of colours wanted; also used as the colormap's lookup-table size,
        so integer i maps directly to entry i.
    :param colormap_name: matplotlib colormap name (default ``'viridis'``).
    :return: numpy array with one RGBA row per colour, shape (size, 4).
    """
    resampled = plt.get_cmap(colormap_name, size)
    entry_indices = np.arange(size)
    return resampled(entry_indices)

In [None]:
# Load the first few volumes of the full dataset as channels-last RGB arrays for a quick look.
n_preview = 5
volumes = []
for sample_key in inputKeys[:n_preview]:
    arr, label = allData[sample_key]
    # NOTE(review): mutates the shared config; after the loop args['img_shape'] holds the
    # spatial shape of the last previewed volume.
    args['img_shape'] = list(arr.shape[1:])
    # repeat the channel axis 3x (presumably a single-channel volume, giving RGB)
    arr = RepeatChannel(repeats=3)(arr)
    volume = arr
    volumeRGB = np.moveaxis(volume, 0, -1)  # channels-first -> channels-last
    volumes.append(volumeRGB)

volumes = torch.tensor(np.array(volumes))  # (n_volumes, depth, rows, cols, channels)
depthLen = volumes.shape[1]

In [None]:
# First slider: one label per previewed volume; second slider: depth index.
labels1 = list(range(len(volumes)))
labels2 = list(range(depthLen))
fig = doubleSliceViewer(volumes, labels1, labels2, 'slider1', 'slider2', 'Slice Viewer', 600)
fig.show()

In [None]:
def _prototype_saliency(features, xs, proto, other_protos, active_mask):
    """Input gradient of one prototype's contrastive score, summed over its active latent cells.

    At every latent location (i, j, k) where ``active_mask`` is True the score is
    ``features[0, proto, i, j, k] - logsumexp(features[0, other_protos, i, j, k])``.
    The scores are summed and differentiated once; this equals the sum of per-location
    gradients.

    :param features: backbone output for the single input ``xs``, still attached to its
        autograd graph and indexed as (batch, channel, i, j, k). Assumes channel indices
        line up with prototype indices, as the caller's thresholding does.
    :param xs: input tensor with ``requires_grad=True``.
    :param proto: channel index of the prototype being explained.
    :param other_protos: list of other channel indices to contrast against (may be empty).
    :param active_mask: boolean tensor over the (i, j, k) latent grid.
    :return: gradient tensor with the shape of ``xs``.
    """
    locs = tuple(idx.to(features.device) for idx in torch.nonzero(active_mask, as_tuple=True))
    score = features[0, proto][locs]
    if other_protos:
        # The contrast term must stay in the autograd graph; building it with
        # torch.tensor([...]) detaches it, so its gradient contribution silently vanishes.
        # logsumexp is the numerically stable log(sum(exp(.))).
        competitors = features[0, other_protos][(slice(None),) + locs]
        score = score - torch.logsumexp(competitors, dim=0)
    # retain_graph: the caller differentiates the same graph once per prototype
    return torch.autograd.grad(score.sum(), xs, retain_graph=True)[0]


def gradient_patchResize(key="20_R",thresh=.5,greyscaleGrad=True):
    """Per-prototype input-gradient saliency and patch outlines for one dataset sample.

    Uses the module-level ``net`` (a ``nn.DataParallel``-wrapped PIPNet), ``projectloader``
    and ``device``.

    1. Candidate prototypes have a classification weight > 1e-3 for at least one class
       (the criterion eval_pipnet uses to count relevant prototypes).
    2. A candidate is kept if its ``_add_on`` activation exceeds ``thresh`` somewhere.
    3. For each kept prototype, the thresholded latent map is resized (nearest) to the input
       and reduced to its outline (mask minus its binary erosion). Its saliency is the
       absolute input gradient of the contrastive score from ``_prototype_saliency``,
       min-max normalised and optionally averaged over channels.
    4. A forward pass gives each prototype's pooled similarity. Multiplied by the class
       weight, this is its contribution to the true and to the predicted class.

    The network is evaluated in eval mode and its previous train/eval mode is restored.

    :param key: sample key passed to ``projectloader.dataset``.
    :param thresh: activation threshold applied to the ``_add_on`` output.
    :param greyscaleGrad: replace the per-channel gradient by its channel mean, repeated
        over 3 channels.
    :return: ``(image_information, label, out)``. ``image_information`` maps a prototype
        index to a dict with ``"gradient"`` (channels-first numpy array), ``"patch"``
        (numpy outline over the input's spatial shape), and ``"simweightTrue"`` /
        ``"simweightPred"`` (floats). ``label`` is the sample label and ``out`` is the
        network output on CPU.
    """
    classification_weights = net.module._classification.weight.detach().cpu()
    # used in eval_pipnet to count relevant prototypes
    relevant = torch.gt(classification_weights, 1e-3).any(dim=0)
    candidate_protos = torch.nonzero(relevant).flatten().tolist()

    arr, label = projectloader.dataset[key]
    transform = Compose([Resize(spatial_size=arr.shape[1:], mode="nearest")])
    print(f"resizeShape {arr.shape[1:]}")

    # Explanations must use inference behaviour (e.g. BatchNorm running statistics) and must
    # not update the model; restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        xs = arr.unsqueeze(0).to(device)
        xs.requires_grad_(True)
        features = net.module._net(xs)
        proto_features = net.module._add_on(features).detach().cpu()
        protoThresh = proto_features > thresh
        topKProtos = {p for p in candidate_protos if torch.any(protoThresh[0, p])}
        print(f"topKProtos : {topKProtos}")
        ordered_protos = sorted(topKProtos)

        image_information = {}
        for proto in ordered_protos:
            active_mask = protoThresh[0, proto]
            patchResize = transform(active_mask.unsqueeze(0))[0].detach().cpu().numpy()
            others = [p for p in ordered_protos if p != proto]
            gradient = _prototype_saliency(features, xs, proto, others, active_mask)
            # absolute value + min-max normalisation
            gradient = np.abs(gradient[0].detach().cpu().numpy())
            if gradient.max() > gradient.min():
                gradient = (gradient - gradient.min()) / (gradient.max() - gradient.min())
            if greyscaleGrad:
                gradient = np.repeat(gradient.mean(axis=0, keepdims=True), 3, axis=0)
            image_information[proto] = {
                "gradient": gradient,
                "patch": patchResize - binary_erosion(patchResize),
                "simweightTrue": -1,
                "simweightPred": -1,
            }

        with torch.no_grad():
            _, clamped_pooled, out = net(xs)
        clamped_pooled = clamped_pooled.detach().cpu()
        out = out.detach().cpu()
    finally:
        net.train(was_training)

    true_class = int(label)
    pred_class = int(out[0].argmax())
    for proto in ordered_protos:
        similarity = clamped_pooled[0, proto]
        image_information[proto]['simweightTrue'] = float(similarity * classification_weights[true_class, proto])
        image_information[proto]['simweightPred'] = float(similarity * classification_weights[pred_class, proto])

    return image_information, label, out


def plot_gradient_patchResize(key=42,thresh=.5,greyscaleGrad=True,overlayOriginal=False,title="grad explanation",height=1200,returnImageInfo=True, proto_colormap_name='viridis'):    
    """Two-slider figure of per-prototype gradient saliency with coloured patch outlines.

    Prototypes are ordered by ``simweightPred`` (contribution to the predicted class),
    largest first. Each prototype's outline is added in its own colour from
    ``proto_colormap_name``. Optionally the input volume is added underneath. Each
    prototype volume is then min-max normalised to [0, 1].

    :param key: sample key, forwarded to ``gradient_patchResize`` and used for the overlay.
    :param thresh: forwarded to ``gradient_patchResize``.
    :param greyscaleGrad: forwarded to ``gradient_patchResize``.
    :param overlayOriginal: add the input volume to every prototype volume.
    :param title: figure title.
    :param height: figure height in pixels.
    :param returnImageInfo: when True (default) return ``(fig, image_information)``;
        when False return only ``fig``.
    :param proto_colormap_name: matplotlib colormap used to colour the outlines.
    :raises ValueError: if no relevant prototype exceeds ``thresh`` for this sample.
    """
    image_information, label, out = gradient_patchResize(key=key, thresh=thresh, greyscaleGrad=greyscaleGrad)
    if not image_information:
        raise ValueError(f"no relevant prototype exceeds thresh={thresh} for key {key!r}")

    # sort by either simweightTrue or simweightPred
    proto_information_sorted = sorted(image_information.items(),
                                      key=lambda item: item[1]['simweightPred'],
                                      reverse=True)

    # channels-first gradients -> (num_protos, depth, rows, cols, 3)
    protoVolumesRGB = np.array([np.moveaxis(info['gradient'], 0, -1)
                                for _, info in proto_information_sorted])
    cmap = generate_unique_colormap(size=len(proto_information_sorted), colormap_name=proto_colormap_name)

    # paint each prototype's outline with its own colour (RGB part of the RGBA entry)
    for i, (_, info) in enumerate(proto_information_sorted):
        for c in range(3):
            protoVolumesRGB[i, ..., c] += info['patch'] * cmap[i][c]

    if overlayOriginal:
        arr, _ = projectloader.dataset[key]
        # broadcast the channels-last input volume over every prototype
        protoVolumesRGB += np.moveaxis(np.asarray(arr), 0, -1)[np.newaxis]

    for k in range(len(protoVolumesRGB)):
        maximum, minimum = protoVolumesRGB[k].max(), protoVolumesRGB[k].min()
        if maximum > minimum:
            protoVolumesRGB[k] = (protoVolumesRGB[k] - minimum) / (maximum - minimum)

    labels1 = [proto for proto, _ in proto_information_sorted]
    labels2 = list(range(proto_information_sorted[0][1]['patch'].shape[0]))
    fig = doubleSliceViewer(protoVolumesRGB, labels1, labels2, 'Proto', 'MRI depth', title, height)

    return (fig, image_information) if returnImageInfo else fig

In [None]:
# Colour-gradient explanation of sample "20_R", overlaid on the input volume.
fig, image_information = plot_gradient_patchResize(
    key="20_R",
    thresh=.5,
    greyscaleGrad=False,
    overlayOriginal=True,
    title="grad explanation",
    height=800,
    returnImageInfo=True,
    proto_colormap_name='viridis',
)

In [None]:
# Render the prototype explanation figure built in the previous cell.
fig.show()

In [None]:
# Step-by-step version of the first half of plot_gradient_patchResize, for inspecting the
# intermediate arrays of one sample.
key = 42
thresh = .5
greyscaleGrad = True
overlayOriginal = False
title = "grad explanation"
height = 1200
returnImageInfo = True
proto_colormap_name = 'viridis'

image_information, label, out = gradient_patchResize(key=key, thresh=thresh, greyscaleGrad=greyscaleGrad)

# prototypes ordered by their contribution to the predicted class, largest first
proto_information_sorted = sorted(image_information.items(),
                                  key=lambda item: item[1]['simweightPred'],
                                  reverse=True)

# channels-first gradients -> (num_protos, depth, rows, cols, 3)
protoVolumesRGB = np.array([np.moveaxis(info['gradient'], 0, -1)
                            for _, info in proto_information_sorted])
cmap = generate_unique_colormap(size=len(proto_information_sorted), colormap_name=proto_colormap_name)

# paint each prototype's outline with its own colour, channel by channel
for idx, (_, info) in enumerate(proto_information_sorted):
    colour = cmap[idx]
    for channel in range(3):
        protoVolumesRGB[idx, ..., channel] += info['patch'] * colour[channel]

In [None]:
# View the raw input volume of the same sample, slice by slice along depth.
arr, label = projectloader.dataset[key]
volume = arr
volumeRGB = np.moveaxis(volume, 0, -1)  # channels-first -> channels-last
depth_labels = list(range(len(volumeRGB)))
sliceViewer(volumeRGB, depth_labels, key="Depth", title="test", height=700)