# Prediction Demo for DeepReaction

This notebook demonstrates how to use a trained molecular reaction prediction model to make predictions using the DeepReaction framework.

## 1. Import Required Libraries

In [1]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Import from deepreaction package
from deepreaction.config.config import Config
from deepreaction.data.load_Reaction import load_reaction_for_inference
from torch_geometric.loader import DataLoader
from deepreaction.core.predictor import ReactionPredictor

## 2. Define Prediction Parameters

All parameters are defined in a single dictionary for simplicity.

In [2]:
# Single source of truth for every setting used in this notebook.
# Keyword form is used because every key is a valid Python identifier;
# the resulting dictionary has the same keys, values and insertion order.
params = dict(
    # --- Dataset parameters ---
    dataset_root='./dataset/DATASET_DA_F',  # Adjust path if needed
    dataset_csv='./dataset/DATASET_DA_F/dataset_xtb_final.csv',  # Adjust path if needed
    target_fields=['G(TS)', 'DrG'],
    input_features=['G(TS)_xtb', 'DrG_xtb'],
    file_patterns=['*_reactant.xyz', '*_ts.xyz', '*_product.xyz'],
    id_field='ID',
    dir_field='R_dir',
    reaction_field='reaction',
    random_seed=42234,
    inference_mode=True,  # Important: set to True for prediction

    # --- Prediction parameters ---
    checkpoint_path='./results/reaction_model/checkpoints/best-epoch=0002-val_total_loss=0.1359-v1.ckpt',  # Path to trained model
    output_dir='./predictions',  # Output directory for prediction results
    batch_size=16,
    use_scaler=True,  # Should match what was used during training
    num_workers=4,
)

## 3. Set Up GPU and Output Directory

In [3]:
# Choose the compute device once, then derive the log message and the
# `gpu` flag stored in `params` from that single decision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
params['gpu'] = device.type == "cuda"

# Ensure the folder that will receive prediction artefacts exists
# (exist_ok=True makes re-running this cell harmless).
os.makedirs(params['output_dir'], exist_ok=True)
print(f"Output directory created/exists: {params['output_dir']}")

Using GPU: NVIDIA GeForce RTX 4090 D
Output directory created/exists: ./predictions


## 4. Load Data Directly for Inference

In [4]:
# A Config object is built from the same parameters, mainly for record-keeping
config = Config.from_params(params)
print("Configuration created successfully")

# Read the reactions straight through load_reaction_for_inference rather than
# going via ReactionDataset. The two path arguments have their own keyword
# names; the remaining keywords match their `params` keys and are forwarded by key.
print("Loading data directly for inference...")
inference_data = load_reaction_for_inference(
    root=params['dataset_root'],
    dataset_csv=params['dataset_csv'],
    **{
        name: params[name]
        for name in (
            'random_seed',
            'file_patterns',
            'input_features',
            'id_field',
            'dir_field',
            'reaction_field',
        )
    },
)

print(f"Loaded {len(inference_data)} samples for inference")

# Wrap the samples in a loader for prediction; shuffle stays off so batches
# are produced in dataset order.
follow_batch = [f"{prefix}{idx}" for prefix in ('z', 'pos') for idx in range(3)]
inference_loader = DataLoader(
    inference_data,
    follow_batch=follow_batch,
    batch_size=params['batch_size'],
    num_workers=params['num_workers'],
    shuffle=False,
)

print(f"Created data loader with {len(inference_loader)} batches")

Configuration created successfully
Loading data directly for inference...
Error checking saved data: 'NoneType' object is not subscriptable
Inference mode: Using dummy target field
Using target fields: ['target']
Using input features: ['G(TS)_xtb', 'DrG_xtb']
Using file patterns: ['*_reactant.xyz', '*_ts.xyz', '*_product.xyz']


Processing reactions:  71%|███████   | 1126/1582 [00:00<00:00, 1873.81it/s]



Processing reactions: 100%|██████████| 1582/1582 [00:00<00:00, 1851.15it/s]


Processed 1580 reactions, skipped 2 reactions
Saved metadata to dataset/DATASET_DA_F/processed/metadata.json
Processed 1580 reactions, saved to dataset/DATASET_DA_F/processed/data_d9139bd83f9f.pt
Loaded 1580 samples for inference
Created data loader with 99 batches


## 5. Verify Checkpoint Exists

In [5]:
# Check that the model checkpoint file exists before trying to load it.
#
# The checkpoint file name configured in `params` embeds the epoch and the
# validation loss, so it has to be updated after every training run. When the
# configured file is missing, the error lists the `.ckpt` files that are
# actually present in the same directory so the right name can be copied
# into `params['checkpoint_path']`.
#
# Raises:
#     FileNotFoundError: if `params['checkpoint_path']` does not exist.
if not os.path.exists(params['checkpoint_path']):
    checkpoint_dir = Path(params['checkpoint_path']).parent
    if checkpoint_dir.is_dir():
        available_checkpoints = sorted(entry.name for entry in checkpoint_dir.glob('*.ckpt'))
        if available_checkpoints:
            checkpoint_hint = (
                f"\nAvailable checkpoints in {checkpoint_dir}: "
                + ", ".join(available_checkpoints)
            )
        else:
            checkpoint_hint = f"\nNo .ckpt files found in {checkpoint_dir}"
    else:
        checkpoint_hint = f"\nCheckpoint directory does not exist: {checkpoint_dir}"
    # Same exception type and message prefix as before; only the hint is appended
    raise FileNotFoundError(f"Checkpoint not found: {params['checkpoint_path']}{checkpoint_hint}")
