# Symbolic AI Tests 2025

# Common Task 1.1. Dataset preprocessing 
Dataset:

https://space.mit.edu/home/tegmark/aifeynman.html 
Note: The authors of this dataset are not affiliated with ML4SCI

Description:
Download the Feynman_with_units.tar.gz features and corresponding FeynmanEquations.csv targets. Preprocess and tokenize the target data and document your rationale for choice of tokenization.



In [2]:
# -*- coding: utf-8 -*-
"""
Script to train a Byte-Level BPE tokenizer on LaTeX equations from a CSV file
and prepare file paths for corresponding numeric data.

This script performs the following steps:
1. Defines paths for input CSV, numeric data directory, and tokenizer output.
2. Loads filenames and validates essential columns from the input CSV.
3. Trains a ByteLevelBPETokenizer on the equation strings found in the CSV.
4. Saves the trained tokenizer (vocab, merges, full config).
5. Wraps the trained tokenizer using Hugging Face's PreTrainedTokenizerFast.
6. Tests the wrapped tokenizer on a sample equation.
7. Maps filenames from the CSV to their corresponding numeric data files.
8. Provides usage notes for the tokenizer and data loading.
"""

import os
from pathlib import Path
import gc # For explicit garbage collection, if needed (see notes).
import sys # For exiting script on critical errors.

import pandas as pd
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast

# --- Configuration ---

# Input Paths
# Path to the CSV file containing filenames and LaTeX equations.
CSV_PATH = Path(r"/home/nikitas/Desktop/send_Miche/GOOGLE/FeynmanEquations.csv")
# Directory containing the corresponding numeric data files (named according to the 'Filename' column).
NUMERIC_DATA_DIR = Path(r"/home/nikitas/Desktop/send_Miche/GOOGLE/Feynman_with_units")

# Output Path
# Directory where the trained tokenizer files will be saved.
TOKENIZER_OUTPUT_DIR = Path("feynman_tokenizer")

# CSV Column Names
# Ensure these exactly match the headers in your CSV file.
FILENAME_COLUMN = 'Filename' # Column containing the base name of the numeric files.
EQUATION_COLUMN = 'Formula'  # Column containing the LaTeX equation strings.

# Tokenizer Settings
VOCAB_SIZE = 10_000     # Desired vocabulary size for the BPE tokenizer.
MIN_FREQUENCY = 2       # Minimum frequency for a token to be included in the vocab.
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"] # Standard special tokens.

# --- Helper Functions ---

def equation_iterator(csv_file_path, target_column):
    """
    Creates a generator to efficiently yield equations from a specific column
    in a potentially large CSV file.

    Reads the CSV in chunks to minimize memory usage. Skips NaN or empty values.

    Args:
        csv_file_path (Path): Path object pointing to the CSV file.
        target_column (str): The exact name of the column containing equations.

    Yields:
        str: Non-empty equation strings from the specified column.

    Raises:
        FileNotFoundError: If the csv_file_path does not exist (handled internally).
        KeyError: If the target_column is not found in the CSV (handled internally).
        Exception: For other pandas read errors (handled internally).
    """
    chunk_size = 1000 # Process N rows at a time. Adjust based on memory constraints.
    print(f"   [Iterator] Initializing for column: '{target_column}'...")

    if not csv_file_path.is_file():
        print(f"   [Iterator Error] CSV file not found at: {csv_file_path}")
        return # Stop iteration if file is missing

    try:
        # Use chunking for memory efficiency.
        # skipinitialspace handles potential spaces after delimiters.
        iterator = pd.read_csv(
            csv_file_path,
            usecols=[target_column],
            chunksize=chunk_size,
            skipinitialspace=True,
            low_memory=False # Can sometimes help with mixed types, but monitor memory.
        )

        processed_count = 0
        for chunk in iterator:
            for equation in chunk[target_column]:
                # Convert to string and check for NaN/null/empty values
                eq_str = str(equation).strip()
                if pd.isna(equation) or not eq_str or eq_str.lower() == 'nan':
                    # Optionally log skipped values if debugging is needed
                    # print(f"   [Iterator] Skipping empty/NaN value.")
                    continue
                yield eq_str
                processed_count += 1

            # Optional: Explicit garbage collection after processing a large chunk.
            # This is often not strictly necessary as Python's GC handles objects
            # going out of scope, but can sometimes be helpful in tight memory
            # situations or long-running processes. Profile if considering uncommenting.
            # gc.collect()

        print(f"   [Iterator] Finished yielding {processed_count} equations.")

    except KeyError:
        print(f"   [Iterator Error] Column '{target_column}' not found in CSV: {csv_file_path}")
        # Stop iteration if the required column is missing
        return
    except Exception as e:
        # Catch other potential errors during file reading or processing
        print(f"   [Iterator Error] Failed reading chunks from {csv_file_path} using column '{target_column}': {e}")
        # Stop iteration on unexpected errors
        return


def load_numeric_data(filename, paths_dict):
    """
    Loads numeric data from a file specified by its filename.

    Assumes the file contains numeric data loadable by np.loadtxt.

    Args:
        filename (str): The base filename (e.g., 'data1.txt').
        paths_dict (dict): A dictionary mapping filenames to their full Path objects.

    Returns:
        np.ndarray: The loaded numpy array if successful.
        None: If the file is not found in the dictionary or cannot be loaded.
    """
    file_path = paths_dict.get(filename)
    if file_path is None:
        print(f"   [Data Loader Warning] No path found for numeric file '{filename}'.")
        return None
    if not file_path.is_file():
        print(f"   [Data Loader Warning] File path exists in map but file not found on disk: '{file_path}'.")
        return None

    try:
        # Load the data, assuming standard text format for numpy
        data_array = np.loadtxt(file_path)
        return data_array
    except Exception as e:
        print(f"   [Data Loader Error] Failed to load numeric file {file_path}: {e}")
        return None


# --- Main Script Execution ---

print("--- Script Start ---")

# 1. Setup and Path Validation
print("\n[Step 1] Setting up directories and validating paths...")
TOKENIZER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) # Ensure output dir exists

if not CSV_PATH.is_file():
    print(f"[Error] Input CSV file not found: {CSV_PATH}")
    sys.exit(1) # Critical error, cannot proceed
if not NUMERIC_DATA_DIR.is_dir():
    print(f"[Warning] Numeric data directory not found: {NUMERIC_DATA_DIR}")
    # Allow script to continue, but numeric file mapping will likely fail.

print(f"  Input CSV: {CSV_PATH}")
print(f"  Numeric Data Dir: {NUMERIC_DATA_DIR}")
print(f"  Tokenizer Output Dir: {TOKENIZER_OUTPUT_DIR}")


# 2. Load Filenames and Validate CSV Structure
print(f"\n[Step 2] Loading filenames from '{FILENAME_COLUMN}' column and validating CSV...")
filenames = []
try:
    # Quickly read header to check for essential columns before loading data
    header = pd.read_csv(CSV_PATH, nrows=0, skipinitialspace=True).columns.tolist()
    print(f"  CSV Header: {header}")
    if FILENAME_COLUMN not in header:
        raise ValueError(f"Required filename column '{FILENAME_COLUMN}' not found in header.")
    if EQUATION_COLUMN not in header:
        raise ValueError(f"Required equation column '{EQUATION_COLUMN}' not found in header.")

    # Load only the necessary filename column
    filenames_series = pd.read_csv(CSV_PATH, usecols=[FILENAME_COLUMN], skipinitialspace=True)[FILENAME_COLUMN]
    # Convert to list of strings, handling potential non-string types gracefully
    filenames = [str(fn).strip() for fn in filenames_series.tolist() if pd.notna(fn)]
    # Filter out any empty strings that might result
    filenames = [fn for fn in filenames if fn]

    print(f"  Successfully loaded {len(filenames)} non-empty filenames.")
    if not filenames:
         print("  [Warning] No valid filenames were loaded from the CSV.")

except FileNotFoundError: # Should be caught by Step 1, but as fallback.
    print(f"[Error] CSV file disappeared after initial check: {CSV_PATH}")
    sys.exit(1)
except ValueError as ve: # Specific error for missing columns
    print(f"[Error] CSV Column Validation Failed: {ve}")
    sys.exit(1) # Critical error if required columns are missing
except Exception as e:
    print(f"[Error] Failed to read CSV header or filenames column '{FILENAME_COLUMN}': {e}")
    # Depending on the error, decide if exit is needed.
    # If filenames are non-critical for a part of the process, could warn instead.
    sys.exit(1)


# 3. Train BPE Tokenizer
print(f"\n[Step 3] Training Byte-Level BPE Tokenizer on '{EQUATION_COLUMN}' column...")

# Create the BPE tokenizer instance
bpe_tokenizer = ByteLevelBPETokenizer()

print(f"  Starting training with vocab_size={VOCAB_SIZE}, min_frequency={MIN_FREQUENCY}")
try:
    # Train using the efficient iterator
    # The iterator handles its own errors internally (e.g., column not found)
    bpe_tokenizer.train_from_iterator(
        equation_iterator(CSV_PATH, EQUATION_COLUMN),
        vocab_size=VOCAB_SIZE,
        min_frequency=MIN_FREQUENCY,
        special_tokens=SPECIAL_TOKENS
    )

    print(f"  Tokenizer training complete. Vocabulary size: {bpe_tokenizer.get_vocab_size()}")

    # Save the tokenizer components (vocabulary and merge rules)
    # Use str() for compatibility with older library versions if needed
    bpe_tokenizer.save_model(str(TOKENIZER_OUTPUT_DIR))
    print(f"  Tokenizer vocabulary and merges saved to: {TOKENIZER_OUTPUT_DIR}")

    # Save the full tokenizer configuration as tokenizer.json for Hugging Face
    tokenizer_json_path = TOKENIZER_OUTPUT_DIR / "tokenizer.json"
    bpe_tokenizer._tokenizer.save(str(tokenizer_json_path))
    print(f"  Full tokenizer config saved to: {tokenizer_json_path}")

except Exception as e:
    print(f"[Error] Tokenizer training failed: {e}")
    # Decide if script should stop; likely yes if tokenizer is core goal.
    tokenizer_json_path = None # Ensure path variable is None if saving failed
    print("  [Warning] Skipping subsequent steps that depend on the tokenizer.")


# 4. Wrap with PreTrainedTokenizerFast (Hugging Face)
print("\n[Step 4] Wrapping tokenizer with Hugging Face PreTrainedTokenizerFast...")

hf_tokenizer = None # Initialize to None
if 'tokenizer_json_path' in locals() and tokenizer_json_path and tokenizer_json_path.exists():
    try:
        hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=str(tokenizer_json_path),
            bos_token="<s>", # Begin-of-sequence
            eos_token="</s>", # End-of-sequence
            unk_token="<unk>", # Unknown token
            pad_token="<pad>", # Padding token
            mask_token="<mask>", # Mask token (if used in models like BERT)
        )
        print("  Hugging Face tokenizer wrapper created successfully.")

        # Perform a quick test on the first available equation
        print("  Testing tokenizer on the first valid equation...")
        try:
            # Get the first valid equation using a fresh iterator instance
            test_iterator = equation_iterator(CSV_PATH, EQUATION_COLUMN)
            first_equation = next(test_iterator, None) # Use default=None to avoid StopIteration
            del test_iterator # Clean up iterator

            if first_equation:
                print(f"    Equation Sample (from '{EQUATION_COLUMN}'): {first_equation}")
                tokens = hf_tokenizer.tokenize(first_equation)
                print(f"    → Tokens ({len(tokens)}): {tokens}")
                encoded_ids = hf_tokenizer.encode(first_equation)
                print(f"    → IDs ({len(encoded_ids)}): {encoded_ids}")
            else:
                print("    [Warning] Could not retrieve the first equation for testing (CSV empty or contains no valid equations?).")

        except Exception as e:
            print(f"    [Warning] Error during tokenizer test on first equation: {e}")

    except Exception as e:
        print(f"[Error] Failed to load tokenizer with PreTrainedTokenizerFast: {e}")
        # hf_tokenizer remains None

else:
     print("  [Skipped] Tokenizer file 'tokenizer.json' not found. Likely due to errors in Step 3.")


# 5. Map Filenames to Numeric Data Files
print(f"\n[Step 5] Mapping filenames to numeric data files in '{NUMERIC_DATA_DIR}'...")

numeric_file_paths = {}
missing_files_count = 0
checked_files_count = 0

if not filenames:
     print("  [Warning] No filenames loaded in Step 2. Cannot map numeric files.")
elif not NUMERIC_DATA_DIR.is_dir():
     print(f"  [Warning] Numeric data directory '{NUMERIC_DATA_DIR}' not found. Cannot map files.")
else:
    print(f"  Checking for {len(filenames)} potential files...")
    for fn in filenames:
        checked_files_count += 1
        # Construct the full expected path
        potential_path = NUMERIC_DATA_DIR / fn
        if potential_path.is_file():
            numeric_file_paths[fn] = potential_path
        else:
            # Only log missing files if verbose logging is needed, otherwise just count.
            # if missing_files_count < 10: # Log first few missing
            #     print(f"    - Missing: {fn} (expected at {potential_path})")
            # elif missing_files_count == 10:
            #      print("    - ... (further missing files not listed)")
            missing_files_count += 1

    found_files_count = len(numeric_file_paths)
    print(f"  Checked {checked_files_count} filenames.")
    print(f"  Found {found_files_count} corresponding numeric files.")
    if missing_files_count > 0:
        print(f"  [Warning] Could not find {missing_files_count} numeric files in '{NUMERIC_DATA_DIR}'.")

    # Show an example path if available
    if numeric_file_paths:
         example_fn = next(iter(numeric_file_paths)) # Get first key
         print(f"  Example mapping: '{example_fn}' -> '{numeric_file_paths[example_fn]}'")


# 6. Usage Information
print("\n[Step 6] Usage Notes")
print("-" * 20)
if hf_tokenizer:
    print(" - Equations Tokenization: Use the 'hf_tokenizer' object.")
    print("   Example: `encoded = hf_tokenizer(list_of_eq_strings, padding=True, truncation=True, max_length=128)`")
else:
    print(" - Equations Tokenization: FAILED (hf_tokenizer was not created).")

if numeric_file_paths:
    print(" - Numeric Data Loading: Use the 'numeric_file_paths' dictionary and the `load_numeric_data` function.")
    print("   Example: `data = load_numeric_data('some_filename.txt', numeric_file_paths)`")
    print("   (Ensure 'some_filename.txt' is a key in the dictionary)")
elif filenames: # Filenames were loaded, but mapping failed
    print(" - Numeric Data Loading: Mapping FAILED (numeric_file_paths dictionary is empty or was not created).")
    print("   Check Step 5 warnings and the NUMERIC_DATA_DIR path.")
else: # No filenames loaded
     print(" - Numeric Data Loading: Cannot load numeric data (no filenames loaded from CSV in Step 2).")
print("-" * 20)


print("\n--- Script Finished ---")

--- Script Start ---

[Step 1] Setting up directories and validating paths...
  Input CSV: /home/nikitas/Desktop/send_Miche/GOOGLE/FeynmanEquations.csv
  Numeric Data Dir: /home/nikitas/Desktop/send_Miche/GOOGLE/Feynman_with_units
  Tokenizer Output Dir: feynman_tokenizer

[Step 2] Loading filenames from 'Filename' column and validating CSV...
  CSV Header: ['Filename', 'Number', 'Output', 'Formula', '# variables', 'v1_name', 'v1_low', 'v1_high', 'v2_name', 'v2_low', 'v2_high', 'v3_name', 'v3_low', 'v3_high', 'v4_name', 'v4_low', 'v4_high', 'v5_name', 'v5_low', 'v5_high', 'v6_name', 'v6_low', 'v6_high', 'v7_name', 'v7_low', 'v7_high', 'v8_name', 'v8_low', 'v8_high', 'v9_name', 'v9_low', 'v9_high', 'v10_name', 'v10_low', 'v10_high']
  Successfully loaded 100 non-empty filenames.

[Step 3] Training Byte-Level BPE Tokenizer on 'Formula' column...
  Starting training with vocab_size=10000, min_frequency=2
   [Iterator] Initializing for column: 'Formula'...
   [Iterator] Finished yielding 1

# Common Task 1.2 Dataset preprocessing
Dataset:

https://alabama.box.com/s/xhgr2onrn503jyse2fs5vxtapg0oifcs 

Dataset:
Download the dataset (split across 10 files) and preprocess and tokenize the target data and document your rationale for choice of tokenization. Data file is formatted with rows like 
“event type : Feynman diagram : amplitude : squared amplitude”
Here the amplitudes are the input sequences and squared amplitudes are the target sequences. Note that indices like _123456 grow over the course of the dataset and should be normalized for each amplitude and squared amplitude. Use an 80-10-10 split of train-val-test across all files.


In [3]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Data Preprocessing and Splitting Pipeline for QED 2-to-2 Tree-Level Expressions.

This script performs the following actions:
1.  Loads raw text data containing QED amplitude and squared-amplitude expressions
    from files matching a specified pattern.
2.  Parses each line, expecting a specific format (event:diagram:amplitude:squared_amplitude).
3.  Applies index normalization to numeric subscripts (e.g., _123 -> _0) within
    each expression separately to reduce vocabulary size and improve generalization.
4.  Tokenizes the normalized expressions using a custom regex designed for
    mathematical and symbolic content (identifiers, numbers, operators).
5.  Structures the tokenized data into input/target pairs.
6.  Randomly shuffles the dataset.
7.  Splits the data into training, validation, and testing sets based on specified fractions.
8.  Saves each dataset split into a separate JSON Lines (.jsonl) file.

Rationale for Design Choices:
-   Custom Regex Tokenization: Provides fine-grained control over token boundaries,
    effectively separating mathematical identifiers (e.g., m_e, gamma_1), numbers,
    operators (+, -, *, /), and punctuation crucial for representing the
    structure of the expressions. Avoids reliance on external subword models,
    leading to a potentially smaller, more interpretable vocabulary specific to
    the domain.
-   Index Normalization: Numeric indices attached to symbols (e.g., `p_1`, `k_23`)
    can grow arbitrarily large across a large dataset. Normalizing these indices
    within each *individual* expression (e.g., `p_1`, `k_23` -> `p_0`, `k_1`)
    prevents the vocabulary from exploding and discourages the model from merely
    memorizing absolute index values, promoting focus on the symbolic structure.
-   Standard Splits: An 80/10/10 split for train/validation/test sets allows for
    standard model training, hyperparameter tuning, and final performance evaluation.
-   JSON Lines Format: A convenient format for storing sequence data, where each
    line is an independent JSON object, facilitating easy reading and processing.
