In [1]:
import torch

torch.manual_seed(42)
torch.set_default_device("cuda")
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, LlamaConfig, LlamaModel
from torch import compile
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import transformers
from tqdm import tqdm
import pandas as pd
import yaml
import contextlib
import os
import time
import sys

sys.path.append("..")

%load_ext autoreload
%autoreload 2

In [2]:
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.1")
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.1")
for param in model.parameters():
    param.requires_grad = False
    
model.config.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [None]:
import random
import numpy as np
import json
import os
from tqdm import tqdm
from datasets import Dataset



def generate_ravel_dataset(n_samples, split="train", domains=["city"], domain_excluded_attributes=[["Latitude", "Longitude", "Timezone"]], filtering_dict_paths=[None],  seed=42):
            
    # Seed
    random.seed(seed)
    np.random.seed(seed)
    dataset = []
    
    sample_per_domain = n_samples // len(domains)
    
    for domain, excluded_attributes, filtering_dict_path in zip(domains, domain_excluded_attributes, filtering_dict_paths):
                        
        templates = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_attribute_to_prompts.json"), "r"))
        entities = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_attributes.json"), "r"))
        entities_split = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_entity_to_split.json"), "r"))
        templates_split = json.load(open(os.path.join("./data/ravel/", f"ravel_{domain}_prompt_to_split.json"), "r"))
        
        all_attributes = [a for a in list(templates.keys()) if a not in excluded_attributes]

        templates_train = {k: [v for v in vs if templates_split[v] == "train"] for k, vs in templates.items()}
        templates_train_idxs = {k: [templates[k].index(v) for v in vs if templates_split[v] == "train"] for k, vs in templates.items()}
        
        templates_test = {k: [v for v in vs if templates_split[v] != "train"] for k, vs in templates.items()}
        templates_test_idxs = {k: [templates[k].index(v) for v in vs if templates_split[v] != "train"] for k, vs in templates.items()}

        entities_train = {k: v for k, v in entities.items() if entities_split[k] == "train"}
        name_train = list(entities_train.keys())

        entities_test = {k: v for k, v in entities.items() if entities_split[k] != "train"}
        name_test = list(entities_test.keys())
        
        if split == "train":
            entity_dict, entity_name, prompt_dict, prompt_idxs_dict = entities_train, name_train, templates_train, templates_train_idxs
        elif split == "test":
            entity_dict, entity_name, prompt_dict, prompt_idxs_dict = entities_test, name_test, templates_test, templates_test_idxs
        else:
            raise ValueError("split must be 'train' or 'test'")
        
        if filtering_dict_path is not None:
            filtering_dict = json.load(open(filtering_dict_path, "r"))
            filtered_key = []
            
            for entity in filtering_dict.keys():
                model_knows = True
                for attribute in all_attributes:                    
                    split_template_idx = [list(filtering_dict[entity][attribute].values())[i] for i in prompt_idxs_dict[attribute]]
                    if True not in split_template_idx:
                        model_knows = False
                        break
                if not model_knows:
                    filtered_key.append(entity)
            print(f"Filtering out {len(filtered_key)} out of {len(filtering_dict)} entities that the model does not know!")
            filtering_dict = {k: v for k, v in filtering_dict.items() if k not in filtered_key}
        else:
            filtering_dict = None
        
        for _ in tqdm(range(sample_per_domain)):
            
            data = {}
            
            if filtering_dict is None:
                source_entity, base_entity = random.sample(entity_name, 2)
                attribute = random.choice(all_attributes)
                frozen_attributes = [k for k in all_attributes if k != attribute]
                source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]
                source_template, base_template = random.choice(prompt_dict[attribute]), random.choice(prompt_dict[attribute])
            else:
                source_entity, base_entity = random.sample([k for k in entity_name if k in filtering_dict.keys()], 2)
                attribute = random.choice(all_attributes)
                frozen_attributes = [k for k in all_attributes if k != attribute]
                source_entity_dict, base_entity_dict = entity_dict[source_entity], entity_dict[base_entity]
                source_template_idxs = [i for i in range(len(filtering_dict[source_entity][attribute])) if filtering_dict[source_entity][attribute][str(i)] == True]
                source_template_idxs = [prompt_idxs_dict[attribute].index(i) for i in source_template_idxs if i in prompt_idxs_dict[attribute]]
                source_template = random.choice([prompt_dict[attribute][i] for i in source_template_idxs])
                base_template_idxs = [i for i in range(len(filtering_dict[base_entity][attribute])) if filtering_dict[base_entity][attribute][str(i)] == True]
                base_template_idxs = [prompt_idxs_dict[attribute].index(i) for i in base_template_idxs if i in prompt_idxs_dict[attribute]]
                base_template = random.choice([prompt_dict[attribute][i] for i in base_template_idxs])
                
            data["input_text"] = base_template % base_entity
            data["counterfactual_input_text"] = source_template % source_entity
            data["edit_instruction"] = f"{base_entity} ; {source_entity} - {attribute}"
            data["target"] = base_entity_dict[attribute]
            data["counterfactual_target"] = source_entity_dict[attribute]
            
            data["unaffected_attributes"] = []
            
            for frozen_attribute in frozen_attributes:
                
                if filtering_dict is None:
                    source_attribute_template, base_attribute_template = random.choice(prompt_dict[frozen_attribute]), random.choice(prompt_dict[frozen_attribute])
                else:
                    source_attribute_idxs = [i for i in range(len(filtering_dict[source_entity][frozen_attribute])) if filtering_dict[source_entity][frozen_attribute][str(i)] == True]
                    source_attribute_idxs = [prompt_idxs_dict[frozen_attribute].index(i) for i in source_attribute_idxs if i in prompt_idxs_dict[frozen_attribute]]
                    source_attribute_template = random.choice([prompt_dict[frozen_attribute][i] for i in source_attribute_idxs])
                    base_attribute_idxs = [i for i in range(len(filtering_dict[base_entity][frozen_attribute])) if filtering_dict[base_entity][frozen_attribute][str(i)] == True]
                    base_attribute_idxs = [prompt_idxs_dict[frozen_attribute].index(i) for i in base_attribute_idxs if i in prompt_idxs_dict[frozen_attribute]]
                    base_attribute_template = random.choice([prompt_dict[frozen_attribute][i] for i in base_attribute_idxs])
                    
                base_prompt = base_attribute_template % base_entity
                counterfactual_prompt = source_attribute_template % source_entity
                
                target = base_entity_dict[frozen_attribute]
                counterfactual_target = source_entity_dict[frozen_attribute]
                
                data["unaffected_attributes"].append(
                    {
                        "input_text": base_prompt,
                        "counterfactual_input_text": counterfactual_prompt,
                        "edit_instruction": f"{base_entity} ; {source_entity} - {attribute}",
                        "target": target,
                        "counterfactual_target": counterfactual_target,
                    }
                )
                
            dataset.append(data)
                
    dataset = Dataset.from_list(dataset)
    return dataset

