In [None]:
# Attach a package given its name as a string, without startup messages or
# masking warnings. Example: import("dplyr")
# TRUE/FALSE are spelled out because T and F are ordinary variables that can
# be reassigned, which would silently change what library() does.
import <- function(pkg) {
    library(pkg, warn.conflicts=FALSE, quietly=TRUE, character.only=TRUE)
}
import("repr")
import("stringr")
import("tidyr")
import("dplyr")
import("ggplot2")
import("lme4")
import("emmeans")

In [None]:
# Notebook display settings: cap the size of printed tables and set the
# default inline figure size and resolution.
options(
  repr.matrix.max.cols = 15,
  repr.matrix.max.rows = 20,
  repr.plot.width = 2,
  repr.plot.height = 1.25,
  repr.plot.res = 300
)

# Shared ggplot2 theme tweaks for compact figures: small text, thin axis lines
# and ticks, and facet strips placed outside the axes with no background.
my.theme <- theme(
  plot.title = element_text(size = 7, hjust = 0.5),
  legend.title = element_text(size = 6),
  legend.text = element_text(size = 6),
  axis.title = element_text(size = 7),
  axis.text = element_text(size = 6),
  axis.line = element_line(linewidth = 0.25),
  axis.ticks = element_line(linewidth = 0.25),
  axis.ticks.length = unit(0.05, "cm"),
  strip.text = element_text(size = 7),
  strip.placement = "outside",
  strip.background = element_blank()
)
# Add to a plot to hide its legend.
no.legend <- theme(legend.position = "none")

# Session-wide geom defaults: small white-filled circles (shape 21), thin lines.
update_geom_defaults("point", list(shape = 21, fill = "white", size = 0.8))
update_geom_defaults("line", list(linewidth = 0.4))

## Load trial data

In [None]:
## metadata: keep the birds whose `behavior` column is "yes"; probe_birds is
## the subset of those whose `probe` column is also "yes"
birds <- filter(data.table::fread("../inputs/bird_metadata.csv"), behavior == "yes")
probe_birds <- filter(birds, probe == "yes")

### Probe

In [None]:
## trials - retrieved with batch/retrieve_trials
## The per-file CSVs are concatenated by the shell: column names come from the
## first line of the first matching file, and the data rows come from every
## matching file with its first line skipped (tail -n+2).
header <- data.table::fread(
  cmd = 'find ../build/ -name "*probe*_trials.csv" | head -n1 | xargs head -n1',
  header = TRUE
)
all_trials <- setNames(
  tibble(
    data.table::fread(
      cmd = 'find ../build/ -name "*probe*_trials.csv" | xargs tail -q -n+2',
      header = FALSE
    )
  ),
  names(header)
)

In [None]:
# sanity check - each stimulus/response should only have one consequence.
# The first count collapses trials to one row per distinct
# (subject, stimulus, response, correct); the second counts how many distinct
# `correct` values remain per (subject, stimulus, response). Any row printed
# here (n > 1) is a violation.
all_trials |>
  group_by(subject, stimulus, response, correct) |>
  summarize(n = n(), .groups = "drop_last") |>
  summarize(n = n(), .groups = "drop_last") |>
  filter(n > 1)

In [None]:
# sanity check: no big skips in the dates, which could indicate trials recorded
# with the wrong subject.
# NB the trials with dates back in 2017 are probably clock errors. They should
# be in the right place using trial id, but it's safer to just discard them.
# Plots the number of trials per calendar day, one panel per subject.
options(repr.plot.width = 10, repr.plot.height = 5, repr.plot.res = 300)
all_trials |>
  mutate(date = lubridate::date(time)) |>
  group_by(subject, date) |>
  summarize(n = n(), .groups = "drop_last") |>
  ggplot(aes(date, n)) +
  facet_wrap(~ subject, scales = "free") +
  geom_point()