"""

import glob
import re
import random
import json
import sys
from pathlib import Path

# --- Configuration ---

# Input data pattern (use raw string r"..." for Windows paths if needed)
# Update this path to point to your dataset files.
# Example Linux/macOS: "/path/to/data/QED-2-to-2-diag-TreeLevel-*.txt"
# Example Windows: r"C:\path\to\data\QED-2-to-2-diag-TreeLevel-*.txt"
FILES_PATTERN = r"/home/nikitas/Desktop/send_Miche/GOOGLE/SYMBA - Test Data-selected/QED-2-to-2-diag-TreeLevel-*.txt"

# Output file prefix (e.g., 'qed_data' -> 'qed_data_train.jsonl', etc.)
OUTPUT_PREFIX = 'qed_expressions'

# Dataset split ratios
TRAIN_FRAC = 0.8
VAL_FRAC = 0.1
# TEST_FRAC is implicitly calculated as 1.0 - TRAIN_FRAC - VAL_FRAC

# Random seed for reproducibility of shuffling and splitting
RANDOM_SEED = 42

# --- Constants ---

# Regex pattern to capture tokens:
# - Group 1: Identifiers ([A-Za-z_]\w*) - Starts with letter or underscore, followed by word chars.
# - Group 2: Numbers (\d+) - Integer numbers.
# - Group 3: Operators/Special Chars (\*\*|\^|[+\-*/=()_,:]) - Includes power (**), caret (^),
#           basic arithmetic, parentheses, comma, colon. Added '^' explicitly.
TOKEN_PATTERN = re.compile(r"([A-Za-z_]\w*|\d+|\*\*|\^|[+\-*/=()_,:])")

# Regex pattern to find numeric subscripts (e.g., "_123", "_4")
INDEX_PATTERN = re.compile(r'_\d+')

# --- Helper Functions ---

def tokenize(expression_string):
    """
    Tokenizes a mathematical expression string using the pre-defined TOKEN_PATTERN.

    Args:
        expression_string (str): The mathematical expression to tokenize.

    Returns:
        list[str]: A list of tokens found in the expression. Returns an empty
                   list if the input is empty or None.
    """
    if not expression_string:
        return []
    return TOKEN_PATTERN.findall(expression_string)

def normalize_indices(expression_string):
    """
    Normalizes numeric subscripts within an expression string.

    Finds all occurrences of `_ suivi par des chiffres` (e.g., `_1`, `_123`) and replaces
    them sequentially with `_0`, `_1`, `_2`, ... within the context of this single
    expression. This ensures that the same logical index appearing multiple times
    (e.g., `p_1 ... p_1`) gets the same normalized index (e.g., `p_0 ... p_0`).

    Args:
        expression_string (str): The expression containing potential numeric subscripts.

    Returns:
        str: The expression with its numeric subscripts normalized.
    """
    if not expression_string:
        return ""

    # A dictionary to map original indices (e.g., '_123') to normalized ones ('_0').
    index_mapping = {}
    # Using a list for 'counter' allows modification within the nested function (_repl).
    # This is a common technique to emulate mutable integer state in closures.
    counter = [0]

    def _replace_match(match):
        """Nested function to handle the replacement logic for re.sub."""
        original_index = match.group(0) # The matched subscript (e.g., '_123')
        if original_index not in index_mapping:
            # Assign the next available normalized index if this is the first time
            # we encounter this specific original index within this expression.
            normalized_index = f"_{counter[0]}"
            index_mapping[original_index] = normalized_index
            counter[0] += 1
        # Return the mapped normalized index (either newly created or previously stored).
        return index_mapping[original_index]

    # Substitute all occurrences using the _replace_match logic.
    return INDEX_PATTERN.sub(_replace_match, expression_string)

def load_and_preprocess(file_glob_pattern):
    """
    Loads data from files matching the pattern, preprocesses, and tokenizes it.

    Expects each non-empty line in the files to have the format:
    'event_info : diagram_info : amplitude_expression : squared_amplitude_expression'

    Lines not matching this format are skipped, and a warning is printed.

    Args:
        file_glob_pattern (str): A glob pattern matching the input text files.

    Returns:
        list[dict]: A list of dictionaries, where each dictionary represents
                    a preprocessed data point with keys 'input_tokens'
                    (from amplitude) and 'target_tokens' (from squared amplitude).
    """
    dataset = []
    skipped_lines = 0
    processed_files = 0

    # Use sorted() for deterministic file processing order, helpful for debugging.
    file_paths = sorted(glob.glob(file_glob_pattern))

    if not file_paths:
        print(f"[Warning] No files found matching pattern: {file_glob_pattern}")
        return []

    print(f"[INFO] Found {len(file_paths)} files matching pattern. Processing...")

    for file_path in file_paths:
        processed_files += 1
        try:
            with open(file_path, 'r', encoding='utf-8') as infile:
                for line_num, line in enumerate(infile, 1):
                    line = line.strip()
                    if not line:
                        continue # Skip empty lines

                    parts = line.split(' : ')
                    # Expecting 4 parts based on the specified format
                    if len(parts) != 4:
                        # print(f"[Warning] Skipping malformed line {line_num} in {file_path}: Expected 4 parts separated by ' : ', found {len(parts)}. Content: '{line[:100]}...'")
                        skipped_lines += 1
                        continue

                    # Unpack (we only need the expressions for tokenization)
                    _event_info, _diagram_info, amplitude_expr, sq_amplitude_expr = parts

                    # 1. Normalize indices
                    normalized_amplitude = normalize_indices(amplitude_expr)
                    normalized_sq_amplitude = normalize_indices(sq_amplitude_expr)

                    # 2. Tokenize
                    amplitude_tokens = tokenize(normalized_amplitude)
                    sq_amplitude_tokens = tokenize(normalized_sq_amplitude)

                    # Append structured data
                    if amplitude_tokens and sq_amplitude_tokens: # Only add if both are non-empty after processing
                        dataset.append({
                            'input_tokens': amplitude_tokens,   # Typically the 'source' sequence
                            'target_tokens': sq_amplitude_tokens # Typically the 'target' sequence
                        })
                    else:
                        # print(f"[Warning] Skipping line {line_num} in {file_path} due to empty tokens after processing.")
                        skipped_lines += 1

        except Exception as e:
            print(f"[Error] Failed to process file {file_path}: {e}")
            # Decide if one faulty file should stop the whole process or just be skipped
            continue # Skip this file and continue with the next

    print(f"[INFO] Processed {processed_files} files.")
    if skipped_lines > 0:
        print(f"[Warning] Skipped {skipped_lines} lines due to formatting issues or empty results.")

    return dataset

def split_and_save_dataset(data, train_frac, val_frac, output_prefix):
    """
    Shuffles the dataset, splits it into train, validation, and test sets,
    and saves each split to a JSON Lines file.

    Args:
        data (list[dict]): The list of preprocessed data points.
        train_frac (float): Fraction of data for the training set (e.g., 0.8).
        val_frac (float): Fraction of data for the validation set (e.g., 0.1).
        output_prefix (str): The prefix for the output filenames.
                             Files will be named f"{output_prefix}_train.jsonl", etc.
    """
    if not data:
        print("[Warning] No data provided to split and save.")
        return

    # Ensure fractions are valid
    if not (0 < train_frac < 1 and 0 < val_frac < 1 and (train_frac + val_frac) < 1):
        print(f"[Error] Invalid split fractions: train={train_frac}, val={val_frac}. Must be between 0 and 1, and sum < 1.")
        sys.exit(1)

    # Shuffle the data in place for random splitting
    random.shuffle(data)

    n_total = len(data)
    n_train = int(train_frac * n_total)
    n_val = int(val_frac * n_total)
    n_test = n_total - n_train - n_val # Remainder goes to test set

    # Perform the splits
    splits = {
        'train': data[:n_train],
        'val': data[n_train : n_train + n_val],
        'test': data[n_train + n_val :]
    }

    print(f"[INFO] Splitting data: {n_train} train, {n_val} validation, {n_test} test examples.")

    # Save each split
    output_dir = Path('.') # Save in the current directory, or specify another Path object
    output_dir.mkdir(parents=True, exist_ok=True) # Ensure directory exists

    for split_name, subset in splits.items():
        output_filename = output_dir / f"{output_prefix}_{split_name}.jsonl"
        try:
            with open(output_filename, 'w', encoding='utf-8') as outfile:
                for entry in subset:
                    # Write each dictionary as a JSON string on its own line
                    outfile.write(json.dumps(entry, ensure_ascii=False) + '\n')
            print(f"  Successfully saved {split_name} set to: {output_filename}")
        except IOError as e:
            print(f"[Error] Failed to write {split_name} set to {output_filename}: {e}")
        except Exception as e:
             print(f"[Error] An unexpected error occurred while writing {split_name} set: {e}")

# --- Main Execution ---

def main():
    """Main function to orchestrate the preprocessing and splitting."""
    print("[INFO] Starting QED expression preprocessing script...")

    # Set the random seed for reproducibility
    random.seed(RANDOM_SEED)
    print(f"[INFO] Random seed set to: {RANDOM_SEED}")

    # 1. Load and preprocess data
    print(f"[INFO] Loading data from pattern: {FILES_PATTERN}")
    preprocessed_data = load_and_preprocess(FILES_PATTERN)

    if not preprocessed_data:
        print("[Error] No data loaded. Exiting.")
        sys.exit(1)

    print(f"[INFO] Successfully loaded and preprocessed {len(preprocessed_data)} examples.")

    # 2. Split and save the dataset
    print(f"[INFO] Splitting and saving dataset with prefix '{OUTPUT_PREFIX}'...")
    split_and_save_dataset(
        data=preprocessed_data,
        train_frac=TRAIN_FRAC,
        val_frac=VAL_FRAC,
        output_prefix=OUTPUT_PREFIX
    )

    print("[INFO] Script finished successfully.")

if __name__ == "__main__":
    main()

[INFO] Starting QED expression preprocessing script...
[INFO] Random seed set to: 42
[INFO] Loading data from pattern: /home/nikitas/Desktop/send_Miche/GOOGLE/SYMBA - Test Data-selected/QED-2-to-2-diag-TreeLevel-*.txt
[INFO] Found 10 files matching pattern. Processing...
[INFO] Processed 10 files.
[INFO] Successfully loaded and preprocessed 15552 examples.
[INFO] Splitting and saving dataset with prefix 'qed_expressions'...
[INFO] Splitting data: 12441 train, 1555 validation, 1556 test examples.
  Successfully saved train set to: qed_expressions_train.jsonl
  Successfully saved val set to: qed_expressions_val.jsonl
  Successfully saved test set to: qed_expressions_test.jsonl
[INFO] Script finished successfully.


# Common Task 2: Train/Evaluate Transformer model
Train a generic next-token-prediction Transformer model to map the input data to the tokenized output sequences. Evaluate performance on the test set using sequence accuracy as a metric.


In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Training and Evaluation Script for a Sequence-to-Sequence Model
on the Preprocessed QED 2-to-2 Tree-Level Dataset.

This script performs the following steps:
1.  Loads preprocessed training, validation, and test datasets from JSONL files.
2.  Prepares the data by joining token lists into whitespace-separated strings.
3.  Initializes a Hugging Face tokenizer (e.g., BERT's).
4.  Tokenizes the input (amplitude) and target (squared amplitude) sequences,
    padding/truncating them to a maximum length.
5.  Creates custom PyTorch Dataset objects for each split.
6.  Instantiates a Hugging Face Encoder-Decoder model (e.g., BERT-to-BERT).
7.  Configures essential model parameters (special tokens, sequence length, etc.).
8.  Sets up a Data Collator for handling dynamic padding within batches.
9.  Defines a custom metric function to calculate sequence-level accuracy.
10. Configures training arguments using `Seq2SeqTrainingArguments`.
11. Initializes a `Seq2SeqTrainer`.
12. Runs the training process on the training set, evaluating on the validation set.
13. Evaluates the final trained model on the test set.
14. Prints the test set sequence accuracy.
"""

import json
import sys
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import can be clearer
)

# --- Configuration ---

# File Paths
# Assumes the script runs in a directory containing these files.
# Consider using absolute paths or command-line arguments for more flexibility.
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Updated from 'data_train.jsonl' based on previous script
VAL_FILE   = Path('qed_expressions_val.jsonl')   # Updated from 'data_val.jsonl'
TEST_FILE  = Path('qed_expressions_test.jsonl')  # Updated from 'data_test.jsonl'
OUTPUT_DIR = Path('qed_model_output') # Directory for checkpoints and logs

# Model Configuration
# Using bert-base-uncased for both encoder and decoder as an example.
# Other combinations (e.g., DistilBERT encoder, BERT decoder) are possible.
ENCODER_MODEL_ID = "bert-base-uncased"
DECODER_MODEL_ID = "bert-base-uncased" # Must match encoder if weights are tied easily

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 128 # Maximum sequence length for padding/truncation

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 16
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 500
LOGGING_STEPS = 100
EVALUATION_STRATEGY = "epoch" # Evaluate at the end of each epoch
SAVE_STRATEGY = "epoch"       # Save a checkpoint at the end of each epoch
# Set to True if generation metrics (BLEU, ROUGE) are needed during evaluation.
# Set to False to use raw logits for metrics like accuracy, which is faster.
PREDICT_WITH_GENERATE = False

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """
    Loads a JSON Lines (.jsonl) file.

    Each line in the file is expected to be a valid JSON object.
    Blank lines or lines containing only whitespace are skipped.

    Args:
        file_path (Path or str): The path to the JSONL file.

    Returns:
        list[dict]: A list of dictionaries loaded from the file.

    Raises:
        FileNotFoundError: If the file_path does not exist.
        json.JSONDecodeError: If a line contains invalid JSON.
        Exception: For other potential I/O errors.
    """
    data = []
    file_path = Path(file_path) # Ensure it's a Path object
    print(f"[INFO] Loading data from: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        return data
    except FileNotFoundError:
        print(f"[Error] Data file not found: {file_path}")
        raise # Re-raise the exception to be handled by the caller
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}")
        raise # Re-raise the exception


def convert_tokens_to_strings(raw_data_list):
    """
    Converts lists of tokens into single whitespace-joined strings.

    Assumes each item in the input list is a dictionary with keys
    'input_tokens' and 'target_tokens', where the values are lists of strings.

    Args:
        raw_data_list (list[dict]): The list of raw data dictionaries.

    Returns:
        tuple[list[str], list[str]]: A tuple containing two lists:
                                     - input strings
                                     - target strings
    """
    input_strings = []
    target_strings = []
    if not raw_data_list:
        return input_strings, target_strings

    for item in raw_data_list:
        # Check if keys exist and values are lists, handle potential errors gracefully
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            input_strings.append(" ".join(map(str, input_toks)))  # Ensure tokens are strings
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item due to missing/invalid keys or non-list values: {item}")

    print(f"[INFO] Converted {len(input_strings)} items to strings.")
    return input_strings, target_strings


def encode_sequence_pairs(tokenizer, input_strings, target_strings, max_len):
    """
    Tokenizes input and target string pairs using the provided tokenizer.

    Pads and truncates sequences to the specified maximum length.
    Sets up the target sequences correctly for decoder training.

    Args:
        tokenizer: An initialized Hugging Face tokenizer instance.
        input_strings (list[str]): List of input sequences.
        target_strings (list[str]): List of target sequences.
        max_len (int): The maximum sequence length for padding/truncation.

    Returns:
        dict: A dictionary containing tokenized 'input_ids', 'attention_mask',
              and 'labels' (tokenized target IDs).
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Tokenize the input sequences (for the encoder)
    encoder_inputs = tokenizer(
        input_strings,
        padding='max_length',   # Pad to max_len
        truncation=True,        # Truncate sequences longer than max_len
        max_length=max_len,
        return_tensors="pt"     # Return PyTorch tensors (though we convert later in Dataset)
    )
    encoder_inputs.pop("token_type_ids", None) # Not needed for BERT encoder in this context

    # Tokenize the target sequences (for the decoder)
    # Use tokenizer in "target mode" to handle special tokens appropriately for decoder inputs
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_strings,
            padding='max_length',
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )

    # The 'labels' for the model are the input_ids of the *target* sequence.
    # The model will typically shift these internally to create decoder_input_ids.
    encoder_inputs['labels'] = decoder_labels['input_ids']

    # Detach tensors and convert back to lists for easier handling in Dataset __getitem__
    # Note: If memory allows, keeping tensors might be slightly faster, but lists are more flexible.
    encoded_data = {k: v.tolist() for k, v in encoder_inputs.items()}

    print(f"[INFO] Encoding complete.")
    return encoded_data


class SequenceDataset(Dataset):
    """
    Custom PyTorch Dataset to wrap tokenized sequence data.
    """
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): A dictionary where keys are 'input_ids',
                              'attention_mask', 'labels', etc., and values
                              are lists of tokenized sequences.
        """
        if not isinstance(encodings, dict) or not encodings:
             raise ValueError("Encodings must be a non-empty dictionary.")
        # Find a key to determine the length (e.g., 'input_ids')
        self.length = 0
        if 'input_ids' in encodings and isinstance(encodings['input_ids'], list):
             self.length = len(encodings['input_ids'])
        elif encodings:
             # Fallback: Use the length of the first list found
             first_key = next(iter(encodings))
             if isinstance(encodings[first_key], list):
                 self.length = len(encodings[first_key])

        if self.length == 0:
             raise ValueError("Could not determine dataset length from encodings.")

        self.encodings = encodings
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """
        Retrieves the tokenized data for a single sample and converts it to tensors.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            dict: A dictionary where keys are feature names (e.g., 'input_ids')
                  and values are the corresponding PyTorch tensors for the sample.
        """
        # Fetch the data for the given index and convert each list item to a tensor
        try:
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            return item
        except IndexError:
            print(f"[Error] Index {idx} out of bounds for dataset of length {self.length}.")
            raise
        except Exception as e:
            print(f"[Error] Failed to retrieve or convert item at index {idx}: {e}")
            raise


def compute_sequence_accuracy(predictions, labels, pad_token_id):
    """
    Calculates exact sequence match accuracy.

    Compares predicted sequences to label sequences, ignoring padding tokens.
    A prediction is correct only if the entire non-padded sequence matches the label.

    Args:
        predictions (np.ndarray): Logits output by the model (shape: batch_size, seq_len, vocab_size).
        labels (np.ndarray): Ground truth label IDs (shape: batch_size, seq_len).
        pad_token_id (int): The ID of the padding token to ignore.

    Returns:
        float: The fraction of sequences that match exactly.
    """
    # Get the predicted token IDs by taking the argmax along the vocabulary dimension
    pred_ids = np.argmax(predictions, axis=-1)

    # Create a mask to ignore padding tokens in the labels
    non_padding_mask = (labels != pad_token_id)

    # Check equality element-wise for non-padded positions
    correct_tokens = (pred_ids == labels) & non_padding_mask

    # Check if all non-padded tokens in a sequence are correct
    # We sum the mask and the correct tokens for each sequence.
    # If the sums are equal, it means every non-padded token was predicted correctly.
    correct_sequences = (np.sum(correct_tokens, axis=1) == np.sum(non_padding_mask, axis=1))

    # Calculate the mean accuracy over the batch
    accuracy = np.mean(correct_sequences)
    return float(accuracy)

def compute_metrics(eval_pred):
    """
    Computes metrics for evaluation. Called by the Trainer.

    Args:
        eval_pred (EvalPrediction): A tuple containing predictions (logits) and labels.

    Returns:
        dict: A dictionary mapping metric names to their values.
    """
    predictions, labels = eval_pred
    # predictions are logits, labels are the actual token IDs

    # Ensure pad_token_id is accessible, assuming tokenizer is globally defined or passed
    # If not global, it needs to be passed or accessed differently.
    # For simplicity here, we assume 'tokenizer' is accessible.
    pad_token_id = tokenizer.pad_token_id

    seq_acc = compute_sequence_accuracy(predictions, labels, pad_token_id)

    return {"sequence_accuracy": seq_acc}

# --- Main Execution Logic ---

def main():
    """Orchestrates the entire training and evaluation pipeline."""
    print("[INFO] Starting Sequence-to-Sequence Model Training and Evaluation...")

    # 1. Load Data
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception:
         print("[Error] Failed to load one or more data files. Exiting.")
         sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[Error] One or more datasets are empty after loading. Exiting.")
        sys.exit(1)

    # 2. Prepare Data (Tokens to Strings)
    train_inputs, train_targets = convert_tokens_to_strings(train_raw_data)
    val_inputs,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_inputs,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_inputs or not val_inputs or not test_inputs:
         print("[Error] Data conversion resulted in empty lists. Check input data format. Exiting.")
         sys.exit(1)


    # 3. Initialize Tokenizer
    print(f"[INFO] Initializing tokenizer: {ENCODER_MODEL_ID}")
    # Making tokenizer global for access in compute_metrics (simplification)
    # A better approach in larger projects might involve passing it explicitly or using a class.
    global tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL_ID)
        # Add special tokens if they are not already present? BERT usually has them.
        # tokenizer.add_special_tokens({'bos_token':'<s>', 'eos_token':'</s>'}) # Example if needed
    except Exception as e:
        print(f"[Error] Failed to initialize tokenizer {ENCODER_MODEL_ID}: {e}")
        sys.exit(1)

    # 4. Tokenize Data
    try:
        train_encodings = encode_sequence_pairs(tokenizer, train_inputs, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequence_pairs(tokenizer, val_inputs,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequence_pairs(tokenizer, test_inputs,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[Error] Failed during data tokenization: {e}")
        sys.exit(1)


    # 5. Create Datasets
    try:
        train_dataset = SequenceDataset(train_encodings)
        val_dataset   = SequenceDataset(val_encodings)
        test_dataset  = SequenceDataset(test_encodings)
    except ValueError as e:
         print(f"[Error] Failed to create Dataset objects: {e}")
         sys.exit(1)

    # 6. Initialize Model
    print(f"[INFO] Initializing Encoder-Decoder model ({ENCODER_MODEL_ID} -> {DECODER_MODEL_ID})")
    try:
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            ENCODER_MODEL_ID,
            DECODER_MODEL_ID,
        )

        # --- Configure Model ---
        # Tie weights if encoder and decoder are compatible (e.g., same architecture/vocab)
        # This reduces parameters and can improve performance. Check documentation if unsure.
        # model.tie_weights() # Uncomment if appropriate for the chosen models

        # Set special token IDs for generation (though not used if predict_with_generate=False)
        # These ensure the decoder starts and ends sequences correctly if generation is enabled.
        model.config.decoder_start_token_id = tokenizer.cls_token_id # Often [CLS] for BERT-like models
        model.config.eos_token_id           = tokenizer.sep_token_id # Often [SEP]
        model.config.pad_token_id           = tokenizer.pad_token_id # Usually 0

        # Ensure critical parameters are aligned between the main config and encoder/decoder configs
        model.config.vocab_size             = model.config.encoder.vocab_size # Use encoder's vocab size

        # Set parameters relevant for training and generation control
        model.config.max_length             = MAX_SEQ_LENGTH # Max length for generation if enabled
        model.config.early_stopping         = True           # Example: enable early stopping for generation
        model.config.no_repeat_ngram_size   = 3              # Example: prevent trigram repetition during generation
        # model.config.length_penalty         = 2.0            # Example: encourage longer sequences
        # model.config.num_beams              = 4              # Example: use beam search for generation

        # Optional: Set dropout rates explicitly if defaults are not desired
        # model.config.dropout                = 0.1
        # model.config.attention_dropout      = 0.1

        print("[INFO] Model configuration complete.")
        # print(model.config) # Uncomment to inspect the full configuration

    except Exception as e:
        print(f"[Error] Failed to initialize or configure the model: {e}")
        sys.exit(1)

    # 7. Initialize Data Collator
    # Dynamically pads sequences within each batch to the maximum length in that batch.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=tokenizer.pad_token_id # Use pad token ID for labels
    )
    print("[INFO] Data collator initialized.")

    # 8. Define Training Arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),              # Directory for checkpoints and logs
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        logging_dir=str(OUTPUT_DIR / 'logs'),    # Directory for TensorBoard logs
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,                      # Keep only the last 2 checkpoints
        load_best_model_at_end=True,             # Load the best checkpoint found during training at the end
        metric_for_best_model="sequence_accuracy",# Use sequence accuracy to determine the best model
        greater_is_better=True,                  # Higher accuracy is better
        predict_with_generate=PREDICT_WITH_GENERATE, # Use logits for faster evaluation if True metrics aren't needed
        # report_to="tensorboard",               # Example: enable reporting to TensorBoard
        # fp16=torch.cuda.is_available(),        # Enable mixed-precision training if CUDA is available
    )
    print("[INFO] Training arguments defined.")

    # 9. Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics, # Pass the custom metric function
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # 10. Train the Model
    print("[INFO] Starting model training...")
    try:
        train_result = trainer.train()
        # Log training metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state() # Save optimizer, scheduler state etc.
        # Save the final best model explicitly (though save_strategy='epoch' does it too)
        trainer.save_model(str(OUTPUT_DIR / "best_model"))
        print("[INFO] Training finished successfully.")
    except Exception as e:
        print(f"[Error] Training failed: {e}")
        sys.exit(1)


    # 11. Evaluate on Test Set
    print("[INFO] Evaluating model on the test set...")
    try:
        test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        # Print the key metric
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' not found in evaluation results.")
             print("Test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}")
        sys.exit(1)

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    main()

[INFO] Starting Sequence-to-Sequence Model Training and Evaluation...
[INFO] Loading data from: qed_expressions_train.jsonl
[INFO] Successfully loaded 12441 records.
[INFO] Loading data from: qed_expressions_val.jsonl
[INFO] Successfully loaded 1555 records.
[INFO] Loading data from: qed_expressions_test.jsonl
[INFO] Successfully loaded 1556 records.
[INFO] Converted 12441 items to strings.
[INFO] Converted 1555 items to strings.
[INFO] Converted 1556 items to strings.
[INFO] Initializing tokenizer: bert-base-uncased
[INFO] Encoding sequence pairs with max_length=128...




[INFO] Encoding complete.
[INFO] Encoding sequence pairs with max_length=128...
[INFO] Encoding complete.
[INFO] Encoding sequence pairs with max_length=128...
[INFO] Encoding complete.
[INFO] Created Dataset with 12441 examples.
[INFO] Created Dataset with 1555 examples.
[INFO] Created Dataset with 1556 examples.
[INFO] Initializing Encoder-Decoder model (bert-base-uncased -> bert-base-uncased)


Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.e

[INFO] Model configuration complete.
[INFO] Data collator initialized.


TypeError: Seq2SeqTrainingArguments.__init__() got an unexpected keyword argument 'evaluation_strategy'

# Specific Test 3: Train/Evaluate advanced model
Repeat task two including checking sequence accuracy but with a model that leverages some slightly more advanced techniques. The model you use should relate to the project you’re applying for.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Training and Evaluation Script for a T5 Sequence-to-Sequence Model
on the Preprocessed QED 2-to-2 Tree-Level Dataset.

This script utilizes several advanced Hugging Face Transformers features:
- T5 Model: Leverages the t5-base pre-trained model for conditional generation.
- Gradient Checkpointing: Reduces GPU memory usage during training, allowing
  for larger models or batch sizes, at the cost of a small computational overhead.
- Mixed Precision Training (FP16): Speeds up training and further reduces memory
  usage on compatible GPUs (NVIDIA Volta architecture or newer) by performing
  certain operations in half-precision floating-point format.
- Label Smoothing: A regularization technique that prevents the model from
  becoming overconfident in its predictions, potentially improving generalization.
- Beam Search Generation: Used during evaluation (`predict_with_generate=True`)
  to generate more fluent and potentially more accurate output sequences compared
  to greedy decoding.
