# Generates and Checks the Trial Info Databases for Each Mouse

In [6]:
from scipy.io import loadmat
import pandas as pd
import numpy as np
import h5py
import matplotlib.pyplot as plt
import os

In [7]:
def behavior_path(mouse):
    """Build the path of the combined perceptual-behavior .mat file for `mouse`."""
    backup_root = '//tsclient/T7 Shield2/BehaviorDataBackup/VoltageMice/'
    pieces = (backup_root, mouse, "/PerceptualData_", mouse, "_all.mat")
    return "".join(pieces)

def movie_path(mouse, date, file):
    """Build the path of the unmixed dF/F movie (.h5) for one mouse, date and recording file."""
    # The date is stored as YYMMDD; the folder on disk is named 20YYMMDD.
    day_folder = "20" + str(date)
    segments = ("N:/GEVI_Wave/Analysis/Visual", mouse, day_folder, file, 'cG_unmixed_dFF.h5')
    return "/".join(segments)

## Generating Dataframe

In [8]:
def flatten_cell(cell):
    """Unwrap one nesting level: first element of a list/array, None if it is empty, else `cell` itself."""
    # Anything that is not a list or ndarray is already a plain value.
    if not isinstance(cell, (list, np.ndarray)):
        return cell
    if len(cell) > 0:
        return cell[0]
    return None

def perceptual_category(df, data_behavior):
    """
    Bin each trial's contrast into perceptual category 1, 2 or 3.

    Args:
    - df (pd.DataFrame): Trial table with a "Contrast" column; modified in place.
    - data_behavior (dict-like): Must provide "ConCrit", a table whose first row is
      skipped and whose first two columns are read as the lower and upper contrast
      bound of each category.

    Returns:
    - pd.DataFrame: The same `df` with a "PerceptualCat" column inserted at position 5.
    """
    criteria = pd.DataFrame(data_behavior["ConCrit"]).iloc[1:]  # drop the first (header) row
    lower_bounds = criteria.iloc[:, 0].to_list()
    upper_bounds = criteria.iloc[:, 1].to_list()

    # Bin edges: the lowest lower bound followed by every upper bound.
    edges = [lower_bounds[0], *upper_bounds]

    # Right-closed bins; include_lowest keeps the very first edge inside category 1.
    df.insert(
        5,
        "PerceptualCat",
        pd.cut(df["Contrast"], bins=edges, labels=[1, 2, 3], right=True, include_lowest=True),
    )
    return df

def categorize_trial(row):
    """
    Label a single trial from its behavioral flags and perceptual category.

    `row` is indexed by "ReactTime", "Enticed?", "Rewarded?", "Consume?",
    "Attrition?" and "PerceptualCat". Returns one of the labelled outcomes
    "Error (0)" ... "HC No Report (9)", or "Uncategorized" when no rule matches.
    """
    # Error trials: reaction faster than 0.25, enticed, or rewarded but not consumed.
    if row["ReactTime"] < 0.25:
        return "Error (0)"
    if row["Enticed?"] == 1:
        return "Error (0)"
    if row["Rewarded?"] == 1 and row["Consume?"] == 0:
        return "Error (0)"

    # (Rewarded?, Attrition?) -> labels for perceptual categories 1, 2, 3 in that order.
    outcome_labels = (
        ((1, 0), ("False Alarm (1)", "MC Hit (2)", "HC Hit (3)")),
        ((0, 0), ("Correct Rejection (4)", "MC Miss (5)", "Incorrect Reject (6)")),
        ((0, 1), ("LC No Report (7)", "MC No Report (8)", "HC No Report (9)")),
    )
    for (rewarded, attrition), labels in outcome_labels:
        if row["Rewarded?"] == rewarded and row["Attrition?"] == attrition:
            for category, label in enumerate(labels, start=1):
                if row["PerceptualCat"] == category:
                    return label

    return "Uncategorized"

def trial_category(df, print_results=False):
    """
    Add a "TrialType" column by labelling every trial with `categorize_trial`.

    Args:
    - df (pd.DataFrame): Trial table; modified in place.
    - print_results (bool, optional): Print the number of trials per type. Default is False.

    Returns:
    - pd.DataFrame: The same `df` with the "TrialType" column set.
    """
    # A rewarded trial is never counted as attrition: clear the flag before labelling.
    rewarded_attrition = (df['Attrition?'] == 1) & (df['Rewarded?'] == 1)
    df.loc[rewarded_attrition, 'Attrition?'] = 0

    # Label row by row.
    df["TrialType"] = df.apply(categorize_trial, axis=1)

    if not print_results:
        return df

    print("\nTrial Type Summary:")
    for category, count in df["TrialType"].value_counts().items():
        print(f"{category}: {count} trials")

    return df

