In [None]:
import numpy as np
import pandas as pd
from methods import list_files_in_directory, init_model, load_model
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
import os
import time

In [None]:
# Compute device used for the weights, the model and every input tensor.
# Apple's 'mps' backend stays the first choice (the original setting); when it
# is not available fall back to CUDA and then to the CPU so the notebook still
# runs on other machines.
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Subject identifier; it is spliced into every "sub-<id>" data, model and results path below.
model_name = "01"
# Selects which results sub-folder (saved models, noise distributions) is used.
GROUP = 'rl' # 'rl' or 'gradient'

In [None]:
# Per-subject "pattern" folder; the weight file and the three input CSVs all live in it.
_pattern_dir = (
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1/sub-"
    + model_name
    + "/pattern/"
)

# Weight CSV, read further down into `weights`.
path_to_weights = _pattern_dir + "dc_weights.csv"

# Input CSVs in_data_d1.csv, in_data_d2.csv and in_data_d3.csv.
path_to_in_data_1, path_to_in_data_2, path_to_in_data_3 = (
    _pattern_dir + "in_data_d" + str(file_idx) + ".csv" for file_idx in (1, 2, 3)
)

# ROI time-series array stored (relative to this notebook) under results/raw_data.
path_to_all_raw_roi_data = os.path.join(
    f'../results/raw_data/sub-{model_name}',
    "all_roi_timeseries_zscore.npy",
)

In [None]:
# Read the header-less weight CSV and transpose it, so the CSV's columns become rows.
_weights_np = pd.read_csv(path_to_weights, header=None).values.T
weights = torch.from_numpy(_weights_np).to(DEVICE, dtype=torch.float32)

# Row count of the transposed weights; passed to init_model and used to size
# the per-sample tensors and result arrays below.
state_size = weights.shape[0]

In [None]:
# data_in_all_list = [pd.read_csv(path_to_in_data, header=None).values for path_to_in_data in [path_to_in_data_1, path_to_in_data_2, path_to_in_data_3]]
# data_in_all_array = np.vstack(data_in_all_list)

In [None]:
# Load the saved ROI time-series array and transpose it; the sampling loop below
# iterates over the rows of the transposed array, one row per pass.
# NOTE(review): each row is later reshaped to (1, state_size), so a row presumably
# holds state_size values — confirm against the layout of the saved .npy file.
data_in_all_array = np.load(path_to_all_raw_roi_data).T

In [None]:
# BATCH_SIZE = len(data_in_all_array)
# Number of steps run per input sample in the sampling loop (t counts down from N_STEPS to 1).
N_STEPS = 40
# FC2_LENGTH = 128

In [None]:
# Folder that is listed below to find the saved models for the selected GROUP and subject.
all_models_path = (
    f'/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/models/{GROUP}/sub-{model_name}'
)

In [None]:
# Collect the entries found under all_models_path.
# NOTE(review): the last entry is the one loaded below, so the order returned by
# list_files_in_directory matters — presumably sorted; confirm in methods.py.
all_models = list_files_in_directory(all_models_path)

In [None]:
# Output folder for the saved means/stds arrays (note the trailing slash: file
# names are appended to this string directly).
save_base_path = (
    f'/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/noise_distribution/{GROUP}/sub-{model_name}/'
)

In [None]:
# How many times the N_STEPS loop is run for each input sample.
reps = 1

In [None]:
# Use the last entry of the model listing.
mod = all_models[-1]
# Build a model with state_size passed for both size arguments, restore it from `mod`,
# and move it to DEVICE.
# NOTE(review): init_model/load_model come from methods.py; the second value returned
# by load_model is discarded here — confirm nothing in it is needed.
model = init_model(DEVICE, state_size, state_size)
model, _ = load_model(model, mod)
model = model.to(DEVICE)

# Per-sample, per-step accumulators for the mean and std returned by
# model.select_action (one state_size-long vector per step).
n_samples = len(data_in_all_array)
means = np.zeros((n_samples, N_STEPS, state_size))
stds = np.zeros((n_samples, N_STEPS, state_size))

# Inference only: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for x_idx, x_main in tqdm(enumerate(data_in_all_array), total=n_samples):
        x_start = torch.from_numpy(x_main.reshape(1, state_size)).float().to(DEVICE)
        for rep in range(reps):
            # Every repetition restarts from the original sample. Previously x
            # carried over, so repetition 2+ began from the already-updated state
            # while t restarted at N_STEPS.
            x = x_start
            for step, t in enumerate(range(N_STEPS, 0, -1)):
                action, _, mean, std = model.select_action(x, t)
                x = x + action  # out-of-place add, so x_start is left untouched

                means[x_idx, step] += mean.cpu().numpy().reshape(state_size)
                stds[x_idx, step] += std.cpu().numpy().reshape(state_size)

# The loop sums over repetitions; convert the sums into per-repetition averages.
# (No-op for reps == 1, which keeps the original results unchanged.)
if reps > 1:
    means /= reps
    stds /= reps


In [None]:
# Timestamp that is embedded in both output file names.
tme = time.time()

