# Sweep over `train_frames` and `losses_to_use`

In [1]:
import hydra
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch
from omegaconf import DictConfig

from lightning_pose.utils.io import return_absolute_data_paths
from lightning_pose.utils.scripts import get_imgaug_transform, get_dataset, get_data_module, get_loss_factories
from lightning_pose.losses.losses import PCALoss

import sys
sys.path.append('/home/jovyan/tracking-diagnostics')
from diagnostics.handler import ModelHandler
from diagnostics.io import get_base_config, get_keypoint_names

In [2]:
# %% get config
dataset_name = "rick-configs-1"
base_config_dir = "/home/jovyan/rick-configs-1"
base_save_dir = "/home/jovyan/"
cfg = get_base_config(config_dir=base_config_dir, config_name="config")

In [3]:

# load ground truth labels
csv_file = os.path.join(cfg.data.data_dir, cfg.data.csv_file)
csv_data = pd.read_csv(csv_file, header=list(cfg.data.header_rows))
keypoints_gt = csv_data.iloc[:, 1:].to_numpy().reshape(csv_data.shape[0], -1, 2)

keypoint_names = get_keypoint_names(csv_data, cfg.data.header_rows)

### This section loops over single `losses_to_use`. no combos.

Start by looping over `train_frames` and individual `losses_to_use`. Later build in complication.

In [4]:
save_dir = "/home/jovyan/lightning-pose"
loss_types = [[], ["pca_multiview"], ["pca_singleview"], ["temporal"], ["unimodal_mse"]] # TODO: add ["pca_singleview"] TODO: order matters
log_weight_list = [] # TODO: order matters
model_names = ["train_frames_sweep"]*len(loss_types)
supervised_model_name = model_names[0] # they all have the same name
train_frames_list = [50,75,100,125]
model_type = "heatmap"
handlers = []
name_strs_to_plot = []
for tr_fr_idx, train_frames in enumerate(train_frames_list):
    for loss_idx, loss in enumerate(loss_types):
        print("==========================")
        print("Searching for train_frames: {}, loss_type: {}...".format(train_frames, loss))
        # TODO: multi-loss version below, if needed
        # name_strs_to_plot = ['+'.join([l[:5] for l in loss]) if len(loss)>0 else 's' for loss in loss_types]
        name_str = '*'.join([loss[0][:5] if len(loss)==1 else 's', str(train_frames)])
        print("name_str: {}".format(name_str)) # TODO: just for single losses now
        name_strs_to_plot.append(name_str) 
        model_cfg = cfg.copy()
        model_cfg.training.train_frames = train_frames
        model_cfg.model.losses_to_use = loss # assume loss is already a list, [] if supervised
        model_cfg.model.model_name = model_names[loss_idx]
        model_cfg.model.model_type = model_type
        # specific arguments to "train_frames_sweep" models. TODO: change if needed
        model_cfg.training.train_prob=0.2
        model_cfg.training.val_prob=0.2
        model_cfg.training.min_epochs=125
        model_cfg.training.max_epochs=2000
        if len(loss) == 0:
            # support for uniquely-named supervised models
            model_cfg.model.model_name = supervised_model_name
        else:
            # loop over the sub losses
            if len(log_weight_list)>0:
                for sub_loss_idx,sub_loss in enumerate(loss):
                    model_cfg.losses[sub_loss].log_weight = log_weight_list[loss_idx][sub_loss_idx]
        
        try:
            handlers.append(ModelHandler(save_dir, model_cfg, verbose=False))
            print("Found: {}".format(model_cfg.model.model_name))
            print("In: {}".format(handlers[-1].model_dir))
        except FileNotFoundError:
            print('did not find %s model for train_frames=%i' % (loss, train_frames))
            continue
# report on the models found
print("==========================")
print("Found {} models out of {}.".format(len(handlers), len(train_frames_list)*len(loss_types)))

Searching for train_frames: 100, loss_type: []...
name_str: s*100
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/3
Searching for train_frames: 100, loss_type: ['pca_multiview']...
name_str: pca_m*100
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/18
Searching for train_frames: 100, loss_type: ['pca_singleview']...
name_str: pca_s*100
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/13-32-43/2
Searching for train_frames: 100, loss_type: ['temporal']...
name_str: tempo*100
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/8
Searching for train_frames: 100, loss_type: ['unimodal_mse']...
name_str: unimo*100
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/13
Found 5 models out of 5.


