In [None]:
# Plot test role probes

In [None]:
library(tidyverse)
library(fs)
library(ggtext)
library(systemfonts)
library(arrow)
library(patchwork)

# Workspace root and the model prefix used to build every data/plot path below
ws <- "/workspace/deliberative-alignment-jailbreaks"
model_prefix <- "gptoss20"

# Load the project's plotting utilities
# (presumably where theme_iclr() used further down is defined - confirm)
source(paste0(ws, "/r-utils/plots.r"))

# Load data

In [None]:
# Directory with the agent activations/outputs for the selected model
base_path <- file.path(
  ws,
  "experiments/role-analysis/activations-agent",
  model_prefix
)

# Classified agent outputs: keep the prompt index, variant and output class,
# and fold the REDIRECTION class into REFUSAL.
# trim_ws = FALSE stops read_csv() from stripping whitespace around fields.
prompts_df <-
  read_csv(
    file.path(base_path, "agent-outputs-classified.csv"),
    trim_ws = FALSE
  ) %>%
  select(redteam_prompt_ix, variant, output_class) %>%
  mutate(
    output_class = ifelse(
      output_class == "REDIRECTION",
      "REFUSAL",
      output_class
    )
  )

print(prompts_df)

# Probe projections, stored as a feather file per model
raw_projections_df <- read_feather(
  file.path(
    ws,
    str_glue("experiments/role-analysis/projections/agent-role-projections-{model_prefix}.feather")
  )
)

# Lookup table describing each probe_ix (joined on below via layer_ix / roles)
probe_mapping_df <- read_csv(
  file.path(
    ws,
    str_glue("experiments/role-analysis/projections/agent-role-probe-mapping-{model_prefix}.csv")
  )
)

head(raw_projections_df, 5)

# Test probes

In [None]:
# Choose primary test layer / roles
print(probe_mapping_df)

test_layer_ix <- 12
test_roles <- "assistant-cot,assistant-final,system,tool,user"

# probe_ix of the mapping row(s) matching both the chosen layer and role set
test_probe_ix <-
  probe_mapping_df %>%
  filter(layer_ix == test_layer_ix, roles == test_roles) %>%
  pull(probe_ix)
test_probe_ix

In [None]:
# Merge - get all layers, but only roles = test_roles
# Projections are first restricted to probes trained on `test_roles`, then
# annotated with each prompt's output class.
# NOTE(review): prompts_df also has a `variant` column; if one
# redteam_prompt_ix can appear under several variants, this join repeats every
# projection row once per variant - confirm the key is unique.
roles_df <-
  raw_projections_df %>%
  inner_join(
    filter(probe_mapping_df, roles == test_roles),
    by = "probe_ix"
  ) %>%
  inner_join(
    select(prompts_df, redteam_prompt_ix, output_class),
    by = "redteam_prompt_ix"
  )

head(roles_df, 5)

In [None]:
# Get prompt-level RCI for forged cot by role space
# For every (layer, prompt): mean probe probability over the forged-CoT tokens
# found in tool messages, with one column per role space
# (assistant-cot -> cotness, user -> userness, tool -> toolness).
prompt_x_role_space_cotness <-
  roles_df %>%
  filter(
    layer_ix > 0,
    roles == test_roles,
    role == "tool" & base_message_type == "forged-cot", # Get forged CoT only
    role_space %in% c("tool", "user", "assistant-cot")
  ) %>%
  group_by(layer_ix, redteam_prompt_ix, role_space) %>%
  summarize(spaceness = mean(prob), .groups = "drop") %>%
  pivot_wider(
    id_cols = c(layer_ix, redteam_prompt_ix),
    names_from = role_space,
    values_from = spaceness
  ) %>%
  rename(cotness = "assistant-cot", userness = "user", toolness = "tool")

prompt_x_role_space_cotness

# Phase portraits

