In [9]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from datasets import Dataset
import torch
import pickle
import transformer_lens
from torch.optim import AdamW
from os.path import join
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
import pandas as pd
import json
import numpy as np

from circuit_gpt import CircuitGPT, CircuitGPTConfig

In [80]:
def prepare_pararel_data(args):
    """Load ParaRel prompts for the selected relations and split them into train/test sets.

    Args:
        args: configuration object providing ``data_dir``, ``data_name``,
            ``pararel_rel_ids`` (whitespace-separated relation ids) and
            ``test_ratio``. An optional ``random_seed`` attribute, when present,
            seeds the split so it is reproducible.

    Returns:
        ``(train_ds, test_ds)``: two ``datasets.Dataset`` objects with columns
        ``'prompt'`` and ``'answer'``. When ``test_ratio`` is 0 every example
        goes to the training set and the test set is empty.
    """
    with open(join(args.data_dir, args.data_name)) as open_file:
        pararel_rel_data = json.load(open_file)

    # str.split() (no argument) tolerates repeated/leading/trailing whitespace,
    # which would otherwise produce empty relation ids and a KeyError.
    data = []
    for rel_id in args.pararel_rel_ids.split():
        data += pararel_rel_data[rel_id]

    ds_dict = {
        'prompt': [],
        'answer': [],
    }
    for entry in data:
        # entry[0] is a (template, target) pair; strip the trailing mask so the
        # prompt ends right before the token to be predicted.
        prompt = entry[0][0].replace(' [MASK] .', '')
        prompt = prompt.replace(' [MASK].', '')
        assert '[MASK]' not in prompt
        target = entry[0][1]
        ds_dict['prompt'].append(prompt)
        # Leading space so the answer is formatted as a word following the prompt.
        ds_dict['answer'].append(' ' + target)

    indices = np.arange(len(data))
    if not args.test_ratio:
        # train_test_split rejects test_size=0.0 with a ValueError, so treat
        # "no test split" explicitly: keep everything (in order) for training.
        train_idx, test_idx = indices, indices[:0]
    else:
        train_idx, test_idx = train_test_split(
            indices,
            test_size=args.test_ratio,
            random_state=getattr(args, 'random_seed', None),
        )

    train_ds_dict = {
        'prompt': [ds_dict['prompt'][i] for i in train_idx],
        'answer': [ds_dict['answer'][i] for i in train_idx]
    }
    test_ds_dict = {
        'prompt': [ds_dict['prompt'][i] for i in test_idx],
        'answer': [ds_dict['answer'][i] for i in test_idx]
    }

    train_ds = Dataset.from_dict(train_ds_dict)
    test_ds = Dataset.from_dict(test_ds_dict)

    return train_ds, test_ds
    

def prepare_batch_inputs(batch, tokenizer):
    """Tokenize a batch of prompts and bundle the tensors used for evaluation.

    Args:
        batch: mapping with a ``'prompt'`` entry (list of strings) and a
            ``'label'`` entry (per-example targets).
        tokenizer: callable invoked as
            ``tokenizer(prompts, return_tensors='pt', padding=True)`` whose result
            exposes ``input_ids`` and ``attention_mask``.

    Returns:
        dict with ``'input_ids'``, ``'seq_lens'`` (number of non-padding tokens per
        row, from the attention mask) and the targets under both ``'label'`` and
        ``'labels'``. The loss/eval helpers in this notebook read ``'labels'``,
        while the original key ``'label'`` is kept for existing callers.
    """
    batch_inputs = tokenizer(
        batch['prompt'], return_tensors='pt', padding=True
    )
    batch_seq_lens = batch_inputs.attention_mask.sum(-1)
    batch_labels = batch['label']

    return {
        'input_ids': batch_inputs.input_ids,
        'seq_lens': batch_seq_lens,
        'label': batch_labels,
        'labels': batch_labels,
    }
    

