In [1]:
import random
import pandas as pd
import os, sys, argparse, json
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from datasets import load_dataset, Dataset
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import numpy as np
import wandb
from omegaconf import DictConfig, OmegaConf

In [None]:
def load_config(config_path: Optional[str] = None) -> DictConfig:
    """Load config from YAML file or use defaults."""
    
    # Default configuration
    default_config = OmegaConf.create({
        "TUNE_HYPERPARAMS": True,
        "INITIAL_WEIGHTS": None,
        "DATASET": "rvi_mdb", 
        "BASE_PATH": ".", # Set to directory containing datasets etc
        "EXPERIMENT_NAME": "tuned",
        "MAX_TRACE_LENGTH": 1000,
        "MAX_TRACE_EXPANSION": 5,
        "FIXED_HYPERPARAMS": {
            "num_train_steps": 1000,
            "group_by_length": False,
            "eval_steps": 10,
            "logging_steps": 10,
        },
        "DEFAULT_TUNABLE_HYPERPARAMS": {
            "learning_rate": 1e-3,
            "per_device_train_batch_size": 32,
            "weight_decay": 0.005,
            "warmup_steps": 100,
            "gradient_accumulation_steps": 1,
            "freeze_feature_encoder": True,
        },
        "SAMPLING_RATE": 200e6,
    })
    
    # Check command line if no path provided
    if config_path is None:
        config_path = OmegaConf.from_cli().get("config")
    
    # Load and merge YAML config
    if config_path and os.path.exists(config_path):
        yaml_config = OmegaConf.load(config_path)
        config = OmegaConf.merge(default_config, yaml_config)
        print(f"Loaded config: {config_path}")
    else:
        config = default_config
        print("Using default config")
    return config

# Simple one-liner setup
def setup_config(config_path: Optional[str] = None):
    """Load config and inject into calling module's globals."""
    config = load_config(config_path)
    sys._getframe(1).f_globals.update(config)

setup_config()

DATASET_PATH = os.path.join(BASE_PATH, f"datasets/{DATASET}")
MODEL_PATH = os.path.join(BASE_PATH, "runs")
EXPERIMENT_IDENTIFIER = f"{EXPERIMENT_NAME}-{DATASET}"
TARGET_PATH = os.path.join(MODEL_PATH,EXPERIMENT_IDENTIFIER)

Using default config


In [3]:
print(f"Hyperparameter tuning: {'ENABLED' if TUNE_HYPERPARAMS else 'DISABLED'}")
if INITIAL_WEIGHTS:
    print(f"Using initial weights from: {INITIAL_WEIGHTS}")
else:
    print("Using default pretrained weights")

# [Your existing vocabulary and data setup code remains the same]
# ... (vocabulary creation, tokenizer setup, dataset loading, etc.)

# Import hyperparameter tuning libraries only if needed
if TUNE_HYPERPARAMS:
    try:
        import optuna
        print("Optuna imported successfully for hyperparameter tuning")
    except ImportError:
        print("Warning: Optuna not installed. Install with: pip install optuna")
        TUNE_HYPERPARAMS = False

Hyperparameter tuning: ENABLED
Using default pretrained weights
Optuna imported successfully for hyperparameter tuning


## Set up key objects

In [None]:
# Set up vocabulary
path_vocab = os.path.join(DATASET_PATH, "tokenizer.vocab")
with open(path_vocab, "r") as f:
    data = f.readlines()
vocab_dict = {key.strip(): val for val, key in enumerate(data)}
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
vocab_dict[" "] = len(vocab_dict) # Delimiter, is it needed?

path_vocab = os.path.join(DATASET_PATH,'vocab.json')
with open(path_vocab, 'w') as f:
    json.dump(vocab_dict, f)

In [None]:
# Set up key model parts

tokenizer = Wav2Vec2CTCTokenizer(path_vocab, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token=" ")

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=SAMPLING_RATE, padding_value=0.0, do_normalize=True, return_attention_mask=False)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

print(f"Length of tokenizer: {len(tokenizer)}")