In [None]:
# Clean the raw trials and recode responses.
# Adds: peck_any (1 unless the response is "timeout"), peck_left (1 for
# "peck_left", NA on timeouts), correct (0/1, NA on timeouts), trial
# (per-subject running index in id order) and rescales rtime.
# The result stays grouped by subject.
trials <- (
    all_trials 
    # correction trials were inadvertently left on for Rb284 in the 0 dB session
    |> filter(correction==0)
    # remove some trials with the wrong date. Compare calendar dates on both
    # sides: if `time` is a POSIXct, comparing it directly against a Date
    # compares seconds-since-epoch with days-since-epoch, so no trial would
    # ever be dropped.
    |> filter(lubridate::date(time) >= lubridate::date("2022-01-01"))
    |> group_by(subject)
    |> arrange(id)
    ## recode stim and response so that we can get bias and LOR
    |> mutate(peck_any=(response != "timeout") * 1,
              peck_left=ifelse(peck_any, (response == "peck_left") * 1, NA),
              correct=ifelse(peck_any, correct * 1, NA),
              trial=row_number(),
              rtime=rtime / 1e6  # convert to s (presumably recorded in microseconds - confirm)
              )
)
# Parse stimulus names of the form <foreground>-<level>_<background>-<level>
# into a lookup table. snr is the background level minus the foreground level,
# stored as a factor with its levels in reversed order. Names that do not
# match the pattern yield NA rows and are dropped.
stims <- unique(trials$stimulus) |>
  str_match("(?<foreground>[:alnum:]+)-(?<foregroundlvl>[:digit:]+)_(?<background>[:alnum:]+)-(?<backgroundlvl>[:digit:]+)") |>
  as.data.frame() |>
  mutate(
    stimulus = V1,
    foreground,
    background,
    snr = forcats::fct_rev(
      factor(as.numeric(backgroundlvl) - as.numeric(foregroundlvl))
    ),
    .keep = "none"
  ) |>
  drop_na()
# Attach the parsed stimulus properties; trials whose stimulus did not parse
# are removed by the inner join.
trials <- inner_join(trials, stims, by = "stimulus")
# Lookup table of which foreground stimuli go with the left key for which
# birds: cross-tabulate the trials and keep every
# (subject, foreground, peck_left) cell in which a correct response occurred
# at least once. peck_left is renamed stim_left.
stimclasses <- xtabs(~ subject + foreground + peck_left + correct, data = trials) |>
  as.data.frame() |>
  filter(correct == 1, Freq > 0) |>
  select(subject, foreground, stim_left = peck_left)
# Parse each experiment name into a session type and a session SNR (the
# latter as a factor with reversed level order).
sessions <- unique(trials$experiment) |>
  str_match("2ac-(?<type>[:alnum:]+)-snr.*_(?<snr>[0-9-]+)-.*") |>
  as.data.frame() |>
  mutate(
    experiment = V1,
    session_type = type,
    session_snr = forcats::fct_rev(factor(as.numeric(snr))),
    .keep = "none"
  )
# Attach the stimulus-class lookup, session info and bird metadata to every
# trial, then keep only the columns used in the analyses below.
trials <- trials |>
  inner_join(stimclasses, by = c("subject", "foreground")) |>
  inner_join(sessions, by = "experiment") |>
  inner_join(birds, by = c(subject = "bird")) |>
  select(
    subject, sex, group, age, sibling,
    trial, session_type, session_snr, foreground, background, stim_left, snr,
    peck_any, peck_left, rtime, correct, result
  )
# Disabled: only keep first 10 trials for each stimulus in each experiment
#   |> group_by(subject, experiment, stimulus)
#   |> slice_head(n=10)


