In [None]:
# Plot test role probes

In [None]:
library(tidyverse)
library(fs)
library(ggtext)
library(systemfonts)
library(arrow)
library(patchwork)

# Workspace root and model tag; every data path below is built from these two values
ws <- '/workspace/deliberative-alignment-jailbreaks'
model_prefix <- 'gptoss-20b'

# Load the project's shared plotting helpers (presumably where theme_iclr() comes from - confirm)
source(file.path(ws, 'r-utils/plots.r'))

# Load data

In [None]:
# Directory holding the classified red-team responses for this model
base_path <- file.path(ws, 'experiments/role-injection-analysis/activations-redteam', model_prefix)

# Prompt-level metadata plus the classifier label for each model response
prompts_df <-
    file.path(base_path, 'base-harmful-responses-classified.csv') %>%
    read_csv(trim_ws = FALSE) %>%
    select(
        redteam_prompt_ix,
        harmful_question_ix, # harmful_question, harmful_question_category,
        qualifier_type, policy_style, synthetic_policy,
        output_class
    ) %>%
    # Fold REDIRECTION into REFUSAL so only two outcome classes are compared downstream
    mutate(output_class = ifelse(output_class == 'REDIRECTION', 'REFUSAL', output_class))

print(prompts_df)

# Probe projections (joined to the mapping below via probe_ix)
raw_projections_df <-
    str_glue('experiments/role-injection-analysis/projections/redteam-role-projections-{model_prefix}.feather') %>%
    file.path(ws, .) %>%
    read_feather()

# Lookup from probe_ix to the layer / role set the probe belongs to
probe_mapping_df <-
    str_glue('experiments/role-injection-analysis/projections/redteam-role-probe-mapping-{model_prefix}.csv') %>%
    file.path(ws, .) %>%
    read_csv()

head(raw_projections_df, 5)

# Test probes

In [None]:
# Choose primary test layer / roles
print(probe_mapping_df)

test_layer_ix <- 12
test_roles <- 'assistant-cot,assistant-final,system,user'

# Probe id(s) matching the chosen layer and role set
test_probe_ix <-
    probe_mapping_df %>%
    filter(layer_ix == test_layer_ix, roles == test_roles) %>%
    pull(probe_ix)
test_probe_ix

In [None]:
# Merge - keep every layer, but only probes trained on the `test_roles` role set,
# then attach the prompt-level metadata needed for later filtering
roles_df <-
    raw_projections_df %>%
    inner_join(filter(probe_mapping_df, roles == test_roles), by = 'probe_ix') %>%
    inner_join(
        select(prompts_df, redteam_prompt_ix, qualifier_type, policy_style, output_class),
        by = 'redteam_prompt_ix'
    )

head(roles_df, 5)

In [None]:
# Get prompt-level RCI for forged cot by role space.
# For each (layer, prompt): mean probe probability in the 'assistant-cot' space (cotness)
# and in the 'user' space (userness), over forged-CoT tokens only.
prompt_x_role_space_cotness <-
    roles_df %>%
    filter(
        layer_ix > 0,
        roles == test_roles,
        role == 'user' & base_message_type == 'forged_cot', # Get forged CoT only
        role_space %in% c('user', 'assistant-cot')
    ) %>%
    group_by(layer_ix, redteam_prompt_ix, role_space) %>%
    summarize(spaceness = mean(prob), .groups = 'drop') %>%
    pivot_wider(
        id_cols = c(layer_ix, redteam_prompt_ix),
        names_from = role_space,
        values_from = spaceness
    ) %>%
    rename(cotness = 'assistant-cot', userness = 'user') %>%
    # Rescale the (cotness - userness) margin from [-1, 1] to [0, 1]
    mutate(rci = .5 * (cotness - userness) + .5)

prompt_x_role_space_cotness

In [None]:
# Per-token average change in (smoothed) role probability at layer 12, for prompts with
# no qualifier and no policy. Result `z` has one row per (token, role_space).
z <-
    roles_df %>%
    filter(layer_ix == 12) %>%
    filter(qualifier_type == 'no_qualifier' & policy_style == 'no_policy') %>%
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(token_ix, .by_group = TRUE) %>%
    mutate(
        # Trailing 10-token rolling mean. `align` must be passed by name: with `fill` named,
        # a third positional argument binds to rollmean()'s deprecated `na.pad` parameter,
        # which leaves the window centered and emits a deprecation warning.
        prob = zoo::rollmean(prob, 10, fill = NA, align = 'right'),
        # prob = zoo::rollapply(prob, seq_along(prob), \(x) .5^(seq(length(x) - 1, 0)) %>% {sum(x * .)/sum(.)}, align = 'right', partial = T),
        # Token-to-token change in the smoothed probability
        dprob = prob - lag(prob, 1)
    ) %>%
    ungroup() %>%
    group_by(token, role_space) %>%
    summarize(
        n = n(),
        mean_dprob = mean(dprob, na.rm = TRUE),
        .groups = 'drop'
    ) %>%
    # Drop tokens with fewer than 20 occurrences
    filter(n >= 20) %>%
    mutate(mean_dprob = round(mean_dprob, 4))

z

In [None]:
# Display the merged projections table for interactive inspection
roles_df

In [None]:
# Tokens seen at least 100 times, ordered by their mean change in the 'user' role space
# (most negative first)
z %>%
    filter(role_space == 'user', n >= 100) %>%
    arrange(mean_dprob)

# Table: CoTness by styled/destyled

In [None]:
# Verify that for CoT forgery, styled is more "assistant-cot-like" than destyled.
# One output row per (layer, policy style), restricted to unqualified prompts.
plot_df <-
    prompt_x_role_space_cotness %>%
    inner_join(prompts_df, by = 'redteam_prompt_ix') %>%
    filter(
        qualifier_type == 'no_qualifier',
        policy_style %in% c('base', 'destyled')
    ) %>%
    group_by(layer_ix, policy_style) %>%
    summarize(
        # NOTE(review): n() counts rows of the prompt-level table here, not tokens - confirm naming
        n_toks = n(),
        n_prompts = n_distinct(redteam_prompt_ix),
        mean_cotness = mean(cotness),
        # Value from the last row of the group (depends on row order)
        tail_cotness = last(cotness),
        .groups = 'drop'
    )

# write_csv(plot_df, str_glue('{ws}/experiments/role-analysis/plots/mean-cotness-by-layer-and-style-{model_prefix}.csv'))

plot_df

In [None]:
# Heatmap: one tile per (layer, policy style), shaded by mean CoTness and labelled with
# the value rounded to two decimals
plot_df %>%
    mutate(
        policy_style = factor(
            policy_style,
            levels = c('base', 'destyled'),
            labels = c('CoT Forgery', 'Destyled CoT Forgery')
        ),
        fmt = format(round(mean_cotness, 2), nsmall = 2)
    ) %>%
    ggplot(aes(x = as.factor(layer_ix), y = policy_style)) +
    geom_tile(aes(fill = mean_cotness)) +
    geom_text(aes(label = fmt)) +
    scale_fill_gradient(
        low = scales::alpha('#ffd230', 0.08),
        high = scales::alpha('#e17100', 0.9)
    ) +
    labs(x = 'Layer index', y = 'Policy style') +
    theme_iclr(base_size = 11) +
    theme(
        legend.position = 'none',
        # Strip panel chrome so only the tiles remain
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major.x = element_blank(),
        panel.grid.major.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank(),
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        axis.text.x = ggtext::element_markdown(size = 9, hjust = 2, angle = 0)
    )


