In [1]:
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import seaborn as sns
import csv
import monai
import torch.nn as nn
import torch.nn.functional as F
from monai.transforms import \
    Compose, LoadImaged, AddChanneld, Orientationd, \
    Spacingd, \
    ToTensord,  \
    DataStatsd, \
    ToDeviced
from monai.data import list_data_collate
import torch
import pytorch_lightning as pl
from torchsummary import summary
monai.config.print_config()
import sys
# sys.path.append(r'/data16/private/zc348/project/DL_HMC_attention/mCT/util/python')
sys.path.append(r'/data16/private/zc348/project/DL_HMC_attention/util/python')
import vicra_toolbox
import nibabel; nibabel.imageglobals.logger.setLevel(40)
# New transforms 

sys.path.append(r'../')

torch.set_num_threads(4)
from dlhmc.transforms import (
    CreateImageStack,
    ComputeRelativeMotion,
    RandSamplePET,
    ComputeRelativeMotiond,
    CreateImageStackd,
    RandSamplePETd,
)

from dlhmc.utils.data import (
    concatenate_vicra,
    split_dataset
)


2025-04-04 09:36:44,311 - Created a temporary directory at /tmp/tmppea7e2x4
2025-04-04 09:36:44,315 - Writing /tmp/tmppea7e2x4/_remote_module_non_scriptable.py
MONAI version: 1.0.1
Numpy version: 1.23.4
Pytorch version: 1.13.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8271a193229fe4437026185e218d5b06f7c8ce69
MONAI __file__: /home1/zc348/anaconda3/envs/dl-hmc_2301c/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 4.0.2
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.2.0
Tensorboard version: 2.11.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.14.1
tqdm version: 4.64.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.0
pandas version: 1.4.3
einops version: 0.6.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd versi

In [2]:

class UNet(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        #Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=32)
        self.conv_maxpool1 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(32, 64)
        self.conv_maxpool2 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(64, 128)
        self.conv_maxpool3 = torch.nn.MaxPool3d(kernel_size=2)


    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """
        This function creates one contracting block
        """
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels, padding=1),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.ReLU(),
        )
        return block
    
    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)


        return encode_pool3

