In [5]:
import os
import re
from scipy.io import loadmat, savemat
import numpy as np
from pprint import pp
import shutil
import tempfile
import uuid
from pathlib import Path
import csv
import datetime

In [6]:
# GLOBAL CONSTANTS - DO NOT CHANGE THIS SECTION -

# Marker written under the 'lastchange' key of every file this notebook rewrites.
# The update functions look for it and refuse to apply the same fix twice.
LASTCHANGE_STRING = "2024-08-18_flip_axes"

# Matrix transform that swaps the x and y axes and negates the z axis:
#   (x, y, z) -> (y, x, -z)
# It is left-multiplied onto each 3d point and onto each rotation matrix.
transform = np.array([[0, 1, 0], [ 1, 0, 0] , [ 0, 0, -1]])


In [None]:
def save_backup_samedir(src_filepath):
    """Copy ``src_filepath`` to a backup file in the same directory.

    The backup name is the source path with ``.old`` appended. If that name is
    already taken, another ``.old`` is appended (``x.mat.old.old``, ...) until
    an unused name is found, so an existing backup is never overwritten.

    Parameters
    ----------
    src_filepath : str or pathlib.Path
        Existing file to back up.

    Returns
    -------
    pathlib.Path
        Path of the backup copy that was created.

    Raises
    ------
    AssertionError
        If ``src_filepath`` does not exist.
    Exception
        If no unused backup name was found within 1000 attempts.
    """
    assert Path(src_filepath).exists(), "src_filepath must exist"

    max_tries = 1000
    candidate_path = src_filepath
    for _ in range(max_tries):
        candidate_path = Path(f"{candidate_path}.old")
        if not candidate_path.exists():
            break
    else:
        # every one of the max_tries candidate names was already in use
        raise Exception("Unable to generate unique file name. Tried 1000 times already.")

    # copy2 keeps file metadata (timestamps, permissions) along with the contents
    shutil.copy2(src_filepath, candidate_path)
    print(f"BACKUP src={src_filepath} dest={candidate_path}")
    return candidate_path


def update_hires_file(filepath, saveto=None):
    """Apply the axis-flip transform to one calibration (hires_camX_params.mat) file.

    The ``r`` variable stored in the file is left-multiplied by the
    module-level ``transform`` and the file is tagged with
    ``lastchange = LASTCHANGE_STRING``.

    Parameters
    ----------
    filepath : str or path-like
        ``.mat`` file to read.
    saveto : str or path-like, optional
        Where to write the updated data. May equal ``filepath`` to update the
        file in place. When ``None`` (default) this is a dry run: nothing is
        written and no backup is created.

    Raises
    ------
    Exception
        If the file's ``lastchange`` already equals ``LASTCHANGE_STRING``.
    """
    print(f"update_hires_file: {filepath}")

    mat1 = loadmat(filepath)

    # loadmat hands text back as a numpy array, so compare element-wise and
    # reduce with np.any instead of relying on the truth value of an array.
    if 'lastchange' in mat1 and np.any(np.asarray(mat1['lastchange']) == LASTCHANGE_STRING):
        raise Exception("SKIPPING BECAUSE UPDATE WAS ALREADY APPLIED")

    mat1['lastchange'] = LASTCHANGE_STRING
    mat1['r'] = transform @ mat1['r']

    if saveto is not None:
        # Back up only when something is about to be written, and only after
        # the transform succeeded, so dry runs and failed runs leave no
        # stray .old copies behind.
        save_backup_samedir(filepath)
        print(f"Saving update mat file as: {saveto}")
        savemat(saveto, mat1)
    else:
        print("Not saving file because saveto not specified")


