In [None]:
### Calculate most common MMLU domains per pair

In [None]:
# ---- Load libs ----
library(tidyverse)
library(slider)

# Model identifier; interpolated into every data/{model_prefix}/... path in this notebook
model_prefix <- 'qwen3moe'

## Load data

In [None]:
# ---- Load data ----
# Reads the per-token samples and the top-k routing table, drops the leading token
# positions that are identical across all questions, and publishes sample_df,
# topk_df and topk1_df to the enclosing environment. Everything else stays local.
local({

    samples_path <- str_glue('data/{model_prefix}/samples.csv')
    topks_path <- str_glue('data/{model_prefix}/topks.csv')

    # Whitespace is kept as-is in the samples file (strip.white = FALSE).
    # shifted_token_ix restarts at 1 within each q_ix, following token_ix order.
    all_samples <-
        data.table::fread(samples_path, strip.white = FALSE) %>%
        arrange(q_ix, token_ix) %>%
        mutate(shifted_token_ix = row_number(), .by = q_ix)

    all_topks <- data.table::fread(topks_path)

    # Largest position whose (position, token) row count equals the number of
    # distinct questions, i.e. the last position where every question has the same token.
    n_distinct_questions <- length(unique(all_samples$q_ix))
    shared_positions <-
        all_samples %>%
        group_by(shifted_token_ix, token) %>%
        summarise(n_questions_with_token_ix_tok = n(), .groups = 'drop') %>%
        filter(n_questions_with_token_ix_tok == n_distinct_questions)
    last_shared_ix <- max(shared_positions$shifted_token_ix)

    # Keep only tokens strictly after the shared positions
    kept_samples <- filter(all_samples, shifted_token_ix > last_shared_ix)

    # Restrict the routing rows to the (q_ix, token_ix) pairs that were kept
    kept_keys <- select(kept_samples, q_ix, token_ix)
    kept_topks <- inner_join(all_topks, kept_keys, by = c('q_ix', 'token_ix'))

    # Rank-1 rows only; topk_ix is constant there, so it is dropped
    kept_top1 <- kept_topks %>% filter(topk_ix == 1) %>% select(-topk_ix)

    sample_df <<- kept_samples
    topk_df <<- kept_topks
    topk1_df <<- kept_top1
})

In [None]:
# Number of sample rows (tokens) for each (domain, lang) combination
domain_lang_map <-
    sample_df %>%
    group_by(domain, lang) %>%
    summarise(n_tokens = n(), .groups = 'drop')

domain_lang_map

In [None]:
# Sanity check: print the sorted distinct expert ids and layer indices in the top-1 table
topk1_df$expert %>% unique() %>% sort() %>% print()
topk1_df$layer_ix %>% unique() %>% sort() %>% print()

In [None]:
# Baseline accuracy per (domain, lang).
# Step 1 collapses tokens to one row per question; step 2 marks a question correct
# when its whitespace-squished output token equals answer_char; step 3 aggregates.
base_accs <-
    sample_df %>%
    group_by(domain, lang, q_ix, question_output_token, answer_char) %>%
    summarise(tokens_per_question = n(), .groups = 'drop') %>%
    mutate(is_correct = as.numeric(str_squish(question_output_token) == answer_char)) %>%
    group_by(domain, lang) %>%
    summarise(
        base_acc = sum(is_correct) / n(),
        questions = n(),
        total_tokens = sum(tokens_per_question),
        .groups = 'drop'
    )

print(base_accs)
write_csv(base_accs, str_glue('data/{model_prefix}/base_accs.csv'))

## Ablation method 1: most common transitions

In [None]:
# ---- Get transition counts by q_ix ----
# For every token, rows are ordered by layer_ix and each row receives a window made
# of the previous row (when there is one) and itself: `path` holds the experts in
# that window and `layers` the matching layer indices.
toks_with_paths <-
    topk1_df %>%
    select(layer_ix, q_ix, token_ix, expert) %>%
    group_by(q_ix, token_ix) %>%
    arrange(layer_ix, .by_group = TRUE) %>%
    mutate(
        path = slide(expert, identity, .before = 1),
        layers = slide(layer_ix, identity, .before = 1)
    ) %>%
    ungroup() %>%
    # Rows with layer_ix <= 0 are dropped; presumably layer 0 is the first layer and
    # therefore has no predecessor in its window -- confirm against the data
    filter(layer_ix > 0) %>%
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c('q_ix', 'token_ix'))

toks_with_paths %>% head(5)

