In [2]:
source("utils/plot.R")

In [3]:
# Notebook display settings for inline plots (width/height presumably in inches).
options(
    jupyter.plot_scale = 1,
    repr.plot.height = 3,
    repr.plot.width = 4
)

In [4]:
# Column types for the simulation-parameters table read below.
# Columns not listed here (e.g. uuid, sweep_mode) fall back to readr's
# default type guessing. Each column is declared exactly once.
col_spec <- cols(
    swept_mutations=col_character(),
    adaptive_mutation_rate=col_number(),
    selection_region_size=col_number(),
    swept_frequencies=col_character(),
    actual_frequency_at_selection=col_number(),
    num_starting_lineages=col_number(),
    num_surviving_lineages=col_number(),
    frequency_at_selection=col_number(),
    num_restarts=col_number(),
    actual_frequency_at_sampling=col_number(),
    dominance_coefficient=col_number(),
    frequency_at_sampling=col_number(),
    selection_coefficient=col_number(),
    selection_coordinate=col_number(),
    selection_generation=col_number(),
    log_selection_coefficient=col_number()
)

In [5]:
# Simulation parameters, restricted to the 'hard', 'rnm (true)' and
# 'sgv (true)' sweep modes and slimmed down to the join key plus the
# sampling frequency that later cells colour and bin by.
parameters <- snakemake@input[[1]] %>%
    read_tsv(col_types = col_spec) %>%
    filter(sweep_mode %in% c('hard', 'rnm (true)', 'sgv (true)')) %>%
    select(uuid, actual_frequency_at_sampling)

### Selection strength regression

In [6]:
# Selection-strength regression output, joined to the selected simulations
# so each row carries actual_frequency_at_sampling.
# NOTE(review): right_join keeps every row of `parameters`, so a simulation
# without a prediction appears with NA true/predicted values; confirm that
# the prediction file covers all selected simulations (or whether an
# inner_join was intended).
selstrength <- snakemake@input$selstrength %>%
    read_tsv(col_types = cols()) %>%
    right_join(parameters, by = 'uuid')

In [7]:
# Predicted vs. true selection coefficient. Both axes hold log10(s) and the
# tick labels are mapped back to s. Points are coloured by
# actual_frequency_at_sampling; the dashed line is y = x (perfect prediction).
selstrength_fig <- ggplot(selstrength) +
    geom_point(
        mapping = aes(
            x = true_log_selection_coefficient,
            y = predicted_log_selection_coefficient,
            colour = actual_frequency_at_sampling
        ),
        size = 0.5
    ) +
    geom_abline(linetype = 'dashed') +
    scale_colour_gradient(name = 'Freq.', low = 'darkred', high = 'grey') +
    scale_x_continuous(labels = function(x) 10^x) +
    scale_y_continuous(labels = function(x) 10^x) +
    labs(x = "True s", y = "Predicted s") +
    sweeps_theme +
    theme(legend.position = 'top')

In [8]:
# Render the selection-strength scatter plot as this cell's output.
selstrength_fig

### Sweep mode confusion matrix

In [9]:
# Sweep-mode classifier output (true vs. predicted label per simulation),
# joined to the selected simulations and reduced to the sampling frequency
# and the two labels.
# NOTE(review): as with `selstrength`, right_join keeps simulations that have
# no prediction (their labels become NA); confirm this is intended.
sweepmode_raw <- snakemake@input$sweepmode %>%
    read_tsv(col_types = cols()) %>%
    select(uuid, true_labels, predicted_labels) %>%
    right_join(parameters, by = 'uuid') %>%
    select(actual_frequency_at_sampling, true_labels, predicted_labels)

In [10]:
# Row-normalised confusion matrix: for each true sweep mode, the fraction of
# simulations assigned to each predicted mode, plus a rounded percent label.
sweepmode_confmat <- sweepmode_raw %>%
    select(true_labels, predicted_labels) %>%
    # table() cross-tabulates every combination of observed labels, so
    # zero-count cells are kept as rows with n = 0.
    table() %>%
    as_tibble() %>%
    mutate(
        true_labels=sweepmode_factor_short(true_labels),
        predicted_labels=sweepmode_factor_short(predicted_labels)
    ) %>%
    group_by(true_labels) %>%
    mutate(
        percent=n/sum(n),
        percent_label=paste0(round(percent*100, 1), '%')
    ) %>%
    # Drop the grouping so later uses of this table are not silently grouped.
    ungroup()