In [None]:
# Response counts for every bird x SNR x stimulus class: total trials, trials
# with any peck, correct responses and left pecks. summarize() drops only the
# last grouping variable, so the result stays grouped by group, subject, snr.
resp_probs <- summarize(
  group_by(trials, group, subject, snr, stim_left),
  n_trials = n(),
  n_peck = sum(peck_any),
  n_correct = sum(correct, na.rm = TRUE),
  n_left = sum(peck_left, na.rm = TRUE)
)

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# Proportion correct among responded trials, one line per bird. The
# summarize() pools the counts over the innermost remaining grouping of
# resp_probs.
resp_probs |>
  summarize(n_correct = sum(n_correct), n_peck = sum(n_peck)) |>
  ggplot(aes(snr, n_correct / n_peck, group = subject, color = group)) +
  geom_line() +
  # facet_wrap(~ subject) +
  scale_y_continuous("p(correct)", limits = c(0, 1)) +
  theme_classic() +
  my.theme

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# Proportion of trials with any peck (i.e. not a timeout), one line per bird.
resp_probs |>
  summarize(
    n_trials = sum(n_trials),
    n_correct = sum(n_correct),
    n_peck = sum(n_peck)
  ) |>
  ggplot(aes(snr, n_peck / n_trials, group = subject, color = group)) +
  geom_line() +
  # facet_wrap(~ subject) +
  scale_y_continuous("p(peck)", limits = c(0, 1)) +
  theme_classic() +
  my.theme

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# Median response time per bird and SNR, using only trials with a peck and a
# positive response time.
trials |>
  filter(peck_any == 1, rtime > 0) |>
  group_by(group, subject, snr) |>
  summarize(rtime = median(rtime)) |>
  ggplot(aes(snr, rtime, group = subject, color = group)) +
  geom_line() +
  # facet_wrap(~ subject) +
  scale_y_continuous("rtime") +
  theme_classic() +
  my.theme

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# Per-bird panels: p(left) among responded trials for each stimulus class
# (colored lines) and the non-response rate, 1 - n_peck / n_trials (black).
resp_probs |>
  ggplot(aes(snr, n_left / n_peck, group = stim_left, color = group)) +
  geom_line() +
  geom_line(aes(y = 1 - n_peck / n_trials), color = "black") +
  facet_wrap(~ subject) +
  scale_y_continuous("p(left)", limits = c(0, 1)) +
  theme_classic() +
  my.theme

### Model for individual bird


In [None]:
# Trials and response counts for the single example bird, C280.
example_trials <- filter(trials, subject == "C280")
example_probs <- filter(resp_probs, subject == "C280")

In [None]:
# List the columns available in the example bird's trial table.
colnames(example_trials)

In [None]:
# Running performance of the example bird in consecutive blocks of
# `block_size` rows. index_trial is the median trial number in each block and
# serves as the x position when plotting.
block_size <- 100
# Timeout rate per block, over all trials.
p_timeout <- (
    example_trials
    # (row_number() - 1) %/% block_size puts exactly block_size rows in every
    # block except possibly the last; floor(row_number() / block_size) left
    # only block_size - 1 rows in the first block
    |> mutate(block=factor((row_number() - 1) %/% block_size))
    |> group_by(block)
    |> summarize(p_timeout=sum(1 - peck_any)/n(), n_trials=n(), index_trial=median(trial))
)
# p(left) per block and stimulus class, over responded trials only (blocks are
# formed after dropping timeouts).
p_left <- (
    example_trials
    |> filter(peck_any==1)
    |> mutate(block=factor((row_number() - 1) %/% block_size))
    |> group_by(block, stim_left)
    |> summarize(n_trials=n(), p_left=sum(peck_left)/n_trials, index_trial=median(trial))
)

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Block-wise timeout rate (default line color) and p(left) per stimulus class
# (colored) against trial number.
p <- ggplot(mapping = aes(index_trial)) +
  geom_line(data = p_timeout, aes(y = p_timeout)) +
  geom_line(
    data = p_left,
    aes(y = p_left, color = factor(stim_left), group = stim_left)
  ) +
  scale_color_manual(values = c("#F68626", "#2677B4")) +
  scale_x_continuous("Trial") +
  scale_y_continuous("p(left)")
p + theme_classic() + my.theme + no.legend

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Example bird: p(left) per stimulus class (colored) and the mean non-response
# rate (black) against SNR. The snr factor is converted to its numeric value,
# and the x axis is reversed.
p <- example_probs |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, n_left / n_peck)) +
  geom_line(aes(group = stim_left, color = stim_left)) +
  stat_summary(
    aes(y = 1 - n_peck / n_trials, group = 1),
    fun = "mean", geom = "line", color = "black"
  ) +
  scale_y_continuous("Prob", limits = c(0, 1)) +
  scale_x_reverse("SNR (dB)") +
  scale_color_manual(values = c("#F68626", "#2677B4")) +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF at the same size used on screen.