- Exact Match Accuracy: Evaluates the model based on whether the generated
  sequence exactly matches the target sequence after decoding and stripping whitespace.

Workflow:
1. Load preprocessed data splits (train, validation, test) from JSONL files.
2. Convert token lists in the data back into whitespace-separated strings.
3. Initialize the T5 tokenizer and T5 model (`t5-base`).
4. Enable gradient checkpointing on the model.
5. Define an encoding function to tokenize input/target strings using the T5 tokenizer.
6. Create PyTorch Dataset objects for each data split.
7. Set up a Data Collator for dynamic padding within batches.
8. Define a custom metric function (`compute_metrics`) that uses generated
   predictions (due to `predict_with_generate=True`) and calculates exact
   sequence match accuracy after decoding.
9. Configure `Seq2SeqTrainingArguments`, enabling advanced features like
   FP16, gradient checkpointing, label smoothing, and generation parameters.
10. Initialize the `Seq2SeqTrainer`.
11. Train the model.
12. Evaluate the final model on the test set using the defined metric.
"""

import json
import sys
import numpy as np
import torch
import datetime # <--- ***** FIX: Import the standard datetime module *****
from pathlib import Path
from torch.utils.data import Dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import for clarity
)

# --- Configuration ---

# File Paths
# Assumes the script runs in a directory containing these files.
# Use absolute paths or command-line arguments if needed.
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Path to the training data JSONL file
VAL_FILE   = Path('qed_expressions_val.jsonl')   # Path to the validation data JSONL file
TEST_FILE  = Path('qed_expressions_test.jsonl')  # Path to the test data JSONL file
OUTPUT_DIR = Path('qed_t5_model_output') # Directory for checkpoints, logs, and final model

# Model Configuration
MODEL_ID = "t5-base" # Pre-trained T5 model identifier from Hugging Face Hub

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 128 # Maximum sequence length for input and target tokenization

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = 8 # Adjust based on GPU memory
PER_DEVICE_EVAL_BATCH_SIZE = 8  # Adjust based on GPU memory
LEARNING_RATE = 3e-4            # Typical learning rate for T5 fine-tuning
WEIGHT_DECAY = 0.01             # Weight decay for regularization
WARMUP_STEPS = 200              # Number of linear warmup steps for the learning rate scheduler
LOGGING_STEPS = 50              # Log metrics every N steps
EVALUATION_STRATEGY = "epoch"   # Evaluate performance at the end of each epoch
SAVE_STRATEGY = "epoch"         # Save a model checkpoint at the end of each epoch

# Advanced Training Feature Flags/Values
# Enable mixed precision if a CUDA GPU is available and requirements are met
USE_FP16 = torch.cuda.is_available()
USE_GRADIENT_CHECKPOINTING = True # Enable gradient checkpointing to save memory
LABEL_SMOOTHING_FACTOR = 0.1      # Apply label smoothing regularization (0.0 means no smoothing)

# Generation Configuration (used during evaluation with predict_with_generate=True)
GENERATION_NUM_BEAMS = 4          # Number of beams for beam search decoding

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads a JSON Lines (.jsonl) file, skipping blank/whitespace-only lines."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        return data
    except FileNotFoundError:
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens into single whitespace-joined strings."""
    input_strings = []
    target_strings = []
    if not raw_data_list:
        return input_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            # Ensure all tokens are strings before joining
            input_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(input_strings)} items to strings (skipped {skipped_count}).")
    return input_strings, target_strings


def encode_sequences(tokenizer, input_strings, target_strings, max_len):
    """
    Tokenizes input and target string pairs using the T5 tokenizer.

    Pads and truncates sequences to the specified maximum length.
    Prepares labels correctly for T5 training.

    Args:
        tokenizer: An initialized T5Tokenizer instance.
        input_strings (list[str]): List of input sequences (add prefix if needed for T5 task).
        target_strings (list[str]): List of target sequences.
        max_len (int): The maximum sequence length for padding/truncation.

    Returns:
        dict: A dictionary containing tokenized 'input_ids', 'attention_mask',
              and 'labels' (tokenized target IDs). T5 does not use token_type_ids.
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Note: T5 often benefits from a task-specific prefix, e.g., "translate English to German: "
    # Add such a prefix to input_strings here if applicable to your task.
    # Example: input_strings = [f"summarize: {s}" for s in input_strings]

    # Tokenize the input sequences (for the encoder)
    encoder_inputs = tokenizer(
        input_strings,
        max_length=max_len,
        padding='max_length',   # Pad to max_len
        truncation=True,        # Truncate sequences longer than max_len
        return_tensors=None     # Return lists, collator will handle tensor conversion
    )

    # Tokenize the target sequences (for the decoder labels)
    # Use tokenizer in target mode context is good practice, although less critical for T5 label encoding.
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_strings,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # The 'labels' for T5 are the input_ids of the target sequence.
    # Padding tokens in labels are typically replaced with -100 by the DataCollator
    # so they are ignored in the loss calculation.
    encoder_inputs['labels'] = decoder_labels['input_ids']

    print(f"[INFO] Encoding complete.")
    return encoder_inputs


class SequenceDataset(Dataset):
    """Custom PyTorch Dataset wrapper for tokenized sequence data."""
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): Dictionary from tokenizer {'input_ids': [...], 'attention_mask': [...], 'labels': [...]}.
        """
        if not isinstance(encodings, dict) or 'input_ids' not in encodings:
            raise ValueError("Encodings must be a dictionary containing at least 'input_ids'.")
        self.encodings = encodings
        self.length = len(encodings['input_ids'])
        if self.length == 0:
            raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """
        Retrieves the tokenized data for a single sample.
        The DataCollator will handle tensor conversion and padding alignment.
        """
        try:
            # Return a dictionary slice for the given index
            return {key: val[idx] for key, val in self.encodings.items()}
        except IndexError:
            print(f"[Error] Index {idx} out of bounds for dataset of length {self.length}.", file=sys.stderr)
            raise
        except Exception as e:
            print(f"[Error] Failed to retrieve item at index {idx}: {e}", file=sys.stderr)
            raise


# --- Main Execution Logic ---

