In [2]:
%load_ext autoreload
%autoreload 1

from collections import OrderedDict
import glob
import os
import pickle

import numpy as np
import pandas as pd
import scipy
import scipy.stats  # imported explicitly; the notebook calls scipy.stats.linregress

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
%matplotlib widget

In [3]:
%aimport utils.plot
from utils.plot import (
    plot_callback_raster_multiblock,
    plot_callback_raster, 
    plot_group_hist, 
    plot_pre_post, 
    plot_violins_by_block,
)

In [None]:
# Output folder for each family of figures.
# Set a value to None to prevent saving that family of figures.

save_folders = dict(
    [
        ("response_rate_by_block", "./data/figures/response_rates"),
        ("summary", "./data/figures/summary"),
        ("correlation", "./data/figures/correlation"),
        ("callback_rasters_by_block", "./data/figures/callback_rasters_by_block"),
        ("callback_rasters_multiblock", "./data/figures/callback_rasters_multiblock"),
        ("violin-n_calls", "./data/figures/violins/n_calls"),
        ("violin-latency", "./data/figures/violins/latency"),
        ("histogram-latency", "./data/figures/histograms/latency"),
        ("histogram-n_calls", "./data/figures/histograms/n_calls"),
    ]
)

# Create every requested folder up front so later savefig calls cannot fail
# on a missing directory.
for folder in save_folders.values():
    if folder is None:
        continue
    os.makedirs(folder, exist_ok=True)

In [None]:
# Load the pickled bundle and pull out the two objects the notebook uses.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
file = "./data/data.pickle"

with open(file, "rb") as f:
    loaded_data = pickle.load(f)

df, all_birds = loaded_data["df"], loaded_data["all_birds"]

# drop the container; only `df` and `all_birds` are needed from here on
del loaded_data

df

In [None]:
### Plotting Parameters

# color used for each day
day_colors = dict([(1, "#a2cffe"), (2, "#840000"), (3, "#4e77a3")])
# note: removed D2 label, since not all birds get D2 loom.
day_labels = dict([(1, "baseline")])

# styling passed to the raster plotting helpers for stimuli and calls
stim_kwargs = {"alpha": 0.5}
call_kwargs = {"color": "black", "alpha": 0.5}

# 600 for high-quality sharing
plt.rcParams.update({"figure.dpi": 600})

### Construct block-by-block response rates

In [None]:
# For every (bird, day, block): the fraction of trials whose n_calls is
# nonzero, and the number of trials in the block.
block_aggregations = {
    "pct_trials_responded": pd.NamedAgg(
        column="n_calls",
        aggfunc=lambda calls: np.count_nonzero(calls) / len(calls),
    ),
    "n_trials": pd.NamedAgg(
        column="n_calls",
        aggfunc=lambda calls: len(calls),
    ),
}

df_block_response_rate = df.groupby(level=["birdname", "day", "block"]).agg(
    **block_aggregations
)
df_block_response_rate

### Plot response rate by block

In [None]:
%%capture  
# %%capture prevents plot output

save_folder = save_folders['response_rate_by_block']

for birdname in all_birds:
    fig, ax = plt.subplots()

    bird_data = df_block_response_rate.loc[birdname]

    days = np.unique(bird_data.index.get_level_values("day"))

    for day in days:
        day_data = bird_data.loc[day]
        dc = day_colors.get(day, f'C{day}')
        dl = day_labels.get(day, f'day{day}')

        ax.plot(
            day_data.index,
            day_data["pct_trials_responded"],
            marker="o",
            color=dc,
            label=dl,
        )

    ax.set(
        xlabel="Block",
        ylabel="Response rate\n(% blocks with ≥ 1 response)",
        title=birdname,
        ylim=[-0.05, 1.05],
    )

    ax.legend()
    
    if save_folder is not None:
        fig.savefig(f'{save_folder}/{birdname}-response_rate.png')

### Plot Response Rates Across Days

In [None]:
# load tempo data
tempo_colors = {
    "fast": "#EE4E4E",
    "slow": "#FFC700",
    "no loom": "#0C1844",
}

# Path is relative to the working directory, like the other data paths in
# this notebook, so the notebook is not tied to one user's machine.
tempo_file = "./data/loom_tempos.csv"
tempo_data = pd.read_csv(tempo_file, index_col="bird")

# add control birds: they only get a family category ("no loom"); every other
# column is left as NaN by the concatenation
control_birds = ["or91rd13", "gr44bu34"]

control_rows = pd.DataFrame(
    {"family_category": "no loom"},
    index=pd.Index(control_birds, name="bird"),
)

# single concat instead of growing the frame one row at a time
tempo_data = pd.concat([tempo_data, control_rows])

