In [None]:
# Plot test role probes

In [None]:
library(tidyverse)
library(fs)
library(ggtext)
library(systemfonts)
library(arrow)
library(patchwork)

# Project root: every input and output path in this notebook is built from it.
ws = '/workspace/deliberative-alignment-jailbreaks'
# NOTE(review): model_prefix is not referenced in the cells visible here.
model_prefix = 'gptoss-20b'

# Shared plotting helpers (presumably where theme_iclr() comes from).
plot_utils_path = file.path(ws, 'r-utils', 'plots.r')
source(plot_utils_path)

# Load data

In [None]:
# Role-probe outputs for the two system-prompt placements:
#   t100 -> system prompt at token position 100 (rolespace.csv)
#   t1   -> system prompt at the start of the prompt (rolespace-early.csv)
raw_df_inject = str_glue("{ws}/experiments/position-analysis/outputs/rolespace.csv") %>% read_csv()
raw_df_early = str_glue("{ws}/experiments/position-analysis/outputs/rolespace-early.csv") %>% read_csv()

# Stack both runs into one long table, tagging each row with its condition.
raw_df =
    list(
        mutate(raw_df_inject, condition = 't100'),
        mutate(raw_df_early, condition = 't1')
    ) %>%
    bind_rows()

In [None]:
# Causal, normalised exponentially weighted moving average.
# Element i is  sum_{j <= i} decay^(i - j) * x[j]  /  sum_{j <= i} decay^(i - j).
# The newest value has weight 1 and older values decay geometrically.
# Both sums are obtained in O(n) with recursive filters
# (y[i] = x[i] + decay * y[i - 1]) instead of re-weighting the whole prefix at
# every position, which was O(n^2) per group.
# An NA propagates to every later position, exactly as in the prefix-sum form.
causal_ewma = function(x, decay = 0.75) {
    if (length(x) == 0L) return(numeric(0))
    numerator = stats::filter(x, decay, method = 'recursive')
    denominator = stats::filter(rep(1, length(x)), decay, method = 'recursive')
    as.numeric(numerator / denominator)
}

# One row per prompt x condition x probe role x token, with the probe
# probability smoothed causally along the token axis within each
# prompt/condition/role series.
prompt_x_token_ix_x_role =
    raw_df %>%
    group_by(prompt_ix, condition, target_role) %>%
    # The EWMA is order-dependent, so sort tokens within each series first.
    arrange(token_in_prompt_ix, .by_group = TRUE) %>%
    mutate(prob = causal_ewma(prob, decay = 0.75)) %>%
    ungroup() %>%
    mutate(
        target_role = factor(
            target_role,
            levels = c('system', 'user', 'cot', 'assistant'),
            labels = c('Systemness', 'Userness', 'CoTness', 'Assistantness')
        )
    )

# Mean smoothed probability over prompts, per condition x token index x role.
token_ix_x_role =
    prompt_x_token_ix_x_role %>%
    group_by(condition, token_in_prompt_ix, target_role) %>%
    summarize(prob = mean(prob), .groups = 'drop')

token_ix_x_role

# Plot smoothed role-probe scores by token position

In [None]:
# Role colour palettes keyed by role name. `point_colors` colours the probe
# lines below; `text_colors` holds darker shades of the same hues.
point_colors = c(
  'System' = '#90a1b9', # slate
  'User' = '#00a6f4', # sky
  'CoT' = '#fd9a00', # amber
  'Assistant' = '#00d492'  # emerald
)

text_colors = c(
  'System' = '#62748e', # slate
  'User' = '#0084d1', # sky
  'CoT' = '#e17100', # amber
  'Assistant' = '#009966'  # emerald
)

# Mean smoothed probe score for each role across token positions in the t100
# condition, with one facet per role. The shaded band marks token indices
# 100-161, the system-tagged span in this condition.
plot =
    token_ix_x_role %>%
    filter(., condition == 't100') %>%
    ggplot() +
    # Draw the highlight first so it sits beneath the lines instead of tinting them.
    annotate('rect', xmin = 100, xmax = 161, ymin = -Inf, ymax = Inf, fill = 'red', alpha = 0.2) +
    geom_line(aes(x = token_in_prompt_ix, y = prob, color = target_role)) +
    # The palette is keyed 'System', 'User', ... while the factor labels are
    # 'Systemness', 'Userness', ... in the same order, so re-key positionally.
    scale_color_manual(values = setNames(unname(point_colors), levels(token_ix_x_role$target_role))) +
    facet_wrap(vars(target_role), ncol = 1) +
    scale_x_continuous(
        expand = expansion(mult = c(0, 0.01))
    ) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.02, 0.02)),
        breaks = c(0, .25, .5, .75, 1)
    ) +
    labs(x = 'Token Index', y = NULL) +
    theme_iclr(base_size = 9) +
    theme(
        plot.title = ggtext::element_markdown(size = 10.5),
        legend.position = 'none',
        axis.title.y = element_blank(),
        axis.text.x = ggtext::element_markdown(size = 8.5, hjust = 0, angle = 0, margin = margin(t = 2)),
        strip.text = element_text(face = 'bold', size = 7.5, margin = margin(b = 2)),
        panel.spacing.y = unit(0.5, 'lines')
    )


