In [None]:
%load_ext autoreload
%autoreload 1

import glob
import os
from collections import OrderedDict

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib widget

# Glob pattern for the processed DeepSqueak .mat files to load.
# NOTE(review): absolute path on an external volume -- if it is not mounted, the glob
# silently returns an empty list and every downstream cell has nothing to plot.
# processed_directory = './data/processed_mats/*.mat'
processed_directory = '/Volumes/AnxietyBU/callbacks/processed/*.mat'

# Any stimulus trial containing a call type NOT in this list is excluded
# (this includes unlabeled calls, which are stored as 'USV'!!)
acceptable_call_labels = ['Call', 'Stimulus']

# glob.glob returns paths in arbitrary, filesystem-dependent order; sort so the files
# (and therefore the rows of the concatenated frames) load identically on every run.
files = sorted(glob.glob(processed_directory))

# or only specific files. Note: may mess up histograms, which may require data from >1 file
# files = [
#     './data/processed_mats/or60rd49-d1-20240425115050-Block1-PROCESSED.mat',
#     './data/processed_mats/or60rd49-d2-20240426114923-Block1-PROCESSED.mat'
# ]
files


In [None]:
%aimport utils.deepsqueak
from utils.deepsqueak import call_mat_stim_trial_loader, multi_index_from_dict

df = pd.DataFrame()

call_types_all = pd.DataFrame()
rejected_trials_all = pd.DataFrame()

for file in files:
    calls_df, stim_trials, rejected_trials, file_info, call_types = call_mat_stim_trial_loader(file, acceptable_call_labels=['Call', 'Stimulus'], verbose=False)

    # TODO: make this a nicely editable parameter
    multi_index_info = OrderedDict()
    multi_index_info['birdname'] = file_info['birdname']
    multi_index_info['day'] = int(file_info['d'])
    multi_index_info['block'] = int(file_info['block'])

    # create multiindex: birdname, stim_trial_index, call_index
    stim_trials = multi_index_from_dict(
        stim_trials, 
        multi_index_info, 
        keep_current_index=True,
    )
    df = pd.concat((df, stim_trials), axis='rows')
    
    rejected_trials = multi_index_from_dict(
        rejected_trials, 
        multi_index_info, 
        keep_current_index=True
    )
    rejected_trials_all = pd.concat((rejected_trials_all, rejected_trials), axis='rows')

    call_types = multi_index_from_dict(
        call_types, 
        multi_index_info, 
        keep_current_index=True
    )
    call_types_all = pd.concat((call_types_all, call_types), axis='rows')

print('Rejected trials:')
rejected_trials_all

In [None]:
print(
    "Call types in rejected trials."
    + "\nLabel `USV` means an accepted call was not given a label."
    + "\nGo back to DeepSqueak & fix this."
)

# Call-type table restricted to the rejected trials (same index as rejected_trials_all).
# NOTE(review): .loc raises KeyError if a rejected index is absent from call_types_all.
rejected_trial_call_types = call_types_all.loc[rejected_trials_all.index]
rejected_trial_call_types
# TODO: add stim index to rej trial type df (is this the first stim?)

# # see only blocks with a specific call type
#
# label = 'USV'
# label = 'Noise'
# call_types_all.loc[call_types_all[label].notna()]

In [None]:
# Accepted stimulus trials from every file; the loading cell prepends birdname, day, block index levels.
df

In [None]:
# Unique bird names (first index level), sorted so plot/save order is stable across
# kernel restarts; iterating a set of strings depends on hash randomization.
all_birds = sorted(df.index.unique(level=0))
all_birds

## Plot Rasters

In [None]:
# Per-day style lookups, keyed by the integer day index (day 1 = baseline, day 2 = loom).
day_colors = {
    1: "#a2cffe",
    2: "#840000",
}
day_labels = {
    1: "baseline",
    2: "loom",
}

# Base matplotlib keyword arguments for stimulus and call marks in the raster plots.
stim_kwargs = {"alpha": 0.5}
call_kwargs = {"color": "black", "alpha": 0.5}

### Raster by block

In [None]:
%%capture  
# %%capture prevents plot output

%aimport utils.plot
from utils.plot import plot_callback_raster

save_folder =  None
save_folder = './data/figures/callback_rasters_by_block'

# every bird/day/block
unique_conditions = list(set([a[0:3] for a in df.index]))

## or select a subset
# unique_conditions = [
#     ('or14pu27', 1, 1),
#     ('or14pu27', 2, 1),
#     ('or54rd45', 1, 1),
#     ('or54rd45', 2, 1),
# ]

# figs = {}

for bird, day, block in unique_conditions:

    fig = plt.figure()
    ax = fig.subplots()

    data = df.loc[(bird, day, block)]
    
    title_str = f'{bird}-d{day}-b{block}'

    stim_kwargs['color'] = day_colors[day]

    plot_callback_raster(
        data,
        ax=ax,
        title = title_str,
        plot_stim_blocks = False,
        show_legend = True,
        call_kwargs = call_kwargs,
        stim_kwargs = stim_kwargs,
    )

    ax.set_xlim([-0.1, 3])

    # figs[title_str] = fig

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

