In [1]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from dataclasses import dataclass

from circuit_gpt import CircuitGPT, CircuitGPTConfig
from ioi_dataset import IOIDataset, CircuitIOIDataset, prepare_batch_inputs, prepare_ioi_data_for_clm

In [2]:
@dataclass
class DiffMaskArgs:
    model_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
    data_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'
    results_dir: str = '/home/leiyu/scratch/circuit-discovery/mask_logits/'
    model_name: str = 'gpt2-small'
    gs_temp_weight: float = 0.01
    gs_temp_edge: float = 1.0
    logits_w_init: float = 0.0
    logits_e_init: float = 0.0
    batch_size: int = 32
    train_epochs_weight: int = 1000
    train_epochs_edge: int = 50
    lr_weight: float = 0.1
    lr_edge: float = 0.1
    lambda_sparse_weight_init: float = 1.
    lambda_sparse_edge_init: float = 1.
    lambda_complete_weight_init: float = 1.
    lambda_complete_edge_init: float = 1.
    save_every: int = 5
    resume_epoch_w: int = 0
    resume_epoch_e: int = 0
    use_weight_masks: bool = False
    n_epoch_warmup_lambda_sparse: int = 20
    n_epoch_cooldown_lambda_sparse: int = 20
    max_times_lambda_sparse: float = 100.
    min_times_lambda_sparse: float = 0.01
    n_ioi_data: int = 640


def compute_faith_loss(batch_logits, batch_inputs):
    # batch_logits: (B, seq_len, vocab_size)
    batch_seq_lens = batch_inputs['seq_lens']
    batch_size = batch_logits.shape[0]

    logits_target_good = batch_logits[torch.arange(batch_size), batch_seq_lens - 1, batch_inputs['target good']]
    logits_target_bad = batch_logits[torch.arange(batch_size), batch_seq_lens - 1, batch_inputs['target bad']]
    logits_gb = torch.stack([logits_target_good, logits_target_bad], -1)  # (B,2)

    batch_labels = torch.zeros(batch_size).long().to(logits_gb.device)
    batch_faith_loss = nn.functional.cross_entropy(logits_gb, batch_labels)

    return batch_faith_loss, logits_gb


def compute_complete_loss(batch_logits, batch_inputs):
    # batch_logits: (B, seq_len, vocab_size)
    batch_seq_lens = batch_inputs['seq_lens']
    batch_size = batch_logits.shape[0]

    logits_target_good = batch_logits[torch.arange(batch_size), batch_seq_lens - 1, batch_inputs['target good']]
    logits_target_bad = batch_logits[torch.arange(batch_size), batch_seq_lens - 1, batch_inputs['target bad']]
    logits_gb = torch.stack([logits_target_good, logits_target_bad], -1)  # (B,2)

    batch_probs_uniform = torch.ones(logits_gb.shape).to(logits_gb.device) * 0.5
    batch_complete_loss = nn.functional.cross_entropy(logits_gb, batch_probs_uniform)

    return batch_complete_loss, logits_gb


@torch.no_grad()
def eval_model(model, eval_dl, tokenizer, device, use_weight_mask=True, use_edge_mask=True, reverse=False):
    model.eval()

    # get weight and edge density     
    model.turn_on_weight_masks(deterministic=True, reverse=reverse)  
    _, _, weight_density = model.get_weight_density()       
    model.turn_on_edge_masks(deterministic=True, reverse=reverse)  
    _, _, edge_density = model.get_edge_density()

    if not use_weight_mask:
        model.turn_off_weight_masks()
    if not use_edge_mask:     
        model.turn_off_edge_masks()

    total = len(eval_dl.dataset)
    correct = 0

    for batch in eval_dl:
        batch_inputs = prepare_batch_inputs(batch, tokenizer)
        batch_logits = model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
        _, batch_logits_gb = compute_faith_loss(batch_logits, batch_inputs)
        # print(batch_logits_gb)
        correct += (batch_logits_gb[:, 0] > batch_logits_gb[:, 1]).sum().cpu().item()

        torch.cuda.empty_cache()

    model.turn_off_weight_masks()
    model.turn_off_edge_masks()

    acc = correct / total

    return acc, weight_density, edge_density
     

In [3]:
# with open('diff_mask_ioi.yml') as f:
#     args = yaml.safe_load(f)
args = DiffMaskArgs()
model_path = join(args.model_dir, args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

ds = IOIDataset(prompt_type="ABBA", N=args.n_ioi_data, tokenizer=tokenizer)
ds_train, ds_test = train_test_split(ds.ioi_prompts, test_size=0.2, random_state=0)
# Note that there are overlaps between train and test sets, due to the way IOIDataset is constructed (randomly sample N items)

ioi_ds_train = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_train))
ioi_ds_test = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_test))

train_dl = DataLoader(
    ioi_ds_train,
    batch_size=args.batch_size
)
eval_dl = DataLoader(
    ioi_ds_train,
    batch_size=args.batch_size,
    shuffle=False
)

In [4]:
device = torch.device('cuda')

# # download gpt2-small weights from EasyTransformer and save it
# reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
# torch.save(reference_gpt2.state_dict(), join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'))

gpt_weights = torch.load(join(model_path, 'model_weights.pt')) 
circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp_weight=args.gs_temp_weight,
    gs_temp_edge=args.gs_temp_edge,
    use_weight_masks=False
)
circuit_gpt = CircuitGPT(circuit_gpt_config)
circuit_gpt.load_pretrained_weight(gpt_weights)

