# Symbolic AI Tests 2025

In [None]:
# YAML for conda env for Task 1

# # environment.yml
# # Conda environment definition for the Feynman Dataset Preprocessing script.
# #
# # This environment contains the packages needed to run the Python script that
# # downloads the Feynman data, preprocesses equations, trains a tokenizer,
# # and maps feature files.
# #
# # To create this environment:
# # conda env create -f environment.yml
# #
# # To activate this environment:
# # conda activate feynman_preprocess_env
# #
# # To update the environment using this file (after saving changes):
# # conda activate feynman_preprocess_env
# # conda env update --file environment.yml --prune
# #
# # To remove the environment:
# # conda deactivate
# # conda env remove -n feynman_preprocess_env

# name: feynman_preprocess_env # Changed name to reflect the specific task

# channels:
#   - conda-forge   # Primary channel for broad package availability
#   - defaults      # Default conda channel

# dependencies:
#   # --- Python Version ---
#   # Using Python 3.10 - a well-supported recent version. Adjust if needed (e.g., 3.9, 3.11).
#   - python=3.10

#   # --- Core Data Handling & I/O ---
#   - pandas      # For reading CSV files
#   - numpy       # For loading numeric data (optional, but used in load_numeric_data)
#   - requests    # For downloading the dataset files from URLs

#   # --- Pip Tool ---
#   - pip         # Required for installing pip packages listed below

#   # --- Optional Development Tools ---
#   # Uncomment these lines if you plan to use Jupyter notebooks within this environment
#   # - jupyter
#   # - ipykernel

#   # --- Pip dependencies ---
#   # Install specific libraries, often those frequently updated or primarily distributed via pip
#   - pip:
#     - transformers  # Hugging Face library (needed for PreTrainedTokenizerFast)
#     - tokenizers    # Hugging Face library (needed for ByteLevelBPETokenizer and training)

# # Notes:
# # - This environment focuses solely on the requirements of the preprocessing script.
# # - It does *not* include PyTorch or other deep learning frameworks. If you need those
# #   for model training later, you can either:
# #     a) Add them to this file (e.g., add 'pytorch::pytorch', 'pytorch::torchvision', etc.
# #        under dependencies and potentially add the 'pytorch'/'nvidia' channels back).
# #     b) Create a separate environment for training that includes these packages.

# Common Task 1.1. Dataset preprocessing 
Dataset:

https://space.mit.edu/home/tegmark/aifeynman.html 
Note: The authors of this dataset are not affiliated with ML4SCI

Description:
Download the Feynman_with_units.tar.gz features and corresponding FeynmanEquations.csv targets. Preprocess and tokenize the target data and document your rationale for choice of tokenization.



In [2]:
import sys
import tarfile # Keep for potential type hints if needed, but logic removed
import requests # Keep for potential type hints if needed, but logic removed
import shutil # Keep for potential type hints if needed, but logic removed
from pathlib import Path
from typing import Iterator, Optional, List, Dict, Any
import logging
import os # Added to check environment variable

# Set up basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Make sure to install required libraries:
# pip install pandas requests tokenizers transformers numpy markupsafe
try:
    import pandas as pd
    from tokenizers import ByteLevelBPETokenizer # type: ignore
    from transformers import PreTrainedTokenizerFast
    import numpy as np
    import markupsafe # Explicitly import to check if installed
except ImportError as e:
    logging.error(f"Error importing libraries: {e}")
    logging.error("Please ensure required libraries are installed in the 'feynman_tokenizer_env' environment:")
    logging.error("  conda install pandas numpy requests markupsafe -c conda-forge")
    logging.error("  pip install transformers tokenizers")
    sys.exit(1)

# --- Configuration --- MODIFIED TO USE LOCAL PATHS ---

# Assume the script is run from the project root directory (e.g., ~/Desktop/Miche/GOOGLE)
PROJECT_ROOT = Path(".")

# Direct paths to existing data within the project root directory
EQUATIONS_CSV_PATH = PROJECT_ROOT / "FeynmanEquations.csv"
FEATURES_EXTRACTED_PATH = PROJECT_ROOT / "Feynman_with_units" # The existing unpacked dir

# Output directory for the tokenizer (will be created in project root)
TOKENIZER_OUTPUT_DIR = PROJECT_ROOT / "feynman_tokenizer"

# --- Original URLs and Paths (Commented out - No longer used) ---
# EQUATIONS_CSV_URL = "https://space.mit.edu/home/tegmark/aifeynman/FeynmanEquations.csv"
# FEATURES_TAR_URL = "https://space.mit.edu/home/tegmark/aifeynman/Feynman_with_units.tar.gz"
# BASE_DATA_DIR = Path("./feynman_data") # No longer using this subdir for inputs
# EQUATIONS_CSV_PATH_ORIG = BASE_DATA_DIR / "FeynmanEquations.csv"
# FEATURES_TAR_PATH = BASE_DATA_DIR / "Feynman_with_units.tar.gz"
# FEATURES_DIR_NAME = "Feynman_with_units"
# FEATURES_EXTRACTED_PATH_ORIG = BASE_DATA_DIR / FEATURES_DIR_NAME

# CSV Column Names
FILENAME_COLUMN = 'Filename'
EQUATION_COLUMN = 'Formula'

# Tokenizer Settings
VOCAB_SIZE = 10_000
MIN_FREQUENCY = 2
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

# --- Helper Functions --- (Download and Unpack are no longer needed by main logic)

def download_file(url: str, destination: Path) -> bool:
    """Downloads a file from a URL to a destination path if it doesn't exist. (NO LONGER CALLED BY main)"""
    if destination.exists():
        logging.info(f"File already exists: {destination}")
        return True
    # ... (rest of download logic remains here in case needed elsewhere, but won't be executed)
    logging.info(f"Attempting download (called unexpectedly): {url} to {destination}")
    # Added security context modification attempt for DH_KEY_TOO_SMALL - USE WITH CAUTION
    # Requires specific OpenSSL version potentially. May not work.
    session = requests.Session()
    try:
        # Try forcing lower security level ciphers - WARNING: REDUCES SECURITY
        logging.warning("Attempting download with reduced SSL security settings (SECLEVEL=1) due to potential DH_KEY_TOO_SMALL error.")
        CIPHERS = ('DEFAULT@SECLEVEL=1')
        from requests.adapters import HTTPAdapter
        from urllib3.util.ssl_ import create_urllib3_context

        class CustomHttpAdapter(HTTPAdapter):
            def init_poolmanager(self, connections, maxsize, block=False):
                context = create_urllib3_context(ciphers=CIPHERS)
                self.poolmanager = requests.packages.urllib3.PoolManager(
                    num_pools=connections, maxsize=maxsize, block=block, ssl_context=context
                )
        session.mount('https://', CustomHttpAdapter())
    except Exception as e:
        logging.error(f"Could not apply custom SSL Adapter: {e}. Proceeding with default settings.")
        # Fallback to default session if adapter setup fails

    try:
        response = session.get(url, stream=True, timeout=120)
        response.raise_for_status()
        destination.parent.mkdir(parents=True, exist_ok=True)
        with open(destination, 'wb') as f:
            shutil.copyfileobj(response.raw, f)
        logging.info(f"Successfully downloaded {destination.name}")
        return True
    except requests.exceptions.SSLError as e:
        logging.error(f"SSL Error during download {url}: {e}")
        logging.error("This often happens if the server uses outdated security (e.g., DH_KEY_TOO_SMALL).")
        logging.error("Consider downloading the file manually via browser and placing it at the destination path.")
        if destination.exists(): destination.unlink(missing_ok=True)
        return False
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to download {url}: {e}")
        if destination.exists(): destination.unlink(missing_ok=True)
        return False
    except Exception as e:
        logging.error(f"An unexpected error occurred during download: {e}")
        if destination.exists(): destination.unlink(missing_ok=True)
        return False


def unpack_tar_gz(tar_path: Path, extract_to: Path) -> bool:
    """Unpacks a .tar.gz file to a specified directory. (NO LONGER CALLED BY main)"""
    if not tar_path.is_file():
        logging.error(f"Archive file not found or is not a file: {tar_path}")
        return False
    if extract_to.is_dir() and any(extract_to.iterdir()):
         logging.info(f"Target directory {extract_to} already exists and is not empty. Assuming unpacked.")
         return True
    # ... (rest of unpack logic remains here but won't be executed)
    logging.info(f"Attempting unpack (called unexpectedly): {tar_path.name} to {extract_to.parent}...")
    try:
        extract_to.parent.mkdir(parents=True, exist_ok=True)
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extractall(path=extract_to.parent)
            if extract_to.is_dir():
                 logging.info(f"Successfully unpacked. Target directory: {extract_to}")
                 return True
            else:
                 logging.warning(f"Tar file unpacked, but expected directory '{extract_to.name}' not found directly in '{extract_to.parent}'.")
                 # Simplified check - assuming failure if exact name doesn't match
                 return False
    except Exception as e:
        logging.error(f"An unexpected error occurred during unpacking: {e}")
        if extract_to.exists(): shutil.rmtree(extract_to, ignore_errors=True)
        return False


def equation_iterator(csv_file_path: Path, target_column: str) -> Iterator[str]:
    """Creates a generator to efficiently yield non-empty equations from a CSV column."""
    # (This function remains unchanged)
    chunk_size = 1000
    logging.info(f"Initializing equation iterator for column '{target_column}' from {csv_file_path}...")
    if not csv_file_path.is_file():
        logging.error(f"Equation CSV file not found: {csv_file_path}")
        return
    processed_count = 0
    try:
        chunk_iterator = pd.read_csv(
            csv_file_path, usecols=[target_column], chunksize=chunk_size,
            skipinitialspace=True, low_memory=False, dtype={target_column: str},
            on_bad_lines='warn'
        )
        for i, chunk in enumerate(chunk_iterator):
            if target_column not in chunk.columns:
                 logging.error(f"Column '{target_column}' not found in CSV chunk {i+1}. Stopping.")
                 return
            for equation in chunk[target_column]:
                if pd.isna(equation): continue
                eq_str = str(equation).strip()
                if not eq_str or eq_str.lower() == 'nan': continue
                yield eq_str
                processed_count += 1
        logging.info(f"Equation iterator finished yielding {processed_count} equations.")
    except Exception as e:
        logging.error(f"Failed reading CSV chunks from {csv_file_path}: {e}")
        return

def load_numeric_data(filename: str, paths_dict: Dict[str, Path]) -> Optional[np.ndarray]:
    """Loads numeric data from a file specified by its base filename using a path map."""
    # (This function remains unchanged)
    file_path = paths_dict.get(filename)
    if file_path is None: return None
    if not file_path.is_file(): return None
    try:
        data_array: np.ndarray = np.loadtxt(file_path)
        return data_array
    except Exception as e:
        logging.error(f"Failed to load numeric file {file_path}: {e}")
        return None

# --- Main Execution Logic --- MODIFIED ---

def main():
    """Runs the main script logic using locally available data."""
    logging.info("--- Feynman Dataset Preprocessing Script (Using Local Data) ---")

    # 1. Verify Local Data Paths
    logging.info("[Step 1] Verifying local data paths...")
    csv_ok = False
    features_dir_ok = False

    if EQUATIONS_CSV_PATH.is_file():
        logging.info(f"Found Equations CSV: {EQUATIONS_CSV_PATH.resolve()}")
        csv_ok = True
    else:
        logging.error(f"Equations CSV file not found at expected location: {EQUATIONS_CSV_PATH.resolve()}")
        logging.error("Please ensure 'FeynmanEquations.csv' is in the same directory as the script.")

    if FEATURES_EXTRACTED_PATH.is_dir():
        logging.info(f"Found unpacked Features directory: {FEATURES_EXTRACTED_PATH.resolve()}")
        features_dir_ok = True
    else:
        logging.error(f"Unpacked features directory not found at expected location: {FEATURES_EXTRACTED_PATH.resolve()}")
        logging.error("Please ensure the 'Feynman_with_units' directory (containing data files like I.10.7 etc.) is in the same directory as the script.")

    if not csv_ok:
        logging.error("Cannot proceed without the Equations CSV file.")
        sys.exit(1)
    # We can proceed with tokenization even if features dir is missing, but mapping will fail.

    # 2. Download Data (SKIPPED)
    logging.info("[Step 2] Download Data (Skipped - Using local files)")

    # 3. Unpack Features Archive (SKIPPED)
    logging.info("[Step 3] Unpack Features Archive (Skipped - Using local directory)")

    # 4. Train BPE Tokenizer (Uses verified EQUATIONS_CSV_PATH)
    logging.info(f"[Step 4] Training Byte-Level BPE Tokenizer on '{EQUATION_COLUMN}' column...")
    tokenizer_json_path: Optional[Path] = None
    tokenizer_trained = False

    # Check equation column exists
    try:
        df_head = pd.read_csv(EQUATIONS_CSV_PATH, nrows=0, skipinitialspace=True)
        if EQUATION_COLUMN not in df_head.columns:
            logging.error(f"Equation column '{EQUATION_COLUMN}' not found in {EQUATIONS_CSV_PATH}. Cannot train tokenizer.")
            sys.exit(1)
        logging.info(f"Confirmed equation column '{EQUATION_COLUMN}' exists.")
    except Exception as e:
        logging.error(f"Could not read or verify columns in {EQUATIONS_CSV_PATH}: {e}")
        sys.exit(1)

    # Create tokenizer output dir
    try:
        TOKENIZER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        logging.info(f"Tokenizer output directory: {TOKENIZER_OUTPUT_DIR.resolve()}")
    except OSError as e:
        logging.error(f"Could not create tokenizer output directory {TOKENIZER_OUTPUT_DIR}: {e}. Tokenizer training will fail.")
        sys.exit(1)

    logging.info(f"Tokenizer Settings: vocab_size={VOCAB_SIZE}, min_frequency={MIN_FREQUENCY}")
    try:
        bpe_tokenizer = ByteLevelBPETokenizer()
        eq_iter = equation_iterator(EQUATIONS_CSV_PATH, EQUATION_COLUMN)
        # Check iterator
        try:
            first_item = next(eq_iter)
            logging.info("Successfully retrieved first equation from iterator.")
        except StopIteration:
            logging.error("Equation iterator did not yield any data. Check CSV content.")
            first_item = None

        if first_item is not None:
            from itertools import chain
            full_iterator = chain([first_item], eq_iter)
            logging.info("Starting tokenizer training...")
            bpe_tokenizer.train_from_iterator(
                full_iterator, vocab_size=VOCAB_SIZE, min_frequency=MIN_FREQUENCY, special_tokens=SPECIAL_TOKENS
            )
            logging.info(f"Tokenizer training complete. Final vocab size: {bpe_tokenizer.get_vocab_size()}")
            bpe_tokenizer.save_model(str(TOKENIZER_OUTPUT_DIR))
            logging.info(f"Tokenizer vocabulary/merges saved to {TOKENIZER_OUTPUT_DIR}.")
            tokenizer_json_path = TOKENIZER_OUTPUT_DIR / "tokenizer.json"
            bpe_tokenizer.save(str(tokenizer_json_path))
            logging.info(f"Full tokenizer config saved to: {tokenizer_json_path}")
            tokenizer_trained = True
        else:
            logging.error("Cannot train tokenizer because no valid equations were found.")
    except Exception as e:
        logging.exception(f"Tokenizer training failed unexpectedly: {e}")

    # 5. Wrap with PreTrainedTokenizerFast
    logging.info("[Step 5] Wrapping tokenizer with Hugging Face PreTrainedTokenizerFast...")
    hf_tokenizer: Optional[PreTrainedTokenizerFast] = None
    if tokenizer_trained and tokenizer_json_path and tokenizer_json_path.is_file():
        try:
            # (Wrap logic remains the same)
            hf_tokenizer = PreTrainedTokenizerFast(
                tokenizer_file=str(tokenizer_json_path),
                bos_token=SPECIAL_TOKENS[0], eos_token=SPECIAL_TOKENS[2],
                unk_token=SPECIAL_TOKENS[3], pad_token=SPECIAL_TOKENS[1],
                mask_token=SPECIAL_TOKENS[4]
            )
            logging.info("Hugging Face tokenizer wrapper created successfully.")
            # (Testing logic remains the same)
            logging.info("Testing tokenizer encoding/decoding...")
            try:
                test_eq_iter = equation_iterator(EQUATIONS_CSV_PATH, EQUATION_COLUMN)
                first_equation = next(test_eq_iter, None)
                del test_eq_iter
                if first_equation:
                     logging.info(f"Test Equation Sample: {first_equation}")
                     tokens = hf_tokenizer.tokenize(first_equation)
                     logging.info(f" -> Tokens ({len(tokens)}): {tokens}")
                     encoded_ids = hf_tokenizer.encode(first_equation)
                     logging.info(f" -> Encoded IDs ({len(encoded_ids)}): {encoded_ids}")
                     decoded_clean = hf_tokenizer.decode(encoded_ids, skip_special_tokens=True)
                     logging.info(f" -> Decoded (clean): {decoded_clean}")
                else:
                     logging.warning("Could not retrieve an equation for tokenizer testing.")
            except Exception as e:
                logging.exception(f"Error during tokenizer test encode/decode: {e}")
        except Exception as e:
            logging.exception(f"Failed to load tokenizer with PreTrainedTokenizerFast: {e}")
    elif not tokenizer_trained:
        logging.warning("Skipping tokenizer wrapping: Tokenizer training failed or was skipped.")
    else:
        logging.warning(f"Skipping tokenizer wrapping: Tokenizer file '{tokenizer_json_path}' not found.")

    # 6. Map Filenames to Feature Files (Uses verified FEATURES_EXTRACTED_PATH)
    logging.info(f"[Step 6] Mapping CSV filenames to feature files in '{FEATURES_EXTRACTED_PATH}'...")
    numeric_file_paths: Dict[str, Path] = {}
    filenames: List[str] = []
    missing_files_count = 0

    # Load filenames from CSV (uses verified EQUATIONS_CSV_PATH)
    try:
        logging.info(f"Loading filenames from '{FILENAME_COLUMN}' column in {EQUATIONS_CSV_PATH}...")
        filenames_series = pd.read_csv(
            EQUATIONS_CSV_PATH, usecols=[FILENAME_COLUMN], skipinitialspace=True, dtype={FILENAME_COLUMN: str}
        )[FILENAME_COLUMN]
        filenames = [fn.strip() for fn in filenames_series if pd.notna(fn) and str(fn).strip()]
        logging.info(f"Loaded {len(filenames)} non-empty filenames.")
        if not filenames: logging.warning(f"No valid filenames found in CSV column '{FILENAME_COLUMN}'.")
    except Exception as e:
        logging.exception(f"Failed to read filenames column '{FILENAME_COLUMN}' from CSV: {e}")
        logging.warning("Skipping feature file mapping due to error reading filenames.")

    # Perform mapping if filenames loaded and features directory exists
    if filenames and features_dir_ok: # Use the features_dir_ok flag from Step 1
        logging.info(f"Checking for {len(filenames)} potential feature files in {FEATURES_EXTRACTED_PATH}...")
        for fn in filenames:
            potential_path = FEATURES_EXTRACTED_PATH / fn
            if potential_path.is_file():
                numeric_file_paths[fn] = potential_path.resolve() # Store absolute path
            else:
                missing_files_count += 1

        found_files_count = len(numeric_file_paths)
        logging.info(f"Finished checking {len(filenames)} filenames.")
        logging.info(f"Found {found_files_count} corresponding feature files.")
        if missing_files_count > 0:
            logging.warning(f"Did not find {missing_files_count} expected feature files in {FEATURES_EXTRACTED_PATH}.")

        if numeric_file_paths:
             example_fn = next(iter(numeric_file_paths))
             logging.info(f"Example mapping: '{example_fn}' -> '{numeric_file_paths[example_fn]}'")
        elif found_files_count == 0 and filenames:
             logging.warning(f"No feature files mapped despite having filenames. Check directory '{FEATURES_EXTRACTED_PATH}'.")

    elif not filenames:
         logging.info("Skipping feature mapping: No filenames loaded.")
    elif not features_dir_ok:
         logging.info(f"Skipping feature mapping: Feature directory '{FEATURES_EXTRACTED_PATH}' not found or verified.")

    # 7. Rationale and Usage Summary (Printed to console - No changes needed here)
    print("\n" + "="*70)
    print("          Rationale for Byte-Level BPE Tokenization")
    print("="*70)
    # (Rationale text remains the same)
    rationale = """
    Byte-Level Byte Pair Encoding (BBPE) was chosen for tokenizing the Feynman equations (targets) for several reasons:

    1.  **Handles Diverse Characters:** Equations contain a wide mix of mathematical symbols (e.g., +, -, *, /), Greek letters (e.g., \\theta, \\omega), numbers, standard letters (variable names), and LaTeX-like commands (e.g., \\sqrt, \\frac, ^, _). BBPE operates at the byte level initially, meaning it can handle *any* character without needing a predefined vocabulary of all possible symbols.

    2.  **Subword Information:** BBPE learns to merge frequent byte sequences into tokens. This allows it to represent common mathematical operators, function names (like 'sin', 'cos'), common variable fragments, and even parts of LaTeX commands as single tokens, while still being able to break down rare or unseen sequences into smaller, known units (subwords or individual bytes). This is beneficial for capturing structure within the equations.

    3.  **Robustness to Variations:** Unlike word-level tokenization (which would struggle with defining "words" in equations), BBPE is less sensitive to minor variations in notation or typos. It can tokenize novel combinations of symbols or slightly different variable names by breaking them down.

    4.  **Controlled Vocabulary Size:** While equations can be infinitely complex, the underlying set of characters and common mathematical constructs is limited. BBPE allows controlling the final vocabulary size (`VOCAB_SIZE`) by limiting the number of merge operations, preventing an excessively large vocabulary while capturing the most frequent and meaningful patterns.

    5.  **No Unknown Tokens (Almost):** Since it works at the byte level, BBPE inherently avoids the "unknown token" (<unk>) problem for individual characters. Unknown *sequences* are simply represented by the tokens corresponding to their constituent bytes or learned subwords. We still include <unk> as a special token for robustness or potential future use, but it's less critical than in word-level tokenizers.
    """
    print(rationale)
    print("="*70)
    print("                    Usage Summary")
    print("="*70)
    # (Usage summary text remains largely the same, paths updated)
    print("\n[Tokenizer Usage]")
    if hf_tokenizer:
        print(" - The Hugging Face tokenizer object is loaded and ready for use.")
        print(f" - Tokenizer files are saved in: {TOKENIZER_OUTPUT_DIR.resolve()}")
        print(" - Example:")
        print("   ```python")
        print("   # Assuming 'hf_tokenizer' is the loaded PreTrainedTokenizerFast object")
        print("   equation = 'F = G * m1 * m2 / r**2'")
        print("   encoded = hf_tokenizer(equation)")
        print("   print('Encoded IDs:', encoded['input_ids'])")
        print("   print('Decoded:', hf_tokenizer.decode(encoded['input_ids']))")
        print("   ```")
    else:
        print(" - Tokenizer training or loading FAILED.")
        print(f" - Check logs for errors. If training seemed to succeed, verify tokenizer files exist in: {TOKENIZER_OUTPUT_DIR.resolve()}")

    print("\n[Feature Data Usage]")
    if numeric_file_paths:
        print(" - The 'numeric_file_paths' dictionary maps filenames to feature file Paths.")
        print(f" - Feature files are located in: {FEATURES_EXTRACTED_PATH.resolve()}")
        print(f" - Total mapped files: {len(numeric_file_paths)}")
        print(" - Example:")
        print("   ```python")
        print("   # Assuming 'numeric_file_paths' is the dictionary and 'load_numeric_data' is defined")
        example_fn_usage = next(iter(numeric_file_paths))
        print(f"   filename = '{example_fn_usage}'")
        print(f"   if filename in numeric_file_paths:")
        print(f"       # data_array = load_numeric_data(filename, numeric_file_paths)")
        print(f"       file_path = numeric_file_paths[filename]")
        print(f"       print(f'Path to feature file: {{file_path}}')")
        print(f"       # data = np.loadtxt(file_path)")
        print("   ```")
        print("\n   Loading one example feature file for demonstration:")
        example_data = load_numeric_data(example_fn_usage, numeric_file_paths)
        if example_data is not None:
             print(f"   > Successfully loaded '{example_fn_usage}' with shape {example_data.shape}")
        else:
             print(f"   > Failed to load example file '{example_fn_usage}'. Check logs for errors.")
    else:
        print(" - Mapping filenames to feature files FAILED or resulted in zero mapped files.")
        print(" - Potential reasons:")
        print(f"   - CSV '{EQUATIONS_CSV_PATH.name}' missing filenames or '{FILENAME_COLUMN}' column.")
        print(f"   - Feature directory '{FEATURES_EXTRACTED_PATH.name}' not found or verified.")
        print(f"   - Feature files not found in the expected directory structure within: {FEATURES_EXTRACTED_PATH.resolve()}")
        print("   - Check script logs above for specific warnings or errors.")
        print(" - The 'numeric_file_paths' dictionary is empty.")

    print("\n" + "="*70)
    logging.info("--- Feynman Dataset Preprocessing Script (Using Local Data) FINISHED ---")


if __name__ == "__main__":
    main()

2025-04-22 11:46:59,566 - INFO - --- Feynman Dataset Preprocessing Script (Using Local Data) ---
2025-04-22 11:46:59,567 - INFO - [Step 1] Verifying local data paths...
2025-04-22 11:46:59,567 - INFO - Found Equations CSV: /home/nikitas/Desktop/Miche/GOOGLE/FeynmanEquations.csv
2025-04-22 11:46:59,567 - INFO - Found unpacked Features directory: /home/nikitas/Desktop/Miche/GOOGLE/Feynman_with_units
2025-04-22 11:46:59,567 - INFO - [Step 2] Download Data (Skipped - Using local files)
2025-04-22 11:46:59,567 - INFO - [Step 3] Unpack Features Archive (Skipped - Using local directory)
2025-04-22 11:46:59,567 - INFO - [Step 4] Training Byte-Level BPE Tokenizer on 'Formula' column...
2025-04-22 11:46:59,576 - INFO - Confirmed equation column 'Formula' exists.
2025-04-22 11:46:59,576 - INFO - Tokenizer output directory: /home/nikitas/Desktop/Miche/GOOGLE/feynman_tokenizer
2025-04-22 11:46:59,576 - INFO - Tokenizer Settings: vocab_size=10000, min_frequency=2
2025-04-22 11:46:59,576 - INFO - Ini





          Rationale for Byte-Level BPE Tokenization

    Byte-Level Byte Pair Encoding (BBPE) was chosen for tokenizing the Feynman equations (targets) for several reasons:

    1.  **Handles Diverse Characters:** Equations contain a wide mix of mathematical symbols (e.g., +, -, *, /), Greek letters (e.g., \theta, \omega), numbers, standard letters (variable names), and LaTeX-like commands (e.g., \sqrt, \frac, ^, _). BBPE operates at the byte level initially, meaning it can handle *any* character without needing a predefined vocabulary of all possible symbols.

    2.  **Subword Information:** BBPE learns to merge frequent byte sequences into tokens. This allows it to represent common mathematical operators, function names (like 'sin', 'cos'), common variable fragments, and even parts of LaTeX commands as single tokens, while still being able to break down rare or unseen sequences into smaller, known units (subwords or individual bytes). This is beneficial for capturing structure w

2025-04-22 11:46:59,881 - INFO - --- Feynman Dataset Preprocessing Script (Using Local Data) FINISHED ---


   > Successfully loaded 'I.6.2a' with shape (1000000, 2)



# Common Task 1.2 Dataset preprocessing
Dataset:

https://alabama.box.com/s/xhgr2onrn503jyse2fs5vxtapg0oifcs 

Dataset:
Download the dataset (split across 10 files) and preprocess and tokenize the target data and document your rationale for choice of tokenization. Data file is formatted with rows like 
“event type : Feynman diagram : amplitude : squared amplitude”
Here the amplitudes are the input sequences and squared amplitudes are the target sequences. Note that indices like _123456 grow over the course of the dataset and should be normalized for each amplitude and squared amplitude. Use an 80-10-10 split of train-val-test across all files.


In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
QED Amplitude Expression Preprocessing Pipeline (Task 1.2)

Handles the preprocessing of QED amplitude data:
1.  Assumes data files (expected format: event:diagram:amplitude:sq_amplitude)
    have been downloaded manually into a specified directory.