def update_label3d_working_file(filepath, saveto=None):
    """Apply the axis-flip transform to a label3d working (save) file.

    These are the save files that can be loaded in label3d. This is NOT
    the label3d exported files used to train dannce (those are *dannce.mat).

    Every 3d point in ``data_3D`` (each row is laid out as
    ``[x y z x2 y2 z2 ... xN yN zN]``) and every camera rotation matrix
    ``camParams[i]['r']`` is left-multiplied by the module-level
    ``transform``. The data is then tagged with
    ``lastchange = LASTCHANGE_STRING``.

    Parameters
    ----------
    filepath : str or path-like
        ``.mat`` file to read.
    saveto : str or path-like, optional
        Where to write the updated data. May equal ``filepath`` to update the
        file in place. When ``None`` (default) this is a dry run: nothing is
        written and no backup is created.

    Raises
    ------
    Exception
        If the file's ``lastchange`` already equals ``LASTCHANGE_STRING``.
    """
    print(f"update_label3d_working_file: {filepath}")

    mat3 = loadmat(filepath)

    # loadmat hands text back as a numpy array, so compare element-wise and
    # reduce with np.any instead of relying on the truth value of an array.
    if 'lastchange' in mat3 and np.any(np.asarray(mat3['lastchange']) == LASTCHANGE_STRING):
        raise Exception("SKIPPING BECAUSE UPDATE WAS ALREADY APPLIED")

    n_cams = mat3['camParams'].shape[0]
    n_samples = mat3['data_3D'].shape[0]
    n_joints = mat3['data_3D'].shape[1] // 3

    # update 3d points (labeled 3d-space)
    # e.g. (75, 69) -> (75, 23, 3); the reshape raises if the column count is
    # not a multiple of 3
    points = mat3['data_3D']
    print(f"fixing all 3d points for #samples={n_samples}")
    points_xyz = points.reshape(n_samples, n_joints, 3)
    # p @ transform.T is the row-vector form of transform @ p, broadcast over
    # every sample and joint at once
    flipped = points_xyz @ transform.T
    mat3['data_3D'] = flipped.reshape(n_samples, 3 * n_joints).astype(points.dtype, copy=False)

    # update camera parameters (rotation mtx) for each camera
    # camParams is indexed as [cam_idx, 0], e.g. shape (6, 1)
    for cam_idx in range(n_cams):
        print(f"updating rot mtx for cam_idx={cam_idx}")

        # field access returns a view, so the item assignment writes into mat3
        rot_field = mat3['camParams'][cam_idx, 0]['r']
        rot_field[0, 0] = transform @ rot_field[0, 0]

    mat3['lastchange'] = LASTCHANGE_STRING

    if saveto is not None:
        # Back up only when something is about to be written, and only after
        # the transform succeeded, so dry runs and failed runs leave no
        # stray .old copies behind.
        save_backup_samedir(filepath)
        print(f"Saving update mat file as: {saveto}")
        savemat(saveto, mat3)
    else:
        print("Not saving file because saveto not specified")


