In [None]:
# Plot test role probes

In [None]:
library(tidyverse)
library(fs)
# library(ggtext)
library(systemfonts)

# Workspace root and the model tag that selects which export file is loaded.
ws <- '/workspace/deliberative-alignment-jailbreaks'
model_prefix <- 'gptoss20'

# Load the shared plotting helpers kept under the workspace
# (presumably the origin of theme_iclr(), which this notebook calls -- confirm).
source(paste0(ws, '/r-utils/plots.r'))

In [None]:
# Install arrow only when it is missing, so re-running the notebook does not
# reinstall the package every time.
if (!requireNamespace("arrow", quietly = TRUE)) {
    install.packages("arrow")
}

Installing package into ‘/usr/local/lib/R/site-library’
(as ‘lib’ is unspecified)



# Load data

In [None]:
# Read the exported role projections for the selected model.
# trim_ws = FALSE keeps leading/trailing whitespace in the parsed fields;
# the `prompt` column is dropped right after loading.
raw_df <-
    file.path(ws, str_glue('experiments/da-role-analysis/exports/test-role-projections-{model_prefix}.csv')) %>%
    read_csv(trim_ws = FALSE) %>%
    select(-prompt)

# Peek at the first row to check the columns.
raw_df %>% head(1)

# Plots

In [None]:
# Build the per-token plotting frame
options(repr.plot.width = 10, repr.plot.height = 6)

plot_df <-
    raw_df %>%
    group_by(prompt_key, prompt_ix, role_space) %>%
    mutate(
        # Trailing 2-token mean of prob; the first position is padded with NA.
        # NOTE(review): `partial` is not a documented rollmean() argument and
        # appears to be ignored here -- confirm before relying on it.
        prob_sma = zoo::rollmean(prob, k = 2, fill = NA, align = 'right', partial = TRUE),
        # Expanding-window weighted mean: window i covers rows 1..i of the
        # group, and the weight halves for every step back from the current row.
        prob_ewma = zoo::rollapply(
            prob, width = seq_along(prob),
            FUN = function(x) {
                weights <- .5^(seq(length(x) - 1, 0))
                sum(x * weights) / sum(weights)
            },
            align = 'right'
        )
    ) %>%
    ungroup() %>%
    # Order the role spaces, then switch to their display labels.
    mutate(role_space = fct_relevel(role_space, 'user', 'assistant-cot', 'assistant-final')) %>%
    mutate(role_space = recode(role_space,
        'system' = 'Systemness',
        'user' = 'Userness',
        'assistant-cot' = 'CoTness',
        'assistant-final' = 'Assistantness'
    )) %>%
    # Same for the type of message each token belongs to.
    mutate(base_message_type = fct_relevel(base_message_type, 'user', 'cot', 'assistant-final')) %>%
    mutate(base_message_type = recode(base_message_type,
        'system' = 'System',
        'user' = 'User',
        'cot' = 'CoT',
        'assistant-final' = 'Assistant'
    )) %>%
    # System tokens and the system role space are not plotted.
    filter(base_message_type != 'System') %>%
    filter(role_space != 'Systemness') %>%
    # Re-number the remaining tokens 1..n within each prompt_key x role_space.
    # row_number() is used instead of 1:n(), which misbehaves when n() is 0.
    group_by(prompt_key, role_space) %>%
    arrange(token_in_prompt_ix, .by_group = TRUE) %>%
    mutate(token_in_prompt_ix = row_number()) %>%
    ungroup()

# Marker colours, keyed by the display label of base_message_type.
point_colors <- c(
    # System = "#90a1b9",  # slate
    User      = "#74d4ff",  # sky
    CoT       = "#f4a8ff",  # fuchsia
    Assistant = "#a4f4cf"   # emerald
)

# Darker shades of the same hues, used for the coloured tick-label text.
text_colors <- c(
    # System = "#62748e",  # slate
    User      = "#00bcff",  # sky
    CoT       = "#ed6bff",  # fuchsia
    Assistant = "#00d492"   # emerald
)


