In [None]:
# Attach a package by name without conflict warnings or startup messages.
#
# pkg: package name as a character string (character.only = TRUE).
# Returns whatever library() returns.
# TRUE/FALSE are spelled out because T and F are ordinary variables that can be
# reassigned, which would silently change how library() is called.
import <- function(pkg) {
  library(pkg, warn.conflicts = FALSE, quietly = TRUE, character.only = TRUE)
}
import("repr")
import("stringr")
import("tidyr")
import("dplyr")
import("ggplot2")
import("lme4")
import("emmeans")

In [None]:
# Notebook display limits for tables and the default (small) figure size.
options(
  repr.matrix.max.cols = 15,
  repr.matrix.max.rows = 20,
  repr.plot.width = 2,
  repr.plot.height = 1.25,
  repr.plot.res = 300
)

# Theme additions shared by the figures: small fonts, thin axis lines and
# ticks, and facet strips drawn outside the axes without a background.
my.theme <- theme(
  plot.title = element_text(size = 7, hjust = 0.5),
  legend.title = element_text(size = 6),
  legend.text = element_text(size = 6),
  axis.title = element_text(size = 7),
  axis.text = element_text(size = 6),
  axis.line = element_line(linewidth = 0.25),
  axis.ticks = element_line(linewidth = 0.25),
  axis.ticks.length = unit(0.05, "cm"),
  strip.text = element_text(size = 7),
  strip.placement = "outside",
  strip.background = element_blank()
)
# Add to a plot to suppress its legend.
no.legend <- theme(legend.position = "none")

# Default geoms: small open circles for points, thin lines.
update_geom_defaults("point", list(fill = "white", shape = 21, size = 0.8))
update_geom_defaults("line", list(linewidth = 0.4))

## Load trial data

In [None]:
## metadata
# Keep the birds flagged behavior == "yes"; probe_birds is the subset of those
# that are also flagged probe == "yes".
birds <- filter(
  data.table::fread("../inputs/bird_metadata.csv"),
  behavior == "yes"
)
probe_birds <- filter(birds, probe == "yes")

In [None]:
## trials - retrieved with batch/retrieve_trials
# The per-file CSVs are concatenated by the shell: column names come from the
# first line of the first file found, data rows from line 2 onward of every
# matching file.
header <- data.table::fread(
  cmd = 'find ../build/ -name "*pretrain*_trials.csv" | head -n1 | xargs head -n1',
  header = TRUE
)
all_trials <- tibble(data.table::fread(
  cmd = 'find ../build/ -name "*pretrain*_trials.csv" | xargs tail -q -n+2',
  header = FALSE
))
# Stop early if the data rows do not have the same number of columns as the
# header taken from the first file.
stopifnot(ncol(all_trials) == ncol(header))
names(all_trials) <- names(header)

In [None]:
# sanity check - each stimulus/response should only have one consequence: result should be empty
(
    all_trials
    |> group_by(subject, stimulus, response, correct)
    |> tally()
    |> tally()
    |> filter(n > 1)
)

In [None]:
# sanity check: no big skips in the dates, which could indicate trials recorded with the wrong subject
# NB the gap with C197 is due to a clock error. The missing days are put in the right place using trial id.
options(repr.plot.width = 10, repr.plot.height = 5, repr.plot.res = 300)
# Trials per subject per calendar day, one panel per subject with free axes.
all_trials |>
  mutate(date = lubridate::date(time)) |>
  group_by(subject, date) |>
  tally() |>
  ggplot(aes(date, n)) +
  # the argument is spelled out in full (`scales`) instead of relying on
  # partial matching of `scale`
  facet_wrap(~ subject, scales = "free") +
  geom_point()

In [None]:
## just the pretraining
block_size <- 100

