In [None]:
library(tidyverse)
library(fs)
library(ggtext)
library(systemfonts)

# Root directory of the project; the paths used in this notebook are built
# relative to it.
ws <- '/workspace/deliberative-alignment-jailbreaks'

# Run the shared plotting script from the project's r-utils folder
# (presumably where theme_iclr(), used further down, is defined -- confirm).
source(str_c(ws, '/r-utils/plots.r'))

In [None]:
# Model identifiers; each one is interpolated into a per-model CSV file name
# when the probe-training results are read.
model_prefixes <- c(
  "gptoss-20b",
  "gptoss-120b",
  "nemotron-3-nano",
  "qwen3-30b-a3b"
)

# Probe training data

In [None]:
# Read one probe-training accuracy CSV per model and stack them row-wise
# into a single table.
acc_by_role <-
  model_prefixes %>%
  map(function(model_prefix) {
    # The file name embeds the model identifier.
    csv_path <- file.path(
      ws,
      str_glue('experiments/role-analysis/outputs/probe-training/acc_by_role_{model_prefix}.csv')
    )
    read_csv(csv_path)
  }) %>%
  list_rbind()

# Pick one "middle" layer per model: the element at position floor(n/2) + 1
# of the sorted distinct layer indices (the median for an odd number of
# layers, the upper median for an even number).
# The earlier offset of + 2 selected the layer one step past the median and
# produced NA for models with fewer than three distinct layers.
# The result column is named `layer_ix` so the table can be joined on
# (model, layer_ix).
median_layers <-
  acc_by_role %>%
  distinct(model, layer_ix) %>%
  group_by(model) %>%
  summarise(layer_ix = sort(layer_ix)[floor(n() / 2) + 1])

# Show the full probe-training table as the cell output.
acc_by_role

In [None]:
# Display the per-model middle-layer table.
median_layers

In [None]:
# Display the per-model middle-layer table (same output as the cell above).
median_layers

In [None]:
# Middle layer per model: position floor(n/2) + 1 among the sorted distinct
# layer indices (median for odd n, upper median for even n).
# The output column must be called `layer_ix`: the later cells join this
# table with join_by(model, layer_ix), which fails if the column has any
# other name (it was previously `median_layer`).
median_layers <-
  acc_by_role %>%
  distinct(model, layer_ix) %>%
  group_by(model) %>%
  summarise(layer_ix = sort(layer_ix)[floor(n() / 2) + 1])


In [None]:
# Layer with the highest overall (count-weighted) accuracy for each model,
# restricted to the 'user,assistant,tool' role space.
acc_by_role %>%
  filter(role_space == 'user,assistant,tool') %>%
  # 1 when the predicted role equals the true role, 0 otherwise.
  mutate(is_correct = as.numeric(role == pred)) %>%
  # Collapse to one row per (model, role, layer, correct/incorrect).
  group_by(model, role, layer_ix, is_correct) %>%
  summarise(count = sum(count), .groups = 'drop') %>%
  # Share of all counted predictions that were correct, per model and layer.
  group_by(model, layer_ix) %>%
  summarise(accuracy = sum(count * is_correct) / sum(count), .groups = 'drop') %>%
  # Keep the top layer(s) per model; ties are all returned.
  group_by(model) %>%
  slice_max(accuracy, n = 1)

In [None]:
# Layer with the highest role-averaged accuracy for each model: accuracy is
# first computed separately per role, then averaged across roles so every
# role carries equal weight regardless of how many samples it has.
# Fix: the aggregation previously called `medan()`, which is not a defined
# function; `mean()` matches the `mean_acc` column name.
# NOTE(review): use median() instead if the median across roles was intended.
acc_by_role %>%
  filter(role_space == 'user,assistant,tool') %>%
  mutate(is_correct = ifelse(role == pred, 1, 0)) %>%
  group_by(model, role, layer_ix) %>%
  summarize(acc = sum(is_correct * count) / sum(count), .groups = 'drop') %>%
  group_by(model, layer_ix) %>%
  summarize(mean_acc = mean(acc), .groups = 'drop') %>%
  group_by(model) %>%
  slice_max(mean_acc, n = 1)

In [None]:
# Overall accuracy per (model, layer) computed in a single aggregation,
# then the best layer(s) per model in the 'user,assistant,tool' role space.
acc_by_role %>%
  filter(role_space == 'user,assistant,tool') %>%
  mutate(is_correct = as.numeric(role == pred)) %>%
  group_by(model, layer_ix) %>%
  summarise(acc = sum(count * is_correct) / sum(count), .groups = 'drop') %>%
  group_by(model) %>%
  slice_max(acc, n = 1)

