In [None]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
from tqdm import tqdm

from methods import list_files_in_directory, init_model, load_model

In [None]:
# Compute device: prefer Apple-Silicon MPS (the original target), then CUDA,
# and fall back to CPU so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
model_name = "02"

# Root of the Study1 data collection. Can be overridden with the
# STUDY1_DATA_DIR environment variable; defaults to the original location.
STUDY1_DATA_DIR = os.environ.get(
    "STUDY1_DATA_DIR",
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1",
)

# Every per-subject input file lives in the subject's "pattern" directory.
subject_pattern_dir = os.path.join(STUDY1_DATA_DIR, "sub-" + model_name, "pattern")

path_to_weights = os.path.join(subject_pattern_dir, "dc_weights.csv")

path_to_in_data_1 = os.path.join(subject_pattern_dir, "in_data_d1.csv")
path_to_in_data_2 = os.path.join(subject_pattern_dir, "in_data_d2.csv")
path_to_in_data_3 = os.path.join(subject_pattern_dir, "in_data_d3.csv")

In [None]:
# Read the weight CSV (no header row), transpose it, and move it to DEVICE as
# float32. After the transpose, the row count defines the state size used below.
weights_table = pd.read_csv(path_to_weights, header=None)
weights = torch.from_numpy(weights_table.values.T)
weights = weights.to(DEVICE, dtype=torch.float32)
state_size = weights.shape[0]

In [None]:
# Load the three input CSVs (d1, d2, d3; no header row) and stack their rows
# into one array of input samples.
_in_data_paths = [path_to_in_data_1, path_to_in_data_2, path_to_in_data_3]
data_in_all_list = [
    pd.read_csv(csv_path, header=None).values
    for csv_path in _in_data_paths
]
data_in_all_array = np.vstack(data_in_all_list)

In [None]:
# Rollout configuration.
BATCH_SIZE = data_in_all_array.shape[0]  # number of stacked input samples; NOTE(review): not referenced later in this notebook
N_STEPS = 40                             # denoising steps per rollout
FC2_LENGTH = 128                         # NOTE(review): not referenced in this notebook

In [None]:
# Which family of trained models to analyse: 'rl' or 'gradient'. The value is
# used to build model/output paths, so a typo would silently point at a
# directory that does not exist -- fail fast instead.
GROUP = 'gradient'
if GROUP not in ('rl', 'gradient'):
    raise ValueError(f"GROUP must be 'rl' or 'gradient', got {GROUP!r}")

In [None]:
# Directory holding the trained models for this group and subject.
all_models_path = f'/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/models/{GROUP}/sub-{model_name}'

In [None]:
# List the model files under all_models_path; the sampling cell loads
# all_models[-1] via load_model.
# NOTE(review): which entry is "last" depends on the order returned by
# list_files_in_directory -- confirm it is deterministic (e.g. sorted).
all_models = list_files_in_directory(all_models_path)

In [None]:
# Output directory (trailing slash included) for the saved noise statistics.
save_base_path = f'/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/noise_distribution/{GROUP}/sub-{model_name}/'

In [None]:
# Number of repetitions of the denoising rollout per input sample in the
# sampling loop.
reps = 1

In [None]:
# Load the last model listed in all_models and roll out its denoising policy
# for every input sample. Record the per-step mean and std returned by
# model.select_action, averaged over `reps` repetitions.
mod = all_models[-1]
model = init_model(DEVICE, state_size, state_size)
model, _ = load_model(model, mod)
model = model.to(DEVICE)

n_samples = len(data_in_all_array)
means = np.zeros((n_samples, N_STEPS, state_size))
stds = np.zeros((n_samples, N_STEPS, state_size))

with torch.no_grad():
    for x_idx, x_main in tqdm(enumerate(data_in_all_array), total=n_samples):
        x_start = torch.from_numpy(x_main.reshape(1, state_size)).float().to(DEVICE)
        for _rep in range(reps):
            # Each repetition restarts from the original input; previously,
            # later reps continued from the already-denoised state.
            x = x_start
            # t counts down N_STEPS..1 while `step` indexes the output arrays.
            for step, t in enumerate(range(N_STEPS, 0, -1)):
                action, _, mean, std = model.select_action(x, t)
                x = x + action  # new tensor; x_start is left untouched

                means[x_idx, step] += mean.cpu().numpy().reshape(state_size)
                stds[x_idx, step] += std.cpu().numpy().reshape(state_size)

# The loop accumulates sums; turn them into averages over repetitions
# (guard reps == 0 so the arrays stay zero instead of becoming NaN).
means /= max(reps, 1)
stds /= max(reps, 1)


In [None]:
# Write the raw per-sample statistics to save_base_path, creating the
# directory first when it is missing.
os.makedirs(save_base_path, exist_ok=True)

np.save(os.path.join(save_base_path, 'means.npy'), means)
np.save(os.path.join(save_base_path, 'stds.npy'), stds)

