# State-space learning model

This notebook summarizes results from the state-space model (SSM) fits for pretraining and training (2AC).

In [None]:
# Attach a package by name, suppressing startup messages and conflict warnings.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable variables.
import <- function(pkg) {
    library(pkg, warn.conflicts=FALSE, quietly=TRUE, character.only=TRUE)
}
import("repr")
import("stringr")
import("tidyr")
import("dplyr")
import("ggplot2")
import("diagis")
import("bssm")

In [None]:
# Notebook display settings: table truncation plus default figure size/resolution.
options(
    repr.matrix.max.cols = 15,
    repr.matrix.max.rows = 20,
    repr.plot.width = 4,
    repr.plot.height = 2.5,
    repr.plot.res = 300
)

# Compact text and thin axes for small figures; add after theme_classic().
my.theme <- theme(
    legend.text = element_text(size = 5),
    legend.title = element_text(size = 6),
    plot.title = element_text(size = 8, hjust = 0.5),
    axis.line = element_line(linewidth = 0.25),
    axis.ticks = element_line(linewidth = 0.25),
    axis.title = element_text(size = 8),
    axis.text = element_text(size = 6),
    strip.placement = "outside",
    strip.text = element_text(size = 8),
    strip.background = element_blank()
)
# Add to a plot to hide its legend.
no.legend <- theme(legend.position = "none")
# Default geoms: white-filled shape-21 points and thin lines.
update_geom_defaults("point", list(fill = "white", shape = 21, size = 1.1))
update_geom_defaults("line", list(linewidth = 0.25))

In [None]:
# Log-odds of probability p (vectorized). qlogis() is base R's logit.
logit <- function(p) qlogis(p)
# Inverse logit (vectorized). plogis() stays finite for large |x|, whereas
# exp(x) / (1 + exp(x)) returns NaN once exp(x) overflows to Inf.
invlogit <- function(x) plogis(x)
# Position of the last TRUE in logical vector x, or 1 when x has no TRUE.
# (A plain if/else replaces the former scalar ifelse(); return types are the same:
# double 1 when nothing is TRUE, integer position otherwise.)
last_true <- function(x) {
    lt <- unname(tail(which(x), 1))
    if (length(lt) == 0) 1 else lt
}
# Posterior-style summary of samples x: mean and 5%/95% quantiles, named
# exactly "mean", "lwr", "upr" (names=FALSE stops quantile() from appending
# "5%"/"95%" to the element names).
summarize_samples <- function(x) {
    c(mean=mean(x),
      lwr=quantile(x, 0.05, names=FALSE),
      upr=quantile(x, 0.95, names=FALSE))
}
# Criterion used for the "trials to criterion" analyses below.
crit_name <- "io_1"

## Pretraining

### Example bird

In [None]:
# Saved SSM summary for the example bird's pretraining fit.
results <- readRDS(file = file.path("..", "build", "C269_pretrain_ssm_summary.rds"))

In [None]:
# Stack every summary table in `results` (all elements except `data` and
# `corr`) into one long table with columns time, variable, mean, lwr, upr.
summary_bird <- results |>
    purrr::discard_at(c("data", "corr")) |>
    purrr::map(\(tbl) select(tbl, time, variable, mean, lwr, upr)) |>
    purrr::list_rbind()

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Example bird: `level` estimate (mean line, lwr-upr band) by trial, with
# no-response trials marked as ticks at y = 1.05.
p_nr <- summary_bird |>
    filter(variable == "level") |>
    inner_join(results$data, by = "time") |>
    ggplot(mapping = aes(x = trial)) +
    geom_point(aes(y = ifelse(noresp, 1.05, NA)), shape = "|") +
    geom_line(aes(y = mean)) +
    geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.25) +
    scale_y_continuous(name = "prob", breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)) +
    theme_classic() +
    no.legend