# Plot: Destyled CoTness + ASR

In [None]:
# Plot
# Bar colors / axis-label colors for the two forgery variants (also used by panel (b))
color_mappings <- c(
    'CoT Forgery' = '#ff637e',
    'Destyled CoT Forgery' = '#c27aff'
)

# Panel (a): mean CoTness of the forged CoT per policy style, pooled over all layers > 0
# and restricted to unqualified prompts. Drawn as horizontal bars via coord_flip().
p1 <- prompt_x_role_space_cotness %>%
    inner_join(prompts_df, by = c('redteam_prompt_ix')) %>%
    # filter(., layer_ix == test_layer_ix) %>%
    filter(layer_ix > 0) %>%
    filter(qualifier_type == 'no_qualifier') %>%
    filter(policy_style %in% c('base', 'destyled')) %>%
    group_by(policy_style) %>%
    summarize(
        # Counts (prompt, layer) rows, since layers are pooled at this point
        n_prompts = n(),
        mean_cotness = mean(cotness),
        .groups = 'drop'
    ) %>%
    mutate(policy_style = factor(
        policy_style,
        levels = c('base', 'destyled'),
        labels = c('CoT Forgery', 'Destyled CoT Forgery')
    )) %>%
    mutate(
        font_style = 'plain',
        # Axis labels wrapped in a colored <span> so ggtext renders them in the bar's color;
        # factor levels follow the order of `color_mappings`
        policy_style_colored = factor(
            paste0('<span style="color:', color_mappings[as.character(policy_style)], '">',
                   as.character(policy_style), '</span>'),
            levels = paste0('<span style="color:', color_mappings[names(color_mappings)], '">',
                           names(color_mappings), '</span>')
        )
    ) %>%
    ggplot() +
    # geom_col() always uses the identity stat; passing `stat =` only triggers an
    # "unknown parameters" warning, so it is omitted
    geom_col(
        aes(x = policy_style_colored, y = mean_cotness, fill = policy_style),
        width = .6, position = position_dodge2(width = 0.8, padding = 0.01, preserve = 'single')
    ) +
    geom_text(
        aes(x = policy_style_colored, y = mean_cotness, label = scales::percent(mean_cotness, accuracy = .1), fontface = font_style),
        position = position_dodge2(width = 0.8, padding = 0.01, preserve = 'single'),
        # vjust = -0.4,
        hjust = -0.1, # nudge the label past the end of the (flipped) bar
        size = 2.5
    ) +
    labs(x = NULL, y = '(a) CoTness of Forged CoT', fill = NULL) +
    scale_x_discrete(limits = rev, expand = expansion(mult = c(0.6, 0.6))) +
    scale_fill_manual(values = color_mappings) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.005, 0.02)),
        breaks = c(0, .5, 1)
    ) +
    theme_iclr(base_size = 10) +
    theme(
        legend.position = 'none',
        axis.title.x = ggtext::element_markdown(angle = 0, vjust = 0.5, margin = margin(t = 6)),
        axis.text.x = ggtext::element_markdown(angle = 0, hjust = 0.5, size = rel(0.95), margin = margin(t = 4)),
        panel.grid.major.y = element_blank(),
        panel.grid.minor.y = element_blank(),
        axis.ticks.y = element_blank(),
        axis.title.y = ggtext::element_markdown(margin = margin(t = 10)),
        axis.text.y = ggtext::element_markdown(hjust = 1, angle = 0, margin = margin(r = 6), face = 'bold'),
        axis.ticks.length.x = unit(0, 'pt')
    ) +
    coord_flip()

p1

In [None]:
# Plot P2
# Panel (b): attack success rate per policy style, computed over distinct prompts as
# HARMFUL_RESPONSE / (HARMFUL_RESPONSE + REFUSAL). Drawn as horizontal bars via coord_flip().
# NOTE(review): unlike the CoTness panel, no qualifier_type filter is applied here - confirm intended.
p2 <-
    prompts_df %>%
    filter(policy_style %in% c('base', 'destyled')) %>%
    # Treat redirections as refusals
    mutate(output_class = ifelse(output_class %in% c('REDIRECTION', 'REFUSAL'), 'REFUSAL', output_class)) %>%
    mutate(policy_style = factor(
        policy_style,
        levels = c('base', 'destyled'),
        labels = c('CoT Forgery', 'Destyled CoT Forgery')
    )) %>%
    group_by(policy_style, output_class) %>%
    summarize(
        n_prompts = n_distinct(redteam_prompt_ix),
        .groups = 'drop'
    ) %>%
    # One column of counts per output class; a class absent for a style counts as 0
    pivot_wider(id_cols = c(policy_style), names_from = c(output_class), values_from = c(n_prompts), values_fill = 0) %>%
    mutate(asr = HARMFUL_RESPONSE/(HARMFUL_RESPONSE + REFUSAL)) %>%
    mutate(
        font_style = 'plain',
        # Axis labels wrapped in a colored <span> so ggtext renders them in the bar's color;
        # factor levels follow the order of `color_mappings`
        policy_style_colored = factor(
            paste0('<span style="color:', color_mappings[as.character(policy_style)], '">',
                   as.character(policy_style), '</span>'),
            levels = paste0('<span style="color:', color_mappings[names(color_mappings)], '">',
                           names(color_mappings), '</span>')
        )
    ) %>%
    ggplot() +
    # geom_col() always uses the identity stat; passing `stat =` only triggers an
    # "unknown parameters" warning, so it is omitted
    geom_col(
        aes(x = policy_style_colored, y = asr, fill = policy_style),
        width = .6, position = position_dodge2(width = 0.8, padding = 0.01, preserve = 'single')
    ) +
    geom_text(
        aes(x = policy_style_colored, y = asr, label = scales::percent(asr, accuracy = .1), fontface = font_style),
        position = position_dodge2(width = 0.8, padding = 0.01, preserve = 'single'),
        # vjust = -0.4,
        hjust = -0.1, # nudge the label past the end of the (flipped) bar
        size = 2.5
    ) +
    labs(x = NULL, y = '(b) Attack success rate', fill = NULL) +
    scale_x_discrete(limits = rev, expand = expansion(mult = c(0.6, 0.6))) +
    scale_fill_manual(values = color_mappings) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.005, 0.02)),
        breaks = c(0, .5, 1)
    ) +
    theme_iclr(base_size = 10) +
    theme(
        legend.position = 'none',
        axis.title.x = ggtext::element_markdown(angle = 0, vjust = 0.5, margin = margin(t = 6)),
        axis.text.x = ggtext::element_markdown(angle = 0, hjust = 0.5, size = rel(0.95), margin = margin(t = 4)),
        panel.grid.major.y = element_blank(),
        panel.grid.minor.y = element_blank(),
        axis.ticks.y = element_blank(),
        axis.title.y = ggtext::element_markdown(margin = margin(t = 10)),
        axis.text.y = ggtext::element_markdown(hjust = 1, angle = 0, margin = margin(r = 6), face = 'bold'),
        axis.ticks.length.x = unit(0, 'pt')
    ) +
    coord_flip()

p2

In [None]:
# Panels (a) and (b) side by side at equal width
combined_plot <- (p1 + p2) + plot_layout(widths = c(1, 1))

# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/destyled-asr.pdf'),
#     plot = combined_plot, width = 7.0, height = 3.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/destyled-asr.png'),
#     plot = combined_plot, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

# ggsave(
#     str_glue('{ws}/docs/destyled-asr.png'),
#     plot = combined_plot, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

# Plot: CoTness vs ASR

In [None]:
# Single layer plot
# Number of CoTness quantile bins on the x axis
ngroups <- 25

# Per-prompt CoTness margin at `test_layer_ix`: mean over forged-CoT tokens of
# P(assistant-cot) - P(user), joined to the prompt's output class.
cotness_by_prompt <-
    roles_df %>%
    # Restrict to the test layer first so only that layer's tokens are pivoted
    filter(layer_ix == test_layer_ix) %>%
    # filter(., qualifier_type == 'no_qualifier') %>%
    filter(policy_style %in% c('base', 'destyled')) %>%
    filter(role == 'user' & base_message_type == 'forged_cot') %>% # Get forged CoT only
    filter(role_space %in% c('user', 'assistant-cot')) %>%
    pivot_wider(names_from = role_space, values_from = prob) %>%
    mutate(prob = (`assistant-cot`) - (`user`)) %>%
    group_by(redteam_prompt_ix) %>%
    summarize(
        cotness = mean(prob),
        # cotness = mean(-1 * log10(1 - prob)),
        .groups = 'drop'
    ) %>%
    inner_join(select(prompts_df, redteam_prompt_ix, output_class), by = 'redteam_prompt_ix')

# Bootstrap (100 resamples of prompts, with replacement): within each resample, bin prompts
# into `ngroups` CoTness quantiles and compute the share of HARMFUL_RESPONSE per bin.
# The plot shows the mean across resamples with a 5%-95% band.
cotness_x_asr <- map(1:100, function(b) {

    cotness_by_prompt_samples <- sample_n(cotness_by_prompt, nrow(cotness_by_prompt), replace = TRUE)

    cotness_by_prompt_samples %>%
        mutate(
            # Quantile bin expressed as a fraction in (0, 1]
            cot_q = ntile(cotness, ngroups) * (100/ngroups)/100
        ) %>%
        group_by(cot_q) %>%
        summarize(
            n = n(),
            asr = sum(ifelse(output_class == 'HARMFUL_RESPONSE', 1, 0))/n(),
            .groups = 'drop'
        ) %>%
        mutate(b = b)
    }) %>%
    list_rbind() %>%
    group_by(cot_q) %>%
    summarize(
        # Rows here are bootstrap replicates, one per resample in which the bin occurred
        n_replicates = n(),
        asr_mean = mean(asr),
        asr_bot = quantile(asr, 0.05),
        asr_top = quantile(asr, 0.95),
        .groups = 'drop'
    ) %>%
    ggplot() +
    geom_ribbon(aes(x = (cot_q), ymin = asr_bot, ymax = asr_top), fill = '#fee685', alpha = 0.5) +
    geom_line(aes(x = (cot_q), y = asr_mean), color = '#ffba00', linewidth = 1, alpha = 0.9) +
    geom_point(aes(x = (cot_q), y = asr_mean), color = '#ffba00', size = 2) +
    scale_x_continuous(
        labels = scales::percent_format(accuracy = 1),
        breaks = c(0, .25, .5, .75, 1),
        expand = expansion(mult = c(0.005, 0.005))
    ) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0, 0.03)),
        breaks = c(0, .2, .4, .6, .8, 1)
    ) +
    labs(
        x = '<b>CoTness</b>, as %ile of CoT Forgery attempts',
        y = '<b>Attack success rate</b>'
    ) +
    theme_iclr(base_size = 11) +
    theme(
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        axis.title.x = ggtext::element_markdown(angle = 0, vjust = 0, margin = margin(t = 6)),
        axis.text.x = ggtext::element_markdown(angle = 0, hjust = 0.5, size = rel(0.95), margin = margin(t = 4)),
        plot.margin = margin(t = 0, r = 8, b = 0, l = 0, unit = 'pt')
    )

# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-x-asr.pdf'),
#     plot = cotness_x_asr, width = 7.0, height = 3.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-x-asr.png'),
#     plot = cotness_x_asr, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )
# ggsave(
#     str_glue('{ws}/docs/cotness-x-asr.png'),
#     plot = cotness_x_asr, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

cotness_x_asr

In [None]:
# CoTness margin by %ile
# Number of CoTness quantile bins, and the layers that each get a facet
ngroups <- 25
layers_to_test <- c(4, 12, 20)

# Per-(layer, prompt) CoTness margin: mean over forged-CoT tokens of
# P(assistant-cot) - P(user), joined to the prompt's output class.
cotness_by_prompt <-
    roles_df %>%
    # Restrict to the tested layers first so only their tokens are pivoted
    filter(layer_ix %in% layers_to_test) %>%
    # filter(., qualifier_type == 'no_qualifier') %>%
    filter(policy_style %in% c('base', 'destyled')) %>%
    filter(role == 'user' & base_message_type == 'forged_cot') %>% # Get forged CoT only
    filter(role_space %in% c('user', 'assistant-cot')) %>%
    pivot_wider(names_from = role_space, values_from = prob) %>%
    mutate(prob = (`assistant-cot`) - (`user`)) %>%
    group_by(layer_ix, redteam_prompt_ix) %>%
    summarize(
        cotness = mean(prob),
        # cotness = mean(-1 * log10(1 - prob)),
        .groups = 'drop'
    ) %>%
    inner_join(select(prompts_df, redteam_prompt_ix, output_class), by = 'redteam_prompt_ix')

# Bootstrap (200 resamples, with replacement): within each resample and layer, bin prompts
# into `ngroups` CoTness quantiles and compute the share of HARMFUL_RESPONSE per bin.
# NOTE(review): rows are resampled jointly across layers, so the per-layer sample size
# varies between resamples - confirm this is intended.
cotness_x_asr_by_layer <- map(1:200, .progress = TRUE, function(b) {

    cotness_by_prompt_samples <- sample_n(cotness_by_prompt, nrow(cotness_by_prompt), replace = TRUE)

    cotness_by_prompt_samples %>%
        group_by(layer_ix) %>%
        mutate(
            # Quantile bin (within layer) expressed as a fraction in (0, 1]
            cot_q = ntile(cotness, ngroups) * (100/ngroups)/100
        ) %>%
        group_by(layer_ix, cot_q) %>%
        summarize(
            n = n(),
            asr = sum(ifelse(output_class == 'HARMFUL_RESPONSE', 1, 0))/n(),
            .groups = 'drop'
        ) %>%
        mutate(b = b)
    }) %>%
    list_rbind() %>%
    group_by(layer_ix, cot_q) %>%
    summarize(
        # Rows here are bootstrap replicates, one per resample in which the bin occurred
        n_replicates = n(),
        asr_mean = mean(asr),
        asr_bot = quantile(asr, 0.05),
        asr_top = quantile(asr, 0.95),
        .groups = 'drop'
    ) %>%
    # Keep only bins that appeared in at least 20 resamples
    filter(n_replicates >= 20) %>%
    mutate(layer_ix = as.factor(layer_ix)) %>%
    ggplot() +
    geom_ribbon(aes(x = (cot_q), ymin = asr_bot, ymax = asr_top, fill = layer_ix), alpha = 0.5, lineend = 'round', linejoin = 'round', linemitre = 2) +
    geom_line(aes(x = (cot_q), y = asr_mean, color = layer_ix), linewidth = 1) +
    geom_point(aes(x = (cot_q), y = asr_mean, color = layer_ix), size = 2) +
    scale_x_continuous(
        labels = scales::percent_format(accuracy = 1),
        breaks = c(0, .5, 1),
        expand = expansion(mult = c(0.005, 0.005))
    ) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0, 0.005)),
        breaks = c(0, .2, .4, .6, .8, 1)
    ) +
    labs(
        x = '<b>CoTness</b>, as %ile of CoT Forgery attempts',
        y = '<b>Attack success rate</b>'
    ) +
    facet_grid(cols = vars(layer_ix)) +
    theme_iclr(base_size = 11) +
    theme(
        legend.position = 'none',
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        axis.title.x = ggtext::element_markdown(angle = 0, vjust = 0, margin = margin(t = 6))
    )

# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-x-asr-by-layer.pdf'),
#     plot = cotness_x_asr_by_layer, width = 7.0, height = 3.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-x-asr-by-layer.png'),
#     plot = cotness_x_asr_by_layer, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )
# ggsave(
#     str_glue('{ws}/docs/cotness-x-asr-by-layer.png'),
#     plot = cotness_x_asr_by_layer, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

# Phase portraits

In [None]:
escape_md_html <- function(x) {
  #' Make text render literally under ggtext::element_markdown()
  #'
  #' Runs a fixed sequence of regex substitutions that backslash-escape
  #' Markdown punctuation and angle brackets, and turn line breaks into the
  #' visible two-character sequence backslash + n.
  #'
  #' @param x Character vector to escape
  #' @return Character vector with the substitutions applied element-wise

  # Ordered (pattern, replacement) pairs. Order matters: literal backslashes
  # are doubled before any escaping backslashes are introduced, and newlines
  # are made visible last so their backslash is not doubled.
  substitutions <- list(
    # Normalize platform newlines
    c("\r\n|\r", "\n"),
    # A literal <br> in the text is treated as a newline
    c("(?i)<br\\s*/?>", "\n"),
    # Double every backslash already present
    c("\\\\", "\\\\\\\\"),
    # Backslash-escape Markdown punctuation
    c("([\\*_|`\\[\\]\\(\\)\\#\\~\\|!\\+\\-=])", "\\\\\\1"),
    # Backslash-escape angle brackets so they are not parsed as tags
    c("([<>])", "\\\\\\1"),
    # Show real newlines as backslash + n
    c("\n", "\\\\n")
  )

  for (rule in substitutions) {
    x <- gsub(rule[[1]], rule[[2]], x, perl = TRUE)
  }

  x
}

wrap_tokens <- function(tokens, line_width, max_lines = 3) {
  #' Greedily pack tokens into lines of limited display width
  #'
  #' Tokens are concatenated as-is (no separator). A token that would push the
  #' current line past `line_width` starts a new line; a token arriving on an
  #' empty line is always accepted, even if it alone exceeds the width.
  #'
  #' @param tokens A vector of tokens to insert linebreaks into
  #' @param line_width The number of characters per line
  #' @param max_lines The maximum number of lines to retain
  #' @return A list with `lines` (character vector of packed lines) and
  #'   `truncated` (TRUE when some tokens were left out)

  finished <- character(0)
  pending <- ''
  n_consumed <- 0L

  for (tok in tokens) {
    joined <- paste0(pending, tok)
    fits <- nchar(joined, type = "width") <= line_width || pending == ""

    if (!fits) {
      # Close the current line; stop once the line budget is used up
      # (this token and all later ones stay unconsumed)
      finished <- c(finished, pending)
      if (length(finished) >= max_lines) {
        break
      }
      joined <- tok
    }

    pending <- joined
    n_consumed <- n_consumed + 1L
  }

  # Flush the line still being built, if there is room for it
  if (length(finished) < max_lines && nzchar(pending)) {
    finished <- c(finished, pending)
  }

  list(
    lines = finished,
    truncated = n_consumed < length(tokens) || length(finished) > max_lines
  )
}

# Marker colors keyed by the display label of each message type
point_colors <- c(
  `System` = '#90a1b9',             # slate
  `User` = '#00a6f4',               # sky
  `CoT` = '#fd9a00',                # amber
  `Assistant` = '#00d492',          # emerald
  `User (CoT Forgery)` = '#ff637e'
)

# Axis-label text colors for the same keys
text_colors <- c(
  `System` = '#62748e',             # slate
  `User` = '#0084d1',               # sky
  `CoT` = '#e17100',                # amber
  `Assistant` = '#009966',          # emerald
  `User (CoT Forgery)` = '#ff637e'
)

In [None]:
# Settings for the phase portraits: one harmful question, one layer, one probe role set
test_harmful_question_ix <- 33 # 1, 4, 33**
test_layer_ix <- 16
test_roles <- 'assistant-cot,assistant-final,system,user'

# Projections for the chosen layer / role set, limited to prompts built from the chosen
# harmful question, with segment bookkeeping columns added
proj_df <-
    raw_projections_df %>%
    inner_join(
        filter(probe_mapping_df, roles == test_roles & layer_ix == test_layer_ix),
        by = 'probe_ix'
    ) %>%
    inner_join(
        prompts_df %>%
            filter(harmful_question_ix == test_harmful_question_ix) %>%
            select(redteam_prompt_ix, qualifier_type, policy_style, harmful_question_ix, output_class),
        by = 'redteam_prompt_ix'
    ) %>%
    # Add in bmt/seg id to allow for filtering later if needed:
    # seg_ix increments each time base_message_type changes along sample_ix order
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(sample_ix, .by_group = TRUE) %>%
    mutate(seg_ix = consecutive_id(base_message_type)) %>%
    # Position of each token within its segment
    group_by(redteam_prompt_ix, role_space, seg_ix) %>%
    mutate(token_in_seg_ix = row_number()) %>%
    ungroup()

head(proj_df, 5)

In [None]:
# Line plot version
# Non-system tokens projected into the non-system role spaces, with display labels,
# at most 100 tokens per segment, and a running token index per (prompt, role space)
plot_df <-
    proj_df %>%
    filter(role_space != 'system' & base_message_type != 'system') %>%
    group_by(redteam_prompt_ix, seg_ix, role_space) %>%
    mutate(
        # Exponentially weighted mean over an expanding window within the segment:
        # the current token has weight 1 and each earlier token 0.1x the next one
        prob_ewma = zoo::rollapply(
            prob,
            seq_along(prob),
            \(x) {
                w <- .1^(seq(length(x) - 1, 0))
                sum(x * w)/sum(w)
            },
            align = 'right',
            partial = TRUE
        )
    ) %>%
    ungroup() %>%
    mutate(
        role_space = factor(
            role_space,
            levels = c('system', 'user', 'assistant-cot', 'assistant-final'),
            labels = c('Systemness', 'Userness', 'CoTness', 'Assistantness')
        ),
        base_message_type = factor(
            base_message_type,
            levels = c('system', 'user', 'assistant-cot', 'assistant-final', 'forged_cot'),
            labels = c('System', 'User', 'CoT', 'Assistant', 'User (CoT Forgery)')
        )
    ) %>%
    filter(token_in_seg_ix <= 100) %>%
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(seg_ix, token_in_seg_ix, .by_group = TRUE) %>%
    # Running x position after truncating each segment
    mutate(token_in_prompt_ix = row_number()) %>%
    ungroup()

