# 230611_causal_tracing

batch 처리 가능하게  - loss 땜시 불가능  
tokenisation 방식 변경 여러개

# Import libraries

In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer, set_seed
from datasets import load_dataset
from tqdm import tqdm
import json
import torch
import argparse
import datasets
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pylab as plt
from datetime import date
import sys

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
# Fix the random state: seed torch's global RNG directly, then call the
# `set_seed` helper imported from transformers with the same value.
torch.manual_seed(718)
set_seed(718)

In [3]:
# Job code layout: "<job kind>_<dataset config name>". Everything after the
# first underscore-separated field is re-joined, so the config name may itself
# contain underscores.
# Alternative job code: "tmod_Religion_and_belief_systems".
# TODO: read the job code from sys.argv[2] instead of hard-coding it.
job_cd = "omcd_Religion_and_belief_systems"

list_job_cd = job_cd.split("_")
job_gubun, model_id = list_job_cd[0], "_".join(list_job_cd[1:])

# Hugging Face cache location and the torch device everything runs on.
cache_dir = "/rds/general/user/jj1122/ephemeral/.cache/huggingface"
device_id = "cpu"

# Sub-modules to trace: one `attn` and one `mlp` per transformer block.
n_layers = 12
list_modules = ['attn', 'mlp']
trace_module_id = "transformer.h.{l}.{m}"

# Directory holding the fine-tuned checkpoint for this dataset config.
tuned_model_path = f"/rds/general/user/jj1122/home/projects/m2d2/dataset/{model_id}/models"

# Results file, stamped with today's date as YYMMDD.
today_dt = f"{date.today():%y%m%d}"
output_file = f"/rds/general/user/jj1122/home/projects/m2d2/utils/output_logs/{today_dt}_{job_gubun}_{model_id}.json"

In [4]:
# One dotted sub-module path per (layer, module kind) pair, ordered by layer
# first and then in `list_modules` order.
list_trace_module_ids = [
    trace_module_id.format(l=layer_idx, m=module_kind)
    for layer_idx in range(n_layers)
    for module_kind in list_modules
]

# Data

In [5]:
# Load the m2d2 configuration named by `model_id` (cached under `cache_dir`)
# and keep only the rows of the test split whose 'text' field is non-empty.
dataset = load_dataset("machelreid/m2d2", model_id, cache_dir=cache_dir)
ds_test = dataset["test"].filter(lambda x: x['text'] != '')
# Disabled: reuse the EOS token as the padding token.
# gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

Found cached dataset m2d2 (/rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/Religion_and_belief_systems/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00)
100%|██████████| 3/3 [00:00<00:00, 276.90it/s]
Loading cached processed dataset at /rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/Religion_and_belief_systems/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00/cache-b257c9c3236091fa.arrow


In [6]:
# Re-seed before the tokenization step.
set_seed(718)

# Tokenizer for the "gpt2" checkpoint; this uses GPT2Tokenizer, not the
# also-imported GPT2TokenizerFast.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def tokenize_function(examples):
    """Run the module-level ``gpt2_tokenizer`` over the ``'text'`` entry of ``examples``.

    Args:
        examples: Mapping with a ``'text'`` entry.

    Returns:
        Whatever ``gpt2_tokenizer`` returns for that text.
    """
    return gpt2_tokenizer(examples['text'])

# Apply `tokenize_function` to the filtered test split in batched mode with
# 8 worker processes, dropping the raw 'text' column from the result.
tokenized_datasets = ds_test.map(
    tokenize_function,
    batched=True,
    num_proc=8,
    remove_columns='text', #TODO
    load_from_cache_file=True,
)

# Number of tokenized rows; later used to size the per-sentence results list.
len_sentences = len(tokenized_datasets)

Loading cached processed dataset at /rds/general/user/jj1122/ephemeral/.cache/huggingface/machelreid___m2d2/Religion_and_belief_systems/0.0.0/eb235f33a5de3163c10549b7f63c906910539c8a8c0ec5ade1285ccbf5067d00/cache-3708937b1d6e8848_*_of_00008.arrow