p_nr

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Example bird: level (black), p_left (red) and p_right (blue) with bands,
# plus per-trial ticks: no-response at y = 1.125, and responses placed by
# peck_left and nudged by stim_left (presumably both 0/1 -- TODO confirm),
# colored by stimulus.
p_left <- summary_bird |>
    filter(variable %in% c("level", "p_left", "p_right")) |>
    pivot_wider(names_from = "variable", values_from = c("mean", "upr", "lwr")) |>
    inner_join(results$data, by = "time") |>
    mutate(rpos = (peck_left - 0.5) * 1.1 + 0.5, spos = stim_left * 0.05 - 0.025) |>
    ggplot(mapping = aes(x = trial)) +
    geom_point(aes(y = ifelse(noresp, 1.125, NA)), shape = "|") +
    geom_point(aes(y = rpos + spos, color = stimulus), shape = "|") +
    geom_line(aes(y = mean_level), color = "black") +
    geom_line(aes(y = mean_p_left), color = "red") +
    geom_ribbon(aes(ymin = lwr_p_left, ymax = upr_p_left), fill = "red", alpha = 0.25) +
    geom_line(aes(y = mean_p_right), color = "blue") +
    geom_ribbon(aes(ymin = lwr_p_right, ymax = upr_p_right), fill = "blue", alpha = 0.25) +
    scale_y_continuous(name = "prob", breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)) +
    theme_classic() +
    no.legend
p_left

In [None]:
# Example bird: discrim and bias estimates (mean, lwr-upr band) in separate
# facets, with a dotted reference line at 1.
p_discrim <- summary_bird |>
    filter(variable %in% c("discrim", "bias")) |>
    inner_join(results$data, by = "time") |>
    ggplot(mapping = aes(x = trial)) +
    facet_grid(rows = vars(variable)) +
    geom_line(aes(y = mean)) +
    geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.25) +
    geom_hline(yintercept = 1, linetype = "dotted") +
    scale_y_continuous(name = "LOR") +
    theme_classic() +
    no.legend
p_discrim

### Group data

In [None]:
## metadata: one table of birds, joined to the model summaries on `bird`
birds <- data.table::fread(file.path("..", "inputs", "bird_metadata.csv"))

In [None]:
# Load one bird's saved SSM summary as a single long table.
#
# subject: bird id used in the file name "../build/<subject>_<stage>_ssm_summary.rds".
# stage:   which fit to load ("pretrain" by default, so single-argument calls
#          such as purrr::map(birds$bird, load_discrim_summary) are unchanged;
#          "train" loads the training fit).
# Returns: all summary tables except `data` and `corr`, reduced to
#          time/variable/mean/lwr/upr, stacked, and inner-joined to the bird's
#          `data` table by `time`; or NULL when the file does not exist, which
#          purrr::list_rbind() drops.
load_discrim_summary <- function(subject, stage="pretrain") {
    summary_file <- file.path("..", "build", paste0(subject, "_", stage, "_ssm_summary.rds"))
    if (!file.exists(summary_file)) {
        return(NULL)
    }
    summaries <- readRDS(summary_file)
    (
        summaries
        |> purrr::discard_at(c("data", "corr"))
        |> purrr::map(\(tbl) select(tbl, time, variable, mean, lwr, upr))
        |> purrr::list_rbind()
        |> inner_join(summaries$data, by="time")
    )
}

In [None]:
# One long table for all birds: birds without a summary file yield NULL and
# are dropped by list_rbind(); bird metadata is attached by subject == bird.
summary_all <- birds$bird |>
    purrr::map(load_discrim_summary) |>
    purrr::list_rbind() |>
    inner_join(birds, by = c(subject = "bird"))

#### Plots of individual learning curves

Mostly useful for inspection.

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Per-bird `level` trajectories (posterior mean), one facet per group.
p_noresp <- summary_all |>
    filter(variable == "level") |>
    ggplot(mapping = aes(x = trial)) +
    geom_line(aes(y = mean, group = subject, color = subject)) +
    facet_grid(rows = vars(group)) +
    # geom_ribbon(aes(ymin = lwr, ymax = upr, fill = subject), alpha = 0.25) +
    # geom_hline(yintercept = 0, linetype = "dotted") +
    scale_y_continuous(name = "p(no_resp)") +
    theme_classic() +
    no.legend
p_noresp

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Per-bird discrimination trajectories (posterior mean) by group, with a
# dotted line at 0.
p_discrim <- summary_all |>
    filter(variable == "discrim") |>
    ggplot(mapping = aes(x = trial)) +
    geom_line(aes(y = mean, group = subject, color = subject)) +
    facet_grid(rows = vars(group)) +
    # geom_ribbon(aes(ymin = lwr, ymax = upr, fill = subject), alpha = 0.25) +
    geom_hline(yintercept = 0, linetype = "dotted") +
    scale_y_continuous(name = "discrim (LOR)") +
    theme_classic() +
    no.legend
p_discrim