def main():
    """Orchestrates the T5 training and evaluation pipeline."""
    print("[INFO] Starting T5 Sequence-to-Sequence Model Training and Evaluation...")
    # --- ***** FIX: Use standard datetime module ***** ---
    try:
        # Use timezone-aware UTC time (recommended)
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        # Fallback if timezone object isn't available (older Python?) or other issue
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current date/time (UTC): {current_time_str}")


    # --- 1. Load Data ---
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception:
        print(f"[Error] Failed to load one or more data files ({TRAIN_FILE}, {VAL_FILE}, {TEST_FILE}). Exiting.", file=sys.stderr)
        sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[Error] One or more datasets are empty after loading. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 2. Prepare Data (Tokens to Strings) ---
    train_inputs, train_targets = convert_tokens_to_strings(train_raw_data)
    val_inputs,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_inputs,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_inputs or not val_inputs or not test_inputs:
        print("[Error] Data conversion resulted in empty lists. Check input data format. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 3. Initialize Tokenizer and Model ---
    print(f"[INFO] Initializing Tokenizer and Model: {MODEL_ID}")
    try:
        # Using legacy=False is recommended for T5Tokenizer for T5v1.1+ behavior, but True is default for t5-base/large
        tokenizer = T5Tokenizer.from_pretrained(MODEL_ID) # legacy=False might be better if using newer T5 variants
        model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
    except Exception as e:
        print(f"[Error] Failed to initialize tokenizer or model '{MODEL_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # Enable Gradient Checkpointing *on the model* if specified
    # This needs to be done before training starts.
    if USE_GRADIENT_CHECKPOINTING:
        try:
            model.gradient_checkpointing_enable()
            print("[INFO] Gradient Checkpointing enabled on the model.")
        except Exception as e:
            print(f"[Warning] Failed to enable gradient checkpointing on model: {e}. Training will proceed without it.")
            USE_GRADIENT_CHECKPOINTING = False # Ensure training arg matches reality

    # --- 4. Tokenize Data ---
    try:
        # Making tokenizer global for access in compute_metrics (simplification)
        # Consider passing via functools.partial if this becomes complex.
        global tokenizer_for_metrics
        tokenizer_for_metrics = tokenizer

        train_encodings = encode_sequences(tokenizer, train_inputs, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_inputs,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_inputs,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[Error] Failed during data tokenization: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 5. Create Datasets ---
    try:
        train_dataset = SequenceDataset(train_encodings)
        val_dataset   = SequenceDataset(val_encodings)
        test_dataset  = SequenceDataset(test_encodings)
    except ValueError as e:
        print(f"[Error] Failed to create Dataset objects: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 6. Initialize Data Collator ---
    # Handles dynamic padding and prepares decoder input IDs and labels correctly for T5.
    # It automatically creates `decoder_input_ids` by shifting `labels` and replaces padded label tokens with -100.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Standard practice for ignoring padding in loss
        pad_to_multiple_of=8 if USE_FP16 else None # Optimize padding for FP16 tensor cores
    )
    print("[INFO] Data collator initialized.")

    # --- 7. Define Metrics Computation ---
    # This function will be called by the Trainer during evaluation.
    # It receives generated token IDs because predict_with_generate=True.
    def compute_metrics_fn(eval_pred):
        """Calculates exact match accuracy between decoded predictions and labels."""
        # eval_pred is a tuple (predictions, labels)
        # For seq2seq with predict_with_generate=True, predictions are generated token IDs (np.ndarray)
        # Labels are also token IDs (np.ndarray), potentially padded with -100.
        predictions, labels = eval_pred

        # Ensure predictions and labels are numpy arrays
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 in labels with the pad_token_id for decoding purposes.
        labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        # Decode predicted IDs and label IDs to strings.
        # skip_special_tokens=True removes tokens like <pad>, </s> etc.
        try:
            decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
        except Exception as e:
             print(f"[Error] Failed to decode predictions/labels in compute_metrics: {e}", file=sys.stderr)
             # Return a default low score or re-raise depending on desired behavior
             return {"sequence_accuracy": 0.0}


        # Post-process: Strip leading/trailing whitespace for robust comparison
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        # Calculate exact matches
        if len(decoded_preds) != len(decoded_labels):
             print(f"[Warning] Mismatch in number of predictions ({len(decoded_preds)}) and labels ({len(decoded_labels)}) after decoding.", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

        matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
        accuracy = np.mean(matches) if matches else 0.0 # Handle empty case

        return {"sequence_accuracy": float(accuracy)}


    # --- 8. Define Training Arguments ---
    # Enables advanced features configured earlier.
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Batch sizes, epochs, learning rate...
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging, evaluation, saving strategies...
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,             # Keep only the last 2 checkpoints
        load_best_model_at_end=True,    # Load best model checkpoint at the end of training
        metric_for_best_model="sequence_accuracy", # Metric to determine the "best" model
        greater_is_better=True,         # Higher accuracy is better
        # Generation settings (used because predict_with_generate=True)
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH, # Should match tokenizer max_length for consistency
        generation_num_beams=GENERATION_NUM_BEAMS,
        # Advanced features
        fp16=USE_FP16, # Enable mixed precision training
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR, # Enable label smoothing
        # Enable gradient checkpointing within the Trainer's control flow
        gradient_checkpointing=USE_GRADIENT_CHECKPOINTING,
        # report_to="tensorboard",      # Example: uncomment to enable TensorBoard logging
    )
    print("[INFO] Training arguments defined.")
    # Print status of key features
    print(f"[INFO] Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Gradient Checkpointing: {'Enabled' if USE_GRADIENT_CHECKPOINTING else 'Disabled'}")


    # --- 9. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,           # Pass tokenizer for saving/loading and decoding if needed
        compute_metrics=compute_metrics_fn, # Pass the metric calculation function
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 10. Train the Model ---
    print(f"[INFO] Starting model training for {NUM_TRAIN_EPOCHS} epochs...")
    try:
        train_result = trainer.train()
        # Log and save training metrics and state
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        # Save the final best model checkpoint explicitly (already done by load_best_model_at_end + save_strategy)
        # trainer.save_model(str(OUTPUT_DIR / "best_model")) # Can keep if explicit save desired
        print("[INFO] Training finished successfully.")
        print(f"[INFO] Best model saved to: {trainer.state.best_model_checkpoint}")
    except Exception as e:
        print(f"[Error] Training failed: {e}", file=sys.stderr)
        # Consider cleanup or further diagnostics here if needed
        sys.exit(1)

    # --- 11. Evaluate on Test Set ---
    print("[INFO] Evaluating model on the test set...")
    try:
        test_results = trainer.evaluate(
            eval_dataset=test_dataset,
            metric_key_prefix="test" # Prefix metrics with 'test_' (e.g., 'test_sequence_accuracy')
        )
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        # Print the key metric clearly
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' not found in evaluation results.")
             print("Full test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}", file=sys.stderr)
        sys.exit(1)

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Set random seeds for reproducibility
    # SEED = 42
    # torch.manual_seed(SEED)
    # np.random.seed(SEED)
    # # import random # If using random module
    # # random.seed(SEED)
    # if torch.cuda.is_available():
    #     torch.cuda.manual_seed_all(SEED)
    #     # May need deterministic algorithms for full reproducibility, potentially slower
    #     # torch.backends.cudnn.deterministic = True
    #     # torch.backends.cudnn.benchmark = False

    main()

[INFO] Starting T5 Sequence-to-Sequence Model Training and Evaluation...


AttributeError: module 'torch' has no attribute 'datetime'

# 3.1: Next-Generation Transformer Models for Symbolic Calculations of Squared Amplitudes in HEP
Model: Transformer model with a contemporary innovation added such as KAN layers, reinforcement learning, genetic algorithms, specialized long-sequence attention, etc. which improves the performance compared to a basic transformer.


In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Fine-tuning and Evaluation Script for BigBird-Pegasus
on the Preprocessed QED 2-to-2 Tree-Level Dataset.

This script leverages the BigBird-Pegasus model, specifically designed for handling
long sequences efficiently using block-sparse attention mechanisms. It incorporates
several advanced training techniques:

- Gradient Checkpointing: Reduces GPU memory footprint during training, enabling
  fine-tuning of large models like BigBird-Pegasus on systems with limited VRAM,
  at the cost of slightly increased computation time.
- Mixed Precision Training (FP16): Accelerates training and further decreases
  memory usage on compatible hardware (NVIDIA Volta GPUs or newer) by utilizing
  half-precision floating-point numbers for certain computations.
- Label Smoothing: A regularization technique applied to the loss function to
  prevent the model from becoming overconfident, potentially enhancing robustness
  and generalization.
- Beam Search Generation: Employed during evaluation (`predict_with_generate=True`)
  to produce output sequences by exploring multiple hypotheses, often leading to
  more coherent results than simple greedy decoding.
- Exact Match Accuracy: The primary evaluation metric, measuring the percentage
  of generated sequences that exactly match the ground truth target sequences
  after decoding and normalization (stripping whitespace).

Workflow Overview:
1.  Load pre-split training, validation, and test datasets from JSONL files.
2.  Reconstruct source (input) and target (label) strings from token lists.
3.  Initialize the BigBird-Pegasus tokenizer and model.
4.  Enable gradient checkpointing on the loaded model instance.
5.  Define a function to tokenize source/target string pairs, handling truncation
    and padding according to the specified maximum sequence length.
6.  Wrap the tokenized data into PyTorch Dataset objects.
7.  Instantiate a Data Collator suitable for sequence-to-sequence tasks, which
    handles dynamic batch padding and prepares decoder inputs/labels.
8.  Define a function (`compute_metrics`) for calculating sequence accuracy based
    on comparing decoded generated sequences with decoded labels.
9.  Configure `Seq2SeqTrainingArguments`, setting hyperparameters and enabling
    the advanced features (FP16, gradient checkpointing, label smoothing, etc.).
10. Initialize the `Seq2SeqTrainer`.
11. Execute the training loop.
12. Perform final evaluation on the held-out test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
from pathlib import Path
from torch.utils.data import Dataset
from transformers import (
    BigBirdPegasusForConditionalGeneration,
    BigBirdPegasusTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import for clarity
)

# --- Configuration ---

# File Paths (adjust if your preprocessed files have different names/locations)
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Training data
VAL_FILE   = Path('qed_expressions_val.jsonl')   # Validation data
TEST_FILE  = Path('qed_expressions_test.jsonl')  # Test data
OUTPUT_DIR = Path('qed_bigbird_pegasus_output')  # Checkpoints and logs directory

# Model Configuration
MODEL_ID = "google/bigbird-pegasus-large-arxiv" # BigBird specialized for scientific text

# Tokenizer and Data Processing Configuration
# BigBird supports longer sequences compared to standard transformers
MAX_SEQ_LENGTH = 1024 # Max sequence length for inputs and targets

# Training Hyperparameters
# Note: BigBird is large; adjust batch sizes based on available GPU memory.
NUM_TRAIN_EPOCHS = 3 # Adjust as needed
PER_DEVICE_TRAIN_BATCH_SIZE = 2 # Likely needs to be small (e.g., 1, 2, 4)
PER_DEVICE_EVAL_BATCH_SIZE = 4  # Can often be slightly larger than train batch size
LEARNING_RATE = 2e-5            # Common starting point for fine-tuning large models
WEIGHT_DECAY = 0.01             # Regularization parameter
WARMUP_STEPS = 250              # Linear learning rate warmup
LOGGING_STEPS = 50              # Frequency of logging metrics
EVALUATION_STRATEGY = "epoch"   # Evaluate every epoch
SAVE_STRATEGY = "epoch"         # Save checkpoint every epoch

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available() # Enable mixed precision if CUDA available
USE_GRADIENT_CHECKPOINTING = True    # Enable gradient checkpointing for memory saving
LABEL_SMOOTHING_FACTOR = 0.1         # Apply label smoothing (0.0 disables)

# Generation Configuration (for evaluation with predict_with_generate=True)
GENERATION_NUM_BEAMS = 4             # Number of beams for beam search

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data:
             print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings = []
    target_strings = []
    if not raw_data_list:
        return source_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings:
        print("[Warning] No items successfully converted to strings.")
    return source_strings, target_strings


def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """
    Tokenizes source and target text sequences using the provided tokenizer.

    Handles padding to max_len and truncation for sequences exceeding max_len.
    Prepares the 'labels' field required for sequence-to-sequence model training.

    Args:
        tokenizer: Initialized Hugging Face tokenizer instance.
        source_texts (list[str]): List of source sequences for the encoder.
        target_texts (list[str]): List of target sequences for the decoder labels.
        max_len (int): Maximum sequence length.

    Returns:
        dict: Dictionary containing 'input_ids', 'attention_mask', and 'labels'.
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Tokenize source texts
    encoder_inputs = tokenizer(
        source_texts,
        max_length=max_len,
        padding='max_length',   # Pad shorter sequences to max_len
        truncation=True,        # Truncate longer sequences to max_len
        return_tensors=None     # Return lists; collator manages tensor conversion
    )

    # Tokenize target texts to create labels
    # Using the context manager ensures correct handling if tokenizer has specific target modes
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_texts,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # Assign the tokenized target IDs as 'labels' for the model
    encoder_inputs['labels'] = decoder_labels['input_ids']

    print(f"[INFO] Encoding complete.")
    # Validate encoding results basic check
    if not encoder_inputs['input_ids'] or not encoder_inputs['labels']:
        print("[Warning] Encoding resulted in empty lists for input_ids or labels.")
    elif len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Mismatch between number of encoded inputs and labels.")

    return encoder_inputs


class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pairs."""
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): A dictionary from the tokenizer, containing lists
                              for 'input_ids', 'attention_mask', 'labels'.
        """
        if not isinstance(encodings, dict) or 'input_ids' not in encodings:
            raise ValueError("Encodings must be a dictionary containing at least 'input_ids'.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            # Basic validation
            for key in encodings:
                if not isinstance(encodings[key], list) or len(encodings[key]) != self.length:
                    raise ValueError(f"Encoding key '{key}' is not a list or has inconsistent length.")
        except Exception as e:
            raise ValueError(f"Failed to validate encodings: {e}")

        if self.length == 0:
            raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return self.length

    def __getitem__(self, idx):
        """
        Retrieves the tokenized data for a single sample index.
        Tensor conversion is typically handled by the DataCollator.
        """
        if not 0 <= idx < self.length:
             raise IndexError(f"Index {idx} out of bounds for dataset of length {self.length}.")
        try:
            # Return a dictionary containing the data for the specified index
            return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e:
            print(f"[Error] Failed to retrieve item at index {idx}: {e}", file=sys.stderr)
            # Depending on severity, could return None or re-raise
            raise


# --- Main Execution Logic ---

def main():
    """Orchestrates the BigBird-Pegasus fine-tuning and evaluation pipeline."""
    print("[INFO] Starting BigBird-Pegasus Fine-tuning Script...")
    try:
        # Use timezone-aware UTC time
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)') # Fallback
    print(f"[INFO] Current date/time (UTC): {current_time_str}")

    # --- 1. Load Data ---
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception as e: # Catch errors from load_jsonl (FileNotFound, JSONDecode, etc.)
        print(f"[FATAL] Critical error during data loading: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[FATAL] One or more required datasets (train, val, test) are empty or failed to load. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 2. Prepare Data (Tokens to Strings) ---
    train_sources, train_targets = convert_tokens_to_strings(train_raw_data)
    val_sources,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_sources,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_sources or not val_sources or not test_sources:
        print("[FATAL] Data conversion resulted in one or more empty lists. Check input data format. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 3. Initialize Tokenizer and Model ---
    print(f"[INFO] Initializing Tokenizer and Model: {MODEL_ID}")
    try:
        tokenizer = BigBirdPegasusTokenizer.from_pretrained(MODEL_ID)
        model = BigBirdPegasusForConditionalGeneration.from_pretrained(MODEL_ID)
    except Exception as e:
        print(f"[FATAL] Failed to initialize tokenizer or model '{MODEL_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # Enable Gradient Checkpointing on the model if configured
    if USE_GRADIENT_CHECKPOINTING:
        try:
            model.gradient_checkpointing_enable()
            print("[INFO] Gradient Checkpointing enabled on the model.")
        except Exception as e:
            print(f"[Warning] Failed to enable gradient checkpointing on model: {e}. Training argument will be disabled.", file=sys.stderr)
            # Ensure the training argument reflects this failure
            global USE_GRADIENT_CHECKPOINTING_EFFECTIVE
            USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False
    else:
         USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False

    # --- 4. Tokenize Data ---
    try:
        # Define tokenizer globally for metric computation simplicity
        global tokenizer_for_metrics
        tokenizer_for_metrics = tokenizer

        train_encodings = encode_sequences(tokenizer, train_sources, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_sources,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_sources,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[FATAL] Failed during data tokenization: {e}", file=sys.stderr)
        sys.exit(1)

     # Validate that encodings are not empty after tokenization
    if not train_encodings or not train_encodings.get('input_ids') or \
       not val_encodings or not val_encodings.get('input_ids') or \
       not test_encodings or not test_encodings.get('input_ids'):
         print("[FATAL] Tokenization resulted in empty encodings for one or more splits. Exiting.", file=sys.stderr)
         sys.exit(1)

    # --- 5. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e:
        print(f"[FATAL] Failed to create Dataset objects: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 6. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Ignore padding tokens in loss calculation
        pad_to_multiple_of=8 if USE_FP16 else None # Optimize for FP16 if used
    )
    print("[INFO] Data collator initialized.")

    # --- 7. Define Metrics Computation ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy after decoding."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 (ignore index) with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
        except Exception as e:
             print(f"[Error] Decoding failed in compute_metrics: {e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

        # Post-processing: strip whitespace for comparison
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        if len(decoded_preds) != len(decoded_labels):
            print(f"[Warning] Mismatch in prediction/label count after decoding: {len(decoded_preds)} vs {len(decoded_labels)}", file=sys.stderr)
            return {"sequence_accuracy": 0.0}

        matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
        accuracy = np.mean(matches) if matches else 0.0
        return {"sequence_accuracy": float(accuracy)}

    # --- 8. Define Training Arguments ---
    # Update effective GC flag if model enabling failed
    effective_gc = USE_GRADIENT_CHECKPOINTING if 'USE_GRADIENT_CHECKPOINTING_EFFECTIVE' not in globals() else USE_GRADIENT_CHECKPOINTING_EFFECTIVE

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Core training parameters
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging, saving, evaluation strategies
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,            # Keep only last 2 checkpoints
        load_best_model_at_end=True,   # Reload best model found during training
        metric_for_best_model="sequence_accuracy", # Metric to define "best"
        greater_is_better=True,        # Higher accuracy is better
        # Generation configuration
        predict_with_generate=True,    # Use model.generate() for evaluation
        generation_max_length=MAX_SEQ_LENGTH, # Max length during generation
        generation_num_beams=GENERATION_NUM_BEAMS, # Beam search width
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        gradient_checkpointing=effective_gc, # Use effective flag based on model setup
        # report_to="tensorboard",     # Optional: Enable TensorBoard logging
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing: {'Enabled' if effective_gc else 'Disabled'}")

    # --- 9. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 10. Train the Model ---
    print(f"[INFO] Starting model training for {NUM_TRAIN_EPOCHS} epochs...")
    try:
        train_result = trainer.train()
        # Log and save final training state and metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        # Best model is loaded automatically if load_best_model_at_end=True
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint:
             print(f"[INFO] Best model checkpoint saved at: {trainer.state.best_model_checkpoint}")
        else:
             print("[Warning] No best model checkpoint recorded by trainer state.")

    except Exception as e:
        print(f"[FATAL] Training loop encountered an error: {e}", file=sys.stderr)
        # Potentially log more details about the state here
        sys.exit(1)

    # --- 11. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(
            eval_dataset=test_dataset,
            metric_key_prefix="test" # Prefix test metrics (e.g., test_loss, test_sequence_accuracy)
        )
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        # Report final test accuracy clearly
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' metric not found in test results.")
             print("Full test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}", file=sys.stderr)
        # The model is trained, but test evaluation failed.
        sys.exit(1) # Or choose to exit with a different code

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Add argument parsing here (e.g., using argparse) to override
    # configuration variables like paths, batch sizes, epochs, etc.

    # Optional: Set random seeds for reproducibility across libraries
    # SEED = 42
    # torch.manual_seed(SEED)
    # np.random.seed(SEED)
    # import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    # Consider also torch.backends.cudnn.deterministic = True / benchmark = False for stricter reproducibility

    main()

ImportError: cannot import name 'BigBirdPegasusTokenizer' from 'transformers' (/home/nikitas/anaconda3/envs/torch_cpu_env/lib/python3.10/site-packages/transformers/__init__.py)

# 3.2: State-space Models for Squared Amplitude Calculation in High-Energy Physics
Model: State-space model such as mamba or other model.

In [7]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Training and Evaluation Script for a Custom Sequence-to-Sequence
Model using State Space Model (SSM) layers (based on Mamba architecture concepts)
on the Preprocessed QED 2-to-2 Tree-Level Dataset.

This script implements a manual PyTorch training loop, including:
- Data loading from preprocessed JSONL files.
- A custom Dataset class with efficient pre-tokenization.
- Definition of an SSM-based Encoder followed by a Linear Decoder head.
- A standard training loop with optimizer steps and loss calculation.
- Validation and testing loops calculating exact sequence match accuracy.
- Device management (CPU/GPU).

Note: This implementation uses a hypothetical `StateSpaceLayer` from a library
      named 'mamba'. Ensure such a library/layer compatible with PyTorch exists
      and is installed in your environment. The 'decoder' here is a simple
      linear projection, predicting the target sequence in parallel based on
      encoder outputs, not an auto-regressive generative decoder.
"""

import json
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm # For progress bars
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# --- Hypothetical SSM Layer Import ---
# Ensure 'mamba' library and 'StateSpaceLayer' are correctly installed/defined
try:
    # Attempt to import the specific layer needed
    from mamba_ssm import Mamba # Example using official 'mamba_ssm' package
    # If using a different library, adjust the import accordingly.
    # For this example, let's assume Mamba serves as the StateSpaceLayer
    StateSpaceLayer = Mamba # Alias for clarity in the model definition below
    print("[INFO] Using Mamba layer from 'mamba_ssm' as StateSpaceLayer.")
except ImportError:
    print("[ERROR] 'mamba_ssm' library not found. Please install it (`pip install mamba_ssm causal-conv1d>=1.1.0`) or replace StateSpaceLayer with your implementation.", file=sys.stderr)
    # Define a dummy layer to allow script structure analysis if library missing
    class StateSpaceLayer(nn.Module):
        def __init__(self, d_model, *args, **kwargs):
             super().__init__()
             print("[Warning] Using Dummy StateSpaceLayer as 'mamba_ssm' not found.")
             self.layer = nn.Linear(d_model, d_model) # Simple placeholder
        def forward(self, x): return self.layer(x)
    # sys.exit(1) # Optionally exit if the real layer is crucial

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl')
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('qed_ssm_model_output')
CHECKPOINT_NAME = "ssm_seq2seq_best.pt"

# Tokenizer Configuration
TOKENIZER_ID = "bert-base-uncased"
MAX_SEQ_LENGTH = 256 # Max sequence length for tokenization

# Model Hyperparameters
D_MODEL = 512       # Core dimensionality of the model embeddings and layers
N_LAYERS = 6        # Number of stacked SSM layers in the encoder
# SSM_KERNEL_SIZE = 64 # Parameter specific to hypothetical 'StateSpaceLayer' - adjust based on actual layer args
# Mamba specific args (replace StateSpaceLayer args if using the library directly)
MAMBA_D_STATE = 16  # Typical value for Mamba state dimension
MAMBA_D_CONV = 4    # Typical value for Mamba conv dimension
MAMBA_EXPAND = 2    # Typical value for Mamba expansion factor

# Training Hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 16     # Adjust based on GPU memory
LEARNING_RATE = 3e-4
OPTIMIZER_EPS = 1e-8
WEIGHT_DECAY = 0.01
LR_SCHEDULER_TYPE = "linear" # Type of learning rate scheduler
WARMUP_RATIO = 0.1           # Percentage of training steps for warmup
LOG_INTERVAL = 50            # Log training loss every N steps

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    # (Same implementation as previous script)
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data:
             print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    # (Same implementation as previous script)
    source_strings = []
    target_strings = []
    if not raw_data_list:
        return source_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings:
        print("[Warning] No items successfully converted to strings.")
    return source_strings, target_strings


class PretokenizedSequenceDataset(Dataset):
    """
    PyTorch Dataset storing pre-tokenized sequences for efficiency.

    Assumes tokenization is done once during initialization.
    """
    def __init__(self, source_strings, target_strings, tokenizer, max_len):
        """
        Initializes the dataset by tokenizing all source and target strings.

        Args:
            source_strings (list[str]): List of source sequences.
            target_strings (list[str]): List of target sequences.
            tokenizer: Initialized Hugging Face tokenizer instance.
            max_len (int): Maximum sequence length for tokenization.
        """
        print(f"[INFO] Pre-tokenizing dataset with max_length={max_len}...")
        if len(source_strings) != len(target_strings):
            raise ValueError("Source and target string lists must have the same length.")

        # Tokenize source sequences
        source_encodings = tokenizer(
            source_strings,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None # Get lists of IDs first
        )

        # Tokenize target sequences to get labels
        with tokenizer.as_target_tokenizer():
            target_encodings = tokenizer(
                target_strings,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_tensors=None
            )

        self.input_ids = source_encodings['input_ids']
        self.attention_mask = source_encodings['attention_mask']
        self.labels = target_encodings['input_ids']

        self.length = len(self.input_ids)
        if self.length == 0:
            raise ValueError("Tokenization resulted in an empty dataset.")
        print(f"[INFO] Pre-tokenization complete. Dataset size: {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Retrieves pre-tokenized data for a given index."""
        if not 0 <= idx < self.length:
            raise IndexError(f"Index {idx} out of bounds for dataset of length {self.length}.")
        # Return data as a dictionary. DataLoader will handle batching and tensor conversion.
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


class SSMEncoderDecoder(nn.Module):
    """
    A sequence-to-sequence model using an SSM-based encoder and a linear decoder head.

    Note: The 'decoder' predicts the entire target sequence in parallel based on the
          final hidden states of the encoder, not auto-regressively.
    """
    def __init__(self, vocab_size, d_model=512, n_layers=4, max_len=256, pad_token_id=0, **ssm_kwargs):
        """
        Initializes the SSM Encoder-Decoder model.

        Args:
            vocab_size (int): The size of the vocabulary.
            d_model (int): The dimensionality of embeddings and hidden states.
            n_layers (int): The number of SSM layers in the encoder stack.
            max_len (int): Maximum sequence length for positional embeddings.
            pad_token_id (int): The ID of the padding token for loss calculation.
            **ssm_kwargs: Additional keyword arguments passed to the StateSpaceLayer
                          (e.g., d_state, d_conv, expand for Mamba).
        """
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size

        # Input Embeddings: Token + Absolute Positional
        # Padding_idx prevents the padding token embedding from being updated.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        # Learnable absolute positional embeddings
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model))
        # Optional: Dropout after embeddings
        self.emb_dropout = nn.Dropout(0.1)

        # Encoder Stack: Consists of multiple State Space Model layers.
        # These layers process the sequence sequentially, updating an internal state.
        self.encoder_ssm_stack = nn.ModuleList([
            StateSpaceLayer(d_model=d_model, **ssm_kwargs)
            # Example Mamba args: d_state=16, d_conv=4, expand=2
            # Ensure ssm_kwargs match the expected arguments of your StateSpaceLayer
            for _ in range(n_layers)
        ])
        # Optional: Layer Normalization after the SSM stack
        self.encoder_norm = nn.LayerNorm(d_model)

        # Decoder Head: Simple linear layer projecting final hidden states to vocabulary logits.
        # Predicts each target token independently based on the corresponding encoder output state.
        self.output_projection = nn.Linear(d_model, vocab_size)

        print(f"[INFO] Initialized SSMEncoderDecoder model:")
        print(f"  - Vocab Size: {vocab_size}")
        print(f"  - Embedding Dim (d_model): {d_model}")
        print(f"  - Max Sequence Length: {max_len}")
        print(f"  - Encoder SSM Layers: {n_layers}")
        print(f"  - SSM Kwargs: {ssm_kwargs}")
        print(f"  - Pad Token ID: {pad_token_id}")


    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Performs the forward pass of the model.

        Args:
            input_ids (torch.Tensor): Tensor of input token IDs (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask indicating non-padding tokens.
                                                    Currently unused by this simple SSM stack
                                                    but kept for API consistency.
            labels (torch.Tensor, optional): Tensor of target token IDs for loss calculation
                                             (batch_size, seq_len).

        Returns:
            dict: A dictionary containing 'loss' (if labels provided) and 'logits'.
        """
        batch_size, seq_len = input_ids.shape

        # 1. Embeddings
        token_embeddings = self.token_emb(input_ids) # (batch_size, seq_len, d_model)
        # Add positional embeddings (slice to match input seq_len)
        positional_embeddings = self.pos_emb[:, :seq_len, :] # (1, seq_len, d_model)
        x = token_embeddings + positional_embeddings # (batch_size, seq_len, d_model)
        x = self.emb_dropout(x)

        # 2. Encoder SSM Stack
        # Note: Basic SSM layers might not inherently use an attention mask.
        # More advanced implementations might incorporate masking.
        for ssm_layer in self.encoder_ssm_stack:
            x = ssm_layer(x) # (batch_size, seq_len, d_model) - state is internal

        x = self.encoder_norm(x) # Apply layer norm after the stack

        # 3. Output Projection (Decoder Head)
        # Project final hidden states to vocabulary logits for each position
        logits = self.output_projection(x) # (batch_size, seq_len, vocab_size)

        # 4. Loss Calculation (if labels are provided)
        loss = None
        if labels is not None:
            # CrossEntropyLoss expects logits as (N, C) and labels as (N)
            # N = batch_size * seq_len, C = vocab_size
            # Flatten logits and labels, ignoring padding tokens in labels.
            loss = nn.functional.cross_entropy(
                logits.view(-1, self.vocab_size), # (batch_size * seq_len, vocab_size)
                labels.view(-1),                  # (batch_size * seq_len)
                ignore_index=self.pad_token_id    # Ignore padding tokens when calculating loss
            )

        return {"loss": loss, "logits": logits}


def calculate_sequence_accuracy(logits, labels, pad_token_id):
    """
    Calculates exact sequence match accuracy, ignoring padding.

    Args:
        logits (torch.Tensor): Model output logits (batch_size, seq_len, vocab_size).
        labels (torch.Tensor): Ground truth labels (batch_size, seq_len).
        pad_token_id (int): The ID of the padding token.

    Returns:
        float: The fraction of sequences that match the labels exactly (excluding padding).
    """
    if logits.shape[0] == 0: return 0.0 # Handle empty batch case

    predictions = torch.argmax(logits, dim=-1) # Get predicted token IDs (batch_size, seq_len)

    # Create mask for non-padding tokens in labels
    non_pad_mask = (labels != pad_token_id)

    # Check equality only for non-padded positions
    correct_tokens = (predictions == labels) & non_pad_mask # (batch_size, seq_len) boolean

    # Check if all non-padded tokens in each sequence are correct
    # Sum correct tokens and mask per sequence. If sums match, the sequence is correct.
    correct_sequences = (torch.sum(correct_tokens, dim=1) == torch.sum(non_pad_mask, dim=1))

    # Calculate mean accuracy across the batch
    accuracy = torch.mean(correct_sequences.float()) # Convert bool tensor to float for mean

    return accuracy.item() # Return as Python float


# --- Main Execution ---

def main():
    """Orchestrates the SSM model training and evaluation pipeline."""
    print("[INFO] Starting Custom SSM Sequence-to-Sequence Script...")
    try:
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current date/time (UTC): {current_time_str}")

    # --- 1. Setup Device ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"[INFO] Using GPU: {torch.cuda.get_device_name(0)}")
    elif torch.backends.mps.is_available(): # Check for Apple Silicon GPU
         device = torch.device("mps")
         print("[INFO] Using MPS device (Apple Silicon GPU)")
    else:
        device = torch.device("cpu")
        print("[INFO] Using CPU")

    # Create output directory if it doesn't exist
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 2. Initialize Tokenizer ---
    print(f"[INFO] Initializing Tokenizer: {TOKENIZER_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
        # Check for pad token; add if missing (though BERT usually has one)
        if tokenizer.pad_token is None:
            print("[Warning] Tokenizer does not have a default pad token. Adding '[PAD]'.")
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # Resize model embeddings later if vocab size changes
    except Exception as e:
        print(f"[FATAL] Failed to initialize tokenizer '{TOKENIZER_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    vocab_size = tokenizer.vocab_size
    pad_token_id = tokenizer.pad_token_id
    print(f"[INFO] Tokenizer Vocab Size: {vocab_size}, Pad Token ID: {pad_token_id}")

    # --- 3. Load and Prepare Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)

        train_src, train_tgt = convert_tokens_to_strings(train_raw)
        val_src,   val_tgt   = convert_tokens_to_strings(val_raw)
        test_src,  test_tgt  = convert_tokens_to_strings(test_raw)

        # Create Datasets (with pre-tokenization)
        train_dataset = PretokenizedSequenceDataset(train_src, train_tgt, tokenizer, MAX_SEQ_LENGTH)
        val_dataset   = PretokenizedSequenceDataset(val_src, val_tgt, tokenizer, MAX_SEQ_LENGTH)
        test_dataset  = PretokenizedSequenceDataset(test_src, test_tgt, tokenizer, MAX_SEQ_LENGTH)

    except Exception as e:
        print(f"[FATAL] Failed during data loading or preprocessing: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_dataset or not val_dataset or not test_dataset:
         print("[FATAL] One or more datasets are empty after processing. Exiting.", file=sys.stderr)
         sys.exit(1)

    # --- 4. Create DataLoaders ---
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_dataloader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
    test_dataloader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
    print("[INFO] DataLoaders created.")

    # --- 5. Initialize Model ---
    print("[INFO] Initializing SSMEncoderDecoder model...")
    # Define SSM specific arguments based on configuration
    ssm_params = {
        "d_state": MAMBA_D_STATE,
        "d_conv": MAMBA_D_CONV,
        "expand": MAMBA_EXPAND
        # Add other args like 'kernel_size' if your StateSpaceLayer uses it
    }
    model = SSMEncoderDecoder(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        max_len=MAX_SEQ_LENGTH,
        pad_token_id=pad_token_id,
        **ssm_params
    )

    # If pad token was added, resize embeddings
    if len(tokenizer) > vocab_size: # Vocab size increased
        print(f"[INFO] Resizing token embeddings to accommodate new vocab size: {len(tokenizer)}")
        model.resize_token_embeddings(len(tokenizer))
        # Ensure new vocab size is used if needed elsewhere
        # vocab_size = len(tokenizer) # Update if needed

    model.to(device) # Move model to the appropriate device

    # --- 6. Initialize Optimizer and Scheduler ---
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        eps=OPTIMIZER_EPS,
        weight_decay=WEIGHT_DECAY
    )

    total_training_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(total_training_steps * WARMUP_RATIO)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_training_steps
    )
    print("[INFO] Optimizer and Learning Rate Scheduler initialized.")

    # --- 7. Training Loop ---
    print(f"[INFO] Starting training for {NUM_EPOCHS} epochs...")
    best_val_accuracy = -1.0
    best_epoch = -1

    for epoch in range(NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")

        # --- Training Phase ---
        model.train() # Set model to training mode
        total_train_loss = 0.0
        train_pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1} Training", leave=False)

        for step, batch in enumerate(train_pbar):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device) # May be unused by model but good practice
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']

            if loss is None:
                 print(f"[Warning] Loss is None at step {step}. Skipping backward pass.")
                 continue

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            # Optional: Gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step() # Update learning rate

            total_train_loss += loss.item()

            # Log training loss periodically
            if (step + 1) % LOG_INTERVAL == 0:
                avg_loss = total_train_loss / LOG_INTERVAL
                train_pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}', 'LR': f'{scheduler.get_last_lr()[0]:.2e}'})
                total_train_loss = 0.0 # Reset loss accumulator

        train_pbar.close()

        # --- Validation Phase ---
        model.eval() # Set model to evaluation mode
        total_val_accuracy = 0.0
        total_val_samples = 0
        val_pbar = tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation", leave=False)

        with torch.no_grad(): # Disable gradient calculations
            for batch in val_pbar:
                # Move batch to device
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                # Forward pass to get logits
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None) # No labels needed for inference here
                logits = outputs['logits']

                # Calculate accuracy for this batch
                batch_accuracy = calculate_sequence_accuracy(logits, labels, pad_token_id)
                batch_size = input_ids.size(0)
                total_val_accuracy += batch_accuracy * batch_size # Accumulate weighted accuracy
                total_val_samples += batch_size

        val_pbar.close()
        epoch_val_accuracy = total_val_accuracy / total_val_samples if total_val_samples > 0 else 0.0
        print(f"Epoch {epoch+1} Validation Sequence Accuracy: {epoch_val_accuracy:.4f}")

        # --- Save Best Model ---
        if epoch_val_accuracy > best_val_accuracy:
            best_val_accuracy = epoch_val_accuracy
            best_epoch = epoch + 1
            save_path = OUTPUT_DIR / CHECKPOINT_NAME
            try:
                torch.save(model.state_dict(), save_path)
                print(f"[INFO] New best model saved to {save_path} (Epoch {best_epoch}, Val Acc: {best_val_accuracy:.4f})")
            except Exception as e:
                print(f"[Error] Failed to save model checkpoint: {e}", file=sys.stderr)

    print(f"\n[INFO] Training complete. Best validation accuracy: {best_val_accuracy:.4f} at epoch {best_epoch}.")

    # --- 8. Final Test Evaluation ---
    print("\n--- Final Test Set Evaluation ---")
    # Load the best model checkpoint
    best_model_path = OUTPUT_DIR / CHECKPOINT_NAME
    if best_model_path.exists():
        print(f"[INFO] Loading best model from: {best_model_path}")
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        except Exception as e:
             print(f"[Error] Failed to load best model checkpoint: {e}. Evaluating with the last state.", file=sys.stderr)
    else:
        print("[Warning] Best model checkpoint not found. Evaluating with the final model state.", file=sys.stderr)

    model.eval() # Ensure model is in eval mode
    total_test_accuracy = 0.0
    total_test_samples = 0
    test_pbar = tqdm(test_dataloader, desc="Testing", leave=False)

    with torch.no_grad():
        for batch in test_pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
            logits = outputs['logits']

            batch_accuracy = calculate_sequence_accuracy(logits, labels, pad_token_id)
            batch_size = input_ids.size(0)
            total_test_accuracy += batch_accuracy * batch_size
            total_test_samples += batch_size

    test_pbar.close()
    final_test_accuracy = total_test_accuracy / total_test_samples if total_test_samples > 0 else 0.0
    print(f"\nFinal Test Sequence Accuracy: {final_test_accuracy:.4f}")

    print("\n[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Add argument parsing (argparse) here for flexibility
    # Optional: Set random seeds for reproducibility
    # SEED = 42
    # torch.manual_seed(SEED)
    # np.random.seed(SEED)
    # import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

    main()

[ERROR] 'mamba_ssm' library not found. Please install it (`pip install mamba_ssm causal-conv1d>=1.1.0`) or replace StateSpaceLayer with your implementation.


[INFO] Starting Custom SSM Sequence-to-Sequence Script...
[INFO] Current date/time (UTC): 2025-04-08 17:34:48 UTC
[INFO] Using CPU
[INFO] Initializing Tokenizer: bert-base-uncased
[INFO] Tokenizer Vocab Size: 30522, Pad Token ID: 0
[INFO] Loading data from: qed_expressions_train.jsonl
[INFO] Successfully loaded 12441 records.
[INFO] Loading data from: qed_expressions_val.jsonl
[INFO] Successfully loaded 1555 records.
[INFO] Loading data from: qed_expressions_test.jsonl
[INFO] Successfully loaded 1556 records.
[INFO] Converted 12441 items to strings (skipped 0).
[INFO] Converted 1555 items to strings (skipped 0).
[INFO] Converted 1556 items to strings (skipped 0).
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset size: 12441 examples.
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset size: 1555 examples.
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset 

Epoch 1 Training:   0%|          | 0/778 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already bee

AttributeError: 'list' object has no attribute 'to'

# 3.3: Transformer Models for Symbolic Regression
Model: Transformer model with a contemporary innovation added such as KAN layers, reinforcement learning, genetic algorithms, specialized long-sequence attention, etc. which improves the performance compared to a basic transformer.


In [8]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tuning Script for a Reformer Encoder-Decoder Model on Symbolic Regression Data.

This script fine-tunes a Reformer model (`google/reformer-enwik8`) configured
as an encoder-decoder for a sequence-to-sequence task, likely symbolic regression,
using preprocessed data. Reformer employs Locality-Sensitive Hashing (LSH)
attention for computational and memory efficiency, particularly suitable for
longer sequences.

Key features and techniques utilized:
- Reformer Model: Uses LSH attention and other efficiency techniques.
- Encoder-Decoder Architecture: Standard setup for sequence-to-sequence tasks.
- Gradient Checkpointing: Reduces GPU memory usage during training.
- Mixed Precision Training (FP16): Accelerates training and saves memory on
  compatible GPUs.
- Label Smoothing: Regularizes the model to prevent overconfidence.
- Beam Search Generation: Used during evaluation for potentially better outputs.
- Exact Match Accuracy: Evaluation metric comparing decoded generated sequences
  against reference sequences.
- Hugging Face Trainer API: Leverages `Seq2SeqTrainer` for streamlined training
  and evaluation.

Workflow:
1. Load pre-split training, validation, and test datasets from JSONL files.
2. Reconstruct source and target strings from token lists.
3. Initialize the Reformer tokenizer and build an Encoder-Decoder model from
   pre-trained Reformer weights.
4. Configure the model (special tokens, tie weights, gradient checkpointing).
5. Define a function to tokenize source/target string pairs efficiently.
6. Wrap tokenized data into PyTorch Dataset objects.
7. Instantiate `DataCollatorForSeq2Seq` for dynamic padding and label handling.
8. Define the `compute_metrics` function for sequence accuracy calculation.
9. Configure `Seq2SeqTrainingArguments` with hyperparameters and advanced features.
10. Initialize and run the `Seq2SeqTrainer`.
11. Evaluate the final model on the test set.

Note on Tokenizer Choice: The script uses `google/reformer-enwik8` tokenizer,
which is character-level based. Ensure this aligns with the tokenization used
in your preprocessed `data_*.jsonl` files. If your data uses subword tokens
(like from BERT or T5), using this character-level tokenizer might lead to
suboptimal results or errors. Adjust `TOKENIZER_ID` and potentially the
`MODEL_ID` if necessary.
"""

import json
import sys
import numpy as np
import torch
import datetime
from pathlib import Path
from torch.utils.data import Dataset
from transformers import (
    ReformerTokenizer,
    ReformerModel,            # Base Reformer model (used for encoder/decoder components)
    EncoderDecoderModel,      # Wrapper to combine encoder and decoder
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments       # Explicit import
)
from tqdm.auto import tqdm # Optional, if adding manual loops or progress bars later

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Assumes qed_expressions prefix from preprocessing
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('reformer_symbolic_regression_output')

# Model Configuration
# Warning: google/reformer-enwik8 is character-level. Verify compatibility with your data tokens.
MODEL_ID = "google/reformer-enwik8" # Pre-trained Reformer model identifier
TOKENIZER_ID = "google/reformer-enwik8" # Typically same as model ID

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 256 # Maximum sequence length (Reformer can handle longer, adjust if needed)

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 5
# Reformer can be memory intensive; adjust batch size carefully.
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 3e-4          # Starting point for fine-tuning
WEIGHT_DECAY = 0.01           # Weight decay regularization
WARMUP_STEPS = 300            # Linear LR warmup steps
LOGGING_STEPS = 50            # Log metrics frequency
EVALUATION_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available() # Enable mixed precision if CUDA available
USE_GRADIENT_CHECKPOINTING = True    # Enable gradient checkpointing
LABEL_SMOOTHING_FACTOR = 0.1         # Apply label smoothing (0.0 disables)

# Generation Configuration (for evaluation)
GENERATION_NUM_BEAMS = 4             # Beam search width

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data:
             print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings = []
    target_strings = []
    if not raw_data_list:
        return source_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings:
        print("[Warning] No items successfully converted to strings.")
    return source_strings, target_strings


def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """
    Tokenizes source and target text sequences using the provided tokenizer.

    Args:
        tokenizer: Initialized Hugging Face tokenizer instance.
        source_texts (list[str]): List of source sequences for the encoder.
        target_texts (list[str]): List of target sequences for the decoder labels.
        max_len (int): Maximum sequence length for padding and truncation.

    Returns:
        dict: Dictionary containing lists of 'input_ids', 'attention_mask', and 'labels'.
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Tokenize source texts (encoder input)
    encoder_inputs = tokenizer(
        source_texts,
        max_length=max_len,
        padding='max_length',   # Pad to max_len
        truncation=True,        # Truncate sequences longer than max_len
        return_tensors=None     # Return lists
    )

    # Tokenize target texts to create labels
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_texts,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # Assign the tokenized target IDs as 'labels'
    encoder_inputs['labels'] = decoder_labels['input_ids']

    print(f"[INFO] Encoding complete.")
    # Basic validation
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels'):
         print("[Warning] Encoding resulted in empty input_ids or labels.")
    elif len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Mismatch in length between encoded inputs and labels.")

    return encoder_inputs


class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): Dictionary from tokenizer, e.g.,
                              {'input_ids': [[...], ...], 'attention_mask': [[...], ...], 'labels': [[...], ...]}.
                              Values should be lists of lists/token IDs.
        """
        if not isinstance(encodings, dict) or 'input_ids' not in encodings:
            raise ValueError("Encodings must be a dictionary containing at least 'input_ids'.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            # Validate other keys have the same length
            for key in encodings:
                if not isinstance(encodings[key], list) or len(encodings[key]) != self.length:
                    raise ValueError(f"Encoding key '{key}' is not a list or has inconsistent length.")
        except Exception as e:
             raise ValueError(f"Failed to validate encodings: {e}")

        if self.length == 0:
            raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Retrieves the tokenized data (as lists) for a single sample index."""
        if not 0 <= idx < self.length:
             raise IndexError(f"Index {idx} out of bounds for dataset of length {self.length}.")
        try:
            # Return a dictionary containing the data for the specified index
            # DataCollator will handle conversion to tensors.
            return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e:
            print(f"[Error] Failed to retrieve item at index {idx}: {e}", file=sys.stderr)
            raise


# --- Main Execution Logic ---

def main():
    """Orchestrates the Reformer fine-tuning and evaluation pipeline."""
    print("[INFO] Starting Reformer Fine-tuning Script for Symbolic Regression...")
    try:
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current date/time (UTC): {current_time_str}")
    print(f"[INFO] Using model: {MODEL_ID} with Tokenizer: {TOKENIZER_ID}")

    # --- 1. Load Data ---
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception as e:
        print(f"[FATAL] Critical error during data loading: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[FATAL] One or more required datasets are empty or failed to load. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 2. Prepare Data (Tokens to Strings) ---
    train_sources, train_targets = convert_tokens_to_strings(train_raw_data)
    val_sources,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_sources,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_sources or not val_sources or not test_sources:
        print("[FATAL] Data conversion resulted in empty lists. Check input data format. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 3. Initialize Tokenizer ---
    print(f"[INFO] Initializing Tokenizer: {TOKENIZER_ID}")
    try:
        # ReformerTokenizer requires sentencepiece if not installed
        # pip install sentencepiece
        tokenizer = ReformerTokenizer.from_pretrained(TOKENIZER_ID)
        # Check for necessary special tokens
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '<pad>'}) # Reformer uses <pad>
        if tokenizer.cls_token is None: tokenizer.add_special_tokens({'cls_token': '<s>'}) # Common start token
        if tokenizer.sep_token is None: tokenizer.add_special_tokens({'sep_token': '</s>'}) # Common end token

    except Exception as e:
        print(f"[FATAL] Failed to initialize tokenizer '{TOKENIZER_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # Store for metrics function
    global tokenizer_for_metrics
    tokenizer_for_metrics = tokenizer
    pad_token_id = tokenizer.pad_token_id
    print(f"[INFO] Tokenizer Pad Token ID: {pad_token_id}")
    print(f"[INFO] Tokenizer Vocab Size (initial): {tokenizer.vocab_size}")

    # --- 4. Tokenize Data ---
    try:
        train_encodings = encode_sequences(tokenizer, train_sources, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_sources,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_sources,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[FATAL] Failed during data tokenization: {e}", file=sys.stderr)
        sys.exit(1)

    # Validate encodings
    if not train_encodings.get('input_ids') or not val_encodings.get('input_ids') or not test_encodings.get('input_ids'):
         print("[FATAL] Tokenization resulted in empty encodings for one or more splits. Exiting.", file=sys.stderr)
         sys.exit(1)


    # --- 5. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e:
        print(f"[FATAL] Failed to create Dataset objects: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 6. Initialize Model ---
    print(f"[INFO] Initializing Reformer Encoder-Decoder Model from: {MODEL_ID}")
    try:
        # Create encoder-decoder model from Reformer base model
        # This loads the Reformer weights into both encoder and decoder components.
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID)

        # Resize embeddings if tokenizer vocab size changed due to added special tokens
        if model.config.encoder.vocab_size != len(tokenizer):
             print(f"[INFO] Resizing model embeddings from {model.config.encoder.vocab_size} to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             # Ensure config reflects this if needed later, although resize_token_embeddings usually updates it.
             model.config.encoder.vocab_size = len(tokenizer)
             model.config.decoder.vocab_size = len(tokenizer)

        # Configure model for sequence-to-sequence tasks
        model.config.decoder_start_token_id = tokenizer.cls_token_id # Use CLS as start token
        model.config.eos_token_id = tokenizer.sep_token_id           # Use SEP as end token
        model.config.pad_token_id = tokenizer.pad_token_id           # Use PAD for padding

        # Important for Reformer: Set decoder's is_decoder flag and add cross-attention
        model.config.decoder.is_decoder = True
        model.config.decoder.add_cross_attention = True

        # Tie weights between encoder and decoder (embeddings and potentially output projection)
        # Reduces parameter count and often improves performance for shared vocabularies.
        print("[INFO] Tying encoder and decoder weights.")
        model.tie_weights() # Ties input/output embeddings usually

        # Set generation parameters
        model.config.max_length = MAX_SEQ_LENGTH
        model.config.early_stopping = True # Optional: For generation
        model.config.num_beams = GENERATION_NUM_BEAMS # Default beam size for generation

        # Enable Gradient Checkpointing if configured
        if USE_GRADIENT_CHECKPOINTING:
            model.gradient_checkpointing_enable()
            print("[INFO] Gradient Checkpointing enabled on the model.")

    except Exception as e:
        print(f"[FATAL] Failed to initialize or configure the Reformer model: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 7. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model, # Pass model for automatic decoder_input_ids creation
        label_pad_token_id=-100, # Ignore padding in loss calculation
        pad_to_multiple_of=8 if USE_FP16 else None # Pad efficiently for FP16
    )
    print("[INFO] Data collator initialized.")

    # --- 8. Define Metrics Computation ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy after decoding."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 (ignore index) with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
        except Exception as e:
             print(f"[Error] Decoding failed in compute_metrics: {e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

        # Post-process: strip whitespace
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        if len(decoded_preds) != len(decoded_labels):
            print(f"[Warning] Mismatch prediction/label count: {len(decoded_preds)} vs {len(decoded_labels)}", file=sys.stderr)
            return {"sequence_accuracy": 0.0}

        matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
        accuracy = np.mean(matches) if matches else 0.0
        return {"sequence_accuracy": float(accuracy)}

    # --- 9. Define Training Arguments ---
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Training schedule
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging, saving, evaluation
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        # Generation configuration
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH,
        generation_num_beams=GENERATION_NUM_BEAMS,
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        # Pass gradient checkpointing flag based on earlier successful enablement
        # Note: Reformer might have internal GC controls too; check docs if issues arise.
        gradient_checkpointing=(USE_GRADIENT_CHECKPOINTING and hasattr(model, 'is_gradient_checkpointing') and model.is_gradient_checkpointing),
        # report_to="tensorboard", # Optional
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing in Args: {training_args.gradient_checkpointing}")


    # --- 10. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 11. Train the Model ---
    print(f"[INFO] Starting model training for {NUM_TRAIN_EPOCHS} epochs...")
    try:
        train_result = trainer.train()
        # Log/save final metrics and state
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint:
             print(f"[INFO] Best model checkpoint saved at: {trainer.state.best_model_checkpoint}")
        else:
             print("[Warning] No best model checkpoint recorded.")

    except Exception as e:
        print(f"[FATAL] Training loop encountered an error: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 12. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(
            eval_dataset=test_dataset,
            metric_key_prefix="test"
        )
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' not found in test results.")
             print("Full test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}", file=sys.stderr)
        sys.exit(1) # Exit with error status

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Add command-line argument parsing (argparse)
    # Optional: Set random seeds for reproducibility
    # SEED = 42
    # torch.manual_seed(SEED); np.random.seed(SEED); import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

    main()

[INFO] Starting Reformer Fine-tuning Script for Symbolic Regression...
[INFO] Current date/time (UTC): 2025-04-08 17:37:32 UTC
[INFO] Using model: google/reformer-enwik8 with Tokenizer: google/reformer-enwik8
[INFO] Loading data from: qed_expressions_train.jsonl
[INFO] Successfully loaded 12441 records.
[INFO] Loading data from: qed_expressions_val.jsonl
[INFO] Successfully loaded 1555 records.
[INFO] Loading data from: qed_expressions_test.jsonl
[INFO] Successfully loaded 1556 records.
[INFO] Converted 12441 items to strings (skipped 0).
[INFO] Converted 1555 items to strings (skipped 0).
[INFO] Converted 1556 items to strings (skipped 0).
[INFO] Initializing Tokenizer: google/reformer-enwik8


[FATAL] Failed to initialize tokenizer 'google/reformer-enwik8': 
ReformerTokenizer requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.



SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# 3.4: Titans for squared amplitude calculation
Model: One of the core architectures from Google’s paper introducing Titans concept


In [10]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Fine-tuning Script for a Hypothetical "Titan" Sequence-to-Sequence Model
on the Preprocessed QED Squared Amplitude Calculation Dataset.

This script demonstrates fine-tuning a sequence-to-sequence model, referred to
here as "Titan" (e.g., `google/titan-small`), presumably incorporating unique
architectural features like time-mixing and recursion modules as specified in
the configuration. It utilizes the Hugging Face Transformers library, including
the `Seq2SeqTrainer` API.

Key features demonstrated:
- Hypothetical Titan Model: Assumes existence of `TitanTokenizer` and
  `TitanForConditionalGeneration` with specific config flags (`use_time_mixing`,
  `use_recursion`).
- Advanced Training Techniques: Includes Gradient Checkpointing, Mixed Precision
  Training (FP16), Label Smoothing, and Beam Search during evaluation.
- Standard Workflow: Follows a common pattern of data loading, preprocessing,
  tokenization, model configuration, training, and evaluation.
- Exact Match Accuracy: Uses sequence-level exact match as the evaluation metric.

Workflow:
1. Load pre-split data (train, validation, test) from JSONL files.
2. Reconstruct source (input amplitude) and target (squared amplitude) strings.
3. Initialize the Titan tokenizer and model.
4. Configure Titan-specific features (time-mixing, recursion) and enable
   gradient checkpointing on the model instance.
5. Define a function to tokenize source/target string pairs.
6. Create PyTorch Dataset objects from the tokenized data.
7. Instantiate `DataCollatorForSeq2Seq` for dynamic batching.
8. Define the `compute_metrics` function for evaluation using decoded sequences.
9. Configure `Seq2SeqTrainingArguments` with hyperparameters and features.
10. Initialize and run the `Seq2SeqTrainer`.
11. Evaluate the final model on the test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
from pathlib import Path
from torch.utils.data import Dataset
# --- Hypothetical Titan Model Imports ---
# Ensure these classes exist in your transformers installation or define placeholders
try:
    from transformers import (
        TitanTokenizer,                # Hypothetical Tokenizer
        TitanForConditionalGeneration, # Hypothetical Model
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
        TrainingArguments           # Explicit import
    )
    print("[INFO] Successfully imported hypothetical Titan classes.")
except ImportError:
    print("[ERROR] Failed to import TitanTokenizer or TitanForConditionalGeneration.", file=sys.stderr)
    print("Please ensure these classes are defined in your transformers library or environment.", file=sys.stderr)
    # Define dummy classes if needed for script structure analysis:
    # class TitanTokenizer: @staticmethod def from_pretrained(s): raise NotImplementedError
    # class TitanForConditionalGeneration: @staticmethod def from_pretrained(s): raise NotImplementedError
    # from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments
    sys.exit(1) # Exit if the core classes are missing

from tqdm.auto import tqdm # Optional progress bars

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Use consistent naming if applicable
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('titan_qed_squared_amplitude_output')

# Model Configuration
# Replace with actual model ID if 'google/titan-small' is a placeholder
MODEL_ID = "google/titan-small" # Hypothetical Titan model identifier
TOKENIZER_ID = "google/titan-small" # Usually same as model ID

# Titan-Specific Configuration (Hypothetical)
USE_TIME_MIXING = True
USE_RECURSION   = True

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 512 # Max sequence length

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 4 # Adjust based on convergence
# Adjust batch size based on Titan model size and GPU memory
PER_DEVICE_TRAIN_BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 200
LOGGING_STEPS = 50
EVALUATION_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available()
USE_GRADIENT_CHECKPOINTING = True
LABEL_SMOOTHING_FACTOR = 0.1

# Generation Configuration (for evaluation)
GENERATION_NUM_BEAMS = 4

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data:
             print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings = []
    target_strings = []
    if not raw_data_list:
        return source_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings:
        print("[Warning] No items successfully converted to strings.")
    return source_strings, target_strings


def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """
    Tokenizes source and target text sequences using the provided tokenizer.

    Args:
        tokenizer: Initialized Hugging Face tokenizer instance (e.g., TitanTokenizer).
        source_texts (list[str]): List of source sequences.
        target_texts (list[str]): List of target sequences for labels.
        max_len (int): Maximum sequence length for padding and truncation.

    Returns:
        dict: Dictionary containing lists of 'input_ids', 'attention_mask', and 'labels'.
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Tokenize source texts
    encoder_inputs = tokenizer(
        source_texts,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors=None # Return lists
    )

    # Tokenize target texts to create labels
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_texts,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # Assign target input_ids as 'labels'
    encoder_inputs['labels'] = decoder_labels['input_ids']

    print(f"[INFO] Encoding complete.")
    # Basic validation
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels'):
         print("[Warning] Encoding resulted in empty input_ids or labels.")
    elif len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Mismatch in length between encoded inputs and labels.")

    return encoder_inputs


class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): Dictionary from tokenizer {'input_ids': [...], ...}.
                              Values should be lists of token IDs or masks.
        """
        if not isinstance(encodings, dict) or 'input_ids' not in encodings:
            raise ValueError("Encodings must be a dictionary containing at least 'input_ids'.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                if not isinstance(encodings[key], list) or len(encodings[key]) != self.length:
                    raise ValueError(f"Encoding key '{key}' is not a list or has inconsistent length.")
        except Exception as e:
             raise ValueError(f"Failed to validate encodings: {e}")

        if self.length == 0:
            raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Retrieves the tokenized data (as lists) for a single sample index."""
        if not 0 <= idx < self.length:
             raise IndexError(f"Index {idx} out of bounds for dataset of length {self.length}.")
        try:
            return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e:
            print(f"[Error] Failed to retrieve item at index {idx}: {e}", file=sys.stderr)
            raise

# --- Main Execution Logic ---

def main():
    """Orchestrates the Titan model fine-tuning and evaluation pipeline."""
    print("[INFO] Starting Hypothetical Titan Model Fine-tuning Script...")
    try:
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current time: {current_time_str} (local: San Diego, CA)") # Include location context
    print(f"[INFO] Using model: {MODEL_ID} with Tokenizer: {TOKENIZER_ID}")

    # --- 1. Load Data ---
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception as e:
        print(f"[FATAL] Critical error during data loading: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[FATAL] One or more required datasets are empty or failed to load. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 2. Prepare Data (Tokens to Strings) ---
    train_sources, train_targets = convert_tokens_to_strings(train_raw_data)
    val_sources,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_sources,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_sources or not val_sources or not test_sources:
        print("[FATAL] Data conversion resulted in empty lists. Check input data format. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 3. Initialize Tokenizer ---
    print(f"[INFO] Initializing Tokenizer: {TOKENIZER_ID}")
    try:
        # Assuming TitanTokenizer exists and works like standard tokenizers
        tokenizer = TitanTokenizer.from_pretrained(TOKENIZER_ID)
        # Add checks/additions for special tokens if necessary for Titan
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '<pad>'})
        # Determine start/end tokens based on tokenizer properties or model requirements
        if tokenizer.bos_token is None: tokenizer.add_special_tokens({'bos_token': '<s>'})
        if tokenizer.eos_token is None: tokenizer.add_special_tokens({'eos_token': '</s>'})

    except Exception as e:
        print(f"[FATAL] Failed to initialize tokenizer '{TOKENIZER_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # Store globally for metric computation
    global tokenizer_for_metrics
    tokenizer_for_metrics = tokenizer
    pad_token_id = tokenizer.pad_token_id
    print(f"[INFO] Tokenizer Pad Token ID: {pad_token_id}")
    print(f"[INFO] Tokenizer Vocab Size (initial): {tokenizer.vocab_size}")


    # --- 4. Initialize Model ---
    print(f"[INFO] Initializing Model: {MODEL_ID}")
    try:
        model = TitanForConditionalGeneration.from_pretrained(MODEL_ID)

        # Resize embeddings if vocab size changed
        if model.config.vocab_size != len(tokenizer):
             print(f"[INFO] Resizing model embeddings from {model.config.vocab_size} to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             model.config.vocab_size = len(tokenizer) # Ensure config is updated

        # Configure hypothetical Titan features - use try/except for robustness
        config_updated = False
        try:
            if USE_TIME_MIXING and hasattr(model.config, 'use_time_mixing'):
                model.config.use_time_mixing = True
                print("[INFO] Enabled model.config.use_time_mixing")
                config_updated = True
            elif USE_TIME_MIXING:
                 print("[Warning] Configured USE_TIME_MIXING=True, but 'use_time_mixing' not found in model config.")

            if USE_RECURSION and hasattr(model.config, 'use_recursion'):
                model.config.use_recursion = True
                print("[INFO] Enabled model.config.use_recursion")
                config_updated = True
            elif USE_RECURSION:
                 print("[Warning] Configured USE_RECURSION=True, but 'use_recursion' not found in model config.")

        except Exception as config_e:
             print(f"[Warning] Error applying hypothetical Titan config options: {config_e}")

        # Configure standard seq2seq settings
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        # Set generation defaults
        model.config.max_length = MAX_SEQ_LENGTH
        model.config.num_beams = GENERATION_NUM_BEAMS

        # Enable Gradient Checkpointing if configured
        if USE_GRADIENT_CHECKPOINTING:
            try:
                 model.gradient_checkpointing_enable()
                 print("[INFO] Gradient Checkpointing enabled on the model.")
            except Exception as gc_e:
                 print(f"[Warning] Failed to enable gradient checkpointing on model: {gc_e}")
                 # Ensure training arg reflects this failure
                 global USE_GRADIENT_CHECKPOINTING_EFFECTIVE
                 USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False
        else:
            USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False


    except Exception as e:
        print(f"[FATAL] Failed to initialize or configure the Titan model '{MODEL_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # --- 5. Tokenize Data ---
    try:
        train_encodings = encode_sequences(tokenizer, train_sources, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_sources,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_sources,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[FATAL] Failed during data tokenization: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_encodings.get('input_ids') or not val_encodings.get('input_ids') or not test_encodings.get('input_ids'):
         print("[FATAL] Tokenization resulted in empty encodings for one or more splits. Exiting.", file=sys.stderr)
         sys.exit(1)

    # --- 6. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e:
        print(f"[FATAL] Failed to create Dataset objects: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 7. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Ignore pad tokens in loss
        pad_to_multiple_of=8 if USE_FP16 else None # Optimize padding for FP16
    )
    print("[INFO] Data collator initialized.")

    # --- 8. Define Metrics Computation ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy after decoding."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
        except Exception as e:
             print(f"[Error] Decoding failed in compute_metrics: {e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

        # Post-process and compare
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        if len(decoded_preds) != len(decoded_labels):
            print(f"[Warning] Mismatch prediction/label count: {len(decoded_preds)} vs {len(decoded_labels)}", file=sys.stderr)
            return {"sequence_accuracy": 0.0}

        matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
        accuracy = np.mean(matches) if matches else 0.0
        return {"sequence_accuracy": float(accuracy)}

    # --- 9. Define Training Arguments ---
    # Use effective GC flag based on earlier attempt
    effective_gc = USE_GRADIENT_CHECKPOINTING if 'USE_GRADIENT_CHECKPOINTING_EFFECTIVE' not in globals() else USE_GRADIENT_CHECKPOINTING_EFFECTIVE

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Schedule
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging / Saving / Evaluation
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        # Generation
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH,
        generation_num_beams=GENERATION_NUM_BEAMS,
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        gradient_checkpointing=effective_gc, # Use the effective flag
        # report_to="tensorboard", # Optional
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing in Args: {training_args.gradient_checkpointing}")


    # --- 10. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 11. Train the Model ---
    print(f"[INFO] Starting model training for {NUM_TRAIN_EPOCHS} epochs...")
    try:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint:
             print(f"[INFO] Best model checkpoint saved at: {trainer.state.best_model_checkpoint}")
        else:
             print("[Warning] No best model checkpoint recorded.")

    except Exception as e:
        print(f"[FATAL] Training loop encountered an error: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 12. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(
            eval_dataset=test_dataset,
            metric_key_prefix="test"
        )
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' not found in test results.")
             print("Full test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}", file=sys.stderr)
        sys.exit(1) # Exit with error status

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Add argument parsing (argparse)
    # Optional: Set random seeds
    # SEED = 42; torch.manual_seed(SEED); np.random.seed(SEED); import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

    main()

[ERROR] Failed to import TitanTokenizer or TitanForConditionalGeneration.
Please ensure these classes are defined in your transformers library or environment.


SystemExit: 1

# 3.5: Evolutionary and Transformer Models for Symbolic Regression
Model: Transformer model integrated with an evolutionary pipeline. It’s possible to start from previous year’s projects but should introduce a substantial innovation.


In [11]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Evolutionary Hyperparameter Optimization for Transformer-Based Symbolic Regression.

This script employs a Genetic Algorithm (GA), using the DEAP library, to optimize
key hyperparameters (learning rate and dropout probability) for a standard
Transformer Encoder-Decoder model (BERT-to-BERT) applied to a symbolic regression
task using the QED 2→2 dataset.

The optimization process works as follows:
1.  A small, fixed subset of the training and validation data is selected for
    rapid evaluation during the GA.
2.  DEAP initializes a population of candidate hyperparameter sets (individuals).
3.  Each individual is evaluated by:
    a. Initializing a fresh BERT-to-BERT `EncoderDecoderModel`.
    b. Configuring the model with the dropout specified by the individual.
    c. Setting up `Seq2SeqTrainingArguments` with the learning rate from the
       individual and settings suitable for a short training run.
    d. Initializing a `Seq2SeqTrainer`.
    e. Training the model for a fixed, small number of steps/epochs on the
       training subset.
    f. Evaluating the trained model on the validation subset using sequence
       accuracy.
    g. Returning the validation accuracy as the individual's fitness.
4.  DEAP applies genetic operators (selection, crossover, mutation) to evolve
    the population over several generations, aiming to maximize fitness
    (validation accuracy).
5.  After the GA completes, the best hyperparameter set found is identified.
6.  A final `EncoderDecoderModel` is initialized.
7.  A full training run is performed using the best hyperparameters found by the
    GA, training on the complete training dataset and validating on the complete
    validation dataset.
8.  The final performance is reported based on evaluation on the held-out test set.

Note: This script assumes the `deap` library is installed (`pip install deap`).
"""

import json
import random
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import
)
from deap import base, creator, tools, algorithms
from tqdm.auto import tqdm # Optional, useful if adding manual loops or progress bars later

# --- Configuration ---

# File Paths
# Ensure these point to your actual preprocessed data files
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust prefix if needed
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('ga_transformer_symbolic_regression_output') # Main output directory
GA_LOG_FILE = OUTPUT_DIR / "ga_logbook.json"                   # Log of GA progress
BEST_HP_FILE = OUTPUT_DIR / "best_hyperparameters.json"         # Best HPs found
GA_EVAL_OUTPUT_DIR = OUTPUT_DIR / "ga_eval_runs" # Temp dir for GA evaluation runs

# Model and Tokenizer Configuration
MODEL_ID = "bert-base-uncased" # Base model for Encoder-Decoder
TOKENIZER_ID = "bert-base-uncased"
MAX_SEQ_LENGTH = 128          # Max sequence length for tokenization

# Genetic Algorithm Configuration (DEAP)
POPULATION_SIZE = 20        # Number of HP sets evaluated per generation
N_GENERATIONS = 10          # Total number of GA generations
CX_PROB = 0.7               # Probability of crossover between individuals
MUT_PROB = 0.2              # Probability of mutating an individual
# Mutation parameters for Gaussian perturbation
MUT_SIGMA_LR = 1e-5         # Std deviation for learning rate mutation
MUT_SIGMA_DROPOUT = 0.05    # Std deviation for dropout mutation
# Boundaries for hyperparameter values
LR_BOUND_LOW = 1e-6
LR_BOUND_HIGH = 1e-3
DROPOUT_BOUND_LOW = 0.0
DROPOUT_BOUND_HIGH = 0.5
TOURNAMENT_SIZE = 3         # Selection pressure for tournament selection

# GA Evaluation Run Configuration (Short runs on subsets for speed)
GA_EVAL_TRAIN_SUBSET_SIZE = 512 # Number of training examples per GA eval run
GA_EVAL_VAL_SUBSET_SIZE = 128   # Number of validation examples per GA eval run
GA_EVAL_EPOCHS = 1              # Epochs per GA eval run (keep very small)
# GA_EVAL_MAX_STEPS = 200       # Alternative: Fixed steps per GA eval run
GA_EVAL_BATCH_SIZE = 16         # Batch size for GA eval runs

# Final Training Configuration (using best HPs on full data)
FINAL_TRAIN_EPOCHS = 5
FINAL_BATCH_SIZE = 16
FINAL_WEIGHT_DECAY = 0.01
FINAL_WARMUP_STEPS = 300
FINAL_LOGGING_STEPS = 50
FINAL_USE_FP16 = torch.cuda.is_available() # Use FP16 if available for final run
FINAL_LABEL_SMOOTHING = 0.1     # Optional label smoothing for final run
FINAL_SAVE_TOTAL_LIMIT = 2      # Keep only the best and latest checkpoints

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings, target_strings, skipped_count = [], [], 0
    if not raw_data_list: return source_strings, target_strings
    for i, item in enumerate(raw_data_list):
        input_toks, target_toks = item.get('input_tokens'), item.get('target_tokens')
        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid data: {item}"); skipped_count += 1
    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings: print("[Warning] No items successfully converted.")
    return source_strings, target_strings

def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """Tokenizes source and target sequences, returning lists of IDs/masks."""
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")
    encoder_inputs = tokenizer(source_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(target_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    encoder_inputs['labels'] = decoder_labels['input_ids']
    print(f"[INFO] Encoding complete.")
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels') or len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Encoding resulted in empty lists or length mismatch.")
    return encoder_inputs

class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        if not isinstance(encodings, dict) or 'input_ids' not in encodings: raise ValueError("Invalid encodings format.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                 if not isinstance(encodings[key], list) or len(encodings[key]) != self.length: raise ValueError(f"Inconsistent length for key '{key}'.")
        except Exception as e: raise ValueError(f"Validation failed: {e}")
        if self.length == 0: raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self): return self.length
    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds.")
        try: return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e: print(f"[Error] Failed retrieval at index {idx}: {e}", file=sys.stderr); raise

# --- Global variables needed by evaluation function ---
# These are populated in main() before the GA starts
ga_train_dataset_subset = None
ga_val_dataset_subset = None
tokenizer_for_eval = None
data_collator_for_eval = None
compute_metrics_internal = None # Holds the metric calculation logic

# --- DEAP Evaluation Function ---

def evaluate_hyperparams(individual):
    """
    Evaluates a hyperparameter set [learning_rate, dropout_rate] by training
    a small model on data subsets and returning validation accuracy.

    Args:
        individual (deap.creator.Individual): List containing [LR, Dropout].

    Returns:
        tuple: Fitness value (validation_accuracy,). Must be a tuple.
    """
    learning_rate, dropout_rate = individual[0], individual[1]
    dropout_rate = max(DROPOUT_BOUND_LOW, min(DROPOUT_BOUND_HIGH, dropout_rate)) # Clamp dropout

    run_id = f"lr_{learning_rate:.1e}_drop_{dropout_rate:.3f}_{random.randint(1000,9999)}"
    run_output_dir = GA_EVAL_OUTPUT_DIR / run_id

    print(f"[GA Eval] Evaluating LR={learning_rate:.3e}, Dropout={dropout_rate:.4f}")

    try:
        # 1. Initialize NEW model instance
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID)
        if model.config.encoder.vocab_size != len(tokenizer_for_eval):
            model.resize_token_embeddings(len(tokenizer_for_eval))
            model.config.encoder.vocab_size = len(tokenizer_for_eval)
            model.config.decoder.vocab_size = len(tokenizer_for_eval)

        # 2. Apply dropout from individual
        model.config.dropout = dropout_rate
        model.config.attention_dropout = dropout_rate
        if hasattr(model.config, 'encoder'): model.config.encoder.dropout = dropout_rate; model.config.encoder.attention_dropout = dropout_rate
        if hasattr(model.config, 'decoder'): model.config.decoder.dropout = dropout_rate; model.config.decoder.attention_dropout = dropout_rate

        # 3. Configure standard seq2seq settings
        model.config.decoder_start_token_id = tokenizer_for_eval.cls_token_id
        model.config.eos_token_id = tokenizer_for_eval.sep_token_id
        model.config.pad_token_id = tokenizer_for_eval.pad_token_id
        model.config.max_length = MAX_SEQ_LENGTH

        # 4. Minimal Training Arguments for GA eval run
        eval_training_args = Seq2SeqTrainingArguments(
            output_dir=str(run_output_dir),
            num_train_epochs=GA_EVAL_EPOCHS,
            # max_steps=GA_EVAL_MAX_STEPS, # Alternative
            per_device_train_batch_size=GA_EVAL_BATCH_SIZE,
            per_device_eval_batch_size=GA_EVAL_BATCH_SIZE,
            learning_rate=learning_rate,
            weight_decay=0.0, # Keep simple for GA eval
            logging_steps=1000, # Reduce logging noise
            evaluation_strategy="no",
            save_strategy="no",
            predict_with_generate=True, # Use generation for accuracy
            generation_max_length=MAX_SEQ_LENGTH,
            generation_num_beams=1, # Greedy search for speed
            fp16=FINAL_USE_FP16, # Consistent FP16 usage
            report_to="none",
            disable_tqdm=True,
        )

        # 5. Initialize Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=eval_training_args,
            train_dataset=ga_train_dataset_subset,
            eval_dataset=ga_val_dataset_subset,
            data_collator=data_collator_for_eval,
            tokenizer=tokenizer_for_eval,
            compute_metrics=compute_metrics_internal # Use the predefined metric logic
        )

        # 6. Train briefly
        trainer.train()

        # 7. Evaluate
        eval_results = trainer.evaluate(eval_dataset=ga_val_dataset_subset)
        accuracy = eval_results.get("eval_sequence_accuracy", 0.0)

        print(f"[GA Eval Result] LR={learning_rate:.3e}, Dropout={dropout_rate:.4f} -> Val Acc={accuracy:.4f}")
        # Optional: Clean up temp dir: import shutil; shutil.rmtree(run_output_dir)

        return (accuracy,) # Return fitness tuple

    except Exception as e:
        print(f"[GA Eval Error] Failed for LR={learning_rate:.3e}, Dropout={dropout_rate:.4f}: {e}", file=sys.stderr)
        return (0.0,) # Return poor fitness on failure


# --- Main Execution Logic ---

def main():
    """Orchestrates data loading, GA optimization, and final model training."""
    print("[INFO] Starting Evolutionary Hyperparameter Optimization Script...")
    # --- ***** Update Timestamp/Location ***** ---
    try:
        # Attempt to get timezone-aware UTC time
        current_time = datetime.datetime.now(datetime.timezone.utc)
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S %Z') # Includes timezone name
    except Exception:
        # Fallback to naive UTC time
        current_time = datetime.datetime.utcnow()
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    # Incorporate provided context
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")
    print(f"[INFO] Using Base Model: {MODEL_ID}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    GA_EVAL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Data Strings ---
    train_src, train_tgt = convert_tokens_to_strings(train_raw)
    val_src,   val_tgt   = convert_tokens_to_strings(val_raw)
    test_src,  test_tgt  = convert_tokens_to_strings(test_raw)
    if not train_src or not val_src or not test_src: print("[FATAL] Data conversion failed.", file=sys.stderr); sys.exit(1)

    # --- 3. Initialize Tokenizer (Global for GA Eval) ---
    global tokenizer_for_eval
    try:
        tokenizer_for_eval = AutoTokenizer.from_pretrained(TOKENIZER_ID)
        if tokenizer_for_eval.pad_token is None: tokenizer_for_eval.add_special_tokens({'pad_token': '[PAD]'})
        if tokenizer_for_eval.cls_token is None: tokenizer_for_eval.add_special_tokens({'cls_token': '[CLS]'})
        if tokenizer_for_eval.sep_token is None: tokenizer_for_eval.add_special_tokens({'sep_token': '[SEP]'})
        pad_token_id = tokenizer_for_eval.pad_token_id
    except Exception as e: print(f"[FATAL] Tokenizer init failed: {e}", file=sys.stderr); sys.exit(1)
    print(f"[INFO] Tokenizer initialized (Vocab: {len(tokenizer_for_eval)}, Pad ID: {pad_token_id}).")

    # --- 4. Tokenize Data & Create Subsets ---
    try:
        print("[INFO] Tokenizing full datasets...")
        train_enc = encode_sequences(tokenizer_for_eval, train_src, train_tgt, MAX_SEQ_LENGTH)
        val_enc   = encode_sequences(tokenizer_for_eval, val_src,   val_tgt,   MAX_SEQ_LENGTH)
        test_enc  = encode_sequences(tokenizer_for_eval, test_src,  test_tgt,  MAX_SEQ_LENGTH)

        full_train_dataset = SequencePairDataset(train_enc)
        full_val_dataset   = SequencePairDataset(val_enc)
        full_test_dataset  = SequencePairDataset(test_enc)

        print(f"[INFO] Creating GA evaluation subsets (Train: {GA_EVAL_TRAIN_SUBSET_SIZE}, Val: {GA_EVAL_VAL_SUBSET_SIZE})")
        train_subset_indices = random.sample(range(len(full_train_dataset)), min(GA_EVAL_TRAIN_SUBSET_SIZE, len(full_train_dataset)))
        val_subset_indices = random.sample(range(len(full_val_dataset)), min(GA_EVAL_VAL_SUBSET_SIZE, len(full_val_dataset)))

        def subset_encodings(enc, indices): return {key: [enc[key][i] for i in indices] for key in enc}

        global ga_train_dataset_subset, ga_val_dataset_subset
        ga_train_dataset_subset = SequencePairDataset(subset_encodings(train_enc, train_subset_indices))
        ga_val_dataset_subset = SequencePairDataset(subset_encodings(val_enc, val_subset_indices))

    except Exception as e: print(f"[FATAL] Data tokenization or subsetting failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 5. Setup Global Collator & Metrics Logic ---
    global data_collator_for_eval, compute_metrics_internal
    temp_model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID) # Need instance for collator setup
    data_collator_for_eval = DataCollatorForSeq2Seq(tokenizer_for_eval, model=temp_model, label_pad_token_id=-100, pad_to_multiple_of=8 if FINAL_USE_FP16 else None)
    del temp_model
    print("[INFO] Data collator prepared.")

    # Define the internal metric calculation logic once
    def compute_metrics_logic(eval_pred):
        """Internal logic for calculating sequence accuracy."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)
        labels = np.where(labels != -100, labels, tokenizer_for_eval.pad_token_id)
        try:
            decoded_preds = tokenizer_for_eval.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_eval.batch_decode(labels, skip_special_tokens=True)
            decoded_preds = [p.strip() for p in decoded_preds]
            decoded_labels = [l.strip() for l in decoded_labels]
            if len(decoded_preds) != len(decoded_labels): return {"sequence_accuracy": 0.0}
            matches = [p == l for p, l in zip(decoded_preds, decoded_labels)]
            accuracy = np.mean(matches) if matches else 0.0
            return {"sequence_accuracy": float(accuracy)}
        except Exception as dec_e:
             print(f"[Metrics Error] Decoding failed: {dec_e}", file=sys.stderr); return {"sequence_accuracy": 0.0}
    compute_metrics_internal = compute_metrics_logic # Assign to global var

    # --- 6. DEAP Genetic Algorithm Setup ---
    print("[INFO] Setting up Genetic Algorithm (DEAP)...")
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", list, fitness=creator.FitnessMax)
    toolbox = base.Toolbox()
    toolbox.register("attr_lr", random.uniform, LR_BOUND_LOW, LR_BOUND_HIGH)
    toolbox.register("attr_dropout", random.uniform, DROPOUT_BOUND_LOW, DROPOUT_BOUND_HIGH)
    toolbox.register("individual", tools.initCycle, creator.Individual, (toolbox.attr_lr, toolbox.attr_dropout), n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("mate", tools.cxTwoPoint)
    # Ensure mutation respects boundaries - DEAP's mutGaussian doesn't inherently
    # We handle clamping within the evaluate function, alternative is custom mutation
    toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=[MUT_SIGMA_LR, MUT_SIGMA_DROPOUT], indpb=0.2)
    toolbox.register("select", tools.selTournament, tournsize=TOURNAMENT_SIZE)
    toolbox.register("evaluate", evaluate_hyperparams)

    # --- 7. Run Genetic Algorithm ---
    print(f"[INFO] Starting GA optimization ({N_GENERATIONS} gens, Pop: {POPULATION_SIZE})...")
    population = toolbox.population(n=POPULATION_SIZE)
    hall_of_fame = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean); stats.register("std", np.std); stats.register("min", np.min); stats.register("max", np.max)

    try:
        population, logbook = algorithms.eaSimple(population, toolbox, cxpb=CX_PROB, mutpb=MUT_PROB, ngen=N_GENERATIONS, stats=stats, halloffame=hall_of_fame, verbose=True)
        ga_success = True
    except Exception as ga_e: print(f"[ERROR] GA execution failed: {ga_e}", file=sys.stderr); ga_success = False; logbook = None; best_individual = None

    if logbook:
        try:
            with open(GA_LOG_FILE, 'w') as f: json.dump(logbook, f, indent=2)
            print(f"[INFO] GA logbook saved to {GA_LOG_FILE}")
        except Exception as log_e: print(f"[Error] Failed to save GA logbook: {log_e}", file=sys.stderr)

    if not ga_success or not hall_of_fame: print("[FATAL] GA optimization failed or found no best individual. Exiting.", file=sys.stderr); sys.exit(1)

    best_individual = hall_of_fame[0]
    best_lr, best_dropout = best_individual[0], max(DROPOUT_BOUND_LOW, min(DROPOUT_BOUND_HIGH, best_individual[1])) # Clamp final best dropout
    best_fitness = best_individual.fitness.values[0]

    print("\n--- GA Optimization Complete ---")
    print(f"Best Individual: LR={best_lr:.3e}, Dropout={best_dropout:.4f}")
    print(f"Best Validation Accuracy (on subset): {best_fitness:.4f}")
    print("----------------------------------")

    best_hp = {'learning_rate': best_lr, 'dropout_rate': best_dropout, 'validation_accuracy': best_fitness}
    try:
        with open(BEST_HP_FILE, 'w') as f: json.dump(best_hp, f, indent=2)
        print(f"[INFO] Best hyperparameters saved to {BEST_HP_FILE}")
    except Exception as hp_e: print(f"[Error] Failed to save best hyperparameters: {hp_e}", file=sys.stderr)

    # --- 8. Final Training with Best Hyperparameters ---
    print("\n[INFO] Starting final model training using best hyperparameters found...")
    try:
        final_model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID)
        if final_model.config.encoder.vocab_size != len(tokenizer_for_eval):
            final_model.resize_token_embeddings(len(tokenizer_for_eval))
            final_model.config.encoder.vocab_size = len(tokenizer_for_eval)
            final_model.config.decoder.vocab_size = len(tokenizer_for_eval)

        final_model.config.dropout = best_dropout
        final_model.config.attention_dropout = best_dropout
        if hasattr(final_model.config, 'encoder'): final_model.config.encoder.dropout = best_dropout; final_model.config.encoder.attention_dropout = best_dropout
        if hasattr(final_model.config, 'decoder'): final_model.config.decoder.dropout = best_dropout; final_model.config.decoder.attention_dropout = best_dropout

        final_model.config.decoder_start_token_id = tokenizer_for_eval.cls_token_id
        final_model.config.eos_token_id = tokenizer_for_eval.sep_token_id
        final_model.config.pad_token_id = tokenizer_for_eval.pad_token_id
        final_model.config.max_length = MAX_SEQ_LENGTH
        final_model.config.num_beams = 4 # Use beam search for final evaluation
        final_model.tie_weights()
        # Optional final gradient checkpointing
        # if FINAL_USE_GRADIENT_CHECKPOINTING: final_model.gradient_checkpointing_enable()

    except Exception as model_e: print(f"[FATAL] Failed to initialize final model: {model_e}", file=sys.stderr); sys.exit(1)

    final_training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR / "final_model"),
        num_train_epochs=FINAL_TRAIN_EPOCHS,
        per_device_train_batch_size=FINAL_BATCH_SIZE,
        per_device_eval_batch_size=FINAL_BATCH_SIZE,
        learning_rate=best_lr, # Use best LR
        weight_decay=FINAL_WEIGHT_DECAY,
        warmup_steps=FINAL_WARMUP_STEPS,
        logging_dir=str(OUTPUT_DIR / "final_model" / 'logs'),
        logging_steps=FINAL_LOGGING_STEPS,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=FINAL_SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH,
        generation_num_beams=4,
        fp16=FINAL_USE_FP16,
        label_smoothing_factor=FINAL_LABEL_SMOOTHING,
        # gradient_checkpointing=FINAL_USE_GRADIENT_CHECKPOINTING, # Optional
        report_to="none",
    )

    final_trainer = Seq2SeqTrainer(
        model=final_model, args=final_training_args,
        train_dataset=full_train_dataset, eval_dataset=full_val_dataset,
        data_collator=data_collator_for_eval, tokenizer=tokenizer_for_eval,
        compute_metrics=compute_metrics_internal,
    )

    print(f"[INFO] Starting final training run ({FINAL_TRAIN_EPOCHS} epochs)...")
    try:
        train_result = final_trainer.train()
        final_trainer.log_metrics("final_train", train_result.metrics)
        final_trainer.save_metrics("final_train", train_result.metrics)
        final_trainer.save_state()
        print("[INFO] Final training complete.")
        if final_trainer.state.best_model_checkpoint: print(f"[INFO] Best final model checkpoint: {final_trainer.state.best_model_checkpoint}")
    except Exception as train_e: print(f"[FATAL] Final training failed: {train_e}", file=sys.stderr); sys.exit(1)

    # --- 9. Final Test Evaluation ---
    print("[INFO] Evaluating final best model on the full test set...")
    try:
        test_results = final_trainer.evaluate(eval_dataset=full_test_dataset, metric_key_prefix="test")
        final_trainer.log_metrics("test", test_results)
        final_trainer.save_metrics("test", test_results)
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Final Test Set Results ---"); print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}"); print(f"-----------------------------")
        else: print("[Warning] 'test_sequence_accuracy' not found.", "Full test results:", test_results)
    except Exception as test_e: print(f"[Error] Final evaluation failed: {test_e}", file=sys.stderr); sys.exit(1)

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    SEED = 42
    random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

ModuleNotFoundError: No module named 'deap'

# 3.6: Symbolic empirical representation of squared amplitudes in high-energy physics
Model: Transformer with novel approach for tokenization, data representation and/or preprocessing that leads to better performance than basic tokenization with normalized indices.


In [12]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tuning Script for a Depth-Aware Transformer for Symbolic Regression.

This script implements and trains a custom Transformer Encoder-Decoder model
designed for symbolic regression tasks, specifically predicting squared amplitudes
in High Energy Physics (HEP) from input amplitudes.

Key features:
- Custom BPE Tokenizer: Trains a Byte-Pair Encoding tokenizer from scratch on the
  combined input and target sequences of the dataset, potentially capturing
  domain-specific physics identifiers and operators more effectively.
- Depth-Aware Embeddings: Calculates the parenthesis nesting depth for each token
  in the input sequence and incorporates this information via dedicated depth
  embeddings, adding structural awareness to the model's input representation.
- Custom Transformer Model: Implements a standard Transformer encoder-decoder
  using PyTorch's `nn.TransformerEncoderLayer` and `nn.TransformerDecoderLayer`,
  modified to accept the combined token, position, and depth embeddings.
- Manual PyTorch Training Loop: Includes standard training, validation, and
  testing phases with loss calculation, optimization, metric computation (sequence
  accuracy), and basic checkpointing.

Workflow:
1. Load raw data splits from JSONL files.
2. Train a new BPE tokenizer on the corpus (if not already trained and saved)
   or load a previously trained tokenizer.
3. Define a custom `DepthDataset` that performs tokenization and calculates
   parenthesis depth for each input sequence during initialization.
4. Define the `DepthTransformer` model architecture using PyTorch's nn.Module.
5. Set up DataLoaders, optimizer, loss criterion, and device (GPU/CPU).
6. Implement the training loop, iterating through epochs and batches, performing
   forward/backward passes, and updating model weights.
7. Implement validation loop to monitor performance (sequence accuracy) and save
   the best model checkpoint based on validation results.
8. Implement the final test loop to evaluate the best model on the held-out
   test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizerFast # Using RobertaTokenizer for BPE training
from tqdm.auto import tqdm # Progress bars
# Optional: Learning rate scheduler
# from transformers import get_linear_schedule_with_warmup

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust if using different names
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('depth_aware_transformer_output')
TOKENIZER_SAVE_DIR = OUTPUT_DIR / "bpe_physics_tokenizer" # Directory to save/load trained tokenizer
CHECKPOINT_NAME = "depth_transformer_best.pt" # Name for the best model checkpoint

# Tokenizer Configuration
VOCAB_SIZE = 8000         # Target vocabulary size for BPE tokenizer
MIN_FREQUENCY = 2         # Minimum frequency for tokens in BPE training
# Special tokens common for sequence models
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

# Data and Model Configuration
MAX_SEQ_LENGTH = 128      # Max sequence length for tokenization and model
MAX_DEPTH = 50            # Maximum expected parenthesis nesting depth + buffer

# Model Hyperparameters (DepthTransformer)
D_MODEL = 512             # Embedding and hidden state dimension
N_HEAD = 8                # Number of attention heads
NUM_ENCODER_LAYERS = 6    # Number of layers in the Transformer encoder
NUM_DECODER_LAYERS = 6    # Number of layers in the Transformer decoder
DIM_FEEDFORWARD = 2048    # Dimension of the feed-forward networks
DROPOUT = 0.1             # Dropout rate (can be added to model layers)

# Training Hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 16           # Adjust based on GPU memory
LEARNING_RATE = 5e-4
OPTIMIZER_EPS = 1e-8
WEIGHT_DECAY = 0.01
# LR_SCHEDULER_TYPE = "linear" # Optional scheduler type
# WARMUP_RATIO = 0.1         # Optional warmup ratio
LOG_INTERVAL = 50         # Log training stats every N batches

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    # (Same implementation as previous script)
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def train_or_load_tokenizer(corpus_iterator, save_dir, vocab_size, min_freq, special_tokens):
    """Trains a new BPE tokenizer or loads it if it already exists."""
    save_dir = Path(save_dir)
    config_file = save_dir / "tokenizer_config.json"

    if save_dir.exists() and config_file.exists():
        print(f"[INFO] Loading existing tokenizer from {save_dir}")
        try:
            tokenizer = RobertaTokenizerFast.from_pretrained(str(save_dir))
            # Verify core special tokens are present
            if any(tok not in tokenizer.get_vocab() for tok in ['<s>', '<pad>', '</s>', '<unk>']):
                 print("[Warning] Loaded tokenizer missing some standard special tokens. Recheck config.")
            return tokenizer
        except Exception as e:
            print(f"[Warning] Failed to load existing tokenizer: {e}. Attempting to retrain.", file=sys.stderr)

    print(f"[INFO] Training new BPE tokenizer (Vocab: {vocab_size}, Min Freq: {min_freq})...")
    try:
        # Initialize a base tokenizer (like Roberta's structure) but don't load weights
        # We use RobertaTokenizerFast as it supports training from iterator
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") # Base structure
        print("[INFO] Training from iterator...")
        # The train_new_from_iterator method might not exist directly on RobertaTokenizerFast.
        # Often, you use the underlying `tokenizers` library for this.
        # Let's adapt using the `tokenizers` library approach if direct method fails.

        try:
             # Attempt direct method if available in specific transformers version
             tokenizer.train_new_from_iterator(corpus_iterator, vocab_size=vocab_size, min_frequency=min_freq, special_tokens=special_tokens)
        except AttributeError:
             print("[INFO] `train_new_from_iterator` not found, using `tokenizers` library directly.")
             from tokenizers import ByteLevelBPETokenizer as TokenizersBPETokenizer

             # Initialize BPE tokenizer from the `tokenizers` library
             tk_lib_tokenizer = TokenizersBPETokenizer()
             tk_lib_tokenizer.train_from_iterator(
                 corpus_iterator,
                 vocab_size=vocab_size,
                 min_frequency=min_freq,
                 special_tokens=special_tokens
             )
             # Need to save this intermediate tokenizer and then load it into RobertaTokenizerFast
             temp_save_path = save_dir / "temp_tokenizer_files"
             temp_save_path.mkdir(parents=True, exist_ok=True)
             tk_lib_tokenizer.save_model(str(temp_save_path)) # Saves vocab.json and merges.txt

             # Load the trained vocab/merges into the desired HF Tokenizer class
             tokenizer = RobertaTokenizerFast.from_pretrained(str(temp_save_path), max_len=MAX_SEQ_LENGTH) # Or appropriate max_len
             # Clean up temp files
             import shutil; shutil.rmtree(temp_save_path)


        print("[INFO] Tokenizer training complete.")
        save_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(str(save_dir))
        print(f"[INFO] Tokenizer saved to {save_dir}")
        return tokenizer
    except ImportError as e:
         print(f"[Error] Required library (`tokenizers`) not found for training: {e}. Please install it.", file=sys.stderr)
         raise
    except Exception as e:
        print(f"[Error] Tokenizer training failed: {e}", file=sys.stderr)
        raise


class DepthDataset(Dataset):
    """
    Dataset that tokenizes sequences and calculates parenthesis nesting depth.

    Tokenization and depth calculation are performed once during initialization.
    Stores token IDs, attention masks, depth IDs, and label IDs.

    Note: Stores all processed examples in memory. May be unsuitable for
          extremely large datasets without modification (e.g., lazy processing
          or memory mapping).
    """
    def __init__(self, raw_data, tokenizer, max_len=128, max_depth=50):
        """
        Args:
            raw_data (list[dict]): List of dictionaries, each with 'input_tokens'
                                   and 'target_tokens' keys containing lists of strings.
            tokenizer: Initialized Hugging Face tokenizer instance.
            max_len (int): Maximum sequence length for padding/truncation.
            max_depth (int): Maximum depth value to cap at; also determines depth embedding size.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_depth = max_depth
        self.examples = [] # List to store processed examples

        print(f"[INFO] Processing dataset for DepthDataset (Max Len: {max_len}, Max Depth: {max_depth})...")
        skipped_count = 0
        for i, ex in enumerate(tqdm(raw_data, desc="Processing Raw Data")):
            input_tokens = ex.get("input_tokens")
            target_tokens = ex.get("target_tokens")

            if not isinstance(input_tokens, list) or not isinstance(target_tokens, list):
                print(f"[Warning] Skipping example {i} due to missing or invalid token lists.")
                skipped_count += 1
                continue

            # Join tokens for BPE tokenizer
            src_str = " ".join(map(str, input_tokens))
            tgt_str = " ".join(map(str, target_tokens))

            # Tokenize source and target
            # Note: We don't use return_tensors='pt' here; store lists/ints
            src_enc = tokenizer(src_str, max_length=max_len, truncation=True, padding="max_length", return_tensors=None)
            # Use context manager for target tokenization (good practice)
            with tokenizer.as_target_tokenizer():
                tgt_enc = tokenizer(tgt_str, max_length=max_len, truncation=True, padding="max_length", return_tensors=None)

            # --- Calculate Parenthesis Depth ---
            # Operates on the *original* input_tokens before BPE tokenization
            current_depth = 0
            depth_sequence = []
            for tok in input_tokens:
                # Assign depth *before* potential change for closing parenthesis
                if tok == ")":
                    current_depth = max(0, current_depth - 1) # Decrement after assigning for ')'
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Cap depth
                elif tok == "(":
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Cap depth
                    current_depth += 1 # Increment after assigning for '('
                else:
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Assign current depth

            # --- Align Depth with Tokenized Sequence (Simple Approach) ---
            # This is a simplification. A robust alignment would map original token
            # positions to BPE token positions. Here, we just pad/truncate the
            # depth sequence calculated on original tokens to match the BPE sequence length.
            # This might misalign depth for tokens split by BPE.
            # TODO: Implement a more robust alignment if needed.
            if len(depth_sequence) >= max_len:
                aligned_depth_ids = depth_sequence[:max_len]
            else:
                # Pad with depth 0 (assuming root level)
                aligned_depth_ids = depth_sequence + [0] * (max_len - len(depth_sequence))

            self.examples.append({
                "input_ids":      src_enc["input_ids"],
                "attention_mask": src_enc["attention_mask"],
                "depth_ids":      aligned_depth_ids,
                "labels":         tgt_enc["input_ids"], # Target token IDs
            })

        if skipped_count > 0:
            print(f"[Warning] Skipped {skipped_count} examples during dataset processing.")
        if not self.examples:
             raise ValueError("Dataset processing resulted in zero valid examples.")
        print(f"[INFO] DepthDataset processed. Total examples: {len(self.examples)}")

    def __len__(self):
        """Returns the total number of processed examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """
        Retrieves a processed example by index.
        Returns data as lists/ints; DataLoader handles tensor conversion.
        """
        if not 0 <= idx < len(self.examples):
            raise IndexError(f"Index {idx} out of bounds.")
        return self.examples[idx]


class DepthTransformer(nn.Module):
    """
    Transformer Encoder-Decoder model incorporating token depth embeddings.

    Input representation combines token, position, and parenthesis depth embeddings.
    Uses standard PyTorch Transformer layers for sequence processing.
    """
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, max_len=128,
                 max_depth=50, dropout=0.1, pad_token_id=None):
        """
        Initializes the Depth-Aware Transformer.

        Args:
            vocab_size (int): Size of the vocabulary.
            d_model (int): Dimension of embeddings and hidden states.
            nhead (int): Number of attention heads.
            num_encoder_layers (int): Number of layers in the encoder stack.
            num_decoder_layers (int): Number of layers in the decoder stack.
            dim_feedforward (int): Dimension of the feed-forward networks.
            max_len (int): Maximum sequence length for positional embeddings.
            max_depth (int): Maximum depth value + 1 (size of depth embedding table).
            dropout (float): Dropout rate.
            pad_token_id (int): ID of the padding token for embeddings.
        """
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id if pad_token_id is not None else 0 # Default assumption

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=self.pad_token_id)
        self.positional_embedding = nn.Embedding(max_len, d_model) # Absolute positional
        self.depth_embedding = nn.Embedding(max_depth, d_model) # Depth information

        self.embedding_dropout = nn.Dropout(dropout)

        # Standard PyTorch Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Standard PyTorch Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        # Output Projection Layer
        self.output_projection = nn.Linear(d_model, vocab_size)

        print(f"[INFO] Initialized DepthTransformer:")
        print(f"  - Vocab Size: {vocab_size}, d_model: {d_model}, Heads: {nhead}")
        print(f"  - Max Len: {max_len}, Max Depth: {max_depth}")
        print(f"  - Enc Layers: {num_encoder_layers}, Dec Layers: {num_decoder_layers}")
        print(f"  - Feedforward Dim: {dim_feedforward}, Dropout: {dropout}")


    def forward(self, src_input_ids, src_attention_mask, src_depth_ids, tgt_input_ids):
        """
        Performs the forward pass through the depth-aware Transformer.

        Args:
            src_input_ids (torch.Tensor): Source sequence token IDs (batch_size, src_seq_len).
            src_attention_mask (torch.Tensor): Source sequence attention mask (batch_size, src_seq_len).
                                               (1 for real tokens, 0 for padding).
            src_depth_ids (torch.Tensor): Source sequence depth IDs (batch_size, src_seq_len).
            tgt_input_ids (torch.Tensor): Target sequence token IDs (batch_size, tgt_seq_len),
                                          typically shifted right for teacher forcing.

        Returns:
            torch.Tensor: Output logits (batch_size, tgt_seq_len, vocab_size).
        """
        batch_size, src_seq_len = src_input_ids.shape
        _, tgt_seq_len = tgt_input_ids.shape

        # 1. Source Embeddings (Token + Position + Depth)
        # Create position IDs (0 to seq_len-1)
        src_pos_ids = torch.arange(src_seq_len, device=src_input_ids.device).unsqueeze(0).expand(batch_size, -1)

        src_emb = self.token_embedding(src_input_ids)           # (B, S, D)
        src_pos = self.positional_embedding(src_pos_ids)       # (B, S, D)
        src_dep = self.depth_embedding(src_depth_ids)         # (B, S, D)

        src_combined_emb = src_emb + src_pos + src_dep # Combine embeddings
        src_combined_emb = self.embedding_dropout(src_combined_emb)

        # 2. Encoder
        # PyTorch Transformer layers expect masks where True indicates a position *not* to attend to.
        # src_attention_mask is (B, S) with 1 for non-pad, 0 for pad. Need (B, S) with True for pad.
        src_key_padding_mask = (src_attention_mask == 0) # True where padded

        # Pass through encoder stack
        memory = self.encoder(src_combined_emb, src_key_padding_mask=src_key_padding_mask) # (B, S, D)

        # 3. Target Embeddings (Token + Position) - No depth for target typically
        # Target IDs are usually shifted right for teacher forcing (e.g., start with BOS, end before EOS)
        tgt_pos_ids = torch.arange(tgt_seq_len, device=tgt_input_ids.device).unsqueeze(0).expand(batch_size, -1)

        tgt_emb = self.token_embedding(tgt_input_ids)           # (B, T, D)
        tgt_pos = self.positional_embedding(tgt_pos_ids)       # (B, T, D)

        tgt_combined_emb = tgt_emb + tgt_pos # Combine target embeddings
        tgt_combined_emb = self.embedding_dropout(tgt_combined_emb)

        # 4. Decoder
        # Create target padding mask (True for padded positions)
        # Assuming target uses same pad token id
        # Need a mask for the target sequence itself
        tgt_key_padding_mask = (tgt_input_ids == self.pad_token_id)

        # Create causal mask (autoregressive mask) to prevent attending to future tokens
        # Shape: (T, T)
        tgt_causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_seq_len, device=tgt_input_ids.device)

        # Pass through decoder stack
        decoder_output = self.decoder(
            tgt=tgt_combined_emb,           # Target sequence embeddings (B, T, D)
            memory=memory,                  # Encoder output (B, S, D)
            tgt_mask=tgt_causal_mask,       # Causal mask (T, T)
            tgt_key_padding_mask=tgt_key_padding_mask, # Target padding mask (B, T)
            memory_key_padding_mask=src_key_padding_mask # Source padding mask (B, S)
        ) # Output shape: (B, T, D)

        # 5. Output Projection
        logits = self.output_projection(decoder_output) # (B, T, VocabSize)

        return logits


