In [1]:
from torch.utils.data import DataLoader, Dataset
from torchvision import models
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
from resnet50_get_relu_outputs import resnet_output
from joblib import Parallel, delayed
from scipy.io import savemat, loadmat
from PIL import Image
from scipy.stats import entropy
import seaborn as sns

In [2]:
input_folder = "/media/sarvagya-pc/2TB HDD/Balgrist/full_MPM_images/wp1"


In [3]:
def entropy(p):
    """Calculate entropy for given probability distribution."""
    p = p[p > 0]  # Filter zero probabilities to avoid log(0)
    return -np.sum(p * np.log(p))

def mutual_information(x, y, bins=120):
    """Estimate mutual information for continuous variables."""
    # Flatten the 2D image arrays to 1D
    x = x.flatten()
    y = y.flatten()

    # Calculate the 2D histogram
    c_xy = np.histogram2d(x, y, bins)[0]
    p_xy = c_xy / np.sum(c_xy)  # Joint probability distribution
    p_x = np.sum(p_xy, axis=1)  # Marginal for x
    p_y = np.sum(p_xy, axis=0)  # Marginal for y
    
    # Calculate entropies
    h_x = entropy(p_x)
    h_y = entropy(p_y)
    h_xy = entropy(p_xy.flatten())
    
    # Mutual information
    mi = h_x + h_y - h_xy
    return mi



In [6]:
subject_folders = sorted(os.listdir(input_folder))

In [10]:
# matrix_sub = list(matrices.keys())
# num_matrices = len(matrix_sub)
# mi_matrix = np.zeros((num_matrices, num_matrices, 13))
# print(matrix_sub)
num_images = 151

# Collect all sessions for all subjects
all_sessions = []
for subject in subject_folders:
    subject_path = os.path.join(input_folder, subject)
    sessions = sorted([s for s in os.listdir(subject_path) if os.path.isdir(os.path.join(subject_path, s))])
    for session in sessions:
        all_sessions.append((subject, session))

num_sessions = len(all_sessions)

# Initialize a matrix to store mutual information values
# Shape: (num_sessions, num_sessions, num_images)
mi_matrix = np.zeros((num_sessions, num_sessions, num_images))

# Iterate through each image index (i)
for i in tqdm(range(num_images)):
    # Iterate through each session of each subject
    for idx_1, (subject_1, session_1) in enumerate(all_sessions):
        img_path_1 = os.path.join(input_folder, subject_1, session_1, 'PCA', 'axial', f'img_{i:03}_normalized.npy')
        try:
            # img_1 = Image.open(img_path_1).convert('L')
            # img_1 = np.array(img_1)
            img_1 = np.load(img_path_1, allow_pickle=True)
        except FileNotFoundError:
            # print(f"Error: Image {img_path_1} not found.")
            continue

        # Iterate through each other session for mutual information calculation
        for idx_2, (subject_2, session_2) in enumerate(all_sessions):
            img_path_2 = os.path.join(input_folder, subject_2, session_2, 'PCA', 'axial', f'img_{i:03}_normalized.npy')
            try:
                # img_2 = Image.open(img_path_2).convert('L')
                # img_2 = np.array(img_2)
                img_2 = np.load(img_path_2, allow_pickle=True)
            except FileNotFoundError:
                # print(f"Error: Image {img_path_2} not found.")
                continue

            # Calculate mutual information between the ith image of the current session and other sessions
            mi_value = mutual_information(img_1, img_2)
            mi_matrix[idx_1, idx_2, i] = mi_value

# Print or save the MI matrix for all images across all sessions and subjects
print("Mutual Information Matrix for All Sessions and Subjects:")
print(mi_matrix)



  0%|          | 0/151 [00:00<?, ?it/s]

