In [None]:
%load_ext autoreload
%autoreload 2

import yaml
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd

import torch

from utils.load_dataset import load_dataset
from models.load_model import load_model
from utils.train_utils import viz_superglue_matching
from utils.eval_utils import load_gt_test, interpolate_longitudinal, interpolate_frame_angles, angle_difference, consistent_normalize, find_min_distance_configuration
from utils.preprocess import circular_to_angle
from utils.circumferential_stats_analysis import circumferential_stats_analysis, save_pkl, load_pkl 
from utils.longitudinal_stats_analysis import longitudinal_stats_analysis

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('connected to device: {}'.format(device))

## Evaluating IntraCross vs Experts

In [None]:
# ---- LOAD MODEL + CONFIG ---- #

save_folder = 'Results/2025-07-23-11-52-26-1753267946-superglue_clustered_b/'
# Create both output sub-folders. exist_ok keeps the cell re-runnable. Unlike a single try/except
# around two mkdir calls, it still creates 'predictions' when 'longi_rot_viz' already exists.
for sub_folder in ('longi_rot_viz', 'predictions'):
    os.makedirs(os.path.join(save_folder, sub_folder), exist_ok=True)

with open('config/intracross.yaml') as f:
    config = yaml.load(f, yaml.FullLoader)

(train_loader, val_loader, test_loader), (train_dataset, val_dataset, test_dataset) = load_dataset(config)

model = load_model(config, device)
# map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine.
weights = torch.load(save_folder + 'best_comb.pt', map_location=device)
model.load_state_dict(weights)
model.eval()

print(config)

