In [None]:
import numpy as np
import glob
import os
import pandas as pd
from skimage import io
import matplotlib.pyplot as plt
import mplcursors
%matplotlib notebook

In [None]:
# Root directories. Imaging TIFFs are looked up under
# <imaging_path>/<animal>/<date>/<experiment_number>/ and the session table
# under <csv_path>/<animal>.csv.
imaging_path = "/home1/bhalla/hrishikeshn/Imaging_Sorted_for_Analysis/"
csv_path = "/home1/bhalla/hrishikeshn/expt_csv/"
output_path = ""  # "" means "use the current directory"
animals = ["G394"]  # Change this to the animals that need to be analysed
# upi values of the sessions to analyse (range(1, 3) selects upi 1 and 2).
# Use [] to analyse all sessions.
upi_list = range(1, 3)

if output_path:
    # makedirs also creates missing parent directories and does not fail if the
    # directory already exists (unlike the isdir-check + mkdir pattern).
    os.makedirs(output_path, exist_ok=True)
else:
    output_path = "."

# Relative-deviation thresholds used to flag outlier trials:
# a trial is flagged when its darkest frame is more than `water_blackout_thresh`
# below the session mean, or its brightest frame more than `flash_thresh` above it.
water_blackout_thresh = 0.2
flash_thresh = 0.2

In [None]:
def get_min_max_intensity(trial_file):
    """Summarise the darkest and brightest frames of one imaging trial.

    Parameters
    ----------
    trial_file : str, os.PathLike or array-like
        Path to a TIFF file read with ``skimage.io.imread``, or an
        already-loaded frame stack whose first axis indexes frames.
        A 2-D array is treated as a single frame.

    Returns
    -------
    numpy.ndarray
        ``[min_mean, std_at_min, max_mean, std_at_max]``:
        the lowest / highest per-frame mean pixel intensity, each followed by
        the standard deviation of the pixels of that same frame.

    Raises
    ------
    ValueError
        If the stack contains no frames.
    """
    if isinstance(trial_file, (str, os.PathLike)):
        frame_stack = io.imread(trial_file)
    else:
        frame_stack = np.asarray(trial_file)

    if frame_stack.ndim == 2:
        # A 2-D array is presumably a single grayscale frame (e.g. a one-page
        # TIFF); without this, iterating it would treat every row as a frame.
        frame_stack = frame_stack[np.newaxis]
    if len(frame_stack) == 0:
        raise ValueError(f"No frames found in {trial_file!r}")

    # Mean over all pixels of each frame, vectorised instead of a Python loop.
    per_frame_mean = frame_stack.reshape(len(frame_stack), -1).mean(axis=1)
    darkest = int(np.argmin(per_frame_mean))
    brightest = int(np.argmax(per_frame_mean))
    return np.array(
        [
            per_frame_mean[darkest],
            np.std(frame_stack[darkest]),
            per_frame_mean[brightest],
            np.std(frame_stack[brightest]),
        ]
    )

In [None]:
def _parse_trial_numbers(cell):
    """Parse a ';'-separated list of trial numbers from one session-CSV cell.

    NaN/empty cells yield an empty set; blank tokens (e.g. from a trailing ';')
    are ignored instead of crashing on ``int('')``.
    """
    if pd.isna(cell):
        return set()
    return {int(token) for token in str(cell).split(";") if token.strip()}


def _report_outlier_trials(trial_files, relative_deviation, threshold, label):
    """Print a WARNING for each trial whose relative deviation exceeds ``threshold``.

    ``trial_files`` and ``relative_deviation`` must be aligned element-wise.
    """
    for idx in np.flatnonzero(relative_deviation > threshold):
        print(
            "WARNING: "
            + os.path.basename(trial_files[idx])
            + " is a potential "
            + label
            + " trial"
        )


def _plot_trial_scatter(ax, trial_numbers, intensities, stds, xlabel):
    """Scatter per-trial intensity vs. frame std on ``ax`` with hover labels.

    Hovering a point shows the trial number of that point.
    """
    scatter = ax.scatter(
        intensities, stds, c=np.random.rand(len(intensities), 3), alpha=0.5
    )
    ax.set_ylabel("std of pixel intensities")
    ax.set_xlabel(xlabel)
    cursor = mplcursors.cursor(scatter, hover=True)
    # Snapshot the labels in this call's scope so each figure keeps its own
    # trial numbers (a closure over a loop variable would see only the last
    # session's values).
    labels = list(trial_numbers)

    @cursor.connect("add")
    def _on_add(sel):
        sel.annotation.set(text=str(labels[sel.target.index]))

    return cursor


