In [1]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from datasets import Dataset
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
import pandas as pd
import json
import numpy as np

import sys
sys.path.append('/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery')
from dmc.circuit_gpt import *
from oqa_dataset import *
from oqa_utils import *
from tqdm.auto import tqdm

In [2]:
@dataclass
class DiffMaskArgs:
    model_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
    data_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'
    results_dir: str = '/home/leiyu/scratch/circuit-discovery/mask_logits/'
    data_name: str = 'pararel_data_all.json'
    pararel_rel_ids: str = 'P36 P1376'
    model_name: str = 'gpt2-small'
    gs_temp_weight: float = 0.01
    gs_temp_edge: float = 1.0
    logits_w_init: float = 0.0
    logits_e_init: float = 0.0
    test_ratio: float = 0.2
    batch_size: int = 16
    train_epochs_weight: int = 1000
    train_epochs_edge: int = 50
    lr_weight: float = 0.1
    lr_edge: float = 0.1
    lambda_sparse_weight_init: float = 1.
    lambda_sparse_edge_init: float = 1.
    lambda_complete_weight_init: float = 1.
    lambda_complete_edge_init: float = 1.
    save_every: int = 5
    resume_epoch_w: int = 0
    resume_epoch_e: int = 0
    use_weight_masks: bool = False
    n_epoch_warmup_lambda_sparse: int = 20
    n_epoch_cooldown_lambda_sparse: int = 20
    max_times_lambda_sparse: float = 100.
    min_times_lambda_sparse: float = 0.01
    random_seed: int = 0

args = DiffMaskArgs()

