In [1]:
import os
import re
import datasets
import pandas as pd
from transformers import PreTrainedTokenizerFast
from typing import List
import numpy as np

SYSTEM_PROMPT = "You are a helpful assistant that answers questions about the table. You only answer the question right after 'Answer: '"
ASSISTANT_PROMPT = "Answer: "
SHUFFLE_SEED = 42


In [4]:
# Create dummy data directory
os.makedirs("../datasets/self_generated/data", exist_ok=True)

# Create dummy train, validation, and test CSV files
dummy_data = pd.DataFrame({
    "context": ["table1", "table2"],
    "question": ["What is the value in row 1?", "What is the value in row 2?"],
    "answer": ["Value 1", "Value 2"]
})
dummy_data.to_csv("../datasets/self_generated/data/train.csv", index=False)
dummy_data.to_csv("../datasets/self_generated/data/val.csv", index=False)
dummy_data.to_csv("../datasets/self_generated/data/test.csv", index=False)

# Create dummy table files
with open("../datasets/self_generated/table1.csv", "w") as f:
    f.write("Column1,Column2\nValue 1,Value 2\n")
with open("../datasets/self_generated/table2.csv", "w") as f:
    f.write("Column1,Column2\nValue 3,Value 4\n")

In [5]:
tokenizer = PreTrainedTokenizerFast.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", token='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')
string = "What is the capital of France?"
# for i in string:
#     token = tokenizer.encode(i, add_special_tokens=False)
#     print(i, token)

