In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import Dataset, DataLoader
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
import yaml
from datasets import load_from_disk
from sklearn.model_selection import train_test_split 

from masked_model import MaskedModel
from circuit_gpt import CircuitGPT, CircuitGPTConfig
from utils import get_target_module_keys
from ioi_dataset import *

In [2]:
# Load the experiment configuration; every `args[...]` lookup below reads from this file.
with open('diff_mask_ioi.yml') as f:
    args = yaml.safe_load(f)

tokenizer = GPT2Tokenizer.from_pretrained(join(args['model_dir'], args['model_name']))
# Use the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

ds = IOIDataset(prompt_type="ABBA", N=640, tokenizer=tokenizer)
ds_train, ds_test = train_test_split(ds.ioi_prompts, test_size=0.2, random_state=0)
# Note: the train and test sets can overlap, because IOIDataset builds its prompts
# by randomly sampling N items.

ioi_ds_train = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_train))
ioi_ds_test = CircuitIOIDataset(prepare_ioi_data_for_clm(ds_test))

train_dl = DataLoader(
    ioi_ds_train,
    batch_size=args['batch_size']
)
# Evaluate on the held-out split. The loader previously wrapped `ioi_ds_train`,
# which left `ioi_ds_test` unused and made the reported accuracy a training accuracy.
eval_dl = DataLoader(
    ioi_ds_test,
    batch_size=args['batch_size'],
    shuffle=False
)


In [3]:
def compute_faith_loss(batch_logits, batch_inputs):
    """Cross-entropy between the 'good' and 'bad' target logits at the last position.

    Args:
        batch_logits: logits of shape (B, seq_len, vocab_size).
        batch_inputs: mapping that provides index tensors under the keys
            'seq_lens', 'target good' and 'target bad' (one entry per example).

    Returns:
        A tuple ``(loss, logits_gb)`` where ``logits_gb`` has shape (B, 2) with
        the good-target logit in column 0 and the bad-target logit in column 1,
        and ``loss`` is the cross-entropy of ``logits_gb`` against class 0.
    """
    n_examples = batch_logits.shape[0]
    example_idx = torch.arange(n_examples)
    # Position of the final token of each sequence.
    final_pos = batch_inputs['seq_lens'] - 1

    # One column per candidate token, in (good, bad) order.
    candidate_logits = [
        batch_logits[example_idx, final_pos, batch_inputs[token_key]]
        for token_key in ('target good', 'target bad')
    ]
    logits_gb = torch.stack(candidate_logits, dim=-1)  # (B,2)

    # The "correct class" is always column 0, i.e. the good target.
    good_class = torch.zeros(n_examples, dtype=torch.long, device=logits_gb.device)
    batch_faith_loss = F.cross_entropy(logits_gb, good_class)

    return batch_faith_loss, logits_gb


