# Setup

In [1]:
import sys

# Make the project's ``src`` directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries in ``sys.path``.
SRC_DIR = "../src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
# Tokenizer setup: load the pretrained tokenizer for the checkpoint named below.
from transformers import AutoTokenizer

TOKENIZER_CHECKPOINT = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [3]:
# Build the cell-level collator that the rest of this notebook exercises.
from collators.data_collator_for_cell_tokenization import DataCollatorForCellTokenizer

# Collator settings, gathered in one mapping so they are easy to scan and tweak.
collator_settings = dict(
    tokenizer=tokenizer,
    max_seq_length=1024,
    is_train=True,
    table_cell_separator="|",
    row_offset=1,
    column_offset=1,
)
collator = DataCollatorForCellTokenizer(**collator_settings)

In [4]:
# Describe which datasets to read, then load them.
from parsers.argument_classes import DatasetArguments
from utils.datasets_loader import load_datasets

# Per-split caps on how many samples are taken from each dataset.
split_sample_caps = {
    "train_max_samples_for_each_dataset": 14149,
    "val_max_samples_for_each_dataset": 100,
    "test_max_samples_for_each_dataset": 100,
}

dataset_args = DatasetArguments(
    dataset_root_dir="../datasets",
    dataset_names=["wtq"],
    table_extension="csv",
    **split_sample_caps,
)

datasets = load_datasets(dataset_args)

In [5]:
# Inspect the first training record to see which fields a sample carries.
first_train_example = datasets["train"][0]
first_train_example