In [8]:
class TableDatasetLoader:
    def __init__(
        self,
        dataset_root: str,
        dataset_name: str,
        tokenizer: PreTrainedTokenizerFast,
        batch_size: int = 4,
        table_extension: str = "html",
        test_max_samples: int = None,
        val_max_samples: int = None,
        train_max_samples: int = None,
        system_prompt: str = SYSTEM_PROMPT,
        assistant_prompt: str = ASSISTANT_PROMPT,
        user_prompt_order: List[str] = ["question", "table"],
        grid_it: bool = False,
        line_length: int = 10,
        skip_validation: bool = False
    ):
        # Initialize instance variables with given parameters
        self.dataset_root = dataset_root
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.table_extension = table_extension
        self.test_max_samples = test_max_samples
        self.val_max_samples = val_max_samples
        self.train_max_samples = train_max_samples
        self.system_prompt = system_prompt
        self.assistant_prompt = assistant_prompt
        self.user_prompt_order = user_prompt_order
        self.grid_it = grid_it
        self.line_length = line_length
        self.table_pad_token = tokenizer.eos_token
        self.start_of_line_token = "[START_OF_LINE]"
        self.end_of_line_token = "[END_OF_LINE]"
        self.table_cell_separator_token = "[TABLE_CELL_SEPARATOR]"
        self.extension_separator_map = {
            "csv": ",",
            "html": " ",
            "tsv": "\t"
        }
        
        # Validate the input parameters
        if not skip_validation:
            self._validate_inputs()
        if self.grid_it:
            new_tokens = [self.start_of_line_token, self.end_of_line_token, self.table_cell_separator_token]
            self.tokenizer.add_tokens(new_tokens)
            # TODO: model.resize_token_embeddings(len(tokenizer))
        # Set the path to the dataset directory
        self.dataset_path = os.path.join(self.dataset_root, self.dataset_name)

    def _validate_inputs(self):
        # Validate dataset name
        if self.dataset_name not in ["self_generated", "wtq"]:
            raise ValueError(f"Invalid dataset name: {self.dataset_name}")
        # Validate table file extension
        if self.table_extension not in ["csv", "html", "tsv"]:
            raise ValueError(f"Invalid table extension: {self.table_extension}")
        # For self_generated datasets, ensure sample sizes are multiples of 80
        if self.dataset_name == "self_generated":
            if self.test_max_samples is not None and self.test_max_samples % 80 != 0:
                raise ValueError("The number of samples for self-generated dataset must be a multiple of 80")
            if self.val_max_samples is not None and self.val_max_samples % 80 != 0:
                raise ValueError("The number of samples for self-generated dataset must be a multiple of 80")
            if self.train_max_samples is not None and self.train_max_samples % 80 != 0:
                raise ValueError("The number of samples for self-generated dataset must be a multiple of 80")

    def _get_table(self, context: str):
        # Remove .csv extension from the context if present
        context = re.sub(r"\.csv$", "", context)
        separator = self.extension_separator_map[self.table_extension]
        modified_lines = []

        with open(os.path.join(self.dataset_path, context + "." + self.table_extension), "r", encoding="utf-8") as f:
            for line in f:
                cells = line.strip().split(separator)
                modified_line = self.start_of_line_token + self.table_cell_separator_token.join(cells) + self.end_of_line_token
                modified_lines.append(modified_line.strip())
        # Join the modified lines into a single string
        return "\n".join(modified_lines)

    def _preprocess_single_example_to_string(self, example):
        # Retrieve the table content based on the context provided in the example
        table = self._get_table(example["context"])
        example["table"] = table

        # Construct the user prompt using the specified order of fields
        user_prompt = "\n".join([example[col_name] for col_name in self.user_prompt_order if col_name in example])
        # Create the system and user messages for the chat template
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # Apply the chat template to create the input string
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        text = text + self.assistant_prompt
        if self.grid_it:
            text = self._grid_it(text)
        example["input_string"] = text
        return example

    def _grid_it(self, text):
        # Seperate the text into before_table, table, and after_table
        table_pattern = r"(\[START_OF_LINE\].*?\[END_OF_LINE\](?:\n\[START_OF_LINE\].*?\[END_OF_LINE\])*)"
        parts = re.split(table_pattern, text, maxsplit=1)
        before_table = parts[0].strip()
        table = parts[1].strip() if len(parts) > 1 else ""
        after_table = parts[2].strip() if len(parts) > 2 else ""

        # Special token ids
        start_of_line_token = self.tokenizer.encode(self.start_of_line_token, add_special_tokens=False)[0]
        end_of_line_token = self.tokenizer.encode(self.end_of_line_token, add_special_tokens=False)[0]
        table_cell_separator_token = self.tokenizer.encode(self.table_cell_separator_token, add_special_tokens=False)[0]
        pad_token_id = self.tokenizer.encode(self.table_pad_token, add_special_tokens=False)[0]

        # before_table
        before_table_tokens = self.tokenizer.encode(before_table, add_special_tokens=False)
        before_table_pad_count = self.line_length - len(before_table_tokens) % self.line_length
        before_table_tokens.extend([pad_token_id] * before_table_pad_count)

        # table
        table_tokens = self.tokenizer.encode(table, add_special_tokens=False)
        rows = table.strip().split("\n")
        col_count = len(rows[0].split(self.table_cell_separator_token))
        row_count = len(rows)
        table_grid = np.full((row_count, self.line_length), pad_token_id, dtype=object)

        # Get the max token number per column
        token_num_per_cell = []
        token_row = []
        token_counter_in_row = 0
        token_counter_in_cell = 0
        for id in table_tokens:
            token_counter_in_row += 1
            
            if id == start_of_line_token:
                token_row = []
                token_counter_in_cell = 0
                token_counter_in_row = 1
            elif id == end_of_line_token:
                token_row.append(token_counter_in_cell)
                token_num_per_cell.append(token_row)
                token_row = []
                token_counter_in_cell = 0
            elif id == table_cell_separator_token:
                token_row.append(token_counter_in_cell)
                token_counter_in_cell = 0
            else:
                token_counter_in_cell += 1
        token_num_per_cell = np.array(token_num_per_cell)
        max_token_num_per_cell = np.max(token_num_per_cell, axis = 0)
        pad_token_count = self.line_length - sum(max_token_num_per_cell)
        if pad_token_count < 0:
            print("The token number in the row exceeds the line length")
            # TODO: discard this data
        

        # Fill the table grid
        token_col_cursor = self.line_length - 1
        token_row_cursor = row_count - 1
        cell_col_cursor = col_count - 1
        cell_inner_token_counter = 0

        for id in reversed(table_tokens):
            current_cell_token_num = max_token_num_per_cell[cell_col_cursor]
            if id == start_of_line_token:
                token_row_cursor -= 1
            elif id == end_of_line_token:
                cell_col_cursor = col_count - 1
                token_col_cursor = self.line_length - 1
                table_grid[token_row_cursor, token_col_cursor - pad_token_count + 1:] = pad_token_id
                token_col_cursor -= pad_token_count
                cell_inner_token_counter = 0
            elif id == table_cell_separator_token:
                need_to_pad_token_count = current_cell_token_num - cell_inner_token_counter
                table_grid[token_row_cursor, token_col_cursor - need_to_pad_token_count + 1 : token_col_cursor + 1] = pad_token_id
                token_col_cursor -= need_to_pad_token_count
                cell_inner_token_counter = 0
            else:
                table_grid[token_row_cursor, token_col_cursor] = id
                token_col_cursor -= 1
                cell_inner_token_counter += 1
        
        # after_table
        after_table_tokens = self.tokenizer.encode(after_table, add_special_tokens=False)
        after_table_pad_count = self.line_length - len(after_table_tokens) % self.line_length
        after_table_tokens.extend([pad_token_id] * after_table_pad_count)

        # Concatenate the three grids and flatten the result
        result = np.concatenate((before_table_tokens, table_grid.flatten(), after_table_tokens), axis=0)
        for i in result:
            print(i, self.tokenizer.decode(i))
        # print(result)
        return self.tokenizer.decode(result)

    def _tokenize_function(self, examples):
        # Tokenize the input strings with padding and truncation
        return self.tokenizer(examples["input_string"], padding=True, truncation=True)

    def load(self):
        # Load the dataset from CSV files
        dataset = datasets.load_dataset("csv", data_files={
            "train": os.path.join(self.dataset_path, "data", "train.csv"),
            "test": os.path.join(self.dataset_path, "data", "test.csv"),
            "validation": os.path.join(self.dataset_path, "data", "val.csv")
        })

        # Shuffle the dataset based on the dataset name
        if self.dataset_name == "wtq":
            dataset = dataset.shuffle(seed=SHUFFLE_SEED)
        elif self.dataset_name == "self_generated":
            dataset["train"] = dataset["train"].shuffle(seed=SHUFFLE_SEED)

        # Select a subset of samples if max sample limits are specified
        if self.test_max_samples is not None:
            dataset["test"] = dataset["test"].select(range(self.test_max_samples))
        if self.val_max_samples is not None:
            dataset["validation"] = dataset["validation"].select(range(self.val_max_samples))
        if self.train_max_samples is not None:
            dataset["train"] = dataset["train"].select(range(self.train_max_samples))
        
        # print(dataset)
        # Preprocess each example to generate input strings
        dataset = dataset.map(self._preprocess_single_example_to_string, batched=False)

        # Set tokenizer padding and padding side
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Tokenize the dataset
        dataset = dataset.map(self._tokenize_function, batched=True, batch_size=self.batch_size)
        return dataset


