In [None]:
## Attach a package given its name as a character string, with conflict warnings
## turned off and library()'s quiet mode on. Uses TRUE/FALSE rather than T/F,
## which are ordinary variables that can be reassigned.
import <- function(pkg) { library(pkg, warn.conflicts=FALSE, quietly=TRUE, character.only=TRUE) }
import("repr")
import("stringr")
import("tidyr")
import("dplyr")
import("ggplot2")
import("lme4")
import("emmeans")

In [None]:
# Notebook display defaults: printed matrix size limits and the default
# inline figure size / resolution.
options(
  repr.matrix.max.cols = 15,
  repr.matrix.max.rows = 20,
  repr.plot.width = 2,
  repr.plot.height = 1.25,
  repr.plot.res = 300
)

# Shared theme fragment for small figures: small text, thin axis lines and
# ticks, facet strips placed outside the panels with no background.
my.theme <- theme(
  legend.text = element_text(size = 5),
  legend.title = element_text(size = 6),
  plot.title = element_text(size = 8, hjust = 0.5),
  axis.line = element_line(linewidth = 0.25),
  axis.ticks = element_line(linewidth = 0.25),
  axis.title = element_text(size = 8),
  axis.text = element_text(size = 6),
  strip.placement = "outside",
  strip.text = element_text(size = 8),
  strip.background = element_blank()
)

# Theme fragment that hides the legend.
no.legend <- theme(legend.position = "none")

# Global geom defaults: hollow (white-filled, shape 21) points and thin lines.
update_geom_defaults("point", list(fill = "white", shape = 21, size = 1.1))
update_geom_defaults("line", list(linewidth = 0.25))

## Load trial data

In [None]:
## metadata: only birds whose `behavior` column is "yes"
birds <- filter(data.table::fread("../inputs/bird_metadata.csv"), behavior == "yes")

In [None]:
## trials - retrieved with batch/retrieve_trials
## Column names come from the first line of the first matching file; every
## matching file is then concatenated without its header line, so all files
## are assumed to share the same column order.
header <- data.table::fread(
  cmd = 'find ../build/ -name "*pretrain*_trials.csv" | head -n1 | xargs head -n1',
  header = TRUE
)
all_trials <- data.table::fread(
  cmd = 'find ../build/ -name "*pretrain*_trials.csv" | xargs tail -q -n+2',
  header = FALSE
) |>
  tibble() |>
  setNames(names(header))

In [None]:
# sanity check - each stimulus/response should only have one consequence:
# count the distinct `correct` values for every subject/stimulus/response
# combination and list the combinations with more than one.
all_trials |>
  group_by(subject, stimulus, response) |>
  summarize(n = n_distinct(correct), .groups = "drop_last") |>
  filter(n > 1)

In [None]:
# sanity check: no big skips in the dates, which could indicate trials recorded with the wrong subject
# NB the gap with C197 is due to a clock error. The missing days are put in the right place using trial id.
# Plots the number of trials per calendar day, one free-scaled panel per subject.
options(repr.plot.width=10, repr.plot.height=5, repr.plot.res = 300)
(
    all_trials
    |> mutate(date=lubridate::date(time))
    |> group_by(subject, date)
    |> tally()
    |> ggplot(aes(date, n))
    ## argument name spelled out: `scale=` only worked via partial matching to `scales`
    + facet_wrap(~ subject, scales="free")
    + geom_point()
)

In [None]:
## just the pretraining
block_size <- 100
trials <- (
    all_trials
    |> group_by(subject)
    ## order by trial id rather than timestamp: C197 had a clock error, and id
    ## puts its trials in the right place
    |> arrange(id)
    ## remove all trials cued with a light
    #|> filter(str_length(lights)==0) # filter(is.na(lights))
    |> mutate(noresp=(response == "timeout") * 1,
              ## stimulus side reconstructed from the response and its outcome: a correct
              ## left peck or an incorrect non-left peck => left stimulus (1). Only meaningful
              ## for side-key pecks; for timeouts it depends on `correct` alone.
              stim_left=1 - xor(response=="peck_left", correct),
              peck_left=ifelse(noresp, NA, (response == "peck_left") * 1),
              ## recoded to 0/1, NA for timeouts
              correct=ifelse(noresp, NA, correct * 1),
              trial=row_number(),
              tot_rewarded=cumsum(result=="feed"),
              tot_noresp=cumsum(response=="timeout"),
              ## trials are considered to be corrections if the stimulus was repeated and the previous trial was incorrect.
              ## lag() has no predecessor for each bird's first trial; the row_number() guard makes that trial
              ## FALSE instead of NA so the `!inferred_correction` filter below no longer silently drops it.
              ## NOTE(review): `correct` is NA after a timeout, so a trial repeating a timed-out stimulus is NA
              ## here and is dropped by that filter -- confirm this is intended.
              inferred_correction=(row_number() > 1 & lag(stimulus)==stimulus & lag(!correct)))
)
valid_trials <- (
    trials
    ## light-free, non-correction trials without center pecks (filter() also drops NA rows)
    |> filter(str_length(lights)==0, !inferred_correction, response != "peck_center")
    ## NB: overwrites the original `time` column with a per-bird running index of valid trials
    |> mutate(time=row_number())
)

