# Load libraries

In [None]:
library(ggplot2)
library(tidyverse)
library(gfile)

# Load data

In [None]:
# Name of the causal-tracing sweep to load. Its CSV is read from
# "<DATA_DIR>/<experiment_name>.csv" (with the default empty DATA_DIR this
# resolves to "/<experiment_name>.csv").
experiment_name <- "gpt-j-6B_counterfact_k0_sd1_tracing_sweep_n2000"
DATA_DIR <- "" # your data directory here
data_path <- paste0(DATA_DIR, "/", experiment_name, ".csv")

# Raw tracing results, one row per CSV record; show a preview and the columns.
orig_data <- read_csv(data_path)
head(orig_data)
names(orig_data)

# Globals

In [None]:
# Shared ggplot2 theme added to every figure in this notebook.
# NOTE: this object is named `theme`, which shadows ggplot2::theme only as a
# *value*. Later calls such as `theme(axis.title.x = ...)` still reach the
# ggplot2 function, because R skips non-function bindings when resolving a
# function call. The RHS names ggplot2::theme explicitly for clarity.
theme <- ggplot2::theme(
  axis.ticks = element_blank(),
  axis.text = element_text(size = 15, color = "black"),
  axis.line.x = element_line(colour = "black", size = .6),
  axis.line.y = element_line(colour = "black", size = .6),
  panel.background = element_blank(),
  panel.border = element_blank(),
  panel.grid = element_line(colour = "#DFDFDF", size = .2),
  text = element_text(size = 18, family = "serif"),
  axis.title.x = element_text(size = 18),
  axis.title.y = element_text(size = 18),
  plot.title = element_text(size = 20, hjust = 0.5),
  legend.text = element_text(size = 16),
  legend.box.background = element_blank(),
  legend.position = "right",
  panel.spacing = unit(1.5, "lines")
)

# Eight-colour palette (hex codes) available for discrete scales.
cbp1 <- c(
  "#E69F00", "#56B4E9", "#009E73",
  "#0072B2", "#D55E00", "#999999", "#F0E442", "#CC79A7"
)


# Postprocess tracing data

In [None]:
options(repr.plot.width = 8, repr.plot.height = 6)

# Add per-row derived variables to the raw tracing data.
orig_data <- orig_data %>%
  mutate(
    last_subj_token = (subj_end_idx - 1 == token_idx),
    restore_effect = restore_prob - corrupted_pred_prob,
    module = ifelse(module == "None", "all", module),
    corruption_effect = corrupted_pred_prob - orig_pred_prob,
    # Bounded below by 0. NOTE(review): a row with corruption_effect == 0
    # yields NaN/Inf here -- confirm such rows cannot occur.
    fraction_restored = pmax(0, restore_effect / abs(corruption_effect))
  )

# Per-data-point statistics: one row per (input_id, module, trace_window_size).
per_data_point <- orig_data %>%
  group_by(input_id, module, trace_window_size) %>%
  summarize(
    experiment_name = min(experiment_name),
    task = min(task),
    split = min(split),
    # The next three are constant within a point; min() just picks that value.
    orig_pred_prob = min(orig_pred_prob),
    corrupted_pred_prob = min(corrupted_pred_prob),
    corruption_effect = min(corruption_effect),
    max_effect = max(restore_effect),
    mean_effect = mean(restore_effect),
    seq_len = max(token_idx) + 1,
    max_fraction_restored = max(fraction_restored),
    .groups = "drop"
  )

# Attach per-point statistics to every row. `seq_len` is required for
# last_seq_token below (and by later cells). It is taken from per_data_point
# only when the raw CSV lacks it. Without it, `seq_len` would resolve to
# base::seq_len and `seq_len - 1` would error. Skipping it when present avoids
# seq_len.x / seq_len.y columns.
per_point_cols <- c("input_id", "module", "trace_window_size",
                    "max_effect", "max_fraction_restored")
if (!"seq_len" %in% names(orig_data)) {
  per_point_cols <- c(per_point_cols, "seq_len")
}
data <- left_join(
  orig_data,
  select(per_data_point, all_of(per_point_cols)),
  by = c("input_id", "module", "trace_window_size")
)
data <- data %>%
  mutate(
    last_seq_token = (token_idx == seq_len - 1),
    # TRUE for the row(s) holding the point's largest restoration effect.
    is_max_effect = (restore_effect >= max_effect)
  )

# Turn trace_window_size into an ordered-by-size factor for faceting. Levels
# come from the window sizes present (sorted numerically), so sizes beyond
# 1/3/5/10 no longer become NA.
window_sizes <- sort(unique(data$trace_window_size))
data <- data %>%
  mutate(trace_window_size_str = factor(
    sprintf("Tracing Window Size: %s", trace_window_size),
    levels = sprintf("Tracing Window Size: %s", window_sizes)
  ))


# Final plots 1: tracing effects by window size

In [None]:
options(repr.plot.width = 13, repr.plot.height = 6)

# Tick labels are 1-indexed layer numbers (1, 4, 8, ..., 28) drawn at the
# matching 0-indexed layer_idx positions (xticks - 1).
xticks <- seq(0, 28, by = 4)
xticks[1] <- 1
grid_ymax <- .3
is_correct_filter <- 1
min_orig_prob <- 0
min_restoration_effect <- 0
MODULE <- "mlp"

# Annotation labels use grid::textGrob(). grid::grid.text() also draws onto the
# current device as a side effect. The grid namespace is referenced explicitly
# because library(grid) is never attached in this notebook.
TITLE <- expression("Causal Tracing shows larger effects when multiple layers are denoised")
# TITLE = expression(paste("Earlier layers show the strongest tracing effects ", italic("on average")))
(avg_plot <- data %>%
  filter(module == MODULE) %>%
  filter(orig_pred_prob > min_orig_prob) %>%
  filter(restore_effect > min_restoration_effect) %>%
  filter(is_correct >= is_correct_filter) %>%
  # Strongest effect over token positions per point and layer, then the mean
  # of that over points.
  group_by(input_id, layer_idx, trace_window_size_str) %>%
  summarise(max_effect = max(fraction_restored), .groups = "drop") %>%
  group_by(layer_idx, trace_window_size_str) %>%
  summarise(mean_effect = mean(max_effect), .groups = "drop") %>%
  ggplot(aes(layer_idx, mean_effect)) +
  geom_point() +
  ggtitle(TITLE) +
  xlab("Layer in GPT-J") +
  ylab("Denoising Effect") +
  theme +
  annotate("rect", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = grid_ymax,
           alpha = .7, fill = "#FF0909") +
  annotate("rect", xmin = 3, xmax = 8, ymin = 0, ymax = grid_ymax,
           alpha = .2, fill = "#2009FF") +
  annotation_custom(grid::textGrob(
    "ROME Edit Layer ", x = 0.694, y = 0.89,
    gp = grid::gpar(col = "#FF0909", fontsize = 14, fontfamily = "serif")
  )) +
  annotation_custom(grid::textGrob(
    label = "MEMIT Edit Layers", check.overlap = TRUE, x = 0.71, y = 0.815,
    gp = grid::gpar(col = "#2009FF", fontsize = 14, fontfamily = "serif")
  )) +
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ggplot2::theme(
    axis.title.x = element_text(size = 18, color = "black", angle = 0, vjust = -.5, hjust = .5),
    axis.title.y = element_text(size = 18, color = "black", angle = 90, vjust = 2, hjust = .5),
    plot.title = element_text(size = 20, color = "black", angle = 0, vjust = .5, hjust = .5)
  ) +
  facet_wrap(~trace_window_size_str, nrow = 1)
)