2.  Loads data from all '.txt' files within that directory.
3.  Parses lines, extracting amplitude (input) and squared amplitude (target).
4.  Applies index normalization (e.g., _123 -> _0) locally within each expression.
5.  Tokenizes normalized expressions using a custom regex tokenizer.
6.  Structures data into input/target token pairs.
7.  Shuffles the combined dataset.
8.  Splits into 80/10/10 train/validation/test sets.
9.  Saves splits to JSON Lines (.jsonl) files.
10. Documents the rationale for tokenization choices.
"""

import re
import random
import json
import sys
import logging
from pathlib import Path
from typing import List, Dict, Tuple # Added Tuple for type hints

# --- Configuration ---

# **IMPORTANT**: Manual Download Required!
# Download the dataset (10 files) from:
# https://alabama.box.com/s/xhgr2onrn503jyse2fs5vxtapg0oifcs
# And place ALL the '.txt' files into the directory specified below.
# Example: Create a directory 'qed_data_raw' next to this script
#          and put the 10 downloaded txt files inside it.
INPUT_DATA_DIR = Path("./qed_data_raw") # <--- UPDATE if you place data elsewhere

# Output directory for processed JSONL files
OUTPUT_DIR = Path("./qed_data_processed")
OUTPUT_PREFIX = 'qed_amplitudes' # Files will be qed_amplitudes_train.jsonl etc.

# Dataset split ratios
TRAIN_FRAC = 0.8
VAL_FRAC = 0.1
TEST_FRAC = 1.0 - TRAIN_FRAC - VAL_FRAC # Calculated automatically

# Random seed for reproducibility
RANDOM_SEED = 42

# Logging Setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# --- Tokenization and Normalization ---

# Regex pattern to capture tokens (same as before, seems suitable):
# Group 1: Identifiers ([A-Za-z_]\w*)
# Group 2: Numbers (\d+)
# Group 3: Operators/Special Chars (\*\*|\^|[+\-*/=()_,:])
TOKEN_PATTERN = re.compile(r"([A-Za-z_]\w*|\d+|\*\*|\^|[+\-*/=()_,:])")

# Regex pattern to find numeric subscripts (e.g., "_123456")
INDEX_PATTERN = re.compile(r'_\d+')

def tokenize(expression_string: str) -> List[str]:
    """Tokenizes a mathematical expression string using TOKEN_PATTERN."""
    if not isinstance(expression_string, str): # Added type check
        return []
    return TOKEN_PATTERN.findall(expression_string)

def normalize_indices(expression_string: str) -> str:
    """Normalizes numeric subscripts (_123 -> _0) within a single expression string."""
    if not isinstance(expression_string, str): # Added type check
        return ""

    index_mapping: Dict[str, str] = {}
    counter = [0] # Mutable counter for closure

    def _replace_match(match):
        original_index = match.group(0)
        if original_index not in index_mapping:
            normalized_index = f"_{counter[0]}"
            index_mapping[original_index] = normalized_index
            counter[0] += 1
        return index_mapping[original_index]

    return INDEX_PATTERN.sub(_replace_match, expression_string)

# --- Data Loading and Processing ---

def load_and_preprocess_data(data_directory: Path) -> List[Dict[str, List[str]]]:
    """
    Loads data from all .txt files in the specified directory, preprocesses,
    and tokenizes amplitude and squared amplitude pairs.

    Args:
        data_directory (Path): The directory containing the downloaded .txt files.

    Returns:
        List[Dict[str, List[str]]]: A list of dictionaries, each with keys
                                     'input_tokens' and 'target_tokens'.
    """
    dataset: List[Dict[str, List[str]]] = []
    skipped_lines = 0
    processed_files = 0

    if not data_directory.is_dir():
        logging.error(f"Input data directory not found: {data_directory}")
        logging.error("Please ensure you have downloaded the data and placed the .txt files there.")
        sys.exit(1)

    # Use pathlib's glob for cleaner path handling
    file_paths = sorted(list(data_directory.glob("*.txt")))

    if not file_paths:
        logging.error(f"No '.txt' files found in directory: {data_directory}")
        logging.error("Please check the directory content and INPUT_DATA_DIR setting.")
        sys.exit(1)

    logging.info(f"Found {len(file_paths)} '.txt' files in {data_directory}. Processing...")

    for file_path in file_paths:
        processed_files += 1
        logging.info(f"  Processing file: {file_path.name}...")
        line_count_in_file = 0
        skipped_in_file = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as infile:
                for line_num, line in enumerate(infile, 1):
                    line = line.strip()
                    if not line: continue # Skip empty lines

                    parts = line.split(' : ', 3) # Split only 3 times max
                    # Expecting 4 parts: event:diagram:amplitude:sq_amplitude
                    if len(parts) != 4:
                        # Reduce logging noise, maybe log first few warnings per file?
                        # if skipped_in_file < 5:
                        #     logging.warning(f"Skipping malformed line {line_num} in {file_path.name}: Expected 4 parts separated by ' : ', found {len(parts)}. Content: '{line[:100]}...'")
                        skipped_lines += 1
                        skipped_in_file += 1
                        continue

                    # Input is amplitude, Target is squared amplitude
                    amplitude_expr = parts[2]
                    sq_amplitude_expr = parts[3]

                    # 1. Normalize indices (Applied independently to each expression)
                    normalized_amplitude = normalize_indices(amplitude_expr)
                    normalized_sq_amplitude = normalize_indices(sq_amplitude_expr)

                    # 2. Tokenize
                    amplitude_tokens = tokenize(normalized_amplitude)
                    sq_amplitude_tokens = tokenize(normalized_sq_amplitude)

                    # Append structured data if tokenization successful
                    if amplitude_tokens and sq_amplitude_tokens:
                        dataset.append({
                            'input_tokens': amplitude_tokens,
                            'target_tokens': sq_amplitude_tokens
                        })
                        line_count_in_file += 1
                    else:
                        # if skipped_in_file < 5: # Example condition to reduce noise
                        #    logging.warning(f"Skipping line {line_num} in {file_path.name} due to empty tokens after processing.")
                        skipped_lines += 1
                        skipped_in_file += 1

            logging.info(f"    Finished {file_path.name}. Added {line_count_in_file} examples. Skipped {skipped_in_file} lines.")

        except FileNotFoundError:
            logging.error(f"File not found during processing: {file_path}. This shouldn't happen if glob worked.")
            continue
        except Exception as e:
            logging.exception(f"Error processing file {file_path}: {e}") # Log full traceback
            continue # Skip faulty file

    logging.info(f"Finished processing {processed_files} files.")
    if skipped_lines > 0:
        logging.warning(f"Total skipped lines across all files: {skipped_lines}")

    return dataset

def split_and_save_dataset(
    data: List[Dict[str, List[str]]],
    train_frac: float,
    val_frac: float,
    output_dir: Path,
    output_prefix: str
) -> None:
    """
    Shuffles, splits (train/val/test), and saves the dataset to JSON Lines files.
    """
    if not data:
        logging.warning("No data provided to split and save.")
        return

    n_total = len(data)
    logging.info(f"Total examples loaded: {n_total}")

    if n_total == 0:
        logging.error("No valid examples were loaded from the data files. Cannot split.")
        sys.exit(1)

    # Ensure fractions are valid
    if not (0 < train_frac < 1 and 0 < val_frac < 1 and (train_frac + val_frac) < 1):
        logging.error(f"Invalid split fractions: train={train_frac}, val={val_frac}. Must be > 0 and sum < 1.")
        sys.exit(1)

    # Shuffle the data
    logging.info(f"Shuffling dataset with random seed {RANDOM_SEED}...")
    random.shuffle(data)

    # Calculate split indices
    n_train = int(train_frac * n_total)
    n_val = int(val_frac * n_total)
    # Test gets the remainder
    n_test = n_total - n_train - n_val

    # Check if splits are sensible
    if n_train == 0 or n_val == 0 or n_test == 0:
         logging.warning(f"Dataset size ({n_total}) is very small, resulting in potentially zero examples in splits: Train={n_train}, Val={n_val}, Test={n_test}. Check input data.")

    # Perform the splits
    splits: Dict[str, List[Dict]] = {
        'train': data[:n_train],
        'val': data[n_train : n_train + n_val],
        'test': data[n_train + n_val :]
    }

    logging.info(f"Calculated split sizes: Train={len(splits['train'])}, Validation={len(splits['val'])}, Test={len(splits['test'])}")

    # Save each split
    try:
        output_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists
        logging.info(f"Saving splits to directory: {output_dir.resolve()}")
    except OSError as e:
        logging.error(f"Could not create output directory {output_dir}: {e}")
        sys.exit(1)

    for split_name, subset in splits.items():
        output_filename = output_dir / f"{output_prefix}_{split_name}.jsonl"
        logging.info(f"  Saving {split_name} set ({len(subset)} examples) to {output_filename}...")
        try:
            with open(output_filename, 'w', encoding='utf-8') as outfile:
                count = 0
                for entry in subset:
                    # Write each dictionary as a JSON string on its own line
                    outfile.write(json.dumps(entry, ensure_ascii=False) + '\n')
                    count +=1
                logging.info(f"    Successfully saved {count} lines.")
        except IOError as e:
            logging.error(f"Failed to write {split_name} set to {output_filename}: {e}")
        except Exception as e:
             logging.exception(f"An unexpected error occurred while writing {split_name} set: {e}")


def print_rationale() -> None:
    """Prints the documented rationale for the chosen preprocessing steps."""
    rationale = """
======================================================================
          Preprocessing and Tokenization Rationale (QED Amplitudes)
======================================================================

1.  **Index Normalization (`_123456` -> `_0`, `_1`, ...):**
    * **Problem:** The raw data contains indices (e.g., `p_123`, `gamma_45`) where the numeric part can grow very large across the dataset. Treating each unique numbered index (like `p_123` vs `p_124`) as a distinct token would lead to an enormous, sparse vocabulary.
    * **Solution:** We normalize these indices *within each individual expression* (amplitude and squared amplitude separately). The first unique numeric subscript encountered (like `_123`) becomes `_0`, the second unique one (`_45`) becomes `_1`, and so on. If the *same* original index appears multiple times in one expression (e.g., `p_123 * p_123`), it gets the *same* normalized index (`p_0 * p_0`).
    * **Benefit:** This dramatically reduces vocabulary size and helps the model focus on the *structure* and *relationships* between indexed terms rather than memorizing absolute index values. It treats `p_0`, `p_1` etc., as generic indexed placeholders.

2.  **Custom Regex Tokenization (`TOKEN_PATTERN`):**
    * **Goal:** To break down the normalized mathematical expressions into meaningful units suitable for sequence modeling.
    * **Method:** A regular expression `([A-Za-z_]\w*|\d+|\*\*|\^|[+\-*/=()_,:])` is used to explicitly define tokens as:
        * Identifiers: `p_0`, `m_e`, `gamma_1`, `exp`, etc. (alphanumeric starting with letter/_).
        * Numbers: `2`, `4`, `1`, etc. (integers).
        * Operators/Special Characters: `+`, `-`, `*`, `/`, `=`, `**` (power), `^` (power), `(`, `)`, `,`, `:`.
    * **Why Regex Here?**
        * **Preserves Structure:** Unlike statistical methods like BPE (Byte Pair Encoding) used for natural language or the previous Feynman equations, this regex approach prevents merging parts of distinct mathematical significance. For example, it ensures `p`, `_`, `0` remain separate if desired (though here `p_0` is one identifier token), or that operators like `**` are treated as single units. It guarantees that fundamental symbols like `+`, `*`, `(` are always distinct tokens.
        * **Interpretability:** The resulting tokens directly correspond to the symbolic components of the expressions, making the input/output sequences easier to understand.
        * **Domain Specificity:** The structure of these QED expressions is relatively regular compared to free-form text. A targeted regex can capture the known relevant components effectively without needing to learn merges from data (like BPE).
    * **Alternative Considered (BPE):** While BPE could be trained, it might learn merges that are not mathematically ideal (e.g., merging an operator with part of a variable name) and could result in a less interpretable token set for this specific symbolic domain. Given the structured nature of the input, regex provides more control.

3.  **Input/Target:**
    * The 'amplitude' expression (normalized and tokenized) serves as the input sequence.
    * The 'squared amplitude' expression (normalized and tokenized) serves as the target sequence.

4.  **Splitting (80/10/10 Train/Val/Test):**
    * Standard practice for robust model development: training on the largest portion, tuning hyperparameters on the validation set, and final unbiased evaluation on the held-out test set. Shuffling *before* splitting ensures randomness across data from all source files.

======================================================================
"""
    print(rationale)

# --- Main Execution ---

def main() -> None:
    """Main function to orchestrate the preprocessing and splitting."""
    logging.info("Starting QED amplitude expression preprocessing script...")

    # Set the random seed for reproducibility
    random.seed(RANDOM_SEED)
    logging.info(f"Random seed set to: {RANDOM_SEED}")

    # Print rationale upfront
    print_rationale()

    # 1. Load and preprocess data
    logging.info(f"Loading data from directory: {INPUT_DATA_DIR.resolve()}")
    preprocessed_data = load_and_preprocess_data(INPUT_DATA_DIR)

    if not preprocessed_data:
        logging.error("No data loaded after processing files. Exiting.")
        # load_and_preprocess_data already exits if dir/files not found
        sys.exit(1)

    logging.info(f"Successfully loaded and preprocessed {len(preprocessed_data)} examples.")

    # 2. Split and save the dataset
    logging.info(f"Splitting and saving dataset with prefix '{OUTPUT_PREFIX}' to {OUTPUT_DIR.resolve()}...")
    split_and_save_dataset(
        data=preprocessed_data,
        train_frac=TRAIN_FRAC,
        val_frac=VAL_FRAC,
        output_dir=OUTPUT_DIR,
        output_prefix=OUTPUT_PREFIX
    )

    logging.info("Script finished successfully.")

if __name__ == "__main__":
    main()

2025-04-22 12:02:31,969 - INFO - Starting QED amplitude expression preprocessing script...
2025-04-22 12:02:31,970 - INFO - Random seed set to: 42
2025-04-22 12:02:31,970 - INFO - Loading data from directory: /home/nikitas/Desktop/Miche/GOOGLE/qed_data_raw
2025-04-22 12:02:31,971 - INFO - Found 10 '.txt' files in qed_data_raw. Processing...
2025-04-22 12:02:31,971 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-0.txt...


2025-04-22 12:02:32,048 - INFO -     Finished QED-2-to-2-diag-TreeLevel-0.txt. Added 1728 examples. Skipped 0 lines.
2025-04-22 12:02:32,049 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-1.txt...
2025-04-22 12:02:32,104 - INFO -     Finished QED-2-to-2-diag-TreeLevel-1.txt. Added 1664 examples. Skipped 0 lines.
2025-04-22 12:02:32,105 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-2.txt...
2025-04-22 12:02:32,161 - INFO -     Finished QED-2-to-2-diag-TreeLevel-2.txt. Added 1600 examples. Skipped 0 lines.
2025-04-22 12:02:32,161 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-3.txt...



          Preprocessing and Tokenization Rationale (QED Amplitudes)

1.  **Index Normalization (`_123456` -> `_0`, `_1`, ...):**
    * **Problem:** The raw data contains indices (e.g., `p_123`, `gamma_45`) where the numeric part can grow very large across the dataset. Treating each unique numbered index (like `p_123` vs `p_124`) as a distinct token would lead to an enormous, sparse vocabulary.
    * **Solution:** We normalize these indices *within each individual expression* (amplitude and squared amplitude separately). The first unique numeric subscript encountered (like `_123`) becomes `_0`, the second unique one (`_45`) becomes `_1`, and so on. If the *same* original index appears multiple times in one expression (e.g., `p_123 * p_123`), it gets the *same* normalized index (`p_0 * p_0`).
    * **Benefit:** This dramatically reduces vocabulary size and helps the model focus on the *structure* and *relationships* between indexed terms rather than memorizing absolute index values. It 

2025-04-22 12:02:32,214 - INFO -     Finished QED-2-to-2-diag-TreeLevel-3.txt. Added 1536 examples. Skipped 0 lines.
2025-04-22 12:02:32,215 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-4.txt...
2025-04-22 12:02:32,269 - INFO -     Finished QED-2-to-2-diag-TreeLevel-4.txt. Added 1472 examples. Skipped 0 lines.
2025-04-22 12:02:32,269 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-5.txt...
2025-04-22 12:02:32,319 - INFO -     Finished QED-2-to-2-diag-TreeLevel-5.txt. Added 1408 examples. Skipped 0 lines.
2025-04-22 12:02:32,319 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-6.txt...
2025-04-22 12:02:32,370 - INFO -     Finished QED-2-to-2-diag-TreeLevel-6.txt. Added 1344 examples. Skipped 0 lines.
2025-04-22 12:02:32,371 - INFO -   Processing file: QED-2-to-2-diag-TreeLevel-7.txt...
2025-04-22 12:02:32,419 - INFO -     Finished QED-2-to-2-diag-TreeLevel-7.txt. Added 1280 examples. Skipped 0 lines.
2025-04-22 12:02:32,420 - INFO -   Processing file: QED-2-to-2-diag

# Common Task 2: Train/Evaluate Transformer model
Train a generic next-token-prediction Transformer model to map the input data to the tokenized output sequences. Evaluate performance on the test set using sequence accuracy as a metric.


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
WORKS ON CPU

Task 2: Transformer Training for QED Amplitude Mapping (Jupyter Compatible)

Trains an Encoder-Decoder Transformer model on the preprocessed QED amplitude dataset.
Uses a custom tokenizer trained on the specific tokens generated during preprocessing.
Modified to run within a Jupyter Notebook by detecting the environment and using
default arguments instead of parsing command-line args when needed.
Uses MINIMAL TrainingArguments parameters for compatibility testing.

Pipeline:
1.  Define paths and configurations (using defaults when in Jupyter).
2.  Train a custom WordLevel tokenizer on the training data tokens if not already present.
3.  Load the custom tokenizer.
4.  Load the train/val/test datasets (JSONL format from Task 1.2).
5.  Define a tokenization function using the custom tokenizer.
6.  Tokenize the datasets using `datasets.map`.
7.  Initialize an Encoder-Decoder model (e.g., bert-base-uncased) and resize
    its token embeddings INDIVIDUALLY to match the custom tokenizer's vocabulary.
8.  Configure the model for sequence-to-sequence tasks.
9.  Set up Data Collator, MINIMAL Training Arguments, and Seq2SeqTrainer.
10. Define sequence accuracy metric function.
11. Train the model.
12. Evaluate the model on the test set.
"""

# --- Core Imports ---
import json
import sys
import numpy as np
import torch
import logging
import argparse
import random
from pathlib import Path
from typing import Dict, List, Optional

# --- Hugging Face Imports ---
from datasets import load_dataset
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, decoders, processors
from transformers import (
    PreTrainedTokenizerFast,
    AutoConfig,
    EncoderDecoderModel,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
# import accelerate

# Logging Setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Default Configuration (Used when running in Jupyter) ---
DEFAULT_ARGS = {
    "data_dir": Path("./qed_data_processed"),
    "output_dir": Path("./qed_model_output_v2"),
    "tokenizer_file_name": "qed_custom_tokenizer.json",
    "encoder_model_id": "bert-base-uncased",
    "decoder_model_id": "bert-base-uncased",
    "max_seq_length": 128,
    "vocab_size": 10000,
    "num_train_epochs": 5,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup_steps": 500,
    "logging_steps": 100,
    # Removed eval/save strategy args - use minimal set
    "seed": 42,
}


# --- Custom Tokenizer Training ---
# (Function remains the same)
def train_custom_tokenizer(
    data_files: Dict[str, str],
    output_path: Path,
    vocab_size: int,
    min_frequency: int = 2
    ) -> None:
    """ Trains a WordLevel tokenizer. """
    if output_path.exists():
        logging.info(f"Tokenizer file already exists at {output_path}, skipping training.")
        return
    logging.info("Training custom WordLevel tokenizer...")
    try:
        train_file_path = data_files.get('train')
        if not train_file_path or not Path(train_file_path).is_file():
            logging.error(f"Training data file path missing or invalid: {train_file_path}")
            sys.exit(1)
        raw_train_dataset = load_dataset("json", data_files={'train': train_file_path}, split="train")
        logging.info(f"Loaded {len(raw_train_dataset)} examples for tokenizer training.")
    except Exception as e: logging.error(f"Failed to load training data: {e}"); sys.exit(1)
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>"]
    def get_training_corpus():
        sequences_yielded = 0
        for example in raw_train_dataset:
            in_tokens = example.get("input_tokens"); tgt_tokens = example.get("target_tokens")
            if in_tokens and isinstance(in_tokens, list): yield in_tokens; sequences_yielded += 1
            if tgt_tokens and isinstance(tgt_tokens, list): yield tgt_tokens; sequences_yielded += 1
        if sequences_yielded == 0: logging.warning("Tokenizer training corpus yielded zero sequences.")
        else: logging.info(f"Tokenizer training corpus yielded {sequences_yielded} sequences.")
    try:
        custom_tokenizer = Tokenizer(models.WordLevel(unk_token="<unk>"))
        custom_tokenizer.pre_tokenizer = None
        trainer = trainers.WordLevelTrainer(vocab_size=vocab_size, min_frequency=min_frequency, special_tokens=special_tokens)
        logging.info(f"Starting tokenizer training (vocab_size={vocab_size}, min_freq={min_frequency})...")
        custom_tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)
        logging.info("Tokenizer training finished.")
        current_vocab = custom_tokenizer.get_vocab()
        tokens_to_add = [token for token in special_tokens if token not in current_vocab]
        if tokens_to_add:
            num_added = custom_tokenizer.add_special_tokens(tokens_to_add)
            logging.warning(f"Added {num_added} special tokens ({tokens_to_add}) missing after training.")
            if not all(token in custom_tokenizer.get_vocab() for token in special_tokens): raise ValueError("Failed to add all required special tokens!")
        bos_token_id = custom_tokenizer.token_to_id("<s>"); eos_token_id = custom_tokenizer.token_to_id("</s>")
        if bos_token_id is None or eos_token_id is None: raise ValueError("BOS/EOS tokens missing from trained tokenizer vocab.")
        custom_tokenizer.post_processor = processors.TemplateProcessing(single="<s> $A </s>", special_tokens=[("<s>", bos_token_id), ("</s>", eos_token_id)])
        logging.info("Set tokenizer post-processor for BOS/EOS.")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        custom_tokenizer.save(str(output_path))
        logging.info(f"Custom tokenizer trained ({custom_tokenizer.get_vocab_size()} vocab size) and saved to {output_path}")
    except Exception as e: logging.exception(f"Error during tokenizer training or saving: {e}"); sys.exit(1)


# --- Data Loading and Tokenization ---
# (Function remains the same)
def tokenize_function(examples, tokenizer, max_len):
    """ Tokenizes input and target token lists using the custom tokenizer. """
    input_strings = [" ".join(map(str, tokens)) if isinstance(tokens, list) else "" for tokens in examples['input_tokens']]
    target_strings = [" ".join(map(str, tokens)) if isinstance(tokens, list) else "" for tokens in examples['target_tokens']]
    model_inputs = tokenizer(input_strings, max_length=max_len, padding=False, truncation=True)
    labels = tokenizer(text_target=target_strings, max_length=max_len, padding=False, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# --- Metrics ---
# (Functions remain the same)
def compute_sequence_accuracy(predictions, labels, pad_token_id):
    """ Calculates exact sequence match accuracy, ignoring padding. """
    pred_ids = np.argmax(predictions, axis=-1); labels = np.asarray(labels)
    if pred_ids.shape != labels.shape:
        logging.warning(f"Shape mismatch in metrics: preds {pred_ids.shape}, labels {labels.shape}. Truncating.")
        min_len = min(pred_ids.shape[1], labels.shape[1]); pred_ids = pred_ids[:, :min_len]; labels = labels[:, :min_len]
    non_padding_mask = (labels != pad_token_id); correct_tokens = (pred_ids == labels) & non_padding_mask
    sum_non_padding = np.sum(non_padding_mask, axis=1); sum_correct_tokens = np.sum(correct_tokens, axis=1)
    correct_sequences = np.where(sum_non_padding > 0, sum_correct_tokens == sum_non_padding, True)
    accuracy = np.mean(correct_sequences); return float(accuracy)

def compute_metrics(eval_pred, tokenizer):
    """ Computes metrics for evaluation, specifically sequence accuracy. """
    predictions, labels = eval_pred; pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.token_to_id("<pad>")
        if pad_token_id is None: raise ValueError("Cannot determine pad_token_id for metrics.")
    seq_acc = compute_sequence_accuracy(predictions, labels, pad_token_id); return {"sequence_accuracy": seq_acc}

# --- Main Execution Logic ---

def main():
    """ Orchestrates the training/evaluation, compatible with Jupyter and terminal. """
    parser = argparse.ArgumentParser(description="Train/evaluate Seq2Seq model.")
    is_jupyter = any(['ipykernel' in arg for arg in sys.argv])

    if is_jupyter:
        logging.warning("Running in Jupyter/IPython! Using default arguments.")
        # Create a temporary dict excluding keys not needed by Namespace/older args
        temp_defaults = {k: v for k, v in DEFAULT_ARGS.items() if k != 'eval_save_strategy'} # Example if strategy was in defaults
        args = argparse.Namespace(**temp_defaults)
        args.tokenizer_file = args.output_dir / args.tokenizer_file_name
    else:
        # Define arguments for command-line execution
        parser.add_argument("--data_dir", type=Path, default=DEFAULT_ARGS["data_dir"])
        parser.add_argument("--output_dir", type=Path, default=DEFAULT_ARGS["output_dir"])
        parser.add_argument("--tokenizer_file_name", type=str, default=DEFAULT_ARGS["tokenizer_file_name"])
        parser.add_argument("--encoder_model_id", type=str, default=DEFAULT_ARGS["encoder_model_id"])
        parser.add_argument("--decoder_model_id", type=str, default=DEFAULT_ARGS["decoder_model_id"])
        parser.add_argument("--max_seq_length", type=int, default=DEFAULT_ARGS["max_seq_length"])
        parser.add_argument("--vocab_size", type=int, default=DEFAULT_ARGS["vocab_size"])
        parser.add_argument("--num_train_epochs", type=int, default=DEFAULT_ARGS["num_train_epochs"])
        parser.add_argument("--per_device_train_batch_size", type=int, default=DEFAULT_ARGS["per_device_train_batch_size"])
        parser.add_argument("--per_device_eval_batch_size", type=int, default=DEFAULT_ARGS["per_device_eval_batch_size"])
        parser.add_argument("--learning_rate", type=float, default=DEFAULT_ARGS["learning_rate"])
        parser.add_argument("--weight_decay", type=float, default=DEFAULT_ARGS["weight_decay"])
        parser.add_argument("--warmup_steps", type=int, default=DEFAULT_ARGS["warmup_steps"])
        parser.add_argument("--logging_steps", type=int, default=DEFAULT_ARGS["logging_steps"])
        # Remove eval/save strategy args from parser definition
        # parser.add_argument("--eval_save_strategy", type=str, default=DEFAULT_ARGS["eval_save_strategy"], choices=["epoch", "steps"])
        # parser.add_argument("--eval_save_steps", type=int, default=DEFAULT_ARGS["eval_save_steps"])
        parser.add_argument("--seed", type=int, default=DEFAULT_ARGS["seed"])
        args = parser.parse_args()
        args.tokenizer_file = args.output_dir / args.tokenizer_file_name

    args.output_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Starting script with effective arguments: {args}")

    torch.manual_seed(args.seed); np.random.seed(args.seed); random.seed(args.seed)
    logging.info(f"Random seed set to {args.seed}")
    OUTPUT_PREFIX = 'qed_amplitudes'

    data_files = {
        "train": str(args.data_dir / f"{OUTPUT_PREFIX}_train.jsonl"),
        "validation": str(args.data_dir / f"{OUTPUT_PREFIX}_val.jsonl"),
        "test": str(args.data_dir / f"{OUTPUT_PREFIX}_test.jsonl"),
    }
    for split, path in data_files.items():
        if not Path(path).is_file(): logging.error(f"{split.capitalize()} data file not found: {path}"); sys.exit(1)
        logging.info(f"Found {split} data file: {path}")

    train_custom_tokenizer(data_files, args.tokenizer_file, args.vocab_size)
    try:
        tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(args.tokenizer_file),
                                            bos_token="<s>", eos_token="</s>",
                                            unk_token="<unk>", pad_token="<pad>")
        if tokenizer.pad_token is None:
             logging.warning("Manually adding PAD token '<pad>' to tokenizer.")
             tokenizer.add_special_tokens({'pad_token': '<pad>'})
        logging.info(f"Successfully loaded custom tokenizer from {args.tokenizer_file}")
        logging.info(f"Tokenizer vocabulary size: {tokenizer.vocab_size}")
        if tokenizer.pad_token_id is None: raise ValueError("Tokenizer pad_token_id is None.")
        logging.info(f"Using PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
    except Exception as e: logging.exception(f"Failed to load custom tokenizer: {e}"); sys.exit(1)

    try:
        logging.info("Loading datasets..."); raw_datasets = load_dataset("json", data_files=data_files)
        logging.info(f"Datasets loaded: {raw_datasets}")
    except Exception as e: logging.exception(f"Failed to load datasets: {e}"); sys.exit(1)

    logging.info("Tokenizing datasets...")
    try:
        tokenize_partial = lambda examples: tokenize_function(examples, tokenizer, args.max_seq_length)
        tokenized_datasets = raw_datasets.map(
            tokenize_partial, batched=True, remove_columns=raw_datasets["train"].column_names, desc="Running tokenizer on dataset"
        )
        logging.info(f"Tokenized datasets structure: {tokenized_datasets}")
    except Exception as e: logging.exception(f"Failed during dataset tokenization: {e}"); sys.exit(1)

    logging.info(f"Initializing Encoder-Decoder model ({args.encoder_model_id} -> {args.decoder_model_id})")
    try:
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(args.encoder_model_id, args.decoder_model_id)
        logging.info(f"Resizing ENCODER token embeddings to {tokenizer.vocab_size}")
        model.encoder.resize_token_embeddings(len(tokenizer))
        logging.info(f"Resizing DECODER token embeddings to {tokenizer.vocab_size}")
        model.decoder.resize_token_embeddings(len(tokenizer))
        model.config.decoder_start_token_id = tokenizer.bos_token_id; model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id; model.config.encoder.vocab_size = tokenizer.vocab_size
        model.config.decoder.vocab_size = tokenizer.vocab_size; model.config.vocab_size = tokenizer.vocab_size
        logging.info("Model configuration complete.")
    except Exception as e: logging.exception(f"Failed to initialize or configure model: {e}"); sys.exit(1)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, label_pad_token_id=tokenizer.pad_token_id)
    logging.info("Data collator initialized.")

    # --- MINIMAL Training Arguments Block (for compatibility testing) ---
    logging.info("Defining Training Arguments (Minimal for compatibility check)...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(args.output_dir),
        # Required arguments
        do_train=True,
        do_eval=True, # Evaluation will happen at the end by default
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        # Basic logging
        logging_dir=str(args.output_dir / 'logs'),
        logging_steps=args.logging_steps,
        # Other standard args likely fine
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        seed=args.seed,
        report_to="none", # Disable external reporting like wandb
        # --- Temporarily REMOVED all potentially problematic evaluation/saving args ---
        # save_strategy="no", # Removed even this
        # evaluation_strategy="no", # Removed even this
        # save_total_limit=2, # Removed
        load_best_model_at_end=False, # Cannot load best if not evaluating/saving during train
        # metric_for_best_model="sequence_accuracy", # Removed
        # greater_is_better=True, # Removed
        predict_with_generate=False, # Keep this
    )
    # --- End Minimal Block ---
    logging.info(f"Training arguments defined (Minimal).")


    compute_metrics_with_tokenizer = lambda eval_pred: compute_metrics(eval_pred, tokenizer)
    trainer = Seq2SeqTrainer(
        model=model, args=training_args,
        train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator, tokenizer=tokenizer,
        compute_metrics=compute_metrics_with_tokenizer,
    )
    logging.info("Seq2SeqTrainer initialized.")

    logging.info("Starting model training...")
    try:
        train_result = trainer.train()
        logging.info("Training finished successfully.")
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        # Save the final model manually since automatic saving might be disabled
        trainer.save_model(str(args.output_dir / "final_model"))
        logging.info(f"Final model saved to {args.output_dir / 'final_model'}")
    except Exception as e: logging.exception(f"Training failed: {e}"); sys.exit(1)

    logging.info("Evaluating model on the test set...")
    try:
        # Note: This evaluates the *final* model state, not necessarily the best validation one
        test_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"], metric_key_prefix="test")
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"  Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             logging.warning("'test_sequence_accuracy' not found in test evaluation results.")
             print("Test results:", test_results)
    except Exception as e: logging.exception(f"Evaluation on test set failed: {e}"); sys.exit(1)

    logging.info(f"Script finished. Model and results saved in {args.output_dir.resolve()}")

# --- Entry Point ---
if __name__ == "__main__":
     main()

# --- Code to Run in Jupyter Cell ---
# Copy all code above into a single cell, then run main() in the *next* cell:
# main()

2025-04-22 13:09:28,774 - INFO - Starting script with effective arguments: Namespace(data_dir=PosixPath('qed_data_processed'), output_dir=PosixPath('qed_model_output_v2'), tokenizer_file_name='qed_custom_tokenizer.json', encoder_model_id='bert-base-uncased', decoder_model_id='bert-base-uncased', max_seq_length=128, vocab_size=10000, num_train_epochs=5, per_device_train_batch_size=16, per_device_eval_batch_size=16, learning_rate=5e-05, weight_decay=0.01, warmup_steps=500, logging_steps=100, seed=42, tokenizer_file=PosixPath('qed_model_output_v2/qed_custom_tokenizer.json'))
2025-04-22 13:09:28,775 - INFO - Random seed set to 42
2025-04-22 13:09:28,776 - INFO - Found train data file: qed_data_processed/qed_amplitudes_train.jsonl
2025-04-22 13:09:28,776 - INFO - Found validation data file: qed_data_processed/qed_amplitudes_val.jsonl
2025-04-22 13:09:28,776 - INFO - Found test data file: qed_data_processed/qed_amplitudes_test.jsonl
2025-04-22 13:09:28,777 - INFO - Tokenizer file already exi