def update_dannce_mat_file(filepath,saveto=None):
    """Apply the axis-flip transform to a label3d export (*dannce.mat) file.

    Every camera rotation matrix ``params[i]['r']`` and every 3d point in
    ``labelData[i]['data_3d']`` (each row is laid out as
    ``[x y z x2 y2 z2 ... xN yN zN]``) is left-multiplied by the module-level
    ``transform``. The data is tagged with ``lastchange = LASTCHANGE_STRING``.

    Parameters
    ----------
    filepath : str or path-like
        ``.mat`` file to read.
    saveto : str or path-like, optional
        Where to write the updated data. May equal ``filepath`` to update the
        file in place. When ``None`` (default) this is a dry run: nothing is
        written and no backup is created.

    Raises
    ------
    Exception
        If the file's ``lastchange`` already equals ``LASTCHANGE_STRING``.
    """
    print(f"update_dannce_mat_file: {filepath}")

    mat4 = loadmat(filepath)

    # loadmat hands text back as a numpy array, so compare element-wise and
    # reduce with np.any instead of relying on the truth value of an array.
    if 'lastchange' in mat4 and np.any(np.asarray(mat4['lastchange']) == LASTCHANGE_STRING):
        raise Exception("SKIPPING BECAUSE UPDATE WAS ALREADY APPLIED")

    mat4['lastchange'] = LASTCHANGE_STRING

    label_data = mat4['labelData']
    cam_params = mat4['params']

    # summary values for the log line below are taken from the first camera
    n_cams = label_data.shape[0]
    n_samples = label_data[0,0][0,0]['data_3d'].shape[0]
    n_joints = label_data[0,0][0,0]['data_3d'].shape[1]//3

    print(f"n_cams={n_cams}, n_samples={n_samples}, n_joints={n_joints}")

    # update camera parameters (rotation mtx) for each camera
    # Iterate over params' own length so every stored rotation is updated
    # even if it ever differs from the number of labelData entries.
    for cam_idx in range(cam_params.shape[0]):
        print(f"updating rot mtx for cam_idx={cam_idx}")

        # field access returns a view, so the item assignment writes into mat4
        rot_field = cam_params[cam_idx, 0]['r']
        rot_field[0, 0] = transform @ rot_field[0, 0]

    for cam_idx in range(n_cams):
        print(f"updating all 3d pts in labelData[cam_idx] -> data_3d for cam_idx={cam_idx}")

        pts_field = label_data[cam_idx, 0]['data_3d']
        flat = pts_field[0, 0]  # e.g. shape (35, 69)
        # use this camera's own shape rather than camera 0's;
        # e.g. (35, 69) -> (35, 23, 3)
        n_rows, n_cols = flat.shape
        xyz = flat.reshape(n_rows, n_cols // 3, 3)
        # p @ transform.T is the row-vector form of transform @ p, broadcast
        # over every sample and joint at once
        flipped = xyz @ transform.T
        pts_field[0, 0] = flipped.reshape(n_rows, n_cols).astype(flat.dtype, copy=False)

    if saveto is not None:
        # Back up only when something is about to be written, and only after
        # the transform succeeded, so dry runs and failed runs leave no
        # stray .old copies behind.
        save_backup_samedir(filepath)
        print(f"Saving update mat file as: {saveto}")
        savemat(saveto, mat4)
    else:
        print("Not saving file because saveto not specified")


def update_com3d_file(filepath, saveto=None):
    """Apply the axis-flip transform to a com3d.mat file.

    Every row of ``com`` (one ``[x y z]`` point per sample) is left-multiplied
    by the module-level ``transform`` and the data is tagged with
    ``lastchange = LASTCHANGE_STRING``.

    Parameters
    ----------
    filepath : str or path-like
        ``.mat`` file to read.
    saveto : str or path-like, optional
        Where to write the updated data. May equal ``filepath`` to update the
        file in place. When ``None`` (default) this is a dry run: nothing is
        written and no backup is created.

    Raises
    ------
    Exception
        If the file's ``lastchange`` already equals ``LASTCHANGE_STRING``.
    """
    print(f"update_com3d_file: {filepath}")
    mat5 = loadmat(filepath)

    # loadmat hands text back as a numpy array, so compare element-wise and
    # reduce with np.any instead of relying on the truth value of an array.
    if 'lastchange' in mat5 and np.any(np.asarray(mat5['lastchange']) == LASTCHANGE_STRING):
        raise Exception("SKIPPING BECAUSE UPDATE WAS ALREADY APPLIED")

    com = mat5['com']  # e.g. shape (90000, 3)
    n_samples = com.shape[0]

    # update 3d points: one [x y z] triple per row (a COM is a single "joint")
    print(f"fixing all 3d points for #samples={n_samples}")
    # the reshape raises if the rows are not exactly [x y z]
    com_xyz = com.reshape(n_samples, 3)
    # p @ transform.T is the row-vector form of transform @ p for every row
    mat5['com'] = (com_xyz @ transform.T).astype(com.dtype, copy=False)
    mat5['lastchange'] = LASTCHANGE_STRING

    if saveto is not None:
        # Back up only when something is about to be written, and only after
        # the transform succeeded, so dry runs and failed runs leave no
        # stray .old copies behind.
        save_backup_samedir(filepath)
        print(f"Saving update mat file as: {saveto}")
        savemat(saveto, mat5)
    else:
        print("Not saving file because saveto not specified")


**Note: label3d working files are commented out**

In [None]:
# Find the files that need the axis fix:
#   1. dannce.mat
#   2. label3d working files (_Label3D.mat)  -- NOTE: not searched for in this cell
#   3. calibration (hires_camX_params.mat) files
#   4. com3d.mat files

# Specify the folders to run the recursive glob search on.
# NOTE(review): the bare name SYNTAX_ERROR_UPDATE_THIS_WITH_YOUR_OWN_FILES in the
# list below makes this cell fail to parse as written -- presumably a deliberate
# tripwire so the paths get edited before anything runs. Remove it once
# base_paths points at your own data.

base_paths = [ 
    '/n/olveczky_lab_tier1/Lab/dannce_rig2/data/M1-M7_photometry/Alone', SYNTAX_ERROR_UPDATE_THIS_WITH_YOUR_OWN_FILES
    '/n/holylabs/LABS/olveczky_lab/Lab/dannce-dev/hannah-data/COM_DANNCE_TRAINING'
]

# One list of path strings per file type, filled by the search below.
dannce_mat_files = []
calibration_files = []
com3d_files = []

# Recursively search every base path. Names starting with "._" are skipped
# (presumably macOS sidecar files, which are not real .mat data -- confirm).
for base_path in base_paths:
    for p in Path(base_path).glob("**/*_dannce.mat"):
        if p.name.startswith("._"):
            continue
        dannce_mat_files.append(str(p))

    # only hires_cam*_params.mat files directly inside a "calibration" folder
    for p in Path(base_path).glob("**/calibration/hires_cam*_params.mat"):
        if p.name.startswith("._"):
            continue
        calibration_files.append(str(p))

    for p in Path(base_path).glob("**/*com3d.mat"):
        if p.name.startswith("._"):
            continue
        com3d_files.append(str(p))


# Sort so the indices printed by the fix loops are stable between runs.
dannce_mat_files = sorted(dannce_mat_files)
calibration_files = sorted(calibration_files)
com3d_files = sorted(com3d_files)

print("\nDANNCE Mat files")
pp(dannce_mat_files)

print("\nCalibration Files")
pp(calibration_files)

print("\nCom3D Files")
pp(com3d_files)

# Combined list of every file found; used only to build the rsync file list.
allfiles = []
allfiles.extend(dannce_mat_files)
allfiles.extend(calibration_files)
allfiles.extend(com3d_files)

# UPDATE THIS TO MAKE PATHS RELATIVE TO SOME BASE PATH
# e.g. /n/holylabs/LABS/olveckzy/Alone/M1/dannce.mat -> ./Alone/M1/dannce.mat
# NOTE(review): str.replace only rewrites paths that contain the exact prefix.
# The first prefix below (/net/holy-nfsisilon/...) is not the same string as the
# first entry of base_paths (/n/olveczky_lab_tier1/...), so check that the paths
# written to file-list really start with "./" before running rsync.
allfiles = list(map(lambda x: x.replace(r"/net/holy-nfsisilon/ifs/rc_labs/olveczky_lab_tier1/Lab/dannce_rig2/data/M1-M7_photometry", "."), allfiles))
allfiles = list(map(lambda x: x.replace(r"/n/holylabs/LABS/olveczky_lab/Lab", "."), allfiles))

# create rsync file list (one path per line) in the current working directory
with open("file-list", "wt") as f:
    for p in allfiles:
        f.write(f"{p}\n")

####
# You can use RSYNC to create a clone of files listed in allfiles
### HOWEVER YOU NEED TO DO THIS 1x BASEPATH AT A TIME ###
### e.g. comment out other so that base_paths list len is 1
### and re-run above code and rsync command for each

### EXAMPLE RSYNC COMMAND
# rsync -avR --files-from=./file-list /n/olveczky_lab_tier1/Lab/dannce_rig2/data/M1-M7_photometry ./backup-files-rot-fix

### NOTE ON RSYNC ARGS:
### rsync -avR --files-from=FILE_LIST RSYNC_SRC RSYNC_DEST
### -avR : use archive mode (a), print extra output (v, verbose), and build a relative tree (R, relative to the dot "." in the file list)
### --files-from=FILE_LIST : FILE_LIST is a text file in which each line is a file to copy, given relative to RSYNC_SRC
### RSYNC_SRC : directory that the paths in FILE_LIST are looked up in. The same relative paths are recreated under RSYNC_DEST.
### RSYNC_DEST : directory where the file copies will be stored.


In [None]:
### RUN FIX ON EACH LIST OF FILES ###
######################################################################
### I WOULD UN-COMMENT AND RUN ONE GROUP AT A TIME 
### IF THERE IS AN ERROR WITH ONE OF THE FILES, YOU CAN CONTINUE AFTER
###    A CERTAIN INDEX. E.G. 5
###    (pass start=5 so the printed index still matches the position in the full list)
###
### for i,f in enumerate(calibration_files[5:], start=5): ...
#######################################################################

# for i,f in enumerate(calibration_files):
#     print (f"calib_file: [{i}]:" ,f)
#     update_hires_file(f, f)

# for i,f in enumerate(dannce_mat_files):
#     print (f"dannce.mat: [{i}]:" ,f)
#     update_dannce_mat_file(f,f)

# for i,f in enumerate(com3d_files):
#     print (f"c3d: [{i}]:" ,f)
#     update_com3d_file(f,f)

In [None]:
# Sanity check after the fix: for every COM file, report how many samples
# have a positive z-coordinate (third column of 'com').

for com_path in com3d_files:
    com_xyz = loadmat(com_path)['com']
    total = com_xyz.shape[0]
    print("FNAME: ", com_path)
    # count the rows whose z value is strictly greater than zero
    positive_z = np.count_nonzero(com_xyz[:, 2] > 0)
    print(f"COM's with z-coord > zero {positive_z}/{total}\n")