pdf(file = "../figures/C280_probe.pdf", width = 2, height = 1.25)
print(p)
dev.off()

In [None]:
options(repr.plot.width = 4, repr.plot.height = 3, repr.plot.res = 300)
# Example bird: response-time densities for correct vs. incorrect responses,
# one panel per SNR (responded trials with positive response time only).
example_trials |>
  filter(peck_any == 1, rtime > 0) |>
  ggplot(aes(x = rtime, group = correct, color = factor(correct))) +
  facet_wrap(~ snr) +
  geom_density() +
  theme_classic() +
  my.theme

In [None]:
options(repr.plot.width=3, repr.plot.height=3, repr.plot.res = 300)
# Example bird, probe sessions only: p(left) by session SNR (x axis) and
# stimulus SNR (panels), colored by stimulus class.
(
    example_trials
    # peck_left is NA on timeout trials, so restrict to trials with a response;
    # otherwise one timeout turns a cell's sum into NA and its point is dropped
    |> filter(session_type=="probe", peck_any==1)
    |> group_by(session_snr, foreground, background, stim_left, snr)
    |> summarize(n_resp=n(), n_left=sum(peck_left))
    |> ggplot(aes(session_snr, n_left / n_resp, color=factor(stim_left)))
    + facet_wrap(~ snr)
    + geom_point()
    + scale_y_continuous("p(left)", limits=c(0, 1))
    + theme_classic() + my.theme + no.legend
)

In [None]:
# Example bird, probe sessions only: p(left) against stimulus SNR, with one
# panel per foreground x background combination and one point per session SNR.
(
    example_trials
    # the trial table has session_type / session_snr (no block_* columns);
    # timeouts are dropped because peck_left is NA for them
    |> filter(session_type=="probe", peck_any==1)
    |> group_by(session_snr, foreground, background, stim_left, snr)
    |> summarize(n_resp=n(), n_left=sum(peck_left))
    #|> filter(session_snr==5)
    |> ggplot(aes(snr, n_left / n_resp, color=factor(stim_left)))
    + facet_grid(foreground ~ background)
    + geom_point()
    #+ geom_line()
    + scale_y_continuous("p(left)", limits=c(0, 1))
    + theme_classic() + my.theme + no.legend
)

In [None]:
# code SNR as a continuous number: binomial GLM of correct vs. incorrect
# responses on numeric SNR for the example bird, keeping only rows whose SNR
# is below 50
fm_example_snr_cov <- example_probs |>
  mutate(snr = as.numeric(as.character(snr))) |>
  filter(snr < 50) |>
  glm(
    cbind(n_correct, n_peck - n_correct) ~ 1 + snr,
    data = _,
    family = binomial
  )
summary(fm_example_snr_cov)

In [None]:
options(repr.plot.width = 3, repr.plot.height = 3, repr.plot.res = 300)
# Model predictions on the response (probability) scale over a grid of SNR
# values, overlaid on the observed proportions correct.
snr_seq <- seq(-10, 35, length.out = 100)
pred <- emmeans(
  fm_example_snr_cov, ~ snr,
  at = list(snr = snr_seq), type = "response"
) |>
  as.data.frame()
example_probs |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, n_correct / n_peck)) +
  geom_point(mapping = aes(color = stim_left)) +
  geom_line(data = pred, mapping = aes(snr, prob)) +
  scale_x_reverse() +
  theme_classic() +
  my.theme +
  no.legend

In [None]:
# code SNR as a factor; peck left as dependent variable.
# Left vs. right pecks (responded trials only) as a function of SNR, stimulus
# class and their interaction.
fm_example_snr_fac <- glm(
  cbind(n_left, n_peck - n_left) ~ 1 + snr*stim_left,
  data = example_probs,
  family = binomial
)
summary(fm_example_snr_fac)

In [None]:
# code SNR as a factor; nonresponse as dependent variable.
# Timeouts (n_trials - n_peck) vs. pecks (n_peck) as a function of SNR.
fm_example_snr_nr <- glm(
  cbind(n_trials - n_peck, n_peck) ~ 1 + snr,
  data = example_probs,
  family = binomial
)
# Estimated non-response probability per SNR with 90% confidence limits; snr
# is converted to a number for plotting.
emm_example_snr_nr <- emmeans(fm_example_snr_nr, ~ snr, type = "response") |>
  confint(level = 0.90) |>
  mutate(snr = as.numeric(as.character(snr)))