# Per-subject derived columns, with trials ordered by id.
# The result stays grouped by subject; later cells rely on that grouping.
trials <- all_trials |>
  group_by(subject) |>
  arrange(id) |>
  ## light-cued trials are kept here and filtered out in valid_trials below
  # |> filter(str_length(lights)==0) # filter(is.na(lights))
  mutate(
    # 1 for timeouts, 0 otherwise
    noresp = as.numeric(response == "timeout"),
    # 1 when the response was a correct left peck, or an incorrect response
    # other than a left peck; uses `correct` as loaded, before it is recoded
    stim_left = 1 - xor(response == "peck_left", correct),
    # side and accuracy are undefined (NA) on timeouts
    peck_left = ifelse(noresp, NA, as.numeric(response == "peck_left")),
    correct = ifelse(noresp, NA, correct * 1),
    # running per-subject counters
    trial = row_number(),
    tot_rewarded = cumsum(result == "feed"),
    tot_noresp = cumsum(response == "timeout"),
    ## trials are considered to be corrections if the stimulus was repeated and the previous trial was incorrect
    ## this is based on the assumption that correction trials are on but may not be logged correctly
    # NB this is NA on each subject's first trial, and whenever the stimulus
    # repeats after a trial whose `correct` is NA (a timeout)
    inferred_correction = (lag(stimulus) == stimulus & lag(!correct))
  )

# Uncued, non-correction trials without center pecks. Rows where
# inferred_correction is NA are dropped by filter() as well.
# `time` is overwritten with the per-subject index among the remaining trials.
valid_trials <- trials |>
  filter(
    str_length(lights) == 0,
    !inferred_correction,
    response != "peck_center"
  ) |>
  mutate(time = row_number())

# Number of valid trials per bird, by group
valid_trials |>
  inner_join(birds, by = c(subject = "bird")) |>
  group_by(group, subject) |>
  tally()

## Example learning curves

In [None]:
# Single subject used for the example learning curves
example_trials <- valid_trials |> filter(subject == "C291")
tail(example_trials)

In [None]:
# Assign the example bird's trials to consecutive blocks of block_size trials
# and flag timeouts. Blocks are numbered from 0; because row_number() starts at
# 1, it is shifted by one before the integer division so that every full block
# holds exactly block_size trials (dividing row_number() directly leaves the
# first block one trial short).
blocked_timeouts <- example_trials |>
  filter(str_length(lights) == 0) |>
  group_by(subject) |>
  mutate(
    block = factor((row_number() - 1) %/% block_size),
    y = response == "timeout"
  )

# Timeout probability per block (90% interval, response scale), joined to the
# median trial number and size of each block. Blocks with 20 or fewer trials
# are dropped.
p_timeout <- glm(y ~ block, data = blocked_timeouts, family = binomial) |>
  emmeans(~ block) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_timeouts |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20)

In [None]:
# NB this changes the global block size used by the cells that follow
block_size <- 50

# Responded trials of the example bird, in consecutive blocks of block_size
# trials, with y flagging left pecks. Blocks are numbered from 0; because
# row_number() starts at 1, it is shifted by one before the integer division so
# that every full block holds exactly block_size trials.
blocked_pecks <- example_trials |>
  filter(response != "timeout") |>
  filter(str_length(lights) == 0) |>
  group_by(subject) |>
  mutate(
    block = factor((row_number() - 1) %/% block_size),
    y = response == "peck_left"
  )