# TITLE = sprintf("Causal Tracing often shows information is localized in mid-to-late layers")
TITLE <- sprintf("Causal Tracing peak distribution shifts outward with lower window size")
# Histogram of the layer at which each point's tracing effect peaks.
(distr_plot <- data %>%
  filter(module == MODULE) %>%
  filter(orig_pred_prob > min_orig_prob) %>%
  filter(restore_effect > min_restoration_effect) %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(is_max_effect == 1) %>%
  ggplot(aes(layer_idx)) +
  geom_histogram(binwidth = 1, size = .1) +
  ggtitle(TITLE) +
  xlab("Layer in GPT-J where Causal Tracing effects peak") +
  ylab("Count") +
  theme +
  annotate("rect", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = 200,
           alpha = .7, fill = "#FF0909") +
  annotate("rect", xmin = 3, xmax = 8, ymin = 0, ymax = 200,
           alpha = .2, fill = "#2009FF") +
  annotation_custom(grid::textGrob(
    "ROME Edit Layer ", x = 0.694, y = 0.89,
    gp = grid::gpar(col = "#FF0909", fontsize = 14, fontfamily = "serif")
  )) +
  annotation_custom(grid::textGrob(
    label = "MEMIT Edit Layers", check.overlap = TRUE, x = 0.71, y = 0.815,
    gp = grid::gpar(col = "#2009FF", fontsize = 14, fontfamily = "serif")
  )) +
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ggplot2::theme(
    axis.title.x = element_text(size = 18, color = "black", angle = 0, vjust = -.5, hjust = .5),
    plot.title = element_text(size = 20, color = "black", angle = 0, vjust = .5, hjust = .5)
  ) +
  facet_wrap(~trace_window_size_str, nrow = 1)
)

options(repr.plot.width = 8, repr.plot.height = 6)
ggsave("avg_plot.pdf", avg_plot, width = 16, height = 4, device = cairo_pdf)
colab::download_file("avg_plot.pdf")
ggsave("distr_plot.pdf", distr_plot, width = 16, height = 4, device = cairo_pdf)
colab::download_file("distr_plot.pdf")

# Final plots 2: Tracing effects with overlay for ROME/MEMIT choices

In [None]:
options(repr.plot.width = 8, repr.plot.height = 6)

TRACE_WINDOW_SIZE <- 5
# 1-indexed tick labels at 0-indexed layer_idx positions.
xticks <- seq(0, 28, by = 4)
xticks[1] <- 1

grid_ymax <- .06
TITLE <- expression(paste("Causal Tracing effects by layer ", italic("on average across data")))
# TITLE = expression(paste("Causal Tracing effects are largest in earlier layers ", italic("on average across data")))
(avg_plot <- data %>%
  filter(module == MODULE) %>%
  filter(orig_pred_prob > min_orig_prob) %>%
  filter(restore_effect > min_restoration_effect) %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%
  # filter to last subj token positions for comparison with ROME design choices
  select(input_id, last_subj_token, layer_idx, trace_window_size_str, restore_effect) %>%
  filter(last_subj_token == TRUE) %>%
  # group_by(input_id, layer_idx, trace_window_size_str) %>%
  # summarise(max_effect = max(restore_effect)) %>%
  group_by(layer_idx, trace_window_size_str) %>%
  # summarise(mean_effect = mean(max_effect)) %>%
  summarise(mean_effect = mean(restore_effect), .groups = "drop") %>%
  ggplot(aes(layer_idx, mean_effect)) +
  geom_point() +
  ggtitle(TITLE) +
  xlab("Layer in GPT-J") +
  ylab("Denoising Effect") +
  annotate("rect", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = grid_ymax,
           alpha = .7, fill = "#FF0909") +
  annotate("rect", xmin = 3, xmax = 8, ymin = 0, ymax = grid_ymax,
           alpha = .2, fill = "#2009FF") +
  # textGrob (not grid.text) so nothing is drawn as a side effect; grid is not attached.
  annotation_custom(grid::textGrob(
    "ROME Edit Layer ", x = 0.802, y = 0.88,
    gp = grid::gpar(col = "#FF0909", fontsize = 16, fontfamily = "serif")
  )) +
  annotation_custom(grid::textGrob(
    label = "MEMIT Edit Layers", check.overlap = TRUE, x = 0.81, y = 0.82,
    gp = grid::gpar(col = "#2009FF", fontsize = 16, fontfamily = "serif")
  )) +
  theme +
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ggplot2::theme(
    axis.title.x = element_text(size = 18, color = "black", angle = 0, vjust = -.5, hjust = .5),
    plot.title = element_text(size = 20, color = "black", angle = 0, vjust = .5, hjust = .5)
  )
)

TITLE <- "How often does Causal Tracing peak in each layer?"
# TITLE = sprintf("Peak Causal Tracing effects often lie outside layers chosen for editing")
YMAX <- 200
(distr_plot <- data %>%
  filter(module == MODULE) %>%
  filter(orig_pred_prob > min_orig_prob) %>%
  filter(restore_effect > min_restoration_effect) %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(is_max_effect) %>%
  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%
  ggplot(aes(layer_idx)) +
  geom_histogram(binwidth = 1, size = .1) +
  ggtitle(TITLE) +
  xlab("Layer in GPT-J where Causal Tracing effects peak") +
  ylab("Num. Points") +
  annotate("rect", xmin = 4.9, xmax = 5.1, ymin = 0, ymax = YMAX,
           alpha = 1, fill = "#FF0909") +
  annotate("rect", xmin = 3, xmax = 8, ymin = 0, ymax = YMAX,
           alpha = .2, fill = "#2009FF") +
  annotation_custom(grid::textGrob(
    "ROME Edit Layer ", x = 0.788, y = 0.876,
    gp = grid::gpar(col = "#FF0909", fontsize = 20, fontfamily = "serif")
  )) +
  annotation_custom(grid::textGrob(
    label = "MEMIT Edit Layers", check.overlap = TRUE, x = 0.801, y = 0.810,
    gp = grid::gpar(col = "#2009FF", fontsize = 20, fontfamily = "serif")
  )) +
  theme +
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ggplot2::theme(
    axis.text = element_text(size = 20, color = "black"),
    axis.title.x = element_text(size = 24, color = "black", angle = 0, vjust = -.5, hjust = .5),
    axis.title.y = element_text(size = 24, color = "black", angle = 90, vjust = .5, hjust = .5),
    plot.title = element_text(size = 25, color = "black", angle = 0, vjust = .5, hjust = .5)
  )
)