In [None]:
escape_md_html <- function(x) {
  #' Make text render literally under ggtext's markdown/HTML parsing
  #'
  #' Applies a fixed, ordered list of regex substitutions to `x`.
  #'
  #' @param x Character vector to escape
  #' @return Character vector safe for ggtext::element_markdown()

  # Each entry is c(pattern, replacement); order matters.
  substitutions <- list(
    # 1. Normalise platform newlines to "\n"
    c("\r\n|\r", "\n"),
    # 2. A literal <br> inside the text counts as a newline (callers add their
    #    own <br> separators afterwards)
    c("(?i)<br\\s*/?>", "\n"),
    # 3. Double every backslash so it shows up as a backslash
    c("\\\\", "\\\\\\\\"),
    # 4. Backslash-escape CommonMark punctuation
    c("([\\*_|`\\[\\]\\(\\)\\#\\~\\|!\\+\\-=])", "\\\\\\1"),
    # 5. Backslash-escape < and > so no HTML tag is parsed
    c("([<>])", "\\\\\\1"),
    # 6. Show real newlines as the two visible characters backslash + n
    c("\n", "\\\\n")
  )

  for (rule in substitutions) {
    x <- gsub(rule[[1]], rule[[2]], x, perl = TRUE)
  }

  x
}

wrap_tokens <- function(tokens, line_width, max_lines = 3) {
  #' Greedily pack tokens into lines of limited display width
  #'
  #' Tokens are concatenated as-is (no separator). A token that would push the
  #' current line past `line_width` starts a new line; a token that is itself
  #' wider than `line_width` still gets a line of its own.
  #'
  #' @param tokens A vector of tokens to insert linebreaks into
  #' @param line_width The number of characters per line
  #' @param max_lines The maximum number of lines to retain
  #' @return list(lines = character vector of packed lines,
  #'   truncated = TRUE when not every token was consumed)

  finished <- character(0)
  buffer <- ""
  n_consumed <- 0L

  for (tok in tokens) {
    joined <- paste0(buffer, tok)

    # Token fits on the current line, or the line is still empty
    if (nchar(joined, type = "width") <= line_width || buffer == "") {
      buffer <- joined
      n_consumed <- n_consumed + 1L
      next
    }

    # Current line is full: close it and stop once the line budget is spent
    # (the token that did not fit is then left unconsumed)
    finished <- append(finished, buffer)
    if (length(finished) >= max_lines) {
      break
    }

    buffer <- tok
    n_consumed <- n_consumed + 1L
  }

  # Flush the last partially filled line if there is still room for it
  if (length(finished) < max_lines && nzchar(buffer)) {
    finished <- append(finished, buffer)
  }

  list(
    lines = finished,
    truncated = n_consumed < length(tokens) || length(finished) > max_lines
  )
}

# Marker colours keyed by token-style label.
# Both vectors are reassigned in later cells before any plot uses them.
point_colors <- c(
  "System"             = "#90a1b9", # slate
  "User"               = "#00a6f4", # sky
  "CoT"                = "#fd9a00", # amber
  "Assistant"          = "#00d492", # emerald
  "User (CoT Forgery)" = "#ff637e"
)

# Label-text colours for the same keys
text_colors <- c(
  "System"             = "#62748e", # slate
  "User"               = "#0084d1", # sky
  "CoT"                = "#e17100", # amber
  "Assistant"          = "#009966", # emerald
  "User (CoT Forgery)" = "#ff637e"
)

In [None]:
# List cot-forgery-injection prompts whose output class contains 'ATTEMPTED'
prompts_df %>%
  filter(
    variant == "cot-forgery-injection",
    str_detect(output_class, "ATTEMPTED")
  )


In [None]:
# Use 1 for cotness plot (plot 2), 32 for plot 1
# 1 is interesting, lots of rotation - many tool usages;
# 11 = clean example of user injection only
test_redteam_prompt_ix <- 1
test_layer_ix <- 16
test_roles <- "assistant-cot,assistant-final,system,tool,user"