valid_trials |> inner_join(birds, by=c(subject="bird")) |> group_by(group, subject) |> tally()

## Example learning curves

In [None]:
## valid trials of the example bird used for the learning curves below
example_trials <- valid_trials |> filter(subject == "C313")
example_trials |> head()

In [None]:
## Example bird: probability of no response (timeout) per block of `block_size`
## trials, from a binomial GLM with block as a factor, with 90% CIs on the
## probability scale.
blocked_timeouts <- (
    example_trials
    |> filter(str_length(lights)==0)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=factor((row_number() - 1) %/% block_size), y=response=="timeout")
)

p_timeout <- (
    glm(y ~ block, data=blocked_timeouts, family=binomial)
    |> emmeans(~ block)
    |> confint(level=0.90, type="response")
    |> inner_join(blocked_timeouts |> group_by(subject, block) |> summarize(index_trial=median(trial), n_trials=n(), .groups="drop"), by="block")
    ## drop sparse blocks (e.g. the partial final block)
    |> filter(n_trials > 20)
)

In [None]:
## Example bird: probability of pecking the left key per block, separately for
## each stimulus side (stim_left 0/1), from non-timeout trials, with 90% CIs.
blocked_pecks <- (
    example_trials
    |> filter(response != "timeout")
    |> filter(str_length(lights)==0)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=factor((row_number() - 1) %/% block_size), y=response=="peck_left")
)

p_left <- (
    glm(y ~ block*stim_left, data=blocked_pecks, family=binomial)
    |> emmeans(~ block/stim_left)
    |> confint(level=0.90, type="response")
    |> inner_join(blocked_pecks |> group_by(subject, block) |> summarize(index_trial=median(trial), n_trials=n(), .groups="drop"), by="block")
    ## drop sparse blocks (e.g. the partial final block)
    |> filter(n_trials > 20)
    ## factor so the two stimulus sides get discrete colors when plotted
    |> mutate(stim_left=factor(stim_left))
)