In [4]:
# path that stores gpt-small weights and gpt tokenizer
model_path = join(args.model_dir, args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# load OQA data
ds_dict = pickle.load(open(join(args.data_dir, 'pararel_capital_ds_dict.p'), 'rb'))
full_model_target_log_probs = torch.load(f'full_model_results/target_log_probs.pt')
full_model_pred_labels = torch.load(f'full_model_results/pred_labels.pt')
capital_vocab_idx = torch.load(f'full_model_results/capital_vocab_idx.pt')
ds_dict['full_model_target_log_probs'] = full_model_target_log_probs
ds_dict['full_model_pred_labels'] = full_model_pred_labels

In [7]:
ds = OQACircuitDataset(ds_dict)
n_train = int(0.8*len(ds))
n_test = len(ds) - n_train
ds_train, ds_test = torch.utils.data.random_split(ds, [n_train, n_test])

train_dl = DataLoader(
    ds_train,
    batch_size=args.batch_size
)
eval_dl = DataLoader(
    ds_test,
    batch_size=args.batch_size,
    shuffle=False
)

In [8]:
ds_train[0]

{'prompt': 'The capital of Haakon County is',
 'label': 50,
 'full_model_target_log_probs': tensor([ -7.3468,  -8.0343,  -8.5534,  -7.9324,  -8.0947,  -5.9757,  -7.2039,
          -7.4835, -11.2138,  -7.4869,  -8.5819, -11.5449, -10.1456,  -6.6759,
         -10.7692,  -7.7960,  -6.8112,  -4.9924,  -7.5453,  -6.9843, -10.4148,
          -6.4187,  -8.2486,  -8.9324,  -8.5128,  -9.5808,  -7.5216,  -4.2317,
          -5.4716,  -8.9545,  -8.1433,  -6.6596,  -6.7450,  -6.4809,  -8.7846,
          -8.3562,  -5.6886,  -6.8378,  -9.1335,  -6.4357,  -6.2878,  -7.6560,
         -10.1810, -10.6336, -10.5368,  -4.8770,  -6.7504,  -2.9278,  -9.5715,
          -9.0878,  -9.5661,  -8.2746,  -4.8770,  -9.5494,  -7.3597,  -9.1926,
          -4.2638, -11.3434,  -8.7225,  -4.2638,  -6.4417,  -8.4388, -10.9134,
         -10.2426,  -4.2317,  -7.2723,  -7.9848, -11.0620,  -8.7507,  -8.4591,
          -9.7536,  -8.8494,  -7.1412,  -8.3157,  -8.1154, -11.4140,  -8.4836,
         -13.1069,  -7.5587,  -6.8248,  

In [27]:

model_path = join(args.model_dir, args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

train_ds, test_ds = prepare_pararel_data(args)

train_dl = DataLoader(
    train_ds,
    batch_size=args.batch_size
)

eval_dl = DataLoader(
    test_ds,
    batch_size=args.batch_size,
    shuffle=False
)


In [34]:
len(train_ds)

702

In [5]:
device = torch.device('cuda')

# # download gpt2-small weights from EasyTransformer and save it
# reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
# torch.save(reference_gpt2.state_dict(), join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'))

gpt_weights = torch.load(join(model_path, 'model_weights.pt')) 
circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp_weight=args.gs_temp_weight,
    gs_temp_edge=args.gs_temp_edge,
    use_weight_masks=False
)
circuit_gpt = CircuitGPT(circuit_gpt_config)
circuit_gpt.load_pretrained_weight(gpt_weights)

# load pretrained mask logits if necessary
if args.resume_epoch_w > 0:
    weight_mask_logits = torch.load(join(model_path), f'weight_mask_logits_{resume_epoch_w}.pt')
    circuit_gpt.load_pretrained_weight_mask(weight_mask_logits)
if args.resume_epoch_e > 0:
    edge_mask_logits = torch.load(join(model_path), f'edge_mask_logits_{resume_epoch_e}.pt')
    circuit_gpt.load_pretrained_edge_mask(edge_mask_logits)

circuit_gpt.to(device);

In [6]:
eval_acc, weight_density, edge_density = eval_model(
    circuit_gpt, eval_dl, tokenizer, device, 
    use_weight_mask=False, use_edge_mask=False
)

print(
    f"Epoch 0. mean pruned model accuracy: {eval_acc:.4f}," + 
    f"weight density: {weight_density:.4f}," + 
    f"edge density: {edge_density:.4f}"
)

  0%|          | 0/158 [00:00<?, ?it/s]

Epoch 0. mean pruned model accuracy: 0.0135,weight density: 1.0000,edge density: 1.0000


In [7]:
eval_acc, weight_density, edge_density = eval_model(
    circuit_gpt, train_dl, tokenizer, device, 
    use_weight_mask=False, use_edge_mask=False
)

print(
    f"Epoch 0. mean pruned model accuracy: {eval_acc:.4f}," + 
    f"weight density: {weight_density:.4f}," + 
    f"edge density: {edge_density:.4f}"
)

  0%|          | 0/629 [00:00<?, ?it/s]

Epoch 0. mean pruned model accuracy: 0.0174,weight density: 1.0000,edge density: 1.0000


In [8]:
0.0135 * len(ds_test) + 0.0174 * len(ds_train)

208.8438

In [7]:
# model = circuit_gpt
# model.eval()

# # get weight and edge density     
# model.turn_on_weight_masks(deterministic=True, reverse=False)  
# _, _, weight_density = model.get_weight_density()       
# model.turn_on_edge_masks(deterministic=True, reverse=False)  
# _, _, edge_density = model.get_edge_density()


# model.turn_off_weight_masks()    
# model.turn_off_edge_masks()

# total = len(eval_dl.dataset)
# correct = 0

# _, batch = next(enumerate(eval_dl))

# batch_inputs = prepare_batch_inputs(batch, tokenizer)
# batch_logits = model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
# _, batch_preds = compute_faith_loss(batch_logits, batch_inputs)
# # print(batch_logits_gb)
# correct += (batch_preds == batch_inputs['labels']).sum().cpu().item()

# torch.cuda.empty_cache()

# model.turn_off_weight_masks()
# model.turn_off_edge_masks()

In [8]:
# weight_logits = [mask for _, mask in circuit_gpt.mask_logits_dict_weight.items()]
edge_logits = [mask for _, mask in circuit_gpt.mask_logits_dict_edge.items()]

# optim_weight = torch.optim.AdamW(weight_logits, lr=args.lr_weight)
optim_edge = torch.optim.AdamW(edge_logits, lr=args.lr_edge)

In [9]:
def get_lambda_sparse(epoch, lambda_0, max_times=100., min_times=0.001, 
                      n_epoch_warmup=10, n_epoch_cooldown=10):

    if epoch < n_epoch_warmup:
        return lambda_0  + lambda_0 * (max_times - 1) * epoch / n_epoch_warmup
        
    elif epoch < n_epoch_warmup + n_epoch_cooldown:
        return lambda_0 * max_times - lambda_0 * (max_times - min_times) * (epoch - n_epoch_warmup) / n_epoch_cooldown
        
    else:
        return lambda_0 * min_times
        
# it takes about 35 mins to run 100 epochs of edge mask training on one A100 GPU with batch_size=32
for epoch in tqdm(range(args.train_epochs_edge)):
    lambda_sparse_edge = get_lambda_sparse(
        epoch, 
        lambda_0=args.lambda_sparse_edge_init,
        max_times=args.max_times_lambda_sparse, 
        min_times=args.min_times_lambda_sparse, 
        n_epoch_warmup=args.n_epoch_warmup_lambda_sparse,
        n_epoch_cooldown=args.n_epoch_cooldown_lambda_sparse,
    )
    lambda_complete_edge = args.lambda_complete_edge_init
    
    for batch in tqdm(train_dl):
        batch_inputs = prepare_batch_inputs(batch, tokenizer)
        
        circuit_gpt.turn_on_edge_masks(deterministic=False)
        sparse_loss_edge = circuit_gpt.edge_sparseness_loss()
    
        batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0] 
        faith_loss_edge, _ = compute_faith_loss(batch_logits, batch_inputs) 
        
        circuit_gpt.turn_on_edge_masks(deterministic=False, reverse=True)
        batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0] 
        complete_loss_edge, _ = compute_complete_loss(batch_logits, batch_inputs) 
        
        loss_edge = faith_loss_edge + sparse_loss_edge *  lambda_sparse_edge + complete_loss_edge * lambda_complete_edge
        loss_edge.backward()
        optim_edge.step()
        optim_edge.zero_grad()
        torch.cuda.empty_cache()

    eval_acc_pruned, weight_density, edge_density = eval_model(
        circuit_gpt, eval_dl, tokenizer, device, use_weight_mask=False, reverse=False
    )
    eval_acc_complement, _, _ = eval_model(
        circuit_gpt, eval_dl, tokenizer, device, use_weight_mask=False, reverse=True
    )
    print(
        "Epoch {}. discovered circuit accuracy: {:.4f}, complementary circuit accuracy: {:.4f}, weight density: {:.4f}, edge density: {:.4f}".format(
            epoch + 1, eval_acc_pruned, eval_acc_complement, weight_density, edge_density)
    )

    # save good edge masks
    if eval_acc_pruned > 0.95 and edge_density < 0.05:
        torch.save(
            circuit_gpt.mask_logits_dict_edge,
            join(args.results_dir, f'mask_logits_dict_edge_oqa_edge_only_{epoch}.pt')
        )

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/158 [00:00<?, ?it/s]

  0%|          | 0/158 [00:00<?, ?it/s]

Epoch 1. discovered circuit accuracy: 0.1046, complementary circuit accuracy: 0.0000, weight density: 1.0000, edge density: 0.5049


KeyboardInterrupt: 

In [12]:
# circuit_gpt initialization
device = torch.device('cuda')    
gpt_weights = torch.load(join(model_path, 'model_weights.pt')) 
circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp_weight=args.gs_temp_weight,
    gs_temp_edge=args.gs_temp_edge,
    use_weight_masks=True,
    logits_w_init=args.logits_w_init
    
)
circuit_gpt = CircuitGPT(circuit_gpt_config)
circuit_gpt.load_pretrained_weight(gpt_weights)

In [9]:
circuit_gpt

CircuitGPT(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (ln_final): LayerNorm()
  (unembed): Unembed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
)

In [14]:
circuit_gpt.mask_logits_dict_weight.keys()

dict_keys(['blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.mlp.W_in', 'blocks.3.mlp.b_in', 'blocks.3.mlp.W_o

In [15]:
n_attn_weight = 0
n_weight = 0

with torch.no_grad():
    for key, mask in circuit_gpt.mask_logits_dict_weight.items():
        n_weight += torch.ones_like(mask).sum()
        if 'attn' in key:
            n_attn_weight += torch.ones_like(mask).sum()

In [16]:
n_attn_weight

tensor(28348416., device='cuda:0')

In [18]:
n_attn_weight / n_weight

tensor(0.3334, device='cuda:0')