In [None]:
suppressMessages(library(cowplot))
library(jsonlite)
library(tools)
suppressMessages(library(pROC))
library(devtools)
suppressMessages(devtools::load_all('../R/sumrep'))
source('../R/plot.R')

# Default ggplot theme for this notebook, plus a fixed colour per sequence
# source that the plots below pass to scale_color_manual().
theme_set(theme_minimal())
source_colors = c(
    basic       = "#fc8d62",
    count_match = "#66c2a5",
    olga        = "#8da0cb",
    data        = "#D3D3D3"
)

In [None]:
# Where figures are written; output_path() prefixes a file name with it.
output_dir = '_output_deneuter/'
# Create the directory from R rather than shelling out to `mkdir -p`: no shell
# quoting issues, and no complaint if it already exists.
dir.create(output_dir, showWarnings=FALSE, recursive=TRUE)
output_path = function(path) paste0(output_dir, path)

# Root of the input data; data_path() prefixes a relative path with it.
data_dir = paste0(normalizePath('../_ignore/plotting/2019-01-01-deneuter'), '/')
data_path = function(path) paste0(data_dir, path)

The following command should run if the data is in the right place.

In [None]:
# Aggregated per-model summary table; listing it confirms the data is in place.
summarized = data_path('summarized.agg.csv')
# Quote the path so a data directory containing spaces does not break `ls`.
system(paste('ls', shQuote(summarized)), intern=TRUE)

We can't compare likelihoods between OLGA and the VAE, because OLGA assigns zero probability to quite a few of the observed sequences in the test set. For example, it assigns zero probability to 13 of the 500 sequences in the first test set.

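Taken literally, this means that OLGA gives this whole test set an out-of-sample likelihood of zero, because the likelihood of a set of sequences is the product of the individual sequence probabilities.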

In [None]:
# Count sequences whose OLGA pgen is exactly 0.0. This is done in R rather than
# with a shell grep: the dot is escaped (an unescaped '.' matches any character)
# and the value must end at a field boundary, so entries such as 0.05 are not
# counted as zeros.
sum(grepl('\t0\\.0(\t|$)',
          readLines(data_path('H10_B0/test-head.pgen.tsv'))))

First we need to pick beta, which we can do on the validation set.

In [None]:
# Choose beta by comparing validation-set log likelihoods across the VAE models.
df = read.csv(summarized)

# Treat beta as categorical so each value gets its own x position.
df$beta = as.factor(df$beta)
# Only the VAE variants are compared here.
df = df[df$model != 'olga', ]

id_vars = c('model', 'beta')
measure_vars = c('validation_mean_log_p', 'validation_median_log_p')
ggplot(
    data = melt(df, id_vars, measure_vars),
    mapping = aes(x = beta, y = value, color = model)
) +
    geom_point() +
    ylab('log likelihood') +
    facet_wrap(vars(variable)) +
    scale_color_manual(values = source_colors)

0.875 looks like a good choice, so we'll use that for the rest of this analysis.

In [None]:
# beta chosen from the validation plot above.
our_beta = 0.875
# Fitted-model directory for that beta, derived from our_beta so the two cannot
# drift apart. Assumes fit directories are named as.character(beta), e.g. "0.875".
fit_dir = data_path(paste0('deneuter-2018-12-31.train/', our_beta, '/'))
fit_path = function(path) paste0(fit_dir, path)

Let's compare the CDR3 length distribution between the various programs.

In [None]:
# Read a sequence CSV and return it as a data.table for the sumrep summaries.
# The amino_acid column is renamed to junction_aa, the column name passed to
# the sumrep functions in this notebook.
prep_sumrep = function(path) {
    repertoire = read.csv(path, stringsAsFactors=FALSE)
    is_aa_column = names(repertoire) == 'amino_acid'
    names(repertoire)[is_aa_column] = 'junction_aa'
    data.table(repertoire)
}
# Apply summary_fun to the file at `path` and return its values as a data.frame
# with one column named summary_name, plus constant `source` and `group` columns
# identifying where the values came from.
named_summary = function(summary_fun, summary_name, path, data_source, data_group) {
    values = summary_fun(path)
    out = data.frame(values)
    colnames(out) = summary_name
    out$source = data_source
    out$group = data_group
    out
}
# Build one summary data.frame covering the generated sequences (OLGA, basic VAE,
# count-match VAE; paths relative to fit_dir) and every observed data file
# matching data_glob_str (relative to data_dir).
#
# Args:
#   data_glob_str: glob, relative to data_dir, selecting observed data files.
#   olga_str, basic_str, count_match_str: paths, relative to fit_dir, of the
#     generated-sequence files for each source.
#   summary_fun: function(path) returning the summary values for one file.
#   summary_name: name of the column holding those values.
# Returns: data.frame with columns <summary_name>, source, group, size.
#   size is 0 for observed-data rows and 1 for generated rows.
prep_summaries_general = function(data_glob_str, olga_str, basic_str, count_match_str, summary_fun, summary_name) {
    aux = function(path, data_source, data_group) {
        named_summary(summary_fun, summary_name, path, data_source, data_group)
    }
    data_paths = Sys.glob(data_path(data_glob_str))
    if (length(data_paths) == 0) {
        # do.call(rbind, list()) is NULL, so without this warning the data rows
        # would silently disappear from the result.
        warning('No data files match ', data_path(data_glob_str), call.=FALSE)
    }
    # Each data file is its own group, so it gets its own curve.
    data_df = do.call(rbind, lapply(data_paths, function(path) aux(path, 'data', path)))
    df = rbind(
        aux(fit_path(olga_str), 'olga', 'olga'),
        aux(fit_path(basic_str), 'basic', 'basic'),
        aux(fit_path(count_match_str), 'count_match', 'count_match')
    )
    df = rbind(df, data_df)
    # Numeric flag used as the size aesthetic by plot_summaries.
    df$size = as.numeric(df$source != 'data')
    df
}
# prep_summaries_general() with the default file layout: the '*/*for-test.csv'
# data files and the plain generated-sequence CSVs of each model.
prep_summaries = function(summary_fun, summary_name) {
    prep_summaries_general(
        data_glob_str = '*/*for-test.csv',
        olga_str = 'olga-generated.csv',
        basic_str = 'basic/vae-generated.csv',
        count_match_str = 'count_match/vae-generated.csv',
        summary_fun = summary_fun,
        summary_name = summary_name
    )
}

# Density plot of the summary_name column of `df` (from prep_summaries*), one
# curve per group and coloured by source. Saves the plot to
# output_path('<summary_name>.png') and returns it.
plot_summaries = function(df, summary_name) {
    # Kept for backward compatibility: this sets the session-wide theme, which
    # later cells that do not set their own theme inherit.
    theme_set(theme_minimal(base_size=18))
    p = ggplot(df,
        aes_string(summary_name, color='source', group='group', size='size')) +
        # The same complete theme on the plot itself, so p renders identically
        # whatever the global theme is when it is printed.
        theme_minimal(base_size=18) +
        geom_density(bw=1) + scale_size(range=c(0.4, 1.2), guide='none') +
        theme(legend.justification=c(0,1), legend.position=c(0,1)) +
        scale_color_manual(values=source_colors)

    # Save p explicitly rather than relying on ggsave's last_plot() default.
    ggsave(output_path(paste0(summary_name, '.png')), plot=p, width=8, height=4.5)
    p
}

# CDR3 amino-acid length distribution per source.
plot_summaries(
    df = prep_summaries(
        summary_fun = function(path) {
            getCDR3LengthDistribution(prep_sumrep(path), by_amino_acid = TRUE)
        },
        summary_name = 'CDR3_length'
    ),
    summary_name = 'CDR3_length'
)

Now let's explore the distribution of pairwise distances between the CDR3 sequences.

In [None]:
# Pairwise distances between junction_aa sequences, per source.
plot_summaries(
    df = prep_summaries(
        summary_fun = function(path) {
            getPairwiseDistanceDistribution(prep_sumrep(path), column = 'junction_aa')
        },
        summary_name = 'pairwise_distance'
    ),
    summary_name = 'pairwise_distance'
)

Let's look at divergences from the test sets for TCR sequences generated by the various programs.

In [None]:
# Fresh copy of the summary table: an earlier cell factored beta and dropped OLGA rows.
df = read.csv(file = summarized)

# Facet labeller: strip the leading "sumdiv_" prefix from divergence column names.
# Anchored so an occurrence elsewhere in the name is left intact.
trim_sumdiv = function(s) sub("^sumdiv_", "", s)

# Box plots of every sumdiv_* column of `df`, one facet per divergence (log10
# y axis), comparing models on the rows fitted at `beta`.
# Side effect: resets the session theme to theme_minimal().
compare_model_divergences = function(df, beta) {
    # %in% rather than == so that rows whose beta is NA are dropped instead of
    # being turned into all-NA rows.
    df = df[df$beta %in% beta, ]
    id_vars = c('test_set', 'model')
    measure_vars = grep('sumdiv_', colnames(df), value=TRUE)
    df = df[c(id_vars, measure_vars)]
    theme_set(theme_minimal())
    ggplot(
        melt(df, id_vars, measure_vars, variable.name='divergence_name', value.name='divergence'),
        aes_string('model', 'divergence', color='model')
    ) + geom_boxplot() +
        facet_wrap(vars(divergence_name), scales='free', labeller=as_labeller(trim_sumdiv)) +
        scale_y_log10() +
        theme(axis.text.x=element_blank(), legend.justification=c(1,0), legend.position=c(1,0)) +
        scale_color_manual(values=source_colors) +
        ggtitle(paste('beta =', beta))
}

# Divergence box plots at the chosen beta.
compare_model_divergences(df = df, beta = our_beta)

Let's look at the statistics on which the VAEs perform worst, such as bulkiness.

In [None]:
# Bulkiness of junction_aa sequences, per source.
plot_summaries(
    df = prep_summaries(
        summary_fun = function(path) {
            getBulkinessDistribution(prep_sumrep(path), column = 'junction_aa')
        },
        summary_name = 'bulkiness'
    ),
    summary_name = 'bulkiness'
)

And polarity.

In [None]:
# Polarity of junction_aa sequences, per source.
plot_summaries(
    df = prep_summaries(
        summary_fun = function(path) {
            getPolarityDistribution(prep_sumrep(path), column = 'junction_aa')
        },
        summary_name = 'polarity'
    ),
    summary_name = 'polarity'
)

Now let's look at a more sophisticated way of evaluating sequences, namely Ppost. 
If a synthetically generated sequence doesn't look like a real VDJ recombination, then Ppost will be low.

In [None]:
# Return the Ppost column of the CSV at `path`.
get_ppost = function(path) {
    ppost_table = read.csv(path)
    ppost_table$Ppost
}

# Ppost for the test data and for each model's generated sequences.
summaries = prep_summaries_general(
    data_glob_str = 'deneuter-2018-12-31.train/*/ppost.csv',
    olga_str = 'olga-generated.ppost.csv',
    basic_str = 'basic/vae-generated.ppost.csv',
    count_match_str = 'count_match/vae-generated.ppost.csv',
    summary_fun = get_ppost,
    summary_name = 'Ppost'
)
# A Ppost of exactly 0 becomes -Inf here; ggplot is expected to drop such
# non-finite values from the density (with a warning).
summaries$log_Ppost = log(summaries$Ppost)

plot_summaries(summaries, 'log_Ppost')

In [None]:
# Return the log_p_x column (VAE log likelihood) of the CSV at `path`.
get_pvae = function(path) {
    pvae_table = read.csv(path)
    pvae_table$log_p_x
}

# VAE log likelihoods of the test data, OLGA-generated and VAE-generated
# sequences. Judging by the 'basic/' paths, the data and OLGA sequences are
# scored by the basic VAE (confirm). The data glob is built from our_beta
# rather than a hard-coded '0.875', so it follows the chosen beta.
summaries = prep_summaries_general(
    paste0('deneuter-2018-12-31.train/', our_beta, '/basic/*-test.head/test.pvae.csv'),
    'basic/olga-generated.pvae.csv', 'basic/vae-generated.pvae.csv',
    'count_match/vae-generated.pvae.csv', get_pvae, 'log_Pvae')

plot_summaries(summaries, 'log_Pvae')

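It's possible that the VAE is just memorizing its training sequences and reproducing them. 
We can rule this out by looking at out-of-sample likelihoods: if the VAE were memorizing, test sets derived from the files used for training would score much higher than test sets from held-out files.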

In [None]:
# Compare VAE test-set log likelihoods between test sets derived from the files
# used for training and test sets from held-out files, at beta = our_beta.
df = read.csv(data_path('../2019-01-01-deneuter-extras.csv'))
colnames(df)[colnames(df) == 'test_mean_log_p'] = 'mean log Pvae'
colnames(df)[colnames(df) == 'test_median_log_p'] = 'median log Pvae'

# Test-set names corresponding to the training files. basename() and
# file_path_sans_ext() are vectorized, so this is a plain character vector.
train_paths = fromJSON(data_path('../deneuter-2018-12-31.json'))$train_paths
train_set = paste0(tools::file_path_sans_ext(basename(train_paths)), '.for-test.head')

df$in_train = df$test_set %in% train_set
df$beta = as.factor(df$beta)

# %in% rather than == so that rows with a missing beta are dropped instead of
# becoming all-NA rows; our_beta instead of a hard-coded 0.875.
df = df[df$beta %in% our_beta, ]
id_vars = c('test_set', 'model', 'in_train')
measure_vars = c('mean log Pvae', 'median log Pvae')
df = df[df$model != 'olga', c(id_vars, measure_vars)]

# Label the groups explicitly. in_train is logical, so a discrete axis orders it
# FALSE, TRUE, and positional labels c("train", "test") would put "train" on
# the held-out sets.
df$in_train = factor(ifelse(df$in_train, 'train', 'test'), levels=c('train', 'test'))

ggplot(
    melt(df, id_vars, measure_vars),
    aes(in_train, value, color=model)) +
    geom_boxplot() +
    ylab('log likelihood') +
    scale_color_manual(values=source_colors) +
    facet_wrap(vars(variable)) +
    theme(axis.title.x=element_blank(), legend.justification=c(0.9,0.2), legend.position=c(0.9,0.2))