emm_example_snr_nr

In [None]:
options(repr.plot.width = 1.7, repr.plot.height = 1.2, repr.plot.res = 450)
# Example bird: model-estimated p(left) per stimulus class (colored lines with
# 90% confidence ribbons) and non-response probability (black) against SNR.
p <- emmeans(fm_example_snr_fac, ~ snr*stim_left, type = "response") |>
  confint(level = 0.90) |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, prob, ymin = asymp.LCL, ymax = asymp.UCL)) +
  geom_line(mapping = aes(group = stim_left, color = stim_left)) +
  geom_ribbon(mapping = aes(group = stim_left, fill = stim_left), alpha = 0.25) +
  geom_line(data = emm_example_snr_nr, color = "black") +
  geom_ribbon(data = emm_example_snr_nr, fill = "black", alpha = 0.25) +
  scale_y_continuous("Prob", limits = c(0, 1)) +
  scale_x_reverse("SNR (dB)") +
  scale_color_manual(values = c("#F68626", "#2677B4")) +
  scale_fill_manual(values = c("#F68626", "#2677B4")) +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF at the same size used on screen.
pdf(file = "../figures/2ac_probe_C280.pdf", width = 1.7, height = 1.2)
print(p)
dev.off()

#### Response time

Model response time as a function of SNR and response correctness.

In [None]:
# Linear model of log10 response time on SNR, correctness and their
# interaction for the example bird (responded trials with positive response
# time only), followed by joint tests of the model terms.
fm_example_rtime <- example_trials |>
  filter(peck_any == 1, rtime > 0) |>
  lm(log10(rtime) ~ 1 + snr*correct, data = _)
joint_tests(fm_example_rtime)

In [None]:
# Estimated response time (back-transformed from log10) by SNR and
# correctness, with confidence limits as vertical ranges.
emmeans(fm_example_rtime, ~ snr:correct, type = "response") |>
  as.data.frame() |>
  ggplot(aes(snr, response, color = factor(correct), group = correct)) +
  geom_line() +
  geom_point() +
  geom_linerange(mapping = aes(ymin = lower.CL, ymax = upper.CL)) +
  scale_y_log10() +
  theme_classic() +
  my.theme +
  no.legend

### Model for all birds

Modeling SNR as a linear covariate means adding a random intercept and slope for subject.

In [None]:
# Binomial mixed model of correct vs. incorrect responses on numeric SNR,
# group and their interaction, with a random intercept and SNR slope per
# subject. The snr == 70 level is left out before converting to numeric.
fm_snr_cov <- resp_probs |>
  filter(snr != 70) |>
  mutate(snr = as.numeric(as.character(snr))) |>
  glmer(
    cbind(n_correct, n_peck - n_correct) ~ 1 + snr*group + (1+snr|subject),
    family = binomial,
    data = _
  )
summary(fm_snr_cov)

In [None]:
# Predicted p(correct) per group over a grid of SNR values, with a +/- 1 SE
# ribbon on the probability scale.
snr_seq <- seq(-10, 35, length.out = 100)
pred <- emmeans(
  fm_snr_cov, ~ snr*group,
  at = list(snr = snr_seq), type = "response"
) |>
  as.data.frame()
ggplot(
  pred,
  aes(snr, prob, ymin = prob - SE, ymax = prob + SE, color = group, fill = group)
) +
  geom_line() +
  geom_ribbon(alpha = 0.2) +
  scale_x_reverse()

### SNR as categorical variable

Switching to SNR as a categorical variable makes it easier to model discrimination rather than performance, because the bias can change with SNR. We still need a random intercept and random slopes because the curves vary a lot by subject, but that introduces a large number of parameters.