def calculate_sequence_accuracy(logits, labels, pad_token_id):
    """Calculates exact sequence match accuracy using PyTorch tensors."""
    if logits.shape[0] == 0: return 0.0
    predictions = torch.argmax(logits, dim=-1)
    non_pad_mask = (labels != pad_token_id)
    correct_tokens = (predictions == labels) & non_pad_mask
    correct_sequences = (torch.sum(correct_tokens, dim=1) == torch.sum(non_pad_mask, dim=1))
    accuracy = torch.mean(correct_sequences.float())
    return accuracy.item()


# --- Main Execution ---

def main():
    """Orchestrates tokenizer training, data processing, model training, and evaluation."""
    print("[INFO] Starting Depth-Aware Transformer Script...")
    try: current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception: current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    # --- ***** Update Timestamp/Location ***** ---
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Raw Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Corpus and Train/Load Tokenizer ---
    def corpus_gen(): # Generator to avoid loading all strings into memory at once
        print("[INFO] Preparing corpus for tokenizer training...")
        count = 0
        for split in (train_raw, val_raw, test_raw):
            for ex in split:
                if ex.get("input_tokens"): yield " ".join(map(str, ex["input_tokens"]))
                if ex.get("target_tokens"): yield " ".join(map(str, ex["target_tokens"]))
                count += 2
        print(f"[INFO] Corpus generator ready ({count} sequences).")

    try:
        tokenizer = train_or_load_tokenizer(
            corpus_iterator=corpus_gen(),
            save_dir=TOKENIZER_SAVE_DIR,
            vocab_size=VOCAB_SIZE,
            min_freq=MIN_FREQUENCY,
            special_tokens=SPECIAL_TOKENS
        )
        vocab_size = len(tokenizer) # Get actual vocab size after training/loading
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
             print("[Error] Tokenizer loaded without a pad token ID!", file=sys.stderr); sys.exit(1)

    except Exception as e:
        print(f"[FATAL] Tokenizer training or loading failed: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"[INFO] Tokenizer ready (Vocab size: {vocab_size}, Pad ID: {pad_token_id}).")


    # --- 3. Create Datasets ---
    try:
        print("[INFO] Creating DepthDatasets...")
        train_dataset = DepthDataset(train_raw, tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
        val_dataset   = DepthDataset(val_raw,   tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
        test_dataset  = DepthDataset(test_raw,  tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
    except Exception as e:
        print(f"[FATAL] Failed to create datasets: {e}", file=sys.stderr)
        sys.exit(1)
    if not train_dataset or not val_dataset or not test_dataset:
         print("[FATAL] One or more datasets are empty after processing. Exiting.", file=sys.stderr)
         sys.exit(1)


    # --- 4. Create DataLoaders ---
    # Define a simple collate function to handle batching of list data into tensors
    def collate_batch(batch):
        elem = batch[0]
        collated = {}
        for key in elem.keys():
            collated[key] = torch.tensor([d[key] for d in batch])
        return collated

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    print("[INFO] DataLoaders created.")


    # --- 5. Initialize Model, Optimizer, Criterion ---
    print(f"[INFO] Initializing DepthTransformer model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    try:
        model = DepthTransformer(
            vocab_size=vocab_size,
            d_model=D_MODEL,
            nhead=N_HEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            dim_feedforward=DIM_FEEDFORWARD,
            max_len=MAX_SEQ_LENGTH,
            max_depth=MAX_DEPTH,
            dropout=DROPOUT,
            pad_token_id=pad_token_id
        ).to(device)
    except Exception as e:
        print(f"[FATAL] Failed to initialize model: {e}", file=sys.stderr)
        sys.exit(1)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        eps=OPTIMIZER_EPS,
        weight_decay=WEIGHT_DECAY
    )
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id) # Ignore padding in loss
    print("[INFO] Model, Optimizer, and Criterion initialized.")

    # Optional: Learning Rate Scheduler
    # total_steps = len(train_loader) * NUM_EPOCHS
    # warmup_steps = int(total_steps * WARMUP_RATIO)
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)


    # --- 6. Training Loop ---
    print(f"[INFO] Starting training for {NUM_EPOCHS} epochs...")
    best_val_accuracy = -1.0
    best_epoch = -1
    checkpoint_path = OUTPUT_DIR / CHECKPOINT_NAME

    for epoch in range(NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
        model.train()
        total_train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)

        for step, batch in enumerate(train_pbar):
            # Move batch data to the device
            src_ids = batch["input_ids"].to(device)
            src_mask = batch["attention_mask"].to(device) # Encoder padding mask source
            depth_ids = batch["depth_ids"].to(device)
            labels = batch["labels"].to(device)          # Target sequence (B, T)

            # Prepare decoder inputs (shifted right) and targets
            # Decoder input: BOS token + sequence (excluding last token)
            # Decoder target: sequence (excluding first token) + EOS token (implicitly handled by loss shifting)
            # For CrossEntropyLoss, logits (B, T, V) should align with targets (B, T)
            tgt_input = labels[:, :-1] # (B, T-1), excludes last token
            tgt_output = labels[:, 1:]  # (B, T-1), excludes first token (BOS/CLS)

            optimizer.zero_grad()

            # Forward pass
            logits = model(src_input_ids=src_ids,
                           src_attention_mask=src_mask,
                           src_depth_ids=depth_ids,
                           tgt_input_ids=tgt_input) # Pass shifted input (B, T-1, V)

            # Calculate loss - compare logits with shifted *output* targets
            loss = criterion(logits.reshape(-1, vocab_size), # (B*(T-1), V)
                             tgt_output.reshape(-1))         # (B*(T-1))

            # Backward pass and optimize
            loss.backward()
            # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # if scheduler: scheduler.step() # Update LR if scheduler is used

            total_train_loss += loss.item()

            # Log training progress
            if (step + 1) % LOG_INTERVAL == 0:
                avg_loss = total_train_loss / LOG_INTERVAL
                # current_lr = scheduler.get_last_lr()[0] if scheduler else LEARNING_RATE
                current_lr = LEARNING_RATE # Simpler if no scheduler
                train_pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}', 'LR': f'{current_lr:.2e}'})
                total_train_loss = 0.0 # Reset accumulator

        train_pbar.close()

        # --- Validation Phase ---
        model.eval()
        total_val_accuracy = 0.0
        total_val_samples = 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} Validation", leave=False)

        with torch.no_grad():
            for batch in val_pbar:
                src_ids = batch["input_ids"].to(device)
                src_mask = batch["attention_mask"].to(device)
                depth_ids = batch["depth_ids"].to(device)
                labels = batch["labels"].to(device)
                tgt_input = labels[:, :-1]
                tgt_output = labels[:, 1:]

                logits = model(src_ids, src_mask, depth_ids, tgt_input)

                batch_accuracy = calculate_sequence_accuracy(logits, tgt_output, pad_token_id)
                total_val_accuracy += batch_accuracy * src_ids.size(0) # Weighted by batch size
                total_val_samples += src_ids.size(0)

        val_pbar.close()
        epoch_val_accuracy = total_val_accuracy / total_val_samples if total_val_samples > 0 else 0.0
        print(f"Epoch {epoch+1} Validation Sequence Accuracy: {epoch_val_accuracy:.4f}")

        # --- Save Best Model Checkpoint ---
        if epoch_val_accuracy > best_val_accuracy:
            best_val_accuracy = epoch_val_accuracy
            best_epoch = epoch + 1
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"[INFO] New best model saved to {checkpoint_path} (Epoch {best_epoch}, Val Acc: {best_val_accuracy:.4f})")
            except Exception as e:
                print(f"[Error] Failed to save model checkpoint: {e}", file=sys.stderr)

    print(f"\n[INFO] Training complete. Best validation accuracy: {best_val_accuracy:.4f} at epoch {best_epoch}.")

    # --- 7. Final Test Evaluation ---
    print("\n--- Final Test Set Evaluation ---")
    if checkpoint_path.exists():
        print(f"[INFO] Loading best model from: {checkpoint_path}")
        try:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        except Exception as e:
             print(f"[Error] Failed to load best model checkpoint: {e}. Evaluating with the final state.", file=sys.stderr)
    else:
        print("[Warning] Best model checkpoint not found. Evaluating with the final model state.", file=sys.stderr)

    model.eval()
    total_test_accuracy = 0.0
    total_test_samples = 0
    test_pbar = tqdm(test_loader, desc="Testing", leave=False)

    with torch.no_grad():
        for batch in test_pbar:
            src_ids = batch["input_ids"].to(device)
            src_mask = batch["attention_mask"].to(device)
            depth_ids = batch["depth_ids"].to(device)
            labels = batch["labels"].to(device)
            tgt_input = labels[:, :-1]
            tgt_output = labels[:, 1:]

            logits = model(src_ids, src_mask, depth_ids, tgt_input)

            batch_accuracy = calculate_sequence_accuracy(logits, tgt_output, pad_token_id)
            total_test_accuracy += batch_accuracy * src_ids.size(0)
            total_test_samples += src_ids.size(0)

    test_pbar.close()
    final_test_accuracy = total_test_accuracy / total_test_samples if total_test_samples > 0 else 0.0
    print(f"\nFinal Test Sequence Accuracy: {final_test_accuracy:.4f}")

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    # Optional: Seed everything
    # SEED = 42
    # torch.manual_seed(SEED); np.random.seed(SEED); import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