In [None]:
options(repr.plot.width=4, repr.plot.height=2.5, repr.plot.res = 300)
# Per-bird bias trajectories (posterior mean) by group, dotted line at 0.
# Stored as p_bias so it does not overwrite the discrimination plot p_discrim.
p_bias <- (
    summary_all
    |> filter(variable=="bias")
    |> ggplot(aes(trial))
    + geom_line(mapping=aes(y=mean, group=subject, color=subject))
    + facet_grid(vars(group))
    #+ geom_ribbon(mapping=aes(ymin=lwr, ymax=upr, fill=subject), alpha=0.25)
    + geom_hline(yintercept=0, linetype="dotted")
    + scale_y_continuous("Bias (log odds)")
    + theme_classic() + no.legend
)
p_bias

#### Trials to criterion

In [None]:
# Per-bird criterion times from the lower bound (lwr) of the discrim estimate:
#   io_chance: last time at which lwr < 0
#   io_1:      last time at which lwr < 1.0 (i.e. criterion not yet met)
#   success:   whether lwr ever reached 1.0
# last_true() returns a position within the bird's rows, so rows are sorted by
# time and the position is translated to the actual `time` value instead of
# assuming time == row number. If the condition is never TRUE, the bird's
# first time is used (last_true() returns 1).
summary_discrim <- filter(summary_all, variable=="discrim")
summary_criterion <- (
    summary_discrim
    |> arrange(subject, time)
    |> mutate(io_chance=lwr < 0, io_1=lwr < 1.0)
    |> group_by(subject)
    |> summarize(
        io_chance=time[last_true(io_chance)],
        io_1=time[last_true(io_1)],
        success=any(lwr >= 1.0)
    )
    |> pivot_longer(starts_with("io_"), values_to="time")
    |> inner_join(summary_discrim, by=c("subject", "time"))
    |> select(subject, name, time, trial, tot_rewarded, tot_noresp, success, uuid, sex, group, age, sibling)
)

In [None]:
# Criterion rows for each bird, ordered by group (uncomment xtabs for counts).
summary_criterion |>
    filter(name == crit_name) |>
    arrange(group) # |> xtabs(~ group, data = _)

In [None]:
options(repr.plot.width=1.5, repr.plot.height=2, repr.plot.res = 300)
# Trial at criterion by group; crosses = birds that never met criterion,
# filled circles = birds that did. Shape values are named so the mapping holds
# even if only one success level is present (unnamed values are assigned by
# level order, which would draw an all-successful set as crosses).
(
    summary_criterion
    |> filter(name==crit_name)
    |> ggplot(aes(group, trial))
    + geom_point(aes(shape=success))
    + scale_shape_manual(values=c("FALSE"=4, "TRUE"=16))
    + theme_classic() + my.theme + no.legend
)

In [None]:
options(repr.plot.width = 1.5, repr.plot.height = 2, repr.plot.res = 300)
# Distribution of trials to criterion by group, successful birds only.
summary_criterion |>
    filter(name == crit_name, success == TRUE) |>
    ggplot(mapping = aes(x = group, y = trial)) +
    geom_boxplot(width = 0.2) +
    theme_classic() +
    my.theme +
    no.legend

In [None]:
options(repr.plot.width=2.5, repr.plot.height=2, repr.plot.res = 300)
# Trial at criterion by sibling, colored by group; crosses = never met
# criterion, filled circles = met criterion. Shape values are named so the
# mapping does not depend on which success levels happen to be present.
(
    summary_criterion
    |> filter(name==crit_name)
    |> ggplot(aes(sibling, trial))
    + geom_point(aes(color=group, shape=success))
    + scale_shape_manual(values=c("FALSE"=4, "TRUE"=16))
    + theme_classic() + my.theme
)

In [None]:
options(repr.plot.width=5, repr.plot.height=2, repr.plot.res = 300)
# Criterion time and cumulative counts at criterion (trial, tot_rewarded,
# tot_noresp) vs. sibling, one facet per measure. Columns are named rather
# than selected by position (previously 3:6) so the plot does not silently
# change if summary_criterion's column order changes.
(
    summary_criterion
    |> filter(name==crit_name)
    |> pivot_longer(cols=c(time, trial, tot_rewarded, tot_noresp), names_to="count", values_to="trials")
    |> ggplot(aes(sibling, trials))
    + facet_grid(~ count)
    + geom_point(aes(color=group))
    + theme_classic() + my.theme
)

