In [42]:
import argparse
import os
import sys
import math
import numpy as np
import SimpleITK as sitk
import random
import pandas as pd
from typing import Tuple, Dict

import torch
from torch import Tensor
import torch.nn as nn
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms

from monai.transforms import (
    Compose,
    Resize,
    RandRotate,
    Affine,
    RandGaussianNoise,
    RandZoom,
    RepeatChannel,
)
current_dir = os.path.dirname(os.path.realpath('__file__'))
data_dir = os.path.join(current_dir, 'data')
from training_pipnet_LR import get_network,get_optimizer_nn
sys.path.append(data_dir)
#from make_dataset import get_dataloaders
import make_dataset_LR
from make_dataset_LR import get_dataloaders,getAllDataloader,getAllDataset
# Construct the path to the models directory
models_dir = os.path.join(current_dir, 'models')

# Add the models directory to sys.path
sys.path.append(models_dir)
from resnet_features import video_resnet18_features
from pipnet import PIPNet,NonNegLinear
from train_model_custom import train_pipnet

from test_model import eval_pipnet


import plotly.graph_objects as go
import xarray as xr
import plotly.express as px

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import joblib
import h5py
import openpyxl
from openpyxl import Workbook
import matplotlib.pyplot as plt
import matplotlib.cm as cm

In [16]:
def sliceViewer(images,labels, key: str, title: str, height:int):
    """
    
    """
    if len(images.shape)==3:
        newIm=RepeatChannel(3)(images.unsqueeze(0))
        newIm=torch.moveaxis(newIm,[0,1,2,3],[-1,0,1,2])
    else:
        newIm=images
    xrData = xr.DataArray(
        data   = newIm,
        dims   = [key, 'row', 'col', 'rgb'],
        coords = {key: labels}
    )
    # Hide the axes
    #layout_dict = dict(yaxis_visible=False, yaxis_showticklabels=False, xaxis_visible=False, xaxis_showticklabels=False)
    layout_dict=dict()
    return px.imshow(xrData, title=title, animation_frame=key).update_layout(layout_dict)

In [13]:
inputData='data/syntheticData_balls_LR.h5'
outputFile='data/syntheticData_balls_LR_fixed.h5'
with h5py.File(inputData, 'r') as f:
    shapesL=[f[k]['L'][:].shape for k in f.keys()]
    shapesR=[f[k]['R'][:].shape for k in f.keys()]

In [14]:
finalShape=shapesR[0]
finalShape

(54, 121, 74)

In [15]:
shapesL[0]

(121, 54, 74)

In [16]:
inputData='data/syntheticData_balls_LR.h5'
outputFile='data/syntheticData_balls_LR_fixed.h5'
with h5py.File(f'{outputFile}','w') as patientData:
    with h5py.File(inputData, 'r') as f:
        for k in f.keys():
            grp=patientData.create_group(k)
            grp.attrs['LRflag']=f[k].attrs['LRflag']
            grp.create_dataset("L",data=np.moveaxis(f[k]["L"][:],[0,1,2],[1,0,2]).astype(np.float32),compression="gzip", compression_opts=6)
            grp.create_dataset("R",data=f[k]["R"][:].astype(np.float32),compression="gzip", compression_opts=6)


## Testing network throughput issues

In [19]:
inputData='data/syntheticData_balls_LR_fixed.h5'
with h5py.File(inputData, 'r') as f:
    keys=f.keys()
    print(keys)
    k='10'    
    LRflag=f[k].attrs['LRflag']
    datLR=[f[k]["L"][:],f[k]["R"][:]]