def map_recording_files(df):
    """
    Adds a "File" column (e.g. 'meas01') with the movie file name of each trial.

    For every unique (AnimalCode, Date) pair the session file `Volt_<date>.mat` is
    loaded and the file numbers are read from column 5 (index 4) of its 'RESULTS'
    table, skipping the first row. A trial's 1-based "Recording" number selects the
    entry in that list; the file name is 'meas0' followed by that file number minus one.

    Args:
    - df (pd.DataFrame): Trial table with "AnimalCode", "Date" and "Recording"; modified in place.

    Returns:
    - pd.DataFrame: The same `df` with the "File" column set. "File" is None for
      trials whose session file could not be loaded, whose recording number is
      missing or outside 1..len(file list), or whose file number is not numeric.
    """
    base_path = "//tsclient/T7 Shield2/BehaviorDataBackup/VoltageMice/"
    recording_file_map = {}

    # Load each session file once per unique (mouse, date) pair
    unique_pairs = df[["AnimalCode", "Date"]].drop_duplicates()

    for mouse, date in unique_pairs.itertuples(index=False, name=None):
        mat_path = f"{base_path}{mouse}/Volt_{date}.mat"

        try:
            mat_data = loadmat(mat_path, struct_as_record=False, squeeze_me=True)
            # (mouse, date) -> ordered file numbers
            recording_file_map[(mouse, date)] = list(mat_data["RESULTS"][1:, 4])
        except Exception as e:
            # Best effort: report the problem and leave this session unmapped
            print(f"Error loading {mat_path}: {e}")
            recording_file_map[(mouse, date)] = None

    def file_name_for(mouse, date, recording):
        """Returns the 'meas0#' name for one trial, or None if it cannot be resolved."""
        file_list = recording_file_map.get((mouse, date))
        if not file_list or pd.isna(recording):
            return None
        # Recording numbers are 1-based; reject 0/negatives instead of wrapping around
        if not 1 <= recording <= len(file_list):
            return None
        try:
            file_number = int(file_list[int(recording) - 1])
        except (TypeError, ValueError):
            return None
        return 'meas0' + str(file_number - 1)

    # Build the name per row so unresolved rows stay None rather than turning the
    # whole column into floats ('meas01.0', 'meas0nan') or raising on None - 1
    df["File"] = [
        file_name_for(mouse, date, recording)
        for mouse, date, recording in zip(df["AnimalCode"], df["Date"], df["Recording"])
    ]

    return df

# Add column for trialID
def trial_ID(df, mouse):
    """Add a "TrialID" column of the form Visual/<mouse>/20<Date>/<File>/trial<NNN> and return `df`."""
    session_prefix = 'Visual/' + str(mouse) + '/20'
    date_text = df['Date'].astype(str)
    trial_text = df['Trial'].astype(str).str.zfill(3)  # zero-pad trial numbers to three digits
    df['TrialID'] = session_prefix + date_text + '/' + df['File'] + '/trial' + trial_text
    return df

def extract_specs(filepath):
    """
    Read frame rate, time origin and frame count from a movie HDF5 file.

    Returns a tuple (fps, timeorigin, movie_length): `fps` and `timeorigin` are the
    [0][0] entries of the "specs/fps" and "specs/timeorigin" datasets, and
    `movie_length` is the size of the first axis of the "mov" dataset.
    """
    with h5py.File(filepath, 'r') as handle:
        spec_group = handle["specs"]
        frame_rate = spec_group["fps"][()][0][0]
        origin = spec_group["timeorigin"][()][0][0]
        frame_count = handle["mov"].shape[0]
    return frame_rate, origin, frame_count

