In [1]:
# Experiment-level knobs for this notebook.
# NOTE(review): stopping_index is not referenced in the visible cells; the original
# remark only mentions GPU models, so its exact meaning should be confirmed.
stopping_index = 8  # 4090 or A6000
# Number of editing heads; the editor config cell below is set to the same value.
num_editing_heads = 16 
# Max gradient norm intended for clip_grad_norm_; the clipping call near the
# optimizer step is currently disabled, so this value has no effect as written.
max_grad_clip = 4.0
# NOTE(review): not referenced in the visible cells; the editor config below sets
# chop_editor_at_layer to 10 rather than this value — confirm which is intended.
chop_layer = 6
# Learning rate handed to the SGD optimizer.
lr = 3e-5

import torch
torch.manual_seed(42)

from torch import compile
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import transformers
from tqdm import tqdm
import pandas as pd
import yaml
import contextlib
import os
import time
import sys

sys.path.append("..")

%load_ext autoreload
%autoreload 2

from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, AutoTokenizer
# Tokenizer for the Llama-3 checkpoint; sequences are padded on the left and the
# EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("/work/frink/models/llama3-8B-HF")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id


import random
import numpy as np
import json
from tqdm import tqdm
from datasets import Dataset

def _read_json(path):
    """Parse the JSON file at ``path``; the file handle is closed once parsing is done."""
    with open(path, "r") as f:
        return json.load(f)