# Create the directory if it does not exist
os.makedirs(save_base_path, exist_ok=True)

# Write the two result arrays side by side: means_all_TRs_<t>.npy and stds_all_TRs_<t>.npy.
for _prefix, _array in (('means', means), ('stds', stds)):
    np.save(save_base_path + f'{_prefix}_all_TRs_{str(tme)}.npy', _array)

In [None]:
# Load means/stds from a previously saved run, replacing the arrays computed above.
# NOTE(review): both paths are hard-coded to rl/sub-01 and one specific timestamp, so
# they ignore GROUP and model_name — confirm this is the run you intend to analyse.
means = np.load('../results/noise_distribution/rl/sub-01/means_all_TRs_1741413146.851917.npy')
stds = np.load('../results/noise_distribution/rl/sub-01/stds_all_TRs_1741413146.851917.npy')

In [None]:
# Average both arrays over their first axis.
means_mean = means.mean(axis=0)
stds_mean = stds.mean(axis=0)

In [None]:
def _min_max_normalize(values, axis=0):
    """Rescale `values` to the range [0, 1] along `axis`.

    Each slice along `axis` is shifted by its minimum and divided by its
    (max - min) range. A slice whose range is zero (all values equal) is mapped
    to 0 instead of producing 0/0 = NaN, which would otherwise propagate into
    the clustering and PCA cells further down.
    """
    lowest = values.min(axis=axis, keepdims=True)
    span = values.max(axis=axis, keepdims=True) - lowest
    safe_span = np.where(span == 0, 1.0, span)  # avoid dividing by zero
    return (values - lowest) / safe_span


# Column-wise (axis 0) min-max normalisation of the averaged means and stds.
means_mean_normalized = _min_max_normalize(means_mean, axis=0)
stds_mean_normalized = _min_max_normalize(stds_mean, axis=0)

In [None]:
# Per-subject folder referenced by the (currently commented-out) savefig calls below.
fig_save_base = f'../results/Imgs/sub-{model_name}/'

In [None]:
# Two stacked heat maps of the averaged (un-normalised) means and stds; the
# lower panel shares its x axis with the upper one.
raw_fig = plt.figure(figsize=(5, 10))

