In [None]:
library(tidyverse)
library(fs)
library(ggtext)
library(systemfonts)

# Root of the project checkout; the paths used in this notebook are built
# relative to it.
ws = '/workspace/deliberative-alignment-jailbreaks'

# Load the project's plotting utilities.
# NOTE(review): presumably this is where theme_iclr() comes from -- confirm.
source(file.path(ws, 'r-utils', 'plots.r'))

In [None]:
# Short model identifiers; each one is interpolated into the per-model CSV
# file names that are read in this notebook.
model_prefixes = c(
  'gptoss-20b',
  'gptoss-120b',
  'nemotron-3-nano',
  'qwen3-30b-a3b'
)

# Probe training accuracy

In [None]:
# Read the probe-training `acc_by_role_<model_prefix>.csv` file for every
# entry of `model_prefixes` and stack the tables row-wise.
acc_by_role =
  model_prefixes %>%
  map(function(model_prefix) {
    read_csv(file.path(ws, str_glue('experiments/role-analysis/outputs/probe-training/acc_by_role_{model_prefix}.csv')))
  }) %>%
  list_rbind()

# One representative "mid" layer per model, chosen from the distinct layer
# indices present in `acc_by_role`: the (floor(n / 2) + 2)-th smallest index,
# where n is the number of distinct layers for that model.
# NOTE(review): a conventional upper median would be floor(n / 2) + 1; the
# extra +1 is kept as-is because changing it would change which layer is
# reported -- confirm it is intentional.
# The position is clamped to n so that a model with fewer than three layers
# yields its last layer instead of an out-of-range NA (an NA layer would make
# the model silently vanish from any later join on `layer_ix`).
median_layers =
  acc_by_role %>%
  distinct(model, layer_ix) %>%
  group_by(model) %>%
  summarise(layer_ix = sort(layer_ix)[min(floor(n() / 2) + 2, n())])

acc_by_role

In [None]:
# Highest accuracy training layer
# Restricted to the 'user,assistant,tool' role space: pool the counts per
# (model, layer), compute the share of counts where `role` equals `pred`, and
# keep the top-accuracy layer(s) for each model (ties are all returned).
acc_by_role %>%
  filter(role_space == 'user,assistant,tool') %>%
  # 1 where `role` matches `pred`, 0 otherwise (NA if either is missing)
  mutate(is_correct = as.numeric(role == pred)) %>%
  group_by(model, role, layer_ix, is_correct) %>%
  summarise(count = sum(count), .groups = 'drop') %>%
  group_by(model, layer_ix) %>%
  summarise(
    accuracy = sum(is_correct * count) / sum(count),
    .groups = 'drop'
  ) %>%
  group_by(model) %>%
  slice_max(order_by = accuracy, n = 1)

# Accuracy on conversational data

In [None]:
model_prefixes = c('gptoss-20b', 'gptoss-120b', 'qwen3-30b-a3b', 'nemotron-3-nano')
# Set to TRUE to also load the `*_conv_projs_*` files and join them onto the
# accuracy tables. Spelled out as FALSE because `F` is an ordinary variable
# that can be reassigned.
include_roleness = FALSE

# Read `<file_stem>_<model_prefix>.csv` from the probe-projections output
# directory for every prefix, add a `model` column holding the prefix, and
# stack the per-model tables row-wise.
#   file_stem: file name up to (not including) the `_<model_prefix>.csv` part.
#   prefixes:  model identifiers to load; defaults to `model_prefixes`.
read_projection_csvs = function(file_stem, prefixes = model_prefixes) {
  list_rbind(map(prefixes, function(model_prefix) {
    read_csv(file.path(ws, str_glue('experiments/role-analysis/outputs/probe-projections/{file_stem}_{model_prefix}.csv'))) %>%
      mutate(model = model_prefix)
  }))
}

base_acc = read_projection_csvs('all_conv_acc')
alt_acc = read_projection_csvs('alt_conv_acc')

if (include_roleness) {
  # The projection files carry their value in a column named `mean_acc`;
  # rename it so it does not collide with the accuracy column in the join.
  base_roleness =
    read_projection_csvs('all_conv_projs') %>%
    rename(mean_roleness = mean_acc)

  alt_roleness =
    read_projection_csvs('alt_conv_projs') %>%
    rename(mean_roleness = mean_acc)

  projections = full_join(
    bind_rows(base_acc, alt_acc),
    bind_rows(base_roleness, alt_roleness),
    join_by(model, conv_type, layer_ix, role_space, role)
  )
} else {
  projections = bind_rows(base_acc, alt_acc)
}

projections

In [None]:
# Get mid-layer results only
# Keep the 'uat' role space, restrict each model to its layer from
# `median_layers`, and derive `conv_type_2`, a combined label for each
# (role, conv_type) pair. Pairs not listed in the case_when() get NA.
convs_df =
  projections %>%
  filter(role_space == 'uat') %>%
  inner_join(median_layers, join_by(model, layer_ix)) %>%
  select(-role_space, -layer_ix) %>%
  mutate(conv_type_2 = case_when(
    role == 'user' & conv_type == 'tagged' ~ 'user_in_user_role_tags',
    role == 'assistant' & conv_type == 'tagged' ~ 'asst_in_asst_role_tags',

    role == 'user' & conv_type == 'untagged' ~ 'user_no_role_tags',
    role == 'assistant' & conv_type == 'untagged' ~ 'assistant_no_role_tags',

    role == 'user' & conv_type == 'tool_tagged' ~ 'user_tool_tagged',
    role == 'assistant' & conv_type == 'tool_tagged' ~ 'assistant_tool_tagged',

    role == 'assistant' & conv_type == 'user_tagged' ~ 'assistant_user_tagged'
  ))

