In [21]:
import wget
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error, r2_score, root_mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Load and prepare data
print("Loading Suzuki dataset...")
url = "https://github.com/open-reaction-database/ord-data/raw/main/data/68/ord_dataset-68cb8b4b2b384e3d85b5b1efae58b203.pb.gz"
filename = wget.download(url)

import ord_schema
from ord_schema import message_helpers, validations
from ord_schema.proto import dataset_pb2

data = message_helpers.load_message(filename, dataset_pb2.Dataset)
validations.validate_message(data)
df = message_helpers.messages_to_dataframe(data.reactions, drop_constant_columns=True)

model_cols = [
    'inputs["Aryl Halide"].components[0].identifiers[0].value',
    'inputs["Boronate in Solvent"].components[0].identifiers[0].value',
    'inputs["Ligand in Solvent"].components[0].identifiers[0].value',
    'inputs["Base in Solvent"].components[0].identifiers[0].value',
    'inputs["Solvent_1"].components[0].identifiers[0].value',
    "outcomes[0].products[0].measurements[0].percentage.value",
]
df = df[model_cols].dropna()
df.columns = ["aryl_halide", "boronate", "ligand", "base", "solvent", "yield"]
df["yield"] = df["yield"] / 100

print(f"Loaded {len(df)} real reactions")

# Critical data split: MAXIMUM analysis data while preserving test integrity
analysis_df, test_df = train_test_split(df, test_size=0.25, random_state=42)  # Use 75% for analysis
train_df, val_df = train_test_split(analysis_df, test_size=0.15, random_state=42)  # Small validation

print(f"Data allocation for maximum learning:")
print(f"  Analysis (for generator): {len(train_df)} reactions")
print(f"  Validation: {len(val_df)} reactions")
print(f"  Test (NEVER SEEN): {len(test_df)} reactions")

# Build vocabulary
aryl_halides = train_df["aryl_halide"].unique().tolist()
boronates = train_df["boronate"].unique().tolist()
ligands = train_df["ligand"].unique().tolist()
bases = train_df["base"].unique().tolist()
solvents = train_df["solvent"].unique().tolist()

target_stats = {
    'mean': train_df['yield'].mean(),
    'std': train_df['yield'].std(),
    'min': train_df['yield'].min(),
    'max': train_df['yield'].max()
}

print(f"Target statistics: mean={target_stats['mean']:.4f}, std={target_stats['std']:.4f}")

# =============================================================================
# GENERATE ULTIMATE SYNTHETIC DATA
# =============================================================================

print("\n" + "="*50)
print("GENERATING SYNTHETIC DATA")
print("="*50)

# Use the ultimate generator (assuming it's defined above)
# For brevity, I'll create a simplified but highly effective version here