In [None]:
def eval(model, test_dataset, device, save_viz=None, save_pred=False, verbose=True, results_folder=None):
    """
    Evaluate model on test dataset against A1, A2, A1* (variables r1, r2, r1star).

    NOTE: this name shadows the builtin ``eval``; it is kept for backward compatibility.

    For every test case, the function does the following:
    - runs a forward pass and turns the predicted matches into longitudinal (IVUS frame, OCT frame)
      and rotational (OCT frame, angle) registration points;
    - snaps those points to the closest end-diastolic (ED) frames and adds the annotated
      start/end keypoints;
    - interpolates between the points and compares the result against the annotators, both on
      the annotated keypoints and on every interpolated frame.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(keypoints0, keypoints1, context0, context1[, img0, img1])``.
        Must return a dict containing a ``'matches0'`` tensor (-1 = unmatched).
    test_dataset : indexable dataset
        ``test_dataset[i]`` returns a dict with ``'ids'``, ``'keypoints0/1'``, ``'context0/1'``,
        ``'original0/1'`` and optionally ``'img0/1'``.
    device : torch.device
        Device that tensors are moved to before the forward pass.
    save_viz : str or None
        If not None, saves the matching visualisation and one longitudinal/rotational plot per case
        under ``<results_folder>longi_rot_viz/``. The value is also forwarded to the statistics
        functions.
    save_pred : bool
        If True, pickles per-case and pooled predictions under ``<results_folder>predictions/``
        and saves raw per-case arrays as ``.npy`` files under ``../predictions/``.
    verbose : bool
        Forwarded to the statistics functions.
    results_folder : str or None
        Output folder, including the trailing slash. Defaults to the notebook-level ``save_folder``.

    Returns
    -------
    tuple
        ``(longitudinal_results, circumferential_results)``, as returned by
        ``longitudinal_stats_analysis`` and ``circumferential_stats_analysis``.
    """
    if results_folder is None:
        # Backward compatibility: fall back to the notebook-level output folder.
        results_folder = save_folder

    # Make sure every output location exists before the (long) loop starts.
    if save_pred:
        os.makedirs(results_folder + 'predictions', exist_ok=True)
        os.makedirs('../predictions', exist_ok=True)
    if save_viz is not None:
        os.makedirs(results_folder + 'longi_rot_viz', exist_ok=True)

    longi_r1_pred = []
    longi_r1star_pred = []
    longi_r2_pred = []
    longi_computer_pred = []
    longi_r2_pred_using_r2_keypts = []
    longi_computer_pred_using_r2_keypts = []
    longi_r1_pred_interpolated = []
    longi_r1star_pred_interpolated = []
    longi_r2_pred_interpolated = []
    longi_computer_pred_interpolated = []

    rot_r1_pred = []
    rot_r1star_pred = []
    rot_r2_pred = []
    rot_computer_pred = []
    rot_r2_pred_using_r2_keypts = []
    rot_computer_pred_using_r2_keypts = []
    rot_r1_pred_interpolated = []
    rot_r1star_pred_interpolated = []
    rot_r2_pred_interpolated = []
    rot_computer_pred_interpolated = []

    ids_longi_all = []
    ids_circ_all = []
    ids_longi_keypts = []
    ids_circ_keypts = []

    with torch.no_grad():

        for i in tqdm(range(len(test_dataset))):

            # --- LOAD DATA + GT ---- #

            data = test_dataset[i]
            id_ = data['ids']

            longi_gt, rot_gt = load_gt_test(id_)
            r1_longi_interpolated, r2_longi_interpolated, r1star_longi_interpolated, r1_longi_keypt, r2_longi_keypt, r1star_longi_keypt = longi_gt
            r1_rot_interpolated, r2_rot_interpolated, r1star_rot_interpolated, r1_rot_keypt, r2_rot_keypt, r1star_rot_keypt = rot_gt
            # ED = end diastolic frames. For OCT this is every other frame
            case_dir = '../Data/Registration Dataset/test/{}/'.format(id_)
            ivus_ed = np.load(case_dir + 'ivus_ids.npy')
            oct_ed = np.load(case_dir + 'oct_ids.npy')
            rot_interpolated = np.load(case_dir + 'rot_interpolated.npy')
            # Annotated first/last keypoints as [ivus_frame, oct_frame].
            start = r1_longi_keypt[0, :].astype(np.int32)
            end = r1_longi_keypt[-1, :].astype(np.int32)

            # --- FORWARD PASS ---- #

            # Add a batch dimension to every tensor and move it to the target device.
            for key, value in data.items():
                if isinstance(value, torch.Tensor):
                    data[key] = value.unsqueeze(0).to(device)
            model_inputs = [data['keypoints0'], data['keypoints1'], data['context0'], data['context1']]
            if 'img0' in data:
                model_inputs += [data['img0'], data['img1']]
            pred = model(*model_inputs)

            m0 = pred['matches0'].detach().cpu().numpy()[0]
            original0 = data['original0']
            original1 = data['original1']

            # --- VIZUALIZE ---- #

            if save_viz is not None:
                viz_superglue_matching(pred, data, results_folder, 'test', 0)

            # --- EXTRACT FINAL PREDICTION ---- #

            # The main objective here is to normalise the predictions to the same range as the GT.
            # There is also some post-processing to ensure the start and end points are the same.
            r1_rot_keypt[:, 1] = r1_rot_keypt[:, 1] % 360
            r2_rot_keypt[:, 1] = r2_rot_keypt[:, 1] % 360
            r1star_rot_keypt[:, 1] = r1star_rot_keypt[:, 1] % 360

            final_matching = []  # rows: [ivus position, oct position]
            final_angles = []    # rows: [oct position, rotation oct->ivus]
            for m0_idx, m1_idx in enumerate(m0):
                if m1_idx == -1:
                    continue  # unmatched keypoint
                pos0 = original0[m0_idx][:, -1].mean(axis=0)
                pos1 = original1[m1_idx][:, -1].mean(axis=0)
                angle0 = circular_to_angle(np.expand_dims(original0[m0_idx][:, 2:4].mean(axis=0), axis=0))[0]
                angle1 = circular_to_angle(np.expand_dims(original1[m1_idx][:, 2:4].mean(axis=0), axis=0))[0]

                # find the original angle1 (before alignment)
                oct_frame_id = np.rint(pos1).astype(np.int32)
                if oct_frame_id < start[1]:
                    rot = rot_interpolated[0, 1]
                elif oct_frame_id > end[1]:
                    rot = rot_interpolated[-1, 1]
                else:
                    # Use the first row whose frame id equals oct_frame_id. If no row matches exactly,
                    # use the nearest frame instead of raising IndexError.
                    rot = rot_interpolated[np.argmin(np.abs(rot_interpolated[:, 0] - oct_frame_id)), 1]

                angle1 = (angle1 - rot) % 360

                # normalise to same range as GT.
                rotation_oct_to_ivus = (angle0 - angle1 + 180) % 360

                final_matching.append([pos0, pos1])
                final_angles.append([pos1, rotation_oct_to_ivus])
            final_matching = np.array(final_matching)
            final_angles = np.array(final_angles)

            # Our predictions can fall on any frame, however the GT is only on ED frames.
            # We need to find the closest ED frames to compare against GT.
            final_matching_cleaned = []
            final_angles_cleaned = []
            # Both arrays carry the OCT position; take it from final_angles.
            for (iv_frame_id, _), (oc_frame_id, oc_angle) in zip(final_matching, final_angles):
                # Closest ED frame (first one on ties, as np.argmin does).
                iv_frame_id = ivus_ed[np.argmin(np.abs(ivus_ed - iv_frame_id))]
                oc_frame_id = oct_ed[np.argmin(np.abs(oct_ed - oc_frame_id))]

                # Drop points on or outside the annotated end points; those are re-added below.
                strictly_inside = (r1_longi_keypt[0, 0] < iv_frame_id < r1_longi_keypt[-1, 0]
                                   and r1_longi_keypt[0, 1] < oc_frame_id < r1_longi_keypt[-1, 1])
                if strictly_inside:
                    final_matching_cleaned.append([iv_frame_id, oc_frame_id])
                    final_angles_cleaned.append([oc_frame_id, oc_angle])
            if final_matching_cleaned:
                final_matching_cleaned = np.concatenate([r1_longi_keypt[:1, :], np.array(final_matching_cleaned), r1_longi_keypt[-1:, :]])
                final_angles_cleaned = np.concatenate([r1_rot_keypt[:1, :], np.array(final_angles_cleaned), r1_rot_keypt[-1:, :]])
            else:
                final_matching_cleaned = r1_longi_keypt[[0, -1], :]
                final_angles_cleaned = r1_rot_keypt[[0, -1], :]

            # --- EVALUATE (and STORE/SAVE for later) ---- #

            # We do this both on the keypoints and on the interpolated frames.

            computer_longi_keypt = final_matching_cleaned
            computer_longi_interpolated = interpolate_longitudinal(computer_longi_keypt, start, end)

            computer_rot = find_min_distance_configuration(final_angles_cleaned)
            computer_rot[:, 1] = consistent_normalize(computer_rot[:, 1], 360, 720)
            computer_rot_interpolated = interpolate_frame_angles(computer_rot, start[1], end[1])
            computer_rot_interpolated[:, 1] = consistent_normalize(computer_rot_interpolated[:, 1], 360, 720)

            # Wrap all rotational curves into [0, 360). The *_rot_keypt arrays were already wrapped above.
            computer_rot_interpolated[:, 1] = computer_rot_interpolated[:, 1] % 360
            r2_rot_interpolated[:, 1] = r2_rot_interpolated[:, 1] % 360
            r1_rot_interpolated[:, 1] = r1_rot_interpolated[:, 1] % 360
            r1star_rot_interpolated[:, 1] = r1star_rot_interpolated[:, 1] % 360
            computer_rot[:, 1] = computer_rot[:, 1] % 360

            # Row indices into the interpolated arrays for the annotated keypoints. The start and end
            # keypoints are excluded because they are the same for all methods.
            r1_longi_idx = (r1_longi_keypt[1:-1, 0] - start[0]).astype(np.int32)
            r2_longi_idx = (r2_longi_keypt[1:-1, 0] - start[0]).astype(np.int32)
            r1_rot_idx = (r1_rot_keypt[1:-1, 0] - start[1]).astype(np.int32)
            r2_rot_idx = (r2_rot_keypt[1:-1, 0] - start[1]).astype(np.int32)

            # store predictions for later
            longi_r1_pred.extend(r1_longi_interpolated[r1_longi_idx, 1])
            longi_r1star_pred.extend(r1star_longi_interpolated[r1_longi_idx, 1])
            longi_r2_pred.extend(r2_longi_interpolated[r1_longi_idx, 1])
            longi_computer_pred.extend(computer_longi_interpolated[r1_longi_idx, 1])
            longi_r2_pred_using_r2_keypts.extend(r2_longi_interpolated[r2_longi_idx, 1])
            longi_computer_pred_using_r2_keypts.extend(computer_longi_interpolated[r2_longi_idx, 1])
            longi_r1_pred_interpolated.extend(r1_longi_interpolated[:, 1])
            longi_r1star_pred_interpolated.extend(r1star_longi_interpolated[:, 1])
            longi_r2_pred_interpolated.extend(r2_longi_interpolated[:, 1])
            longi_computer_pred_interpolated.extend(computer_longi_interpolated[:, 1])

            rot_r1_pred.extend(r1_rot_interpolated[r1_rot_idx, 1])
            rot_r1star_pred.extend(r1star_rot_interpolated[r1_rot_idx, 1])
            rot_r2_pred.extend(r2_rot_interpolated[r1_rot_idx, 1])
            rot_computer_pred.extend(computer_rot_interpolated[r1_rot_idx, 1])
            rot_r2_pred_using_r2_keypts.extend(r2_rot_interpolated[r2_rot_idx, 1])
            rot_computer_pred_using_r2_keypts.extend(computer_rot_interpolated[r2_rot_idx, 1])
            rot_r1_pred_interpolated.extend(r1_rot_interpolated[:, 1])
            rot_r1star_pred_interpolated.extend(r1star_rot_interpolated[:, 1])
            rot_r2_pred_interpolated.extend(r2_rot_interpolated[:, 1])
            rot_computer_pred_interpolated.extend(computer_rot_interpolated[:, 1])

            ids_longi_all.extend([id_] * r1_longi_interpolated.shape[0])
            ids_circ_all.extend([id_] * r1_rot_interpolated.shape[0])
            ids_longi_keypts.extend([id_] * r1_longi_keypt[1:-1].shape[0])
            ids_circ_keypts.extend([id_] * r1_rot_keypt[1:-1].shape[0])

            if save_pred:
                save_pkl(results_folder + 'predictions/{}_longi_interpolated.pkl'.format(id_),
                         [computer_longi_interpolated, r1_longi_interpolated, r2_longi_interpolated, r1star_longi_interpolated])
                save_pkl(results_folder + 'predictions/{}_longi_keypts.pkl'.format(id_),
                         [computer_longi_interpolated[r1_longi_idx],
                          r1_longi_interpolated[r1_longi_idx],
                          r2_longi_interpolated[r1_longi_idx],
                          r1star_longi_interpolated[r1_longi_idx]])
                save_pkl(results_folder + 'predictions/{}_rot_interpolated.pkl'.format(id_),
                         [computer_rot_interpolated, r1_rot_interpolated, r2_rot_interpolated, r1star_rot_interpolated])
                save_pkl(results_folder + 'predictions/{}_rot_keypts.pkl'.format(id_),
                         [computer_rot_interpolated[r1_rot_idx],
                          r1_rot_interpolated[r1_rot_idx],
                          r2_rot_interpolated[r1_rot_idx],
                          r1star_rot_interpolated[r1_rot_idx]])

                # Raw per-case arrays ("gm" = this model's prediction).
                npy_arrays = {
                    'longi_r1_interpolated': r1_longi_interpolated,
                    'longi_r2_interpolated': r2_longi_interpolated,
                    'longi_r1star_interpolated': r1star_longi_interpolated,
                    'longi_gm_interpolated': computer_longi_interpolated,
                    'longi_r1_keypt': r1_longi_keypt,
                    'longi_r2_keypt': r2_longi_keypt,
                    'longi_r1star_keypt': r1star_longi_keypt,
                    'longi_gm_keypt': computer_longi_keypt,
                    'rot_r1_interpolated': r1_rot_interpolated,
                    'rot_r2_interpolated': r2_rot_interpolated,
                    'rot_r1star_interpolated': r1star_rot_interpolated,
                    'rot_gm_interpolated': computer_rot_interpolated,
                    'rot_r1_keypt': r1_rot_keypt,
                    'rot_r2_keypt': r2_rot_keypt,
                    'rot_r1star_keypt': r1star_rot_keypt,
                    'rot_gm_keypt': computer_rot,
                }
                for array_name, array in npy_arrays.items():
                    np.save(f'../predictions/{id_}_{array_name}.npy', array)

            if save_viz is not None:
                # Per-case mean errors, shown only in the plot titles.
                longi_computer_r1_all = np.abs(r1_longi_interpolated - computer_longi_interpolated)[:, 1]
                longi_computer_r2_all = np.abs(r2_longi_interpolated - computer_longi_interpolated)[:, 1]
                longi_r1_r2_all = np.abs(r1_longi_interpolated - r2_longi_interpolated)[:, 1]
                longi_computer_r1_all_mean = longi_computer_r1_all.mean()
                longi_computer_r2_all_mean = longi_computer_r2_all.mean()
                longi_r1_r2_all_mean = longi_r1_r2_all.mean()
                longi_computer_r1_keypt_mean = longi_computer_r1_all[r1_longi_idx].mean()
                longi_computer_r2_keypt_mean = longi_computer_r2_all[r2_longi_idx].mean()
                longi_r1_r2_keypt_mean = longi_r1_r2_all[r1_longi_idx].mean()

                angle_computer_R1_all = angle_difference(r1_rot_interpolated[:, 1], computer_rot_interpolated[:, 1])
                angle_computer_R1_all_mean = angle_computer_R1_all.mean()
                angle_computer_R1_keypt_mean = angle_computer_R1_all[r1_rot_idx].mean()
                angle_computer_R2_all = angle_difference(r2_rot_interpolated[:, 1], computer_rot_interpolated[:, 1])
                angle_computer_R2_all_mean = angle_computer_R2_all.mean()
                angle_computer_R2_keypt_mean = angle_computer_R2_all[r2_rot_idx].mean()
                angle_R1_R2_all = angle_difference(r1_rot_interpolated[:, 1], r2_rot_interpolated[:, 1])
                angle_R1_R2_all_mean = angle_R1_R2_all.mean()
                angle_R1_R2_keypt_mean = angle_R1_R2_all[r1_rot_idx].mean()

                fig, axes = plt.subplots(1, 2, figsize=(10, 5))
                axes[0].scatter(*r2_longi_interpolated.T, c='royalblue', alpha=0.3, s=1)
                axes[0].scatter(*r1_longi_interpolated.T, c='lime', alpha=0.3, s=1)
                axes[0].scatter(*r1star_longi_interpolated.T, c='darkgreen', alpha=0.3, s=1)
                axes[0].scatter(*r2_longi_keypt.T, label='R2', c='royalblue', alpha=1, s=20)
                axes[0].scatter(*r1_longi_keypt.T, label='R1', c='lime', alpha=1, s=20)
                axes[0].scatter(*r1star_longi_keypt.T, label='R1*', c='darkgreen', alpha=1, s=20)
                axes[0].scatter(*computer_longi_keypt.T, label='Computer', c='r', alpha=1, s=20)
                axes[0].scatter(*computer_longi_interpolated.T, c='r', s=1, alpha=0.3)
                axes[0].legend()
                axes[0].set_xlabel('IVUS frame id')
                axes[0].set_ylabel('OCT frame id')
                axes[0].set_title("Longitudinal: \nR1 vs Computer (all): {:.1f}, R1 vs Computer (keypt): {:.1f}, \nR2 vs Computer (all): {:.1f}, R2 vs Computer (keypt): {:.1f}, \nR1 vs R2 (all): {:.1f}, R1 vs R2 (keypt): {:.1f} ".format(
                         longi_computer_r1_all_mean, longi_computer_r1_keypt_mean, longi_computer_r2_all_mean, longi_computer_r2_keypt_mean, longi_r1_r2_all_mean, longi_r1_r2_keypt_mean), fontsize=10)
                axes[1].scatter(*computer_rot_interpolated.T, c='r', s=1)
                axes[1].scatter(*r2_rot_interpolated.T, s=1, c='royalblue')
                axes[1].scatter(*r1_rot_interpolated.T, s=1, c='lime')
                axes[1].scatter(*r1star_rot_interpolated.T, s=1, c='darkgreen')
                axes[1].scatter(*r2_rot_keypt.T, s=20, c='royalblue', label='R2')
                axes[1].scatter(*r1_rot_keypt.T, s=20, c='lime', label='R1')
                axes[1].scatter(*r1star_rot_keypt.T, s=20, c='darkgreen', label='R1*')
                axes[1].scatter(*computer_rot.T, s=20, c='r', label='Computer')
                axes[1].set_ylim(0, 360)
                axes[1].set_ylabel('Angle (degrees)')
                axes[1].set_xlabel('OCT frame id')
                axes[1].legend()
                axes[1].set_title("Rotational: \nR1 vs Computer (all): {:.1f}, R1 vs Computer (keypt): {:.1f}, \nR2 vs Computer (all): {:.1f}, R2 vs Computer (keypt): {:.1f}, \nR1 vs R2 (all): {:.1f}, R1 vs R2 (keypt): {:.1f} ".format(
                        angle_computer_R1_all_mean, angle_computer_R1_keypt_mean, angle_computer_R2_all_mean, angle_computer_R2_keypt_mean, angle_R1_R2_all_mean, angle_R1_R2_keypt_mean), fontsize=10)
                plt.suptitle(id_)
                plt.tight_layout()
                plt.savefig(results_folder + 'longi_rot_viz/{}.jpg'.format(id_), dpi=100)
                plt.close('all')

    if save_pred:
        # Pooled predictions over all test cases. File names must stay stable; downstream loaders
        # (e.g. compare_to_icf_work) read them.
        pooled = {
            # SAVE LONGITUDINAL PREDS
            'test_longi_r1_pred': longi_r1_pred,
            'test_longi_r1star_pred': longi_r1star_pred,
            'test_longi_r2_pred': longi_r2_pred,
            'test_longi_computer_pred': longi_computer_pred,
            'test_longi_r2_pred_using_r2_keypts': longi_r2_pred_using_r2_keypts,
            'test_longi_computer_pred_using_r2_keypts': longi_computer_pred_using_r2_keypts,
            'test_longi_r1_pred_interpolated': longi_r1_pred_interpolated,
            'test_longi_r1star_pred_interpolated': longi_r1star_pred_interpolated,
            'test_longi_r2_pred_interpolated': longi_r2_pred_interpolated,
            'test_longi_computer_interpolated_pred': longi_computer_pred_interpolated,
            # SAVE ROTATIONAL PREDS
            'test_rot_r1_pred': rot_r1_pred,
            'test_rot_r1star_pred': rot_r1star_pred,
            'test_rot_r2_pred': rot_r2_pred,
            'test_rot_computer_pred': rot_computer_pred,
            'test_rot_r2_pred_using_r2_keypts': rot_r2_pred_using_r2_keypts,
            'test_rot_computer_pred_using_r2_keypts': rot_computer_pred_using_r2_keypts,
            'test_rot_r1_pred_interpolated': rot_r1_pred_interpolated,
            'test_rot_r1star_pred_interpolated': rot_r1star_pred_interpolated,
            'test_rot_r2_pred_interpolated': rot_r2_pred_interpolated,
            'test_rot_computer_interpolated_pred': rot_computer_pred_interpolated,
            # SAVE IDS
            'ids_longi_all': ids_longi_all,
            'ids_circ_all': ids_circ_all,
            'ids_longi_keypts': ids_longi_keypts,
            'ids_circ_keypts': ids_circ_keypts,
        }
        for pkl_name, values in pooled.items():
            save_pkl(results_folder + 'predictions/{}.pkl'.format(pkl_name), values)

    longi_pred = [longi_r1_pred, longi_r1star_pred, longi_r2_pred, longi_computer_pred, longi_r2_pred_using_r2_keypts, longi_computer_pred_using_r2_keypts,
                  longi_r1_pred_interpolated, longi_r1star_pred_interpolated, longi_r2_pred_interpolated, longi_computer_pred_interpolated, ids_longi_all, ids_longi_keypts]
    rot_pred = [rot_r1_pred, rot_r1star_pred, rot_r2_pred, rot_computer_pred, rot_r2_pred_using_r2_keypts, rot_computer_pred_using_r2_keypts,
                rot_r1_pred_interpolated, rot_r1star_pred_interpolated, rot_r2_pred_interpolated, rot_computer_pred_interpolated, ids_circ_all, ids_circ_keypts]

    # This is the full analysis that computes the statistics in the paper.
    # results_folder is already the full output path (e.g. 'Results/<run>/'). Wrapping it again in
    # 'Results/{}/' produced 'Results/Results/<run>//'.
    longitudinal_results = longitudinal_stats_analysis(load_folder=results_folder, data=longi_pred, verbose=verbose, save_viz=save_viz)
    circumferential_results = circumferential_stats_analysis(load_folder=results_folder, data=rot_pred, verbose=verbose, save_viz=save_viz)
    return longitudinal_results, circumferential_results