Mutual Information Matrix for All Sessions and Subjects:
[[[0.46872496 0.46003574 0.46148326 ... 0.97153397 0.         0.        ]
  [0.4665181  0.454023   0.46089431 ... 0.93958544 0.         0.        ]
  [0.04823727 0.05185392 0.0505216  ... 0.24290169 0.         0.        ]
  ...
  [0.04960552 0.07387784 0.12716375 ... 0.41673221 0.         0.        ]
  [0.04975924 0.07407435 0.12675445 ... 0.41681762 0.         0.        ]
  [0.05104923 0.07306438 0.12663543 ... 0.41662824 0.         0.        ]]

 [[0.4665181  0.454023   0.46089431 ... 0.93958544 0.         0.        ]
  [0.46895036 0.45995181 0.46160114 ... 0.96596205 0.         0.        ]
  [0.04833868 0.05169981 0.05066692 ... 0.24094957 0.         0.        ]
  ...
  [0.04953135 0.07379291 0.1271109  ... 0.41395051 0.         0.        ]
  [0.04968464 0.07399102 0.12669152 ... 0.41403812 0.         0.        ]
  [0.05097384 0.07298597 0.12658225 ... 0.41379855 0.         0.        ]]

 [[0.04823727 0.05185392 0.0505216  ...

In [11]:
np.save('mutual_information_matrix_all_PCA.npy', mi_matrix)
mi_matrix[:,:,100].shape

(165, 165)

In [17]:
plt.figure(figsize=(10, 8))
sns.heatmap(mi_matrix[:,:,120], annot=False, cmap='viridis', xticklabels=subject_folders, yticklabels=[f'Image_{i:03}' for i in range(num_images)])
# plt.title(f'Mutual Information Confusion Matrix for {session_key}')
# plt.xlabel('Subjects')
# plt.ylabel('Images')
plt.tight_layout()
# plt.savefig(f'mutual_information_confusion_matrix_{session_key}.png')
plt.close()

In [35]:
count = 0
for subject in subject_folders:
    # print(subject)
    sessions = sorted([ses for ses in os.listdir(input_folder+'/'+subject+'/') if "ses-" in ses])
    for ses in sessions:
        print(subject, ses)
        count+=1
print(count)

BCA-001 ses-01
BCA-001 ses-03
BCA-002 ses-02
BCA-002 ses-03
BSL-001 ses-02
BSL-001 ses-03
BSL-002 ses-01
BSL-002 ses-02
BSL-002 ses-03
BSL-003 ses-01
BSL-003 ses-02
BSL-003 ses-03
BSL-004 ses-01
BSL-004 ses-02
BSL-005 ses-01
BSL-005 ses-02
BSL-005 ses-03
BSL-006 ses-01
BSL-006 ses-03
HDG-002 ses-01
HDG-002 ses-02
HDG-004 ses-01
HDG-004 ses-02
HDG-004 ses-03
HDG-006 ses-01
HDG-006 ses-02
HDG-006 ses-03
HDG-007 ses-02
HDG-007 ses-03
HDG-008 ses-01
HDG-008 ses-02
HDG-008 ses-03
HDG-009 ses-02
HDG-009 ses-03
HDG-011 ses-01
HDG-011 ses-02
HDG-012 ses-01
HDG-012 ses-02
HDG-012 ses-03
HDG-014 ses-01
HDG-014 ses-02
HDG-014 ses-03
HDG-015 ses-01
HDG-015 ses-02
HDG-018 ses-01
HDG-018 ses-02
HLE-003 ses-01
HLE-003 ses-02
HLE-003 ses-03
HLE-004 ses-01
HLE-004 ses-02
HLE-004 ses-03
HLE-005 ses-01
HLE-005 ses-02
HLE-005 ses-03
HLE-006 ses-01
HLE-006 ses-03
HLE-007 ses-01
HLE-007 ses-02
HLE-008 ses-01
HLE-008 ses-03
HLE-011 ses-01
HLE-011 ses-02
HLE-011 ses-03
HLE-012 ses-01
HLE-012 ses-02
HLE-012 se

In [46]:
mi_matrix.shape

(165, 165, 149)

In [36]:
confusion_matrix_mi = mi_matrix = np.zeros((mi_matrix.shape[1], mi_matrix.shape[1], mi_matrix.shape[0]))
for i in range(mi_matrix.shape[1]):
    # print(matrix_sub[i])
    for j in range(i, mi_matrix.shape[1]):
        confusion_matrix_mi[i,j,:] = mi_matrix[:,i]
        confusion_matrix_mi[j,i,:] = mi_matrix[:,i]

ValueError: could not broadcast input array from shape (165,149) into shape (149,)

In [40]:
mi_matrix[:,0].shape

(165, 149)