# plotting color for each bird, looked up from its family category
tempo_data["color"] = tempo_data["family_category"].map(tempo_colors)


tempo_data

In [None]:
# Per (bird, day) summary statistics, all derived from the per-trial rows.
day_aggregations = {
    "mean_n_calls": pd.NamedAgg(column="n_calls", aggfunc="mean"),
    # mean over trials with a nonzero call count only
    "mean_n_calls_excl_zero": pd.NamedAgg(
        column="n_calls",
        aggfunc=lambda calls: np.sum(calls) / np.count_nonzero(calls),
    ),
    # median ignoring nans
    "median_latency_s": pd.NamedAgg(
        column="latency_s",
        aggfunc=lambda latencies: np.nanmedian(latencies),
    ),
    "pct_trials_responded": pd.NamedAgg(
        column="n_calls",
        aggfunc=lambda calls: np.count_nonzero(calls) / len(calls),
    ),
    "n_trials_no_response": pd.NamedAgg(
        column="n_calls",
        aggfunc=lambda calls: len(calls) - np.count_nonzero(calls),
    ),
    "n_trials": pd.NamedAgg(column="n_calls", aggfunc=lambda calls: len(calls)),
}

df_day = df.groupby(level=["birdname", "day"]).agg(**day_aggregations)
df_day

### Pre/post summary plots

- n_calls
- latency
- response rate

In [None]:
%%capture

#TODO: add option to specify days (eg, exclude d3)

save_folder = save_folders['summary']
highlight_bird = None
# highlight_bird = 'pk81rd39'

to_plot_pre_post = {
    "n_calls" : dict(
        fieldname="mean_n_calls",
        ax_kwargs = dict(
            title="Mean Calls/Trial Across Days",
            ylabel="Mean # Calls per Trial",
            ylim=[-0.1,3]
        )
    ),
    
    "n_calls_excl_zero" : dict(
        fieldname="mean_n_calls_excl_zero",
        ax_kwargs = dict(
            title="Mean Calls/Trial Across Days (Excl. 0)",
            ylabel="Mean # Calls per Trial",
            ylim=[0.9,3]
        )
    ),
    
    "latency" : dict(
        fieldname="median_latency_s", 
        ax_kwargs = dict(
            title="Median Latency Across Days",
            ylabel="Median Latency (s)",
        )
    ),
    
    "response_rate" : dict(
        fieldname="pct_trials_responded", 
        ax_kwargs = dict(
            yticks=np.arange(0, 1.2, 0.2),
            title="Response Rate Across Days",
            ylabel="% Trials with ≥1 Response",
        )
    ),
}

# proxy artists for legend
handles = [
    mpatches.Patch(color=color, label=speed)
    for speed, color in tempo_colors.items()
]

for fname, parameters in to_plot_pre_post.items():

    fig, ax = plt.subplots(figsize=[4, 6])

    ax_kwargs = parameters.get('ax_kwargs', {})

    ax = plot_pre_post(
        df_day, 
        fieldname=parameters['fieldname'], 
        ax=ax,
        color=tempo_data,
        plot_kwargs={'marker': 'o'},
        add_bird_label=True,  # default: false. adds bird ids on plot
    )

    if highlight_bird is not None:
        fname += f"-{highlight_bird}"

        ax = plot_pre_post(
            df_day.xs(highlight_bird, level="birdname", drop_level=False),
            fieldname=parameters['fieldname'], 
            ax=ax,
            color='green',
            plot_kwargs={'marker': 'o', 'linestyle':'--'}
        )

    ax.set(
        xticks=[1, 2, 3],
        # xlim=[0.75, 2.25],  # if no bird labels
        xlim=[0.75, 3.5],  # if bird labels
        # xticklabels=["baseline", "loom"],
        xlabel='Day',
        # yticks=np.arange(0, 1.2, 0.2),
        **ax_kwargs,
    )

    if save_folder is not None:
        fig.savefig(f"{save_folder}/{fname}.png", bbox_inches="tight")

        ax.legend(handles=handles)

        fig.savefig(f"./data/figures/summary/{fname}-legend.png", bbox_inches="tight")

### Deltas & Correlation

In [None]:
# scatter plot kwargs for correlation plot
corr_kwargs = {"marker": "o", "color": "k"}

# regression line kwargs for correlation plot
corr_reg_kwargs = {"color": "k", "linestyle": "--"}

In [None]:
# display the per-bird, per-day summary table that the deltas below are computed from
df_day

In [None]:
def d2minusD1(x):
    """Return day 2 minus day 1 for a Series whose index has a "day" level.

    The two days are selected separately and subtracted with index alignment,
    so an entry that is present on only one of the two days yields NaN.
    """
    return x.xs(2, level="day") - x.xs(1, level="day")


