# 230623_causal_tracing

Hooking 방식 좀 더 안전하게 변경  
tokenisation 할 때 max_sequence 부여

# Import libraries

In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer, set_seed
from datasets import load_dataset
from tqdm import tqdm
import json
import torch
import argparse
import datasets
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pylab as plt
from datetime import date
import sys
import os


  from .autonotebook import tqdm as notebook_tqdm


# # module params

In [3]:
import json
with open("./data/dict_n_module_params_230623.json", "r") as json_file:
    dict_temp = json.load(json_file)

In [27]:
def parse_code(row):
    list_code = row.code.split(".")
    is_in_layer = row["code"].startswith("transformer.h")
    if is_in_layer:        
        row["l"] = int(list_code[2])
        row["m"] = list_code[3]
    else:
        row["m"] = list_code[1]
        if row["m"] == "ln_f": 
            row["l"] = int(99)
        elif row["m"] in ["wte", "wpe"]: 
            row['l'] = int(-1)
    row["w_or_b"] = list_code[-1]
    row["is_in_layer"] = is_in_layer
    return row

df_temp = pd.DataFrame.from_dict(dict_temp, orient="index").reset_index()
df_temp.columns = ['code', 'n_params']
df_temp = df_temp.apply(lambda row: parse_code(row), axis=1)
df_temp

Unnamed: 0,code,is_in_layer,l,m,n_params,w_or_b
0,transformer.wte.weight,False,-1,wte,38597376,weight
1,transformer.wpe.weight,False,-1,wpe,786432,weight
2,transformer.h.0.ln_1.weight,True,0,ln_1,768,weight
3,transformer.h.0.ln_1.bias,True,0,ln_1,768,bias
4,transformer.h.0.attn.c_attn.weight,True,0,attn,1769472,weight
...,...,...,...,...,...,...
143,transformer.h.11.mlp.c_fc.bias,True,11,mlp,3072,bias
144,transformer.h.11.mlp.c_proj.weight,True,11,mlp,2359296,weight
145,transformer.h.11.mlp.c_proj.bias,True,11,mlp,768,bias
146,transformer.ln_f.weight,False,99,ln_f,768,weight


In [33]:
df_temp.loc[df_temp.w_or_b == "weight", ["m", "n_params", "is_in_layer"]]\
    .drop_duplicates()\
    .groupby(["is_in_layer", "m"], as_index=False)[["n_params"]].sum().drop_duplicates()

Unnamed: 0,is_in_layer,m,n_params
0,False,ln_f,768
1,False,wpe,786432
2,False,wte,38597376
3,True,attn,2359296
4,True,ln_1,768
5,True,ln_2,768
6,True,mlp,2359296


In [None]:
def parse_code(row):
    list_code = row.code.split(".")
    row['l'] = int(list_code[2])
    row['m'] = list_code[3]
#     row['t'] = int(list_code[4])
    return row

list_df_ide_temp = []
for i, d in enumerate(list_results_temp):
    df_t = pd.DataFrame.from_dict(d['restored_loss'], orient='index')
    TE = (corrupted_loss - clean_loss) / clean_loss
    IDE = {}
    for m_id in list_trace_module_ids:
        IDE[m_id] = (restored_loss[m_id] - clean_loss) / clean_loss
    
    df_ide = pd.DataFrame.from_dict(IDE, orient='index').reset_index()
    df_ide.columns = ['code', 'ide']
    df_ide = df_ide.apply(lambda row: parse_code(row), axis=1)
    df_ide["seq"] = i
    list_df_ide_temp.append(df_ide)
    break

# Config

In [2]:
torch.manual_seed(718)
set_seed(718)

In [3]:
job_cd = "omcd_econ_l1"
# job_cd = "tmod_Religion_and_belief_systems"
# job_cd = sys.argv[2] # TODO

list_job_cd = job_cd.split("_")
job_gubun = list_job_cd[0]
model_id = "_".join(list_job_cd[1:])

In [4]:
assert job_gubun in ["omcd", "tmod"]

cache_dir = "/rds/general/user/jj1122/ephemeral/.cache/huggingface"
device_id = "cpu"

n_layers = 12
list_modules = ['attn', 'mlp']
trace_module_id = "transformer.h.{l}.{m}"

tuned_model_path = f"/rds/general/user/jj1122/ephemeral/projects/m2d2/dataset/{model_id}/models"

today_dt = date.today().strftime("%y%m%d")
output_file = f"/rds/general/user/jj1122/home/projects/m2d2/utils/output_logs/{today_dt}_{job_gubun}_{model_id}.json"


In [5]:
list_trace_module_ids = []
for l in range(n_layers):
    for m in list_modules:
        list_trace_module_ids.append(trace_module_id.format(l=l, m=m))


# Data

In [6]:
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

