In [2]:
import torch
import numpy as np
import random
import json
import os
from tqdm import tqdm

In [3]:
# Run configuration: which trained run to analyse, which training steps, and where outputs go.
MODEL_DIR = "/mnt/nas/jinho/GrokkedTransformer/trained_checkpoints/composition.2000.200.inf-controlled_wd-0.1_layer-8_head-12_seed-42"
STEP_LIST = [250, 3500, 5000, 300000]
BASE_DIR = "/home/jinho/repos/GrokkedTransformer"

# Dataset name = run-directory name up to its first "_". If MODEL_DIR ends with "/",
# the last path component is empty, so the component before it is used instead.
dataset = MODEL_DIR.split("/")[-2 if MODEL_DIR.split("/")[-1] == "" else -1].split("_")[0]

In [4]:
import logging
def setup_logging(debug_mode):
    """Configure root logging at DEBUG (``debug_mode`` truthy) or INFO level.

    Delegates to ``logging.basicConfig``, which is a no-op when the root logger
    already has handlers.
    """
    chosen_level = {True: logging.DEBUG, False: logging.INFO}[bool(debug_mode)]
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=chosen_level)

setup_logging(False)

In [5]:
def _checkpoint_step(name):
    """Return the integer step of a ``checkpoint-<step>`` name, or None if the suffix is not numeric."""
    suffix = name.rsplit("-", 1)[-1]
    return int(suffix) if suffix.isdigit() else None

# Checkpoint directories whose step is in STEP_LIST, ordered by training step.
# Filter and sort share one parser, so names like "checkpoint-best" are skipped instead of crashing.
all_checkpoints = sorted(
    (name for name in os.listdir(MODEL_DIR)
     if name.startswith("checkpoint") and _checkpoint_step(name) in STEP_LIST),
    key=_checkpoint_step,
)
assert all(os.path.isdir(os.path.join(MODEL_DIR, name)) for name in all_checkpoints)

# Surface requested steps that have no checkpoint on disk instead of dropping them silently.
_missing_steps = sorted(set(STEP_LIST) - {_checkpoint_step(name) for name in all_checkpoints})
if _missing_steps:
    logging.warning("No checkpoint found in %s for steps: %s", MODEL_DIR, _missing_steps)

all_checkpoints

['checkpoint-250', 'checkpoint-3500', 'checkpoint-5000', 'checkpoint-300000']

In [6]:
import torch.nn.functional as F
def return_rank(hd, word_embedding_, token_ids_list, metric='dot', token_list=None):
    """Rank of each batch item's target token among all vocabulary logits, per position.

    Args:
        hd: hidden states; indexing below assumes shape (batch, seq, hidden).
        word_embedding_: unembedding matrix of shape (vocab, hidden).
        token_ids_list: exactly one target id per batch item, e.g. ``[[id0], [id1]]``
            as returned by a tokenizer's ``input_ids`` for single-token strings.
        metric: 'dot' scores with raw dot products. 'cos' L2-normalizes the
            embedding rows first. ``hd`` is not normalized, but scaling a row
            by a positive constant does not change ranks along the vocab axis.
        token_list: unused; kept for backward compatibility.

    Returns:
        CPU int64 tensor of shape (batch, seq). Each entry is the 0-based position
        of the target id in the descending sort of that position's logits
        (0 = top-1). Ties are ordered however ``torch.sort`` orders them.

    Raises:
        ValueError: unknown ``metric``, or ``token_ids_list`` does not hold
            exactly one id per batch item.
    """
    if metric == 'dot':
        word_embedding = word_embedding_
    elif metric == 'cos':
        word_embedding = F.normalize(word_embedding_, p=2, dim=1)
    else:
        # Explicit raise: an `assert` here would vanish under `python -O`.
        raise ValueError(f"unknown metric {metric!r}; expected 'dot' or 'cos'")

    logits = torch.matmul(hd, word_embedding.T)  # (batch, seq, vocab)
    batch_size, seq_len = logits.shape[0], logits.shape[1]

    target_ids = torch.as_tensor(token_ids_list, device=logits.device)
    if target_ids.numel() != batch_size:
        raise ValueError(
            f"expected one target id per batch item ({batch_size}), got shape {tuple(target_ids.shape)}; "
            "does each target tokenize to a single id?"
        )
    # (batch, 1, 1) broadcasts against (batch, seq, vocab); no need to materialize an expand.
    target_ids = target_ids.view(batch_size, 1, 1)

    _, sorted_indices = logits.sort(dim=-1, descending=True)
    # Each (batch, pos) row of sorted_indices is a permutation of the vocab, so it has
    # exactly one match. The match's column index is the rank, emitted in row-major order.
    rank = (sorted_indices == target_ids).nonzero(as_tuple=True)[-1].view(batch_size, seq_len).cpu()
    return rank