# One figure per red-team prompt: token-level probe probability in each role space
# (facet rows), colored by the token's message type. Each figure is printed and returned.
plots <- map(group_split(plot_df, redteam_prompt_ix), function(this_plot_df) {

    # Seg text + index for the start of each consecutive base_message_type segment
    segs <-
        this_plot_df %>%
        # Collapse out role_space (lossless)
        distinct(redteam_prompt_ix, token_in_prompt_ix, seg_ix, base_message_type, token) %>%
        group_by(redteam_prompt_ix, seg_ix, base_message_type) %>%
        summarize(
            start_ix = first(token_in_prompt_ix),
            # Contiguous segment
            start_str = {
                bmt <- first(base_message_type)
                # Per-type line width for the wrapped token preview
                line_width_chars <- case_when(
                    bmt %in% c('User') ~ 6,
                    bmt %in% c('CoT') ~ 10,
                    bmt %in% c('Assistant') ~ 30,
                    TRUE ~ 12
                )

                # For user segments the first 17 tokens are dropped from the preview
                tok_vec <- if (bmt == 'User') tail(token, -17) else token
                # Build up to 2 lines by character length
                wrapped <- wrap_tokens(tok_vec, line_width = line_width_chars, max_lines = 2)

                # Escape Markdown/HTML inside the token content
                # Join with <br>, append '..' if truncated
                escaped_lines <-
                    map_chr(wrapped$lines, escape_md_html) %>%
                    paste0(., collapse = '<br>') %>%
                    {if(isTRUE(wrapped$truncated)) paste0(., '..') else .}

                # The token preview above is currently unused; the tick label is the
                # message type. Swap in `escaped_lines` to label ticks with the text.
                # escaped_lines
                bmt
            },
            .groups = 'drop'
        ) %>%
        # Color each tick label to match its message type
        mutate(start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))

    this_p <-
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob, color = base_message_type), size  = 0.5) +
        facet_grid(rows = vars(role_space), switch = 'y') +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, 1)
        ) +
        # One x tick at the start of every segment
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$redteam_prompt_ix),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            # plot.title = ggtext::element_markdown(size = 10.5),
            legend.position = 'none',
            axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(0.4, 'lines'),
            axis.text.x = ggtext::element_markdown(size = 8, hjust = 0, angle = 0)
        )

    # ggsave(
    #     str_glue('{ws}/experiments/role-analysis/plots/tomato-role-space-projections-{this_plot_df$prompt_key[[1]]}.pdf'),
    #     plot = this_p, width = 7, height = 4.0, units = 'in', dpi = 300, device = cairo_pdf
    # )
    # ggsave(
    #     str_glue('{ws}/experiments/role-analysis/plots/tomato-role-space-projections-{this_plot_df$prompt_key[[1]]}.png'),
    #     plot = this_p,  width = 7, height = 4.0, units = 'in', dpi = 300
    # )

    # ggsave(
    #     str_glue('{ws}/docs/tomato-role-space-projections-{this_plot_df$prompt_key[[1]]}.png'),
    #     plot = this_p, width = 7, height = 4.0, units = 'in', dpi = 300
    # )
    print(this_p)
})

In [None]:
# Marker colors keyed by message-type display label (this cell abbreviates the
# assistant label to 'Asst')
point_colors <- c(
  `System` = '#90a1b9',             # slate
  `User` = '#00a6f4',               # sky
  `CoT` = '#fd9a00',                # amber
  `Asst` = '#00d492',               # emerald
  `User (CoT Forgery)` = '#ff637e'
)

# Axis-label text colors for the same keys
text_colors <- c(
  `System` = '#62748e',             # slate
  `User` = '#0084d1',               # sky
  `CoT` = '#e17100',                # amber
  `Asst` = '#009966',               # emerald
  `User (CoT Forgery)` = '#ff637e'
)

# CoTness only
# Unqualified prompts, non-system tokens; smoothed probability per segment; at most 200
# tokens per segment; finally reduced to the CoTness role space
plot_df <-
    proj_df %>%
    filter(role_space != 'system' & base_message_type != 'system') %>%
    filter(qualifier_type == 'no_qualifier') %>%
    group_by(redteam_prompt_ix, seg_ix, role_space) %>%
    mutate(
        # Exponentially weighted mean over an expanding window within the segment:
        # the current token has weight 1 and each earlier token half the next one
        prob_ewma = zoo::rollapply(
            prob,
            seq_along(prob),
            \(x) {
                w <- .5^(seq(length(x) - 1, 0))
                sum(x * w)/sum(w)
            },
            align = 'right',
            partial = TRUE
        )
    ) %>%
    ungroup() %>%
    mutate(
        role_space = factor(
            role_space,
            levels = c('system', 'user', 'assistant-cot', 'assistant-final'),
            labels = c('Systemness', 'Userness', 'CoTness', 'Assistantness')
        ),
        base_message_type = factor(
            base_message_type,
            levels = c('system', 'user', 'assistant-cot', 'assistant-final', 'forged_cot'),
            labels = c('System', 'User', 'CoT', 'Asst', 'User (CoT Forgery)')
        ),
        policy_style = factor(
            policy_style,
            levels = c('no_policy', 'base', 'destyled'),
            labels = c('(a) No CoT Forgery', '(b) CoT Forgery', '(c) Destyled CoT Forgery')
        )
    ) %>%
    filter(token_in_seg_ix <= 200) %>%
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(seg_ix, token_in_seg_ix, .by_group = TRUE) %>%
    # Running x position after truncating each segment
    mutate(token_in_prompt_ix = row_number()) %>%
    ungroup() %>%
    filter(role_space == 'CoTness') %>%
    select(-role_space)


# One figure per policy style: smoothed CoTness along the prompt, colored by the token's
# message type. Each figure is printed and returned.
plots <- map(group_split(plot_df, policy_style), function(this_plot_df) {

    # Seg text + index for the start of each consecutive base_message_type segment
    segs <-
        this_plot_df %>%
        group_by(redteam_prompt_ix, seg_ix, base_message_type) %>%
        summarize(
            start_ix = first(token_in_prompt_ix),
            # Contiguous segment
            start_str = {
                bmt <- first(base_message_type)
                # Per-type line width for the wrapped token preview; the assistant label
                # in this cell's data is 'Asst'
                line_width_chars <- case_when(
                    bmt %in% c('User') ~ 6,
                    bmt %in% c('CoT') ~ 10,
                    bmt %in% c('Asst') ~ 30,
                    TRUE ~ 12
                )

                # For user segments the first 17 tokens are dropped from the preview
                tok_vec <- if (bmt == 'User') tail(token, -17) else token
                # Build up to 2 lines by character length
                wrapped <- wrap_tokens(tok_vec, line_width = line_width_chars, max_lines = 2)

                # Escape Markdown/HTML inside the token content
                # Join with <br>, append '..' if truncated
                escaped_lines <-
                    map_chr(wrapped$lines, escape_md_html) %>%
                    paste0(., collapse = '<br>') %>%
                    {if(isTRUE(wrapped$truncated)) paste0(., '..') else .}

                # The token preview above is currently unused; the tick label is the
                # message type. Swap in `escaped_lines` to label ticks with the text.
                # escaped_lines

                bmt
            },
            .groups = 'drop'
        ) %>%
        # Color each tick label to match its message type
        mutate(start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))

    this_p <-
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob_ewma, color = base_message_type), size = 0.5, alpha = 0.9) +
        # Lines are grouped by segment so they do not connect across message boundaries
        geom_line(aes(x = token_in_prompt_ix, y = prob_ewma, color = base_message_type, group = seg_ix), linewidth = 0.5, alpha = 0.7) +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, .5, 1)
        ) +
        # One x tick at the start of every segment
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$policy_style),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            plot.title = ggtext::element_markdown(face = 'bold', color = '#45556c', size = 9.5, margin = margin(b = 3, t = 8)),
            legend.position = 'none',
            axis.title.y = element_blank(),  # Remove y-axis title
            axis.text.y = element_text(margin = margin(r = 0)),  # Remove right margin from y-axis text
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(0.8, 'lines'),
            axis.text.x = ggtext::element_markdown(size = 8, hjust = 0, angle = 0),
            panel.grid.major.y = element_blank()
            # plot.margin = margin(0, 0, 0, 0),  # top, right, bottom, left
        )

    print(this_p)
})