2025-04-22 13:09:29,031 - INFO - Datasets loaded: DatasetDict({
    train: Dataset({
        features: ['input_tokens', 'target_tokens'],
        num_rows: 12441
    })
    validation: Dataset({
        features: ['input_tokens', 'target_tokens'],
        num_rows: 1555
    })
    test: Dataset({
        features: ['input_tokens', 'target_tokens'],
        num_rows: 1556
    })
})
2025-04-22 13:09:29,032 - INFO - Tokenizing datasets...
Running tokenizer on dataset: 100%|██████████| 12441/12441 [00:01<00:00, 9203.72 examples/s]
Running tokenizer on dataset: 100%|██████████| 1555/1555 [00:00<00:00, 10288.61 examples/s]
Running tokenizer on dataset: 100%|██████████| 1556/1556 [00:00<00:00, 6132.66 examples/s]
2025-04-22 13:09:30,799 - INFO - Tokenized datasets structure: DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 12441
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 

Step,Training Loss
100,4.0905
200,0.5897
300,0.0177
400,0.0069
500,0.0007
600,0.0004
700,0.0003


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x76322503cd00>>
Traceback (most recent call last):
  File "/home/nikitas/anaconda3/envs/feynman_seq2seq_env/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


# Specific Test 3: Train/Evaluate advanced model
Repeat task two including checking sequence accuracy but with a model that leverages some slightly more advanced techniques. The model you use should relate to the project you’re applying for.

In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Training and Evaluation Script for a T5 Sequence-to-Sequence Model
on the Preprocessed QED 2-to-2 Tree-Level Dataset.

*** NOTE: This version is significantly modified for robust execution ***
*** within a Jupyter Notebook (.ipynb) by REMOVING ARGPARSE.   ***
*** ***
*** Configuration is now handled by manually setting attributes on    ***
*** the `args` object defined near the top of the script below.       ***
*** Edit the `args` object directly to change parameters.             ***
*** ***
*** CORRECTED parameter name from evaluation_strategy to eval_strategy ***
*** Removed diagnostic blocks for cleaner notebook execution.        ***
*** ***
*** This script is designed to automatically utilize GPUs via     ***
*** Hugging Face Trainer/Accelerate when run in a correctly     ***
*** configured environment (like the one from t5_gpu_env.yaml or t5_pip_test). ***

Dependencies:
- Working Python 3.10 environment (e.g., t5_pip_test) with necessary packages.
"""

import json
import sys
import numpy as np
import torch
import datetime # Use the standard datetime module
import logging
import argparse # Still used for the Namespace object, but not for parsing
from pathlib import Path
from torch.utils.data import Dataset
import inspect # Keep inspect if needed by other parts, otherwise optional now
# Import the specific classes needed
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Import base class for comparison if needed
)
from transformers.trainer_utils import set_seed # Use transformers' set_seed


# Configure basic logging (Set level to DEBUG for detailed output, INFO for progress)
# Check if logger already exists (useful in notebooks where cells might be re-run)
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logging.basicConfig(
        level=logging.INFO, # INFO shows progress, DEBUG shows fine details
        format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Use stream=sys.stdout to ensure output appears in notebook cell
        stream=sys.stdout
    )
else:
    # Ensure level is still INFO if logger was configured previously
    logger.setLevel(logging.INFO)


# =============================================================================
# <<< CONFIGURATION >>>
# =============================================================================
# --- EDIT THESE VALUES TO CONFIGURE THE RUN ---
args = argparse.Namespace(
    # --- File Paths ---
    train_file='qed_data_processed/qed_amplitudes_train.jsonl',
    val_file='qed_data_processed/qed_amplitudes_val.jsonl',
    test_file='qed_data_processed/qed_amplitudes_test.jsonl',
    output_dir='qed_t5_model_output_gpu_final', # Consider a new output dir

    # --- Model Configuration ---
    model_id="t5-base",
    task_prefix="",
    tokenizer_legacy=False,

    # --- Tokenizer and Data Processing ---
    max_seq_length=128,

    # --- Training Hyperparameters ---
    num_epochs=5,
    # *** Adjust batch sizes based on GPU memory (RTX 4090 16GB can likely handle larger) ***
    train_batch_size=16,    # Increased from 8, monitor GPU memory usage
    eval_batch_size=16,     # Increased from 8
    learning_rate=3e-4,
    weight_decay=0.01,
    warmup_steps=200,
    logging_steps=50,
    eval_strategy="epoch",         # <<< CORRECTED NAME
    eval_steps=500, # Note: eval_steps is ignored if eval_strategy is 'epoch'
    save_strategy="epoch",
    save_steps=500, # Note: save_steps is ignored if save_strategy is 'epoch'
    save_total_limit=2,
    seed=42,
    dataloader_num_workers=0, # 0 is safest for notebooks, increase if needed

    # --- Advanced Training Features ---
    fp16=torch.cuda.is_available(), # Automatically use FP16 if GPU available
    gradient_checkpointing=True, # Good for saving memory on large models
    label_smoothing=0.1,

    # --- Generation Configuration (for evaluation) ---
    num_beams=4,

    # --- Reporting ---
    report_to="tensorboard", # Or "none", "wandb", etc.
)
# --- END OF CONFIGURATION ---

# (Helper Functions: load_jsonl, convert_tokens_to_strings, encode_sequences, SequenceDataset, compute_metrics_fn)
def load_jsonl(file_path):
    """Loads a JSON Lines (.jsonl) file, skipping blank/whitespace-only lines."""
    data = []
    file_path = Path(file_path)
    logger.info(f"Loading data from: {file_path}")
    if not file_path.is_file():
        logger.error(f"Data file not found: {file_path}")
        raise FileNotFoundError(f"Data file not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        logger.info(f"Successfully loaded {len(data)} records from {file_path}.")
        return data
    except Exception as e:
        logger.error(f"Failed to load data from {file_path}: {e}", exc_info=True)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens into single whitespace-joined strings."""
    input_strings = []
    target_strings = []
    if not raw_data_list:
        logger.warning("Input data list is empty for token-to-string conversion.")
        return input_strings, target_strings

    skipped_count = 0
    logger.info(f"Attempting to convert {len(raw_data_list)} items to strings...")
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            try:
                # Ensure all tokens are strings before joining
                str_input_toks = [str(tok) for tok in input_toks]
                str_target_toks = [str(tok) for tok in target_toks]
                joined_input = " ".join(str_input_toks)
                joined_target = " ".join(str_target_toks)
                input_strings.append(joined_input)
                target_strings.append(joined_target)
            except Exception as e:
                 logger.warning(f"Error joining tokens for item at index {i}: {e}. Item: {item}", exc_info=True)
                 skipped_count += 1
        else:
            missing_keys = []
            if 'input_tokens' not in item: missing_keys.append('input_tokens')
            if 'target_tokens' not in item: missing_keys.append('target_tokens')
            invalid_types = []
            if 'input_tokens' in item and not isinstance(input_toks, list): invalid_types.append(f"input_tokens (type: {type(input_toks).__name__})")
            if 'target_tokens' in item and not isinstance(target_toks, list): invalid_types.append(f"target_tokens (type: {type(target_toks).__name__})")
            log_msg = f"Skipping item at index {i}."
            if missing_keys: log_msg += f" Missing keys: {missing_keys}."
            if invalid_types: log_msg += f" Keys with non-list values: {invalid_types}."
            item_str = str(item)
            if len(item_str) > 200: item_str = item_str[:200] + "..."
            log_msg += f" Item (truncated): {item_str}"
            logger.warning(log_msg)
            skipped_count += 1

    logger.info(f"Successfully converted {len(input_strings)} items to strings (skipped {skipped_count}).")
    if not input_strings and skipped_count > 0 and skipped_count == len(raw_data_list):
         logger.error("All items were skipped during token-to-string conversion.")
         raise ValueError("Failed to convert any items to strings.")
    elif not input_strings and len(raw_data_list) > 0 and skipped_count != len(raw_data_list):
        logger.error(f"Conversion resulted in an empty list of strings, but {len(raw_data_list)} items were processed ({skipped_count} skipped).")
        raise ValueError("String conversion unexpectedly produced no output despite valid input items existing.")
    elif not input_strings and not raw_data_list:
        logger.warning("Input data list was empty, resulting in empty string lists.")
    return input_strings, target_strings


def encode_sequences(tokenizer, input_strings, target_strings, max_len, task_prefix="", batch_size=500):
    """Tokenizes input and target string pairs using the T5 tokenizer."""
    logger.info(f"Starting encoding for {len(input_strings)} pairs. MaxLen={max_len}. Prefix='{task_prefix}'. (Batch size: {batch_size})")

    if not isinstance(input_strings, list) or not isinstance(target_strings, list):
        raise TypeError("input_strings and target_strings must be lists.")
    if not input_strings or not target_strings:
        logger.warning("Received empty list(s) for encoding. Returning empty.")
        return {'input_ids': [], 'attention_mask': [], 'labels': []}
    if len(input_strings) != len(target_strings):
        raise ValueError(f"Input ({len(input_strings)}) and Target ({len(target_strings)}) string lists must have the same length.")

    all_input_ids = []
    all_attention_mask = []
    all_labels = []
    num_sequences = len(input_strings)

    for i in range(0, num_sequences, batch_size):
        batch_start = i
        batch_end = min(i + batch_size, num_sequences)
        logger.info(f"  Encoding batch {batch_start+1}-{batch_end}/{num_sequences}...")

        input_batch = input_strings[batch_start:batch_end]
        target_batch = target_strings[batch_start:batch_end]

        # Encode Inputs
        try:
            processed_input_batch = [f"{task_prefix}{s}" if task_prefix else s for s in input_batch]
            # Use batch encoding directly
            encoder_outputs = tokenizer(
                processed_input_batch, max_length=max_len, padding='max_length', truncation=True, return_tensors=None
            )
            batch_input_ids = encoder_outputs['input_ids']
            batch_attention_mask = encoder_outputs['attention_mask']
        except Exception as e:
            logger.error(f"Tokenizer CRASHED on ENCODER batch: {e}", exc_info=True)
            raise

        # Encode Targets (Labels)
        try:
            # Use text_target argument for T5/Seq2Seq models
            decoder_outputs = tokenizer(
                text_target=target_batch, max_length=max_len, padding='max_length', truncation=True, return_tensors=None
            )
            batch_labels = decoder_outputs['input_ids']
        except Exception as e:
            logger.error(f"Tokenizer CRASHED on DECODER batch: {e}", exc_info=True)
            raise

        # Validation before extending
        if len(batch_input_ids) != len(input_batch) or len(batch_attention_mask) != len(input_batch) or len(batch_labels) != len(target_batch):
            raise ValueError(f"Batch encoding length mismatch! Input={len(batch_input_ids)}, Attn={len(batch_attention_mask)}, Labels={len(batch_labels)}, Expected={len(input_batch)}")

        all_input_ids.extend(batch_input_ids)
        all_attention_mask.extend(batch_attention_mask)
        all_labels.extend(batch_labels)

    logger.info(f"Finished encoding all {len(all_input_ids)} sequence pairs.")
    if len(all_input_ids) != num_sequences:
         raise ValueError(f"Encoded sequence count ({len(all_input_ids)}) mismatch original ({num_sequences}).")

    return {'input_ids': all_input_ids, 'attention_mask': all_attention_mask, 'labels': all_labels}


class SequenceDataset(Dataset):
    """Custom PyTorch Dataset wrapper for tokenized sequence data."""
    def __init__(self, encodings):
        required_keys = ['input_ids', 'attention_mask', 'labels']
        if not isinstance(encodings, dict) or not all(key in encodings for key in required_keys):
            logger.error(f"Invalid encodings passed to SequenceDataset. Got Keys: {list(encodings.keys()) if isinstance(encodings, dict) else type(encodings)}")
            raise ValueError(f"Encodings must be a dict with keys: {required_keys}.")
        try:
            lengths = {key: len(val) for key, val in encodings.items() if key in required_keys and isinstance(val, list)}
            if len(lengths) != len(required_keys):
                 missing_or_wrong_type = [k for k in required_keys if k not in lengths]
                 raise ValueError(f"Missing or non-list required keys in encodings: {missing_or_wrong_type}")
            if len(set(lengths.values())) > 1:
                 raise ValueError(f"Inconsistent lengths in encodings: {lengths}")
            self.length = lengths.get('input_ids', 0)
        except Exception as e:
            raise ValueError(f"Failed validating encoding structure/lengths: {e}") from e

        self.encodings = {k: v for k, v in encodings.items() if k in required_keys}
        if self.length == 0: logger.warning("Initializing SequenceDataset with length 0.")
        logger.info(f"Created SequenceDataset with {self.length} examples.")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds ({self.length}).")
        try:
            item = {key: val[idx] for key, val in self.encodings.items()}
            if 'labels' not in item: raise KeyError(f"'labels' key missing for index {idx}")
            return item
        except Exception as e:
            logger.error(f"Failed creating item at index {idx}: {e}", exc_info=True)
            raise IndexError(f"Error retrieving item at index {idx}: {e}") from e


tokenizer_for_metrics = None # Global tokenizer for compute_metrics

def compute_metrics_fn(eval_pred):
    """Calculates exact match accuracy between decoded predictions and labels."""
    global tokenizer_for_metrics
    if tokenizer_for_metrics is None:
        logger.error("Tokenizer missing in compute_metrics!")
        return {"sequence_accuracy": 0.0}

    predictions, labels = eval_pred
    if isinstance(predictions, tuple): predictions = predictions[0]
    if predictions.ndim == 3 and predictions.shape[-1] > 1:
         predictions = np.argmax(predictions, axis=-1)

    if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
    if not isinstance(labels, np.ndarray): labels = np.array(labels)

    labels = labels.astype(np.int64)
    pad_token_id = tokenizer_for_metrics.pad_token_id
    labels = np.where(labels != -100, labels, pad_token_id)

    try:
        decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    except Exception as e:
        logger.error(f"Decode failed in compute_metrics: {e}", exc_info=True)
        return {"sequence_accuracy": 0.0}

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    if not decoded_preds or not decoded_labels or len(decoded_preds) != len(decoded_labels):
        logger.warning(f"Metrics decode issue: Preds empty={not decoded_preds}, Labels empty={not decoded_labels}, LenPred={len(decoded_preds)}, LenLabel={len(decoded_labels)}")
        return {"sequence_accuracy": 0.0}

    matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
    accuracy = np.mean(matches) if matches else 0.0

    logger.info(f"Computed sequence accuracy: {accuracy:.4f}")
    return {"sequence_accuracy": float(accuracy)}


# =============================================================================
# Main Execution Logic
# =============================================================================

def main(config_args):
    """Orchestrates the T5 training and evaluation pipeline using config object."""
    args = config_args # Use the config object passed as argument

    # --- Basic Setup ---
    logger.info("="*60)
    logger.info(" Starting T5 Sequence-to-Sequence Model Training and Evaluation ")
    logger.info("="*60)
    start_time = datetime.datetime.now()
    logger.info(f"Script execution started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    if torch.cuda.is_available():
        logger.info(f"CUDA available. Device: {torch.cuda.get_device_name(0)}")
        logger.info(f"PyTorch CUDA version: {torch.version.cuda}")
    else:
        logger.warning("CUDA not available. Running on CPU.")
        if args.fp16:
             logger.warning("fp16=True ignored because CUDA is not available.")
             args.fp16 = False # Override config if no GPU

    logger.info(f"Running with configuration:")
    for k, v in vars(args).items(): logger.info(f"  {k}: {v}")

    set_seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory set to: {output_dir}")

    # --- Data Loading and Preprocessing ---
    logger.info("--- Stage 1: Loading Raw Data ---")
    train_raw = load_jsonl(args.train_file)
    val_raw = load_jsonl(args.val_file)
    test_raw = load_jsonl(args.test_file)
    if not train_raw or not val_raw: # Test raw is optional for training
        raise ValueError("Training and/or validation raw datasets empty.")
    logger.info("--- Raw Data Loading Complete ---")

    logger.info("--- Stage 2: Converting Tokens to Strings ---")
    train_in, train_tgt = convert_tokens_to_strings(train_raw)
    val_in, val_tgt = convert_tokens_to_strings(val_raw)
    test_in, test_tgt = convert_tokens_to_strings(test_raw) # Convert test even if empty list results
    if not train_in or not val_in:
        raise ValueError("Training and/or validation datasets empty after string conversion.")
    logger.info("--- Token-to-String Conversion Complete ---")

    logger.info("--- Stage 3: Initializing Tokenizer ---")
    try:
        tokenizer = T5Tokenizer.from_pretrained(args.model_id, legacy=args.tokenizer_legacy)
        global tokenizer_for_metrics; tokenizer_for_metrics = tokenizer
        logger.info(f"Tokenizer {args.model_id} initialized (Fast: {tokenizer.is_fast}, Vocab: {tokenizer.vocab_size}).")
    except Exception as e: logger.error(f"Failed initializing tokenizer: {e}", exc_info=True); raise
    logger.info("--- Tokenizer Initialization Complete ---")

    logger.info("--- Stage 4: Encoding Data (Tokenization) ---")
    task_prefix = args.task_prefix or ""
    logger.info(f"Using Task Prefix: '{task_prefix}'")
    try:
        logger.info("--> Encoding Training Data...")
        train_enc = encode_sequences(tokenizer, train_in, train_tgt, args.max_seq_length, task_prefix)
        if not train_enc.get('input_ids'): raise ValueError("Training encoding failed (returned empty).")

        logger.info("--> Encoding Validation Data...")
        val_enc = encode_sequences(tokenizer, val_in, val_tgt, args.max_seq_length, task_prefix)
        if not val_enc.get('input_ids'): raise ValueError("Validation encoding failed (returned empty).")

        # Encode test data if available
        test_enc = None
        if test_in and test_tgt:
            logger.info("--> Encoding Test Data...")
            test_enc = encode_sequences(tokenizer, test_in, test_tgt, args.max_seq_length, task_prefix)
            if not test_enc.get('input_ids'): logger.warning("Test encoding returned empty.")
        else:
             logger.info("Test input/target strings empty, skipping test encoding.")

    except Exception as e: logger.error(f"Data encoding failed: {e}", exc_info=True); raise
    logger.info("--- Data Encoding Phase Complete ---")

    logger.info("--- Stage 5: Creating PyTorch Datasets ---")
    try:
        train_ds = SequenceDataset(train_enc)
        val_ds = SequenceDataset(val_enc)
        test_ds = SequenceDataset(test_enc) if test_enc and test_enc.get('input_ids') else None # Only create if encoding successful
    except Exception as e: logger.error(f"Dataset creation failed: {e}", exc_info=True); raise
    logger.info("--- PyTorch Datasets Created Successfully ---")

    logger.info("--- Stage 6: Initializing Model ---")
    try:
        model = T5ForConditionalGeneration.from_pretrained(args.model_id)
        logger.info(f"Model {args.model_id} loaded.")
        if torch.cuda.is_available(): logger.info(f"Est. Model Memory: {model.get_memory_footprint() / 1e9:.2f} GB")
    except Exception as e: logger.error(f"Failed initializing model: {e}", exc_info=True); raise
    logger.info("--- Model Initialization Complete ---")

    logger.info("--- Stage 7: Configuring Training Environment ---")
    use_gc = args.gradient_checkpointing
    if use_gc: logger.info("Attempting to enable Gradient Checkpointing...")
    try:
        if use_gc and hasattr(model, 'gradient_checkpointing_enable'):
             model.gradient_checkpointing_enable(); logger.info("Gradient Checkpointing enabled via model method.")
    except Exception as e:
        logger.warning(f"Failed enabling Gradient Checkpointing on model: {e}. Relying on TrainingArguments setting.", exc_info=True)
        use_gc = args.gradient_checkpointing # Keep desired value

    use_fp16 = args.fp16 # Already checked against cuda availability
    logger.info("Initializing Data Collator...")
    try:
        collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, pad_to_multiple_of=8 if use_fp16 else None)
        logger.info("Data Collator initialized.")
    except Exception as e: logger.error(f"Collator init failed: {e}", exc_info=True); raise


    # <<<--- DIRECT INITIALIZATION (Corrected Parameter Name) --- >>>
    logger.info("Defining Training Arguments...")
    report_to = args.report_to.lower() if isinstance(args.report_to, str) else args.report_to
    if report_to == "none": logger.info("Reporting disabled.")
    elif report_to == "tensorboard": logger.info("Reporting to tensorboard.")
    else: logger.info(f"Reporting to: {report_to}")

    train_args = None # Initialize
    try:
        # Note: No signature check needed now, using corrected param name
        train_args = Seq2SeqTrainingArguments(
            output_dir=str(output_dir),
            # Core Training Params
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps=args.warmup_steps,
            seed=args.seed,
            # Evaluation and Saving Strategy (Corrected Name)
            eval_strategy=args.eval_strategy,               # <<< CORRECTED
            eval_steps=args.eval_steps if args.eval_strategy == "steps" else None, # <<< CORRECTED
            save_strategy=args.save_strategy,
            save_steps=args.save_steps if args.save_strategy == "steps" else None,
            save_total_limit=args.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="sequence_accuracy",
            greater_is_better=True,
            # Logging
            logging_dir=str(output_dir / 'logs'),
            logging_strategy="steps",
            logging_steps=args.logging_steps,
            report_to=report_to,
            # Performance / Hardware
            fp16=use_fp16,
            gradient_checkpointing=use_gc,
            dataloader_num_workers=args.dataloader_num_workers,
            # Seq2Seq Specific
            predict_with_generate=True,
            generation_max_length=args.max_seq_length,
            generation_num_beams=args.num_beams,
            label_smoothing_factor=args.label_smoothing,
        )
        logger.info("Successfully initialized Seq2SeqTrainingArguments.")
        logger.info(f"Effective FP16: {train_args.fp16}, Grad Checkpointing: {train_args.gradient_checkpointing}")

    except TypeError as te:
        logger.error(f"FAILED to initialize Seq2SeqTrainingArguments: {te}", exc_info=True)
        raise te # Re-raise to stop execution
    except Exception as e:
        logger.error(f"FAILED to initialize Seq2SeqTrainingArguments with unexpected error: {e}", exc_info=True)
        raise e

    if train_args is None:
        raise RuntimeError("train_args was not successfully defined.")

    logger.info("--- Training Environment Configuration Complete ---")
    # <<<--- END DIRECT INITIALIZATION --- >>>


    logger.info("--- Stage 8: Initializing Seq2SeqTrainer ---")
    try:
        trainer = Seq2SeqTrainer(
            model=model,
            args=train_args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=collator,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics_fn,
        )
        logger.info("Seq2SeqTrainer Initialized.")
    except Exception as e:
        logger.error(f"Trainer init failed: {e}", exc_info=True)
        raise e


    logger.info("--- Stage 9: Starting Model Training ---")
    try:
        logger.info(f"Starting training for {args.num_epochs} epochs...")
        train_res = trainer.train()
        metrics = train_res.metrics
        logger.info("--- Training Finished ---")
        logger.info("--- Training Metrics ---")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Trainer state saved.")
        best_path = str(output_dir / "best_model")
        trainer.save_model(best_path)
        logger.info(f"Final best model saved to: {best_path}")
    except Exception as e: logger.error(f"Training failed: {e}", exc_info=True); raise

    logger.info("--- Stage 10: Evaluating Model on Test Set ---")
    if test_ds:
        try:
            logger.info("Running final evaluation on the test set...")
            test_res = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
            logger.info("--- Evaluation on Test Set Finished ---")
            logger.info("--- Test Set Results ---")
            trainer.log_metrics("test", test_res)
            trainer.save_metrics("test", test_res)
            metric_key = "test_sequence_accuracy"
            if metric_key in test_res:
                logger.info(f"\n*** Test Sequence Accuracy: {test_res[metric_key]:.4f} ***\n")
            else:
                logger.warning(f"Metric '{metric_key}' not found in test results.")
                logger.info(f"Available test metrics: {list(test_res.keys())}")
        except Exception as e:
            logger.error(f"Test evaluation failed: {e}", exc_info=True)
    else:
        logger.warning("Test dataset unavailable or not loaded. Skipping final evaluation.")
    logger.info("--- Final Evaluation Stage Complete ---")

    end_time = datetime.datetime.now()
    total_time = end_time - start_time
    logger.info("="*60)
    try:
        end_time_str = end_time.strftime('%Y-%m-%d %H:%M:%S')
    except Exception:
        end_time_str = "Time Unavailable"
    logger.info(f" Script execution finished at: {end_time_str}")
    logger.info(f" Total execution time: {total_time}")
    logger.info("="*60)


# =============================================================================
# Execute Main Function (for Notebook context)
# =============================================================================
# In a notebook, we call main directly instead of using if __name__ == "__main__":
try:
    # Ensure CUDA check happens right before main call if needed
    if not torch.cuda.is_available() and args.fp16:
        logger.warning("MAIN (Notebook): CUDA not detected before starting main function!")
        logger.warning("MAIN (Notebook): Disabling fp16 as CUDA is not available.")
        args.fp16 = False

    main(args) # Call the main function with the configured args

except Exception as e:
     logger.critical(f"Unhandled exception terminated script execution: {e}", exc_info=True)
     # In a notebook, just logging might be sufficient, or re-raise if preferred
     # raise e

2025-04-22 20:46:14 - INFO - [main] -  Starting T5 Sequence-to-Sequence Model Training and Evaluation 
2025-04-22 20:46:14 - INFO - [main] - Script execution started at: 2025-04-22 20:46:14
2025-04-22 20:46:14 - INFO - [main] - CUDA available. Device: NVIDIA GeForce RTX 4090 Laptop GPU
2025-04-22 20:46:14 - INFO - [main] - PyTorch CUDA version: 12.1
2025-04-22 20:46:14 - INFO - [main] - Running with configuration:
2025-04-22 20:46:14 - INFO - [main] -   train_file: qed_data_processed/qed_amplitudes_train.jsonl
2025-04-22 20:46:14 - INFO - [main] -   val_file: qed_data_processed/qed_amplitudes_val.jsonl
2025-04-22 20:46:14 - INFO - [main] -   test_file: qed_data_processed/qed_amplitudes_test.jsonl
2025-04-22 20:46:14 - INFO - [main] -   output_dir: qed_t5_model_output_gpu_final
2025-04-22 20:46:14 - INFO - [main] -   model_id: t5-base
2025-04-22 20:46:14 - INFO - [main] -   task_prefix: 
2025-04-22 20:46:14 - INFO - [main] -   tokenizer_legacy: False
2025-04-22 20:46:14 - INFO - [main] 

Epoch,Training Loss,Validation Loss,Sequence Accuracy
1,1.4484,1.422035,0.215434
2,1.4204,1.405642,0.326688
3,1.4128,1.401123,0.342122
4,1.4081,1.397476,0.404502
5,1.4055,1.395355,0.406431


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
2025-04-22 20:50:58 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.2154
2025-04-22 20:55:34 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.3267
2025-04-22 21:00:11 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.3421
2025-04-22 21:04:44 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.4045
2025-04-22 21:09:16 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.4064
There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
2025-04-22 21:09:18 - INFO - [main] - --- Training Finished ---
2025-04-22 21:09:18 - INFO - [main] - --- Training Metrics ---
2025-04-22 21:09:18 - INFO - [main] - Trainer state saved.


***** train metrics *****
  epoch                    =        5.0
  total_flos               =  8819677GF
  train_loss               =     1.4702
  train_runtime            = 0:22:54.71
  train_samples_per_second =      45.25
  train_steps_per_second   =       2.83


2025-04-22 21:09:19 - INFO - [main] - Final best model saved to: qed_t5_model_output_gpu_final/best_model
2025-04-22 21:09:19 - INFO - [main] - --- Stage 10: Evaluating Model on Test Set ---
2025-04-22 21:09:19 - INFO - [main] - Running final evaluation on the test set...


2025-04-22 21:11:33 - INFO - [compute_metrics_fn] - Computed sequence accuracy: 0.4254
2025-04-22 21:11:33 - INFO - [main] - --- Evaluation on Test Set Finished ---
2025-04-22 21:11:33 - INFO - [main] - --- Test Set Results ---
2025-04-22 21:11:33 - INFO - [main] - 
*** Test Sequence Accuracy: 0.4254 ***

2025-04-22 21:11:33 - INFO - [main] - --- Final Evaluation Stage Complete ---
2025-04-22 21:11:33 - INFO - [main] -  Script execution finished at: 2025-04-22 21:11:33
2025-04-22 21:11:33 - INFO - [main] -  Total execution time: 0:25:19.068186


***** test metrics *****
  epoch                   =        5.0
  test_loss               =     1.3951
  test_runtime            = 0:02:14.57
  test_samples_per_second =     11.562
  test_sequence_accuracy  =     0.4254
  test_steps_per_second   =      0.728


# 3.1: Next-Generation Transformer Models for Symbolic Calculations of Squared Amplitudes in HEP
Model: Transformer model with a contemporary innovation added such as KAN layers, reinforcement learning, genetic algorithms, specialized long-sequence attention, etc. which improves the performance compared to a basic transformer.


In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Task 3.1: End-to-end Training and Evaluation Script for a LongT5 Model
on Symbolic Calculations of Squared Amplitudes in HEP.

This script adapts the previous T5 framework to use LongT5, a Transformer
model with improved long-sequence handling capabilities (Transient Global Attention),
as the "contemporary innovation".

*** NOTE: This version is adapted for robust execution ***
*** within a Jupyter Notebook (.ipynb). Configuration is handled    ***
*** by manually setting attributes on the `args` object.            ***

Key Changes from Basic T5 Script:
- Model: Uses LongT5ForConditionalGeneration.
- Tokenizer: Still uses T5Tokenizer (compatible with LongT5 checkpoints).
- Checkpoint: Uses a pre-trained LongT5 checkpoint ('google/long-t5-tglobal-base').
- Max Sequence Length: Increased significantly to leverage long-sequence capability.
- Batch Size: Potentially reduced due to higher memory usage from longer sequences.
- Output Directory: Changed to reflect the new model.
- Comments/Logging: Updated for the new task and model.
- Diagnostics: Removed path/signature checks for clarity.

Dependencies:
- Working Python 3.10 environment (e.g., t5_pip_test) with necessary packages
  (transformers, accelerate, torch, datasets, sentencepiece, etc.).