In [7]:
def load_model():
    """Load a GPT-2 LM-head model from ``tuned_model_path`` and move it to ``device_id``.

    Each call builds a fresh model instance, so hooks registered on one
    returned model do not affect another.
    """
    model = GPT2LMHeadModel.from_pretrained(tuned_model_path)
    return model.to(device_id)


# OMCD

## Hook

In [8]:
# First run: clean run def
def save_clean_activation(m_id):
    """Build a forward hook that records a module's output in ``clean_activations``.

    Args:
        m_id: Dotted sub-module path, used as the key in the module-level
            ``clean_activations`` dict.

    Returns:
        A hook ``(module, _input, _output)`` that stores a detached copy of the
        output under ``m_id``: element 0 of the output for ids ending in
        ``'attn'``, the output itself for ids ending in ``'mlp'``. Other ids
        store nothing. The hook returns ``None``, so the module output is left
        unchanged.
    """
    is_attn = m_id.endswith('attn')
    is_mlp = m_id.endswith('mlp')

    def save_clean_activation_hook(module, _input, _output):
        if is_attn:
            activation = _output[0]
        elif is_mlp:
            activation = _output
        else:
            return
        clean_activations[m_id] = activation.detach()

    return save_clean_activation_hook

# Clean-run model: every traced sub-module gets a forward hook built by
# `save_clean_activation`, so one forward pass records all traced outputs.
tuned_model = load_model().eval()
for m_id in list_trace_module_ids:
    traced_module = tuned_model.get_submodule(m_id)
    traced_module.register_forward_hook(save_clean_activation(m_id))

    
# Second run: corrupted run def
def corrupt_input_vector(module, _input):
    """Forward pre-hook that adds Gaussian noise to a module's first positional input.

    The noise is scaled by 1.5 times the standard deviation of the input
    itself. torch's global RNG is re-seeded with 718 on every call, so inputs
    of the same shape always receive the same noise pattern.

    Args:
        module: The hooked module (unused).
        _input: Tuple of the module's positional inputs; only ``_input[0]`` is
            modified.

    Returns:
        A tuple that replaces the module's positional inputs: the corrupted
        tensor followed by any remaining positional inputs, unchanged.
    """
    hidden = _input[0]
    torch.manual_seed(718)
    std = torch.std(hidden)
    # Sample with the default (CPU) generator so the noise values do not
    # depend on where the model lives, then match the input's own device and
    # dtype rather than relying on a module-level device setting.
    noise = torch.randn(hidden.shape).to(device=hidden.device, dtype=hidden.dtype)
    return (hidden + (std * 1.5) * noise,) + tuple(_input[1:])

# Corrupted-run model: `corrupt_input_vector` rewrites the input of the first
# block's attention module on every forward pass.
corrupted_model = load_model().eval()
first_attn = corrupted_model.get_submodule("transformer.h.0.attn")
first_attn.register_forward_pre_hook(corrupt_input_vector)


# Third run: restored run def    
def restore_activation(m_id):
    """Build a forward hook that swaps a module's output for its stored clean activation.

    The hook reads ``clean_activations[m_id]`` (module-level dict filled during
    the clean run) every time it fires.

    Args:
        m_id: Dotted sub-module path; ids ending in ``'attn'`` or ``'mlp'``
            are restored, any other id leaves the output untouched.

    Returns:
        A hook ``(module, _input, _output)`` whose return value replaces the
        module output (``None`` means "keep the original output").
    """
    def restore_activation_hook(module, _input, _output):
        clean_activation = clean_activations[m_id]
        if m_id.endswith('attn'):
            if isinstance(clean_activation, tuple):
                # `save_clean_activation_v2` stores the complete attn output
                # tuple; hand it back as-is instead of nesting it.
                return clean_activation
            # Replace only element 0 and pass every other output element
            # through untouched, whatever it is (cache tuple, None, extra
            # outputs), instead of assuming a (key, value) pair at index 1.
            return (clean_activation,) + tuple(_output[1:])
        elif m_id.endswith('mlp'):
            return clean_activation
        return None
    return restore_activation_hook

# Restored-run model: receives the same input corruption as the corrupted
# model. Restore hooks are attached to it later, one traced module at a time.
restored_model = load_model().eval()
first_attn = restored_model.get_submodule("transformer.h.0.attn")
first_attn.register_forward_pre_hook(corrupt_input_vector)

