In [1]:
!pip install optuna



In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit import RDLogger
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_add_pool
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import random
import warnings
from matplotlib.colors import LinearSegmentedColormap, Normalize
from io import BytesIO
from PIL import Image
import matplotlib.cm as cm

# Add necessary imports
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors


from torch_geometric.nn import GATv2Conv, GlobalAttention
from torch.optim.swa_utils import AveragedModel, SWALR
from sklearn.metrics import explained_variance_score, max_error, median_absolute_error


# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings('ignore')

plot_dir_name = 'plots_by_c/'

In [3]:
# Basic data analysis
def analyze_data(df):
    print("\nData Statistics:")
    print(df.describe())
    
    # Plot Tg distribution
    plt.figure(figsize=(10, 6))
    sns.histplot(df['tg'], bins=30, kde=True)
    plt.title('Distribution of Glass Transition Temperatures (Tg)')
    plt.xlabel('Tg (°C)')
    plt.ylabel('Count')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}tg_distribution.png')
    plt.close()
    
    # Analyze SMILES complexity
    df['smiles_length'] = df['SMILES'].apply(len)
    
    plt.figure(figsize=(10, 6))
    sns.histplot(df['smiles_length'], bins=30, kde=True)
    plt.title('Distribution of SMILES String Lengths')
    plt.xlabel('Length')
    plt.ylabel('Count')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}smiles_length_distribution.png')
    plt.close()
    
    # Correlation between SMILES length and Tg
    plt.figure(figsize=(10, 6))
    sns.scatterplot(x='smiles_length', y='tg', data=df, alpha=0.5)
    plt.title('Relationship Between Molecule Complexity and Tg')
    plt.xlabel('SMILES String Length')
    plt.ylabel('Tg (°C)')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}length_vs_tg.png')
    plt.close()
    
    return df


# Prepare the dataset
def prepare_dataset(dataframe, smiles_col='SMILES', target_col='tg'):
    data_list = []
    valid_indices = []
    invalid_smiles = []
    
    for idx, row in dataframe.iterrows():
        smiles = row[smiles_col]
        graph = smiles_to_graph(smiles)
        if graph is not None:
            # Add target value
            graph.y = torch.tensor([row[target_col]], dtype=torch.float)
            graph.smiles = smiles  # Store SMILES for reference
            data_list.append(graph)
            valid_indices.append(idx)
        else:
            invalid_smiles.append((idx, smiles))
    
    if invalid_smiles:
        print(f"\nWarning: {len(invalid_smiles)} invalid SMILES strings found and skipped.")
        
    return data_list, valid_indices


# Convert SMILES to molecular graphs
def smiles_to_graph(smiles):
    # Convert SMILES to RDKit molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    
    # Get atom features
    atom_features = []
    for atom in mol.GetAtoms():
        # Atom features: One-hot encoding of atom type, formal charge, hybridization, aromaticity
        atom_type_one_hot = [0] * 100  # Limit to first 100 elements
        atom_num = atom.GetAtomicNum()
        if atom_num < 100:
            atom_type_one_hot[atom_num] = 1
            
        formal_charge = [atom.GetFormalCharge()]
        hybridization_type = [0, 0, 0, 0, 0]  # One-hot encoding of hybridization
        hyb_type = atom.GetHybridization()
        if hyb_type == Chem.rdchem.HybridizationType.SP:
            hybridization_type[0] = 1
        elif hyb_type == Chem.rdchem.HybridizationType.SP2:
            hybridization_type[1] = 1
        elif hyb_type == Chem.rdchem.HybridizationType.SP3:
            hybridization_type[2] = 1
        elif hyb_type == Chem.rdchem.HybridizationType.SP3D:
            hybridization_type[3] = 1
        elif hyb_type == Chem.rdchem.HybridizationType.SP3D2:
            hybridization_type[4] = 1
            
        is_aromatic = [1 if atom.GetIsAromatic() else 0]
        degree = [atom.GetDegree()]
        num_h = [atom.GetTotalNumHs()]
        
        # Combine all features
        features = atom_type_one_hot + formal_charge + hybridization_type + is_aromatic + degree + num_h
        atom_features.append(features)
    
    # Create node feature matrix (num_nodes x num_features)
    x = torch.tensor(atom_features, dtype=torch.float)
    
    # Get edge indices (bonds)
    edge_indices = []
    edge_attr = []
    
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        
        # Bond features
        bond_type = bond.GetBondType()
        bond_features = [0, 0, 0, 0]  # One-hot encoding of bond type
        if bond_type == Chem.rdchem.BondType.SINGLE:
            bond_features[0] = 1
        elif bond_type == Chem.rdchem.BondType.DOUBLE:
            bond_features[1] = 1
        elif bond_type == Chem.rdchem.BondType.TRIPLE:
            bond_features[2] = 1
        elif bond_type == Chem.rdchem.BondType.AROMATIC:
            bond_features[3] = 1
            
        is_conjugated = [1 if bond.GetIsConjugated() else 0]
        is_in_ring = [1 if bond.IsInRing() else 0]
        
        # Combine all features
        features = bond_features + is_conjugated + is_in_ring
        
        # Add bonds in both directions
        edge_indices.append([i, j])
        edge_indices.append([j, i])
        edge_attr.append(features)
        edge_attr.append(features)
    
    if len(edge_indices) == 0:  # For molecules with no bonds
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 6), dtype=torch.float)
    else:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)


