# Behavior statistics

This notebook generates panels for Figure 2 comparing the behavioral performance of birds raised in the CR and PR conditions.

In [None]:
#' Attach a package quietly.
#'
#' Thin wrapper around `library()` that takes the package name as a string and
#' suppresses conflict and startup reporting.
#'
#' @param pkg Package name as a character string.
#' @return Invisibly, the value returned by `library()` (the attached packages).
import <- function(pkg) {
    library(pkg, warn.conflicts = FALSE, quietly = TRUE, character.only = TRUE)
}
import("repr")
import("stringr")
import("tidyr")
import("dplyr")
import("ggplot2")
import("lme4")
import("emmeans")
import("diagis")
import("bssm")

In [None]:
# Notebook display settings (repr): print at most 15 columns x 20 rows of
# matrix-like output, and render inline plots at 2 x 1.25 in, 300 dpi.
options(
    repr.matrix.max.cols = 15,
    repr.matrix.max.rows = 20,
    repr.plot.width = 2,
    repr.plot.height = 1.25,
    repr.plot.res = 300
)

# Shared theme for all figure panels: small (6-7 pt) text, thin axis lines and
# ticks, and facet strips placed outside the panels with no background box.
my.theme <- theme(
    legend.text = element_text(size = 6),
    legend.title = element_text(size = 6),
    plot.title = element_text(size = 7, hjust = 0.5),
    axis.line = element_line(linewidth = 0.25),
    axis.ticks = element_line(linewidth = 0.25),
    axis.ticks.length = unit(0.05, "cm"),
    axis.title = element_text(size = 7),
    axis.text = element_text(size = 6),
    strip.placement = "outside",
    strip.text = element_text(size = 7),
    strip.background = element_blank()
)
# Add to a plot to hide its legend.
no.legend <- theme(legend.position = "none")

# Default point: small hollow circle with white fill; default line: thin.
update_geom_defaults("point", list(fill = "white", shape = 21, size = 0.8))
update_geom_defaults("line", list(linewidth = 0.4))

In [None]:
# Log-odds transform and its inverse. Delegating to the logistic distribution
# functions in base R's stats package keeps invlogit() finite for large |x|;
# the direct form exp(x) / (1 + exp(x)) overflows to Inf / Inf = NaN once
# exp(x) exceeds the double range (x > ~709).
logit <- function(p) qlogis(p)
invlogit <- function(x) plogis(x)

## Training

Before plotting the training data, you need to fit the state-space model by running `Rscript scripts/ssm-training.R datasets/zebf-discrim-noise/trials/C280_train_trials.csv`. The following is a brief explanation of the model.

### Non-response model

The training trials are analyzed with a state-space model in which each trial is modeled as a Bernoulli random variable that depends on one or more latent state variables. For non-response probability, the outcomes are coded as {Peck, Timeout}, and there is a single latent state variable $x_t$ that represents the log odds of pecking. On each trial, $x_t$ changes by a random amount drawn from a zero-mean normal distribution with unknown variance.

\begin{align}
y_t & \sim \mathrm{Bernoulli}\left(\frac{\exp x_t}{1 + \exp x_t}\right) \\
x_{t+1} & = x_t + \eta_t \\
\eta_t & \sim N(0, \sigma_\nu^2)
\end{align}

To fit the model, we need priors for the initial state, $x_0 \sim N(0, \sigma_x^2)$ and for the hyperparameter $\sigma_\nu$. This model can be fit in `bssm` with the `bsm_ng` function.

### Discrimination model

To quantify discrimination, we exclude the non-response trials and code the remaining trials as {Right, Left}. The probability of pecking left is conditioned on the reinforcement contingency
for the stimulus, $Z_t$, which is coded as $(1, -0.5)$ for trials rewarded on the right key and $(1, 0.5)$ for trials rewarded on the left key. The latent variable $x_t$ is now a vector of two components, one for bias and the other for discrimination. Our random-walk model is as follows:

\begin{align}
y_t|x_t,Z_t,\theta & \sim \mathrm{Bernoulli}(p_t) \\
\mathrm{logit}(p_t) & = Z_t x_t \\
x_{t+1} & = x_t + \eta_t \\
\eta_t & \sim N(0, \Sigma_\eta) \\
\end{align}

For this model we construct the state-space model with the more general `ssm_ung` function. Its documentation in `bssm` is sparse, so here are a few notes. `ssm_ung` formalizes the model as

\begin{align}
y_t & \sim p(y_t|D_t + Z_tx_t) \\
x_{t+1} & = C_t + T_t x_t + R_t \nu_t,
\end{align}

with $\nu_t \sim N(0, I_k)$. For our model, $D_t$ and $C_t$ are zero, and $T_t$ is a $2 \times 2$ identity matrix, all constant over time. $Z_t$ does depend on time; it's $(1, -0.5)$ for trials reinforced on the right and $(1, 0.5)$ for trials reinforced on the left. $x_1$ has a multivariate normal prior that's specified with the `a1` and `P1` arguments. 

- $n$ is the number of time points, $m$ is the dimension of the state vector, and $k$ is the dimension of the process noise.
- $Z$ is supplied as an $m \times n$ array (if it varies over time).
- $T$ is supplied as an $m \times m$ matrix.
- $R$ is $m \times k$. We probably just want $k = m$ here.

The prior for $R_t$ is specified by supplying functions to the `update_fn` and `prior_fn` arguments. All parameters that have priors are concatenated into a single vector $\theta$, which is passed to both functions. `prior_fn` evaluates the (joint) log prior density of $\theta$. `update_fn` unpacks $\theta$ into a list with named elements.

The simplest prior to implement for $R$ is to have independent normal distributions for bias and discrimination, which means their random walks will be uncorrelated.

In [None]:
# Results of scripts/ssm-training.R are written to this file. Stop with an
# actionable message, rather than a bare "cannot open file" error, if the
# script has not been run yet.
results <- local({
    summary_path <- "../build/C280_train_ssm_summary.rds"
    if (!file.exists(summary_path)) {
        stop("Missing ", summary_path, "; run `Rscript scripts/ssm-training.R ",
             "datasets/zebf-discrim-noise/trials/C280_train_trials.csv` first",
             call. = FALSE)
    }
    readRDS(summary_path)
})

In [None]:
# Stack the per-variable posterior summaries (time, variable, mean, lwr, upr)
# from the results list into one long table, skipping the "data" and "corr"
# entries.
summary_bird <- purrr::list_rbind(
    lapply(
        purrr::discard_at(results, c("data", "corr")),
        \(component) select(component, time, variable, mean, lwr, upr)
    )
)

In [None]:
options(repr.plot.width = 1.9, repr.plot.height = 1.2, repr.plot.res = 450)
# Training trajectory for C280. Lines are the posterior means of "level"
# (black), "p_left" (blue) and "p_right" (orange), with lwr/upr ribbons for
# the latter two. Tick marks show individual trials. Non-responses sit at
# y = 1.125. Responses sit near 1.05 or -0.05 according to peck_left, nudged
# by +/-0.025 and coloured by stim_left (assuming both are coded 0/1).
p_left <- (
    summary_bird
    |> filter(variable %in% c("level", "p_left", "p_right"))
    # one row per time point with mean_*, upr_* and lwr_* columns per variable
    |> pivot_wider(names_from = "variable", values_from = c("mean", "upr", "lwr"))
    |> inner_join(results$data, by = "time")
    |> mutate(
        rpos = (peck_left - 0.5) * 1.1 + 0.5,
        spos = stim_left * 0.05 - 0.025
    )
    |> ggplot(aes(x = trial))
    + geom_point(aes(y = ifelse(noresp, 1.125, NA)), size = 0.3, shape = "|")
    + geom_point(aes(y = rpos + spos, color = factor(stim_left)), size = 0.3, shape = "|")
    + geom_line(aes(y = mean_level), color = "black")
    + geom_line(aes(y = mean_p_left), color = "#2677B4")
    + geom_ribbon(aes(ymin = lwr_p_left, ymax = upr_p_left), fill = "#2677B4", alpha = 0.25)
    + geom_line(aes(y = mean_p_right), color = "#F68626")
    + geom_ribbon(aes(ymin = lwr_p_right, ymax = upr_p_right), fill = "#F68626", alpha = 0.25)
    + scale_color_manual(values = c("#F68626", "#2677B4"))
    + scale_y_continuous("prob", breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0))
    + theme_classic() + my.theme + no.legend
)
p_left

