# 230623_sub_causal_tracing

submodule 단위로 가능하도록 변경

# Import libraries

In [2]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer, set_seed
from datasets import load_dataset
from tqdm import tqdm
import json
import torch
import argparse
import datasets
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pylab as plt
from datetime import date
import sys
import os


  from .autonotebook import tqdm as notebook_tqdm


# Config

In [21]:
torch.manual_seed(718)
set_seed(718)

In [22]:
job_cd = "omcd_econ_l1"
# job_cd = "tmod_Religion_and_belief_systems"
# job_cd = sys.argv[2] # TODO
list_job_cd = job_cd.split("_")
job_gubun = list_job_cd[0]
model_id = "_".join(list_job_cd[1:])

In [23]:
assert job_gubun in ["omcd", "tmod"]

cache_dir = "/rds/general/user/jj1122/ephemeral/.cache/huggingface"
device_id = "cpu"

n_layers = 12
list_modules = ['attn', 'mlp']
# trace_module_id = "transformer.h.{l}.{m}"

tuned_model_path = f"/rds/general/user/jj1122/home/projects/m2d2/dataset/{model_id}/models"

today_dt = date.today().strftime("%y%m%d")
output_file = f"/rds/general/user/jj1122/home/projects/m2d2/utils/output_logs/{today_dt}_sub_{job_cd}.json"


# module params

In [4]:
cache_dir = "/rds/general/user/jj1122/ephemeral/.cache/huggingface"

def parse_code_submodule(row):
    list_code = row.code.split(".")
    row["trace_id"] = ".".join(row["code"].split(".")[:-1])
#     row["component_id"] = ".".join(row["code"].split(".")[1:-1])
    is_in_layer = row["code"].startswith("transformer.h")
    

    if is_in_layer:        
        row["layer"] = int(list_code[2])
        row["module"] = list_code[3]
    else:
        row["module"] = list_code[1]
        if row["module"] == "ln_f": 
            row["layer"] = int(99)
        elif row["module"] in ["wte", "wpe"]:
            row['layer'] = int(-1)
    
    if row["module"] in ["attn", "mlp"]:
        row["submodule"] = list_code[-2]
    else:
        row["submodule"] = row["module"]

    row["w_or_b"] = list_code[-1]
    row["is_in_layer"] = is_in_layer
    
    is_investigated = (not row["code"].startswith("transformer.w")) and (not row["submodule"].startswith("ln"))
    row["is_investigated"] = is_investigated

    return row


dict_n_parmas = {tup[0]: tup[1].numel() for tup in GPT2LMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir).to("cpu").named_parameters()}
df_submodule = pd.DataFrame.from_dict(dict_n_parmas, orient="index").reset_index()
df_submodule.columns = ["code", 'params']
df_submodule = df_submodule.apply(lambda row: parse_code_submodule(row), axis=1)
df_temp = df_submodule.groupby("trace_id")["params"].sum().to_frame().reset_index()
df_temp.columns = ["trace_id", "total"]

df_submodule["display_id"] = df_submodule["layer"].astype(str) + "." + df_submodule["module"]

df_submodule.loc[df_submodule.module != df_submodule.submodule, "display_id"] \
    = df_submodule.loc[df_submodule.module != df_submodule.submodule, "display_id"] \
        + "." + df_submodule.loc[df_submodule.module != df_submodule.submodule, "submodule"]

df_submodule = df_submodule.merge(df_temp, on=["trace_id"])
df_submodule.head()

Unnamed: 0,code,is_in_layer,is_investigated,layer,module,params,submodule,trace_id,w_or_b,display_id,total
0,transformer.wte.weight,False,False,-1,wte,38597376,wte,transformer.wte,weight,-1.wte,38597376
1,transformer.wpe.weight,False,False,-1,wpe,786432,wpe,transformer.wpe,weight,-1.wpe,786432
2,transformer.h.0.ln_1.weight,True,False,0,ln_1,768,ln_1,transformer.h.0.ln_1,weight,0.ln_1,1536
3,transformer.h.0.ln_1.bias,True,False,0,ln_1,768,ln_1,transformer.h.0.ln_1,bias,0.ln_1,1536
4,transformer.h.0.attn.c_attn.weight,True,True,0,attn,1769472,c_attn,transformer.h.0.attn.c_attn,weight,0.attn.c_attn,1771776


In [6]:
list_trace_ids = df_submodule.loc[df_submodule.is_investigated].trace_id.unique()

# Data

In [71]:
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
dataset = load_dataset("machelreid/m2d2", model_id, cache_dir=cache_dir)
ds_test = dataset["test"].filter(lambda x: x['text'] != '')