In [7]:
def _hidden_states(model, tokenizer, texts, device):
    """Run ``model`` on ``texts`` without gradients and return its ``hidden_states`` tuple."""
    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
            output_hidden_states=True
        )
    return outputs['hidden_states']


def _forward_from_layer(model, hidden_states, start_layer):
    """Re-run GPT-2 blocks ``start_layer`` .. last on a residual stream, then apply ``ln_f``.

    Each block is pre-LN attention + residual, then pre-LN MLP + residual.
    No attention mask is passed, so this assumes a single unpadded sequence.
    """
    with torch.no_grad():
        for block in model.transformer.h[start_layer:]:
            attn_output = block.attn(block.ln_1(hidden_states))[0]
            hidden_states = attn_output + hidden_states
            mlp_hidden = block.mlp.c_fc(block.ln_2(hidden_states))
            hidden_states = hidden_states + block.mlp.c_proj(block.mlp.act(mlp_hidden))
        return model.transformer.ln_f(hidden_states)


def intervene_and_measure(original_data, intervene_data, model, tokenizer, device, random=False, count_layer=5):
    """Patch the first-relation (r1) token's residual stream and measure target-rank changes.

    For every entry of ``original_data`` whose bridge entity also appears in
    ``intervene_data``, this does four things:
      1. It picks a random ``intervene_data`` entry with the same bridge entity but
         a different ``(h, r1)`` prefix. If there is none, the example is skipped
         and counted in a log message.
      2. It runs the model on the original input and records the rank of the target
         token (4th field of ``target_text``) in the final hidden state
         (``rank_before``).
      3. It runs the model on the counterfactual 2-token prefix ``<h'><r1'>``.
      4. For each layer ``l`` in ``1 .. n_layer-1``, it overwrites position 1 of the
         original run's ``hidden_states[l]`` with the counterfactual one. It then
         re-runs blocks ``l ..`` plus ``ln_f`` and stores the new rank under
         ``f"r1_{l}"``.
    HF convention assumed: ``hidden_states[l]`` is the input to block ``l``, and
    ``hidden_states[-1]`` already includes ``ln_f``. This matches the manual re-run.

    Args:
        original_data: dict bridge entity -> list of entries with keys
            'input_text', 'target_text', 'identified_target'.
        intervene_data: dict bridge entity -> list of entries with 'input_text'.
        model: GPT-2 style LM exposing ``transformer.h``, ``transformer.ln_f``, ``lm_head``.
        tokenizer: tokenizer in which each ``<token>`` is a single id.
        device: device for the input tensors.
        random: unused; kept for backward compatibility.
        count_layer: patched layer whose rank is compared with ``rank_before``
            to build the returned change count (default 5, as before).

    Returns:
        (results, value). ``results`` holds one dict per processed example with
        'rank_before' and 'r1_<l>' keys (int ranks). ``value`` is how many
        examples' rank at ``count_layer`` differs from ``rank_before``.
    """
    import logging

    word_embedding = model.lm_head.weight.data
    n_layer = len(model.transformer.h)
    # Token position of r1 in "<h><r1>...". Inputs are run one at a time, so no padding shifts it.
    r1_pos = 1

    results = []
    value = 0
    skipped_data = 0

    for bridge_entity, entries in tqdm(original_data.items()):
        if bridge_entity not in intervene_data:
            continue  # no counterfactual source shares this bridge entity
        candidates = intervene_data[bridge_entity]
        for entry in entries:
            assert entry['identified_target'] == bridge_entity
            real_tokens = entry['target_text'].strip("><").split("><")
            assert len(real_tokens) == 5
            target_ids = tokenizer([f"<{real_tokens[3]}>"])["input_ids"]

            # Counterfactual source: same bridge entity, different (h, r1) prefix.
            intervene_entries = [
                candidate for candidate in candidates
                if candidate["input_text"].strip("><").split("><")[:2] != real_tokens[:2]
            ]
            if not intervene_entries:
                skipped_data += 1
                continue
            selected = intervene_entries[np.random.randint(0, len(intervene_entries))]
            query = ''.join(f"<{token}>" for token in selected["input_text"].strip("><").split("><")[:2])

            all_hidden_states = _hidden_states(model, tokenizer, [entry['input_text']], device)
            all_hidden_states_ctft = _hidden_states(model, tokenizer, [query], device)

            temp_dict = {
                'rank_before': return_rank(all_hidden_states[-1], word_embedding, target_ids)[:, -1].tolist()
            }
            for layer_to_intervene in range(1, n_layer):
                hidden_states = all_hidden_states[layer_to_intervene].clone()
                hidden_states[:, r1_pos, :] = all_hidden_states_ctft[layer_to_intervene][:, r1_pos, :]
                final_hidden = _forward_from_layer(model, hidden_states, layer_to_intervene)
                temp_dict[f'r1_{layer_to_intervene}'] = return_rank(
                    final_hidden, word_embedding, target_ids
                )[:, -1].tolist()

            if temp_dict[f"r1_{count_layer}"] != temp_dict['rank_before']:
                value += 1
            # Batch size is 1, so unwrap each single-element rank list.
            results.append({key: ranks[0] for key, ranks in temp_dict.items()})

    if skipped_data:
        logging.info("intervene_and_measure: skipped %d examples with no counterfactual candidate", skipped_data)
    return results, value