In [20]:
@dataclass
class DiffMaskArgs:
    """Configuration bundle for this notebook: paths, data selection and run settings.

    Only some fields are consumed by the cells shown in this notebook; the
    remaining ones are presumably read by training code elsewhere (TODO confirm).
    """

    # --- Filesystem locations -------------------------------------------------
    # model_dir + model_name form the directory the tokenizer/weights are loaded from.
    model_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'
    # data_dir + data_name form the path of the ParaRel JSON file.
    data_dir: str = '/home/leiyu/projects/def-yangxu/leiyu/circuit-discovery/data/'
    results_dir: str = '/home/leiyu/scratch/circuit-discovery/mask_logits/'
    data_name: str = 'pararel_data_all.json'

    # --- Data / model selection -----------------------------------------------
    # Space-separated relation ids used as keys into the ParaRel JSON.
    pararel_rel_ids: str = 'P36 P1376'
    model_name: str = 'gpt2-small'

    # --- Mask parameterisation ------------------------------------------------
    # Forwarded to CircuitGPTConfig(gs_temp_weight=..., gs_temp_edge=...).
    gs_temp_weight: float = 0.01
    gs_temp_edge: float = 1.0
    logits_w_init: float = 0.0
    logits_e_init: float = 0.0

    # --- Data splitting / batching --------------------------------------------
    # Passed as test_size to train_test_split in prepare_pararel_data.
    test_ratio: float = 0.0
    batch_size: int = 16

    # --- Optimisation (presumably used by a training loop; names only) --------
    train_epochs_weight: int = 1000
    train_epochs_edge: int = 50
    lr_weight: float = 0.1
    lr_edge: float = 0.1
    lambda_sparse_weight_init: float = 1.0
    lambda_sparse_edge_init: float = 1.0
    lambda_complete_weight_init: float = 1.0
    lambda_complete_edge_init: float = 1.0
    save_every: int = 5

    # --- Checkpoint resumption --------------------------------------------------
    # A value > 0 makes the model-loading cell restore saved mask logits.
    resume_epoch_w: int = 0
    resume_epoch_e: int = 0
    use_weight_masks: bool = False

    # --- Sparsity-coefficient schedule (presumed from names; TODO confirm) -----
    n_epoch_warmup_lambda_sparse: int = 20
    n_epoch_cooldown_lambda_sparse: int = 20
    max_times_lambda_sparse: float = 100.0
    min_times_lambda_sparse: float = 0.01

    random_seed: int = 0
    

def compute_faith_loss(batch_logits, batch_inputs):
    """Cross-entropy between the next-token logits and the target token ids.

    Args:
        batch_logits: tensor of shape (B, seq_len, vocab_size).
        batch_inputs: dict with ``'seq_lens'`` (number of real tokens per row) and
            the targets under ``'labels'`` (or ``'label'``, the key produced by
            ``prepare_batch_inputs``). Targets must index into the vocab dimension.

    Returns:
        ``(loss, preds)``: the scalar mean cross-entropy and a CPU tensor of shape
        (B,) holding the arg-max token id at each row's last real position.
    """
    # batch_logits: (B, seq_len, vocab_size)
    batch_seq_lens = batch_inputs['seq_lens']
    batch_size = batch_logits.shape[0]

    # Logits at the last non-padding position of every row.
    batch_logits_next_tok = batch_logits[torch.arange(batch_size), batch_seq_lens - 1]  # (B, vocab_size)

    # Accept either key so the function works with the dict built by
    # prepare_batch_inputs ('label') as well as with callers passing 'labels'.
    raw_labels = batch_inputs['labels'] if 'labels' in batch_inputs else batch_inputs['label']
    batch_labels = torch.as_tensor(raw_labels).long().to(batch_logits_next_tok.device)
    batch_faith_loss = nn.functional.cross_entropy(batch_logits_next_tok, batch_labels)

    with torch.no_grad():
        # argmax yields the top-1 id directly, without sorting the whole vocabulary.
        batch_preds = batch_logits_next_tok.argmax(-1).cpu()

    return batch_faith_loss, batch_preds


def compute_complete_loss(batch_logits, batch_inputs):
    """Cross-entropy between the next-token logits and a uniform target distribution.

    Args:
        batch_logits: tensor of shape (B, seq_len, vocab_size).
        batch_inputs: dict providing ``'seq_lens'`` (number of real tokens per row).

    Returns:
        ``(loss, preds)``: the scalar loss and a CPU tensor of shape (B,) with the
        highest-scoring token id at each row's last real position.
    """
    seq_lens = batch_inputs['seq_lens']
    row_ids = torch.arange(batch_logits.shape[0])

    # Select, for every row, the logits at its last non-padding position.
    next_tok_logits = batch_logits[row_ids, seq_lens - 1]  # (B, vocab_size)
    vocab_size = next_tok_logits.shape[-1]

    # Soft targets: every vocabulary entry gets probability 1 / vocab_size.
    uniform_targets = torch.ones(next_tok_logits.shape).to(next_tok_logits.device) / vocab_size
    complete_loss = nn.functional.cross_entropy(next_tok_logits, uniform_targets)

    with torch.no_grad():
        # Last column of the ascending sort is the top-scoring token id.
        ranked_ids = torch.argsort(next_tok_logits, -1)
        top_preds = ranked_ids[:, -1].cpu()

    return complete_loss, top_preds