# If the above approach has issues with coordinates, here's an even simpler alternative
def visualize_molecule_with_color_overlay(smiles, atom_importances, tg, idx, prediction=None, plot_dir_name=""):
    """
    Visualize a molecule with atom importance by creating two overlaid images:
    1. A standard black and white molecule with atom indices
    2. Colored circles positioned approximately where atoms are
    
    Parameters are the same as previous function.
    """
    # Convert SMILES to molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Could not convert SMILES to molecule: {smiles}")
        return
    
    # Normalize importance values
    if len(atom_importances) > 0 and max(atom_importances) > 0:
        norm_importances = atom_importances / max(atom_importances)
    else:
        norm_importances = atom_importances
    
    # Set up the colormap
    cmap = plt.cm.coolwarm
    
    # Generate a standard molecule image with atom indices
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 800)
    drawer.drawOptions().addAtomIndices = True
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    png_data = drawer.GetDrawingText()
    molecule_img = Image.open(BytesIO(png_data))
    
    # Create a figure
    fig, ax = plt.subplots(figsize=(12, 12))
    
    # Display the molecule image
    ax.imshow(molecule_img)
    ax.set_title(f"SMILES: {smiles}\nTg: {tg}°C\nPredicted Tg: {prediction:.2f}°C", fontsize=14)
    ax.axis('off')
    
    # Add colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, shrink=0.7, pad=0.05, orientation='horizontal')
    cbar.set_label('Atom Importance', fontsize=14)
    cbar.set_ticks([0.0, 0.5, 1.0])
    cbar.set_ticklabels(['Low', 'Medium', 'High'])
    
    
    plt.figtext(0.5, 0.01, "Blue = Low Importance, Red = High Importance for Tg Prediction",
                ha="center", fontsize=12, 
                bbox={"facecolor":"lightgray", "alpha":0.3, "pad":5})
    
    plt.savefig(f'{plot_dir_name}molecule_attention_{idx}.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    # Also print to console
    top_indices = np.argsort(atom_importances)[-5:][::-1]
    print(f"Top 5 important atoms for molecule {idx} (indices): {top_indices}")
    print(f"Importance values: {atom_importances[top_indices]}")


# Function to analyze model interpretation
def analyze_model_interpretation(model, data_loader, device, num_examples=5):
    model.eval()
    
    # Get a batch of examples
    examples = []
    predictions = []
    for data in data_loader:
        data = data.to(device)
        with torch.no_grad():
            output = model(data)
            
        for i in range(min(len(data.y), num_examples - len(examples))):
            if len(examples) >= num_examples:
                break
                
            # Extract single molecule
            single_data = data[i].to(device)
            
            # Get atom importance
            atom_importance = model.get_atom_importance(single_data)
            
            # Store example and prediction
            examples.append((single_data.smiles, atom_importance, single_data.y.item()))
            predictions.append(output[i].item())
            
        if len(examples) >= num_examples:
            break
    
    # Visualize examples with attention weights
    print("\nVisualizing molecules with attention weights...")
    for i, (smiles, atom_importance, tg) in enumerate(examples):
        #visualize_molecule_with_attention
        visualize_molecule_with_color_overlay(smiles, atom_importance, tg, i, predictions[i])
        
        # Print most important atoms
        top_indices = np.argsort(atom_importance)[-5:][::-1]
        print(f"\nMolecule {i+1} (SMILES: {smiles})")
        print(f"Actual Tg: {tg:.2f}°C, Predicted Tg: {predictions[i]:.2f}°C")
        print(f"Top 5 important atoms (indices): {top_indices}")
        print(f"Importance values: {atom_importance[top_indices]}")


# Early stopping
class EarlyStopping:
    def __init__(self, patience=10, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None
    
    def __call__(self, val_loss, model):
        if self.best_score is None:
            self.best_score = val_loss
            self.best_model_state = model.state_dict().copy()
        elif val_loss > self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_loss
            self.best_model_state = model.state_dict().copy()
            self.counter = 0
        return self.early_stop
    

# Evaluation function
def evaluate(model, loader, device):
    model.eval()
    predictions = []
    actual = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            predictions.extend(output.cpu().numpy())
            actual.extend(data.y.cpu().numpy())
    
    predictions = np.array(predictions)
    actual = np.array(actual)
    
    # Calculate metrics
    mse = mean_squared_error(actual, predictions)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(actual, predictions)
    r2 = r2_score(actual, predictions)
    
    return {
        'predictions': predictions,
        'actual': actual,
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'r2': r2
    }



In [4]:

# Define the GAT Model with attention
class GATTgPredictor(torch.nn.Module):
    def __init__(self, node_features, edge_features, hidden_channels=64, heads=4):
        super(GATTgPredictor, self).__init__()
        
        # Graph attention layers - these will provide interpretability
        self.conv1 = GATConv(node_features, hidden_channels, heads=heads, dropout=0.2)
        self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads, dropout=0.2)
        self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1, dropout=0.2)
        
        # Batch normalization for stability
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn3 = torch.nn.BatchNorm1d(hidden_channels)
        
        # Fully connected layers for regression
        self.fc1 = torch.nn.Linear(hidden_channels, 32)
        self.fc2 = torch.nn.Linear(32, 1)
        
        # Dropout for regularization
        self.dropout = torch.nn.Dropout(0.2)
        
        # For storing attention weights
        self.attention_weights = None
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # First GAT layer with attention
        x1, attention_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
        x1 = F.relu(self.bn1(x1))
        x1 = self.dropout(x1)
        
        # Second GAT layer
        x2, attention_weights2 = self.conv2(x1, edge_index, return_attention_weights=True)
        x2 = F.relu(self.bn2(x2))
        x2 = self.dropout(x2)
        
        # Third GAT layer - final layer for capturing node importance
        x3, attention_weights3 = self.conv3(x2, edge_index, return_attention_weights=True)
        x3 = F.relu(self.bn3(x3))
        
        # Store attention weights from the final layer for interpretation
        # edge_index, attention (edge_index shape: [2, num_edges], attention shape: [num_edges, heads])
        self.attention_weights = attention_weights3
        
        # Global pooling - aggregate node features for each graph
        x = global_mean_pool(x3, batch)
        
        # Apply fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x.view(-1)
    
    # Method to get atom-level importance
    def get_atom_importance(self, data):
        self.eval()
        with torch.no_grad():
            # Forward pass to get attention weights
            _ = self(data)
            
            # Extract attention weights
            edge_index, attn_weights = self.attention_weights
            
            # Initialize importance scores for each atom
            num_nodes = data.x.size(0)
            importance = torch.zeros(num_nodes, device=data.x.device)
            
            # Sum attention weights for each node
            for i in range(edge_index.size(1)):
                target_node = edge_index[1, i].item()
                importance[target_node] += attn_weights[i].item()
            
            # Normalize importance scores
            if importance.max() > 0:
                importance = importance / importance.max()
                
            return importance.cpu().numpy()