In [None]:
# Write the training-trajectory panel at the same size as the inline preview.
pdf(file = "../figures/2ac_ssm_train_C280.pdf", width = 1.9, height = 1.2)
print(p_left)
dev.off()

In [None]:
options(repr.plot.width = 1.9, repr.plot.height = 1.2, repr.plot.res = 450)
# Discrimination state for C280 across training: posterior mean of the
# "discrim" variable (plotted as LOR) with its lwr/upr ribbon. The dotted
# line marks zero.
p_discrim <- (
    summary_bird
    |> filter(variable == "discrim")
    |> inner_join(results$data, by = "time")
    |> ggplot(aes(x = trial))
    + geom_line(aes(y = mean))
    + geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.25)
    + geom_hline(yintercept = 0, linetype = "dotted")
    + scale_y_continuous("LOR")
    + theme_classic() + my.theme + no.legend
)
p_discrim

In [None]:
# Write the training discrimination panel at the same size as the preview.
pdf(file = "../figures/2ac_ssm_train_discrim_C280.pdf", width = 1.9, height = 1.2)
print(p_discrim)
dev.off()

## Generalizing to noisy stimuli

For the test trials, we don't need to use a state space model because we assume that the trials for a given bird to a given stimulus are exchangeable. However, because we have multiple animals, we need to use a generalized linear mixed effects model to deal with the hierarchical nature of the data.

In [None]:
## load metadata: one row per bird, joined to trials below on subject == bird
## (presumably also supplies sex, group, age and siblings -- verify)
birds <- data.table::fread(file = "../datasets/zebf-discrim-noise/birds.csv")

In [None]:
## load trials
## All probe-session trial files are concatenated, minus their header rows, in
## a single shell call, so fread infers column types once over all rows. Column
## names come from the first file's header, so all files are assumed to share
## the same columns (TODO confirm). File names are shell-quoted, which keeps
## this safe for paths containing spaces.
probe_files <- list.files(
    "../datasets/zebf-discrim-noise/trials/",
    pattern = "probe.*_trials\\.csv$", recursive = TRUE, full.names = TRUE
)
if (length(probe_files) == 0) {
    stop("no *probe*_trials.csv files found under ../datasets/zebf-discrim-noise/trials/",
         call. = FALSE)
}
header <- data.table::fread(probe_files[[1]], nrows = 0)
all_trials <- as_tibble(data.table::fread(
    cmd = paste("tail -q -n+2", paste(shQuote(probe_files), collapse = " ")),
    header = FALSE
))
names(all_trials) <- names(header)

