In [1]:
import pandas as pd
import numpy as np
from fibermagic.IO.NeurophotometricsIO import extract_leds
import os
import pathlib
import copy
from detrend import detrend
from perievents import perievents



In [2]:
# Root folder of this project. DATA/ holds the raw recordings, RESULTS/ receives the output.
# Replace with your own project folder (keep the trailing '/').
_project_root = '/home/james/Massive/PROJECTDATA/NAcC gDA3m + rAdo1.3 FR20-PR/'

# Folder searched recursively for recording folders.
basepath = _project_root + 'DATA/'

# Folder where the peri-event tables and heatmaps are written.
outpath = _project_root + 'RESULTS/'

In [3]:
def read_keyfile(keyfile):
    """Load the experiment key table into a DataFrame.

    The table is expected to have the columns region, mouse, wave_len,
    protocol, condition and include.

    Args:
        keyfile (pathlib.Path): Location of the key file (e.g. ``region_to_mouse.csv``).

    Returns:
        pandas.DataFrame: The parsed experiment metadata.
    """
    return pd.read_csv(keyfile.absolute())

def read_mouse_log(logfile, id):
    """Load the behaviour log CSV and keep only the rows of one animal.

    The CSV's header is discarded and replaced positionally by the seven
    expected column names, so the file must have exactly seven columns
    (otherwise pandas raises ValueError on the column assignment).

    Args:
        logfile (pathlib.Path): Path to ``logs.csv``.
        id: Animal identifier compared against the ``animal.ID`` column.
            (The parameter name shadows the builtin but is kept for callers.)

    Returns:
        pandas.DataFrame: An independent copy of the rows belonging to ``id``,
        with the original row index preserved.
    """
    mouse_id = id
    logs = pd.read_csv(logfile.absolute())
    logs.columns = ['ComputerTimestamp', 'SystemTimestamp', 'animal.ID', 'Event',
                    'pi.time', 'pc.time', 'datetimestamp']
    # .copy() so later edits on the result cannot hit SettingWithCopy on a slice.
    return logs[logs['animal.ID'] == mouse_id].copy()




def get_photometry(filepath, start_timestamp, logs):
    """Load a photometry CSV, trim it to the behaviour session and split LED states.

    Args:
        filepath: Path to ``photometry.csv``.
        start_timestamp: System timestamp at which the session starts; rows with an
            earlier ``SystemTimestamp`` are dropped. If None, the first
            ``SystemTimestamp`` in ``logs`` is used instead.
        logs (pandas.DataFrame): Behaviour log (only consulted when
            ``start_timestamp`` is None).

    Returns:
        pandas.DataFrame: The output of ``extract_leds`` with NaN rows removed.
    """
    photometry = pd.read_csv(filepath)
    # Only the fibre columns R0 and G1 are renamed to the Region<N><colour> names;
    # any other fibre columns keep their raw names.
    photometry = photometry.rename(columns={'R0': 'Region0R', 'G1': 'Region1G',
                                            'Timestamp': 'SystemTimestamp'})
    if start_timestamp is None:
        start_timestamp = logs.iloc[0]['SystemTimestamp']
    photometry = photometry[photometry.SystemTimestamp >= start_timestamp]
    if 'Flags' in photometry.columns:  # legacy fix: Flags were renamed to LedState
        photometry = photometry.rename(columns={'Flags': 'LedState'})
    return extract_leds(photometry).dropna()