In [3]:
class Cross_attention(nn.Module):
    def  __init__(self, all_channel=128, all_dim=4*4*4):	

        super(Cross_attention, self).__init__()
        self.linear_e = nn.Linear(all_channel, all_channel,bias = False)
        self.channel = all_channel
        self.dim = all_dim
        self.gate = nn.Conv3d(all_channel, 1, kernel_size  = 1, bias = False)
        self.gate_s = nn.Sigmoid()
        self.conv1 = nn.Conv3d(all_channel*1, all_channel, kernel_size=3, padding=1, bias = False)
        self.conv2 = nn.Conv3d(all_channel*1, all_channel, kernel_size=3, padding=1, bias = False)
        self.bn1 = nn.BatchNorm3d(all_channel)
        self.bn2 = nn.BatchNorm3d(all_channel)
        self.prelu = nn.ReLU(inplace=True)
        self.conv3d_7 = nn.Conv3d(in_channels=all_channel *2 , out_channels=all_channel, kernel_size=1, stride=(1, 1, 1), padding=0)
        self.pathC_bn1 = nn.BatchNorm3d(all_channel*2)
        self.conv3d_8 = nn.Conv3d(in_channels=all_channel, out_channels=all_channel//2, kernel_size=3, stride=(1, 1, 1), padding=1)
        self.conv3d_9 = nn.Conv3d(in_channels=all_channel//2, out_channels=all_channel//8, kernel_size=3, stride=(1, 1, 1), padding=1)
        self.pathC_bn2 = nn.BatchNorm3d(all_channel//8)

        self.conva = nn.Conv3d(all_channel, all_channel, kernel_size=1, padding=0, bias = False)
        self.convb = nn.Conv3d(all_channel, all_channel, kernel_size=1, padding=0, bias = False)
    
    
		
    def forward(self, exemplar, query): 
        
	 
        fea_size = query.size()[2:]	 
        
        #### correlation matrix computation
        exemplar_v =  self.convb(exemplar) 
        query_v =  self.convb(query) 
        exemplar_q =  self.conva(exemplar) 
        query_k =  self.conva(query) 
        exemplar_flat = exemplar_q.view(-1, self.channel, self.dim) #N,C,H*W
        query_flat = query_k.view(-1, self.channel, self.dim)
        exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous()  #batch size x dim x num
        A = torch.bmm(exemplar_t, query_flat)
        A = F.softmax(A, dim = 1) 
        B = F.softmax(torch.transpose(A,1,2),dim=1)
        query_att = torch.bmm(exemplar_v.view(-1, self.channel, self.dim) , A).contiguous() 
        exemplar_att = torch.bmm(query_v.view(-1, self.channel, self.dim), B).contiguous()
        #####self-gate mechanism
        input1_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1], fea_size[2])  
        input2_att = query_att.view(-1, self.channel, fea_size[0], fea_size[1], fea_size[2])
        input1_mask = self.gate(input1_att)
        input2_mask = self.gate(input2_att)
        input1_mask = self.gate_s(input1_mask)
        input2_mask = self.gate_s(input2_mask)
        input1_att = input1_att * input1_mask
        input2_att = input2_att * input2_mask
        input1_att = input1_att +  exemplar_v
        input2_att = input2_att + query_v
        ######Deep norm and fusion
        input1_att  = self.conv1(input1_att )
        input2_att  = self.conv2(input2_att ) 
        input1_att  = self.bn1(input1_att )
        input2_att  = self.bn2(input2_att )
        input1_att  = self.prelu(input1_att )
        input2_att  = self.prelu(input2_att )

        conv_input = torch.cat((input1_att, input2_att), 1)
        x = self.pathC_bn1(conv_input)
        x = self.conv3d_7(x)
        x = self.prelu(x)
        x = self.conv3d_8(x)
        x = self.prelu(x)
        x = self.conv3d_9(x)
        x = self.pathC_bn2(x)

        return x.view(x.size()[0],-1)

In [4]:
class cross_att_dataloader(pl.LightningModule):
    
    def __init__(self, dropout=0.3,img_size=32):
        super().__init__()
        
             
        self.dropout = dropout
        
        self.feature_extractor = UNet(1,1)

        self.coattention = Cross_attention()


        self.regression_layers = torch.nn.Sequential(

            torch.nn.Linear(1024, 128),
            torch.nn.Linear(128, 16),
            torch.nn.Linear(16, 6),
        )
        self.loss_function = torch.nn.MSELoss()
        
    def forward(self, x1, x2):

        y1 = self.feature_extractor(x1)
        y2 = self.feature_extractor(x2)
        y = self.coattention(y1,y2)
        y = self.regression_layers(y)
        return y
    
    
    
    def prepare_data(self):
        # set deterministic training for reproducibility
        monai.utils.misc.set_determinism(seed=42)
    
    def training_step(self, batch, batch_idx):
        x1 = batch["ThreeD_Cloud_ref"]
        x2 = batch["ThreeD_Cloud_mov"]
        ref_time = batch['ScanStart_ref']
        mov_time = batch['ScanStart_mov']
        # x_t = torch.stack([ref_time, mov_time], dim=1)
        gt_reg = batch["VICRA_rel"].float()
        y = self.forward(x1,x2)
       
        target_six = matrix_transformation(gt_reg)
        loss = self.loss_function(y, target_six)   

        return {"loss": loss}

        
    def training_epoch_end(self, outputs):
        # Calculate the average loss
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        # Logging at the end of every epoch
        self.logger.experiment.add_scalar('Train/Loss', avg_loss, self.current_epoch)
    
    def validation_step(self, batch, batch_idx):
        x1 = batch["ThreeD_Cloud_ref"]
        x2 = batch["ThreeD_Cloud_mov"]
        ref_time = batch['ScanStart_ref']
        mov_time = batch['ScanStart_mov']
        x_t = torch.stack([ref_time, mov_time], dim=1)
        gt_reg = batch["VICRA_rel"].float()
        y = self.forward(x1,x2)
        target_six = matrix_transformation(gt_reg)
        loss = self.loss_function(y, target_six)   

        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        # Calculate the average loss
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Logging at the end of every epoch
        self.logger.experiment.add_scalar('Val/Loss', avg_loss, self.current_epoch)

        # Log the value for model checkpoint saving
        self.log('val_loss', avg_loss.item()) #added .item(), otherwise validation wouldn't work 

    def configure_optimizers(self):
        total_params = list(self.feature_extractor.parameters()) + list(self.regression_layers.parameters()) + list(self.coattention.parameters())
        opt = torch.optim.Adam(total_params, lr=5e-4)
        scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=200, gamma=0.98),
                     'name': 'Learning Rate'}
        return [opt], [scheduler]