def generate_ravel_dataset(n_samples, split="train", domains=["city"], domain_excluded_attributes=[["Latitude", "Longitude", "Timezone"]], filtering_dict_paths=[None],  seed=42):
    """Build a dataset of base / source (counterfactual) prompt pairs from the RAVEL files.

    For every domain, ``n_samples // len(domains)`` records are drawn. Each record
    picks two distinct entities (source and base) and one attribute to edit, then
    fills one prompt template per entity. For every other non-excluded attribute an
    extra prompt pair is stored under ``"unaffected_attributes"``.

    Args:
        n_samples: total number of records, split evenly (floor division) over ``domains``.
        split: ``"train"`` or ``"test"``. Entities and templates whose split label is
            anything other than ``"train"`` are treated as test.
        domains: domain names used to locate ``./data/ravel/ravel_{domain}_*.json``.
        domain_excluded_attributes: per-domain list of attributes that are never used.
        filtering_dict_paths: per-domain path to a JSON file of the form
            ``{entity: {attribute: {"0": bool, "1": bool, ...}}}`` (or ``None``). When
            given, entities lacking a ``True`` entry among the split's templates for
            any used attribute are dropped, and templates are only drawn from the
            entries marked ``True``.
        seed: seed applied to ``random`` and ``numpy.random``.

    The three per-domain lists are consumed with ``zip`` and are only read, never
    mutated.

    Returns:
        A ``Dataset`` whose rows have the keys ``input_text``,
        ``counterfactual_input_text``, ``edit_instruction``, ``target``,
        ``counterfactual_target`` and ``unaffected_attributes``.

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"test"``.
    """
    # Fail fast on a bad split, before touching the RNG state or the filesystem.
    if split not in ("train", "test"):
        raise ValueError("split must be 'train' or 'test'")

    # Seed
    random.seed(seed)
    np.random.seed(seed)
    dataset = []

    sample_per_domain = n_samples // len(domains)
    ravel_dir = "./data/ravel/"

    def in_split(label):
        # Any label other than "train" belongs to the test split.
        return label == "train" if split == "train" else label != "train"

    for domain, excluded_attributes, filtering_dict_path in zip(domains, domain_excluded_attributes, filtering_dict_paths):

        templates = _read_json(os.path.join(ravel_dir, f"ravel_{domain}_attribute_to_prompts.json"))
        entities = _read_json(os.path.join(ravel_dir, f"ravel_{domain}_entity_attributes.json"))
        entities_split = _read_json(os.path.join(ravel_dir, f"ravel_{domain}_entity_to_split.json"))
        templates_split = _read_json(os.path.join(ravel_dir, f"ravel_{domain}_prompt_to_split.json"))

        all_attributes = [a for a in templates.keys() if a not in excluded_attributes]

        # Templates of the requested split, plus their positions in the full
        # per-attribute template list (the filtering file is indexed by position).
        prompt_dict = {k: [v for v in vs if in_split(templates_split[v])] for k, vs in templates.items()}
        prompt_idxs_dict = {k: [vs.index(v) for v in vs if in_split(templates_split[v])] for k, vs in templates.items()}

        entity_dict = {k: v for k, v in entities.items() if in_split(entities_split[k])}
        entity_name = list(entity_dict.keys())

        if filtering_dict_path is not None:
            filtering_dict = _read_json(filtering_dict_path)
            filtered_key = set()

            for entity, known in filtering_dict.items():
                for attribute in all_attributes:
                    outcomes = list(known[attribute].values())
                    # Drop the entity as soon as one attribute has no usable template.
                    if True not in [outcomes[i] for i in prompt_idxs_dict[attribute]]:
                        filtered_key.add(entity)
                        break
            print(f"Filtering out {len(filtered_key)} out of {len(filtering_dict)} entities that the model does not know!")
            filtering_dict = {k: v for k, v in filtering_dict.items() if k not in filtered_key}
            # Loop-invariant: computed once instead of once per sample.
            candidate_entities = [k for k in entity_name if k in filtering_dict]
        else:
            filtering_dict = None
            candidate_entities = entity_name

        def pick_template(entity, attribute):
            """Draw one template of the active split for ``attribute``.

            With a filtering dict, only templates flagged ``True`` for ``entity``
            are eligible.
            """
            if filtering_dict is None:
                return random.choice(prompt_dict[attribute])
            known = filtering_dict[entity][attribute]
            split_idxs = prompt_idxs_dict[attribute]
            usable = [
                prompt_dict[attribute][split_idxs.index(i)]
                for i in range(len(known))
                if known[str(i)] == True and i in split_idxs
            ]
            return random.choice(usable)

        for _ in tqdm(range(sample_per_domain)):

            source_entity, base_entity = random.sample(candidate_entities, 2)
            attribute = random.choice(all_attributes)
            frozen_attributes = [k for k in all_attributes if k != attribute]
            source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]

            # Source is drawn before base so the random stream matches the
            # historical order and a given seed keeps producing the same data.
            source_template = pick_template(source_entity, attribute)
            base_template = pick_template(base_entity, attribute)

            edit_instruction = f"{base_entity} ; {source_entity} - {attribute}"

            data = {
                "input_text": base_template % base_entity,
                "counterfactual_input_text": source_template % source_entity,
                "edit_instruction": edit_instruction,
                "target": base_entity_dict[attribute],
                "counterfactual_target": source_entity_dict[attribute],
                "unaffected_attributes": [],
            }

            for frozen_attribute in frozen_attributes:
                source_attribute_template = pick_template(source_entity, frozen_attribute)
                base_attribute_template = pick_template(base_entity, frozen_attribute)

                data["unaffected_attributes"].append(
                    {
                        "input_text": base_attribute_template % base_entity,
                        "counterfactual_input_text": source_attribute_template % source_entity,
                        "edit_instruction": edit_instruction,
                        "target": base_entity_dict[frozen_attribute],
                        "counterfactual_target": source_entity_dict[frozen_attribute],
                    }
                )

            dataset.append(data)

    dataset = Dataset.from_list(dataset)
    return dataset

# Train and test splits of the city domain, both filtered with the same
# per-entity prompt statistics file.
city_train_set = generate_ravel_dataset(
    10000,
    split="train",
    filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"],
)
city_test_set = generate_ravel_dataset(
    1000,
    split="test",
    filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"],
)

