In [None]:
suppressMessages(library(cowplot))
library(jsonlite)
library(tools)
suppressMessages(library(pROC))
library(devtools)
suppressMessages(devtools::load_all('../R/sumrep'))
source('../R/plot.R')

# Make the minimal ggplot2 theme the default for every plot produced below.
theme_set(theme_minimal())

In [None]:
# Root directory of the plotting inputs. It must end with '/' because
# data_path() joins by plain string concatenation.
data_dir <- '../_ignore/plotting/2019-01-01-deneuter/'

# Resolve `path` (vectorized) relative to data_dir.
data_path <- function(path) {
    paste(data_dir, path, sep = '')
}

The following command should list the summary file if the data is in the right place; if it reports an error, check `data_dir`.

In [None]:
# Aggregated per-model summary table used by the cells below.
summarized <- data_path('summarized.agg.csv')
# Sanity check: `ls` reports an error instead of the path if the file is missing.
system(command = paste('ls', summarized), intern = TRUE)

We can't compare likelihoods between OLGA and the VAE, because OLGA assigns zero probability to quite a few of the observed sequences in the test set. For example, the command below shows that it assigns zero probability to 13 of 500 sequences from the first test set.

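Taken literally, this means that OLGA's out-of-sample likelihood for this test set is zero. The likelihood is a product over sequences, so a single zero-probability sequence makes the whole product zero (a log likelihood of negative infinity).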

In [None]:
# Count test sequences whose OLGA generation probability is exactly 0.0.
# The R escape \t puts a literal tab inside the single-quoted grep pattern.
# The dot is escaped (an unescaped `.` matches any character), and \b requires
# a word boundary after "0.0" so that values such as 0.05 or 0.0123 are not
# miscounted as zeros. shQuote protects paths containing spaces.
system(paste("grep '\t0\\.0\\b'", shQuote(data_path('H10_B0/test-head.pgen.tsv')), '| wc -l'), intern=TRUE)

First we need to pick beta, which we do using the validation-set log likelihoods.

In [None]:
# Load the per-model summary table. Beta becomes a factor so each value gets
# its own discrete x position. OLGA is dropped because its likelihoods cannot
# be compared with the VAE's.
df <- read.csv(summarized)
df$beta <- as.factor(df$beta)
df <- df[df$model != 'olga', ]

# Mean and median validation log likelihood against beta: one facet per
# statistic, one color per model.
id_vars <- c('model', 'beta')
measure_vars <- c('validation_mean_log_p', 'validation_median_log_p')
ggplot(
    melt(df, id.vars = id_vars, measure.vars = measure_vars),
    aes(x = beta, y = value, color = model)
) +
    geom_point() +
    ylab('log likelihood') +
    facet_wrap(vars(variable))

Based on these plots, beta = 0.875 looks like a good choice, so we'll use it for the rest of this analysis.

In [None]:
# Beta chosen from the validation plot above, and the directory holding the
# models fit with it. The directory is derived from our_beta so the two can no
# longer drift apart. paste0() turns 0.875 into "0.875", which matches the
# previously hard-coded directory name. Assumes every fit directory is named
# exactly as.character(beta); confirm this before picking a different beta.
our_beta <- 0.875
fit_dir <- data_path(paste0('deneuter-2018-12-31.train/', our_beta, '/'))
# Resolve `path` (vectorized) relative to fit_dir.
fit_path <- function(path) paste0(fit_dir, path)

Let's compare the CDR3 length distributions of the test data and of the sequences generated by each model (OLGA and the two VAE variants, basic and count_match).

In [None]:
# Read a sequence CSV into a data.table for the sumrep calls below. Any
# `amino_acid` column is renamed to `junction_aa`.
prep_sumrep <- function(path) {
    seqs <- read.csv(path, stringsAsFactors = FALSE)
    is_aa_col <- names(seqs) == 'amino_acid'
    names(seqs)[is_aa_col] <- 'junction_aa'
    data.table(seqs)
}
# Apply `summary_fun` to `path` and return the result as a labelled data frame.
#
# summary_fun:  function of a file path returning a vector of values.
# summary_name: name given to the value column.
# data_source, data_group: labels stored in the `source` and `group` columns
#   (recycled to every row).
# Returns a data frame with columns summary_name, source and group.
# rep(..., nrow(df)) replaces scalar assignment so that a zero-length summary
# yields an empty data frame. Scalar assignment errored in that case with
# "replacement has 1 row, data has 0".
named_summary <- function(summary_fun, summary_name, path, data_source, data_group) {
    df <- data.frame(summary_fun(path))
    colnames(df) <- summary_name
    df$source <- rep(data_source, nrow(df))
    df$group <- rep(data_group, nrow(df))
    df
}
# Gather one summary statistic from the observed-data files and from each
# model's generated sequences into a single long data frame for plotting.
#
# data_glob_str:   glob, relative to data_dir, of files labelled source 'data'.
#   Each matched file becomes its own `group`.
# olga_str, basic_str, count_match_str: paths, relative to fit_dir, of the
#   OLGA, basic-VAE and count-match-VAE files.
# summary_fun, summary_name: function applied to each file path, and the name
#   of the resulting value column (see named_summary()).
# Returns a data frame with columns summary_name, source, group and size
# (0 for source 'data', 1 for model output).
prep_summaries_general <- function(data_glob_str, olga_str, basic_str, count_match_str, summary_fun, summary_name) {
    aux <- function(path, data_source, data_group) {
        named_summary(summary_fun, summary_name, path, data_source, data_group)
    }
    data_paths <- Sys.glob(data_path(data_glob_str))
    # Without this warning, an empty glob (e.g. a wrong data_dir) silently
    # drops every data curve from the resulting plot.
    if (length(data_paths) == 0) {
        warning("No data files match '", data_path(data_glob_str), "'", call. = FALSE)
    }
    data_df <- do.call(rbind, lapply(data_paths, function(path) aux(path, 'data', path)))
    df <- rbind(
        aux(fit_path(olga_str), 'olga', 'olga'),
        aux(fit_path(basic_str), 'basic', 'basic'),
        aux(fit_path(count_match_str), 'count_match', 'count_match')
    )
    # rbind(df, NULL) is a no-op, so an empty data_df is harmless here.
    df <- rbind(df, data_df)
    df$size <- 1 - as.numeric(df$source == 'data')
    df
}
# prep_summaries_general() with the standard layout: '*for-test.csv' files
# under data_dir as the data, and each model's generated CSV under fit_dir.
prep_summaries <- function(summary_fun, summary_name) {
    prep_summaries_general(
        data_glob_str = '*/*for-test.csv',
        olga_str = 'olga-generated.csv',
        basic_str = 'basic/vae-generated.csv',
        count_match_str = 'count_match/vae-generated.csv',
        summary_fun = summary_fun,
        summary_name = summary_name
    )
}
# Overlay kernel density curves of column `summary_name`: one curve per
# `group`, colored by `source`. Line width follows the `size` column,
# rescaled to the range 0.4-1.
#
# df:           long data frame with columns summary_name, source, group, size.
# summary_name: name of the column to estimate densities for.
# bw:           smoothing bandwidth passed to geom_density(). It was previously
#               hard-coded to 1, which stays the default.
plot_summaries <- function(df, summary_name, bw = 1) {
    ggplot(df, aes_string(summary_name, color = 'source', group = 'group', size = 'size')) +
        geom_density(bw = bw) +
        scale_size(range = c(0.4, 1))
}

# CDR3 length distribution of one sequence file, computed by sumrep with
# by_amino_acid = TRUE.
get_cdr3_lengths <- function(path) {
    seqs <- prep_sumrep(path)
    getCDR3LengthDistribution(seqs, by_amino_acid = TRUE)
}
plot_summaries(
    prep_summaries(get_cdr3_lengths, 'CDR3_length'),
    'CDR3_length'
)

In [None]:
# Pairwise distance distribution of one sequence file, computed by sumrep on
# the junction_aa column.
get_pwdd <- function(path) {
    seqs <- prep_sumrep(path)
    getPairwiseDistanceDistribution(seqs, column = 'junction_aa')
}
plot_summaries(
    prep_summaries(get_pwdd, 'pairwise_distance'),
    'pairwise_distance'
)

In [None]:
# Re-read the summary table. The earlier cell dropped the OLGA rows and
# turned beta into a factor.
df <- read.csv(file = summarized)

# Facet labeller: remove the first "sumdiv_" from each label (vectorized).
trim_sumdiv <- function(s) {
    sub("sumdiv_", "", s, fixed = TRUE)
}

# Box plots of every `sumdiv_*` column for the rows fit with `beta`: one
# facet per divergence and one box per model, spread over test sets.
#
# df:   summary table with columns test_set, model, beta and sumdiv_*.
# beta: value of df$beta to keep.
# Returns a ggplot object.
compare_model_divergences <- function(df, beta) {
    # which() drops rows whose beta is NA (e.g. models without a beta, if any).
    # Plain logical indexing turned them into all-NA rows.
    df <- df[which(df$beta == beta), ]
    id_vars <- c('test_set', 'model')
    measure_vars <- grep('sumdiv_', colnames(df), value = TRUE)
    df <- df[c(id_vars, measure_vars)]
    ggplot(
        melt(df, id_vars, measure_vars, variable.name = 'divergence_name', value.name = 'divergence'),
        aes_string('model', 'divergence', color = 'model')
    ) +
        geom_boxplot() +
        facet_wrap(vars(divergence_name), scales = 'free', labeller = as_labeller(trim_sumdiv)) +
        scale_y_log10() +
        theme(axis.text.x = element_blank()) +
        ggtitle(paste('beta =', beta))
}

# Compare the sumdiv_* divergences of each model at the chosen beta.
compare_model_divergences(df = df, beta = our_beta)

In [None]:
# Read the `Ppost` column from a CSV file.
# [[ ]] is used instead of $ because $ on a data frame partially matches
# names. The function stops with a clear message when the column is absent,
# where $ would have returned NULL.
get_ppost <- function(path) {
    ppost <- read.csv(path)[['Ppost']]
    if (is.null(ppost)) {
        stop("Column 'Ppost' not found in ", path, call. = FALSE)
    }
    ppost
}

# Ppost of the files matched by the data glob and of each model's generated
# sequences, plotted as densities of log(Ppost).
summaries <- prep_summaries_general(
    data_glob_str = 'deneuter-2018-12-31.train/*/ppost.csv',
    olga_str = 'olga-generated.ppost.csv',
    basic_str = 'basic/vae-generated.ppost.csv',
    count_match_str = 'count_match/vae-generated.ppost.csv',
    summary_fun = get_ppost,
    summary_name = 'Ppost'
)
summaries$log_Ppost <- log(summaries$Ppost)

plot_summaries(summaries, 'log_Ppost')

In [None]:
# Read the `log_p_x` column from a CSV file.
# [[ ]] is used instead of $ because $ on a data frame partially matches
# names. The function stops with a clear message when the column is absent,
# where $ would have returned NULL.
get_pvae <- function(path) {
    log_p_x <- read.csv(path)[['log_p_x']]
    if (is.null(log_p_x)) {
        stop("Column 'log_p_x' not found in ", path, call. = FALSE)
    }
    log_p_x
}

# log_p_x of the test files and of each model's generated sequences. The
# OLGA entry reads basic/olga-generated.pvae.csv, i.e. the file stored under
# the basic model's directory.
summaries <- prep_summaries_general(
    data_glob_str = 'deneuter-2018-12-31.train/0.875/basic/*-test.head/test.pvae.csv',
    olga_str = 'basic/olga-generated.pvae.csv',
    basic_str = 'basic/vae-generated.pvae.csv',
    count_match_str = 'count_match/vae-generated.pvae.csv',
    summary_fun = get_pvae,
    summary_name = 'log_Pvae'
)

plot_summaries(summaries, 'log_Pvae')