In [5]:
from evaluation_toolbox import build_df_results, show_df_loss, build_df_results_12, plot_vicra_network, plot_diff_vicra_network, save_synthetic_vicra, print_loss,plot_networks_comparison
from dataset_summary_toolbox import compute_delta_T 
from data_prep_toolbox import delta_T_magnitude, Relative_motion_A_to_B_12, build_legal_dataset, deal_dataframe, clean_df
from sampling_toolbox import data_split_sample, add_T_deltaT

In [6]:
test_set=['PF605', 'JB538', 'AF120', 'JG369', 'SY636',
'EC950', 'NM937', 'AS469', 'CH568', 'JO308',
'CJ509', 'SY869', 'BB688', 'HR322', 'DS636',
 'JR684', 'JM100', 'SM968', 'TM628', 'MC181']

In [7]:
summaries_test, delta_T_all_test=compute_delta_T(['FDG'], test_set,data_type='real')
df_test=[]
predictions=[]
df_input_diff_all=[]

for i in range(len(test_set)):
    df_test.append(deal_dataframe(summaries_test[i]))


KEYS = ['ThreeD_Cloud_ref','ThreeD_Cloud_mov'] 

train_transforms = Compose([
    LoadImaged(keys=KEYS, reader='NibabelReader', as_closest_canonical=False),
    AddChanneld(keys=KEYS), 
    Orientationd(keys=KEYS, axcodes='RAS'),
    ToTensord(keys=KEYS)    
])

<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
saved_model_path = '/data16/private/zc348/project/Best_FDG.ckpt'
loaded_model = cross_att_dataloader().load_from_checkpoint(saved_model_path)
loaded_model.eval()
loaded_model.to(device)