In [None]:
## data cleaning
trials <- (
    all_trials
    # correction trials were inadvertently left on for Rb284 in the 0 dB session
    |> filter(correction == 0)
    # remove some trials with the wrong date. Both sides are reduced to calendar
    # dates. If fread parses `time` as a date-time (POSIXct), comparing it
    # directly with a Date mixes incompatible Ops methods. R then warns and
    # compares the raw numbers (seconds vs days), which keeps every row.
    |> filter(as.Date(time) > as.Date("2022-01-01"))
    |> group_by(subject)
    |> arrange(id)
    ## recode stim and response so that we can get bias and LOR
    |> mutate(peck_any = (response != "timeout") * 1,
              peck_left = ifelse(peck_any, (response == "peck_left") * 1, NA),
              correct = ifelse(peck_any, correct * 1, NA),
              # trial index counts from 1 within each subject, in `id` order
              trial = row_number(),
              rtime = rtime / 1e6  # convert to s
              )
    # drop the per-subject grouping so later steps don't silently run per bird
    |> ungroup()
)
# Parse stimulus names of the form "<foreground>-<level>_<background>-<level>".
# snr = background level - foreground level, stored as a factor whose levels
# run from highest to lowest. Names that do not match the pattern are dropped.
stims <- (
    trials$stimulus
    |> unique()
    |> str_match("(?<foreground>[:alnum:]+)-(?<foregroundlvl>[:digit:]+)_(?<background>[:alnum:]+)-(?<backgroundlvl>[:digit:]+)")
    |> as.data.frame()
    |> mutate(
        stimulus = V1,
        foreground,
        background,
        snr = as.numeric(backgroundlvl) - as.numeric(foregroundlvl),
        .keep = "none"
    )
    |> mutate(snr = forcats::fct_rev(factor(snr)))
    |> drop_na()
)
# Attach the parsed stimulus attributes; trials whose stimulus did not parse drop out.
trials <- inner_join(trials, stims, by = "stimulus")
# Lookup table of which foreground stimuli are rewarded on the left key for
# each bird. Each (subject, foreground) pair takes stim_left = the peck_left
# value of its correct responses. Because the table comes from xtabs(),
# stim_left is a factor ("0"/"1").
stimclasses <- (
    xtabs(~ subject + foreground + peck_left + correct, data = trials)
    |> as.data.frame()
    |> filter(correct == 1, Freq > 0)
    |> select(subject, foreground, stim_left = peck_left)
)
# Session metadata parsed from experiment names ("2ac-<type>-snr..._<snr>-..."):
# session_type, and session_snr as a factor with levels from highest to lowest.
sessions <- (
    trials$experiment
    |> unique()
    |> str_match("2ac-(?<type>[:alnum:]+)-snr.*_(?<snr>[0-9-]+)-.*")
    |> as.data.frame()
    |> mutate(
        experiment = V1,
        session_type = type,
        session_snr = as.numeric(snr),
        .keep = "none"
    )
    |> mutate(session_snr = forcats::fct_rev(factor(session_snr)))
)
# Attach each trial's stimulus class (stim_left), session info and bird
# metadata, then keep only the columns used in the analyses below.
trials <- (
    trials
    |> inner_join(stimclasses, by = c("subject", "foreground"))
    |> inner_join(sessions, by = "experiment")
    |> inner_join(birds, by = c(subject = "bird"))
    |> select(
        # bird
        subject, sex, group, age, siblings,
        # trial, session and stimulus
        trial, session_type, session_snr, foreground, background, stim_left, snr,
        # outcome
        peck_any, peck_left, rtime, correct, result
    )
)


In [None]:
## tabulate the number of responses for each subject, snr, and stimulus type. We can use these
## pooled counts as binomial random variables - much faster than trying to fit individual trials, and same results.
## .groups = "drop" returns a plain ungrouped table instead of one still grouped
## by group/subject/snr (and silences summarize's regrouping message).
resp_probs <- (
    trials
    |> group_by(group, subject, snr, stim_left)
    |> summarize(
        n_trials = n(),                          # all trials
        n_peck = sum(peck_any),                  # trials with a response
        n_correct = sum(correct, na.rm = TRUE),  # correct responses (NA for non-responses)
        n_left = sum(peck_left, na.rm = TRUE),   # left-key responses
        .groups = "drop"
    )
)

### Example bird


In [None]:
# Raw trials and pooled response counts for the example bird (C280).
example_trials <- filter(trials, subject == "C280")
example_probs <- filter(resp_probs, subject == "C280")