In [None]:
# Full test-set evaluation: saves per-case plots and predictions under save_folder, then runs the stats analysis.
eval(model=model, test_dataset=test_dataset, device=device, save_viz=save_folder, save_pred=True, verbose=True);

## Statistical significance: IntraCross vs He et al.

In [None]:
from scipy.stats import wilcoxon
import seaborn as sns
import matplotlib as mpl
from utils.circumferential_stats_analysis import jackknife_mwi

In [None]:
def _load_pooled_pred(folder, name):
    """Load ``Results/<folder>/predictions/test_<name>.pkl`` (the pooled file written by ``eval``) as a numpy array."""
    return np.array(load_pkl(f'Results/{folder}/predictions/test_{name}.pkl'))


def _report_agreement(header, prefix, r1_dp_diff, r1_gm_diff, r2_dp_diff, r2_gm_diff,
                      r1_r2_diff, r1_r1star_diff, confidence_level):
    """
    Print the DP (He et al.) vs GM (IntraCross) comparison for one registration axis.

    The output includes:
    - mean +/- std of each error array;
    - the absolute and relative difference between the DP and GM mean errors;
    - paired Wilcoxon p-values for DP vs GM;
    - the jackknife Williams index of each method, with its confidence interval.
    """
    print('# ----------------------------------------------------#')
    print(header)
    print('# ----------------------------------------------------#')

    print('Mean')
    print(f'{prefix} vs R1: DP: {r1_dp_diff.mean():.1f} +/- {r1_dp_diff.std():.1f}, GM: {r1_gm_diff.mean():.1f} +/- {r1_gm_diff.std():.1f}')
    print(f'{prefix} vs R2: DP: {r2_dp_diff.mean():.1f} +/- {r2_dp_diff.std():.1f}, GM: {r2_gm_diff.mean():.1f} +/- {r2_gm_diff.std():.1f}')
    print(f'Inter: {r1_r2_diff.mean():.1f} +/- {r1_r2_diff.std():.1f}, Intra: {r1_r1star_diff.mean():.1f} +/- {r1_r1star_diff.std():.1f}')

    for annotator, dp_diff, gm_diff in (('R1', r1_dp_diff, r1_gm_diff), ('R2', r2_dp_diff, r2_gm_diff)):
        delta = dp_diff.mean() - gm_diff.mean()
        # Relative improvement of GM over DP, as a percentage of the DP error.
        print(f'DPvsGM %Diff ({annotator}): {delta:.1f}, {(delta / dp_diff.mean()) * 100:.1f}%')

    _, r1_p_value = wilcoxon(r1_dp_diff, r1_gm_diff)
    _, r2_p_value = wilcoxon(r2_dp_diff, r2_gm_diff)
    print(f'--- P-values between GM-DP (<0.05 sig dif): {r1_p_value:.4f} {r2_p_value:.4f}')

    for method, r1_diff, r2_diff in (('DP', r1_dp_diff, r2_dp_diff), ('GM', r1_gm_diff, r2_gm_diff)):
        mwi, ci_lower, ci_upper = jackknife_mwi(r1_r2_diff, r1_diff, r2_diff, confidence_level, reduction='mean')
        print(f"Williams Index ({method}): {mwi:.2f} ({ci_lower:.2f}, {ci_upper:.2f})")


