In [1]:
import os
import sys
# Make the sibling ../src and ../datasets folders importable from this notebook.
# Note: dirname() is applied to the literal string "__file__", which yields "",
# so both paths resolve relative to the current working directory.
src_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), "../src"))
datasets_path = os.path.abspath(os.path.join(os.path.dirname("__file__"), "../datasets"))

# Append each folder only once so re-running the cell does not grow sys.path.
sys.path.extend(
    candidate for candidate in (src_path, datasets_path) if candidate not in sys.path
)
from dotenv import load_dotenv
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

load_dotenv()
# With grid
from models.modeling_table_llama import TableLlamaConfig

In [2]:
# Hugging Face Hub id of the base checkpoint used by the cells below.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Testing TableLlamaForCausalLM
* With all RoPE position ids set to zero (every channel, every token)
* With Gaussian noise added to the RoPE position ids

In [3]:
# Start from the pretrained checkpoint's configuration.
table_llama_config = TableLlamaConfig.from_pretrained(MODEL_NAME)

# Leave every table-RoPE option unset (None) for these experiments.
table_llama_config.rope_table_llama = dict.fromkeys(
    (
        "x_channels_start",
        "x_channels_end",
        "x_channels_step",
        "y_channels_start",
        "y_channels_end",
        "y_channels_step",
        "line_length",
    )
)

# 8-bit quantization settings used when loading the model in this notebook.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Release cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

## Entire channel set to zero

In [4]:
"""
This file contains the implementation of the TableLlama model.
"""

import torch
from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

import transformers.utils.logging as logging
logger = logging.get_logger(__name__)

"""
TableLlamaRotaryEmbedding
"""


class TableLlamaRotaryEmbeddingZeroChannel(torch.nn.Module):
    """Rotary-embedding ablation that uses position 0 for every token and channel.

    The inverse frequencies and attention scaling are initialised through
    ``ROPE_INIT_FUNCTIONS`` exactly like a regular rotary embedding, but
    ``forward`` ignores the values in ``position_ids`` (only their shape and
    device are used). The resulting ``cos`` is all ones and ``sin`` all zeros,
    each multiplied by ``attention_scaling``.

    The ``rope_table_llama`` options are validated and stored as attributes,
    but ``forward`` in this ablation does not read them.
    """

    def __init__(
        self,
        config: TableLlamaConfig,
        device=None,
    ):
        """Build the embedding from ``config``.

        Args:
            config: Model config; must expose ``rope_table_llama`` (a mapping with
                the x/y channel keys and ``line_length``) and the fields needed by
                the selected RoPE init function.
            device: Device passed through to the RoPE init function.

        Raises:
            NotImplementedError: If the configured RoPE type is a "dynamic" one;
                this class has no dynamic frequency update.
            ValueError: If ``x_channels_end`` / ``y_channels_end`` is set without
                the matching start and step.
        """
        super().__init__()

        self.config = config

        # When `rope_scaling` is a dict, getattr() cannot see its keys, so the
        # RoPE type has to be read by key. A missing `rope_scaling` means the
        # unscaled "default" RoPE rather than "llama3".
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        elif isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        else:
            self.rope_type = getattr(rope_scaling, "rope_type", "llama3")

        if "dynamic" in self.rope_type:
            # Fail early instead of hitting a missing method inside forward().
            raise NotImplementedError(
                f"RoPE type '{self.rope_type}' needs a dynamic frequency update, "
                "which TableLlamaRotaryEmbeddingZeroChannel does not implement"
            )

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # Table Llama specific initialization
        self.rope_table_llama = getattr(self.config, "rope_table_llama")

        x_channels_start = self.rope_table_llama["x_channels_start"]
        x_channels_end = self.rope_table_llama["x_channels_end"]
        x_channels_step = self.rope_table_llama["x_channels_step"]
        y_channels_start = self.rope_table_llama["y_channels_start"]
        y_channels_end = self.rope_table_llama["y_channels_end"]
        y_channels_step = self.rope_table_llama["y_channels_step"]
        line_length = self.rope_table_llama["line_length"]

        if line_length is None:
            # Set a large number to avoid the RoPE effect
            line_length = 10**8

        if x_channels_end is None:
            # Unset x range: use a large sentinel for start, end and step.
            x_channels_start = 10**8
            x_channels_end = 10**8
            x_channels_step = 10**8
        elif x_channels_step is None or x_channels_start is None:
            raise ValueError("You have set x_channels_end but not x_channels_step or x_channels_start")

        if y_channels_end is None:
            # Unset y range: same sentinel convention as for x.
            y_channels_start = 10**8
            y_channels_end = 10**8
            y_channels_step = 10**8
        elif y_channels_step is None or y_channels_start is None:
            raise ValueError("You have set y_channels_end but not y_channels_step or y_channels_start")

        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.num_channels = inv_freq.shape[0]
        self.line_length = line_length
        self.x_channels_start = x_channels_start
        self.x_channels_end = x_channels_end
        self.x_channels_step = x_channels_step
        self.y_channels_start = y_channels_start
        self.y_channels_end = y_channels_end
        self.y_channels_step = y_channels_step

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` computed as if every token sat at position 0.

        Args:
            x: Tensor whose ``device`` and ``dtype`` determine the autocast device
                type and the dtype of the returned tensors.
            position_ids: Indexed as ``(batch_size, seq_len)``; only its shape and
                device are used.

        Returns:
            Tuple ``(cos, sin)``, each of shape
            ``(batch_size, seq_len, 2 * num_channels)`` and cast to ``x.dtype``.
        """
        batch_size, seq_len = position_ids.shape[0], position_ids.shape[-1]

        # (batch_size, num_channels, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)

        # The ablation: all positions are zero, for every channel and token.
        position_ids_expanded = torch.zeros(
            (batch_size, self.num_channels, seq_len),
            dtype=torch.float32,
            device=position_ids.device,
        )  # (batch_size, num_channels, seq_len)

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = position_ids_expanded * inv_freq_expanded  # (batch_size, num_channels, seq_len)
            freqs = freqs.transpose(1, 2)  # (batch_size, seq_len, num_channels)

            emb = torch.cat((freqs, freqs), dim=-1)

            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling  # (batch_size, seq_len, dim)
        sin = sin * self.attention_scaling  # (batch_size, seq_len, dim)

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

"""
TableLlamaModel
"""


class TableLlamaModelZeroChannel(LlamaModel):
    """``LlamaModel`` whose rotary embedding is the zero-position ablation."""

    def __init__(self, config):
        """Build the base model, then replace its ``rotary_emb`` module."""
        super().__init__(config)
        zero_position_rope = TableLlamaRotaryEmbeddingZeroChannel(config)
        self.rotary_emb = zero_position_rope


class TableLlamaForCausalLMZeroChannel(LlamaForCausalLM):
    """Causal-LM wrapper whose backbone is ``TableLlamaModelZeroChannel``."""

    def __init__(self, config: TableLlamaConfig):
        """Build the stock causal LM, then replace ``self.model`` with the ablated backbone.

        ``config`` is used as given; no default ``rope_table_llama`` is filled in here.
        """
        super().__init__(config)
        zero_channel_backbone = TableLlamaModelZeroChannel(config)
        self.model = zero_channel_backbone

In [5]:
# Load the zero-position ablation with the config and quantization settings
# defined in the earlier cells.
table_llama_model = TableLlamaForCausalLMZeroChannel.from_pretrained(
    MODEL_NAME,
    config=table_llama_config,
    quantization_config=bnb_config,
    device_map="auto",
)

# Show the module tree as the cell output.
table_llama_model

TableLlamaForCausalLMZeroChannel(
  (model): TableLlamaModelZeroChannel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear8bitLt(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear8bitLt(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear8bitLt(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear8bitLt(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear8bitLt(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attent

## With Gaussian noise

In [6]:
"""
TableLlamaRotaryEmbedding
"""

class TableLlamaRotaryEmbeddingNoise(torch.nn.Module):
    """Rotary-embedding ablation that adds Gaussian noise to the position ids.

    The inverse frequencies and attention scaling are initialised through
    ``ROPE_INIT_FUNCTIONS`` like a regular rotary embedding. On every
    ``forward`` call one sample is drawn from ``N(0, noise_std)`` with
    ``np.random.normal`` and added to all position ids before the cos/sin
    tables are computed.

    The ``rope_table_llama`` options are validated and stored as attributes,
    but ``forward`` in this ablation does not read them.
    """

    def __init__(
        self,
        config: TableLlamaConfig,
        device=None,
        noise_std: float = 0.1,
    ):
        """Build the embedding from ``config``.

        Args:
            config: Model config; must expose ``rope_table_llama`` (a mapping with
                the x/y channel keys and ``line_length``) and the fields needed by
                the selected RoPE init function.
            device: Device passed through to the RoPE init function.
            noise_std: Standard deviation of the Gaussian position noise.
                Defaults to 0.1, the value that was previously hard-coded.

        Raises:
            NotImplementedError: If the configured RoPE type is a "dynamic" one;
                this class has no dynamic frequency update.
            ValueError: If ``x_channels_end`` / ``y_channels_end`` is set without
                the matching start and step.
        """
        super().__init__()

        self.config = config
        self.noise_std = noise_std

        # When `rope_scaling` is a dict, getattr() cannot see its keys, so the
        # RoPE type has to be read by key. A missing `rope_scaling` means the
        # unscaled "default" RoPE rather than "llama3".
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        elif isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        else:
            self.rope_type = getattr(rope_scaling, "rope_type", "llama3")

        if "dynamic" in self.rope_type:
            # Fail early instead of hitting a missing method inside forward().
            raise NotImplementedError(
                f"RoPE type '{self.rope_type}' needs a dynamic frequency update, "
                "which TableLlamaRotaryEmbeddingNoise does not implement"
            )

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # Table Llama specific initialization
        self.rope_table_llama = getattr(self.config, "rope_table_llama")

        x_channels_start = self.rope_table_llama["x_channels_start"]
        x_channels_end = self.rope_table_llama["x_channels_end"]
        x_channels_step = self.rope_table_llama["x_channels_step"]
        y_channels_start = self.rope_table_llama["y_channels_start"]
        y_channels_end = self.rope_table_llama["y_channels_end"]
        y_channels_step = self.rope_table_llama["y_channels_step"]
        line_length = self.rope_table_llama["line_length"]

        if line_length is None:
            # Set a large number to avoid the RoPE effect
            line_length = 10**8

        if x_channels_end is None:
            # Unset x range: use a large sentinel for start, end and step.
            x_channels_start = 10**8
            x_channels_end = 10**8
            x_channels_step = 10**8
        elif x_channels_step is None or x_channels_start is None:
            raise ValueError("You have set x_channels_end but not x_channels_step or x_channels_start")

        if y_channels_end is None:
            # Unset y range: same sentinel convention as for x.
            y_channels_start = 10**8
            y_channels_end = 10**8
            y_channels_step = 10**8
        elif y_channels_step is None or y_channels_start is None:
            raise ValueError("You have set y_channels_end but not y_channels_step or y_channels_start")

        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.num_channels = inv_freq.shape[0]
        self.line_length = line_length
        self.x_channels_start = x_channels_start
        self.x_channels_end = x_channels_end
        self.x_channels_step = x_channels_step
        self.y_channels_start = y_channels_start
        self.y_channels_end = y_channels_end
        self.y_channels_step = y_channels_step

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for position ids perturbed by Gaussian noise.

        Args:
            x: Tensor whose ``device`` and ``dtype`` determine the autocast device
                type and the dtype of the returned tensors.
            position_ids: Indexed as ``(batch_size, seq_len)``.

        Returns:
            Tuple ``(cos, sin)``, each of shape
            ``(batch_size, seq_len, 2 * num_channels)`` and cast to ``x.dtype``.
        """
        # (batch_size, num_channels, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)

        # (batch_size, 1, seq_len) -> (batch_size, num_channels, seq_len)
        position_ids_expanded = position_ids[:, None, :].float().repeat(1, self.num_channels, 1)

        # Draw the sample as a plain Python float: adding a NumPy array to a
        # tensor in place relies on implicit Tensor/ndarray interop, whereas a
        # scalar works on any device.
        # NOTE(review): a single shared offset shifts every position in this call
        # by the same amount; if per-token jitter was intended, draw a noise
        # tensor shaped like position_ids_expanded instead — confirm intent.
        noise = float(np.random.normal(0.0, self.noise_std))
        position_ids_expanded += noise

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = position_ids_expanded * inv_freq_expanded  # (batch_size, num_channels, seq_len)
            freqs = freqs.transpose(1, 2)  # (batch_size, seq_len, num_channels)

            emb = torch.cat((freqs, freqs), dim=-1)

            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling  # (batch_size, seq_len, dim)
        sin = sin * self.attention_scaling  # (batch_size, seq_len, dim)

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

"""
TableLlamaModel
"""


class TableLlamaModelNoise(LlamaModel):
    """``LlamaModel`` whose rotary embedding is the Gaussian-noise ablation."""

    def __init__(self, config):
        """Build the base model, then replace its ``rotary_emb`` module."""
        super().__init__(config)
        # Use the noise embedding here, not the zero-channel one.
        self.rotary_emb = TableLlamaRotaryEmbeddingNoise(config)


class TableLlamaForCausalNoise(LlamaForCausalLM):
    """Causal-LM wrapper whose backbone is ``TableLlamaModelNoise``."""

    def __init__(self, config: TableLlamaConfig):
        """Build the stock causal LM, then replace ``self.model`` with the noise backbone."""
        super().__init__(config)
        # Use the noise backbone here, not the zero-channel one, so that the
        # noise experiment actually runs the noise ablation.
        self.model = TableLlamaModelNoise(config)

# Eval

In [7]:
import sys
import os
from datetime import datetime
from tqdm import tqdm
from transformers import (
    HfArgumentParser, 
    PreTrainedTokenizerFast, 
    BitsAndBytesConfig
)
import torch
from torch.utils.data.dataloader import DataLoader

from parsers.argument_classes import DatasetArguments, ModelArguments, TrainingArguments, GenerationArguments
from utils.datasets_loader import load_datasets
from collators.data_collator_for_grid_tokenization import DataCollatorForGridTokenization

def main(args_list=None):
    """Run one RoPE-ablation evaluation on the test split and save the predictions.

    Parses the arguments, loads the tokenizer and the selected (optionally
    quantized) model variant, generates an answer for every test batch, prints
    exact-match accuracy, and writes all predictions to a CSV file.

    Args:
        args_list: Optional list of command-line style strings passed to
            ``HfArgumentParser.parse_args_into_dataclasses``. When ``None`` (the
            default) the previous behaviour is kept: a single ``*.yaml`` path in
            ``sys.argv`` is parsed if present, otherwise the built-in arguments
            for the noise experiment are used.

    Raises:
        ValueError: If neither ``set_channel_zero`` nor ``add_channel_noise`` is
            set, because no model variant would be selected.
    """
    parser = HfArgumentParser((DatasetArguments, ModelArguments, TrainingArguments, GenerationArguments))
    if args_list is not None:
        dataset_args, model_args, training_args, generation_args = parser.parse_args_into_dataclasses(list(args_list))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        dataset_args, model_args, training_args, generation_args = parser.parse_yaml_file(sys.argv[1])
    else:
        args_list = ["--dataset_root_dir", "../datasets", "--dataset_names", "wtq", "--test_max_samples_for_each_dataset", "480", "--load_in_4bit", "True",
                     "--batch_size", "4", "--set_channel_zero", "False", "--add_channel_noise", "True"]
        dataset_args, model_args, training_args, generation_args = parser.parse_args_into_dataclasses(args_list)
    print(model_args)
    print(training_args)

    # Pick the model variant. If both flags are set, zero-channel wins as before.
    if model_args.set_channel_zero:
        model_cls = TableLlamaForCausalLMZeroChannel
    elif model_args.add_channel_noise:
        model_cls = TableLlamaForCausalNoise
    else:
        # Without this check, `model` would be unbound further down.
        raise ValueError("Set either set_channel_zero or add_channel_noise to choose a model variant")

    # Tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps the generated tokens directly after the prompt columns.
    tokenizer.padding_side = "left"

    # Model
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=model_args.load_in_4bit,
        load_in_8bit=model_args.load_in_8bit,
        bnb_4bit_compute_dtype=torch.bfloat16 if model_args.load_in_4bit else None,
        bnb_4bit_use_double_quant=model_args.load_in_4bit,
    )
    # TableLlama
    table_llama_config = TableLlamaConfig.from_pretrained(model_args.model_name)
    table_llama_config.rope_table_llama = {
        "line_length": model_args.line_length,
        "x_channels_start": model_args.x_channels_start,
        "x_channels_end": model_args.x_channels_end,
        "x_channels_step": model_args.x_channels_step,
        "y_channels_start": model_args.y_channels_start,
        "y_channels_end": model_args.y_channels_end,
        "y_channels_step": model_args.y_channels_step,
    }
    model = model_cls.from_pretrained(
        model_args.model_name,
        quantization_config=bnb_config if model_args.load_in_4bit or model_args.load_in_8bit else None,
        device_map="auto",
        config=table_llama_config,
    )
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id
    # Load adapter
    if model_args.adapter_path:
        model.load_adapter(model_args.adapter_path)

    # Load datasets
    def filter_function(example):
        """Keep only examples within the optional row-count and width limits."""
        if dataset_args.max_table_row_num is not None and example["table_row_num"] > dataset_args.max_table_row_num:
            return False
        if dataset_args.max_table_width is not None and example["table_width"] > dataset_args.max_table_width:
            return False
        return True
    datasets = load_datasets(dataset_args, filter_function=filter_function)

    # Data collator
    data_collator = DataCollatorForGridTokenization(
        tokenizer=tokenizer,
        max_seq_length=training_args.max_seq_length,
        is_train=False,
        is_grid_tokenization=model_args.line_length is not None,
        line_length=model_args.line_length if model_args.line_length is not None else 64,
    )

    # Inference loop
    pred_dataloader = DataLoader(
        datasets["test"],
        collate_fn=data_collator,
        batch_size=training_args.batch_size,
    )

    # Predict
    predictions = []
    for batch in tqdm(pred_dataloader):
        with torch.no_grad():
            # Prompt width; everything after this column is newly generated text.
            input_length = batch["input_ids"].size(1)
            # Move the batch to the device
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model.generate(
                **batch,
                max_new_tokens=generation_args.max_new_tokens,
                do_sample=generation_args.do_sample,
                top_k=generation_args.top_k,
                top_p=generation_args.top_p,
                temperature=generation_args.temperature,
                pad_token_id=tokenizer.eos_token_id
            )
            output_strings = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=False)
            predictions.extend(output_strings)

    # Create a new column for predictions
    df = datasets["test"].to_pandas()
    df["raw_predictions"] = predictions

    def clean_predictions(predictions):
        """Strip the end-of-turn token, the "Answer:" prefix and outer whitespace."""
        return [pred.replace("<|eot_id|>", "").replace("Answer:", "").strip() for pred in predictions]

    df["predictions"] = clean_predictions(df["raw_predictions"])
    # Exact string match between the cleaned prediction and the reference answer.
    df["correct"] = df["answer"] == df["predictions"]

    def report(label, correct, total):
        """Print one accuracy line; an empty subset prints n/a instead of dividing by zero."""
        if total:
            print(f"{label}: {correct} / {total} = {correct / total * 100:.2f}%")
        else:
            print(f"{label}: {correct} / {total} = n/a (no samples)")

    def tally(mask):
        """Return (number correct, number of rows) for the rows selected by ``mask``."""
        return int((mask & df["correct"]).sum()), int(mask.sum())

    print(f"Base model: {model_args.model_name}")
    print(f"Adapter: {model_args.adapter_path}")
    print(f"Total samples: {df.shape[0]}")

    if "self_generated" in dataset_args.dataset_names:
        # Count accuracy for each task and direction
        list_item_row_correct, list_item_row_total = tally((df["task"] == "list_items") & (df["direction"] == "row"))
        list_item_col_correct, list_item_col_total = tally((df["task"] == "list_items") & (df["direction"] == "column"))
        arithmetic_row_correct, arithmetic_row_total = tally((df["task"] == "arithmetic") & (df["direction"] == "row"))
        arithmetic_col_correct, arithmetic_col_total = tally((df["task"] == "arithmetic") & (df["direction"] == "column"))

        self_generated_total = list_item_row_total + list_item_col_total + arithmetic_row_total + arithmetic_col_total
        self_generated_correct = list_item_row_correct + list_item_col_correct + arithmetic_row_correct + arithmetic_col_correct

        report("List item row correct", list_item_row_correct, list_item_row_total)
        report("List item column correct", list_item_col_correct, list_item_col_total)
        report("Arithmetic row correct", arithmetic_row_correct, arithmetic_row_total)
        report("Arithmetic column correct", arithmetic_col_correct, arithmetic_col_total)
        report("Self-generated correct", self_generated_correct, self_generated_total)

    if "wtq" in dataset_args.dataset_names:
        wtq_correct, wtq_total = tally(df["task"] == "wtq")
        report("WTQ correct", wtq_correct, wtq_total)

    report("Total correct", int(df["correct"].sum()), df.shape[0])

    # Save it to the adapter path
    if model_args.adapter_path:
        output_path = os.path.join(model_args.adapter_path, "predictions.csv")
    else:
        # Create the output directory if not exists
        os.makedirs(training_args.output_dir, exist_ok=True)
        output_path = os.path.join(training_args.output_dir, f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_predictions.csv")

    df.to_csv(output_path, index=False)


In [8]:
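# With zero channel -- NOTE: main() hard-codes "--set_channel_zero False --add_channel_noise True", so this call selects the noise branch (see the ModelArguments printed below); change those arguments to actually run the zero-channel variant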
main()

ModelArguments(model_name='meta-llama/Llama-3.2-1B-Instruct', adapter_path='', load_in_8bit=False, load_in_4bit=True, line_length=None, x_channels_start=None, x_channels_end=None, x_channels_step=None, y_channels_start=None, y_channels_end=None, y_channels_step=None, set_channel_zero=False, add_channel_noise=True)
TrainingArguments(output_dir='./outputs', gradient_accumulation_steps=1, batch_size=4, num_train_epochs=1, save_total_limit=3, save_steps=100, logging_steps=10, eval_steps=200, max_seq_length=1024, dry_run=False, run_id_prefix='run', wandb_entity=None, wandb_project=None, hf_organization='cs230-table-llama', push_to_hub=False)


  2%|█▍                                                                                  | 2/120 [00:03<03:13,  1.64s/it]


KeyboardInterrupt: 

In [None]:
# With gaussian noise
main()

ModelArguments(model_name='meta-llama/Llama-3.2-1B-Instruct', adapter_path='', load_in_8bit=False, load_in_4bit=True, line_length=None, x_channels_start=None, x_channels_end=None, x_channels_step=None, y_channels_start=None, y_channels_end=None, y_channels_step=None, set_channel_zero=False, add_channel_noise=True)
TrainingArguments(output_dir='./outputs', gradient_accumulation_steps=1, batch_size=4, num_train_epochs=1, save_total_limit=3, save_steps=100, logging_steps=10, eval_steps=200, max_seq_length=1024, dry_run=False, run_id_prefix='run', wandb_entity=None, wandb_project=None, hf_organization='cs230-table-llama', push_to_hub=False)


 88%|████████████████████████████████████████████████████████████████████████▍         | 106/120 [03:27<00:45,  3.27s/it]