In [1]:
import sys

sys.path.append('../..')
%load_ext autoreload
%autoreload 2

In [3]:
from torch.utils.data import DataLoader
from datasets import load_from_disk
import torch
from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset

from transformers import AutoTokenizer, Gemma2ForCausalLM

In [4]:
# Local snapshot of google/gemma-2-9b-it. Defined once so the tokenizer and the
# model are always loaded from the same checkpoint directory.
MODEL_PATH = "/nlp/scr/sjd24/cache/hub/models--google--gemma-2-9b-it/snapshots/11c9b309abf73637e4b6f9a3fa1e92e615547819"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Gemma2ForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
model = model.cuda()

# Pad on the right; the pad token is left at the tokenizer's own default
# (it is deliberately not aliased to the EOS token here).
tokenizer.padding_side = "right"

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  7.61it/s]


In [10]:
# Build one filtered RAVEL city dataset per (target attribute, template split)
# pair and save each of them under ./gemma2_data/.
all_attributes = ["Country", "Continent", "Language", "Latitude", "Longitude", "Timezone"]
split_sizes = [("train", 20000), ("test", 4000)]

for attribute in all_attributes:
    # Every attribute except the current target is passed as an isolate attribute.
    all_other_attributes = [other for other in all_attributes if other != attribute]

    for split, size in split_sizes:
        print(f"Generating data for {attribute} in split {split}")

        dataset = generate_ravel_dataset(
            size,
            root_path="../../data/RAVEL",
            target_attributes=[attribute],
            isolate_attributes=all_other_attributes,
            template_split=split,
            entity_split="both",
        )

        # Three consecutive filtering passes over the same dataset; in the
        # recorded runs the later passes still remove a handful of examples.
        for filter_pass in range(3):
            dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

        dataset.save_to_disk(f"./gemma2_data/city_{attribute.lower()}_{split}")

Generating data for Country in split train


1250it [01:46, 11.68it/s]


Accuracy: 0.6368; filtered out 7264 examples


796it [01:07, 11.76it/s]


Accuracy: 0.9997644472361809; filtered out 3 examples


796it [01:07, 11.76it/s]


Accuracy: 0.9989004947773502; filtered out 14 examples


Saving the dataset (1/1 shards): 100%|██████████| 12719/12719 [00:00<00:00, 66820.76 examples/s]


Generating data for Country in split test


250it [00:21, 11.68it/s]


Accuracy: 0.6105; filtered out 1558 examples


153it [00:13, 11.66it/s]


Accuracy: 0.9946764946764947; filtered out 13 examples


152it [00:13, 11.69it/s]


Accuracy: 0.9942363112391931; filtered out 14 examples


Saving the dataset (1/1 shards): 100%|██████████| 2415/2415 [00:00<00:00, 51325.77 examples/s]


Generating data for Continent in split train


1250it [01:47, 11.66it/s]


Accuracy: 0.66245; filtered out 6751 examples


829it [01:10, 11.71it/s]


Accuracy: 0.9993207034493169; filtered out 9 examples


828it [01:10, 11.73it/s]


Accuracy: 0.9993202416918429; filtered out 9 examples


Saving the dataset (1/1 shards): 100%|██████████| 13231/13231 [00:00<00:00, 72537.24 examples/s]

Generating data for Continent in split test



250it [00:21, 11.67it/s]


Accuracy: 0.635; filtered out 1460 examples


159it [00:13, 11.72it/s]


Accuracy: 0.9952755905511811; filtered out 12 examples


158it [00:13, 11.69it/s]


Accuracy: 1.0; filtered out 0 examples


Saving the dataset (1/1 shards): 100%|██████████| 2528/2528 [00:00<00:00, 38347.23 examples/s]


Generating data for Language in split train


1250it [01:48, 11.47it/s]


Accuracy: 0.5672; filtered out 8656 examples


709it [01:00, 11.71it/s]


Accuracy: 0.9992947813822285; filtered out 8 examples


709it [01:00, 11.73it/s]


Accuracy: 0.9992942836979535; filtered out 8 examples


Saving the dataset (1/1 shards): 100%|██████████| 11328/11328 [00:00<00:00, 70122.55 examples/s]


Generating data for Language in split test


250it [00:22, 11.28it/s]


Accuracy: 0.5545; filtered out 1782 examples


139it [00:12, 11.52it/s]


Accuracy: 0.9950405770964833; filtered out 11 examples