"""

import json
import sys
import numpy as np
import torch
import datetime # Use the standard datetime module
import logging
import argparse # Still used for the Namespace object, but not for parsing
from pathlib import Path
from torch.utils.data import Dataset
import inspect # Keep inspect if needed by other parts, otherwise optional now

# Import the specific classes needed
# NOTE: LongT5 often uses the standard T5Tokenizer
from transformers import (
    T5Tokenizer,
    LongT5ForConditionalGeneration, # <<< CHANGED Model Class
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Keep for reference if needed
)
from transformers.trainer_utils import set_seed # Use transformers' set_seed

# Configure basic logging (Set level to DEBUG for detailed output, INFO for progress)
# Check if logger already exists (useful in notebooks where cells might be re-run)
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logging.basicConfig(
        level=logging.INFO, # INFO shows progress, DEBUG shows fine details
        format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Use stream=sys.stdout to ensure output appears in notebook cell
        stream=sys.stdout
    )
else:
    # Ensure level is still INFO if logger was configured previously
    logger.setLevel(logging.INFO)


# =============================================================================
# <<< CONFIGURATION >>>
# =============================================================================
# --- EDIT THESE VALUES TO CONFIGURE THE RUN ---
args = argparse.Namespace(
    # --- File Paths ---
    # !!! Update these if your HEP dataset has different names/locations !!!
    train_file='qed_data_processed/qed_amplitudes_train.jsonl',
    val_file='qed_data_processed/qed_amplitudes_val.jsonl',
    test_file='qed_data_processed/qed_amplitudes_test.jsonl',
    output_dir='longt5_symbolic_hep_output', # <<< CHANGED Output Dir

    # --- Model Configuration ---
    # <<< CHANGED Model Checkpoint to a LongT5 variant >>>
    # (Other options: 'google/long-t5-local-base', 'google/long-t5-tglobal-large', etc.)
    model_id="google/long-t5-tglobal-base",
    task_prefix="", # Add task prefix if needed e.g. "calculate squared amplitude: "
    tokenizer_legacy=False, # Usually False for newer models

    # --- Tokenizer and Data Processing ---
    # <<< INCREASED Max Sequence Length >>> (Adjust based on your data & GPU memory)
    max_seq_length=2048, # e.g., 1024, 2048, 4096, up to 16k for some variants

    # --- Training Hyperparameters ---
    # !!! These likely need significant tuning for LongT5 and longer sequences !!!
    num_epochs=5,           # May need more or fewer epochs
    # <<< REDUCED Batch Size due to longer sequences / potentially larger model >>>
    train_batch_size=4,     # START SMALL! Monitor GPU memory. Adjust as needed.
    eval_batch_size=4,      # START SMALL!
    learning_rate=5e-5,     # Common starting point, adjust based on results
    weight_decay=0.01,
    warmup_steps=500,       # May need adjustment based on dataset size/LR schedule
    logging_steps=100,      # Log more often initially?
    eval_strategy="epoch",  # Correct parameter name
    eval_steps=None,        # Not used when eval_strategy is 'epoch'
    save_strategy="epoch",
    save_steps=None,        # Not used when save_strategy is 'epoch'
    save_total_limit=2,     # Keep latest 2 checkpoints
    seed=42,
    dataloader_num_workers=0, # Safest for notebooks

    # --- Advanced Training Features ---
    fp16=torch.cuda.is_available(), # Use mixed precision if possible
    gradient_checkpointing=True,    # HIGHLY Recommended for long sequences/large models
    label_smoothing=0.1,

    # --- Generation Configuration (for evaluation) ---
    num_beams=4,

    # --- Reporting ---
    report_to="tensorboard", # Or "none", "wandb", etc.
)
# --- END OF CONFIGURATION ---

# (Helper Functions: load_jsonl, convert_tokens_to_strings, encode_sequences, SequenceDataset, compute_metrics_fn)
def load_jsonl(file_path):
    """Loads a JSON Lines (.jsonl) file, skipping blank/whitespace-only lines."""
    data = []
    file_path = Path(file_path)
    logger.info(f"Loading data from: {file_path}")
    if not file_path.is_file():
        logger.error(f"Data file not found: {file_path}")
        raise FileNotFoundError(f"Data file not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        logger.info(f"Successfully loaded {len(data)} records from {file_path}.")
        return data
    except Exception as e:
        logger.error(f"Failed to load data from {file_path}: {e}", exc_info=True)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens into single whitespace-joined strings."""
    input_strings = []
    target_strings = []
    if not raw_data_list:
        logger.warning("Input data list is empty for token-to-string conversion.")
        return input_strings, target_strings

    skipped_count = 0
    logger.info(f"Attempting to convert {len(raw_data_list)} items to strings...")
    for i, item in enumerate(raw_data_list):
        # !!! Adapt these keys if your HEP dataset uses different names !!!
        input_toks = item.get('input_tokens')  # Or e.g., 'process_tokens'
        target_toks = item.get('target_tokens') # Or e.g., 'amplitude_tokens'

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            try:
                str_input_toks = [str(tok) for tok in input_toks]
                str_target_toks = [str(tok) for tok in target_toks]
                joined_input = " ".join(str_input_toks)
                joined_target = " ".join(str_target_toks)
                input_strings.append(joined_input)
                target_strings.append(joined_target)
            except Exception as e:
                 logger.warning(f"Error joining tokens for item at index {i}: {e}. Item: {item}", exc_info=True)
                 skipped_count += 1
        else:
            missing_keys = []
            # !!! Update keys checked if needed !!!
            if 'input_tokens' not in item: missing_keys.append('input_tokens')
            if 'target_tokens' not in item: missing_keys.append('target_tokens')
            invalid_types = []
            if 'input_tokens' in item and not isinstance(input_toks, list): invalid_types.append(f"input_tokens (type: {type(input_toks).__name__})")
            if 'target_tokens' in item and not isinstance(target_toks, list): invalid_types.append(f"target_tokens (type: {type(target_toks).__name__})")
            log_msg = f"Skipping item at index {i}."
            if missing_keys: log_msg += f" Missing keys: {missing_keys}."
            if invalid_types: log_msg += f" Keys with non-list values: {invalid_types}."
            item_str = str(item)
            if len(item_str) > 200: item_str = item_str[:200] + "..."
            log_msg += f" Item (truncated): {item_str}"
            logger.warning(log_msg)
            skipped_count += 1

    logger.info(f"Successfully converted {len(input_strings)} items to strings (skipped {skipped_count}).")
    if not input_strings and skipped_count > 0 and skipped_count == len(raw_data_list):
         logger.error("All items were skipped during token-to-string conversion.")
         raise ValueError("Failed to convert any items to strings.")
    elif not input_strings and len(raw_data_list) > 0 and skipped_count != len(raw_data_list):
        logger.error(f"Conversion resulted in an empty list of strings, but {len(raw_data_list)} items were processed ({skipped_count} skipped).")
        raise ValueError("String conversion unexpectedly produced no output despite valid input items existing.")
    elif not input_strings and not raw_data_list:
        logger.warning("Input data list was empty, resulting in empty string lists.")
    return input_strings, target_strings


def encode_sequences(tokenizer, input_strings, target_strings, max_len, task_prefix="", batch_size=500):
    """Tokenizes input and target string pairs."""
    logger.info(f"Starting encoding for {len(input_strings)} pairs. MaxLen={max_len}. Prefix='{task_prefix}'. (Batch size: {batch_size})")

    if not isinstance(input_strings, list) or not isinstance(target_strings, list):
        raise TypeError("input_strings and target_strings must be lists.")
    if not input_strings or not target_strings:
        logger.warning("Received empty list(s) for encoding. Returning empty.")
        return {'input_ids': [], 'attention_mask': [], 'labels': []}
    if len(input_strings) != len(target_strings):
        raise ValueError(f"Input ({len(input_strings)}) and Target ({len(target_strings)}) string lists must have the same length.")

    all_input_ids = []
    all_attention_mask = []
    all_labels = []
    num_sequences = len(input_strings)

    for i in range(0, num_sequences, batch_size):
        batch_start = i
        batch_end = min(i + batch_size, num_sequences)
        logger.info(f"  Encoding batch {batch_start+1}-{batch_end}/{num_sequences}...")

        input_batch = input_strings[batch_start:batch_end]
        target_batch = target_strings[batch_start:batch_end]

        # Encode Inputs
        try:
            processed_input_batch = [f"{task_prefix}{s}" if task_prefix else s for s in input_batch]
            encoder_outputs = tokenizer(
                processed_input_batch, max_length=max_len, padding='max_length', truncation=True, return_tensors=None
            )
            batch_input_ids = encoder_outputs['input_ids']
            batch_attention_mask = encoder_outputs['attention_mask']
        except Exception as e:
            logger.error(f"Tokenizer CRASHED on ENCODER batch: {e}", exc_info=True)
            raise

        # Encode Targets (Labels)
        try:
            decoder_outputs = tokenizer(
                text_target=target_batch, max_length=max_len, padding='max_length', truncation=True, return_tensors=None
            )
            batch_labels = decoder_outputs['input_ids']
        except Exception as e:
            logger.error(f"Tokenizer CRASHED on DECODER batch: {e}", exc_info=True)
            raise

        if len(batch_input_ids) != len(input_batch) or len(batch_attention_mask) != len(input_batch) or len(batch_labels) != len(target_batch):
            raise ValueError(f"Batch encoding length mismatch! Input={len(batch_input_ids)}, Attn={len(batch_attention_mask)}, Labels={len(batch_labels)}, Expected={len(input_batch)}")

        all_input_ids.extend(batch_input_ids)
        all_attention_mask.extend(batch_attention_mask)
        all_labels.extend(batch_labels)

    logger.info(f"Finished encoding all {len(all_input_ids)} sequence pairs.")
    if len(all_input_ids) != num_sequences:
         raise ValueError(f"Encoded sequence count ({len(all_input_ids)}) mismatch original ({num_sequences}).")

    return {'input_ids': all_input_ids, 'attention_mask': all_attention_mask, 'labels': all_labels}


class SequenceDataset(Dataset):
    """Custom PyTorch Dataset wrapper for tokenized sequence data."""
    def __init__(self, encodings):
        required_keys = ['input_ids', 'attention_mask', 'labels']
        if not isinstance(encodings, dict) or not all(key in encodings for key in required_keys):
            logger.error(f"Invalid encodings passed to SequenceDataset. Got Keys: {list(encodings.keys()) if isinstance(encodings, dict) else type(encodings)}")
            raise ValueError(f"Encodings must be a dict with keys: {required_keys}.")
        try:
            lengths = {key: len(val) for key, val in encodings.items() if key in required_keys and isinstance(val, list)}
            if len(lengths) != len(required_keys):
                 missing_or_wrong_type = [k for k in required_keys if k not in lengths]
                 raise ValueError(f"Missing or non-list required keys in encodings: {missing_or_wrong_type}")
            if len(set(lengths.values())) > 1:
                 raise ValueError(f"Inconsistent lengths in encodings: {lengths}")
            self.length = lengths.get('input_ids', 0)
        except Exception as e:
            raise ValueError(f"Failed validating encoding structure/lengths: {e}") from e

        self.encodings = {k: v for k, v in encodings.items() if k in required_keys}
        if self.length == 0: logger.warning("Initializing SequenceDataset with length 0.")
        logger.info(f"Created SequenceDataset with {self.length} examples.")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds ({self.length}).")
        try:
            item = {key: val[idx] for key, val in self.encodings.items()}
            if 'labels' not in item: raise KeyError(f"'labels' key missing for index {idx}")
            return item
        except Exception as e:
            logger.error(f"Failed creating item at index {idx}: {e}", exc_info=True)
            raise IndexError(f"Error retrieving item at index {idx}: {e}") from e


tokenizer_for_metrics = None # Global tokenizer for compute_metrics

def compute_metrics_fn(eval_pred):
    """Calculates exact match accuracy between decoded predictions and labels."""
    # !!! Consider adding more sophisticated metrics for symbolic math if needed !!!
    # E.g., SymPy based comparison, tree edit distance, etc. Accuracy might be too strict.
    global tokenizer_for_metrics
    if tokenizer_for_metrics is None:
        logger.error("Tokenizer missing in compute_metrics!")
        return {"sequence_accuracy": 0.0}

    predictions, labels = eval_pred
    if isinstance(predictions, tuple): predictions = predictions[0]
    if predictions.ndim == 3 and predictions.shape[-1] > 1:
         predictions = np.argmax(predictions, axis=-1)

    if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
    if not isinstance(labels, np.ndarray): labels = np.array(labels)

    labels = labels.astype(np.int64)
    pad_token_id = tokenizer_for_metrics.pad_token_id
    labels = np.where(labels != -100, labels, pad_token_id)

    try:
        decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    except Exception as e:
        logger.error(f"Decode failed in compute_metrics: {e}", exc_info=True)
        return {"sequence_accuracy": 0.0}

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    if not decoded_preds or not decoded_labels or len(decoded_preds) != len(decoded_labels):
        logger.warning(f"Metrics decode issue: Preds empty={not decoded_preds}, Labels empty={not decoded_labels}, LenPred={len(decoded_preds)}, LenLabel={len(decoded_labels)}")
        return {"sequence_accuracy": 0.0}

    matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
    accuracy = np.mean(matches) if matches else 0.0

    logger.info(f"Computed sequence accuracy: {accuracy:.4f}")
    # Consider adding other metrics here later
    return {"sequence_accuracy": float(accuracy)}


# =============================================================================
# Main Execution Logic
# =============================================================================

