In [None]:
library(tidyverse)
library(fs)
library(extrafont)
library(ggtext)

# Default typography handed to theme_bw() by the plotting cells below.
# Later cells reassign these globals before drawing their own figures.
base_font_family <- "CMU Serif"
base_font_size <- 22
# One-time font registration, left disabled: font_import(prompt = FALSE)

## Transitions

In [None]:
# Read every exports/svd-transition-stability-<model>.csv file and stack them
# into one table. The <model> part of each file name is stored in a `model`
# column so rows can be traced back to their source file.
transition_df <- local({
  csv_paths <- dir_ls(
    "exports/",
    regexp = "^exports/svd-transition-stability-.*\\.csv$"
  )
  # Strip the fixed prefix and the .csv suffix to recover the model name.
  model_names <- str_extract(
    path_file(csv_paths),
    "(?<=^svd-transition-stability-).*(?=\\.csv$)"
  )

  per_model <- map2(csv_paths, model_names, \(csv_path, model_name) {
    mutate(read_csv(csv_path), model = model_name)
  })

  list_rbind(per_model)
})

# Quick look at the combined table.
head(transition_df)

In [None]:
# Draw and export one transition-similarity figure per model.
#
# For each model in `transition_df`, the `para_*` columns (labelled
# "router-visible") and the `orth_*` columns (labelled "router-blind") are
# stacked into long format and plotted against `layer_ix_1`: a band spanning
# mean -/+ `*_cis`, with a line and points for the mean on top. Each figure is
# written to exports/ as PDF and PNG, printed, and collected in `plots`.
plots <- map(group_split(transition_df, model), function(model_transition_df) {
  # Every row of a group carries the same model name; it goes into the file name.
  model_prefix <- model_transition_df$model[[1]]

  # Legend text (HTML rendered by ggtext) and colours, shared by the colour
  # and fill scales so both collapse into a single legend.
  channel_labels <- c(
    h_para = "Router-visible channel (<em>h<sup>vis</sup></em>)",
    h_orth = "Router-blind channel (<em>h<sup>blind</sup></em>)"
  )
  channel_colors <- c(h_para = "#E69F00", h_orth = "#56B4E9")

  # One row per (layer, component) with the mean and the band limits.
  channel_df <- bind_rows(
    transmute(
      model_transition_df,
      model, layer_ix_1,
      component = "h_para",
      mean = para_mean_across_layers,
      upper = mean + para_cis,
      lower = mean - para_cis
    ),
    transmute(
      model_transition_df,
      model, layer_ix_1,
      component = "h_orth",
      mean = orth_mean_across_layers,
      upper = mean + orth_cis,
      lower = mean - orth_cis
    )
  )

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      axis.title = element_text(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      # Markdown element so the <em>/<sup> tags in the labels are rendered.
      legend.text = ggtext::element_markdown(size = rel(0.9)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  fig <-
    ggplot(channel_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index",
      y = "Cosine similarity to previous layer"
    ) +
    # The band is added first: ggplot2 draws layers in the order they are
    # added, so adding it last would paint it over the lines and points.
    geom_ribbon(
      aes(x = layer_ix_1, ymin = lower, ymax = upper, fill = component),
      alpha = 0.25, linetype = "blank", color = NA
    ) +
    geom_line(
      aes(x = layer_ix_1, y = mean, group = component, color = component),
      linewidth = 0.8
    ) +
    geom_point(
      aes(x = layer_ix_1, y = mean, group = component, color = component),
      shape = 16, size = 2
    ) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    figure_theme

  # Vector (PDF via cairo) and raster (PNG) copies at 8 x 5 inches.
  ggsave(
    str_glue("exports/svd-transition-stability-{model_prefix}-md.pdf"),
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
  ggsave(
    str_glue("exports/svd-transition-stability-{model_prefix}-md.png"),
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300
  )
  print(fig)

  fig
})


In [None]:
# Combined transition-similarity figure: one facet per model.
#
# Stacks the `para_*` and `orth_*` columns of `transition_df` into long format,
# orders and relabels the models for display, and draws for each facet a band
# spanning mean -/+ `*_cis` with the mean as a line and points on top. The
# result is kept in the global `plot`, saved as PDF and PNG, and displayed.
plot <- local({
  # Legend text (HTML rendered by ggtext) and colours, shared by the colour
  # and fill scales so both collapse into a single legend.
  channel_labels <- c(
    h_para = "Router-visible channel (<em>h<sup>vis</sup></em>)",
    h_orth = "Router-blind channel (<em>h<sup>blind</sup></em>)"
  )
  channel_colors <- c(h_para = "#E69F00", h_orth = "#56B4E9")

  channel_df <-
    bind_rows(
      transmute(
        transition_df,
        model, layer_ix_1,
        component = "h_para",
        mean = para_mean_across_layers,
        upper = mean + para_cis,
        lower = mean - para_cis
      ),
      transmute(
        transition_df,
        model, layer_ix_1,
        component = "h_orth",
        mean = orth_mean_across_layers,
        upper = mean + orth_cis,
        lower = mean - orth_cis
      )
    ) %>%
    mutate(
      model = model |>
        fct_relevel("olmoe", "qwen1.5moe", "dsv2", "qwen3moe") |> # facet order
        fct_recode(                                               # facet titles
          "OlMoE"      = "olmoe",
          "Qwen1.5MoE" = "qwen1.5moe",
          "DSv2-Lite"  = "dsv2",
          "Qwen3MoE"   = "qwen3moe"
        )
    )

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      # Markdown element so the <em>/<sub> tags in the axis titles are rendered.
      axis.title = ggtext::element_markdown(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      legend.text = ggtext::element_markdown(size = rel(0.9)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  ggplot(channel_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index <em>l</em>",
      y = "Transition similarity <em>C<sub>l</sub></em>"
    ) +
    # The band is added first: ggplot2 draws layers in the order they are
    # added, so adding it last would paint it over the lines and points.
    geom_ribbon(
      aes(x = layer_ix_1, ymin = lower, ymax = upper, fill = component),
      alpha = 0.25, linetype = "blank", color = NA
    ) +
    geom_line(
      aes(x = layer_ix_1, y = mean, group = component, color = component),
      linewidth = 0.8
    ) +
    geom_point(
      aes(x = layer_ix_1, y = mean, group = component, color = component),
      shape = 16, size = 2
    ) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    # Each facet gets its own x and y ranges.
    facet_wrap(vars(model), nrow = 2, scales = "free") +
    figure_theme
})

# Vector (PDF via cairo) and raster (PNG) copies at 16 x 8 inches.
ggsave(
  "exports/svd-transition-stability-all-md.pdf",
  plot = plot,
  width = 1600 / 100, height = 800 / 100,
  units = "in", dpi = 300,
  device = cairo_pdf
)
ggsave(
  "exports/svd-transition-stability-all-md.png",
  plot = plot,
  width = 1600 / 100, height = 800 / 100,
  units = "in", dpi = 300
)

plot

## Expert ID probes

In [None]:
# Read every exports/svd-probe-expert-id-<model>.csv file and stack them into
# one table, tagging each row with the <model> part of its file name.
layer_pred <- local({
  csv_paths <- dir_ls(
    "exports/",
    regexp = "^exports/svd-probe-expert-id-.*\\.csv$"
  )
  # Strip the fixed prefix and the .csv suffix to recover the model name.
  model_names <- str_extract(
    path_file(csv_paths),
    "(?<=^svd-probe-expert-id-).*(?=\\.csv$)"
  )

  per_model <- map2(csv_paths, model_names, \(csv_path, model_name) {
    mutate(read_csv(csv_path), model = model_name)
  })

  list_rbind(per_model)
})

# Quick look at the combined table.
head(layer_pred)

In [None]:
# Typography for the figures drawn in this cell (overrides the globals).
base_font_family <- "CMU Serif"
base_font_size <- 11

# Reshape the current-layer probe rows to one row per (layer, model, channel).
# Each measurement column name is split by position: its first four characters
# become `channel` and everything from the sixth character on becomes the
# metric name (so a column such as `para_acc` yields channel "para", metric
# "acc"). Metrics are then spread back out into their own columns.
layer_pred_df <-
  layer_pred |>
  filter(target == "current_layer") |>
  pivot_longer(
    cols = -c(test_layer_1, target, model),
    names_to = "tstr",
    values_to = "value"
  ) |>
  mutate(channel = str_sub(tstr, 1, 4), metric = str_sub(tstr, 6)) |>
  select(-tstr) |>
  pivot_wider(
    id_cols = c(test_layer_1, target, model, channel),
    names_from = metric,
    values_from = value
  )

print(head(layer_pred_df))

# One accuracy-vs-layer figure per model, saved as PDF and PNG, printed, and
# collected in `plots`.
plots <- map(group_split(layer_pred_df, model), function(model_layer_df) {
  # Every row of a group carries the same model name; it goes into the file name.
  model_prefix <- model_layer_df$model[[1]]

  # Legend text (HTML rendered by ggtext) and colours per channel.
  channel_labels <- c(
    para = "Router-visible channel (<em>h<sub>vis</sub></em>)",
    orth = "Router-blind channel (<em>h<sub>blind</sub></em>)"
  )
  channel_colors <- c(para = "#E69F00", orth = "#56B4E9")

  # The line and point layers share one mapping.
  accuracy_mapping <- aes(x = test_layer_1, y = acc, group = channel, color = channel)

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      axis.title = element_text(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      # Markdown element so the <em>/<sub> tags in the labels are rendered.
      legend.text = ggtext::element_markdown(size = rel(0.85)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  fig <-
    ggplot(model_layer_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index",
      y = "Probe accuracy for current-layer expert ID (%)"
    ) +
    geom_line(mapping = accuracy_mapping, linewidth = 0.8) +
    geom_point(mapping = accuracy_mapping, shape = 16, size = 2) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    # No layer in this figure maps `fill`; the scale is kept as in the original.
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    figure_theme

  # Vector (PDF via cairo) and raster (PNG) copies at 8 x 5 inches.
  pdf_path <- str_glue("exports/svd-probe-same-expert-id-{model_prefix}-md.pdf")
  png_path <- str_glue("exports/svd-probe-same-expert-id-{model_prefix}-md.png")
  ggsave(
    pdf_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
  ggsave(
    png_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300
  )
  print(fig)

  fig
})


In [None]:
# Typography for the figures drawn in this cell (overrides the globals).
base_font_family <- "CMU Serif"
base_font_size <- 11

# Reshape the next-layer probe rows to one row per (layer, model, channel).
# Each measurement column name is split by position: its first four characters
# become `channel` and everything from the sixth character on becomes the
# metric name (so a column such as `para_acc` yields channel "para", metric
# "acc"). Metrics are then spread back out into their own columns.
layer_pred_df <-
  layer_pred |>
  filter(target == "next_layer") |>
  pivot_longer(
    cols = -c(test_layer_1, target, model),
    names_to = "tstr",
    values_to = "value"
  ) |>
  mutate(channel = str_sub(tstr, 1, 4), metric = str_sub(tstr, 6)) |>
  select(-tstr) |>
  pivot_wider(
    id_cols = c(test_layer_1, target, model, channel),
    names_from = metric,
    values_from = value
  )

# One accuracy-vs-layer figure per model, saved as PDF and PNG, printed, and
# collected in `plots`.
plots <- map(group_split(layer_pred_df, model), function(model_layer_df) {
  # Every row of a group carries the same model name; it goes into the file name.
  model_prefix <- model_layer_df$model[[1]]

  # Legend text (HTML rendered by ggtext) and colours per channel.
  channel_labels <- c(
    para = "Router-visible channel (<em>h<sub>vis</sub></em>)",
    orth = "Router-blind channel (<em>h<sub>blind</sub></em>)"
  )
  channel_colors <- c(para = "#E69F00", orth = "#56B4E9")

  # The line and point layers share one mapping.
  accuracy_mapping <- aes(x = test_layer_1, y = acc, group = channel, color = channel)

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      axis.title = element_text(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      # Markdown element so the <em>/<sub> tags in the labels are rendered.
      legend.text = ggtext::element_markdown(size = rel(0.85)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  fig <-
    ggplot(model_layer_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index",
      y = "Probe accuracy for next-layer expert ID (%)"
    ) +
    geom_line(mapping = accuracy_mapping, linewidth = 0.8) +
    geom_point(mapping = accuracy_mapping, shape = 16, size = 2) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    # No layer in this figure maps `fill`; the scale is kept as in the original.
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    figure_theme

  # Vector (PDF via cairo) and raster (PNG) copies at 8 x 5 inches.
  pdf_path <- str_glue("exports/svd-probe-next-expert-id-{model_prefix}-md.pdf")
  png_path <- str_glue("exports/svd-probe-next-expert-id-{model_prefix}-md.png")
  ggsave(
    pdf_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
  ggsave(
    png_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300
  )
  print(fig)

  fig
})


## Language probes

In [None]:
# Read every exports/svd-probe-lang-<model>.csv file and stack them into one
# table, tagging each row with the <model> part of its file name.
lang_pred <- local({
  csv_paths <- dir_ls(
    "exports/",
    regexp = "^exports/svd-probe-lang-.*\\.csv$"
  )
  # Strip the fixed prefix and the .csv suffix to recover the model name.
  model_names <- str_extract(
    path_file(csv_paths),
    "(?<=^svd-probe-lang-).*(?=\\.csv$)"
  )

  per_model <- map2(csv_paths, model_names, \(csv_path, model_name) {
    mutate(read_csv(csv_path), model = model_name)
  })

  list_rbind(per_model)
})

# Quick look at the combined table.
head(lang_pred)

In [None]:
# Typography for the figures drawn in this cell (overrides the globals).
base_font_family <- "CMU Serif"
base_font_size <- 11

# Reshape the language-probe table to one row per (layer, model, channel).
# Each measurement column name is split by position: its first four characters
# become `channel` and everything from the sixth character on becomes the
# metric name (so a column such as `para_acc` yields channel "para", metric
# "acc"). Metrics are then spread back out into their own columns.
layer_pred_df <-
  lang_pred |>
  pivot_longer(
    cols = -c(test_layer_1, model),
    names_to = "tstr",
    values_to = "value"
  ) |>
  mutate(channel = str_sub(tstr, 1, 4), metric = str_sub(tstr, 6)) |>
  select(-tstr) |>
  pivot_wider(
    id_cols = c(test_layer_1, model, channel),
    names_from = metric,
    values_from = value
  )

head(layer_pred_df)

# One accuracy-vs-layer figure per model, saved as PDF and PNG, printed, and
# collected in `plots`.
plots <- map(group_split(layer_pred_df, model), function(model_layer_df) {
  # Every row of a group carries the same model name; it goes into the file name.
  model_prefix <- model_layer_df$model[[1]]

  # Legend text (HTML rendered by ggtext) and colours per channel.
  channel_labels <- c(
    para = "Router-visible channel (<em>h<sub>vis</sub></em>)",
    orth = "Router-blind channel (<em>h<sub>blind</sub></em>)"
  )
  channel_colors <- c(para = "#E69F00", orth = "#56B4E9")

  # The line and point layers share one mapping.
  accuracy_mapping <- aes(x = test_layer_1, y = acc, group = channel, color = channel)

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      axis.title = element_text(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      # Markdown element so the <em>/<sub> tags in the labels are rendered.
      legend.text = ggtext::element_markdown(size = rel(0.85)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  fig <-
    ggplot(model_layer_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index",
      y = "Probe accuracy for language (%)"
    ) +
    geom_line(mapping = accuracy_mapping, linewidth = 0.8) +
    geom_point(mapping = accuracy_mapping, shape = 16, size = 2) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    # No layer in this figure maps `fill`; the scale is kept as in the original.
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    figure_theme

  # Vector (PDF via cairo) and raster (PNG) copies at 8 x 5 inches.
  pdf_path <- str_glue("exports/svd-probe-lang-{model_prefix}-md.pdf")
  png_path <- str_glue("exports/svd-probe-lang-{model_prefix}-md.png")
  ggsave(
    pdf_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
  ggsave(
    png_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300
  )
  print(fig)

  fig
})

## TID probe

In [None]:
# Read every exports/svd-probe-tid-<model>.csv file and stack them into one
# table, tagging each row with the <model> part of its file name.
tid_pred <- local({
  csv_paths <- dir_ls(
    "exports/",
    regexp = "^exports/svd-probe-tid-.*\\.csv$"
  )
  # Strip the fixed prefix and the .csv suffix to recover the model name.
  model_names <- str_extract(
    path_file(csv_paths),
    "(?<=^svd-probe-tid-).*(?=\\.csv$)"
  )

  per_model <- map2(csv_paths, model_names, \(csv_path, model_name) {
    mutate(read_csv(csv_path), model = model_name)
  })

  list_rbind(per_model)
})

# Quick look at the combined table.
head(tid_pred)

In [None]:
# Typography for the figures drawn in this cell (overrides the globals).
base_font_family <- "CMU Serif"
base_font_size <- 11

# Reshape the token-ID probe table to one row per (layer, model, channel).
# Each measurement column name is split by position: its first four characters
# become `channel` and everything from the sixth character on becomes the
# metric name (so a column such as `para_acc` yields channel "para", metric
# "acc"). Metrics are then spread back out into their own columns.
layer_pred_df <-
  tid_pred |>
  pivot_longer(
    cols = -c(test_layer_1, model),
    names_to = "tstr",
    values_to = "value"
  ) |>
  mutate(channel = str_sub(tstr, 1, 4), metric = str_sub(tstr, 6)) |>
  select(-tstr) |>
  pivot_wider(
    id_cols = c(test_layer_1, model, channel),
    names_from = metric,
    values_from = value
  )

head(layer_pred_df)

# One accuracy-vs-layer figure per model, saved as PDF and PNG, printed, and
# collected in `plots`.
plots <- map(group_split(layer_pred_df, model), function(model_layer_df) {
  # Every row of a group carries the same model name; it goes into the file name.
  model_prefix <- model_layer_df$model[[1]]

  # Legend text (HTML rendered by ggtext) and colours per channel.
  channel_labels <- c(
    para = "Router-visible channel (<em>h<sub>vis</sub></em>)",
    orth = "Router-blind channel (<em>h<sub>blind</sub></em>)"
  )
  channel_colors <- c(para = "#E69F00", orth = "#56B4E9")

  # The line and point layers share one mapping.
  accuracy_mapping <- aes(x = test_layer_1, y = acc, group = channel, color = channel)

  figure_theme <-
    theme_bw(base_size = base_font_size, base_family = base_font_family) +
    theme(
      plot.title = element_blank(),
      plot.subtitle = element_blank(),

      axis.title = element_text(face = "plain", size = rel(1.0)),
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = rel(0.9)),
      axis.text.y = element_text(size = rel(0.9)),

      legend.position = "bottom",
      legend.title = element_text(face = "bold", size = rel(0.9)),
      # Markdown element so the <em>/<sub> tags in the labels are rendered.
      legend.text = ggtext::element_markdown(size = rel(0.85)),
      legend.background = element_rect(fill = "white", color = "grey90", linewidth = 0.2),
      legend.key.size = unit(0.8, "lines"),

      panel.grid.major = element_line(colour = "grey85", linewidth = 0.3),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(colour = "grey70", fill = NA, linewidth = 0.5),

      strip.background = element_blank(),
      strip.text = element_text(face = "bold", size = rel(1.0))
    )

  fig <-
    ggplot(model_layer_df) +
    labs(
      title = NULL,
      color = NULL, fill = NULL,
      x = "Layer index",
      y = "Probe accuracy for token ID (%)"
    ) +
    geom_line(mapping = accuracy_mapping, linewidth = 0.8) +
    geom_point(mapping = accuracy_mapping, shape = 16, size = 2) +
    scale_color_manual(labels = channel_labels, values = channel_colors) +
    # No layer in this figure maps `fill`; the scale is kept as in the original.
    scale_fill_manual(labels = channel_labels, values = channel_colors) +
    figure_theme

  # Vector (PDF via cairo) and raster (PNG) copies at 8 x 5 inches.
  pdf_path <- str_glue("exports/svd-probe-tid-{model_prefix}-md.pdf")
  png_path <- str_glue("exports/svd-probe-tid-{model_prefix}-md.png")
  ggsave(
    pdf_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300,
    device = cairo_pdf
  )
  ggsave(
    png_path,
    plot = fig,
    width = 800 / 100, height = 500 / 100,
    units = "in", dpi = 300
  )
  print(fig)

  fig
})

## Make tables

In [None]:

# Aggregate table written to exports/svd-aggregates.csv.
#
# One block of four rows per model, in the order the models first appear in
# `layer_pred`:
#
#   probe_target    h-visible         h-blind           model
#   current_layer   mean(para_acc)    mean(orth_acc)    <model>
#   next_layer      ...
#   language        ...
#   token           ...
#
# The expert-ID rows average over `layer_pred` (split by `target`), the
# language row over `lang_pred`, and the token row over `tid_pred`.
map(unique(layer_pred$model), function(this_model) {
  # Rows belonging to this model for each probe; each filter runs once.
  current_rows <- filter(layer_pred, target == "current_layer", model == this_model)
  next_rows <- filter(layer_pred, target == "next_layer", model == this_model)
  lang_rows <- filter(lang_pred, model == this_model)
  tid_rows <- filter(tid_pred, model == this_model)

  tribble(
    ~"probe_target", ~"h-visible", ~"h-blind",
    "current_layer", mean(current_rows$para_acc), mean(current_rows$orth_acc),
    "next_layer", mean(next_rows$para_acc), mean(next_rows$orth_acc),
    "language", mean(lang_rows$para_acc), mean(lang_rows$orth_acc),
    "token", mean(tid_rows$para_acc), mean(tid_rows$orth_acc)
  ) %>%
    mutate(model = this_model)
}) %>%
  list_rbind() %>%
  write_csv("exports/svd-aggregates.csv")