# Stack the first three per-policy-style panels in a single column
p <- plots[[1]] + plots[[2]] + plots[[3]] + plot_layout(ncol = 1)

print(p)

In [None]:
# imap(plots, function(p, i) {
#   ggsave(paste0('plot', i, '.svg'), p, device = 'svg', width = 10, height = 6, units = 'in')
# })

In [None]:
# Auto-print the `plots` list so each individual plot is rendered in the cell output
plots

In [None]:
# Shared y-axis title drawn as its own panel: rotated "CoTness" text on an empty canvas
ylab =
  ggplot() +
  annotate("text", x = 0, y = 0, label = "CoTness", angle = 90, vjust = 0.5) +
  theme_void(base_size = 11, base_family = "TeX Gyre Termes")

# Layout grid: column 1 holds the y label across all three rows,
# column 2 holds one plot per row (area() defaults b = t and r = l)
lay = c(
  area(t = 1, l = 1, b = 3),
  area(t = 1, l = 2),
  area(t = 2, l = 2),
  area(t = 3, l = 2)
)

# Assemble label + plots; the label column is kept very narrow (0.03 vs 1).
# margin() order is top, right, bottom, left: the left margin is removed on every panel.
p = wrap_plots(
  list(ylab, plots[[1]], plots[[2]], plots[[3]]),
  design = lay,
  widths = c(0.03, 1)
) & theme(plot.margin = margin(2, 2, 2, 0))

p

In [None]:
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-redteam.pdf'),
#     plot = p, width = 7.0, height = 4.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-redteam.png'),
#     plot = p, width = 7.0, height = 4.0, units = 'in', dpi = 300
# )

# ggsave(
#     str_glue('{ws}/docs/cotness-redteam.png'),
#     plot = p, width = 7.0, height = 4.0, units = 'in', dpi = 300
# )

# Differential Role Projections

In [None]:
test_layer_ix = 12
test_roles = 'assistant-cot,assistant-final,system,user'
test_prompt_ixs = c(1, 2)

# Keep only projections from the probe matching (test_roles, test_layer_ix), attach prompt
# metadata, and index tokens within each contiguous run of base_message_type (BMT).
# prompt_type = joint measure of CoT forgery / destyled x qualifier presence; each
# redteam_prompt_ix is a unique combination of prompt type x harmful_question.
# Rows matching none of the case_when() branches get prompt_type = NA.
filtered_proj_df =
    raw_projections_df %>%
    inner_join(
        probe_mapping_df %>% filter(roles == test_roles & layer_ix == test_layer_ix),
        by = 'probe_ix'
    ) %>%
    inner_join(
        prompts_df %>%
            select(redteam_prompt_ix, qualifier_type, policy_style, harmful_question_ix, output_class) %>%
            mutate(prompt_type = case_when(
                policy_style == 'base' & qualifier_type != 'no_qualifier' ~ 'cot_forgery_qual',
                policy_style == 'base' & qualifier_type == 'no_qualifier' ~ 'cot_forgery',
                policy_style == 'destyled' & qualifier_type != 'no_qualifier' ~ 'destyled_cot_forgery_qual',
                policy_style == 'destyled' & qualifier_type == 'no_qualifier' ~ 'destyled_cot_forgery',
                policy_style == 'no_policy' ~ 'no_policy'
            )),
        by = 'redteam_prompt_ix'
    ) %>%
    # Assign BMT indices: seg_ix increments every time base_message_type changes along
    # sample_ix order; token_in_seg_ix counts rows within each such segment
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(sample_ix, .by_group = TRUE) %>%
    mutate(seg_ix = consecutive_id(base_message_type)) %>%
    group_by(redteam_prompt_ix, role_space, seg_ix) %>%
    mutate(token_in_seg_ix = row_number()) %>%
    ungroup() %>%
    arrange(seg_ix, token_in_seg_ix)

# Sanity check: number of test prompts present in the 'system' role space.
# filtered_proj_df already carries the probe mapping columns, so it is not joined again
# (a second join would only duplicate them with .x/.y suffixes).
filtered_proj_df %>%
    filter(redteam_prompt_ix %in% test_prompt_ixs, role_space == 'system') %>%
    summarize(n_prompts = n_distinct(redteam_prompt_ix)) %>%
    print()

# Mean probe probability per prompt_type x role_space x BMT x segment x token position.
# seg_ix is part of the grouping key (the plots below order and smooth by it), so a position
# is only averaged across prompts whose segments share the same index.
prompt_type_bmt_and_token_ix =
    filtered_proj_df %>%
    # Cluster by prompt
    group_by(
        prompt_type,
        role_space,
        base_message_type,
        seg_ix,
        token_in_seg_ix
    ) %>%
    summarize(
        n_prompts = n_distinct(redteam_prompt_ix), # Same as n()
        mean_spaceness = mean(prob),
        .groups = 'drop'
    ) %>%
    # Drop positions observed in fewer than 50 prompts
    filter(n_prompts >= 50)

head(prompt_type_bmt_and_token_ix, 5)

In [None]:
# Build the plotting frame for mean CoTness over token position, one facet per prompt type.
p_dfs =
    prompt_type_bmt_and_token_ix %>%
    filter(base_message_type != 'system') %>%
    filter(prompt_type %in% c('destyled_cot_forgery', 'cot_forgery', 'no_policy')) %>%
    # Cap segment length: 50 tokens for user segments, 100 for everything else.
    # (The previous `(tok <= 50 & user) | tok <= 100` reduced to `tok <= 100`, so the
    # user cap never applied.)
    filter(token_in_seg_ix <= if_else(base_message_type == 'user', 50, 100)) %>%
    filter(role_space == 'assistant-cot') %>%
    # Running x position within each prompt type, following segment then token order
    group_by(prompt_type) %>%
    arrange(seg_ix, token_in_seg_ix, role_space, .by_group = TRUE) %>%
    mutate(token_in_prompt_ix = row_number()) %>%
    # Exponentially weighted mean (decay 0.5 per token) over an expanding right-aligned window.
    # prompt_type must be in the grouping: otherwise the window runs on from one prompt
    # type's segment into the next one that shares the same seg_ix.
    group_by(prompt_type, seg_ix, role_space) %>%
    mutate(
        prob_ewma = zoo::rollapply(
            mean_spaceness,
            seq_along(mean_spaceness),
            \(x) {
                w = .5^seq(length(x) - 1, 0)
                sum(x * w) / sum(w)
            },
            align = 'right',
            partial = TRUE
        )
    ) %>%
    ungroup() %>%
    mutate(
        base_message_type = factor(
            base_message_type,
            levels = c('user', 'forged_cot', 'assistant-cot', 'assistant-final'),
            labels = c('User', 'User (CoT Forgery)', 'CoT', 'Asst')
        ),
        prompt_type = factor(
            prompt_type,
            levels = c('no_policy', 'cot_forgery', 'destyled_cot_forgery'),
            labels = c('(a) No CoT Forgery', '(b) CoT Forgery', '(c) Destyled CoT Forgery')
        )
    )

