# Training Demo for DeepReaction

This notebook demonstrates how to train a molecular reaction prediction model using the DeepReaction framework.

## 1. Import Required Libraries

In [1]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

# Assuming deepreaction is installed or in the Python path
# If not, you might need to add its location to sys.path
# Example: sys.path.append('/path/to/deepreaction/parent/directory')
from deepreaction import ReactionDataset, ReactionTrainer
from deepreaction.config import ReactionConfig, ModelConfig, TrainingConfig, Config, save_config

## 2. Define Training Parameters

The parameters below can be modified to fit your specific use case.

In [2]:
# Single source of truth for every tunable value used by the cells below.
# Edit values here, then re-run the notebook from this cell onwards.
params = dict(
    # ---- Dataset ----
    dataset='XTB',
    readout='mean',
    dataset_root='./dataset/DATASET_DA_F',  # Adjust path if needed
    dataset_csv='./dataset/DATASET_DA_F/dataset_xtb_final.csv',  # Adjust path if needed
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    target_fields=['G(TS)', 'DrG'],
    target_weights=[1.0, 1.0],
    input_features=['G(TS)_xtb', 'DrG_xtb'],
    file_patterns=['*_reactant.xyz', '*_ts.xyz', '*_product.xyz'],
    file_dir_pattern='reaction_*',
    id_field='ID',
    dir_field='R_dir',
    reaction_field='reaction',
    cv_folds=0,  # Set > 0 for cross-validation

    # ---- Model: general ----
    model_type='dimenet++',
    node_dim=128,
    dropout=0.1,
    prediction_hidden_layers=3,
    prediction_hidden_dim=512,
    use_layer_norm=False,

    # ---- Model: DimeNet++ specific ----
    hidden_channels=128,
    num_blocks=5,
    int_emb_size=64,
    basis_emb_size=8,
    out_emb_channels=256,
    num_spherical=7,
    num_radial=6,
    cutoff=5.0,
    envelope_exponent=5,
    num_before_skip=1,
    num_after_skip=2,
    num_output_layers=3,
    max_num_neighbors=32,

    # ---- Training: optimisation ----
    batch_size=16,
    eval_batch_size=None,  # Uses batch_size if None
    lr=0.0005,
    finetune_lr=None,
    epochs=1,  # Set higher for actual training (e.g., 100, 500)
    min_epochs=0,
    early_stopping=40,
    optimizer='adamw',
    scheduler='warmup_cosine',
    warmup_epochs=10,
    min_lr=1e-7,
    weight_decay=0.0001,
    random_seed=42234,

    # ---- Training: outputs and checkpoints ----
    out_dir='./results/reaction_model',  # Adjust path if needed
    save_best_model=True,
    save_last_model=False,
    checkpoint_path=None,  # Path to a .ckpt file to resume/continue
    # 'train' or 'continue'.
    # NOTE(review): 'continue' is combined with checkpoint_path=None here;
    # confirm the trainer treats that as a fresh run.
    mode='continue',
    freeze_base_model=False,

    # ---- Hardware ----
    cuda=True,  # Set to False to force CPU
    gpu_id=0,
    num_workers=4,  # Number of workers for data loading
)

## 3. Set Up GPU and Output Directory

In [3]:
# Select the compute device and make sure the output directory exists.
#
# CUDA_VISIBLE_DEVICES is read when the CUDA runtime is first initialised in
# this process, so it has to be exported BEFORE the first torch.cuda call.
# Exporting it after torch.cuda.is_available() (as this cell used to) leaves
# the mask without effect in the current kernel.
# NOTE: after changing params['gpu_id'], restart the kernel so the new mask
# is picked up.
if params['cuda']:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(params['gpu_id'])

if params['cuda'] and torch.cuda.is_available():
    # When the mask took effect exactly one GPU is visible and it is index 0,
    # whatever its physical id. If several devices are still visible (CUDA
    # was already initialised earlier in this kernel), address the GPU by
    # its original index instead.
    device_index = params['gpu_id'] if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{device_index}")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    # Hide all GPUs so downstream libraries do not try to use one.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    device = torch.device("cpu")
    print("Using CPU")
    params['cuda'] = False  # Ensure cuda param reflects reality

# Create output directory
os.makedirs(params['out_dir'], exist_ok=True)
print(f"Output directory created/exists: {params['out_dir']}")

Using GPU: NVIDIA TITAN Xp
Output directory created/exists: ./results/reaction_model


## 4. Create Configuration Objects

In [4]:
# Translate the flat `params` dict into the three config sections, bundle
# them into one Config, and write that Config to the output directory.

# --- Reaction / dataset section --------------------------------------------
# params keys that are forwarded to ReactionConfig under the same name.
_reaction_keys = (
    'dataset_root', 'dataset_csv', 'target_fields', 'file_patterns',
    'input_features', 'train_ratio', 'val_ratio', 'test_ratio', 'cv_folds',
    'id_field', 'dir_field', 'reaction_field', 'random_seed',
)
reaction_config = ReactionConfig(
    **{key: params[key] for key in _reaction_keys},
    use_scaler=True,
    cv_test_fold=-1,  # Which fold is test set in CV, -1 if standard split
    cv_stratify=False,
    cv_grouped=True,
)

# --- Model section ----------------------------------------------------------
# params keys that are forwarded to ModelConfig under the same name.
_model_keys = (
    'model_type', 'readout',
    # DimeNet++ specific
    'hidden_channels', 'num_blocks', 'cutoff', 'int_emb_size',
    'basis_emb_size', 'out_emb_channels', 'num_spherical', 'num_radial',
    'envelope_exponent', 'num_before_skip', 'num_after_skip',
    'num_output_layers', 'max_num_neighbors',
    # General model params
    'node_dim', 'dropout', 'use_layer_norm',
    'prediction_hidden_layers', 'prediction_hidden_dim',
)
# The xTB feature switches are derived from how many input features are listed.
_n_input_features = len(params['input_features'])
model_config = ModelConfig(
    **{key: params[key] for key in _model_keys},
    use_xtb_features=_n_input_features > 0,
    num_xtb_features=_n_input_features,
)

# --- Training section -------------------------------------------------------
# TrainingConfig field name -> params key (several names differ).
_training_field_to_param = {
    'output_dir': 'out_dir',
    'batch_size': 'batch_size',
    'learning_rate': 'lr',
    'max_epochs': 'epochs',
    'min_epochs': 'min_epochs',
    'early_stopping_patience': 'early_stopping',
    'save_best_model': 'save_best_model',
    'save_last_model': 'save_last_model',
    'optimizer': 'optimizer',
    'weight_decay': 'weight_decay',
    'scheduler': 'scheduler',
    'warmup_epochs': 'warmup_epochs',
    'min_lr': 'min_lr',
    'target_weights': 'target_weights',
    'gpu': 'cuda',
    'num_workers': 'num_workers',
    'resume_from_checkpoint': 'checkpoint_path',
}
training_config = TrainingConfig(
    **{field: params[key] for field, key in _training_field_to_param.items()}
)

# --- Bundle and persist -----------------------------------------------------
config = Config(reaction=reaction_config, model=model_config, training=training_config)

# save_config receives the path without extension; the message below reports
# the .yaml and .json files it is expected to write.
config_path = os.path.join(params['out_dir'], 'config')
save_config(config, config_path)  # Saves both .yaml and .json
print(f"Configuration saved to {config_path}.yaml and {config_path}.json")

Configuration saved to ./results/reaction_model/config.yaml and ./results/reaction_model/config.json


## 5. Load and Prepare Dataset

*Note: This might take a while depending on the dataset size and preprocessing steps.*

In [5]:
# Build the ReactionDataset from the paths and settings in `params`, report
# the split sizes, and print the attributes of one sample as a sanity check.

# Common attributes expected in a PyG Data object for reactions.
_SAMPLE_ATTRS = ('reaction_id', 'id', 'reaction', 'y',
                 'pos0', 'z0', 'pos1', 'z1', 'pos2', 'z2', 'xtb_features')


def _print_sample_attributes(sample, heading):
    """Print `heading`, then every attribute in _SAMPLE_ATTRS that `sample` has.

    Tensors are summarised by their shape; any other value is printed as is.
    """
    print(heading)
    for attr in _SAMPLE_ATTRS:
        if not hasattr(sample, attr):
            continue
        val = getattr(sample, attr)
        val_repr = f"shape: {val.shape}" if isinstance(val, torch.Tensor) else val
        print(f"  {attr}: {val_repr}")


print(f"Loading dataset from {params['dataset_root']}")

# Both the dataset directory and the CSV index must exist before loading.
_dataset_paths_exist = (os.path.exists(params['dataset_root'])
                        and os.path.exists(params['dataset_csv']))

if _dataset_paths_exist:
    dataset = ReactionDataset(
        root=params['dataset_root'],
        csv_file=params['dataset_csv'],
        target_fields=params['target_fields'],
        file_patterns=params['file_patterns'],
        input_features=params['input_features'],
        use_scaler=True,  # Important for consistent scaling
        random_seed=params['random_seed'],
        train_ratio=params['train_ratio'],
        val_ratio=params['val_ratio'],
        test_ratio=params['test_ratio'],
        cv_folds=params['cv_folds'],
        id_field=params['id_field'],
        dir_field=params['dir_field'],
        reaction_field=params['reaction_field'],
        # Add other ReactionDataset specific args if needed
    )
    print("Dataset loaded successfully")

    data_stats = dataset.get_data_stats()
    print(f"Dataset stats: Train: {data_stats['train_size']}, Validation: {data_stats['val_size']}, Test: {data_stats['test_size']}")
    if params['cv_folds'] > 0:
        print(f"Cross-validation enabled with {dataset.get_num_folds()} folds.")

    # Optional preview; any failure here is reported but does not stop the cell.
    try:
        if dataset.train_data and len(dataset.train_data) > 0:
            _print_sample_attributes(dataset.train_data[0], "\nSample data attributes:")
        elif dataset.data and len(dataset.data) > 0 and params['cv_folds'] > 0:
            # If using CV, train_data might be None initially, so fall back to the raw data.
            _print_sample_attributes(dataset.data[0], "\nSample data attributes (from raw dataset for CV):")
        else:
            print("\nNo training data available to display sample.")
    except Exception as e:
        print(f"Could not display sample data: {e}")
else:
    # Missing inputs are reported but not raised; later cells check whether
    # `dataset` exists before using it.
    print(f"Error: Dataset root ({params['dataset_root']}) or CSV ({params['dataset_csv']}) not found.")
    print("Please ensure the dataset files are correctly placed and paths are updated in Section 2.")

Loading dataset from ./dataset/DATASET_DA_F
Error checking saved data: 'NoneType' object is not subscriptable
Using target fields: ['G(TS)', 'DrG']
Using input features: ['G(TS)_xtb', 'DrG_xtb']
Using file patterns: ['*_reactant.xyz', '*_ts.xyz', '*_product.xyz']


Processing reactions:  63%|██████▎   | 1002/1582 [00:00<00:00, 1155.57it/s]



Processing reactions: 100%|██████████| 1582/1582 [00:01<00:00, 1112.84it/s]


Processed 1580 reactions, skipped 2 reactions
Saved metadata to dataset/DATASET_DA_F/processed/metadata.json
Processed 1580 reactions, saved to dataset/DATASET_DA_F/processed/data_038b0f2fed6b.pt
Dataset split: train 1269, validation 162, test 149 samples
Dataset loaded successfully
Dataset stats: Train: 1269, Validation: 162, Test: 149

Sample data attributes:
  reaction_id: ID87464
  id: reaction_R13963
  reaction: [C:1]([C:2](=[C:3]([C:4]([H:23])([H:24])[H:25])[C:5](=[C:6]([O:7][H:28])[H:27])[H:26])[H:22])([H:19])([H:20])[H:21].[O:8]=[C:9]1[N:10]([C:11]([O:12][H:31])([H:29])[H:30])[C:17](=[O:18])[C:15]([C:16]([H:32])([H:33])[H:34])=[C:13]1[C:14]([H:35])([H:36])[H:37]>>[C:1]([C@:2]1([H:22])[C:3]([C:4]([H:23])([H:24])[H:25])=[C:5]([H:26])[C@@:6]([O:7][H:28])([H:27])[C@@:13]2([C:14]([H:35])([H:36])[H:37])[C:9](=[O:8])[N:10]([C:11]([O:12][H:31])([H:29])[H:30])[C:17](=[O:18])[C@@:15]12[C:16]([H:32])([H:33])[H:34])([H:19])([H:20])[H:21]
  y: shape: torch.Size([1, 2])
  pos0: shape: torch.

## 6. Initialize and Configure Trainer

In [6]:
# Build the ReactionTrainer from the Config sections created earlier.
# At notebook top level locals() is the global namespace, so the membership
# test below checks whether the dataset cell actually defined `dataset`.
if 'dataset' in locals():
    # Optional trainer arguments: only forwarded when the user set them.
    _optional_candidates = (
        # (keyword, value to pass, include?)
        ('finetune_lr', params['finetune_lr'], params['finetune_lr'] is not None),
        ('freeze_base_model', True, bool(params['freeze_base_model'])),
        ('eval_batch_size', params['eval_batch_size'], params['eval_batch_size'] is not None),
    )
    additional_kwargs = {
        keyword: value for keyword, value, include in _optional_candidates if include
    }

    # config.model attributes forwarded to the trainer under the same name.
    _model_fields = (
        'model_type', 'readout',
        # General model architecture
        'node_dim', 'dropout', 'use_layer_norm', 'use_xtb_features',
        'num_xtb_features', 'prediction_hidden_layers', 'prediction_hidden_dim',
        # DimeNet++ specific
        'hidden_channels', 'num_blocks', 'cutoff', 'int_emb_size',
        'basis_emb_size', 'out_emb_channels', 'num_spherical', 'num_radial',
        'envelope_exponent', 'num_before_skip', 'num_after_skip',
        'num_output_layers', 'max_num_neighbors',
    )
    # config.training attributes forwarded to the trainer under the same name.
    # NOTE(review): target_weights and resume_from_checkpoint are stored in
    # TrainingConfig but are not forwarded here (same as before) - confirm
    # that this is intended.
    _training_fields = (
        'batch_size', 'max_epochs', 'learning_rate', 'output_dir',
        'early_stopping_patience', 'save_best_model', 'save_last_model',
        'optimizer', 'weight_decay', 'scheduler', 'warmup_epochs', 'min_lr',
        'gpu', 'min_epochs', 'num_workers',
    )

    _reaction_cfg = config.reaction
    trainer = ReactionTrainer(
        **{name: getattr(config.model, name) for name in _model_fields},
        **{name: getattr(config.training, name) for name in _training_fields},
        # Values taken from the reaction/dataset section.
        random_seed=_reaction_cfg.random_seed,  # Use seed from reaction config for consistency
        num_targets=len(_reaction_cfg.target_fields),
        target_field_names=_reaction_cfg.target_fields,
        use_scaler=_reaction_cfg.use_scaler,
        scalers=dataset.get_scalers(),  # Get scalers from the loaded dataset
        # finetune_lr / freeze_base_model / eval_batch_size, when set.
        **additional_kwargs,
    )
    print("ReactionTrainer initialized successfully.")
else:
    # The dataset cell did not produce `dataset`; report it and do nothing.
    print("Error: Dataset not loaded. Please run the previous cell successfully.")

Seed set to 42234


ReactionTrainer initialized successfully.


## 7. Train the Model

*Note: This is where the actual training happens. If `epochs` is set low (e.g., 1), this will be very fast but won't result in a trained model. Increase `epochs` in Section 2 for real training.*

In [7]:
# Run one training job on the standard train/val/test split.
# Skipped when the trainer or dataset is missing, or when cross-validation
# is enabled (that case is handled in Section 10).
_prerequisites_ready = 'trainer' in locals() and 'dataset' in locals()

if not _prerequisites_ready:
    print("Error: Trainer or Dataset not initialized. Please run previous cells.")
elif params['cv_folds'] > 0:
    print("Cross-validation is enabled. Training will be handled in the CV section (Section 10).")
    print("Skipping single training run.")
else:
    print(f"Starting {params['mode']} training with {params['epochs']} epochs")

    # fit() needs both a training and a validation split.
    _splits_ready = dataset.train_data is not None and dataset.val_data is not None
    if _splits_ready:
        train_metrics = trainer.fit(
            train_dataset=dataset.train_data,
            val_dataset=dataset.val_data,
            test_dataset=dataset.test_data,  # Optional, used for final evaluation if provided
            checkpoint_path=params['checkpoint_path'],
            mode=params['mode'],
        )

        print("Training completed.")
        print("Metrics:", train_metrics)

        # Report where the model was saved: the best checkpoint if there is
        # one, otherwise the last one when saving it was requested.
        _has_best = 'best_model_path' in train_metrics and train_metrics['best_model_path']
        if _has_best:
            print(f"Best model saved to: {train_metrics['best_model_path']}")
        elif params['save_last_model'] and 'last_model_path' in train_metrics and train_metrics['last_model_path']:
            print(f"Last model saved to: {train_metrics['last_model_path']}")
    else:
        print("Error: Train or Validation data split not found. Check dataset loading and splitting.")

Starting continue training with 1 epochs


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/root/miniconda3/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /root/autodl-tmp/deepooooo19/results/reaction_model/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                      | Params | Mode 
---------------------------------------------------------------------
0 | model          | MoleculePredictionModel   | 3.5 M  | train
1 | net            | DimeNetPlusPlus           | 2.3 M  | train
2 | readout_module | MeanReadout               | 0      | train
3 | regr_or_cls_nn | MultiTargetPredictionHead | 1.2 M  | train
---------------------------------------------------------------------
3.5 M     Trainable params
0         Non-trainable params
3.5 M     Total params
13.866    Total estimated model params size (MB)
193       Modules in train mode
0         

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_total_loss improved. New best score: 0.433
Epoch 0, global step 80: 'val_total_loss' reached 0.43263 (best 0.43263), saving model to '/root/autodl-tmp/deepooooo19/results/reaction_model/checkpoints/best-epoch=0000-val_total_loss=0.4326.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      Test Avg MAE           4.679178714752197
     Test Avg MAX_AE         17.82347869873047
   Test Avg MEDIAN_AE        3.879878282546997
      Test Avg MPAE          52.42013168334961
       Test Avg R2           0.509222149848938
      Test Avg RMSE          6.083659648895264
      Test MAE DrG          4.0490851402282715
     Test MAE G(TS)          5.309272289276123
     Test MAX_AE DrG        19.114639282226562
    Test MAX_AE G(TS)       16.532320022583008
   Test MEDIAN_AE DrG        2.354483127593994
  Test MEDIAN_AE G(TS)         5.4052734375
      Test MPAE DrG          88.11637115478516
     Test MPAE G(TS)        16.723892211914062
       Test R2 DrG          0.7958351373672485
      Test R

## 8. Visualize Training Results

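*This section attempts to load the training logs saved by PyTorch Lightning and plot the loss and learning rate. It looks for a `metrics.csv` file in the logger's `log_dir`; if the trainer logs in another format (the recorded run below used a `tensorboard/` log directory), that file does not exist and the cell only prints a notice.*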

In [8]:
# Plot the loss curves and the learning-rate schedule from the logger's
# metrics.csv, if training ran and such a file exists.


def _plot_logged_values(metrics_df, column, label=None, marker=None):
    """Plot the rows where `column` was actually logged on the current figure.

    Rows in which `column` is NaN are dropped. The x axis is the 'epoch'
    column when present, otherwise the row index.
    """
    logged = metrics_df.dropna(subset=[column])
    x_values = logged['epoch'] if 'epoch' in logged.columns else logged.index
    plt.plot(x_values, logged[column], label=label, marker=marker)


# Check that the trainer object exists and exposes the underlying Lightning
# trainer together with its logger.
if 'trainer' in locals() and hasattr(trainer, 'trainer') and hasattr(trainer.trainer, 'logger'):
    try:
        # metrics.csv is expected directly inside the logger's log_dir.
        log_dir = trainer.trainer.logger.log_dir
        metrics_path = os.path.join(log_dir, 'metrics.csv')

        if os.path.exists(metrics_path):
            print(f"Loading metrics from: {metrics_path}")
            metrics_df = pd.read_csv(metrics_path)

            # Only the epoch column is filled, so every row has an x position.
            # Metric columns are deliberately NOT forward/back-filled: doing
            # so would invent values on rows where the metric was never
            # logged and draw one epoch's loss at the next epoch. Missing
            # metric values are dropped per column when plotting instead.
            if 'epoch' in metrics_df.columns:
                metrics_df['epoch'] = metrics_df['epoch'].ffill().bfill()

            # Find the loss columns (names might vary slightly based on logging).
            train_loss_col = None
            val_loss_col = None
            for col in metrics_df.columns:
                if 'train_total_loss' in col and 'epoch' in col:
                    train_loss_col = col
                if 'val_total_loss' in col:
                    val_loss_col = col

            if train_loss_col or val_loss_col:
                # Create the figure only when there is something to draw.
                plt.figure(figsize=(10, 6))
                if train_loss_col:
                    _plot_logged_values(metrics_df, train_loss_col, label='Train Loss', marker='.')
                if val_loss_col:
                    _plot_logged_values(metrics_df, val_loss_col, label='Val Loss', marker='.')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Training and Validation Loss')
                plt.legend()
                plt.grid(True, alpha=0.3)
                plt.show()
            else:
                print("Could not find training/validation loss columns in metrics.csv")

            # Learning rate: take the first column named 'lr-*' or 'learning_rate'.
            possible_lr_cols = [col for col in metrics_df.columns
                                if col.startswith('lr-') or col == 'learning_rate']
            lr_col = possible_lr_cols[0] if possible_lr_cols else None

            if lr_col:
                plt.figure(figsize=(10, 4))
                _plot_logged_values(metrics_df, lr_col)
                plt.xlabel('Epoch')
                plt.ylabel('Learning Rate')
                plt.title('Learning Rate Schedule')
                plt.grid(True, alpha=0.3)
                plt.show()
            else:
                print("Could not find learning rate column in metrics.csv")

        else:
            print(f"Metrics file not found at {metrics_path}. Cannot plot results.")
            print("Ensure training ran and generated logs.")
            # The recorded run logged to a tensorboard/ directory, which has
            # no metrics.csv; tell the reader what this cell actually needs.
            print("Note: this cell only reads CSV logs (metrics.csv); other log formats are not parsed here.")

    except Exception as e:
        # Plotting is optional, so report the problem instead of stopping.
        print(f"Could not plot metrics: {e}")
elif params['cv_folds'] > 0:
    print("Plotting is skipped for cross-validation runs in this section.")
    print("Consider plotting results for each fold or aggregated results after CV finishes.")
else:
    print("Trainer object not found or training hasn't run. Cannot plot results.")

Metrics file not found at ./results/reaction_model/tensorboard/version_4/metrics.csv. Cannot plot results.
Ensure training ran and generated logs.


## 9. Test Model Performance (Optional)

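*This evaluates the trained model on the held-out test set using the best saved checkpoint (`ckpt_path='best'`). In the recorded run, Section 7 already printed test metrics because `test_dataset` was passed to `trainer.fit`, and the numbers below match them; run this cell when you want to re-evaluate without retraining.*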

In [9]:
# Evaluate the best checkpoint on the held-out test split.
# Only runs for a single (non cross-validation) training run with a
# non-empty test split.
_can_evaluate = (
    'trainer' in locals()
    and hasattr(trainer, 'trainer')
    and 'dataset' in locals()
    and dataset.test_data
    and len(dataset.test_data) > 0
    and params['cv_folds'] == 0
)

if _can_evaluate:
    try:
        from torch_geometric.loader import DataLoader

        # Attributes that need their own batch vector (specific to this data
        # structure): atomic numbers (z) and positions (pos) for each of the
        # three states, limited to those the samples actually carry.
        sample_data = dataset.test_data[0]
        _batch_sensitive = [f"{prefix}{state}" for state in range(3) for prefix in ('z', 'pos')]
        follow_batch = [name for name in _batch_sensitive if hasattr(sample_data, name)]
        # Add other batch-sensitive attributes if needed

        # Fall back to the training batch size when no eval size is given.
        if params['eval_batch_size'] is not None:
            eval_batch_size = params['eval_batch_size']
        else:
            eval_batch_size = params['batch_size']

        test_loader = DataLoader(
            dataset.test_data,
            batch_size=eval_batch_size,
            shuffle=False,
            num_workers=params['num_workers'],
            follow_batch=follow_batch,
        )

        print("\nEvaluating model on the test set...")
        # ckpt_path='best' asks the trainer for the best saved checkpoint;
        # pass a path to a specific .ckpt file to evaluate that one instead.
        test_results = trainer.trainer.test(dataloaders=test_loader, ckpt_path='best')

        print("\nTest Set Evaluation Results:")
        if not test_results:
            print("No test results returned.")
        else:
            # Results come back as a list; the first entry holds the metrics.
            for key, value in test_results[0].items():
                print(f"  {key}: {value:.4f}")

    except FileNotFoundError:
        print("Could not evaluate on test set: Best checkpoint not found. Ensure training ran and saved a model.")
    except Exception as e:
        # Evaluation is optional, so report the problem instead of stopping.
        print(f"Could not evaluate on test set: {e}")
elif params['cv_folds'] > 0:
    print("\nTest set evaluation is typically done within or after the cross-validation loop (Section 10).")
elif 'dataset' in locals() and (dataset.test_data is None or len(dataset.test_data) == 0):
    print("\nNo test data available for evaluation.")
else:
    print("\nTrainer not available or not configured for single run. Skipping test set evaluation.")

Restoring states from the checkpoint path at /root/autodl-tmp/deepooooo19/results/reaction_model/checkpoints/best-epoch=0000-val_total_loss=0.4326.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /root/autodl-tmp/deepooooo19/results/reaction_model/checkpoints/best-epoch=0000-val_total_loss=0.4326.ckpt



Evaluating model on the test set...


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      Test Avg MAE           4.679178714752197
     Test Avg MAX_AE         17.82347869873047
   Test Avg MEDIAN_AE        3.879878282546997
      Test Avg MPAE          52.42013168334961
       Test Avg R2           0.509222149848938
      Test Avg RMSE          6.083659648895264
      Test MAE DrG          4.0490851402282715
     Test MAE G(TS)          5.309272289276123
     Test MAX_AE DrG        19.114639282226562
    Test MAX_AE G(TS)       16.532320022583008
   Test MEDIAN_AE DrG        2.354483127593994
  Test MEDIAN_AE G(TS)         5.4052734375
      Test MPAE DrG          88.11637115478516
     Test MPAE G(TS)        16.723892211914062
       Test R2 DrG          0.7958351373672485
      Test R

## 10. Cross-Validation Training Example (Optional)

If you want to perform cross-validation, use the following code:

**Important:**
* Make sure `params['cv_folds']` was set to a value greater than 0 in Section 2.
* The `ReactionDataset` must be initialized with the same `cv_folds` value.
* This cell will train the model multiple times (once for each fold), which can take a significant amount of time.