# Count of peak-effect rows whose peak layer lies inside 0-indexed layers 3-8
# (1-indexed 4-9, the shaded MEMIT range) vs. outside, for window size
# TRACE_WINDOW_SIZE.
peak_layers <- data %>%
  filter(module == MODULE) %>%
  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%
  filter(orig_pred_prob > min_orig_prob) %>%
  filter(restore_effect > min_restoration_effect) %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(is_max_effect == 1) %>%
  pull(layer_idx)
table(peak_layers %in% 3:8)

ggsave("avg_plot.pdf", avg_plot, width = 8, height = 6, device = cairo_pdf)
colab::download_file("avg_plot.pdf")
ggsave("distr_plot.pdf", distr_plot, width = 8, height = 6, device = cairo_pdf)
colab::download_file("distr_plot.pdf")

## Join tracing and editing results

In [None]:
# Tracing rows for the MLP module, keyed by `case_id` (copied from input_id)
# so they can be joined with the editing results. The original
# `mutate(..., module == 'mlp')` only added a junk logical column; a filter was
# intended, and every downstream use restricts to module == "mlp" anyway.
tracing_data <- data %>%
  mutate(case_id = input_id) %>%
  filter(module == "mlp")

# set an individual experiment to load

# FT fact forcing
# experiment_name <- "gpt-j-6B_FT_outputs_cf_editing_sweep_ws-_5__layer-all_fact-forcing_n2000"

# ROME error injection
experiment_name <- "gpt-j-6B_ROME_outputs_cf_editing_sweep_ws-_1__layer-all_n2000"

editing_path <- sprintf("%s/%s.csv", DATA_DIR, experiment_name)
editing_data <- read_csv(editing_path)
# Number of distinct edited data points in this experiment.
n_distinct(editing_data$case_id)

TRACE_WINDOW_SIZE <- 5
data <- data %>%
  mutate(case_id = input_id)

# add essence_ppl_diff transformations to editing_data
# Essence perplexity differences are clipped to [0, MAX_ppl_diff] before normalizing.
MAX_ppl_diff <- 5
# Elementwise reciprocal 1/x that maps exact zeros to 2e16 instead of Inf.
# Vectorized with ifelse(): this is applied to whole columns inside mutate(),
# where a scalar `if (x == 0)` errors for length > 1 input (R >= 4.2).
# @param x numeric vector.
# @return numeric vector the same length as x (NA stays NA).
safe_inverse <- function(x) {
  ifelse(x == 0, 2e16, 1 / x)
}
# Add essence-perplexity transformations and aggregate scores to editing_data.
editing_data <- editing_data %>%
  mutate(
    essence_ppl_diff_bounded = pmax(0, pmin(essence_ppl_diff, MAX_ppl_diff)),
    essence_diff_normed = 1 - essence_ppl_diff_bounded / MAX_ppl_diff,
    # Harmonic mean of the four component scores (zeros handled by safe_inverse).
    target_score_v2 = 4 / (safe_inverse(rewrite_score) + safe_inverse(paraphrase_score) +
                             safe_inverse(neighborhood_score) + safe_inverse(essence_diff_normed)),
    target_score_mean = (rewrite_score + paraphrase_score + neighborhood_score) / 3,
    target_score_mean_v2 = (rewrite_score + paraphrase_score + neighborhood_score + essence_diff_normed) / 4
  )

# MAKE FIRST STYLE OF JOINED DATA.
# One record per edit per datapoint, with the tracing variables of the MAX
# effect per point. Joins use `by =`: a `join_by =` argument is not the key
# argument (it errors in dplyr >= 1.1 and silently natural-joins in older dplyr).
combined_data <- left_join(tracing_data, editing_data, by = "case_id")
combined_data_max_MLP <- combined_data %>%
  filter(is_max_effect == TRUE, module == "mlp") %>%
  mutate(layer_discrepancy = edit_central_layer - layer_idx,
         max_tracing_layer = layer_idx)
combined_data_max_MLP <- combined_data_max_MLP %>% select(-layer_idx)
combined_data_max_MLP %>% pull(case_id) %>% n_distinct()

# SECOND STYLE OF JOINED DATA.
# one record per edit per datapoint.

# Max-over-tokens tracing effect per (point, layer).
max_per_layer_per_record <- tracing_data %>%
  filter(module == "mlp", trace_window_size == TRACE_WINDOW_SIZE) %>%
  group_by(case_id, layer_idx) %>%
  summarise(
    max_token_effect = max(restore_effect),
    max_fraction_restored = max(fraction_restored),
    # Constant per point, so mean() just picks the value; kept for later filtering.
    is_correct = mean(is_correct),
    corruption_effect = mean(corruption_effect),
    orig_pred_prob = mean(orig_pred_prob),
    .groups = "drop"
  )
# Tracing effect at the last subject token, per (point, layer).
subj_effect_per_layer_per_record <- tracing_data %>%
  filter(module == "mlp", trace_window_size == TRACE_WINDOW_SIZE, last_subj_token == TRUE) %>%
  transmute(case_id,
            layer_idx,
            subj_effect = restore_effect,
            subj_end_idx,
            seq_len,
            subj_effect_fraction = restore_effect / abs(corruption_effect))
# Match each edit with tracing at its central edit layer.
per_edit_data <- editing_data %>%
  mutate(layer_idx = edit_central_layer) %>%
  left_join(max_per_layer_per_record, by = c("case_id", "layer_idx")) %>%
  left_join(subj_effect_per_layer_per_record, by = c("case_id", "layer_idx"))

# Indicator for a tracing effect above the 95th percentile of MLP tracing effects.
effect_cutoff <- quantile(data %>% filter(module == "mlp") %>% pull(restore_effect), .95)
print(effect_cutoff)
# Add more variables:
# - discretized original prediction probability
# - relative subject position
per_edit_data <- per_edit_data %>%
  mutate(
    large_tracing_effect = max_token_effect > effect_cutoff,
    # Explicit last condition (not TRUE ~) so NA stays NA as with nested ifelse.
    orig_pred_prob_disc = case_when(
      orig_pred_prob < .05 ~ "<.05",
      orig_pred_prob < .1 ~ ".05-.1",
      orig_pred_prob < .15 ~ ".1-.15",
      orig_pred_prob < .2 ~ ".15-.2",
      orig_pred_prob < .25 ~ ".2-.25",
      orig_pred_prob >= .25 ~ ">.25"
    ),
    subj_position = subj_end_idx / seq_len
  )