def compute_delays_and_bfm_times(df):
    """
    Computes delays and BFM times for each unique (Date, AnimalCode, Recording)
    combination and adds the columns "Delay", "BFMTime" and "ValidTrial?" to df.

    - "Delay" is the mean of the recording's 'FrameAlignmentInfo' entry in the
      'MasterN' table of `Volt_<date>_processed.mat`.
    - "BFMTime" is Time - Delay - int(timeorigin) / fps, using the movie specs
      returned by `extract_specs`.
    - "ValidTrial?" is True when 0 <= BFMTime <= movie_length / fps.

    Args:
    - df (pd.DataFrame): Trial table with "Date", "AnimalCode", "Recording", "File"
      and "Time"; modified in place.

    Returns:
    - pd.DataFrame: The same `df` with the three columns added.

    Raises whatever the underlying file reads raise (e.g. a missing .mat or movie file).
    """

    basePath = '//tsclient/T7 Shield2/BehaviorDataBackup/VoltageMice/'
    # One processed .mat per (date, mouse) session: parse it once and reuse it for
    # every recording of that session instead of reloading it per recording
    alignment_cache = {}
    # Per-recording results, keyed by (date, mouse, recording)
    delay_cache = {}
    specs_cache = {}

    # Find unique (Date, AnimalCode, Recording, File) combinations
    unique_pairs = df[["Date", "AnimalCode", "Recording", "File"]].drop_duplicates()

    for _, row in unique_pairs.iterrows():
        date = row["Date"]
        mouse = row["AnimalCode"]
        recording = row["Recording"]
        file = row["File"]
        key = (date, mouse, recording)

        # Compute delay
        session = (date, mouse)
        if session not in alignment_cache:
            path2 = f"{basePath}{mouse}/Volt_{date}_processed.mat"
            data2 = loadmat(path2)
            # Unwrap the one-element arrays loadmat puts around each cell
            df2 = pd.DataFrame(data2["MasterN"]).map(lambda x: x[0] if isinstance(x, np.ndarray) and x.size > 0 else x)
            df2.columns = df2.iloc[0]  # first row holds the column names
            df2 = df2[1:]
            alignment_cache[session] = df2['FrameAlignmentInfo']
        # The header row was dropped without resetting the index, so the remaining
        # labels start at 1 and the recording number is used directly as the label
        delayEstimates = alignment_cache[session][recording][0][0]
        delay_cache[key] = np.mean(delayEstimates)

        # Extract movie specs, unwrapped to plain numbers once per recording
        fps, timeorigin, movie_length = extract_specs(movie_path(mouse, date, file))
        specs_cache[key] = (fps.item(), int(timeorigin.item()), movie_length)

    # Map delays and specs back to every trial in a single pass
    delays = []
    bfm_times = []
    valid_flags = []
    for date, mouse, recording, time in zip(df["Date"], df["AnimalCode"], df["Recording"], df["Time"]):
        key = (date, mouse, recording)
        fps, timeorigin, movie_length = specs_cache[key]
        delay = delay_cache[key]
        bfm_time = time - delay - timeorigin / fps

        delays.append(delay)
        bfm_times.append(bfm_time)
        # A trial is valid when its stimulus onset falls inside the movie
        valid_flags.append(bool(0 <= bfm_time <= movie_length / fps))

    df["Delay"] = delays
    df["BFMTime"] = bfm_times
    df["ValidTrial?"] = valid_flags

    return df