@torch.no_grad()
def eval_mask_model(masked_model, eval_dl, tokenizer, device, use_mask=True):
    """Evaluate good-vs-bad target accuracy and report mask densities.

    The model is put in eval mode and deterministic weight and edge masks are
    switched on so the densities can be read. With ``use_mask=False`` the masks
    are removed again before the forward passes; otherwise they stay on during
    evaluation and are removed afterwards. In both cases the masks are off when
    this function exits, including when evaluation raises or is interrupted.

    Args:
        masked_model: wrapper exposing ``model``, ``apply_masks``,
            ``remove_masks`` and ``get_pruned_model_density``; calling it with
            input ids returns a sequence whose first item is the logits.
        eval_dl: iterable of batches with a ``dataset`` attribute supporting ``len``.
        tokenizer: forwarded to ``prepare_batch_inputs``.
        device: device the input ids are moved to before the forward pass.
        use_mask: whether the masks stay active during the forward passes.

    Returns:
        ``(accuracy, weight_density, edge_density)``. ``accuracy`` is a Python
        float, or NaN when the dataset is empty.
    """
    masked_model.model.eval()
    masked_model.apply_masks(deterministic_masks=True)  # weight masks
    masked_model.model.turn_on_edge_masks(deterministic_masks=True)  # edge masks
    masks_active = True

    try:
        # Densities are read while the deterministic masks are applied.
        _, _, pruned_model_density_weight = masked_model.get_pruned_model_density()
        _, _, pruned_model_density_edge = masked_model.model.get_pruned_model_density()

        if not use_mask:
            masked_model.remove_masks()
            masked_model.model.turn_off_edge_masks()
            masks_active = False

        total = len(eval_dl.dataset)
        correct = 0

        for batch in eval_dl:
            batch_inputs = prepare_batch_inputs(batch, tokenizer)
            batch_logits = masked_model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
            _, batch_logits_gb = compute_faith_loss(batch_logits, batch_inputs)
            # An example counts as correct when the good target outscores the bad one.
            correct += int((batch_logits_gb[:, 0] > batch_logits_gb[:, 1]).sum())

            torch.cuda.empty_cache()
    finally:
        # Always undo the masks, even if evaluation failed or was interrupted,
        # so a later training step does not run with stale deterministic masks.
        if masks_active:
            masked_model.remove_masks()
            masked_model.model.turn_off_edge_masks()

    # Accuracy over zero examples is undefined; report NaN instead of dividing by zero.
    acc = correct / total if total else float('nan')

    return acc, pruned_model_density_weight, pruned_model_density_edge


In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # download gpt2-small weights from EasyTransformer and save it
# reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
# torch.save(reference_gpt2.state_dict(), join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'))

# Load onto the CPU regardless of the device the checkpoint was saved from; the
# model is moved to `device` below.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
gpt_weights = torch.load(
    join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'),
    map_location='cpu',
)

# `tokenizer` is already created (with its pad token set) in the data-loading cell,
# so it is not rebuilt here.

circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp=args['temperature_edge'],
    mask_logit_init=args['mask_logit_init']
)
circuit_gpt = CircuitGPT(circuit_gpt_config, means=None)

# strict=False tolerates key mismatches silently, so report them explicitly.
load_result = circuit_gpt.load_state_dict(gpt_weights, strict=False)
unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
missing_keys = list(getattr(load_result, 'missing_keys', []))
if unexpected_keys:
    print("Checkpoint keys not used by CircuitGPT ({}): {}".format(len(unexpected_keys), unexpected_keys))
if missing_keys:
    print("CircuitGPT parameters not found in checkpoint ({}): {}".format(len(missing_keys), missing_keys))

circuit_gpt.to(device);


In [6]:
# Record the keys returned by get_target_module_keys in the shared config, then
# build the MaskedModel wrapper from the model and that config.
# NOTE(review): presumably MaskedModel reads args['param_keys'] to decide which
# parameters receive weight masks; confirm against masked_model.py.
args.update(param_keys=get_target_module_keys(circuit_gpt))
masked_model = MaskedModel(circuit_gpt, args)

# weight_mask_fn = join(args['model_dir'], 'mask_logits', 'gpt2_small_ioi_weight_mask_logits.pt')
# masked_model.load_mask_logits(weight_mask_fn, device)

# optim_weight = torch.optim.AdamW(masked_model.get_trainable_parameters(), lr=args['lr_weight'])

In [7]:
# eval_acc, pruned_model_density_weight, _ = eval_mask_model(masked_model, eval_dl, tokenizer, device, use_mask=False)
# print("Epoch 0. unmasked model accuracy {:.2f}".format(eval_acc))
# Baseline evaluation before any mask training.
# NOTE(review): use_mask=False makes eval_mask_model remove the masks before the
# forward passes, so this accuracy is for the unmasked model even though the
# message says "pruned"; confirm which one is intended.
eval_acc, pruned_model_density_weight, pruned_model_density_edge = eval_mask_model(
    masked_model, eval_dl, tokenizer, device, use_mask=False
)
epoch0_summary = "Epoch 0. mean pruned model accuracy: {:.2f}, weight density: {:.4f}, edge density: {:.4f}".format(
    eval_acc, pruned_model_density_weight, pruned_model_density_edge
)
print(epoch0_summary)

