In [1]:
import numpy as np
import pandas as pd
from methods import list_files_in_directory, init_model, load_model
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
import os

In [2]:
DEVICE = 'mps'

In [3]:
model_name = "01"
path_to_weights = (
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1/sub-"
    + model_name
    + "/pattern/dc_weights.csv"
)


path_to_in_data_1 = (
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1/sub-"
    + model_name
    + "/pattern/in_data_d1.csv"
)
path_to_in_data_2 = (
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1/sub-"
    + model_name
    + "/pattern/in_data_d2.csv"
)
path_to_in_data_3 = (
    "/Users/hazimiasad/Documents/Work/megan/data/collection/Study1/sub-"
    + model_name
    + "/pattern/in_data_d3.csv"
)

In [4]:
weights = torch.from_numpy(pd.read_csv(path_to_weights, header=None).values.T).to(
    DEVICE, dtype=torch.float32
)
state_size = len(weights)

In [5]:
data_in_all_list = [pd.read_csv(path_to_in_data, header=None).values for path_to_in_data in [path_to_in_data_1, path_to_in_data_2, path_to_in_data_3]]
data_in_all_array = np.vstack(data_in_all_list)

In [6]:
BATCH_SIZE = len(data_in_all_array)
N_STEPS = 40
FC2_LENGTH = 128

In [7]:
all_models_path = '/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/models/sub-'+model_name

In [8]:
all_models = list_files_in_directory(all_models_path)

In [9]:
def forward_with_fc2(self, state, t):
    t = torch.tensor([t] * len(state)).to(self.device).unsqueeze(1)
    state = torch.cat([state, t], dim=-1)
    x = F.relu(self.fc1(state))
    fc2_activation = F.relu(self.fc2(x))  # Capture FC2 output
    mean = self.mean(fc2_activation)
    log_std = self.log_std(fc2_activation)
    std = torch.exp(log_std)
    return mean, std, fc2_activation

In [10]:
save_base_path = '/Users/hazimiasad/Documents/Work/megan/code/playground/RL-Diffusion/results/activations/sub-01/'

In [11]:
for mod in tqdm (all_models, total=len(all_models)):
    model = init_model(DEVICE, state_size, state_size)
    model, model_name = load_model(model, mod)
    model = model.to(DEVICE)

    policy_network = model.policy
    policy_network.forward = forward_with_fc2.__get__(policy_network)

    activations = np.zeros((BATCH_SIZE, N_STEPS, FC2_LENGTH))

    for x_idx, x in enumerate(data_in_all_array):
        x = torch.from_numpy(x.reshape(1, state_size)).float().to(DEVICE)
        for t in range(40, 0, -1):
            with torch.no_grad():
                mean, std, fc2_activation = policy_network.forward(x, t)
                # mean, std = policy_network(x, t)
            activations[x_idx, t-1] = fc2_activation.cpu()
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample().clamp(-1.0, 1.0)
            x = x + action
    
    np.save(os.path.join(save_base_path, f'epoch_{model_name}.npy'), activations)

100%|██████████| 300/300 [2:37:44<00:00, 31.55s/it]  


In [11]:
policy_network = model.policy
policy_network.forward = forward_with_fc2.__get__(policy_network)

In [14]:
activations = np.zeros((BATCH_SIZE, N_STEPS, FC2_LENGTH))

# def get_activation(data_idx, step_idx):
#     def hook(model, input, output):
#         activations[data_idx, step_idx] = output.cpu().detach()  # Store activation without gradient computation
#     return hook

In [None]:
for x_idx, x in tqdm(enumerate(data_in_all_array), total=len(data_in_all_array)):
    x = torch.from_numpy(x.reshape(1, state_size)).float().to(DEVICE)
    for t in range(40, 0, -1):
        # policy_network.fc2.register_forward_hook(get_activation(x_idx, t - 1))
        with torch.no_grad():
            mean, std, fc2_activation = policy_network.forward(x, t)
            # mean, std = policy_network(x, t)
        activations[x_idx, t-1] = fc2_activation.cpu()
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample().clamp(-1.0, 1.0)
        x = x + action

In [25]:
np.save(os.path.join(save_base_path, f'epoch_{model_name}.npy'), activations)

In [None]:
plt.imshow(activations['fc2'].cpu())
plt.colorbar()
plt.show()