def filter_vdt_trials(df, xlsx_path, sheet_name="Sheet1"):
    """
    Filters trials from df to include only those where the corresponding recording has Task = 'VDT'.

    The recording log is matched to the trials on (AnimalCode, Date, File). Recordings
    that are missing from the log, and recordings whose Task is not 'VDT', are printed
    and their trials are dropped. Because a missing recording has no Task, it is listed
    in both printed tables.

    Args:
    - df (pd.DataFrame): The trial dataframe containing 'Date', 'File', and 'AnimalCode'. Not modified.
    - xlsx_path (str): Path to the xlsx file. It is read for the columns 'Animal',
      'Date' (YYYYMMDD), 'File' and 'Task'.
    - sheet_name (str, optional): The name of the sheet containing recording metadata. Default is "Sheet1".

    Returns:
    - pd.DataFrame: The filtered dataframe containing only trials with Task = 'VDT',
      with the remaining log columns merged in and the bookkeeping columns removed.
      'Date' is returned as a string.
    """

    recording_info = pd.read_excel(xlsx_path, sheet_name=sheet_name)

    df2 = df.copy()

    # Bring both sides to the same key format before merging
    df2["Date"] = df2["Date"].astype(str)
    recording_info["Date"] = recording_info["Date"].astype(str).str[2:]  # Convert YYYYMMDD → YYMMDD
    recording_info = recording_info.rename(columns={"Animal": "AnimalCode"})
    # The log may omit the "mjr" suffix that the trial table's animal codes carry
    recording_info["AnimalCode"] = recording_info["AnimalCode"].apply(lambda x: x if x.endswith("mjr") else x + "mjr")

    # Merge df with the xlsx data to bring in the Task column
    recording_keys = ["AnimalCode", "Date", "File"]
    merged_df = df2.merge(recording_info, on=recording_keys, how="left")

    # Identify recordings where Task is not 'VDT' (this includes recordings without a Task)
    flagged_recordings = merged_df.loc[merged_df["Task"] != "VDT", recording_keys].drop_duplicates()

    # Identify recordings missing from the xlsx file
    missing_recordings = merged_df.loc[merged_df["Task"].isna(), recording_keys].drop_duplicates()

    # Print missing recordings
    if not missing_recordings.empty:
        print()
        print("The following recordings were not found in the xlsx file:")
        print(missing_recordings.to_string(index=False))

    # Print flagged recordings
    if not flagged_recordings.empty:
        print()
        print("The following recordings were flagged for removal (Task != 'VDT'):")
        print(flagged_recordings.to_string(index=False))

    # Keep only trials where Task is 'VDT'
    merged_df = merged_df[merged_df["Task"] == "VDT"].copy()
    # Drop log bookkeeping columns; tolerate logs that lack some of them
    merged_df.drop(
        columns=['Record#', 'identifier', 'Task', 'Active', 'VDT Behavior Quality.1'],
        inplace=True,
        errors="ignore",
    )

    return merged_df

def gen_trial_info(mouse, days_to_omit=None):
    """
    Builds the complete trial-information table for one mouse.

    Loads the behavioral .mat file, then in order: assigns perceptual categories,
    labels trial types (printing a summary), maps recordings to movie files, builds
    trial IDs, computes delays / BFM times / validity, and keeps only trials whose
    recording is logged as a VDT task.

    Args:
    - mouse (str): Identifier for the mouse.
    - days_to_omit (list of ints): Days for which to remove all trials from dataframe
    Returns:
    - pd.DataFrame: Processed trial information with relevant columns.
    """
    data_behavior = loadmat(behavior_path(mouse))
    # Cells arrive doubly wrapped, so unwrap twice.
    table = pd.DataFrame(data_behavior["TrialInfo"]).map(flatten_cell).map(flatten_cell)

    # The first row holds the column names: promote it to the header and keep the rest.
    table.columns = table.iloc[0]
    df = table[1:]

    # Drop whole days on request (e.g. days with missing files or invalid data).
    if days_to_omit:
        on_kept_day = ~df["Date"].isin(days_to_omit)
        df = df[on_kept_day]

    # Restart the index at 0 and use a column name without the '#'.
    df = df.reset_index(drop=True).rename(columns={"Recording#": "Recording"})

    df = perceptual_category(df, data_behavior)
    df = trial_category(df, print_results=True)
    df = map_recording_files(df)
    df = trial_ID(df, mouse)
    df = compute_delays_and_bfm_times(df)

    xlsx_path = "//tsclient/T7 Shield2/BehaviorDataBackup/VoltageMice/AllVoltage/RecordingLogsVoltage.xlsx"
    df = filter_vdt_trials(df, xlsx_path)

    # Final column order; names that are absent from df are skipped.
    preferred_order = ['TrialID', 'TrialType', 'AnimalCode', 'Date', 'Recording', 'File', 'Time', 'BFMTime', 'ValidTrial?', 
                       'Duration', 'Contrast', 'PerceptualCat', 'ReactTime', 'Rewarded?', 'Enticed?', 'Seen?', 'Consume?', 
                       'Attrition?', 'EngagementScore', 'EngagementScore_S20', 'VDT Behavior Quality']
    present = [name for name in preferred_order if name in df.columns]

    return df[present]

## Checking