# Probability of a left peck per block and stim_left value (90% interval,
# response scale), joined to the median trial number and size of each block.
# Blocks with 20 or fewer trials are dropped.
p_left <- glm(y ~ block * stim_left, data = blocked_pecks, family = binomial) |>
  emmeans(~ block / stim_left) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_pecks |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20) |>
  mutate(stim_left = factor(stim_left))

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Timeout probability (single line) overlaid with the left-peck probability
# (one colored line per stim_left value), against trial number.
p <- ggplot(mapping = aes(index_trial, prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_left, aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("p(left)")
p + theme_classic() + my.theme + no.legend

In [None]:
# Probability of a correct response per block and stim_left value (90%
# interval, response scale), joined to the median trial number and size of
# each block. Blocks with 20 or fewer trials are dropped.
p_correct <- glm(
  correct ~ factor(block) * stim_left,
  data = blocked_pecks,
  family = binomial
) |>
  emmeans(~ block / stim_left) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_pecks |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20) |>
  mutate(stim_left = factor(stim_left))

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Timeout probability (single line) overlaid with the probability of a correct
# response (one colored line per stim_left value), against trial number.
p <- ggplot(mapping = aes(index_trial, prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_correct, aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("p(correct)")
p + theme_classic() + my.theme + no.legend

In [None]:
# Print the reference grid emmeans builds for the p(correct) model
ref_grid(
  glm(
    correct ~ factor(block) * stim_left,
    data = blocked_pecks,
    family = binomial
  )
)

## Trials to criterion

We want to estimate the number of trials that it takes for each subject to reach a criterion level of performance. Performance needs to be assessed as an estimated marginal mean - i.e., what is the probability of a correct response averaged over all stimuli, assuming any stimulus is equally likely to be presented? We also need to impose a prior so that we can estimate probabilities in blocks where there are no correct or no incorrect trials for a given stimulus. The poor man's way of doing this is to add 1 success and 1 failure to each condition, which is equivalent to using the posterior mean under a $\mathrm{Beta}(1,1)$ prior.

Analysis needs to be done with the full set of data for each animal so that xtabs returns a full table.

In [None]:
# Responded, uncued, non-correction trials of every bird, split per subject
# into consecutive blocks of block_size trials (block_size is the global set in
# an earlier cell). Blocks are numbered from 0; because row_number() starts at
# 1, it is shifted by one before the integer division so that every full block
# holds exactly block_size trials.
blocked_pecks <- valid_trials |>
  filter(
    response != "timeout",
    str_length(lights) == 0,
    !inferred_correction
  ) |>
  group_by(subject) |>
  mutate(block = (row_number() - 1) %/% block_size)

In [None]:
# Per-block proportion correct for one subject, averaged over stimuli.
#
# df: trials of a single subject with columns block (numeric), stimulus,
#     correct (0/1 or logical) and trial.
# Returns one row per block with index_trial and last_trial (first and last
# `trial` value in the block), n_trials, p_corr_mean (unweighted mean over
# stimuli of the smoothed proportion correct) and p_corr_min (the smallest
# per-stimulus proportion).
p_corr_emm <- function(df) {
  # Fix the outcome levels so the "1" (correct) slice exists even if only one
  # outcome ever occurs in df.
  outcome <- factor(as.integer(df$correct), levels = 0:1)
  # add one to every cell for the prior
  counts <- xtabs(~ block + stimulus + outcome, data = df) + 1
  # Proportion correct within each block/stimulus cell. drop = FALSE keeps the
  # table three-dimensional, so a single block or stimulus still converts to a
  # long data frame below.
  p_correct <- prop.table(counts, margin = c(1, 2))[, , "1", drop = FALSE]
  emms <- as.data.frame(p_correct) |>
    group_by(block) |>
    summarize(p_corr_mean = mean(Freq), p_corr_min = min(Freq)) |>
    # as.data.frame() turns block into a factor. Converting through the labels
    # recovers the block numbers; as.numeric() alone returns the 1-based level
    # codes, which shifts every block by one in the join below.
    mutate(block = as.numeric(as.character(block)))
  df |>
    group_by(block) |>
    summarize(
      index_trial = first(trial),
      last_trial = last(trial),
      n_trials = n()
    ) |>
    inner_join(emms, by = "block")
}

In [None]:
# Apply p_corr_emm to each subject's trials and stack the per-block results.
# Later cells summarize this table per subject without regrouping it.
blocked_emms <- blocked_pecks |>
  group_by(subject) |>
  nest() |>
  mutate(emms = purrr::map(data, p_corr_emm)) |>
  select(subject, emms) |>
  unnest(cols = c(emms))

In [None]:
# One learning curve per subject: mean proportion correct against the first
# trial of each block.
blocked_emms |>
  ggplot(aes(index_trial, p_corr_mean, group = subject, color = subject)) +
  geom_line() +
  my.theme +
  no.legend

In [None]:
options(repr.plot.width = 3, repr.plot.height = 2, repr.plot.res = 300)
# Group mean +/- SE of the per-bird proportion correct, for blocks up to 20.
p <- blocked_emms |>
  filter(block <= 20) |>
  inner_join(birds, by = c(subject = "bird")) |>
  group_by(group, block) |>
  summarize(
    n_birds = n(),
    # the SE is computed before p_corr_mean is replaced by the group mean
    p_corr_se = sd(p_corr_mean) / sqrt(n_birds),
    p_corr_mean = mean(p_corr_mean)
  ) |>
  ggplot(aes(block, p_corr_mean, color = group, group = group)) +
  geom_point() +
  geom_linerange(
    aes(ymin = p_corr_mean - p_corr_se, ymax = p_corr_mean + p_corr_se)
  ) +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Trials to criterion per bird: the last trial of the first block whose mean
# proportion correct exceeds 0.8. Birds with no such block are marked as
# failed and assigned their final trial number instead.
trials_to_criterion <- blocked_emms |>
  # group explicitly instead of relying on grouping left over from the cell
  # that built blocked_emms
  group_by(subject) |>
  filter(p_corr_mean > 0.8) |>
  summarize(criterion_trial = first(last_trial)) |>
  # the right join brings back subjects that never reached criterion
  right_join(
    trials |> group_by(subject) |> summarize(final_trial = last(trial)),
    by = "subject"
  ) |>
  transmute(
    subject,
    failed = is.na(criterion_trial),
    criterion_trial = coalesce(criterion_trial, final_trial)
  ) |>
  inner_join(birds, by = c(subject = "bird"))
trials_to_criterion |> arrange(group)

In [None]:
# Cross-tabulate group against failure, also counting birds that took more
# than 10000 trials to reach criterion as failures.
xtabs(
  ~ group + failed,
  data = mutate(trials_to_criterion, failed = failed | criterion_trial > 10000)
)
# |> chisq.test()

In [None]:
options(repr.plot.width = 1.5, repr.plot.height = 2, repr.plot.res = 300)
# Trials to criterion for every bird by group, colored by sex, on a log10
# y axis.
p <- trials_to_criterion |>
  # filter(probe=="yes") |>
  ggplot(aes(group, criterion_trial)) +
  # geom_boxplot(width=0.2, outlier.size=1) +
  geom_jitter(aes(color = sex), width = 0.1) +
  # stat_summary(fun.data="mean_se") +
  scale_y_log10("Trials") +
  scale_x_discrete(NULL) +
  ggtitle("Trials to Criterion\n(80% correct)") +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF with the same size as the notebook figure
pdf(
  file = "../figures/discrim_trials_to_criterion.pdf",
  width = 1.5,
  height = 2
)
print(p)
dev.off()

In [None]:
options(repr.plot.width = 3, repr.plot.height = 2.5, repr.plot.res = 300)
# Trials to criterion against age, colored by group
p <- ggplot(trials_to_criterion, aes(age, criterion_trial)) +
  geom_point(aes(color = group))
p + theme_classic() + my.theme

In [None]:
# Rank-sum comparison of trials to criterion between groups, restricted to the
# birds that reached criterion
wilcox.test(
  criterion_trial ~ group,
  data = trials_to_criterion,
  subset = !failed
)

In [None]:
# Linear model of log10 trials to criterion on group, sex and their
# interaction, fitted to the birds that reached criterion
fm_learning <- lm(
  log10(criterion_trial) ~ group * sex,
  data = filter(trials_to_criterion, !failed)
)
joint_tests(fm_learning)

In [None]:
options(repr.plot.width = 1.8, repr.plot.height = 1.45, repr.plot.res = 450)
# Model estimates per group and sex (back-transformed, 90% interval) with the
# individual birds that reached criterion jittered on top; log10 y axis.
p <- fm_learning |>
  emmeans(~ group * sex, type = "response") |>
  confint(level = 0.90) |>
  ggplot(aes(group, response, color = sex)) +
  geom_point(position = position_dodge(width = 1), size = 1.5) +
  geom_linerange(
    aes(ymin = lower.CL, ymax = upper.CL),
    position = position_dodge(width = 1)
  ) +
  geom_jitter(
    data = trials_to_criterion |> filter(!failed),
    mapping = aes(y = criterion_trial),
    width = 0.05
  ) +
  scale_x_discrete(name = NULL) +
  scale_y_log10(name = "trials to criterion") +
  scale_color_manual(values = c("red", "blue")) +
  theme_classic() +
  my.theme
p

In [None]:
# Write the current plot `p` to a PDF
pdf(
  file = "../figures/pretraining_trials_to_criterion.pdf",
  width = 1.8,
  height = 1.5
)
print(p)
dev.off()

## Probability of non-response

Here we can take advantage of mixed-effects modeling by using multiple trials from each animal.

In [None]:
# Last 200 uncued trials of each bird (timeouts included), with bird metadata
# attached and the birds that failed to reach criterion removed.
last_block <- trials |>
  filter(str_length(lights) == 0) |>
  group_by(subject) |>
  slice_tail(n = 200) |>
  ungroup() |>
  select(subject, stimulus, noresp, stim_left) |>
  inner_join(birds, by = c(subject = "bird")) |>
  anti_join(filter(trials_to_criterion, failed), by = "subject")

In [None]:
# Mixed-effects logistic model of non-response: group, sex and their
# interaction as fixed effects, random intercepts for subject and stimulus
fm_noresp <- glmer(
  noresp ~ group * sex + (1 | subject) + (1 | stimulus),
  data = last_block,
  family = binomial
)
joint_tests(fm_noresp)

In [None]:
options(repr.plot.width = 1.2, repr.plot.height = 1.45, repr.plot.res = 450)
# Estimated probability of non-response per sex and group (90% interval)
p <- fm_noresp |>
  emmeans(~ group * sex, type = "response") |>
  confint(level = 0.90) |>
  ggplot(aes(sex, prob, color = group)) +
  # facet_wrap(~ spike) +
  geom_point(position = position_dodge(width = 0.5), size = 1.5) +
  geom_linerange(
    aes(ymin = asymp.LCL, ymax = asymp.UCL),
    position = position_dodge(width = 0.5)
  ) +
  scale_x_discrete(name = NULL) +
  scale_y_continuous(limits = c(0, 1), name = "p(noresp) [last block]") +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
options(repr.plot.width = 1.5, repr.plot.height = 1.5, repr.plot.res = 300)
# Raw proportion of timeouts per bird in the last block. This reuses the name
# p_timeout from the example learning-curve cells.
p_timeout <- last_block |>
  group_by(subject) |>
  summarize(
    n_trials = n(),
    n_timeout = sum(noresp),
    p_timeout = n_timeout / n_trials
  ) |>
  inner_join(birds, by = c(subject = "bird"))

# One point per bird, by group and colored by sex
p <- p_timeout |>
  ggplot(aes(group, p_timeout, color = sex)) +
  geom_jitter(width = 0.1) +
  # geom_boxplot(width=0.2, outlier.size=1) +
  scale_y_continuous("p(no resp) [last block]", limits = c(0, 1)) +
  scale_x_discrete(NULL) +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF with the same size as the notebook figure
pdf(file = "../figures/discrim_pnoresp.pdf", width = 1.5, height = 1.5)
print(p)
dev.off()

## Training

In [None]:
## trials - retrieved with batch/retrieve_trials
# The per-file CSVs are concatenated by the shell: column names come from the
# first line of the first file found, data rows from line 2 onward of every
# matching file. This replaces the pretraining all_trials.
header <- data.table::fread(
  cmd = 'find ../build/ -name "*_train_trials.csv" | head -n1 | xargs head -n1',
  header = TRUE
)
all_trials <- tibble(data.table::fread(
  cmd = 'find ../build/ -name "*_train_trials.csv" | xargs tail -q -n+2',
  header = FALSE
))
# Stop early if the data rows do not have the same number of columns as the
# header taken from the first file.
stopifnot(ncol(all_trials) == ncol(header))
names(all_trials) <- names(header)

In [None]:
## just the training
block_size <- 100

# Per-subject derived columns for the training data, with trials ordered by
# time. The result stays grouped by subject.
# NOTE(review): the pretraining cell orders trials by `id` (its sanity-check
# comment mentions a clock error for C197); here the order is by `time` -
# confirm this difference is intended.
trials <- all_trials |>
  group_by(subject) |>
  arrange(time) |>
  ## light-cued trials are kept here and filtered out in valid_trials below
  # |> filter(str_length(lights)==0) # filter(is.na(lights))
  mutate(
    # 1 for timeouts, 0 otherwise
    noresp = as.numeric(response == "timeout"),
    # 1 when the response was a correct left peck, or an incorrect response
    # other than a left peck; uses `correct` as loaded, before it is recoded
    stim_left = 1 - xor(response == "peck_left", correct),
    # side and accuracy are undefined (NA) on timeouts
    peck_left = ifelse(noresp, NA, as.numeric(response == "peck_left")),
    correct = ifelse(noresp, NA, correct * 1),
    # running per-subject counters
    trial = row_number(),
    tot_rewarded = cumsum(result == "feed"),
    tot_noresp = cumsum(response == "timeout"),
    ## trials are considered to be corrections if the stimulus was repeated and the previous trial was incorrect
    # NB this is NA on each subject's first trial, and whenever the stimulus
    # repeats after a trial whose `correct` is NA (a timeout)
    inferred_correction = (lag(stimulus) == stimulus & lag(!correct))
  )

# Uncued, non-correction trials without center pecks. Rows where
# inferred_correction is NA are dropped by filter() as well.
# `time` is overwritten with the per-subject index among the remaining trials.
valid_trials <- trials |>
  filter(
    str_length(lights) == 0,
    !inferred_correction,
    response != "peck_center"
  ) |>
  mutate(time = row_number())

# Number of valid trials per bird, by group
valid_trials |>
  inner_join(birds, by = c(subject = "bird")) |>
  group_by(group, subject) |>
  tally()

In [None]:
# Trial counts for each value of `lights`
xtabs(~ lights, data = trials)

In [None]:
# The first `experiment` value recorded for each probe bird, by group
valid_trials |>
  inner_join(probe_birds, by = c(subject = "bird")) |>
  group_by(group, subject) |>
  summarize(experiment = first(experiment))

### Example training

In [None]:
# Single subject used for the example training curves
example_trials <- valid_trials |> filter(subject == "C313")
head(example_trials)

In [None]:
# Trial counts per stimulus for the example bird
xtabs(~ stimulus, data = example_trials)

In [None]:
# Assign the example bird's trials to consecutive blocks of block_size trials
# and flag timeouts. Blocks are numbered from 0; because row_number() starts at
# 1, it is shifted by one before the integer division so that every full block
# holds exactly block_size trials (dividing row_number() directly leaves the
# first block one trial short).
blocked_timeouts <- example_trials |>
  filter(str_length(lights) == 0) |>
  group_by(subject) |>
  mutate(
    block = factor((row_number() - 1) %/% block_size),
    y = response == "timeout"
  )

# Timeout probability per block (90% interval, response scale), joined to the
# median trial number and size of each block. Blocks with 20 or fewer trials
# are dropped.
p_timeout <- glm(y ~ block, data = blocked_timeouts, family = binomial) |>
  emmeans(~ block) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_timeouts |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20)

In [None]:
# Responded trials of the example bird, in consecutive blocks of block_size
# trials, with y flagging left pecks. Blocks are numbered from 0; because
# row_number() starts at 1, it is shifted by one before the integer division so
# that every full block holds exactly block_size trials.
blocked_pecks <- example_trials |>
  filter(response != "timeout") |>
  filter(str_length(lights) == 0) |>
  group_by(subject) |>
  mutate(
    block = factor((row_number() - 1) %/% block_size),
    y = response == "peck_left"
  )

# Probability of a left peck per block and stim_left value (90% interval,
# response scale), joined to the median trial number and size of each block.
# Blocks with 20 or fewer trials are dropped.
p_left <- glm(y ~ block * stim_left, data = blocked_pecks, family = binomial) |>
  emmeans(~ block / stim_left) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_pecks |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20) |>
  mutate(stim_left = factor(stim_left))

In [None]:
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
# Timeout probability (single line) overlaid with the left-peck probability
# (one colored line per stim_left value), against trial number.
p <- ggplot(mapping = aes(index_trial, prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_left, aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("Prob")
p + theme_classic() + my.theme + no.legend

In [None]:
# Write the themed example-training plot to a PDF.
# NOTE(review): the file name says C294, but example_trials in this notebook is
# filtered to subject C313 - confirm which is intended.
pdf(
  file = "../figures/2ac_example_training_C294.pdf",
  width = 2,
  height = 1.25
)
print(p + theme_classic() + my.theme + no.legend)
dev.off()

### Performance in final block

In [None]:
# Last block_size responded valid trials of each probe bird, with bird
# metadata attached.
last_block <- valid_trials |>
  filter(response != "timeout") |>
  group_by(subject) |>
  slice_tail(n = block_size) |>
  ungroup() |>
  select(time, trial, subject, stimulus, correct, peck_left, stim_left) |>
  inner_join(probe_birds, by = c(subject = "bird"))

In [None]:
# Distribution across birds of the proportion correct in the last block,
# by group
last_block |>
  group_by(group, subject) |>
  summarize(p_corr = mean(correct)) |>
  ggplot(aes(group, p_corr)) +
  geom_boxplot()

In [None]:
# Mixed-effects logistic model of accuracy in the last block: group as fixed
# effect, random intercepts for subject and stimulus, bobyqa optimizer
fm_train_performance <- glmer(
  correct ~ group + (1 | subject) + (1 | stimulus),
  data = last_block,
  family = binomial,
  control = glmerControl(optimizer = "bobyqa")
)
joint_tests(fm_train_performance)

In [None]:
# Estimated probability correct per group, on the response scale
emmeans(fm_train_performance, ~ group, type = "response")

### Trials to criterion

For training, because we're only using birds that were included in the probe experiments, we just use the number of trials before the experiment was stopped.


In [None]:
# Per-bird training summary taken from the last block: the final value of the
# all-trial counter (n_trials), the final value of the valid-trial index
# (n_valid_trials), and the proportion correct. Probe-bird metadata is joined
# back on because summarize() keeps only the subject column.
training_stats <- last_block |>
  group_by(subject) |>
  summarize(
    n_trials = last(trial),
    n_valid_trials = last(time),
    p_corr = mean(correct)
  ) |>
  inner_join(probe_birds, by = c(subject = "bird"))

In [None]:
options(repr.plot.width = 1, repr.plot.height = 1.5, repr.plot.res = 450)
# Number of training trials per bird by group, with a crossbar at the group
# median; log10 y axis with labeled breaks at 5k, 7k and 10k.
p <- training_stats |>
  ggplot(aes(group, n_trials, color = group)) +
  # geom_boxplot(width=0.2, outlier.size=1) +
  geom_point() +
  stat_summary(fun = median, geom = "crossbar", width = 0.2, linewidth = 0.2) +
  scale_y_log10(
    "Trials to Criterion",
    breaks = c(5000, 7000, 10000),
    labels = c("5k", "7k", "10k")
  ) +
  scale_x_discrete(NULL) +
  theme_classic() +
  my.theme +
  no.legend
p

In [None]:
# Write the current plot `p` to a PDF with the same size as the notebook figure
pdf(file = "../figures/train_ntrials.pdf", width = 1, height = 1.5)
print(p)
dev.off()

In [None]:
# Rank-sum comparison of the number of training trials between groups
wilcox.test(n_trials ~ group, data = training_stats)