#### Check that performance on the baseline stimuli (70 dB SNR) remains stable.

In [None]:
## Bin the baseline (70 dB SNR) trials into consecutive blocks of `block_size`
## trials. For each block, compute the timeout rate and, per stimulus class,
## p(peck left) among response trials, located at the block's median trial.
## Blocks are numbered from 0, and (row_number() - 1) %/% block_size puts
## exactly block_size trials in every block except possibly the last.
## floor(row_number() / block_size) put only block_size - 1 trials in block 0.
block_size <- 100
p_timeout <- (
    example_trials
    |> filter(snr == 70)
    |> mutate(block = factor((row_number() - 1) %/% block_size))
    |> group_by(block)
    |> summarize(p_timeout = sum(1 - peck_any) / n(), n_trials = n(), index_trial = median(trial))
)
# NOTE: this reuses the name `p_left`, replacing the training-trajectory plot
# object of the same name (if cells run in order, its PDF is already written).
p_left <- (
    example_trials
    |> filter(snr == 70)
    |> filter(peck_any == 1)
    |> mutate(block = factor((row_number() - 1) %/% block_size))
    |> group_by(block, stim_left)
    |> summarize(n_trials = n(), p_left = sum(peck_left) / n_trials, index_trial = median(trial))
)

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Stability check on the baseline stimuli: per-block timeout rate (black) and
# per-block p(peck left) for each stimulus class (coloured), plotted at each
# block's median trial number.
p <- (
    ggplot(mapping = aes(x = index_trial))
    + geom_line(aes(y = p_timeout), data = p_timeout)
    + geom_line(aes(y = p_left, color = factor(stim_left), group = stim_left), data = p_left)
    + scale_color_manual(values = c("#F68626", "#2677B4"))
    + scale_x_continuous("Trial")
    + scale_y_continuous("p(left)")
)
p + theme_classic() + my.theme + no.legend

#### Generalized linear models

For an individual bird, we can use a standard GLM with SNR as our primary fixed effect. For non-response probability, we're using the proportion of trials where the bird pecks as the outcome. For discrimination, we use the proportion of response trials where the bird pecks left, with the stimulus type as an additional fixed effect.

In [None]:
# Non-response GLM for the example bird: logistic regression of timeouts vs
# pecks, with SNR entering as a factor (one coefficient per level).
fm_example_snr_nr <- glm(
    cbind(n_trials - n_peck, n_peck) ~ 1 + snr,
    data = example_probs,
    family = binomial
)

In [None]:
# Discrimination GLM for the example bird: logistic regression of left vs
# right pecks (response trials only) on SNR (as a factor), stimulus class,
# and their interaction.
fm_example_snr_pl <- glm(
    cbind(n_left, n_peck - n_left) ~ 1 + snr*stim_left,
    data = example_probs,
    family = binomial
)

In [None]:
## Use emmeans to get marginal response probabilities: p(no response) at each
## SNR on the probability scale, with 90% confidence limits. SNR is converted
## from factor labels back to numbers for plotting.
emm_example_snr_nr <- (
    emmeans(fm_example_snr_nr, specs = ~ snr, type = "response")
    |> confint(level = 0.90)
    |> mutate(snr = as.numeric(as.character(snr)))
)
# Marginal p(peck left) for each SNR x stimulus class on the probability scale,
# with 90% confidence limits; SNR converted back to numbers for plotting.
emm_example_snr_pl <- (
    emmeans(fm_example_snr_pl, specs = ~ snr*stim_left, type = "response")
    |> confint(level = 0.90)
    |> mutate(snr = as.numeric(as.character(snr)))
)