else:
    print(f"Found checkpoint: {params['checkpoint_path']}")

FileNotFoundError: Checkpoint not found: ./results/reaction_model/checkpoints/best-epoch=0002-val_total_loss=0.1359-v1.ckpt

## 6. Initialize Predictor and Make Predictions

In [None]:
# Build the predictor from the trained checkpoint. Each constructor keyword
# has the same name as its entry in `params`, so the values are forwarded by key.
predictor = ReactionPredictor(
    **{
        option: params[option]
        for option in (
            'checkpoint_path',
            'output_dir',
            'batch_size',
            'gpu',
            'num_workers',
            'use_scaler',
        )
    }
)
print("Predictor initialized successfully")

In [None]:
# Predict over the inference loader; the results are returned and the
# target CSV path is handed to the predictor.
predictions_csv_path = os.path.join(params['output_dir'], 'predictions.csv')

print("Starting prediction...")
predictions_df = predictor.predict(
    csv_output_path=predictions_csv_path,
    data_loader=inference_loader,
)
print("Prediction completed successfully")
print(f"Predictions saved to {predictions_csv_path}")

## 7. Analyze Prediction Results

In [None]:
# Show the first rows of the prediction table
print("\nPreview of predictions:")
print(predictions_df.head())

# Summarise every predicted target. A target without a "<target>_predicted"
# column is skipped. Statistics are computed on the underlying NumPy array,
# so `std` is NumPy's default (population) standard deviation.
target_stats = {}
for target in params['target_fields']:
    pred_col = f"{target}_predicted"
    if pred_col not in predictions_df.columns:
        continue

    pred_values = predictions_df[pred_col].values
    stats = dict(
        min=pred_values.min(),
        max=pred_values.max(),
        mean=pred_values.mean(),
        std=pred_values.std(),
    )
    target_stats[target] = stats

    print(f"\nStatistics for {pred_col}:")
    for label, stat_name in (("Min", 'min'), ("Max", 'max'), ("Mean", 'mean'), ("Std", 'std')):
        print(f"{label}: {stats[stat_name]:.4f}")

## 8. Visualize Predictions

In [None]:
# Draw one histogram panel per target, side by side, with the mean marked.
# Targets that have no prediction column keep their slot but stay empty.
plt.figure(figsize=(12, 5))
n_panels = len(params['target_fields'])

for panel_number, target in enumerate(params['target_fields'], start=1):
    pred_col = f"{target}_predicted"
    if pred_col not in predictions_df.columns:
        continue

    ax = plt.subplot(1, n_panels, panel_number)
    ax.hist(predictions_df[pred_col], bins=20, alpha=0.7)
    ax.set_title(f"{target} Predictions")
    ax.set_xlabel("Predicted Value")
    ax.set_ylabel("Frequency")

    # Dashed red line at the mean computed in the statistics cell
    mean_value = target_stats[target]['mean']
    ax.axvline(x=mean_value, color='r', linestyle='--', label=f"Mean: {mean_value:.2f}")
    ax.legend()

histogram_path = os.path.join(params['output_dir'], 'prediction_histogram.png')
plt.tight_layout()
plt.savefig(histogram_path)
print(f"Saved prediction histogram to {histogram_path}")
plt.show()

