In [None]:
# This file calculates ablation targets using MMLU data. Run after running `run-base-mmlu.ipynb`.

In [None]:
# ---- Load libs ----
library(tidyverse)
library(slider)
library(IRdisplay, include.only = 'display')
library(IRkernel)

# Model identifier; interpolated into every data/{model_prefix}/... path read or written below.
model_prefix <- "qwen3moe"

## Load data

In [None]:
# ---- Load data ----
# Reads the per-token sample table and the top-k routing table for `model_prefix`,
# drops every token at or before the last position at which all questions hold the
# same token, and defines three globals:
#   sample_df - filtered per-token samples
#   topk_df   - routing rows restricted to the tokens kept in sample_df
#   topk1_df  - rows of topk_df with topk_ix == 1 (topk_ix column removed)
# The work happens inside local() so the raw tables do not linger in the workspace;
# the results are returned as a list instead of being pushed out with `<<-`.
loaded_tables <- local({

    raw_sample_df <-
        data.table::fread(str_glue("data/{model_prefix}/train_samples.csv"), strip.white = FALSE) %>%
        arrange(q_ix, token_ix) %>%
        # Re-index tokens 1..n within each question, in token_ix order.
        mutate(shifted_token_ix = seq_len(n()), .by = c("q_ix"))

    raw_topk_df <- data.table::fread(str_glue("data/{model_prefix}/train_topks.csv"))

    n_questions <- n_distinct(raw_sample_df$q_ix)

    # Positions where one single token appears in every question.
    shared_positions <-
        raw_sample_df %>%
        group_by(shifted_token_ix, token) %>%
        summarize(n_questions_with_token_ix_tok = n(), .groups = "drop") %>%
        filter(n_questions_with_token_ix_tok == n_questions) %>%
        pull(shifted_token_ix)

    # max() of an empty vector is -Inf plus a warning; with no shared position
    # nothing should be trimmed, so fall back to 0 (positions start at 1).
    last_shared_tok <- if (length(shared_positions) > 0) max(shared_positions) else 0L

    sample_df <-
        raw_sample_df %>%
        filter(shifted_token_ix > last_shared_tok)

    topk_df <-
        raw_topk_df %>%
        inner_join(select(sample_df, q_ix, token_ix), by = c("q_ix", "token_ix"))

    topk1_df <-
        topk_df %>%
        filter(topk_ix == 1) %>%
        select(-topk_ix)

    list(sample_df = sample_df, topk_df = topk_df, topk1_df = topk1_df)
})

sample_df <- loaded_tables$sample_df
topk_df <- loaded_tables$topk_df
topk1_df <- loaded_tables$topk1_df
rm(loaded_tables)

In [None]:
# ---- Diagnostic checks ----
# Number of retained tokens for every (domain, lang) pair.
domain_lang_map <-
    sample_df |>
    group_by(domain, lang) |>
    summarise(n_tokens = n(), .groups = "drop")

print(domain_lang_map)

# Distinct top-1 expert ids and layer indices present in the routing table.
topk1_df$expert |> unique() |> sort() |> print()
topk1_df$layer_ix |> unique() |> sort() |> print()

In [None]:
# ---- Check accuracy baselines ----
# Collapse tokens to one row per question, mark a question correct when its
# whitespace-squished output token equals answer_char, then aggregate per (domain, lang).
base_accs <-
    sample_df |>
    group_by(domain, lang, q_ix, question_output_token, answer_char) |>
    summarise(tokens_per_question = n(), .groups = "drop") |>
    mutate(is_correct = as.numeric(str_squish(question_output_token) == answer_char)) |>
    group_by(domain, lang) |>
    summarise(
        base_acc = sum(is_correct) / n(),
        questions = n(),
        total_tokens = sum(tokens_per_question),
        .groups = "drop"
    )

print(base_accs)

In [None]:
# ---- Check accuracy by question ----
# One row per question with a 1/0 correctness flag (same rule as the baseline check:
# squished output token equals answer_char).
question_accs <-
    sample_df |>
    group_by(domain, lang, q_ix, question_output_token, answer_char) |>
    summarise(tokens_per_question = n(), .groups = "drop") |>
    mutate(is_correct = as.numeric(str_squish(question_output_token) == answer_char)) |>
    select(q_ix, is_correct)

# q_ix values of the questions flagged correct.
q_ix_correct <-
    question_accs |>
    filter(is_correct == 1) |>
    pull(q_ix)
