In [12]:
import torch

torch.manual_seed(42)

from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, LlamaConfig, LlamaModel
import torch
from torch import compile
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import transformers
from tqdm import tqdm
import pandas as pd
import yaml
import contextlib
import os
import time
import sys

sys.path.append("..")

%load_ext autoreload
%autoreload 2

from transformers import PreTrainedTokenizerFast, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.1")
tokenizer = AutoTokenizer.from_pretrained("/work/frink/models/llama3-8B-HF")
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id


import random
import numpy as np
import json
from tqdm import tqdm
from datasets import Dataset, load_from_disk

def generate_ravel_dataset(n_samples, split="train", domains=["city"], domain_excluded_attributes=[["Latitude", "Longitude", "Timezone"]], filtering_dict_paths=[None],  seed=42):
            
    # Seed
    random.seed(seed)
    np.random.seed(seed)
    dataset = []
    
    sample_per_domain = n_samples // len(domains)
    
    for domain, excluded_attributes, filtering_dict_path in zip(domains, domain_excluded_attributes, filtering_dict_paths):
                        
        templates = {
            "Language": ["People in %s usually speak"],
            "Country": ["%s is in the country of"],
            "Continent": ["%s is in the continent of"]
        }
        
        prompt_idxs_dict = {
            "Language": [3],
            "Country": [5],
            "Continent": [3]
        }
        
        entities = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_attributes.json"), "r"))
        entities_split = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_to_split.json"), "r"))
        
        all_attributes = [a for a in list(templates.keys()) if a not in excluded_attributes]

        entities_train = {k: v for k, v in entities.items() if entities_split[k] == "train"}
        name_train = list(entities_train.keys())

        entities_test = {k: v for k, v in entities.items() if entities_split[k] != "train"}
        name_test = list(entities_test.keys())
        
        if split == "train":
            entity_dict, entity_name = entities_train, name_train
        elif split == "test":
            entity_dict, entity_name = entities_test, name_test
        else:
            raise ValueError("split must be 'train' or 'test'")
        
        if filtering_dict_path is not None:
            filtering_dict = json.load(open(filtering_dict_path, "r"))
            filtered_key = []
            
            for entity in filtering_dict.keys():
                model_knows = True
                
                entity_token_len = len(tokenizer(entity)["input_ids"])
                
                if entity_token_len > 1:
                    model_knows = False
                
                for attribute in all_attributes:                    
                    split_template_idx = [list(filtering_dict[entity][attribute].values())[i] for i in prompt_idxs_dict[attribute]]
                    if True not in split_template_idx:
                        model_knows = False
                        break
                if not model_knows:
                    filtered_key.append(entity)
            print(f"Filtering out {len(filtered_key)} out of {len(filtering_dict)} entities that the model does not know!")
            filtering_dict = {k: v for k, v in filtering_dict.items() if k not in filtered_key}
        else:
            filtering_dict = None
        
        for _ in tqdm(range(sample_per_domain)):
            
            data = {}
            
            if filtering_dict is None:
                source_entity, base_entity = random.sample(entity_name, 2)
                attribute = random.choice(all_attributes)
                another_attribute = random.choice(all_attributes)
                source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]
                source_template, base_template = random.choice(templates[another_attribute]), random.choice(templates[attribute])
            else:
                source_entity, base_entity = random.sample([k for k in entity_name if k in filtering_dict.keys()], 2)
                attribute = random.choice(all_attributes)
                another_attribute = random.choice(all_attributes)
                source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]
                source_template, base_template = random.choice(templates[another_attribute]), random.choice(templates[attribute])
                
            data["input_text"] = base_template % base_entity
            data["counterfactual_input_text"] = source_template % source_entity
            data["edit_instruction"] = f"{base_entity} ; {source_entity} - {attribute}"
            data["target"] = base_entity_dict[attribute]
            data["counterfactual_target"] = source_entity_dict[attribute]            
                
            dataset.append(data)
                
    dataset = Dataset.from_list(dataset)
    return dataset

# city_train_set = generate_ravel_dataset(1000, split="train", filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"])
# city_test_set =  generate_ravel_dataset(100, split="test", filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"])


# city_train_set = generate_ravel_dataset(1000, split="train", filtering_dict_paths=None)
# city_test_set =  generate_ravel_dataset(100, split="test", filtering_dict_paths=None)
# city_train_set = load_from_disk("./experiment_models_new/city_train_set")
# city_test_set = load_from_disk("./experiment_models_new/city_test_set")