In [None]:
# Welch t-test of trials to criterion between groups, successful birds only
# (t.test() with a formula requires `group` to have exactly two levels).
# The filter condition must be a logical expression: `success=TRUE` is parsed
# as a named argument, which dplyr::filter() rejects with an error.
(
    summary_criterion
    |> filter(name==crit_name, success)
    |> t.test(trial ~ group, data=_)
)

#### Average discrimination

Excluding C197 and C250 (C250 failed to reach criterion), what is the average discrimination at trial 2799, when the fastest bird succeeded (i.e., the earliest final trial among the remaining birds)?

In [None]:
# Final trial reached by each bird, excluding C197 and C250.
max_trials <- (
    summary_discrim
    |> filter(!subject %in% c("C197", "C250"))
    |> group_by(subject)
    |> summarize(trial=max(trial))
)
# Each remaining bird's discrimination summary at its first trial at or after
# the earliest final trial in max_trials (when the first bird was removed).
# slice_min() picks the smallest qualifying trial regardless of row order
# (row_number() == 1 relied on rows being sorted by trial), and ungroup()
# returns a plain data frame for the plot and t.test() that follow.
discrim_by_first_removal <- (
    summary_discrim
    |> filter(!subject %in% c("C197", "C250"))
    |> filter(trial >= min(max_trials$trial))
    |> group_by(subject)
    |> slice_min(trial, n=1, with_ties=FALSE)
    |> ungroup()
)

In [None]:
options(repr.plot.width = 1.5, repr.plot.height = 2, repr.plot.res = 300)
# Discrimination (posterior mean) at the first-removal trial, one point per bird.
discrim_by_first_removal |>
    ggplot(mapping = aes(x = group, y = mean)) +
    geom_point() +
    ylab("Discrim (LOR)") +
    theme_classic() +
    my.theme +
    no.legend

In [None]:
# Welch t-test of discrimination at the first-removal trial between groups.
t.test(mean ~ group, data = discrim_by_first_removal)

#### Non-response probability

Average the estimated probability of non-response (the `level` variable) over the last 100 trials before each bird reaches criterion.

In [None]:
# Per-bird mean of the `level` estimate (posterior mean) over the 100 time
# steps before the bird's criterion time (crit_time - time in 1..100),
# joined to the bird metadata.
# last_trial_stats is kept as a data frame: the pipe previously continued into
# ggplot(), so it held a plot and the later ggplot()/t.test() calls on it failed.
# The window was also 0 < d < 100, i.e. only 99 steps.
last_trial_stats <- (
    summary_all
    |> filter(variable=="level")
    |> inner_join(
        summary_criterion |> filter(name==crit_name) |> select(subject, crit_time=time),
        by="subject"
    )
    |> filter(crit_time - time > 0, crit_time - time <= 100)
    |> group_by(subject)
    |> summarize(p_noresp=mean(mean))
    |> inner_join(birds, by=c(subject="bird"))
)
(
    last_trial_stats
    |> ggplot(aes(group, p_noresp))
    + geom_boxplot(width=0.2)
)

## Training

In [None]:
options(repr.plot.width = 1.5, repr.plot.height = 2, repr.plot.res = 300)
# Pre-criterion non-response estimate, one point per bird, by group.
last_trial_stats |>
    ggplot(mapping = aes(x = group, y = p_noresp)) +
    geom_point() +
    ylab("p(no resp)") +
    theme_classic() +
    my.theme +
    no.legend

In [None]:
# Welch t-test of the pre-criterion non-response estimate between groups.
t.test(p_noresp ~ group, data = last_trial_stats)

### Example bird

In [None]:
# Saved SSM summary for the example bird's training fit.
results <- readRDS(file = file.path("..", "build", "Rb279_train_ssm_summary.rds"))

In [None]:
# Stack every summary table in `results` (all elements except `data` and
# `corr`) into one long time/variable/mean/lwr/upr table.
summary_bird <- results |>
    purrr::discard_at(c("data", "corr")) |>
    purrr::map(\(part) select(part, time, variable, mean, lwr, upr)) |>
    purrr::list_rbind()

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Example bird (training): `level` estimate with its band, no-response trials
# marked as ticks at y = 1.05.
p_nr <- summary_bird |>
    filter(variable == "level") |>
    inner_join(results$data, by = "time") |>
    ggplot(mapping = aes(x = trial)) +
    geom_point(aes(y = ifelse(noresp, 1.05, NA)), shape = "|") +
    geom_line(aes(y = mean)) +
    geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.25) +
    scale_y_continuous(name = "prob", breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)) +
    theme_classic() +
    no.legend