In [None]:
def expand_dataset_with_repetitions(dataset, shuffle=True):
    """
    Expand each matrix entry into N separate dataset entries, each with the same labels
    """
    expanded_data = {
        "input_values": [],
        "input_length": [],
        "labels": [],
        "program_label": []  # Keep original text labels for reference
    }
    
    for example in dataset:
        # Load numpy trace data
        traces = np.load(example["traces"])

        if traces.shape[1] > MAX_TRACE_LENGTH:
            continue
        
        # Handle both single traces and matrices
        if traces.ndim == 1:
            # Single trace - keep as one entry
            traces = [traces]
        else:
            # Matrix - each row becomes a separate entry
            traces = [traces[j] for j in range(traces.shape[0]) if j < MAX_TRACE_EXPANSION]
        
        # Process labels once for this example
        processed_labels = processor(text=example["program_label"]).input_ids
        
        # Create separate dataset entries for each trace repetition
        for trace in traces:
            # Process the trace
            processed_trace = processor(trace, sampling_rate=SAMPLING_RATE).input_values[0]
            
            # Add to expanded dataset
            expanded_data["input_values"].append(processed_trace)
            expanded_data["input_length"].append(len(processed_trace))
            expanded_data["labels"].append(processed_labels)
            expanded_data["program_label"].append(example["program_label"])
        
    if shuffle:
        # Convert to list of indices and shuffle
        indices = list(range(len(expanded_data["input_values"])))
        random.shuffle(indices)
        
        # Reorder all lists according to shuffled indices
        for key in expanded_data:
            expanded_data[key] = [expanded_data[key][i] for i in indices]
    
    return expanded_data

In [None]:
# Set up data

splits = ["train", "val", "test"]
data_paths = [os.path.join(DATASET_PATH, f"{split}.json") for split in splits]
data = load_dataset("json", data_files={key:val for key, val in zip(["train","validation","test"],data_paths)}, field="batches")

print("Expanding datasets with repetitions...")

# Expand each split
expanded_datasets = {}
for split in ["train", "validation", "test"]:
    if split in data:
        print(f"Expanding {split} split...")
        expanded_data = expand_dataset_with_repetitions(data[split])
        expanded_datasets[split] = Dataset.from_dict(expanded_data)
        print(f"  {split}: {len(data[split])} -> {len(expanded_datasets[split])} examples")

# Create the final dataset
from datasets import DatasetDict
dataset = DatasetDict(expanded_datasets)

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [None]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
def print_debug(pred_str, pred_logits, pred_ids, pred, label_str):
    # === DEBUGGING SECTION ===
    print("\n=== COMPUTE_METRICS DEBUG ===")
    print(f"Batch size: {len(pred_str)}")
    print(f"Pred logits shape: {pred_logits.shape}")
    print(f"Pred ids shape: {pred_ids.shape}")
    print(f"Label ids shape: {pred.label_ids.shape}")
    
    import random
    sample_indices = random.sample(range(len(pred_str)), 3)

    # Check raw predictions and labels
    print(f"\nRaw pred_ids (first 3 samples, first 10 tokens):")
    for i in sample_indices:
        print(f"  Sample {i}: {pred_ids[i][:10]}")
    
    print(f"\nRaw label_ids (first 3 samples, first 10 tokens):")
    for i in sample_indices:
        print(f"  Sample {i}: {pred.label_ids[i][:10]}")
    
    # Check decoded strings
    print(f"\nDecoded predictions (first 3):")
    for i in sample_indices:
        print(f"  Pred {i}: '{pred_str[i]}'")
    
    print(f"\nDecoded labels (first 3):")
    for i in sample_indices:
        print(f"  Label {i}: '{label_str[i]}'")
    
    # Check vocabulary usage
    flat_pred_ids = pred_ids.flatten()
    flat_label_ids = pred.label_ids.flatten()
    flat_label_ids = flat_label_ids[flat_label_ids != -100]  # Remove padding
    
    print(f"\nVocabulary analysis:")
    print(f"  Processor vocab size: {processor.tokenizer.vocab_size}")
    print(f"  Unique pred tokens used: {len(np.unique(flat_pred_ids))}")
    print(f"  Pred token range: {flat_pred_ids.min()}-{flat_pred_ids.max()}")
    print(f"  Unique label tokens: {len(np.unique(flat_label_ids))}")
    print(f"  Label token range: {flat_label_ids.min()}-{flat_label_ids.max()}")
    
    # Check for common issues
    if len(set(pred_str)) == 1:
        print(f"  WARNING: All predictions are identical: '{pred_str[0]}'")
    
    if len(set(label_str)) <= 5:
        print(f"  INFO: Only {len(set(label_str))} unique labels in batch: {set(label_str)}")
    
    # Check for empty predictions/labels
    empty_preds = sum(1 for p in pred_str if not p.strip())
    empty_labels = sum(1 for l in label_str if not l.strip())
    print(f"  Empty predictions: {empty_preds}/{len(pred_str)}")
    print(f"  Empty labels: {empty_labels}/{len(label_str)}")
    
    # Sample token-by-token comparison
    print(f"\nSample comparisons (first 3):")
    for i in range(min(3, len(pred_str))):
        print(f"  Sample {i}:")
        print(f"    Pred: '{pred_str[i]}'")
        print(f"    Label: '{label_str[i]}'")
        print(f"    Match: {pred_str[i] == label_str[i]}")
    
    print("=" * 40)