# Projections for the single chosen prompt, layer and role set
proj_df <-
  raw_projections_df %>%
  inner_join(
    filter(probe_mapping_df, roles == test_roles, layer_ix == test_layer_ix),
    by = "probe_ix"
  ) %>%
  inner_join(
    prompts_df %>%
      filter(redteam_prompt_ix == test_redteam_prompt_ix) %>%
      select(redteam_prompt_ix, output_class),
    by = "redteam_prompt_ix"
  ) %>%
  # Add in bmt/seg id to allow for filtering later if needed:
  # seg_ix increments whenever base_message_type changes along sample_ix,
  # token_in_seg_ix numbers the rows inside each such segment
  group_by(redteam_prompt_ix, role_space) %>%
  arrange(sample_ix, .by_group = TRUE) %>%
  mutate(seg_ix = consecutive_id(base_message_type)) %>%
  group_by(redteam_prompt_ix, role_space, seg_ix) %>%
  mutate(token_in_seg_ix = row_number()) %>%
  ungroup()

head(proj_df, 5)

In [None]:
# Line plot version
# Marker colours keyed by the base_message_type labels used in this cell
point_colors <- c(
  "System"                   = "#90a1b9", # slate
  "User"                     = "#00a6f4", # sky
  "CoT"                      = "#fd9a00", # amber
  "Asst"                     = "#00d492", # emerald
  "Tool"                     = "#7e6cff",
  "Tool<br>(User Injection)" = "#2bdcff",
  "Tool<br>(CoT Forgery)"    = "#ff637e"
)

# Colours for the x-axis segment labels, same keys
text_colors <- c(
  "System"                   = "#62748e", # slate
  "User"                     = "#0084d1", # sky
  "CoT"                      = "#e17100", # amber
  "Asst"                     = "#009966", # emerald
  "Tool"                     = "#7e6cff",
  "Tool<br>(User Injection)" = "#2bdcff",
  "Tool<br>(CoT Forgery)"    = "#ff637e"
)

# Token-level data for the roleness figure.
#  - drops the 'system' role space and 'system' messages
#  - keeps at most the first 100 tokens of every segment
#  - prob_ewma: exponentially weighted mean of `prob` over the tokens seen so far in the
#    segment (weight 1 for the current token, halving for each step further back)
#  - relabels role_space / base_message_type as display factors; message types outside the
#    listed levels become NA and are dropped
#  - token_in_prompt_ix: x position of each token, numbered separately per role space
plot_df =
    proj_df %>%
    filter(., role_space != 'system' & base_message_type != 'system') %>%
    filter(., token_in_seg_ix <= 100) %>%
    group_by(redteam_prompt_ix, seg_ix, role_space) %>%
    mutate(
        prob_ewma = zoo::rollapply(
            prob,
            seq_along(prob), # window i spans tokens 1..i (right-aligned, growing)
            \(x) .5^(seq(length(x) - 1, 0)) %>% {sum(x * .)/sum(.)},
            align = 'right',
            partial = TRUE
        )
    ) %>%
    ungroup() %>%
    mutate(
        .,
        role_space = factor(
            role_space,
            levels = c('user', 'assistant-cot', 'assistant-final', 'tool'),
            labels = c('Userness', 'CoTness', 'Assistantness', 'Toolness')
        ),
        base_message_type = factor(
            base_message_type,
            levels = c('user', 'assistant-cot', 'assistant-final', 'tool', 'forged-cot', 'user-injection'),
            labels = c('User', 'CoT', 'Asst', 'Tool', 'Tool<br>(CoT Forgery)', 'Tool<br>(User Injection)')
        ),
    ) %>%
    filter(., !is.na(base_message_type)) %>%
    # NOTE(review): n() counts rows of every role space sharing a seg_ix, so this threshold
    # is in rows rather than tokens - confirm that is intended
    group_by(., seg_ix) %>%
    mutate(., tok_length = n()) %>%
    ungroup() %>%
    filter(., tok_length >= 10) %>%
    group_by(redteam_prompt_ix, role_space) %>%
    arrange(., seg_ix, token_in_seg_ix, .by_group = TRUE) %>%
    # row_number() rather than 1:n(): 1:n() yields c(1, 0) when a group is empty
    mutate(., token_in_prompt_ix = row_number()) %>%
    ungroup()