In [None]:
# Create correlation plots between input features and predictions.
#
# For every (target, input feature) pair this draws a scatter plot of the
# feature against the predicted value, a least-squares trend line, and the
# Pearson correlation coefficient. The whole cell is best-effort: any failure
# is reported with a message instead of stopping the notebook.
if params['input_features']:
    try:
        # Load original CSV to get input features
        original_df = pd.read_csv(params['dataset_csv'])

        # Left merge keeps every prediction; rows without a matching ID in the
        # CSV get NaN for the feature columns.
        merged_df = pd.merge(
            predictions_df,
            original_df[[params['id_field']] + params['input_features']],
            on=params['id_field'],
            how='left'
        )

        if not merged_df.empty:
            plt.figure(figsize=(12, 10))
            plot_idx = 1

            for target in params['target_fields']:
                pred_col = f"{target}_predicted"

                for feature in params['input_features']:
                    if pred_col not in merged_df.columns or feature not in merged_df.columns:
                        continue

                    # np.polyfit and np.corrcoef cannot cope with NaN, so only
                    # rows where both the feature and the prediction are
                    # present are used for this panel.
                    pairs = merged_df[[feature, pred_col]].dropna()
                    if pairs.empty:
                        print(f"No complete data for {target} vs {feature}; skipping plot")
                        continue
                    x_values = pairs[feature].to_numpy(dtype=float)
                    y_values = pairs[pred_col].to_numpy(dtype=float)

                    plt.subplot(len(params['target_fields']), len(params['input_features']), plot_idx)
                    plt.scatter(x_values, y_values, alpha=0.5)
                    plt.title(f"{target} vs {feature}")
                    plt.xlabel(feature)
                    plt.ylabel(f"{target} (Predicted)")

                    # A straight-line fit needs at least two distinct x values;
                    # otherwise the fit is degenerate and is left out.
                    if x_values.max() > x_values.min():
                        z = np.polyfit(x_values, y_values, 1)
                        p = np.poly1d(z)
                        x_line = np.array([x_values.min(), x_values.max()])
                        plt.plot(x_line, p(x_line), "r--")

                        # Correlation is undefined when the predictions are constant
                        if y_values.max() > y_values.min():
                            corr = np.corrcoef(x_values, y_values)[0, 1]
                            plt.annotate(f"Corr: {corr:.2f}", xy=(0.05, 0.95), xycoords='axes fraction')

                    plot_idx += 1

            plt.tight_layout()
            plt.savefig(os.path.join(params['output_dir'], 'feature_correlation.png'))
            print(f"Saved feature correlation plot to {os.path.join(params['output_dir'], 'feature_correlation.png')}")
            plt.show()
    except Exception as e:
        print(f"Error creating correlation plots: {e}")

## 9. Compare with Original Values (If Available)

In [None]:
# Optional: Compare predictions with original values if available.
#
# Merges the predictions with the reference values from the dataset CSV,
# draws an actual-vs-predicted scatter plot per target (annotated with MAE,
# RMSE and R²) and writes the merged table to `comparison.csv`. The cell is
# best-effort: any failure is reported with a message instead of stopping
# the notebook.
try:
    # Imported once here instead of inside the plotting loop
    from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

    # Load original CSV
    original_df = pd.read_csv(params['dataset_csv'])

    # Reference values are optional: only request columns the CSV really has,
    # so a dataset without target columns does not abort the whole cell.
    requested_cols = params['target_fields'] + params['input_features']
    available_cols = [col for col in requested_cols if col in original_df.columns]
    missing_cols = [col for col in requested_cols if col not in original_df.columns]
    if missing_cols:
        print(f"Columns not found in {params['dataset_csv']}: {missing_cols}")

    # Merge with predictions (left merge keeps every prediction)
    comparison_df = pd.merge(
        predictions_df,
        original_df[[params['id_field']] + available_cols],
        on=params['id_field'],
        how='left'
    )

    if not comparison_df.empty:
        # Create scatter plots for predicted vs actual values
        plt.figure(figsize=(15, 6))

        for i, target in enumerate(params['target_fields']):
            pred_col = f"{target}_predicted"

            if pred_col not in comparison_df.columns or target not in comparison_df.columns:
                continue

            # The sklearn metrics reject NaN, so use only rows where both the
            # reference value and the prediction are present.
            pairs = comparison_df[[target, pred_col]].dropna()
            if pairs.empty:
                print(f"No reference values available for {target}; skipping comparison")
                continue
            actual = pairs[target]
            predicted = pairs[pred_col]

            plt.subplot(1, len(params['target_fields']), i+1)

            # Get min and max for setting plot limits
            min_val = min(actual.min(), predicted.min())
            max_val = max(actual.max(), predicted.max())

            # Plot points
            plt.scatter(actual, predicted, alpha=0.5)

            # Plot diagonal line (perfect prediction)
            plt.plot([min_val, max_val], [min_val, max_val], 'r--')

            # Calculate metrics. RMSE is the square root of the MSE; the
            # `squared=False` shortcut was deprecated in scikit-learn 1.4 and
            # removed in 1.6, so it is not used here.
            mae = mean_absolute_error(actual, predicted)
            rmse = np.sqrt(mean_squared_error(actual, predicted))
            r2 = r2_score(actual, predicted)

            plt.title(f"{target} (Actual vs Predicted)\nMAE: {mae:.2f}, RMSE: {rmse:.2f}, R²: {r2:.2f}")
            plt.xlabel(f"Actual {target}")
            plt.ylabel(f"Predicted {target}")

            # Add grid for better readability
            plt.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(os.path.join(params['output_dir'], 'actual_vs_predicted.png'))
        print(f"Saved actual vs predicted plot to {os.path.join(params['output_dir'], 'actual_vs_predicted.png')}")
        plt.show()

        # Save comparison to CSV (all merged rows, including incomplete ones)
        comparison_csv = os.path.join(params['output_dir'], 'comparison.csv')
        comparison_df.to_csv(comparison_csv, index=False)
        print(f"Saved comparison data to {comparison_csv}")
except Exception as e:
    print(f"Error creating comparison plots: {e}")