### Raster by day

In [None]:
%%capture  
# %%capture prevents plot output

%aimport utils.plot
from utils.plot import plot_callback_raster_multiblock

save_folder =  None
save_folder = './data/figures/callback_rasters_multiblock'

# every bird/day
unique_conditions = list(set([a[0:2] for a in df.index]))

for bird, day in unique_conditions:
    data = df.loc[(bird, day)]

    title_str = f"{bird}-d{day}"

    stim_kwargs = dict(color=day_colors[day], alpha=0.5, edgecolor=None)
    call_kwargs = dict(color="black", alpha=0.5, edgecolor=None)

    fig = plt.figure()
    ax = fig.subplots()

    plot_callback_raster_multiblock(
        data,
        ax=ax,
        plot_hlines=True,
        show_block_axis=True,
        show_legend=False,
        xlim=[-0.1, 3],
        stim_kwargs = stim_kwargs,
        call_kwargs = call_kwargs,
        title = title_str,
    )

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

## Violin plots

In [None]:
# Day indices shown in the violin plots (keys of day_colors) and the violin width.
days = list(range(1, 3))
width = 0.75

In [None]:
%%capture  
# %%capture prevents plot output

%aimport utils.plot
from utils.plot import plot_violins_by_block

save_folder =  None
save_folder = './data/figures/n_calls'
# save_folder = './data/figures/n_calls-norm'

for bird in all_birds:
    fig, ax = plt.subplots()
    title_str = bird

    ax = plot_violins_by_block(
            df.loc[bird],
            field="n_calls",
            ax=ax,
            days=days,
            day_colors=day_colors,
            width=width,
            dropna=False,
    )

    ax.set(
        xlim=[-0.5,9.5],
        xticks= np.arange(0,10),
        # ylim=[-.5, 8],
        xlabel='Block',
        ylabel='Calls per stimulus',
        title=title_str,
    )
    
    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

In [None]:
%%capture  
# %%capture prevents plot output

%aimport utils.plot
from utils.plot import plot_violins_by_block

save_folder =  None
save_folder = './data/figures/latency'
# save_folder = './data/figures/latency-norm'

for bird in all_birds:
    fig, ax = plt.subplots()
    title_str = bird

    ax = plot_violins_by_block(
            df.loc[bird],
            field="latency_s",
            ax=ax,
            days=days,
            day_colors=day_colors,
            width=width,
            dropna=True,
    )

    ax.set(
        xlim=[-0.5,9.5],
        xticks= np.arange(0,10),
        # ylim=[0, 2.5],
        xlabel='Block',
        ylabel='Latency to first call (s)',
        title=title_str,
    )
    
    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

## Histograms

All blocks are pooled: each bird gets one histogram per day, combining the trials from every block of that day.

### Latency

In [None]:
%%capture

# index levels: 'birdname', 'day', 'block', 'stims_index'
# idx = pd.IndexSlice
# this_bird = df.loc[idx[birdname, :, :, :]]

save_folder = None
save_folder = './data/figures/histograms/latency'

%aimport utils.plot
from utils.plot import plot_group_hist

for bird in all_birds:

    fig, ax = plt.subplots()

    plot_group_hist(
        df.loc[bird],
        field="latency_s",
        grouping_level="day",
        group_colors=day_colors,
        alt_labels={1: "baseline", 2: "loom"},
        ax=ax,
        density=True,
        ignore_nan=True,
        histogram_kwargs={
            "range": (0, 1.5),
            "bins": 40,
        },
        stair_kwargs={
            1: {"hatch": "/"},
            2: {"hatch": "\\"},
        },
    )

    ax.set(
        title=f"{bird}: latency to first call",
        xlabel="Latency (s)",
    )

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{bird}-latency.png')

In [None]:
%%capture

save_folder = None
save_folder = "./data/figures/histograms/n_calls"
# save_folder = './data/figures'

%aimport utils.plot
from utils.plot import plot_group_hist

for bird in all_birds:

    fig, ax = plt.subplots()

    plot_group_hist(
        df.loc[bird],
        field="n_calls",
        grouping_level="day",
        group_colors=day_colors,
        alt_labels={1: "baseline", 2: "loom"},
        ax=ax,
        density=True,
        ignore_nan=False,
        histogram_kwargs={
            "range": (-0.5, 9.5),
            "bins": 10,
        },
        stair_kwargs={
            1: {"hatch": "/"},
            2: {"hatch": "\\"},
        },
    )

    ax.set(
        title=f"{bird}: number of calls per trial",
        xlabel="# of calls",
        xticks=list(range(0, 10)),
    )

    if save_folder is not None:
        fig.savefig(f"{save_folder}/{bird}-ncalls.png")