# load pretrained mask logits if necessary
if args.resume_epoch_w > 0:
    weight_mask_logits = torch.load(join(model_path), f'weight_mask_logits_{resume_epoch_w}.pt')
    circuit_gpt.load_pretrained_weight_mask(weight_mask_logits)
if args.resume_epoch_e > 0:
    edge_mask_logits = torch.load(join(model_path), f'edge_mask_logits_{resume_epoch_e}.pt')
    circuit_gpt.load_pretrained_edge_mask(edge_mask_logits)

circuit_gpt.to(device);

In [5]:
eval_acc, weight_density, edge_density = eval_model(
    circuit_gpt, eval_dl, tokenizer, device, 
    use_weight_mask=False, use_edge_mask=False
)

print(
    f"Epoch 0. mean pruned model accuracy: {eval_acc:.4f}," + 
    f"weight density: {weight_density:.4f}," + 
    f"edge density: {edge_density:.4f}"
)

Epoch 0. mean pruned model accuracy: 1.00,weight density: 1.0000,edge density: 1.0000


In [6]:
# weight_logits = [mask for _, mask in circuit_gpt.mask_logits_dict_weight.items()]
edge_logits = [mask for _, mask in circuit_gpt.mask_logits_dict_edge.items()]

# optim_weight = torch.optim.AdamW(weight_logits, lr=args.lr_weight)
optim_edge = torch.optim.AdamW(edge_logits, lr=args.lr_edge)

In [7]:
def get_lambda_sparse(epoch, lambda_0, max_times=100., min_times=0.001, 
                      n_epoch_warmup=10, n_epoch_cooldown=10):

    if epoch < n_epoch_warmup:
        return lambda_0  + lambda_0 * (max_times - 1) * epoch / n_epoch_warmup
        
    elif epoch < n_epoch_warmup + n_epoch_cooldown:
        return lambda_0 * max_times - lambda_0 * (max_times - min_times) * (epoch - n_epoch_warmup) / n_epoch_cooldown
        
    else:
        return lambda_0 * min_times
        
# it takes about 35 mins to run 100 epochs of edge mask training on one A100 GPU with batch_size=32
for epoch in tqdm(range(args.train_epochs_edge)):
    lambda_sparse_edge = get_lambda_sparse(
        epoch, 
        lambda_0=args.lambda_sparse_edge_init,
        max_times=args.max_times_lambda_sparse, 
        min_times=args.min_times_lambda_sparse, 
        n_epoch_warmup=args.n_epoch_warmup_lambda_sparse,
        n_epoch_cooldown=args.n_epoch_cooldown_lambda_sparse,
    )
    lambda_complete_edge = args.lambda_complete_edge_init
    
    for batch in train_dl:
        batch_inputs = prepare_batch_inputs(batch, tokenizer)
        
        circuit_gpt.turn_on_edge_masks(deterministic=False)
        sparse_loss_edge = circuit_gpt.edge_sparseness_loss()
    
        batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0] 
        faith_loss_edge, _ = compute_faith_loss(batch_logits, batch_inputs) 
        
        circuit_gpt.turn_on_edge_masks(deterministic=False, reverse=True)
        batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0] 
        complete_loss_edge, _ = compute_complete_loss(batch_logits, batch_inputs) 
        
        loss_edge = faith_loss_edge + sparse_loss_edge *  lambda_sparse_edge + complete_loss_edge * lambda_complete_edge
        loss_edge.backward()
        optim_edge.step()
        optim_edge.zero_grad()
        torch.cuda.empty_cache()

    eval_acc_pruned, weight_density, edge_density = eval_model(
        circuit_gpt, eval_dl, tokenizer, device, use_weight_mask=False, reverse=False
    )
    eval_acc_complement, _, _ = eval_model(
        circuit_gpt, eval_dl, tokenizer, device, use_weight_mask=False, reverse=True
    )
    print(
        "Epoch {}. discovered circuit accuracy: {:.4f}, complementary circuit accuracy: {:.4f}, weight density: {:.4f}, edge density: {:.4f}".format(
            epoch + 1, eval_acc_pruned, eval_acc_complement, weight_density, edge_density)
    )

    # save good edge masks
    if eval_acc_pruned > 0.95 and edge_density < 0.05:
        torch.save(
            circuit_gpt.mask_logits_dict_edge,
            join(args.results_dir, f'mask_logits_dict_edge_ioi_edge_only_{epoch}.pt')
        )

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1. discovered circuit accuracy: 0.9746, complementary circuit accuracy: 0.4785, weight density: 1.0000, edge density: 0.9922
Epoch 2. discovered circuit accuracy: 0.5312, complementary circuit accuracy: 0.4785, weight density: 1.0000, edge density: 0.8904
Epoch 3. discovered circuit accuracy: 0.7070, complementary circuit accuracy: 0.5000, weight density: 1.0000, edge density: 0.7580
Epoch 4. discovered circuit accuracy: 0.7129, complementary circuit accuracy: 0.5176, weight density: 1.0000, edge density: 0.6174
Epoch 5. discovered circuit accuracy: 0.7812, complementary circuit accuracy: 0.5078, weight density: 1.0000, edge density: 0.4917
Epoch 6. discovered circuit accuracy: 0.6270, complementary circuit accuracy: 0.5000, weight density: 1.0000, edge density: 0.3864
Epoch 7. discovered circuit accuracy: 0.5469, complementary circuit accuracy: 0.4844, weight density: 1.0000, edge density: 0.3001
Epoch 8. discovered circuit accuracy: 0.5684, complementary circuit accuracy: 0.478

KeyboardInterrupt: 

In [1]:
import torch

x = torch.rand(2,5)

In [3]:
(x[0] < x[1]).long()

tensor([1, 0, 1, 0, 1])