Found cached dataset m2d2 (/rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/econ_l1/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00)
100%|██████████| 3/3 [00:00<00:00, 828.26it/s]
Loading cached processed dataset at /rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/econ_l1/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00/cache-1eec3cf03a698048.arrow


In [72]:
set_seed(718)

def tokenize_function(examples):
    output = gpt2_tokenizer(examples['text'], max_length=1024, truncation=True)
    return output

tokenized_datasets = ds_test.map(
    tokenize_function,
    batched=True,
    num_proc=8,
    remove_columns='text', #TODO
    load_from_cache_file=True,
)

len_sentences = len(tokenized_datasets)

Loading cached processed dataset at /rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/econ_l1/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00/cache-12826ea30f123b74_*_of_00008.arrow


# MODEL & HOOK

In [77]:
# MODEL & HOOK
def load_model(model="tuned"):
    if model == "tuned":
        return GPT2LMHeadModel.from_pretrained(tuned_model_path).to(device_id)
    else:
        return GPT2LMHeadModel.from_pretrained(model, cache_dir=cache_dir).to(device_id)


def save_clean_activation(m_id):
    def save_clean_activation_hook(module, _input, _output):
        clean_activations[m_id] = _output
    return save_clean_activation_hook


def corrupt_input_vector_v2(module, _input, _output):
    torch.manual_seed(718)
    global before
    global after
    
    std = torch.std(_output)
    output = _output + (std*1.5) * torch.randn(_output[0].shape).to(device_id)
    before = _output
    after = output
    return output


def restore_activation(m_id):
    def restore_activation_hook(module, _input, _output):
        clean_activation = clean_activations[m_id]#[:, t]
        return clean_activation
    return restore_activation_hook

## Register Hook

In [95]:
if job_gubun == "omcd":
    clean_model = load_model()
    clean_model.eval()
    for m_id in list_trace_ids:
        clean_model.get_submodule(m_id).register_forward_hook(save_clean_activation(m_id))    

    # Second run: corrupted run def
    corrupted_model = load_model()
    corrupted_model.eval()
    corrupted_model.get_submodule("transformer.wte").register_forward_hook(corrupt_input_vector_v2)

    # Third run: restored run def    
    restored_model = load_model()
    restored_model.get_submodule("transformer.wte").register_forward_hook(corrupt_input_vector_v2)
else:
    clean_model = load_model()
    clean_model.eval()
    for m_id in list_trace_ids:
        clean_model.get_submodule(m_id).register_forward_hook(save_clean_activation(m_id))

    # Second run: base run def
    corrupted_model = load_model("gpt2")
    corrupted_model.eval()

    # Third run: restored run def
    restored_model = load_model("gpt2")
    restored_model.eval()


# Causal Tracing

In [96]:
try:
    with open(output_file, 'r') as json_file:
        list_results = json.load(json_file)
    for i, d in enumerate(list_results):
        if len(d) == 0:
            done_idx = i - 1
            break
except:
    list_results = [{} for x in range(len_sentences)]
    done_idx = 0
print(done_idx)

-1


In [97]:
done_idx = 0

In [98]:
torch.manual_seed(718)
set_seed(718)

for sentence_idx, data in enumerate(tokenized_datasets):
    if done_idx > sentence_idx: continue
#         print(sentence_idx, data)
#     print(sentence_idx)

    if sentence_idx % 1000 == 0:
        print(f"sentence_idx: {sentence_idx}")
        with open(output_file, 'w') as json_file:
            json.dump(list_results, json_file)

    inputs = torch.tensor(data['input_ids']).to(device_id)

    # First run: clean run
    clean_activations = {}
    with torch.no_grad():
        clean_outputs = clean_model(inputs, labels=inputs.clone())
        clean_loss = np.exp(clean_outputs.loss.item())

    # Second run: corrupted run
    with torch.no_grad():
        corrupted_outputs = corrupted_model(inputs, labels=inputs.clone())
        corrupted_loss = np.exp(corrupted_outputs.loss.item())

    # Third run: corrupted-with-restoration run
    restored_loss = {}
    with torch.no_grad():
        for m_id in list_trace_ids:
            hook = restored_model.get_submodule(m_id).register_forward_hook(restore_activation(m_id))
            restored_outputs = restored_model(inputs, labels=inputs.clone())
            restored_loss[m_id] = np.exp(restored_outputs.loss.item())
            hook.remove()

    list_results[sentence_idx]['clean_loss'] = clean_loss
    list_results[sentence_idx]['corrupted_loss'] = corrupted_loss
    list_results[sentence_idx]['restored_loss'] = restored_loss

sentence_idx: 0


KeyboardInterrupt: 