In [None]:
# Binomial mixed model of left vs. right pecks with SNR as a factor: full
# stimulus class x SNR x group interaction, with a random intercept and
# stimulus-class effect per subject.
# (optional, disabled: drop the snr == 70 level first via filter(snr != 70))
fm_snr_cat <- resp_probs |>
  glmer(
    cbind(n_left, n_peck - n_left) ~ stim_left*snr*group + (1+stim_left|subject),
    family = binomial,
    data = _
  )
joint_tests(fm_snr_cat)

In [None]:
# Stimulus class x group interaction contrast, restricted to every SNR level
# except the first two levels of the snr factor.
noisy_levels <- tail(levels(resp_probs$snr), -2)
emmeans(fm_snr_cat, ~ stim_left*group, at = list(snr = noisy_levels)) |>
  contrast(interaction = c("revpairwise", "pairwise"))

In [None]:
# Show the SNR levels left after dropping the first two factor levels.
tail(levels(resp_probs$snr), -2)

In [None]:
# Stimulus class x group interaction contrast, computed separately at each
# SNR level.
emmeans(fm_snr_cat, ~ stim_left*group*snr) |>
  contrast(interaction = c("revpairwise", "pairwise"), by = "snr")

In [None]:
options(repr.plot.width = 1.6, repr.plot.height = 1.2, repr.plot.res = 450)
# Discrimination: the reverse-pairwise stimulus-class contrast on the model's
# link scale, per group and SNR, with +/- 1 SE ranges.
p <- emmeans(fm_snr_cat, ~ stim_left | snr/group) |>
  contrast("revpairwise") |>
  as.data.frame() |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, estimate, group = group, color = group)) +
  geom_line() +
  geom_point(size = 1.5) +
  geom_linerange(mapping = aes(ymin = estimate - SE, ymax = estimate + SE)) +
  scale_y_continuous("Discrimination (LOR)") +
  scale_x_reverse("SNR (dB)") +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF at the same size used on screen.
pdf(file = "../figures/probe_discrim.pdf", width = 1.6, height = 1.2)
print(p)
dev.off()

### Non-response probability

It looks like PR birds are more likely to time out than CR birds.

In [None]:
# Binomial mixed model of non-response (timeout) probability as a function of
# SNR (factor), group and their interaction, with a random intercept per
# subject.
fm_noresp <- (
    resp_probs
    # a two-column binomial response is cbind(successes, failures): timeouts
    # (n_trials - n_peck) vs. responded trials (n_peck). Using n_trials as the
    # second column would count the timeouts again as failures and bias the
    # estimated probabilities downward.
    |> glmer(cbind(n_trials - n_peck, n_peck) ~ snr*group + (1|subject), data=_, family=binomial)
)
summary(fm_noresp)

In [None]:
# Group difference in non-response at each SNR level.
emmeans(fm_noresp, ~ group | snr) |>
  contrast("revpairwise", type = "response") |>
  as.data.frame()

In [None]:
# Group difference in non-response, averaged over SNR levels.
contrast(emmeans(fm_noresp, ~ group), "revpairwise")

In [None]:
options(repr.plot.width = 1.65, repr.plot.height = 1.2, repr.plot.res = 450)
# Estimated non-response per group and SNR on the link (log-odds) scale, with
# +/- 1 SE ranges.
p <- emmeans(fm_noresp, ~ group:snr) |>
  as.data.frame() |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, emmean, group = group, color = group)) +
  geom_line() +
  geom_point(size = 1.5) +
  geom_linerange(mapping = aes(ymin = emmean - SE, ymax = emmean + SE)) +
  scale_y_continuous("p(no resp) [log odds]") +
  scale_x_reverse("SNR (dB)") +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF at the same size used on screen.
pdf(file = "../figures/probe_respond.pdf", width = 1.65, height = 1.2)
print(p)
dev.off()

#### p(correct)

Considering non-responses as errors

In [None]:
# Binomial mixed model of correct responses out of all trials (so timeouts
# count as errors) on SNR (factor), group and their interaction, with a
# random intercept per subject.
fm_correct <- glmer(
  cbind(n_correct, n_trials - n_correct) ~ snr*group + (1|subject),
  data = resp_probs,
  family = binomial
)