def main(config_args):
    """Orchestrates the LongT5 training and evaluation for symbolic HEP."""
    args = config_args # Use the config object passed as argument

    # --- Basic Setup ---
    logger.info("="*60)
    logger.info(" Starting LongT5 Training/Evaluation for Symbolic HEP Calculation ")
    logger.info("="*60)
    start_time = datetime.datetime.now()
    logger.info(f"Script execution started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    if torch.cuda.is_available():
        logger.info(f"CUDA available. Device: {torch.cuda.get_device_name(0)}")
        logger.info(f"PyTorch CUDA version: {torch.version.cuda}")
    else:
        logger.warning("CUDA not available. Running on CPU.")
        if args.fp16:
             logger.warning("fp16=True ignored because CUDA is not available.")
             args.fp16 = False # Override config if no GPU

    logger.info(f"Running with configuration:")
    for k, v in vars(args).items(): logger.info(f"  {k}: {v}")

    set_seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory set to: {output_dir}")

    # --- Data Loading and Preprocessing ---
    logger.info("--- Stage 1: Loading Raw Data ---")
    train_raw = load_jsonl(args.train_file)
    val_raw = load_jsonl(args.val_file)
    test_raw = load_jsonl(args.test_file)
    if not train_raw or not val_raw:
        raise ValueError("Training and/or validation raw datasets empty.")
    logger.info("--- Raw Data Loading Complete ---")

    logger.info("--- Stage 2: Converting Tokens to Strings ---")
    train_in, train_tgt = convert_tokens_to_strings(train_raw)
    val_in, val_tgt = convert_tokens_to_strings(val_raw)
    test_in, test_tgt = convert_tokens_to_strings(test_raw)
    if not train_in or not val_in:
        raise ValueError("Training and/or validation datasets empty after string conversion.")
    logger.info("--- Token-to-String Conversion Complete ---")

    logger.info("--- Stage 3: Initializing Tokenizer ---")
    try:
        # Use T5Tokenizer, often compatible with LongT5 checkpoints
        tokenizer = T5Tokenizer.from_pretrained(args.model_id, legacy=args.tokenizer_legacy)
        global tokenizer_for_metrics; tokenizer_for_metrics = tokenizer
        logger.info(f"Tokenizer for {args.model_id} initialized (Fast: {tokenizer.is_fast}, Vocab: {tokenizer.vocab_size}).")
    except Exception as e: logger.error(f"Failed initializing tokenizer: {e}", exc_info=True); raise
    logger.info("--- Tokenizer Initialization Complete ---")

    logger.info("--- Stage 4: Encoding Data (Tokenization) ---")
    task_prefix = args.task_prefix or ""
    logger.info(f"Using Task Prefix: '{task_prefix}'")
    try:
        logger.info("--> Encoding Training Data...")
        train_enc = encode_sequences(tokenizer, train_in, train_tgt, args.max_seq_length, task_prefix)
        if not train_enc.get('input_ids'): raise ValueError("Training encoding failed (returned empty).")

        logger.info("--> Encoding Validation Data...")
        val_enc = encode_sequences(tokenizer, val_in, val_tgt, args.max_seq_length, task_prefix)
        if not val_enc.get('input_ids'): raise ValueError("Validation encoding failed (returned empty).")

        test_enc = None
        if test_in and test_tgt:
            logger.info("--> Encoding Test Data...")
            test_enc = encode_sequences(tokenizer, test_in, test_tgt, args.max_seq_length, task_prefix)
            if not test_enc.get('input_ids'): logger.warning("Test encoding returned empty.")
        else:
             logger.info("Test input/target strings empty, skipping test encoding.")

    except Exception as e: logger.error(f"Data encoding failed: {e}", exc_info=True); raise
    logger.info("--- Data Encoding Phase Complete ---")

    logger.info("--- Stage 5: Creating PyTorch Datasets ---")
    try:
        train_ds = SequenceDataset(train_enc)
        val_ds = SequenceDataset(val_enc)
        test_ds = SequenceDataset(test_enc) if test_enc and test_enc.get('input_ids') else None
    except Exception as e: logger.error(f"Dataset creation failed: {e}", exc_info=True); raise
    logger.info("--- PyTorch Datasets Created Successfully ---")

    logger.info("--- Stage 6: Initializing Model ---")
    try:
        # <<< CHANGED to LongT5 Model >>>
        model = LongT5ForConditionalGeneration.from_pretrained(args.model_id)
        logger.info(f"Model {args.model_id} loaded.")
        if torch.cuda.is_available(): logger.info(f"Est. Model Memory: {model.get_memory_footprint() / 1e9:.2f} GB")
    except Exception as e: logger.error(f"Failed initializing model: {e}", exc_info=True); raise
    logger.info("--- Model Initialization Complete ---")

    logger.info("--- Stage 7: Configuring Training Environment ---")
    use_gc = args.gradient_checkpointing
    if use_gc: logger.info("Attempting to enable Gradient Checkpointing...")
    try:
        if use_gc and hasattr(model, 'gradient_checkpointing_enable'):
             model.gradient_checkpointing_enable(); logger.info("Gradient Checkpointing enabled via model method.")
             if hasattr(model, 'is_gradient_checkpointing'): logger.info(f"Model GC state: {model.is_gradient_checkpointing}")
    except Exception as e:
        logger.warning(f"Failed enabling Gradient Checkpointing on model: {e}. Relying on TrainingArguments setting.", exc_info=True)
        use_gc = args.gradient_checkpointing # Keep desired value

    use_fp16 = args.fp16 # Already checked against cuda availability
    logger.info("Initializing Data Collator...")
    try:
        collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, pad_to_multiple_of=8 if use_fp16 else None)
        logger.info("Data Collator initialized.")
    except Exception as e: logger.error(f"Collator init failed: {e}", exc_info=True); raise


    # --- Initialize Training Arguments (Corrected Parameter Name) ---
    logger.info("Defining Training Arguments...")
    report_to = args.report_to.lower() if isinstance(args.report_to, str) else args.report_to
    if report_to == "none": logger.info("Reporting disabled.")
    else: logger.info(f"Reporting to: {report_to}")

    train_args = None # Initialize
    try:
        train_args = Seq2SeqTrainingArguments(
            output_dir=str(output_dir),
            # Core Training Params
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps=args.warmup_steps,
            seed=args.seed,
            # Evaluation and Saving Strategy (Corrected Name)
            eval_strategy=args.eval_strategy,
            eval_steps=args.eval_steps if args.eval_strategy == "steps" else None,
            save_strategy=args.save_strategy,
            save_steps=args.save_steps if args.save_strategy == "steps" else None,
            save_total_limit=args.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="sequence_accuracy", # Make sure this matches compute_metrics output
            greater_is_better=True,
            # Logging
            logging_dir=str(output_dir / 'logs'),
            logging_strategy="steps",
            logging_steps=args.logging_steps,
            report_to=report_to,
            # Performance / Hardware
            fp16=use_fp16,
            gradient_checkpointing=use_gc, # Use value potentially modified if model enable failed
            dataloader_num_workers=args.dataloader_num_workers,
            # Seq2Seq Specific
            predict_with_generate=True,
            # Generation params used during evaluation AND testing with predict_with_generate=True
            generation_max_length=args.max_seq_length, # Control output length during eval/predict
            generation_num_beams=args.num_beams,
            label_smoothing_factor=args.label_smoothing,
        )
        logger.info("Successfully initialized Seq2SeqTrainingArguments.")
        logger.info(f"Effective FP16: {train_args.fp16}, Grad Checkpointing: {train_args.gradient_checkpointing}")

    except Exception as e: # Catch any exception during init
        logger.error(f"FAILED to initialize Seq2SeqTrainingArguments: {e}", exc_info=True)
        raise e # Re-raise to stop execution

    if train_args is None:
        raise RuntimeError("train_args was not successfully defined.")

    logger.info("--- Training Environment Configuration Complete ---")

    # --- Initialize Trainer ---
    logger.info("--- Stage 8: Initializing Seq2SeqTrainer ---")
    try:
        trainer = Seq2SeqTrainer(
            model=model,
            args=train_args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=collator,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics_fn,
        )
        logger.info("Seq2SeqTrainer Initialized.")
    except Exception as e:
        logger.error(f"Trainer init failed: {e}", exc_info=True)
        raise e

    # --- Training ---
    logger.info("--- Stage 9: Starting Model Training ---")
    try:
        logger.info(f"Starting training for {args.num_epochs} epochs...")
        train_res = trainer.train()
        metrics = train_res.metrics
        logger.info("--- Training Finished ---")
        logger.info("--- Training Metrics ---")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Trainer state saved.")
        best_path = str(output_dir / "best_model")
        trainer.save_model(best_path)
        logger.info(f"Final best model saved to: {best_path}")
    except Exception as e: logger.error(f"Training failed: {e}", exc_info=True); raise

    # --- Evaluation ---
    logger.info("--- Stage 10: Evaluating Model on Test Set ---")
    if test_ds:
        try:
            logger.info("Running final evaluation on the test set...")
            test_res = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
            logger.info("--- Evaluation on Test Set Finished ---")
            logger.info("--- Test Set Results ---")
            trainer.log_metrics("test", test_res)
            trainer.save_metrics("test", test_res)
            metric_key = "test_sequence_accuracy"
            if metric_key in test_res:
                logger.info(f"\n*** Test Sequence Accuracy: {test_res[metric_key]:.4f} ***\n")
            else:
                logger.warning(f"Metric '{metric_key}' not found in test results.")
                logger.info(f"Available test metrics: {list(test_res.keys())}")
        except Exception as e:
            logger.error(f"Test evaluation failed: {e}", exc_info=True)
    else:
        logger.warning("Test dataset unavailable or not loaded. Skipping final evaluation.")
    logger.info("--- Final Evaluation Stage Complete ---")

    # --- End Script ---
    end_time = datetime.datetime.now()
    total_time = end_time - start_time
    logger.info("="*60)
    logger.info(f" Script execution finished successfully at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f" Total execution time: {total_time}")
    logger.info("="*60)


# =============================================================================
# Execute Main Function (for Notebook context)
# =============================================================================
# In a notebook, we call main directly instead of using if __name__ == "__main__":
try:
    # Ensure CUDA check happens right before main call if needed
    if not torch.cuda.is_available() and args.fp16:
        logger.warning("MAIN (Notebook): CUDA not detected before starting main function!")
        logger.warning("MAIN (Notebook): Disabling fp16 as CUDA is not available.")
        args.fp16 = False

    # Call the main function with the configured args object
    main(args)

except Exception as e:
     logger.critical(f"Unhandled exception terminated script execution: {e}", exc_info=True)
     # In a notebook, just logging might be sufficient, or re-raise if preferred
     # raise e

2025-04-22 21:31:15 - INFO - [main] -  Starting LongT5 Training/Evaluation for Symbolic HEP Calculation 
2025-04-22 21:31:15 - INFO - [main] - Script execution started at: 2025-04-22 21:31:15
2025-04-22 21:31:15 - INFO - [main] - CUDA available. Device: NVIDIA GeForce RTX 4090 Laptop GPU
2025-04-22 21:31:15 - INFO - [main] - PyTorch CUDA version: 12.1
2025-04-22 21:31:15 - INFO - [main] - Running with configuration:
2025-04-22 21:31:15 - INFO - [main] -   train_file: qed_data_processed/qed_amplitudes_train.jsonl
2025-04-22 21:31:15 - INFO - [main] -   val_file: qed_data_processed/qed_amplitudes_val.jsonl
2025-04-22 21:31:15 - INFO - [main] -   test_file: qed_data_processed/qed_amplitudes_test.jsonl
2025-04-22 21:31:15 - INFO - [main] -   output_dir: longt5_symbolic_hep_output
2025-04-22 21:31:15 - INFO - [main] -   model_id: google/long-t5-tglobal-base
2025-04-22 21:31:15 - INFO - [main] -   task_prefix: 
2025-04-22 21:31:15 - INFO - [main] -   tokenizer_legacy: False
2025-04-22 21:31:

  trainer = Seq2SeqTrainer(


2025-04-22 21:31:25 - INFO - [main] - Seq2SeqTrainer Initialized.
2025-04-22 21:31:25 - INFO - [main] - --- Stage 9: Starting Model Training ---
2025-04-22 21:31:25 - INFO - [main] - Starting training for 5 epochs...
2025-04-22 21:31:27 - ERROR - [main] - Training failed: CUDA out of memory. Tried to allocate 192.00 MiB. GPU 0 has a total capacity of 15.70 GiB of which 67.25 MiB is free. Process 99334 has 12.93 GiB memory in use. Including non-PyTorch memory, this process has 2.69 GiB memory in use. Of the allocated memory 2.14 GiB is allocated by PyTorch, and 268.35 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Traceback (most recent call last):
  File "/tmp/ipykernel_140970/1604048880.py", line 539, in main
    train_res = trainer.train()
  File "

# 3.2: State-space Models for Squared Amplitude Calculation in High-Energy Physics
Model: State-space model such as mamba or other model.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Task 3.2: End-to-end Training and Evaluation Script for a Mamba SSM
for Symbolic Calculations of Squared Amplitudes in HEP.

This script adapts the previous framework to use a Mamba model,
a type of State-Space Model (SSM), integrated within the Hugging Face ecosystem.

*** NOTE: This version is adapted for robust execution ***
*** within a Jupyter Notebook (.ipynb). Configuration is handled    ***
*** by manually setting attributes on the `args` object.            ***

Key Changes from T5/LongT5 Script:
- Model: Uses MambaForCausalLM.
- Tokenizer: Uses a tokenizer compatible with the Mamba checkpoint (e.g., GPTNeoXTokenizer).
- Checkpoint: Uses a pre-trained Mamba checkpoint (e.g., 'state-spaces/mamba-130m').
- Data Formatting: Input and Target are concatenated into a single sequence for Causal LM training.
- Data Collator: Uses DataCollatorForLanguageModeling.
- Training Arguments: Uses base TrainingArguments, predict_with_generate=True.
- Trainer: Uses base Trainer.
- Evaluation: compute_metrics decodes generated sequences and compares to references.
- Output Directory: Changed to reflect the Mamba model.
- Hyperparameters: Adjusted defaults for Mamba (likely need tuning).

Dependencies:
- Working Python 3.10 environment (e.g., t5_pip_test) with necessary packages
  (transformers>=4.38, accelerate, torch, datasets, sentencepiece, causal-conv1d, mamba-ssm).
"""

import json
import sys
import numpy as np
import torch
import datetime # Use the standard datetime module
import logging
import argparse # Still used for the Namespace object, but not for parsing
from pathlib import Path
from torch.utils.data import Dataset
import inspect # Keep inspect if needed by other parts, otherwise optional now

# Import the specific classes needed
from transformers import (
    AutoTokenizer, # Use AutoTokenizer for flexibility
    MambaForCausalLM, # <<< CHANGED Model Class
    DataCollatorForLanguageModeling, # <<< CHANGED Data Collator
    Trainer, # <<< CHANGED to base Trainer
    TrainingArguments, # <<< CHANGED to base TrainingArguments
    TrainerCallback # For potential custom logging/debugging
)
from transformers.trainer_utils import set_seed, EvalPrediction

# Configure basic logging
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stdout # Ensure output to notebook
    )
else:
    logger.setLevel(logging.INFO)


# =============================================================================
# <<< CONFIGURATION >>>
# =============================================================================
# --- EDIT THESE VALUES TO CONFIGURE THE RUN ---
args = argparse.Namespace(
    # --- File Paths ---
    # !!! Update these if your HEP dataset has different names/locations !!!
    train_file='qed_data_processed/qed_amplitudes_train.jsonl',
    val_file='qed_data_processed/qed_amplitudes_val.jsonl',
    test_file='qed_data_processed/qed_amplitudes_test.jsonl',
    output_dir='mamba_symbolic_hep_output', # <<< CHANGED Output Dir

    # --- Model Configuration ---
    # <<< CHANGED Model Checkpoint to a Mamba variant >>>
    # Examples: 'state-spaces/mamba-130m', 'state-spaces/mamba-370m', ...
    model_id="state-spaces/mamba-130m",
    # <<< Define a separator for causal LM formatting >>>
    # Using EOS token is common
    # tokenizer_id will be set based on model_id later, but we anticipate needing eos
    separator_token="<|endoftext|>", # Will be replaced by actual tokenizer.eos_token

    # --- Tokenizer and Data Processing ---
    # Mamba can handle long sequences, adjust based on data & GPU memory
    max_seq_length=1024, # Start moderate, increase if possible

    # --- Training Hyperparameters ---
    # !!! Mamba often requires different HPs than Transformers !!!
    num_epochs=3,           # May need more/fewer
    # <<< Batch Size likely needs to be small for Mamba/long sequences >>>
    train_batch_size=4,     # START SMALL!
    eval_batch_size=4,      # START SMALL!
    learning_rate=3e-4,     # Mamba might need different LR, check model card
    weight_decay=0.1,       # Check model card recommendations
    warmup_steps=100,       # Adjust based on dataset size/LR schedule
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    seed=42,
    dataloader_num_workers=0,

    # --- Advanced Training Features ---
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True, # Recommended for Mamba
    # label_smoothing=0.0, # Typically not used for Causal LM

    # --- Generation Configuration (for evaluation) ---
    # These are used by compute_metrics via model.generate
    generation_max_length=512, # Max length of the *generated target* part
    num_beams=1, # Use greedy decoding (beam search less common for Mamba eval)
    do_sample=False, # Use greedy decoding

    # --- Reporting ---
    report_to="tensorboard",
)
# --- END OF CONFIGURATION ---

# --- Global Tokenizer ---
# We initialize the tokenizer globally after loading the config
# because encode_sequences needs it, including the separator token.
tokenizer = None

# (Helper Functions: load_jsonl, convert_tokens_to_strings - Adapt if needed)
def load_jsonl(file_path):
    """Loads a JSON Lines (.jsonl) file."""
    # (Implementation is the same as before)
    data = []
    file_path = Path(file_path)
    logger.info(f"Loading data from: {file_path}")
    if not file_path.is_file():
        logger.error(f"Data file not found: {file_path}")
        raise FileNotFoundError(f"Data file not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        logger.info(f"Successfully loaded {len(data)} records from {file_path}.")
        return data
    except Exception as e:
        logger.error(f"Failed to load data from {file_path}: {e}", exc_info=True)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens into single whitespace-joined strings."""
    # (Implementation is the same as before, check keys)
    input_strings = []
    target_strings = []
    if not raw_data_list: return input_strings, target_strings
    logger.info(f"Attempting to convert {len(raw_data_list)} items to strings...")
    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens') # Adapt key if needed
        target_toks = item.get('target_tokens') # Adapt key if needed
        if isinstance(input_toks, list) and isinstance(target_toks, list):
            try:
                joined_input = " ".join([str(tok) for tok in input_toks])
                joined_target = " ".join([str(tok) for tok in target_toks])
                input_strings.append(joined_input)
                target_strings.append(joined_target)
            except Exception as e: logger.warning(f"Error joining tokens idx {i}: {e}", exc_info=True); skipped_count+=1
        else: logger.warning(f"Skipping item idx {i}: Invalid/missing keys/types."); skipped_count+=1
    logger.info(f"Successfully converted {len(input_strings)} items to strings (skipped {skipped_count}).")
    return input_strings, target_strings


# <<< CHANGED Data Encoding for Causal LM >>>
def encode_sequences_for_causal_lm(input_strings, target_strings, max_len):
    """
    Encodes input and target strings into a single sequence for Causal LM training.
    Formats as: input_string + eos_token + target_string + eos_token
    Creates labels where input tokens and pad tokens are masked (-100).
    """
    global tokenizer # Use the globally initialized tokenizer
    if tokenizer is None:
        raise RuntimeError("Tokenizer is not initialized. Call init_tokenizer first.")
    if tokenizer.eos_token is None:
        raise ValueError("Tokenizer must have an EOS token defined for this encoding scheme.")

    logger.info(f"Starting Causal LM encoding for {len(input_strings)} pairs. MaxLen={max_len}.")

    results = {'input_ids': [], 'attention_mask': [], 'labels': []}
    skipped_count = 0

    for i, (input_str, target_str) in enumerate(zip(input_strings, target_strings)):
        # 1. Tokenize Input
        input_encoding = tokenizer(input_str, add_special_tokens=False) # Don't add EOS here yet
        input_ids = input_encoding['input_ids']

        # 2. Tokenize Target
        target_encoding = tokenizer(target_str, add_special_tokens=False)
        target_ids = target_encoding['input_ids']

        # 3. Combine with Separator (EOS) and add final EOS
        eos_token_id = tokenizer.eos_token_id
        combined_ids = input_ids + [eos_token_id] + target_ids + [eos_token_id]

        # 4. Truncate if necessary
        if len(combined_ids) > max_len:
            combined_ids = combined_ids[:max_len]
            # logger.warning(f"Sequence {i} truncated (len {len(combined_ids)+len(input_ids)+2} > {max_len})") # Be careful about logging frequency

        # 5. Create Attention Mask
        attention_mask = [1] * len(combined_ids)

        # 6. Create Labels (Mask input tokens and separator)
        # Input length including the separator EOS token
        input_len_plus_sep = len(input_ids) + 1
        labels = [-100] * input_len_plus_sep + combined_ids[input_len_plus_sep:]
        # Ensure labels length matches combined_ids length after potential truncation
        labels = labels[:len(combined_ids)]
        # Mask any remaining positions if truncation happened within the input part
        labels[:min(input_len_plus_sep, len(labels))] = [-100] * min(input_len_plus_sep, len(labels))


        # 7. Padding (Will be handled by DataCollator, but store results)
        # Note: DataCollatorForLanguageModeling pads input_ids, attention_mask, and labels
        results['input_ids'].append(combined_ids)
        results['attention_mask'].append(attention_mask)
        results['labels'].append(labels)

        if i % 5000 == 0 and i > 0: # Log progress occasionally
             logger.info(f"  Encoded {i} sequences...")

    logger.info(f"Finished Causal LM encoding {len(results['input_ids'])} sequences (skipped {skipped_count}).")
    return results


class SequenceDataset(Dataset):
    """Custom PyTorch Dataset wrapper for Causal LM formatted sequence data."""
    # (Implementation is the same, just ensure keys match encode_sequences output)
    def __init__(self, encodings):
        required_keys = ['input_ids', 'attention_mask', 'labels']
        if not isinstance(encodings, dict) or not all(key in encodings for key in required_keys):
            logger.error(f"Invalid encodings passed to SequenceDataset. Got Keys: {list(encodings.keys()) if isinstance(encodings, dict) else type(encodings)}")
            raise ValueError(f"Encodings must be a dict with keys: {required_keys}.")
        try:
            lengths = {key: len(val) for key, val in encodings.items() if key in required_keys and isinstance(val, list)}
            if len(lengths) != len(required_keys):
                 missing_or_wrong_type = [k for k in required_keys if k not in lengths]
                 raise ValueError(f"Missing or non-list required keys in encodings: {missing_or_wrong_type}")
            if len(set(lengths.values())) > 1:
                 raise ValueError(f"Inconsistent lengths in encodings: {lengths}")
            self.length = lengths.get('input_ids', 0)
        except Exception as e:
            raise ValueError(f"Failed validating encoding structure/lengths: {e}") from e

        self.encodings = {k: v for k, v in encodings.items() if k in required_keys}
        if self.length == 0: logger.warning("Initializing SequenceDataset with length 0.")
        logger.info(f"Created SequenceDataset with {self.length} examples.")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds ({self.length}).")
        try:
            item = {key: self.encodings[key][idx] for key in self.encodings}
            return item
        except Exception as e:
            logger.error(f"Failed creating item at index {idx}: {e}", exc_info=True)
            raise IndexError(f"Error retrieving item at index {idx}: {e}") from e


# <<< CHANGED Evaluation Metrics for Causal LM Generation >>>
def compute_metrics_causal_lm(eval_pred: EvalPrediction):
    """
    Calculates exact match accuracy for Causal LM by comparing generated
    sequences to reference sequences.
    Assumes eval_pred.predictions contains generated token IDs and
    eval_pred.label_ids contains the reference IDs (with -100 masking).
    """
    global tokenizer # Use the globally initialized tokenizer
    if tokenizer is None:
        logger.error("Tokenizer missing in compute_metrics!")
        return {"sequence_accuracy": 0.0}

    # Predictions are generated token IDs (potentially padded)
    # Label IDs are reference token IDs (with -100 masking for input/padding)
    predictions = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if predictions is None or label_ids is None:
         logger.error("compute_metrics received None for predictions or label_ids.")
         return {"sequence_accuracy": 0.0}

    # Ensure numpy arrays
    if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
    if not isinstance(label_ids, np.ndarray): label_ids = np.array(label_ids)

    # Replace -100 in labels with pad_token_id for decoding comparison
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    try:
        # Decode generated predictions
        # skip_special_tokens=True removes EOS, PAD etc.
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Decode reference labels
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    except Exception as e:
        logger.error(f"Decode failed in compute_metrics: {e}", exc_info=True)
        # Log shapes for debugging
        logger.error(f"Pred shape: {predictions.shape}, Labels shape: {label_ids.shape}")
        return {"sequence_accuracy": 0.0}

    # Post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Basic validation
    if not decoded_preds or not decoded_labels or len(decoded_preds) != len(decoded_labels):
        logger.warning(f"Metrics decode issue: Preds empty={not decoded_preds}, Labels empty={not decoded_labels}, LenPred={len(decoded_preds)}, LenLabel={len(decoded_labels)}")
        return {"sequence_accuracy": 0.0}

    # Calculate exact matches
    matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
    accuracy = np.mean(matches) if matches else 0.0

    logger.info(f"Computed sequence accuracy (generated vs reference): {accuracy:.4f}")
    # Return metric (can add others like BLEU, ROUGE if needed, though maybe less relevant for symbolic)
    return {"sequence_accuracy": float(accuracy)}


# =============================================================================
# Main Execution Logic
# =============================================================================

def main(config_args):
    """Orchestrates the Mamba training and evaluation for symbolic HEP."""
    global tokenizer # Allow main to modify the global tokenizer
    args = config_args

    # --- Basic Setup ---
    logger.info("="*60)
    logger.info(" Starting Mamba Training/Evaluation for Symbolic HEP Calculation ")
    logger.info("="*60)
    start_time = datetime.datetime.now()
    logger.info(f"Script execution started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    if torch.cuda.is_available():
        logger.info(f"CUDA available. Device: {torch.cuda.get_device_name(0)}")
        logger.info(f"PyTorch CUDA version: {torch.version.cuda}")
    else:
        logger.warning("CUDA not available. Running on CPU.")
        if args.fp16:
             logger.warning("fp16=True ignored because CUDA is not available.")
             args.fp16 = False

    logger.info(f"Running with configuration:")
    for k, v in vars(args).items(): logger.info(f"  {k}: {v}")

    set_seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory set to: {output_dir}")

    # --- Initialize Tokenizer ---
    # Needs to happen before data encoding
    logger.info("--- Stage 1: Initializing Tokenizer ---")
    try:
        # Use AutoTokenizer based on the Mamba model ID
        # Mamba models often use tokenizers like GPTNeoX
        tokenizer_id = args.model_id # Or specify explicitly e.g., "EleutherAI/gpt-neox-20b"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

        # <<< IMPORTANT: Set PAD token if missing (GPTNeoX doesn't always have one) >>>
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.warning(f"Tokenizer missing PAD token. Setting pad_token to eos_token ({tokenizer.eos_token}).")
            # Update model config potentially? Usually handled by Trainer/Collator if PAD=EOS
            # model.config.pad_token_id = tokenizer.eos_token_id # Do this after model load

        # Set the separator token based on the actual loaded tokenizer
        args.separator_token = tokenizer.eos_token
        logger.info(f"Using separator token: {args.separator_token}")

        logger.info(f"Tokenizer '{tokenizer_id}' initialized (Pad: {tokenizer.pad_token}, EOS: {tokenizer.eos_token}).")
    except Exception as e: logger.error(f"Failed initializing tokenizer: {e}", exc_info=True); raise
    logger.info("--- Tokenizer Initialization Complete ---")


    # --- Data Loading and Preprocessing ---
    logger.info("--- Stage 2: Loading Raw Data ---")
    train_raw = load_jsonl(args.train_file)
    val_raw = load_jsonl(args.val_file)
    test_raw = load_jsonl(args.test_file)
    if not train_raw or not val_raw:
        raise ValueError("Training and/or validation raw datasets empty.")
    logger.info("--- Raw Data Loading Complete ---")

    logger.info("--- Stage 3: Converting Tokens to Strings ---")
    train_in, train_tgt = convert_tokens_to_strings(train_raw)
    val_in, val_tgt = convert_tokens_to_strings(val_raw)
    test_in, test_tgt = convert_tokens_to_strings(test_raw)
    if not train_in or not val_in:
        raise ValueError("Training and/or validation datasets empty after string conversion.")
    logger.info("--- Token-to-String Conversion Complete ---")

    logger.info("--- Stage 4: Encoding Data for Causal LM ---")
    try:
        logger.info("--> Encoding Training Data...")
        train_enc = encode_sequences_for_causal_lm(train_in, train_tgt, args.max_seq_length)
        if not train_enc.get('input_ids'): raise ValueError("Training encoding failed.")

        logger.info("--> Encoding Validation Data...")
        val_enc = encode_sequences_for_causal_lm(val_in, val_tgt, args.max_seq_length)
        if not val_enc.get('input_ids'): raise ValueError("Validation encoding failed.")

        test_enc = None
        if test_in and test_tgt:
            logger.info("--> Encoding Test Data...")
            test_enc = encode_sequences_for_causal_lm(test_in, test_tgt, args.max_seq_length)
            if not test_enc.get('input_ids'): logger.warning("Test encoding returned empty.")
        else:
             logger.info("Test input/target strings empty, skipping test encoding.")

    except Exception as e: logger.error(f"Data encoding failed: {e}", exc_info=True); raise
    logger.info("--- Data Encoding Phase Complete ---")

    logger.info("--- Stage 5: Creating PyTorch Datasets ---")
    try:
        train_ds = SequenceDataset(train_enc)
        val_ds = SequenceDataset(val_enc)
        test_ds = SequenceDataset(test_enc) if test_enc and test_enc.get('input_ids') else None
    except Exception as e: logger.error(f"Dataset creation failed: {e}", exc_info=True); raise
    logger.info("--- PyTorch Datasets Created Successfully ---")

    logger.info("--- Stage 6: Initializing Model ---")
    try:
        # <<< CHANGED to Mamba Model >>>
        model = MambaForCausalLM.from_pretrained(
            args.model_id,
            # torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # Use BF16 if available
            # Use FP32 for stability initially, Trainer handles FP16/BF16 later
            torch_dtype=torch.float32
        )
        # <<< Set pad_token_id in model config if tokenizer needed it >>>
        if tokenizer.pad_token_id == tokenizer.eos_token_id:
             logger.info("Setting model.config.pad_token_id = tokenizer.eos_token_id")
             model.config.pad_token_id = tokenizer.eos_token_id

        logger.info(f"Model {args.model_id} loaded.")
        if torch.cuda.is_available(): logger.info(f"Est. Model Memory (FP32): {model.get_memory_footprint() / 1e9:.2f} GB")
    except Exception as e: logger.error(f"Failed initializing model: {e}", exc_info=True); raise
    logger.info("--- Model Initialization Complete ---")

    logger.info("--- Stage 7: Configuring Training Environment ---")
    use_gc = args.gradient_checkpointing
    if use_gc: logger.info("Attempting to enable Gradient Checkpointing...")
    # Mamba models might have different GC enabling methods or rely solely on Trainer arg
    try:
        if use_gc and hasattr(model, 'gradient_checkpointing_enable'):
             model.gradient_checkpointing_enable(); logger.info("Gradient Checkpointing enabled via model method.")
        elif use_gc:
             logger.info("Model has no 'gradient_checkpointing_enable' method. Relying on TrainingArguments.")
    except Exception as e:
        logger.warning(f"Failed enabling Gradient Checkpointing on model: {e}. Relying on TrainingArguments setting.", exc_info=True)
        use_gc = args.gradient_checkpointing

    use_fp16 = args.fp16
    logger.info("Initializing Data Collator...")
    try:
        # <<< Use DataCollatorForLanguageModeling >>>
        # mlm=False for Causal LM (not Masked LM)
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        logger.info("Data Collator for Language Modeling initialized.")
    except Exception as e: logger.error(f"Collator init failed: {e}", exc_info=True); raise


    # --- Initialize Training Arguments ---
    logger.info("Defining Training Arguments...")
    report_to = args.report_to.lower() if isinstance(args.report_to, str) else args.report_to
    if report_to == "none": logger.info("Reporting disabled.")
    else: logger.info(f"Reporting to: {report_to}")

    train_args = None
    try:
        # <<< Use base TrainingArguments >>>
        train_args = TrainingArguments(
            output_dir=str(output_dir),
            # Core Training Params
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps=args.warmup_steps,
            seed=args.seed,
            # Evaluation and Saving Strategy
            eval_strategy=args.eval_strategy,
            eval_steps=args.eval_steps if args.eval_strategy == "steps" else None,
            save_strategy=args.save_strategy,
            save_steps=args.save_steps if args.save_strategy == "steps" else None,
            save_total_limit=args.save_total_limit,
            load_best_model_at_end=True, # Load best model based on eval metric
            metric_for_best_model="sequence_accuracy", # Needs to match compute_metrics output
            greater_is_better=True, # Accuracy should be maximized
            # Logging
            logging_dir=str(output_dir / 'logs'),
            logging_strategy="steps",
            logging_steps=args.logging_steps,
            report_to=report_to,
            # Performance / Hardware
            fp16=use_fp16,
            # bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(), # Optionally enable BF16
            gradient_checkpointing=use_gc,
            dataloader_num_workers=args.dataloader_num_workers,
            # Other Args
            remove_unused_columns=False, # Important for custom datasets/collators sometimes
            # <<< Add args relevant for generation during evaluation >>>
            predict_with_generate=True, # Tell Trainer to use generate() for predictions
            # label_names=["labels"], # Usually inferred correctly
        )
        logger.info("Successfully initialized TrainingArguments.")
        logger.info(f"Effective FP16: {train_args.fp16}, Grad Checkpointing: {train_args.gradient_checkpointing}")

    except Exception as e:
        logger.error(f"FAILED to initialize TrainingArguments: {e}", exc_info=True)
        raise e

    if train_args is None:
        raise RuntimeError("train_args was not successfully defined.")

    logger.info("--- Training Environment Configuration Complete ---")

    # --- Initialize Trainer ---
    logger.info("--- Stage 8: Initializing Trainer ---")
    try:
        # <<< Use base Trainer >>>
        trainer = Trainer(
            model=model,
            args=train_args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=collator,
            tokenizer=tokenizer, # Pass tokenizer for generation
            compute_metrics=compute_metrics_causal_lm, # Use the Causal LM metrics function
        )
        logger.info("Trainer Initialized.")
    except Exception as e:
        logger.error(f"Trainer init failed: {e}", exc_info=True)
        raise e

    # --- Training ---
    logger.info("--- Stage 9: Starting Model Training ---")
    try:
        logger.info(f"Starting training for {args.num_epochs} epochs...")
        train_res = trainer.train()
        metrics = train_res.metrics
        logger.info("--- Training Finished ---")
        logger.info("--- Training Metrics ---")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Trainer state saved.")
        best_path = str(output_dir / "best_model")
        trainer.save_model(best_path)
        logger.info(f"Final best model saved to: {best_path}")
    except Exception as e: logger.error(f"Training failed: {e}", exc_info=True); raise

    # --- Evaluation ---
    logger.info("--- Stage 10: Evaluating Model on Test Set ---")
    if test_ds:
        try:
            logger.info("Running final evaluation on the test set...")
            # Use predict() which respects predict_with_generate=True
            # Or use evaluate() which should also trigger generation if configured
            test_res = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
            logger.info("--- Evaluation on Test Set Finished ---")
            logger.info("--- Test Set Results ---")
            trainer.log_metrics("test", test_res)
            trainer.save_metrics("test", test_res)
            metric_key = "test_sequence_accuracy"
            if metric_key in test_res:
                logger.info(f"\n*** Test Sequence Accuracy: {test_res[metric_key]:.4f} ***\n")
            else:
                logger.warning(f"Metric '{metric_key}' not found in test results.")
                logger.info(f"Available test metrics: {list(test_res.keys())}")
        except Exception as e:
            logger.error(f"Test evaluation failed: {e}", exc_info=True)
    else:
        logger.warning("Test dataset unavailable or not loaded. Skipping final evaluation.")
    logger.info("--- Final Evaluation Stage Complete ---")

    # --- End Script ---
    end_time = datetime.datetime.now()
    total_time = end_time - start_time
    logger.info("="*60)
    logger.info(f" Script execution finished successfully at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f" Total execution time: {total_time}")
    logger.info("="*60)


# =============================================================================
# Execute Main Function (for Notebook context)
# =============================================================================
try:
    if not torch.cuda.is_available() and args.fp16:
        logger.warning("MAIN (Notebook): CUDA not detected before starting main function!")
        logger.warning("MAIN (Notebook): Disabling fp16 as CUDA is not available.")
        args.fp16 = False

    main(args) # Call the main function

except Exception as e:
     logger.critical(f"Unhandled exception terminated script execution: {e}", exc_info=True)
     # raise e # Optionally re-raise

[ERROR] 'mamba_ssm' library not found. Please install it (`pip install mamba_ssm causal-conv1d>=1.1.0`) or replace StateSpaceLayer with your implementation.


[INFO] Starting Custom SSM Sequence-to-Sequence Script...
[INFO] Current date/time (UTC): 2025-04-08 17:34:48 UTC
[INFO] Using CPU
[INFO] Initializing Tokenizer: bert-base-uncased
[INFO] Tokenizer Vocab Size: 30522, Pad Token ID: 0
[INFO] Loading data from: qed_expressions_train.jsonl
[INFO] Successfully loaded 12441 records.
[INFO] Loading data from: qed_expressions_val.jsonl
[INFO] Successfully loaded 1555 records.
[INFO] Loading data from: qed_expressions_test.jsonl
[INFO] Successfully loaded 1556 records.
[INFO] Converted 12441 items to strings (skipped 0).
[INFO] Converted 1555 items to strings (skipped 0).
[INFO] Converted 1556 items to strings (skipped 0).
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset size: 12441 examples.
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset size: 1555 examples.
[INFO] Pre-tokenizing dataset with max_length=256...
[INFO] Pre-tokenization complete. Dataset 

Epoch 1 Training:   0%|          | 0/778 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already bee

AttributeError: 'list' object has no attribute 'to'

# 3.3: Transformer Models for Symbolic Regression
Model: Transformer model with a contemporary innovation added such as KAN layers, reinforcement learning, genetic algorithms, specialized long-sequence attention, etc. which improves the performance compared to a basic transformer.


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Task 3.3: End-to-end Training and Evaluation Script for a Reformer Model
for Symbolic Regression.

This script fine-tunes a Reformer model configured as an Encoder-Decoder
for symbolic regression tasks (sequence-to-sequence). Reformer is chosen
for its efficient attention mechanism (LSH), suitable for potentially long sequences.

*** NOTE: This version is adapted for robust execution ***
*** within a Jupyter Notebook (.ipynb). Configuration is handled    ***
*** by manually setting attributes on the `args` object.            ***

Key Changes:
- Model: Uses EncoderDecoderModel with Reformer weights.
- Tokenizer: Uses ReformerTokenizer (character-level for 'google/reformer-enwik8').
- Checkpoint: Uses 'google/reformer-enwik8' (with caveats about tokenizer suitability).
- Data Loading: Assumes input/target strings are directly in the JSONL data.
- Output Directory: Changed to reflect the model and task.
- Dependencies: Requires 'transformers', 'torch', 'datasets', 'sentencepiece'.

!!! WARNING !!!
The 'google/reformer-enwik8' checkpoint uses a character-level tokenizer.
This may be suboptimal for symbolic regression compared to a tokenizer trained
on mathematical symbols or subwords. Consider using a different base model or
tokenizer fine-tuning if needed.
"""

import json
import sys
import numpy as np
import torch
import datetime # Use the standard datetime module
import logging
import argparse # Still used for the Namespace object, but not for parsing
from pathlib import Path
from torch.utils.data import Dataset
import inspect # Keep inspect if needed by other parts, otherwise optional now

# Import the specific classes needed
from transformers import (
    ReformerTokenizer,         # <<< Reformer Tokenizer
    ReformerModel,             # <<< Base Reformer model
    EncoderDecoderModel,       # <<< Wrapper for Enc-Dec setup
    DataCollatorForSeq2Seq,    # <<< Standard Seq2Seq Collator
    Seq2SeqTrainer,            # <<< Standard Seq2Seq Trainer
    Seq2SeqTrainingArguments,  # <<< Standard Seq2Seq Arguments
)
from transformers.trainer_utils import set_seed, EvalPrediction

# Configure basic logging
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stdout # Ensure output to notebook
    )
else:
    logger.setLevel(logging.INFO)

# =============================================================================
# <<< CONFIGURATION >>>
# =============================================================================
# --- EDIT THESE VALUES TO CONFIGURE THE RUN ---
args = argparse.Namespace(
    # --- File Paths ---
    # !!! UPDATE these to your symbolic regression data files !!!
    # Expects keys like 'input_str' and 'target_str' in the JSONL
    train_file='symbolic_regression_train.jsonl', # Placeholder name
    val_file='symbolic_regression_val.jsonl',     # Placeholder name
    test_file='symbolic_regression_test.jsonl',      # Placeholder name
    output_dir='reformer_symbolic_regression_output', # <<< CHANGED Output Dir

    # --- Model Configuration ---
    # Using Reformer - WARNING about character-level tokenizer for enwik8!
    model_id="google/reformer-enwik8",
    tokenizer_id="google/reformer-enwik8", # Usually same as model_id
    task_prefix="regress expression: ",     # Optional task prefix

    # --- Tokenizer and Data Processing ---
    max_seq_length=512, # Reformer can handle longer sequences, adjust based on data/memory

    # --- Training Hyperparameters ---
    # !!! Tune these for Reformer and your specific dataset !!!
    num_epochs=5,
    train_batch_size=8,     # Adjust based on sequence length and GPU memory
    eval_batch_size=8,      # Adjust based on sequence length and GPU memory
    learning_rate=5e-5,     # Starting point, may need adjustment
    weight_decay=0.01,
    warmup_steps=300,
    logging_steps=100,
    eval_strategy="epoch",  # Correct parameter name
    save_strategy="epoch",
    save_total_limit=2,
    seed=42,
    dataloader_num_workers=0,

    # --- Advanced Training Features ---
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,    # Recommended for Reformer
    label_smoothing_factor=0.1,

    # --- Generation Configuration (for evaluation) ---
    generation_num_beams=4,

    # --- Reporting ---
    report_to="tensorboard",
)
# --- END OF CONFIGURATION ---

# --- Global Tokenizer ---
tokenizer = None
tokenizer_for_metrics = None

# (Helper Functions)
def load_jsonl_direct_strings(file_path, input_key='input_str', target_key='target_str'):
    """
    Loads data from a JSON Lines file, extracting specified input/target strings.
    """
    input_strings = []
    target_strings = []
    file_path = Path(file_path)
    logger.info(f"Loading data from: {file_path} (expecting keys '{input_key}', '{target_key}')")
    if not file_path.is_file():
        logger.error(f"Data file not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    skipped_count = 0
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        item = json.loads(line)
                        input_str = item.get(input_key)
                        target_str = item.get(target_key)
                        if isinstance(input_str, str) and isinstance(target_str, str):
                            input_strings.append(input_str)
                            target_strings.append(target_str)
                        else:
                            logger.warning(f"Skipping line {i+1}: Missing/invalid types for keys '{input_key}' or '{target_key}'.")
                            skipped_count += 1
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON on line {i+1}: {e}")
                        skipped_count += 1
                    except Exception as item_e:
                        logger.warning(f"Error processing line {i+1}: {item_e}")
                        skipped_count += 1

        logger.info(f"Successfully loaded {len(input_strings)} string pairs (skipped {skipped_count} lines).")
        if not input_strings:
             logger.warning(f"No valid records loaded from {file_path}.")
        return input_strings, target_strings
    except Exception as e:
        logger.error(f"Failed to load data from {file_path}: {e}", exc_info=True)
        raise


def encode_sequences(source_texts, target_texts, max_len):
    """
    Tokenizes source and target text sequences for Encoder-Decoder models.
    """
    global tokenizer # Use global tokenizer
    if tokenizer is None: raise RuntimeError("Tokenizer not initialized.")

    logger.info(f"Encoding sequence pairs with max_length={max_len}...")

    # Tokenize source texts (encoder input)
    encoder_inputs = tokenizer(
        source_texts,
        max_length=max_len,
        padding='max_length',   # Pad to max_len
        truncation=True,        # Truncate sequences longer than max_len
        return_tensors=None     # Return lists
    )

    # Tokenize target texts to create labels
    # Decoder input_ids are usually created internally by shifting labels
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_texts,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # Assign the tokenized target IDs as 'labels'
    encoder_inputs['labels'] = decoder_labels['input_ids']

    logger.info(f"Encoding complete for {len(encoder_inputs['input_ids'])} sequences.")
    return encoder_inputs


class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data."""
    def __init__(self, encodings):
        if not isinstance(encodings, dict) or 'input_ids' not in encodings or 'labels' not in encodings:
            raise ValueError("Encodings must be a dictionary containing 'input_ids' and 'labels'.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            # Basic validation
            if not all(isinstance(encodings[key], list) and len(encodings[key]) == self.length for key in encodings):
                 raise ValueError(f"All encoding keys must be lists of the same length ({self.length}).")
        except Exception as e:
             raise ValueError(f"Failed to validate encodings: {e}")
        if self.length == 0: logger.warning("Initializing SequencePairDataset with length 0.")
        logger.info(f"Created Dataset with {self.length} examples.")

    def __len__(self): return self.length
    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds.")
        try:
            return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e: logger.error(f"Failed retrieving item idx {idx}: {e}"); raise


def compute_metrics_fn(eval_pred):
    """Calculates exact sequence match accuracy after decoding."""
    # !!! Consider SymPy or numerical evaluation for more robust symbolic regression metrics !!!
    global tokenizer_for_metrics
    if tokenizer_for_metrics is None: logger.error("Tokenizer missing!"); return {"sequence_accuracy": 0.0}

    predictions, labels = eval_pred
    if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
    if not isinstance(labels, np.ndarray): labels = np.array(labels)

    labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

    try:
        # Predictions might already be IDs if predict_with_generate=True
        # Check shape, if it looks like logits, argmax it. Reformer might output logits.
        if predictions.ndim == 3: # Shape (batch_size, seq_len, vocab_size) -> Logits
            predictions = np.argmax(predictions, axis=-1)

        decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
    except Exception as e:
         logger.error(f"Decoding failed in compute_metrics: {e}", exc_info=True)
         return {"sequence_accuracy": 0.0}

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    if len(decoded_preds) != len(decoded_labels):
        logger.warning(f"Metrics: Mismatch pred/label count: {len(decoded_preds)} vs {len(decoded_labels)}")
        return {"sequence_accuracy": 0.0}

    matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
    accuracy = np.mean(matches) if matches else 0.0
    logger.info(f"Computed sequence accuracy: {accuracy:.4f}")
    return {"sequence_accuracy": float(accuracy)}

# =============================================================================
# Main Execution Logic
# =============================================================================

def main(config_args):
    """Orchestrates the Reformer fine-tuning for symbolic regression."""
    global tokenizer, tokenizer_for_metrics # Allow main to modify globals
    args = config_args

    # --- Basic Setup ---
    logger.info("="*60)
    logger.info(" Starting Reformer Training/Evaluation for Symbolic Regression ")
    logger.info("="*60)
    start_time = datetime.datetime.now()
    logger.info(f"Script execution started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    if torch.cuda.is_available(): logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
    else: logger.warning("CUDA not available. Running on CPU."); args.fp16 = False

    logger.info(f"Running with configuration:")
    for k, v in vars(args).items(): logger.info(f"  {k}: {v}")

    set_seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory set to: {output_dir}")

    # --- Initialize Tokenizer ---
    logger.info("--- Stage 1: Initializing Tokenizer ---")
    try:
        tokenizer = ReformerTokenizer.from_pretrained(args.tokenizer_id)
        # Add special tokens if they don't exist - important for Reformer structure
        special_tokens_to_add = {}
        if tokenizer.pad_token is None: special_tokens_to_add['pad_token'] = '<pad>'
        if tokenizer.bos_token is None: special_tokens_to_add['bos_token'] = '<s>' # Use BOS for start
        if tokenizer.eos_token is None: special_tokens_to_add['eos_token'] = '</s>' # Use EOS for end
        if special_tokens_to_add:
             logger.warning(f"Adding special tokens: {special_tokens_to_add}")
             tokenizer.add_special_tokens(special_tokens_to_add)
        tokenizer_for_metrics = tokenizer # Set global for metrics
        logger.warning("Using ReformerTokenizer (google/reformer-enwik8) which is CHARACTER-LEVEL. "
                       "Ensure this matches your data or consider a different tokenizer/model.")
        logger.info(f"Tokenizer '{args.tokenizer_id}' initialized (Vocab size: {len(tokenizer)}).")
        logger.info(f"PAD={tokenizer.pad_token_id}, BOS={tokenizer.bos_token_id}, EOS={tokenizer.eos_token_id}")
    except Exception as e: logger.error(f"Failed initializing tokenizer: {e}", exc_info=True); raise
    logger.info("--- Tokenizer Initialization Complete ---")


    # --- Data Loading (Direct Strings) ---
    logger.info("--- Stage 2: Loading Raw Data (Direct Strings) ---")
    try:
        # Assumes keys 'input_str', 'target_str' in the jsonl files
        train_sources, train_targets = load_jsonl_direct_strings(args.train_file)
        val_sources, val_targets = load_jsonl_direct_strings(args.val_file)
        test_sources, test_targets = load_jsonl_direct_strings(args.test_file)
    except Exception as e: logger.error(f"Critical error during data loading: {e}", exc_info=True); raise
    if not train_sources or not val_sources: raise ValueError("Training and/or validation source/target strings empty.")
    logger.info("--- Raw String Data Loading Complete ---")


    # --- Tokenize Data ---
    logger.info("--- Stage 3: Encoding Data (Tokenization) ---")
    try:
        train_encodings = encode_sequences(train_sources, train_targets, args.max_seq_length)
        val_encodings = encode_sequences(val_sources, val_targets, args.max_seq_length)
        test_encodings = encode_sequences(test_sources, test_targets, args.max_seq_length) if test_sources else None
    except Exception as e: logger.error(f"Data encoding failed: {e}", exc_info=True); raise
    logger.info("--- Data Encoding Phase Complete ---")


    # --- Create Datasets ---
    logger.info("--- Stage 4: Creating PyTorch Datasets ---")
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset = SequencePairDataset(val_encodings)
        test_dataset = SequencePairDataset(test_encodings) if test_encodings else None
    except Exception as e: logger.error(f"Dataset creation failed: {e}", exc_info=True); raise
    logger.info("--- PyTorch Datasets Created Successfully ---")

    # --- Initialize Model ---
    logger.info("--- Stage 5: Initializing Model ---")
    try:
        logger.info(f"Loading Reformer weights from {args.model_id} for EncoderDecoderModel")
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(args.model_id, args.model_id)

        # Resize embeddings if tokenizer vocab changed
        if model.config.encoder.vocab_size != len(tokenizer):
             logger.info(f"Resizing model token embeddings from {model.config.encoder.vocab_size} to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             # Update config values after resizing
             model.config.encoder.vocab_size = len(tokenizer)
             model.config.decoder.vocab_size = len(tokenizer)

        # Configure for Seq2Seq Generation
        model.config.decoder_start_token_id = tokenizer.bos_token_id # Use BOS as decoder start
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        # Critical: Ensure decoder is configured as a decoder with cross-attention
        model.config.decoder.is_decoder = True
        model.config.decoder.add_cross_attention = True

        # Set generation parameters on the model config
        model.config.max_length = args.max_seq_length # Max output length during generation
        model.config.num_beams = args.generation_num_beams

        # Tie weights
        logger.info("Tying encoder and decoder weights.")
        model.tie_weights()

        logger.info(f"Model {args.model_id} (EncoderDecoder) loaded and configured.")
        if torch.cuda.is_available(): logger.info(f"Est. Model Memory: {model.get_memory_footprint() / 1e9:.2f} GB")

        # Enable gradient checkpointing *on the model* if desired
        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            logger.info("Gradient Checkpointing enabled via model method.")

    except Exception as e: logger.error(f"Failed initializing model: {e}", exc_info=True); raise
    logger.info("--- Model Initialization Complete ---")

    # --- Initialize Collator ---
    logger.info("--- Stage 6: Initializing Data Collator ---")
    try:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100, # Important: Ignore padding in loss
            pad_to_multiple_of=8 if args.fp16 else None
        )
        logger.info("Data Collator Initialized.")
    except Exception as e: logger.error(f"Collator init failed: {e}", exc_info=True); raise

    # --- Initialize Training Arguments ---
    logger.info("--- Stage 7: Defining Training Arguments ---")
    report_to = args.report_to.lower() if isinstance(args.report_to, str) else args.report_to
    logger.info(f"Reporting to: {report_to}")
    train_args = None
    try:
        # Use Seq2SeqTrainingArguments
        train_args = Seq2SeqTrainingArguments(
            output_dir=str(output_dir),
            # Core Training Params
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.train_batch_size,
            per_device_eval_batch_size=args.eval_batch_size,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            warmup_steps=args.warmup_steps,
            seed=args.seed,
            # Evaluation and Saving Strategy
            eval_strategy=args.eval_strategy, # Use correct name
            save_strategy=args.save_strategy,
            save_total_limit=args.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="sequence_accuracy",
            greater_is_better=True,
            # Logging
            logging_dir=str(output_dir / 'logs'),
            logging_strategy="steps",
            logging_steps=args.logging_steps,
            report_to=report_to,
            # Performance / Hardware
            fp16=args.fp16,
            # Gradient checkpointing - enable in args if enabled successfully on model
            gradient_checkpointing=(args.gradient_checkpointing and hasattr(model, 'is_gradient_checkpointing') and model.is_gradient_checkpointing),
            dataloader_num_workers=args.dataloader_num_workers,
            # Seq2Seq Specific
            predict_with_generate=True,
            generation_max_length=args.max_seq_length, # Use config seq len for consistency
            generation_num_beams=args.generation_num_beams,
            label_smoothing_factor=args.label_smoothing_factor,
        )
        logger.info("Successfully initialized Seq2SeqTrainingArguments.")
        logger.info(f"Effective FP16: {train_args.fp16}, Grad Checkpointing: {train_args.gradient_checkpointing}")

    except Exception as e: logger.error(f"FAILED to initialize Training Arguments: {e}", exc_info=True); raise
    if train_args is None: raise RuntimeError("train_args failed initialization.")
    logger.info("--- Training Environment Configuration Complete ---")

    # --- Initialize Trainer ---
    logger.info("--- Stage 8: Initializing Seq2SeqTrainer ---")
    try:
        trainer = Seq2SeqTrainer(
            model=model,
            args=train_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics_fn,
        )
        logger.info("Seq2SeqTrainer Initialized.")
    except Exception as e: logger.error(f"Trainer init failed: {e}", exc_info=True); raise

    # --- Training ---
    logger.info("--- Stage 9: Starting Model Training ---")
    try:
        logger.info(f"Starting training for {args.num_epochs} epochs...")
        train_result = trainer.train()
        metrics = train_result.metrics
        logger.info("--- Training Finished ---")
        logger.info("--- Training Metrics ---")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Trainer state saved.")
        best_path = str(output_dir / "best_model")
        trainer.save_model(best_path) # Saves the best model if load_best_model_at_end=True
        logger.info(f"Final best model saved to: {best_path}")
    except Exception as e: logger.error(f"Training failed: {e}", exc_info=True); raise

    # --- Evaluation ---
    logger.info("--- Stage 10: Evaluating Model on Test Set ---")
    if test_dataset:
        try:
            logger.info("Running final evaluation on the test set...")
            test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
            logger.info("--- Evaluation on Test Set Finished ---")
            logger.info("--- Test Set Results ---")
            trainer.log_metrics("test", test_results)
            trainer.save_metrics("test", test_results)
            metric_key = "test_sequence_accuracy"
            if metric_key in test_results:
                logger.info(f"\n*** Test Sequence Accuracy: {test_results[metric_key]:.4f} ***\n")
            else:
                logger.warning(f"Metric '{metric_key}' not found in test results.")
                logger.info(f"Available test metrics: {list(test_results.keys())}")
        except Exception as e: logger.error(f"Test evaluation failed: {e}", exc_info=True)
    else: logger.warning("Test dataset unavailable or not loaded. Skipping final evaluation.")
    logger.info("--- Final Evaluation Stage Complete ---")

    # --- End Script ---
    end_time = datetime.datetime.now()
    total_time = end_time - start_time
    logger.info("="*60)
    logger.info(f" Script execution finished successfully at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f" Total execution time: {total_time}")
    logger.info("="*60)

# =============================================================================
# Execute Main Function (for Notebook context)
# =============================================================================
try:
    if not torch.cuda.is_available() and args.fp16:
        logger.warning("MAIN (Notebook): CUDA not detected!")
        args.fp16 = False
    # Call the main function with the configured args object
    main(args)
except Exception as e:
     logger.critical(f"Unhandled exception terminated script execution: {e}", exc_info=True)
     # raise e # Optionally re-raise

[INFO] Starting Reformer Fine-tuning Script for Symbolic Regression...
[INFO] Current date/time (UTC): 2025-04-23 03:45:11 UTC
[INFO] Using model: google/reformer-enwik8 with Tokenizer: google/reformer-enwik8
[INFO] Loading data from: qed_expressions_train.jsonl


[Error] Data file not found: qed_expressions_train.jsonl
[FATAL] Critical error during data loading: File not found: qed_expressions_train.jsonl


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# 3.4: Titans for squared amplitude calculation
Model: One of the core architectures from Google’s paper introducing Titans concept


In [10]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-end Fine-tuning Script for a Hypothetical "Titan" Sequence-to-Sequence Model
on the Preprocessed QED Squared Amplitude Calculation Dataset.

This script demonstrates fine-tuning a sequence-to-sequence model, referred to
here as "Titan" (e.g., `google/titan-small`), presumably incorporating unique
architectural features like time-mixing and recursion modules as specified in
the configuration. It utilizes the Hugging Face Transformers library, including
the `Seq2SeqTrainer` API.

Key features demonstrated:
- Hypothetical Titan Model: Assumes existence of `TitanTokenizer` and
  `TitanForConditionalGeneration` with specific config flags (`use_time_mixing`,
  `use_recursion`).
- Advanced Training Techniques: Includes Gradient Checkpointing, Mixed Precision
  Training (FP16), Label Smoothing, and Beam Search during evaluation.
- Standard Workflow: Follows a common pattern of data loading, preprocessing,
  tokenization, model configuration, training, and evaluation.
- Exact Match Accuracy: Uses sequence-level exact match as the evaluation metric.

Workflow:
1. Load pre-split data (train, validation, test) from JSONL files.
2. Reconstruct source (input amplitude) and target (squared amplitude) strings.
3. Initialize the Titan tokenizer and model.
4. Configure Titan-specific features (time-mixing, recursion) and enable
   gradient checkpointing on the model instance.
5. Define a function to tokenize source/target string pairs.
6. Create PyTorch Dataset objects from the tokenized data.
7. Instantiate `DataCollatorForSeq2Seq` for dynamic batching.
8. Define the `compute_metrics` function for evaluation using decoded sequences.
9. Configure `Seq2SeqTrainingArguments` with hyperparameters and features.
10. Initialize and run the `Seq2SeqTrainer`.
11. Evaluate the final model on the test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
from pathlib import Path
from torch.utils.data import Dataset
# --- Hypothetical Titan Model Imports ---
# Ensure these classes exist in your transformers installation or define placeholders
try:
    from transformers import (
        TitanTokenizer,                # Hypothetical Tokenizer
        TitanForConditionalGeneration, # Hypothetical Model
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
        TrainingArguments           # Explicit import
    )
    print("[INFO] Successfully imported hypothetical Titan classes.")
except ImportError:
    print("[ERROR] Failed to import TitanTokenizer or TitanForConditionalGeneration.", file=sys.stderr)
    print("Please ensure these classes are defined in your transformers library or environment.", file=sys.stderr)
    # Define dummy classes if needed for script structure analysis:
    # class TitanTokenizer: @staticmethod def from_pretrained(s): raise NotImplementedError
    # class TitanForConditionalGeneration: @staticmethod def from_pretrained(s): raise NotImplementedError
    # from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments
    sys.exit(1) # Exit if the core classes are missing

from tqdm.auto import tqdm # Optional progress bars

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Use consistent naming if applicable
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('titan_qed_squared_amplitude_output')

# Model Configuration
# Replace with actual model ID if 'google/titan-small' is a placeholder
MODEL_ID = "google/titan-small" # Hypothetical Titan model identifier
TOKENIZER_ID = "google/titan-small" # Usually same as model ID

# Titan-Specific Configuration (Hypothetical)
USE_TIME_MIXING = True
USE_RECURSION   = True

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 512 # Max sequence length

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 4 # Adjust based on convergence
# Adjust batch size based on Titan model size and GPU memory
PER_DEVICE_TRAIN_BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 200
LOGGING_STEPS = 50
EVALUATION_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available()
USE_GRADIENT_CHECKPOINTING = True
LABEL_SMOOTHING_FACTOR = 0.1

# Generation Configuration (for evaluation)
GENERATION_NUM_BEAMS = 4

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
                        continue
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data:
             print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings = []
    target_strings = []
    if not raw_data_list:
        return source_strings, target_strings

    skipped_count = 0
    for i, item in enumerate(raw_data_list):
        input_toks = item.get('input_tokens')
        target_toks = item.get('target_tokens')

        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid keys or non-list values: {item}")
            skipped_count += 1

    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings:
        print("[Warning] No items successfully converted to strings.")
    return source_strings, target_strings


def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """
    Tokenizes source and target text sequences using the provided tokenizer.

    Args:
        tokenizer: Initialized Hugging Face tokenizer instance (e.g., TitanTokenizer).
        source_texts (list[str]): List of source sequences.
        target_texts (list[str]): List of target sequences for labels.
        max_len (int): Maximum sequence length for padding and truncation.

    Returns:
        dict: Dictionary containing lists of 'input_ids', 'attention_mask', and 'labels'.
    """
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")

    # Tokenize source texts
    encoder_inputs = tokenizer(
        source_texts,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors=None # Return lists
    )

    # Tokenize target texts to create labels
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(
            target_texts,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )

    # Assign target input_ids as 'labels'
    encoder_inputs['labels'] = decoder_labels['input_ids']

    print(f"[INFO] Encoding complete.")
    # Basic validation
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels'):
         print("[Warning] Encoding resulted in empty input_ids or labels.")
    elif len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Mismatch in length between encoded inputs and labels.")

    return encoder_inputs


class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        """
        Args:
            encodings (dict): Dictionary from tokenizer {'input_ids': [...], ...}.
                              Values should be lists of token IDs or masks.
        """
        if not isinstance(encodings, dict) or 'input_ids' not in encodings:
            raise ValueError("Encodings must be a dictionary containing at least 'input_ids'.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                if not isinstance(encodings[key], list) or len(encodings[key]) != self.length:
                    raise ValueError(f"Encoding key '{key}' is not a list or has inconsistent length.")
        except Exception as e:
             raise ValueError(f"Failed to validate encodings: {e}")

        if self.length == 0:
            raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Retrieves the tokenized data (as lists) for a single sample index."""
        if not 0 <= idx < self.length:
             raise IndexError(f"Index {idx} out of bounds for dataset of length {self.length}.")
        try:
            return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e:
            print(f"[Error] Failed to retrieve item at index {idx}: {e}", file=sys.stderr)
            raise

# --- Main Execution Logic ---

def main():
    """Orchestrates the Titan model fine-tuning and evaluation pipeline."""
    print("[INFO] Starting Hypothetical Titan Model Fine-tuning Script...")
    try:
        current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception:
        current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current time: {current_time_str} (local: San Diego, CA)") # Include location context
    print(f"[INFO] Using model: {MODEL_ID} with Tokenizer: {TOKENIZER_ID}")

    # --- 1. Load Data ---
    try:
        train_raw_data = load_jsonl(TRAIN_FILE)
        val_raw_data = load_jsonl(VAL_FILE)
        test_raw_data = load_jsonl(TEST_FILE)
    except Exception as e:
        print(f"[FATAL] Critical error during data loading: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_raw_data or not val_raw_data or not test_raw_data:
        print("[FATAL] One or more required datasets are empty or failed to load. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 2. Prepare Data (Tokens to Strings) ---
    train_sources, train_targets = convert_tokens_to_strings(train_raw_data)
    val_sources,   val_targets   = convert_tokens_to_strings(val_raw_data)
    test_sources,  test_targets  = convert_tokens_to_strings(test_raw_data)

    if not train_sources or not val_sources or not test_sources:
        print("[FATAL] Data conversion resulted in empty lists. Check input data format. Exiting.", file=sys.stderr)
        sys.exit(1)

    # --- 3. Initialize Tokenizer ---
    print(f"[INFO] Initializing Tokenizer: {TOKENIZER_ID}")
    try:
        # Assuming TitanTokenizer exists and works like standard tokenizers
        tokenizer = TitanTokenizer.from_pretrained(TOKENIZER_ID)
        # Add checks/additions for special tokens if necessary for Titan
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '<pad>'})
        # Determine start/end tokens based on tokenizer properties or model requirements
        if tokenizer.bos_token is None: tokenizer.add_special_tokens({'bos_token': '<s>'})
        if tokenizer.eos_token is None: tokenizer.add_special_tokens({'eos_token': '</s>'})

    except Exception as e:
        print(f"[FATAL] Failed to initialize tokenizer '{TOKENIZER_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # Store globally for metric computation
    global tokenizer_for_metrics
    tokenizer_for_metrics = tokenizer
    pad_token_id = tokenizer.pad_token_id
    print(f"[INFO] Tokenizer Pad Token ID: {pad_token_id}")
    print(f"[INFO] Tokenizer Vocab Size (initial): {tokenizer.vocab_size}")


    # --- 4. Initialize Model ---
    print(f"[INFO] Initializing Model: {MODEL_ID}")
    try:
        model = TitanForConditionalGeneration.from_pretrained(MODEL_ID)

        # Resize embeddings if vocab size changed
        if model.config.vocab_size != len(tokenizer):
             print(f"[INFO] Resizing model embeddings from {model.config.vocab_size} to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             model.config.vocab_size = len(tokenizer) # Ensure config is updated

        # Configure hypothetical Titan features - use try/except for robustness
        config_updated = False
        try:
            if USE_TIME_MIXING and hasattr(model.config, 'use_time_mixing'):
                model.config.use_time_mixing = True
                print("[INFO] Enabled model.config.use_time_mixing")
                config_updated = True
            elif USE_TIME_MIXING:
                 print("[Warning] Configured USE_TIME_MIXING=True, but 'use_time_mixing' not found in model config.")

            if USE_RECURSION and hasattr(model.config, 'use_recursion'):
                model.config.use_recursion = True
                print("[INFO] Enabled model.config.use_recursion")
                config_updated = True
            elif USE_RECURSION:
                 print("[Warning] Configured USE_RECURSION=True, but 'use_recursion' not found in model config.")

        except Exception as config_e:
             print(f"[Warning] Error applying hypothetical Titan config options: {config_e}")

        # Configure standard seq2seq settings
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        # Set generation defaults
        model.config.max_length = MAX_SEQ_LENGTH
        model.config.num_beams = GENERATION_NUM_BEAMS

        # Enable Gradient Checkpointing if configured
        if USE_GRADIENT_CHECKPOINTING:
            try:
                 model.gradient_checkpointing_enable()
                 print("[INFO] Gradient Checkpointing enabled on the model.")
            except Exception as gc_e:
                 print(f"[Warning] Failed to enable gradient checkpointing on model: {gc_e}")
                 # Ensure training arg reflects this failure
                 global USE_GRADIENT_CHECKPOINTING_EFFECTIVE
                 USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False
        else:
            USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False


    except Exception as e:
        print(f"[FATAL] Failed to initialize or configure the Titan model '{MODEL_ID}': {e}", file=sys.stderr)
        sys.exit(1)

    # --- 5. Tokenize Data ---
    try:
        train_encodings = encode_sequences(tokenizer, train_sources, train_targets, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_sources,   val_targets,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_sources,  test_targets,  MAX_SEQ_LENGTH)
    except Exception as e:
        print(f"[FATAL] Failed during data tokenization: {e}", file=sys.stderr)
        sys.exit(1)

    if not train_encodings.get('input_ids') or not val_encodings.get('input_ids') or not test_encodings.get('input_ids'):
         print("[FATAL] Tokenization resulted in empty encodings for one or more splits. Exiting.", file=sys.stderr)
         sys.exit(1)

    # --- 6. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e:
        print(f"[FATAL] Failed to create Dataset objects: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 7. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Ignore pad tokens in loss
        pad_to_multiple_of=8 if USE_FP16 else None # Optimize padding for FP16
    )
    print("[INFO] Data collator initialized.")

    # --- 8. Define Metrics Computation ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy after decoding."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            decoded_preds = tokenizer_for_metrics.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_metrics.batch_decode(labels, skip_special_tokens=True)
        except Exception as e:
             print(f"[Error] Decoding failed in compute_metrics: {e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

        # Post-process and compare
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        if len(decoded_preds) != len(decoded_labels):
            print(f"[Warning] Mismatch prediction/label count: {len(decoded_preds)} vs {len(decoded_labels)}", file=sys.stderr)
            return {"sequence_accuracy": 0.0}

        matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]
        accuracy = np.mean(matches) if matches else 0.0
        return {"sequence_accuracy": float(accuracy)}

    # --- 9. Define Training Arguments ---
    # Use effective GC flag based on earlier attempt
    effective_gc = USE_GRADIENT_CHECKPOINTING if 'USE_GRADIENT_CHECKPOINTING_EFFECTIVE' not in globals() else USE_GRADIENT_CHECKPOINTING_EFFECTIVE

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Schedule
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging / Saving / Evaluation
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        # Generation
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH,
        generation_num_beams=GENERATION_NUM_BEAMS,
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        gradient_checkpointing=effective_gc, # Use the effective flag
        # report_to="tensorboard", # Optional
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing in Args: {training_args.gradient_checkpointing}")


    # --- 10. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 11. Train the Model ---
    print(f"[INFO] Starting model training for {NUM_TRAIN_EPOCHS} epochs...")
    try:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint:
             print(f"[INFO] Best model checkpoint saved at: {trainer.state.best_model_checkpoint}")
        else:
             print("[Warning] No best model checkpoint recorded.")

    except Exception as e:
        print(f"[FATAL] Training loop encountered an error: {e}", file=sys.stderr)
        sys.exit(1)

    # --- 12. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(
            eval_dataset=test_dataset,
            metric_key_prefix="test"
        )
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)

        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---")
             print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}")
             print(f"------------------------")
        else:
             print("[Warning] 'test_sequence_accuracy' not found in test results.")
             print("Full test results:", test_results)

    except Exception as e:
        print(f"[Error] Evaluation on test set failed: {e}", file=sys.stderr)
        sys.exit(1) # Exit with error status

    print("[INFO] Script finished successfully.")


if __name__ == "__main__":
    # Optional: Add argument parsing (argparse)
    # Optional: Set random seeds
    # SEED = 42; torch.manual_seed(SEED); np.random.seed(SEED); import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

    main()

[ERROR] Failed to import TitanTokenizer or TitanForConditionalGeneration.
Please ensure these classes are defined in your transformers library or environment.


SystemExit: 1

# 3.5: Evolutionary and Transformer Models for Symbolic Regression
Model: Transformer model integrated with an evolutionary pipeline. It’s possible to start from previous year’s projects but should introduce a substantial innovation.


In [11]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Evolutionary Hyperparameter Optimization for Transformer-Based Symbolic Regression.

This script employs a Genetic Algorithm (GA), using the DEAP library, to optimize
key hyperparameters (learning rate and dropout probability) for a standard
Transformer Encoder-Decoder model (BERT-to-BERT) applied to a symbolic regression
task using the QED 2→2 dataset.

The optimization process works as follows:
1.  A small, fixed subset of the training and validation data is selected for
    rapid evaluation during the GA.
2.  DEAP initializes a population of candidate hyperparameter sets (individuals).
3.  Each individual is evaluated by:
    a. Initializing a fresh BERT-to-BERT `EncoderDecoderModel`.
    b. Configuring the model with the dropout specified by the individual.
    c. Setting up `Seq2SeqTrainingArguments` with the learning rate from the
       individual and settings suitable for a short training run.
    d. Initializing a `Seq2SeqTrainer`.
    e. Training the model for a fixed, small number of steps/epochs on the
       training subset.
    f. Evaluating the trained model on the validation subset using sequence
       accuracy.
    g. Returning the validation accuracy as the individual's fitness.
4.  DEAP applies genetic operators (selection, crossover, mutation) to evolve
    the population over several generations, aiming to maximize fitness
    (validation accuracy).
5.  After the GA completes, the best hyperparameter set found is identified.
6.  A final `EncoderDecoderModel` is initialized.
7.  A full training run is performed using the best hyperparameters found by the
    GA, training on the complete training dataset and validating on the complete
    validation dataset.
8.  The final performance is reported based on evaluation on the held-out test set.

Note: This script assumes the `deap` library is installed (`pip install deap`).
"""

import json
import random
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import
)
from deap import base, creator, tools, algorithms
from tqdm.auto import tqdm # Optional, useful if adding manual loops or progress bars later

# --- Configuration ---

# File Paths
# Ensure these point to your actual preprocessed data files
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust prefix if needed
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('ga_transformer_symbolic_regression_output') # Main output directory
GA_LOG_FILE = OUTPUT_DIR / "ga_logbook.json"                   # Log of GA progress
BEST_HP_FILE = OUTPUT_DIR / "best_hyperparameters.json"         # Best HPs found
GA_EVAL_OUTPUT_DIR = OUTPUT_DIR / "ga_eval_runs" # Temp dir for GA evaluation runs

# Model and Tokenizer Configuration
MODEL_ID = "bert-base-uncased" # Base model for Encoder-Decoder
TOKENIZER_ID = "bert-base-uncased"
MAX_SEQ_LENGTH = 128          # Max sequence length for tokenization

# Genetic Algorithm Configuration (DEAP)
POPULATION_SIZE = 20        # Number of HP sets evaluated per generation
N_GENERATIONS = 10          # Total number of GA generations
CX_PROB = 0.7               # Probability of crossover between individuals
MUT_PROB = 0.2              # Probability of mutating an individual
# Mutation parameters for Gaussian perturbation
MUT_SIGMA_LR = 1e-5         # Std deviation for learning rate mutation
MUT_SIGMA_DROPOUT = 0.05    # Std deviation for dropout mutation
# Boundaries for hyperparameter values
LR_BOUND_LOW = 1e-6
LR_BOUND_HIGH = 1e-3
DROPOUT_BOUND_LOW = 0.0
DROPOUT_BOUND_HIGH = 0.5
TOURNAMENT_SIZE = 3         # Selection pressure for tournament selection

# GA Evaluation Run Configuration (Short runs on subsets for speed)
GA_EVAL_TRAIN_SUBSET_SIZE = 512 # Number of training examples per GA eval run
GA_EVAL_VAL_SUBSET_SIZE = 128   # Number of validation examples per GA eval run
GA_EVAL_EPOCHS = 1              # Epochs per GA eval run (keep very small)
# GA_EVAL_MAX_STEPS = 200       # Alternative: Fixed steps per GA eval run
GA_EVAL_BATCH_SIZE = 16         # Batch size for GA eval runs

# Final Training Configuration (using best HPs on full data)
FINAL_TRAIN_EPOCHS = 5
FINAL_BATCH_SIZE = 16
FINAL_WEIGHT_DECAY = 0.01
FINAL_WARMUP_STEPS = 300
FINAL_LOGGING_STEPS = 50
FINAL_USE_FP16 = torch.cuda.is_available() # Use FP16 if available for final run
FINAL_LABEL_SMOOTHING = 0.1     # Optional label smoothing for final run
FINAL_SAVE_TOTAL_LIMIT = 2      # Keep only the best and latest checkpoints

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings, target_strings, skipped_count = [], [], 0
    if not raw_data_list: return source_strings, target_strings
    for i, item in enumerate(raw_data_list):
        input_toks, target_toks = item.get('input_tokens'), item.get('target_tokens')
        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid data: {item}"); skipped_count += 1
    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings: print("[Warning] No items successfully converted.")
    return source_strings, target_strings

def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """Tokenizes source and target sequences, returning lists of IDs/masks."""
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")
    encoder_inputs = tokenizer(source_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(target_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    encoder_inputs['labels'] = decoder_labels['input_ids']
    print(f"[INFO] Encoding complete.")
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels') or len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Encoding resulted in empty lists or length mismatch.")
    return encoder_inputs

class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        if not isinstance(encodings, dict) or 'input_ids' not in encodings: raise ValueError("Invalid encodings format.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                 if not isinstance(encodings[key], list) or len(encodings[key]) != self.length: raise ValueError(f"Inconsistent length for key '{key}'.")
        except Exception as e: raise ValueError(f"Validation failed: {e}")
        if self.length == 0: raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self): return self.length
    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds.")
        try: return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e: print(f"[Error] Failed retrieval at index {idx}: {e}", file=sys.stderr); raise

# --- Global variables needed by evaluation function ---
# These are populated in main() before the GA starts
ga_train_dataset_subset = None
ga_val_dataset_subset = None
tokenizer_for_eval = None
data_collator_for_eval = None
compute_metrics_internal = None # Holds the metric calculation logic

# --- DEAP Evaluation Function ---

def evaluate_hyperparams(individual):
    """
    Evaluates a hyperparameter set [learning_rate, dropout_rate] by training
    a small model on data subsets and returning validation accuracy.

    Args:
        individual (deap.creator.Individual): List containing [LR, Dropout].

    Returns:
        tuple: Fitness value (validation_accuracy,). Must be a tuple.
    """
    learning_rate, dropout_rate = individual[0], individual[1]
    dropout_rate = max(DROPOUT_BOUND_LOW, min(DROPOUT_BOUND_HIGH, dropout_rate)) # Clamp dropout

    run_id = f"lr_{learning_rate:.1e}_drop_{dropout_rate:.3f}_{random.randint(1000,9999)}"
    run_output_dir = GA_EVAL_OUTPUT_DIR / run_id

    print(f"[GA Eval] Evaluating LR={learning_rate:.3e}, Dropout={dropout_rate:.4f}")

    try:
        # 1. Initialize NEW model instance
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID)
        if model.config.encoder.vocab_size != len(tokenizer_for_eval):
            model.resize_token_embeddings(len(tokenizer_for_eval))
            model.config.encoder.vocab_size = len(tokenizer_for_eval)
            model.config.decoder.vocab_size = len(tokenizer_for_eval)

        # 2. Apply dropout from individual
        model.config.dropout = dropout_rate
        model.config.attention_dropout = dropout_rate
        if hasattr(model.config, 'encoder'): model.config.encoder.dropout = dropout_rate; model.config.encoder.attention_dropout = dropout_rate
        if hasattr(model.config, 'decoder'): model.config.decoder.dropout = dropout_rate; model.config.decoder.attention_dropout = dropout_rate

        # 3. Configure standard seq2seq settings
        model.config.decoder_start_token_id = tokenizer_for_eval.cls_token_id
        model.config.eos_token_id = tokenizer_for_eval.sep_token_id
        model.config.pad_token_id = tokenizer_for_eval.pad_token_id
        model.config.max_length = MAX_SEQ_LENGTH

        # 4. Minimal Training Arguments for GA eval run
        eval_training_args = Seq2SeqTrainingArguments(
            output_dir=str(run_output_dir),
            num_train_epochs=GA_EVAL_EPOCHS,
            # max_steps=GA_EVAL_MAX_STEPS, # Alternative
            per_device_train_batch_size=GA_EVAL_BATCH_SIZE,
            per_device_eval_batch_size=GA_EVAL_BATCH_SIZE,
            learning_rate=learning_rate,
            weight_decay=0.0, # Keep simple for GA eval
            logging_steps=1000, # Reduce logging noise
            evaluation_strategy="no",
            save_strategy="no",
            predict_with_generate=True, # Use generation for accuracy
            generation_max_length=MAX_SEQ_LENGTH,
            generation_num_beams=1, # Greedy search for speed
            fp16=FINAL_USE_FP16, # Consistent FP16 usage
            report_to="none",
            disable_tqdm=True,
        )

        # 5. Initialize Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=eval_training_args,
            train_dataset=ga_train_dataset_subset,
            eval_dataset=ga_val_dataset_subset,
            data_collator=data_collator_for_eval,
            tokenizer=tokenizer_for_eval,
            compute_metrics=compute_metrics_internal # Use the predefined metric logic
        )

        # 6. Train briefly
        trainer.train()

        # 7. Evaluate
        eval_results = trainer.evaluate(eval_dataset=ga_val_dataset_subset)
        accuracy = eval_results.get("eval_sequence_accuracy", 0.0)

        print(f"[GA Eval Result] LR={learning_rate:.3e}, Dropout={dropout_rate:.4f} -> Val Acc={accuracy:.4f}")
        # Optional: Clean up temp dir: import shutil; shutil.rmtree(run_output_dir)

        return (accuracy,) # Return fitness tuple

    except Exception as e:
        print(f"[GA Eval Error] Failed for LR={learning_rate:.3e}, Dropout={dropout_rate:.4f}: {e}", file=sys.stderr)
        return (0.0,) # Return poor fitness on failure


# --- Main Execution Logic ---

def main():
    """Orchestrates data loading, GA optimization, and final model training."""
    print("[INFO] Starting Evolutionary Hyperparameter Optimization Script...")
    # --- ***** Update Timestamp/Location ***** ---
    try:
        # Attempt to get timezone-aware UTC time
        current_time = datetime.datetime.now(datetime.timezone.utc)
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S %Z') # Includes timezone name
    except Exception:
        # Fallback to naive UTC time
        current_time = datetime.datetime.utcnow()
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    # Incorporate provided context
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")
    print(f"[INFO] Using Base Model: {MODEL_ID}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    GA_EVAL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Data Strings ---
    train_src, train_tgt = convert_tokens_to_strings(train_raw)
    val_src,   val_tgt   = convert_tokens_to_strings(val_raw)
    test_src,  test_tgt  = convert_tokens_to_strings(test_raw)
    if not train_src or not val_src or not test_src: print("[FATAL] Data conversion failed.", file=sys.stderr); sys.exit(1)

    # --- 3. Initialize Tokenizer (Global for GA Eval) ---
    global tokenizer_for_eval
    try:
        tokenizer_for_eval = AutoTokenizer.from_pretrained(TOKENIZER_ID)
        if tokenizer_for_eval.pad_token is None: tokenizer_for_eval.add_special_tokens({'pad_token': '[PAD]'})
        if tokenizer_for_eval.cls_token is None: tokenizer_for_eval.add_special_tokens({'cls_token': '[CLS]'})
        if tokenizer_for_eval.sep_token is None: tokenizer_for_eval.add_special_tokens({'sep_token': '[SEP]'})
        pad_token_id = tokenizer_for_eval.pad_token_id
    except Exception as e: print(f"[FATAL] Tokenizer init failed: {e}", file=sys.stderr); sys.exit(1)
    print(f"[INFO] Tokenizer initialized (Vocab: {len(tokenizer_for_eval)}, Pad ID: {pad_token_id}).")

    # --- 4. Tokenize Data & Create Subsets ---
    try:
        print("[INFO] Tokenizing full datasets...")
        train_enc = encode_sequences(tokenizer_for_eval, train_src, train_tgt, MAX_SEQ_LENGTH)
        val_enc   = encode_sequences(tokenizer_for_eval, val_src,   val_tgt,   MAX_SEQ_LENGTH)
        test_enc  = encode_sequences(tokenizer_for_eval, test_src,  test_tgt,  MAX_SEQ_LENGTH)

        full_train_dataset = SequencePairDataset(train_enc)
        full_val_dataset   = SequencePairDataset(val_enc)
        full_test_dataset  = SequencePairDataset(test_enc)

        print(f"[INFO] Creating GA evaluation subsets (Train: {GA_EVAL_TRAIN_SUBSET_SIZE}, Val: {GA_EVAL_VAL_SUBSET_SIZE})")
        train_subset_indices = random.sample(range(len(full_train_dataset)), min(GA_EVAL_TRAIN_SUBSET_SIZE, len(full_train_dataset)))
        val_subset_indices = random.sample(range(len(full_val_dataset)), min(GA_EVAL_VAL_SUBSET_SIZE, len(full_val_dataset)))

        def subset_encodings(enc, indices): return {key: [enc[key][i] for i in indices] for key in enc}

        global ga_train_dataset_subset, ga_val_dataset_subset
        ga_train_dataset_subset = SequencePairDataset(subset_encodings(train_enc, train_subset_indices))
        ga_val_dataset_subset = SequencePairDataset(subset_encodings(val_enc, val_subset_indices))

    except Exception as e: print(f"[FATAL] Data tokenization or subsetting failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 5. Setup Global Collator & Metrics Logic ---
    global data_collator_for_eval, compute_metrics_internal
    temp_model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID) # Need instance for collator setup
    data_collator_for_eval = DataCollatorForSeq2Seq(tokenizer_for_eval, model=temp_model, label_pad_token_id=-100, pad_to_multiple_of=8 if FINAL_USE_FP16 else None)
    del temp_model
    print("[INFO] Data collator prepared.")

    # Define the internal metric calculation logic once
    def compute_metrics_logic(eval_pred):
        """Internal logic for calculating sequence accuracy."""
        predictions, labels = eval_pred
        if not isinstance(predictions, np.ndarray): predictions = np.array(predictions)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)
        labels = np.where(labels != -100, labels, tokenizer_for_eval.pad_token_id)
        try:
            decoded_preds = tokenizer_for_eval.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer_for_eval.batch_decode(labels, skip_special_tokens=True)
            decoded_preds = [p.strip() for p in decoded_preds]
            decoded_labels = [l.strip() for l in decoded_labels]
            if len(decoded_preds) != len(decoded_labels): return {"sequence_accuracy": 0.0}
            matches = [p == l for p, l in zip(decoded_preds, decoded_labels)]
            accuracy = np.mean(matches) if matches else 0.0
            return {"sequence_accuracy": float(accuracy)}
        except Exception as dec_e:
             print(f"[Metrics Error] Decoding failed: {dec_e}", file=sys.stderr); return {"sequence_accuracy": 0.0}
    compute_metrics_internal = compute_metrics_logic # Assign to global var

    # --- 6. DEAP Genetic Algorithm Setup ---
    print("[INFO] Setting up Genetic Algorithm (DEAP)...")
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", list, fitness=creator.FitnessMax)
    toolbox = base.Toolbox()
    toolbox.register("attr_lr", random.uniform, LR_BOUND_LOW, LR_BOUND_HIGH)
    toolbox.register("attr_dropout", random.uniform, DROPOUT_BOUND_LOW, DROPOUT_BOUND_HIGH)
    toolbox.register("individual", tools.initCycle, creator.Individual, (toolbox.attr_lr, toolbox.attr_dropout), n=1)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("mate", tools.cxTwoPoint)
    # Ensure mutation respects boundaries - DEAP's mutGaussian doesn't inherently
    # We handle clamping within the evaluate function, alternative is custom mutation
    toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=[MUT_SIGMA_LR, MUT_SIGMA_DROPOUT], indpb=0.2)
    toolbox.register("select", tools.selTournament, tournsize=TOURNAMENT_SIZE)
    toolbox.register("evaluate", evaluate_hyperparams)

    # --- 7. Run Genetic Algorithm ---
    print(f"[INFO] Starting GA optimization ({N_GENERATIONS} gens, Pop: {POPULATION_SIZE})...")
    population = toolbox.population(n=POPULATION_SIZE)
    hall_of_fame = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean); stats.register("std", np.std); stats.register("min", np.min); stats.register("max", np.max)

    try:
        population, logbook = algorithms.eaSimple(population, toolbox, cxpb=CX_PROB, mutpb=MUT_PROB, ngen=N_GENERATIONS, stats=stats, halloffame=hall_of_fame, verbose=True)
        ga_success = True
    except Exception as ga_e: print(f"[ERROR] GA execution failed: {ga_e}", file=sys.stderr); ga_success = False; logbook = None; best_individual = None

    if logbook:
        try:
            with open(GA_LOG_FILE, 'w') as f: json.dump(logbook, f, indent=2)
            print(f"[INFO] GA logbook saved to {GA_LOG_FILE}")
        except Exception as log_e: print(f"[Error] Failed to save GA logbook: {log_e}", file=sys.stderr)

    if not ga_success or not hall_of_fame: print("[FATAL] GA optimization failed or found no best individual. Exiting.", file=sys.stderr); sys.exit(1)

    best_individual = hall_of_fame[0]
    best_lr, best_dropout = best_individual[0], max(DROPOUT_BOUND_LOW, min(DROPOUT_BOUND_HIGH, best_individual[1])) # Clamp final best dropout
    best_fitness = best_individual.fitness.values[0]

    print("\n--- GA Optimization Complete ---")
    print(f"Best Individual: LR={best_lr:.3e}, Dropout={best_dropout:.4f}")
    print(f"Best Validation Accuracy (on subset): {best_fitness:.4f}")
    print("----------------------------------")

    best_hp = {'learning_rate': best_lr, 'dropout_rate': best_dropout, 'validation_accuracy': best_fitness}
    try:
        with open(BEST_HP_FILE, 'w') as f: json.dump(best_hp, f, indent=2)
        print(f"[INFO] Best hyperparameters saved to {BEST_HP_FILE}")
    except Exception as hp_e: print(f"[Error] Failed to save best hyperparameters: {hp_e}", file=sys.stderr)

    # --- 8. Final Training with Best Hyperparameters ---
    print("\n[INFO] Starting final model training using best hyperparameters found...")
    try:
        final_model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_ID, MODEL_ID)
        if final_model.config.encoder.vocab_size != len(tokenizer_for_eval):
            final_model.resize_token_embeddings(len(tokenizer_for_eval))
            final_model.config.encoder.vocab_size = len(tokenizer_for_eval)
            final_model.config.decoder.vocab_size = len(tokenizer_for_eval)

        final_model.config.dropout = best_dropout
        final_model.config.attention_dropout = best_dropout
        if hasattr(final_model.config, 'encoder'): final_model.config.encoder.dropout = best_dropout; final_model.config.encoder.attention_dropout = best_dropout
        if hasattr(final_model.config, 'decoder'): final_model.config.decoder.dropout = best_dropout; final_model.config.decoder.attention_dropout = best_dropout

        final_model.config.decoder_start_token_id = tokenizer_for_eval.cls_token_id
        final_model.config.eos_token_id = tokenizer_for_eval.sep_token_id
        final_model.config.pad_token_id = tokenizer_for_eval.pad_token_id
        final_model.config.max_length = MAX_SEQ_LENGTH
        final_model.config.num_beams = 4 # Use beam search for final evaluation
        final_model.tie_weights()
        # Optional final gradient checkpointing
        # if FINAL_USE_GRADIENT_CHECKPOINTING: final_model.gradient_checkpointing_enable()

    except Exception as model_e: print(f"[FATAL] Failed to initialize final model: {model_e}", file=sys.stderr); sys.exit(1)

    final_training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR / "final_model"),
        num_train_epochs=FINAL_TRAIN_EPOCHS,
        per_device_train_batch_size=FINAL_BATCH_SIZE,
        per_device_eval_batch_size=FINAL_BATCH_SIZE,
        learning_rate=best_lr, # Use best LR
        weight_decay=FINAL_WEIGHT_DECAY,
        warmup_steps=FINAL_WARMUP_STEPS,
        logging_dir=str(OUTPUT_DIR / "final_model" / 'logs'),
        logging_steps=FINAL_LOGGING_STEPS,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=FINAL_SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        predict_with_generate=True,
        generation_max_length=MAX_SEQ_LENGTH,
        generation_num_beams=4,
        fp16=FINAL_USE_FP16,
        label_smoothing_factor=FINAL_LABEL_SMOOTHING,
        # gradient_checkpointing=FINAL_USE_GRADIENT_CHECKPOINTING, # Optional
        report_to="none",
    )

    final_trainer = Seq2SeqTrainer(
        model=final_model, args=final_training_args,
        train_dataset=full_train_dataset, eval_dataset=full_val_dataset,
        data_collator=data_collator_for_eval, tokenizer=tokenizer_for_eval,
        compute_metrics=compute_metrics_internal,
    )

    print(f"[INFO] Starting final training run ({FINAL_TRAIN_EPOCHS} epochs)...")
    try:
        train_result = final_trainer.train()
        final_trainer.log_metrics("final_train", train_result.metrics)
        final_trainer.save_metrics("final_train", train_result.metrics)
        final_trainer.save_state()
        print("[INFO] Final training complete.")
        if final_trainer.state.best_model_checkpoint: print(f"[INFO] Best final model checkpoint: {final_trainer.state.best_model_checkpoint}")
    except Exception as train_e: print(f"[FATAL] Final training failed: {train_e}", file=sys.stderr); sys.exit(1)

    # --- 9. Final Test Evaluation ---
    print("[INFO] Evaluating final best model on the full test set...")
    try:
        test_results = final_trainer.evaluate(eval_dataset=full_test_dataset, metric_key_prefix="test")
        final_trainer.log_metrics("test", test_results)
        final_trainer.save_metrics("test", test_results)
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Final Test Set Results ---"); print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}"); print(f"-----------------------------")
        else: print("[Warning] 'test_sequence_accuracy' not found.", "Full test results:", test_results)
    except Exception as test_e: print(f"[Error] Final evaluation failed: {test_e}", file=sys.stderr); sys.exit(1)

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    SEED = 42
    random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

ModuleNotFoundError: No module named 'deap'

# 3.6: Symbolic empirical representation of squared amplitudes in high-energy physics
Model: Transformer with novel approach for tokenization, data representation and/or preprocessing that leads to better performance than basic tokenization with normalized indices.


In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tuning Script for a Depth-Aware Transformer for Symbolic Regression.

This script implements and trains a custom Transformer Encoder-Decoder model
designed for symbolic regression tasks, specifically predicting squared amplitudes
in High Energy Physics (HEP) from input amplitudes.

Key features:
- Custom BPE Tokenizer: Trains a Byte-Pair Encoding tokenizer from scratch on the
  combined input and target sequences of the dataset, potentially capturing
  domain-specific physics identifiers and operators more effectively.
- Depth-Aware Embeddings: Calculates the parenthesis nesting depth for each token
  in the input sequence and incorporates this information via dedicated depth
  embeddings, adding structural awareness to the model's input representation.
- Custom Transformer Model: Implements a standard Transformer encoder-decoder
  using PyTorch's `nn.TransformerEncoderLayer` and `nn.TransformerDecoderLayer`,
  modified to accept the combined token, position, and depth embeddings.
- Manual PyTorch Training Loop: Includes standard training, validation, and
  testing phases with loss calculation, optimization, metric computation (sequence
  accuracy), and basic checkpointing.

Workflow:
1. Load raw data splits from JSONL files.
2. Train a new BPE tokenizer on the corpus (if not already trained and saved)
   or load a previously trained tokenizer.
3. Define a custom `DepthDataset` that performs tokenization and calculates
   parenthesis depth for each input sequence during initialization.
4. Define the `DepthTransformer` model architecture using PyTorch's nn.Module.
5. Set up DataLoaders, optimizer, loss criterion, and device (GPU/CPU).
6. Implement the training loop, iterating through epochs and batches, performing
   forward/backward passes, and updating model weights.
7. Implement validation loop to monitor performance (sequence accuracy) and save
   the best model checkpoint based on validation results.
8. Implement the final test loop to evaluate the best model on the held-out
   test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizerFast # Using RobertaTokenizer for BPE training
from tqdm.auto import tqdm # Progress bars
# Optional: Learning rate scheduler
# from transformers import get_linear_schedule_with_warmup

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust if using different names
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('depth_aware_transformer_output')
TOKENIZER_SAVE_DIR = OUTPUT_DIR / "bpe_physics_tokenizer" # Directory to save/load trained tokenizer
CHECKPOINT_NAME = "depth_transformer_best.pt" # Name for the best model checkpoint

# Tokenizer Configuration
VOCAB_SIZE = 8000         # Target vocabulary size for BPE tokenizer
MIN_FREQUENCY = 2         # Minimum frequency for tokens in BPE training
# Special tokens common for sequence models
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

# Data and Model Configuration
MAX_SEQ_LENGTH = 128      # Max sequence length for tokenization and model
MAX_DEPTH = 50            # Maximum expected parenthesis nesting depth + buffer

# Model Hyperparameters (DepthTransformer)
D_MODEL = 512             # Embedding and hidden state dimension
N_HEAD = 8                # Number of attention heads
NUM_ENCODER_LAYERS = 6    # Number of layers in the Transformer encoder
NUM_DECODER_LAYERS = 6    # Number of layers in the Transformer decoder
DIM_FEEDFORWARD = 2048    # Dimension of the feed-forward networks
DROPOUT = 0.1             # Dropout rate (can be added to model layers)

# Training Hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 16           # Adjust based on GPU memory
LEARNING_RATE = 5e-4
OPTIMIZER_EPS = 1e-8
WEIGHT_DECAY = 0.01
# LR_SCHEDULER_TYPE = "linear" # Optional scheduler type
# WARMUP_RATIO = 0.1         # Optional warmup ratio
LOG_INTERVAL = 50         # Log training stats every N batches

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    # (Same implementation as previous script)
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise


def train_or_load_tokenizer(corpus_iterator, save_dir, vocab_size, min_freq, special_tokens):
    """Trains a new BPE tokenizer or loads it if it already exists."""
    save_dir = Path(save_dir)
    config_file = save_dir / "tokenizer_config.json"

    if save_dir.exists() and config_file.exists():
        print(f"[INFO] Loading existing tokenizer from {save_dir}")
        try:
            tokenizer = RobertaTokenizerFast.from_pretrained(str(save_dir))
            # Verify core special tokens are present
            if any(tok not in tokenizer.get_vocab() for tok in ['<s>', '<pad>', '</s>', '<unk>']):
                 print("[Warning] Loaded tokenizer missing some standard special tokens. Recheck config.")
            return tokenizer
        except Exception as e:
            print(f"[Warning] Failed to load existing tokenizer: {e}. Attempting to retrain.", file=sys.stderr)

    print(f"[INFO] Training new BPE tokenizer (Vocab: {vocab_size}, Min Freq: {min_freq})...")
    try:
        # Initialize a base tokenizer (like Roberta's structure) but don't load weights
        # We use RobertaTokenizerFast as it supports training from iterator
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") # Base structure
        print("[INFO] Training from iterator...")
        # The train_new_from_iterator method might not exist directly on RobertaTokenizerFast.
        # Often, you use the underlying `tokenizers` library for this.
        # Let's adapt using the `tokenizers` library approach if direct method fails.

        try:
             # Attempt direct method if available in specific transformers version
             tokenizer.train_new_from_iterator(corpus_iterator, vocab_size=vocab_size, min_frequency=min_freq, special_tokens=special_tokens)
        except AttributeError:
             print("[INFO] `train_new_from_iterator` not found, using `tokenizers` library directly.")
             from tokenizers import ByteLevelBPETokenizer as TokenizersBPETokenizer

             # Initialize BPE tokenizer from the `tokenizers` library
             tk_lib_tokenizer = TokenizersBPETokenizer()
             tk_lib_tokenizer.train_from_iterator(
                 corpus_iterator,
                 vocab_size=vocab_size,
                 min_frequency=min_freq,
                 special_tokens=special_tokens
             )
             # Need to save this intermediate tokenizer and then load it into RobertaTokenizerFast
             temp_save_path = save_dir / "temp_tokenizer_files"
             temp_save_path.mkdir(parents=True, exist_ok=True)
             tk_lib_tokenizer.save_model(str(temp_save_path)) # Saves vocab.json and merges.txt

             # Load the trained vocab/merges into the desired HF Tokenizer class
             tokenizer = RobertaTokenizerFast.from_pretrained(str(temp_save_path), max_len=MAX_SEQ_LENGTH) # Or appropriate max_len
             # Clean up temp files
             import shutil; shutil.rmtree(temp_save_path)


        print("[INFO] Tokenizer training complete.")
        save_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(str(save_dir))
        print(f"[INFO] Tokenizer saved to {save_dir}")
        return tokenizer
    except ImportError as e:
         print(f"[Error] Required library (`tokenizers`) not found for training: {e}. Please install it.", file=sys.stderr)
         raise
    except Exception as e:
        print(f"[Error] Tokenizer training failed: {e}", file=sys.stderr)
        raise


class DepthDataset(Dataset):
    """
    Dataset that tokenizes sequences and calculates parenthesis nesting depth.

    Tokenization and depth calculation are performed once during initialization.
    Stores token IDs, attention masks, depth IDs, and label IDs.

    Note: Stores all processed examples in memory. May be unsuitable for
          extremely large datasets without modification (e.g., lazy processing
          or memory mapping).
    """
    def __init__(self, raw_data, tokenizer, max_len=128, max_depth=50):
        """
        Args:
            raw_data (list[dict]): List of dictionaries, each with 'input_tokens'
                                   and 'target_tokens' keys containing lists of strings.
            tokenizer: Initialized Hugging Face tokenizer instance.
            max_len (int): Maximum sequence length for padding/truncation.
            max_depth (int): Maximum depth value to cap at; also determines depth embedding size.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_depth = max_depth
        self.examples = [] # List to store processed examples

        print(f"[INFO] Processing dataset for DepthDataset (Max Len: {max_len}, Max Depth: {max_depth})...")
        skipped_count = 0
        for i, ex in enumerate(tqdm(raw_data, desc="Processing Raw Data")):
            input_tokens = ex.get("input_tokens")
            target_tokens = ex.get("target_tokens")

            if not isinstance(input_tokens, list) or not isinstance(target_tokens, list):
                print(f"[Warning] Skipping example {i} due to missing or invalid token lists.")
                skipped_count += 1
                continue

            # Join tokens for BPE tokenizer
            src_str = " ".join(map(str, input_tokens))
            tgt_str = " ".join(map(str, target_tokens))

            # Tokenize source and target
            # Note: We don't use return_tensors='pt' here; store lists/ints
            src_enc = tokenizer(src_str, max_length=max_len, truncation=True, padding="max_length", return_tensors=None)
            # Use context manager for target tokenization (good practice)
            with tokenizer.as_target_tokenizer():
                tgt_enc = tokenizer(tgt_str, max_length=max_len, truncation=True, padding="max_length", return_tensors=None)

            # --- Calculate Parenthesis Depth ---
            # Operates on the *original* input_tokens before BPE tokenization
            current_depth = 0
            depth_sequence = []
            for tok in input_tokens:
                # Assign depth *before* potential change for closing parenthesis
                if tok == ")":
                    current_depth = max(0, current_depth - 1) # Decrement after assigning for ')'
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Cap depth
                elif tok == "(":
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Cap depth
                    current_depth += 1 # Increment after assigning for '('
                else:
                    depth_sequence.append(min(current_depth, max_depth - 1)) # Assign current depth

            # --- Align Depth with Tokenized Sequence (Simple Approach) ---
            # This is a simplification. A robust alignment would map original token
            # positions to BPE token positions. Here, we just pad/truncate the
            # depth sequence calculated on original tokens to match the BPE sequence length.
            # This might misalign depth for tokens split by BPE.
            # TODO: Implement a more robust alignment if needed.
            if len(depth_sequence) >= max_len:
                aligned_depth_ids = depth_sequence[:max_len]
            else:
                # Pad with depth 0 (assuming root level)
                aligned_depth_ids = depth_sequence + [0] * (max_len - len(depth_sequence))

            self.examples.append({
                "input_ids":      src_enc["input_ids"],
                "attention_mask": src_enc["attention_mask"],
                "depth_ids":      aligned_depth_ids,
                "labels":         tgt_enc["input_ids"], # Target token IDs
            })

        if skipped_count > 0:
            print(f"[Warning] Skipped {skipped_count} examples during dataset processing.")
        if not self.examples:
             raise ValueError("Dataset processing resulted in zero valid examples.")
        print(f"[INFO] DepthDataset processed. Total examples: {len(self.examples)}")

    def __len__(self):
        """Returns the total number of processed examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """
        Retrieves a processed example by index.
        Returns data as lists/ints; DataLoader handles tensor conversion.
        """
        if not 0 <= idx < len(self.examples):
            raise IndexError(f"Index {idx} out of bounds.")
        return self.examples[idx]


class DepthTransformer(nn.Module):
    """
    Transformer Encoder-Decoder model incorporating token depth embeddings.

    Input representation combines token, position, and parenthesis depth embeddings.
    Uses standard PyTorch Transformer layers for sequence processing.
    """
    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, max_len=128,
                 max_depth=50, dropout=0.1, pad_token_id=None):
        """
        Initializes the Depth-Aware Transformer.

        Args:
            vocab_size (int): Size of the vocabulary.
            d_model (int): Dimension of embeddings and hidden states.
            nhead (int): Number of attention heads.
            num_encoder_layers (int): Number of layers in the encoder stack.
            num_decoder_layers (int): Number of layers in the decoder stack.
            dim_feedforward (int): Dimension of the feed-forward networks.
            max_len (int): Maximum sequence length for positional embeddings.
            max_depth (int): Maximum depth value + 1 (size of depth embedding table).
            dropout (float): Dropout rate.
            pad_token_id (int): ID of the padding token for embeddings.
        """
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id if pad_token_id is not None else 0 # Default assumption

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=self.pad_token_id)
        self.positional_embedding = nn.Embedding(max_len, d_model) # Absolute positional
        self.depth_embedding = nn.Embedding(max_depth, d_model) # Depth information

        self.embedding_dropout = nn.Dropout(dropout)

        # Standard PyTorch Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Standard PyTorch Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        # Output Projection Layer
        self.output_projection = nn.Linear(d_model, vocab_size)

        print(f"[INFO] Initialized DepthTransformer:")
        print(f"  - Vocab Size: {vocab_size}, d_model: {d_model}, Heads: {nhead}")
        print(f"  - Max Len: {max_len}, Max Depth: {max_depth}")
        print(f"  - Enc Layers: {num_encoder_layers}, Dec Layers: {num_decoder_layers}")
        print(f"  - Feedforward Dim: {dim_feedforward}, Dropout: {dropout}")


    def forward(self, src_input_ids, src_attention_mask, src_depth_ids, tgt_input_ids):
        """
        Performs the forward pass through the depth-aware Transformer.

        Args:
            src_input_ids (torch.Tensor): Source sequence token IDs (batch_size, src_seq_len).
            src_attention_mask (torch.Tensor): Source sequence attention mask (batch_size, src_seq_len).
                                               (1 for real tokens, 0 for padding).
            src_depth_ids (torch.Tensor): Source sequence depth IDs (batch_size, src_seq_len).
            tgt_input_ids (torch.Tensor): Target sequence token IDs (batch_size, tgt_seq_len),
                                          typically shifted right for teacher forcing.

        Returns:
            torch.Tensor: Output logits (batch_size, tgt_seq_len, vocab_size).
        """
        batch_size, src_seq_len = src_input_ids.shape
        _, tgt_seq_len = tgt_input_ids.shape

        # 1. Source Embeddings (Token + Position + Depth)
        # Create position IDs (0 to seq_len-1)
        src_pos_ids = torch.arange(src_seq_len, device=src_input_ids.device).unsqueeze(0).expand(batch_size, -1)

        src_emb = self.token_embedding(src_input_ids)           # (B, S, D)
        src_pos = self.positional_embedding(src_pos_ids)       # (B, S, D)
        src_dep = self.depth_embedding(src_depth_ids)         # (B, S, D)

        src_combined_emb = src_emb + src_pos + src_dep # Combine embeddings
        src_combined_emb = self.embedding_dropout(src_combined_emb)

        # 2. Encoder
        # PyTorch Transformer layers expect masks where True indicates a position *not* to attend to.
        # src_attention_mask is (B, S) with 1 for non-pad, 0 for pad. Need (B, S) with True for pad.
        src_key_padding_mask = (src_attention_mask == 0) # True where padded

        # Pass through encoder stack
        memory = self.encoder(src_combined_emb, src_key_padding_mask=src_key_padding_mask) # (B, S, D)

        # 3. Target Embeddings (Token + Position) - No depth for target typically
        # Target IDs are usually shifted right for teacher forcing (e.g., start with BOS, end before EOS)
        tgt_pos_ids = torch.arange(tgt_seq_len, device=tgt_input_ids.device).unsqueeze(0).expand(batch_size, -1)

        tgt_emb = self.token_embedding(tgt_input_ids)           # (B, T, D)
        tgt_pos = self.positional_embedding(tgt_pos_ids)       # (B, T, D)

        tgt_combined_emb = tgt_emb + tgt_pos # Combine target embeddings
        tgt_combined_emb = self.embedding_dropout(tgt_combined_emb)

        # 4. Decoder
        # Create target padding mask (True for padded positions)
        # Assuming target uses same pad token id
        # Need a mask for the target sequence itself
        tgt_key_padding_mask = (tgt_input_ids == self.pad_token_id)

        # Create causal mask (autoregressive mask) to prevent attending to future tokens
        # Shape: (T, T)
        tgt_causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_seq_len, device=tgt_input_ids.device)

        # Pass through decoder stack
        decoder_output = self.decoder(
            tgt=tgt_combined_emb,           # Target sequence embeddings (B, T, D)
            memory=memory,                  # Encoder output (B, S, D)
            tgt_mask=tgt_causal_mask,       # Causal mask (T, T)
            tgt_key_padding_mask=tgt_key_padding_mask, # Target padding mask (B, T)
            memory_key_padding_mask=src_key_padding_mask # Source padding mask (B, S)
        ) # Output shape: (B, T, D)

        # 5. Output Projection
        logits = self.output_projection(decoder_output) # (B, T, VocabSize)

        return logits