In [9]:
def generate_ttl_trace(df, fps, total_frames, date, file):
    """
    Builds the expected per-frame stimulus trace (1 = on, 0 = off) from the trial table.

    Args:
    - df (pd.DataFrame): Trial table with 'Date', 'File', 'ValidTrial?', 'BFMTime' and 'Duration'.
    - fps (float): Frames per second of the movie.
    - total_frames (int): Length of the returned trace.
    - date (int or str): Date of the recording (e.g., 240506); compared as a string.
    - file (str): file name - formatted as meas0#

    Returns:
    - np.ndarray: Integer trace of length `total_frames`, 1 while a valid trial's stimulus is on.
    """
    # Valid trials of the requested date and file only.
    selected = (
        (df["Date"].astype(str) == str(date))
        & (df["File"] == file)
        & (df["ValidTrial?"] == 1)
    )
    trials = df.loc[selected, ["BFMTime", "Duration"]]

    ttl_trace = np.zeros(total_frames, dtype=int)

    for onset_time, duration in zip(trials["BFMTime"], trials["Duration"]):
        # Seconds -> frame index; int() truncates toward zero.
        first_frame = int(onset_time * fps)
        stop_frame = int((onset_time + duration) * fps)
        ttl_trace[first_frame:stop_frame] = 1

    return ttl_trace

def get_ttl_trace(timestamps_table, timestamps_table_names, timeorigin, timebinning=1):
    """
    Pulls the 'behavior_ttl' signal out of a timestamps table.

    Args:
    - timestamps_table (np.ndarray): 2-D table with one row per named signal.
    - timestamps_table_names (list of str): Row names; must contain "behavior_ttl".
    - timeorigin (int): First sample to keep (converted with int()).
    - timebinning (int, optional): Number of consecutive samples averaged into one (default: 1).

    Returns:
    - np.ndarray: The raw signal from `timeorigin` on when `timebinning` is not above 1;
      otherwise the rounded mean of each complete bin (trailing samples that do not
      fill a bin are dropped).
    """
    ttl_row = timestamps_table_names.index("behavior_ttl")
    raw_signal = timestamps_table[ttl_row, int(timeorigin):]

    # No binning requested: hand back the raw samples.
    if not timebinning > 1:
        return raw_signal

    # Cut off the tail so the length is a multiple of the bin size.
    usable_length = len(raw_signal) - (len(raw_signal) % timebinning)
    whole_bins = raw_signal[:usable_length].reshape(-1, timebinning)

    bin_means = np.mean(whole_bins, axis=1)
    return np.round(bin_means)

