# DeepReaction Model Prediction

This notebook uses a pre-trained DeepReaction model checkpoint to make predictions on a dataset specified by a CSV file and corresponding XYZ files.

## 1. Import Required Libraries

In [17]:
import os
import sys
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from torch_geometric.loader import DataLoader


from deepreaction.data.PygReaction import ReactionXYZDataset 
from deepreaction.module.pl_wrap import Estimator


## 2. Configuration Parameters

Modify the parameters below to match your dataset, model checkpoint, and desired output locations.

In [18]:

dataset_root = './dataset/DATASET_DA_F' 
dataset_csv = './dataset/DATASET_DA_F/dataset_xtb_final.csv'

input_features = ['G(TS)_xtb', 'DrG_xtb'] # Example: ['G(TS)_xtb', 'DrG_xtb'] or [] 
file_patterns = ['*_reactant.xyz', '*_ts.xyz', '*_product.xyz'] # Patterns to find geometry files
id_field = 'ID'            # Column name in CSV for unique identifier
dir_field = 'R_dir'       # Column name in CSV for subdirectory containing XYZ files
reaction_field = 'reaction' # Column name in CSV for reaction SMILES/identifier (optional, for output)

checkpoint_path = './results/reaction_model/checkpoints/best-epoch=0000-val_total_loss=0.4343.ckpt' 

# Output parameters (Adjust paths as needed)
output_csv = './predictions.csv' # Path to save the predictions in CSV format
output_dir = './predictions'   # Directory to save prediction outputs (e.g., numpy arrays)

# Inference parameters
batch_size = 32          # Number of samples per batch for prediction
use_cuda = True          # Set to True to use GPU if available, False to force CPU
gpu_id = 0               # GPU ID to use if use_cuda is True and GPU is available
num_workers = 4          # Number of workers for data loading (set to 0 on Windows if issues arise)

# --- End Configuration ---

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory set to: {output_dir}")

# Handle case where input_features might be None
if input_features is None:
    input_features = []

Output directory set to: ./predictions


## 3. Setup Device (GPU/CPU)

In [19]:
if use_cuda and torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    device = torch.device("cpu")
    print("Using CPU")
    use_cuda = False # Ensure flag reflects actual device used

Using GPU: NVIDIA TITAN Xp


## 4. Load Model and Extract Target Fields

Load the trained model from the checkpoint. The target field names and scaler information (if saved) will be extracted from the model.

In [20]:
# Load the model from the checkpoint
print(f"Loading model from checkpoint: {checkpoint_path}")
model = None
target_fields = None

if not os.path.exists(checkpoint_path):
    print(f"Error: Checkpoint file not found at '{checkpoint_path}'")
    print("Please verify the 'checkpoint_path' in the configuration cell.")
else:
    try:
        # Ensure the Estimator class is available from imports
        # Use map_location to ensure the model loads correctly whether on CPU or GPU
        model = Estimator.load_from_checkpoint(checkpoint_path, map_location=device)
        model = model.to(device) # Move model to the selected device
        model.eval() # Set model to evaluation mode (important!)

        # Extract target field names from the loaded model
        if hasattr(model, 'target_field_names'):
            target_fields = model.target_field_names
            print(f"Successfully loaded model. Using target fields from model: {target_fields}")
        else:
             print("Warning: Could not automatically determine target fields from the model.")
             # Manually define target_fields here if necessary, e.g.:
             # target_fields = ['G(TS)', 'DrG'] 
             # Make sure this matches the order of outputs from your model!
             raise ValueError("target_field_names attribute not found in the loaded model.")
            
        # Check for scaler information (used for inverse transforming predictions)
        if hasattr(model, 'scaler') and model.scaler is not None:
             print(f"Model contains scaler information for {len(model.scaler)} targets.")
        else:
             print("Warning: Model does not contain scaler information. Predictions will be in scaled units.")

    except Exception as e:
        print(f"Error loading model from checkpoint: {e}")
        print("Ensure the checkpoint file is valid and compatible with the current deepreaction code.")
        model = None # Ensure model is None if loading failed

Loading model from checkpoint: ./results/reaction_model/checkpoints/best-epoch=0000-val_total_loss=0.4343.ckpt
Successfully loaded model. Using target fields from model: ['G(TS)', 'DrG']
Model contains scaler information for 2 targets.


## 5. Load Dataset for Inference

Load the dataset using `ReactionXYZDataset`. This class reads the CSV and finds the corresponding XYZ files based on the provided patterns and directory structure.

