In [None]:
#load all desired packages
import sys
import pandas as pd
import numpy as np
import os, glob
import h5py


import os, sys
directory = os.path.abspath('')
sys.path.insert(1,os.path.join(directory,'src')) # setting path can also append directory.parent
#import utils
#import metrics
import datasets
import evaluate
#import plots
#import train
#from UNet import UNet

from matplotlib import pyplot as plt
%load_ext autoreload
%autoreload 2
%matplotlib inline

#--------- DEFINE FUNCTIONS HERE ------------
# Reader for raw big-endian int16 (.int2) volumes; returns an array of shape (256 x 256 x number of slices) with the default in-plane dims
def open_int2(path_,dim_x=256, dim_y=256):
    """Read a raw big-endian 16-bit integer (``.int2``) volume from disk.

    The file is read as a flat stream of ``>i2`` samples and reshaped
    column-major (Fortran order) into ``(dim_x, dim_y, n_slices)``, with the
    slice count inferred from the number of samples. The two in-plane axes are
    then swapped. Equivalently, slice ``k`` holds the ``k``-th run of
    ``dim_x * dim_y`` samples laid out row-major as ``(dim_y, dim_x)``.

    Args:
        path_: path to the ``.int2`` file.
        dim_x: first in-plane dimension used for the reshape (default 256).
        dim_y: second in-plane dimension used for the reshape (default 256).

    Returns:
        ``np.ndarray`` of dtype ``>i2`` with shape ``(dim_y, dim_x, n_slices)``.

    Raises:
        ValueError: from ``reshape`` when the sample count is not a multiple
            of ``dim_x * dim_y``.
    """
    samples = np.fromfile(path_, dtype='>i2')
    volume = samples.reshape((dim_x, dim_y, -1), order='F')
    # np.rot90(axes=(1, 0)) followed by np.flip(axis=1) is exactly a swap of
    # the first two axes, so do that swap directly.
    return np.swapaxes(volume, 0, 1)

def load_result(filename:str, key:str, data_type=np.float64):
    """Read dataset ``key`` from an HDF5 file and return it cast to ``data_type``.

    The file is opened read-only inside a ``with`` block, so the handle is
    released even when the key is missing or the cast fails. The original
    explicit ``close()`` was skipped on any exception.

    Args:
        filename: path to the HDF5 file.
        key: name of the dataset to read (e.g. ``'pred'``).
        data_type: numpy dtype to cast the loaded values to (default float64).

    Returns:
        The full dataset contents (``dataset[()]``) as ``data_type``.

    Raises:
        Whatever ``h5py`` raises for a missing file or key (``KeyError`` for a
        missing dataset name).
    """
    with h5py.File(filename, 'r') as f:
        return f[key][()].astype(data_type)

# Inference runs on 256x256 images; resize_imgs re-lays a prediction stack out onto the original in-plane grid (512x512 by default)
def resize_imgs(pred,dim_x=512,dim_y=512):
    """Re-lay a stack of images out onto a ``(dim_x, dim_y)`` in-plane grid.

    Three steps:

    1. Swap the in-plane axes, undoing the swap that ``open_int2`` applies.
    2. Reinterpret the resulting column-major (Fortran order) sample stream
       with the new in-plane size.
    3. Swap the in-plane axes back.

    Args:
        pred: array whose first two axes are in-plane (at least 2-D).
        dim_x: first in-plane dimension of the Fortran-order reshape (default 512).
        dim_y: second in-plane dimension of the Fortran-order reshape (default 512).

    Returns:
        Array of shape ``(dim_y, dim_x, pred.size // (dim_x * dim_y))``.

    Raises:
        ValueError: from ``reshape`` when ``pred.size`` is not a multiple of
            ``dim_x * dim_y``.
    """
    # rot90(axes=(1, 0)) + flip(axis=1) is a pure swap of axes 0 and 1, which
    # is its own inverse, so the same swap is used on both sides of the reshape.
    unswapped = np.swapaxes(pred, 0, 1)
    regridded = unswapped.reshape((dim_x, dim_y, -1), order='F')
    return np.swapaxes(regridded, 0, 1)

