## Model Training Pipeline: ESM2 with LoRA Fine-tuning

This notebook handles the complete model training pipeline with flexible experimentation support:

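1. **Environment Setup** - Check the GPU and load the base configuration
2. **Training Configuration** - Override hyperparameters for experimentation
3. **Data Loading** - Load the tokenized datasets
4. **Model Initialization** - Load the ESM2 model and attach LoRA adapters
5. **W&B Setup and Training Utilities** - Run-ID handling, checkpoint lookup, logging callback
6. **Detailed Metrics Setup** - Accuracy, macro/micro F1, and per-class F1
7. **Trainer Setup** - `TrainingArguments`, checkpointing, and best-model selection
8. **Training** - Train with automatic checkpointing and W&B logging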

**Quick Start:**
- First time: Run all cells to train for configured epochs
- Resume training: Set `RESUME_TRAINING = True` to continue from last checkpoint
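- New experiment: Set `RESUME_TRAINING = False` to start fresh training, and set `EXPERIMENT_NAME` so checkpoints are written to a separate directory instead of the default one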
- Change hyperparameters: Modify the "Training Configuration" cell

### Import Libraries

In [1]:
import yaml
import torch
import wandb
import pandas as pd
from pathlib import Path
from typing import Dict, Optional

from transformers import (
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model
from datasets import load_from_disk
from sklearn.metrics import precision_recall_fscore_support

# Enable cuDNN benchmark for performance
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

## 1. Environment Setup

In [2]:
# Check GPU availability
print("=" * 70)
print("ENVIRONMENT CHECK")
print("=" * 70)

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.cuda
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"GPU detected: {device_name}")
    print(f"  CUDA version: {cuda_version}")
    print(f"  Available memory: {total_memory:.2f} GB")
    print(f"  Number of GPUs: {torch.cuda.device_count()}")
else:
    print("WARNING: No GPU detected!")
    print("  Training will be extremely slow on CPU.")

print(f"\nPyTorch version: {torch.__version__}")
print("=" * 70)

ENVIRONMENT CHECK
GPU detected: NVIDIA GeForce RTX 5060
  CUDA version: 12.9
  Available memory: 8.55 GB
  Number of GPUs: 1

PyTorch version: 2.8.0+cu129


In [3]:
# Load base configuration from config.yaml
print("\n=== LOADING BASE CONFIGURATION ===")

with open("../config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print("\nBase configuration loaded from config.yaml")
print(f"  Model: {cfg['model']['name']}")
print("\nTip: Modify the next cell to override hyperparameters for experimentation")


=== LOADING BASE CONFIGURATION ===

Base configuration loaded from config.yaml
  Model: facebook/esm2_t12_35M_UR50D

Tip: Modify the next cell to override hyperparameters for experimentation


## 2. Training Configuration

**Modify this cell for experimentation**

### Resume vs New Training
- `RESUME_TRAINING = True` - Continue from last checkpoint with same hyperparameters
- `RESUME_TRAINING = False` - Start fresh training (creates new W&B run)

### Hyperparameter Overrides
Set any override to `None` to keep the value from `config.yaml`.
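The repeated `if ... is not None` blocks in the cell below follow one pattern: a `None` override keeps the config value, anything else replaces it. A minimal sketch of the same idea as a table of dotted config paths, assuming the nested `cfg` layout used in this notebook (`training.epochs`, `model.lora.r`, ...). `apply_overrides` is a hypothetical helper, not part of this notebook:

```python
from typing import Any, Dict, Optional, Tuple


def apply_overrides(
    cfg: Dict[str, Any], overrides: Dict[str, Optional[Any]]
) -> Dict[str, Tuple[Any, str]]:
    """Write each non-None override into the nested cfg (in place).

    Keys are dotted paths such as "model.lora.r". Returns
    {path: (effective_value, "overridden" | "from config")}.
    A missing path raises KeyError instead of silently adding a new key.
    """
    report = {}
    for path, value in overrides.items():
        *parents, leaf = path.split(".")
        node = cfg
        for key in parents:
            node = node[key]
        current = node[leaf]  # KeyError if config.yaml has no such entry
        if value is None:
            report[path] = (current, "from config")
        else:
            node[leaf] = value
            report[path] = (value, "overridden")
    return report


# Hypothetical config shaped like the values this notebook prints
cfg = {
    # PyYAML loads 2e-4 (no decimal point) as a str; the notebook casts with float()
    "training": {"epochs": 3, "learning_rate": "2e-4", "batch_size": 8},
    "model": {"lora": {"r": 8, "alpha": 16}},
}
report = apply_overrides(cfg, {
    "training.epochs": 5,         # override
    "training.batch_size": None,  # keep the config.yaml value
    "model.lora.r": 16,           # override
})
for path, (value, source) in report.items():
    print(f"  {path}: {value} ({source})")
```

Failing loudly on unknown paths is a deliberate choice here: a typo in an override key would otherwise create a new, unused config entry and the experiment would silently run with the old value.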

In [4]:
# ============================================================================
# EXPERIMENT CONFIGURATION - MODIFY THIS FOR YOUR EXPERIMENTS
# ============================================================================

# Resume from checkpoint or start fresh?
RESUME_TRAINING = False  # Set to True to continue training from last checkpoint

# Experiment name (used for W&B run name and checkpoint directory)
EXPERIMENT_NAME = None  # e.g., "high_lr_experiment" or None to use config

# Training hyperparameters (None = use config.yaml value)
EPOCHS = None           # e.g., 5 for 5 epochs total
LEARNING_RATE = None    # e.g., 5e-4 for higher learning rate
BATCH_SIZE = None       # e.g., 8 for smaller batch
GRAD_ACCUM_STEPS = None # e.g., 4 for gradient accumulation

# LoRA hyperparameters (None = use config.yaml value)
LORA_R = None          # e.g., 16 for rank 16
LORA_ALPHA = None      # e.g., 32 for alpha 32
LORA_DROPOUT = None    # e.g., 0.1 for 10% dropout

# Advanced options
WARMUP_RATIO = 0.1     # Learning rate warmup (10% of training)
FP16 = None            # Mixed precision (None = use config.yaml)

# ============================================================================

# Apply overrides
print("\n=== EXPERIMENT CONFIGURATION ===")
print(f"\nMode: {'RESUME TRAINING' if RESUME_TRAINING else 'NEW TRAINING'}")

# Create experiment-specific checkpoint directory
if EXPERIMENT_NAME:
    base_checkpoint_dir = Path('..') / 'models' / 'experiments' / EXPERIMENT_NAME
    run_name = EXPERIMENT_NAME
else:
    base_checkpoint_dir = Path('..') / cfg['paths']['checkpoints']
    run_name = cfg['wandb']['run_name']

base_checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Override training config
if EPOCHS is not None:
    cfg['training']['epochs'] = EPOCHS
    print(f"  Epochs: {EPOCHS} (overridden)")
else:
    print(f"  Epochs: {cfg['training']['epochs']} (from config)")

if LEARNING_RATE is not None:
    cfg['training']['learning_rate'] = LEARNING_RATE
    print(f"  Learning rate: {LEARNING_RATE} (overridden)")
else:
    print(f"  Learning rate: {cfg['training']['learning_rate']} (from config)")

if BATCH_SIZE is not None:
    cfg['training']['batch_size'] = BATCH_SIZE
    print(f"  Batch size: {BATCH_SIZE} (overridden)")
else:
    print(f"  Batch size: {cfg['training']['batch_size']} (from config)")

if GRAD_ACCUM_STEPS is not None:
    cfg['training']['grad_accum_steps'] = GRAD_ACCUM_STEPS
    print(f"  Gradient accumulation: {GRAD_ACCUM_STEPS} (overridden)")
else:
    print(f"  Gradient accumulation: {cfg['training']['grad_accum_steps']} (from config)")

if FP16 is not None:
    cfg['training']['fp16'] = FP16
    print(f"  FP16: {FP16} (overridden)")
else:
    print(f"  FP16: {cfg['training']['fp16']} (from config)")

# Override LoRA config
if LORA_R is not None:
    cfg['model']['lora']['r'] = LORA_R
    print(f"\n  LoRA rank: {LORA_R} (overridden)")
else:
    print(f"\n  LoRA rank: {cfg['model']['lora']['r']} (from config)")

if LORA_ALPHA is not None:
    cfg['model']['lora']['alpha'] = LORA_ALPHA
    print(f"  LoRA alpha: {LORA_ALPHA} (overridden)")
else:
    print(f"  LoRA alpha: {cfg['model']['lora']['alpha']} (from config)")

if LORA_DROPOUT is not None:
    cfg['model']['lora']['dropout'] = LORA_DROPOUT
    print(f"  LoRA dropout: {LORA_DROPOUT} (overridden)")
else:
    print(f"  LoRA dropout: {cfg['model']['lora']['dropout']} (from config)")

print(f"\nCheckpoint directory: {base_checkpoint_dir}")
print(f"W&B run name: {run_name}")
print("\nConfiguration ready")


=== EXPERIMENT CONFIGURATION ===

Mode: NEW TRAINING
  Epochs: 3 (from config)
  Learning rate: 2e-4 (from config)
  Batch size: 8 (from config)
  Gradient accumulation: 1 (from config)
  FP16: True (from config)

  LoRA rank: 8 (from config)
  LoRA alpha: 16 (from config)
  LoRA dropout: 0.1 (from config)

Checkpoint directory: ..\models\checkpoints
W&B run name: esm2-lora-baseline

Configuration ready


## 3. Load Tokenized Datasets

In [5]:
print("\n=== LOADING TOKENIZED DATASETS ===")

train_path = "../data/tokenized/train_dataset"
val_path = "../data/tokenized/val_dataset"

train_ds = load_from_disk(train_path).with_format("torch")
val_ds = load_from_disk(val_path).with_format("torch")

print(f"Train samples: {len(train_ds):,}")
print(f"Validation samples: {len(val_ds):,}")

# Quick class distribution check
train_labels = pd.Series([int(label) for label in train_ds['labels']])
print(f"\nClass distribution:")
print(train_labels.value_counts().sort_index())

print("\nDatasets loaded")


=== LOADING TOKENIZED DATASETS ===
Train samples: 111,434
Validation samples: 23,879

Class distribution:
0    21000
1    21000
2    21000
3    14141
4     9020
5    16797
6     8476
Name: count, dtype: int64

Datasets loaded


## 4. Model Initialization

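**LoRA (Low-Rank Adaptation)** freezes the base model and adds small trainable low-rank update matrices to selected weight matrices (`target_modules` in `config.yaml`):
- Trains only a small fraction of parameters (about 1.2% here, which includes the newly initialized classification head that is trained in full)
- Reduces memory and training time
- Preserves pre-trained knowledge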
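The trainable-parameter count that `print_trainable_parameters()` reports below can be reproduced by hand. A sketch under stated assumptions: ESM2-35M has hidden size 480 and 12 layers; `target_modules` is assumed to be `["query", "value"]` with `bias="none"` (the config values are not shown in this section, but this choice reproduces the printed count exactly); and PEFT trains a copy of the classification head in full because `task_type="SEQ_CLS"`. An adapted `d_in x d_out` layer gets `A` (`r x d_in`) and `B` (`d_out x r`), and its update `B @ A` is scaled by `alpha / r`.

```python
# Assumed ESM2-35M shapes and LoRA settings (see lead-in)
HIDDEN = 480           # hidden size of facebook/esm2_t12_35M_UR50D
NUM_LAYERS = 12
NUM_LABELS = 7
R, ALPHA = 8, 16
TARGETS_PER_LAYER = 2  # assumption: target_modules = ["query", "value"]


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


# Adapters on the attention projections (each HIDDEN x HIDDEN, biases frozen)
lora_total = NUM_LAYERS * TARGETS_PER_LAYER * lora_params(HIDDEN, HIDDEN, R)

# Classification head: dense (HIDDEN -> HIDDEN) + out_proj (HIDDEN -> NUM_LABELS), with biases
head_total = (HIDDEN * HIDDEN + HIDDEN) + (HIDDEN * NUM_LABELS + NUM_LABELS)

trainable = lora_total + head_total
base_params = 33_503_768              # printed by the base-model cell
all_params = base_params + trainable  # frozen original head + trainable copy + adapters

print(f"LoRA adapters:   {lora_total:,}")
print(f"Classifier head: {head_total:,}")
print(f"Trainable:       {trainable:,} ({100 * trainable / all_params:.4f}%)")
print(f"LoRA scaling alpha/r = {ALPHA / R}")
```

The arithmetic also shows why the trainable share exceeds 1%: the adapters alone are about 0.5% of the model, and the fully trained head adds more than they do.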

In [6]:
print("\n=== LOADING BASE MODEL ===")

model_name = cfg['model']['name']
num_labels = cfg['model']['num_labels']
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Model: {model_name}")
print(f"Device: {device}")

# Load base model
print("\nLoading base model...")
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    trust_remote_code=True
)

total_params = sum(p.numel() for p in model.parameters())
print(f"Base model loaded ({total_params:,} parameters)")


=== LOADING BASE MODEL ===
Model: facebook/esm2_t12_35M_UR50D
Device: cuda

Loading base model...


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Base model loaded (33,503,768 parameters)


In [7]:
print("\n=== APPLYING LoRA ADAPTERS ===")

lora_cfg = cfg['model']['lora']

lora_config = LoraConfig(
    r=lora_cfg['r'],
    lora_alpha=lora_cfg['alpha'],
    target_modules=lora_cfg['target_modules'],
    lora_dropout=lora_cfg['dropout'],
    bias=lora_cfg['bias'],
    task_type="SEQ_CLS"
)

print(f"LoRA config: r={lora_config.r}, alpha={lora_config.lora_alpha}, dropout={lora_config.lora_dropout}")

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

model = model.to(device)
print(f"\nModel ready on {device}")


=== APPLYING LoRA ADAPTERS ===
LoRA config: r=8, alpha=16, dropout=0.1
trainable params: 418,567 || all params: 33,922,335 || trainable%: 1.2339

Model ready on cuda


## 5. W&B Setup and Training Utilities

In [8]:
def get_or_create_run_id(checkpoint_dir: Path, resume: bool) -> tuple[str, str]:
    """
    Get existing run ID or create new one for W&B.
    
    Returns:
        (run_id, resume_mode) tuple
    """
    run_id_file = checkpoint_dir / 'wandb_run_id.txt'
    
    if resume and run_id_file.exists():
        with open(run_id_file, 'r') as f:
            run_id = f.read().strip()
        print(f"Resuming W&B run: {run_id}")
        return run_id, "must"
    else:
        run_id = wandb.util.generate_id()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        with open(run_id_file, 'w') as f:
            f.write(run_id)
        print(f"Starting new W&B run: {run_id}")
        return run_id, "never"

def flatten_config(config: Dict, parent_key: str = '') -> Dict:
    """Flatten nested config for W&B."""
    items = []
    for k, v in config.items():
        new_key = f"{parent_key}/{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_config(v, new_key).items())
        else:
            items.append((new_key, v))
    return dict(items)

def get_latest_checkpoint(checkpoint_dir: Path) -> Optional[Path]:
    """Get the latest checkpoint directory."""
    if not checkpoint_dir.exists():
        return None
    
    checkpoints = [p for p in checkpoint_dir.iterdir() 
                   if p.is_dir() and p.name.startswith('checkpoint-')]
    
    if not checkpoints:
        return None
    
    # Sort by checkpoint number
    def get_checkpoint_num(path):
        try:
            return int(path.name.split('-')[-1])
        except ValueError:  # folder name without a numeric suffix
            return -1
    
    return max(checkpoints, key=get_checkpoint_num)

class WandBCallbackEnhanced(TrainerCallback):
    """Enhanced W&B callback for richer logging."""
    
    def __init__(self, class_names: list):
        self.class_names = class_names
        self.best_f1_macro = 0
        self.best_accuracy = 0
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and wandb.run:
            if 'learning_rate' in logs:
                wandb.log({"train/learning_rate": logs['learning_rate']}, step=state.global_step)
    
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics and wandb.run:
            # Track best metrics
            if 'eval_f1_macro' in metrics:
                if metrics['eval_f1_macro'] > self.best_f1_macro:
                    self.best_f1_macro = metrics['eval_f1_macro']
                    wandb.run.summary['best_f1_macro'] = self.best_f1_macro
            
            if 'eval_accuracy' in metrics:
                if metrics['eval_accuracy'] > self.best_accuracy:
                    self.best_accuracy = metrics['eval_accuracy']
                    wandb.run.summary['best_accuracy'] = self.best_accuracy

print("W&B utilities loaded")

W&B utilities loaded


## 6. Detailed Metrics Setup

- **Accuracy**: Overall correctness, (number of correct predictions) / (total number of predictions)
- **Macro F1**: Unweighted mean of the per-class F1 scores (treats each class equally)
- **Micro F1**: F1 computed from global TP/FP/FN counts; for single-label multi-class classification it equals accuracy
- **Per-class F1**: Performance on each individual enzyme class

**Macro F1 is used as the key metric**: the training classes are imbalanced (21,000 Oxidoreductases vs. 8,476 Translocases), and macro F1 weights all 7 EC classes equally, so weak performance on a rare class cannot hide behind the common ones.
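A dependency-free sketch of the two averages, to show why micro F1 coincides with accuracy: in single-label data every misclassification adds exactly one false positive (to the predicted class) and one false negative (to the true class). This illustrates the definitions only; the notebook itself uses `sklearn.metrics.precision_recall_fscore_support`, and the toy labels below are made up.

```python
from collections import Counter
from typing import Dict, Sequence, Tuple


def confusion_counts(labels: Sequence[int], preds: Sequence[int]) -> Tuple[Counter, Counter, Counter]:
    """Per-class true positives, false positives and false negatives."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for y, p in zip(labels, preds):
        if y == p:
            tp[y] += 1
        else:
            fp[p] += 1  # the predicted class gains a false positive
            fn[y] += 1  # the true class gains a false negative
    return tp, fp, fn


def f1(tp: int, fp: int, fn: int) -> float:
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0  # mirrors zero_division=0


def f1_summary(labels: Sequence[int], preds: Sequence[int], num_classes: int) -> Dict[str, float]:
    tp, fp, fn = confusion_counts(labels, preds)
    per_class = [f1(tp[c], fp[c], fn[c]) for c in range(num_classes)]
    return {
        "accuracy": sum(y == p for y, p in zip(labels, preds)) / len(labels),
        # Macro: unweighted mean over classes, so a weak rare class counts fully
        "f1_macro": sum(per_class) / num_classes,
        # Micro: pool the counts first; every error adds one FP and one FN
        "f1_micro": f1(sum(tp.values()), sum(fp.values()), sum(fn.values())),
    }


# Toy, imbalanced example: class 0 is common, classes 1 and 2 are rare
labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
preds = [0, 0, 0, 0, 0, 0, 1, 0, 0, 2]
print(f1_summary(labels, preds, num_classes=3))
```

Here accuracy and micro F1 are both 0.8, while macro F1 is about 0.730 because each rare class loses half of its examples.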

In [9]:
print("\n=== SETTING UP DETAILED METRICS ===")

# EC class names for reference
EC_NAMES = [
    "Oxidoreductases",  # EC 1
    "Transferases",     # EC 2
    "Hydrolases",       # EC 3
    "Lyases",           # EC 4
    "Isomerases",       # EC 5
    "Ligases",          # EC 6
    "Translocases"      # EC 7
]

def compute_metrics(eval_pred):
    """
    Compute comprehensive metrics for multi-class classification.
    
    Returns:
        - accuracy: Overall accuracy
        - f1_macro: Macro-averaged F1 (treats each class equally)
        - f1_micro: Micro-averaged F1 (weighted by frequency)
        - precision_macro, recall_macro: Macro averages
        - precision_micro, recall_micro: Micro averages
        - f1_class_N: Per-class F1 scores for each enzyme class
    """
    logits, labels = eval_pred
    predictions = logits.argmax(-1)
    
    # Basic accuracy
    accuracy = (predictions == labels).mean()
    
    # Macro metrics (average across classes - treats each class equally)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    
    # Micro metrics (global average - weighted by frequency)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )
    
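    # Per-class metrics; pass the class indices explicitly so the output always has
    # one entry per EC class, in order, even if a class is absent from the eval set
    precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
        labels, predictions, labels=list(range(len(EC_NAMES))), average=None, zero_division=0
    )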
    
    # Build comprehensive metrics dictionary
    metrics = {
        'accuracy': float(accuracy),
        
        # Macro metrics
        'f1_macro': float(f1_macro),
        'precision_macro': float(precision_macro),
        'recall_macro': float(recall_macro),
        
        # Micro metrics
        'f1_micro': float(f1_micro),
        'precision_micro': float(precision_micro),
        'recall_micro': float(recall_micro),
    }
    
    # Add per-class F1 scores
    for i, (f1, name) in enumerate(zip(f1_per_class, EC_NAMES)):
        metrics[f'f1_class_{i}_{name[:8]}'] = float(f1)
    
    return metrics

print("\nMetrics configured:")
print("  ✓ Accuracy (overall)")
print("  ✓ Macro F1 (treats each class equally)")
print("  ✓ Micro F1 (weighted by class frequency)")
print("  ✓ Precision & Recall (both macro and micro)")
print("  ✓ Per-class F1 scores for all 7 enzyme classes")


=== SETTING UP DETAILED METRICS ===

Metrics configured:
  ✓ Accuracy (overall)
  ✓ Macro F1 (treats each class equally)
  ✓ Micro F1 (weighted by class frequency)
  ✓ Precision & Recall (both macro and micro)
  ✓ Per-class F1 scores for all 7 enzyme classes


## 7. Trainer Setup

**If `RESUME_TRAINING = False`:**
- Starts fresh training
- Creates new W&B run
- Trains for configured epochs
- Saves checkpoints along the way

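**If `RESUME_TRAINING = True`:**
- Continues from the latest checkpoint in the checkpoint directory (or starts fresh if none is found)
- Resumes the existing W&B run (run ID stored in `wandb_run_id.txt`)
- Trains until the configured total number of epochs is reached (the epoch count is a total, not additional epochs)
- Set the flag in section 2, then re-run the notebook from that cell so the new values reach `cfg` and `training_args`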

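**Checkpoints:**
- Saved at the end of each epoch (if `save_strategy="epoch"`)
- At most `save_total_limit` checkpoints are kept; with `load_best_model_at_end` enabled, the best checkpoint is retained in addition to the most recent ones
- The best model (by `f1_macro`) is loaded at the end if `load_best_model_at_end` is enabled, which requires `eval_strategy` and `save_strategy` to match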
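The retention rule for `save_total_limit` combined with `load_best_model_at_end` can be modeled in a few lines. This is a simplified sketch of the behavior described in the `TrainingArguments` documentation (the best checkpoint survives alongside the most recent ones), not the Trainer's own code. `retained_checkpoints` is hypothetical and only models `save_total_limit >= 2`; with a limit of 1 the Trainer may keep two checkpoints (the last and the best).

```python
from typing import List, Optional


def retained_checkpoints(steps: List[int], limit: int, best: Optional[int] = None) -> List[int]:
    """Steps whose checkpoint-<step> folders survive rotation (simplified model).

    The best checkpoint is always kept; the remaining slots go to the
    most recent checkpoints.
    """
    if limit < 2:
        raise ValueError("this sketch only models save_total_limit >= 2")
    ordered = sorted(steps)
    if best is None or best not in ordered:
        return ordered[-limit:]
    others = [s for s in ordered if s != best]
    return sorted(others[-(limit - 1):] + [best])


# Hypothetical step numbers for four epoch-end checkpoints
steps = [1000, 2000, 3000, 4000]
print(retained_checkpoints(steps, limit=2, best=2000))  # best kept alongside the newest
print(retained_checkpoints(steps, limit=2, best=4000))  # best is the newest
```

One practical consequence: with a small limit, an early best checkpoint occupies a slot for the whole run, so fewer recent checkpoints remain available for resuming.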

In [10]:
print("\n=== CONFIGURING TRAINING ===")

# Training arguments
training_args = TrainingArguments(
    output_dir=str(base_checkpoint_dir),
    
    # Training
    num_train_epochs=cfg['training']['epochs'],
    per_device_train_batch_size=cfg['training']['batch_size'],
    per_device_eval_batch_size=cfg['training']['batch_size'],
    gradient_accumulation_steps=cfg['training']['grad_accum_steps'],
    learning_rate=float(cfg['training']['learning_rate']),
    warmup_ratio=WARMUP_RATIO,
    fp16=cfg['training']['fp16'] and torch.cuda.is_available(),
    
    # Checkpointing
    save_strategy=cfg['training']['save_strategy'],
    save_total_limit=cfg['training']['save_total_limit'],
    load_best_model_at_end=cfg['training']['load_best_model_at_end'],
    
    # Evaluation - USE F1_MACRO as primary metric
    eval_strategy=cfg['training']['eval_strategy'],
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    
    # Logging
    logging_dir=str(base_checkpoint_dir / "logs"),
    logging_steps=10,
    logging_first_step=True,
    report_to=["wandb"] if cfg['wandb']['enabled'] else [],
    
    # Data loading
    dataloader_num_workers=cfg['training']['num_workers'],
    dataloader_pin_memory=cfg['training']['pin_memory'],
    dataloader_persistent_workers=cfg['training']['persistent_workers'],
)

# Mirror the Trainer's step count on a single GPU: the last partial batch is kept
# (ceiling division), then batches are grouped by gradient accumulation
effective_batch = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
batches_per_epoch = -(-len(train_ds) // training_args.per_device_train_batch_size)
steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
total_steps = int(steps_per_epoch * training_args.num_train_epochs)

print(f"Effective batch size: {effective_batch}")
print(f"Steps per epoch: {steps_per_epoch:,}")
print(f"Total training steps: {total_steps:,}")
print("\nBest model selection: Based on MACRO F1 (not accuracy)")
print("This favors a model that performs well across ALL enzyme classes.")
print("\nTraining configured")


=== CONFIGURING TRAINING ===
Effective batch size: 8
Steps per epoch: 13,930
Total training steps: 41,790

Best model selection: Based on MACRO F1 (not accuracy)
This favors a model that performs well across ALL enzyme classes.

Training configured


## 8. Training

In [11]:
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# EC class names (reuse the list defined in the metrics cell)
ec_names = EC_NAMES

# Check for existing checkpoint
checkpoint_to_resume = None
if RESUME_TRAINING:
    checkpoint_to_resume = get_latest_checkpoint(base_checkpoint_dir)
    if checkpoint_to_resume:
        print(f"\n✓ Found checkpoint: {checkpoint_to_resume.name}")
        print("  Will resume training from this checkpoint\n")
    else:
        print("\n⚠ No checkpoint found. Starting fresh training\n")
else:
    print("\n✓ Starting fresh training (RESUME_TRAINING=False)\n")

# Initialize W&B
if cfg['wandb']['enabled']:
    run_id, resume_mode = get_or_create_run_id(base_checkpoint_dir, RESUME_TRAINING)
    
    wandb_context = wandb.init(
        project=cfg['wandb']['project'],
        name=run_name,
        id=run_id,
        resume=resume_mode,
        config=flatten_config(cfg),
        tags=['esm2', 'lora', 'protein-classification'],
        notes=f"Training {cfg['model']['name']} with LoRA (r={cfg['model']['lora']['r']})"
    )
else:
    from contextlib import nullcontext
    wandb_context = nullcontext()

with wandb_context:
    # Log model info to W&B
    if cfg['wandb']['enabled']:
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        wandb.run.summary.update({
            'model/total_parameters': total_params,
            'model/trainable_parameters': trainable_params,
            'model/trainable_percentage': 100 * trainable_params / total_params,
            'data/train_samples': len(train_ds),
            'data/val_samples': len(val_ds),
        })
    
    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[WandBCallbackEnhanced(ec_names)] if cfg['wandb']['enabled'] else []
    )
    
    # Train
    if checkpoint_to_resume:
        print(f"Resuming from: {checkpoint_to_resume}\n")
        train_result = trainer.train(resume_from_checkpoint=str(checkpoint_to_resume))
    else:
        train_result = trainer.train()
    
    # Training completed
    print("\n" + "=" * 70)
    print("TRAINING COMPLETED")
    print("=" * 70)
    
    print("\nFinal Metrics:")
    for key, value in train_result.metrics.items():
        metric_name = key.replace('train_', '').replace('_', ' ').title()
        if isinstance(value, float):
            print(f"  {metric_name}: {value:.4f}")
        else:
            print(f"  {metric_name}: {value}")
    
    # Save final model
    final_model_path = base_checkpoint_dir / "final_model"
    print(f"\nSaving final model to: {final_model_path}")
    trainer.save_model(str(final_model_path))
    
    # Log model as W&B artifact
    if cfg['wandb']['enabled']:
        print("\nLogging model to W&B artifacts...")
        artifact = wandb.Artifact(
            name=f"esm2-lora-ec-classifier",
            type="model",
            description=f"ESM2 with LoRA (r={cfg['model']['lora']['r']}) for EC classification",
            metadata={
                'epochs': cfg['training']['epochs'],
                'learning_rate': cfg['training']['learning_rate'],
                'lora_r': cfg['model']['lora']['r'],
            }
        )
        artifact.add_dir(str(final_model_path))
        wandb.log_artifact(artifact)
        print("✓ Model artifact logged")
    
    print("\n✓ Training complete!")
    print(f"\n   Model: {final_model_path}")
    if cfg['wandb']['enabled']:
        print(f"   W&B: {wandb.run.url}")
    print("\n" + "=" * 70)
    print("NEXT STEPS")
    print("=" * 70)
    print("\n1. To continue training for more epochs:")
    print("   - Set RESUME_TRAINING = True")
    print("   - Adjust EPOCHS to total desired epochs")
    print("   - Re-run this cell")
    print("\n2. To start a new experiment:")
    print("   - Set RESUME_TRAINING = False")
    print("   - Set EXPERIMENT_NAME to a new name")
    print("   - Adjust hyperparameters as needed")
    print("   - Re-run from 'Training Configuration' cell")
    print("\n3. For evaluation:")
    print("   - Run 03_evaluation.ipynb")


STARTING TRAINING

✓ Starting fresh training (RESUME_TRAINING=False)

Starting new W&B run: gd3mymsv


[34m[1mwandb[0m: Currently logged in as: [33mcristinalee0723[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro | Precision Macro | Recall Macro | F1 Micro | Precision Micro | Recall Micro | F1 Class 0 Oxidored | F1 Class 1 Transfer | F1 Class 2 Hydrolas | F1 Class 3 Lyases | F1 Class 4 Isomeras | F1 Class 5 Ligases | F1 Class 6 Transloc |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.4096 | 0.234101 | 0.94179 | 0.941408 | 0.946499 | 0.938144 | 0.94179 | 0.94179 | 0.94179 | 0.94989 | 0.940486 | 0.924637 | 0.913275 | 0.922407 | 0.978664 | 0.960494 |
| 2 | 0.1608 | 0.150899 | 0.964655 | 0.964143 | 0.963366 | 0.965047 | 0.964655 | 0.964655 | 0.964655 | 0.970198 | 0.95992 | 0.95803 | 0.952 | 0.940488 | 0.987854 | 0.980511 |
| 3 | 0.0056 | 0.139935 | 0.970476 | 0.970897 | 0.970919 | 0.970891 | 0.970476 | 0.970476 | 0.970476 | 0.972985 | 0.965379 | 0.963743 | 0.958607 | 0.96 | 0.990391 | 0.985173 |



TRAINING COMPLETED

Final Metrics:
  Runtime: 9452.8848
  Samples Per Second: 35.3650
  Steps Per Second: 4.4210
  Total Flos: 68146074658037880.0000
  Loss: 0.2739
  Epoch: 3.0000

Saving final model to: ..\models\checkpoints\final_model


[34m[1mwandb[0m: Adding directory to artifact (..\models\checkpoints\final_model)... Done. 0.1s



Logging model to W&B artifacts...
✓ Model artifact logged

✓ Training complete!

   Model: ..\models\checkpoints\final_model
   W&B: https://wandb.ai/cristinalee0723/enzyme-classification-esm2/runs/gd3mymsv

NEXT STEPS

1. To continue training for more epochs:
   - Set RESUME_TRAINING = True
   - Adjust EPOCHS to total desired epochs
   - Re-run this cell

2. To start a new experiment:
   - Set RESUME_TRAINING = False
   - Set EXPERIMENT_NAME to a new name
   - Adjust hyperparameters as needed
   - Re-run from 'Training Configuration' cell

3. For evaluation:
   - Run 03_evaluation.ipynb


Run history:
eval/accuracy ▁▇█
eval/f1_class_0_Oxidored ▁▇█
eval/f1_class_1_Transfer ▁▆█
eval/f1_class_2_Hydrolas ▁▇█
eval/f1_class_3_Lyases ▁▇█
eval/f1_class_4_Isomeras ▁▄█
eval/f1_class_5_Ligases ▁▆█
eval/f1_class_6_Transloc ▁▇█
eval/f1_macro ▁▆█
eval/f1_micro ▁▇█

Run summary:
best_accuracy 0.97048
best_f1_macro 0.9709
data/train_samples 111434
data/val_samples 23879
eval/accuracy 0.97048
eval/f1_class_0_Oxidored 0.97298
eval/f1_class_1_Transfer 0.96538
eval/f1_class_2_Hydrolas 0.96374
eval/f1_class_3_Lyases 0.95861
eval/f1_class_4_Isomeras 0.96