dataset = load_dataset("machelreid/m2d2", model_id, cache_dir=cache_dir)
ds_test = dataset["test"].filter(lambda x: x['text'] != '')


Found cached dataset m2d2 (/rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/econ_l1/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00)
100%|██████████| 3/3 [00:00<00:00, 677.27it/s]
Loading cached processed dataset at /rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/econ_l1/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00/cache-1eec3cf03a698048.arrow


In [102]:
set_seed(718)

def tokenize_function(examples):
    output = gpt2_tokenizer(examples['text'], max_length=1024, truncation=True)
    return output

tokenized_datasets = ds_test.map(
    tokenize_function,
    batched=True,
    num_proc=8,
    remove_columns='text', #TODO
    load_from_cache_file=True,
)

len_sentences = len(tokenized_datasets)



# MODEL & HOOK

In [90]:
# MODEL & HOOK
def load_model(model="tuned"):
    if model == "tuned":
        return GPT2LMHeadModel.from_pretrained(tuned_model_path).to(device_id)
    else:
        return GPT2LMHeadModel.from_pretrained(model, cache_dir=cache_dir).to(device_id)


def save_clean_activation(m_id):
    def save_clean_activation_hook(module, _input, _output):
        clean_activations[m_id] = _output
    return save_clean_activation_hook


def corrupt_input_vector_v2(module, _input, _output):
    torch.manual_seed(718)
    global before
    global after
    
    std = torch.std(_output)
    output = _output + (std*1.5) * torch.randn(_output[0].shape).to(device_id)
    before = _output
    after = output
    return output


def restore_activation(m_id):
    def restore_activation_hook(module, _input, _output):
        clean_activation = clean_activations[m_id]#[:, t]
        return clean_activation
    return restore_activation_hook

## Register Hook

In [91]:
if job_gubun == "omcd":
    clean_model = load_model()
    clean_model.eval()
    for m_id in list_trace_module_ids:
        clean_model.get_submodule(m_id).register_forward_hook(save_clean_activation(m_id))    

    # Second run: corrupted run def
    corrupted_model = load_model()
    corrupted_model.eval()
    #         corrupted_model.get_submodule("transformer.h.0.attn").register_forward_pre_hook(corrupt_input_vector)
    corrupted_model.get_submodule("transformer.wte").register_forward_hook(corrupt_input_vector_v2)

    # Third run: restored run def    
    restored_model = load_model()
    #         restored_model.get_submodule("transformer.h.0.attn").register_forward_pre_hook(corrupt_input_vector)
    restored_model.get_submodule("transformer.wte").register_forward_hook(corrupt_input_vector_v2)
else:
    clean_model = load_model()
    clean_model.eval()
    for m_id in list_trace_module_ids:
        clean_model.get_submodule(m_id).register_forward_hook(save_clean_activation(m_id))

    # Second run: base run def
    corrupted_model = load_model("gpt2")
    corrupted_model.eval()

    # Third run: restored run def

    restored_model = load_model("gpt2")
    restored_model.eval()


# Causal Tracing

In [92]:
try:
    with open(output_file, 'r') as json_file:
        list_results = json.load(json_file)
    for i, d in enumerate(list_results):
        if len(d) == 0:
            done_idx = i - 1
            break
except:
    list_results = [{} for x in range(len_sentences)]
    done_idx = 0
print(done_idx)

58


In [93]:
done_idx = 0

In [101]:
torch.manual_seed(718)
set_seed(718)

for sentence_idx, data in enumerate(tokenized_datasets):
    if done_idx > sentence_idx: continue
#         print(sentence_idx, data)
#     print(sentence_idx)

    if sentence_idx % 1000 == 0:
        print(f"sentence_idx: {sentence_idx}")
        with open(output_file, 'w') as json_file:
            json.dump(list_results, json_file)

    inputs = torch.tensor(data['input_ids']).to(device_id)

    # First run: clean run
    clean_activations = {}
    with torch.no_grad():
        clean_outputs = clean_model(inputs, labels=inputs.clone())
        clean_loss = np.exp(clean_outputs.loss.item())

    # Second run: corrupted run
    with torch.no_grad():
        corrupted_outputs = corrupted_model(inputs, labels=inputs.clone())
        corrupted_loss = np.exp(corrupted_outputs.loss.item())

    # Third run: corrupted-with-restoration run
    restored_loss = {}
    with torch.no_grad():
        for m_id in list_trace_module_ids:
            hook = restored_model.get_submodule(m_id).register_forward_hook(restore_activation(m_id))
            restored_outputs = restored_model(inputs, labels=inputs.clone())
            restored_loss[m_id] = np.exp(restored_outputs.loss.item())
            hook.remove()

    list_results[sentence_idx]['clean_loss'] = clean_loss
    list_results[sentence_idx]['corrupted_loss'] = corrupted_loss
    list_results[sentence_idx]['restored_loss'] = restored_loss

sentence_idx: 0
