# Analyze the trained model

## Set up

In [32]:
# autoreload
%load_ext autoreload
%autoreload 2
# jupyter black formatter
%load_ext jupyter_black

import subprocess
import os
import sys

gitroot_path = subprocess.check_output(
    ["git", "rev-parse", "--show-toplevel"], universal_newlines=True
)

os.chdir(os.path.join(gitroot_path[:-1], "pirnns"))
print("Working directory: ", os.getcwd())

sys_dir = os.path.dirname(os.getcwd())
sys.path.append(sys_dir)
print("Directory added to path: ", sys_dir)
sys.path.append(os.getcwd())
print("Directory added to path: ", os.getcwd())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The jupyter_black extension is already loaded. To reload it, use:
  %reload_ext jupyter_black
Working directory:  /Users/facosta/code/pirnns/pirnns
Directory added to path:  /Users/facosta/code/pirnns
Directory added to path:  /Users/facosta/code/pirnns/pirnns


## Load the trained model

In [33]:
from pirnns.model import PathIntRNN
import torch

model = PathIntRNN(input_size=2, hidden_size=50, output_size=2, alpha=0.9)

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

model.to(device)


# model.load_state_dict(torch.load("model_weights/model_weights.pth"))

print(model)

PathIntRNN(
  (rnn_step): RNNStep(
    (activation): Tanh()
    (W_in): Linear(in_features=2, out_features=50, bias=True)
    (W_rec): Linear(in_features=50, out_features=50, bias=True)
  )
  (W_out): Linear(in_features=50, out_features=2, bias=True)
  (W_h_init): Linear(in_features=2, out_features=50, bias=True)
)


## Load analysis data

In [34]:
from datamodule import PathIntegrationDataModule

datamodule = PathIntegrationDataModule(
    num_trajectories=1000,
    batch_size=128,
    num_workers=1,
    train_val_split=0.8,
    start_time=0,
    end_time=100,
    num_time_steps=1000,
    arena_L=5,
    mu_speed=0.2,
    sigma_speed=0.5,
    tau_vel=1,
)
datamodule.setup()

In [35]:
print(f"Number of validation batches: {len(datamodule.val_dataloader())}")
print(f"Validation dataset size: {len(datamodule.val_dataset)}")
print(f"Batch size: {datamodule.batch_size}")

Number of validation batches: 2
Validation dataset size: 200
Batch size: 128


In [36]:
import torch

model.eval()

all_hidden_states = []
all_outputs = []
all_inputs = []
all_targets = []

num_batches = 4  # or however many you want

with torch.no_grad():
    for i, (inputs, targets) in enumerate(datamodule.val_dataloader()):
        if i >= num_batches:
            break

        inputs = inputs.to(device)
        targets = targets.to(device)

        hidden_states, outputs = model(inputs=inputs, pos_0=targets[:, 0, :])

        all_hidden_states.append(hidden_states)
        all_outputs.append(outputs)
        all_inputs.append(inputs)
        all_targets.append(targets)

# Concatenate all batches
hidden_states = torch.cat(all_hidden_states, dim=0)
outputs = torch.cat(all_outputs, dim=0)
inputs = torch.cat(all_inputs, dim=0)
targets = torch.cat(all_targets, dim=0)

print(hidden_states.shape)
print(outputs.shape)

torch.Size([200, 1000, 50])
torch.Size([200, 1000, 2])


## Visualize PCA of hidden states

In [37]:
import plotly.graph_objects as go

# Your existing PCA code stays the same
hidden_states_data = hidden_states.reshape(-1, hidden_states.shape[-1]).cpu()

from sklearn.decomposition import PCA

pca = PCA(n_components=3)
pca.fit(hidden_states_data)

print(f"Total variance explained: {pca.explained_variance_ratio_.sum() * 100:.2f}%")

reduced_data = pca.transform(hidden_states_data)

# Create interactive 3D plot with Plotly
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=reduced_data[:, 0],
            y=reduced_data[:, 1],
            z=reduced_data[:, 2],
            mode="markers",
            marker=dict(size=3, opacity=0.7),
        )
    ]
)

fig.update_layout(
    title="3D PCA of Hidden States",
    scene=dict(
        xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0] * 100:.2f}%)",
        yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1] * 100:.2f}%)",
        zaxis_title=f"PC3 ({pca.explained_variance_ratio_[2] * 100:.2f}%)",
    ),
    width=600,
    height=600,
)

fig.show()

Total variance explained: 84.84%