def plot_QC(img_list:list,idx=12,clim:list=None,cmap:str='jet',data_root:str='/data/path/'):
    """Show slice ``idx`` of every volume in ``img_list`` side by side for visual QC.

    How each path is read depends on its extension:

    - ``.h5`` files use ``load_result(path, 'pred')``.
    - ``.int2`` files use ``open_int2`` and are cast to float64.

    Panels are laid out three per row. At least two rows are always allocated,
    so up to six paths keep the original 2 x 3 grid and 10 x 6 figure. Longer
    lists get extra rows instead of crashing.

    Args:
        img_list: paths to ``.h5`` or ``.int2`` volumes.
        idx: slice index along the third axis to display.
        clim: optional ``[vmin, vmax]`` colour limits; autoscaled when ``None``.
        cmap: matplotlib colormap name.
        data_root: substring removed from each path before the first remaining
            ``/``-separated component is used as the patient label in the
            panel title. If it does not occur in the path, the label is the
            path's first component (empty for absolute paths).

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: if any path has an extension other than ``.h5`` or ``.int2``.
            This is checked before plotting. Previously such a path silently
            re-plotted the previous image under a new title.
    """
    readers = {
        '.h5': lambda path: load_result(path, 'pred'),
        '.int2': lambda path: open_int2(path).astype(np.float64),
    }
    unsupported = [p for p in img_list if os.path.splitext(p)[1] not in readers]
    if unsupported:
        raise ValueError(f'plot_QC only reads .h5 and .int2 files; got {unsupported}')

    n_cols = 3
    # ceil(len / n_cols), but never fewer than the original two rows
    n_rows = max(2, (len(img_list) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(10, 3 * n_rows))
    vmin, vmax = (None, None) if clim is None else clim

    for panel, img_path in enumerate(img_list, start=1):
        img = readers[os.path.splitext(img_path)[1]](img_path)

        ax = fig.add_subplot(n_rows, n_cols, panel)
        ax.imshow(img[:, :, idx], cmap=cmap, vmin=vmin, vmax=vmax)
        patient = img_path.replace(data_root, '').split('/')[0]
        ax.set_title(f'{patient} slice {idx}')

    return fig


In [None]:
# Load model
# Inference settings: the checkpoint to load and the input-scaling options that
# are later handed to evaluate.get_model_infers
# (presumably "clip intensities to [0, 150]" — confirm in evaluate).
#
# To use the best run from training instead (ref 1031), look it up in the
# training log:
#   df_modelInfo = pd.read_csv('/data/knee_mri9/mwtong/t1rho_map_synthesis/training/trained_model_info.csv')
#   df_best = df_modelInfo.iloc[np.where(df_modelInfo['Ref'] == 1031.0)[0][0]]
model_path = '/data/knee_mri9/mwtong/t1rho_map_synthesis/code/code_py/checkpoints/run_604_best_model'
scaleInput_options = ['clip', 0, 150]

print(model_path, scaleInput_options, sep='\n')

In [None]:
#Load data
# Find every T2 map under data_path (layout: <data_path>/*/*/reg/T2_Map.int2).
# Derive the matching T1rho, synthetic-T1rho and first-echo paths from the same
# reg/ directory.
data_path = '/data/folder'
# glob order is filesystem-dependent; sort so plots and inference order are reproducible
t2_paths = sorted(glob.glob(f'{data_path}/*/*/reg/T2_Map.int2'))
assert t2_paths, f'No T2_Map.int2 files found under {data_path}'

# Swap only the file name. str.replace on the full path would also rewrite any
# directory name that happens to contain '/T2'.
reg_dirs = [os.path.dirname(p) for p in t2_paths]
t1r_paths = [os.path.join(d, 'T1rho_Map.int2') for d in reg_dirs]
syn_t1r_paths = [os.path.join(d, 'Syn_T1rho_Map.h5') for d in reg_dirs]
e1_paths = [os.path.join(d, 'Echo_e1.int2') for d in reg_dirs]

fig_e1 = plot_QC(e1_paths, cmap='gray')

fig_t2 = plot_QC(t2_paths, clim=[0, 80], cmap='jet')

In [None]:
# Create a dataframe with the path to the image and slice number:
# one row per (T2 volume, slice index), in t2_paths order.
# Each volume's rows are contiguous and in slice order.
t2_path_by_slices_list = []
slice_nums_list = []
for t2_path in t2_paths:
    # Only the slice count is needed. Skip the float64 copy, which is 4x the
    # size of the int16 volume.
    n_slices = open_int2(t2_path).shape[2]

    t2_path_by_slices_list.extend([t2_path] * n_slices)
    slice_nums_list.extend(range(n_slices))

    print(t2_path, n_slices)

df_infer = pd.DataFrame({'t2': t2_path_by_slices_list, 'slice number': slice_nums_list})
df_infer

In [None]:
# Run Inference on every (volume, slice) row of df_infer.
# NOTE(review): get_model_infers is project code (src/evaluate). preds is
# presumably a tensor of shape (n_rows, channels, H, W) — confirm.
preds,images,inferset = evaluate.get_model_infers(df_infer,model_path,scaleInput_options)

print(len(df_infer))
print(np.shape(preds))
# Saving predictions indexes preds by df_infer row position, so they must line up one-to-one.
assert len(preds) == len(df_infer), (
    f'model returned {len(preds)} predictions for {len(df_infer)} requested slices')

In [None]:
# Save predictions: regroup the per-slice predictions into one volume per T2 map.
# Write each volume to its Syn_T1rho_Map.h5 path under the dataset key 'pred'.
# Convert once rather than per volume. This assumes a CPU torch tensor, as the
# per-volume .numpy() call already required.
preds_np = preds.numpy()
for t2_path, syn_t1r_path in zip(t2_paths, syn_t1r_paths):
    # Positional row indices of this volume's slices. A plain 1-D index keeps
    # the slice axis even for single-slice volumes; the old
    # list(np.where(...)) + np.squeeze dropped it, and the transpose then failed.
    vol_indices = np.flatnonzero((df_infer['t2'] == t2_path).to_numpy())

    pred_vol = preds_np[vol_indices, 0, :, :]       # (n_slices, H, W)
    pred_vol = np.transpose(pred_vol, (1, 2, 0))    # (H, W, n_slices)
    print(syn_t1r_path, pred_vol.shape)

    with h5py.File(syn_t1r_path, 'w') as f:
        f.create_dataset('pred', data=pred_vol, dtype=pred_vol.dtype)
syn_t1r_paths

In [None]:
# Plot for QC: slice 6 on a shared 0-80 colour range.
# First the ground-truth T1rho maps, then the synthesized predictions.
for qc_paths in (t1r_paths, syn_t1r_paths):
    plot_QC(qc_paths, idx=6, clim=[0, 80], cmap='jet')