def calculate_sequence_accuracy(logits, labels, pad_token_id):
    """Calculates exact sequence match accuracy using PyTorch tensors."""
    if logits.shape[0] == 0: return 0.0
    predictions = torch.argmax(logits, dim=-1)
    non_pad_mask = (labels != pad_token_id)
    correct_tokens = (predictions == labels) & non_pad_mask
    correct_sequences = (torch.sum(correct_tokens, dim=1) == torch.sum(non_pad_mask, dim=1))
    accuracy = torch.mean(correct_sequences.float())
    return accuracy.item()


# --- Main Execution ---

def main():
    """Orchestrates tokenizer training, data processing, model training, and evaluation."""
    print("[INFO] Starting Depth-Aware Transformer Script...")
    try: current_time_str = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
    except Exception: current_time_str = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    # --- ***** Update Timestamp/Location ***** ---
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Raw Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Corpus and Train/Load Tokenizer ---
    def corpus_gen(): # Generator to avoid loading all strings into memory at once
        print("[INFO] Preparing corpus for tokenizer training...")
        count = 0
        for split in (train_raw, val_raw, test_raw):
            for ex in split:
                if ex.get("input_tokens"): yield " ".join(map(str, ex["input_tokens"]))
                if ex.get("target_tokens"): yield " ".join(map(str, ex["target_tokens"]))
                count += 2
        print(f"[INFO] Corpus generator ready ({count} sequences).")

    try:
        tokenizer = train_or_load_tokenizer(
            corpus_iterator=corpus_gen(),
            save_dir=TOKENIZER_SAVE_DIR,
            vocab_size=VOCAB_SIZE,
            min_freq=MIN_FREQUENCY,
            special_tokens=SPECIAL_TOKENS
        )
        vocab_size = len(tokenizer) # Get actual vocab size after training/loading
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
             print("[Error] Tokenizer loaded without a pad token ID!", file=sys.stderr); sys.exit(1)

    except Exception as e:
        print(f"[FATAL] Tokenizer training or loading failed: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"[INFO] Tokenizer ready (Vocab size: {vocab_size}, Pad ID: {pad_token_id}).")


    # --- 3. Create Datasets ---
    try:
        print("[INFO] Creating DepthDatasets...")
        train_dataset = DepthDataset(train_raw, tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
        val_dataset   = DepthDataset(val_raw,   tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
        test_dataset  = DepthDataset(test_raw,  tokenizer, MAX_SEQ_LENGTH, MAX_DEPTH)
    except Exception as e:
        print(f"[FATAL] Failed to create datasets: {e}", file=sys.stderr)
        sys.exit(1)
    if not train_dataset or not val_dataset or not test_dataset:
         print("[FATAL] One or more datasets are empty after processing. Exiting.", file=sys.stderr)
         sys.exit(1)


    # --- 4. Create DataLoaders ---
    # Define a simple collate function to handle batching of list data into tensors
    def collate_batch(batch):
        elem = batch[0]
        collated = {}
        for key in elem.keys():
            collated[key] = torch.tensor([d[key] for d in batch])
        return collated

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch, num_workers=4, pin_memory=True)
    print("[INFO] DataLoaders created.")


    # --- 5. Initialize Model, Optimizer, Criterion ---
    print(f"[INFO] Initializing DepthTransformer model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    try:
        model = DepthTransformer(
            vocab_size=vocab_size,
            d_model=D_MODEL,
            nhead=N_HEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            dim_feedforward=DIM_FEEDFORWARD,
            max_len=MAX_SEQ_LENGTH,
            max_depth=MAX_DEPTH,
            dropout=DROPOUT,
            pad_token_id=pad_token_id
        ).to(device)
    except Exception as e:
        print(f"[FATAL] Failed to initialize model: {e}", file=sys.stderr)
        sys.exit(1)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        eps=OPTIMIZER_EPS,
        weight_decay=WEIGHT_DECAY
    )
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id) # Ignore padding in loss
    print("[INFO] Model, Optimizer, and Criterion initialized.")

    # Optional: Learning Rate Scheduler
    # total_steps = len(train_loader) * NUM_EPOCHS
    # warmup_steps = int(total_steps * WARMUP_RATIO)
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)


    # --- 6. Training Loop ---
    print(f"[INFO] Starting training for {NUM_EPOCHS} epochs...")
    best_val_accuracy = -1.0
    best_epoch = -1
    checkpoint_path = OUTPUT_DIR / CHECKPOINT_NAME

    for epoch in range(NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
        model.train()
        total_train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)

        for step, batch in enumerate(train_pbar):
            # Move batch data to the device
            src_ids = batch["input_ids"].to(device)
            src_mask = batch["attention_mask"].to(device) # Encoder padding mask source
            depth_ids = batch["depth_ids"].to(device)
            labels = batch["labels"].to(device)          # Target sequence (B, T)

            # Prepare decoder inputs (shifted right) and targets
            # Decoder input: BOS token + sequence (excluding last token)
            # Decoder target: sequence (excluding first token) + EOS token (implicitly handled by loss shifting)
            # For CrossEntropyLoss, logits (B, T, V) should align with targets (B, T)
            tgt_input = labels[:, :-1] # (B, T-1), excludes last token
            tgt_output = labels[:, 1:]  # (B, T-1), excludes first token (BOS/CLS)

            optimizer.zero_grad()

            # Forward pass
            logits = model(src_input_ids=src_ids,
                           src_attention_mask=src_mask,
                           src_depth_ids=depth_ids,
                           tgt_input_ids=tgt_input) # Pass shifted input (B, T-1, V)

            # Calculate loss - compare logits with shifted *output* targets
            loss = criterion(logits.reshape(-1, vocab_size), # (B*(T-1), V)
                             tgt_output.reshape(-1))         # (B*(T-1))

            # Backward pass and optimize
            loss.backward()
            # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # if scheduler: scheduler.step() # Update LR if scheduler is used

            total_train_loss += loss.item()

            # Log training progress
            if (step + 1) % LOG_INTERVAL == 0:
                avg_loss = total_train_loss / LOG_INTERVAL
                # current_lr = scheduler.get_last_lr()[0] if scheduler else LEARNING_RATE
                current_lr = LEARNING_RATE # Simpler if no scheduler
                train_pbar.set_postfix({'Avg Loss': f'{avg_loss:.4f}', 'LR': f'{current_lr:.2e}'})
                total_train_loss = 0.0 # Reset accumulator

        train_pbar.close()

        # --- Validation Phase ---
        model.eval()
        total_val_accuracy = 0.0
        total_val_samples = 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} Validation", leave=False)

        with torch.no_grad():
            for batch in val_pbar:
                src_ids = batch["input_ids"].to(device)
                src_mask = batch["attention_mask"].to(device)
                depth_ids = batch["depth_ids"].to(device)
                labels = batch["labels"].to(device)
                tgt_input = labels[:, :-1]
                tgt_output = labels[:, 1:]

                logits = model(src_ids, src_mask, depth_ids, tgt_input)

                batch_accuracy = calculate_sequence_accuracy(logits, tgt_output, pad_token_id)
                total_val_accuracy += batch_accuracy * src_ids.size(0) # Weighted by batch size
                total_val_samples += src_ids.size(0)

        val_pbar.close()
        epoch_val_accuracy = total_val_accuracy / total_val_samples if total_val_samples > 0 else 0.0
        print(f"Epoch {epoch+1} Validation Sequence Accuracy: {epoch_val_accuracy:.4f}")

        # --- Save Best Model Checkpoint ---
        if epoch_val_accuracy > best_val_accuracy:
            best_val_accuracy = epoch_val_accuracy
            best_epoch = epoch + 1
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"[INFO] New best model saved to {checkpoint_path} (Epoch {best_epoch}, Val Acc: {best_val_accuracy:.4f})")
            except Exception as e:
                print(f"[Error] Failed to save model checkpoint: {e}", file=sys.stderr)

    print(f"\n[INFO] Training complete. Best validation accuracy: {best_val_accuracy:.4f} at epoch {best_epoch}.")

    # --- 7. Final Test Evaluation ---
    print("\n--- Final Test Set Evaluation ---")
    if checkpoint_path.exists():
        print(f"[INFO] Loading best model from: {checkpoint_path}")
        try:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        except Exception as e:
             print(f"[Error] Failed to load best model checkpoint: {e}. Evaluating with the final state.", file=sys.stderr)
    else:
        print("[Warning] Best model checkpoint not found. Evaluating with the final model state.", file=sys.stderr)

    model.eval()
    total_test_accuracy = 0.0
    total_test_samples = 0
    test_pbar = tqdm(test_loader, desc="Testing", leave=False)

    with torch.no_grad():
        for batch in test_pbar:
            src_ids = batch["input_ids"].to(device)
            src_mask = batch["attention_mask"].to(device)
            depth_ids = batch["depth_ids"].to(device)
            labels = batch["labels"].to(device)
            tgt_input = labels[:, :-1]
            tgt_output = labels[:, 1:]

            logits = model(src_ids, src_mask, depth_ids, tgt_input)

            batch_accuracy = calculate_sequence_accuracy(logits, tgt_output, pad_token_id)
            total_test_accuracy += batch_accuracy * src_ids.size(0)
            total_test_samples += src_ids.size(0)

    test_pbar.close()
    final_test_accuracy = total_test_accuracy / total_test_samples if total_test_samples > 0 else 0.0
    print(f"\nFinal Test Sequence Accuracy: {final_test_accuracy:.4f}")

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    # Optional: Seed everything
    # SEED = 42
    # torch.manual_seed(SEED); np.random.seed(SEED); import random; random.seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