city_train_set = generate_ravel_dataset(10000, split="train", filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"])
city_test_set =  generate_ravel_dataset(1000, split="test", filtering_dict_paths=["./notebooks/ravel_llama-3-8b_city_prompt_to_output_statistics.json"])

def ravel_collate_fn(batch):
    
    def tokenize_text_inputs(texts, counterfactual_texts, target_texts):
        
        input_texts = [text + " " + target for text, target in zip(texts, target_texts)]
        input_texts = [text.replace(" \" ", " \" ") for text in input_texts]
        
        tokenized = tokenizer(input_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_counterfactual = tokenizer(counterfactual_texts, return_tensors="pt", padding=True, max_length=50, truncation=True)
        tokenized_labels = []
        
        for input_ids, input_text in zip(tokenized["input_ids"], texts):
            input_length = tokenizer(input_text, return_tensors="pt", padding=False)["input_ids"].shape[-1]
            label = torch.full_like(input_ids, -100)
            label[input_length:] = input_ids[input_length:]
            label[input_ids == tokenizer.pad_token_id] = -100
            tokenized_labels.append(label)
        
        tokenized_labels = torch.stack(tokenized_labels)
        return {
            "base_input_ids": tokenized["input_ids"],
            "base_attention_mask": tokenized["attention_mask"],
            "source_input_ids": tokenized_counterfactual["input_ids"],
            "source_attention_mask": tokenized_counterfactual["attention_mask"],
            "labels": tokenized_labels
        }
        
    prompts, edit_instructions, targets, unaffected_attributes, counterfactual_prompts = [], [], [], [], []
    for b in batch:
        prompts.append(b["input_text"])
        edit_instructions.append(b["edit_instruction"])
        targets.append(b["counterfactual_target"])
        unaffected_attributes.append(b["unaffected_attributes"])
        counterfactual_prompts.append(b["counterfactual_input_text"])
        
        
    editor_input_ids = tokenizer(edit_instructions, return_tensors="pt", padding=True, truncation=True)["input_ids"]
    
    returned_dict = {
        "editor_input_ids": editor_input_ids,
        **tokenize_text_inputs(prompts, counterfactual_prompts, targets),
    }
    
    base_prompts_unaffected, counterfactual_prompts_unaffected, targets_unaffected, edit_instructions_unaffected, instance_indices = [], [], [], [], []
    
    for i, attribute_list in enumerate(unaffected_attributes):
        
        for d in attribute_list:
            base_prompts_unaffected.append(d["input_text"])
            targets_unaffected.append(d["target"])
            counterfactual_prompts_unaffected.append(d["counterfactual_input_text"])
                    
        for _ in range(len(attribute_list)):
            edit_instructions_unaffected.append(editor_input_ids[i])
            instance_indices.append(i)
        
    edit_instructions_unaffected = torch.stack(edit_instructions_unaffected)
    instance_indices = torch.tensor(instance_indices)
    
    assert len(base_prompts_unaffected) == len(targets_unaffected)
    
    tokenized_unaffected = tokenize_text_inputs(base_prompts_unaffected, counterfactual_prompts_unaffected, targets_unaffected)
    
    returned_dict["editor_input_ids_unaffected"] = edit_instructions_unaffected
    returned_dict["base_input_ids_unaffected"] = tokenized_unaffected["base_input_ids"]
    returned_dict["base_attention_mask_unaffected"] = tokenized_unaffected["base_attention_mask"]
    returned_dict["source_input_ids_unaffected"] = tokenized_unaffected["source_input_ids"]
    returned_dict["source_attention_mask_unaffected"] = tokenized_unaffected["source_attention_mask"]
    returned_dict["labels_unaffected"] = tokenized_unaffected["labels"]
    returned_dict["instance_indices"] = instance_indices
    
    return returned_dict

batch_size = 32  # 50 or so
data_loader = DataLoader(
    city_train_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=True
)  # batch_size, collate_fn=collate_fn)
test_data_loader = DataLoader(
    city_test_set, batch_size=batch_size, collate_fn=ravel_collate_fn, shuffle=True
)

Filtering out 3 out of 3552 entities that the model does not know!


100%|██████████| 10000/10000 [00:03<00:00, 2875.14it/s]


AttributeError: type object 'Dataset' has no attribute 'from_list'

In [14]:
def filtering_ravel_dataset(dataset, model, tokenizer, filtering_unaffected_features=True):
    
    def tokenize_text_inputs(texts, target_texts):
        input_texts = [text + " " for text in texts]
        input_texts = [text.replace(" \" ", " \" ") for text in input_texts]
        
        tokenized = tokenizer(input_texts, return_tensors="pt", padding=True)
        tokenized_labels = tokenizer(target_texts, return_tensors="pt", padding=True)["input_ids"]
        tokenized_labels[tokenized_labels == tokenizer.pad_token_id] = -100
        tokenized["labels"] = tokenized_labels
        return tokenized
    
    def filtering_collate_fn(batch):
        prompts, targets, unaffected_attributes = [], [], []
        for b in batch:
            prompts.append(b["counterfactual_input_text"])
            targets.append(b["counterfactual_target"])
            unaffected_attributes.append(b["unaffected_attributes"])
        
        returned_dict = {
            **tokenize_text_inputs(prompts, targets),
        }
        
        prompts_unaffected, targets_unaffected, instance_indices = [], [], []
        
        for i, attribute_list in enumerate(unaffected_attributes):
    
            for d in attribute_list:
                prompts_unaffected.append(d["input_text"])
                targets_unaffected.append(d["target"])
                        
            for _ in range(len(attribute_list)):
                instance_indices.append(i)
            
        instance_indices = torch.tensor(instance_indices)
        
        assert len(prompts_unaffected) == len(targets_unaffected)
        
        tokenized_unaffected = tokenize_text_inputs(prompts_unaffected, targets_unaffected)
        
        returned_dict["input_ids_unaffected"] = tokenized_unaffected["input_ids"]
        returned_dict["attention_mask_unaffected"] = tokenized_unaffected["attention_mask"]
        returned_dict["labels_unaffected"] = tokenized_unaffected["labels"]
        returned_dict["instance_indices"] = instance_indices
        
        return returned_dict
        
    filtered_idx = []
    
    batch_size = 16  # 50 or so
    data_loader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=filtering_collate_fn, shuffle=False, generator=torch.Generator(device='cuda')
    )  # batch_size, collate_fn=collate_fn)
    
    model = model.to("cuda")
    model.eval()
  
    for step, batch in tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Currently have {len(filtered_idx)} samples"):
            
        batch_pred_ids = model.generate(
            input_ids=batch["input_ids"].to("cuda"),
            attention_mask=batch["attention_mask"].to("cuda"),
            max_new_tokens=10
        )
        
        if filtering_unaffected_features:
            batch_unaffected_pred_ids = model.generate(
                input_ids=batch["input_ids_unaffected"].to("cuda"),
                attention_mask=batch["attention_mask_unaffected"].to("cuda"),
                max_new_tokens=10
            )
            
            batch_instance_indices = batch["instance_indices"]
        else:
            batch_unaffected_pred_ids = None
            batch_instance_indices = None
                        
        for i, (label, pred_ids) in enumerate(zip(batch["labels"], batch_pred_ids)):
            
            
            output_idx = label != -100
            label = label[output_idx]
            pred_ids = pred_ids[len(batch["input_ids"][i]):]
            pred_ids = pred_ids[:len(label)]
            print(tokenizer.decode(label, skip_special_tokens=True).strip(), tokenizer.decode(pred_ids, skip_special_tokens=True).strip())
            is_correct = tokenizer.decode(label, skip_special_tokens=True).strip() in tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
            print(is_correct)
            print("---------------")
            
            if filtering_unaffected_features:
                instance_unaffected_pred_ids = batch_unaffected_pred_ids[batch_instance_indices == i]
                instance_unaffected_labels = batch["labels_unaffected"][batch_instance_indices == i]
                
                for (unaffected_label, unaffected_pred_ids) in zip(instance_unaffected_labels, instance_unaffected_pred_ids):
                    
                    unaffected_output_idx = unaffected_label != -100
                    unaffected_label = unaffected_label[unaffected_output_idx]
                    unaffected_pred_ids = unaffected_pred_ids[len(batch["input_ids_unaffected"][i]):]
                    unaffected_pred_ids = unaffected_pred_ids[:len(unaffected_label)]
                    
                    unaffected_is_correct = tokenizer.decode(unaffected_label, skip_special_tokens=True).strip() in tokenizer.decode(unaffected_pred_ids, skip_special_tokens=True).strip()
                    if not unaffected_is_correct:
                        is_correct = False
                        break
        
            if is_correct:
                filtered_idx.append(i + step * batch_size)
                
        print(filtered_idx)
        raise
    
    print(f"Filtered {len(filtered_idx)} out of {len(dataset)} samples")
    return dataset.select(filtered_idx)

filtered_dataset = filtering_ravel_dataset(city_train_set, model, tokenizer)
filtered_test_dataset = filtering_ravel_dataset(city_test_set, model, tokenizer)

Currently have 0 samples:   0%|          | 0/625 [00:06<?, ?it/s]

Africa/Ndjamena 1.5 hours.
False
---------------
Asia/Kolkata 1.
False
---------------
Oceania Australia"}, {"
False
---------------
Europe 
False
---------------
31 31
True
---------------
31 31
True
---------------
Bengali Bengali"},
True
---------------
-33 33.
False
---------------
42 41
False
---------------
Asia/Kuala_Lumpur 2019-12-31
False
---------------
Africa 
False
---------------
India India
True
---------------
-12 12"}
False
---------------
Russian 2
False
---------------
Turkmenistan Turkey"}, {"city
False
---------------
-57 60.
False
---------------
[]





RuntimeError: No active exception to reraise