In [None]:
# loop over handlers and compute metrics
to_compute = "rmse" # | "rmse" | "pca_singleview" | "unimodal_mse"
keypoint_names = get_keypoint_names(csv_data, cfg.data.header_rows)
error_metric = "reprojection_error" # only for PCA
pca_loss = None
data_module = None

# store results here
if to_compute == "pca_singleview":
    # remove obstacle keypoints
    keypoint_names = [kp for kp in keypoint_names if kp not in ['obs_top','obsHigh_bot','obsLow_bot']]
    print(keypoint_names)

metrics_collected = {bp: [] for bp in keypoint_names}
# can change this as a function of train_frames
    
for hand_idx, handler in enumerate(handlers):
    print(hand_idx)
    print("name: {}".format(handler.cfg.model.model_name))
    print("losses_to_use: {}".format(handler.cfg.model.losses_to_use))
    print(handler.model_dir)
    if to_compute == 'rmse':
        y_label = 'RMSE per bodypart'
    elif to_compute == 'pca_multiview' or to_compute == 'pca_singleview':
        y_label = 'PCA reprojection error'
        model_cfg.model.losses_to_use = [to_compute] # TODO: not sure that makes sense here. assume loss is already a list, [] f
        model_cfg.training.train_frames = handler.cfg.training.train_frames
        data_dir, video_dir = return_absolute_data_paths(data_cfg=handler.cfg.data)
        imgaug_transform = get_imgaug_transform(cfg=handler.cfg)
        dataset = get_dataset(cfg=handler.cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)
        data_module = get_data_module(cfg=handler.cfg, dataset=dataset, video_dir=video_dir)
        data_module.setup()
        # compute pca params
        loss_factories = get_loss_factories(cfg=model_cfg, data_module=data_module) # TODO: keeping model_cfg here for now
        pca_loss = loss_factories["unsupervised"].loss_instance_dict[to_compute]
    # compute metric
    try:
        result = handler.compute_metric(
            to_compute, 'predictions.csv',
            keypoints_true=keypoints_gt, pca_loss_obj=pca_loss, datamodule=data_module)
        print(result.shape)
    except FileNotFoundError:
        print('could not find model predictions')
        continue
    for b, bodypart in enumerate(keypoint_names):
        metrics_collected[bodypart].append(result[:, b])

In [None]:
print(name_strs_to_plot)
for col_name, metric in zip(name_strs_to_plot, metrics_collected[bodypart]):
    print(col_name)
    print(metric.shape)

In [None]:
# collect results
# TODO: currently ignorant of the train_frames. either fix here or above.               
results_df = []
for bodypart in keypoint_names:
    dict_tmp = {
        'bodypart': bodypart,
        #'rng_seed': rng_seed,
        'eval_mode': handlers[-1].pred_df.iloc[:, -1].to_numpy(),
        'img_file': csv_data.iloc[:, 0], # TODO: fix, this is wrong. should be a str not a float
    }
    for col_name, metric in zip(name_strs_to_plot, metrics_collected[bodypart]):
        dict_tmp[col_name] = metric
    results_df.append(pd.DataFrame(dict_tmp))

results_df = pd.concat(results_df)

In [None]:
results_df.head()

In [None]:
df_tmp = pd.melt(
    results_df, 
    id_vars=['bodypart', 'img_file', 'eval_mode'], 
    value_vars=name_strs_to_plot,
)
def add_loss_name_col(row):
    return '_'.join(row['variable'].split('_')[:-1])
def add_loss_val_col(row):
    return row['variable'].split('_')[-1]
df_tmp['loss'] = df_tmp.apply(add_loss_name_col, axis=1) # TODO: not doing anything but fails otherwise

In [None]:
df_tmp.head()

In [None]:
df_tmp["loss_type"] = df_tmp.variable.str.split('*').str[0]
df_tmp["train_frames"] = df_tmp.variable.str.split('*').str[1]

In [None]:
df_tmp

