In [None]:
# Plot test role probes

In [None]:
library(tidyverse)
library(fs)
# library(ggtext)
library(systemfonts)

# Workspace root and the model whose projections are plotted in this notebook
ws <- "/workspace/deliberative-alignment-jailbreaks"
model_prefix <- "gptoss20"

# Shared plotting helpers (defines the theme used by the plot cells)
source(file.path(ws, "r-utils", "plots.r"))

# Load data

In [None]:
# Per-token role-space projections for every test prompt.
# trim_ws = FALSE keeps whitespace-only / whitespace-padded tokens intact.
# NOTE(review): readr's default na = c("", "NA") would read a literal "NA"
# (or empty) token as missing — confirm that is acceptable here.
raw_df <-
    read_csv(
        file.path(ws, str_glue("experiments/da-role-analysis/projections/test-role-projections-{model_prefix}.csv")),
        trim_ws = FALSE
    ) %>%
    select(-prompt) %>%
    arrange(prompt_ix, token_in_prompt_ix) %>%
    # bmt_run: id of each contiguous run of one base_message_type within a prompt
    # (kept so runs can be filtered on later if needed)
    mutate(bmt_run = consecutive_id(base_message_type), .by = prompt_ix) %>%
    # token_in_bmt_ix: 1-based position of the token inside its run
    mutate(token_in_bmt_ix = row_number(), .by = c(prompt_ix, bmt_run))

head(raw_df, 1)

# Plots

In [None]:
# Helpers ----
# Default figure size for plots displayed in the notebook (repr options)
options(repr.plot.height = 6, repr.plot.width = 10)

escape_md_html <- function(x) {
  #' Escape Markdown + HTML so token text renders literally in plots
  #'
  #' Any `<br>` / `<br/>` already present in the text, and any real newline,
  #' is shown as the two visible characters `\n` rather than as a line break,
  #' so the caller's own `<br>` tags (added after escaping) stay distinct.
  #'
  #' @param x Character vector to escape (NA elements stay NA)
  #' @return Character vector the same length as `x`, safe for
  #'   ggtext::element_markdown()

  # Normalize platform newlines (CRLF / CR -> LF)
  x <- gsub("\r\n|\r", "\n", x, perl = TRUE)
  # Treat any literal <br> tag in the token text as a newline character
  x <- gsub("(?i)<br\\s*/?>", "\n", x, perl = TRUE)
  # Double every backslash first, so the escapes added below are unambiguous
  x <- gsub("\\\\", "\\\\\\\\", x, perl = TRUE)
  # Backslash-escape Markdown punctuation, plus `&` (so entity-like text such
  # as "&amp;" is not decoded), `.` (so "1. " cannot start an ordered list)
  # and `<` / `>` (so HTML tags are not parsed)
  x <- gsub("([*_|`\\[\\]()#~!+\\-=&.<>])", "\\\\\\1", x, perl = TRUE)
  # Finally, show real newlines as the visible two characters "\n"
  # (the replacement \\\\n emits a single backslash followed by n)
  x <- gsub("\n", "\\\\n", x, perl = TRUE)

  x
}

wrap_tokens <- function(tokens, line_width, max_lines = 3) {
  #' Greedily pack tokens into lines by display width
  #'
  #' Tokens are concatenated as-is (no separator). A token that would push the
  #' current line past `line_width` starts a new line; a token is never split,
  #' so a single over-wide token still occupies a line. Packing stops once
  #' `max_lines` lines have been produced.
  #'
  #' @param tokens A vector of tokens to insert linebreaks into
  #' @param line_width The number of characters (nchar type "width") per line
  #' @param max_lines The maximum number of lines to retain
  #' @return list(lines = character vector of packed lines,
  #'   truncated = TRUE when some token text did not make it into `lines`)

  lines <- character(0)
  cur <- ""
  used <- 0L

  for (tok in tokens) {
    cand <- paste0(cur, tok) # append next token as-is
    if (nchar(cand, type = "width") <= line_width || cur == "") {
      cur <- cand
      used <- used + 1L
      next
    }
    # `tok` does not fit: close the current line before starting a new one.
    # (Only reachable with max_lines < 1: no room even for the first line.)
    if (length(lines) >= max_lines) break
    lines <- c(lines, cur)
    cur <- ""
    # Line budget exhausted; `tok` and everything after it stay unused
    if (length(lines) >= max_lines) break
    cur <- tok
    used <- used + 1L
  }

  # Flush the partially built line if there is still room for it; otherwise
  # its text is dropped and must be reported as truncation
  leftover <- nzchar(cur)
  if (leftover && length(lines) < max_lines) {
    lines <- c(lines, cur)
    leftover <- FALSE
  }

  list(lines = lines, truncated = used < length(tokens) || leftover)
}