# FINAL PLOT: score vs. restoration effect by layer

In [None]:
is_correct_filter <- 1
# edit_layers = c(0,12,16,20)
# edit_layers = c(0,4,8,12)
# edit_layers = c(0, 4, 5, 8)
# edit_layers = c(0, 12, 16, 20)
# edit_layers = c(8, 12, 16, 20)
# edit_layers = c(0, 8, 16, 24)
# edit_layers = c(0, 4, 8, 12)
# options(repr.plot.width=13, repr.plot.height=4)
# NROW=1

edit_layers <- c(0, 4, 8, 12, 16, 20, 24, 27)
options(repr.plot.width = 13, repr.plot.height = 8)
NROW <- 2

x_ub <- 1
min_orig_prob <- 0
# Facet labels are 1-indexed ("Layer 1" is layer_idx 0).
layer_levels <- sprintf("Layer %s", edit_layers + 1)

# Name of the tracing column plotted on the x axis.
x <- "max_fraction_restored"
# x = "max_token_effect"
# x = "subj_effect_fraction"
# x = "subj_effect"
point_alpha <- .1
CI_alpha <- .15
CI_fill <- "orange"
show_se <- TRUE
smooth_method <- "lm"
# smooth_method = 'loess'
outcome <- "score"
ovr_name <- "target_score_mean"
line_size <- 1.3

n_unique_points <- per_edit_data %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(orig_pred_prob >= min_orig_prob) %>%
  filter(.data[[x]] < x_ub) %>%
  pull(case_id) %>%
  n_distinct()
sprintf("Plotting with %s points", n_unique_points)

qs <- quantile(per_edit_data$rewrite_score, c(.4, 1))
TITLE <- "Rewrite Score by Tracing Effect (Grouped by Edit Layer)"
# TITLE="Fact Forcing Rewrite Score by Tracing Effect (Grouped by Edit Layer)"
# TITLE="ROME Rewrite Score by Tracing Effect (Error Injection)"
# TITLE="ROME Rewrite Score by Last Subject Token Tracing Effect (Error Injection)"
# Columns are chosen by name with the .data pronoun (aes_string is deprecated).
(rewrite_plot <- per_edit_data %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(orig_pred_prob >= min_orig_prob) %>%
  filter(edit_central_layer %in% edit_layers) %>%
  mutate(edit_central_layer = factor(sprintf("Layer %s", edit_central_layer + 1),
                                     levels = layer_levels)) %>%
  ggplot(aes(.data[[x]], .data[[sprintf("rewrite_%s", outcome)]])) +
  geom_point(alpha = point_alpha) +
  geom_smooth(method = smooth_method, se = show_se, color = "orange", alpha = CI_alpha,
              fill = CI_fill, size = line_size) +
  geom_abline(slope = 1, intercept = 0, color = "#F94343", linetype = 2, size = .8,
              inherit.aes = FALSE) +
  # geom_segment(aes(x=0, xend=1, y=0, yend=1), color='Red', linetype=1, size=.25, inherit.aes = FALSE) +
  ggtitle(TITLE) +
  xlab("Tracing Effect") +
  # xlab("Tracing Effect at Last Subject Token") +
  ylab("Rewrite Score") +
  theme +
  coord_cartesian(xlim = c(0, x_ub), ylim = c(0, 1)) +
  ggplot2::theme(
    axis.title.y = element_text(size = 20, color = "black", angle = 90, vjust = 1.5, hjust = .5),
    strip.text.x = element_text(size = 16, color = "black", angle = 0, vjust = .5, hjust = .5),
    axis.text = element_text(size = 16, color = "black"),
    axis.title.x = element_text(size = 20, color = "black", angle = 0, vjust = 0, hjust = .5),
    plot.title = element_text(size = 22, color = "black", angle = 0, vjust = .5, hjust = .5)
  ) +
  facet_wrap(~edit_central_layer, nrow = NROW))

ggsave("rewrite_plot.pdf", rewrite_plot, width = 13, height = 4 * NROW, dpi = 600, device = cairo_pdf)
colab::download_file("rewrite_plot.pdf")

options(repr.plot.width = 8, repr.plot.height = 6)
TITLE <- "ROME Rewrite Score by Tracing Effect at Layer 6"
show_se <- TRUE
line_size <- 1.4
# Single-layer version: central edit layer_idx 5 (1-indexed layer 6).
(rewrite_plot <- per_edit_data %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(orig_pred_prob >= min_orig_prob) %>%
  filter(edit_central_layer == 5) %>%
  ggplot(aes(.data[[x]], .data[[sprintf("rewrite_%s", outcome)]])) +
  geom_point(alpha = point_alpha) +
  geom_abline(slope = 1, intercept = 0, color = "#F94343", linetype = 2, size = .8) +
  geom_smooth(method = smooth_method, se = show_se, color = "orange", alpha = CI_alpha,
              fill = CI_fill, size = line_size) +
  ggtitle(TITLE) +
  xlab("Tracing Effect (Fraction Restored)") +
  ylab("Rewrite Score") +
  theme +
  coord_cartesian(xlim = c(0, x_ub), ylim = c(0, 1)) +
  ggplot2::theme(
    axis.title.y = element_text(size = 21, color = "black", angle = 90, vjust = 1.5, hjust = .5),
    axis.text = element_text(size = 18, color = "black"),
    axis.title.x = element_text(size = 21, color = "black", angle = 0, vjust = 0, hjust = .5),
    plot.title = element_text(size = 25, color = "black", angle = 0, vjust = .5, hjust = .5)
  )
)
ggsave("rewrite_plot.pdf", rewrite_plot, width = 8, height = 6, device = cairo_pdf)
colab::download_file("rewrite_plot.pdf")

# Correlation of rewrite score with tracing effects at layer_idx 5. Stored in
# `cor_data` so the x-column name in `x` is not overwritten by a data frame.
cor_data <- per_edit_data %>%
  filter(is_correct >= is_correct_filter) %>%
  filter(orig_pred_prob >= min_orig_prob) %>%
  filter(edit_central_layer == 5) %>%
  select(rewrite_score, max_fraction_restored, subj_effect_fraction)

cor.test(cor_data$rewrite_score, cor_data$max_fraction_restored)
cor.test(cor_data$rewrite_score, cor_data$subj_effect_fraction)

# ALL experiments: Tracing vs. editing