<torch.utils.hooks.RemovableHandle at 0x14ead6342590>

In [8]:
# First run: clean run def
def save_clean_activation_v2(m_id):
    """Build a forward hook that records a module's raw output in ``clean_activations``.

    Unlike ``save_clean_activation``, the output object is stored as-is (no
    indexing, no ``detach``) for both ``'attn'`` and ``'mlp'`` ids.

    Args:
        m_id: Dotted sub-module path, used as the key in the module-level
            ``clean_activations`` dict.

    Returns:
        A hook ``(module, _input, _output)`` that stores ``_output`` under
        ``m_id`` when the id ends in ``'attn'`` or ``'mlp'`` and does nothing
        otherwise. It returns ``None``, leaving the module output unchanged.
    """
    def save_clean_activation_hook_v2(module, _input, _output):
        if m_id.endswith(('attn', 'mlp')):
            clean_activations[m_id] = _output
    return save_clean_activation_hook_v2

# Clean-run model (v2 hooks): every traced sub-module gets a forward hook
# built by `save_clean_activation_v2`.
tuned_model = load_model().eval()
for m_id in list_trace_module_ids:
    traced_module = tuned_model.get_submodule(m_id)
    traced_module.register_forward_hook(save_clean_activation_v2(m_id))

    
# Second run: corrupted run def
def corrupt_input_vector_v2(module, _input):
    """Forward pre-hook that adds Gaussian noise to a module's first positional input.

    The noise is scaled by 1.5 times the standard deviation of the input
    itself. torch's global RNG is re-seeded with 718 on every call, so inputs
    of the same shape always receive the same noise pattern.

    Args:
        module: The hooked module (unused).
        _input: Tuple of the module's positional inputs; only ``_input[0]`` is
            modified.

    Returns:
        A tuple that replaces the module's positional inputs: the corrupted
        tensor followed by any remaining positional inputs, unchanged.
    """
    hidden = _input[0]
    torch.manual_seed(718)
    std = torch.std(hidden)
    # Sample with the default (CPU) generator so the noise values do not
    # depend on where the model lives, then match the input's own device and
    # dtype rather than relying on a module-level device setting.
    noise = torch.randn(hidden.shape).to(device=hidden.device, dtype=hidden.dtype)
    return (hidden + (std * 1.5) * noise,) + tuple(_input[1:])

# Corrupted-run model (v2 hooks): `corrupt_input_vector_v2` rewrites the input
# of the first block's attention module on every forward pass.
corrupted_model = load_model().eval()
first_attn = corrupted_model.get_submodule("transformer.h.0.attn")
first_attn.register_forward_pre_hook(corrupt_input_vector_v2)


# Third run: restored run def    
def restore_activation_v2(m_id):
    """Build a forward hook that replaces a module's output with ``clean_activations[m_id]``.

    The stored object is returned whole for both ``'attn'`` and ``'mlp'`` ids,
    which matches what ``save_clean_activation_v2`` records.

    Args:
        m_id: Dotted sub-module path used to look up the stored activation.

    Returns:
        A hook ``(module, _input, _output)`` returning the stored activation
        for ids ending in ``'attn'`` or ``'mlp'`` and ``None`` (output left
        unchanged) for any other id. The lookup happens on every call, before
        the id is checked.
    """
    def restore_activation_hook_v2(module, _input, _output):
        stored = clean_activations[m_id]
        return stored if m_id.endswith(('attn', 'mlp')) else None
    return restore_activation_hook_v2

# Restored-run model (v2 hooks): receives the same input corruption as the
# corrupted model. Restore hooks are attached later, one module at a time.
restored_model = load_model().eval()
first_attn = restored_model.get_submodule("transformer.h.0.attn")
first_attn.register_forward_pre_hook(corrupt_input_vector_v2)

<torch.utils.hooks.RemovableHandle at 0x14876b9255a0>

## Causal Tracing

In [9]:
# Resume support: reload previously saved results if the output file exists
# and work out where to continue from.
try:
    with open(output_file, 'r') as json_file:
        list_results = json.load(json_file)