def check_ttl_alignment_file(df, mouse, date, file, tolerance=5, plot_all=False, plot_if_misaligned=True):
    """
    Checks if the TTL traces from the movie and behavioral data are aligned for one file.

    The movie trace is read from the movie's specs (`get_ttl_trace`); the behavioral
    trace is rebuilt from the valid trials in `df` (`generate_ttl_trace`). Stimulus
    onsets and offsets of both traces are then compared frame by frame.

    Args:
    - df (pd.DataFrame): Trial table used to build the behavioral TTL trace.
    - mouse (str): Animal code, used to locate the movie file.
    - date (int or str): Date of the recording (e.g., 240506).
    - file (str): file name - formatted as meas0#
    - tolerance (int, optional): Allowed frame difference for alignment. Default is 5.
    - plot_all (bool, optional): Plot both traces regardless of the result. Default is False.
    - plot_if_misaligned (bool, optional): Plot both traces when the result is 'Misaligned'. Default is True.

    Returns:
    - str: 'Aligned' if all onset/offset differences are within tolerance,
      'Misaligned' if any difference exceeds it, or 'Mismatch Number of Stimuli' if
      the two traces do not contain the same number of onsets/offsets (including the
      case where the behavioral trace contains none). Nothing is plotted for a mismatch.
    """
    result = 'Aligned'

    path_movie = movie_path(mouse, date, file)

    with h5py.File(path_movie, 'r') as mov_file:
        print(f"Loading {path_movie}")
        specs = mov_file["specs"]
        # NOTE(review): fps/timeorigin are indexed one level deeper here than in
        # extract_specs, presumably to reach the scalar itself - confirm.
        fps = specs["fps"][()][0][0][0]
        timeorigin = specs["timeorigin"][()][0][0][0]
        timebinning = specs["timebinning"][()][0][0]
        timestamps_table = specs["extra_specs"]["timestamps_table"][()].squeeze()
        timestamps_table_names = specs["extra_specs"]["timestamps_table_names"][()]
        # The names are stored as bytes; join them and split on ';' to get a list of str
        timestamps_table_names = b''.join(timestamps_table_names.flatten()).decode("utf-8").split(';')

    ttl_movie = get_ttl_trace(timestamps_table, timestamps_table_names, timeorigin, timebinning)
    ttl_df = generate_ttl_trace(df, fps, len(ttl_movie), date, file)

    # Find stimulus ONSET (0 -> 1) and OFFSET (1 -> 0) frames for both traces
    onsets_movie = np.where(np.diff(ttl_movie) == 1)[0] 
    offsets_movie = np.where(np.diff(ttl_movie) == -1)[0] 
    onsets_df = np.where(np.diff(ttl_df) == 1)[0] 
    offsets_df = np.where(np.diff(ttl_df) == -1)[0] 

    # Without any behavioral onset/offset (e.g. no valid trials for this file) there
    # is nothing to compare against and the windowing below would fail on [0]
    if len(onsets_df) == 0 or len(offsets_df) == 0:
        print(f"ERROR for {mouse} {date} {file}: Mismatch in number of stimuli: {len(onsets_movie)} in movie, {len(onsets_df)} in df")
        return 'Mismatch Number of Stimuli'

    # Remove movie stimuli from before or after the behavior TTL trace
    onsets_movie = onsets_movie[
        (onsets_movie >= onsets_df[0] - 10) & (onsets_movie <= onsets_df[-1] + 10)]
    offsets_movie = offsets_movie[
        (offsets_movie >= offsets_df[0] - 10) & (offsets_movie <= offsets_df[-1] + 10)]

    # Ensure equal number of onsets and offsets in both traces
    if len(onsets_movie) != len(onsets_df) or len(offsets_movie) != len(offsets_df):
        print(f"ERROR for {mouse} {date} {file}: Mismatch in number of stimuli: {len(onsets_movie)} in movie, {len(onsets_df)} in df")
        result = 'Mismatch Number of Stimuli'
        return result

    # Compute frame differences for onsets and offsets
    onset_diff = np.abs(onsets_movie - onsets_df)
    offset_diff = np.abs(offsets_movie - offsets_df)

    # Check if all differences are within tolerance
    if (onset_diff > tolerance).any() or (offset_diff > tolerance).any():
        print(f"ERROR for {mouse} {date} {file}: Some frame differences exceed tolerance of {tolerance} frames.")
        print(f"Onset differences: {onset_diff}")
        print(f"Offset differences: {offset_diff}")
        result = 'Misaligned'

    if plot_all or (plot_if_misaligned and result == 'Misaligned'):
        # Overview: at most the first 10000 frames of both traces
        frames_shown = min(len(ttl_movie), len(ttl_df), 10000)
        plt.plot(ttl_movie[0:frames_shown])
        plt.plot(ttl_df[0:frames_shown], ls='--')
        plt.legend(["Movie", "Behavioral Data"], loc=1)
        plt.xlabel("Frame")
        plt.show()

        # Zoom: 10 frames either side of the first movie offset. Clamp the start so
        # an early offset does not produce a negative (wrapping) slice start.
        # NOTE(review): the axis label says "onset" but the window is centered on
        # the first offset - confirm which one is intended.
        zoom_start = max(offsets_movie[0] - 10, 0)
        zoom_stop = offsets_movie[0] + 10
        plt.plot(ttl_movie[zoom_start:zoom_stop])
        plt.plot(ttl_df[zoom_start:zoom_stop], ls='--')
        plt.legend(["Movie", "Behavioral Data"], loc=1)
        plt.xlabel("Frame (centered on first onset in movie)")
        plt.show()

    return result

