In [None]:
import os
import numpy as np
import scipy.io as sio
import imageio
import tqdm
from projection import *
import connectivity

import matplotlib
#matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import projection, connectivity
from com_traga_utlis import find_calib_file, load_data

def com_label_valid(com_data, base_folder, graph_title, save_path, cam='Camera1',
                    n_frames=100, start_frame=0):
    """Render a video of a 3D centre-of-mass (COM) track projected onto one camera.

    The points in ``com_data[start_frame:start_frame + n_frames]`` are projected into
    ``cam`` using the calibration returned by ``find_calib_file(base_folder)``. They are
    drawn as white dots over the matching frames of ``<base_folder>/videos/<cam>/0.mp4``.
    The result is written at 30 fps to ``<save_path>/vis_<graph_title>continued.mp4``.

    :param com_data: COM positions indexed by frame along axis 0 (presumably shape
        (n_total_frames, 3) as returned by ``load_data`` — TODO confirm).
    :param base_folder: Session folder holding the calibration file and ``videos/``.
    :param graph_title: Title drawn on every frame; also the output file-name prefix.
    :param save_path: Output directory; created if it does not exist.
    :param cam: Camera name, used as the calibration key and the video sub-folder.
    :param n_frames: Maximum number of frames to render (default 100). Clamped to the
        number of COM samples available from ``start_frame`` on.
    :param start_frame: Index of the first frame to render (default 0).
    :raises ValueError: If no COM samples exist from ``start_frame`` on.
    """
    label3d_path = find_calib_file(base_folder)
    video_path = os.path.join(base_folder, f'videos/{cam}/0.mp4')
    vid_title = graph_title
    vid_name = vid_title + 'continued.mp4'

    # Clamp the frame count to what is actually available. Otherwise the reshape
    # below fails when the track is shorter than the requested window.
    pts = com_data[start_frame:start_frame + n_frames]
    n_frames = len(pts)
    if n_frames == 0:
        raise ValueError(f'no COM samples available from frame {start_frame}')

    camera = load_cameras(label3d_path)[cam]

    # World -> undistorted pixel -> distorted pixel. distortPoints' output is
    # transposed back before grouping the points per frame.
    projpts = project_to_2d(pts, camera["K"], camera["r"], camera["t"])[:, :2]
    projpts = distortPoints(projpts,
                            camera["K"],
                            np.squeeze(camera["RDistort"]),
                            np.squeeze(camera["TDistort"]))
    kpts_per_frame = np.reshape(projpts.T, (n_frames, -1, 2))
    print('pred_2d', len(kpts_per_frame))

    os.makedirs(save_path, exist_ok=True)
    metadata = dict(title='dannce_visualization', artist='Matplotlib')
    writer = FFMpegWriter(fps=30, metadata=metadata)  # orig fps = 30; 0.5 is handy for COM debugging

    # Pass figsize explicitly. Setting rcParams after plt.figure() only affected
    # *later* figures, so the first video used a different size.
    fig = plt.figure(figsize=(6, 6))
    try:
        with imageio.get_reader(video_path) as vids:
            with writer.saving(fig, os.path.join(save_path, "vis_" + vid_name), dpi=300):
                for i in tqdm.tqdm(range(n_frames)):
                    plt.clf()
                    frame = vids.get_data(i + start_frame)
                    kpts_2d = kpts_per_frame[i]

                    plt.imshow(frame)
                    plt.scatter(kpts_2d[:, 0], kpts_2d[:, 1], marker='.', color='white',
                                linewidths=2, alpha=0.5)
                    plt.title(vid_title)
                    plt.axis("off")

                    writer.grab_frame()
    finally:
        # Release the figure even on error; this function is called in a loop.
        plt.close(fig)


def adjust_viewport(kpts_2d, margin=70):
    """
    Centre the current pyplot axes on the centroid of the given 2D keypoints.

    The visible window spans ``margin`` on each side of the keypoint mean. The
    y-limits are given as (larger, smaller), so the y axis increases downwards
    like image row coordinates.

    Fitting the window to the keypoints' min/max extent was tried before, but it
    made the view jump around too much between frames.

    :param kpts_2d: Keypoints for the current frame; column 0 is x, column 1 is y.
    :param margin: Half-width of the window, in the keypoints' units.
    """
    cx, cy = (np.mean(kpts_2d[:, axis]) for axis in (0, 1))
    plt.xlim([cx - margin, cx + margin])
    plt.ylim([cy + margin, cy - margin])