def ravel_collate_fn(batch):
    """Collate RAVEL records into padded tensors for the editor and the target model.

    Uses the module-level ``tokenizer`` as bound at call time.

    Args:
        batch: list of records with the keys ``input_text``,
            ``counterfactual_input_text``, ``edit_instruction``,
            ``counterfactual_target`` and ``unaffected_attributes`` (a list of
            dicts with ``input_text``, ``counterfactual_input_text`` and ``target``).

    Returns:
        dict with ``editor_input_ids``, ``base_input_ids``, ``base_attention_mask``,
        ``source_input_ids``, ``source_attention_mask`` and ``labels`` for the edited
        attribute, the same keys with an ``_unaffected`` suffix for the frozen
        attributes, and ``instance_indices`` mapping every unaffected row back to
        its position in ``batch``.
    """

    def tokenize_text_inputs(texts, counterfactual_texts, target_texts):
        """Tokenize prompt+target and counterfactual prompts; labels cover only the target."""

        input_texts = [text + " " + target for text, target in zip(texts, target_texts)]
        # NOTE(review): this replace maps the pattern onto itself, so it changes
        # nothing as written; kept until the intended normalisation is confirmed.
        input_texts = [text.replace(" \" ", " \" ") for text in input_texts]

        tokenized = tokenizer(input_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_counterfactual = tokenizer(counterfactual_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_labels = []

        for input_ids, attention_mask, input_text in zip(tokenized["input_ids"], tokenized["attention_mask"], texts):
            input_length = tokenizer(input_text, return_tensors="pt", padding=False)["input_ids"].shape[-1]
            # Index of the first attended token: 0 with right padding, the number
            # of pad tokens with left padding. Offsetting by it keeps the prompt
            # fully masked regardless of the tokenizer's padding side.
            first_real = int(attention_mask.argmax())
            target_start = first_real + input_length

            label = torch.full_like(input_ids, -100)
            label[target_start:] = input_ids[target_start:]
            label[input_ids == tokenizer.pad_token_id] = -100
            tokenized_labels.append(label)

        tokenized_labels = torch.stack(tokenized_labels)
        return {
            "base_input_ids": tokenized["input_ids"],
            "base_attention_mask": tokenized["attention_mask"],
            "source_input_ids": tokenized_counterfactual["input_ids"],
            "source_attention_mask": tokenized_counterfactual["attention_mask"],
            "labels": tokenized_labels
        }

    prompts, edit_instructions, targets, unaffected_attributes, counterfactual_prompts = [], [], [], [], []
    for b in batch:
        prompts.append(b["input_text"])
        edit_instructions.append(b["edit_instruction"])
        # The edited attribute is supervised with the source entity's value.
        targets.append(b["counterfactual_target"])
        unaffected_attributes.append(b["unaffected_attributes"])
        counterfactual_prompts.append(b["counterfactual_input_text"])

    editor_input_ids = tokenizer(edit_instructions, return_tensors="pt", padding=True, truncation=True)["input_ids"]

    returned_dict = {
        "editor_input_ids": editor_input_ids,
        **tokenize_text_inputs(prompts, counterfactual_prompts, targets),
    }

    base_prompts_unaffected, counterfactual_prompts_unaffected, targets_unaffected, instance_indices = [], [], [], []

    for i, attribute_list in enumerate(unaffected_attributes):
        for d in attribute_list:
            base_prompts_unaffected.append(d["input_text"])
            # Frozen attributes keep the base entity's own value as target.
            targets_unaffected.append(d["target"])
            counterfactual_prompts_unaffected.append(d["counterfactual_input_text"])
            instance_indices.append(i)

    instance_indices = torch.tensor(instance_indices, dtype=torch.long)
    # Repeat each instance's editor prompt once per unaffected attribute.
    edit_instructions_unaffected = editor_input_ids[instance_indices]

    assert len(base_prompts_unaffected) == len(targets_unaffected)

    tokenized_unaffected = tokenize_text_inputs(base_prompts_unaffected, counterfactual_prompts_unaffected, targets_unaffected)

    returned_dict["editor_input_ids_unaffected"] = edit_instructions_unaffected
    returned_dict["base_input_ids_unaffected"] = tokenized_unaffected["base_input_ids"]
    returned_dict["base_attention_mask_unaffected"] = tokenized_unaffected["base_attention_mask"]
    returned_dict["source_input_ids_unaffected"] = tokenized_unaffected["source_input_ids"]
    returned_dict["source_attention_mask_unaffected"] = tokenized_unaffected["source_attention_mask"]
    returned_dict["labels_unaffected"] = tokenized_unaffected["labels"]
    returned_dict["instance_indices"] = instance_indices

    return returned_dict

batch_size = 32  # 50 or so

# Both loaders share the collate function and shuffle their split.
data_loader = DataLoader(city_train_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=True)
test_data_loader = DataLoader(city_test_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Filtering out 3 out of 3552 entities that the model does not know!


100%|██████████| 10000/10000 [00:03<00:00, 2781.54it/s]


Filtering out 3 out of 3552 entities that the model does not know!


100%|██████████| 1000/1000 [00:00<00:00, 3163.57it/s]


In [2]:
from models.llama3.model import LlamaInterpretor, LlamaInterpretorConfig
from models.utils import EditorModelOutput

# Checkpoint the editor config and its tokenizer are loaded from.
editor_checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v0.1"

editor_config = LlamaInterpretorConfig.from_pretrained(editor_checkpoint)

editor_config.name_or_path = editor_checkpoint
editor_config.torch_dtype = torch.float16
# Reuse the notebook-level constant so the value is defined in one place only.
editor_config.num_editing_heads = num_editing_heads
# NOTE(review): the first cell defines chop_layer = 6, but 10 is used here —
# confirm which value is intended before wiring the constant in.
editor_config.chop_editor_at_layer = 10
editor_config.default_intervention_layer = 10
editor_config._attn_implementation = 'eager'

editor_inner = LlamaInterpretor(editor_config)
# NOTE(review): this rebinds the global `tokenizer` that ravel_collate_fn reads at
# call time, so batches collated after this cell use this tokenizer instead of the
# one created in the first cell — confirm that this is intended.
tokenizer = PreTrainedTokenizerFast.from_pretrained(editor_config.name_or_path)
editor_inner = editor_inner.to("cuda")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [3]:
# Every parameter whose name does not contain "target_model" is optimised.
trainable_parameters = [
    param
    for name, param in editor_inner.named_parameters()
    if "target_model" not in name
]

# usually: lr = 5e-5. 1e-3 worked well!
# (The variable name is spelled "optimzer"; the optimizer-step cell refers to it by that name.)
optimzer = optim.SGD(trainable_parameters, lr=lr, weight_decay=0.01)

In [4]:
# Grab the first batch from the training loader.
for batch in data_loader:
    break

# Move the tensors needed for one forward pass onto the GPU.
(
    editor_input_ids,
    base_input_ids,
    base_attention_mask,
    source_input_ids,
    source_attention_mask,
    labels,
) = (
    batch[key].cuda()
    for key in (
        "editor_input_ids",
        "base_input_ids",
        "base_attention_mask",
        "source_input_ids",
        "source_attention_mask",
        "labels",
    )
)

# Editor positions holding the config's EOS id are masked out
# (presumably because padding uses EOS — confirm against the active tokenizer).
editor_attention_mask = editor_input_ids != editor_config.eos_token_id

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
# One forward pass of the editor on the fetched batch.
editor_forward_kwargs = dict(
    editor_input_ids=editor_input_ids,
    editor_attention_mask=editor_attention_mask,
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
)
_pred: EditorModelOutput = editor_inner(**editor_forward_kwargs)

In [9]:
# Collapse all leading dimensions so every row is one token position, then
# turn the logits into log-probabilities over the last (vocabulary) axis.
vocab_size = _pred.logits.shape[-1]
flat_logits = _pred.logits.reshape(-1, vocab_size)
log_prob_predictions = torch.nn.functional.log_softmax(flat_logits, dim=1)

# Flatten the labels the same way so they line up row by row.
labels = labels.reshape(-1)

In [10]:
# Keep only the positions that carry a real target (labels of -100 are ignored).
label_indices = labels != -100
log_prob_predictions = log_prob_predictions[label_indices, :]
labels = labels[label_indices]

# Compute the cross-entropy loss with masking.
# log_prob_predictions already holds log-probabilities, so the negative
# log-likelihood loss is the matching criterion; CrossEntropyLoss expects raw
# logits and would apply log-softmax a second time.
criterion = torch.nn.NLLLoss(reduction="mean")
loss = criterion(log_prob_predictions, labels.long())
loss.backward()

In [12]:
# Disabled steps kept for reference (backward already ran in the loss cell, and
# gradient clipping is switched off):
#   loss.backward()
#   nn.utils.clip_grad_norm_(editor_inner.parameters(), max_grad_clip)
optimzer.step()

In [13]:
# Show one entry of the first hypernetwork layer's query projection weight
# together with the gradient stored for that same entry.
q_proj_weight = editor_inner.hypernetwork.model.layers[0].self_attn.q_proj.weight
q_proj_weight[0, 0], q_proj_weight.grad[0, 0]

(tensor(-0.0012, device='cuda:0', dtype=torch.float16,
        grad_fn=<SelectBackward0>),
 tensor(0., device='cuda:0', dtype=torch.float16))