# Day-2-minus-day-1 change of each summary metric, one row per bird.
# Computed by aligned subtraction rather than a per-bird aggregation, so a
# bird that is missing one of the days gets NaN instead of raising.
delta_df = pd.DataFrame(
    {
        "d_pct_trials_responded": d2minusD1(df_day["pct_trials_responded"]),
        "d_mean_n_calls": d2minusD1(df_day["mean_n_calls"]),
        "d_median_latency_s": d2minusD1(df_day["median_latency_s"]),
    }
).reindex(df_day.index.get_level_values("birdname").unique())  # keep one row per bird

# get song tempo & callback latency
delta_df["song_tempo"] = delta_df.index.map(tempo_data["median song tempo"])
# delta_df["callback_latency"] = delta_df.index.map(tempo_data["callback latency"])

delta_df

In [None]:
# Axis label for each delta_df column used in the correlation plots.
# Raw strings keep the backslash of the mathtext command `\Delta` without
# relying on Python passing an unrecognized escape sequence through.
axis_labels_by_field = {
    "song_tempo": "Median song tempo\n(syl/s)",
    "d_pct_trials_responded": r"$\Delta$ % trials with ≥1 response",
    "d_mean_n_calls": r"$\Delta$ mean call count/trial",
    "d_median_latency_s": r"$\Delta$ median latency (s)",
}


# output file stem -> which columns to correlate, plus extra axis settings
correlations_to_plot = {
    "C--song_tempo-response_rate": dict(
        x_field="song_tempo",
        y_field="d_pct_trials_responded",
        ax_kwargs=dict(
            title=r"$\Delta$ response rate vs. song tempo",
        ),
    ),
    "C--song_tempo-n_calls": dict(
        x_field="song_tempo",
        y_field="d_mean_n_calls",
        ax_kwargs=dict(
            title=r"$\Delta$ mean calls/trial vs. Song tempo",
        ),
    ),
    "C--song_tempo-latency": dict(
        x_field="song_tempo",
        y_field="d_median_latency_s",
        ax_kwargs=dict(
            title=r"$\Delta$ median latency vs. Song tempo",
        ),
    ),
}

In [None]:
%%capture

save_folder = save_folders['correlation']

for fname, parameters in correlations_to_plot.items():
    x_field = parameters["x_field"]
    y_field = parameters["y_field"]
    x = delta_df[x_field]
    y = delta_df[y_field]

    ii = ~np.isnan(x) & ~np.isnan(y)
    x, y = x[ii], y[ii]
    

    fig, ax = plt.subplots()
    ax.scatter(x, y, **corr_kwargs, label=None)

    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(x=x, y=y)
    p = lambda x: slope*x + intercept
    eq_string = "$%.3ex + %.6f$\n$r=%.3f$"%(slope, intercept, r_value)

    ends = np.array([np.min(x), np.max(x)])

    ax.plot(ends, p(ends), label=eq_string, **corr_reg_kwargs)

    ax.set(
        xlabel=axis_labels_by_field[x_field],
        ylabel=axis_labels_by_field[y_field],
        **parameters["ax_kwargs"],
    )

    ax.legend(loc='best')

    if save_folder is not None:
        fig.savefig(f"{save_folder}/{fname}.png", bbox_inches="tight")

In [None]:
# savefig = False

# fig, ax = plt.subplots()

# ax.scatter(delta_df["song_tempo"], delta_df[""], **corr_kwargs)

# ax.set(
#     # xlim=[4.5, 8],
# )

# if savefig:
#     fig.savefig(f"./data/figures/correlation/{fname}.png", bbox_inches="tight")

In [None]:
# close every figure opened so far before the raster figures are created
plt.close("all")

## Plot Rasters

### Raster by block

In [None]:
%%capture  
# %%capture prevents plot output

save_folder = save_folders['callback_rasters_by_block']

# every bird/day/block
unique_conditions = list(set([a[0:3] for a in df.index]))

## or select a subset
# unique_conditions = [
#     ('or14pu27', 1, 1),
#     ('or14pu27', 2, 1),
#     ('or54rd45', 1, 1),
#     ('or54rd45', 2, 1),
# ]

# figs = {}

for bird, day, block in unique_conditions:

    fig = plt.figure()
    ax = fig.subplots()

    data = df.loc[(bird, day, block)]
    
    title_str = f'{bird}-d{day}-b{block}'

    stim_kwargs['color'] = day_colors[day]

    plot_callback_raster(
        data,
        ax=ax,
        title = title_str,
        plot_stim_blocks = False,
        show_legend = True,
        call_kwargs = call_kwargs,
        stim_kwargs = stim_kwargs,
    )

    ax.set_xlim([-0.1, 3])

    # figs[title_str] = fig

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

### Raster by day

In [None]:
%%capture  
# %%capture prevents plot output