In [None]:
options(repr.plot.width = 1.7, repr.plot.height = 1.2, repr.plot.res = 450)
# Example-bird probe performance: marginal p(peck left) per stimulus class
# (coloured) and p(no response) (black), each with its confidence ribbon.
# The x axis is reversed so that SNR decreases from left to right.
p <- (
    emm_example_snr_pl
    |> ggplot(aes(x = snr, y = prob, ymin = asymp.LCL, ymax = asymp.UCL))
    + geom_line(aes(group = stim_left, color = stim_left))
    + geom_ribbon(aes(group = stim_left, fill = stim_left), alpha = 0.25)
    + geom_line(data = emm_example_snr_nr, color = "black")
    + geom_ribbon(data = emm_example_snr_nr, fill = "black", alpha = 0.25)
    + scale_y_continuous("Prob", limits = c(0, 1))
    + scale_x_reverse("SNR (dB)")
    + scale_color_manual(values = c("#F68626", "#2677B4"))
    + scale_fill_manual(values = c("#F68626", "#2677B4"))
    + theme_classic() + my.theme + no.legend
)
p

In [None]:
# Write the example-bird probe panel at the same size as the preview.
pdf(file = "../figures/2ac_probe_C280.pdf", width = 1.7, height = 1.2)
print(p)
dev.off()

### All birds

First we'll just look at the data for all the birds plotted on the same axes.

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# One panel per bird: observed p(peck left | response) for each stimulus class
# vs SNR, coloured by group, plus the observed non-response rate in black.
(
    ggplot(resp_probs, aes(x = snr, y = n_left / n_peck, group = stim_left, color = group))
    + geom_line()
    + geom_line(aes(y = 1 - n_peck / n_trials), color = "black")
    + facet_wrap(~ subject)
    + scale_y_continuous("p(left)", limits = c(0, 1))
    + theme_classic() + my.theme
)

#### Generalized linear mixed effects models

To make the models mixed-effects, we add a random intercept for each subject.

In [None]:
## Non-response GLMM across birds: logistic regression of timeouts vs pecks on
## SNR (as a factor), group and their interaction, with a random intercept per
## subject. A binomial response is cbind(successes, failures) =
## cbind(timeouts, pecks), as in the single-bird model. Using n_trials as the
## failure count counted timeouts twice and underestimated p(no response).
fm_snr_nr <- (
    resp_probs
    |> glmer(cbind(n_trials - n_peck, n_peck) ~ snr*group + (1|subject), data=_, family=binomial)
)
joint_tests(fm_snr_nr)

In [None]:
options(repr.plot.width = 1.65, repr.plot.height = 1.2, repr.plot.res = 450)
# Estimated marginal log odds of not responding for each group at each SNR
# (link scale), +/- 1 SE; SNR decreases from left to right.
p <- (
    fm_snr_nr
    |> emmeans(~ group:snr)
    |> as.data.frame()
    |> mutate(snr = as.numeric(as.character(snr)))
    |> ggplot(aes(x = snr, y = emmean, group = group, color = group))
    + geom_line()
    + geom_point(size = 1.5)
    + geom_linerange(aes(ymin = emmean - SE, ymax = emmean + SE))
    + scale_y_continuous("p(no resp) [log odds]")
    + scale_x_reverse("SNR (dB)")
    + theme_classic() + my.theme + no.legend
)
p

In [None]:
# Write the group non-response panel at the same size as the preview.
pdf(file = "../figures/probe_respond.pdf", width = 1.65, height = 1.2)
print(p)
dev.off()

In [None]:
## contrast CR and PR at each SNR: "revpairwise" group comparison within each
## SNR, summarized on the response scale.
fm_snr_nr |>
    emmeans(specs = ~ group | snr) |>
    contrast(method = "revpairwise", type = "response")

In [None]:
## contrast CR and PR averaged across all SNRS ("revpairwise", response scale)
fm_snr_nr |>
    emmeans(specs = ~ group, type = "response") |>
    contrast(method = "revpairwise")

For discrimination, we also need to add a random slope for stimulus type.

