In [8]:
import torch

torch.manual_seed(42)

from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, LlamaConfig, LlamaModel
import torch
from torch import compile
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import transformers
from tqdm import tqdm
import pandas as pd
import yaml
import contextlib
import os
import time
import sys

sys.path.append("..")

%load_ext autoreload
%autoreload 2

from transformers import PreTrainedTokenizerFast, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.1")
tokenizer = AutoTokenizer.from_pretrained("/work/frink/models/llama3-8B-HF")
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id


import random
import numpy as np
import json
from tqdm import tqdm
from datasets import Dataset, load_from_disk

def generate_ravel_dataset(n_samples, split="train", domains=["city"], domain_excluded_attributes=[["Latitude", "Longitude", "Timezone"]], filtering_dict_paths=[None],  seed=42):
            
    # Seed
    random.seed(seed)
    np.random.seed(seed)
    dataset = []
    
    sample_per_domain = n_samples // len(domains)
    
    for domain, excluded_attributes in zip(domains, domain_excluded_attributes):
                        
        templates = {
            "Language": ["People in %s usually speak"],
            "Country": ["%s is in the country of"],
            "Continent": ["%s is in the continent of"]
        }
        
        prompt_idxs_dict = {
            "Language": [3],
            "Country": [5],
            "Continent": [3]
        }
        
        entities = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_attributes.json"), "r"))
        entities_split = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_to_split.json"), "r"))
        
        all_attributes = [a for a in list(templates.keys()) if a not in excluded_attributes]

        entities_train = {k: v for k, v in entities.items() if entities_split[k] == "train"}
        name_train = list(entities_train.keys())

        entities_test = {k: v for k, v in entities.items() if entities_split[k] != "train"}
        name_test = list(entities_test.keys())
        
        if split == "train":
            entity_dict, entity_name = entities_train, name_train
        elif split == "test":
            entity_dict, entity_name = entities_test, name_test
        else:
            raise ValueError("split must be 'train' or 'test'")
        
        
        filtered_key = []
        for entity in entity_dict.keys():
            model_knows = True
            
            entity_token_len = len(tokenizer(entity)["input_ids"])
            
            if entity_token_len > 1:
                filtered_key.append(entity)
            
            if not model_knows:
                filtered_key.append(entity)
        
        print(f"Filtering out {len(filtered_key)} out of {len(entity_dict)} entities that the model does not know!")
        filtering_dict = {k: v for k, v in entity_dict.items() if k not in filtered_key}
        entity_name = [n for n in entity_name if n not in filtered_key]
        
        for _ in tqdm(range(sample_per_domain)):
            
            data = {}
            
            source_entity, base_entity = random.sample(entity_name, 2)
            attribute = random.choice(all_attributes)
            another_attribute = random.choice(all_attributes)
            source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]
            source_template, base_template = random.choice(templates[another_attribute]), random.choice(templates[attribute])
        
                
            data["input_text"] = base_template % base_entity
            data["counterfactual_input_text"] = source_template % source_entity
            data["entity"] = base_entity
            data["edit_instruction"] = f"{base_entity} ; {source_entity} - {attribute}"
            data["counterfactual_entity"] = source_entity
            data["target"] = base_entity_dict[attribute]
            data["counterfactual_target"] = source_entity_dict[attribute]  
            data["counterfactual_input_text_with_same_template"] = base_template % source_entity
                
            dataset.append(data)
                
    dataset = Dataset.from_list(dataset)
    return dataset

city_train_set = generate_ravel_dataset(10000, split="train", filtering_dict_paths=None)
city_test_set =  generate_ravel_dataset(1000, split="test", filtering_dict_paths=None)


# city_train_set = generate_ravel_dataset(1000, split="train", filtering_dict_paths=None)
# city_test_set =  generate_ravel_dataset(100, split="test", filtering_dict_paths=None)
# city_train_set = load_from_disk("./experiment_models_new/city_train_set")
# city_test_set = load_from_disk("./experiment_models_new/city_test_set")