In [None]:
#have multiple columns
sns.set(context='talk', style='whitegrid', font_scale=1, rc = {'figure.figsize':(20,40)})
eval_mode = 'test'
saving_format = 'eps'
y_label = "Singleview PCA reconstruction error (pix.)"
hue_order = ['s', 'unimo', 'tempo', 'pca_s', 'pca_m']
num_losses_to_plot = 20 #len(cols_collected) # can exclude vals here
# average over keypoints and frames
df_tmp_ = df_tmp[df_tmp.eval_mode==eval_mode]
g = sns.catplot(
    data=df_tmp_, x='train_frames', hue="loss_type", y='value', kind='bar', hue_order=hue_order
)
# for ax in g.axes:
#     plt.setp(ax.get_yticklabels(), visible=True, rotation=None)
#g.set_xticklabels(rotation=80, ha='center')
# plt.title("train frames: {}, metric: {}, {} data".format(train_frames, to_compute, eval_mode))
plt.ylabel("{}".format(y_label))
plt.gcf().set_size_inches(12, 8)
plt.savefig('/home/jovyan/figs/{}_{}_{}_losses.{}'.format(to_compute, eval_mode, y_label, saving_format))

### This section loops over combos, no inividual losses
Loop over train_frames. Loop over all combos for a giving number of train frames.

In [4]:
save_dir = "/home/jovyan/lightning-pose"
loss_types = [["temporal","unimodal_mse"],["temporal","pca_multiview"],["temporal","unimodal_mse","pca_multiview"], \
    ["unimodal_mse","pca_multiview"], ["pca_singleview","unimodal_mse"],["pca_singleview","temporal"], \
        ["pca_singleview","pca_multiview"],["pca_singleview","unimodal_mse","temporal"],["pca_singleview","unimodal_mse","pca_multiview"],["pca_singleview","unimodal_mse","pca_multiview","temporal"]] # TODO: add ["pca_singleview"] TODO: order matters
log_weight_list = [] # TODO: order matters
model_names = ["train_frames_sweep"]*len(loss_types)
supervised_model_name = model_names[0] # they all have the same name
train_frames_list = [50,75,100,125]
model_type = "heatmap"
handlers = []
name_strs_to_plot = []
for tr_fr_idx, train_frames in enumerate(train_frames_list):
    for loss_idx, loss in enumerate(loss_types):
        print("==========================")
        print("Searching for train_frames: {}, loss_type: {}...".format(train_frames, loss))
        # TODO: multi-loss version below, if needed
        loss_str = '+'.join([l[:5] for l in loss] if len(loss)>0 else 's')
        # name_strs_to_plot = ['+'.join([l[:5] for l in loss]) if len(loss)>0 else 's' for loss in loss_types]
        name_str = '*'.join([loss_str, str(train_frames)])
        print("name_str: {}".format(name_str)) # TODO: just for single losses now
        name_strs_to_plot.append(name_str) 
        model_cfg = cfg.copy()
        model_cfg.training.train_frames = train_frames
        model_cfg.model.losses_to_use = loss # assume loss is already a list, [] if supervised
        model_cfg.model.model_name = model_names[loss_idx]
        model_cfg.model.model_type = model_type
        # specific arguments to "train_frames_sweep" models. TODO: change if needed
        model_cfg.training.train_prob=0.2
        model_cfg.training.val_prob=0.2
        model_cfg.training.min_epochs=125
        model_cfg.training.max_epochs=2000
        if len(loss) == 0:
            # support for uniquely-named supervised models
            model_cfg.model.model_name = supervised_model_name
        else:
            # loop over the sub losses
            if len(log_weight_list)>0:
                for sub_loss_idx,sub_loss in enumerate(loss):
                    model_cfg.losses[sub_loss].log_weight = log_weight_list[loss_idx][sub_loss_idx]
        
        try:
            handlers.append(ModelHandler(save_dir, model_cfg, verbose=False))
            print("Found: {}".format(model_cfg.model.model_name))
            print("In: {}".format(handlers[-1].model_dir))
        except FileNotFoundError:
            print('did not find %s model for train_frames=%i' % (loss, train_frames))
            continue
# report on the models found
print("==========================")
print("Found {} models out of {}.".format(len(handlers), len(train_frames_list)*len(loss_types)))