# Draw one figure per prompt_key: prob for every token, one facet row per role
# space, with x-axis ticks at the first token of each run of same-type tokens.
# Each figure is printed as it is built; `plots` keeps the ggplot objects.
plots <- map(group_split(plot_df, prompt_key), function(this_plot_df) {

    # Tokens are arbitrary text but the tick labels are rendered as
    # markdown/HTML, so escape the characters that would be read as markup.
    escape_html <- function(x) {
        x <- gsub('&', '&amp;', x, fixed = TRUE)
        x <- gsub('<', '&lt;', x, fixed = TRUE)
        gsub('>', '&gt;', x, fixed = TRUE)
    }

    # Seg text + index for the start of each consecutive base_message_type segment
    segs <-
        this_plot_df %>%
        arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
        # Collapse out role_space: keep exactly one row per token position
        group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
        summarize(token = first(token), .groups = 'drop') %>%
        group_by(prompt_ix) %>%
        mutate(rleid = consecutive_id(base_message_type)) %>%
        group_by(prompt_ix, rleid, base_message_type) %>%
        arrange(token_in_prompt_ix, .by_group = TRUE) %>%
        summarize(
            start_ix = first(token_in_prompt_ix),
            # Label = the first k tokens of the run: 1 for User, 3 for CoT,
            # 6 otherwise. head() already caps k at the run length.
            start_str = {
                k <- switch(
                    as.character(first(base_message_type)),
                    'User' = 1,
                    'CoT' = 3,
                    6
                )
                paste0(head(token, k), collapse = "")
            },
            .groups = 'drop'
        ) %>%
        # base_message_type is a factor: index the colour vector by label, not
        # by the factor's integer codes, so the colour never depends on the
        # order of the factor levels.
        mutate(start_str = paste0(
            "<span style='color:", text_colors[as.character(base_message_type)], "'>",
            escape_html(start_str), "...</span>"
        ))

    this_p <-
        this_plot_df %>%
        ggplot() +
        geom_point(aes(x = token_in_prompt_ix, y = prob, color = base_message_type)) +
        facet_grid(rows = vars(role_space), switch = 'y') +
        scale_y_continuous(
            labels = scales::percent_format(accuracy = 1),
            limits = c(0, 1),
            expand = expansion(mult = c(0.02, 0.02)),
            breaks = c(0, .5, 1)
        ) +
        # One tick per run start, labelled with the coloured token preview
        scale_x_continuous(
            breaks = segs$start_ix,
            labels = segs$start_str,
            expand = expansion(mult = c(0, 0))
        ) +
        scale_color_manual(values = point_colors, drop = FALSE) +
        labs(
            title = unique(this_plot_df$prompt_key),
            x = NULL,
            y = NULL,
            color = 'Token Style'
        ) +
        coord_cartesian(clip = "off") +
        theme_iclr(base_size = 11) +
        theme(
            legend.position = 'top',
            axis.title.y = ggtext::element_markdown(angle = 90, vjust = 0.5, margin = margin(r = 6)),
            strip.placement = 'outside',
            strip.text.y.left = element_text(angle = 90, face = 'bold'),
            panel.spacing.y = unit(2.0, "lines"),
            # element_markdown renders the <span> colour markup in the labels
            axis.text.x = ggtext::element_markdown(size = 10, hjust = 0, angle = 0)
        )

    # print() draws the figure and returns the plot, which map() collects
    print(this_p)
})


In [None]:
# Show the list of per-prompt_key data frames that the plotting step maps over.
plot_df %>% group_split(prompt_key)

In [None]:
# Re-apply the per-(prompt_key, role_space) token numbering to inspect it.
# row_number() replaces 1:n(), and TRUE replaces the reassignable T.
plot_df %>%
    group_by(prompt_key, role_space) %>%
    arrange(token_in_prompt_ix, .by_group = TRUE) %>%
    mutate(token_in_prompt_ix = row_number())

In [None]:
# Scratch cell: group_by() is called without any grouping columns.
group_by(plot_df)

In [None]:
# Split plot_df into one data frame per prompt_key and display the list.
plot_df %>% group_split(prompt_key)

In [None]:
# Take the first prompt_key's rows as a working sample for the cells below.
this_plot_df <- group_split(plot_df, prompt_key)[[1]]

In [None]:
# Display the single-prompt_key sample frame.
this_plot_df