In [21]:
dataset = None
if model is not None and target_fields is not None:
    # Check if dataset files exist
    if not os.path.exists(dataset_root):
        print(f"Error: Dataset root directory not found at '{dataset_root}'")
    elif not os.path.exists(dataset_csv):
        print(f"Error: Dataset CSV file not found at '{dataset_csv}'")
    else:
        print(f"Loading dataset from {dataset_root} using CSV {dataset_csv}")
        try:
            # Ensure the ReactionXYZDataset class is available from imports
            dataset = ReactionXYZDataset(
                root=dataset_root,
                csv_file=dataset_csv,
                target_fields=target_fields, # Use fields extracted from the model
                file_patterns=file_patterns,
                input_features=input_features, # Use features defined in config
                id_field=id_field,
                dir_field=dir_field,
                reaction_field=reaction_field,
                inference_mode=True # Crucial for prediction datasets (skips train/val/test split)
                # Add other relevant parameters if your dataset class requires them
            )
            print(f"Dataset loaded successfully with {len(dataset)} samples")
            
            # Optional: Display a sample from the dataset
            if len(dataset) > 0:
                print("\nSample data object from dataset:")
                print(dataset[0])
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print("Check dataset paths, CSV format, and file structure.")
            dataset = None # Ensure dataset is None if loading failed
else:
    print("Skipping dataset loading because the model could not be loaded or target fields are missing.")

Loading dataset from ./dataset/DATASET_DA_F using CSV ./dataset/DATASET_DA_F/dataset_xtb_final.csv
Error checking saved data: 'NoneType' object is not subscriptable
Using target fields: ['G(TS)', 'DrG']
Using input features: ['G(TS)_xtb', 'DrG_xtb']
Using file patterns: ['*_reactant.xyz', '*_ts.xyz', '*_product.xyz']


Processing reactions:  71%|███████   | 1117/1582 [00:01<00:00, 1190.76it/s]



Processing reactions: 100%|██████████| 1582/1582 [00:01<00:00, 1127.65it/s]


Processed 1580 reactions, skipped 2 reactions
Saved metadata to dataset/DATASET_DA_F/processed/metadata.json
Processed 1580 reactions, saved to dataset/DATASET_DA_F/processed/data_c08fa62613d2.pt
Dataset loaded successfully with 1580 samples

Sample data object from dataset:
Data(y=[1, 2], z0=[16], z1=[16], z2=[16], pos0=[16, 3], pos1=[16, 3], pos2=[16, 3], xtb_features=[1, 2], feature_names=[2], reaction_id='ID63623', id='reaction_R0', reaction='[C:1](=[C:2]([C:3](=[C:4]([H:11])[H:12])[H:10])[H:9])([H:7])[H:8].[C:5](=[C:6]([H:15])[H:16])([H:13])[H:14]>>[C:1]1([H:7])([H:8])[C:2]([H:9])=[C:3]([H:10])[C:4]([H:11])([H:12])[C:5]([H:13])([H:14])[C:6]1([H:15])[H:16]', num_nodes=16)


## 6. Create DataLoader

Prepare the data loader for batching during inference.

In [22]:
data_loader = None
if dataset is not None and len(dataset) > 0:
    # Define which attributes require special batching in PyG (usually atom counts and positions)
    # Check a sample to see which attributes exist (e.g., pos0, z0, pos1, z1, pos2, z2)
    follow_batch = []
    sample_data = dataset[0]
    for key in ['z0', 'z1', 'z2', 'pos0', 'pos1', 'pos2']:
        if hasattr(sample_data, key):
            follow_batch.append(key)
    print(f"Using follow_batch attributes: {follow_batch}")

    # Create the DataLoader
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False, # Important: Do not shuffle data during prediction
        num_workers=num_workers,
        follow_batch=follow_batch # Handles batching for varying numbers of atoms
    )
    print(f"DataLoader created with batch size {batch_size}.")
elif dataset is not None and len(dataset) == 0:
    print("Dataset is empty. Cannot create DataLoader.")
else:
    print("Skipping DataLoader creation because the dataset was not loaded successfully.")

Using follow_batch attributes: ['z0', 'z1', 'z2', 'pos0', 'pos1', 'pos2']
DataLoader created with batch size 32.


## 7. Run Inference

Iterate through the data loader, pass batches to the model, and collect the predictions.

In [23]:
all_predictions_scaled = [] # Store predictions (likely scaled)
all_batch_metadata = []  # Store corresponding metadata (IDs, etc.)