Searching for train_frames: 50, loss_type: ['temporal', 'unimodal_mse']...
name_str: tempo+unimo*50
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/21
Searching for train_frames: 50, loss_type: ['temporal', 'pca_multiview']...
name_str: tempo+pca_m*50
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/26
Searching for train_frames: 50, loss_type: ['temporal', 'unimodal_mse', 'pca_multiview']...
name_str: tempo+unimo+pca_m*50
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/31
Searching for train_frames: 50, loss_type: ['unimodal_mse', 'pca_multiview']...
name_str: unimo+pca_m*50
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/36
Searching for train_frames: 50, loss_type: ['pca_singleview', 'unimodal_mse']...
name_str: pca_s+unimo*50
Found: train_frames_sweep
In: /home/jovyan/lightning-pose/multirun/2022-04-02/13-32-43/4
Searching for tr

In [5]:
# Now for each model in a given training frame class, compute RMSE
pca_loss = None
data_module = None
to_compute = 'rmse'
metrics_collected = {bp: [] for bp in keypoint_names}
for hand_idx, handler in enumerate(handlers):
    print("idx: %i" % hand_idx)
    print("name: {}".format(handler.cfg.model.model_name))
    print("train_frames: {}".format(handler.cfg.training.train_frames))
    print("losses_to_use: {}".format(handler.cfg.model.losses_to_use))
    print(handler.model_dir)
    try:
        result = handler.compute_metric(
            to_compute, 'predictions.csv',
            keypoints_true=keypoints_gt, pca_loss_obj=pca_loss, datamodule=data_module)
    except FileNotFoundError:
        print('could not find model predictions')
        continue
    for b, bodypart in enumerate(keypoint_names):
        metrics_collected[bodypart].append(result[:, b])

idx: 0
name: train_frames_sweep
train_frames: 50
losses_to_use: ['temporal', 'unimodal_mse']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/21
Metric: rmse
Computing RMSE...
idx: 1
name: train_frames_sweep
train_frames: 50
losses_to_use: ['temporal', 'pca_multiview']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/26
Metric: rmse
Computing RMSE...
idx: 2
name: train_frames_sweep
train_frames: 50
losses_to_use: ['temporal', 'unimodal_mse', 'pca_multiview']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/31
Metric: rmse
Computing RMSE...
idx: 3
name: train_frames_sweep
train_frames: 50
losses_to_use: ['unimodal_mse', 'pca_multiview']
/home/jovyan/lightning-pose/multirun/2022-04-02/00-31-26/36
Metric: rmse
Computing RMSE...
idx: 4
name: train_frames_sweep
train_frames: 50
losses_to_use: ['pca_singleview', 'unimodal_mse']
/home/jovyan/lightning-pose/multirun/2022-04-02/13-32-43/4
Metric: rmse
Computing RMSE...
idx: 5
name: train_frames_sweep
train_frames: 50
los

In [6]:
# TODO: currently ignorant of the train_frames. either fix here or above.               
results_df = []
for bodypart in keypoint_names:
    dict_tmp = {
        'bodypart': bodypart,
        #'rng_seed': rng_seed,
        'eval_mode': handlers[-1].pred_df.iloc[:, -1].to_numpy(),
        'img_file': csv_data.iloc[:, 0], # TODO: fix, this is wrong. should be a str not a float
    }
    for col_name, metric in zip(name_strs_to_plot, metrics_collected[bodypart]):
        dict_tmp[col_name] = metric
    results_df.append(pd.DataFrame(dict_tmp))

results_df = pd.concat(results_df)

In [7]:
results_df.head()