ggsave(
    str_glue('{ws}/experiments/position-analysis/outputs/plots/position-plots.pdf'),
    plot = plot, width = 6.75, height = 2.8, units = 'in', dpi = 300, device = cairo_pdf
)

ggsave(
    str_glue('{ws}/experiments/position-analysis/outputs/plots/position-plots.png'),
    plot = plot,  width = 6.75, height = 2.8, units = 'in', dpi = 300
)

plot

In [None]:
# Percentile-bootstrap bands (resampling whole prompts) for the mean smoothed
# Systemness score at each token index, computed separately per condition.

token_x_systemness_t100 = prompt_x_token_ix_x_role %>% filter(., target_role == 'Systemness', condition == 't100')
token_x_systemness_t1 = prompt_x_token_ix_x_role %>% filter(., target_role == 'Systemness', condition == 't1')

# Bootstrap the per-token mean Systemness over prompts.
#   token_df                  rows with prompt_ix, token_in_prompt_ix, prob (one condition)
#   region_start, region_end  inclusive token-index bounds of the system-tagged span
#   n_boot                    number of bootstrap replicates
#   probs                     lower/upper quantiles of the replicate means forming the band
# Returns one row per token index with: n_prompts (distinct prompts reaching that
# index), n_boot (replicates containing it), systemness_mean / _bot / _top, and
# region ('system_tagged' inside [region_start, region_end], else 'other').
bootstrap_systemness = function(token_df, region_start, region_end, n_boot = 200, probs = c(0.025, 0.975)) {
    unique_prompts = distinct(token_df, prompt_ix)

    replicate_means = map(seq_len(n_boot), function(b) {
        unique_prompts %>%
            slice_sample(n = nrow(unique_prompts), replace = TRUE) %>%
            # A prompt drawn k times contributes with weight k; this equals the
            # plain mean over the duplicated rows, without a many-to-many join.
            count(prompt_ix, name = 'n_draws') %>%
            inner_join(token_df, by = 'prompt_ix', relationship = 'one-to-many') %>%
            group_by(token_in_prompt_ix) %>%
            summarize(systemness = weighted.mean(prob, n_draws), .groups = 'drop') %>%
            mutate(b = b)
    }) %>%
        list_rbind()

    # Real number of prompts per token. n() over replicate_means would count
    # bootstrap replicates instead.
    prompts_per_token =
        token_df %>%
        group_by(token_in_prompt_ix) %>%
        summarize(n_prompts = n_distinct(prompt_ix), .groups = 'drop')

    replicate_means %>%
        group_by(token_in_prompt_ix) %>%
        summarize(
            n_boot = n(),
            systemness_mean = mean(systemness),
            systemness_bot = quantile(systemness, probs[1]),
            systemness_top = quantile(systemness, probs[2]),
            .groups = 'drop'
        ) %>%
        left_join(prompts_per_token, by = 'token_in_prompt_ix') %>%
        relocate(n_prompts, .after = token_in_prompt_ix) %>%
        mutate(region = if_else(between(token_in_prompt_ix, region_start, region_end), 'system_tagged', 'other'))
}

# Fix the RNG so the bootstrap bands are reproducible between notebook runs.
set.seed(1)

plot_df_t100 = bootstrap_systemness(token_x_systemness_t100, region_start = 100, region_end = 161)
plot_df_t1 = bootstrap_systemness(token_x_systemness_t1, region_start = 1, region_end = 61)

plot_df = bind_rows(
    plot_df_t100 %>% mutate(condition = 't100'),
    plot_df_t1 %>% mutate(condition = 't1')
)

plot_df