cross_att_dataloader(
  (feature_extractor): UNet(
    (conv_encode1): Sequential(
      (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (conv_maxpool1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv_encode2): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    

In [9]:
from tqdm import tqdm
from copy import deepcopy
df_results_all=[]
y_list_all=[]
df_input_diff_all=[]
prediction_list = list()
times = torch.zeros((1800*len(test_set),1)) 
# torch.cuda.synchronize()
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
idx=0

for j in range(len(test_set)):

    df = df_test[j]

    #fixed reference time
    df_input_diff = vicra_toolbox.build_netinput_fixed_reference(df).reset_index()
    pairs=[]
    for i in range(len(df_input_diff)):
        pairs.append(np.array([df_input_diff['ScanStart_ref'][i], df_input_diff['ScanStart_mov'][i]]))
    df_input_diff['pairs']=pairs
    
    df_input_diff_all.append(df_input_diff)

    ##building testing dataloader
    test_dict = df_input_diff.to_dict('records')

    for i in range(len(test_dict)):
        x = test_dict[i]['ThreeD_Cloud_ref'].find('nii')
        fn_cloud1 = test_dict[i]['ThreeD_Cloud_ref'][0:x] + 'nii_monai_resize'
        x = x+3
        y = test_dict[i]['ThreeD_Cloud_ref'].find('3dcld')
        fn_cloud2 =  test_dict[i]['ThreeD_Cloud_ref'][x:y] + '3dcld_monai_rz.nii'
        test_dict[i]['ThreeD_Cloud_ref'] = fn_cloud1 + fn_cloud2

        x = test_dict[i]['ThreeD_Cloud_mov'].find('nii')
        fn_cloud1 = test_dict[i]['ThreeD_Cloud_mov'][0:x] + 'nii_monai_resize'
        x = x+3
        y = test_dict[i]['ThreeD_Cloud_mov'].find('3dcld')
        fn_cloud2 =  test_dict[i]['ThreeD_Cloud_mov'][x:y] + '3dcld_monai_rz.nii'
        test_dict[i]['ThreeD_Cloud_mov'] = fn_cloud1 + fn_cloud2

    # Create the Dataset
#     ds_test = monai.data.CacheDataset(data=test_dict, transform=train_transforms)
    ds_test = monai.data.Dataset(data=test_dict, transform=train_transforms)
    #ds_tr = monai.data.SmartCacheDataset(data=tr_dict,transform=train_transforms,replace_rate=1,cache_num=64,shuffle=True)
    # Create the DataLoader
    test_loader = monai.data.DataLoader(ds_test, batch_size=8, num_workers=2, collate_fn=list_data_collate)
    
    #calculate loss function and network output
    # saved_model_path = os.path.join(MODEL_PATH,'PETRegNet-epoch=3130-val_loss=0.454.ckpt')
    # loaded_model = PETRegNet.load_from_checkpoint(saved_model_path)


    loss = 0
    loss_list = list()
    time_list = list()
    
    y_list = list()
    loss1 = []
    
    pp = 0
    for test_data in tqdm(test_loader):
        x1 = test_data['ThreeD_Cloud_ref'].to(device)
        x2 = test_data['ThreeD_Cloud_mov'].to(device)
        x_t = test_data['pairs'].to(device)
        time = test_data['delta_t']
        ref_time = test_data['ScanStart_ref']
        mov_time = test_data['ScanStart_mov']
        x_t = torch.stack([ref_time, mov_time], dim=1).to(device)
        y = test_data['T'].cpu().numpy()
        # starter.record()
        y_test = loaded_model(x1, x2).detach().cpu().numpy()
        torch.cuda.synchronize() 
        loss = y-y_test
        l = len(loss)
        for j in range(l):
            loss1=sum(np.square(loss[j]))/len(loss[j])
            loss_list.append(loss1)
            time_list.append(time.numpy()[j])
            prediction_list.append(y_test[j])
            y_list.append(y[j])
            pp+=1
            # print(loss1)
        del test_data
        idx= idx+1
        
    
    df_results = pd.DataFrame()
    y_list_all.append(y_list)
    
    df_results['Time'] = time_list
    df_results['Loss'] = loss_list
    df_results_all.append(df_results)


100%|██████████| 225/225 [00:18<00:00, 12.06it/s]
100%|██████████| 225/225 [00:24<00:00,  9.14it/s]
100%|██████████| 225/225 [00:18<00:00, 12.13it/s]
100%|██████████| 225/225 [00:17<00:00, 13.20it/s]
100%|██████████| 225/225 [00:15<00:00, 14.15it/s]
100%|██████████| 225/225 [00:19<00:00, 11.36it/s]
100%|██████████| 225/225 [00:16<00:00, 13.25it/s]
100%|██████████| 225/225 [00:15<00:00, 14.74it/s]
100%|██████████| 225/225 [00:15<00:00, 14.71it/s]
100%|██████████| 225/225 [00:15<00:00, 14.53it/s]
100%|██████████| 225/225 [00:17<00:00, 13.00it/s]
100%|██████████| 225/225 [00:16<00:00, 13.94it/s]
100%|██████████| 225/225 [00:18<00:00, 11.98it/s]
100%|██████████| 225/225 [00:17<00:00, 12.99it/s]
100%|██████████| 225/225 [00:15<00:00, 14.12it/s]
100%|██████████| 225/225 [00:18<00:00, 12.26it/s]
100%|██████████| 225/225 [00:17<00:00, 12.73it/s]
100%|██████████| 225/225 [00:17<00:00, 12.71it/s]
100%|██████████| 225/225 [00:16<00:00, 13.51it/s]
100%|██████████| 225/225 [00:17<00:00, 12.74it/s]


In [10]:
import numpy as np
import torch.nn as nn
from sklearn.metrics import mean_squared_error
class RMSE(nn.Module):
    def __init__(self, prediction,gt,class_num):
        np.random.seed(42)  # 
        num_cases = class_num
        num_results_per_case = 1800
        self.data = np.stack(prediction)
        self.data = self.data.reshape(num_cases,num_results_per_case,1,6)
        self.gt = np.stack(gt)
        self.gt = self.gt.reshape(num_cases,num_results_per_case,1,6)
        self.translation_rmses = []
        self.rotation_rmses = []
        self.total_rmses = []

    def calculate_rmse(self, true, pred):
        return np.sqrt(mean_squared_error(true, pred))
    
    def calculate_statistics(self, data):
        mean = np.mean(data)
        variance = np.var(data)
        median = np.median(data)
        q1 = np.percentile(data, 25)
        q3 = np.percentile(data, 75)
        iqr = q3 - q1
        return mean, variance, median, iqr
    
    def forward(self):
        for case, case_gt in zip(self.data,self.gt):
            # obtain translation & rotation
            translations = case[:, :, :3]  # translation
            rotations = case[:, :, 3:]    # rotation

            translations_gt = case_gt[:, :, :3]  # translation
            rotations_gt = case_gt[:, :, 3:]    # rotation
            # each case translation and rotation RMSE
            translation_rmse = self.calculate_rmse(translations.flatten(), translations_gt.flatten())
            rotation_rmse = self.calculate_rmse(rotations.flatten(), rotations_gt.flatten())
            total_rmse = self.calculate_rmse(case.flatten(), case_gt.flatten())

            self.translation_rmses.append(translation_rmse)
            self.rotation_rmses.append(rotation_rmse)
            self.total_rmses.append(total_rmse)



        # 
        translation_mean, translation_var, translation_median, translation_iqr = self.calculate_statistics(self.translation_rmses)
        rotation_mean, rotation_var, rotation_median, rotation_iqr = self.calculate_statistics(self.rotation_rmses)
        total_mean, total_var, total_median, total_iqr = self.calculate_statistics(self.total_rmses)
        
        # print
        print("Translation RMSE:")
        print(f"Mean: {translation_mean}, SD: {translation_var}, Median: {translation_median}, IQR: {translation_iqr}")

        print("\nRotation RMSE:")
        print(f"Mean: {rotation_mean}, SD: {rotation_var}, Median: {rotation_median}, IQR: {rotation_iqr}")

        print("\nTotal RMSE:")
        print(f"Mean: {total_mean}, SD: {total_var}, Median: {total_median}, IQR: {total_iqr}")

In [11]:
CRMSE_cross = RMSE(np.stack(y_list_all),np.stack(prediction_list),len(test_set))
CRMSE_cross.forward()

Translation RMSE:
Mean: 1.2654570927698043, SD: 0.4624681273977167, Median: 1.1494540428074, IQR: 0.6367913300291387

Rotation RMSE:
Mean: 1.1643550047873894, SD: 1.2026572332472591, Median: 0.8887497633560273, IQR: 0.3429450120230557

Total RMSE:
Mean: 1.2457516438031546, SD: 0.7592176376916496, Median: 1.043129245769346, IQR: 0.4542637601312354