# Quick-look strip plot: one point per conversation type, positioned by mean
# accuracy, faceted by model (rows) and role (columns). The y aesthetic is a
# constant, so the y axis carries no information and is hidden.
# NOTE(review): the filter tests `conv_type`, not the derived `conv_type_2`;
# if the intent was to drop unmapped (role, conv_type) pairs it should test
# `conv_type_2` -- confirm before changing.
convs_df %>%
  filter(!is.na(conv_type)) %>%
  mutate(y = 0) %>%
  ggplot() +
  geom_point(aes(x = mean_acc, y = y, color = conv_type), size = 4) +
  facet_grid(rows = vars(model), cols = vars(role)) +
  theme_iclr(base_size = 12) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    axis.title.y = element_blank(),
    panel.grid.major.y = element_blank(),
    panel.grid.minor.y = element_blank()
  )

# Persist the mid-layer table. The target directory is created first because
# write_csv() fails when the parent directory does not exist.
convs_out_path = file.path(ws, 'experiments/role-analysis/outputs/summary-stats/convs_results.csv')
dir.create(dirname(convs_out_path), recursive = TRUE, showWarnings = FALSE)

convs_df %>%
  write_csv(convs_out_path)

In [None]:
# Layer-wise accuracy curves for the 'uat' role space: one column of panels
# per model, one row per role, one coloured line per conversation type
# (restricted to the three types listed in the filter below).
plot =
  projections %>%
  filter(role_space == 'uat') %>%
  filter(conv_type %in% c('tagged', 'untagged', 'tool_tagged')) %>%
  mutate(
    # Factor levels fix the left-to-right panel order.
    model = factor(
      model,
      levels = c('gptoss-20b', 'gptoss-120b', 'nemotron-3-nano', 'qwen3-30b-a3b')
    ),
    # Display names; these must match the names used in scale_color_manual().
    conv_type = factor(
      conv_type,
      levels = c('tagged', 'untagged', 'tool_tagged'),
      labels = c('Baseline', 'No tags', 'Injection (tool tagged)')
    ),
    # Strip labels are markdown/HTML, rendered by ggtext::element_markdown().
    role = factor(
      role,
      levels = c('user', 'assistant'),
      labels = c(
        'Userness<br><span style="font-size:8pt; font-weight:normal">(of user-style text)</span>',
        'Assistantness<br><span style="font-size:8pt; font-weight:normal">(of assistant-style text)</span>'
      )
    )
  ) %>%
  ggplot() +
  geom_line(aes(x = layer_ix, y = mean_acc, color = conv_type), linewidth = 0.5, alpha = 0.7) +
  geom_point(aes(x = layer_ix, y = mean_acc, color = conv_type), size = 1.5) +
  scale_color_manual(
    values = c(
      'Baseline' = '#62748e',
      'No tags' = '#00a6f4',
      'Injection (tool tagged)' = '#ff6467'
    ),
    name = NULL
  ) +
  # Row strips are moved to the left (switch = 'y') so they act as y-axis titles.
  facet_grid(cols = vars(model), rows = vars(role), scales = 'free', switch = 'y') +
  scale_y_continuous(
    labels = scales::percent_format(accuracy = 1),
    limits = c(0, 1),
    expand = expansion(mult = c(0.02, 0.02)),
    breaks = c(0, 0.5, 1)
  ) +
  labs(
    x = 'Layer index', y = NULL
  ) +
  theme_iclr(base_size = 11) +
  theme(
    plot.title = ggtext::element_markdown(size = 11),
    axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
    strip.placement = 'outside',
    panel.spacing.y = unit(1.0, "lines"),
    axis.text.x = ggtext::element_markdown(size = 9, hjust = 0.5, angle = 0, margin = margin(t = 2)),
    axis.line.x = element_blank(),
    axis.ticks.x = element_blank(),
    strip.text.x = element_text(margin = margin(b = 4)),
    strip.text.y.left = ggtext::element_markdown(angle = 90, face = 'bold')
    # panel.grid.major.x = element_blank(),
    # panel.grid.minor.x = element_blank(),
  )

# Save `fig` to `path` at the shared figure size, creating the parent
# directory first so a fresh checkout does not fail on a missing folder.
# Extra arguments (e.g. `device`) are forwarded to ggsave().
save_figure = function(fig, path, ...) {
  dir.create(dirname(path), recursive = TRUE, showWarnings = FALSE)
  ggsave(path, plot = fig, width = 7, height = 3.5, units = 'in', dpi = 300, ...)
}

save_figure(
  plot,
  file.path(ws, 'experiments/role-analysis/plots/cross-model-validation.pdf'),
  device = cairo_pdf
)
save_figure(
  plot,
  file.path(ws, 'experiments/role-analysis/plots/cross-model-validation.png')
)

# Second PNG copy written under docs/.
save_figure(
  plot,
  file.path(ws, 'docs/cross-model-validation.png')
)

plot