def convert_to_long(df):
    """Convert a wide photometry frame to long format: one row per (region, frame).

    Every column whose name contains 'Region' is treated as a fibre. Columns with
    'G' in the name are paired with the 470 nm channel, all others with 560 nm;
    the 410 nm (isosbestic) samples of the same column become the Reference.

    The input frame is NOT modified.

    Args:
        df (pandas.DataFrame): Needs ``wave_len``, ``SystemTimestamp`` and one or
            more ``Region*`` columns, with LED states interleaved row by row.

    Returns:
        pandas.DataFrame: Indexed by ``Region`` with columns FrameCounter (when the
        frame-counter index name survives), Channel, Timestamp, Signal and
        Reference; rows with any NaN are dropped.

    Raises:
        ValueError: If ``df`` has no ``Region*`` columns.
    """
    NPM_RED = 560
    NPM_GREEN = 470
    NPM_ISO = 410
    # Dirty hack to get around dropped frames until we find a better solution -
    # it makes about 0.16 s difference. Consecutive rows (one per LED state) share
    # a frame number. assign() returns a new frame instead of overwriting the
    # caller's FrameCounter column.
    n_states = len(df.wave_len.unique())
    df = df.assign(FrameCounter=np.arange(0, len(df)) // n_states).set_index('FrameCounter')

    regions = [column for column in df.columns if 'Region' in column]
    if not regions:
        raise ValueError("convert_to_long: no 'Region' columns found in the photometry data")

    is_iso = df.wave_len == NPM_ISO
    per_region = []
    for region in regions:
        channel = NPM_GREEN if 'G' in region else NPM_RED
        is_channel = df.wave_len == channel
        # Series are aligned on FrameCounter, pairing each signal sample with
        # the isosbestic sample of the same frame.
        per_region.append(pd.DataFrame(data={
            'Region': region,
            'Channel': channel,
            'Timestamp': df.SystemTimestamp[is_channel],
            'Signal': df[region][is_channel],
            'Reference': df[region][is_iso],
        }))
    return pd.concat(per_region).reset_index().set_index('Region').dropna()


def detrend_data(df):
    """Reshape photometry to long format and attach a detrended ``zdFF`` column.

    ``detrend`` is called with Timestamp/Signal/Reference/Channel column names and
    method="airPLS", smooth=10, steps=False, standardize=True.
    NOTE(review): presumably 'Channel' is the grouping column for detrend — confirm.
    """
    long_df = convert_to_long(df)
    zdff = detrend(
        long_df, "Timestamp", "Signal", "Reference", "Channel",
        steps=False, method="airPLS", smooth=10, standardize=True,
    )
    long_df["zdFF"] = zdff
    return long_df

def sync_behavior(logs, detrended, channels=(560, 470)):
    """Match every behaviour event to the nearest photometry frame of each channel.

    Args:
        logs (pandas.DataFrame): Behaviour log with ``SystemTimestamp``, ``Event``
            and ``animal.ID`` columns.
        detrended (pandas.DataFrame): Long-format photometry indexed by ``Region``
            with ``Channel``, ``FrameCounter`` and ``Timestamp`` columns.
        channels (tuple): Wavelengths to align, in output order. The default keeps
            the red (560) rows before the green (470) rows.

    Returns:
        tuple[pandas.DataFrame, pandas.DataFrame]: ``(dfsx, slogs)``, both indexed by
        (Region, Channel, FrameCounter). ``dfsx`` is a copy of ``detrended`` (it also
        carries the positional 'index' column created by the second reset_index,
        kept for compatibility). ``slogs`` has one row per event per channel.

    Neither input is modified.
    """
    # merge_asof requires both sides sorted on the key; stable sort keeps the
    # original order when the data is already sorted.
    events = (logs.rename(columns={'SystemTimestamp': 'Timestamp'})
                  .sort_values('Timestamp', kind='mergesort'))
    frames = detrended.reset_index()  # reset_index returns a new frame (no deepcopy needed)
    event_columns = ['Region', 'Channel', 'FrameCounter', 'Event', 'Timestamp', 'animal.ID']

    per_channel = []
    for channel in channels:
        channel_frames = frames[frames.Channel == channel].sort_values('Timestamp', kind='mergesort')
        matched = pd.merge_asof(events, channel_frames, on="Timestamp", direction="nearest")
        per_channel.append(matched[event_columns])

    slogs = pd.concat(per_channel, axis=0)
    slogs = slogs.reset_index(drop=True).set_index(['Region', 'Channel', 'FrameCounter'])
    dfsx = frames.reset_index().set_index(['Region', 'Channel', 'FrameCounter'])
    return dfsx, slogs


In [4]:
def get_experiment_dirs(start_path, required_files):
    """Find the bottom-level folders under ``start_path`` that contain every required file.

    Each such folder is treated as one recording session (in this notebook it must
    hold ``logs.csv``, ``photometry.csv`` and ``region_to_mouse.csv``). Folders that
    have subfolders are never returned.

    Args:
        start_path (str | os.PathLike): Top-level folder containing the experiment data.
        required_files (Iterable[str]): File names that must all be present in a folder.
            Any iterable works, including a one-shot generator.

    Returns:
        list[str]: Paths of the matching folders, sorted so the processing order
        does not depend on file-system enumeration order. Empty if nothing matches
        (including when ``start_path`` does not exist).
    """
    # Materialise once: a generator would otherwise be exhausted after the first folder.
    required = set(required_files)
    leaf_dirs = []
    for dirpath, dirnames, filenames in os.walk(start_path):
        if not dirnames and required.issubset(filenames):
            leaf_dirs.append(dirpath)
    return sorted(leaf_dirs)


def _single_value(values, field, mouse_id, hint=""):
    """Return the only distinct entry of ``values``; raise ValueError if there are zero or several."""
    distinct = list(values.unique())
    if len(distinct) == 1:
        return distinct[0]
    if len(distinct) > 1:
        raise ValueError(f"There are multiple {field}s for animal {mouse_id}. That should not be the case. \nPlease check that 'region_to_mouse.csv' is correct")
    raise ValueError(f"There was no {field} found for animal {mouse_id}. \nPlease check that 'region_to_mouse.csv' is correct{hint}")


def get_experiment_metadata(path):
    """Build one metadata record per included mouse from ``<path>/region_to_mouse.csv``.

    Only rows whose ``include`` column equals exactly "yes" are used.

    Args:
        path (str | os.PathLike): Recording folder containing ``region_to_mouse.csv``.

    Returns:
        list[dict]: One dict per mouse (in order of first appearance) with keys
        'id', 'path' (pathlib.Path), 'wavelengths', 'regions', 'protocol' and
        'condition'.

    Raises:
        ValueError: If a mouse has more than one protocol or condition, or none
            (the latter can only happen for a missing/NaN mouse id, whose
            equality mask matches no rows).
    """
    path = pathlib.Path(path)
    metadata = read_keyfile(path / 'region_to_mouse.csv')
    keep = metadata[metadata['include'] == "yes"]

    experiment = []
    for mouse_id in keep['mouse'].unique():
        rows = keep[keep['mouse'] == mouse_id]
        experiment.append({
            'id': mouse_id,
            'path': path,
            'wavelengths': list(rows['wave_len'].unique()),
            'regions': list(rows['region'].unique()),
            'protocol': _single_value(rows['protocol'], 'protocol', mouse_id),
            'condition': _single_value(
                rows['condition'], 'condition', mouse_id,
                hint=". If there was no condition, enter 'NaN' or None for condition"),
        })
    return experiment




In [5]:
# A recording folder is used only if it contains all three of these files.
required_files = ['logs.csv', 'photometry.csv', 'region_to_mouse.csv']
experiment_paths = get_experiment_dirs(start_path=basepath, required_files=required_files)

In [6]:
# One metadata dict per included mouse, across every recording folder.
# Built in a single pass; repeated `list + list` re-copied the accumulator each time.
all_experiments = [
    mouse
    for experiment_dir in experiment_paths
    for mouse in get_experiment_metadata(pathlib.Path(experiment_dir))
]

In [7]:
# Load each mouse's behaviour log and photometry, then detrend the photometry.
for mouse in all_experiments:
    session_dir = mouse['path']
    mouse['logs'] = read_mouse_log(session_dir / 'logs.csv', mouse['id'])
    # The first logged event of this mouse is passed as the session start.
    first_event_time = mouse['logs'].iloc[0]['SystemTimestamp']
    mouse['photometry'] = get_photometry(session_dir / 'photometry.csv', first_event_time, mouse['logs'])
    mouse['detrended'] = detrend_data(mouse['photometry'])


In [None]:
# Peri-event settings passed to `perievents` (uses the top-level import).
# NOTE(review): 'FD' presumably marks reward delivery (plots label trials "Reward #") — confirm.
PERIEVENT_EVENT = 'FD'
PERIEVENT_WINDOW = 20
PERIEVENT_FREQUENCY = 10

for experiment in all_experiments:
    print(experiment['id'])
    print(experiment['path'])
    dfsx, slogs = sync_behavior(experiment['logs'], experiment['detrended'])
    experiment['dfsx'] = dfsx
    experiment['slogs'] = slogs
    # Best-effort: one failing mouse must not abort the batch. The error text is
    # stored instead of a DataFrame so the saving/plotting cells skip this mouse.
    try:
        experiment['perievents'] = perievents(
            dfsx, slogs[slogs.Event == PERIEVENT_EVENT],
            window=PERIEVENT_WINDOW, frequency=PERIEVENT_FREQUENCY,
        )
    except Exception as e:
        message = f"***ERROR!!!: {e}"
        print(message)
        experiment['perievents'] = message


In [9]:
from pathlib import Path
output_path = Path(outpath)
# to_csv / write_image do not create missing folders, so create the results folder up front.
output_path.mkdir(parents=True, exist_ok=True)

def save_perievents(experiment, outpath):
    """Write one mouse's peri-event table to ``<outpath>/<id>-<protocol>-<condition>.csv``.

    Does nothing when ``experiment['perievents']`` is not a DataFrame (a failed
    perievent run stores an error string there instead).

    Args:
        experiment (dict): Record with 'id', 'protocol', 'condition' and 'perievents'.
        outpath (str | os.PathLike): Existing output folder; a trailing separator
            is optional because the path is joined with pathlib.
    """
    perievent_table = experiment['perievents']
    if not isinstance(perievent_table, pd.DataFrame):
        return
    filename = f"{experiment['id']}-{experiment['protocol']}-{experiment['condition']}.csv"
    perievent_table.to_csv(pathlib.Path(outpath) / filename)

# Persist every successfully computed peri-event table.
for mouse_experiment in all_experiments:
    save_perievents(experiment=mouse_experiment, outpath=outpath)

In [None]:
import plotly.express as px
import kaleido


def plot_perievents(experiment, outpath, red_region='Region0R', green_region='Region1G',
                    green_title="iSPN activity"):
    """Show peri-event heatmaps for one mouse and save the combined figure as a JPEG.

    Displays one heatmap per region, then a stacked two-row figure that is also
    written to ``<outpath>/<id>-<protocol>-<condition>_heatmap.jpg``. Does nothing
    when ``experiment['perievents']`` is not a DataFrame (failed runs store an
    error string there).

    Args:
        experiment (dict): Record with 'id', 'protocol', 'condition' and 'perievents'.
            ``perievents`` is selected by region with ``.loc`` and must provide
            'Timestamp', 'Trial' and 'zdFF' after ``reset_index()``.
        outpath (str | os.PathLike): Output folder; a trailing separator is optional.
        red_region (str): Region label plotted in the top panel.
        green_region (str): Region label plotted in the bottom panel.
        green_title (str): Title of the bottom panel.
            NOTE(review): "iSPN activity" looks inherited from another experiment
            (this dataset is gDA3m/rAdo1.3) — confirm the label.
    """
    from plotly.subplots import make_subplots

    peri = experiment['perievents']
    if not isinstance(peri, pd.DataFrame):
        return

    def _heatmap(region):
        # Square markers on a Trial-vs-time scatter render as a heatmap.
        fig = px.scatter(peri.loc[region].reset_index(), x='Timestamp', y='Trial', color='zdFF',
                         range_color=(-5, 5), color_continuous_scale=['blue', 'grey', 'red'],
                         height=300)
        for scatter in fig.data:
            scatter.marker.symbol = 'square'
        return fig

    figR = _heatmap(red_region)
    figR.update_yaxes(autorange="reversed", title_text='Reward #',
                      title_font={'size': 20}, tickfont={'size': 18})
    figR.update_xaxes(title_text=None, showticklabels=False)
    # Title shows protocol and condition (it previously repeated the condition twice).
    figR.update_layout(title={'text': f"{experiment['protocol']}: {experiment['condition']}", 'x': 0.5})
    figR.show()

    figG = _heatmap(green_region)
    figG.update_yaxes(autorange="reversed", title_font={'size': 20}, tickfont={'size': 18})
    figG.update_xaxes(title_text='Time (s)', title_font={'size': 20}, tickfont={'size': 18})
    figG.update_layout(title={'text': green_title, 'x': 0.5})
    figG.show()

    fig = make_subplots(rows=2, cols=1, subplot_titles=(experiment['condition'], green_title),
                        vertical_spacing=0.1, shared_xaxes=True, shared_yaxes=True)
    for trace in figR.data:
        fig.add_trace(trace, row=1, col=1)
    for trace in figG.data:
        fig.add_trace(trace, row=2, col=1)
    common_colorscale = [[0, 'blue'], [0.5, 'grey'], [1, 'red']]
    # NOTE(review): the combined figure uses +/-4 while the single panels use +/-5 — confirm intent.
    coloraxis_range = [-4, 4]
    fig.update_layout(coloraxis=dict(colorscale=common_colorscale, colorbar_title='Z dF/F',
                                     cmin=coloraxis_range[0], cmax=coloraxis_range[1]), height=500)
    fig.update_yaxes(autorange='reversed', row=1, col=1, title_text='Reward #',
                     title_font={'size': 16}, tickfont={'size': 14}).update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_yaxes(autorange='reversed', row=2, col=1, title_text='Reward #',
                     title_font={'size': 16}, tickfont={'size': 14}).update_xaxes(
        title="Time (s) from reward delivery", showticklabels=True,
        title_font={'size': 16}, tickfont={'size': 14}, row=2, col=1)
    fig.show()
    image_name = f"{experiment['id']}-{experiment['protocol']}-{experiment['condition']}_heatmap.jpg"
    fig.write_image(str(pathlib.Path(outpath) / image_name))

# Render and save heatmaps for every mouse (failed perievent runs are skipped inside).
for mouse_experiment in all_experiments:
    plot_perievents(experiment=mouse_experiment, outpath=outpath)