# Per-role label wrap width (characters per line), keyed by base message type.
# NOTE(review): the plot cells in this notebook compute their own local widths
# and do not read this vector — confirm whether it is still needed.
line_width_chars <- c(User = 6, CoT = 12, Assistant = 16)

In [None]:
# Plot full analysis results by segment, with systemness included
plot_df <-
    raw_df %>%
    group_by(prompt_key, prompt_ix, role_space) %>%
    mutate(
        # NOTE(review): prob_sma / prob_ewma are not used by the plots in this cell
        prob_sma = zoo::rollmean(prob, k = 2, fill = NA, align = 'right', partial = TRUE),
        # Expanding-window EWMA (decay 0.5), recomputed from scratch at each position
        prob_ewma = zoo::rollapply(
            prob, width = seq_along(prob),
            FUN = function(x) {
                weights <- .5^(seq(length(x) - 1, 0))
                sum(x * weights) / sum(weights)
            },
            align = 'right'
        )
    ) %>%
    ungroup() %>%
    mutate(
        role_space = fct_relevel(role_space, 'system', 'user', 'assistant-cot', 'assistant-final'),
        role_space = recode(role_space,
            'system' = 'Systemness',
            'user' = 'Userness',
            'assistant-cot' = 'CoTness',
            'assistant-final' = 'Assistantness'
        ),
        base_message_type = fct_relevel(base_message_type, 'system', 'user', 'cot', 'assistant-final'),
        base_message_type = recode(base_message_type,
            'system' = 'System',
            'user' = 'User',
            'cot' = 'CoT',
            'assistant-final' = 'Assistant'
        )
    ) %>%
    # Keep at most the first 200 tokens of each contiguous message-type run
    filter(token_in_bmt_ix <= 200) %>%
    # Re-index tokens densely (1..n) per prompt_key x role_space so the x axis
    # has no gaps where tokens were dropped; ordering by prompt_ix first keeps
    # tokens from different prompts under one key from interleaving
    group_by(prompt_key, role_space) %>%
    arrange(prompt_ix, token_in_prompt_ix, .by_group = TRUE) %>%
    mutate(token_in_prompt_ix = row_number()) %>%
    ungroup()

point_colors = c(
  'System' = '#90a1b9', # slate
  'User' = '#74d4ff', # sky
  'CoT' = '#f4a8ff', # fuchsia
  'Assistant' = '#a4f4cf'  # emerald
)

text_colors = c(
  'System' = '#62748e', # slate
  'User' = '#00bcff', # sky
  'CoT' = '#ed6bff', # fuchsia
  'Assistant' = '#00d492'  # emerald
)

plots <- map(group_split(plot_df, prompt_key), function(this_plot_df) {

    # Seg text + index for the start of each consecutive base_message_type segment
    segs <-
        this_plot_df %>%
        arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
        # Collapse out role_space (every role space carries the same tokens)
        group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
        summarize(token = first(token), .groups = 'drop') %>%
        group_by(prompt_ix) %>%
        mutate(rleid = consecutive_id(base_message_type)) %>%
        group_by(prompt_ix, rleid, base_message_type) %>%
        arrange(token_in_prompt_ix, .by_group = TRUE) %>%
        summarize(
            start_ix = first(token_in_prompt_ix),
            # Label for the contiguous segment: its first tokens, wrapped
            start_str = {
                bmt <- first(base_message_type)
                line_width <- case_when(
                    bmt %in% c('User') ~ 8,
                    bmt %in% c('CoT') ~ 12,
                    bmt %in% c('Assistant') ~ 16,
                    TRUE ~ 16
                )

                # Build up to 3 lines by character length
                wrapped <- wrap_tokens(token, line_width = line_width, max_lines = 3)

                # Escape Markdown/HTML inside the token content, join lines
                # with <br>, and append an ellipsis if text was cut off
                label <- paste0(map_chr(wrapped$lines, escape_md_html), collapse = '<br>')
                if (isTRUE(wrapped$truncated)) paste0(label, '…') else label
            },
            .groups = 'drop'
        ) %>%
        # Look colours up by name: indexing a named vector with a factor would
        # use the factor's integer codes, not its labels
        mutate(start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))

    this_p <-
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob, color = base_message_type)) +
        facet_grid(rows = vars(role_space), switch = 'y') +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, .5, 1)
        ) +
        # One x break per segment start, labelled with that segment's text
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$prompt_key),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            legend.position = 'top',
            axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(2.0, "lines"),
            axis.text.x = ggtext::element_markdown(size = 10, hjust = 0, angle = 0)
        )

    print(this_p)
})