In [None]:
## Discrimination GLMM: logistic regression of left vs right pecks (response
## trials only) on stimulus class, SNR, group and all their interactions, with
## a per-subject random intercept and random stim_left slope.
## (To exclude the 70 dB baseline, filter resp_probs with snr != 70 first.)
fm_snr_pl <- glmer(
    cbind(n_left, n_peck - n_left) ~ stim_left*snr*group + (1+stim_left|subject),
    family = binomial,
    data = resp_probs
)
joint_tests(fm_snr_pl)

In [None]:
options(repr.plot.width = 1.6, repr.plot.height = 1.2, repr.plot.res = 450)
# Discrimination (log odds ratio) for each group at each SNR: the "revpairwise"
# contrast between stim_left levels of the estimated log odds of pecking left,
# +/- 1 SE; SNR decreases from left to right.
p <- (
    fm_snr_pl
    |> emmeans(~ stim_left | snr/group)
    |> contrast("revpairwise")
    |> as.data.frame()
    |> mutate(snr = as.numeric(as.character(snr)))
    |> ggplot(aes(x = snr, y = estimate, group = group, color = group))
    + geom_line()
    + geom_point(size = 1.5)
    + geom_linerange(aes(ymin = estimate - SE, ymax = estimate + SE))
    + scale_y_continuous("Discrimination (LOR)")
    + scale_x_reverse("SNR (dB)")
    + theme_classic() + my.theme + no.legend
)
p

In [None]:
# Write the group discrimination panel at the same size as the preview.
pdf(file = "../figures/probe_discrim.pdf", width = 1.6, height = 1.2)
print(p)
dev.off()

In [None]:
# Contrast LOR for CR and PR at each SNR (this is a contrast of contrasts):
# the stim_left "revpairwise" contrast (the LOR) is compared between groups
# with "pairwise", separately within each SNR.
fm_snr_pl |>
    emmeans(specs = ~ stim_left*group*snr) |>
    contrast(interaction = c("revpairwise", "pairwise"), by = "snr")

Finally, look at p(correct), counting non-responses as incorrect. This is consistent with the fact that
we punish non-responses.

In [None]:
## p(correct) GLMM: correct responses vs everything else (incorrect responses
## and non-responses) on SNR, group and their interaction, with a random
## intercept per subject.
fm_correct <- glmer(
    cbind(n_correct, n_trials - n_correct) ~ snr*group + (1|subject),
    data = resp_probs,
    family = binomial
)

In [None]:
options(repr.plot.width = 1.65, repr.plot.height = 1.2, repr.plot.res = 450)
# Estimated marginal p(correct) for each group at each SNR (response scale,
# non-responses counted as incorrect), +/- 1 SE.
p <- (
    fm_correct
    |> emmeans(~ group:snr, type = "response")
    |> as.data.frame()
    |> mutate(snr = as.numeric(as.character(snr)))
    |> ggplot(aes(x = snr, y = prob, group = group, color = group))
    + geom_line()
    + geom_point(size = 1.5)
    + geom_linerange(aes(ymin = prob - SE, ymax = prob + SE))
    + scale_y_continuous("p(correct)", limits = c(0, 1))
    + scale_x_reverse("SNR (dB)")
    + theme_classic() + my.theme + no.legend
)
p

In [None]:
# Write the group p(correct) panel.
# NOTE(review): the PDF width here is 1.7 in, while the inline preview uses
# repr.plot.width = 1.65; the other panels use matching sizes -- confirm.
pdf(file = "../figures/probe_correct.pdf", width = 1.7, height = 1.2)
print(p)
dev.off()

In [None]:
## contrast CR and PR in p(correct) at each SNR ("revpairwise", response scale)
fm_correct |>
    emmeans(specs = ~ group | snr) |>
    contrast(method = "revpairwise", type = "response")

In [None]:
## contrast CR and PR in p(correct) averaged across all SNRs. This uses
## "revpairwise" like every other group contrast in this notebook, so the
## comparison is reported in the same direction as the others.
(
    fm_correct
    |> emmeans(~ group, type="response")
    |> contrast("revpairwise")
)