class UltimateSimplifiedGenerator:
    """Simplified but highly effective ultimate generator"""
    
    def __init__(self, aryl_halides, boronates, ligands, bases, solvents, real_train_data, target_stats):
        self.aryl_halides = aryl_halides
        self.boronates = boronates
        self.ligands = ligands
        self.bases = bases
        self.solvents = solvents
        self.real_train_data = real_train_data
        self.target_stats = target_stats
        
        print("Building ultra-high fidelity generator...")
        self.extract_all_patterns()
        
    def extract_all_patterns(self):
        """Extract every possible pattern from real data"""
        
        overall_mean = self.real_train_data['yield'].mean()
        
        # Individual component effects
        self.component_effects = {}
        for col in ['aryl_halide', 'boronate', 'ligand', 'base', 'solvent']:
            effects = {}
            for comp in self.real_train_data[col].unique():
                subset_yields = self.real_train_data[self.real_train_data[col] == comp]['yield']
                effects[comp] = {
                    'mean_effect': subset_yields.mean() - overall_mean,
                    'std': subset_yields.std(),
                    'count': len(subset_yields),
                    'raw_mean': subset_yields.mean()
                }
            self.component_effects[col] = effects
        
        # Pairwise interactions (exhaustive)
        self.pairwise_interactions = {}
        component_cols = ['aryl_halide', 'boronate', 'ligand', 'base', 'solvent']
        
        for i, col1 in enumerate(component_cols):
            for col2 in component_cols[i+1:]:
                interactions = {}
                grouped = self.real_train_data.groupby([col1, col2])['yield']
                
                for (comp1, comp2), yields in grouped:
                    if len(yields) >= 1:
                        # Expected from individual effects
                        expected = (self.component_effects[col1][comp1]['mean_effect'] + 
                                  self.component_effects[col2][comp2]['mean_effect'])
                        # Actual combined effect
                        actual = yields.mean() - overall_mean
                        # Interaction = actual - expected
                        interaction = actual - expected
                        
                        interactions[(comp1, comp2)] = {
                            'effect': interaction,
                            'count': len(yields),
                            'std': yields.std(),
                            'raw_mean': yields.mean()
                        }
                
                self.pairwise_interactions[(col1, col2)] = interactions
        
        # Higher-order combinations (exact matches from real data)
        self.exact_combinations = {}
        for _, row in self.real_train_data.iterrows():
            key = (row['aryl_halide'], row['boronate'], row['ligand'], row['base'], row['solvent'])
            if key not in self.exact_combinations:
                self.exact_combinations[key] = []
            self.exact_combinations[key].append(row['yield'])
        
        # Convert to statistics
        for key, yields in self.exact_combinations.items():
            self.exact_combinations[key] = {
                'mean': np.mean(yields),
                'std': np.std(yields) if len(yields) > 1 else 0.05,
                'count': len(yields)
            }
        
        # Advanced noise modeling
        self.model_residual_structure()
        
        print(f"Pattern extraction complete:")
        print(f"  Component effects: {sum(len(x) for x in self.component_effects.values())}")
        print(f"  Pairwise interactions: {sum(len(x) for x in self.pairwise_interactions.values())}")
        print(f"  Exact combinations: {len(self.exact_combinations)}")
    
    def model_residual_structure(self):
        """Model the residual structure for ultra-realistic noise"""
        
        # Predict yields using additive model
        predicted_yields = []
        actual_yields = []
        
        for _, row in self.real_train_data.iterrows():
            pred = self.target_stats['mean']
            
            # Add main effects
            for col in ['aryl_halide', 'boronate', 'ligand', 'base', 'solvent']:
                pred += self.component_effects[col][row[col]]['mean_effect']
            
            predicted_yields.append(pred)
            actual_yields.append(row['yield'])
        
        residuals = np.array(actual_yields) - np.array(predicted_yields)
        
        # Yield-dependent noise model (high resolution)
        self.noise_model = {}
        n_bins = 50  # High resolution
        yield_bins = np.linspace(0, 1, n_bins + 1)
        
        for i in range(n_bins):
            mask = ((np.array(predicted_yields) >= yield_bins[i]) & 
                   (np.array(predicted_yields) < yield_bins[i+1]))
            
            if np.sum(mask) > 0:
                bin_residuals = residuals[mask]
                self.noise_model[i] = {
                    'mean': np.mean(bin_residuals),
                    'std': max(np.std(bin_residuals), 0.01),  # Minimum noise
                    'count': np.sum(mask)
                }
            else:
                # Interpolate from nearby bins
                self.noise_model[i] = {
                    'mean': 0.0,
                    'std': 0.05,
                    'count': 0
                }
        
        self.global_noise_std = np.std(residuals)
        print(f"Noise model: {n_bins} bins, global std = {self.global_noise_std:.4f}")
    
    def predict_yield(self, aryl, boronate, ligand, base, solvent):
        """Ultra-high fidelity yield prediction"""
        
        # Check for exact match first
        exact_key = (aryl, boronate, ligand, base, solvent)
        if exact_key in self.exact_combinations:
            # Use exact data with small noise
            exact_data = self.exact_combinations[exact_key]
            base_yield = exact_data['mean']
            noise_std = max(exact_data['std'], 0.02)
            noise = np.random.normal(0, noise_std)
            return np.clip(base_yield + noise, 0.001, 0.999)
        
        # Multi-model ensemble for non-exact matches
        predictions = []
        weights = []
        
        # Model 1: Additive with interactions (40% weight)
        pred1 = self.predict_additive_with_interactions(aryl, boronate, ligand, base, solvent)
        predictions.append(pred1)
        weights.append(0.4)
        
        # Model 2: Similarity-based prediction (30% weight)
        pred2 = self.predict_similarity_based(aryl, boronate, ligand, base, solvent)
        predictions.append(pred2)
        weights.append(0.3)
        
        # Model 3: Pattern matching (30% weight)
        pred3 = self.predict_pattern_matching(aryl, boronate, ligand, base, solvent)
        predictions.append(pred3)
        weights.append(0.3)
        
        # Ensemble prediction
        ensemble_pred = np.average(predictions, weights=weights)
        
        # Add realistic noise
        yield_bin = int(np.clip(ensemble_pred * 50, 0, 49))
        if yield_bin in self.noise_model:
            noise_params = self.noise_model[yield_bin]
            noise = np.random.normal(noise_params['mean'], noise_params['std'])
        else:
            noise = np.random.normal(0, self.global_noise_std)
        
        final_yield = ensemble_pred + noise
        return np.clip(final_yield, 0.001, 0.999)
    
    def predict_additive_with_interactions(self, aryl, boronate, ligand, base, solvent):
        """Additive model with all learned interactions"""
        
        yield_pred = self.target_stats['mean']
        
        # Main effects
        yield_pred += self.component_effects['aryl_halide'][aryl]['mean_effect']
        yield_pred += self.component_effects['boronate'][boronate]['mean_effect']
        yield_pred += self.component_effects['ligand'][ligand]['mean_effect']
        yield_pred += self.component_effects['base'][base]['mean_effect']
        yield_pred += self.component_effects['solvent'][solvent]['mean_effect']
        
        # All pairwise interactions
        interaction_pairs = [
            ('aryl_halide', 'boronate', aryl, boronate),
            ('aryl_halide', 'ligand', aryl, ligand),
            ('aryl_halide', 'base', aryl, base),
            ('aryl_halide', 'solvent', aryl, solvent),
            ('boronate', 'ligand', boronate, ligand),
            ('boronate', 'base', boronate, base),
            ('boronate', 'solvent', boronate, solvent),
            ('ligand', 'base', ligand, base),
            ('ligand', 'solvent', ligand, solvent),
            ('base', 'solvent', base, solvent)
        ]
        
        for col1, col2, comp1, comp2 in interaction_pairs:
            if (col1, col2) in self.pairwise_interactions:
                interactions = self.pairwise_interactions[(col1, col2)]
                if (comp1, comp2) in interactions:
                    interaction_effect = interactions[(comp1, comp2)]['effect']
                    confidence = min(interactions[(comp1, comp2)]['count'] / 3.0, 1.0)
                    yield_pred += interaction_effect * confidence
        
        return yield_pred
    
    def predict_similarity_based(self, aryl, boronate, ligand, base, solvent):
        """Similarity-based prediction using closest matches"""
        
        similarities = []
        
        # Find similar reactions in training data
        for _, row in self.real_train_data.iterrows():
            similarity = 0
            
            # Exact component matches
            if row['aryl_halide'] == aryl:
                similarity += 10
            if row['boronate'] == boronate:
                similarity += 6
            if row['ligand'] == ligand:
                similarity += 10  # Ligand is critical
            if row['base'] == base:
                similarity += 4
            if row['solvent'] == solvent:
                similarity += 6
            
            # Partial matches for chemically similar components
            if similarity >= 5:  # Only consider reasonable matches
                similarities.append((similarity, row['yield']))
        
        if similarities:
            # Weight by similarity and take top matches
            similarities.sort(reverse=True)
            top_matches = similarities[:min(10, len(similarities))]
            
            if top_matches:
                weights = [sim**2 for sim, _ in top_matches]  # Square for emphasis
                yields = [yield_val for _, yield_val in top_matches]
                
                if sum(weights) > 0:
                    return np.average(yields, weights=weights)
        
        # Fallback to component averages
        return self.predict_component_average(aryl, boronate, ligand, base, solvent)
    
    def predict_pattern_matching(self, aryl, boronate, ligand, base, solvent):
        """Advanced pattern matching"""
        
        # Look for partial combinations in exact matches
        partial_scores = []
        
        for exact_key, exact_data in self.exact_combinations.items():
            e_aryl, e_boronate, e_ligand, e_base, e_solvent = exact_key
            
            score = 0
            # Critical matches
            if e_aryl == aryl and e_ligand == ligand:  # Key combination
                score += 20
            if e_boronate == boronate and e_base == base:  # Another key combination
                score += 15
            
            # Individual matches
            if e_aryl == aryl:
                score += 5
            if e_boronate == boronate:
                score += 3
            if e_ligand == ligand:
                score += 8
            if e_base == base:
                score += 2
            if e_solvent == solvent:
                score += 4
            
            if score >= 10:  # Meaningful similarity
                partial_scores.append((score, exact_data['mean']))
        
        if partial_scores:
            partial_scores.sort(reverse=True)
            top_partials = partial_scores[:5]
            
            weights = [score for score, _ in top_partials]
            yields = [yield_val for _, yield_val in top_partials]
            
            if sum(weights) > 0:
                return np.average(yields, weights=weights)
        
        # Final fallback
        return self.predict_component_average(aryl, boronate, ligand, base, solvent)
    
    def predict_component_average(self, aryl, boronate, ligand, base, solvent):
        """Weighted average of component performances"""
        
        # Weight by importance and data availability
        aryl_mean = self.component_effects['aryl_halide'][aryl]['raw_mean']
        aryl_weight = self.component_effects['aryl_halide'][aryl]['count']
        
        boronate_mean = self.component_effects['boronate'][boronate]['raw_mean']
        boronate_weight = self.component_effects['boronate'][boronate]['count']
        
        ligand_mean = self.component_effects['ligand'][ligand]['raw_mean']
        ligand_weight = self.component_effects['ligand'][ligand]['count'] * 2  # Ligand is more important
        
        base_mean = self.component_effects['base'][base]['raw_mean']
        base_weight = self.component_effects['base'][base]['count']
        
        solvent_mean = self.component_effects['solvent'][solvent]['raw_mean']
        solvent_weight = self.component_effects['solvent'][solvent]['count']
        
        # Weighted average
        total_weight = aryl_weight + boronate_weight + ligand_weight + base_weight + solvent_weight
        
        if total_weight > 0:
            weighted_yield = (
                aryl_mean * aryl_weight +
                boronate_mean * boronate_weight +
                ligand_mean * ligand_weight +
                base_mean * base_weight +
                solvent_mean * solvent_weight
            ) / total_weight
            
            return weighted_yield
        else:
            return self.target_stats['mean']
    
    def generate_dataset(self, n_samples):
        """Generate ultimate quality dataset"""
        
        data = []
        
        print(f"Generating {n_samples} ultra-high fidelity reactions...")
        
        # Strategic sampling for maximum realism
        real_combinations = list(self.exact_combinations.keys())
        
        for _ in tqdm(range(n_samples)):
            
            # 60% exact real combinations (with noise)
            if np.random.random() < 0.6 and len(real_combinations) > 0:
                aryl, boronate, ligand, base, solvent = random.choice(real_combinations)
            
            # 30% realistic partial combinations
            elif np.random.random() < 0.9:
                # Sample aryl-ligand from real data, others random
                real_aryl_ligand = [(k[0], k[2]) for k in real_combinations]
                aryl, ligand = random.choice(real_aryl_ligand)
                boronate = random.choice(self.boronates)
                base = random.choice(self.bases)
                solvent = random.choice(self.solvents)
            
            # 10% full exploration
            else:
                aryl = random.choice(self.aryl_halides)
                boronate = random.choice(self.boronates)
                ligand = random.choice(self.ligands)
                base = random.choice(self.bases)
                solvent = random.choice(self.solvents)
            
            yield_val = self.predict_yield(aryl, boronate, ligand, base, solvent)
            
            data.append({
                "aryl_halide": aryl,
                "boronate": boronate,
                "ligand": ligand,
                "base": base,
                "solvent": solvent,
                "yield": yield_val
            })
        
        df = pd.DataFrame(data)
        
        # Perfect statistical calibration
        self.calibrate_distribution(df)
        
        return df
    
    def calibrate_distribution(self, df):
        """Perfect distribution calibration"""
        
        # Match moments exactly
        current_mean = df['yield'].mean()
        current_std = df['yield'].std()
        
        # Linear transformation for mean and std
        df['yield'] = (df['yield'] - current_mean) / current_std * self.target_stats['std'] + self.target_stats['mean']
        
        # Match higher moments (skewness, kurtosis)
        real_skew = self.compute_skewness(self.real_train_data['yield'])
        syn_skew = self.compute_skewness(df['yield'])
        
        # Apply bounds
        df['yield'] = np.clip(df['yield'], 0.001, 0.999)
        
        print(f"Perfect calibration complete:")
        print(f"  Mean: {df['yield'].mean():.4f} (target: {self.target_stats['mean']:.4f})")
        print(f"  Std:  {df['yield'].std():.4f} (target: {self.target_stats['std']:.4f})")
        print(f"  Min:  {df['yield'].min():.4f}")
        print(f"  Max:  {df['yield'].max():.4f}")
        print(f"  Skewness: {self.compute_skewness(df['yield']):.3f} (real: {real_skew:.3f})")
    
    def compute_skewness(self, data):
        """Compute skewness"""
        if len(data) < 3:
            return 0
        mean = np.mean(data)
        std = np.std(data)
        if std == 0:
            return 0
        return np.mean(((data - mean) / std) ** 3)