[INFO] Starting Depth-Aware Transformer Script...
[INFO] Current time: 2025-04-08 17:47:44 UTC
[INFO] Location context: San Diego, CA, USA
[INFO] Loading data from: qed_expressions_train.jsonl
[INFO] Successfully loaded 12441 records.
[INFO] Loading data from: qed_expressions_val.jsonl
[INFO] Successfully loaded 1555 records.
[INFO] Loading data from: qed_expressions_test.jsonl
[INFO] Successfully loaded 1556 records.
[INFO] Training new BPE tokenizer (Vocab: 8000, Min Freq: 2)...
[INFO] Training from iterator...


[Error] Tokenizer training failed: tokenizers.trainers.BpeTrainer() got multiple values for keyword argument 'special_tokens'
[FATAL] Tokenizer training or loading failed: tokenizers.trainers.BpeTrainer() got multiple values for keyword argument 'special_tokens'


SystemExit: 1

# 3.7: Foundation models for symbolic regression tasks
Model: Novel foundation model for symbolic regression tasks. Should be sufficiently novel beyond the current literature.


In [13]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tuning Script for a T5 Model with Mixture-of-Experts (MoE) Layers
for Symbolic Regression (Squared Amplitude Calculation).

This script demonstrates modifying a standard T5 (`t5-small`) model by replacing
its Feed-Forward Networks (FFNs) with custom Mixture-of-Experts (MoE) layers.
The MoE layers potentially allow different "expert" networks within the model to
specialize in processing different types of symbolic patterns found in the input
sequences (e.g., polynomials, trigonometric functions).