In [None]:
options(repr.plot.width = 1.65, repr.plot.height = 1.2, repr.plot.res = 450)
# Estimated p(correct) per group and SNR on the probability scale, with
# +/- 1 SE ranges.
p <- emmeans(fm_correct, ~ group:snr, type = "response") |>
  as.data.frame() |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, prob, group = group, color = group)) +
  geom_line() +
  geom_point(size = 1.5) +
  geom_linerange(mapping = aes(ymin = prob - SE, ymax = prob + SE)) +
  scale_y_continuous("p(correct)", limits = c(0, 1)) +
  scale_x_reverse("SNR (dB)") +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF.
# NOTE(review): width is 1.7 here, but the plot is shown on screen at
# repr.plot.width = 1.65 - confirm which is intended.
pdf(file = "../figures/probe_correct.pdf", width = 1.7, height = 1.2)
print(p)
dev.off()

In [None]:
# Group difference in p(correct) at each SNR level.
emmeans(fm_correct, ~ group | snr) |>
  contrast("revpairwise", type = "response")

In [None]:
# Group difference in p(correct), averaged over SNR levels.
contrast(emmeans(fm_correct, ~ group), "pairwise")

#### Response time

In [None]:
# Linear mixed model of log10 response time on SNR (factor), group and their
# interaction, with a random intercept per subject; responded trials with
# positive response time only.
fm_rtime <- trials |>
  filter(peck_any == 1, rtime > 0) |>
  lmer(log10(rtime) ~ snr*group + (1|subject), data = _)
summary(fm_rtime)

In [None]:
# Estimated response time (back-transformed from log10) per group and SNR,
# with +/- 1 SE ranges, on a log-scaled y axis.
emmeans(fm_rtime, ~ group:snr, type = "response") |>
  as.data.frame() |>
  mutate(snr = as.numeric(as.character(snr))) |>
  ggplot(aes(snr, response, group = group, color = group)) +
  geom_line() +
  geom_point() +
  geom_linerange(mapping = aes(ymin = response - SE, ymax = response + SE)) +
  scale_y_log10("Response Time (s)") +
  scale_x_reverse() +
  theme_classic() +
  my.theme +
  no.legend

### Population model averaging

There's also the poor man's approach of fitting models to individual subjects and then doing statistics on the point estimates.

In [None]:
# Fit a binomial GLM of left vs. right pecks on SNR x stimulus class to one
# subject's response counts.
# df: data frame with columns n_left, n_peck, snr and stim_left.
# Returns the fitted glm object.
# NOTE(review): the response was cbind(n_peck, n_trials - n_peck), which on
# resp_probs would model responding rather than key choice. It is rewritten
# to match the left-peck models used elsewhere in this notebook - confirm
# this is the intended dependent variable.
fit_glm <- function(df) {
    glm(cbind(n_left, n_peck - n_left) ~ 1 + snr*stim_left, data=df, family=binomial)
}

# Extract, for each SNR level, the reverse-pairwise contrast between stimulus
# classes on the link scale.
# fm: a model returned by fit_glm().
# Returns a data frame with columns snr and estimate.
emm_glm <- function(fm) {
    fm |> emmeans(~ snr:stim_left) |> contrast(interaction="revpairwise", by="snr") |> as.data.frame() |> select(snr, estimate)
}

# One model per subject, then one row per subject x SNR with its estimate.
# `trials_pooled` is not defined anywhere in this notebook; resp_probs holds
# the per-subject counts by SNR and stimulus class that fit_glm() needs.
subj_estimates <- (
    resp_probs
    |> group_by(group, subject)
    |> nest()
    |> mutate(model=purrr::map(data, fit_glm), emms=purrr::map(model, emm_glm))
    |> select(group, subject, emms) 
    |> unnest(cols=emms)
)
                

In [None]:
# Per-subject estimates against SNR, one line per bird and one panel per group.
ggplot(
  subj_estimates,
  aes(snr, estimate, group = subject, color = subject)
) +
  geom_line() +
  facet_grid(~ group) +
  theme_classic() +
  my.theme

In [None]:
# quick inspection of where a bird is: number of trials per experiment in one
# bird's trial file
xtabs(~ experiment, data = tibble(data.table::fread("../C342_trials.csv")))