# write_csv(question_accs, str_glue('data/{model_prefix}/question_accs.csv'))

## Ablation method 1: most common transitions

In [None]:
# ---- Get a (token, path) level dataframe ----
# Within each token (q_ix, token_ix), rows are ordered by layer_ix and a sliding window
# of the current row plus the one before it yields:
#   path   - the experts in the window
#   layers - the layer indices in the window
# Rows with layer_ix == 0 are then dropped and domain/lang are attached from sample_df.
toks_with_paths <-
    topk1_df |>
    select(layer_ix, q_ix, token_ix, expert) |>
    group_by(q_ix, token_ix) |>
    arrange(layer_ix, .by_group = TRUE) |>
    mutate(
        path = slide(expert, .f = identity, .before = 1, .after = 0),
        layers = slide(layer_ix, .f = identity, .before = 1, .after = 0)
    ) |>
    ungroup() |>
    filter(layer_ix > 0) |>
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c("q_ix", "token_ix"))

head(toks_with_paths, 5)

In [None]:
# ---- Check diagnostics ----
# Row counts of the three tables built so far, printed for a quick sanity check.
n_sample_rows <- nrow(sample_df)
n_topk1_rows <- nrow(topk1_df)
n_path_rows <- nrow(toks_with_paths)
cat("# Samples: ", n_sample_rows)
cat("\n# Topk1 Experts (n_layers x samples): ", n_topk1_rows)
cat("\n# Paths ((n_layers - 1) x samples): ", n_path_rows)
rm(n_sample_rows, n_topk1_rows, n_path_rows)

In [None]:
# ---- Get a (domain, path) level dataframe with counts ----
# Number of token-level path rows for every (domain, lang, path, layers) combination.
dom_x_path <-
    toks_with_paths |>
    group_by(domain, lang, path, layers) |>
    summarise(n_samples = n(), .groups = "drop")

head(dom_x_path, 5)

In [None]:
# ---- Get a (domain) level dataframe with token counts ----
# Total path samples per (domain, lang) and each pair's share of the grand total.
dom_tok_counts <-
    dom_x_path |>
    group_by(domain, lang) |>
    summarise(n_tok_samples = sum(n_samples), .groups = "drop") |>
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

dom_tok_counts

In [None]:
# ---- For each target_domain x target_lang x target_k, get the specialized paths ----
# One target per (domain, lang) pair present in sample_df and per multiplier k in {2, 4}.
targets <-
    sample_df %>%
    distinct(domain, lang) %>%
    expand_grid(k = c(2, 4)) %>%
    rename(target_domain = domain, target_lang = lang, target_k = k)