The script utilizes the Hugging Face Transformers library, including the
`Seq2SeqTrainer` API for streamlined training and evaluation.

Key features demonstrated:
- T5 Base Model: Uses `t5-small` as the starting point.
- Custom MoE Layer: Defines and integrates a `MoEFeedForward` module.
- Model Patching: Dynamically replaces the standard FFNs in the loaded T5 model
  with the custom MoE layers.
- Advanced Training Techniques: Incorporates Gradient Checkpointing, Mixed
  Precision Training (FP16), and Label Smoothing.
- Standard Workflow: Follows data loading, preprocessing, tokenization, model
  configuration, training, and evaluation steps.
- Exact Match Accuracy (Logit-based): Evaluates performance using sequence
  accuracy calculated directly from model logits (`predict_with_generate=False`).

Workflow:
1. Load pre-split data from JSONL files.
2. Reconstruct source and target strings from token lists.
3. Initialize the T5 tokenizer.
4. Define the custom `MoEFeedForward` nn.Module.
5. Load the base T5 model and patch its encoder/decoder blocks by replacing
   standard FFN layers with instances of `MoEFeedForward`.
6. Configure the modified model (special tokens, gradient checkpointing).
7. Define a function to tokenize source/target string pairs.
8. Wrap tokenized data into PyTorch Dataset objects.
9. Instantiate `DataCollatorForSeq2Seq`.
10. Define the `compute_metrics` function for sequence accuracy from logits.
11. Configure `Seq2SeqTrainingArguments`.
12. Initialize and run the `Seq2SeqTrainer`.
13. Evaluate the final model on the test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    T5Config, # To inspect T5 block structure if needed
    T5Block, # To check layer types
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import
)
from tqdm.auto import tqdm # Optional progress bars

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust if needed
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('t5_moe_symbolic_regression_output') # Output directory