ImportError: tokenizers>=0.11.1,!=0.11.3,<0.14 is required for a normal functioning of this module, but found tokenizers==0.21.1.
Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main

# 3.7: Foundation models for symbolic regression tasks
Model: Novel foundation model for symbolic regression tasks. Should be sufficiently novel beyond the current literature.


In [13]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fine-tuning Script for a T5 Model with Mixture-of-Experts (MoE) Layers
for Symbolic Regression (Squared Amplitude Calculation).

This script demonstrates modifying a standard T5 (`t5-small`) model by replacing
its Feed-Forward Networks (FFNs) with custom Mixture-of-Experts (MoE) layers.
The MoE layers potentially allow different "expert" networks within the model to
specialize in processing different types of symbolic patterns found in the input
sequences (e.g., polynomials, trigonometric functions).

The script utilizes the Hugging Face Transformers library, including the
`Seq2SeqTrainer` API for streamlined training and evaluation.

Key features demonstrated:
- T5 Base Model: Uses `t5-small` as the starting point.
- Custom MoE Layer: Defines and integrates a `MoEFeedForward` module.
- Model Patching: Dynamically replaces the standard FFNs in the loaded T5 model
  with the custom MoE layers.
- Advanced Training Techniques: Incorporates Gradient Checkpointing, Mixed
  Precision Training (FP16), and Label Smoothing.