except (OSError, ValueError):
    # No usable checkpoint (file missing/unreadable, or not valid JSON;
    # json.JSONDecodeError is a ValueError): start with one empty slot per
    # sentence.
    list_results = [{} for _ in range(len_sentences)]
    done_idx = 0
else:
    # Continue from one entry before the first unfilled slot. If every slot is
    # already filled, point past the end so nothing is recomputed; previously
    # `done_idx` was left unassigned in that case.
    done_idx = next(
        (i - 1 for i, d in enumerate(list_results) if len(d) == 0),
        len(list_results),
    )
print(done_idx)

2


In [10]:
# Manual override: discard the resume index computed from the checkpoint so
# processing starts again from the first sentence.
done_idx = 0

In [11]:
# Preview the three-row subset (rows 0, 1, 2) of the tokenized dataset.
tokenized_datasets.select([0, 1, 2])

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 3
})

In [12]:
torch.manual_seed(718)
set_seed(718)

# Causal tracing on a three-row debug subset. For each sentence:
#   1. clean run      - `tuned_model`; its hooks fill `clean_activations`
#   2. corrupted run  - `corrupted_model`
#   3. restored runs  - `restored_model`, once per traced module, with that
#                       module's output replaced by the stored clean activation
# Every stored "*_loss" value is exp(model loss).
try:
    for sentence_idx, data in enumerate(tokenized_datasets.select([0, 1, 2])):
        if done_idx > sentence_idx:
            continue
        # Periodic checkpoint of everything computed so far.
        if sentence_idx % 1000 == 0:
            print(f"sentence_idx: {sentence_idx}")
            with open(output_file, 'w') as json_file:
                json.dump(list_results, json_file)

        inputs = torch.tensor(data['input_ids']).to(device_id)

        # First run: clean run. `clean_activations` must be a module-level
        # name because the save/restore hooks look it up as a global.
        clean_activations = {}
        with torch.no_grad():
            clean_outputs = tuned_model(inputs, labels=inputs.clone())
            clean_loss = np.exp(clean_outputs.loss.item())

        # Second run: corrupted run
        with torch.no_grad():
            corrupted_outputs = corrupted_model(inputs, labels=inputs.clone())
            corrupted_loss = np.exp(corrupted_outputs.loss.item())

        # Third run: corrupted-with-restoration run, one traced module at a time
        restored_loss = {}
        with torch.no_grad():
            for m_id in list_trace_module_ids:
                hook = restored_model.get_submodule(m_id).register_forward_hook(restore_activation(m_id))
                try:
                    restored_outputs = restored_model(inputs, labels=inputs.clone())
                finally:
                    # Always detach the restore hook, even if the forward pass
                    # is interrupted, so it cannot leak into later runs.
                    hook.remove()
                restored_loss[m_id] = np.exp(restored_outputs.loss.item())

        list_results[sentence_idx]['clean_loss'] = clean_loss
        list_results[sentence_idx]['corrupted_loss'] = corrupted_loss
        list_results[sentence_idx]['restored_loss'] = restored_loss
finally:
    # Write the results once more when the loop ends or is interrupted;
    # otherwise everything after the last periodic checkpoint exists only in
    # memory.
    with open(output_file, 'w') as json_file:
        json.dump(list_results, json_file)

sentence_idx: 0


In [13]:
# Inspect the first three result entries.
list_results[:3]