@torch.no_grad()
def eval_model(model, eval_dl, tokenizer, device, use_weight_mask=True, use_edge_mask=True, reverse=False):
    """Measure next-token accuracy of ``model`` on ``eval_dl`` and report mask densities.

    Args:
        model: object exposing ``turn_on/turn_off_weight_masks``,
            ``turn_on/turn_off_edge_masks``, ``get_weight_density`` and
            ``get_edge_density``; calling it on input ids returns a sequence whose
            first element is the logits tensor (B, seq_len, vocab_size).
        eval_dl: DataLoader whose batches contain ``'prompt'`` and ``'label'``.
        tokenizer: tokenizer forwarded to ``prepare_batch_inputs``.
        device: device the input ids are moved to before the forward pass.
        use_weight_mask: if False, weight masks are switched off for the forward passes.
        use_edge_mask: if False, edge masks are switched off for the forward passes.
        reverse: forwarded to ``turn_on_weight_masks`` / ``turn_on_edge_masks``.

    Returns:
        ``(acc, weight_density, edge_density)``. ``acc`` is 0.0 for an empty dataset.
    """
    model.eval()

    # Densities are always read with both mask types switched on (deterministic),
    # regardless of which masks are used for the accuracy measurement below.
    model.turn_on_weight_masks(deterministic=True, reverse=reverse)
    _, _, weight_density = model.get_weight_density()
    model.turn_on_edge_masks(deterministic=True, reverse=reverse)
    _, _, edge_density = model.get_edge_density()

    total = len(eval_dl.dataset)
    correct = 0

    try:
        if not use_weight_mask:
            model.turn_off_weight_masks()
        if not use_edge_mask:
            model.turn_off_edge_masks()

        # len(eval_dl) is the exact batch count; int(n / bs) + 1 over-counts by
        # one whenever the dataset size is a multiple of the batch size.
        for batch in tqdm(eval_dl, total=len(eval_dl)):
            batch_inputs = prepare_batch_inputs(batch, tokenizer)
            # prepare_batch_inputs stores targets under 'label', while
            # compute_faith_loss reads 'labels'; make sure both are present.
            if 'labels' not in batch_inputs:
                batch_inputs['labels'] = batch_inputs['label']

            batch_logits = model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
            _, batch_preds = compute_faith_loss(batch_logits, batch_inputs)

            batch_labels = torch.as_tensor(batch_inputs['labels']).cpu()
            correct += (batch_preds == batch_labels).sum().item()

            torch.cuda.empty_cache()
    finally:
        # Always leave the model unmasked, even if a batch raised.
        model.turn_off_weight_masks()
        model.turn_off_edge_masks()

    acc = correct / total if total > 0 else 0.0

    return acc, weight_density, edge_density


In [22]:
# Instantiate the run configuration with its defaults and load the tokenizer
# from the local model directory (<model_dir>/<model_name>).
args = DiffMaskArgs()
model_path = join(args.model_dir, args.model_name)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse EOS as the padding token so that batched tokenization with padding=True
# (see prepare_batch_inputs) has a pad token to work with.
tokenizer.pad_token = tokenizer.eos_token

In [81]:
# Build the evaluation dataset: one prompt per ParaRel entry of the selected
# relations, with the label being the vocab id of the answer's first token.
with open(join(args.data_dir, args.data_name)) as open_file:
    pararel_rel_data = json.load(open_file)

rel_ids = args.pararel_rel_ids.split(' ')
data = []
for rel_id in rel_ids:
    data += pararel_rel_data[rel_id]

ds_dict = {
    'prompt': [],
    'answer': [],
}

for entry in data:
    # entry[0] is a (template, target) pair; drop the trailing mask so the
    # prompt ends right before the token to be predicted.
    prompt = entry[0][0].replace(' [MASK] .', '')
    prompt = prompt.replace(' [MASK].', '')
    assert '[MASK]' not in prompt
    target = entry[0][1]
    ds_dict['prompt'].append(prompt)
    ds_dict['answer'].append(' ' + target)

# Unique answer strings. The previous comprehension iterated over the dict's
# keys and indexed a str, which raises TypeError. Sorting makes the class order
# reproducible across kernel restarts (set order depends on string hashing).
capital_vocab = sorted(set(ds_dict['answer']))

# Vocab id of the FIRST token of every candidate answer; multi-token answers
# are represented by their first token only.
capital_vocab_idx = torch.tensor([
    input_ids[0] for input_ids in tokenizer(capital_vocab).input_ids
])

# vocab id -> column index in capital_vocab_idx (if several answers share a
# first token, the last index wins).
capital_vocab2class_id = {
    capital_vocab_idx[i].item(): i for i in range(len(capital_vocab_idx))
}