In [None]:
# Average over input samples (axis 0), leaving one (N_STEPS, state_size) map
# each for the policy mean and std.
means_mean, stds_mean = means.mean(axis=0), stds.mean(axis=0)

In [None]:
def _minmax_normalize_columns(arr):
    """Min-max scale each column of a 2-D array independently to [0, 1].

    Here columns are voxels and rows are denoising steps. A column whose values
    are constant (max == min) maps to 0 instead of producing NaN from a
    division by zero; such NaNs would otherwise break the KMeans fit later on.
    """
    col_min = arr.min(axis=0, keepdims=True)
    col_range = arr.max(axis=0, keepdims=True) - col_min
    safe_range = np.where(col_range == 0, 1.0, col_range)
    return (arr - col_min) / safe_range


means_mean_normalized = _minmax_normalize_columns(means_mean)
stds_mean_normalized = _minmax_normalize_columns(stds_mean)

In [None]:
# Output directory for figures, relative to the notebook's working directory.
# Create it here so the savefig calls in the plotting cells do not fail on a
# fresh checkout.
fig_save_base = '../results/Imgs/sub-'+model_name+'/'
os.makedirs(fig_save_base, exist_ok=True)

In [None]:
# Heatmaps of the sample-averaged policy mean (top) and std (bottom):
# voxels on the y-axis, denoising steps on the x-axis (shared between panels).
fig = plt.figure(figsize=(5, 10))

ax_mean = fig.add_subplot(2, 1, 1)
im_mean = ax_mean.imshow(means_mean.T, aspect='auto')
fig.colorbar(im_mean, ax=ax_mean)
ax_mean.set_title('Mean')
ax_mean.set_xlabel('Denoising Step')
ax_mean.set_ylabel('Voxel')

ax_std = fig.add_subplot(2, 1, 2, sharex=ax_mean)
im_std = ax_std.imshow(stds_mean.T, aspect='auto')
fig.colorbar(im_std, ax=ax_std)
ax_std.set_title('Std')
ax_std.set_xlabel('Denoising Step')
ax_std.set_ylabel('Voxel')

fig.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_raw_'+GROUP+'.pdf')

plt.show()

In [None]:
# Same layout as the raw figure, but each voxel is min-max scaled across
# denoising steps so temporal profiles are comparable between voxels.
fig = plt.figure(figsize=(5, 10))

ax_mean = fig.add_subplot(2, 1, 1)
im_mean = ax_mean.imshow(means_mean_normalized.T, aspect='auto')
fig.colorbar(im_mean, ax=ax_mean)
ax_mean.set_title('Mean')
ax_mean.set_xlabel('Denoising Step')
ax_mean.set_ylabel('Voxel')

ax_std = fig.add_subplot(2, 1, 2, sharex=ax_mean)
im_std = ax_std.imshow(stds_mean_normalized.T, aspect='auto')
fig.colorbar(im_std, ax=ax_std)
ax_std.set_title('Std')
ax_std.set_xlabel('Denoising Step')
ax_std.set_ylabel('Voxel')

fig.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_normalized_'+GROUP+'.pdf')

plt.show()

In [None]:
# Cluster voxels by their normalized mean profile across denoising steps, then
# reorder voxels so members of the same cluster sit next to each other in both
# heatmaps. (KMeans is imported in the top import cell.)
n_clusters = 5

kmeans_means = KMeans(n_clusters=n_clusters, random_state=0).fit(means_mean_normalized.T)
labels_means = kmeans_means.labels_

# A stable sort keeps the original voxel order inside each cluster, so the
# figure is reproducible; numpy's default quicksort gives no such guarantee.
sorted_indices_means = np.argsort(labels_means, kind='stable')
means_mean_normalized_sorted = means_mean_normalized[:, sorted_indices_means]
# The std panel uses the same voxel order so both panels line up row-for-row.
stds_mean_normalized_sorted = stds_mean_normalized[:, sorted_indices_means]

fig = plt.figure(figsize=(5, 10))

ax_mean = fig.add_subplot(2, 1, 1)
im_mean = ax_mean.imshow(means_mean_normalized_sorted.T, aspect='auto')
fig.colorbar(im_mean, ax=ax_mean)
ax_mean.set_title('Mean')
ax_mean.set_xlabel('Denoising Step')
ax_mean.set_ylabel('Voxel')

ax_std = fig.add_subplot(2, 1, 2, sharex=ax_mean)
im_std = ax_std.imshow(stds_mean_normalized_sorted.T, aspect='auto')
fig.colorbar(im_std, ax=ax_std)
ax_std.set_title('Std')
ax_std.set_xlabel('Denoising Step')
ax_std.set_ylabel('Voxel')

fig.savefig(fig_save_base+'sub-'+model_name+'_noise_distribution_means_std_normalized_clustered_'+GROUP+'.pdf')

plt.show()