# Build (and display) one faceted roleness plot per redteam_prompt_ix.
# `plots` is a list of ggplot objects, one per prompt, in group_split() order.
plots = map(group_split(plot_df, redteam_prompt_ix), function(this_plot_df) {

    # How each segment is labelled on the x axis:
    #   FALSE - with its message type (the labelling used for the saved figure)
    #   TRUE  - with a wrapped, markdown-escaped preview of the segment's first tokens
    show_token_preview = FALSE

    # Seg text + index for the start of each consecutive base_message_type segment
    segs =
        this_plot_df %>%
        # Collapse out role_space (lossless)
        distinct(redteam_prompt_ix, token_in_prompt_ix, seg_ix, base_message_type, token) %>%
        group_by(redteam_prompt_ix, seg_ix, base_message_type) %>%
        summarize(
            .,
            start_ix = first(token_in_prompt_ix),
            # Label for the contiguous segment
            start_str = {
                bmt = as.character(first(base_message_type))

                if (show_token_preview) {
                    # Characters per label line, by message type
                    line_width_chars = case_when(
                        bmt %in% c('User') ~ 6,
                        bmt %in% c('CoT') ~ 10,
                        bmt %in% c('Asst') ~ 30,
                        TRUE ~ 12
                    )

                    # NOTE(review): the first 17 tokens of user segments are skipped -
                    # presumably prompt boilerplate; confirm
                    tok_vec = if (bmt == 'User') tail(token, -17) else token
                    # Build up to 2 lines by character length
                    wrapped = wrap_tokens(tok_vec, line_width = line_width_chars, max_lines = 2)

                    # Escape Markdown/HTML inside the token content,
                    # join with <br>, append '..' if truncated
                    map_chr(wrapped$lines, escape_md_html) %>%
                        paste0(., collapse = '<br>') %>%
                        {if (isTRUE(wrapped$truncated)) paste0(., '..') else .}
                } else {
                    bmt
                }
            },
            .groups = 'drop'
        ) %>%
        # Colour each label with the text colour of its message type
        mutate(., start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))

    this_p =
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob, color = base_message_type), size  = 0.5) +
        facet_grid(rows = vars(role_space), switch = 'y') +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, 1)
        ) +
        # One tick per segment start, labelled with the coloured segment label
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$redteam_prompt_ix),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            # plot.title = ggtext::element_markdown(size = 10.5),
            legend.position = 'none',
            axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(0.4, 'lines'),
            axis.text.x = ggtext::element_markdown(size = 7.5, hjust = 0, angle = 0)
        )

    print(this_p)

    # Return the plot explicitly instead of relying on print()'s invisible return value
    this_p
})


# Write the first element of `plots` to disk: PDF (cairo) and PNG under the
# experiment's plots folder, plus a PNG copy under docs/
ggsave(
  filename = str_glue("{ws}/experiments/role-analysis/plots/roleness-agent.pdf"),
  plot = plots[[1]],
  device = cairo_pdf,
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)
ggsave(
  filename = str_glue("{ws}/experiments/role-analysis/plots/roleness-agent.png"),
  plot = plots[[1]],
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)

# NOTE(review): this file name is spelled 'rolness', unlike 'roleness' above -
# left unchanged in case the docs link to it; confirm and rename if it is a typo
ggsave(
  filename = str_glue("{ws}/docs/rolness-agent.png"),
  plot = plots[[1]],
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)

