In [None]:
# Automatically reload edited modules (e.g. the local `motorlab` package)
# before executing code, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from pathlib import Path

import motorlab as ml
import numpy as np
import yaml

In [None]:
# Work from the parent of the notebook's starting directory so the relative
# paths used below ("data/...", "config/...", "plots/...") resolve.
# The target is computed only once and cached in PROJECT_ROOT, which makes this
# cell idempotent: a bare `os.chdir(Path().resolve().parent)` would climb one
# more directory level on every re-run.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path().resolve().parent
os.chdir(PROJECT_ROOT)

In [None]:
def get_config(experiment="gbyk"):
    """Build the default pose -> position training configuration.

    Parameters
    ----------
    experiment : str, default "gbyk"
        Dataset name. ``"gbyk"`` selects the ``*_spikes_sorted_SES`` sessions;
        any other value selects the bex/jon/luk session list. The name is also
        used verbatim in ``DATA_DIR`` (``data/<experiment>``) and stored under
        ``"experiment"``.

    Returns
    -------
    dict
        A freshly built configuration dictionary on every call.
    """
    if experiment == "gbyk":
        suffix = "_spikes_sorted_SES"
        recording_dates = {
            "bex": [
                "20230621",  # before
                "20230624",  # before
                "20230629",  # before
                "20230630",  # before
                "20230701",  # before
                "20230708",  # while
            ],
            "ken": [
                # "20230614",  # while and before
                "20230618",  # before
                "20230622",  # while, before and free
                "20230629",  # while, before and free
                "20230630",  # while
                "20230701",  # before
                "20230703",  # while
            ],
        }
    else:
        # Every non-"gbyk" experiment uses this session list.
        suffix = ""
        recording_dates = {
            "bex": [
                "20230221",
                "20230222",
                "20230223",
                "20230224",
                "20230225",
                "20230226",
            ],
            "jon": [
                "20230125",
                "20230126",
                "20230127",
                "20230130",
                "20230131",
                "20230202",
                "20230203",
            ],
            "luk": [
                "20230126",
                "20230127",
                "20230130",
                "20230131",
                "20230202",
                "20230203",
            ],
        }

    # Session ids are "<subject>_<date><suffix>", subjects in table order.
    sessions = [
        f"{subject}_{date}{suffix}"
        for subject, dates in recording_dates.items()
        for date in dates
    ]

    return dict(
        DATA_DIR=f"data/{experiment}",
        CHECKPOINT_DIR="checkpoint/pose_to_position",
        CONFIG_DIR="config/pose_to_position",
        save=True,
        experiment=experiment,
        include_trial=True,
        include_homing=True,
        in_modalities="poses",
        out_modalities="position",
        architecture="fc",
        sessions=sessions,
        position_repr="com",
        body_repr="egocentric",
        loss_fn="mse",
        metric="mse",
        model=dict(
            embedding_dim=256,
            hidden_dim=256,
            n_layers=1,
            readout="linear",
        ),
        train=dict(n_epochs=300, lr=5e-3),
        track=dict(metrics=True, wandb=False, save_checkpoint=True),
        dataset=dict(seq_length=1, stride=1),
    )

In [None]:
# Train a model with the default ("gbyk") configuration from get_config.
# NOTE(review): this runs on every Restart & Run All (n_epochs=300 in
# get_config) and, with save/save_checkpoint enabled, presumably writes a new
# run — consider skipping this cell when only evaluating an existing run.
ml.model.train(get_config())

In [None]:
# classification
# run = 20250624174842  # filter: false, homing: false
# run = 20250624232437  # filter: false, homing: true

In [None]:
# Select the run to evaluate: `run` names the `{run}.yaml` config file that is
# loaded from CONFIG_DIR for evaluation. The ids look like YYYYMMDDhhmmss
# timestamps (presumably the training start time — confirm). Keep exactly one
# line uncommented.
# regression
# run = 20250702104726  # model: gru | homing: false
# run = 20250702112218  # model: gru | homing: true
run = 20250701155654  # model: fc  | homing: false
# run = 20250702140752  # model: fc  | homing: true

In [None]:
# Reload the configuration saved for `run` and evaluate that model.
CONFIG_DIR = Path(get_config()["CONFIG_DIR"])
CONFIG_PATH = CONFIG_DIR / f"{run}.yaml"
config = yaml.safe_load(CONFIG_PATH.read_text())