# answer string -> vocab id of its first token
capital_name2vocab_id = {
    capital: capital_vocab_id.item()
    for capital, capital_vocab_id in zip(capital_vocab, capital_vocab_idx)
}

# Labels are vocab ids (not class indices): the evaluation loop compares
# capital_vocab_idx[pred] with the label, and compute_faith_loss applies
# cross-entropy over the full vocabulary.
ds_dict['label'] = [
    capital_name2vocab_id[answer] for answer in ds_dict['answer']
]


ds = Dataset.from_dict(ds_dict)
dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False)

In [82]:
# Spot-check a single example: its prompt, answer string and label.
ds[10]

{'prompt': 'The capital of Alexandria Governorate is',
 'answer': ' Alexandria',
 'label': 27872}

In [34]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # download gpt2-small weights from EasyTransformer and save it
# reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
# torch.save(reference_gpt2.state_dict(), join(args['model_dir'], 'gpt2-small/gpt2_small_weights.pt'))

# Load checkpoints onto the CPU first; the assembled model is moved to `device`
# in one step at the end of this cell.
gpt_weights = torch.load(join(model_path, 'model_weights.pt'), map_location='cpu')
circuit_gpt_config = CircuitGPTConfig(
    debug=False,
    gs_temp_weight=args.gs_temp_weight,
    gs_temp_edge=args.gs_temp_edge,
    use_weight_masks=False
)
circuit_gpt = CircuitGPT(circuit_gpt_config)
circuit_gpt.load_pretrained_weight(gpt_weights)

# load pretrained mask logits if necessary
# The file name belongs inside join(...); it was previously passed to torch.load
# as its second positional argument (map_location), and the epoch was read from
# an undefined bare name instead of `args`.
# NOTE(review): checkpoints are looked up under model_path as before; confirm
# they are not actually stored under args.results_dir.
if args.resume_epoch_w > 0:
    weight_mask_logits = torch.load(
        join(model_path, f'weight_mask_logits_{args.resume_epoch_w}.pt'),
        map_location='cpu',
    )
    circuit_gpt.load_pretrained_weight_mask(weight_mask_logits)
if args.resume_epoch_e > 0:
    edge_mask_logits = torch.load(
        join(model_path, f'edge_mask_logits_{args.resume_epoch_e}.pt'),
        map_location='cpu',
    )
    circuit_gpt.load_pretrained_edge_mask(edge_mask_logits)

circuit_gpt.to(device);

In [92]:
# Accuracy of the unmasked model when its prediction is restricted to the
# candidate answers' first tokens (the columns listed in capital_vocab_idx).
model = circuit_gpt
model.eval()

model.turn_off_weight_masks()
model.turn_off_edge_masks()

total = len(dl.dataset)
correct = 0

# Predicted vocab id for every example, in dataset order.
model_predictions = []
# len(dl) is the exact number of batches; int(n / bs) + 1 over-counts by one
# when the dataset size is a multiple of the batch size.
n_batch = len(dl)

for batch in tqdm(dl, total=n_batch):
    batch_inputs = prepare_batch_inputs(batch, tokenizer)
    with torch.no_grad():
        batch_logits = model(batch_inputs['input_ids'].to(device))[0]  # (B, seq_len, vocab_size)
        batch_seq_lens = batch_inputs['seq_lens']
        batch_size = batch_logits.shape[0]
        # Logits at each row's last real token, restricted to the candidate columns.
        batch_logits_next_tok = batch_logits[torch.arange(batch_size), batch_seq_lens - 1][:, capital_vocab_idx]  # (B, capital_vocab_size)
        batch_pred_cap_ids = batch_logits_next_tok.argmax(-1).cpu()

    # Map candidate-column indices back to vocab ids and compare with the
    # labels for the whole batch at once.
    batch_pred_vocab_ids = capital_vocab_idx[batch_pred_cap_ids]
    batch_labels = torch.as_tensor(batch['label'])
    correct += (batch_pred_vocab_ids == batch_labels).sum().item()
    model_predictions.extend(batch_pred_vocab_ids.tolist())

    torch.cuda.empty_cache()

acc = correct / total if total > 0 else 0.0

  0%|          | 0/59 [00:00<?, ?it/s]

In [93]:
# Fraction of examples counted as correct by the evaluation loop (correct / total).
acc

0.3479188900747065

In [94]:
# Number of examples counted as correct by the evaluation loop.
correct

326

In [91]:
# Number of examples in the evaluation dataset (the denominator of `acc`).
len(dl.dataset)

937