upper_ax = None
for row, (panel_title, panel_data) in enumerate(
    (('Mean', means_mean), ('Std', stds_mean)), start=1
):
    panel_ax = raw_fig.add_subplot(2, 1, row, sharex=upper_ax)
    heat = panel_ax.imshow(panel_data.T, aspect='auto')
    raw_fig.colorbar(heat, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Denoising Step')
    panel_ax.set_ylabel('Voxel')
    if upper_ax is None:
        upper_ax = panel_ax

# plt.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_raw_'+GROUP+'.pdf')

plt.show()

In [None]:
# Same two-panel layout as the raw figure, drawn from the min-max normalised arrays.
norm_fig = plt.figure(figsize=(5, 10))

upper_ax = None
for row, (panel_title, panel_data) in enumerate(
    (('Mean', means_mean_normalized), ('Std', stds_mean_normalized)), start=1
):
    panel_ax = norm_fig.add_subplot(2, 1, row, sharex=upper_ax)
    heat = panel_ax.imshow(panel_data.T, aspect='auto')
    norm_fig.colorbar(heat, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Denoising Step')
    panel_ax.set_ylabel('Voxel')
    if upper_ax is None:
        upper_ax = panel_ax

# plt.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_normalized_'+GROUP+'.pdf')

plt.show()

In [None]:
from sklearn.cluster import KMeans

# Number of clusters
n_clusters = 5

# Cluster the rows of the transposed, normalised mean matrix (i.e. the columns of
# means_mean_normalized); the random seed is fixed.
kmeans_means = KMeans(n_clusters=n_clusters, random_state=0)
kmeans_means.fit(means_mean_normalized.T)
# kmeans_stds = KMeans(n_clusters=n_clusters, random_state=0).fit(stds_mean_normalized.T)

# Cluster label assigned to each clustered row
labels_means = kmeans_means.labels_
# labels_stds = kmeans_stds.labels_

# Column order that places members of the same cluster next to each other; the
# same order (derived from the means) is applied to both matrices.
sorted_indices_means = labels_means.argsort()
means_mean_normalized_sorted = np.take(means_mean_normalized, sorted_indices_means, axis=1)

stds_mean_normalized_sorted = np.take(stds_mean_normalized, sorted_indices_means, axis=1)

# Two-panel heat map of the normalised arrays with columns reordered by cluster label.
clustered_fig = plt.figure(figsize=(5, 10))

upper_ax = None
for row, (panel_title, panel_data) in enumerate(
    (('Mean', means_mean_normalized_sorted), ('Std', stds_mean_normalized_sorted)),
    start=1,
):
    panel_ax = clustered_fig.add_subplot(2, 1, row, sharex=upper_ax)
    heat = panel_ax.imshow(panel_data.T, aspect='auto')
    clustered_fig.colorbar(heat, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Denoising Step')
    panel_ax.set_ylabel('Voxel')
    if upper_ax is None:
        upper_ax = panel_ax

# plt.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_normalized_clustered_'+GROUP+'.pdf')

plt.show()

In [None]:
# Pair the two transposed, normalised matrices on a new axis 1: entry k of the result
# is a 2-row array holding row k of means_mean_normalized.T and row k of stds_mean_normalized.T.
all_features=(np.stack([means_mean_normalized.T, stds_mean_normalized.T], axis=1))

In [None]:
# Imported in this cell because it runs before the later cell that imports PCA;
# without it a fresh "Restart & Run All" stops here with a NameError.
from sklearn.decomposition import PCA

# Collapse every (mean, std) pair to a single 1-D profile: each pair is transposed
# so that its rows are samples with 2 features, then projected onto its first
# principal component.
reduced_pairs = np.array([
    PCA(n_components=1).fit_transform(pair.T).flatten()
    for pair in all_features
]).T  # transpose so that each reduced profile becomes a column

In [None]:
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D

# Project the rows of reduced_pairs onto their first three principal components.
pca = PCA(n_components=3)
means_pca = pca.fit_transform(reduced_pairs)

# Create a 3D plot
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# One coordinate array per principal component.
pc1, pc2, pc3 = means_pca.T
ax.scatter(pc1, pc2, pc3, c='b', marker='o')

ax.set(
    xlabel='Principal Component 1',
    ylabel='Principal Component 2',
    zlabel='Principal Component 3',
    title='3D PCA',
)
plt.show(block=False)
plt.show()

In [None]:
# Same 3-component scatter as before, with a grid, larger fonts and a legend.
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot: the three columns of means_pca are the x, y and z coordinates.
sc = ax.scatter(*means_pca.T, c='b', marker='o', label='Data Points')

# Add grid
ax.grid(True)

# Axis labels share one font size; the title is slightly larger.
for set_label, label_text in (
    (ax.set_xlabel, 'Principal Component 1'),
    (ax.set_ylabel, 'Principal Component 2'),
    (ax.set_zlabel, 'Principal Component 3'),
):
    set_label(label_text, fontsize=12)
ax.set_title('3D PCA', fontsize=15)

# Add legend
ax.legend(loc='best')

# Show plot
plt.show()

In [None]:
# Fraction of variance captured by each component of the fitted 3-component PCA.
explained_variance_ratio = pca.explained_variance_ratio_

# Bar chart with components numbered from 1.
component_numbers = range(1, len(explained_variance_ratio) + 1)
variance_fig, variance_ax = plt.subplots(figsize=(10, 6))
variance_ax.bar(component_numbers, explained_variance_ratio, alpha=0.7, align='center')
variance_ax.set_xlabel('Principal Component')
variance_ax.set_ylabel('Variance Explained')
variance_ax.set_title('Variance Explained by Each Principal Component')
plt.show()

In [None]:
# 3D PCA plot from different angles.
# Each entry is (title, elevation, azimuth). The camera angles are chosen so that
# every panel really looks at the plane named in its title:
#   elev=90, azim=-90 -> straight down the PC3 axis (PC1-PC2 plane)
#   elev=0,  azim=90  -> along the PC2 axis         (PC1-PC3 plane)
#   elev=0,  azim=0   -> along the PC1 axis         (PC2-PC3 plane)
# Previously the third panel used elev=90, azim=0, which is a top-down PC1-PC2
# view, and the first panel used the default oblique view.
pca_views = (
    ('View from PC1-PC2', 90, -90),
    ('View from PC1-PC3', 0, 90),
    ('View from PC2-PC3', 0, 0),
)

fig = plt.figure(figsize=(15, 5))

view_axes = []
for position, (view_title, elev, azim) in enumerate(pca_views, start=1):
    view_ax = fig.add_subplot(1, 3, position, projection='3d')
    view_ax.scatter(means_pca[:, 0], means_pca[:, 1], means_pca[:, 2], c='b', marker='o')
    view_ax.view_init(elev=elev, azim=azim)
    view_ax.set_xlabel('Principal Component 1')
    view_ax.set_ylabel('Principal Component 2')
    view_ax.set_zlabel('Principal Component 3')
    view_ax.set_title(view_title)
    view_axes.append(view_ax)

# Keep the per-panel names the original cell defined.
ax1, ax2, ax3 = view_axes

plt.tight_layout()
plt.show()

In [None]:
# 3D PCA scatter with the explained-variance ratios overlaid as red 3D bars.
fig = plt.figure(figsize=(5, 5))
ax1 = fig.add_subplot(111, projection='3d')

ax1.scatter(*means_pca.T, c='b', marker='o')
ax1.set(
    xlabel='Principal Component 1',
    ylabel='Principal Component 2',
    zlabel='Principal Component 3',
    title='3D PCA',
)

# One bar per component: anchored at (index, 0, 0), 0.1 x 0.1 footprint, height
# equal to that component's explained-variance ratio.
for bar_x, ratio in enumerate(explained_variance_ratio):
    ax1.bar3d(bar_x, 0, 0, 0.1, 0.1, ratio, color='r', alpha=0.6)

plt.tight_layout()
plt.show()
