In [1]:
from matplotlib import pyplot as plt
import numpy as np
import yaml
import torch
from os.path import join
from transformers import AutoTokenizer
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from dataclasses import dataclass

import sys
sys.path.append('/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery')
from dmc.circuit_gpt import CircuitGPT, CircuitGPTConfig
from ioi_dataset import *
from ioi_utils import *

# Absolute cluster paths for this experiment -- edit these when running on another machine.
model_name = 'gpt2-small'
model_dir = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
data_dir = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'
# Learned mask logits are both written to and read from the same scratch directory.
results_dir = '/home/leiyu/scratch/circuit-discovery/mask_logits/'
mask_dir = results_dir


In [2]:

# Checkpoint selector, as encoded in the file name: IOI data split index,
# weight-mask epoch and edge-mask epoch (presumably training epochs -- confirm).
ds_idx, n_epoch_w, n_epoch_e = 0, 0, 66

# Load the edge-mask logits and keep detached CPU copies only. The loaded dict
# is a temporary inside the comprehension, so nothing else keeps its tensors
# alive and the CUDA cache can be released right after.
mask_logits_dict_edge = {
    name: logits.detach().cpu()
    for name, logits in torch.load(
        join(mask_dir, f'mask_logits_dict_edge_ioi_{ds_idx}_weight_{n_epoch_w}_edge_{n_epoch_e}.pt')
    ).items()
}
torch.cuda.empty_cache()

In [3]:
def _layer_source_slots(logits, layer_i, slots_per_layer):
    """Return a view of the rows of ``logits`` that belong to layer ``layer_i``'s source nodes.

    Row 0 is skipped (a source that precedes every layer -- presumably the
    embedding, TODO confirm), then each layer owns ``slots_per_layer``
    consecutive rows. Only basic slicing is used, so the result is a view and
    writes through it modify ``logits`` in place.
    """
    start = layer_i * slots_per_layer
    return logits[1:][start:start + slots_per_layer]


def head_ablation(mask_logits_dict_edge, layer_is, head_js,
                  n_layers=12, n_heads=12, ablation_logit=-1.):
    """Return a copy of the edge-mask logits with every outgoing edge of the given heads set low.

    For each ``(layer_is[n], head_js[n])`` pair, the logit of the edge from that
    head is overwritten with ``ablation_logit`` in:

    * the MLP edge mask of every block ``k`` with ``k >= layer``;
    * the attention q/k/v edge masks of every block ``k`` with ``k > layer``
      (all columns; presumably one column per target head of block ``k``);
    * the final output edge mask.

    Layout assumption (inferred from the indexing, TODO confirm against
    CircuitGPT): every edge-mask tensor is indexed by source node along its
    first dimension, with one leading source row followed by ``n_heads + 1``
    rows per layer, the layer's heads occupying the first ``n_heads`` of them.

    Args:
        mask_logits_dict_edge: mapping from mask name to edge-mask logit tensor.
            Not modified; every tensor is cloned.
        layer_is: iterable of layer indices (ints or integer scalar tensors).
        head_js: iterable of head indices, paired element-wise with ``layer_is``.
        n_layers: number of transformer blocks (default 12, as in GPT-2 small).
        n_heads: attention heads per block (default 12, as in GPT-2 small).
        ablation_logit: value written into the ablated edges' logits.

    Returns:
        A new dict with the same keys, holding the ablated clones.

    Raises:
        ValueError: if ``layer_is`` and ``head_js`` differ in length, or a
            (layer, head) pair is outside ``[0, n_layers) x [0, n_heads)``.
    """
    layer_is = [int(layer_i) for layer_i in layer_is]
    head_js = [int(head_j) for head_j in head_js]
    # zip() would silently drop unmatched entries; make the mismatch explicit.
    if len(layer_is) != len(head_js):
        raise ValueError(
            f'layer_is and head_js must have the same length, got {len(layer_is)} and {len(head_js)}'
        )
    for layer_i, head_j in zip(layer_is, head_js):
        # Negative or too-large indices would otherwise hit unrelated slots.
        if not (0 <= layer_i < n_layers and 0 <= head_j < n_heads):
            raise ValueError(
                f'head ({layer_i}, {head_j}) is out of range for {n_layers} layers x {n_heads} heads'
            )

    # One slot per head plus one extra slot per layer (presumably the MLP).
    slots_per_layer = n_heads + 1
    ablated_mask_logits_dict_edge = {k: torch.clone(v) for k, v in mask_logits_dict_edge.items()}
    for layer_i, head_j in zip(layer_is, head_js):
        for k in range(layer_i, n_layers):
            # MLP edges are cut for the head's own block and every later block.
            mlp_logits = ablated_mask_logits_dict_edge[f'blocks.{k}.edge_mask_mlp_logits']
            _layer_source_slots(mlp_logits, layer_i, slots_per_layer)[head_j] = ablation_logit
            # Attention edges are cut only for strictly later blocks.
            if k > layer_i:
                for proj in ('q', 'k', 'v'):
                    attn_logits = ablated_mask_logits_dict_edge[f'blocks.{k}.edge_mask_attention_{proj}_logits']
                    # Writing a scalar into the row sets it for every column at once.
                    _layer_source_slots(attn_logits, layer_i, slots_per_layer)[head_j] = ablation_logit
        output_logits = ablated_mask_logits_dict_edge['edge_mask_output_logits']
        _layer_source_slots(output_logits, layer_i, slots_per_layer)[head_j] = ablation_logit

    return ablated_mask_logits_dict_edge