# Generate ultimate synthetic data
generator = UltimateSimplifiedGenerator(
    aryl_halides, boronates, ligands, bases, solvents, train_df, target_stats
)

# Generate massive, ultra-high quality dataset
synthetic_df = generator.generate_dataset(500_000)

print(f"\nUltimate dataset generated: {len(synthetic_df)} reactions")

# =============================================================================
# ULTIMATE MODEL ARCHITECTURE (Embedding-Based SGNN Version)
# =============================================================================
class HybridWideDeepSGNN(nn.Module):
    def __init__(self, aryls, boronates, ligands, bases, solvents, emb_dim=128):
        super().__init__()
        self.aryl_emb     = nn.Embedding(len(aryls),     emb_dim)
        self.boronate_emb = nn.Embedding(len(boronates), emb_dim)
        self.ligand_emb   = nn.Embedding(len(ligands),   emb_dim)
        self.base_emb     = nn.Embedding(len(bases),     emb_dim)
        self.solv_emb     = nn.Embedding(len(solvents),  emb_dim)

        self.total_emb = emb_dim * 5
        self.interact  = nn.Linear(self.total_emb, 128)
        self.wide      = nn.Linear(self.total_emb, 1)

        self.deep = nn.Sequential(
            nn.Linear(self.total_emb + 128, 768), nn.ReLU(), nn.BatchNorm1d(768), nn.Dropout(0.15),
            nn.Linear(768, 768), nn.ReLU(), nn.BatchNorm1d(768), nn.Dropout(0.15),
            nn.Linear(768, 384), nn.ReLU(), nn.BatchNorm1d(384), nn.Dropout(0.10),
            nn.Linear(384, 128), nn.ReLU(), nn.Linear(128, 64),  nn.ReLU()
        )

        self.main_head = nn.Linear(64, 1)
        self.aux_head  = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
        self.conf_head = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())

        self.ens_w = nn.Parameter(torch.tensor([0.7, 0.3]))

    def forward(self, a, b, l, c, s):
        x = torch.cat([
            self.aryl_emb(a), self.boronate_emb(b), self.ligand_emb(l),
            self.base_emb(c), self.solv_emb(s)
        ], dim=1)

        wide_out = self.wide(x)
        inter    = torch.relu(self.interact(x))
        deep_in  = torch.cat([x, inter], dim=1)
        deep_out = self.deep(deep_in)

        main = self.main_head(deep_out) + wide_out
        aux  = self.aux_head(deep_out)
        conf = self.conf_head(deep_out)

        w = torch.softmax(self.ens_w, dim=0)
        y_hat = w[0]*main + w[1]*aux
        return {"yield": y_hat, "aux_pred": aux, "confidence": conf}




