## Setup and imports

In [1]:
import setup

setup.main()
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black
import yaml
import torch

Working directory:  /home/facosta/neurometry/neurometry
Directory added to path:  /home/facosta/neurometry
Directory added to path:  /home/facosta/neurometry/neurometry


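### Specify run name

The run name selects both the saved config (`results/configs/<run_name>.json`) and the trained weights (`results/trained_models/<run_name>_model.pt`).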

In [2]:
# Run identifier: selects both the config file and the checkpoint loaded below.
# NOTE(review): the name appears to encode the s_0 / sigma_saliency / x_saliency
# settings that also appear in the model config — confirm against the config.
run_name = "run_tus9d935_s_0=1_sigma_saliency=0.05_x_saliency=0.5"

In [3]:
import os

# Every artifact for this notebook lives under <cwd>/neuroai/piRNNs/models/results;
# the working directory itself is configured by setup.main().
base_dir = os.path.join(os.getcwd(), "neuroai/piRNNs/models")
configs_dir, models_dir = (
    os.path.join(base_dir, f"results/{subdir}")
    for subdir in ("configs", "trained_models")
)

In [4]:
def _load_expt_config(run_name, configs_dir):
    config_file = os.path.join(configs_dir, f"{run_name}.json")

    with open(config_file) as file:
        return yaml.safe_load(file)

### Load experiment config

In [5]:
# Flat mapping of experiment settings for `run_name`; `_convert_config` (next
# cell) regroups it into train/data/model/integration sections.
expt_config = _load_expt_config(run_name, configs_dir)

In [6]:
import ml_collections


def _d(**kwargs):
    """Build an ``ml_collections.ConfigDict`` from keyword arguments.

    Parameters
    ----------
    **kwargs
        Entries to place in the new config dict.

    Returns
    -------
    ml_collections.ConfigDict
        Config dict initialised with ``kwargs``.
    """
    initial_entries = dict(kwargs)
    return ml_collections.ConfigDict(initial_dictionary=initial_entries)


import ml_collections


def _convert_config(normal_config):
    """Regroup a flat experiment config into a sectioned ``ConfigDict``.

    The flat mapping loaded from disk is split into four sections
    (``train``, ``data``, ``model`` and ``integration``). Each section holds a
    fixed list of keys, copied verbatim and in order from ``normal_config``.

    Parameters
    ----------
    normal_config : dict
        Flat mapping of setting name to value.

    Returns
    -------
    ml_collections.ConfigDict
        Config with ``train``, ``data``, ``model`` and ``integration`` sections.

    Raises
    ------
    KeyError
        If a required key is missing. ``freeze_decoder`` is the only optional
        key and defaults to ``False``.
    """

    def _pick(keys):
        # Copy the requested keys in the given order; a missing key raises KeyError.
        return {key: normal_config[key] for key in keys}

    config = ml_collections.ConfigDict()

    # Training loop / optimizer settings.
    config.train = _pick(
        (
            "load_pretrain",
            "pretrain_path",
            "num_steps_train",
            "lr",
            "lr_decay_from",
            "steps_per_logging",
            "steps_per_large_logging",
            "steps_per_integration",
            "norm_v",
            "positive_v",
            "positive_u",
            "optimizer_type",
        )
    )

    # Simulated-data settings.
    config.data = _pick(
        (
            "max_dr_trans",
            "max_dr_isometry",
            "batch_size",
            "sigma_data",
            "add_dx_0",
            "small_int",
        )
    )

    # Model settings; `freeze_decoder` is optional, everything else is required.
    model_section = {"freeze_decoder": normal_config.get("freeze_decoder", False)}
    model_section.update(
        _pick(
            (
                "trans_type",
                "rnn_step",
                "num_grid",
                "num_neurons",
                "block_size",
                "sigma",
                "w_kernel",
                "w_trans",
                "w_isometry",
                "w_reg_u",
                "reg_decay_until",
                "adaptive_dr",
                "s_0",
                "x_saliency",
                "sigma_saliency",
                "reward_step",
                "saliency_type",
            )
        )
    )
    config.model = model_section

    # Path-integration evaluation settings.
    config.integration = _pick(
        ("n_inte_step", "n_traj", "n_inte_step_vis", "n_traj_vis")
    )

    return config

### Load Trained Model

In [7]:
from neurometry.neuroai.piRNNs.models import model as model_lib

config = _convert_config(expt_config)
model_config = model_lib.GridCellConfig(**config.model)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep the module (`model_lib`) and the instance (`model`, used by later cells)
# under different names so this cell can be re-run without an AttributeError.
model = model_lib.GridCell(model_config).to(device)

In [8]:
# Checkpoint for this run; map_location loads its tensors onto `device`.
trained_model_path = os.path.join(models_dir, f"{run_name}_model.pt")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (weights_only=True may work if the checkpoint
# holds only tensors and plain containers — confirm before switching).
trained_model = torch.load(trained_model_path, map_location=device)
# Weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(trained_model["state_dict"])

<All keys matched successfully>

### Load evaluation data (trajectories)

In [9]:
# Number of path-integration steps used at evaluation time; overrides the value
# stored in this run's config.
N_INTE_STEP_EVAL = 150
config.integration.n_inte_step = N_INTE_STEP_EVAL

print(config.integration)

n_inte_step: 150
n_inte_step_vis: 50
n_traj: 100
n_traj_vis: 5



In [10]:
import numpy as np
from neurometry.neuroai.piRNNs.models import input_pipeline
import neurometry.neuroai.piRNNs.models.utils as utils

# Seed the generator so the evaluation batch (and every downstream plot/metric)
# is reproducible under Restart & Run All.
# NOTE(review): assumes EvalDataset draws all of its randomness from `rng` — confirm.
SEED = 0
rng = np.random.default_rng(SEED)

eval_dataset = input_pipeline.EvalDataset(
    rng, config.integration, config.data.max_dr_trans, config.model.num_grid
)

eval_iter = iter(eval_dataset)

# Move the first evaluation batch onto the same device as the model.
eval_data = utils.dict_to_device(next(eval_iter), device)

2024-07-02 10:31:13.735186: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
path_integration_output = model.path_integration(eval_data["traj"]["traj"])

# NOTE(review): positional unpacking relies on the insertion order of the dict
# returned by path_integration; indexing by key would be more robust — confirm keys.
err, traj_real, traj_pred, heatmaps = path_integration_output.values()

# Two prediction variants are returned; keep both for comparison.
traj_pred_vanilla, traj_pred_reencode = (
    traj_pred[variant] for variant in ("vanilla", "reencode")
)

In [12]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached and copied to the CPU. Anything else, including
    arrays, goes through ``np.asarray``, so re-running this cell after the names
    already hold arrays is a no-op instead of an AttributeError on ``.cpu()``.
    """
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


traj_real = _to_numpy(traj_real)
traj_pred_vanilla = _to_numpy(traj_pred_vanilla)
traj_pred_reencode = _to_numpy(traj_pred_reencode)

In [15]:
# Prediction variant animated below; use `traj_pred_reencode` instead to view
# the re-encoded predictions.
traj_predicted = traj_pred_vanilla

In [16]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

plt.style.use("dark_background")

# Index of the trajectory to animate; change this to visualize another one.
traj_idx_to_visualize = 0

num_steps = traj_real.shape[1]


def _axis_limits(traj_idx, pad=1.0):
    """Return ``(min_x, max_x, min_y, max_y)`` spanning both trajectories.

    Covers the real and predicted paths of trajectory ``traj_idx`` and pads
    every side by ``pad`` so markers at the extremes stay inside the frame.
    """
    xs = np.concatenate([traj_real[traj_idx, :, 0], traj_predicted[traj_idx, :, 0]])
    ys = np.concatenate([traj_real[traj_idx, :, 1], traj_predicted[traj_idx, :, 1]])
    return xs.min() - pad, xs.max() + pad, ys.min() - pad, ys.max() + pad


# Create the figure before `animate` so frames are drawn on an explicit Axes
# rather than through pyplot's "current figure" state.
fig, ax = plt.subplots(figsize=(10, 8), dpi=150)


def animate(i, traj_idx):
    """Draw frame ``i`` of the real-vs-predicted animation for ``traj_idx``.

    Each frame shows both trajectories' trails up to (not including) step
    ``i``, plus a labelled marker at step ``i``.
    """
    ax.cla()  # redraw each frame from scratch
    real = traj_real[traj_idx]
    pred = traj_predicted[traj_idx]

    # Real trajectory: faded trail plus current position.
    ax.plot(real[:i, 0], real[:i, 1], "b-", alpha=0.5, label="Real Traj", linewidth=2)
    ax.plot(real[i, 0], real[i, 1], "bo", markersize=10)
    # Predicted trajectory: faded trail plus current position.
    ax.plot(pred[:i, 0], pred[:i, 1], "r-", alpha=0.5, label="Pred Traj", linewidth=2)
    ax.plot(pred[i, 0], pred[i, 1], "ro", markersize=10)

    # Limits are derived from the trajectory being drawn, with the y lower
    # bound used as-is (not negated).
    min_x, max_x, min_y, max_y = _axis_limits(traj_idx)
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)

    ax.set_title(f"Real vs Predicted Trajectory at Time t={i}", fontsize=16)
    ax.set_xlabel("X Coordinate", fontsize=14)
    ax.set_ylabel("Y Coordinate", fontsize=14)
    ax.legend(loc="upper right", fontsize=12)
    ax.grid(True)

    ax.annotate(
        "Real",
        xy=(real[i, 0], real[i, 1]),
        xytext=(5, 5),
        textcoords="offset points",
        color="blue",
    )
    ax.annotate(
        "Pred",
        xy=(pred[i, 0], pred[i, 1]),
        xytext=(5, 5),
        textcoords="offset points",
        color="red",
    )


ani = animation.FuncAnimation(
    fig, animate, frames=num_steps, fargs=(traj_idx_to_visualize,), interval=100
)

# Close the static figure so only the rendered video is displayed; the video is
# self-contained HTML, so no interactive backend switch is needed.
plt.close(fig)
HTML(ani.to_html5_video())

<IPython.core.display.Javascript object>

In [97]:
heatmaps.shape

# TODO: document what `heatmaps` represent. The shape shown below,
# (100, 151, 40, 40), appears to be (n_traj, n_inte_step + 1, H, W) given
# n_traj=100 and n_inte_step=150 above; the 40x40 grid is possibly num_grid —
# confirm against model.path_integration.

torch.Size([100, 151, 40, 40])

In [85]:
# Names of the style sheets accepted by plt.style.use (e.g. "dark_background").
plt.style.available

['Solarize_Light2',
 '_classic_test_patch',
 '_mpl-gallery',
 '_mpl-gallery-nogrid',
 'bmh',
 'classic',
 'dark_background',
 'fast',
 'fivethirtyeight',
 'ggplot',
 'grayscale',
 'seaborn-v0_8',
 'seaborn-v0_8-bright',
 'seaborn-v0_8-colorblind',
 'seaborn-v0_8-dark',
 'seaborn-v0_8-dark-palette',
 'seaborn-v0_8-darkgrid',
 'seaborn-v0_8-deep',
 'seaborn-v0_8-muted',
 'seaborn-v0_8-notebook',
 'seaborn-v0_8-paper',
 'seaborn-v0_8-pastel',
 'seaborn-v0_8-poster',
 'seaborn-v0_8-talk',
 'seaborn-v0_8-ticks',
 'seaborn-v0_8-white',
 'seaborn-v0_8-whitegrid',
 'tableau-colorblind10']