In [None]:
# Per-role accuracy for qwen3-30b-a3b at layers 4 and 8, laid out with one
# row per (layer, role space) and one column per role.
acc_by_role %>%
  filter(
    model == 'qwen3-30b-a3b',
    layer_ix %in% c(4, 8)
  ) %>%
  group_by(role, model, role_space, layer_ix) %>%
  summarise(
    acc = sum(as.numeric(role == pred) * count) / sum(count),
    .groups = 'drop'
  ) %>%
  pivot_wider(
    id_cols = c(layer_ix, role_space),
    names_from = role,
    values_from = acc
  )

In [None]:
## JUST DO MID LAYER

In [None]:
# Inspect the 'uat' role-space rows for nemotron-3-nano at layer 20.
# NOTE: `projections` is assigned in the "Conversational validation" cell
# further down, so that cell has to be run before this one.
projections %>%
  filter(
    role_space == 'uat',
    model == 'nemotron-3-nano',
    layer_ix == 20
  )


In [None]:
# Overall accuracy in the 'user,assistant,tool' role space for every 4th
# layer, shown as one row per layer and one column per model.
#
# Accuracy is computed directly as correct-count / total-count. The earlier
# version pivoted correct/incorrect counts into two columns and divided
# them, which yields NA whenever a (model, layer) has only correct or only
# incorrect rows, and errors if one of the two columns never appears.
acc_by_role %>%
  filter(role_space == 'user,assistant,tool') %>%
  mutate(is_correct = ifelse(role == pred, 1, 0)) %>%
  group_by(model, layer_ix) %>%
  summarize(accuracy = sum(is_correct * count) / sum(count), .groups = 'drop') %>%
  filter(layer_ix %% 4 == 0) %>%
  pivot_wider(id_cols = c(layer_ix), names_from = model, values_from = accuracy)

# Conversational validation

In [None]:
# Models whose conversation-level probe results are loaded in this section.
model_prefixes <- c("gptoss-20b", "gptoss-120b", "qwen3-30b-a3b", "nemotron-3-nano")

# When TRUE, the per-model projection ("roleness") CSVs are also read and
# joined onto the accuracy tables. Spelled FALSE rather than F because F is
# an ordinary variable that can be reassigned.
include_roleness <- FALSE

# Accuracy results for the base conversations: one CSV per model, each
# tagged with its model identifier, stacked into a single table.
base_acc <-
  model_prefixes %>%
  map(function(model_prefix) {
    csv_path <- file.path(
      ws,
      str_glue('experiments/role-analysis/outputs/probe-projections/all_conv_acc_{model_prefix}.csv')
    )
    read_csv(csv_path) %>%
      mutate(model = model_prefix)
  }) %>%
  list_rbind()

# Accuracy results for the alternative conversations: one CSV per model,
# each tagged with its model identifier, stacked into a single table.
alt_acc <-
  model_prefixes %>%
  map(function(model_prefix) {
    csv_path <- file.path(
      ws,
      str_glue('experiments/role-analysis/outputs/probe-projections/alt_conv_acc_{model_prefix}.csv')
    )
    read_csv(csv_path) %>%
      mutate(model = model_prefix)
  }) %>%
  list_rbind()

# Build `projections`: the stacked base + alternative accuracy tables,
# optionally extended with the projection ("roleness") values.
if (!include_roleness) {
  # Accuracy only.
  projections <- bind_rows(base_acc, alt_acc)
} else {
  # Projection files carry their value in a `mean_acc` column; rename it to
  # `mean_roleness` so it does not collide with the accuracy column.
  base_roleness <-
    model_prefixes %>%
    map(function(model_prefix) {
      csv_path <- file.path(
        ws,
        str_glue('experiments/role-analysis/outputs/probe-projections/all_conv_projs_{model_prefix}.csv')
      )
      read_csv(csv_path) %>%
        mutate(model = model_prefix) %>%
        rename(mean_roleness = mean_acc)
    }) %>%
    list_rbind()

  alt_roleness <-
    model_prefixes %>%
    map(function(model_prefix) {
      csv_path <- file.path(
        ws,
        str_glue('experiments/role-analysis/outputs/probe-projections/alt_conv_projs_{model_prefix}.csv')
      )
      read_csv(csv_path) %>%
        mutate(model = model_prefix) %>%
        rename(mean_roleness = mean_acc)
    }) %>%
    list_rbind()

  # Full join keeps rows that appear in only one of the two sources.
  projections <- full_join(
    bind_rows(base_acc, alt_acc),
    bind_rows(base_roleness, alt_roleness),
    join_by(model, conv_type, layer_ix, role_space, role)
  )
}

projections

## Mid-layer tagged -> untagged