In [None]:
# Single-panel Systemness curve for the t100 condition.
# - Grey band: bootstrap interval everywhere.
# - Purple band and line: the system-tagged span (token indices 100-161).
# Assigned to `systemness_plot` so the (optional) ggsave calls below can use it.
systemness_plot =
    plot_df_t100 %>%
    ggplot(aes(x = token_in_prompt_ix)) +
    geom_ribbon(aes(ymin = systemness_bot, ymax = systemness_top), fill = 'gray70', alpha = 0.3) +
    geom_ribbon(aes(ymin = systemness_bot, ymax = systemness_top), data = filter(plot_df_t100, region == "system_tagged"), fill = 'mediumpurple', alpha = 0.3) +
    geom_line(aes(y = systemness_mean, color = region, group = 1)) +
    scale_color_manual(values = c("other" = "gray40", "system_tagged" = "mediumpurple")) +
    geom_vline(xintercept = c(100, 161), linetype = 'dashed', color = 'mediumpurple', alpha = 0.5) +
    annotate("text", x = 130, y = 0.5,  label = "System-tagged\nregion",  size = 2.5, color = "mediumpurple", fontface = "bold") +
    scale_x_continuous(
        expand = expansion(mult = c(0, 0))
    ) +
    scale_y_continuous(
        labels = scales::percent_format(accuracy = 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.02, 0.02)),
        breaks = c(0, .5, 1)
    ) +
    labs(x = 'Token Index', y = 'Systemness') +
    theme_iclr(base_size = 9) +
    theme(
        plot.title = element_blank(),
        legend.position = 'none',
        axis.text.x = ggtext::element_markdown(size = 6.5, hjust = 0.5, angle = 0, margin = margin(t = 4)),
        axis.title.x = ggtext::element_markdown(size = 7.5, angle = 0, margin = margin(t = 4)),
        axis.title.y = ggtext::element_markdown(size = 7.5, angle = 90, margin = margin(r = 4))
    )

systemness_plot

# ggsave(
#     str_glue('{ws}/experiments/position-analysis/outputs/plots/systemness.pdf'),
#     plot = systemness_plot, width = 3.75, height = 1.5, units = 'in', dpi = 300, device = cairo_pdf
# )

# ggsave(
#     str_glue('{ws}/experiments/position-analysis/outputs/plots/systemness.png'),
#     plot = systemness_plot, width = 3.75, height = 1.5, units = 'in', dpi = 300
# )


In [None]:
# Side-by-side Systemness curves for both system-prompt placements. Each panel
# shows the bootstrap band (purple inside the system-tagged span) and dashed
# boundaries of that span, and labels the span.

# Dashed span boundaries, per condition
vline_df = tribble(
    ~condition, ~xintercept,
    't1',       1,
    't1',       61,
    't100',     100,
    't100',     161
)

# Span labels, per condition
label_df = tribble(
    ~condition, ~x,  ~y,  ~label,
    't1',       31,  0.8, "<system>-tag\nregion",
    't100',     130, 0.5, "<system>-tag\nregion"
)

# Facet strip titles
condition_labels = c(
    't1' = 'System prompt at start',
    't100' = 'System prompt at position 100'
)

comparison_plot =
    ggplot(plot_df, aes(x = token_in_prompt_ix)) +
    # Grey band everywhere, then a purple overlay restricted to the tagged span
    geom_ribbon(aes(ymin = systemness_bot, ymax = systemness_top), fill = 'gray70', alpha = 0.3) +
    geom_ribbon(
        data = filter(plot_df, region == "system_tagged"),
        mapping = aes(ymin = systemness_bot, ymax = systemness_top),
        alpha = 0.3,
        fill = 'mediumpurple'
    ) +
    # Single path per panel; segment colour switches with the region label
    geom_line(aes(y = systemness_mean, color = region, group = 1)) +
    scale_color_manual(values = c("other" = "gray40", "system_tagged" = "mediumpurple")) +
    geom_vline(
        data = vline_df,
        mapping = aes(xintercept = xintercept),
        color = 'mediumpurple', linetype = 'dashed', alpha = 0.5
    ) +
    geom_text(
        data = label_df,
        mapping = aes(x = x, y = y, label = label),
        color = "mediumpurple", fontface = "bold", size = 2.2
    ) +
    facet_wrap(~condition, ncol = 2, labeller = labeller(condition = condition_labels)) +
    scale_x_continuous(expand = expansion(mult = c(0, 0))) +
    scale_y_continuous(
        breaks = c(0, .5, 1),
        limits = c(0, 1),
        expand = expansion(mult = c(0.02, 0.02)),
        labels = scales::percent_format(accuracy = 1)
    ) +
    labs(x = 'Token Index', y = 'Systemness') +
    theme_iclr(base_size = 9) +
    theme(
        plot.title = element_blank(),
        legend.position = 'none',
        strip.text = element_text(face = 'bold', size = 9),
        axis.title.x = ggtext::element_markdown(size = 8, angle = 0, margin = margin(t = 4)),
        axis.title.y = ggtext::element_markdown(size = 8, angle = 90, margin = margin(r = 4)),
        axis.text.x = ggtext::element_markdown(size = 7.5, hjust = 0.5, angle = 0, margin = margin(t = 4))
    )

# Vector (Cairo PDF) and raster (PNG) exports at identical size
ggsave(
    str_glue('{ws}/experiments/position-analysis/outputs/plots/systemness-comparison.pdf'),
    plot = comparison_plot,
    width = 5.5, height = 2.0, units = 'in',
    dpi = 300,
    device = cairo_pdf
)

ggsave(
    str_glue('{ws}/experiments/position-analysis/outputs/plots/systemness-comparison.png'),
    plot = comparison_plot,
    width = 5.5, height = 2.0, units = 'in',
    dpi = 300
)