In [8]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Seeds the global NumPy RNG used to pick counterfactual examples.
np.random.seed(0)
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

def get_total_list_length(json_data):
    """Return the total number of items across all list-valued entries of ``json_data``.

    Non-list values are ignored. An empty mapping gives 0.
    """
    return sum(len(value) for value in json_data.values() if isinstance(value, list))

# Root of the precomputed, deduplicated example splits for this dataset.
# "(5,1)" is a hard-coded analysis sub-directory.
DEDUP_ROOT = os.path.join("/mnt/nas/jinho/GrokkedTransformer/collapse_analysis", dataset, "(5,1)")
SPLIT_NAMES = ("id_train_dedup", "id_test_dedup", "ood_dedup", "nonsense_dedup")
# Each patching run: (result key, printed label, source split, counterfactual split).
COMPARISONS = [
    ("train_inferred-test_inferred_id", "train_inferred <- test_inferred_id", "id_train_dedup", "id_test_dedup"),
    ("train_inferred-test_inferred_ood", "train_inferred <- test_inferred_ood", "id_train_dedup", "ood_dedup"),
    ("test_inferred_id-test_inferred_ood", "test_inferred_id <- test_inferred_ood", "id_test_dedup", "ood_dedup"),
    ("train_inferred", "train_inferred <- train_inferred", "id_train_dedup", "id_train_dedup"),
    ("test_inferred_id", "test_inferred_id <- test_inferred_id", "id_test_dedup", "id_test_dedup"),
    ("test_inferred_ood", "test_inferred_ood <- test_inferred_ood", "ood_dedup", "ood_dedup"),
]

# checkpoint name -> comparison key -> list of per-example rank dicts
results = dict()
for checkpoint in all_checkpoints:
    print("\nnow checkpoint", checkpoint)
    step = checkpoint.split("-")[-1]
    model_path = os.path.join(MODEL_DIR, checkpoint)
    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()

    # Load the already deduplicated splits for this step ("nonsense" is only counted).
    splits = {}
    for split_name in SPLIT_NAMES:
        with open(os.path.join(DEDUP_ROOT, step, f"{split_name}.json")) as f:
            splits[split_name] = json.load(f)
        print(f"Total data number of {split_name}.json : {get_total_list_length(splits[split_name])}")
    id_train_dedup, id_test_dedup, ood_dedup, nonsense_dedup = (splits[name] for name in SPLIT_NAMES)

    result_ckpt = {}
    for result_key, label, source_split, ctft_split in COMPARISONS:
        split_results, value = intervene_and_measure(splits[source_split], splits[ctft_split], model, tokenizer, device)
        print(f"{label} : {value}")
        result_ckpt[result_key] = split_results
    results[checkpoint] = result_ckpt