<KeysViewHDF5 ['0', '1', '10', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '11', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '12', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '13', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '14', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '15', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '16', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '17', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '18', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '19', '190', '191', '192', '193', '194', '195', '196', '197', '198', '199', '2', '20', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '21', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '22', '220', '221', '222', '223', '224', '225', '226', '22

In [34]:
#arr,label=torch.tensor(datLR[LRflag[0]]).unsqueeze(0),0 # always guarantees empty scan
arr,label=torch.tensor(datLR[0]).unsqueeze(0),LRflag[0] # always guarantees empty scan
arr=RepeatChannel(repeats=3)(arr)
arr.shape

torch.Size([3, 54, 121, 74])

In [35]:
volume=arr
volumeRGB=np.moveaxis(volume,[0,1,2,3],[-1,0,1,2])

sliceViewer(volumeRGB,[i for i in range(len(volumeRGB))],key="Depth",title="test",height=700)

In [36]:

if arr.max()>arr.min():
    print("true")

true


In [17]:
LRflag

array([1, 0])

In [37]:
args={
    'inputData':inputData,
    'batch_size':10,
    'num_classes':1,
    'epochs':5,
    'epochs_pretrain':1,
    'freeze_epochs':0,
    'epochs_finetune':0,
    'seed':44,
    'experiment_folder':'data/experiment_1',
    'lr':.0001,
    'lr_net':.0001,
    'lr_block':.0001,
    'lr_class':.05,
    'lr_backbone':.0001,
    'weight_decay':0,
    'gamma':.1,
    'step_size':1,
    'channels':3,
    'net':"3Dresnet18",
    'num_features':0,
    'bias':False,
    'out_shape':1,
    'disable_pretrained':False,
    'optimizer':'Adam',
    'state_dict_dir_net':'',
    'log_dir':'logs/kFold3',
    "dic_classes":{False:0,True:1},
    'val_split':.05,
    'test_split':.2,
    'defaultFinetune':True,
    'lr_finetune':.05,
    'flipTrain':False,
    'stratSampling':True,
    'excludePatients':['735','322','531','523','876','552'],
    'log_power':1,
    'img_shape':[54,121,74],
    'wshape':5, # this is assigned mid script and doesn't matter here
    'hshape':8, # these matter and should bechanged to correct vals for the analyzing_network
    'dshape':7,
    'backboneStrides':[1,2,2,2],
    'verbose':False,
}


In [40]:
dataloaders=get_dataloaders(dataset_h5path=inputData,
                                k_fold=5,
                                test_p=args['test_split'],
                                val_p=args['val_split'],
                                batchSize=args['batch_size'],
                                seed=args['seed'],
                                kMeansSaveDir=None,
                                flipTrain=args['flipTrain'],
                                stratSampling=args['stratSampling'],
                                excludePatients=args['excludePatients'],
                                )

trainloader = dataloaders[0]
trainloader_pretraining = dataloaders[1]
trainloader_normal = dataloaders[2] 
trainloader_normal_augment = dataloaders[3]
projectloader = dataloaders[4]
valloader = dataloaders[5]
testloader = dataloaders[6] 
test_projectloader = dataloaders[7]

In [43]:
useGPU=True
devID=0
if useGPU:
    device=torch.device(f'cuda:{devID}')
else:
    device=torch.device('cpu')


network_layers = get_network(num_classes=args['num_classes'], args=args)
feature_net = network_layers[0]
add_on_layers = network_layers[1]
pool_layer = network_layers[2]
classification_layer = network_layers[3]
num_prototypes = network_layers[4]
newFeatures=feature_net
"""
### let's try hacking in our layer here?
testLayer=[nn.Conv3d(in_channels = 1, 
                out_channels = 3, 
                kernel_size =1, 
                stride = 1, 
                padding = 1, 
                bias = True),]
newFeatures=nn.Sequential(testLayer[0],feature_net)
"""


classification_layer.normalization_multiplier=nn.Parameter(
        torch.ones((1,), requires_grad = True)*args['log_power'])

net = PIPNet(
    num_classes = args['num_classes'],
    num_prototypes = num_prototypes,
    feature_net = newFeatures,
    args = args,
    add_on_layers = add_on_layers,
    pool_layer = pool_layer,
    classification_layer = classification_layer
    )

net = net.to(device=device)

net = nn.DataParallel(net, device_ids = [0])  

optimizer = get_optimizer_nn(net, args)
optimizer_net = optimizer[0]
optimizer_classifier = optimizer[1] 
params_to_freeze = optimizer[2] 
params_to_train = optimizer[3] 
params_backbone = optimizer[4]   

Number of prototypes:  512
Network is  3Dresnet18
