In [None]:
library(tidyverse)
library(fs)
library(extrafont)
# font_import()  # Only needs to be run once per machine. The CMU fonts may first need to be installed with `apt install fonts-cmu`.

## Transitions

In [None]:
# Gather every per-model transition-stability export into a single data frame.
# Matching files look like exports/transition-stability-<model>.csv; the
# <model> part of each filename is stored in a new `model` column on the rows
# read from that file.
transition_csv_paths = dir_ls('exports/', regexp = '^exports/transition-stability-.*\\.csv$')

# Name each path by its model tag so imap() can hand the tag to the reader.
names(transition_csv_paths) = str_extract(
    path_file(transition_csv_paths),
    '(?<=^transition-stability-).*(?=\\.csv$)'
)

transition_df = list_rbind(
    imap(transition_csv_paths, function(csv_path, model_name) {
        mutate(read_csv(csv_path), model = model_name)
    })
)

In [None]:
# Build one stability plot per model. Each plot shows the mean cosine
# similarity to the previous layer for the parallel (h_para) and orthogonal
# (h_orth) components, with a shaded band spanning mean +/- CI. Each plot is
# saved under exports/ as PDF and PNG, printed, and collected into `plots`
# (a list in group_split() order).
base_font_family = 'CMU Serif'
base_font_size = 11

# Shared by the colour and fill scales so a component's line, points and band
# always get the same label and colour. Identical scales also let ggplot merge
# the two legends into one.
component_labels = c(h_para = 'Parallel component', h_orth = 'Orthogonal component')
component_colors = c(h_para = '#E69F00', h_orth = '#56B4E9')

# Saved figure size in inches: 800 x 500 px at 100 px per inch.
plot_width_in = 800/100
plot_height_in = 500/100

plots = map(group_split(transition_df, model), function(model_transition_df) {
    # group_split() yields one data frame per model, so every row shares it.
    model_prefix = model_transition_df$model[[1]]

    # Long format: one row per (layer, component). Inside transmute(),
    # `upper` and `lower` refer to the `mean` column created just before them.
    plot_df = bind_rows(
        model_transition_df %>%
            transmute(model, layer_ix, component = 'h_para', mean = para_mean_across_layers, upper = mean + para_cis, lower = mean - para_cis),
        model_transition_df %>%
            transmute(model, layer_ix, component = 'h_orth', mean = orth_mean_across_layers, upper = mean + orth_cis, lower = mean - orth_cis)
    )

    plot =
        ggplot(plot_df, aes(x = layer_ix)) +
        # Draw the translucent CI band first so the lines and points are
        # layered on top of it rather than tinted by it.
        geom_ribbon(aes(ymin = lower, ymax = upper, fill = component), alpha = 0.25, color = NA) +
        geom_line(aes(y = mean, group = component, color = component), linewidth = 0.8) +
        geom_point(aes(y = mean, color = component), shape = 16, size = 2) +
        scale_color_manual(labels = component_labels, values = component_colors) +
        scale_fill_manual(labels = component_labels, values = component_colors) +
        labs(
            title = NULL,
            color = NULL, fill = NULL,
            x = 'Layer index',
            y = 'Cosine similarity to previous layer'
        ) +
        theme_bw(base_size = base_font_size, base_family = base_font_family) +
        theme(
            plot.title = element_blank(),
            plot.subtitle = element_blank(),

            axis.title = element_text(face = 'plain', size = rel(1.0)),
            axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
            axis.text.y = element_text(size = rel(0.9)),

            legend.position = "bottom",
            # Only takes effect if a legend title is set in labs() above.
            legend.title = element_text(face = "bold", size = rel(0.9)),
            legend.text = element_text(size = rel(0.85)),
            legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2), # Subtle legend box
            legend.key.size = unit(0.8, "lines"),

            panel.grid.major = element_line(colour = "grey85", linewidth = 0.3), # Lighter, thinner major grid
            panel.grid.minor = element_blank(),
            panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5), # Add back a border

            strip.background = element_blank(),
            strip.text = element_text(face = "bold", size = rel(1.0))
        )

    ggsave(
        str_glue('exports/transition-{model_prefix}-md.pdf'),
        plot = plot,
        width = plot_width_in, height = plot_height_in,
        units = "in", dpi = 300,
        device = cairo_pdf
    )
    ggsave(
        str_glue('exports/transition-{model_prefix}-md.png'),
        plot = plot,
        width = plot_width_in, height = plot_height_in,
        units = "in", dpi = 300
    )

    # Auto-printing does not happen inside a function, so print explicitly to
    # render each plot in the notebook output.
    print(plot)
    plot
})