tracing_dir = os.path.join(BASE_DIR, "collapse_analysis", "tracing_results")
os.makedirs(tracing_dir, exist_ok=True)  # avoid losing hours of compute to a missing directory
with open(os.path.join(tracing_dir, f"{MODEL_DIR.split('/')[-1]}.json"), "w", encoding='utf-8') as f:
    json.dump(results, f, indent=4)


now checkpoint checkpoint-250


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Total data number of id_train_dedup.json : 2895
Total data number of id_test_dedup.json : 2889
Total data number of ood_dedup.json : 1322
Total data number of nonsense_dedup.json : 3000


100%|██████████| 1536/1536 [05:42<00:00,  4.49it/s]


train_inferred <- test_inferred_id : 1449


100%|██████████| 1536/1536 [02:56<00:00,  8.68it/s]


train_inferred <- test_inferred_ood : 770


100%|██████████| 1510/1510 [03:03<00:00,  8.24it/s]


test_inferred_id <- test_inferred_ood : 806


100%|██████████| 1536/1536 [06:01<00:00,  4.25it/s]


train_inferred <- train_inferred : 1430


100%|██████████| 1510/1510 [06:09<00:00,  4.09it/s]


test_inferred_id <- test_inferred_id : 1548


100%|██████████| 824/824 [02:27<00:00,  5.58it/s]


test_inferred_ood <- test_inferred_ood : 779

now checkpoint checkpoint-3500


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Total data number of id_train_dedup.json : 2895
Total data number of id_test_dedup.json : 2889
Total data number of ood_dedup.json : 1322
Total data number of nonsense_dedup.json : 3000


100%|██████████| 1536/1536 [05:40<00:00,  4.50it/s]


train_inferred <- test_inferred_id : 0


100%|██████████| 1536/1536 [03:00<00:00,  8.49it/s]


train_inferred <- test_inferred_ood : 1


100%|██████████| 1510/1510 [03:01<00:00,  8.34it/s]


test_inferred_id <- test_inferred_ood : 69


100%|██████████| 1536/1536 [05:40<00:00,  4.51it/s]


train_inferred <- train_inferred : 0


100%|██████████| 1510/1510 [06:04<00:00,  4.14it/s]


test_inferred_id <- test_inferred_id : 54


100%|██████████| 824/824 [02:24<00:00,  5.70it/s]


test_inferred_ood <- test_inferred_ood : 851

now checkpoint checkpoint-5000


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Total data number of id_train_dedup.json : 2895
Total data number of id_test_dedup.json : 2889
Total data number of ood_dedup.json : 1322
Total data number of nonsense_dedup.json : 3000


100%|██████████| 1536/1536 [05:38<00:00,  4.53it/s]


train_inferred <- test_inferred_id : 1


100%|██████████| 1536/1536 [03:02<00:00,  8.42it/s]


train_inferred <- test_inferred_ood : 3


100%|██████████| 1510/1510 [03:00<00:00,  8.38it/s]


test_inferred_id <- test_inferred_ood : 43


100%|██████████| 1536/1536 [05:56<00:00,  4.31it/s]


train_inferred <- train_inferred : 0


100%|██████████| 1510/1510 [06:02<00:00,  4.16it/s]


test_inferred_id <- test_inferred_id : 28


100%|██████████| 824/824 [02:25<00:00,  5.68it/s]


test_inferred_ood <- test_inferred_ood : 852

now checkpoint checkpoint-300000


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Total data number of id_train_dedup.json : 2895
Total data number of id_test_dedup.json : 2889
Total data number of ood_dedup.json : 1322
Total data number of nonsense_dedup.json : 3000


100%|██████████| 1536/1536 [05:43<00:00,  4.48it/s]


train_inferred <- test_inferred_id : 2


100%|██████████| 1536/1536 [03:01<00:00,  8.47it/s]