# Color per displayed role label
point_colors = c(
  'System' = '#90a1b9', # slate
  'User' = '#00a6f4', # sky
  'CoT' = '#fd9a00', # amber
  'Asst' = '#00d492',  # emerald,
  'User (CoT Forgery)' = '#ff637e'
)

# Smoothed CoTness by token position, colored by role, stacked facets per prompt type
plot =
    p_dfs %>%
    ggplot() +
    geom_point(aes(x = token_in_prompt_ix, y = prob_ewma, color = base_message_type), size = 0.5, alpha = 0.9) +
    geom_line(aes(x = token_in_prompt_ix, y = prob_ewma, color = base_message_type, group = base_message_type), linewidth = 0.4, alpha = 0.7) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.02, 0.02)),
        breaks = c(0, .5, 1)
    ) +
    facet_wrap(vars(prompt_type), ncol = 1, strip.position = 'top') +
    scale_color_manual(values = point_colors, drop = FALSE) +
    labs(
        title = NULL,
        x = NULL,
        y = 'CoTness',
        color = 'Role'
    ) +
    coord_cartesian(clip = "off") +
    theme_iclr(base_size = 11) +
    theme(
        plot.title = ggtext::element_markdown(face = 'bold', color = '#45556c', size = 9.5, margin = margin(b = 5)),
        # legend.position = 'none',
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        panel.spacing.y = unit(1.0, 'lines'),
        axis.text.x = element_blank(),
        strip.placement = "outside",
        strip.text.x = element_text(face = "bold", color = '#45556c', size = 9.5, margin = margin(b = 5)),
        strip.background = element_blank(),
        panel.grid.major.x = element_blank(),
        axis.ticks.x = element_blank(),

        legend.background    = element_rect(fill = "#f4f7fb", color = NA),
        legend.margin        = margin(t = 6, r = 8, b = 6, l = 8),

        # outer border around the whole legend box
        # legend.box.background = element_rect(color = "#90a1b9", linewidth = 0.6),
        legend.box.margin     = margin(t = 6, r = 6, b = 6, l = 6),

        # tidy keys/text
        legend.key           = element_rect(fill = NA, color = NA),
        legend.key.width     = unit(14, "pt"),
        legend.key.height    = unit(10, "pt"),
        legend.title         = element_text(face = "bold", colour = "#45556c"),
        legend.text          = ggtext::element_markdown(colour = "#45556c")
    )


# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-redteam-mean.pdf'),
#     plot = plot, width = 7.0, height = 4.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-analysis/plots/cotness-redteam-mean.png'),
#     plot = plot, width = 7.0, height = 4.0, units = 'in', dpi = 300
# )

# ggsave(
#     str_glue('{ws}/docs/cotness-redteam-mean.png'),
#     plot = plot, width = 7.0, height = 4.0, units = 'in', dpi = 300
# )

plot

In [None]:
# Single panel restricted to the forged-CoT region of the user message:
# unsmoothed mean CoTness per token position (first 100 tokens), base vs. destyled forgery
prompt_type_bmt_and_token_ix %>%
    filter(
        role_space == 'assistant-cot',
        base_message_type == 'forged_cot',
        prompt_type %in% c('destyled_cot_forgery', 'cot_forgery'),
        token_in_seg_ix <= 100
    ) %>%
    ggplot(aes(x = token_in_seg_ix, y = mean_spaceness, color = prompt_type)) +
    geom_point(size = 0.8, alpha = 0.9) +
    geom_line(aes(group = prompt_type), linewidth = 0.8, alpha = 0.7) +
    scale_y_continuous(
        breaks = c(0, .5, 1),
        limits = c(0, 1),
        labels = scales::percent_format(accuracy = 1),
        expand = expansion(mult = c(0.02, 0.02))
    ) +
    labs(x = NULL, y = NULL, color = 'Token Style') +
    coord_cartesian(clip = "off") +
    theme_iclr(base_size = 11) +
    theme(
        legend.position = 'none',
        plot.title = ggtext::element_markdown(face = 'bold', color = '#45556c', size = 9.5, margin = margin(b = 5)),
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        axis.text.x = ggtext::element_markdown(size = 8, hjust = 0, angle = 0),
        panel.grid.major.y = element_blank(),
        panel.spacing.y = unit(0.4, 'lines'),
        strip.placement = 'outside',
        strip.text.y.left = element_text(angle = 90, face = 'bold')
    )

In [None]:
# Recompute the per-position summary from filtered_proj_df: one row per
# prompt_type x role_space x base_message_type x seg_ix x token_in_seg_ix cell,
# holding the number of contributing prompts and their mean probe probability.
# Cells backed by fewer than 50 prompts are discarded.
prompt_type_bmt_and_token_ix =
    filtered_proj_df %>%
    group_by(prompt_type, role_space, base_message_type, seg_ix, token_in_seg_ix) %>%
    summarise(
        n_prompts = n_distinct(redteam_prompt_ix),
        mean_spaceness = mean(prob),
        .groups = 'drop'
    ) %>%
    filter(n_prompts >= 50)

head(prompt_type_bmt_and_token_ix, n = 5)

# Analysis

In [None]:
test_layer_ix = 16

# Heatmap of mean probe probability: rows are role spaces, columns are the true message
# type of the tokens, faceted by CoT-forgery style. Probabilities are averaged within each
# prompt first and then across prompts, so every prompt carries equal weight.
p =
    roles_df %>%
    filter(
        layer_ix == test_layer_ix,
        qualifier_type == 'no_qualifier' & policy_style %in% c('base', 'destyled'),
        base_message_type %in% c('assistant-cot', 'assistant-final',  'user', 'forged_cot')
    ) %>%
    # filter(., role_space != 'system') %>%
    # filter(., seg_id <= 20) %>%
    group_by(policy_style, role_space, base_message_type, redteam_prompt_ix) %>%
    summarize(roleness = mean(prob), .groups = 'drop') %>%
    group_by(policy_style, role_space, base_message_type) %>%
    summarize(mean_roleness = mean(roleness), .groups = 'drop') %>%
    # Human-readable, ordered labels for both axes and the facets
    mutate(
        role_space = factor(
            role_space,
            levels = c('user', 'assistant-cot', 'assistant-final', 'tool', 'system'),
            labels = c('Userness', 'CoTness',  'Assistantness', 'Toolness', 'Systemness')
        ),
        base_message_type = factor(
            base_message_type,
            levels = c('user', 'assistant-cot', 'assistant-final', 'forged_cot'),
            labels = c('User<br>(Excl. CoT Forgery)', 'CoT', 'Assistant', 'User<br>(CoT Forgery)')
        ),
        policy_style = factor(
            policy_style,
            levels = c('base', 'destyled'),
            labels = c('CoT Forgery', 'Destyled CoT Forgery')
        )
    ) %>%
    ggplot(aes(x = base_message_type, y = role_space)) +
    geom_tile(aes(fill = mean_roleness), color = 'white') +
    geom_text(aes(label = scales::percent(mean_roleness, accuracy = 0.1)), size = 3.0) +
    # Red outline around the fourth column (forged-CoT tokens)
    annotate('rect', xmin = 3.5, xmax = 4.49, ymin = 0.51, ymax = 4.49, color = "red", fill = NA, linewidth = 0.5) +
    scale_fill_gradient(low = 'white', high = 'lawngreen', limits = c(0, 1), labels = scales::percent_format(accuracy = 1)) +
    scale_x_discrete(expand = c(0, 0)) +
    scale_y_discrete(limits = rev, expand = c(0, 0)) +
    facet_grid(cols = vars(policy_style)) +
    labs(x = NULL, y = 'Role Space') +
    guides(fill = 'none') +
    theme_iclr(base_size = 11) +
    theme(
        legend.position = 'none',
        plot.title = ggtext::element_markdown(face = 'bold', color = '#45556c', size = 9.5, margin = margin(b = 5)),
        axis.title.x = ggtext::element_markdown(angle = 0, margin = margin(t = 10)),
        axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
        axis.text.x = ggtext::element_markdown(size = 8.5, hjust = .5, angle = 0, margin = margin(t = 2)),
        axis.text.y = ggtext::element_markdown(size = 8.5, hjust = .5, angle = 0, margin = margin(r = 2)),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank(),
        panel.grid.major.y = element_blank(),
        panel.spacing.x = unit(3.0, 'lines'),
        strip.placement = 'outside',
        strip.text = element_text(face = 'bold')
    )