In [None]:
# (edit method, edit window size) pairs whose editing sweeps are loaded below;
# both entries are strings because they are spliced into experiment names.
conditions <- list(
  c("FT", "1"),
  c("FT", "5"),
  c("ROME", "1"),
  c("MEMIT", "5")
)
# Editing objectives to load for every condition.
objectives <- c(
  "Falsehood Injection",
  "Tracing Reversal",
  "Fact Erasure",
  "Fact Forcing",
  "Fact Amplification"
)

# Load "<data_dir>/<experiment_name>.csv" via gfile, returning an empty
# data.frame() if the load fails for any reason (missing file, no access, ...).
# The failure is reported with message() instead of being swallowed silently.
# @param experiment_name experiment identifier (file name without ".csv").
# @param data_dir directory holding the CSVs; defaults to the original
#   hard-coded output location.
# @return the loaded tibble, or a zero-row data.frame on failure.
safe_load_data <- function(experiment_name,
                           data_dir = "/cns/mf-d/home/brain-frameworks/cloudsync/belief-localization-xgcp/output") {
  path <- sprintf("%s/%s.csv", data_dir, experiment_name)
  tryCatch(
    read_csv(gfile::GFile(path)),
    error = function(e) {
      message(sprintf("Could not load %s: %s", path, conditionMessage(e)))
      data.frame()
    }
  )
}

# File-name suffix for each editing objective ("" = falsehood injection).
objective_tags <- c(
  "Falsehood Injection" = "",
  "Tracing Reversal" = "_trace-reverse",
  "Fact Erasure" = "_fact-erasure",
  "Fact Forcing" = "_fact-forcing",
  "Fact Amplification" = "_fact-amplification"
)

# Load every (condition, objective) sweep that exists, tagging rows with their
# objective. Results are collected in a list and bound once at the end.
loaded_results <- list()
for (condition in conditions) {
  method <- condition[1]
  window_size <- condition[2]
  for (objective in objectives) {
    experiment_name <- sprintf("gpt-j-6B_%s_outputs_cf_editing_sweep_ws-_%s__layer-all%s_n2000",
                               method, window_size, objective_tags[[objective]])
    print(sprintf("Trying to load %s", experiment_name))
    experiment_data <- safe_load_data(experiment_name)
    if (nrow(experiment_data) == 0) next
    experiment_data$objective <- objective
    print("Loaded!")
    loaded_results[[length(loaded_results) + 1]] <- experiment_data
  }
}
if (length(loaded_results) == 0) {
  stop("No editing results could be loaded.", call. = FALSE)
}
running_editing_data <- bind_rows(loaded_results)

# Keep real edits (edit_central_layer >= 0). `factor()` is needed for explicit
# levels: as.factor() has no `levels` argument and errors if given one.
editing_data <- running_editing_data %>%
  filter(edit_central_layer >= 0) %>%
  select(case_id, rewrite_score, paraphrase_score, neighborhood_score, essence_ppl_diff,
         target_score, edit_method, edit_central_layer, objective, edit_window_size) %>%
  mutate(objective = factor(objective, levels = objectives))

TRACE_WINDOW_SIZE <- 5
tracing_data <- data %>%
  filter(trace_window_size == TRACE_WINDOW_SIZE) %>%
  mutate(case_id = input_id) %>%
  select(case_id, trace_window_size, layer_idx, token_idx, module, orig_pred_prob,
         corrupted_pred_prob, corruption_effect, seq_len, is_correct, last_seq_token,
         is_subj_token, last_subj_token, restore_effect, fraction_restored, max_effect,
         max_fraction_restored, is_max_effect, trace_window_size_str)
print("Tracing # per trace window size")
table(tracing_data$trace_window_size)

data <- data %>%
  mutate(case_id = input_id)

# add essence_ppl_diff transformations to editing_data
# Essence perplexity differences are clipped to [0, MAX_ppl_diff] before normalizing.
MAX_ppl_diff <- 5
# Elementwise reciprocal 1/x that maps exact zeros to 2e16 instead of Inf.
# Vectorized with ifelse(): it is used on whole columns inside mutate(), where
# a scalar `if (x == 0)` errors for length > 1 input (R >= 4.2).
# @param x numeric vector.
# @return numeric vector the same length as x (NA stays NA).
safe_inverse <- function(x) {
  ifelse(x == 0, 2e16, 1 / x)
}
# add cols and is_correct variable
editing_data <- editing_data %>%
  mutate(essence_ppl_diff_bounded = pmax(0, pmin(essence_ppl_diff, MAX_ppl_diff)),
         essence_diff_normed = 1 - essence_ppl_diff_bounded / MAX_ppl_diff,
         target_score_v2 = 4/(safe_inverse(rewrite_score) + safe_inverse(paraphrase_score) + safe_inverse(neighborhood_score) + safe_inverse(essence_diff_normed)),
         target_score_mean = (rewrite_score + paraphrase_score + neighborhood_score) / 3,
         target_score_mean_v2 = (rewrite_score + paraphrase_score + neighborhood_score + essence_diff_normed) / 4,
         ) %>%
  left_join(tracing_data %>% select(case_id, is_correct) %>% unique())

# MAKE FIRST STYLE OF JOINED DATA. 
# one record per edit per datapoint. includes all tracing variables pertaining to the MAX effect per point
combined_data <- left_join(tracing_data, editing_data, join_by='case_id')
combined_data_max_MLP <- combined_data %>%
  filter(is_max_effect == TRUE, module=='mlp') %>%
  mutate(layer_discrepancy = edit_central_layer - layer_idx,
         max_tracing_layer = layer_idx)
combined_data_max_MLP <- combined_data_max_MLP %>% select(-layer_idx)

# SECOND STYLE OF JOINED DATA.
# one record per edit per datapoint.

# add max-token and subj-token effects for each point+layer to editing data
max_per_layer_per_record <- tracing_data %>% 
  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%
  group_by(case_id, layer_idx) %>%
  summarise(
    max_token_effect=max(restore_effect),
    max_fraction_restored=max(fraction_restored),
    # variables below are constant per point, so mean does nothing. need them for later filtering 
    corruption_effect=mean(corruption_effect),
    orig_pred_prob=mean(orig_pred_prob),
    ) 
subj_effect_per_layer_per_record <- tracing_data %>% 
  filter(module=='mlp', trace_window_size==TRACE_WINDOW_SIZE) %>%
  group_by(case_id, layer_idx) %>%
  filter(last_subj_token==TRUE) %>%
  summarise(subj_effect=restore_effect,
            subj_effect_fraction = restore_effect / abs(corruption_effect))
per_edit_data <- editing_data %>%
  mutate(layer_idx = edit_central_layer) %>% # need to add layer_idx for matching tracing by layer_idx
  left_join(
    max_per_layer_per_record,
    join_by=c("case_id", "layer_idx")
  ) %>%
  left_join(
    subj_effect_per_layer_per_record,
    join_by=c("case_id", "layer_idx")
  )
