In [1]:
import sys

# Add the directory two levels up to the import path so that the `src.hyperdas`
# package imported below can be found (presumably the repository root; this
# relies on the kernel's working directory being the notebook's own folder).
sys.path.append("../..")
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before executing each cell so edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from src.hyperdas.data_utils import (
    filter_dataset,
    generate_ravel_dataset,
)

In [3]:
# Model identifier handed to both `AutoTokenizer.from_pretrained` and
# `LlamaForCausalLM.from_pretrained` when the tokenizer and model are loaded.
model_name_or_path = "meta-llama/Meta-Llama-3-8B"

In [4]:
# Tokenizer: pad on the left and reuse the end-of-sequence token (and its id)
# as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model: load the checkpoint in bfloat16 and move it to the CUDA device.
model = LlamaForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16
).cuda()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Build the Country/Continent RAVEL datasets for the train and test splits,
# run each through three passes of `filter_dataset`, and save it to disk.
all_attributes = [
    "Country",
    "Continent",
    "Language",
    "Latitude",
    "Longitude",
    "Timezone",
]
target_attributes = ["Country", "Continent"]
# Order-preserving difference. A `set` difference would hand the attributes to
# `generate_ravel_dataset` in a hash-dependent order, which changes between
# kernel sessions because string hashing is randomized per process.
isolate_attributes = [a for a in all_attributes if a not in target_attributes]

for split in ["train", "test"]:
    dataset = generate_ravel_dataset(
        10000,
        root_path="/workspace/HyperDAS/assets/data/ravel",
        # Pass fresh copies so every call receives its own list objects.
        target_attributes=list(target_attributes),
        isolate_attributes=list(isolate_attributes),
        template_split=split,
        entity_split=split,
    )

    # The filter is applied three times in a row: in the recorded runs the
    # second and third passes still drop a handful of examples.
    for filter_pass in range(3):
        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

    dataset.save_to_disk(
        f"/workspace/HyperDAS/experiments/RAVEL/data/city_country_{split}"
    )

625it [00:50, 12.30it/s]


Accuracy: 0.5883; filtered out 4117 examples


368it [00:27, 13.20it/s]


Accuracy: 0.9959204487506375; filtered out 24 examples


367it [00:25, 14.58it/s]

Accuracy: 0.9989759344598055; filtered out 6 examples





Saving the dataset (0/1 shards):   0%|          | 0/5853 [00:00<?, ? examples/s]

625it [00:45, 13.84it/s]


Accuracy: 0.5804; filtered out 4196 examples


363it [00:25, 14.45it/s]


Accuracy: 0.9943142660234321; filtered out 33 examples


361it [00:24, 14.66it/s]


Accuracy: 0.9975740772829665; filtered out 14 examples


Saving the dataset (0/1 shards):   0%|          | 0/5757 [00:00<?, ? examples/s]

In [5]:
# v0.1 of the Country/Continent datasets: same generation as the plain
# city_country datasets, but with an explicit edit-instruction template of the
# form
#   CHANGE: {base_entity} -> {source_entity} | ATTR: {random_target_attribute}
all_attributes = [
    "Country",
    "Continent",
    "Language",
    "Latitude",
    "Longitude",
    "Timezone",
]
target_attributes = ["Country", "Continent"]
# Order-preserving difference. A `set` difference would hand the attributes to
# `generate_ravel_dataset` in a hash-dependent order, which changes between
# kernel sessions because string hashing is randomized per process.
isolate_attributes = [a for a in all_attributes if a not in target_attributes]

for split in ["train", "test"]:
    dataset = generate_ravel_dataset(
        10000,
        root_path="/workspace/HyperDAS/assets/data/ravel",
        # Pass fresh copies so every call receives its own list objects.
        target_attributes=list(target_attributes),
        isolate_attributes=list(isolate_attributes),
        template_split=split,
        entity_split=split,
        edit_instruction_template="CHANGE: {base_entity} -> {source_entity} | ATTR: {random_target_attribute}",
    )

    # The filter is applied three times in a row: in the recorded runs the
    # second and third passes still drop a handful of examples.
    for filter_pass in range(3):
        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

    dataset.save_to_disk(
        f"/workspace/HyperDAS/experiments/RAVEL/data/city_country_{split}_v0.1"
    )