def compare_to_icf_work(icf_folder, ics_folder, verbose=True):
    """
    Compare He et al. ("DP", ``icf_folder``) with IntraCross ("GM", ``ics_folder``) against the
    expert annotations R1/A1, R1*/A1* and R2/A2.

    Both folders are looked up under ``Results/`` and must contain the pooled
    ``predictions/test_{longi,rot}_*_pred.pkl`` files written by ``eval(..., save_pred=True)``.
    The expert annotations are read from ``ics_folder``. The function saves a violin plot to
    ``Results/<ics_folder>/violin.png`` and prints the longitudinal and circumferential statistics.

    Parameters
    ----------
    icf_folder : str
        Results sub-folder of the He et al. baseline.
    ics_folder : str
        Results sub-folder of the IntraCross run.
    verbose : bool
        Not used; kept for interface compatibility.
    """
    longi_dp_pred = _load_pooled_pred(icf_folder, 'longi_computer_pred')
    longi_gm_pred = _load_pooled_pred(ics_folder, 'longi_computer_pred')
    longi_r1_pred = _load_pooled_pred(ics_folder, 'longi_r1_pred')
    longi_r1star_pred = _load_pooled_pred(ics_folder, 'longi_r1star_pred')
    longi_r2_pred = _load_pooled_pred(ics_folder, 'longi_r2_pred')

    rot_dp_pred = _load_pooled_pred(icf_folder, 'rot_computer_pred') % 360
    rot_gm_pred = _load_pooled_pred(ics_folder, 'rot_computer_pred') % 360
    rot_r1_pred = _load_pooled_pred(ics_folder, 'rot_r1_pred') % 360
    rot_r2_pred = _load_pooled_pred(ics_folder, 'rot_r2_pred') % 360
    rot_r1star_pred = _load_pooled_pred(ics_folder, 'rot_r1star_pred') % 360

    longi_r1_dp_diff = np.abs(longi_r1_pred - longi_dp_pred)
    longi_r1_gm_diff = np.abs(longi_r1_pred - longi_gm_pred)
    longi_r2_dp_diff = np.abs(longi_r2_pred - longi_dp_pred)
    longi_r2_gm_diff = np.abs(longi_r2_pred - longi_gm_pred)
    longi_r1_r2_diff = np.abs(longi_r1_pred - longi_r2_pred)
    longi_r1_r1star_diff = np.abs(longi_r1_pred - longi_r1star_pred)

    rot_r1_dp_diff = angle_difference(rot_r1_pred, rot_dp_pred)
    rot_r1_gm_diff = angle_difference(rot_r1_pred, rot_gm_pred)
    rot_r2_dp_diff = angle_difference(rot_r2_pred, rot_dp_pred)
    rot_r2_gm_diff = angle_difference(rot_r2_pred, rot_gm_pred)
    rot_r1_r2_diff = angle_difference(rot_r2_pred, rot_r1_pred)
    rot_r1_r1star_diff = angle_difference(rot_r1_pred, rot_r1star_pred)

    # Order must match the labels inside violin_plot, which is defined in a later cell.
    violin_plot(
        [longi_r1_r2_diff, longi_r1_dp_diff, longi_r1_gm_diff, longi_r2_dp_diff, longi_r2_gm_diff],
        [rot_r1_r2_diff, rot_r1_dp_diff, rot_r1_gm_diff, rot_r2_dp_diff, rot_r2_gm_diff],
        save_path=f'Results/{ics_folder}/violin.png')

    confidence_level = 0.95

    _report_agreement('# ------------- LONGITUDINAL ANALYSIS ----------------#', 'Longi',
                      longi_r1_dp_diff, longi_r1_gm_diff, longi_r2_dp_diff, longi_r2_gm_diff,
                      longi_r1_r2_diff, longi_r1_r1star_diff, confidence_level)
    _report_agreement('# ------------- CIRCUMFERENTIAL ANALYSIS -------------#', 'Rot',
                      rot_r1_dp_diff, rot_r1_gm_diff, rot_r2_dp_diff, rot_r2_gm_diff,
                      rot_r1_r2_diff, rot_r1_r1star_diff, confidence_level)