In [None]:
## Example learning curves: p(no response) as a single line, p(peck left) as
## one colored line per stimulus side.
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
p <- ggplot(mapping = aes(x = index_trial, y = prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_left, mapping = aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("Prob")
p + theme_classic() + my.theme + no.legend

In [None]:
## Save the example learning curves (`p` from the previous cell) to a PDF named
## after the example bird. The name used to be hard-coded to C294 even though
## `example_trials` holds a different subject, so it is derived from the data.
example_subject <- unique(example_trials$subject)
stopifnot(length(example_subject) == 1)
pdf(sprintf("../figures/2ac_example_%s.pdf", example_subject), width=2, height=1.25)
print(p + theme_classic() + my.theme + no.legend)
dev.off()

In [None]:
## Example bird: probability of a correct response per block and stimulus
## side, with 90% CIs; sparse blocks (<= 20 trials) are dropped.
p_correct <- glm(correct ~ factor(block) * stim_left, data = blocked_pecks, family = binomial) |>
  emmeans(~ block / stim_left) |>
  confint(level = 0.90, type = "response") |>
  inner_join(
    blocked_pecks |>
      group_by(subject, block) |>
      summarize(index_trial = median(trial), n_trials = n()),
    by = "block"
  ) |>
  filter(n_trials > 20) |>
  mutate(stim_left = factor(stim_left))

In [None]:
## Example curves: p(no response) as a single line, p(correct) as one colored
## line per stimulus side.
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
p <- ggplot(mapping = aes(x = index_trial, y = prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_correct, mapping = aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("Prob")
p + theme_classic() + my.theme + no.legend

## Trials to criterion

In [None]:
## Trials to criterion: split each bird's non-timeout valid trials into blocks of
## `block_size` and report the first trial of the first block with > 80% correct.
## NOTE(review): birds that never exceed 80% drop out of the table (and out of the
## group comparisons below) rather than being censored -- confirm that is intended.
blocked_pecks <- (
    valid_trials
    |> filter(response != "timeout")
    ## already applied when building valid_trials; kept as guards
    |> filter(str_length(lights)==0)
    |> filter(!inferred_correction)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=(row_number() - 1) %/% block_size)
)
trials_to_criterion <- (
    blocked_pecks
    |> group_by(subject, block)
    |> summarize(index_trial=first(trial), n_trials=n(), n_correct=sum(correct), p_correct=n_correct/n_trials, .groups="drop_last")
    ## ignore sparse blocks (the partial final block), as the learning-curve cells do,
    ## so a handful of trials cannot count as reaching criterion
    |> filter(n_trials > 20, p_correct > 0.8)
    |> summarize(criterion_trial=first(index_trial))
    |> inner_join(birds, by=c(subject="bird"))
)
trials_to_criterion |> arrange(group)

In [None]:
# exclude C197 for now - accuracy was above 80% but had massive key bias
trials_to_criterion <- trials_to_criterion |> filter(subject != "C197")

In [None]:
## Trials to criterion by group (box plot)
options(repr.plot.width = 1.5, repr.plot.height = 2, repr.plot.res = 300)
p <- ggplot(trials_to_criterion, aes(x = group, y = criterion_trial)) +
  geom_boxplot(width = 0.2) +
  # geom_point(aes(color = sex)) +
  # stat_summary(fun.data = "mean_se") +
  scale_y_continuous("Trials") +
  scale_x_discrete(NULL) +
  ggtitle("Trials to Criterion\n(80% correct)")
p + theme_classic() + my.theme

In [None]:
## Trials to criterion against age, colored by group
options(repr.plot.width = 3, repr.plot.height = 2.5, repr.plot.res = 300)
p <- ggplot(trials_to_criterion, aes(x = age, y = criterion_trial)) +
  geom_point(mapping = aes(color = group))
p + theme_classic() + my.theme

In [None]:
## Two-sample t-test of trials to criterion between the groups (t.test's formula
## interface requires `group` to have exactly two levels)
t.test(criterion_trial ~ group, data = trials_to_criterion)

In [None]:
## p(no response) in each bird's last block of light-free trials, joined to the
## bird metadata. Built from `trials`, so inferred corrections and center pecks
## are included (NOTE(review): confirm this is intended rather than valid_trials).
blocked_timeouts <- (
    trials
    |> filter(str_length(lights)==0)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=(row_number() - 1) %/% block_size)
)
p_timeout <- (
    blocked_timeouts
    |> group_by(subject, block)
    |> summarize(index_trial=first(trial), n_trials=n(), n_timeout=sum(response=="timeout"), p_timeout=n_timeout/n_trials, .groups="drop_last")
    ## skip sparse blocks so "last block" is the last one with enough trials to estimate p
    |> filter(n_trials > 20)
    |> arrange(block)
    |> summarize(p_timeout=last(p_timeout))
    |> inner_join(birds, by=c(subject="bird"))
    |> arrange(group)
)

In [None]:
## Last-block p(no response) by group (box plot, y fixed to [0, 1])
p <- ggplot(p_timeout, aes(x = group, y = p_timeout)) +
  geom_boxplot(width = 0.2, outlier.size = 1) +
  scale_y_continuous("Prob", limits = c(0, 1)) +
  scale_x_discrete(NULL) +
  ggtitle("p(no resp)\n[last block]")
p + theme_classic() + my.theme

In [None]:
## Wilcoxon rank-sum test of last-block p(no response) between the groups
wilcox.test(p_timeout ~ group, data = p_timeout)

## Training

In [None]:
## trials - retrieved with batch/retrieve_trials
## Column names come from the first line of the first matching file; every
## matching file is then concatenated without its header line, so all files
## are assumed to share the same column order.
header <- data.table::fread(
  cmd = 'find ../build/ -name "*_train_trials.csv" | head -n1 | xargs head -n1',
  header = TRUE
)
all_trials <- data.table::fread(
  cmd = 'find ../build/ -name "*_train_trials.csv" | xargs tail -q -n+2',
  header = FALSE
) |>
  tibble() |>
  setNames(names(header))

In [None]:
## just the training
block_size <- 100
trials <- (
    all_trials
    |> group_by(subject)
    ## order each bird's trials by timestamp
    ## NOTE(review): the pretraining cell orders by trial `id` instead (C197 clock error);
    ## confirm `time` is reliable for every bird in training
    |> arrange(time)
    ## remove all trials cued with a light
    #|> filter(str_length(lights)==0) # filter(is.na(lights))
    |> mutate(noresp=(response == "timeout") * 1,
              ## stimulus side reconstructed from the response and its outcome: a correct
              ## left peck or an incorrect non-left peck => left stimulus (1). Only meaningful
              ## for side-key pecks; for timeouts it depends on `correct` alone.
              stim_left=1 - xor(response=="peck_left", correct),
              peck_left=ifelse(noresp, NA, (response == "peck_left") * 1),
              ## recoded to 0/1, NA for timeouts
              correct=ifelse(noresp, NA, correct * 1),
              trial=row_number(),
              tot_rewarded=cumsum(result=="feed"),
              tot_noresp=cumsum(response=="timeout"),
              ## trials are considered to be corrections if the stimulus was repeated and the previous trial was incorrect.
              ## lag() has no predecessor for each bird's first trial; the row_number() guard makes that trial
              ## FALSE instead of NA so the `!inferred_correction` filter below no longer silently drops it.
              ## NOTE(review): `correct` is NA after a timeout, so a trial repeating a timed-out stimulus is NA
              ## here and is dropped by that filter -- confirm this is intended.
              inferred_correction=(row_number() > 1 & lag(stimulus)==stimulus & lag(!correct)))
)

valid_trials <- (
    trials
    ## light-free, non-correction trials without center pecks (filter() also drops NA rows)
    |> filter(str_length(lights)==0, !inferred_correction, response != "peck_center")
    ## NB: overwrites the original `time` column with a per-bird running index of valid trials
    |> mutate(time=row_number())
)

valid_trials |> inner_join(birds, by=c(subject="bird")) |> group_by(group, subject) |> tally()

In [None]:
## valid training trials of the example bird used for the learning curves below
example_trials <- valid_trials |> filter(subject == "C235")
example_trials |> head()

In [None]:
## Example bird (training): probability of no response (timeout) per block of
## `block_size` trials, from a binomial GLM with block as a factor, with 90% CIs
## on the probability scale.
blocked_timeouts <- (
    example_trials
    |> filter(str_length(lights)==0)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=factor((row_number() - 1) %/% block_size), y=response=="timeout")
)

p_timeout <- (
    glm(y ~ block, data=blocked_timeouts, family=binomial)
    |> emmeans(~ block)
    |> confint(level=0.90, type="response")
    |> inner_join(blocked_timeouts |> group_by(subject, block) |> summarize(index_trial=median(trial), n_trials=n(), .groups="drop"), by="block")
    ## drop sparse blocks (e.g. the partial final block)
    |> filter(n_trials > 20)
)

In [None]:
## Example bird (training): probability of pecking the left key per block,
## separately for each stimulus side (stim_left 0/1), from non-timeout trials,
## with 90% CIs.
blocked_pecks <- (
    example_trials
    |> filter(response != "timeout")
    |> filter(str_length(lights)==0)
    |> group_by(subject)
    ## (row_number() - 1) %/% block_size gives every full block exactly block_size trials;
    ## floor(row_number() / block_size) left only block_size - 1 trials in block 0
    |> mutate(block=factor((row_number() - 1) %/% block_size), y=response=="peck_left")
)

p_left <- (
    glm(y ~ block*stim_left, data=blocked_pecks, family=binomial)
    |> emmeans(~ block/stim_left)
    |> confint(level=0.90, type="response")
    |> inner_join(blocked_pecks |> group_by(subject, block) |> summarize(index_trial=median(trial), n_trials=n(), .groups="drop"), by="block")
    ## drop sparse blocks (e.g. the partial final block)
    |> filter(n_trials > 20)
    ## factor so the two stimulus sides get discrete colors when plotted
    |> mutate(stim_left=factor(stim_left))
)

In [None]:
## Training example learning curves: p(no response) as a single line,
## p(peck left) as one colored line per stimulus side.
options(repr.plot.width = 2, repr.plot.height = 1.25, repr.plot.res = 300)
p <- ggplot(mapping = aes(x = index_trial, y = prob)) +
  geom_line(data = p_timeout) +
  geom_line(data = p_left, mapping = aes(color = stim_left, group = stim_left)) +
  scale_x_continuous("Trial") +
  scale_y_continuous("Prob")
p + theme_classic() + my.theme + no.legend