625it [00:49, 12.59it/s]


Accuracy: 0.5956; filtered out 4044 examples


373it [00:27, 13.75it/s]


Accuracy: 0.9971457353928811; filtered out 17 examples


372it [00:25, 14.64it/s]

Accuracy: 0.9988213503956895; filtered out 7 examples





Saving the dataset (0/1 shards):   0%|          | 0/5932 [00:00<?, ? examples/s]

625it [00:44, 13.97it/s]


Accuracy: 0.5884; filtered out 4116 examples


368it [00:25, 14.31it/s]


Accuracy: 0.9954112848402448; filtered out 27 examples


367it [00:24, 14.76it/s]


Accuracy: 0.9994877923851802; filtered out 3 examples


Saving the dataset (0/1 shards):   0%|          | 0/5854 [00:00<?, ? examples/s]

In [None]:
# One dataset per (attribute, split): the chosen attribute is the sole target
# and every other attribute is passed as an isolate attribute. `entity_split`
# is fixed to "both" while `template_split` follows the split being generated.
all_attributes = [
    "Country",
    "Continent",
    "Language",
    "Latitude",
    "Longitude",
    "Timezone",
]
# Requested number of examples for each split.
split_sizes = [("train", 20000), ("test", 4000)]

for attribute in all_attributes:
    for split, size in split_sizes:
        print(f"Generating data for {attribute} in split {split}")
        all_other_attributes = [a for a in all_attributes if a != attribute]

        dataset = generate_ravel_dataset(
            size,
            root_path="/home/ubuntu/HyperDAS/data/ravel",
            target_attributes=[attribute],
            isolate_attributes=all_other_attributes,
            template_split=split,
            entity_split="both",
        )

        # Three consecutive filtering passes over the generated examples.
        for filter_pass in range(3):
            dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

        output_dir = f"/home/ubuntu/HyperDAS/experiments/ravel/data/city_{attribute.lower()}_{split}"
        dataset.save_to_disk(output_dir)

In [13]:
def generate_ravel_causal_dataset(split):
    """Generate, filter and save the Country-only RAVEL dataset for one split.

    Requests 10000 examples from `generate_ravel_dataset` with "Country" as the
    only target attribute and no isolate attributes, runs the result through
    `filter_dataset` three times using the notebook-level `model` and
    `tokenizer`, and writes the surviving examples to disk.

    Args:
        split: Value used for both `template_split` and `entity_split`, and as
            the suffix of the output directory name.

    Returns:
        The filtered dataset that was saved.
    """
    dataset = generate_ravel_dataset(
        10000,
        root_path="/workspace/HyperDAS/assets/data/ravel",
        target_attributes=["Country"],
        isolate_attributes=[],
        template_split=split,
        entity_split=split,
    )

    # Three consecutive filtering passes, each fed the previous pass's output.
    for _ in range(3):
        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

    output_dir = f"/workspace/HyperDAS/experiments/RAVEL/data/ravel_country_causal_only_{split}"
    dataset.save_to_disk(output_dir)

    return dataset

In [None]:
# Generate (and save) the causal-only Country dataset for each split, train
# first, then test.
train_dataset, test_dataset = [
    generate_ravel_causal_dataset(split_name) for split_name in ("train", "test")
]

In [None]:
# Display the first training example to inspect the fields of a record.
train_dataset[0]

In [None]:
# Test examples whose "attribute" field is "Country".
all_country = [
    example for example in test_dataset if example["attribute"] == "Country"
]
# Of those, the ones whose "attribute_type" field is "isolate".
all_country_with_isolation = [
    example for example in all_country if example["attribute_type"] == "isolate"
]
# Sizes: whole test set, Country examples, Country examples marked "isolate".
len(test_dataset), len(all_country), len(all_country_with_isolation)

In [None]:
# Count how many test examples carry each value of the "attribute" field.
attr_dict = {}

for d in test_dataset:
    attr_name = d["attribute"]
    # Start unseen attributes at zero, then add this example.
    attr_dict[attr_name] = attr_dict.get(attr_name, 0) + 1

attr_dict