In [None]:
# Scratch version of the tick-label table: one row per run of consecutive
# tokens that share a base_message_type, labelled with the run's first tokens.
segs <-
    this_plot_df %>%
    arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
    # Collapse out role_space: keep exactly one row per token position
    group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
    summarize(token = first(token), .groups = 'drop') %>%
    group_by(prompt_ix) %>%
    mutate(rleid = consecutive_id(base_message_type)) %>%
    group_by(prompt_ix, rleid, base_message_type) %>%
    summarize(
        start_ix = first(token_in_prompt_ix),
        start_str = paste0(head(token, 2), collapse = ''),
        .groups = 'drop'
    ) %>%
    mutate(start_str = paste0(start_str, '...'))

# Color the tick labels to match base_message_type (uses a discrete hue palette)
bmt_levels <- if (is.factor(this_plot_df$base_message_type)) {
    levels(droplevels(this_plot_df$base_message_type))
} else {
    unique(this_plot_df$base_message_type)
}
bmt_cols <- set_names(scales::hue_pal()(length(bmt_levels)), bmt_levels)
segs <-
    segs %>%
    # Look colours up by label: indexing with the factor itself would use its
    # integer codes, which no longer line up with bmt_cols once droplevels()
    # has removed an unused level.
    mutate(start_str = paste0(
        "<span style='color:", bmt_cols[as.character(base_message_type)], "'><b>",
        start_str, "</b></span>"
    ))


segs

In [None]:
# Inspect the intermediate table: one row per token position, with a run id
# that increments whenever base_message_type changes within a prompt_ix.
this_plot_df %>%
    arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
    # Collapse out role_space
    group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
    summarize(token = first(token), .groups = 'drop') %>%
    group_by(prompt_ix) %>%
    mutate(rleid = consecutive_id(base_message_type))

In [None]:
# Show every row recorded for token position 57 in the sample prompt.
this_plot_df %>%
    arrange(prompt_key, prompt_ix, token_in_prompt_ix) %>%
    # Same grouping as the role_space collapse, kept so the output stays grouped
    group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
    filter(token_in_prompt_ix == 57)

In [None]:
# Number the distinct token positions of the sample prompt.
# NOTE(review): this cell read `this_prompt_df`, which is never defined in this
# notebook; `this_plot_df` looks like the intended frame -- confirm.
this_plot_df %>%
    distinct(token_in_prompt_ix) %>%
    mutate(token_ix = row_number())

In [None]:
# Display the tick-label table built for the sample prompt.
segs

In [None]:
# Reset the working sample to the first prompt_key's rows.
this_plot_df <- group_split(plot_df, prompt_key)[[1]]

In [None]:
# Scratch: one summary row per token position of the sample prompt.
this_plot_df %>%
    # `prompt` is already removed when raw_df is loaded, so a plain
    # select(-prompt) fails here; any_of() drops the column only if present.
    select(-any_of('prompt')) %>%
    arrange(prompt_ix, token_in_prompt_ix) %>%
    group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
    summarize(token = first(token), .groups = 'drop') %>%
    group_by(prompt_ix, token_in_prompt_ix) %>%
    arrange(token_in_prompt_ix, .by_group = TRUE) %>%
    summarize(
        start_ix = first(token_in_prompt_ix),
        # Groups are single token positions here (not runs), so this pastes at
        # most two rows that share a position; token text is kept as-is.
        label_raw = paste0(head(token, 2), collapse = ""),
        .groups = "drop"
    )



In [None]:
# Scratch: label each run of same-type tokens with its first five tokens.
this_plot_df %>%
    # `prompt` is already removed when raw_df is loaded, so a plain
    # select(-prompt) fails here; any_of() drops the column only if present.
    select(-any_of('prompt')) %>%
    arrange(prompt_ix, token_in_prompt_ix) %>%
    # Collapse out role_space: keep exactly one row per token position
    group_by(prompt_ix, token_in_prompt_ix, base_message_type) %>%
    summarize(token = first(token), .groups = 'drop') %>%
    group_by(prompt_ix) %>%
    mutate(rleid = consecutive_id(base_message_type)) %>%
    group_by(prompt_ix, rleid, base_message_type) %>%
    summarize(
        start_ix = first(token_in_prompt_ix),
        start_str = paste0(head(token, 5), collapse = ''),
        .groups = 'drop'
    ) %>%
    mutate(start_str = paste0(start_str, '...'))


In [None]:
# Set the notebook figure size options to 10 x 8.
options(
    repr.plot.width = 10,
    repr.plot.height = 8
)

In [None]:
# Display the full plotting frame.
plot_df