def compute_metrics(pred, verbose=True):
    import evaluate

    # Load the WER metric
    wer_metric = evaluate.load("wer")

    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Create a copy to avoid modifying original
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False, spaces_between_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    if verbose:
        print_debug(pred_str, pred_logits, pred_ids, pred, label_str)

    return {"wer": wer}

## Set up training and tuning functions

In [None]:
def get_model_path():
    """Get the model path to use for initialization"""
    if INITIAL_WEIGHTS and os.path.exists(INITIAL_WEIGHTS):
        return INITIAL_WEIGHTS
    else:
        default_path = os.path.join(MODEL_PATH, f"wav2vec2-pretrain-{DATASET}")
        if not os.path.exists(default_path) and INITIAL_WEIGHTS:
            print(f"Warning: Specified initial weights path {INITIAL_WEIGHTS} not found, using default")
        return default_path

def create_model(hyperparams):
    """Create and configure model based on hyperparameters"""
    from transformers import Wav2Vec2ForCTC
    
    model_path = get_model_path()
    print(f"Loading model from: {model_path}")
    
    model = Wav2Vec2ForCTC.from_pretrained(
        model_path,
        ctc_loss_reduction="mean", 
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    
    if hyperparams.get("freeze_feature_encoder", True):
        model.freeze_feature_encoder()
        print("Feature encoder frozen")
    else:
        print("Feature encoder unfrozen")
    
    return model

def create_training_args(hyperparams):
    """Create training arguments from hyperparameters"""
    from transformers import TrainingArguments
    
    # Combine tunable and fixed hyperparameters
    all_hyperparams = {**FIXED_HYPERPARAMS, **hyperparams}
    
    return TrainingArguments(
        output_dir=TARGET_PATH,
        learning_rate=all_hyperparams["learning_rate"],
        per_device_train_batch_size=all_hyperparams["per_device_train_batch_size"],
        weight_decay=all_hyperparams["weight_decay"],
        warmup_steps=all_hyperparams["warmup_steps"],
        gradient_accumulation_steps=all_hyperparams["gradient_accumulation_steps"],
        max_steps=all_hyperparams["num_train_steps"],
        group_by_length=all_hyperparams["group_by_length"],
        eval_strategy="steps",
        eval_steps=all_hyperparams["eval_steps"],
        logging_steps=all_hyperparams["logging_steps"],
        fp16=True,
        gradient_checkpointing=True,
        save_steps=all_hyperparams["eval_steps"] if not TUNE_HYPERPARAMS else 10e6, # De-Facto - don't save
        save_total_limit=1 if not TUNE_HYPERPARAMS else 0,
        load_best_model_at_end=not TUNE_HYPERPARAMS,
        metric_for_best_model="eval_wer",
        greater_is_better=False,
        report_to=["wandb"],
    )


def train_model(hyperparams, verbose=True):
    """Train model with given hyperparameters"""
    from transformers import Trainer
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    run = wandb.init(
        project=EXPERIMENT_IDENTIFIER,
        name=timestamp,
    )

    try:
        # Create model and training args
        model = create_model(hyperparams)
        training_args = create_training_args(hyperparams)

        # Select dataset (subset for tuning, full for normal training)
        train_data = dataset["train"]
        eval_data = dataset["validation"]

        # Create trainer
        trainer = Trainer(
            model=model,
            data_collator=data_collator,
            args=training_args,
            compute_metrics=lambda pred: compute_metrics(pred, verbose=verbose),
            train_dataset=train_data,
            eval_dataset=eval_data,
            processing_class=processor.feature_extractor,
        )

        # Train
        trainer.train()

        # Evaluate BEFORE finishing wandb session
        eval_result = trainer.evaluate()

        result = {
            "wer": eval_result["eval_wer"],
            "eval_loss": eval_result["eval_loss"],
            "trainer": trainer,
            "hyperparams": hyperparams
        }

        if not TUNE_HYPERPARAMS: # avoid multiprocessing issue for optuna
            result["model"] = model

        return result

    finally:
        # Always finish wandb session, even if there's an exception
        run.finish()

def hyperparameter_tuning():
    """Perform hyperparameter tuning using Optuna"""
    print("Starting hyperparameter tuning...")
    print(f"Fixed parameters: {FIXED_HYPERPARAMS}")

    def objective(trial):
        # Clear any CUDA cache at start of each trial
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
        # Define search space (only tunable parameters)
        hyperparams = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
            "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [32, 64, 128, 256]),
            "weight_decay": trial.suggest_float("weight_decay", 0.001, 0.1),
            "warmup_steps": trial.suggest_int("warmup_steps", 50, 500),
            "gradient_accumulation_steps": trial.suggest_categorical("gradient_accumulation_steps", [1, 2]),
            "freeze_feature_encoder": trial.suggest_categorical("freeze_feature_encoder", [True, False]),
        }

        print(f"Trial {trial.number}: {hyperparams}")

        # Train with subset
        result = train_model(
            hyperparams, 
            verbose=False,
        )

        # Return eval_loss instead of WER
        print(f"Trial {trial.number} eval_loss: {result['eval_loss']:.4f}, WER: {result['wer']:.4f}")
        return result["eval_loss"]

    # Create study - direction should be "minimize" for loss (lower is better)
    
    db_filename = "optuna.db"
    db_path = os.path.join(TARGET_PATH, db_filename)
    os.makedirs(TARGET_PATH, exist_ok=True)

    study = optuna.create_study(
        direction="minimize",
        study_name="wav2vec2_tuning_loss",
        storage=f"sqlite:///{db_path}",
        load_if_exists=True
    )

    # Optimize
    study.optimize(objective, n_trials=200, timeout=3600*12)  # 4 hours max

    print("Hyperparameter tuning completed!")
    print(f"Best eval_loss: {study.best_value:.4f}")  # Changed message
    print("Best parameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    print("Fixed parameters (not tuned):")
    for key, value in FIXED_HYPERPARAMS.items():
        print(f"  {key}: {value}")

    # Save best parameters (only tunable ones)
    best_hyperparams = study.best_params.copy()

    return best_hyperparams

IndentationError: expected an indented block after 'else' statement on line 75 (1286412964.py, line 78)

In [None]:
if TUNE_HYPERPARAMS:
    # Run hyperparameter tuning
    best_hyperparams = hyperparameter_tuning()
    
    print(f"Finished Optuna run.\n")
else:
    # Normal training mode
    print("Running normal training...")
    
    print("Using default hyperparameters")
    hyperparams = DEFAULT_TUNABLE_HYPERPARAMS
    
    # Train model
    result = train_model(hyperparams, verbose=True)
    
    # Save model
    result["model"].save_pretrained(TARGET_PATH)
    processor.save_pretrained(TARGET_PATH)
    print(f"Model saved to {TARGET_PATH}")
    print(f"Final WER: {result['wer']:.4f}")
    print("Training completed!")