def ravel_collate_fn(batch):
    
    def tokenize_text_inputs(texts, counterfactual_texts, target_texts):
        
        input_texts = [text + " " + target for text, target in zip(texts, target_texts)]
        input_texts = [text.replace(" \" ", " \" ") for text in input_texts]
        
        tokenized = tokenizer(input_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_counterfactual = tokenizer(counterfactual_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_labels = []
        
        for input_ids, input_text in zip(tokenized["input_ids"], texts):
            input_length = tokenizer(input_text, return_tensors="pt", padding=False)["input_ids"].shape[-1]
            input_length += torch.sum(input_ids == tokenizer.pad_token_id)
            
            label = torch.full_like(input_ids, -100)
            label[input_length:] = input_ids[input_length:]
            tokenized_labels.append(label)
        
        tokenized_labels = torch.stack(tokenized_labels)
        return {
            "base_input_ids": tokenized["input_ids"],
            "base_attention_mask": tokenized["attention_mask"],
            "source_input_ids": tokenized_counterfactual["input_ids"],
            "source_attention_mask": tokenized_counterfactual["attention_mask"],
            "labels": tokenized_labels
        }
        
    prompts, edit_instructions, targets, counterfactual_prompts = [], [], [], []
    for b in batch:
        prompts.append(b["counterfactual_input_text_with_same_template"])
        edit_instructions.append(b["edit_instruction"])
        targets.append(b["counterfactual_target"])
        counterfactual_prompts.append(b["counterfactual_input_text"])
        
    editor_input_ids = tokenizer(edit_instructions, return_tensors="pt", padding=True, truncation=True)["input_ids"]
    
    returned_dict = {
        "editor_input_ids": editor_input_ids,
        **tokenize_text_inputs(prompts, counterfactual_prompts, targets),
    }    
    
    return returned_dict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Filtering out 1751 out of 1787 entities that the model does not know!


100%|██████████| 10000/10000 [00:00<00:00, 171916.73it/s]


Filtering out 1724 out of 1765 entities that the model does not know!


100%|██████████| 1000/1000 [00:00<00:00, 197872.53it/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [2]:
# @torch.compile #Apparently this fails when used inside jupyter notebooks but is fine if i make dedicated scripts
from models.llama3.model import LlamaInterpretor, LlamaInterpretorConfig, RavelInterpretorHypernetwork
from models.utils import InterpretorModelOutput
from transformers import LlamaForCausalLM

llama = LlamaForCausalLM.from_pretrained("/work/frink/models/llama3-8B-HF")
llama = llama.cuda()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [18]:
def filter_dataset(model, dataset):

    batch_size = 16  # 50 or so
    data_loader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=False
    )  # batch_size, collate_fn=collate_fn)

    model.eval()
    correct_idxs = set()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_id, batch in enumerate(data_loader):
            
            predictions = model.forward(
                input_ids=batch["base_input_ids"].to("cuda"),
                attention_mask=batch["base_attention_mask"].to("cuda"),
                labels=batch["labels"].to("cuda"),
            )
            
            batch_pred_ids = torch.argmax(predictions["logits"], dim=-1)
            
            for i, (label, pred_ids) in enumerate(zip(batch["labels"].to("cuda"), batch_pred_ids)):
                            
                label_idx = label != -100
                output_idx = torch.zeros_like(label_idx)
                output_idx[:-1] = label_idx[1:]
                
                label = label[label_idx]
                pred_ids = pred_ids[output_idx]
                is_correct = (torch.sum(label == pred_ids) == torch.numel(label)).item()
                    
                correct += is_correct
                if is_correct:
                    correct_idxs.add(batch_id * len(batch["labels"]) + i)
                total += 1
                
    print(f"Accuracy: {correct / total}")
    new_dataset = dataset.select(correct_idxs)
    return new_dataset

In [20]:
city_train_set = filter_dataset(llama, city_train_set)
city_test_set = filter_dataset(llama, city_test_set)

Accuracy: 0.3742
Accuracy: 0.331


In [22]:
len(city_test_set), len(city_train_set)

(330, 3742)

In [24]:
city_train_set.save_to_disk("./data/city_train_set")
city_test_set.save_to_disk("./data/city_test_set")

Saving the dataset (0/1 shards):   0%|          | 0/3742 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/330 [00:00<?, ? examples/s]