In [6]:
import numpy as np
import torch
from torch_geometric.data import DataLoader
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
import optuna
from optuna.samplers import TPESampler
import copy

# Cross-validation function
def cross_validate(data_list, model_class, node_features, edge_features, device, 
                   n_splits=10, batch_size=32, epochs=100, patience=15):
    """
    Perform k-fold cross-validation.
    
    Parameters:
    -----------
    data_list : list
        List of PyG Data objects
    model_class : class
        Model class to instantiate
    node_features : int
        Number of node features
    edge_features : int
        Number of edge features
    device : torch.device
        Device to run the model on
    n_splits : int
        Number of folds for cross-validation
    batch_size : int
        Batch size for training
    epochs : int
        Maximum number of epochs for training
    patience : int
        Patience for early stopping
        
    Returns:
    --------
    dict
        Dictionary with cross-validation results
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    
    fold_results = {
        'train_loss': [],
        'val_loss': [],
        'val_r2': [],
        'val_rmse': [],
        'val_mae': []
    }
    
    print(f"\nPerforming {n_splits}-fold cross-validation...")
    
    # Create dataset indices for cross-validation
    indices = np.arange(len(data_list))
    
    for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
        print(f"\nFold {fold+1}/{n_splits}")
        
        # Split data for this fold
        train_data = [data_list[i] for i in train_idx]
        val_data = [data_list[i] for i in val_idx]
        
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=batch_size)
        
        # Initialize model
        model = model_class(node_features, edge_features).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
        criterion = torch.nn.MSELoss()
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
        
        # Initialize early stopping
        early_stopping = EarlyStopping(patience=patience)
        
        train_losses = []
        val_losses = []
        val_r2s = []
        
        for epoch in range(epochs):
            # Train
            model.train()
            total_loss = 0
            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, data.y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * data.num_graphs
            train_loss = total_loss / len(train_loader.dataset)
            train_losses.append(train_loss)
            
            # Validate
            val_results = evaluate(model, val_loader, device)
            val_loss = val_results['mse']
            val_r2 = val_results['r2']
            val_losses.append(val_loss)
            val_r2s.append(val_r2)
            
            # Update learning rate scheduler
            scheduler.step(val_loss)
            
            # Print progress
            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, '
                      f'Val Loss: {val_loss:.4f}, Val R²: {val_r2:.4f}')
            
            # Check early stopping
            if early_stopping(val_loss, model):
                print(f"Early stopping at epoch {epoch+1}")
                break
        
        # Load the best model
        model.load_state_dict(early_stopping.best_model_state)
        
        # Final evaluation on validation set
        val_results = evaluate(model, val_loader, device)
        
        # Store results for this fold
        fold_results['train_loss'].append(train_losses[-1])
        fold_results['val_loss'].append(val_results['mse'])
        fold_results['val_r2'].append(val_results['r2'])
        fold_results['val_rmse'].append(val_results['rmse'])
        fold_results['val_mae'].append(val_results['mae'])
        
        print(f"Fold {fold+1} Results - Val MSE: {val_results['mse']:.4f}, "
              f"Val RMSE: {val_results['rmse']:.4f}, Val R²: {val_results['r2']:.4f}")
    
    # Calculate average metrics across folds
    for metric in fold_results:
        avg_value = np.mean(fold_results[metric])
        std_value = np.std(fold_results[metric])
        print(f"Average {metric}: {avg_value:.4f} ± {std_value:.4f}")
    
    return fold_results

# Custom R² Loss function that optimizes for R² directly
class R2Loss(torch.nn.Module):
    def __init__(self, epsilon=1e-10):
        super(R2Loss, self).__init__()
        self.epsilon = epsilon
        
    def forward(self, y_pred, y_true):
        """
        Calculate negative R² (to minimize) as a loss function
        
        Parameters:
        -----------
        y_pred : tensor
            Predicted values
        y_true : tensor
            True values
            
        Returns:
        --------
        tensor
            Negative R² loss
        """
        y_pred = y_pred.view(-1)
        y_true = y_true.view(-1)
        
        # Calculate mean of true values
        y_mean = torch.mean(y_true)
        
        # Calculate total sum of squares
        ss_tot = torch.sum((y_true - y_mean) ** 2)
        
        # Calculate residual sum of squares
        ss_res = torch.sum((y_true - y_pred) ** 2)
        
        # Calculate R²
        r2 = 1 - (ss_res / (ss_tot + self.epsilon))
        
        # Return negative R² for minimization
        return -r2

# Combined loss function to balance MSE and R² optimization
class CombinedLoss(torch.nn.Module):
    def __init__(self, alpha=0.7, epsilon=1e-10):
        """
        Combined loss function with MSE and R²
        
        Parameters:
        -----------
        alpha : float
            Weight for R² loss (1-alpha is weight for MSE)
        epsilon : float
            Small value to avoid division by zero
        """
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.mse_loss = torch.nn.MSELoss()
        self.r2_loss = R2Loss(epsilon=epsilon)
        
    def forward(self, y_pred, y_true):
        mse = self.mse_loss(y_pred, y_true)
        r2 = self.r2_loss(y_pred, y_true)
        
        # Combine losses (note: lower is better for both)
        return (1 - self.alpha) * mse + self.alpha * r2

# Modified training function with combined loss
def train_with_combined_loss(model, train_loader, optimizer, device, alpha=0.5):
    model.train()
    total_mse_loss = 0
    total_r2_loss = 0
    total_combined_loss = 0
    
    # Create loss functions
    mse_criterion = torch.nn.MSELoss()
    r2_criterion = R2Loss()
    combined_criterion = CombinedLoss(alpha=alpha)
    
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        
        # Calculate MSE loss for tracking
        mse_loss = mse_criterion(output, data.y)
        total_mse_loss += mse_loss.item() * data.num_graphs
        
        # Calculate R² loss for tracking
        r2_loss = r2_criterion(output, data.y)
        total_r2_loss += r2_loss.item() * data.num_graphs
        
        # Use combined loss for optimization
        combined_loss = combined_criterion(output, data.y)
        combined_loss.backward()
        optimizer.step()
        
        total_combined_loss += combined_loss.item() * data.num_graphs
    
    num_samples = len(train_loader.dataset)
    return {
        'mse': total_mse_loss / num_samples,
        'r2': total_r2_loss / num_samples,
        'combined': total_combined_loss / num_samples
    }

# Modified evaluation function to return R² directly
def evaluate_with_r2(model, loader, device):
    model.eval()
    predictions = []
    actual = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            predictions.extend(output.cpu().numpy())
            actual.extend(data.y.cpu().numpy())
    
    predictions = np.array(predictions)
    actual = np.array(actual)
    
    # Calculate metrics
    mse = mean_squared_error(actual, predictions)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(actual, predictions)
    r2 = r2_score(actual, predictions)
    
    return {
        'predictions': predictions,
        'actual': actual,
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'r2': r2
    }

# Hyperparameter tuning with Optuna
# Fix for the hyperparameter optimization function
def optimize_hyperparameters(data_list, node_features, edge_features, device, n_trials=30):
    """
    Optimize hyperparameters using Optuna.
    
    Parameters:
    -----------
    data_list : list
        List of PyG Data objects
    node_features : int
        Number of node features
    edge_features : int
        Number of edge features
    device : torch.device
        Device to run the model on
    n_trials : int
        Number of optimization trials
        
    Returns:
    --------
    dict
        Best hyperparameters
    """
    print("\nOptimizing hyperparameters with Optuna...")
    
    # Split data once for hyperparameter tuning
    train_idx, val_idx = train_test_split(
        np.arange(len(data_list)), test_size=0.2, random_state=42
    )
    
    train_data = [data_list[i] for i in train_idx]
    val_data = [data_list[i] for i in val_idx]
    
    def objective(trial):
        # Sample hyperparameters
        hidden_channels = trial.suggest_int('hidden_channels', 32, 128, step=16)
        heads = trial.suggest_int('heads', 1, 8)
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
        r2_weight = trial.suggest_float('r2_weight', 0.0, 1.0)
        
        # Create loaders with current batch size
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=batch_size)
        
        # Create a modified GATTgPredictor with tunable parameters
        class TunableGATTgPredictor(torch.nn.Module):
            def __init__(self):
                super(TunableGATTgPredictor, self).__init__()
                
                # Graph attention layers - these will provide interpretability
                self.conv1 = GATConv(node_features, hidden_channels, heads=heads, dropout=dropout)
                self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads, dropout=dropout)
                self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1, dropout=dropout)
                
                # Batch normalization for stability
                self.bn1 = torch.nn.BatchNorm1d(hidden_channels*heads)
                self.bn2 = torch.nn.BatchNorm1d(hidden_channels*heads)
                self.bn3 = torch.nn.BatchNorm1d(hidden_channels)
                
                # Fully connected layers for regression
                self.fc1 = torch.nn.Linear(hidden_channels, 32)
                self.fc2 = torch.nn.Linear(32, 1)
                
                # Dropout for regularization
                self.dropout = torch.nn.Dropout(dropout)
                
                # For storing attention weights
                self.attention_weights = None
            
            def forward(self, data):
                x, edge_index, batch = data.x, data.edge_index, data.batch
                
                # First GAT layer with attention
                x1, attention_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
                x1 = F.relu(self.bn1(x1))
                x1 = self.dropout(x1)
                
                # Second GAT layer
                x2, attention_weights2 = self.conv2(x1, edge_index, return_attention_weights=True)
                x2 = F.relu(self.bn2(x2))
                x2 = self.dropout(x2)
                
                # Third GAT layer - final layer for capturing node importance
                x3, attention_weights3 = self.conv3(x2, edge_index, return_attention_weights=True)
                x3 = F.relu(self.bn3(x3))
                
                # Store attention weights from the final layer for interpretation
                self.attention_weights = attention_weights3
                
                # Global pooling - aggregate node features for each graph
                x = global_mean_pool(x3, batch)
                
                # Apply fully connected layers
                x = F.relu(self.fc1(x))
                x = self.dropout(x)
                x = self.fc2(x)
                
                return x.view(-1)
            
            # Method to get atom-level importance
            def get_atom_importance(self, data):
                self.eval()
                with torch.no_grad():
                    # Forward pass to get attention weights
                    _ = self(data)
                    
                    # Extract attention weights
                    edge_index, attn_weights = self.attention_weights
                    
                    # Initialize importance scores for each atom
                    num_nodes = data.x.size(0)
                    importance = torch.zeros(num_nodes, device=data.x.device)
                    
                    # Sum attention weights for each node
                    for i in range(edge_index.size(1)):
                        target_node = edge_index[1, i].item()
                        importance[target_node] += attn_weights[i].item()
                    
                    # Normalize importance scores
                    if importance.max() > 0:
                        importance = importance / importance.max()
                        
                    return importance.cpu().numpy()
        
        # Initialize model with trial parameters
        model = TunableGATTgPredictor().to(device)
        
        optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=learning_rate, 
            weight_decay=weight_decay
        )
        
        # Custom combined loss function
        combined_loss = CombinedLoss(alpha=r2_weight)
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=False
        )
        
        # Initialize early stopping
        early_stopping = EarlyStopping(patience=10)
        
        # Training loop
        max_epochs = 50
        for epoch in range(max_epochs):
            # Train with combined loss
            model.train()
            train_loss = 0
            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = combined_loss(output, data.y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * data.num_graphs
            
            # Evaluate
            model.eval()
            val_predictions = []
            val_targets = []
            with torch.no_grad():
                for data in val_loader:
                    data = data.to(device)
                    output = model(data)
                    val_predictions.extend(output.cpu().numpy())
                    val_targets.extend(data.y.cpu().numpy())
            
            val_predictions = np.array(val_predictions)
            val_targets = np.array(val_targets)
            val_mse = mean_squared_error(val_targets, val_predictions)
            
            # Update scheduler
            scheduler.step(val_mse)
            
            # Check early stopping
            if early_stopping(val_mse, model):
                break
        
        # Load best model
        model.load_state_dict(early_stopping.best_model_state)
        
        # Final evaluation
        model.eval()
        val_predictions = []
        val_targets = []
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                output = model(data)
                val_predictions.extend(output.cpu().numpy())
                val_targets.extend(data.y.cpu().numpy())
        
        val_predictions = np.array(val_predictions)
        val_targets = np.array(val_targets)
        val_r2 = r2_score(val_targets, val_predictions)
        
        # Report R² (higher is better)
        return val_r2
    
    # Create the study and optimize
    sampler = TPESampler(seed=42)
    study = optuna.create_study(direction='maximize', sampler=sampler)
    study.optimize(objective, n_trials=n_trials)
    
    # Print results
    print("\nBest trial:")
    trial = study.best_trial
    print(f"  Value (R²): {trial.value:.4f}")
    print("  Hyperparameters:")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
    
    return trial.params

# Modified GAT model with additional options for interpretability
class EnhancedGATTgPredictor(torch.nn.Module):
    def __init__(self, node_features, edge_features, hidden_channels=64, heads=4, 
                 dropout=0.2, use_gat_v2=False, use_global_attention=False):
        super(EnhancedGATTgPredictor, self).__init__()
        
        self.use_gat_v2 = use_gat_v2
        self.use_global_attention = use_global_attention
        
        # Choose between GAT and GATv2
        if use_gat_v2:
            # GATv2 has improved attention mechanism
            self.conv1 = GATv2Conv(node_features, hidden_channels, heads=heads, dropout=dropout)
            self.conv2 = GATv2Conv(hidden_channels*heads, hidden_channels, heads=heads, dropout=dropout)
            self.conv3 = GATv2Conv(hidden_channels*heads, hidden_channels, heads=1, dropout=dropout)
        else:
            # Standard GAT
            self.conv1 = GATConv(node_features, hidden_channels, heads=heads, dropout=dropout)
            self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads, dropout=dropout)
            self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1, dropout=dropout)
        
        # Batch normalization for stability
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn3 = torch.nn.BatchNorm1d(hidden_channels)
        
        # Optional global attention pooling
        if use_global_attention:
            self.global_attention = GlobalAttention(
                torch.nn.Sequential(
                    torch.nn.Linear(hidden_channels, hidden_channels // 2),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_channels // 2, 1)
                )
            )
        
        # Fully connected layers for regression
        self.fc1 = torch.nn.Linear(hidden_channels, 32)
        self.fc2 = torch.nn.Linear(32, 1)
        
        # Dropout for regularization
        self.dropout = torch.nn.Dropout(dropout)
        
        # For storing attention weights
        self.attention_weights = None
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # First GAT layer with attention
        if self.use_gat_v2:
            x1 = self.conv1(x, edge_index)
            self.attention_weights = None  # GATv2Conv doesn't return attention weights by default
        else:
            x1, attention_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
            self.attention_weights = attention_weights1
            
        x1 = F.relu(self.bn1(x1))
        x1 = self.dropout(x1)
        
        # Second GAT layer
        if self.use_gat_v2:
            x2 = self.conv2(x1, edge_index)
        else:
            x2, attention_weights2 = self.conv2(x1, edge_index, return_attention_weights=True)
            
        x2 = F.relu(self.bn2(x2))
        x2 = self.dropout(x2)
        
        # Third GAT layer - final layer for capturing node importance
        if self.use_gat_v2:
            x3 = self.conv3(x2, edge_index)
        else:
            x3, attention_weights3 = self.conv3(x2, edge_index, return_attention_weights=True)
            self.attention_weights = attention_weights3
            
        x3 = F.relu(self.bn3(x3))
        
        # Global pooling - aggregate node features for each graph
        if self.use_global_attention:
            x = self.global_attention(x3, batch)
        else:
            x = global_mean_pool(x3, batch)
        
        # Apply fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x.view(-1)
    
    # Method to get atom-level importance
    def get_atom_importance(self, data):
        self.eval()
        with torch.no_grad():
            # Forward pass to get attention weights
            _ = self(data)
            
            if self.use_gat_v2:
                # For GATv2, we need a different approach to get importance
                # Run a forward pass and use gradients with respect to node features
                self.train()  # temporarily set to train mode to compute gradients
                
                # Create a copy of node features that requires gradients
                x = data.x.clone().detach().to(data.x.device).requires_grad_(True)
                
                # Forward pass with the copied features
                data_copy = copy.copy(data)
                data_copy.x = x
                out = self(data_copy)
                
                # Compute gradients
                out.backward()
                
                # Use gradient magnitudes as importance scores
                importance = torch.sum(torch.abs(x.grad), dim=1)
                
                # Normalize importance scores
                if importance.max() > 0:
                    importance = importance / importance.max()
                
                self.eval()  # set back to eval mode
                return importance.cpu().numpy()
            
            # For standard GAT, use attention weights
            if self.attention_weights is None:
                # Fallback if no attention weights
                num_nodes = data.x.size(0)
                return np.ones(num_nodes) / num_nodes
                
            # Extract attention weights
            edge_index, attn_weights = self.attention_weights
            
            # Initialize importance scores for each atom
            num_nodes = data.x.size(0)
            importance = torch.zeros(num_nodes, device=data.x.device)
            
            # Sum attention weights for each node
            for i in range(edge_index.size(1)):
                target_node = edge_index[1, i].item()
                importance[target_node] += attn_weights[i].item()
            
            # Normalize importance scores
            if importance.max() > 0:
                importance = importance / importance.max()
                
            return importance.cpu().numpy()

# Function to run an experiment with optimized model and cross-validation
def run_optimized_experiment(data_list, best_params, node_features, edge_features, device, 
                            test_size=0.2, seed=42):
    """
    Run a complete experiment with the optimized model.
    
    Parameters:
    -----------
    data_list : list
        List of PyG Data objects
    best_params : dict
        Best hyperparameters from optimization
    node_features : int
        Number of node features
    edge_features : int
        Number of edge features
    device : torch.device
        Device to run the model on
    test_size : float
        Proportion of data to use for testing
    seed : int
        Random seed for reproducibility
        
    Returns:
    --------
    dict
        Experiment results
    """
    # Set reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create train/test split
    train_data, test_data = train_test_split(data_list, test_size=test_size, random_state=seed)
    
    print(f"\nRunning final experiment with optimized parameters...")
    print(f"Training set: {len(train_data)}, Test set: {len(test_data)}")
    
    # Extract hyperparameters
    hidden_channels = best_params.get('hidden_channels', 64)
    heads = best_params.get('heads', 4)
    dropout = best_params.get('dropout', 0.2)
    learning_rate = best_params.get('learning_rate', 0.001)
    weight_decay = best_params.get('weight_decay', 1e-4)
    batch_size = best_params.get('batch_size', 64)
    r2_weight = best_params.get('r2_weight', 0.5)
    
    # Initialize model with optimized parameters
    model = EnhancedGATTgPredictor(
        node_features, edge_features,
        hidden_channels=hidden_channels,
        heads=heads,
        dropout=dropout,
        use_gat_v2=False,  # Change to True to try GATv2
        use_global_attention=False  # Change to True to try global attention
    ).to(device)
    
    # Create optimized dataloader
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size)
    
    # Create optimizer with optimized parameters
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )
    
    # Learning rate scheduler
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    # Initialize early stopping
    early_stopping = EarlyStopping(patience=15)
    
    # Training loop
    num_epochs = 200
    train_losses = []
    train_r2s = []
    
    print("\nTraining optimized model...")
    for epoch in range(num_epochs):
        # Train with combined loss
        train_results = train_with_combined_loss(
            model, train_loader, optimizer, device, alpha=r2_weight
        )
        train_losses.append(train_results['mse'])
        
        # Validate on training set for tracking
        train_eval = evaluate_with_r2(model, train_loader, device)
        train_r2s.append(train_eval['r2'])
        
        # Print progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_results["mse"]:.4f}, '
                  f'Train R²: {train_eval["r2"]:.4f}')
        
        # Update scheduler based on MSE
        scheduler.step(train_results['mse'])
        
        # Check early stopping
        if early_stopping(train_results['mse'], model):
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Load the best model
    model.load_state_dict(early_stopping.best_model_state)
    
    # Evaluate on test set
    test_results = evaluate_with_r2(model, test_loader, device)
    print("\nTest Results with Optimized Model:")
    print(f"MSE: {test_results['mse']:.4f}")
    print(f"RMSE: {test_results['rmse']:.4f}")
    print(f"MAE: {test_results['mae']:.4f}")
    print(f"R²: {test_results['r2']:.4f}")
    
    # Plot training curves
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss (MSE)')
    plt.title('Loss Curve')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(alpha=0.3)
    
    plt.subplot(1, 2, 2)
    plt.plot(train_r2s, label='Train R²')
    plt.axhline(y=test_results['r2'], color='r', linestyle='--', label='Test R²')
    plt.title('R² Score')
    plt.xlabel('Epoch')
    plt.ylabel('R² Score')
    plt.legend()
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'{plot_dir_name}optimized_training_curves.png')
    plt.close()
    
    # Plot predictions vs actual
    plt.figure(figsize=(10, 8))
    plt.scatter(test_results['actual'], test_results['predictions'], alpha=0.5)
    
    # Add identity line
    min_val = min(min(test_results['actual']), min(test_results['predictions']))
    max_val = max(max(test_results['actual']), max(test_results['predictions']))
    plt.plot([min_val, max_val], [min_val, max_val], 'r--')
    
    plt.xlabel('Actual Tg (°C)')
    plt.ylabel('Predicted Tg (°C)')
    plt.title(f'Optimized GAT Model: Actual vs Predicted Tg (R² = {test_results["r2"]:.4f})')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}optimized_prediction_results.png')
    plt.close()
    
    # Analyze model interpretation
    print("\nAnalyzing model interpretation with attention weights...")
    analyze_model_interpretation(model, test_loader, device, num_examples=20)
    
    return {
        'model': model,
        'test_results': test_results,
        'best_params': best_params
    }

# Main function to run the complete workflow
def run_complete_workflow(data, smiles_col='SMILES', target_col='tg'):
    """
    Run the complete workflow from data preparation to optimized model training.
    
    Parameters:
    -----------
    data : pandas.DataFrame
        DataFrame with molecule data
    smiles_col : str
        Column name for SMILES strings
    target_col : str
        Column name for target values
        
    Returns:
    --------
    dict
        Results from the experiment
    """
    # Analyze data
    data = analyze_data(data)
    
    # Prepare dataset
    print("\nPreparing dataset...")
    data_list, valid_indices = prepare_dataset(data, smiles_col=smiles_col, target_col=target_col)
    print(f"Valid molecules processed: {len(data_list)} out of {len(data)}")
    
    # Extract node/edge features
    if len(data_list) > 0:
        node_features = data_list[0].x.shape[1]
        edge_features = data_list[0].edge_attr.shape[1] if data_list[0].edge_attr.shape[0] > 0 else 0
        print(f"Node features: {node_features}, Edge features: {edge_features}")
    else:
        print("Error: No valid molecules processed.")
        return None
    
    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # 1. Perform cross-validation with baseline model
    print("\nPerforming cross-validation with baseline model...")
    cv_results = cross_validate(
        data_list, GATTgPredictor, node_features, edge_features, device, 
        n_splits=10, batch_size=64, epochs=100
    )
    
    # 2. Hyperparameter optimization
    print("\nOptimizing hyperparameters...")
    best_params = optimize_hyperparameters(
        data_list, node_features, edge_features, device, n_trials=5
    )
    
    # 3. Train optimized model
    print("\nTraining final model with optimized parameters...")
    results = run_optimized_experiment(
        data_list, best_params, node_features, edge_features, device
    )
    
    return results

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# Code to integrate the new functionality into the main workflow
if __name__ == "__main__":
    # Set seeds for reproducibility
    set_seed(42)
    
    # Load data
    DATA_PATH = "./data/D1/tg_raw.csv"
    data = pd.read_csv(DATA_PATH)
    print(f"Dataset shape: {data.shape}")
    print(data.head())
    
    # Run the complete workflow
    results = run_complete_workflow(data)
    
    # Print final results
    if results:
        print("\nExperiment completed successfully!")
        print(f"Final R² score: {results['test_results']['r2']:.4f}")
        print(f"Final RMSE: {results['test_results']['rmse']:.4f}")
        print("\nBest parameters:")
        for param, value in results['best_params'].items():
            print(f"  {param}: {value}")
    else:
        print("Experiment failed to complete.")

Dataset shape: (7174, 2)
        SMILES    tg
0          *C* -54.0
1      *CC(*)C  -3.0
2     *CC(*)CC -24.1
3    *CC(*)CCC -37.0
4  *CC(*)C(C)C  60.0

Data Statistics:
                tg
count  7174.000000
mean    141.948090
std     112.178143
min    -139.000000
25%      55.000000
50%     134.000000
75%     231.000000
max     495.000000

Preparing dataset...
Valid molecules processed: 7174 out of 7174
Node features: 109, Edge features: 6
Using device: cpu

Performing cross-validation with baseline model...

Performing 10-fold cross-validation...

Fold 1/10
Epoch 1/100, Train Loss: 31367.9359, Val Loss: 28152.6543, Val R²: -1.2918
Epoch 10/100, Train Loss: 3567.9802, Val Loss: 3875.6616, Val R²: 0.6845
Epoch 20/100, Train Loss: 3293.0040, Val Loss: 3137.3389, Val R²: 0.7446
Epoch 30/100, Train Loss: 3082.8440, Val Loss: 3093.8530, Val R²: 0.7481
Early stopping at epoch 34
Fold 1 Results - Val MSE: 2946.2444, Val RMSE: 54.2793, Val R²: 0.7602

Fold 2/10
Epoch 1/100, Train Loss: 31292.45

[I 2025-03-16 16:27:52,739] A new study created in memory with name: no-name-1096ad54-5878-4c52-8776-4b14717ff0ca


Fold 10 Results - Val MSE: 2187.3552, Val RMSE: 46.7692, Val R²: 0.8332
Average train_loss: 2791.7264 ± 210.4418
Average val_loss: 2351.5829 ± 279.2697
Average val_r2: 0.8125 ± 0.0255
Average val_rmse: 48.4106 ± 2.8284
Average val_mae: 35.4905 ± 2.7705

Optimizing hyperparameters...

Optimizing hyperparameters with Optuna...


[I 2025-03-16 16:33:53,584] Trial 0 finished with value: 0.7942336797714233 and parameters: {'hidden_channels': 64, 'heads': 8, 'dropout': 0.39279757672456206, 'learning_rate': 0.0015751320499779737, 'weight_decay': 2.9380279387035354e-06, 'batch_size': 64, 'r2_weight': 0.7080725777960455}. Best is trial 0 with value: 0.7942336797714233.
[I 2025-03-16 16:38:40,858] Trial 1 finished with value: 0.7383342981338501 and parameters: {'hidden_channels': 32, 'heads': 8, 'dropout': 0.4329770563201687, 'learning_rate': 0.00026587543983272726, 'weight_decay': 3.5113563139704077e-06, 'batch_size': 64, 'r2_weight': 0.2912291401980419}. Best is trial 0 with value: 0.7942336797714233.
[I 2025-03-16 16:43:31,287] Trial 2 finished with value: 0.8116323947906494 and parameters: {'hidden_channels': 96, 'heads': 2, 'dropout': 0.21685785941408728, 'learning_rate': 0.0005404103854647331, 'weight_decay': 2.334586407601622e-05, 'batch_size': 16, 'r2_weight': 0.046450412719997725}. Best is trial 2 with value:


Best trial:
  Value (R²): 0.8263
  Hyperparameters:
    hidden_channels: 32
    heads: 4
    dropout: 0.11375540844608736
    learning_rate: 0.006586289317583112
    weight_decay: 5.975027999960295e-06
    batch_size: 16
    r2_weight: 0.18485445552552704

Training final model with optimized parameters...

Running final experiment with optimized parameters...
Training set: 5739, Test set: 1435

Training optimized model...
Epoch 1/200, Train Loss: 8368.1025, Train R²: 0.7229
Epoch 10/200, Train Loss: 3302.9197, Train R²: 0.7981
Epoch 20/200, Train Loss: 3227.9120, Train R²: 0.7772
Epoch 30/200, Train Loss: 3069.7076, Train R²: 0.7412
Epoch 40/200, Train Loss: 2958.4408, Train R²: 0.8130
Epoch 50/200, Train Loss: 2609.1094, Train R²: 0.8498
Epoch 60/200, Train Loss: 2500.7507, Train R²: 0.8197
Epoch 70/200, Train Loss: 2416.1790, Train R²: 0.8551
Epoch 80/200, Train Loss: 2390.9546, Train R²: 0.8563
Epoch 90/200, Train Loss: 2305.0630, Train R²: 0.8725
Epoch 100/200, Train Loss: 2182.75