[{'clean_loss': 63.91732102791386,
  'corrupted_loss': 132.8242889027916,
  'restored_loss': {'transformer.h.0.attn': 63.91732102791386,
   'transformer.h.0.mlp': 72.45838449838195,
   'transformer.h.1.attn': 108.48160917896826,
   'transformer.h.1.mlp': 98.25952101895331,
   'transformer.h.2.attn': 105.73393221716806,
   'transformer.h.2.mlp': 157.6101965283722,
   'transformer.h.3.attn': 114.146045922409,
   'transformer.h.3.mlp': 115.46617860310457,
   'transformer.h.4.attn': 100.84806083556465,
   'transformer.h.4.mlp': 95.89193990231547,
   'transformer.h.5.attn': 131.30281967082848,
   'transformer.h.5.mlp': 97.5344761810575,
   'transformer.h.6.attn': 102.20784520792095,
   'transformer.h.6.mlp': 109.59660862325092,
   'transformer.h.7.attn': 106.16173384104698,
   'transformer.h.7.mlp': 95.2186388402459,
   'transformer.h.8.attn': 112.70963296738027,
   'transformer.h.8.mlp': 108.70634142756694,
   'transformer.h.9.attn': 127.1640592432581,
   'transformer.h.9.mlp': 84.25118249

In [15]:
# Inspect the first three result entries.
list_results[:3]

[{'clean_loss': 63.91732102791386,
  'corrupted_loss': 132.8242889027916,
  'restored_loss': {'transformer.h.0.attn': 63.91732102791386,
   'transformer.h.0.mlp': 72.45838449838195,
   'transformer.h.1.attn': 108.48160917896826,
   'transformer.h.1.mlp': 98.25952101895331,
   'transformer.h.2.attn': 105.73393221716806,
   'transformer.h.2.mlp': 157.6101965283722,
   'transformer.h.3.attn': 114.146045922409,
   'transformer.h.3.mlp': 115.46617860310457,
   'transformer.h.4.attn': 100.84806083556465,
   'transformer.h.4.mlp': 95.89193990231547,
   'transformer.h.5.attn': 131.30281967082848,
   'transformer.h.5.mlp': 97.5344761810575,
   'transformer.h.6.attn': 102.20784520792095,
   'transformer.h.6.mlp': 109.59660862325092,
   'transformer.h.7.attn': 106.16173384104698,
   'transformer.h.7.mlp': 95.2186388402459,
   'transformer.h.8.attn': 112.70963296738027,
   'transformer.h.8.mlp': 108.70634142756694,
   'transformer.h.9.attn': 127.1640592432581,
   'transformer.h.9.mlp': 84.25118249

# TMOD Two models with one data

## Hook

## load models

In [19]:
# First run: tuned run def
def save_tuned_activation(m_id):
    """Build a forward hook that records a module's output in ``tuned_activations``.

    Args:
        m_id: Dotted sub-module path, used as the key in the module-level
            ``tuned_activations`` dict.

    Returns:
        A hook ``(module, _input, _output)`` that stores a detached copy of the
        output under ``m_id``: element 0 of the output for ids ending in
        ``'attn'``, the output itself for ids ending in ``'mlp'``. Other ids
        store nothing. The hook returns ``None``, so the module output is left
        unchanged.
    """
    is_attn = m_id.endswith('attn')
    is_mlp = m_id.endswith('mlp')

    def save_tuned_activation_hook(module, _input, _output):
        if is_attn:
            activation = _output[0]
        elif is_mlp:
            activation = _output
        else:
            return
        tuned_activations[m_id] = activation.detach()

    return save_tuned_activation_hook

# Tuned-run model: every traced sub-module gets a forward hook built by
# `save_tuned_activation`, so one forward pass records all traced outputs.
tuned_model = load_model().eval()
for m_id in list_trace_module_ids:
    traced_module = tuned_model.get_submodule(m_id)
    traced_module.register_forward_hook(save_tuned_activation(m_id))

    
# Second run: base run def    
# def corrupt_input_vector(module, _input):#, _output):
#     torch.manual_seed(718)
#     std = torch.std(_input[0])
#     return tuple([_input[0] + (std*1.5) * torch.randn(_input[0].shape).to(device_id), ])

# Base-run model: the stock "gpt2" checkpoint, with no hooks attached.
base_model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
base_model = base_model.to(device_id)
base_model.eval()
# base_model.get_submodule("transformer.h.0.attn").register_forward_pre_hook(corrupt_input_vector)    


# Third run: restored run def    
def restore_activation(m_id):
    """Build a forward hook that swaps a module's output for the stored tuned-model activation.

    The hook reads ``tuned_activations[m_id]`` (module-level dict filled during
    the tuned run) every time it fires.

    Args:
        m_id: Dotted sub-module path; ids ending in ``'attn'`` or ``'mlp'``
            are restored, any other id leaves the output untouched.

    Returns:
        A hook ``(module, _input, _output)`` whose return value replaces the
        module output (``None`` means "keep the original output").
    """
    def restore_activation_hook(module, _input, _output):
        tuned_activation = tuned_activations[m_id]
        if m_id.endswith('attn'):
            # Replace only element 0 and pass every other output element
            # through untouched, whatever it is (cache tuple, None, extra
            # outputs), instead of assuming a (key, value) pair at index 1.
            return (tuned_activation,) + tuple(_output[1:])
        elif m_id.endswith('mlp'):
            return tuned_activation
        return None
    return restore_activation_hook

# Restored-run model: a second copy of the stock "gpt2" checkpoint. Restore
# hooks are attached to it later, one traced module at a time.
restored_model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
restored_model = restored_model.to(device_id)
restored_model.eval()

# restored_model.get_submodule("transformer.h.0.attn").register_forward_pre_hook(corrupt_input_vector)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## Causal Tracing

In [28]:
# Resume support: reload previously saved results if the output file exists
# and work out where to continue from.
try:
    with open(output_file, 'r') as json_file:
        list_results = json.load(json_file)
except (OSError, ValueError):
    # No usable checkpoint (file missing/unreadable, or not valid JSON;
    # json.JSONDecodeError is a ValueError): start with one empty slot per
    # sentence.
    list_results = [{} for _ in range(len_sentences)]
    done_idx = 0
else:
    first_empty = next(
        (i for i, d in enumerate(list_results) if len(d) == 0),
        None,
    )
    if first_empty is None:
        # Every slot is already filled: point past the end so nothing is
        # recomputed. Previously `done_idx` was left unassigned here.
        done_idx = len(list_results)
    else:
        # Continue from one entry before the first unfilled slot.
        done_idx = first_empty - 1
        print(done_idx)

In [29]:
torch.manual_seed(718)
set_seed(718)

# Two-model tracing over the whole tokenized test split. For each sentence:
#   1. tuned run      - `tuned_model`; its hooks fill `tuned_activations`
#   2. base run       - `base_model`
#   3. restored runs  - `restored_model`, once per traced module, with that
#                       module's output replaced by the stored tuned activation
# Every stored "*_loss" value is exp(model loss).
try:
    for sentence_idx, data in enumerate(tokenized_datasets):
        if done_idx > sentence_idx:
            continue
        # Periodic checkpoint of everything computed so far.
        if sentence_idx % 1000 == 0:
            print(f"sentence_idx: {sentence_idx}")
            with open(output_file, 'w') as json_file:
                json.dump(list_results, json_file)

        inputs = torch.tensor(data['input_ids']).to(device_id)

        # First run: tuned run. `tuned_activations` must be a module-level
        # name because the save/restore hooks look it up as a global.
        tuned_activations = {}
        with torch.no_grad():
            tuned_outputs = tuned_model(inputs, labels=inputs.clone())
            tuned_loss = np.exp(tuned_outputs.loss.item())

        # Second run: base run
        with torch.no_grad():
            base_outputs = base_model(inputs, labels=inputs.clone())
            base_loss = np.exp(base_outputs.loss.item())

        # Third run: base-with-restoration run, one traced module at a time
        restored_loss = {}
        with torch.no_grad():
            for m_id in list_trace_module_ids:
                hook = restored_model.get_submodule(m_id).register_forward_hook(restore_activation(m_id))
                try:
                    restored_outputs = restored_model(inputs, labels=inputs.clone())
                finally:
                    # Always detach the restore hook, even if the forward pass
                    # is interrupted, so it cannot leak into later runs.
                    hook.remove()
                restored_loss[m_id] = np.exp(restored_outputs.loss.item())

        list_results[sentence_idx]['tuned_loss'] = tuned_loss
        list_results[sentence_idx]['base_loss'] = base_loss
        list_results[sentence_idx]['restored_loss'] = restored_loss
finally:
    # Write the results once more when the loop ends or is interrupted;
    # otherwise everything after the last periodic checkpoint exists only in
    # memory.
    with open(output_file, 'w') as json_file:
        json.dump(list_results, json_file)

sentence_idx: 0


KeyboardInterrupt: 