In [9]:
# Define dataset root and name
dataset_root = "../datasets"
dataset_name = "self_generated"

# Create an instance of TableDatasetLoader
loader = TableDatasetLoader(
    dataset_root=dataset_root,
    dataset_name=dataset_name,
    tokenizer=tokenizer,
    batch_size=2,
    table_extension="csv",
    test_max_samples=2,
    val_max_samples=2,
    train_max_samples=2,
    grid_it=True,
    skip_validation = True
)

# Load the dataset
dataset = loader.load()

# Print a few examples from the training set
# print(dataset["train"][0]["input_string"])
# print(dataset["train"][0]["input_ids"])


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

128000 <|begin_of_text|>
128006 <|start_header_id|>
9125 system
128007 <|end_header_id|>
271 


38766 Cut
1303 ting
33025  Knowledge
2696  Date
25 :
6790  December
220  
2366 202
18 3
198 

15724 Today
2696  Date
25 :
220  
1627 26
10263  Jul
220  
2366 202
19 4
271 


2675 You
527  are
264  a
11190  helpful
18328  assistant
430  that
11503  answers
4860  questions
922  about
279  the
2007  table
13 .
1472  You
1193  only
4320  answer
279  the
3488  question
1314  right
1306  after
364  '
16533 Answer
25 :
364  '
128009 <|eot_id|>
128006 <|start_header_id|>
882 user
128007 <|end_header_id|>
271 


3923 What
374  is
279  the
907  value
304  in
2872  row
220  
17 2
30 ?
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
3006 Column
16 1
128009 <|eot_id|>
3006 Column
17 2
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
1150 Value
220  
18 3
1150 Value
220 

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

128000 <|begin_of_text|>
128006 <|start_header_id|>
9125 system
128007 <|end_header_id|>
271 


38766 Cut
1303 ting
33025  Knowledge
2696  Date
25 :
6790  December
220  
2366 202
18 3
198 

15724 Today
2696  Date
25 :
220  
1627 26
10263  Jul
220  
2366 202
19 4
271 


2675 You
527  are
264  a
11190  helpful
18328  assistant
430  that
11503  answers
4860  questions
922  about
279  the
2007  table
13 .
1472  You
1193  only
4320  answer
279  the
3488  question
1314  right
1306  after
364  '
16533 Answer
25 :
364  '
128009 <|eot_id|>
128006 <|start_header_id|>
882 user
128007 <|end_header_id|>
271 


3923 What
374  is
279  the
907  value
304  in
2872  row
220  
16 1
30 ?
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
3006 Column
16 1
128009 <|eot_id|>
3006 Column
17 2
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
1150 Value
220  
16 1
1150 Value
220 

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

128000 <|begin_of_text|>
128006 <|start_header_id|>
9125 system
128007 <|end_header_id|>
271 


38766 Cut
1303 ting
33025  Knowledge
2696  Date
25 :
6790  December
220  
2366 202
18 3
198 

15724 Today
2696  Date
25 :
220  
1627 26
10263  Jul
220  
2366 202
19 4
271 


2675 You
527  are
264  a
11190  helpful
18328  assistant
430  that
11503  answers
4860  questions
922  about
279  the
2007  table
13 .
1472  You
1193  only
4320  answer
279  the
3488  question
1314  right
1306  after
364  '
16533 Answer
25 :
364  '
128009 <|eot_id|>
128006 <|start_header_id|>
882 user
128007 <|end_header_id|>
271 


3923 What
374  is
279  the
907  value
304  in
2872  row
220  
16 1
30 ?
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
3006 Column
16 1
128009 <|eot_id|>
3006 Column
17 2
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
128009 <|eot_id|>
1150 Value
220  
16 1
1150 Value
220 

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]