In [1]:
import sys

# Make the project's ``src`` directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry to ``sys.path``.
SRC_DIR = "../src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
from transformers import AutoTokenizer

# Checkpoint whose tokenizer is used by every cell in this notebook.
TOKENIZER_CHECKPOINT = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [3]:
from collators.data_collator_for_grid_tokenization import DataCollatorForGridTokenization

In [4]:
# Row width, in tokens: passed to the collator as ``line_length`` and used as
# the default number of tokens per printed line by the pretty-printer.
LINE_LENGTH = 64

In [5]:
# Collator configured for evaluation (is_train=False) with grid tokenization
# enabled and rows of LINE_LENGTH tokens.
collator_options = {
    "is_train": False,
    "is_grid_tokenization": True,
    "line_length": LINE_LENGTH,
}
# NOTE(review): 1024 is the second positional constructor argument —
# presumably the maximum sequence length; confirm against the collator class.
collator = DataCollatorForGridTokenization(tokenizer, 1024, **collator_options)

In [6]:
# Run a single string through the collator's string-level grid tokenizer.
# The empty string is the active input; the longer sample is kept for manual experiments.
# input_text = "Hello, world! This is a very long string that should be padded to the nearest multiple of 32. We will also include the end of text token at the end."
input_text = ""
string_options = dict(include_eot=True, include_header=True, header_content="system")
token_ids = collator._grid_tokenize_string(input_text, **string_options)


In [7]:
class PrettyPrintTokens:
    """Print a flat sequence of tokens as a fixed-width grid.

    Every item is rendered in a cell of ``len_for_each_token`` characters
    (right-aligned when shorter, truncated when longer) and the cells are
    laid out ``line_length`` per printed line. A header line with the column
    indices is printed first; each following line is prefixed with its
    zero-based row index.
    """

    def __init__(self, len_for_each_token=4, line_length=32):
        """Store the default cell width and the default number of cells per line."""
        self.len_for_each_token = len_for_each_token
        self.line_length = line_length

    @staticmethod
    def _fit(text, width):
        """Return ``text`` truncated or left-padded with spaces to exactly ``width`` characters."""
        return text[:width].rjust(width)

    def __call__(self, tokens, len_for_each_token=None, line_length=None):
        """Print ``tokens`` as a grid.

        Args:
            tokens: Iterable of items to display. Each item is converted with
                ``str`` first, so raw ids or ``None`` entries are accepted as
                well as token strings. The "Ġ" character is shown as "_".
            len_for_each_token: Cell width override; a falsy value (``None``
                or 0) falls back to the instance default.
            line_length: Cells-per-line override; a falsy value falls back to
                the instance default.

        Returns:
            None. The grid is written to standard output.
        """
        width = len_for_each_token or self.len_for_each_token
        per_line = line_length or self.line_length

        cells = [self._fit(str(tkn).replace("Ġ", "_"), width) for tkn in tokens]

        # The column-index header is not a token row, so it gets a blank row
        # label (same width as a real label) instead of consuming index 0.
        header = [self._fit(str(col), width) for col in range(per_line)]
        print(" " * 3 + ": " + " ".join(header))

        # Token rows are numbered from 0, matching the zero-based column indices.
        for row_idx, start in enumerate(range(0, len(cells), per_line)):
            row_label = str(row_idx).rjust(3, " ")
            print(row_label + ": " + " ".join(cells[start:start + per_line]))
        
# Shared printer: 4-character cells, one printed line per LINE_LENGTH tokens.
pretty_printer = PrettyPrintTokens(len_for_each_token=4, line_length=LINE_LENGTH)

In [None]:

# Show the raw text, a separator line, then the grid of tokens produced from it.
print(input_text, "-" * 100, sep="\n")
tkns = tokenizer.convert_ids_to_tokens(token_ids)
pretty_printer(tkns)



In [None]:
# Hand-written table: a header row (empty corner cell first) followed by one
# list per data row, fed through the collator's table-level grid tokenizer.
column_names = ["", "Name", "Age", "Email", "Phone"]
data_rows = [
    ["Row 1", "John", "25", "john@example.com", "123-456-7890"],
    ["Row 2", "Jane", "30", "jane@example.com", "098-765-4321"],
    ["Row 3", "Christopher", "36", "chris.testing@verylongcompany.com", "123-123-1234"],
]
table = [column_names] + data_rows

token_ids = collator._grid_tokenize_table(table)
tkns = tokenizer.convert_ids_to_tokens(token_ids)
print("-" * 100)
pretty_printer(tkns)


In [None]:
# The same table as CSV text. The string deliberately starts and ends with a
# newline, exactly as a triple-quoted block would.
csv_lines = [
    ",Name,Age,Email,Phone",
    "Row 1,John,25,john@example.com,123-456-7890",
    "Row 2,Jane,30,jane@example.com,098-765-4321",
    "Row 3,Christopher,36,chris.testing@verylongcompany.com,123-123-1234",
]
csv_string = "\n" + "\n".join(csv_lines) + "\n"

