In [2]:
source('utils/plot.R')

In [3]:
# Read one learning-curve TSV and tag every row with the target it belongs to.
#
# The target is derived from the file's base name: the name is split on '_'
# and only the first of three expected pieces is kept (a file named
# '<target>_<a>_<b>' yields target '<target>'). With tidyr's default
# `extra`/`fill` settings, names that do not split into exactly three pieces
# make `separate()` warn (extra pieces dropped, missing pieces filled with NA).
#
# @param filename Path to a tab-separated file readable by `read_tsv()`.
# @return A tibble with the file's columns plus a character column `target`.
read_learning_curves <- function(filename) {
    read_tsv(filename, col_types=cols()) %>%
        # `.env$` pins this to the function argument; a bare `filename` would be
        # data-masked by a `filename` column if the TSV happened to contain one.
        mutate(filename=basename(.env$filename)) %>%
        separate('filename', sep='_', into=c('target', NA, NA))
}

In [4]:
# First epoch (inclusive) kept in the plot; earlier epochs are dropped.
# NOTE(review): presumably this keeps the large initial losses from dominating
# the free y-scales of the facets — confirm.
min_epoch <- 10

# Stack every learning-curve file listed as a Snakemake input, then reshape to
# long format with one row per (input row, loss type). This gives ggplot a
# single `loss` value column and a `loss_name` column to colour by.
data <- bind_rows(lapply(snakemake@input, read_learning_curves)) %>%
    pivot_longer(c('train_loss', 'valid_loss'), names_to='loss_name', values_to='loss') %>%
    mutate(
        target=target_factor(target),
        # Human-readable legend labels for the two loss columns.
        loss_name=str_replace_all(loss_name, c('valid_loss'='Validation loss', 'train_loss'='Training loss'))
    ) %>%
    # `.env$` so a data column named `min_epoch` could never shadow the constant.
    filter(epoch >= .env$min_epoch)

In [5]:
# Line plot of training and validation loss against epoch, one facet per
# target; each facet gets its own y-axis range.
fig <- ggplot(data) +
    geom_line(mapping=aes(x=epoch, y=loss, colour=loss_name)) +
    facet_wrap(~target, scales='free_y') +
    sweeps_colour +
    xlab('Epoch') +
    ylab('Loss') +
    sweeps_theme +
    # Added after `sweeps_theme` so these settings override it. The legend sits
    # above the panels, untitled and justified at (0, 0). Legend/panel spacing
    # is tightened, and the y-axis title is justified with hjust=1.
    theme(
        legend.position='top',
        legend.justification=c(0, 0),
        legend.title=element_blank(),
        legend.box.spacing=unit(0.1, "cm"),
        panel.spacing=unit(0.2, "cm"),
        axis.title.y=element_text(hjust=1)
    )

In [6]:
# Overlay bold panel tags A-D on the figure. Coordinates are fractions of the
# whole drawing canvas and are hand-tuned.
# NOTE(review): the positions look tailored to a 2x2 facet grid (four
# targets) — confirm they still line up if the number of targets changes.
panel_tags <- list(
    list(text='A', x=0.07, y=0.8),
    list(text='B', x=0.55, y=0.8),
    list(text='C', x=0.07, y=0.43),
    list(text='D', x=0.55, y=0.43)
)
labeled_fig <- Reduce(
    function(canvas, tag) {
        canvas + draw_label(tag$text, x=tag$x, y=tag$y, hjust=0, vjust=0, fontface='bold', size=11)
    },
    panel_tags,
    ggdraw(fig)
)

In [7]:
# Write the tagged figure to the first output path declared by the Snakemake rule.
figure_path <- snakemake@output[[1]]
sweeps_save(figure_path, labeled_fig)