p_nr

In [None]:
# Example bird (training): level (black), p_left (blue) and p_right (red)
# with bands, plus per-trial ticks for no-response (y = 1.125) and responses
# placed by peck_left / stim_left (presumably 0/1 -- TODO confirm), colored
# by stim_left.
p_left <- summary_bird |>
    filter(variable %in% c("level", "p_left", "p_right")) |>
    pivot_wider(names_from = "variable", values_from = c("mean", "upr", "lwr")) |>
    inner_join(results$data, by = "time") |>
    mutate(rpos = (peck_left - 0.5) * 1.1 + 0.5, spos = stim_left * 0.05 - 0.025) |>
    ggplot(mapping = aes(x = trial)) +
    geom_point(aes(y = ifelse(noresp, 1.125, NA)), shape = "|") +
    geom_point(aes(y = rpos + spos, color = factor(stim_left)), shape = "|") +
    geom_line(aes(y = mean_level), color = "black") +
    geom_line(aes(y = mean_p_left), color = "blue") +
    geom_ribbon(aes(ymin = lwr_p_left, ymax = upr_p_left), fill = "blue", alpha = 0.25) +
    geom_line(aes(y = mean_p_right), color = "red") +
    geom_ribbon(aes(ymin = lwr_p_right, ymax = upr_p_right), fill = "red", alpha = 0.25) +
    scale_y_continuous(name = "prob", breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)) +
    theme_classic() +
    no.legend
p_left

In [None]:
# Example bird (training): discrim and bias estimates with bands, one facet
# each, dotted line at 0.
p_discrim <- summary_bird |>
    filter(variable %in% c("discrim", "bias")) |>
    inner_join(results$data, by = "time") |>
    ggplot(mapping = aes(x = trial)) +
    facet_grid(rows = vars(variable)) +
    geom_line(aes(y = mean)) +
    geom_ribbon(aes(ymin = lwr, ymax = upr), alpha = 0.25) +
    geom_hline(yintercept = 0, linetype = "dotted") +
    scale_y_continuous(name = "LOR") +
    theme_classic() +
    no.legend
p_discrim

### Group data

In [None]:
## metadata: one table of birds, joined to the model summaries on `bird`
birds <- data.table::fread(file.path("..", "inputs", "bird_metadata.csv"))

In [None]:
# Load one bird's saved SSM summary as a single long table.
#
# subject: bird id used in the file name "../build/<subject>_<stage>_ssm_summary.rds".
# stage:   which fit to load ("train" by default, so single-argument calls
#          such as purrr::map(birds$bird, load_discrim_summary) are unchanged;
#          "pretrain" loads the pretraining fit).
# Returns: all summary tables except `data` and `corr`, reduced to
#          time/variable/mean/lwr/upr, stacked, and inner-joined to the bird's
#          `data` table by `time`; or NULL when the file does not exist, which
#          purrr::list_rbind() drops.
load_discrim_summary <- function(subject, stage="train") {
    summary_file <- file.path("..", "build", paste0(subject, "_", stage, "_ssm_summary.rds"))
    if (!file.exists(summary_file)) {
        return(NULL)
    }
    summaries <- readRDS(summary_file)
    (
        summaries
        |> purrr::discard_at(c("data", "corr"))
        |> purrr::map(\(tbl) select(tbl, time, variable, mean, lwr, upr))
        |> purrr::list_rbind()
        |> inner_join(summaries$data, by="time")
    )
}

In [None]:
# One long table for all birds' training fits: birds without a summary file
# yield NULL and are dropped by list_rbind(); metadata joined by subject == bird.
summary_all <- birds$bird |>
    purrr::map(load_discrim_summary) |>
    purrr::list_rbind() |>
    inner_join(birds, by = c(subject = "bird"))

In [None]:
options(repr.plot.width=4, repr.plot.height=2.5, repr.plot.res = 300)
# Per-bird `level` trajectories (posterior mean) by group. The y axis is
# labelled as a probability, as in the pretraining version of this plot and the
# 0-1 "prob" axis of the example-bird plot (it was mislabelled "LOR").
p_noresp <- (
    summary_all
    |> filter(variable=="level")
    |> ggplot(aes(trial))
    + geom_line(mapping=aes(y=mean, group=subject, color=subject))
    + facet_grid(vars(group))
    #+ geom_ribbon(mapping=aes(ymin=lwr, ymax=upr, fill=subject), alpha=0.25)
    #+ geom_hline(yintercept=0, linetype="dotted")
    + scale_y_continuous("p(no_resp)")
    + theme_classic() + no.legend
)
p_noresp