In [11]:
# Heatmap of the row-normalised confusion matrix, with each tile annotated by
# its percentage. Text is white where percent >= 0.5 and black otherwise
# (presumably for contrast against the fill).
confmat_fig <- ggplot(sweepmode_confmat) +
    geom_tile(aes(x=true_labels, y=predicted_labels, fill=percent)) +
    geom_text(aes(x=true_labels, y=predicted_labels, label=percent_label, colour=percent<0.5)) +
    # Named values map TRUE/FALSE explicitly. Unnamed values are matched in
    # order to the levels actually present, so if only one of TRUE/FALSE
    # occurred it would always receive 'white'.
    scale_colour_manual(values=c('TRUE'='black', 'FALSE'='white')) +
    # Reverse the y order so the first predicted label is at the top.
    scale_y_discrete(limits=rev) +
    scale_fill_distiller(palette=3, direction=1) +
    sweeps_theme +
    labs(x='True', y='Predicted') +
    theme(
        legend.position='none',
        panel.grid=element_blank(),
        panel.spacing=unit(0.3, "in")
    )

In [12]:
# Render the confusion-matrix heatmap as this cell's output.
confmat_fig

### Selection strength by frequency bracket

In [13]:
# Break points for five equal-width sweep-frequency brackets on [0, 1]:
# 0, 0.2, 0.4, 0.6, 0.8, 1 (k / 5 is correctly rounded, so these equal the
# literal decimals exactly).
freq_breaks <- (0:5) / 5

In [14]:
# Error of the selection-strength regressor within each sweep-frequency
# bracket, followed by an 'All' row over every simulation. Bracket labels are
# rewritten for display (e.g. "(0,0.2]" -> "0-0.2") and ordered with 'All' last.
selstrength_freq <- selstrength %>%
    mutate(
        freq_bracket = cut(
            actual_frequency_at_sampling,
            breaks = freq_breaks,
            dig.lab = 3
        )
    ) %>%
    group_by(freq_bracket) %>%
    summarize(
        rmse = rmse(
            true_log_selection_coefficient,
            predicted_log_selection_coefficient
        ),
        mean_relative_error = mean_relative_error(
            true_log_selection_coefficient,
            predicted_log_selection_coefficient
        )
    )
freq_bracket_levels <- levels(selstrength_freq$freq_bracket)

# The same metrics computed over the whole data set.
baseline <- with(selstrength, tibble(
    rmse = rmse(true_log_selection_coefficient, predicted_log_selection_coefficient),
    mean_relative_error = mean_relative_error(true_log_selection_coefficient, predicted_log_selection_coefficient),
    freq_bracket = 'All'
))

selstrength_freq <- bind_rows(selstrength_freq, baseline) %>%
    mutate(
        freq_bracket = factor(freq_bracket, levels = c(freq_bracket_levels, 'All')),
        freq_bracket_label = str_replace_all(
            freq_bracket,
            c("\\(" = "", "\\]" = "", "," = "-")
        ),
        freq_bracket_label = fct_reorder(freq_bracket_label, as.integer(freq_bracket))
    )

In [15]:
# Show the per-bracket selection-strength metrics table.
selstrength_freq

In [16]:
# Bar chart of the regressor's mean relative error per sweep-frequency
# bracket (bars ordered by freq_bracket_label's factor levels).
selstren_freq_fig <- ggplot(
    selstrength_freq,
    aes(x = freq_bracket_label, y = mean_relative_error)
) +
    geom_col() +
    labs(x = 'Sweep frequency', y = 'Mean relative error') +
    sweeps_theme +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

In [17]:
# Render the per-bracket mean-relative-error bar chart.
selstren_freq_fig