def ravel_collate_fn(batch):
    
    def tokenize_text_inputs(texts, counterfactual_texts, target_texts):
        
        input_texts = [text + " " + target for text, target in zip(texts, target_texts)]
        input_texts = [text.replace(" \" ", " \" ") for text in input_texts]
        
        tokenized = tokenizer(input_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_counterfactual = tokenizer(counterfactual_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_labels = []
        
        for input_ids, input_text in zip(tokenized["input_ids"], texts):
            input_length = tokenizer(input_text, return_tensors="pt", padding=False)["input_ids"].shape[-1]
            input_length += torch.sum(input_ids == tokenizer.pad_token_id)
            
            label = torch.full_like(input_ids, -100)
            label[input_length:] = input_ids[input_length:]
            tokenized_labels.append(label)
        
        tokenized_labels = torch.stack(tokenized_labels)
        return {
            "base_input_ids": tokenized["input_ids"],
            "base_attention_mask": tokenized["attention_mask"],
            "source_input_ids": tokenized_counterfactual["input_ids"],
            "source_attention_mask": tokenized_counterfactual["attention_mask"],
            "labels": tokenized_labels
        }
        
    prompts, edit_instructions, targets, counterfactual_prompts = [], [], [], []
    for b in batch:
        prompts.append(b["input_text"])
        edit_instructions.append(b["edit_instruction"])
        targets.append(b["target"])
        counterfactual_prompts.append(b["counterfactual_input_text"])
        
    editor_input_ids = tokenizer(edit_instructions, return_tensors="pt", padding=True, truncation=True)["input_ids"]
    
    returned_dict = {
        "editor_input_ids": editor_input_ids,
        **tokenize_text_inputs(prompts, counterfactual_prompts, targets),
    }    
    
    return returned_dict

batch_size = 16  # 50 or so
data_loader = DataLoader(
    city_train_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=False
)  # batch_size, collate_fn=collate_fn)
test_data_loader = DataLoader(
    city_test_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=False
)

for batch in data_loader:
    break

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


TypeError: 'NoneType' object is not iterable

In [9]:
# @torch.compile #Apparently this fails when used inside jupyter notebooks but is fine if i make dedicated scripts
from models.llama3.model import LlamaInterpretor, LlamaInterpretorConfig, RavelInterpretorHypernetwork
from models.utils import InterpretorModelOutput
from transformers import LlamaForCausalLM
import wandb

llama = LlamaForCausalLM.from_pretrained("/work/frink/models/llama3-8B-HF")
llama = llama.cuda()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
llama.eval()
test_loss = []
correct_idxs = []
correct = 0
total = 0

with torch.no_grad():
    for batch_id, batch in enumerate(test_data_loader):
        
        predictions = llama.forward(
            input_ids=batch["base_input_ids"].to("cuda"),
            attention_mask=batch["base_attention_mask"].to("cuda"),
            labels=batch["labels"].to("cuda"),
            
        )
        test_loss.append(predictions["loss"].item())
        
        batch_pred_ids = torch.argmax(predictions["logits"], dim=-1)
        
        for i, (label, pred_ids) in enumerate(zip(batch["labels"].to("cuda"), batch_pred_ids)):
            
            print(label)
            label_idx = label != -100
            output_idx = torch.zeros_like(label_idx)
            output_idx[:-1] = label_idx[1:]
            
            label = label[label_idx]
            pred_ids = pred_ids[output_idx]
            print(tokenizer.decode(label))
            print(tokenizer.decode(pred_ids))
            print()
            
            is_correct = (torch.sum(label == pred_ids) == torch.numel(label)).item()
                
            correct += is_correct
            if is_correct:
                correct_idxs.append(batch_id * len(batch["labels"]) + i)
            total += 1
        

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100, 15506], device='cuda:0')
 Spanish
 Spanish

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 4606],
       device='cuda:0')
 Europe
 South

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6498],
       device='cuda:0')
 English
 English

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100, 10384], device='cuda:0')
 Africa
 Africa

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 3723, 4273],
       device='cuda:0')
 United States
 his States

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
        48475,    64], device='cuda:0')
 Hausa
 Englisha

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100, 35217], device='cuda:0')
 Arabic
 the

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 4606],
 

In [11]:
correct / total

0.3