In [None]:
options(repr.plot.width = 4, repr.plot.height = 2.5, repr.plot.res = 300)
# Per-bird discrimination trajectories (posterior mean) by group, dotted line
# at 0; the subject legend is kept here.
p_discrim <- summary_all |>
    filter(variable == "discrim") |>
    ggplot(mapping = aes(x = trial)) +
    geom_line(aes(y = mean, group = subject, color = subject)) +
    facet_grid(rows = vars(group)) +
    # geom_ribbon(aes(ymin = lwr, ymax = upr, fill = subject), alpha = 0.25) +
    geom_hline(yintercept = 0, linetype = "dotted") +
    scale_y_continuous(name = "discrim (LOR)") +
    theme_classic()
p_discrim

In [None]:
options(repr.plot.width=4, repr.plot.height=2.5, repr.plot.res = 300)
# Per-bird bias trajectories (posterior mean) by group, dotted line at 0.
# Stored as p_bias so it does not overwrite the discrimination plot p_discrim.
p_bias <- (
    summary_all
    |> filter(variable=="bias")
    |> ggplot(aes(trial))
    + geom_line(mapping=aes(y=mean, group=subject, color=subject))
    + facet_grid(vars(group))
    #+ geom_ribbon(mapping=aes(ymin=lwr, ymax=upr, fill=subject), alpha=0.25)
    + geom_hline(yintercept=0, linetype="dotted")
    + scale_y_continuous("bias (LOR)")
    + theme_classic() + no.legend
)
p_bias

In [None]:
# Inspection of a single bird (C235): discrimination summary with the two
# criterion indicator columns.
summary_discrim <- summary_all |> filter(variable == "discrim")
df <- summary_discrim |>
    filter(subject == "C235") |>
    mutate(io_chance = lwr < 0, io_1 = lwr < 1.0)

In [None]:
# Per-bird criterion times (training) from the lower bound (lwr) of discrim:
#   io_chance: last time at which lwr < 0
#   io_1:      last time at which lwr < 1.0 (i.e. criterion not yet met)
#   success:   whether lwr ever reached 1.0
# last_true() returns a position within the bird's rows, so rows are sorted by
# time and the position is translated to the actual `time` value instead of
# assuming time == row number. If the condition is never TRUE, the bird's
# first time is used (last_true() returns 1).
summary_discrim <- filter(summary_all, variable=="discrim")
summary_criterion <- (
    summary_discrim
    |> arrange(subject, time)
    |> mutate(io_chance=lwr < 0, io_1=lwr < 1.0)
    |> group_by(subject)
    |> summarize(
        io_chance=time[last_true(io_chance)],
        io_1=time[last_true(io_1)],
        success=any(lwr >= 1.0)
    )
    |> pivot_longer(starts_with("io_"), values_to="time")
    |> inner_join(summary_discrim, by=c("subject", "time"))
    |> select(subject, name, time, trial, tot_rewarded, tot_noresp, success, uuid, sex, group, age, sibling)
)

In [None]:
# Criterion rows for each bird, ordered by subject; uses the shared crit_name
# (rather than a hard-coded "io_1") to match the pretraining analysis.
summary_criterion |> filter(name==crit_name) |> arrange(subject)

In [None]:
options(repr.plot.width=1.5, repr.plot.height=2, repr.plot.res = 300)
# Trial at criterion by group (training); crosses = never met criterion,
# filled circles = met criterion. Shape values are named so an all-successful
# (or all-failed) set is still drawn with the right shape, and the shared
# crit_name is used to match the pretraining analysis.
(
    summary_criterion
    |> filter(name==crit_name)
    |> ggplot(aes(group, trial))
    + geom_point(aes(shape=success))
    + scale_shape_manual(values=c("FALSE"=4, "TRUE"=16))
    + theme_classic() + my.theme + no.legend
)

In [None]:
options(repr.plot.width = 2.5, repr.plot.height = 2, repr.plot.res = 300)
# Last trial with discrimination lower bound below chance (io_chance), by
# sibling and colored by group.
summary_criterion |>
    filter(name == "io_chance") |>
    ggplot(mapping = aes(x = sibling, y = trial)) +
    geom_point(aes(color = group)) +
    theme_classic() +
    my.theme