In [4]:
# Load the IOI prompts for split `ds_idx` together with the outputs saved for the full
# model, and wrap them in a DataLoader for evaluation.
# Note: IOIDataset builds its prompt set by randomly sampling N items, so the train and
# test splits can overlap.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(join(data_dir, f'ioi_prompts_{ds_idx}.p'), 'rb') as ioi_prompts_file:
    ioi_prompts = pickle.load(ioi_prompts_file)
# These paths are relative to the notebook's working directory, unlike data_dir.
full_model_target_log_probs = torch.load(f'full_model_results/target_log_probs_{ds_idx}.pt')
full_model_pred_labels = torch.load(f'full_model_results/pred_labels_{ds_idx}.pt')

data_dict = prepare_ioi_data_for_clm(ioi_prompts, full_model_target_log_probs, full_model_pred_labels)
ds = IOICircuitDataset(data_dict)
dl = DataLoader(
    ds,
    batch_size=32
)

In [10]:
# Evaluate the edge-pruned circuit (edge masks from edge epoch `n_epoch_e`) after additionally
# ablating every attention head in the first `n_ablated_layers` layers.
# Note: this is NOT the unpruned model -- edge masks are applied during evaluation.
n_ablated_layers = 1
n_heads_per_layer = 12  # gpt2-small
# One (layer, head) pair per ablated head; e.g. for 2 layers:
#   layer_is = [0]*12 + [1]*12,  head_js = [0..11] + [0..11]
layer_is = torch.repeat_interleave(torch.arange(n_ablated_layers), n_heads_per_layer)
head_js = torch.arange(n_heads_per_layer).repeat(n_ablated_layers)

ds_idx, n_epoch_w, n_epoch_e = 0, 0, 66

# Reload from disk so that re-running this cell always ablates the original checkpoint;
# the global `mask_logits_dict_edge` is overwritten with the ablated copy below.
mask_logits_dict_edge = torch.load(
    join(mask_dir, f'mask_logits_dict_edge_ioi_{ds_idx}_weight_{n_epoch_w}_edge_{n_epoch_e}.pt')
)
mask_logits_dict_edge = {k: v.detach().cpu() for k, v in mask_logits_dict_edge.items()}
torch.cuda.empty_cache()
mask_logits_dict_edge = head_ablation(mask_logits_dict_edge, layer_is, head_js)

# circuit_gpt initialization
model_path = join(model_dir, model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token for batching

device = torch.device('cuda')
gpt_weights = torch.load(join(model_path, 'model_weights.pt'))
circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp_weight=1.0,
    gs_temp_edge=1.0,
    use_weight_masks=True
)
circuit_gpt = CircuitGPT(circuit_gpt_config)
circuit_gpt.load_pretrained_weight(gpt_weights)
# Only edge masks are loaded here. Learned weight masks (e.g. results_dir/
# mask_logits_dict_weight_ioi_0_weight_257_edge_0.pt) could be applied with
# circuit_gpt.load_pretrained_weight_mask, but the evaluation below uses use_weight_mask=False.
circuit_gpt.load_pretrained_edge_mask(mask_logits_dict_edge)
circuit_gpt.to(device);

# Despite its name, this holds results for the pruned, head-ablated circuit.
eval_results_full_model = eval_model(circuit_gpt, dl, tokenizer, device,
    use_weight_mask=False, use_edge_mask=True, reverse=False
)
print(
    f"Edge epoch {n_epoch_e}, {len(head_js)} heads ablated. "
    f"mean pruned model eval accuracy: {eval_results_full_model['acc']:.2f}, "
    f"mean eval kl-div: {eval_results_full_model['kl']:.4f}, "
    f"weight density: {eval_results_full_model['weight_density']:.4f}, "
    f"edge density: {eval_results_full_model['edge_density']:.4f}"
)

torch.cuda.empty_cache()

Epoch 0. mean pruned model eval accuracy: 0.94,mean eval kl-div: 0.1620,weight density: 1.0000,edge density: 0.0240


In [211]:
# Sanity check of the layer-index layout used for ablation: each layer index repeated once per head.
torch.arange(3).repeat_interleave(12)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [210]:
# Matching head indices: 0..11 tiled once per ablated layer.
torch.tile(torch.arange(12), (3,))

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,
         6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])