# ggsave(
#     str_glue('{ws}/experiments/role-injection-analysis/plots/cotness-redteam-rolespace.pdf'),
#     plot = p, width = 7.0, height = 3.0, units = 'in', dpi = 300, device = cairo_pdf
# )
# ggsave(
#     str_glue('{ws}/experiments/role-injection-analysis/plots/cotness-redteam-rolespace.png'),
#     plot = p, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

# ggsave(
#     str_glue('{ws}/docs/cotness-redteam-rolespace.png'),
#     plot = p, width = 7.0, height = 3.0, units = 'in', dpi = 300
# )

p

# Ternary plot

In [None]:
# Prep ternary inputs - standard triangle shape.
# For every forged-CoT prompt: average the probe probability per role space, rescale the
# role-space means so they sum to 1 within the prompt, spread them into one column per
# role space, and project the (cotness, userness, other) composition onto 2-D coordinates.
ternary_input_df =
    roles_df %>%
    filter(layer_ix == test_layer_ix) %>%
    filter(qualifier_type == 'no_qualifier' & policy_style %in% c('base', 'destyled')) %>%
    filter(base_message_type == 'forged_cot') %>%
    group_by(policy_style, role_space, redteam_prompt_ix, output_class) %>%
    summarize(roleness = mean(prob), .groups = 'drop') %>%
    arrange(redteam_prompt_ix, policy_style, role_space, output_class, roleness) %>%
    # Normalize within prompt so the role-space shares form a composition
    group_by(redteam_prompt_ix, policy_style, output_class) %>%
    mutate(roleness = roleness / sum(roleness)) %>%
    # Drop the grouping here so the result is a plain tibble for everything downstream
    ungroup() %>%
    pivot_wider(
        id_cols = c(redteam_prompt_ix, policy_style, output_class),
        names_from = role_space,
        values_from = roleness,
        values_fill = 0
    ) %>%
    rename(cotness = 'assistant-cot', assistantness = 'assistant-final', userness = 'user', systemness = 'system') %>%
    # Collapse system + assistant into a single third component
    mutate(other_roleness = systemness + assistantness) %>%
    # Barycentric -> Cartesian: pure CoT maps to (1, 0), pure User to (-1/2, sqrt(3)/2),
    # pure "other" to (-1/2, -sqrt(3)/2), i.e. the corners of triangle_df below
    mutate(
        x = cotness - 0.5 * (userness + other_roleness),
        y = (sqrt(3) / 2) * (userness - other_roleness)
    )

# Closed outline of the triangle (first corner repeated at the end)
triangle_df = tibble(x = c(1, cos(2 * pi/3), cos(4 * pi/3), 1), y = c(0, sin(2 * pi/3), sin(4 * pi/3), 0))

In [None]:
# Draw the plot: one point per prompt inside the role triangle, colored by attack outcome.
# output_class values other than HARMFUL_RESPONSE / REFUSAL become NA in the factor.
ternary_input_df %>%
    mutate(output_class = factor(
        output_class,
        levels = c('HARMFUL_RESPONSE', 'REFUSAL'),
        labels = c('Attack Success', 'Attack Failure')
    )) %>%
    ggplot(aes(x = x, y = y, colour = output_class)) +
    geom_polygon(data = triangle_df, aes(x, y), inherit.aes = FALSE, linewidth = 0.5, fill = '#f8fafc', color = '#f1f5f9') +
    # geom_path(aes(group = prompt_ix), alpha = 0.5, linewidth = 0.3) +
    geom_point(size = 0.8, alpha = 0.5) +
    scale_color_manual(
        values = c(
            'Attack Success' = '#fb2c36',
            'Attack Failure' = '#009966'
        )
    ) +
    # expand takes a logical; FALSE keeps the limits exactly at [-1, 1]
    coord_equal(xlim = c(-1, 1), ylim = c(-1, 1), expand = FALSE, clip = 'off') +
    theme_void(base_size = 11) +
    theme(legend.position = 'top', plot.margin = margin(5, 5, 5, 5)) +
    # Label the three role directions just outside the triangle corners
    annotate('text', x = 1.10 - 0.22, y = 0 + 0.10, label = 'CoTness', hjust = 0, vjust = 0.5) +
    annotate('text', x = cos(2 * pi/3) * 1.05, y = sin(2*pi/3) * 1.05, label = 'Userness', hjust = 0.5, vjust = -0.1) +
    annotate('text', x = cos(4 * pi/3) * 1.05 + 0.10, y = sin(4*pi/3) * 1.05 + 0.10, label = 'All Other Roles', hjust = 0.5, vjust = 1.1) +
    labs(color = NULL) +
    # The legend belongs to the colour aesthetic; a `fill` guide would be ignored here
    guides(
        color = guide_legend(
            nrow = 1,
            keyheight = unit(12, "pt"), keywidth = unit(18, "pt"),
            label.theme = ggtext::element_markdown(margin = margin(l = 4, r = 12, t = 2))
        )
    )





In [None]:
# "Unrolled" view of the ternary plot relative to the CoT-User edge:
#   edge_pos       - share of userness among (userness + cotness); 0 = all CoT, 1 = all User
#   dist_from_edge - combined assistant + system share (other_roleness), i.e. how far the
#                    point sits from that edge; this column was previously never defined
ternary_input_df %>%
  mutate(
    edge_pos = userness / (userness + cotness),
    dist_from_edge = other_roleness
  ) %>%
  ggplot(aes(edge_pos, dist_from_edge, colour = output_class)) +
  geom_point(alpha = 0.4, size = 0.8) +
  # sqrt scale spreads out the many points lying close to the edge
  scale_y_continuous(trans = "sqrt") +
  theme_minimal() +
  labs(x = "Position along CoT ↔ User edge", y = "Assistant+System (distance from edge)", colour = NULL)