Epoch 0. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000


In [7]:
# def get_lambda_sparse_edge(epoch):
#     if epoch < 20:
#         return 0.01 * epoch / 20
#     elif epoch < 100:
#         return 0.01 - 0.009 * epoch / 100
#     else:
#         return 0.001


# def get_lambda_sparse_weight(epoch, args):
#     if epoch <= 20:
#         return args['lambda_sparse_weight']
#     elif epoch < 1020:
#         return args['lambda_sparse_weight'] + epoch - 20
#     else:
#         return 1000


# def weight_pruning_epoch(args, epoch, train_dl, eval_dl, masked_model, optim_weight):
#     # masked_model.model.turn_off_edge_masks()

#     lambda_sparse_w = get_lambda_sparse_weight(epoch, args)

#     for batch in train_dl:
#         batch_inputs = prepare_batch_inputs(batch, tokenizer)

#         # weight pruning
#         masked_model.apply_masks()
#         batch_logits = masked_model(batch_inputs['input_ids'].to(device))[0] # (B, seq_len, vocab_size)
#         faith_loss_weight, _ = compute_faith_loss(batch_logits, batch_inputs)
#         # losses['faithfulness'].append(faith_loss_weight.detach().cpu().item())
#         sparse_loss_weight = masked_model.get_sparseness_loss()
#         # losses['sparseness'].append(sparse_loss_weight.detach().cpu().item())
#         loss_weight = sparse_loss_weight * lambda_sparse_w + faith_loss_weight
#         loss_weight.backward()
#         optim_weight.step()
#         optim_weight.zero_grad()
#         masked_model.remove_masks()
#         # torch.cuda.empty_cache()

#     eval_acc, pruned_model_density_weight, pruned_model_density_edge = eval_mask_model(masked_model, eval_dl, tokenizer, device)
#     print(
#         "Epoch {}. mean pruned model accuracy {:.2f}, weight density: {:.4f}".format(
#             epoch + 1, eval_acc, pruned_model_density_weight))
#     # print('\n')


# def edge_pruning_epoch(args, epoch, train_dl, eval_dl, masked_model, circuit_gpt, optim_edge):

#     circuit_gpt.turn_on_edge_masks(deterministic_masks=False)
#     masked_model.apply_masks(deterministic_masks=True)

#     for batch in train_dl:
#         batch_inputs = prepare_batch_inputs(batch, tokenizer)

#         lambda_sparse_edge = get_lambda_sparse_edge(epoch)
#         circuit_gpt.train() 
#         circuit_gpt.turn_on_edge_masks(deterministic_masks=False)
#         sparse_loss_edge = 0
#         for p in circuit_gpt.mask_params_edge:
#             sparse_loss_edge += nn.functional.sigmoid(p).sum()
    
#         batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0] 
#         faith_loss_edge, _ = compute_faith_loss(batch_logits, batch_inputs)        
#         loss_edge = faith_loss_edge + sparse_loss_edge *  lambda_sparse_edge
#         loss_edge.backward()
#         optim_edge.step()
#         optim_edge.zero_grad()
#         # torch.cuda.empty_cache()

#     eval_acc, pruned_model_density_weight, pruned_model_density_edge = eval_mask_model(masked_model, eval_dl, tokenizer, device)
#     print(
#         "Epoch {}. mean pruned model accuracy: {:.2f}, weight density: {:.4f}, edge density: {:.4f}".format(
#             epoch + 1, eval_acc, pruned_model_density_weight, pruned_model_density_edge))


# def combined_pruning(args, masked_model, circuit_gpt, train_dl, eval_dl, n_epoch_edge=100, n_epoch_weight_per_edge_epoch=10):
#     optim_weight = torch.optim.AdamW(masked_model.get_trainable_parameters(), lr=args['lr_weight'])
#     optim_edge = torch.optim.AdamW(circuit_gpt.mask_params_edge, lr=args['lr_edge'])

