In [6]:
# Reload edited modules (e.g. motorlab) automatically before executing each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import os

from pathlib import Path

import motorlab as ml
import numpy as np
import yaml

In [8]:
# Work from the project root so the relative paths used below ("artifacts/...", "config/...",
# "plots/...") resolve. Assumes the notebook lives one directory below that root.
# The root is resolved only once: re-running this cell must not keep climbing the tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path().resolve().parent
os.chdir(PROJECT_ROOT)

### PCs to exclude

In [9]:
# PC exclusion tables produced by the notebook 'analysis_pca.ipynb'.
# Each table lists the PCs to delete, i.e. those at least a given margin above baseline:
#   loose >= 0.15, medium >= 0.10, strict >= 0.05, draconian >= 0.01
pca_tables_dir = "artifacts/tables/analysis_pca"


def _read_yaml(path):
    """Open ``path`` and parse it with ``yaml.safe_load``."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)


loose = _read_yaml(f"{pca_tables_dir}/loose.yml")
medium = _read_yaml(f"{pca_tables_dir}/medium.yml")
strict = _read_yaml(f"{pca_tables_dir}/strict.yml")
draconian = _read_yaml(f"{pca_tables_dir}/draconian.yml")

### Train

In [10]:
experiment = "gbyk"
sessions = ml.sessions.GBYK
config = ml.config.load_default(experiment, sessions)

# Model depth and pose representation for this run.
config["model"]["n_layers"] = 2
config["poses"]["representation"] = "centered"

# Keypoints removed from the pose input: the two tail points, then hip/knee/ankle and
# shoulder/elbow/wrist, each for the `l_` and `r_` variants (same order as before).
excluded_tail = ["e_tail", "s_tail"]
excluded_limbs = [
    f"{side}_{joint}"
    for limb in (("hip", "knee", "ankle"), ("shoulder", "elbow", "wrist"))
    for side in ("l", "r")
    for joint in limb
]
config["poses"]["keypoints_to_exclude"] = excluded_tail + excluded_limbs

# Optional: project poses onto PCs and drop those listed in one of the exclusion tables
# (loose / medium / strict / draconian).
# config["poses"]["project_to_pca"] = True
# config["poses"]["pcs_to_exclude"] = draconian

ml.model.train(config)

Number of parameters: 205,336
FCModel(
  (embedding): LinearEmbedding(
    (linear): ModuleDict(
      (bex_20230621_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (bex_20230624_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (bex_20230629_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (bex_20230630_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (bex_20230701_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (bex_20230708_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )
      (ken_20230618_spikes_sorted_SES): ModuleDict(
        (poses): Linear(in_features=21, out_features=256, bias=True)
      )


### Evaluate

In [None]:
# Evaluate a trained run from its saved config under config/poses_to_position/.
# `run` (the config file stem) is not assigned anywhere in this notebook, so on a fresh kernel
# this cell would die with a bare NameError; fail early with an actionable message instead.
if "run" not in globals():
    raise NameError(
        "`run` is undefined: set it to the id of a trained run "
        "(config/poses_to_position/<run>.yaml) before running this cell."
    )

config_path = Path(f"config/poses_to_position/{run}.yaml")

with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# Optional overrides selecting which segments are evaluated (enable at most one pair):
# config["include_trial"] = True
# config["include_homing"] = False

# config["include_trial"] = False
# config["include_homing"] = True

eval_metrics, eval_gts, eval_preds = ml.model.evaluate(config)


def _positions_and_tiles(per_session):
    """Reshape each session's array to two columns and map every row to a room tile.

    Returns ``(positions, tiles)``, both dicts keyed by session. Column 0 and column 1 are
    passed to ``ml.room.get_tiles`` as its first and second coordinate (presumably x and y).
    """
    positions = {
        session: values.reshape(-1, 2) for session, values in per_session.items()
    }
    tiles = {
        session: ml.room.get_tiles(xy[:, 0], xy[:, 1])
        for session, xy in positions.items()
    }
    return positions, tiles


# Same processing for ground truth and predictions; later cells use all four dicts.
eval_gts, tiled_gts = _positions_and_tiles(eval_gts)
eval_preds, tiled_preds = _positions_and_tiles(eval_preds)

In [None]:
# Tile confusion matrix over every evaluated session (group="y", include_sitting=False).
# Pass concat=True to pool the sessions instead.
ml.plot.confusion_matrix(
    tiled_gts,
    tiled_preds,
    include_sitting=False,
    group="y",
)

In [None]:
# Keep only sessions whose name contains "bex", then plot positions with concat=True.
bex_gts = dict(item for item in eval_gts.items() if "bex" in item[0])
bex_preds = dict(item for item in eval_preds.items() if "bex" in item[0])

ml.plot.room_histogram2d(bex_gts, bex_preds, concat=True)

In [None]:
# Tile confusion matrix restricted to the "bex" sessions.
bex_tiled_gts = dict(item for item in tiled_gts.items() if "bex" in item[0])
bex_tiled_preds = dict(item for item in tiled_preds.items() if "bex" in item[0])

ml.plot.confusion_matrix(
    bex_tiled_gts,
    bex_tiled_preds,
    concat=True,
    include_sitting=False,
    group="y",
)

In [None]:
# Keep only sessions whose name contains "ken", then plot positions with concat=True.
ken_gts = dict(item for item in eval_gts.items() if "ken" in item[0])
ken_preds = dict(item for item in eval_preds.items() if "ken" in item[0])

ml.plot.room_histogram2d(ken_gts, ken_preds, concat=True)

In [None]:
# Tile confusion matrix restricted to the "ken" sessions.
ken_tiled_gts = dict(item for item in tiled_gts.items() if "ken" in item[0])
ken_tiled_preds = dict(item for item in tiled_preds.items() if "ken" in item[0])

ml.plot.confusion_matrix(
    ken_tiled_gts,
    ken_tiled_preds,
    concat=True,
    include_sitting=False,
    group="y",
)

In [None]:
# One entry per group ("bex", "ken"): that group's sessions concatenated along axis 0.
bex_ken_gts = {
    group: np.concatenate([*per_session.values()], axis=0)
    for group, per_session in (("bex", bex_gts), ("ken", ken_gts))
}
bex_ken_preds = {
    group: np.concatenate([*per_session.values()], axis=0)
    for group, per_session in (("bex", bex_preds), ("ken", ken_preds))
}

ml.plot.room_histogram2d(
    bex_ken_gts,
    bex_ken_preds,
    # save_path="plots/pose_to_position/histogram_trial_fc.svg",
)

In [None]:
# Tile labels pooled per group ("bex", "ken") by concatenating sessions along axis 0.
bex_ken_tiled_gts = {
    group: np.concatenate([*per_session.values()], axis=0)
    for group, per_session in (("bex", bex_tiled_gts), ("ken", ken_tiled_gts))
}
bex_ken_tiled_preds = {
    group: np.concatenate([*per_session.values()], axis=0)
    for group, per_session in (("bex", bex_tiled_preds), ("ken", ken_tiled_preds))
}

# Pass include_sitting=True to also show the sitting category.
ml.plot.confusion_matrix(
    bex_ken_tiled_gts,
    bex_ken_tiled_preds,
    group="y",
)

In [None]:
tile_size = 0.865

# Synthetic "ideal" example: one point per cell of a 3-column x 5-row grid. The first
# coordinate is column * tile_size (columns 1-3); the second is row * tile_size + tile_size / 2
# (rows 0-4). The "prediction" mirrors the columns (1 <-> 3) and swaps rows 1 and 3,
# leaving rows 0, 2 and 4 in place.
ideal_columns = (1, 2, 3)
ideal_rows = (0, 1, 2, 3, 4)
predicted_row = {0: 0, 1: 3, 2: 2, 3: 1, 4: 4}

gts = {
    "ideal": np.array(
        [
            [column * tile_size, row * tile_size + tile_size / 2]
            for row in ideal_rows
            for column in ideal_columns
        ]
    )
}
preds = {
    "ideal": np.array(
        [
            [(4 - column) * tile_size, predicted_row[row] * tile_size + tile_size / 2]
            for row in ideal_rows
            for column in ideal_columns
        ]
    )
}
ml.plot.room_histogram2d(
    gts,
    preds,
    save_path="plots/pose_to_position/histogram_homing_ideal.svg",
)

In [None]:
# Tile confusion matrix over all sessions, with the sitting category included.
# Pass the tile labels (tiled_gts / tiled_preds), not the raw reshaped positions, to match
# every other confusion-matrix call in this notebook.
ml.plot.confusion_matrix(
    tiled_gts,
    tiled_preds,
    # group="x",
    include_sitting=True,
    # save_path="plots/pose_to_position/confusion_matrix_nofilter_homing.svg",
)

In [None]:
# Plot the room via ml.room.plot and save the figure as SVG.
room_plot_path = "plots/pose_to_position/room.svg"
ml.room.plot(save_path=room_plot_path)