if model is not None and data_loader is not None:
    print("Running inference...")
    with torch.no_grad(): # Disable gradient calculations for efficiency
        batch_count = 0
        for batch in data_loader:
            batch = batch.to(device) # Move batch data to the target device
            
            # Extract necessary inputs for the model's forward method
            # These names must match the attributes in your PyG Data objects
            pos0, pos1, pos2 = batch.pos0, batch.pos1, batch.pos2
            z0, z1, z2 = batch.z0, batch.z1, batch.z2
            batch_mapping = batch.batch # PyG batch mapping tensor
            
            # Handle optional xtb_features (or other input features)
            xtb_features = getattr(batch, 'xtb_features', None)
            if xtb_features is not None:
                xtb_features = xtb_features.to(device)
            
            # Perform the forward pass
            # The output format depends on your Estimator's forward method
            # Typically: embeddings, atomwise_results, predictions = model(...)
            try:
                 _, _, predictions = model(pos0, pos1, pos2, z0, z1, z2, batch_mapping, xtb_features)
            except TypeError as e:
                 print(f"\nError during model forward pass: {e}")
                 print("Check if the arguments passed match the model's forward signature.")
                 print("Expected arguments based on common patterns: pos0, pos1, pos2, z0, z1, z2, batch_mapping, xtb_features (optional)")
                 print("Stopping inference.")
                 all_predictions_scaled = [] # Clear partial results on error
                 all_batch_metadata = []
                 break # Exit the loop
            except Exception as e:
                 print(f"\nAn unexpected error occurred during model forward pass: {e}")
                 print("Stopping inference.")
                 all_predictions_scaled = []
                 all_batch_metadata = []
                 break
                 
            # Store predictions (move to CPU and convert to numpy)
            all_predictions_scaled.append(predictions.cpu().numpy())
            
            # Store metadata (IDs, reaction strings, etc.) from the batch
            batch_meta = {}
            for attr in ['reaction_id', 'id', 'reaction']: # Add other relevant fields
                if hasattr(batch, attr):
                    value = getattr(batch, attr)
                    # Ensure metadata is easily serializable (e.g., list of strings/numbers)
                    if isinstance(value, torch.Tensor):
                        batch_meta[attr] = value.cpu().tolist() 
                    elif isinstance(value, list):
                        batch_meta[attr] = value
                    else: # Handle potential single values if batch size is 1
                         batch_meta[attr] = [value] * len(predictions)
            all_batch_metadata.append(batch_meta)
            
            batch_count += 1
            if batch_count % 50 == 0: # Print progress periodically
                 print(f"  Processed {batch_count * batch_size} / {len(dataset)} samples...")

    if all_predictions_scaled: # Check if inference ran without breaking early
        print(f"Inference completed. Processed {len(all_predictions_scaled)} batches.")
        # Concatenate predictions from all batches
        predictions_scaled_np = np.vstack(all_predictions_scaled)
        print(f"Shape of concatenated scaled predictions: {predictions_scaled_np.shape}")
    else:
        print("Inference did not produce results (possibly due to errors or empty dataset).")
        predictions_scaled_np = np.array([]) # Ensure it's an empty array

else:
    print("Skipping inference because model or data loader is not available.")
    predictions_scaled_np = np.array([])

Running inference...
  Processed 1600 / 1580 samples...
Inference completed. Processed 50 batches.
Shape of concatenated scaled predictions: (1580, 2)


## 8. Process and Save Predictions

Combine the predictions with identifiers, apply inverse scaling if the model has scaler information, and save the results to a CSV file and numpy arrays.

In [24]:
results_df = pd.DataFrame()