# Wide view of path counts (one column per domain_lang), printed for inspection only.
dom_x_path %>%
    pivot_wider(id_cols = c(layers, path), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>%
    print()

# prop_of_samples is the share of a (layers, path)'s samples contributed by each (domain, lang).
# A path is kept for a target when that share is at least target_k times the target's
# overall share of path samples (n_tok_prop).
spec_paths <-
    dom_x_path %>%
    group_by(layers, path) %>%
    mutate(prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    expand_grid(targets, .) %>%
    filter(domain == target_domain, lang == target_lang) %>%
    left_join(dom_tok_counts, by = c("domain", "lang")) %>%
    filter(prop_of_samples >= n_tok_prop * target_k)

spec_paths_bio_en <-
    spec_paths %>%
    filter(target_domain == "biology", target_lang == "en", target_k == 4)

cat("Unique path counts: ", nrow(spec_paths_bio_en))
cat(
    "\nPaths taken counts: ",
    sum(spec_paths_bio_en$n_samples),
    " of ",
    sum(filter(dom_x_path, domain == "biology", lang == "en")$n_samples)
)
head(spec_paths_bio_en, 5)

In [None]:
# ---- For each target_domain x target_lang x target_k, get various summary stats on the ablation %s ----

# For every target, flag each token-level path that is in the target's specialized set,
# then report per (domain, lang) how many paths / tokens would be affected.
ablation_props <- map(
    group_split(spec_paths, target_domain, target_lang, target_k),
    \(df_for_target) {

        spec_flags <- transmute(df_for_target, layers, path, is_spec = 1)

        props <-
            toks_with_paths |>
            left_join(spec_flags, by = c("path", "layers")) |>
            # Unmatched rows come back as NA from the join; turn the flag into 1/0.
            mutate(is_spec = as.numeric(!is.na(is_spec))) |>
            group_by(q_ix, token_ix, domain, lang) |>
            summarise(n_spec_paths = sum(is_spec), n_possible_paths = n(), .groups = "drop") |>
            group_by(domain, lang) |>
            # Order matters here: the per-token columns are read before they are
            # overwritten by their own sums.
            summarise(
                n_questions = n_distinct(q_ix),
                n_toks = n(),
                n_toks_with_any_spec_path = sum(as.numeric(n_spec_paths > 0)),
                n_toks_with_half_spec_path = sum(as.numeric(n_spec_paths >= n_possible_paths * .5)),
                n_spec_paths = sum(n_spec_paths),
                n_possible_paths = sum(n_possible_paths),
                .groups = "drop"
            ) |>
            mutate(
                prop_spec_paths = n_spec_paths / n_possible_paths,
                prop_toks_with_any_spec_path = n_toks_with_any_spec_path / n_toks,
                prop_toks_with_half_spec_path = n_toks_with_half_spec_path / n_toks
            )

        list(
            target_domain = df_for_target$target_domain[[1]],
            target_lang = df_for_target$target_lang[[1]],
            target_k = df_for_target$target_k[[1]],
            props = props
        )
    },
    .progress = TRUE
)

cat("Bio_en, k=2")
display(keep(ablation_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("Bio_en, k=4")
display(keep(ablation_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

cat("CS, k=2")
display(keep(ablation_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("CS, k=4")
display(keep(ablation_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

In [None]:
# ---- For each target_domain x target_lang x target_k, save JSON str ----

# We save it in format
# {
#     layer: [
#         [prefix, target_e]
#         ...
#     ]
# }
# One file is written per target; the same structure is also returned in memory.
spec_paths_exportable <- map(
    group_split(spec_paths, target_domain, target_lang, target_k),
    \(spec_df_for_target) {

        target_domain <- spec_df_for_target$target_domain[[1]]
        target_lang <- spec_df_for_target$target_lang[[1]]
        target_k <- spec_df_for_target$target_k[[1]]

        # Split each path into (all experts but the last, last expert) and collect the
        # pairs under the last layer of the window.
        rules_by_layer <-
            spec_df_for_target |>
            mutate(
                target_layer = map_int(layers, \(lyr) tail(lyr, 1)),
                target_expert = map_int(path, \(p) tail(p, 1)),
                expert_prefix = map(path, \(p) head(p, -1)),
                rule_pair = map2(expert_prefix, target_expert, \(prefix, target) list(prefix, target))
            ) |>
            select(target_layer, rule_pair) |>
            group_by(target_layer) |>
            summarise(rules = list(rule_pair), .groups = "drop")

        spec_paths_str <- setNames(rules_by_layer$rules, rules_by_layer$target_layer)

        out_path <- str_glue("data/{model_prefix}/path_ablation_targets_{target_domain}_{target_lang}_{target_k}.json")
        json_str <- jsonlite::toJSON(spec_paths_str, simplifyVector = FALSE, auto_unbox = TRUE, pretty = FALSE)
        writeLines(json_str, out_path)

        list(
            target_domain = target_domain,
            target_lang = target_lang,
            target_k = target_k,
            spec_paths_str = spec_paths_str
        )
    },
    .progress = TRUE
)

# Number of rules under the first exported layer for biology / en / k = 2.
length(keep(spec_paths_exportable, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 2)[[1]]$spec_paths_str[[1]])

## Ablation method 2: within-layer

In [None]:
# ---- Get a (token, expert) level dataframe ----
# One row per (q_ix, token_ix, layer_ix, expert) in the top-1 routing table, with its
# row count as n_samples and the token's domain/lang attached from sample_df.
toks_with_expert <-
    topk1_df |>
    group_by(q_ix, token_ix, layer_ix, expert) |>
    summarise(n_samples = n(), .groups = "drop") |>
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c("q_ix", "token_ix"))

head(toks_with_expert, 5)

In [None]:
# ---- Get (domain, expert) level dataframe ----
# Number of token-level rows for every (domain, lang, layer_ix, expert) combination.
dom_x_expert <-
    toks_with_expert |>
    group_by(domain, lang, layer_ix, expert) |>
    summarise(n_samples = n(), .groups = "drop")

head(dom_x_expert, 5)

In [None]:
# ---- Get (domain) level dataframe with counts of token usage ----
# Total expert samples per (domain, lang) and each pair's share of the grand total.
dom_full_tok_counts <-
    dom_x_expert |>
    group_by(domain, lang) |>
    summarise(n_tok_samples = sum(n_samples), .groups = "drop") |>
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

dom_full_tok_counts

In [None]:
# ---- Get specialized experts for each (target) ----
# Wide view of expert counts (one column per domain_lang), printed for inspection only.
dom_x_expert %>%
    pivot_wider(., id_cols = c(layer_ix, expert), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>%
    print()

# prop_of_samples is the share of a (layer_ix, expert)'s samples contributed by each
# (domain, lang). An expert is kept for a target when that share is at least target_k
# times the target's overall share of expert samples (n_tok_prop).
spec_experts <-
    dom_x_expert %>%
    group_by(layer_ix, expert) %>%
    mutate(., prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    expand_grid(targets, .) %>%
    filter(domain == target_domain & lang == target_lang) %>%
    left_join(dom_full_tok_counts, by = c("domain", "lang")) %>%
    filter(., prop_of_samples >= n_tok_prop * target_k)

spec_experts_bio_en <-
    spec_experts %>%
    filter(target_domain == "biology" & target_lang == "en" & target_k == 4)

cat("Unique expert counts: ", nrow(spec_experts_bio_en))
# The denominator must come from the expert-level table (dom_x_expert). The path-level
# table (dom_x_path) belongs to method 1 and has already dropped layer_ix == 0 rows,
# so it would understate the total.
cat(
    "\nExperts taken counts: ",
    sum(spec_experts_bio_en$n_samples),
    " of ",
    sum(filter(dom_x_expert, domain == "biology" & lang == "en")$n_samples)
)
head(spec_experts_bio_en, 5)

In [None]:
# ---- For each target_domain x target_lang x target_k, get various summary stats on the ablation %s ----
# For every target, flag each token-level (layer_ix, expert) row that is in the target's
# specialized set, then report per (domain, lang) how many experts / tokens would be affected.
expert_ablation_props <- map(
    group_split(spec_experts, target_domain, target_lang, target_k),
    \(df_for_target) {

        spec_flags <- transmute(df_for_target, layer_ix, expert, is_spec = 1)

        props <-
            toks_with_expert |>
            left_join(spec_flags, by = c("layer_ix", "expert")) |>
            # Unmatched rows come back as NA from the join; turn the flag into 1/0.
            mutate(is_spec = as.numeric(!is.na(is_spec))) |>
            group_by(q_ix, token_ix, domain, lang) |>
            summarise(n_spec_exps = sum(is_spec), n_possible_exps = n(), .groups = "drop") |>
            group_by(domain, lang) |>
            # Order matters here: the per-token columns are read before they are
            # overwritten by their own sums.
            summarise(
                n_questions = n_distinct(q_ix),
                n_toks = n(),
                n_toks_with_any_spec_exp = sum(as.numeric(n_spec_exps > 0)),
                n_toks_with_half_spec_exp = sum(as.numeric(n_spec_exps >= n_possible_exps * .5)),
                n_spec_exps = sum(n_spec_exps),
                n_possible_exps = sum(n_possible_exps),
                .groups = "drop"
            ) |>
            mutate(
                prop_spec_exps = n_spec_exps / n_possible_exps,
                prop_toks_with_any_spec_exp = n_toks_with_any_spec_exp / n_toks,
                prop_toks_with_half_spec_exp = n_toks_with_half_spec_exp / n_toks
            )

        list(
            target_domain = df_for_target$target_domain[[1]],
            target_lang = df_for_target$target_lang[[1]],
            target_k = df_for_target$target_k[[1]],
            props = props
        )
    },
    .progress = TRUE
)

cat("Bio_en, k=2")
display(keep(expert_ablation_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("Bio_en, k=4")
display(keep(expert_ablation_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

cat("CS, k=2")
display(keep(expert_ablation_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("CS, k=4")
display(keep(expert_ablation_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

In [None]:
# ---- For each target_domain x target_lang x target_k, save JSON str ----

# We save it in format
# {
#     layer: [expert1, expert2],
# }
# One file is written per target; the same structure is also returned in memory.
spec_experts_exportable <- map(group_split(spec_experts, target_domain, target_lang, target_k), .progress = TRUE, function(spec_df_for_target) {

    target_domain <- spec_df_for_target$target_domain[[1]]
    target_lang <- spec_df_for_target$target_lang[[1]]
    target_k <- spec_df_for_target$target_k[[1]]

    # Named list: layer index -> vector of specialized experts (layer 0 is excluded).
    spec_experts_str <-
        spec_df_for_target %>%
        group_by(., layer_ix) %>%
        summarize(
            expert = list(expert),
            .groups = "drop"
        ) %>%
        filter(layer_ix > 0) %>%
        {setNames(.$expert, .$layer_ix)}

    # auto_unbox must stay off here: with it on, a layer holding exactly one expert is
    # written as a bare number (`"3": 7`) instead of the documented array (`"3": [7]`).
    jsonlite::toJSON(spec_experts_str, simplifyVector = FALSE, auto_unbox = FALSE, pretty = FALSE) %>%
        writeLines(., str_glue("data/{model_prefix}/expert_ablation_targets_{target_domain}_{target_lang}_{target_k}.json"))

    list(
        target_domain = target_domain,
        target_lang = target_lang,
        target_k = target_k,
        spec_experts_str = spec_experts_str
    )
})

# Number of specialized experts under the first exported layer for biology / en / k = 2.
spec_experts_exportable %>% keep(., \(x) x$target_domain == "biology" && x$target_lang == "en" && x$target_k == 2) %>% .[[1]] %>% .$spec_experts_str %>% .[[1]] %>% length()

## Ablation method 3: multi-layer paths (each path spans `multipath_length + 1` layers)

In [None]:
# ---- Get a (token, path) level dataframe ----
# Number of preceding rows included in each sliding window; a full window therefore
# holds multipath_length + 1 entries.
multipath_length <- 2

# Within each token (q_ix, token_ix), rows are ordered by layer_ix and a sliding window
# of the current row plus up to `multipath_length` rows before it yields:
#   path   - the experts in the window
#   layers - the layer indices in the window
# Only rows with layer_ix >= multipath_length are kept, then domain/lang are attached.
toks_with_multipaths <-
    topk1_df |>
    select(layer_ix, q_ix, token_ix, expert) |>
    group_by(q_ix, token_ix) |>
    arrange(layer_ix, .by_group = TRUE) |>
    mutate(
        path = slide(expert, .f = identity, .before = multipath_length, .after = 0),
        layers = slide(layer_ix, .f = identity, .before = multipath_length, .after = 0)
    ) |>
    ungroup() |>
    filter(layer_ix > multipath_length - 1) |>
    left_join(select(sample_df, q_ix, token_ix, domain, lang), by = c("q_ix", "token_ix"))

head(toks_with_multipaths, 5)

In [None]:
# ---- Get a (domain, path) level dataframe with counts ----
# Number of token-level multipath rows for every (domain, lang, path, layers) combination.
dom_x_multipath <-
    toks_with_multipaths |>
    group_by(domain, lang, path, layers) |>
    summarise(n_samples = n(), .groups = "drop")

head(dom_x_multipath, 5)

In [None]:
# ---- Get a (domain) level dataframe with token counts ----
# Total multipath samples per (domain, lang) and each pair's share of the grand total.
dom_multipath_tok_counts <-
    dom_x_multipath |>
    group_by(domain, lang) |>
    summarise(n_tok_samples = sum(n_samples), .groups = "drop") |>
    mutate(n_tok_prop = n_tok_samples / sum(n_tok_samples))

dom_multipath_tok_counts

In [None]:
# ---- For each target_domain x target_lang x target_k, get the specialized paths ----
# One target per (domain, lang) pair present in sample_df and per multiplier k in {2, 4}.
targets <-
    sample_df %>%
    distinct(domain, lang) %>%
    expand_grid(k = c(2, 4)) %>%
    rename(target_domain = domain, target_lang = lang, target_k = k)

# Wide view of multipath counts (one column per domain_lang), printed for inspection only.
dom_x_multipath %>%
    pivot_wider(id_cols = c(layers, path), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>%
    print()

# prop_of_samples is the share of a (layers, path)'s samples contributed by each (domain, lang).
# A path is kept for a target when that share is at least target_k times the target's
# overall share of multipath samples (n_tok_prop).
spec_multipaths <-
    dom_x_multipath %>%
    group_by(layers, path) %>%
    mutate(prop_of_samples = n_samples / sum(n_samples)) %>%
    ungroup() %>%
    expand_grid(targets, .) %>%
    filter(domain == target_domain, lang == target_lang) %>%
    left_join(dom_multipath_tok_counts, by = c("domain", "lang")) %>%
    filter(prop_of_samples >= n_tok_prop * target_k)

spec_multipaths_bio_en <-
    spec_multipaths %>%
    filter(target_domain == "biology", target_lang == "en", target_k == 4)

cat("Unique path counts: ", nrow(spec_multipaths_bio_en))
cat(
    "\nPaths taken counts: ",
    sum(spec_multipaths_bio_en$n_samples),
    " of ",
    sum(filter(dom_x_multipath, domain == "biology", lang == "en")$n_samples)
)
head(spec_multipaths_bio_en, 5)

In [None]:
# ---- For each target_domain x target_lang x target_k, get various summary stats on the ablation %s ----

# For every target, flag each token-level multipath that is in the target's specialized
# set, then report per (domain, lang) how many paths / tokens would be affected.
ablation_multipath_props <- map(
    group_split(spec_multipaths, target_domain, target_lang, target_k),
    \(df_for_target) {

        spec_flags <- transmute(df_for_target, layers, path, is_spec = 1)

        props <-
            toks_with_multipaths |>
            left_join(spec_flags, by = c("path", "layers")) |>
            # Unmatched rows come back as NA from the join; turn the flag into 1/0.
            mutate(is_spec = as.numeric(!is.na(is_spec))) |>
            group_by(q_ix, token_ix, domain, lang) |>
            summarise(n_spec_paths = sum(is_spec), n_possible_paths = n(), .groups = "drop") |>
            group_by(domain, lang) |>
            # Order matters here: the per-token columns are read before they are
            # overwritten by their own sums.
            summarise(
                n_questions = n_distinct(q_ix),
                n_toks = n(),
                n_toks_with_any_spec_path = sum(as.numeric(n_spec_paths > 0)),
                n_toks_with_half_spec_path = sum(as.numeric(n_spec_paths >= n_possible_paths * .5)),
                n_spec_paths = sum(n_spec_paths),
                n_possible_paths = sum(n_possible_paths),
                .groups = "drop"
            ) |>
            mutate(
                prop_spec_paths = n_spec_paths / n_possible_paths,
                prop_toks_with_any_spec_path = n_toks_with_any_spec_path / n_toks,
                prop_toks_with_half_spec_path = n_toks_with_half_spec_path / n_toks
            )

        list(
            target_domain = df_for_target$target_domain[[1]],
            target_lang = df_for_target$target_lang[[1]],
            target_k = df_for_target$target_k[[1]],
            props = props
        )
    },
    .progress = TRUE
)

cat("Bio_en, k=2")
display(keep(ablation_multipath_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("Bio_en, k=4")
display(keep(ablation_multipath_props, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

cat("CS, k=2")
display(keep(ablation_multipath_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 2)[[1]]$props)

cat("CS, k=4")
display(keep(ablation_multipath_props, \(res) res$target_domain == "compsci" && res$target_lang == "en" && res$target_k == 4)[[1]]$props)

In [None]:
# ---- For each target_domain x target_lang x target_k, save JSON str ----

# We save it in format
# {
#     layer: [
#         [prefix, target_e]
#         ...
#     ]
# }
# One file is written per target; the same structure is also returned in memory.
spec_multipaths_exportable <- map(
    group_split(spec_multipaths, target_domain, target_lang, target_k),
    \(spec_df_for_target) {

        target_domain <- spec_df_for_target$target_domain[[1]]
        target_lang <- spec_df_for_target$target_lang[[1]]
        target_k <- spec_df_for_target$target_k[[1]]

        # Split each path into (all experts but the last, last expert) and collect the
        # pairs under the last layer of the window.
        rules_by_layer <-
            spec_df_for_target |>
            # filter(., n_samples >= 2) %>%
            mutate(
                target_layer = map_int(layers, \(lyr) tail(lyr, 1)),
                target_expert = map_int(path, \(p) tail(p, 1)),
                expert_prefix = map(path, \(p) head(p, -1)),
                rule_pair = map2(expert_prefix, target_expert, \(prefix, target) list(prefix, target))
            ) |>
            select(target_layer, rule_pair) |>
            group_by(target_layer) |>
            summarise(rules = list(rule_pair), .groups = "drop")

        spec_multipaths_str <- setNames(rules_by_layer$rules, rules_by_layer$target_layer)

        out_path <- str_glue("data/{model_prefix}/multipath_ablation_targets_{target_domain}_{target_lang}_{target_k}.json")
        json_str <- jsonlite::toJSON(spec_multipaths_str, simplifyVector = FALSE, auto_unbox = TRUE, pretty = FALSE)
        writeLines(json_str, out_path)

        list(
            target_domain = target_domain,
            target_lang = target_lang,
            target_k = target_k,
            spec_multipaths_str = spec_multipaths_str
        )
    },
    .progress = TRUE
)

# Number of rules under the fourth exported layer for biology / en / k = 2.
length(keep(spec_multipaths_exportable, \(res) res$target_domain == "biology" && res$target_lang == "en" && res$target_k == 2)[[1]]$spec_multipaths_str[[4]])

## UNUSED: Ablation method 3: within-layer, multiple experts

In [None]:
# # Get multi-topk [order matters!]
# toks_with_multi_topk =
#     topk_df %>%
#     filter(topk_ix %in% 1:2) %>%
#     arrange(q_ix, token_ix, layer_ix, topk_ix, expert) %>%
#     group_by(q_ix, token_ix, layer_ix) %>%
#     arrange(topk_ix, .by_group = T) %>% # Order matters, switch topk_ix for expert otherwise
#     summarize(., experts = list(expert), .groups = 'drop') %>%
#     left_join(., select(sample_df, q_ix, token_ix, domain, lang), by = c('q_ix', 'token_ix'))
    
# dom_x_experts = 
#     toks_with_multi_topk %>%
#     group_by(., domain, lang, layer_ix, experts) %>%
#     summarize(., n_samples = n(), .groups = 'drop')

# print(head(topks_by_multi_topk, 5))

# dom_counts =
#     dom_x_experts %>%
#     group_by(., domain, lang) %>%
#     summarize(., n_tok_samples = sum(n_samples), .groups = 'drop') %>%
#     mutate(., n_tok_prop = n_tok_samples/sum(n_tok_samples))

# print(head(dom_counts))

In [None]:
# dom_x_experts %>%
#     pivot_wider(., id_cols = c(layer_ix, experts), names_from = c(domain, lang), values_from = n_samples, values_fill = 0) %>% 
#     print()

# spec_multi_topk =
#     dom_x_experts %>%
#     group_by(layer_ix, experts) %>%
#     mutate(., prop_of_samples = n_samples/sum(n_samples)) %>%
#     ungroup() %>%
#     filter(., domain == test_domain & lang == test_lang) %>%
#     left_join(dom_counts, by = c('domain', 'lang')) %>%
#     filter(., prop_of_samples >= n_tok_prop * 5)

# cat('Specialized [exp1, exp2]: ', sum(spec_multi_topk$n_samples), ' of ', sum(filter(dom_x_experts, domain == test_domain & lang == test_lang)$n_samples))

In [None]:
# # Analyze proportions to be ablated
# toks_with_multi_topk %>%
#     left_join(transmute(spec_multi_topk, layer_ix, experts, is_spec = 1), by = c('layer_ix', 'experts')) %>%
#     mutate(., is_spec = ifelse(!is.na(is_spec), 1, 0)) %>%
#     group_by(q_ix, token_ix, domain, lang) %>%
#     summarize(., n_spec_exp_pairs = sum(is_spec), n_possible_exp_pairs = n(), .groups = 'drop') %>%
#     group_by(domain, lang) %>%
#     summarize(
#         n_questions = n_distinct(q_ix),
#         n_toks = n(),
#         n_toks_with_any_spec_exp_pairs = sum(ifelse(n_spec_exp_pairs > 0, 1, 0)),
#         n_toks_with_half_spec_exp_pairs = sum(ifelse(n_spec_exp_pairs >= n_possible_exp_pairs * .5, 1, 0)),
#         n_spec_exp_pairs = sum(n_spec_exp_pairs),
#         n_possible_exp_pairs = sum(n_possible_exp_pairs),
#         .groups = 'drop'
#     ) %>%
#     mutate(
#         .,
#         prop_spec_exp_pairs = n_spec_exp_pairs/n_possible_exp_pairs,
#         prop_toks_with_any_spec_exp_pairs = n_toks_with_any_spec_exp_pairs/n_toks,
#         prop_toks_with_half_spec_exp_pairs = n_toks_with_half_spec_exp_pairs/n_toks
#     )