{'question': 'who was the top ranked swimmer in the semifinals?',
 'answer': 'Dyana Calub',
 'context': 'csv/204-csv/544.csv',
 'id': 'nt-9309',
 'task': 'wtq',
 'direction': 'none',
 'size': -1,
 'table_row_num': 17,
 'table_width': 20,
 'table': '"Rank","Name","Nationality","Time","Notes"\n"1","Dyana Calub","Australia","1:01.77","Q"\n"2","Natalie Coughlin","United States","1:01.99","Q"\n"3","Noriko Inada","Japan","1:02.00","Q"\n"4","Haley Cope","United States","1:02.09","Q"\n"5","Diana MacManus","United States","1:02.10","Q"\n"6","Courtney Shealy","United States","1:02.28","Q"\n"7","Aya Terakawa","Japan","1:02.39","Q"\n"8","Giaan Rooney","Australia","1:02.53","Q"\n"9","Erin Gammel","Canada","1:02.63",""\n"10","Hannah McLean","New Zealand","1:02.82",""\n"11","Melissa Morgan","Australia","1:02.86",""\n"12","Reiko Nakamura","Japan","1:02.91",""\n"13","Michelle Lischinsky","Canada","1:03.22",""\n"14","Jennifer Fratesi","Canada","1:03.42",""\n"15","Kelly Stefanyshyn","Canada","1:03.44",""

In [6]:
def pretty_print(tokens, tokenizer=None, show_ids=False):
    """Print decoded tokens, starting a new output line after each newline token.

    Every token id is decoded on its own and shown as its ``repr`` so that
    whitespace and newline characters inside a token stay visible. Entries on
    a line are separated by single spaces. A line is flushed as soon as a
    token whose decoded text contains ``"\\n"`` has been added to it; any
    leftover entries are printed at the end.

    Args:
        tokens: Iterable of token ids. Each element is passed to
            ``tokenizer.decode`` individually.
        tokenizer: Object exposing ``decode(token)``. When omitted, the
            notebook-level ``collator.tokenizer`` is used, which matches the
            previous behaviour of this helper.
        show_ids: When true, each entry is printed as ``<id> <repr>`` instead
            of just ``<repr>``.

    Returns:
        None. The formatted lines are written to standard output.
    """
    if tokenizer is None:
        # Fall back to the notebook global so existing calls keep working.
        tokenizer = collator.tokenizer

    pieces = []
    for token in tokens:
        decoded = tokenizer.decode(token)
        pieces.append(f"{token} {decoded!r}" if show_ids else repr(decoded))
        # A newline inside the decoded token ends the current display line.
        if "\n" in decoded:
            print(" ".join(pieces))
            pieces = []

    # Flush whatever is left after the last newline-bearing token.
    if pieces:
        print(" ".join(pieces))

# Test _tokenize_string

In [7]:
# An empty message should still yield the header tokens and the <|eot_id|> token.
input_text = ""
token_ids = collator._tokenize_string(
    input_text,
    include_eot=True,
    include_header=True,
    header_content="system"
)
print(token_ids)
decoded_text = collator.tokenizer.decode(token_ids[0])
print(decoded_text)

# Turn the visual check into real assertions so a regression fails loudly
# under Restart & Run All instead of relying on someone reading the output.
# The call returns several parallel arrays; they must all have the same length.
assert len({len(part) for part in token_ids}) == 1, "returned arrays differ in length"
assert decoded_text == "<|start_header_id|>system<|end_header_id|>\n\n<|eot_id|>", decoded_text

(array([128006,   9125, 128007,    271, 128009]), array([0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0]))
<|start_header_id|>system<|end_header_id|>

<|eot_id|>


In [8]:
# A longer message: the decoded ids should round-trip to header + text + <|eot_id|>.
input_text = "Hello, world! This is a very long string that should be padded to the nearest multiple of 32. We will also include the end of text token at the end."
token_ids = collator._tokenize_string(
    input_text,
    include_eot=True,
    include_header=True,
    header_content="system"
)
print(token_ids)
decoded_text = collator.tokenizer.decode(token_ids[0])
print(decoded_text)

# Turn the visual check into real assertions so a regression fails loudly
# under Restart & Run All instead of relying on someone reading the output.
# The call returns several parallel arrays; they must all have the same length.
assert len({len(part) for part in token_ids}) == 1, "returned arrays differ in length"
expected_text = f"<|start_header_id|>system<|end_header_id|>\n\n{input_text}<|eot_id|>"
assert decoded_text == expected_text, decoded_text

(array([128006,   9125, 128007,    271,   9906,     11,   1917,      0,
         1115,    374,    264,   1633,   1317,    925,    430,   1288,
          387,  44968,    311,    279,  24379,   5361,    315,    220,
          843,     13,   1226,    690,   1101,   2997,    279,    842,
          315,   1495,   4037,    520,    279,    842,     13, 128009]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
<|start_header_id|>system<|end_header_id|>

Hello, world! This is a very long string that should be padded to the nearest multiple of 32. We will also include the end of text token at the end.<|eot_id|>


# Test _cell_tokenize_table

In [9]:
# Show the raw CSV text stored with the first training example.
first_example = datasets["train"][0]
table = first_example["table"]
print(table)

"Rank","Name","Nationality","Time","Notes"
"1","Dyana Calub","Australia","1:01.77","Q"
"2","Natalie Coughlin","United States","1:01.99","Q"
"3","Noriko Inada","Japan","1:02.00","Q"
"4","Haley Cope","United States","1:02.09","Q"
"5","Diana MacManus","United States","1:02.10","Q"
"6","Courtney Shealy","United States","1:02.28","Q"
"7","Aya Terakawa","Japan","1:02.39","Q"
"8","Giaan Rooney","Australia","1:02.53","Q"
"9","Erin Gammel","Canada","1:02.63",""
"10","Hannah McLean","New Zealand","1:02.82",""
"11","Melissa Morgan","Australia","1:02.86",""
"12","Reiko Nakamura","Japan","1:02.91",""
"13","Michelle Lischinsky","Canada","1:03.22",""
"14","Jennifer Fratesi","Canada","1:03.42",""
"15","Kelly Stefanyshyn","Canada","1:03.44",""
"16","Clementine Stoney","Australia","1:03.52",""



In [10]:
# Parse the raw CSV text, tokenize the parsed table cell by cell, and
# pretty-print the first element of what the tokenizer step returns.
csv_text = datasets["train"][0]["table"]
table = collator._convert_csv_string_to_table(csv_text)
result = collator._cell_tokenize_table(table)
first_output = result[0]
pretty_print(first_output)

'Rank' '|' 'Name' '|' 'National' 'ity' '|' 'Time' '|' 'Notes' '|\n'
'1' '|' 'D' 'y' 'ana' ' Cal' 'ub' '|' 'Australia' '|' '1' ':' '01' '.' '77' '|' 'Q' '|\n'
'2' '|' 'N' 'atal' 'ie' ' C' 'ough' 'lin' '|' 'United' ' States' '|' '1' ':' '01' '.' '99' '|' 'Q' '|\n'
'3' '|' 'Nor' 'iko' ' In' 'ada' '|' 'Japan' '|' '1' ':' '02' '.' '00' '|' 'Q' '|\n'
'4' '|' 'H' 'aley' ' C' 'ope' '|' 'United' ' States' '|' '1' ':' '02' '.' '09' '|' 'Q' '|\n'
'5' '|' 'D' 'iana' ' Mac' 'Man' 'us' '|' 'United' ' States' '|' '1' ':' '02' '.' '10' '|' 'Q' '|\n'
'6' '|' 'Court' 'ney' ' She' 'aly' '|' 'United' ' States' '|' '1' ':' '02' '.' '28' '|' 'Q' '|\n'
'7' '|' 'A' 'ya' ' Ter' 'ak' 'awa' '|' 'Japan' '|' '1' ':' '02' '.' '39' '|' 'Q' '|\n'
'8' '|' 'G' 'ia' 'an' ' Rooney' '|' 'Australia' '|' '1' ':' '02' '.' '53' '|' 'Q' '|\n'
'9' '|' 'Er' 'in' ' G' 'amm' 'el' '|' 'Canada' '|' '1' ':' '02' '.' '63' '|' 'nan' '|\n'
'10' '|' 'H' 'annah' ' Mc' 'Lean' '|' 'New' ' Zealand' '|' '1' ':' '02' '.' '82' '|' 'nan' '|\n'
'

# E2E Test

In [11]:
# Collate the first two training examples into a single batch.
first_two_examples = [datasets["train"][index] for index in range(2)]
result = collator(first_two_examples)

# Print the shape of each batch tensor of interest, one per line.
for key in ("input_ids", "row_ids", "column_ids", "attention_mask", "labels"):
    print(result[key].shape)


torch.Size([2, 471])
torch.Size([2, 471])
torch.Size([2, 471])
torch.Size([2, 471])
torch.Size([2, 471])


In [12]:
# Dump each collated tensor with its name, separated by a dashed rule
# between consecutive tensors (no rule after the last one).
for position, key in enumerate(("labels", "attention_mask", "input_ids", "row_ids", "column_ids")):
    if position > 0:
        print("-"*100)
    print(key, result[key])


labels tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  

In [13]:
# Test segment ids
table_input_segment = result["input_ids"][0][result["segment_ids"][0] == 1]
text_input_segment = result["input_ids"][0][result["segment_ids"][0] == 0]
print(collator.tokenizer.decode(table_input_segment))
print(datasets["train"][0]["table"])
print("-"*100)
print(collator.tokenizer.decode(text_input_segment))
print(datasets["train"][0]["question"])


Rank|Name|Nationality|Time|Notes|
1|Dyana Calub|Australia|1:01.77|Q|
2|Natalie Coughlin|United States|1:01.99|Q|
3|Noriko Inada|Japan|1:02.00|Q|
4|Haley Cope|United States|1:02.09|Q|
5|Diana MacManus|United States|1:02.10|Q|
6|Courtney Shealy|United States|1:02.28|Q|
7|Aya Terakawa|Japan|1:02.39|Q|
8|Giaan Rooney|Australia|1:02.53|Q|
9|Erin Gammel|Canada|1:02.63|nan|
10|Hannah McLean|New Zealand|1:02.82|nan|
11|Melissa Morgan|Australia|1:02.86|nan|
12|Reiko Nakamura|Japan|1:02.91|nan|
13|Michelle Lischinsky|Canada|1:03.22|nan|
14|Jennifer Fratesi|Canada|1:03.42|nan|
15|Kelly Stefanyshyn|Canada|1:03.44|nan|
16|Clementine Stoney|Australia|1:03.52|nan|

"Rank","Name","Nationality","Time","Notes"
"1","Dyana Calub","Australia","1:01.77","Q"
"2","Natalie Coughlin","United States","1:01.99","Q"
"3","Noriko Inada","Japan","1:02.00","Q"
"4","Haley Cope","United States","1:02.09","Q"
"5","Diana MacManus","United States","1:02.10","Q"
"6","Courtney Shealy","United States","1:02.28","Q"
"7","Aya T

In [14]:
# Test row ids and column ids
random_row_index = 1
random_column_index = 2
input_segment = result["input_ids"][0][(result["row_ids"][0] == random_row_index) & (result["column_ids"][0] == random_column_index)]
print(collator.tokenizer.decode(input_segment))
# 1 is the row offset and column offset
row_offset, column_offset = 1, 1
print(collator._convert_csv_string_to_table(datasets["train"][0]["table"])[random_row_index - row_offset][random_column_index - column_offset])

Name|
Name


In [15]:

# Labels: keep only the positions of the first sample whose label is not -100,
# then show those ids and the text they decode to.
sample_labels = result["labels"][0]
kept_labels = sample_labels[sample_labels != -100]
print(kept_labels)
print(collator.tokenizer.decode(kept_labels))

tensor([16533,    25, 43048,  3444,  3400,   392])
Answer: Dyana Calub


In [16]:
# Attention mask: print the first sample's mask, then decode the input tokens
# at the positions where the mask is 0.
sample_attention_mask = result["attention_mask"][0]
print(sample_attention_mask)
masked_out_ids = result["input_ids"][0][sample_attention_mask == 0]
print(collator.tokenizer.decode(masked_out_ids))

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

# Test loading the entire dataset


In [17]:
# Collate the whole training split in one call (up to 14149 samples requested
# via train_max_samples_for_each_dataset).
train_split = datasets["train"]
batch_result = collator(train_split)
for key in ("input_ids", "row_ids", "column_ids", "attention_mask", "labels"):
    print(batch_result[key].shape)

# Report how many samples did not make it into the batch. In the recorded run
# the batch had 14140 rows and nine "ParserError" lines were emitted, so the
# collator presumably skips tables it cannot parse -- TODO confirm in the
# collator source.
num_requested = len(train_split)
num_collated = batch_result["input_ids"].shape[0]
if num_collated != num_requested:
    print(f"{num_requested - num_collated} of {num_requested} samples are missing from the collated batch")

# NOTE(review): the recorded sequence dimension is 17797 although the collator
# was created with max_seq_length=1024 -- confirm whether truncation is expected here.

ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
ParserError: Error tokenizing data. C error: EOF inside string starting at row 17
torch.Size([14140, 17797])
torch.Size([14140, 17797])
torch.Size([14140, 17797])
torch.Size([14140, 17797])
torch.Size([14140, 17797])