# add indicator for if tracing effect is in the 99th percentile of tracing effects
effect_cutoff <- quantile(data %>% filter(module == 'mlp') %>% pull(restore_effect), .95)
# add discrete restoration effect variable
per_edit_data <- per_edit_data %>%
  mutate(large_tracing_effect = max_token_effect > effect_cutoff,
         orig_pred_prob_disc = ifelse(orig_pred_prob < .05, "<.05", 
                    ifelse(orig_pred_prob < .1, ".05-.1", 
                    ifelse(orig_pred_prob < .15, ".1-.15", 
                    ifelse(orig_pred_prob < .2, ".15-.2", 
                    ifelse(orig_pred_prob < .25, ".2-.25", '>.25')))))
  )

editing_data %>%
  select(edit_method, edit_window_size, objective, case_id) %>%
  unique() %>%
  group_by(edit_method, edit_window_size, objective) %>%
  summarise(n = n())
# editing_data %>%
#   group_by(edit_method, edit_window_size, objective, case_id, edit_central_layer) %>%
#   unique() %>%
#   group_by(edit_method, edit_window_size, objective, edit_central_layer) %>%
#   summarise(n = n())

## Performance stats

In [None]:
correctness_filter <- 1

# Mean editing metrics per (method, central layer, window size, objective),
# restricted to points the model originally got right.
editing_data_table <- per_edit_data %>%
  filter(is_correct >= correctness_filter) %>%
  group_by(edit_method, edit_central_layer, edit_window_size, objective) %>%
  summarise(
    n = n(),
    # Must come before `rewrite_score = mean(...)`: summarise() evaluates its
    # arguments in order, so afterwards `rewrite_score` is the group mean.
    rewrite_score_sd = sd(rewrite_score),
    rewrite_score = mean(rewrite_score),
    paraphrase_score = mean(paraphrase_score),
    neighborhood_score = mean(neighborhood_score),
    target_score_mean = mean(target_score_mean),
    target_score_mean_v2 = mean(target_score_mean_v2),
    essence_ppl_diff = mean(essence_ppl_diff),
    essence_score = mean(essence_diff_normed),
    .groups = "drop"
  ) %>%
  # One lexicographic sort; same order as the former chain of three stable
  # arrange() calls (method, then layer, then window size).
  arrange(edit_window_size, edit_central_layer, edit_method) %>%
  mutate(across(where(is.double), ~ round(.x, 3)))

editing_data_table %>%
  select(-rewrite_score_sd) %>%
  # filter(is_correct >= correctness_filter) %>%
  # filter(edit_method == "FT") %>%
  # filter(objective == "Fact Forcing") %>%
  filter(objective == "Tracing Reversal") %>%
  # filter(objective == "Falsehood Injection") %>%
  # filter(edit_window_size == 5) %>%
  arrange(edit_method, objective, edit_window_size, edit_central_layer)

# Performance plots

In [None]:
options(repr.plot.width=11, repr.plot.height=6)
editing_data_table <- editing_data_table %>%
  mutate(Method = sprintf("%s (ws=%s)", edit_method, edit_window_size),
         Method = factor(Method, levels=c("FT (ws=1)", "FT (ws=5)", "ROME (ws=1)", "MEMIT (ws=5)")))
xticks = seq(0, 28, by=4)
xticks[1] <- 1

OBJECTIVE = "Falsehood Injection"
# OBJECTIVE = "Tracing Reversal"
# OBJECTIVE = "Fact Erasure"
# OBJECTIVE = "Fact Amplification"
# OBJECTIVE = "Fact Forcing"
line_size = 1

TITLE <- sprintf("%s Rewrite Score by Edit Layer", OBJECTIVE)
if (OBJECTIVE == "Falsehood Injection"){
  TITLE <- "Error Injection Rewrite Score by Edit Layer"
}
(rewrite_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, rewrite_score, color=Method)) + 
  geom_line(size=line_size) +
  ggtitle(TITLE) + 
  xlab("(Central) Edit Layer") + 
  ylab("Rewrite Score") + 
  scale_x_continuous(labels = xticks, breaks=xticks-1) + 
  theme + 
  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),
        axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),
        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5),
        legend.title = element_text(size=16),
        legend.position=c(0.2, 0.2),  
        legend.background = element_rect(fill = "white", color = "#555555"),
        legend.key = element_blank())
)

TITLE <- sprintf("%s Paraphrase score by edit layer", OBJECTIVE)
if (OBJECTIVE == "Falsehood Injection"){
  TITLE <- "Error Injection Paraphrase Score by Edit Layer"
}
(paraphrase_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, paraphrase_score, color=Method)) + 
  geom_line(size=line_size) +
  ylim(c(0, 1)) + 
  ggtitle(TITLE) + 
  xlab("(Central) Edit Layer") + 
  ylab("Paraphrase Score") + 
  scale_x_continuous(labels = xticks, breaks=xticks-1) + 
  theme + 
  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),
        axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),
        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5),
        legend.title = element_text(size=16),
        legend.position=c(0.8, 0.8),  
        legend.background = element_rect(fill = "white", color = "#555555"),
        legend.key = element_blank())
)

TITLE <- sprintf("%s Neighbor score by edit layer", OBJECTIVE)
if (OBJECTIVE == "Falsehood Injection"){
  TITLE <- "Error Injection Neighborhood Score by Edit Layer"
}
(neighborhood_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, neighborhood_score, color=Method)) + 
  geom_line(size=line_size) +
  ggtitle(TITLE) + 
  xlab("(Central) Edit Layer") + 
  ylab("Neighborhood Score") + 
  scale_x_continuous(labels = xticks, breaks=xticks-1) + 
  ylim(c(.9, 1)) + 
  theme + 
  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),
        axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),
        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5),
        legend.title = element_text(size=16),
        legend.position=c(0.2, 0.2),  
        legend.background = element_rect(fill = "white", color = "#555555"),
        legend.key = element_blank())
)

TITLE <- sprintf("%s Essence Score by Edit Layer", OBJECTIVE)
if (OBJECTIVE == "Falsehood Injection"){
  TITLE <- "Error Injection Essence Score by Edit Layer"
}
(essence_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, essence_score, color=Method)) + 
  geom_line(size=line_size) +
  ggtitle(TITLE) + 
  scale_x_continuous(labels = xticks, breaks=xticks-1) + 
  ylim(c(.4, 1)) + 
  xlab("(Central) Edit Layer") + 
  ylab("Essence Score") + 
  theme + 
  theme(axis.title.x = element_text(size=18, color='black', angle=0, vjust=-.5, hjust=.5),
        axis.title.y = element_text(size=18, color='black', angle=90, vjust=2, hjust=.5),
        plot.title = element_text(size=20, color='black', angle=0, vjust=.5, hjust=.5),
        legend.title = element_text(size=16),
        legend.position=c(0.8, 0.2),  
        legend.background = element_rect(fill = "white", color = "#555555"),
        legend.key = element_blank())
)