In [None]:
# Print the row counts of the three tables, one labelled line each
iwalk(
    list(
        '# Samples: ' = nrow(sample_df),
        '\n# Topk1 Experts (n_layers x samples): ' = nrow(topk1_df),
        '\n# Paths ((n_layers - 1) x samples): ' = nrow(toks_with_paths)
    ),
    \(row_count, label) cat(label, row_count)
)

In [None]:
# Number of token-level rows for each (domain, lang, expert path, layer pair)
dom_x_path <-
    toks_with_paths %>%
    group_by(domain, lang, path, layers) %>%
    summarise(n_samples = n(), .groups = 'drop')

dom_x_path %>% head(5)

In [None]:
# Total path rows per (domain, lang) and each combination's share of the grand total
dom_tok_counts <-
    dom_x_path %>%
    group_by(domain, lang) %>%
    summarise(n_tok_samples = sum(n_samples), .groups = 'drop') %>%
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

dom_tok_counts

In [None]:
# Path counts, pivot domains out
# Domain/language whose specialized paths are selected in this and later cells
test_domain <- 'biology'
test_lang <- 'en'

# Wide view for inspection: one row per (layers, path), one count column per domain_lang
dom_x_path %>%
    pivot_wider(id_cols = c(layers, path), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>%
    print()

# Every path observed for the test domain/language; used as the denominator below
test_dom_paths <- filter(dom_x_path, domain == test_domain & lang == test_lang)

# A path is kept when the test domain's share of that path's rows (prop_of_samples)
# is at least its share of all path rows (n_tok_prop) times the multiplier (here 1).
spec_paths <-
    dom_x_path %>%
    group_by(layers, path) %>%
    mutate(prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    filter(domain == test_domain & lang == test_lang) %>%
    left_join(dom_tok_counts, by = c('domain', 'lang')) %>%
    filter(prop_of_samples >= n_tok_prop * 1)

# Selected vs. all paths of the test domain (previously compared spec_paths with itself)
cat('Unique path counts: ', nrow(spec_paths), ' of ', nrow(test_dom_paths))
cat('\nPaths taken counts: ', sum(spec_paths$n_samples), ' of ', sum(test_dom_paths$n_samples))

In [None]:
# Analyze proportions to be ablated
# Flags each token-level path row that matches a specialized path, then reports per
# (domain, lang) how many rows match and how many tokens have any / at least half
# of their rows matched.
toks_with_paths %>%
    left_join(
        spec_paths %>% select(layers, path) %>% mutate(is_spec = 1),
        by = c('path', 'layers')
    ) %>%
    # Unmatched rows come back as NA from the join; convert the flag to 1/0
    mutate(is_spec = as.numeric(!is.na(is_spec))) %>%
    group_by(q_ix, token_ix, domain, lang) %>%
    summarise(n_spec_paths = sum(is_spec), n_possible_paths = n(), .groups = 'drop') %>%
    group_by(domain, lang) %>%
    summarise(
        n_questions = n_distinct(q_ix),
        n_toks = n(),
        n_toks_with_any_spec_path = sum(as.numeric(n_spec_paths > 0)),
        n_toks_with_half_spec_path = sum(as.numeric(n_spec_paths >= n_possible_paths * .5)),
        # Totals come last because they overwrite the per-token columns used above
        n_spec_paths = sum(n_spec_paths),
        n_possible_paths = sum(n_possible_paths),
        .groups = 'drop'
    ) %>%
    mutate(
        prop_spec_paths = n_spec_paths / n_possible_paths,
        prop_toks_with_any_spec_path = n_toks_with_any_spec_path / n_toks,
        prop_toks_with_half_spec_path = n_toks_with_half_spec_path / n_toks
    )

In [None]:
# Build the export structure: a list named by target layer, each element holding
# that layer's rules, each rule being list(prefix experts, target expert):
# {
#     layer: [
#         [prefix, target_e]
#         ...
#     ]
# }
exportable_format <-
    spec_paths %>%
    mutate(
        # The last entry of each window is the target; everything before it is the prefix
        target_layer = map_int(layers, \(lyr) tail(lyr, 1)),
        target_expert = map_int(path, \(exps) tail(exps, 1)),
        expert_prefix = map(path, \(exps) head(exps, -1)),
        rule_pair = map2(expert_prefix, target_expert, list)
    ) %>%
    select(target_layer, rule_pair) %>%
    group_by(target_layer) %>%
    summarise(rules = list(rule_pair), .groups = 'drop') %>%
    with(setNames(rules, target_layer))

In [None]:
# Number of rules targeting layer 1. `[[ ]]` matches the name exactly, whereas `$`
# on a list partial-matches and could return another layer starting with "1".
length(exportable_format[['1']])

In [None]:
# Serialize the per-layer rules as compact JSON and write them to the model's data
# directory. With auto_unbox on, length-1 atomic vectors are emitted as scalars.
json_output <-
    exportable_format %>%
    jsonlite::toJSON(simplifyVector = FALSE, auto_unbox = TRUE, pretty = FALSE)
writeLines(text = json_output, con = str_glue('data/{model_prefix}/path_ablation_targets.json'))

In [None]:
# First and last row position of each (domain, lang) among the distinct
# (q_ix, domain, lang) rows, numbered in their order of appearance.
# row_number() is used instead of 1:n() so an empty table does not error.
sample_df %>%
    distinct(q_ix, domain, lang) %>%
    mutate(row_ix = row_number()) %>%
    group_by(domain, lang) %>%
    summarise(start = min(row_ix), end = max(row_ix), .groups = 'drop')

## Ablation method 2: within-layer

In [None]:
# Number of top-1 routing rows per (domain, lang, layer, expert), for rows with layer_ix > 0
dom_expert_layer_counts <-
    topk1_df %>%
    select(layer_ix, q_ix, token_ix, expert) %>%
    # The filter is row-wise, so no per-token grouping is needed before it
    filter(layer_ix > 0) %>%
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c('q_ix', 'token_ix')) %>%
    group_by(domain, lang, layer_ix, expert) %>%
    summarise(n_samples = n(), .groups = 'drop')

head(dom_expert_layer_counts, 5)

In [None]:
# Total routing rows per (domain, lang) and each combination's share of the grand total
dom_full_tok_counts <-
    dom_expert_layer_counts %>%
    group_by(domain, lang) %>%
    summarise(n_tok_samples = sum(n_samples), .groups = 'drop') %>%
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

dom_full_tok_counts

In [None]:
# Wide view for inspection: one row per (layer_ix, expert), one count column per domain_lang
dom_expert_layer_counts %>%
    pivot_wider(id_cols = c(layer_ix, expert), names_from = c(domain, lang), names_sep = '_', values_from = n_samples, values_fill = 0) %>%
    print()

# An expert is kept when the test domain's share of its rows (prop_of_samples) is at
# least 5 times the domain's share of all rows (n_tok_prop).
spec_experts <-
    dom_expert_layer_counts %>%
    group_by(layer_ix, expert) %>%
    mutate(prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    filter(domain == test_domain & lang == test_lang) %>%
    left_join(dom_full_tok_counts, by = c('domain', 'lang')) %>%
    filter(prop_of_samples >= n_tok_prop * 5)

cat(
    'Specialized experts: ',
    sum(spec_experts$n_samples),
    ' of ',
    sum(filter(dom_expert_layer_counts, domain == test_domain & lang == test_lang)$n_samples)
)

In [None]:
# Analyze proportions to be ablated
# Flags each top-1 routing row that hits a specialized (layer, expert), then reports
# per (domain, lang) how many rows hit and how many tokens have any / at least half
# of their rows hit.
# NOTE(review): topk1_df is not filtered to layer_ix > 0 here, while spec_experts is
# built from rows with layer_ix > 0 only, so layer 0 counts toward n_possible_exps
# but can never be flagged -- confirm this is intended.
topk1_df %>%
    left_join(select(sample_df, domain, lang, q_ix, token_ix), by = c('q_ix', 'token_ix')) %>%
    left_join(
        spec_experts %>% select(layer_ix, expert) %>% mutate(is_spec = 1),
        by = c('layer_ix', 'expert')
    ) %>%
    # Unmatched rows come back as NA from the join; convert the flag to 1/0
    mutate(is_spec = as.numeric(!is.na(is_spec))) %>%
    group_by(q_ix, token_ix, domain, lang) %>%
    summarise(n_spec_exps = sum(is_spec), n_possible_exps = n(), .groups = 'drop') %>%
    group_by(domain, lang) %>%
    summarise(
        n_questions = n_distinct(q_ix),
        n_toks = n(),
        n_toks_with_any_spec_exp = sum(as.numeric(n_spec_exps > 0)),
        n_toks_with_half_spec_exp = sum(as.numeric(n_spec_exps >= n_possible_exps * .5)),
        # Totals come last because they overwrite the per-token columns used above
        n_spec_exps = sum(n_spec_exps),
        n_possible_exps = sum(n_possible_exps),
        .groups = 'drop'
    ) %>%
    mutate(
        prop_spec_exps = n_spec_exps / n_possible_exps,
        prop_toks_with_any_spec_exp = n_toks_with_any_spec_exp / n_toks,
        prop_toks_with_half_spec_exp = n_toks_with_half_spec_exp / n_toks
    )

## Ablation method 3: within-layer, multiple experts

In [None]:
# Get multi-topk [order matters!]
# One row per (q_ix, token_ix, layer_ix) holding the rank-1 and rank-2 experts as a
# list ordered by rank, so the same two experts in a different rank order count as
# a different combination.
toks_with_multi_topk <-
    topk_df %>%
    filter(topk_ix %in% 1:2) %>%
    # One sort is enough: within each token/layer the rows end up ordered by topk_ix.
    # Order matters; sort by expert ahead of topk_ix to make pairs order-insensitive.
    arrange(q_ix, token_ix, layer_ix, topk_ix, expert) %>%
    group_by(q_ix, token_ix, layer_ix) %>%
    summarise(experts = list(expert), .groups = 'drop') %>%
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c('q_ix', 'token_ix'))

# Number of token-level rows per (domain, lang, layer, ordered expert pair)
dom_x_experts <-
    toks_with_multi_topk %>%
    group_by(domain, lang, layer_ix, experts) %>%
    summarise(n_samples = n(), .groups = 'drop')

# Preview the per-token table (the previous name, topks_by_multi_topk, was never defined)
print(head(toks_with_multi_topk, 5))

# Total rows per (domain, lang) and each combination's share of the grand total
dom_counts <-
    dom_x_experts %>%
    group_by(domain, lang) %>%
    summarise(n_tok_samples = sum(n_samples), .groups = 'drop') %>%
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

print(head(dom_counts))

In [None]:
# Wide view for inspection: one row per (layer_ix, experts), one count column per domain_lang
dom_x_experts %>%
    pivot_wider(id_cols = c(layer_ix, experts), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>%
    print()

# An expert pair is kept when the test domain's share of its rows (prop_of_samples)
# is at least 5 times the domain's share of all rows (n_tok_prop).
spec_multi_topk <-
    dom_x_experts %>%
    group_by(layer_ix, experts) %>%
    mutate(prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    filter(domain == test_domain & lang == test_lang) %>%
    left_join(dom_counts, by = c('domain', 'lang')) %>%
    filter(prop_of_samples >= n_tok_prop * 5)

cat(
    'Specialized [exp1, exp2]: ',
    sum(spec_multi_topk$n_samples),
    ' of ',
    sum(filter(dom_x_experts, domain == test_domain & lang == test_lang)$n_samples)
)

In [None]:
# Analyze proportions to be ablated
# Flags each token/layer row whose ordered expert pair is specialized, then reports
# per (domain, lang) how many rows match and how many tokens have any / at least
# half of their rows matched.
toks_with_multi_topk %>%
    left_join(
        spec_multi_topk %>% select(layer_ix, experts) %>% mutate(is_spec = 1),
        by = c('layer_ix', 'experts')
    ) %>%
    # Unmatched rows come back as NA from the join; convert the flag to 1/0
    mutate(is_spec = as.numeric(!is.na(is_spec))) %>%
    group_by(q_ix, token_ix, domain, lang) %>%
    summarise(n_spec_exp_pairs = sum(is_spec), n_possible_exp_pairs = n(), .groups = 'drop') %>%
    group_by(domain, lang) %>%
    summarise(
        n_questions = n_distinct(q_ix),
        n_toks = n(),
        n_toks_with_any_spec_exp_pairs = sum(as.numeric(n_spec_exp_pairs > 0)),
        n_toks_with_half_spec_exp_pairs = sum(as.numeric(n_spec_exp_pairs >= n_possible_exp_pairs * .5)),
        # Totals come last because they overwrite the per-token columns used above
        n_spec_exp_pairs = sum(n_spec_exp_pairs),
        n_possible_exp_pairs = sum(n_possible_exp_pairs),
        .groups = 'drop'
    ) %>%
    mutate(
        prop_spec_exp_pairs = n_spec_exp_pairs / n_possible_exp_pairs,
        prop_toks_with_any_spec_exp_pairs = n_toks_with_any_spec_exp_pairs / n_toks,
        prop_toks_with_half_spec_exp_pairs = n_toks_with_half_spec_exp_pairs / n_toks
    )