# Optional overrides to evaluate on a different part of the data:
# config["include_trial"] = True
# config["include_homing"] = False

# config["include_trial"] = False
# config["include_homing"] = True
# config["dataset"] = {"seq_length": 1, "stride": 1}

eval_metrics, eval_gts, eval_preds = ml.model.evaluate(config)

# Flatten each session's targets and predictions into (N, 2) arrays
# (presumably (x, y) room positions) for the plotting helpers below.
for session in eval_preds:
    eval_gts[session], eval_preds[session] = (
        eval_gts[session].reshape(-1, 2),
        eval_preds[session].reshape(-1, 2),
    )

In [None]:
# Room heatmap of ground truth and predictions for every evaluated session.
ml.plot.room_heatmap(eval_gts, eval_preds)

In [None]:
# Keep only sessions whose id contains "bex" and plot them pooled
# (concat=True).
bex_gts = {name: arr for name, arr in eval_gts.items() if "bex" in name}
bex_preds = {name: arr for name, arr in eval_preds.items() if "bex" in name}
ml.plot.room_heatmap(bex_gts, bex_preds, concat=True)

In [None]:
# Keep only sessions whose id contains "ken" and plot them pooled
# (concat=True).
ken_gts = {name: arr for name, arr in eval_gts.items() if "ken" in name}
ken_preds = {name: arr for name, arr in eval_preds.items() if "ken" in name}
ml.plot.room_heatmap(ken_gts, ken_preds, concat=True)

In [None]:
# Pool all sessions of each subject ("bex", "ken") into a single array so the
# two subjects can be plotted side by side.
# A subject with no evaluated sessions is skipped rather than crashing:
# np.concatenate raises ValueError on an empty list, which happens e.g. for
# "ken" with the non-"gbyk" session list in get_config.
bex_ken_gts = {
    subject: np.concatenate(list(per_session.values()), axis=0)
    for subject, per_session in (("bex", bex_gts), ("ken", ken_gts))
    if per_session
}

bex_ken_preds = {
    subject: np.concatenate(list(per_session.values()), axis=0)
    for subject, per_session in (("bex", bex_preds), ("ken", ken_preds))
    if per_session
}

ml.plot.room_heatmap(
    bex_ken_gts,
    bex_ken_preds,
    # save_path="plots/pose_to_position/histogram_trial_fc.svg",
)

In [None]:
# Hand-built "ideal" example for the room heatmap.
# NOTE(review): 0.865 is presumably the side length of one floor tile, in the
# same units as the evaluated positions — confirm.
tile_size = 0.865

# Ground truth: 15 targets, row-major over rows 0-4 and columns 1-3.
# x is column * tile_size; y is the middle of the row.
gts = {
    "ideal": np.array(
        [
            [col * tile_size, row * tile_size + tile_size / 2]
            for row in range(5)
            for col in (1, 2, 3)
        ]
    )
}

# Predictions: target k maps to the tile with its column mirrored
# (col -> 4 - col) and rows 1 and 3 swapped (row order 0, 3, 2, 1, 4).
preds = {
    "ideal": np.array(
        [
            [(4 - col) * tile_size, pred_row * tile_size + tile_size / 2]
            for pred_row in (0, 3, 2, 1, 4)
            for col in (1, 2, 3)
        ]
    )
}
# Room heatmap of the hand-built "ideal" gts/preds defined just above.
# NOTE(review): save_path is always set here, so (presumably) every run
# rewrites this SVG file.
ml.plot.room_heatmap(
    gts,
    preds,
    save_path="plots/pose_to_position/histogram_homing_ideal.svg",
)

In [None]:
# Confusion matrix of ground truth vs. predictions over all evaluated sessions.
# Uncomment `group` / `save_path` to change the grouping or save the figure.
# NOTE(review): the meaning of include_sitting is defined in ml.plot — confirm.
ml.plot.confusion_matrix(
    eval_gts,
    eval_preds,
    # group="x",
    include_sitting=True,
    # save_path="plots/pose_to_position/confusion_matrix_nofilter_homing.svg",
)

In [None]:
# Draw the room layout (presumably also written to the given SVG path).
ml.room.plot(save_path="plots/pose_to_position/room.svg")