# Overall score (target_score_mean) against the central edit layer, one line
# per editing method, for the objective currently stored in OBJECTIVE.
# "Falsehood Injection" is presented under the name "Error Injection".
TITLE <- if (OBJECTIVE == "Falsehood Injection") {
  "Error Injection Overall Score by Edit Layer"
} else {
  sprintf("%s Overall Score by Edit Layer", OBJECTIVE)
}
ovr_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, target_score_mean, color = Method)) +
  geom_line(size = line_size) +
  labs(title = TITLE, x = "(Central) Edit Layer", y = "Overall Score") +
  # tick labels are shifted by one relative to the plotted layer index
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ylim(c(.3, 1)) +
  theme +
  theme(
    axis.title.x = element_text(size = 18, color = 'black', angle = 0, vjust = -.5, hjust = .5),
    axis.title.y = element_text(size = 18, color = 'black', angle = 90, vjust = 2, hjust = .5),
    plot.title = element_text(size = 20, color = 'black', angle = 0, vjust = .5, hjust = .5),
    legend.title = element_text(size = 16),
    legend.position = c(0.2, 0.2),
    legend.background = element_rect(fill = "white", color = "#555555"),
    legend.key = element_blank()
  )
ovr_plot

# Overall score variant that also includes the essence score
# (target_score_mean_v2) against the central edit layer, one line per editing
# method, for the objective currently stored in OBJECTIVE.
# The title carries "(+Essence)" for every objective so this figure is not
# confused with the plain overall-score plot (target_score_mean).
TITLE <- sprintf("%s Overall Score (+Essence) by Edit Layer", OBJECTIVE)
if (OBJECTIVE == "Falsehood Injection") {
  # "Falsehood Injection" is presented under the name "Error Injection".
  TITLE <- "Error Injection Overall Score (+Essence) by Edit Layer"
}
(ovr_v2_plot <- editing_data_table %>%
  filter(objective == OBJECTIVE) %>%
  ggplot(aes(edit_central_layer, target_score_mean_v2, color = Method)) +
  geom_line(size = line_size) +
  ggtitle(TITLE) +
  xlab("(Central) Edit Layer") +
  ylab("Overall Score") +
  # tick labels are shifted by one relative to the plotted layer index
  scale_x_continuous(labels = xticks, breaks = xticks - 1) +
  ylim(c(.3, 1)) +
  theme +
  theme(axis.title.x = element_text(size = 18, color = 'black', angle = 0, vjust = -.5, hjust = .5),
        axis.title.y = element_text(size = 18, color = 'black', angle = 90, vjust = 2, hjust = .5),
        plot.title = element_text(size = 20, color = 'black', angle = 0, vjust = .5, hjust = .5),
        legend.title = element_text(size = 16),
        legend.position = c(0.2, 0.2),
        legend.background = element_rect(fill = "white", color = "#555555"),
        legend.key = element_blank())
)

# Write each figure of this section to "<plot name>.pdf" (8 x 6, Cairo PDF
# device) and download it from the Colab runtime, in a fixed order.
for (plot_name in c("rewrite_plot", "paraphrase_plot", "neighborhood_plot",
                    "essence_plot", "ovr_plot", "ovr_v2_plot")) {
  pdf_file <- paste0(plot_name, ".pdf")
  ggsave(pdf_file, get(plot_name), width = 8, height = 6, device = cairo_pdf)
  colab::download_file(pdf_file)
}

## R2 values for each method

In [None]:
# For every (edit method, edit window size) condition and every editing
# objective, regress each edit-success metric on
#   (a) the central edit layer (as a factor),
#   (b) the tracing effect (max_fraction_restored), and
#   (c) both, including their interaction,
# and record each model's R^2, the R^2 gained by adding tracing on top of the
# layer-only model (<metric>_R2_diff), and the F-test p-value comparing the
# nested models (<metric>_F_pvalue).

# Each condition is c(edit method, edit window size).
conditions <- list(
  c("FT", "1"),
  c("FT", "5"),
  c("ROME", "1"),
  c("MEMIT", "5")
)
objectives <- c("Falsehood Injection", "Tracing Reversal", "Fact Erasure", "Fact Forcing", "Fact Amplification")
# Only rows with is_correct >= correctness_filter enter the regressions.
correctness_filter <- 1

# Empty template fixing the column set and types of the result table.
results_df <- data.frame(
  edit_method = character(),
  objective = character(),
  edit_window_size = character(),
  rewrite_R2_lyr = double(),
  rewrite_R2_trace = double(),
  rewrite_R2_both = double(),
  rewrite_R2_diff = double(),
  rewrite_F_pvalue = double(),
  paraphrase_R2_lyr = double(),
  paraphrase_R2_trace = double(),
  paraphrase_R2_both = double(),
  paraphrase_R2_diff = double(),
  paraphrase_F_pvalue = double(),
  neighborhood_R2_lyr = double(),
  neighborhood_R2_trace = double(),
  neighborhood_R2_both = double(),
  neighborhood_R2_diff = double(),
  neighborhood_F_pvalue = double(),
  essence_R2_lyr = double(),
  essence_R2_trace = double(),
  essence_R2_both = double(),
  essence_R2_diff = double(),
  essence_F_pvalue = double()
)

# Collect one single-row data frame per condition/objective and bind once at
# the end instead of re-binding the growing table on every iteration.
result_rows <- list(results_df)