138it [00:11, 11.53it/s]


Accuracy: 0.9932034435885818; filtered out 15 examples


Saving the dataset (1/1 shards): 100%|██████████| 2192/2192 [00:00<00:00, 53346.61 examples/s]


Generating data for Latitude in split train


1250it [01:49, 11.44it/s]


Accuracy: 0.4803; filtered out 10394 examples


601it [00:52, 11.56it/s]


Accuracy: 0.9980220695398709; filtered out 19 examples


600it [00:51, 11.58it/s]


Accuracy: 0.9988526129133202; filtered out 11 examples


Saving the dataset (1/1 shards): 100%|██████████| 9576/9576 [00:00<00:00, 60633.12 examples/s]


Generating data for Latitude in split test


250it [00:21, 11.41it/s]


Accuracy: 0.46225; filtered out 2151 examples


116it [00:10, 11.49it/s]


Accuracy: 0.9951325040562466; filtered out 9 examples


115it [00:10, 11.42it/s]


Accuracy: 1.0; filtered out 0 examples


Saving the dataset (1/1 shards): 100%|██████████| 1840/1840 [00:00<00:00, 44516.27 examples/s]


Generating data for Longitude in split train


1250it [01:49, 11.37it/s]


Accuracy: 0.4675; filtered out 10650 examples


585it [00:50, 11.48it/s]


Accuracy: 0.9987165775401069; filtered out 12 examples


584it [00:50, 11.45it/s]


Accuracy: 0.9985007496251874; filtered out 14 examples


Saving the dataset (1/1 shards): 100%|██████████| 9324/9324 [00:00<00:00, 48984.73 examples/s]


Generating data for Longitude in split test


250it [00:22, 11.35it/s]


Accuracy: 0.4485; filtered out 2206 examples


113it [00:09, 11.46it/s]


Accuracy: 0.9977703455964325; filtered out 4 examples


112it [00:09, 11.45it/s]


Accuracy: 0.9916201117318436; filtered out 15 examples


Saving the dataset (1/1 shards): 100%|██████████| 1775/1775 [00:00<00:00, 42369.12 examples/s]


Generating data for Timezone in split train


1250it [01:44, 11.99it/s]


Accuracy: 0.53155; filtered out 9369 examples


665it [00:55, 11.95it/s]


Accuracy: 0.9983068384912049; filtered out 18 examples


664it [00:55, 12.01it/s]


Accuracy: 0.9993404315462169; filtered out 7 examples


Saving the dataset (1/1 shards): 100%|██████████| 10606/10606 [00:00<00:00, 73362.81 examples/s]


Generating data for Timezone in split test


250it [00:21, 11.78it/s]


Accuracy: 0.49125; filtered out 2035 examples


123it [00:10, 11.85it/s]


Accuracy: 0.9918575063613232; filtered out 16 examples


122it [00:10, 11.69it/s]


Accuracy: 0.9933299127757824; filtered out 13 examples


Saving the dataset (1/1 shards): 100%|██████████| 1936/1936 [00:00<00:00, 55597.44 examples/s]


AttributeError: 'Gemma2ForCausalLM' object has no attribute 'padding_side'

In [None]:
# Sanity check on a tiny right-padded batch: compare the per-position argmax of
# a single forward pass with the tokens produced by `generate`.
data = [
    "ravel_occupation_attribute_to_prompts.json",
    "what is the pattern of the occupation attribute in the ravel dataset?",
]

tokenizer = AutoTokenizer.from_pretrained("/nlp/scr/sjd24/cache/hub/models--google--gemma-2-9b-it/snapshots/11c9b309abf73637e4b6f9a3fa1e92e615547819")
tokenizer.padding_side = "right"

# `truncation=True` is omitted on purpose: with no max_length available it does
# not truncate anything and only emits a warning.
batch = tokenizer(data, return_tensors="pt", padding=True)
batch = {k: v.cuda() for k, v in batch.items()}

# Inference only: disable autograd so the forward pass does not keep
# activations around for a backward pass that never happens.
with torch.no_grad():
    logits = model(**batch).logits
pred_ids = torch.argmax(logits, dim=-1)

# NOTE(review): `generate` warns about right padding on a decoder-only model;
# the padded first row may therefore continue from pad tokens - confirm this is
# the behaviour being probed before relying on `pred[0]`.
pred = model.generate(**batch, max_new_tokens=30)