### Sweep mode by frequency bracket

In [18]:
# Sweep-mode classification accuracy within each sweep-frequency bracket,
# followed by an 'All' row over every simulation. Bracket labels are rewritten
# for display and ordered with 'All' last.
sweepmode_freq <- sweepmode_raw %>%
    mutate(freq_bracket=cut(actual_frequency_at_sampling, breaks=freq_breaks, dig.lab=3)) %>%
    group_by(freq_bracket) %>%
    summarize(
        accuracy=accuracy(true_labels, predicted_labels)
    )
# Take the bracket levels from this table's own cut() instead of relying on
# `freq_bracket_levels` from the selection-strength cell. Both come from the
# same freq_breaks/dig.lab, and cut() keeps all interval levels even when
# empty, so the levels are identical.
sweepmode_bracket_levels <- levels(sweepmode_freq$freq_bracket)
baseline <- tibble(
    accuracy=accuracy(sweepmode_raw$true_labels, sweepmode_raw$predicted_labels),
    freq_bracket='All'
)
sweepmode_freq <- bind_rows(sweepmode_freq, baseline) %>%
    mutate(
        freq_bracket=factor(freq_bracket, levels=c(sweepmode_bracket_levels, 'All')),
        # "(0,0.2]" -> "0-0.2"; 'All' passes through unchanged.
        freq_bracket_label=str_replace_all(freq_bracket, c(
            "\\("="",
            "\\]"="",
            ","="-"
        )),
        # Order the labels like the brackets, so 'All' comes last.
        freq_bracket_label=fct_reorder(freq_bracket_label, as.integer(freq_bracket))
    )

In [19]:
# Show the per-bracket sweep-mode accuracy table.
sweepmode_freq

In [20]:
# Bar chart of sweep-mode classification accuracy per sweep-frequency
# bracket, on a fixed 0-1 accuracy axis.
sweepmode_freq_fig <- ggplot(
    sweepmode_freq,
    aes(x = freq_bracket_label, y = accuracy)
) +
    geom_col() +
    scale_y_continuous(limits = c(0, 1)) +
    labs(x = 'Sweep frequency', y = 'Accuracy') +
    sweeps_theme +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

In [21]:
# Render the per-bracket accuracy bar chart.
sweepmode_freq_fig

### Plot it all together

In [22]:
# Assemble the final figure: A (selection-strength scatter) and B (confusion
# matrix) on the top row; C and D (per-bracket metrics) on the bottom row.
top_part <- plot_grid(selstrength_fig, confmat_fig, nrow=1, labels=c('A', 'B'))
# cowplot's `align` takes only "none", "h", "v" or "hv"; the margins to match
# go in `axis`. Side-by-side panels align horizontally on their top/bottom.
freq_brackets <- plot_grid(selstren_freq_fig, sweepmode_freq_fig, axis='tb', align='h', nrow=1, labels=c('C', 'D'), label_y=1.1)
# Stacked rows align vertically on their left/right margins.
all_fig <- plot_grid(
    top_part,
    freq_brackets,
    nrow=2,
    axis='lr', align='v')

In [23]:
# Write the assembled figure to the Snakemake output path
# (6 wide with a 4:3 aspect, as interpreted by the project helper sweeps_save).
sweeps_save(
    snakemake@output$figure,
    all_fig,
    width = 6,
    asp = 4 / 3
)

## Get metrics

In [24]:
# One row per frequency bracket (including 'All') combining the regressor's
# errors with the classifier's accuracy, with model-prefixed column names.
metrics <- selstrength_freq %>%
    inner_join(
        sweepmode_freq %>% select(freq_bracket, accuracy),
        by = 'freq_bracket'
    ) %>%
    rename(
        selstrength_rmse = rmse,
        selstrength_mre = mean_relative_error,
        sweepmode_accuracy = accuracy
    )

In [25]:
# Show the combined metrics table before writing it out.
metrics

In [26]:
# Persist the combined metrics table as TSV at the Snakemake output path.
metrics_path <- snakemake@output$metrics
write_tsv(metrics, metrics_path)