Unnamed: 0,bodypart,eval_mode,img_file,tempo+unimo*50,tempo+pca_m*50,tempo+unimo+pca_m*50,unimo+pca_m*50,pca_s+unimo*50,pca_s+tempo*50,pca_s+pca_m*50,...,tempo+unimo*125,tempo+pca_m*125,tempo+unimo+pca_m*125,unimo+pca_m*125,pca_s+unimo*125,pca_s+tempo*125,pca_s+pca_m*125,pca_s+unimo+tempo*125,pca_s+unimo+pca_m*125,pca_s+unimo+pca_m+tempo*125
0,paw1LH_top,unused,barObstacleScaling1/img1.png,107.167164,19.484978,49.218768,27.729192,11.284295,81.7205,40.422242,...,10.165655,13.613108,21.551051,5.469441,7.278329,39.892265,9.262841,20.134874,37.455921,7.821145
1,paw1LH_top,validation,barObstacleScaling1/img2.png,32.32532,30.303992,28.153266,39.417194,21.096111,35.661139,29.156698,...,22.072084,32.700153,71.764882,79.929805,68.726678,19.663656,80.856533,69.370396,20.622683,67.822863
2,paw1LH_top,test,barObstacleScaling1/img3.png,22.602413,7.744324,4.759023,66.695939,5.981631,14.399546,56.574518,...,3.976566,3.637896,5.1191,3.901515,3.260632,2.930602,2.910508,1.984852,2.912916,3.12086
3,paw1LH_top,test,barObstacleScaling1/img4.png,4.198313,4.597901,1.87054,3.126915,3.041861,3.458074,3.763082,...,5.618093,2.670412,2.504021,3.312007,4.087109,4.350227,2.293038,1.783343,3.953196,1.600337
4,paw1LH_top,test,barObstacleScaling1/img5.png,7.271259,7.495884,7.345606,8.308105,8.733459,9.096077,9.505998,...,10.088732,8.674619,8.053103,8.232509,8.330179,7.980887,8.358817,7.65239,8.355274,8.799086


In [8]:
# go ahead and separate the names
df_tmp = pd.melt(
    results_df, 
    id_vars=['bodypart', 'img_file', 'eval_mode'],
    value_vars=name_strs_to_plot,
)
def add_loss_name_col(row):
    return '_'.join(row['variable'].split('_')[:-1])
def add_loss_val_col(row):
    return row['variable'].split('_')[-1]
df_tmp['loss'] = df_tmp.apply(add_loss_name_col, axis=1) # TODO: not doing anything

In [9]:
df_tmp["loss_type"] = df_tmp.variable.str.split('*').str[0]
df_tmp["train_frames"] = df_tmp.variable.str.split('*').str[1]

In [10]:
df_tmp.head()

Unnamed: 0,bodypart,img_file,eval_mode,variable,value,loss,loss_type,train_frames
0,paw1LH_top,barObstacleScaling1/img1.png,unused,tempo+unimo*50,107.167164,,tempo+unimo,50
1,paw1LH_top,barObstacleScaling1/img2.png,validation,tempo+unimo*50,32.32532,,tempo+unimo,50
2,paw1LH_top,barObstacleScaling1/img3.png,test,tempo+unimo*50,22.602413,,tempo+unimo,50
3,paw1LH_top,barObstacleScaling1/img4.png,test,tempo+unimo*50,4.198313,,tempo+unimo,50
4,paw1LH_top,barObstacleScaling1/img5.png,test,tempo+unimo*50,7.271259,,tempo+unimo,50


In [19]:
# TODO: continue here. this will allow us to pick the best performing combo of losses for each train_frames
df_eval = df_tmp[df_tmp.eval_mode=="validation"]
df_grouped = df_eval.groupby(['loss_type', 'train_frames'])["value"].mean().reset_index()
print(df_grouped.shape)
df_grouped.groupby(['train_frames'])["value"].idxmin()
#df_grouped.head()
#df_grouped.groupby(['train_frames'])["value"].min().reset_index()

(40, 3)


train_frames
100    16
125    17
50     14
75     27
Name: value, dtype: int64

In [None]:
sns.set(context='talk', style='whitegrid', font_scale=1, rc = {'figure.figsize':(20,40)})
eval_mode = 'test'
saving_format = 'eps'
y_label = "Pixel error"
hue_order = None # ['s', 'unimo', 'tempo', 'pca_s', 'pca_m']
num_losses_to_plot = 20 #len(cols_collected) # can exclude vals here
# average over keypoints and frames
df_tmp_ = df_tmp[df_tmp.eval_mode==eval_mode]
g = sns.catplot(
    data=df_tmp_, x='train_frames', hue="loss_type", y='value', kind='bar', hue_order=hue_order
)
# for ax in g.axes:
#     plt.setp(ax.get_yticklabels(), visible=True, rotation=None)
#g.set_xticklabels(rotation=80, ha='center')
# plt.title("train frames: {}, metric: {}, {} data".format(train_frames, to_compute, eval_mode))
plt.ylabel("{}".format(y_label))
plt.gcf().set_size_inches(12, 8)
#plt.savefig('/home/jovyan/figs/{}_{}_{}_losses.{}'.format(to_compute, eval_mode, y_label, saving_format))