# Displayed for the first batch row: generated ids, forward-pass argmax ids, input ids.
pred[0], pred_ids[0], batch["input_ids"][0]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


(tensor([     2,  62467, 235298,  88461, 235298,  14831, 235298,    511, 235298,
          12159,  23602, 235265,   3327,      0,      0,    108,   1917,    109,
            688,  32729,  66058,    109,   1596,  12041,   5548,    476,  11384,
           2482,  10751,  29037,  16258,    578,  57878,   1174,   1280,  73815,
          10383,    604,  31928,  12354,   2793,  32808,   1154,  29632, 235269],
        device='cuda:0'),
 tensor([   185, 235256,   2195, 235298,   1259, 235298,  31332, 235298,   1973,
          23602, 235278,   2158,    108,    108,    108], device='cuda:0'),
 tensor([     2,  62467, 235298,  88461, 235298,  14831, 235298,    511, 235298,
          12159,  23602, 235265,   3327,      0,      0], device='cuda:0'))

In [93]:
# Display the id the tokenizer uses for its beginning-of-sequence token.
tokenizer.bos_token_id

2

In [78]:
# Decode the second batch row twice, special tokens removed: first the ids
# returned by `generate`, then the per-position argmax of the forward pass.
generated_text = tokenizer.decode(pred[1], skip_special_tokens=True)
argmax_text = tokenizer.decode(pred_ids[1], skip_special_tokens=True)
generated_text, argmax_text

('what is the pattern of the occupation attribute in the ravel dataset?\n\nThe RAVEL dataset doesn\'t have an "occupation" attribute.',
 '<h1>s the difference in the sequence of in the datasetchelry?\n\n')

In [75]:
# Highest-scoring token id along the last dimension of the logits.
pred_ids = logits.argmax(dim=-1)
pred_ids

tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0],
        [   185, 235256,    573,   5830,    575,    573,  10629,    576,    575,
            573,  20884,  27272,    831, 235336,    109]], device='cuda:0')

In [67]:
# Decode the argmax ids of the second batch row back to text, dropping special tokens.
tokenizer.decode(pred_ids[1], skip_special_tokens=True)

'<h1>s the difference in the sequence of in the datasetchelry?\n\n'

In [11]:
# Sanity check: build a Country-targeted test split with all other city
# attributes isolated, then run it through the model-based filter.
# NOTE(review): this root_path is an absolute path that differs from the
# relative one used by the dataset-generation loop - confirm which is intended.
isolate_attrs = ["Continent", "Language", "Latitude", "Longitude", "Timezone"]

test_dataset = generate_ravel_dataset(
    10000,
    root_path="/home/ubuntu/HyperDAS/data/ravel",
    target_attributes=["Country"],
    isolate_attributes=isolate_attrs,
    template_split="test",
    entity_split="both",
)

# Three consecutive filtering passes over the same dataset.
for filter_pass in range(3):
    test_dataset = filter_dataset(model, tokenizer, test_dataset, batch_size=16)

# Intentionally not persisted:
# test_dataset.save_to_disk("./data/ravel_sanity_check_test")

0it [00:00, ?it/s]

625it [00:26, 23.98it/s]


Accuracy: 0.6655; filtered out 3345 examples


416it [00:17, 23.65it/s]


Accuracy: 0.9959429000751315; filtered out 27 examples


415it [00:17, 23.78it/s]

Accuracy: 0.998038624019312; filtered out 13 examples





In [5]:
# Number of examples in `test_dataset`. The cell that builds `test_dataset` must
# already have been run in this kernel, otherwise this raises NameError.
len(test_dataset)

NameError: name 'test_dataset' is not defined

In [13]:
# Examples whose attribute is "Country", and the subset of those that are also
# tagged with attribute_type "isolate"; show the three sizes side by side.
all_country = [example for example in test_dataset if example["attribute"] == "Country"]
all_country_with_isolation = [
    example for example in all_country if example["attribute_type"] == "isolate"
]
len(test_dataset), len(all_country), len(all_country_with_isolation)

(6615, 3923, 0)

In [14]:
# Tally how many examples in `test_dataset` carry each "attribute" value,
# keeping the keys in first-seen order.
attr_dict = {}

for d in test_dataset:
    attr_dict[d["attribute"]] = attr_dict.get(d["attribute"], 0) + 1

attr_dict

{'Country': 3923,
 'Continent': 829,
 'Language': 629,
 'Latitude': 339,
 'Longitude': 319,
 'Timezone': 576}