for (condition in conditions) {
  METHOD <- condition[1]
  EDIT_WINDOW_SIZE <- condition[2]
  for (OBJECTIVE in objectives) {
    model_data <- per_edit_data %>%
      filter(is_correct >= correctness_filter) %>%
      filter(edit_method == METHOD, edit_window_size == EDIT_WINDOW_SIZE, objective == OBJECTIVE) %>%
      mutate(essence_score = essence_diff_normed)
    # Not every method/objective combination has data; skip the empty ones.
    if (nrow(model_data) == 0) next
    row <- data.frame(
      edit_method = METHOD,
      objective = OBJECTIVE,
      edit_window_size = EDIT_WINDOW_SIZE
    )
    for (metric in c("rewrite", "paraphrase", "neighborhood", "essence")) {
      score_name <- sprintf("%s_score", metric)
      # Build the formulas from the column name so the response is a real
      # variable of model_data rather than a get() call inside the formula.
      formula_lyr <- as.formula(sprintf("%s ~ as.factor(edit_central_layer)", score_name))
      formula_trace <- as.formula(sprintf("%s ~ max_fraction_restored", score_name))
      formula_both <- as.formula(sprintf("%s ~ as.factor(edit_central_layer) * max_fraction_restored", score_name))
      model_edit_layer_only <- lm(formula_lyr, data = model_data)
      model_fraction_restored_only <- lm(formula_trace, data = model_data)
      model_both <- lm(formula_both, data = model_data)
      r2_lyr <- summary(model_edit_layer_only)$r.squared
      r2_trace <- summary(model_fraction_restored_only)$r.squared
      r2_both <- summary(model_both)$r.squared
      # Nested-model F-test: does adding the tracing effect (and its
      # interaction with layer) improve on the layer-only model?
      f_test <- anova(model_edit_layer_only, model_both)
      row[[sprintf("%s_R2_lyr", metric)]] <- r2_lyr
      row[[sprintf("%s_R2_trace", metric)]] <- r2_trace
      row[[sprintf("%s_R2_both", metric)]] <- r2_both
      row[[sprintf("%s_R2_diff", metric)]] <- r2_both - r2_lyr
      row[[sprintf("%s_F_pvalue", metric)]] <- f_test$`Pr(>F)`[2]
    }
    result_rows[[length(result_rows) + 1]] <- row
  }
}

# Round only the R^2 columns for display. P-values stay unrounded: rounding
# them to 3 decimals would report every p < 5e-4 as exactly 0.
results_df <- bind_rows(result_rows) %>%
  mutate(across(contains("_R2_"), ~ round(.x, 3)))

In [None]:
# Export a per-condition summary of how much the tracing effect explains
# beyond the edit layer: for each metric, the R^2 gained by adding tracing
# (<metric>_R2_diff) and the nested-model F-test p-value (<metric>_F_pvalue).
# R^2 columns are rounded to 4 decimals; p-values are kept to 3 significant
# digits because fixed-decimal rounding would report small p-values as 0.
results_df %>%
  mutate(
    across(contains("_R2_"), ~ round(.x, 4)),
    across(ends_with("_F_pvalue"), ~ signif(.x, 3))
  ) %>%
  # filter(objective == "Falsehood Injection" | objective == "Fact Forcing") %>%
  select(objective, edit_method, edit_window_size,
         rewrite_R2_diff, rewrite_F_pvalue,
         paraphrase_R2_diff, paraphrase_F_pvalue,
         neighborhood_R2_diff, neighborhood_F_pvalue,
         essence_R2_diff, essence_F_pvalue) %>%
  # select(objective, edit_method, edit_window_size, rewrite_R2_lyr, rewrite_R2_trace, rewrite_R2_both, rewrite_R2_diff, rewrite_F_pvalue) %>%
  arrange(objective, edit_method, edit_window_size) %>%
  write_csv('tmp.csv') %>%
  # write_csv() returns its input invisibly; identity() makes the result
  # visible so the notebook displays the table.
  identity()

colab::download_file('tmp.csv')

## Final R2 plot

In [None]:
options(repr.plot.width = 13, repr.plot.height = 4)

# Metric whose regression results are plotted
# ("rewrite", "paraphrase", "neighborhood" or "essence").
obj <- 'rewrite'
R2_lyr_name <- sprintf("%s_R2_lyr", obj)
R2_diff_name <- sprintf("%s_R2_diff", obj)
# Facet order; "Falsehood Injection" is displayed as "Error Injection".
objective_levels <- c("Error Injection", "Tracing Reversal", "Fact Amplification", "Fact Erasure", "Fact Forcing")

# Long format: one row per (objective, method, model). R2 is either the
# layer-only R^2 or the extra R^2 from adding the tracing effect, so the two
# stacked bars add up to the R^2 of the combined model.
plot_data <- results_df %>%
  mutate(method = sprintf('%s-%s', edit_method, edit_window_size)) %>%
  select(objective, method, all_of(R2_lyr_name), all_of(R2_diff_name)) %>%
  pivot_longer(cols = all_of(c(R2_lyr_name, R2_diff_name)), names_to = "model", values_to = "R2") %>%
  mutate(
    model = ifelse(model == R2_lyr_name, "Layer", "Layer + Tracing Effect"),
    objective = ifelse(objective == "Falsehood Injection", "Error Injection", objective),
    objective = factor(objective, levels = objective_levels),
    model = factor(model, levels = c("Layer + Tracing Effect", "Layer")),
    method = case_when(
      method == "ROME-1" ~ "ROME",
      method == "MEMIT-5" ~ "MEMIT",
      TRUE ~ method
    ),
    method = factor(method, levels = c("FT-1", "FT-5", "ROME", "MEMIT"))
  )

# A single-row layer so the 'p < 1e-4' annotation is drawn exactly once, on
# the Fact Forcing facet at the FT-1 "Layer" bar, instead of once per bar.
# NOTE(review): the label is hard-coded; re-check it against the
# <obj>_F_pvalue column of results_df whenever the data or `obj` change.
text_data <- plot_data %>%
  filter(objective == "Fact Forcing", method == "FT-1", model == "Layer")

(bar_plot <- plot_data %>%
  filter(objective != "Error Injection") %>%
  ggplot(aes(x = method, y = R2, fill = model)) +
    geom_col(position = 'stack', width = 0.5) +
    geom_text(data = text_data,
              label = "p < 1e-4",
              size = 5,
              hjust = 0, vjust = -1.5) +
    xlab("") +
    ylab(bquote(R^2)) +
    ylim(c(0, 1)) +
    scale_fill_manual(name = "Explanatory Variable(s):", values = c(cbp1[2], cbp1[1]), limits = c("Layer", "Layer + Tracing Effect")) +
    theme +
    ggtitle("Tracing effects are very weakly predictive of edit success") +
    theme(axis.title.y = element_text(size = 23, angle = 0, vjust = 0.45, hjust = -2),
          axis.text.y = element_text(size = 18),
          axis.text.x = element_text(size = 16, angle = 0, vjust = 0, hjust = .5),
          axis.title.x = element_text(size = 23, angle = 0, vjust = -1, hjust = .5),
          plot.title = element_text(size = 24, angle = 0, vjust = 1, hjust = .5),
          strip.text.x = element_text(size = 16, color = 'black', angle = 0, vjust = .5, hjust = .5),
          legend.position = c(.5, -0.25),
          legend.direction = 'horizontal',
          legend.title = element_text(size = 22, vjust = .55),
          legend.text = element_text(size = 22),
          legend.box.margin = margin(100, 100, 100, 100),
        ) +
    facet_wrap(~objective, nrow = 1)
)

# ggsave('R2-bar-plot.pdf', bar_plot, width=17, height=4, device=cairo_pdf)
ggsave('R2-bar-plot.pdf', bar_plot, width = 14.5, height = 3.6, device = cairo_pdf)
colab::download_file('R2-bar-plot.pdf')