# Model Configuration
MODEL_ID = "t5-small" # Base T5 model identifier
TOKENIZER_ID = "t5-small"

# MoE Configuration
NUM_EXPERTS = 4             # Number of experts in each MoE layer
# d_ff (intermediate FFN dimension) will be taken from the base T5 config

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 128        # Max sequence length

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 200
LOGGING_STEPS = 50
EVALUATION_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available()
USE_GRADIENT_CHECKPOINTING = True
LABEL_SMOOTHING_FACTOR = 0.1

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings, target_strings, skipped_count = [], [], 0
    if not raw_data_list: return source_strings, target_strings
    for i, item in enumerate(raw_data_list):
        input_toks, target_toks = item.get('input_tokens'), item.get('target_tokens')
        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid data: {item}"); skipped_count += 1
    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings: print("[Warning] No items successfully converted.")
    return source_strings, target_strings

def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """Tokenizes source and target sequences, returning lists of IDs/masks."""
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")
    encoder_inputs = tokenizer(source_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(target_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    encoder_inputs['labels'] = decoder_labels['input_ids']
    print(f"[INFO] Encoding complete.")
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels') or len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Encoding resulted in empty lists or length mismatch.")
    return encoder_inputs

class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        if not isinstance(encodings, dict) or 'input_ids' not in encodings: raise ValueError("Invalid encodings format.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                 if not isinstance(encodings[key], list) or len(encodings[key]) != self.length: raise ValueError(f"Inconsistent length for key '{key}'.")
        except Exception as e: raise ValueError(f"Validation failed: {e}")
        if self.length == 0: raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self): return self.length
    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds.")
        try: return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e: print(f"[Error] Failed retrieval at index {idx}: {e}", file=sys.stderr); raise

# --- Custom MoE Layer ---

class MoEFeedForward(nn.Module):
    """
    Mixture of Experts Feed-Forward Network Layer.

    Routes input tokens to different 'expert' FFNs based on a gating mechanism.
    The gating mechanism here uses the mean-pooled representation of the sequence
    to decide the weights for combining expert outputs.
    """
    def __init__(self, d_model, d_ff, n_experts=4, dropout_rate=0.1):
        """
        Initializes the MoE layer.

        Args:
            d_model (int): Input and output dimension of the layer.
            d_ff (int): Intermediate dimension of the expert FFNs.
            n_experts (int): Number of expert networks.
            dropout_rate (float): Dropout rate for expert FFNs.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_experts = n_experts

        # Define the expert networks (simple two-layer MLPs)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(), # Or GeLU, SiLU etc. matching base model if known
                nn.Dropout(dropout_rate), # Add dropout within experts
                nn.Linear(d_ff, d_model)
            ) for _ in range(n_experts)
        ])

        # Gating network: maps pooled sequence representation to expert weights
        # Takes mean across sequence length dimension as input
        self.gate = nn.Linear(d_model, n_experts)

    def forward(self, hidden_states):
        """
        Forward pass through the MoE layer.

        Args:
            hidden_states (torch.Tensor): Input tensor (batch_size, seq_len, d_model).

        Returns:
            torch.Tensor: Output tensor (batch_size, seq_len, d_model).
        """
        # hidden_states: (B, S, D)
        batch_size, seq_len, d_model = hidden_states.shape

        # 1. Gating: Use mean pooling for simplicity to get sequence-level weights
        # Alternative: Could implement token-level gating for finer control.
        pooled_hidden_states = hidden_states.mean(dim=1) # (B, D)

        # Calculate gating scores (logits) and normalize to probabilities (weights)
        gating_logits = self.gate(pooled_hidden_states) # (B, N_experts)
        gating_weights = torch.softmax(gating_logits, dim=-1) # (B, N_experts)

        # 2. Expert Evaluation
        # Pass the full sequence through each expert
        expert_outputs = []
        for expert in self.experts:
            expert_outputs.append(expert(hidden_states))
        # Stack expert outputs: List[ (B, S, D) ] -> Tensor(N_experts, B, S, D)
        expert_outputs_stacked = torch.stack(expert_outputs, dim=0)

        # 3. Weighted Combination
        # Reshape weights for broadcasting: (B, N_experts) -> (B, N_experts, 1, 1)
        gating_weights_reshaped = gating_weights.unsqueeze(-1).unsqueeze(-1)
        # Reshape expert outputs for broadcasting: (N_experts, B, S, D) -> (B, N_experts, S, D)
        expert_outputs_permuted = expert_outputs_stacked.permute(1, 0, 2, 3)

        # Weighted sum: (B, N_experts, 1, 1) * (B, N_experts, S, D) -> sum over N_experts axis
        # Result: (B, S, D)
        output = torch.sum(gating_weights_reshaped * expert_outputs_permuted, dim=1)

        return output


def replace_ffn_with_moe(model, n_experts):
    """
    Replaces the standard FFN layers in a T5 model with MoEFeedForward layers.

    Iterates through encoder and decoder blocks and replaces the appropriate layer.
    Requires knowledge of the T5Block structure.

    Args:
        model: The T5ForConditionalGeneration model instance.
        n_experts (int): Number of experts for the MoE layers.

    Returns:
        The modified model instance.
    """
    print(f"[INFO] Patching T5 model with MoE layers (Experts: {n_experts})...")
    replaced_count = 0
    try:
        # Encoder blocks
        if hasattr(model, 'encoder') and hasattr(model.encoder, 'block'):
            for i, block in enumerate(model.encoder.block):
                # T5Block structure: [LayerNorm, SelfAttention, Dropout], [LayerNorm, DenseReluDense, Dropout]
                # The FFN (`DenseReluDense`) is typically the second layer in the second sub-block group.
                # Accessing via attribute name is safer if known, e.g., block.layer[1].DenseReluDense
                # Let's try accessing the assumed attribute location and verify type.
                ffn_layer_attr = None
                if hasattr(block, 'layer') and len(block.layer) > 1 and \
                   hasattr(block.layer[1], 'DenseReluDense') and \
                   isinstance(block.layer[1].DenseReluDense, nn.Module): # Check if it looks like the FFN component
                    ffn_layer_attr = block.layer[1].DenseReluDense
                    layer_target = block.layer[1]
                    attr_name = 'DenseReluDense'

                if ffn_layer_attr is not None:
                    print(f"  - Replacing Encoder Block {i} FFN...")
                    # Extract config needed for MoE layer
                    d_model = model.config.d_model
                    d_ff = model.config.d_ff
                    dropout_rate = model.config.dropout_rate # Get dropout from config

                    # Create and assign the new MoE layer
                    setattr(layer_target, attr_name, MoEFeedForward(d_model, d_ff, n_experts, dropout_rate))
                    replaced_count += 1
                else:
                     print(f"  - [Warning] Could not find/replace FFN in Encoder Block {i} structure.")

        # Decoder blocks
        if hasattr(model, 'decoder') and hasattr(model.decoder, 'block'):
            for i, block in enumerate(model.decoder.block):
                # T5DecoderBlock structure: [LN, SelfAttn, Drop], [LN, CrossAttn, Drop], [LN, FFN(DenseReluDense), Drop]
                # The FFN is typically the second layer in the *third* sub-block group.
                ffn_layer_attr = None
                if hasattr(block, 'layer') and len(block.layer) > 2 and \
                   hasattr(block.layer[2], 'DenseReluDense') and \
                   isinstance(block.layer[2].DenseReluDense, nn.Module):
                    ffn_layer_attr = block.layer[2].DenseReluDense
                    layer_target = block.layer[2]
                    attr_name = 'DenseReluDense'

                if ffn_layer_attr is not None:
                    print(f"  - Replacing Decoder Block {i} FFN...")
                    d_model = model.config.d_model
                    d_ff = model.config.d_ff
                    dropout_rate = model.config.dropout_rate

                    setattr(layer_target, attr_name, MoEFeedForward(d_model, d_ff, n_experts, dropout_rate))
                    replaced_count += 1
                else:
                    print(f"  - [Warning] Could not find/replace FFN in Decoder Block {i} structure.")

        if replaced_count == 0:
             print("[Warning] No FFN layers were replaced. Check model structure and replacement logic.")
        else:
             print(f"[INFO] Successfully replaced {replaced_count} FFN layers with MoE.")

    except Exception as e:
        print(f"[Error] Failed during model patching: {e}", file=sys.stderr)
        # Depending on severity, may want to raise or exit
    return model


# --- Main Execution Logic ---

def main():
    """Orchestrates data loading, MoE model patching, training, and evaluation."""
    print("[INFO] Starting T5 with Mixture-of-Experts Fine-tuning Script...")
    # --- ***** Update Timestamp/Location ***** ---
    try:
        current_time = datetime.datetime.now(datetime.timezone.utc)
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S %Z')
    except Exception:
        current_time = datetime.datetime.utcnow()
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")
    print(f"[INFO] Using Base Model: {MODEL_ID}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Data Strings ---
    train_src, train_tgt = convert_tokens_to_strings(train_raw)
    val_src,   val_tgt   = convert_tokens_to_strings(val_raw)
    test_src,  test_tgt  = convert_tokens_to_strings(test_raw)
    if not train_src or not val_src or not test_src: print("[FATAL] Data conversion failed.", file=sys.stderr); sys.exit(1)

    # --- 3. Initialize Tokenizer ---
    global tokenizer_for_metrics # For metrics function access
    try:
        tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_ID)
        # T5 generally uses <pad>, </s>, <unk> - check if needed
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '<pad>'})
        # T5 uses EOS as decoder start token ID by default, check if BOS needed
        # if tokenizer.bos_token is None: tokenizer.add_special_tokens({'bos_token': '<s>'})
        tokenizer_for_metrics = tokenizer
        pad_token_id = tokenizer.pad_token_id
    except Exception as e: print(f"[FATAL] Tokenizer init failed: {e}", file=sys.stderr); sys.exit(1)
    print(f"[INFO] Tokenizer initialized (Vocab: {tokenizer.vocab_size}, Pad ID: {pad_token_id}).")

    # --- 4. Load Base Model and Patch with MoE ---
    print(f"[INFO] Loading base model: {MODEL_ID}")
    try:
        model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)

        # Patch the loaded model
        model = replace_ffn_with_moe(model, n_experts=NUM_EXPERTS)

        # Resize embeddings if vocab changed
        if model.config.vocab_size != len(tokenizer):
             print(f"[INFO] Resizing model embeddings to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             model.config.vocab_size = len(tokenizer)

        # Configure standard seq2seq settings (T5 specific)
        model.config.decoder_start_token_id = tokenizer.pad_token_id # T5 starts decoder with pad
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        # Configure generation defaults
        model.config.max_length = MAX_SEQ_LENGTH
        model.config.num_beams = 1 # Default to greedy for logits; beams set in TrainingArgs if needed

        # Enable Gradient Checkpointing if configured
        if USE_GRADIENT_CHECKPOINTING:
            try:
                 model.gradient_checkpointing_enable()
                 print("[INFO] Gradient Checkpointing enabled on the model.")
            except Exception as gc_e:
                 print(f"[Warning] Failed to enable gradient checkpointing on model: {gc_e}")
                 global USE_GRADIENT_CHECKPOINTING_EFFECTIVE
                 USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False
        else:
             USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False


    except Exception as e:
        print(f"[FATAL] Failed to initialize or patch the T5 model: {e}", file=sys.stderr)
        sys.exit(1)


    # --- 5. Tokenize Data ---
    try:
        train_encodings = encode_sequences(tokenizer, train_src, train_tgt, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_src,   val_tgt,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_src,  test_tgt,  MAX_SEQ_LENGTH)
    except Exception as e: print(f"[FATAL] Data tokenization failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_encodings.get('input_ids') or not val_encodings.get('input_ids') or not test_encodings.get('input_ids'):
        print("[FATAL] Tokenization resulted in empty encodings.", file=sys.stderr); sys.exit(1)

    # --- 6. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e: print(f"[FATAL] Dataset creation failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 7. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Ignore pad tokens in loss
        pad_to_multiple_of=8 if USE_FP16 else None
    )
    print("[INFO] Data collator initialized.")

    # --- 8. Define Metrics Computation (from Logits) ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy from logits."""
        logits, labels = eval_pred # Trainer passes logits when predict_with_generate=False
        if not isinstance(logits, np.ndarray): logits = np.array(logits)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 in labels for accuracy calculation comparison if needed,
        # but accuracy function should handle pad_token_id masking.
        # labels_for_acc = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            # Use PyTorch tensor version for consistency and potential device placement
            logits_torch = torch.from_numpy(logits)
            labels_torch = torch.from_numpy(labels)
            # Pass the original labels (with -100 potentially) and pad_token_id
            accuracy = calculate_sequence_accuracy_torch(logits_torch, labels_torch, tokenizer_for_metrics.pad_token_id)
            return {"sequence_accuracy": accuracy}
        except Exception as met_e:
             print(f"[Error] compute_metrics failed: {met_e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

    def calculate_sequence_accuracy_torch(logits, labels, pad_token_id):
        """Calculates exact sequence match accuracy using PyTorch tensors."""
        # Ensure inputs are tensors
        if not isinstance(logits, torch.Tensor): logits = torch.tensor(logits)
        if not isinstance(labels, torch.Tensor): labels = torch.tensor(labels)

        if logits.shape[0] == 0: return 0.0
        predictions = torch.argmax(logits, dim=-1) # (Batch, SeqLen)

        # Create mask for non-padding tokens in labels (use pad_token_id directly)
        # Labels might contain -100, treat them as padding for accuracy check
        non_pad_mask = (labels != pad_token_id) & (labels != -100)

        correct_tokens = (predictions == labels) & non_pad_mask
        correct_sequences = (torch.sum(correct_tokens, dim=1) == torch.sum(non_pad_mask, dim=1))
        accuracy = torch.mean(correct_sequences.float())
        return accuracy.item()


    # --- 9. Define Training Arguments ---
    effective_gc = USE_GRADIENT_CHECKPOINTING if 'USE_GRADIENT_CHECKPOINTING_EFFECTIVE' not in globals() else USE_GRADIENT_CHECKPOINTING_EFFECTIVE
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Schedule
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging / Saving / Evaluation
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        # Use logits for evaluation (faster if only accuracy needed)
        predict_with_generate=False,
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        gradient_checkpointing=effective_gc,
        # report_to="tensorboard", # Optional
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing in Args: {training_args.gradient_checkpointing}")

    # --- 10. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 11. Train the Model ---
    print(f"[INFO] Starting model training ({NUM_TRAIN_EPOCHS} epochs)...")
    try:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics); trainer.save_metrics("train", metrics)
        trainer.save_state()
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint: print(f"[INFO] Best model checkpoint: {trainer.state.best_model_checkpoint}")
    except Exception as e: print(f"[FATAL] Training failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 12. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_results); trainer.save_metrics("test", test_results)
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---"); print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}"); print(f"-----------------------------")
        else: print("[Warning] Test accuracy metric not found.", "Results:", test_results)
    except Exception as e: print(f"[Error] Final evaluation failed: {e}", file=sys.stderr); sys.exit(1)

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    # Optional: Seeds for reproducibility
    # SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

ImportError: cannot import name 'T5Block' from 'transformers' (/home/nikitas/anaconda3/envs/torch_cpu_env/lib/python3.10/site-packages/transformers/__init__.py)