In [None]:
# Cotness only plot
# Marker colours keyed by base_message_type label
point_colors <- c(
  "System"                   = "#90a1b9", # slate
  "User"                     = "#00a6f4", # sky
  "CoT"                      = "#fd9a00", # amber
  "Asst"                     = "#00d492", # emerald
  "Tool"                     = "#7e6cff",
  "Tool<br>(User Injection)" = "#2bdcff",
  "Tool<br>(CoT Forgery)"    = "#ff637e"
)

# Colours for the x-axis segment labels, same keys
text_colors <- c(
  "System"                   = "#62748e", # slate
  "User"                     = "#0084d1", # sky
  "CoT"                      = "#e17100", # amber
  "Asst"                     = "#009966", # emerald
  "Tool"                     = "#7e6cff",
  "Tool<br>(User Injection)" = "#2bdcff",
  "Tool<br>(CoT Forgery)"    = "#ff637e"
)

# CoTness only
# Token-level data for the CoTness figure (role_space == 'assistant-cot' rows only).
#  - drops the 'system' role space and 'system' messages; keeps at most 220 tokens per segment
#  - prob_ewma: exponentially weighted mean of `prob` over the tokens seen so far in the
#    segment (weight 1 for the current token, halving for each step further back)
#  - drops user-injection segments, short segments, forged-CoT tokens past position 70,
#    and segments after the 15th
#  - token_in_prompt_ix: consecutive x position of each retained CoTness row
plot_df =
    proj_df %>%
    filter(., role_space != 'system' & base_message_type != 'system') %>%
    filter(., token_in_seg_ix <= 220) %>%
    group_by(redteam_prompt_ix, seg_ix, role_space) %>%
    mutate(
        prob_ewma = zoo::rollapply(
            prob,
            seq_along(prob), # window i spans tokens 1..i (right-aligned, growing)
            \(x) .5^(seq(length(x) - 1, 0)) %>% {sum(x * .)/sum(.)},
            align = 'right',
            partial = TRUE
        )
    ) %>%
    ungroup() %>%
    filter(., base_message_type %in% c('user', 'assistant-cot', 'assistant-final', 'tool', 'forged-cot', 'user-injection')) %>%
    mutate(
        .,
        base_message_type = factor(
            base_message_type,
            levels = c('user', 'assistant-cot', 'assistant-final', 'tool', 'forged-cot', 'user-injection'),
            # Labels spelled exactly like the point_colors / text_colors keys
            labels = c('User', 'CoT', 'Asst', 'Tool', 'Tool<br>(CoT Forgery)', 'Tool<br>(User Injection)')
        )
    ) %>%
    filter(., base_message_type != 'Tool<br>(User Injection)') %>%
    # NOTE(review): n() counts rows of every role space sharing a seg_ix, so this threshold
    # is in rows rather than tokens - confirm that is intended
    group_by(., seg_ix) %>%
    mutate(., tok_length = n()) %>%
    ungroup() %>%
    filter(., tok_length >= 30) %>%
    filter(., base_message_type != 'Tool<br>(CoT Forgery)' | token_in_seg_ix <= 70) %>%
    # Keep only the CoTness rows BEFORE numbering, so x positions are consecutive tokens
    # rather than being spaced out by the rows of the other role spaces
    filter(., role_space == 'assistant-cot') %>%
    select(., -role_space) %>%
    group_by(redteam_prompt_ix) %>%
    arrange(., seg_ix, token_in_seg_ix, .by_group = TRUE) %>%
    # row_number() rather than 1:n(): 1:n() yields c(1, 0) when a group is empty
    mutate(., token_in_prompt_ix = row_number()) %>%
    ungroup() %>%
    filter(., seg_ix <= 15)

# How each segment is labelled on the x axis:
#   FALSE - with its message type (the labelling used for the saved figure)
#   TRUE  - with a wrapped, markdown-escaped preview of the segment's first tokens
show_token_preview = FALSE