In [None]:
def violin_plot(longi_data, rot_data, save_path):
    """
    Draw, save and show a two-row violin plot (longitudinal / rotational) of registration errors.

    Parameters
    ----------
    longi_data, rot_data : sequence of 5 1-D array-likes
        Error values in this order: inter-observer (A1 vs A2), A1 vs He et al., A1 vs IntraCross,
        A2 vs He et al., A2 vs IntraCross. Longitudinal errors are in frames, rotational errors
        in degrees.
    save_path : str
        Output image path.

    Raises
    ------
    ValueError
        If either input does not have exactly one entry per comparison. Previously, ``zip``
        silently dropped the extras.
    """
    labels = [
        "Inter-Observer \n(A1 vs A2)",
        "A1 vs He et al.",
        "A1 vs IntraCross",
        "A2 vs He et al.",
        "A2 vs IntraCross"
    ]
    for name, data in (("longi_data", longi_data), ("rot_data", rot_data)):
        if len(data) != len(labels):
            raise ValueError(f"{name} must have {len(labels)} entries, got {len(data)}")

    # One colour per method: blue = inter-observer, orange = He et al., green = IntraCross.
    custom_palette = dict(zip(labels, ["#1f77b4", "#ff7f0e", "#2ca02c", "#ff7f0e", "#2ca02c"]))

    # Long-format dataframe: one row per error value.
    records = [
        (value, label, error_type)
        for error_type, data in (("Longitudinal", longi_data), ("Rotational", rot_data))
        for value_array, label in zip(data, labels)
        for value in value_array
    ]
    df = pd.DataFrame(records, columns=["Error", "Comparison", "Type"])

    plot_style = {
        "font.sans-serif": "Lato",
        "font.size": 10,
        "axes.titlesize": 10,
        "axes.labelsize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "axes.linewidth": 0.8,
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
        "axes.grid": True,
        "grid.linewidth": 0.4,
        "grid.linestyle": "--",
    }
    # rc_context scopes the style to this figure instead of permanently changing global rcParams.
    with mpl.rc_context(plot_style):
        g = sns.catplot(
            data=df,
            x="Comparison", y="Error", hue="Comparison", row="Type",
            kind="violin", inner="box", linewidth=1, cut=0,
            height=2, aspect=3, palette=custom_palette, legend=False,
            sharey=False  # independent y-axis per row (frames vs degrees)
        )

        # Set titles and axis labels
        g.set_titles("{row_name} Error")
        g.set_xlabels("")
        g.set_ylabels("")

        for ax, row_name in zip(g.axes.flat, ["Longitudinal", "Rotational"]):
            ax.set_ylabel("Frames" if row_name == "Longitudinal" else "Degrees")
            ax.tick_params(axis='x', rotation=20)
            sns.despine(ax=ax)

        plt.tight_layout()
        plt.savefig(save_path, dpi=400)
        plt.show()


In [None]:
# Compare He et al. (DTW baseline run) with an IntraCross run on the pooled test-set predictions.
# NOTE(review): this IntraCross run differs from save_folder used above — confirm it is intended.
compare_to_icf_work(icf_folder='DTW_Labels_AI_Baseline v2', ics_folder='2025-07-14-16-08-25-1752505705-superglue_clustered_d')