def check_ttl_alignment_all(df, tolerance=10, plot_all=False, plot_if_misaligned=False, days_to_skip_testing=None):
    """
    Checks TTL alignment for all unique recordings in the DataFrame.

    Every recording whose check returns 'Misaligned' or 'Mismatch Number of Stimuli'
    gets "ValidTrial?" set to False for all of its trials (df is modified in place),
    and the affected recordings are listed at the end.

    Args:
    - df (pd.DataFrame): DataFrame containing the data.
    - tolerance (int, optional): Allowed frame difference for alignment. Default is 10.
    - plot_all (bool, optional): Whether to plot all traces. Default is False.
    - plot_if_misaligned (bool, optional): Whether to plot traces if misaligned. Default is False.
    - days_to_skip_testing (list, optional): Dates whose recordings are not checked.
      Dates are compared as strings, so ints and strings both match. Default is None.

    Returns:
    - None
    """
    misaligned_recordings = []

    # Extract unique combinations of 'AnimalCode', 'Date', and 'File'
    unique_recordings = df[['AnimalCode', 'Date', 'File']].drop_duplicates()

    # Don't test recordings in the days to skip testing
    print(days_to_skip_testing)
    if days_to_skip_testing:
        # Compare as text: 'Date' may hold strings while the caller passes ints
        skipped_days = [str(day) for day in days_to_skip_testing]
        unique_recordings = unique_recordings[~unique_recordings["Date"].astype(str).isin(skipped_days)]
    print(unique_recordings)

    # Iterate over each unique recording
    for _, row in unique_recordings.iterrows():
        mouse = row['AnimalCode']
        date = row['Date']
        file = row['File']

        # Honor the caller's plotting choices instead of hard-coding them
        result = check_ttl_alignment_file(
            df, mouse, date, file, tolerance,
            plot_all=plot_all, plot_if_misaligned=plot_if_misaligned,
        )

        if result in ('Misaligned', 'Mismatch Number of Stimuli'):
            misaligned_recordings.append((mouse, date, file))
            df.loc[(df["AnimalCode"] == mouse) & (df["Date"] == date) & (df["File"] == file), "ValidTrial?"] = False

    # Summary of results
    if misaligned_recordings:
        print("\nMisaligned Recordings:")
        for mouse, date, file in misaligned_recordings:
            print(f"Mouse: {mouse}, Date: {date}, File: {file}")
    else:
        print("\nAll recordings are aligned.")

In [10]:
def check_files_exist(df):
    """
    Verifies that a movie file exists for every unique (AnimalCode, Date, File) in df.

    Each missing path is printed; "All files exist" is printed when none is missing.

    Args:
    - df (pd.DataFrame): DataFrame containing trial information.

    Returns:
    - bool: True if all files exist, False if any are missing.
    """
    recordings = df[["AnimalCode", "Date", 'File']].drop_duplicates()
    missing_count = 0

    for mouse, date, file in recordings.itertuples(index=False, name=None):
        candidate = movie_path(mouse, date, file)
        if os.path.exists(candidate):
            continue
        print("Missing File: ", candidate)
        missing_count += 1

    all_files_exist = missing_count == 0
    if all_files_exist:
        print("All files exist")

    return all_files_exist

## Putting it Together: Generating, Checking and Saving Trial Info File

In [11]:
# Build the trial table for one mouse, validate it against the movies, then save it.
mouse = 'cmm002mjr'
df = gen_trial_info(mouse)
print('Generated Dataframe, Now Checking')
# Marks trials of misaligned recordings as invalid in df before it is written out
check_ttl_alignment_all(df)
check_files_exist(df)
# Make sure the output folder exists so the save cannot fail after the long run above
output_dir = 'trial_info'
os.makedirs(output_dir, exist_ok=True)
df.to_csv(f'{output_dir}/TrialInfo_{mouse}.csv', index=False)


Trial Type Summary:
HC No Report (9): 510 trials
MC No Report (8): 420 trials
LC No Report (7): 299 trials
HC Hit (3): 223 trials
Error (0): 181 trials
MC Miss (5): 118 trials
MC Hit (2): 99 trials
Correct Rejection (4): 94 trials
Incorrect Reject (6): 55 trials
False Alarm (1): 11 trials

The following recordings were not found in the xlsx file:
AnimalCode   Date   File
 cmm002mjr 231214 meas03

The following recordings were flagged for removal (Task != 'VDT'):
AnimalCode   Date   File
 cmm002mjr 231208 meas00
 cmm002mjr 231214 meas03
 cmm002mjr 231215 meas02
 cmm002mjr 240502 meas01
 cmm002mjr 240502 meas02
Generated Dataframe, Now Checking
None
     AnimalCode    Date    File
249   cmm002mjr  231212  meas00
296   cmm002mjr  231212  meas01
398   cmm002mjr  231212  meas02
605   cmm002mjr  231213  meas01
664   cmm002mjr  231213  meas02
813   cmm002mjr  231213  meas03
1011  cmm002mjr  231214  meas00
1074  cmm002mjr  231214  meas01
1420  cmm002mjr  231215  meas03
Loading N:/GEVI_Wave/An