In [None]:
# Mid layer plot
# Keep the 'uat' rows that sit on each model's middle layer (the rows that
# match `median_layers` on model and layer_ix), drop the two columns that
# are now constant, and derive a combined role / conversation-type label.
# Combinations not listed in the case_when() get NA in `conv_type_2`.
convs_df <-
  projections %>%
  filter(role_space == 'uat') %>%
  inner_join(median_layers, join_by(model, layer_ix)) %>%
  select(-c(role_space, layer_ix)) %>%
  mutate(
    conv_type_2 = case_when(
      conv_type == 'tagged' & role == 'user' ~ 'user_in_user_role_tags',
      conv_type == 'tagged' & role == 'assistant' ~ 'asst_in_asst_role_tags',

      conv_type == 'untagged' & role == 'user' ~ 'user_no_role_tags',
      conv_type == 'untagged' & role == 'assistant' ~ 'assistant_no_role_tags',

      conv_type == 'tool_tagged' & role == 'user' ~ 'user_tool_tagged',
      conv_type == 'tool_tagged' & role == 'assistant' ~ 'assistant_tool_tagged',

      conv_type == 'user_tagged' & role == 'assistant' ~ 'assistant_user_tagged'
    )
  )
  
# Dot plot of mean accuracy at the middle layer: one panel per
# (model, role), one point per conversation type. All points share a dummy
# y value of 0, so the y axis carries no information and is stripped.
# NOTE(review): rows are filtered on `conv_type`, not on the derived
# `conv_type_2` -- confirm that is the intended column.
convs_df %>%
  drop_na(conv_type) %>%
  mutate(y = 0) %>%
  ggplot() +
  geom_point(aes(x = mean_acc, y = y, color = conv_type), size = 4) +
  facet_grid(rows = vars(model), cols = vars(role)) +
  theme_iclr(base_size = 12) +
  theme(
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.grid.major.y = element_blank(),
    panel.grid.minor.y = element_blank()
  )

# Save the mid-layer table (including `conv_type_2`) to the summary-stats
# output folder.
convs_df %>% write_csv(
  file.path(ws, 'experiments/role-analysis/outputs/summary-stats/convs_results.csv')
)

In [None]:
# Same role / conversation-type labelling as for the mid-layer table, but
# kept for every layer (no join against `median_layers`).
# Combinations not listed in the case_when() get NA in `conv_type_2`.
convs_df_2 <-
  projections %>%
  filter(role_space == 'uat') %>%
  mutate(
    conv_type_2 = case_when(
      conv_type == 'tagged' & role == 'user' ~ 'user_in_user_role_tags',
      conv_type == 'tagged' & role == 'assistant' ~ 'asst_in_asst_role_tags',

      conv_type == 'untagged' & role == 'user' ~ 'user_no_role_tags',
      conv_type == 'untagged' & role == 'assistant' ~ 'assistant_no_role_tags',

      conv_type == 'tool_tagged' & role == 'user' ~ 'user_tool_tagged',
      conv_type == 'tool_tagged' & role == 'assistant' ~ 'assistant_tool_tagged',

      conv_type == 'user_tagged' & role == 'assistant' ~ 'assistant_user_tagged'
    )
  )

# Spot check: user turns inside user role tags for qwen3-30b-a3b, all layers.
convs_df_2 %>%
  filter(
    model == 'qwen3-30b-a3b',
    conv_type_2 == 'user_in_user_role_tags'
  )

In [None]:
# Middle-layer 'uat' rows whose conversation type is 'alt_tagged'.
projections %>%
  filter(
    role_space == 'uat',
    conv_type == 'alt_tagged'
  ) %>%
  inner_join(median_layers, join_by(model, layer_ix))


In [None]:
# Middle-layer 'uat' rows whose conversation type is 'tagged'.
projections %>%
  filter(
    role_space == 'uat',
    conv_type == 'tagged'
  ) %>%
  inner_join(median_layers, join_by(model, layer_ix))



In [None]:
# User-role 'uat' rows, excluding the 'user_tagged' conversation type,
# restricted to each model's middle layer. The middle layer is computed
# inline per model as the element at position ceiling((n + 1) / 2) of the
# sorted distinct layer indices.
# NOTE(review): `alt_projections` is not assigned anywhere in the visible
# notebook -- confirm where it comes from (possibly `alt_acc` or a subset
# of `projections`).
alt_projections %>%
  filter(
    role == 'user' & role_space == 'uat',
    conv_type != 'user_tagged'
  ) %>%
  group_by(model) %>%
  filter(layer_ix %in% {
    layer_levels <- sort(unique(layer_ix))
    layer_levels[ceiling((length(layer_levels) + 1) / 2)]
  }) %>%
  ungroup()

In [None]:
# Display `alt_projections`.
# NOTE(review): this object is not assigned anywhere in the visible
# notebook -- confirm where it is created.
alt_projections