In [None]:
# Inspect the available columns
colnames(raw_df)

In [None]:
# Same plot, without systemness
plot_df <-
    raw_df %>%
    filter(base_message_type != 'system') %>%
    filter(role_space != 'system') %>%
    mutate(
        role_space = fct_relevel(role_space, 'user', 'assistant-cot', 'assistant-final'),
        role_space = recode(role_space,
            'user' = 'Userness',
            'assistant-cot' = 'CoTness',
            'assistant-final' = 'Assistantness'
        ),
        base_message_type = fct_relevel(base_message_type, 'user', 'cot', 'assistant-final'),
        base_message_type = recode(base_message_type,
            'user' = 'User',
            'cot' = 'CoT',
            'assistant-final' = 'Assistant'
        )
    ) %>%
    # Keep at most the first 240 tokens of each contiguous message-type run
    filter(token_in_bmt_ix <= 240) %>%
    # Re-index tokens densely (1..n) per prompt_key x role_space; ordering by
    # prompt_ix first keeps tokens from different prompts under one key apart
    group_by(prompt_key, role_space) %>%
    arrange(prompt_ix, token_in_prompt_ix, .by_group = TRUE) %>%
    mutate(token_in_prompt_ix = row_number()) %>%
    ungroup()

point_colors = c(
  'User' = '#74d4ff', # sky
  'CoT' = '#f4a8ff', # fuchsia
  'Assistant' = '#a4f4cf'  # emerald
)

text_colors = c(
  'User' = '#00bcff', # sky
  'CoT' = '#ed6bff', # fuchsia
  'Assistant' = '#00d492'  # emerald
)

plots <- map(group_split(plot_df, prompt_key), function(this_plot_df) {

    # Seg text + index for the start of each consecutive base_message_type segment
    segs <-
        this_plot_df %>%
        arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
        # Collapse out role_space (every role space carries the same tokens)
        group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
        summarize(token = first(token), .groups = 'drop') %>%
        group_by(prompt_ix) %>%
        mutate(rleid = consecutive_id(base_message_type)) %>%
        group_by(prompt_ix, rleid, base_message_type) %>%
        arrange(token_in_prompt_ix, .by_group = TRUE) %>%
        summarize(
            start_ix = first(token_in_prompt_ix),
            # Label for the contiguous segment: its first tokens, wrapped
            start_str = {
                bmt <- first(base_message_type)
                line_width <- case_when(
                    bmt %in% c('User') ~ 14,
                    bmt %in% c('CoT') ~ 20,
                    bmt %in% c('Assistant') ~ 28,
                    TRUE ~ 16
                )

                # Build up to 3 lines by character length
                wrapped <- wrap_tokens(token, line_width = line_width, max_lines = 3)

                # Escape Markdown/HTML inside the token content, join lines
                # with <br>, and append an ellipsis if text was cut off
                label <- paste0(map_chr(wrapped$lines, escape_md_html), collapse = '<br>')
                if (isTRUE(wrapped$truncated)) paste0(label, '…') else label
            },
            .groups = 'drop'
        ) %>%
        # Look colours up by name: indexing a named vector with a factor would
        # use the factor's integer codes, not its labels
        mutate(start_str = str_glue('<span style="color:{text_colors[as.character(base_message_type)]}">{start_str}</span>'))

    this_p <-
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob, color = base_message_type)) +
        facet_grid(rows = vars(role_space), switch = 'y') +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, .5, 1)
        ) +
        # One x break per segment start, labelled with that segment's text
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$prompt_key),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            legend.position = 'top',
            axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(2.0, "lines"),
            axis.text.x = ggtext::element_markdown(size = 10, hjust = 0, angle = 0)
        )

    print(this_p)
})


In [None]:
# CoTness only