#     epoch = 0
#     for i in tqdm(range(n_epoch_edge), position=1):
#         for j in tqdm(range(n_epoch_weight_per_edge_epoch), position=0):
#             weight_pruning_epoch(args, epoch, train_dl, eval_dl, masked_model, optim_weight)
#             epoch += 1
#         edge_pruning_epoch(args, epoch, train_dl, eval_dl, masked_model, circuit_gpt, optim_edge)
#         epoch += 1


In [8]:
# combined_pruning(
#     args, masked_model, circuit_gpt, train_dl, eval_dl, n_epoch_edge=100, n_epoch_weight_per_edge_epoch=10
# )

In [9]:
# weight pruning
# torch.cuda.empty_cache()
# masked_model.model.turn_off_edge_masks()


# def get_lambda_sparse_weight_ioi(epoch, args):
#     if epoch <= 20:
#         return args['lambda_sparse_weight']
#     else:
#         return args['lambda_sparse_weight'] + (epoch - 40)* 1.0


# for epoch in tqdm(range(args['train_epochs_weight'])):

#     # if epoch > 20 and args['lambda_sparse_weight'] < 1000:
#     #     args['lambda_sparse_weight'] += 1
#     lambda_sparse_w = get_lambda_sparse_weight_ioi(epoch, args)

#     for batch in train_dl:
#         batch_inputs = prepare_batch_inputs(batch, tokenizer)

#         # weight pruning
#         masked_model.apply_masks()
#         batch_logits = masked_model(batch_inputs['input_ids'].to(device))[0] # (B, seq_len, vocab_size)
        
#         faith_loss_weight, _ = compute_faith_loss(batch_logits, batch_inputs)

#         sparse_loss_weight = masked_model.get_sparseness_loss()
#         loss_weight = sparse_loss_weight * lambda_sparse_w + faith_loss_weight
#         loss_weight.backward()
#         optim_weight.step()
#         optim_weight.zero_grad()
#         masked_model.remove_masks()
        
#         torch.cuda.empty_cache()

#     eval_acc, pruned_model_density_weight, pruned_model_density_edge = eval_mask_model(masked_model, eval_dl, tokenizer, device)
#     print(
#         "Epoch {}. mean pruned model accuracy {:.2f}, weight density: {:.4f}".format(
#             epoch + 1, eval_acc, pruned_model_density_weight))
#     print('\n')

In [10]:
# save weight masks
# weight_mask_fn = join(args['model_dir'], 'mask_logits', 'gpt2_small_ioi_weight_mask_logits.pt')
# masked_model.save_mask_logits(weight_mask_fn)


In [8]:
# edge pruning

def get_lambda_sparse_edge(epoch, lambda_0=1e-3, lambda_1=2e-3, lambda_2=1e-4, n_epoch_warmup=20, n_epoch_cooldown=30):
    """Piecewise-linear schedule for the edge-sparsity loss weight.

    The weight moves linearly from ``lambda_0`` to ``lambda_1`` over the first
    ``n_epoch_warmup`` epochs, then linearly from ``lambda_1`` to ``lambda_2``
    over the next ``n_epoch_cooldown`` epochs, and stays at ``lambda_2`` after.

    Args:
        epoch: zero-based epoch index.
        lambda_0: weight at epoch 0.
        lambda_1: weight reached at the end of warmup.
        lambda_2: weight reached at the end of cooldown and held afterwards.
        n_epoch_warmup: number of warmup epochs.
        n_epoch_cooldown: number of cooldown epochs.

    Returns:
        The sparsity weight to use for ``epoch``.
    """
    if epoch < n_epoch_warmup:
        # Warmup: interpolate from lambda_0 towards lambda_1.
        warmup_fraction = epoch / n_epoch_warmup
        rise = lambda_1 - lambda_0
        return lambda_0 + rise * warmup_fraction

    cooldown_end = n_epoch_warmup + n_epoch_cooldown
    if epoch < cooldown_end:
        # Cooldown: interpolate from lambda_1 down to lambda_2.
        drop = lambda_1 - lambda_2
        epochs_into_cooldown = epoch - n_epoch_warmup
        return lambda_1 - drop * epochs_into_cooldown / n_epoch_cooldown

    # Past the cooldown window the weight is held constant.
    return lambda_2
    
