In [1]:
# dappy env
import json
import h5py
import numpy as np
import pandas as pd
import os

def export_aligned_data_to_h5(
    data_obj, 
    rec_path, 
    frame_mapping_file, 
    out_file
):
    """
    Filters data_obj to a specific rec_path and frames from frame_mapping_file,
    then saves minimal data to an HDF5 file (with string columns stored
    as variable-length UTF-8 strings).
    """
    
    # 1. Filter by Prediction_path
    rec_mat_path = os.path.join(rec_path, "DANNCE/predict00/save_data_AVG.mat")
    path_data = data_obj.data[data_obj.data["Prediction_path"] == rec_mat_path].copy()
    if path_data.empty:
        print(f"No data found for rec_path: {rec_mat_path}")
        return
    

    # 2. Read the frame mapping JSON
    with open(frame_mapping_file, "r") as f:
        map_data = json.load(f)
    mapped_frames = set(map_data["mapped_sixcam_frame_indices"])
    time_offset = map_data["time_offset"]
    
    # 3. Filter rows by mapped frame indices (assuming your DataFrame has 'frame' column)
    if "frame" not in path_data.columns:
        raise ValueError("DataFrame does not have 'frame' column to filter by.")
    
    # Adjust frames so that they start at 0
    min_frame = path_data["frame"].min()
    path_data["frame"] = path_data["frame"] - min_frame

    # 1. Create helper offsets DataFrame
    offsets = pd.DataFrame({'offset': range(10)})

    # 2. Cross-merge (pandas 1.2+ supports `how="cross"`)
    expanded = path_data.merge(offsets, how='cross')

    # 3. Update the frame by adding the offset
    expanded['frame'] = expanded['frame'] + expanded['offset']
    expanded.drop(columns='offset', inplace=True)
    
    filtered_data = expanded[expanded["frame"].isin(mapped_frames)]
    
    if filtered_data.empty:
        print("No overlapping frames found between path_data and mapped_sixcam_frame_indices.")
    
    # 4. Save to HDF5
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    # Use variable-length string dtype for columns that are (or become) text
    variable_length_string_dt = h5py.special_dtype(vlen=str)
    
    with h5py.File(out_file, "w") as hf:
        
        #
        # (A) Save filtered DataFrame columns
        #
        grp = hf.create_group("filtered_data")
        
        for col in filtered_data.columns:
            col_data = filtered_data[col].to_numpy()  # get as NumPy array
            
            # Check if it's string-like (object, unicode, or bytes)
            if col_data.dtype.kind in ["O", "U", "S"]:
                # Convert each element to a Python string, store as variable-length UTF-8
                # Flatten, map to str, then reshape to original shape if multi-dimensional
                original_shape = col_data.shape
                col_data = col_data.reshape(-1)  # flatten
                col_data = np.array([str(item) for item in col_data], dtype=object)
                col_data = col_data.reshape(original_shape)
                
                grp.create_dataset(
                    col, 
                    data=col_data, 
                    dtype=variable_length_string_dt, 
                    compression="gzip"
                )
            else:
                # Numeric or other supported dtype can be written directly
                grp.create_dataset(col, data=col_data, compression="gzip")
        
        #
        # (B) Save relevant data_obj attributes
        #
        if hasattr(data_obj, "embed_vals") and data_obj.embed_vals is not None:
            hf.create_dataset("embed_vals", data=data_obj.embed_vals, compression="gzip")
        
        # Example: saving meta info
        if hasattr(data_obj, "meta") and data_obj.meta is not None:
            meta_grp = hf.create_group("meta")
            if isinstance(data_obj.meta, dict):
                for key, val in data_obj.meta.items():
                    # Convert to array for consistency
                    val_array = np.array(val, dtype=object)  # object to handle strings
                    # If it has any string/unicode, cast them properly
                    if val_array.dtype.kind in ["O", "U", "S"]:
                        val_array = val_array.reshape(-1)
                        val_array = np.array([str(item) for item in val_array], dtype=object)
                        # Reshape back if needed (only if it's consistent)
                        # But typically meta might be a 1D list, so might not need reshape
                        
                        meta_grp.create_dataset(
                            key,
                            data=val_array,
                            dtype=variable_length_string_dt,
                            compression="gzip"
                        )
                    else:
                        meta_grp.create_dataset(key, data=val_array, compression="gzip")
            else:
                # Non-dict meta structure
                meta_vals = np.array(data_obj.meta, dtype=object)
                if meta_vals.dtype.kind in ["O", "U", "S"]:
                    meta_vals = [str(item) for item in meta_vals.flatten()]
                    meta_vals = np.array(meta_vals, dtype=object)
                    meta_grp.create_dataset(
                        "meta_data",
                        data=meta_vals,
                        dtype=variable_length_string_dt,
                        compression="gzip"
                    )
                else:
                    meta_grp.create_dataset("meta_data", data=meta_vals, compression="gzip")
        
        
        # (C) Save frame mapping info
        #
        map_grp = hf.create_group("frame_mapping")
        map_grp.create_dataset(
            "mapped_sixcam_frame_indices",
            data=np.array(map_data["mapped_sixcam_frame_indices"]),
            compression="gzip"
        )
        map_grp.attrs["time_offset"] = time_offset

    print(f"Filtered data for '{rec_path}' saved to '{out_file}'")

    import pdb
    pdb.set_trace()


In [2]:
import pickle

trrrry = '60_p'

dts_p = "/home/lq53/mir_repos/dappy_24_nov/byws_version/250116_wav_ffix_ang_pos/50_p/datastruct.p"
# f"/home/lq53/mir_repos/dappy_24_nov/byws_version/250109_opti/{trrrry}/datastruct.p"
# Load the data structure
with open(dts_p, "rb") as f:
    loaded_data_obj = pickle.load(f)

rec_path = "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_25/20241002PMCr2_17_05"
frame_mapping_file = "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_25/20241002PMCr2_17_05/MIR_Aligned/frame_mapping.json"
out_file = "/home/lq53/mir_repos/dappy_24_nov/byws_version/250116_wav_ffix_ang_pos/50_p/test_250120.h5"
# f"/home/lq53/mir_repos/dappy_24_nov/byws_version/250109_opti/{trrrry}/aligned_mir_walalala_filtered_data.h5"

export_aligned_data_to_h5(
    data_obj=loaded_data_obj, 
    rec_path=rec_path, 
    frame_mapping_file=frame_mapping_file, 
    out_file=out_file
)


> [0;32m/tmp/ipykernel_37692/786061616.py[0m(30)[0;36mexport_aligned_data_to_h5[0;34m()[0m
[0;32m     28 [0;31m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0;31m# 2. Read the frame mapping JSON[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m    [0;32mwith[0m [0mopen[0m[0;34m([0m[0mframe_mapping_file[0m[0;34m,[0m [0;34m"r"[0m[0;34m)[0m [0;32mas[0m [0mf[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mmap_data[0m [0;34m=[0m [0mjson[0m[0;34m.[0m[0mload[0m[0;34m([0m[0mf[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m    [0mmapped_frames[0m [0;34m=[0m [0mset[0m[0;34m([0m[0mmap_data[0m[0;34m[[0m[0;34m"mapped_sixcam_frame_indices"[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
          id    frame       AnimalID   Sex  Condition        date   time  \
312600  99.0  3126000  20241002PMCr2  male  miniscope  2024_10_25  17:05   
312601  99.0  3126010  20241002PMCr2  male  