# Parse the CSV into a table, show it, then tokenize and print the grid.
table = collator._convert_csv_string_to_table(csv_string)
print(table)
token_ids = collator._grid_tokenize_table(table)
tkns = tokenizer.convert_ids_to_tokens(token_ids)
print("-" * 100)
pretty_printer(tkns)


In [None]:
from parsers.argument_classes import DatasetArguments
from utils.datasets_loader import load_datasets

# Same per-dataset sample cap for the train, val and test splits.
MAX_SAMPLES_PER_SPLIT = 100

dataset_args = DatasetArguments(
    dataset_root_dir="../datasets",
    dataset_names=["self_generated"],
    table_extension="csv",
    train_max_samples_for_each_dataset=MAX_SAMPLES_PER_SPLIT,
    val_max_samples_for_each_dataset=MAX_SAMPLES_PER_SPLIT,
    test_max_samples_for_each_dataset=MAX_SAMPLES_PER_SPLIT,
)
datasets = load_datasets(dataset_args)


In [None]:
# Print the table, question and answer of one test example (each followed by
# a separator line), then show how the whole example is grid-tokenized.
idx = 0
example = datasets["test"][idx]
for field in ("table", "question", "answer"):
    print(example[field])
    print("-" * 100)

collator.is_train = False
input_ids = collator._grid_tokenize_example(example)
tkns = tokenizer.convert_ids_to_tokens(input_ids)
# Wide cells (30 chars) so tokens are not truncated; 32 cells per printed line.
pretty_printer(tkns, len_for_each_token=30, line_length=32)

In [None]:

# Collate three training examples (indices 1-3) in training mode and print
# the whole batch followed by the entries of its last example.
collator.is_train = True
train_examples = [datasets["train"][i] for i in (1, 2, 3)]
batch = collator(train_examples)
print(batch)
for key in ("input_ids", "attention_mask"):
    print(batch[key][2])
# Labels are printed only when the collator produced them.
if "labels" in batch:
    print(batch["labels"][2])
    

In [None]:
# Show one collated example three ways — tokens, labels and attention mask —
# as grids of identical layout so positions can be compared cell by cell.
# All three views must index the SAME example; mixing examples would make
# the grids describe different sequences.
sample_idx = 0
tkns = tokenizer.convert_ids_to_tokens(batch["input_ids"][sample_idx])
attn_mask = [str(i) for i in batch["attention_mask"][sample_idx].tolist()]
labels = [str(i) for i in batch["labels"][sample_idx].tolist()]

pretty_printer(tkns, len_for_each_token=5, line_length=64)

pretty_printer(labels, len_for_each_token=5, line_length=64)
pretty_printer(attn_mask, len_for_each_token=5, line_length=64)



In [None]:
# Look up the token string for vocabulary id 128009.
# NOTE(review): presumably a special end-of-turn token for this tokenizer — confirm from the output.
tokenizer.convert_ids_to_tokens([128009])

In [None]:
# Display the "size" entry of the train split.
# NOTE(review): assumes the loaded dataset exposes a "size" field — confirm against the loader.
datasets["train"]["size"]

In [17]:
# Switch the existing collator to training mode with grid tokenization off,
# then collate training examples 1-3.
collator.is_train, collator.is_grid_tokenization = True, False
batch = collator([datasets["train"][i] for i in (1, 2, 3)])

In [None]:
# Print labels, attention mask and input ids of the second example in the batch.
for field in ("labels", "attention_mask", "input_ids"):
    print(batch[field][1])



In [None]:
from torch.utils.data import DataLoader

# Stress test the collator: rebuild it in training mode with grid
# tokenization, load the "wtq" dataset, and push every training batch
# through it, printing the shape of each batch's input_ids.
collator = DataCollatorForGridTokenization(
    tokenizer,
    1024,
    is_train=True,
    is_grid_tokenization=True,
    # Use the notebook-wide constant rather than a second hard-coded 64 so the
    # grid width cannot drift from the one the pretty-printer was built with.
    line_length=LINE_LENGTH,
)

dataset_args = DatasetArguments(
    dataset_root_dir="../datasets",
    dataset_names=["wtq"],
    table_extension="csv",
    train_max_samples_for_each_dataset=100,
    val_max_samples_for_each_dataset=100,
    test_max_samples_for_each_dataset=100,
)

datasets = load_datasets(dataset_args)

# The collator is used as collate_fn, so each iteration yields one collated batch.
dataloader = DataLoader(datasets["train"], batch_size=8, collate_fn=collator)
for batch in dataloader:
    print(batch["input_ids"].shape)


In [None]:
# Display the current value of `batch` (the last batch from the stress-test
# loop when the cells are run in order).
batch