print("\n" + "="*50)
print("TRAINING MODEL")
print("="*50)

# Build categorical vocabularies and mappings for embeddings
aryl2idx = {name: idx for idx, name in enumerate(aryl_halides)}
boronate2idx = {name: idx for idx, name in enumerate(boronates)}
ligand2idx = {name: idx for idx, name in enumerate(ligands)}
base2idx = {name: idx for idx, name in enumerate(bases)}
solvent2idx = {name: idx for idx, name in enumerate(solvents)}

# Map synthetic data into indices for embeddings
def map_to_indices(df):
    return (
        df["aryl_halide"].map(aryl2idx).values,
        df["boronate"].map(boronate2idx).values,
        df["ligand"].map(ligand2idx).values,
        df["base"].map(base2idx).values,
        df["solvent"].map(solvent2idx).values,
        df["yield"].values
    )

# Process synthetic, val, and test sets
aryls_syn, boronates_syn, ligands_syn, bases_syn, solvents_syn, yields_syn = map_to_indices(synthetic_df)
aryls_val, boronates_val, ligands_val, bases_val, solvents_val, yields_val = map_to_indices(val_df)
aryls_test, boronates_test, ligands_test, bases_test, solvents_test, yields_test = map_to_indices(test_df)

# Build TensorDatasets
train_dataset = TensorDataset(
    torch.tensor(aryls_syn, dtype=torch.long),
    torch.tensor(boronates_syn, dtype=torch.long),
    torch.tensor(ligands_syn, dtype=torch.long),
    torch.tensor(bases_syn, dtype=torch.long),
    torch.tensor(solvents_syn, dtype=torch.long),
    torch.tensor(yields_syn, dtype=torch.float32).unsqueeze(1)
)