if predictions_scaled_np.size > 0 and model is not None and target_fields is not None:
    print("Processing predictions...")
    
    # Apply inverse scaling if scaler is available
    predictions_final = {} # Dictionary to hold final (potentially unscaled) predictions
    has_scaler = hasattr(model, 'scaler') and model.scaler is not None and len(model.scaler) == len(target_fields)

    for i, target_name in enumerate(target_fields):
        if i < predictions_scaled_np.shape[1]:
            target_preds_scaled = predictions_scaled_np[:, i].reshape(-1, 1)
            if has_scaler:
                try:
                    # Assume model.scaler is a list/tuple of scaler objects (e.g., StandardScaler)
                    target_preds_unscaled = model.scaler[i].inverse_transform(target_preds_scaled)
                    predictions_final[target_name] = target_preds_unscaled.flatten()
                    print(f"  Applied inverse scaling for target: '{target_name}'")
                except Exception as e:
                    print(f"  Warning: Could not inverse scale target '{target_name}'. Error: {e}. Using scaled values.")
                    predictions_final[target_name] = target_preds_scaled.flatten() # Use scaled if error
            else:
                predictions_final[target_name] = target_preds_scaled.flatten() # Use scaled if no scaler
        else:
            print(f"  Warning: Prediction array has fewer columns ({predictions_scaled_np.shape[1]}) than target fields ({len(target_fields)}). Skipping target '{target_name}'.")

    if not has_scaler:
         print("  No scaler found or scaler mismatch. Final predictions are the scaled model outputs.")
            
    # Create output DataFrame
    # Start with metadata (IDs, reaction strings, etc.)
    df_metadata = {}
    if all_batch_metadata:
        # Concatenate metadata from all batches
        first_batch_keys = all_batch_metadata[0].keys()
        for key in first_batch_keys:
             df_metadata[key] = [item for batch_meta in all_batch_metadata for item in batch_meta.get(key, [])]
    
    try:
        results_df = pd.DataFrame(df_metadata)
        
        # Add predicted values (potentially unscaled)
        for target_name, preds in predictions_final.items():
             # Ensure length matches DataFrame length
             if len(preds) == len(results_df):
                  results_df[f'{target_name}_predicted'] = preds
             else:
                  print(f"  Warning: Length mismatch for predicted target '{target_name}' ({len(preds)}) vs metadata ({len(results_df)}). Skipping column.")
                  # Optionally pad or handle mismatch differently
                  # results_df[f'{target_name}_predicted'] = [np.nan] * len(results_df) 

        # Reorder columns for better readability (optional)
        id_cols = [col for col in ['reaction_id', 'id', 'reaction'] if col in results_df.columns]
        pred_cols = sorted([col for col in results_df.columns if col.endswith('_predicted')])
        other_cols = sorted([col for col in results_df.columns if col not in id_cols and col not in pred_cols])
        results_df = results_df[id_cols + pred_cols + other_cols]
        
        # Save predictions
        results_df.to_csv(output_csv, index=False)
        print(f"\nPredictions successfully saved to: {output_csv}")
        
        # Save raw scaled predictions as numpy array
        scaled_npy_path = os.path.join(output_dir, 'predictions_scaled.npy')
        np.save(scaled_npy_path, predictions_scaled_np)
        print(f"Raw scaled predictions saved to: {scaled_npy_path}")

        # Save final (potentially unscaled) predictions as numpy array
        if predictions_final:
            try:
                 # Ensure all arrays have the same length before stacking
                 valid_preds = {k: v for k, v in predictions_final.items() if len(v) == predictions_scaled_np.shape[0]}
                 if len(valid_preds) == len(target_fields):
                    final_preds_array = np.stack([valid_preds[key] for key in target_fields], axis=1)
                    final_npy_path = os.path.join(output_dir, 'predictions_final.npy')
                    np.save(final_npy_path, final_preds_array)
                    print(f"Final (unscaled) predictions saved to: {final_npy_path}")
                 else:
                    print("Could not save final predictions numpy array due to inconsistent lengths or missing targets.")
            except Exception as e:
                 print(f"Error saving final predictions numpy array: {e}")

        print(f"\nTotal number of predictions generated: {len(results_df)}")
        if len(results_df) > 0:
            print("\nSample predictions (first 5 rows):")
            print(results_df.head())

    except Exception as e:
        print(f"Error creating or saving DataFrame: {e}")
        print("Please check the collected metadata and predictions.")

elif predictions_scaled_np.size == 0:
    print("No predictions were generated, skipping saving.")
else:
    print("Model or target fields missing, skipping processing and saving.")


Processing predictions...
  Applied inverse scaling for target: 'G(TS)'
  Applied inverse scaling for target: 'DrG'

Predictions successfully saved to: ./predictions.csv
Raw scaled predictions saved to: ./predictions/predictions_scaled.npy
Final (unscaled) predictions saved to: ./predictions/predictions_final.npy

Total number of predictions generated: 1580

Sample predictions (first 5 rows):
  reaction_id               id  \
0     ID63623      reaction_R0   
1     ID86062      reaction_R1   
2     ID52093     reaction_R10   
3     ID31786    reaction_R100   
4     ID30289  reaction_R10166   

                                            reaction  DrG_predicted  \
0  [C:1](=[C:2]([C:3](=[C:4]([H:11])[H:12])[H:10]...    -140.557632   
1  [C:6](=[C:7]([H:14])[H:15])([H:12])[H:13].[c:1...     -17.935015   
2  [C:1]([c:2]1[c:3]([H:12])[c:4]([H:13])[c:5]([H...       7.588223   
3  [N:6](/[C:7](=[C:8](\[N:9]([H:20])[H:21])[H:19...     -67.345047   
4  [C:1]([C:2](=[C:3]([C:4](=[C:5]([H:24])[H