save_folder = save_folders['callback_rasters_multiblock']

# every bird/day
unique_conditions = list(set([a[0:2] for a in df.index]))

for bird, day in unique_conditions:
    data = df.loc[(bird, day)]

    title_str = f"{bird}-d{day}"

    stim_kwargs = dict(color=day_colors[day], alpha=0.5, edgecolor=None)
    call_kwargs = dict(color="black", alpha=0.5, edgecolor=None)

    fig = plt.figure()
    ax = fig.subplots()

    plot_callback_raster_multiblock(
        data,
        ax=ax,
        plot_hlines=True,
        show_block_axis=True,
        show_legend=False,
        xlim=[-0.1, 3],
        stim_kwargs = stim_kwargs,
        call_kwargs = call_kwargs,
        title = title_str,
    )

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

## Violin plots

In [None]:
# NOTE: violin plot only works for days = [1,2] right now
# `days` and `width` are passed to plot_violins_by_block in the cells below.
days, width = [1, 2], 0.75

In [None]:
%%capture  
# %%capture prevents plot output

save_folder =  save_folders['violin-n_calls']

for bird in all_birds:
    fig, ax = plt.subplots()
    title_str = bird

    ax = plot_violins_by_block(
            df.loc[bird],
            field="n_calls",
            ax=ax,
            days=days,
            day_colors=day_colors,
            width=width,
            dropna=False,
    )

    ax.set(
        xlim=[-0.5,9.5],
        xticks= np.arange(0,10),
        # ylim=[-.5, 8],
        xlabel='Block',
        ylabel='Calls per stimulus',
        title=title_str,
    )
    
    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

In [None]:
%%capture  
# %%capture prevents plot output

save_folder =  save_folders['violin-latency']

for bird in all_birds:
    fig, ax = plt.subplots()
    title_str = bird

    ax = plot_violins_by_block(
            df.loc[bird],
            field="latency_s",
            ax=ax,
            days=days,
            day_colors=day_colors,
            width=width,
            dropna=True,
    )

    ax.set(
        xlim=[-0.5,9.5],
        xticks= np.arange(0,10),
        # ylim=[0, 2.5],
        xlabel='Block',
        ylabel='Latency to first call (s)',
        title=title_str,
    )
    
    if save_folder is not None:
        fig.savefig(f'{save_folder}/{title_str}.png')

## Histograms

All blocks merged

### Latency

In [None]:
%%capture

# index levels: 'birdname', 'day', 'block', 'stims_index'
# idx = pd.IndexSlice
# this_bird = df.loc[idx[birdname, :, :, :]]

save_folder = save_folders['histogram-latency']

for bird in all_birds:

    fig, ax = plt.subplots()

    plot_group_hist(
        df.loc[bird],
        field="latency_s",
        grouping_level="day",
        group_colors=day_colors,
        alt_labels=day_labels,
        ax=ax,
        density=True,
        ignore_nan=True,
        histogram_kwargs={
            "range": (0, 1.5),
            "bins": 40,
        },
        stair_kwargs={
            1: {"hatch": "/"},
            2: {"hatch": "\\"},
        },
    )

    ax.set(
        title=f"{bird}: latency to first call",
        xlabel="Latency (s)",
    )

    if save_folder is not None:
        fig.savefig(f'{save_folder}/{bird}-latency.png')

In [None]:
%%capture

save_folder = save_folders['histogram-n_calls']

for bird in all_birds:

    fig, ax = plt.subplots()

    plot_group_hist(
        df.loc[bird],
        field="n_calls",
        grouping_level="day",
        group_colors=day_colors,
        alt_labels=day_labels,
        ax=ax,
        density=True,
        ignore_nan=False,
        histogram_kwargs={
            "range": (-0.5, 9.5),
            "bins": 10,
        },
        stair_kwargs={
            1: {"hatch": "/"},
            2: {"hatch": "\\"},
        },
    )

    ax.set(
        title=f"{bird}: number of calls per trial",
        xlabel="# of calls",
        xticks=list(range(0, 10)),
    )

    if save_folder is not None:
        fig.savefig(f"{save_folder}/{bird}-ncalls.png")

In [None]:
# to generate legend

alpha = 0.5

fig, ax = plt.subplots()

# zero-size rectangles act purely as legend handles, one per day color
handles = [
    Rectangle([0, 0], 0, 0, color=color, alpha=alpha, label=f"Day {day}")
    for day, color in day_colors.items()
]

ax.legend(handles=handles)

plt.show()

TODO

In [None]:
# TODO: maybe replace some hacky indexing stuff with `groupby`. eg:
# mean of n_calls for each (birdname, day) pair, grouped on the index levels
df["n_calls"].groupby(level=["birdname", "day"]).mean()