# When True, the edge-sparsity weight follows get_lambda_sparse_edge (warmup then
# cooldown). When False, the constant args['lambda_sparse_edge'] is used every epoch.
# The previous code always computed the schedule and then overwrote it with the
# constant, so False reproduces what actually ran.
USE_LAMBDA_SCHEDULE = False

optim_edge = AdamW(circuit_gpt.mask_params_edge, lr=args['lr_edge'])

torch.cuda.empty_cache()

for epoch in tqdm(range(args['train_epochs_edge'])):

    if USE_LAMBDA_SCHEDULE:
        lambda_sparse_edge = get_lambda_sparse_edge(
            epoch,
            lambda_0=args['lambda_sparse_edge'],
            lambda_1=2*args['lambda_sparse_edge'],
            lambda_2=0.1*args['lambda_sparse_edge'],
            n_epoch_warmup=50,
            n_epoch_cooldown=450
        )
    else:
        lambda_sparse_edge = args['lambda_sparse_edge']

    # eval_mask_model (run at the end of every epoch) switches the model to eval
    # mode and removes the weight masks, so restore the training state here.
    # Previously the weight masks were applied once before the loop and were
    # therefore missing from the second epoch onwards.
    circuit_gpt.train()
    masked_model.apply_masks(deterministic_masks=True)

    for batch in train_dl:
        batch_inputs = prepare_batch_inputs(batch, tokenizer)

        # Re-enable the non-deterministic edge masks for this step.
        circuit_gpt.turn_on_edge_masks(deterministic_masks=False)

        # Sparsity penalty: sum of sigmoid(mask logit) over all edge-mask parameters.
        sparse_loss_edge = sum(torch.sigmoid(p).sum() for p in circuit_gpt.mask_params_edge)

        batch_logits = circuit_gpt(batch_inputs['input_ids'].to(device))[0]
        faith_loss_edge, _ = compute_faith_loss(batch_logits, batch_inputs)
        loss_edge = faith_loss_edge + sparse_loss_edge * lambda_sparse_edge
        loss_edge.backward()
        optim_edge.step()
        optim_edge.zero_grad()
        torch.cuda.empty_cache()

    # Pair the apply_masks call above with a removal, so eval_mask_model starts
    # from an unmasked model and applies its own deterministic masks.
    masked_model.remove_masks()

    eval_acc, pruned_model_density_weight, pruned_model_density_edge = eval_mask_model(masked_model, eval_dl, tokenizer, device)
    print(
        "Epoch {}. mean pruned model accuracy: {:.2f}, weight density: {:.4f}, edge density: {:.4f}".format(
            epoch + 1, eval_acc, pruned_model_density_weight, pruned_model_density_edge))
    print('\n')

  0%|          | 1/500 [00:05<49:43,  5.98s/it]

Epoch 1. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  0%|          | 2/500 [00:12<51:50,  6.25s/it]

Epoch 2. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  1%|          | 3/500 [00:18<52:18,  6.32s/it]

Epoch 3. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  1%|          | 4/500 [00:25<52:27,  6.35s/it]

Epoch 4. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  1%|          | 5/500 [00:31<52:51,  6.41s/it]

Epoch 5. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  1%|          | 6/500 [00:38<52:37,  6.39s/it]

Epoch 6. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  1%|▏         | 7/500 [00:44<52:55,  6.44s/it]

Epoch 7. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  2%|▏         | 8/500 [00:50<52:34,  6.41s/it]

Epoch 8. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  2%|▏         | 9/500 [00:57<52:36,  6.43s/it]

Epoch 9. mean pruned model accuracy: 1.00, weight density: 1.0000, edge density: 1.0000




  2%|▏         | 9/500 [00:57<52:43,  6.44s/it]


KeyboardInterrupt: 