In [2]:
source('utils/plot.R')

In [3]:
# Load every per-replicate training history given as a Snakemake input, average
# train/validation loss over replicates, and reshape to long format for plotting.
#
# Result: one row per (target, epoch, loss_name) holding the mean `loss`,
# with `target` passed through target_factor() and `loss_name` relabelled to
# human-readable legend text. Epochs before 3 are dropped.
#
# NOTE(review): the separate() calls assume `replicate` looks like
# `<target>_<training>_<something>-<id>`; confirm against the rule producing it.
tables <- lapply(snakemake@input, read_tsv, col_types=cols())
data <- bind_rows(tables) %>%
    separate('replicate', sep='_', into=c('target', 'training', 'replicate')) %>%
    # Keep only the part after '-' (type-converted). It is not used by the
    # grouping below, so replicates are averaged together.
    separate('replicate', sep='-', into=c(NA, 'replicate'), convert=TRUE) %>%
    pivot_longer(c('train_loss', 'valid_loss'), names_to='loss_name', values_to='loss') %>%
    # NOTE(review): `training` is not a grouping key, so if several training
    # variants share a target their losses are averaged together — confirm.
    group_by(target, epoch, loss_name) %>%
    summarize(loss=mean(loss)) %>%
    # summarize() leaves the result grouped by (target, epoch); drop that so the
    # mutate() below works on whole columns (one consistent factor for `target`)
    # and no stray grouping is carried into the plot.
    ungroup() %>%
    mutate(
        target=target_factor(target),
        loss_name=str_replace_all(loss_name, c('valid_loss'='Validation loss', 'train_loss'='Training loss'))
    ) %>%
    filter(epoch >= 3) # helps visualize better

In [4]:
# Loss curves per epoch, one facet per target, coloured by loss type
# (training vs validation). A dashed vertical line is drawn at the
# `epochs_for_model_training` value from the Snakemake config.
# Facets use independent y scales ("free_y") and the y-axis text/title are
# blanked, so the panels are presumably meant to be read for curve shape only.
# `sweeps_colour` / `sweeps_theme` are presumably provided by utils/plot.R.
fig <- ggplot(data) +
    geom_vline(
        xintercept=snakemake@config$epochs_for_model_training,
        linetype="dashed",
        size=0.3
    ) +
    geom_line(aes(x=epoch, y=loss, colour=loss_name)) +
    facet_wrap(vars(target), scales="free_y") +
    sweeps_colour +
    labs(x="Epoch", y="Loss value") +
    sweeps_theme +
    # Applied after sweeps_theme so these settings override it.
    theme(
        axis.text.y=element_blank(),
        axis.title.y=element_blank(),
        legend.position="top",
        legend.title=element_blank(),
        legend.justification=c(0, 0),
        legend.box.spacing=unit(0.1, "cm")
    )

In [5]:
# Write the figure to the first output path declared by the Snakemake rule
# (sweeps_save is presumably the shared save helper from utils/plot.R).
sweeps_save(snakemake@output[[1L]], fig)