train_inferred <- test_inferred_ood : 611


100%|██████████| 1510/1510 [03:00<00:00,  8.36it/s]


test_inferred_id <- test_inferred_ood : 646


100%|██████████| 1536/1536 [06:02<00:00,  4.24it/s]


train_inferred <- train_inferred : 1


100%|██████████| 1510/1510 [06:03<00:00,  4.16it/s]


test_inferred_id <- test_inferred_id : 2


100%|██████████| 824/824 [02:25<00:00,  5.67it/s]


test_inferred_ood <- test_inferred_ood : 851


In [9]:
# Summarize the saved per-example ranks. For each checkpoint and comparison, count how
# many examples' target rank changed when r1 was patched at each layer.
with open(os.path.join(BASE_DIR, "collapse_analysis",  "tracing_results", f"{MODEL_DIR.split('/')[-1]}.json")) as f:
    results = json.load(f)

# Patched layers whose "r1_<l>" keys are summarized (1..7 for the 8-layer model).
INTERVENED_LAYERS = range(1, 8)

refined_results = {}
all_checkpoints = sorted(results, key=lambda name: int(name.split("-")[1]))
# Loop variable is `checkpoint` so the print shows the current checkpoint, not a stale
# value left over from an earlier cell.
for checkpoint in all_checkpoints:
    print("\nnow checkpoint", checkpoint)
    results_4_ckpt = {}
    for key, entries in tqdm(results[checkpoint].items()):
        result_4_type = {"total_num": len(entries)}
        for i in INTERVENED_LAYERS:
            result_4_type[f"r1_{i}"] = sum(entry["rank_before"] != entry[f"r1_{i}"] for entry in entries)
        results_4_ckpt[key] = result_4_type
    refined_results[checkpoint] = results_4_ckpt

with open(os.path.join(BASE_DIR, "collapse_analysis", "tracing_results", f"{MODEL_DIR.split('/')[-1]}_refined-result.json"), "w", encoding='utf-8') as f:
    json.dump(refined_results, f, indent=4)


now checkpoint checkpoint-300000


100%|██████████| 6/6 [00:00<00:00, 249.27it/s]



now checkpoint checkpoint-300000


100%|██████████| 6/6 [00:00<00:00, 338.81it/s]



now checkpoint checkpoint-300000


100%|██████████| 6/6 [00:00<00:00, 337.72it/s]



now checkpoint checkpoint-300000


100%|██████████| 6/6 [00:00<00:00, 329.39it/s]

{'checkpoint-250': {'train_inferred-test_inferred_id': {'total_num': 2199, 'r1_1': 1773, 'r1_2': 1683, 'r1_3': 1605, 'r1_4': 1515, 'r1_5': 1449, 'r1_6': 1298, 'r1_7': 1120}, 'train_inferred-test_inferred_ood': {'total_num': 1160, 'r1_1': 934, 'r1_2': 878, 'r1_3': 843, 'r1_4': 806, 'r1_5': 770, 'r1_6': 727, 'r1_7': 638}, 'test_inferred_id-test_inferred_ood': {'total_num': 1173, 'r1_1': 952, 'r1_2': 927, 'r1_3': 875, 'r1_4': 831, 'r1_5': 806, 'r1_6': 758, 'r1_7': 661}, 'train_inferred': {'total_num': 2198, 'r1_1': 1736, 'r1_2': 1667, 'r1_3': 1561, 'r1_4': 1501, 'r1_5': 1430, 'r1_6': 1306, 'r1_7': 1125}, 'test_inferred_id': {'total_num': 2223, 'r1_1': 1814, 'r1_2': 1769, 'r1_3': 1685, 'r1_4': 1615, 'r1_5': 1548, 'r1_6': 1425, 'r1_7': 1264}, 'test_inferred_ood': {'total_num': 853, 'r1_1': 827, 'r1_2': 808, 'r1_3': 795, 'r1_4': 786, 'r1_5': 779, 'r1_6': 739, 'r1_7': 724}}, 'checkpoint-3500': {'train_inferred-test_inferred_id': {'total_num': 2199, 'r1_1': 32, 'r1_2': 7, 'r1_3': 1, 'r1_4': 1,