def _analyse_session(
    animal_name, animal_dir, session, img_dates, water_blackout_thresh, flash_thresh
):
    """Validate one session's TIFF files, flag outlier trials and plot their stats."""
    date = session["date"]
    expt = session["experiment_number"]

    if date not in img_dates:
        print("ERROR: Imaging session not found")
        return
    session_dir = os.path.join(animal_dir, date)
    img_expt_nums = {
        os.path.basename(p) for p in glob.glob(os.path.join(session_dir, "[0-9]"))
    }
    if expt not in img_expt_nums:
        print("ERROR: Imaging session not found")
        return

    trial_files = sorted(glob.glob(os.path.join(session_dir, expt, "*.tif*")))
    if len(trial_files) != session["num_imaging_trials"]:
        print(
            "ERROR: Mismatch in number of imaging tiff files "
            f"(found {len(trial_files)}, expected {session['num_imaging_trials']})"
        )
        return

    excluded = _parse_trial_numbers(
        session["skip_imaging_trials"]
    ) | _parse_trial_numbers(session["missing_imaging_trials"])

    # Keep file names, trial numbers and stats in lock-step so that outlier
    # indices computed on the stats map back to the right file.
    kept_files, kept_numbers, stats = [], [], []
    for trial_file in trial_files:
        t_num = int(trial_file.split("-")[-3])
        if t_num in excluded:
            continue
        kept_files.append(trial_file)
        kept_numbers.append(t_num)
        stats.append(get_min_max_intensity(trial_file))

    if not stats:
        print("WARNING: No imaging trials left after excluding skipped/missing trials")
        return

    min_int, std_at_min, max_int, std_at_max = np.asarray(stats).T
    mean_min = min_int.mean()
    mean_max = max_int.mean()
    _report_outlier_trials(
        kept_files, (mean_min - min_int) / mean_min, water_blackout_thresh, "water blackout"
    )
    _report_outlier_trials(
        kept_files, (max_int - mean_max) / mean_max, flash_thresh, "flash"
    )

    fig, ax = plt.subplots(2, 1)
    _plot_trial_scatter(
        ax[0], kept_numbers, min_int, std_at_min, "min intensity during trial"
    )
    _plot_trial_scatter(
        ax[1], kept_numbers, max_int, std_at_max, "max intensity during trial"
    )
    fig.suptitle(f"{animal_name}/{date}/{expt}")
    fig.tight_layout()


def process_imaging_data(
    csv_path,
    imaging_path,
    animals,
    water_blackout_thresh=0.2,
    flash_thresh=0.2,
    upi_list=None,
):
    """Run imaging quality control for every session of the given animals.

    For each animal, reads ``<csv_path>/<animal>.csv`` and, for every selected
    session with imaging trials, checks that the TIFF files exist and that
    their count matches ``num_imaging_trials``. It then prints warnings for
    potential water-blackout / flash trials and draws per-trial intensity
    scatter plots.

    Parameters
    ----------
    csv_path, imaging_path : str
        Root directories of the session tables and the imaging data.
    animals : iterable of str
        Animal names to process.
    water_blackout_thresh : float
        Flag a trial when its darkest frame mean is more than this fraction
        below the session average of darkest-frame means.
    flash_thresh : float
        Flag a trial when its brightest frame mean is more than this fraction
        above the session average of brightest-frame means.
    upi_list : collection of int or None
        Only sessions whose ``upi`` is in this collection are processed; an
        empty collection processes all sessions. ``None`` (default) falls back
        to a notebook-level ``upi_list`` variable if one exists, matching the
        behaviour of earlier versions that read that global directly.
    """
    if upi_list is None:
        upi_list = globals().get("upi_list") or []

    for animal_name in animals:
        csv_data = pd.read_csv(
            os.path.join(csv_path, animal_name + ".csv"),
            dtype={
                "upi": np.int64,
                "date": str,
                "experiment_number": str,
                "missing_imaging_trials": str,
                "skip_imaging_trials": str,
            },
        )
        print("--------------------------------------------------")
        print(animal_name)
        print("--------------------------------------------------")

        animal_dir = os.path.join(imaging_path, animal_name)
        img_dates = {
            os.path.basename(os.path.dirname(p))
            for p in glob.glob(os.path.join(animal_dir, "20*", "[0-9]"))
        }

        for _, session in csv_data.iterrows():
            if len(upi_list) > 0 and session["upi"] not in upi_list:
                continue

            print("**************************************************")
            print(
                f"{session['date']}/{session['experiment_number']}\t"
                f"{session['behaviour_code']}_{session['behaviour_session_number']}"
            )
            print("**************************************************")

            if session["num_imaging_trials"] > 0:
                _analyse_session(
                    animal_name,
                    animal_dir,
                    session,
                    img_dates,
                    water_blackout_thresh,
                    flash_thresh,
                )

            print()
        print()
        print()

In [None]:
# Run the imaging QC over every configured animal, using the paths and
# thresholds defined in the configuration cell.
process_imaging_data(
    csv_path=csv_path,
    imaging_path=imaging_path,
    animals=animals,
    water_blackout_thresh=water_blackout_thresh,
    flash_thresh=flash_thresh,
)