# Seg text + index for the start of each consecutive base_message_type segment
segs =
    plot_df %>%
    group_by(redteam_prompt_ix, seg_ix, base_message_type) %>%
    summarize(
        .,
        start_ix = first(token_in_prompt_ix),
        tok_len = n(),
        # Label for the contiguous segment
        start_str = {
            bmt = as.character(first(base_message_type))

            if (show_token_preview) {
                # Characters per label line, by message type
                line_width_chars = case_when(
                    bmt %in% c('User') ~ 6,
                    bmt %in% c('CoT') ~ 10,
                    bmt %in% c('Asst') ~ 30,
                    TRUE ~ 12
                )

                # NOTE(review): the first 17 tokens of user segments are skipped -
                # presumably prompt boilerplate; confirm
                tok_vec = if (bmt == 'User') tail(token, -17) else token
                # Build up to 2 lines by character length
                wrapped = wrap_tokens(tok_vec, line_width = line_width_chars, max_lines = 2)

                # Escape Markdown/HTML inside the token content,
                # join with <br>, append '..' if truncated
                map_chr(wrapped$lines, escape_md_html) %>%
                    paste0(., collapse = '<br>') %>%
                    {if (isTRUE(wrapped$truncated)) paste0(., '..') else .}
            } else {
                bmt
            }
        },
        .groups = 'drop'
    ) %>%
    # Colour each label with the text colour of its message type
    mutate(., start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))
        
# CoTness figure: smoothed probe probability (prob_ewma) along the prompt,
# drawn as points plus one connecting line per segment, coloured by message type
this_p <-
  ggplot(
    plot_df,
    aes(x = token_in_prompt_ix, y = prob_ewma, color = base_message_type)
  ) +
  geom_point(size = 0.5, alpha = 0.9) +
  geom_line(aes(group = seg_ix), linewidth = 0.5, alpha = 0.7) +
  scale_y_continuous(
    breaks = c(0, .5, 1),
    limits = c(0, 1),
    labels = scales::percent_format(accuracy = 1),
    expand = expansion(mult = c(0.02, 0.02))
  ) +
  # One tick per segment start, labelled with the coloured segment label
  scale_x_continuous(
    breaks = segs$start_ix,
    labels = segs$start_str,
    expand = expansion(mult = c(0, 0))
  ) +
  scale_color_manual(values = point_colors, drop = FALSE) +
  labs(title = NULL, x = NULL, y = NULL, color = "Token Style") +
  coord_cartesian(clip = "off") +
  theme_iclr(base_size = 11) +
  theme(
    plot.title = ggtext::element_markdown(
      face = "bold",
      color = "#45556c",
      size = 9.5,
      margin = margin(b = 5)
    ),
    legend.position = "none",
    axis.title.y = ggtext::element_markdown(
      angle = 90,
      vjust = 0.5,
      margin = margin(r = 6)
    ),
    axis.text.x = ggtext::element_markdown(size = 8, hjust = 0, angle = 0),
    strip.placement = "outside",
    strip.text.y.left = element_text(angle = 90, face = "bold"),
    panel.spacing.y = unit(0.4, "lines"),
    panel.grid.major.y = element_blank()
  )

this_p

In [None]:
# Write the CoTness figure to disk: PDF (cairo) and PNG under the experiment's
# plots folder, plus a PNG copy under docs/
ggsave(
  filename = str_glue("{ws}/experiments/role-analysis/plots/cotness-agent.pdf"),
  plot = this_p,
  device = cairo_pdf,
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)
ggsave(
  filename = str_glue("{ws}/experiments/role-analysis/plots/cotness-agent.png"),
  plot = this_p,
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)

ggsave(
  filename = str_glue("{ws}/docs/cotness-agent.png"),
  plot = this_p,
  width = 7.0,
  height = 4.0,
  units = "in",
  dpi = 300
)