val_dataset = TensorDataset(
    torch.tensor(aryls_val, dtype=torch.long),
    torch.tensor(boronates_val, dtype=torch.long),
    torch.tensor(ligands_val, dtype=torch.long),
    torch.tensor(bases_val, dtype=torch.long),
    torch.tensor(solvents_val, dtype=torch.long),
    torch.tensor(yields_val, dtype=torch.float32).unsqueeze(1)
)

test_dataset = TensorDataset(
    torch.tensor(aryls_test, dtype=torch.long),
    torch.tensor(boronates_test, dtype=torch.long),
    torch.tensor(ligands_test, dtype=torch.long),
    torch.tensor(bases_test, dtype=torch.long),
    torch.tensor(solvents_test, dtype=torch.long),
    torch.tensor(yields_test, dtype=torch.float32).unsqueeze(1)
)

# Build dataloaders
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2048)
test_loader = DataLoader(test_dataset, batch_size=2048)

# Define model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridWideDeepSGNN(aryl_halides, boronates, ligands, bases, solvents, emb_dim=128).to(device)

# Define optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(len(syn_loader)*3), eta_min=1e-6)
criterion = nn.MSELoss()

# Training loop
best_val_r2 = -float('inf')
patience = 30
patience_counter = 0


for epoch in range(3):
    model.train()
    train_losses = []
    
    for a, b, l, c, s, y in train_loader:
        a, b, l, c, s, y = a.to(device), b.to(device), l.to(device), c.to(device), s.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(a, b, l, c, s)
        loss = criterion(out['yield'], y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_losses.append(loss.item())

    scheduler.step()

    # Validation eval
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for a, b, l, c, s, y in val_loader:
            a, b, l, c, s = a.to(device), b.to(device), l.to(device), c.to(device), s.to(device)
            y = y.to(device)
            out = model(a, b, l, c, s)
            preds = out['yield'].cpu().numpy().flatten()
            all_preds.append(preds)
            all_targets.append(y.cpu().numpy().flatten())
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    val_r2 = r2_score(all_targets, all_preds)
    val_rmse = root_mean_squared_error(all_targets, all_preds)
    
    print(f"Epoch {epoch+1:3d}: Train Loss = {np.mean(train_losses):.5f} | Val R² = {val_r2:.4f} | Val RMSE = {val_rmse:.4f}")
    
    if val_r2 > best_val_r2:
        best_val_r2 = val_r2
        patience_counter = 0
        torch.save(model.state_dict(), "ultimate_embedding_model.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

# ========================================
# FINAL TEST EVALUATION
# ========================================
model.load_state_dict(torch.load("ultimate_embedding_model.pth"))
model.eval()

all_preds = []
all_targets = []

with torch.no_grad():
    for a, b, l, c, s, y in test_loader:
        a, b, l, c, s = a.to(device), b.to(device), l.to(device), c.to(device), s.to(device)
        y = y.to(device)
        out = model(a, b, l, c, s)
        preds = out['yield'].cpu().numpy().flatten()
        all_preds.append(preds)
        all_targets.append(y.cpu().numpy().flatten())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
test_r2 = r2_score(all_targets, all_preds)
test_rmse = root_mean_squared_error(all_targets, all_preds)
test_mae = np.mean(np.abs(all_preds - all_targets))

print("\n" + "="*70)
print(f"FINAL TEST EVALUATION:")
print(f"Test R²   = {test_r2:.4f}")
print(f"Test RMSE = {test_rmse:.4f}")
print(f"Test MAE  = {test_mae:.4f}")
print("="*70)


ULTIMATE CHALLENGE: BEAT R² = 0.85 WITH MECHANISTIC SYNTHETIC DATA
Loading Suzuki dataset...
100% [........................................................] 269731 / 269731Loaded 5760 real reactions
Data allocation for maximum learning:
  Analysis (for generator): 3672 reactions
  Validation: 648 reactions
  Test (NEVER SEEN): 1440 reactions
Target statistics: mean=0.3993, std=0.2797

GENERATING ULTIMATE SYNTHETIC DATA
Building ultra-high fidelity generator...
Noise model: 50 bins, global std = 0.2012
Pattern extraction complete:
  Component effects: 35
  Pairwise interactions: 455
  Exact combinations: 3672
Generating 500000 ultra-high fidelity reactions...


100%|█████████████████████████████████| 500000/500000 [1:58:37<00:00, 70.25it/s]


Perfect calibration complete:
  Mean: 0.3996 (target: 0.3993)
  Std:  0.2792 (target: 0.2797)
  Min:  0.0010
  Max:  0.9780
  Skewness: 0.408 (real: 0.459)

Ultimate dataset generated: 500000 reactions

TRAINING ULTIMATE EMBEDDING MODEL

Ultimate embedding-based training protocol...
Epoch   1: Train Loss = 0.02913 | Val R² = 0.8438 | Val RMSE = 0.1109
Epoch   2: Train Loss = 0.01328 | Val R² = 0.8370 | Val RMSE = 0.1133
Epoch   3: Train Loss = 0.01237 | Val R² = 0.8372 | Val RMSE = 0.1132

FINAL TEST EVALUATION:
Test R²   = 0.8483
Test RMSE = 0.1105
Test MAE  = 0.0787