def dannce_label_valid(base_path, pred_folder, pred_name="AVG0", cam="Camera6", N_FRAMES=100, START_FRAME=0, smooth = False):
    """Render a video overlaying DANNCE keypoints, skeleton and COM on one camera view.

    Inputs:
    - Predictions from ``<base_path>/<pred_folder>/save_data_<pred_name>.mat``, or from
      ``smoothed_prediction_AVG0.mat`` when ``smooth`` is true.
    - The COM track from ``<base_path>/<pred_folder>/com3d_used.mat``.

    Both are projected into ``cam`` and drawn over ``<base_path>/videos/<cam>/0.mp4``:
    keypoints as white dots, COM as a red dot, and the ``mouse20`` skeleton from
    ``connectivity`` as coloured lines. The view is re-centred on the keypoints every
    frame. Output is written at 20 fps to ``<base_path>/<pred_folder>/vis/vis_<title>.mp4``.

    :param base_path: Session folder holding the calibration file and ``videos/``.
    :param pred_folder: Prediction folder, relative to ``base_path``.
    :param pred_name: Suffix of ``save_data_<pred_name>.mat``; unused when ``smooth``.
    :param cam: Camera name, used as the calibration key and the video sub-folder.
    :param N_FRAMES: Number of frames to render.
    :param START_FRAME: Index of the first frame to render.
    :param smooth: Read the smoothed prediction file instead of the raw one.
    :raises ValueError: If the prediction or COM file has fewer than ``N_FRAMES``
        frames from ``START_FRAME`` on.
    """
    video_path = os.path.join(base_path, f'videos/{cam}/0.mp4')
    label3d_path = find_calib_file(base_path)
    if smooth:
        pred_path = os.path.join(base_path, pred_folder, "smoothed_prediction_AVG0.mat")
        vid_title = f'combined_{cam}_smoothed_{N_FRAMES}_start{START_FRAME}'
    else:
        pred_path = os.path.join(base_path, pred_folder, f"save_data_{pred_name}.mat")
        vid_title = f'combined_{cam}_{pred_name}_{N_FRAMES}_start{START_FRAME}'

    animal = 'mouse20'
    vid_name = vid_title + '.mp4'
    colors = connectivity.COLOR_DICT[animal]
    links = connectivity.CONNECTIVITY_DICT[animal]
    save_path = os.path.join(base_path, pred_folder, 'vis')
    os.makedirs(save_path, exist_ok=True)

    com_file = os.path.join(base_path, pred_folder, 'com3d_used.mat')
    com_data = sio.loadmat(com_file)

    camera = load_cameras(label3d_path)[cam]

    def _project(points_3d):
        """Project (n_points, 3) world points into `cam`; group as (N_FRAMES, -1, 2) pixels."""
        proj = project_to_2d(points_3d, camera["K"], camera["r"], camera["t"])[:, :2]
        proj = distortPoints(proj,
                             camera["K"],
                             np.squeeze(camera["RDistort"]),
                             np.squeeze(camera["TDistort"]))
        return np.reshape(proj.T, (N_FRAMES, -1, 2))

    # 'pred' is assumed (frames, 3, n_keypoints): it is transposed to
    # (frames, n_keypoints, 3) and flattened to one xyz row per point.
    pred_3d = sio.loadmat(pred_path)['pred'][START_FRAME:START_FRAME + N_FRAMES]
    pts_com = com_data['com'][START_FRAME:START_FRAME + N_FRAMES]
    # A short file can otherwise be silently mis-grouped by the reshape in _project.
    if len(pred_3d) != N_FRAMES or len(pts_com) != N_FRAMES:
        raise ValueError(
            f'expected {N_FRAMES} frames from frame {START_FRAME}, got '
            f'{len(pred_3d)} predicted and {len(pts_com)} COM frames')

    kpts_all = _project(np.reshape(np.transpose(pred_3d, (0, 2, 1)), (-1, 3)))
    com_all = _project(pts_com)
    del pred_3d

    metadata = dict(title='combined_visualization', artist='Matplotlib')
    writer = FFMpegWriter(fps=20, metadata=metadata)  # orig fps = 30.

    # Pass figsize explicitly. Setting rcParams after plt.figure() only affected
    # *later* figures.
    fig = plt.figure(figsize=(6, 6))
    try:
        with imageio.get_reader(video_path) as vids:
            with writer.saving(fig, os.path.join(save_path, "vis_" + vid_name), dpi=300):
                for curr_frame in tqdm.tqdm(range(N_FRAMES)):
                    plt.clf()
                    img = vids.get_data(curr_frame + START_FRAME)
                    kpts_2d = kpts_all[curr_frame]
                    kpts_2d_com = com_all[curr_frame]

                    # Keypoints 6 and 7 are left out of the viewport centre and the
                    # scatter; the skeleton below still uses every keypoint.
                    # (Reason not visible here — TODO document why.)
                    shown_kpts = np.r_[kpts_2d[0:6, :], kpts_2d[8:, :]]

                    # Set the limits before imshow: explicit limits disable
                    # autoscaling, so the image does not reset the zoom.
                    adjust_viewport(shown_kpts, margin=450)  # 150 gives a tighter crop

                    plt.imshow(img)
                    plt.scatter(kpts_2d_com[:, 0], kpts_2d_com[:, 1], marker='.', color='red',
                                linewidths=2, alpha=0.5)
                    plt.scatter(shown_kpts[:, 0], shown_kpts[:, 1], marker='.', color='white',
                                linewidths=2, alpha=0.5)

                    for color, (index_from, index_to) in zip(colors, links):
                        segment = kpts_2d[[index_from, index_to]]
                        plt.plot(segment[:, 0], segment[:, 1], c=color, lw=2)

                    plt.title(vid_title)
                    plt.axis("off")

                    writer.grab_frame()
    finally:
        # Release the figure even on error, so repeated calls do not accumulate figures.
        plt.close(fig)