- Standard Workflow: Follows data loading, preprocessing, tokenization, model
  configuration, training, and evaluation steps.
- Exact Match Accuracy (Logit-based): Evaluates performance using sequence
  accuracy calculated directly from model logits (`predict_with_generate=False`).

Workflow:
1. Load pre-split data from JSONL files.
2. Reconstruct source and target strings from token lists.
3. Initialize the T5 tokenizer.
4. Define the custom `MoEFeedForward` nn.Module.
5. Load the base T5 model and patch its encoder/decoder blocks by replacing
   standard FFN layers with instances of `MoEFeedForward`.
6. Configure the modified model (special tokens, gradient checkpointing).
7. Define a function to tokenize source/target string pairs.
8. Wrap tokenized data into PyTorch Dataset objects.
9. Instantiate `DataCollatorForSeq2Seq`.
10. Define the `compute_metrics` function for sequence accuracy from logits.
11. Configure `Seq2SeqTrainingArguments`.
12. Initialize and run the `Seq2SeqTrainer`.
13. Evaluate the final model on the test set.
"""

import json
import sys
import numpy as np
import torch
import datetime
import os
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    T5Config, # To inspect T5 block structure if needed
    T5Block, # To check layer types
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainingArguments # Explicit import
)
from tqdm.auto import tqdm # Optional progress bars

# --- Configuration ---

# File Paths
TRAIN_FILE = Path('qed_expressions_train.jsonl') # Adjust if needed
VAL_FILE   = Path('qed_expressions_val.jsonl')
TEST_FILE  = Path('qed_expressions_test.jsonl')
OUTPUT_DIR = Path('t5_moe_symbolic_regression_output') # Output directory

# Model Configuration
MODEL_ID = "t5-small" # Base T5 model identifier
TOKENIZER_ID = "t5-small"

# MoE Configuration
NUM_EXPERTS = 4             # Number of experts in each MoE layer
# d_ff (intermediate FFN dimension) will be taken from the base T5 config

# Tokenizer and Data Processing Configuration
MAX_SEQ_LENGTH = 128        # Max sequence length

# Training Hyperparameters
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 200
LOGGING_STEPS = 50
EVALUATION_STRATEGY = "epoch"
SAVE_STRATEGY = "epoch"

# Advanced Training Feature Flags/Values
USE_FP16 = torch.cuda.is_available()
USE_GRADIENT_CHECKPOINTING = True
LABEL_SMOOTHING_FACTOR = 0.1

# --- Helper Functions and Classes ---

def load_jsonl(file_path):
    """Loads data from a JSON Lines file, handling basic errors."""
    data = []
    file_path = Path(file_path)
    print(f"[INFO] Loading data from: {file_path}")
    if not file_path.is_file():
        print(f"[Error] Data file not found: {file_path}", file=sys.stderr)
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        with file_path.open('r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line:
                    try: data.append(json.loads(line))
                    except json.JSONDecodeError as e: print(f"[Warning] Skipping invalid JSON on line {i+1} in {file_path}: {e}")
        print(f"[INFO] Successfully loaded {len(data)} records.")
        if not data: print(f"[Warning] No valid records loaded from {file_path}.")
        return data
    except Exception as e:
        print(f"[Error] Failed to load data from {file_path}: {e}", file=sys.stderr)
        raise

def convert_tokens_to_strings(raw_data_list):
    """Converts lists of tokens from loaded data into source and target strings."""
    source_strings, target_strings, skipped_count = [], [], 0
    if not raw_data_list: return source_strings, target_strings
    for i, item in enumerate(raw_data_list):
        input_toks, target_toks = item.get('input_tokens'), item.get('target_tokens')
        if isinstance(input_toks, list) and isinstance(target_toks, list):
            source_strings.append(" ".join(map(str, input_toks)))
            target_strings.append(" ".join(map(str, target_toks)))
        else:
            print(f"[Warning] Skipping item at index {i} due to missing/invalid data: {item}"); skipped_count += 1
    print(f"[INFO] Converted {len(source_strings)} items to strings (skipped {skipped_count}).")
    if not source_strings: print("[Warning] No items successfully converted.")
    return source_strings, target_strings

def encode_sequences(tokenizer, source_texts, target_texts, max_len):
    """Tokenizes source and target sequences, returning lists of IDs/masks."""
    print(f"[INFO] Encoding sequence pairs with max_length={max_len}...")
    encoder_inputs = tokenizer(source_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    with tokenizer.as_target_tokenizer():
        decoder_labels = tokenizer(target_texts, max_length=max_len, padding='max_length', truncation=True, return_tensors=None)
    encoder_inputs['labels'] = decoder_labels['input_ids']
    print(f"[INFO] Encoding complete.")
    if not encoder_inputs.get('input_ids') or not encoder_inputs.get('labels') or len(encoder_inputs['input_ids']) != len(encoder_inputs['labels']):
         print("[Warning] Encoding resulted in empty lists or length mismatch.")
    return encoder_inputs

class SequencePairDataset(Dataset):
    """Simple PyTorch Dataset for holding tokenized sequence pair data (as lists)."""
    def __init__(self, encodings):
        if not isinstance(encodings, dict) or 'input_ids' not in encodings: raise ValueError("Invalid encodings format.")
        self.encodings = encodings
        try:
            self.length = len(encodings['input_ids'])
            for key in encodings:
                 if not isinstance(encodings[key], list) or len(encodings[key]) != self.length: raise ValueError(f"Inconsistent length for key '{key}'.")
        except Exception as e: raise ValueError(f"Validation failed: {e}")
        if self.length == 0: raise ValueError("Input encodings are empty.")
        print(f"[INFO] Created Dataset with {self.length} examples.")

    def __len__(self): return self.length
    def __getitem__(self, idx):
        if not 0 <= idx < self.length: raise IndexError(f"Index {idx} out of bounds.")
        try: return {key: self.encodings[key][idx] for key in self.encodings}
        except Exception as e: print(f"[Error] Failed retrieval at index {idx}: {e}", file=sys.stderr); raise

# --- Custom MoE Layer ---

class MoEFeedForward(nn.Module):
    """
    Mixture of Experts Feed-Forward Network Layer.

    Routes input tokens to different 'expert' FFNs based on a gating mechanism.
    The gating mechanism here uses the mean-pooled representation of the sequence
    to decide the weights for combining expert outputs.
    """
    def __init__(self, d_model, d_ff, n_experts=4, dropout_rate=0.1):
        """
        Initializes the MoE layer.

        Args:
            d_model (int): Input and output dimension of the layer.
            d_ff (int): Intermediate dimension of the expert FFNs.
            n_experts (int): Number of expert networks.
            dropout_rate (float): Dropout rate for expert FFNs.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_experts = n_experts

        # Define the expert networks (simple two-layer MLPs)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(), # Or GeLU, SiLU etc. matching base model if known
                nn.Dropout(dropout_rate), # Add dropout within experts
                nn.Linear(d_ff, d_model)
            ) for _ in range(n_experts)
        ])

        # Gating network: maps pooled sequence representation to expert weights
        # Takes mean across sequence length dimension as input
        self.gate = nn.Linear(d_model, n_experts)

    def forward(self, hidden_states):
        """
        Forward pass through the MoE layer.

        Args:
            hidden_states (torch.Tensor): Input tensor (batch_size, seq_len, d_model).

        Returns:
            torch.Tensor: Output tensor (batch_size, seq_len, d_model).
        """
        # hidden_states: (B, S, D)
        batch_size, seq_len, d_model = hidden_states.shape

        # 1. Gating: Use mean pooling for simplicity to get sequence-level weights
        # Alternative: Could implement token-level gating for finer control.
        pooled_hidden_states = hidden_states.mean(dim=1) # (B, D)

        # Calculate gating scores (logits) and normalize to probabilities (weights)
        gating_logits = self.gate(pooled_hidden_states) # (B, N_experts)
        gating_weights = torch.softmax(gating_logits, dim=-1) # (B, N_experts)

        # 2. Expert Evaluation
        # Pass the full sequence through each expert
        expert_outputs = []
        for expert in self.experts:
            expert_outputs.append(expert(hidden_states))
        # Stack expert outputs: List[ (B, S, D) ] -> Tensor(N_experts, B, S, D)
        expert_outputs_stacked = torch.stack(expert_outputs, dim=0)

        # 3. Weighted Combination
        # Reshape weights for broadcasting: (B, N_experts) -> (B, N_experts, 1, 1)
        gating_weights_reshaped = gating_weights.unsqueeze(-1).unsqueeze(-1)
        # Reshape expert outputs for broadcasting: (N_experts, B, S, D) -> (B, N_experts, S, D)
        expert_outputs_permuted = expert_outputs_stacked.permute(1, 0, 2, 3)

        # Weighted sum: (B, N_experts, 1, 1) * (B, N_experts, S, D) -> sum over N_experts axis
        # Result: (B, S, D)
        output = torch.sum(gating_weights_reshaped * expert_outputs_permuted, dim=1)

        return output


def replace_ffn_with_moe(model, n_experts):
    """
    Replaces the standard FFN layers in a T5 model with MoEFeedForward layers.

    Iterates through encoder and decoder blocks and replaces the appropriate layer.
    Requires knowledge of the T5Block structure.

    Args:
        model: The T5ForConditionalGeneration model instance.
        n_experts (int): Number of experts for the MoE layers.

    Returns:
        The modified model instance.
    """
    print(f"[INFO] Patching T5 model with MoE layers (Experts: {n_experts})...")
    replaced_count = 0
    try:
        # Encoder blocks
        if hasattr(model, 'encoder') and hasattr(model.encoder, 'block'):
            for i, block in enumerate(model.encoder.block):
                # T5Block structure: [LayerNorm, SelfAttention, Dropout], [LayerNorm, DenseReluDense, Dropout]
                # The FFN (`DenseReluDense`) is typically the second layer in the second sub-block group.
                # Accessing via attribute name is safer if known, e.g., block.layer[1].DenseReluDense
                # Let's try accessing the assumed attribute location and verify type.
                ffn_layer_attr = None
                if hasattr(block, 'layer') and len(block.layer) > 1 and \
                   hasattr(block.layer[1], 'DenseReluDense') and \
                   isinstance(block.layer[1].DenseReluDense, nn.Module): # Check if it looks like the FFN component
                    ffn_layer_attr = block.layer[1].DenseReluDense
                    layer_target = block.layer[1]
                    attr_name = 'DenseReluDense'

                if ffn_layer_attr is not None:
                    print(f"  - Replacing Encoder Block {i} FFN...")
                    # Extract config needed for MoE layer
                    d_model = model.config.d_model
                    d_ff = model.config.d_ff
                    dropout_rate = model.config.dropout_rate # Get dropout from config

                    # Create and assign the new MoE layer
                    setattr(layer_target, attr_name, MoEFeedForward(d_model, d_ff, n_experts, dropout_rate))
                    replaced_count += 1
                else:
                     print(f"  - [Warning] Could not find/replace FFN in Encoder Block {i} structure.")

        # Decoder blocks
        if hasattr(model, 'decoder') and hasattr(model.decoder, 'block'):
            for i, block in enumerate(model.decoder.block):
                # T5DecoderBlock structure: [LN, SelfAttn, Drop], [LN, CrossAttn, Drop], [LN, FFN(DenseReluDense), Drop]
                # The FFN is typically the second layer in the *third* sub-block group.
                ffn_layer_attr = None
                if hasattr(block, 'layer') and len(block.layer) > 2 and \
                   hasattr(block.layer[2], 'DenseReluDense') and \
                   isinstance(block.layer[2].DenseReluDense, nn.Module):
                    ffn_layer_attr = block.layer[2].DenseReluDense
                    layer_target = block.layer[2]
                    attr_name = 'DenseReluDense'

                if ffn_layer_attr is not None:
                    print(f"  - Replacing Decoder Block {i} FFN...")
                    d_model = model.config.d_model
                    d_ff = model.config.d_ff
                    dropout_rate = model.config.dropout_rate

                    setattr(layer_target, attr_name, MoEFeedForward(d_model, d_ff, n_experts, dropout_rate))
                    replaced_count += 1
                else:
                    print(f"  - [Warning] Could not find/replace FFN in Decoder Block {i} structure.")

        if replaced_count == 0:
             print("[Warning] No FFN layers were replaced. Check model structure and replacement logic.")
        else:
             print(f"[INFO] Successfully replaced {replaced_count} FFN layers with MoE.")

    except Exception as e:
        print(f"[Error] Failed during model patching: {e}", file=sys.stderr)
        # Depending on severity, may want to raise or exit
    return model


# --- Main Execution Logic ---

def main():
    """Orchestrates data loading, MoE model patching, training, and evaluation."""
    print("[INFO] Starting T5 with Mixture-of-Experts Fine-tuning Script...")
    # --- ***** Update Timestamp/Location ***** ---
    try:
        current_time = datetime.datetime.now(datetime.timezone.utc)
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S %Z')
    except Exception:
        current_time = datetime.datetime.utcnow()
        current_time_str = current_time.strftime('%Y-%m-%d %H:%M:%S UTC (naive)')
    print(f"[INFO] Current time: {current_time_str}")
    print(f"[INFO] Location context: San Diego, CA, USA")
    print(f"[INFO] Using Base Model: {MODEL_ID}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- 1. Load Data ---
    try:
        train_raw = load_jsonl(TRAIN_FILE)
        val_raw = load_jsonl(VAL_FILE)
        test_raw = load_jsonl(TEST_FILE)
    except Exception as e: print(f"[FATAL] Data loading failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_raw or not val_raw or not test_raw: print("[FATAL] Datasets empty after loading.", file=sys.stderr); sys.exit(1)

    # --- 2. Prepare Data Strings ---
    train_src, train_tgt = convert_tokens_to_strings(train_raw)
    val_src,   val_tgt   = convert_tokens_to_strings(val_raw)
    test_src,  test_tgt  = convert_tokens_to_strings(test_raw)
    if not train_src or not val_src or not test_src: print("[FATAL] Data conversion failed.", file=sys.stderr); sys.exit(1)

    # --- 3. Initialize Tokenizer ---
    global tokenizer_for_metrics # For metrics function access
    try:
        tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_ID)
        # T5 generally uses <pad>, </s>, <unk> - check if needed
        if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '<pad>'})
        # T5 uses EOS as decoder start token ID by default, check if BOS needed
        # if tokenizer.bos_token is None: tokenizer.add_special_tokens({'bos_token': '<s>'})
        tokenizer_for_metrics = tokenizer
        pad_token_id = tokenizer.pad_token_id
    except Exception as e: print(f"[FATAL] Tokenizer init failed: {e}", file=sys.stderr); sys.exit(1)
    print(f"[INFO] Tokenizer initialized (Vocab: {tokenizer.vocab_size}, Pad ID: {pad_token_id}).")

    # --- 4. Load Base Model and Patch with MoE ---
    print(f"[INFO] Loading base model: {MODEL_ID}")
    try:
        model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)

        # Patch the loaded model
        model = replace_ffn_with_moe(model, n_experts=NUM_EXPERTS)

        # Resize embeddings if vocab changed
        if model.config.vocab_size != len(tokenizer):
             print(f"[INFO] Resizing model embeddings to {len(tokenizer)}")
             model.resize_token_embeddings(len(tokenizer))
             model.config.vocab_size = len(tokenizer)

        # Configure standard seq2seq settings (T5 specific)
        model.config.decoder_start_token_id = tokenizer.pad_token_id # T5 starts decoder with pad
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

        # Configure generation defaults
        model.config.max_length = MAX_SEQ_LENGTH
        model.config.num_beams = 1 # Default to greedy for logits; beams set in TrainingArgs if needed

        # Enable Gradient Checkpointing if configured
        if USE_GRADIENT_CHECKPOINTING:
            try:
                 model.gradient_checkpointing_enable()
                 print("[INFO] Gradient Checkpointing enabled on the model.")
            except Exception as gc_e:
                 print(f"[Warning] Failed to enable gradient checkpointing on model: {gc_e}")
                 global USE_GRADIENT_CHECKPOINTING_EFFECTIVE
                 USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False
        else:
             USE_GRADIENT_CHECKPOINTING_EFFECTIVE = False


    except Exception as e:
        print(f"[FATAL] Failed to initialize or patch the T5 model: {e}", file=sys.stderr)
        sys.exit(1)


    # --- 5. Tokenize Data ---
    try:
        train_encodings = encode_sequences(tokenizer, train_src, train_tgt, MAX_SEQ_LENGTH)
        val_encodings   = encode_sequences(tokenizer, val_src,   val_tgt,   MAX_SEQ_LENGTH)
        test_encodings  = encode_sequences(tokenizer, test_src,  test_tgt,  MAX_SEQ_LENGTH)
    except Exception as e: print(f"[FATAL] Data tokenization failed: {e}", file=sys.stderr); sys.exit(1)
    if not train_encodings.get('input_ids') or not val_encodings.get('input_ids') or not test_encodings.get('input_ids'):
        print("[FATAL] Tokenization resulted in empty encodings.", file=sys.stderr); sys.exit(1)

    # --- 6. Create Datasets ---
    try:
        train_dataset = SequencePairDataset(train_encodings)
        val_dataset   = SequencePairDataset(val_encodings)
        test_dataset  = SequencePairDataset(test_encodings)
    except ValueError as e: print(f"[FATAL] Dataset creation failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 7. Initialize Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Ignore pad tokens in loss
        pad_to_multiple_of=8 if USE_FP16 else None
    )
    print("[INFO] Data collator initialized.")

    # --- 8. Define Metrics Computation (from Logits) ---
    def compute_metrics_fn(eval_pred):
        """Calculates exact sequence match accuracy from logits."""
        logits, labels = eval_pred # Trainer passes logits when predict_with_generate=False
        if not isinstance(logits, np.ndarray): logits = np.array(logits)
        if not isinstance(labels, np.ndarray): labels = np.array(labels)

        # Replace -100 in labels for accuracy calculation comparison if needed,
        # but accuracy function should handle pad_token_id masking.
        # labels_for_acc = np.where(labels != -100, labels, tokenizer_for_metrics.pad_token_id)

        try:
            # Use PyTorch tensor version for consistency and potential device placement
            logits_torch = torch.from_numpy(logits)
            labels_torch = torch.from_numpy(labels)
            # Pass the original labels (with -100 potentially) and pad_token_id
            accuracy = calculate_sequence_accuracy_torch(logits_torch, labels_torch, tokenizer_for_metrics.pad_token_id)
            return {"sequence_accuracy": accuracy}
        except Exception as met_e:
             print(f"[Error] compute_metrics failed: {met_e}", file=sys.stderr)
             return {"sequence_accuracy": 0.0}

    def calculate_sequence_accuracy_torch(logits, labels, pad_token_id):
        """Calculates exact sequence match accuracy using PyTorch tensors."""
        # Ensure inputs are tensors
        if not isinstance(logits, torch.Tensor): logits = torch.tensor(logits)
        if not isinstance(labels, torch.Tensor): labels = torch.tensor(labels)

        if logits.shape[0] == 0: return 0.0
        predictions = torch.argmax(logits, dim=-1) # (Batch, SeqLen)

        # Create mask for non-padding tokens in labels (use pad_token_id directly)
        # Labels might contain -100, treat them as padding for accuracy check
        non_pad_mask = (labels != pad_token_id) & (labels != -100)

        correct_tokens = (predictions == labels) & non_pad_mask
        correct_sequences = (torch.sum(correct_tokens, dim=1) == torch.sum(non_pad_mask, dim=1))
        accuracy = torch.mean(correct_sequences.float())
        return accuracy.item()


    # --- 9. Define Training Arguments ---
    effective_gc = USE_GRADIENT_CHECKPOINTING if 'USE_GRADIENT_CHECKPOINTING_EFFECTIVE' not in globals() else USE_GRADIENT_CHECKPOINTING_EFFECTIVE
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        # Schedule
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        # Logging / Saving / Evaluation
        logging_dir=str(OUTPUT_DIR / 'logs'),
        logging_steps=LOGGING_STEPS,
        evaluation_strategy=EVALUATION_STRATEGY,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="sequence_accuracy",
        greater_is_better=True,
        # Use logits for evaluation (faster if only accuracy needed)
        predict_with_generate=False,
        # Advanced features
        fp16=USE_FP16,
        label_smoothing_factor=LABEL_SMOOTHING_FACTOR,
        gradient_checkpointing=effective_gc,
        # report_to="tensorboard", # Optional
    )
    print("[INFO] Training arguments defined.")
    print(f"[INFO] Effective Mixed Precision (FP16): {'Enabled' if USE_FP16 else 'Disabled'}")
    print(f"[INFO] Effective Label Smoothing Factor: {LABEL_SMOOTHING_FACTOR}")
    print(f"[INFO] Effective Gradient Checkpointing in Args: {training_args.gradient_checkpointing}")

    # --- 10. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
    )
    print("[INFO] Seq2SeqTrainer initialized.")

    # --- 11. Train the Model ---
    print(f"[INFO] Starting model training ({NUM_TRAIN_EPOCHS} epochs)...")
    try:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics); trainer.save_metrics("train", metrics)
        trainer.save_state()
        print("[INFO] Training finished successfully.")
        if trainer.state.best_model_checkpoint: print(f"[INFO] Best model checkpoint: {trainer.state.best_model_checkpoint}")
    except Exception as e: print(f"[FATAL] Training failed: {e}", file=sys.stderr); sys.exit(1)

    # --- 12. Evaluate on Test Set ---
    print("[INFO] Evaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_results); trainer.save_metrics("test", test_results)
        if 'test_sequence_accuracy' in test_results:
             print(f"\n--- Test Set Results ---"); print(f"Test Sequence Accuracy: {test_results['test_sequence_accuracy']:.4f}"); print(f"-----------------------------")
        else: print("[Warning] Test accuracy metric not found.", "Results:", test_results)
    except Exception as e: print(f"[Error] Final evaluation failed: {e}", file=sys.stderr); sys.exit(1)

    print("\n[INFO] Script finished successfully.")

if __name__ == "__main__":
    # Optional: Seeds for reproducibility
    # SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
    # if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
    main()

ImportError: cannot import name 'T5Block' from 'transformers' (/home/nikitas/anaconda3/envs/torch_cpu_env/lib/python3.10/site-packages/transformers/__init__.py)