In [None]:
# Render COM-validation videos for every session folder listed here.
# plot_com_all (written earlier) does much the same. This loop is kept so folder
# and file names can be adjusted freely.
weired_folders = [
    '/home/lq53/mir_data/24summ/2024_07_19/240605PMC_window2_right2holes_12_14',  # validation set
]

for wie in weired_folders:
    # Example output:
    # /home/lq53/mir_data/24summ/2024_06_26/1686940_left/COM_df/predict_results//vis/vis_combined_Camera2_1000_from_0.mp4
    com_foler = os.path.join(wie, 'COM_df_final/predict_results/')
    com_path = os.path.join(com_foler, 'com3d.mat')
    com_folder_save = os.path.join(com_foler, 'vis')
    # exist_ok avoids the check-then-create race and keeps the cell re-runnable.
    os.makedirs(com_folder_save, exist_ok=True)
    graph_title = "z_com3d_vis_"

    com_data = load_data(com_path)

    # Optional diagnostics:
    # plot_3d_trajectory(com_data, graph_title, com_folder_save)
    # jump_indices = detect_jumps(com_data, com_folder_save)
    # generate_jump_video(com_data, wie, jump_indices, graph_title, com_folder_save, cam='Camera1')

    # The video is only needed when labelling more COM frames to debug a bad track.
    save_path = com_folder_save
    base_folder = wie
    # NOTE(review): `generate_com_video` is not defined in any cell visible here.
    # It may come from `from projection import *` or an external module.
    # com_label_valid above has the same signature — confirm which one is intended.
    generate_com_video(com_data, base_folder, graph_title, save_path, cam='Camera1')
# generate_com_video_choice: a leaner variant for quick iteration in the one-click pipeline.

    

In [14]:
import scipy.io as sio

# Peek at a Label3D/DANNCE label file; the loaded dict is shown as the cell output.
# Drill into e.g. labell['labelData'] to inspect individual fields.
label = '/home/lq53/mir_data/24summ/2024_07_03/1691486_left_right_habituation/labels/df_mir_com_30_rand_20240815_162122_Label3D_dannce.mat'
labell = sio.loadmat(file_name=label)
labell

{'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Thu Aug 22 14:32:27 2024',
 '__version__': '1.0',
 '__globals__': [],
 'labelData': array([[array([[(array([[1434.26302657, 1089.13278237],
                        [1432.2040044 , 1091.74032078],
                        [ 675.42795788,  192.16078175],
                        [ 651.86982034,  514.28181416],
                        [ 662.88520176,  513.85487106],
                        [1058.13919187, 1091.4352583 ],
                        [2037.61167386, 1220.75600805],
                        [1428.34878108, 1037.47249942],
                        [1349.44293775,  993.68318783],
                        [ 376.97048783, 1305.35667485],
                        [ 811.7083806 ,   78.25533452],
                        [1479.06006233, 1069.03539086],
                        [1282.21787609, 1049.66212374],
                        [1381.56004973, 1158.71282844],
                        [1364.86421741, 1183.80063821],
          