In [1]:
!pip install rdkit
!pip install torch
!pip install torch_geometric



In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# Relative path (from the notebook's working directory) to the raw Tg table.
DATA_PATH = "./data/D1/tg_raw.csv"

# Read the CSV into a DataFrame; the bare `data` expression makes the frame the
# cell's displayed output.
data = pd.read_csv(DATA_PATH)
data

Unnamed: 0,SMILES,tg
0,*C*,-54.0
1,*CC(*)C,-3.0
2,*CC(*)CC,-24.1
3,*CC(*)CCC,-37.0
4,*CC(*)C(C)C,60.0
...,...,...
7169,*CC(*)(F)C(=O)OCCC,62.0
7170,*CC(F)(F)C1(F)C(*)CC(O)(C(F)(F)F)C1(F)F,152.0
7171,*CC(F)(F)C1(F)CC(CC(O)(C(F)(F)F)C(F)(F)F)CC1*,98.0
7172,*CC(F)(F)C1(F)CC(C(O)(C(F)(F)F)C(F)(F)F)CC1*,118.0


In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit import RDLogger
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_add_pool
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import random
import warnings
from matplotlib.colors import LinearSegmentedColormap, Normalize
from io import BytesIO
from PIL import Image
import matplotlib.cm as cm

# Add necessary imports
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors


from torch_geometric.nn import GATv2Conv, GlobalAttention
from torch.optim.swa_utils import AveragedModel, SWALR
from sklearn.metrics import explained_variance_score, max_error, median_absolute_error


# Suppress RDKit warnings
# With RDKit logging disabled, SMILES that fail to parse are reported only
# through the `mol is None` checks in the helper functions below.
RDLogger.DisableLog('rdApp.*')
# NOTE(review): this ignores *every* Python warning for the rest of the session,
# including deprecation notices from torch / torch_geometric — consider
# narrowing the filter to specific categories or modules.
warnings.filterwarnings('ignore')

In [4]:
import os

# Prefix prepended to every figure file name written by the helpers below.
# The trailing slash makes it act as an output directory.
plot_dir_name = 'plots_by_c/'

# matplotlib's savefig() does not create missing directories, so make sure the
# target exists before any plotting helper runs (no-op if it is already there).
os.makedirs(plot_dir_name, exist_ok=True)

In [5]:
''' function definitions '''
# Basic data analysis
def analyze_data(df):
    """Print summary statistics and save three exploratory figures.

    Figures are written under the module-level ``plot_dir_name`` prefix:
    the Tg histogram, the SMILES-length histogram and a SMILES-length vs Tg
    scatter plot.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'tg'`` and ``'SMILES'``.

    Returns
    -------
    pandas.DataFrame
        The *same* object that was passed in; a ``'smiles_length'`` column is
        added to it in place.
    """

    def _decorate_and_save(title, xlabel, ylabel, filename):
        # Shared finishing steps for the currently active pyplot figure.
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.grid(alpha=0.3)
        plt.savefig(f'{plot_dir_name}{filename}')
        plt.close()

    print("\nData Statistics:")
    print(df.describe())

    # 1) Distribution of the regression target.
    plt.figure(figsize=(10, 6))
    sns.histplot(df['tg'], bins=30, kde=True)
    _decorate_and_save(
        'Distribution of Glass Transition Temperatures (Tg)',
        'Tg (°C)',
        'Count',
        'tg_distribution.png',
    )

    # SMILES string length is used as a crude proxy for molecule complexity.
    df['smiles_length'] = df['SMILES'].apply(len)

    # 2) Distribution of that proxy.
    plt.figure(figsize=(10, 6))
    sns.histplot(df['smiles_length'], bins=30, kde=True)
    _decorate_and_save(
        'Distribution of SMILES String Lengths',
        'Length',
        'Count',
        'smiles_length_distribution.png',
    )

    # 3) Proxy versus target.
    plt.figure(figsize=(10, 6))
    sns.scatterplot(x='smiles_length', y='tg', data=df, alpha=0.5)
    _decorate_and_save(
        'Relationship Between Molecule Complexity and Tg',
        'SMILES String Length',
        'Tg (°C)',
        'length_vs_tg.png',
    )

    return df


# Training function
def train(model, train_loader, optimizer, device, criterion=None):
    """Run one training epoch and return the mean per-sample loss.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)``; must return one prediction per graph.
    train_loader : iterable
        Yields batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``;
        must also expose ``.dataset`` (its length is the divisor).
    optimizer : torch.optim.Optimizer
        Optimizer bound to ``model``'s parameters.
    device : torch.device or str
        Device the batches are moved to.
    criterion : callable, optional
        Loss ``criterion(output, target)``. When omitted, a notebook-level
        ``criterion`` variable is used if one exists (this is what earlier
        revisions relied on implicitly); otherwise MSE loss is used.

    Returns
    -------
    float
        Sum of ``batch_loss * batch.num_graphs`` divided by
        ``len(train_loader.dataset)``.

    Raises
    ------
    ZeroDivisionError
        If ``train_loader.dataset`` is empty.
    """
    if criterion is None:
        # Backward compatibility with callers that only define a global loss.
        criterion = globals().get("criterion")
        if criterion is None:
            criterion = torch.nn.MSELoss()

    model.train()
    total_loss = 0.0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-represented.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

# Evaluation function
def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)``.
    loader : iterable
        Yields batches exposing ``.to(device)`` and ``.y``.
    device : torch.device or str
        Device the batches are moved to.

    Returns
    -------
    dict
        ``'predictions'`` and ``'actual'`` (numpy arrays) plus the scalar
        metrics ``'mse'``, ``'rmse'``, ``'mae'`` and ``'r2'``.
    """
    model.eval()

    predicted_values, target_values = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_output = model(batch)
            predicted_values.extend(batch_output.cpu().numpy())
            target_values.extend(batch.y.cpu().numpy())

    y_pred = np.array(predicted_values)
    y_true = np.array(target_values)

    # RMSE is derived from MSE, so compute MSE once up front.
    mse = mean_squared_error(y_true, y_pred)

    return {
        'predictions': y_pred,
        'actual': y_true,
        'mse': mse,
        'rmse': np.sqrt(mse),
        'mae': mean_absolute_error(y_true, y_pred),
        'r2': r2_score(y_true, y_pred),
    }

# Early stopping
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Parameters
    ----------
    patience : int
        Number of consecutive "bad" epochs tolerated before stopping.
    delta : float
        Tolerance: an epoch is "bad" only when its loss exceeds the best loss
        by more than ``delta``.

    Attributes
    ----------
    best_score : lowest validation loss seen so far (``None`` before the first call).
    counter : number of consecutive bad epochs.
    early_stop : ``True`` once ``counter`` has reached ``patience``.
    best_model_state : independent copy of the model weights at ``best_score``.
    """

    def __init__(self, patience=10, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None

    @staticmethod
    def _snapshot(model):
        """Return a copy of the state dict whose tensors do not alias the model.

        ``state_dict().copy()`` is only a shallow copy: its values are the live
        parameter tensors, so later optimizer steps would silently overwrite
        the "best" weights. Cloning every tensor freezes them.
        """
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for this epoch and report whether to stop.

        Returns
        -------
        bool
            ``True`` when ``patience`` consecutive bad epochs have occurred.
        """
        if self.best_score is None:
            self.best_score = val_loss
            self.best_model_state = self._snapshot(model)
        elif val_loss > self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Within tolerance: patience is reset, but the best score and the
            # snapshot are only replaced when the loss is not worse than the
            # best — otherwise `delta > 0` would let the baseline drift upward.
            if val_loss <= self.best_score:
                self.best_score = val_loss
                self.best_model_state = self._snapshot(model)
            self.counter = 0
        return self.early_stop

# Training loop with visualization
def train_and_evaluate(model, train_loader, val_loader, test_loader, optimizer, device,
                       num_epochs=100, scheduler=None):
    """Train with early stopping, evaluate on the test set and save diagnostic plots.

    Parameters
    ----------
    model : torch.nn.Module
        Model to optimise (modified in place).
    train_loader, val_loader, test_loader : iterable
        Batch loaders for the three splits.
    optimizer : torch.optim.Optimizer
        Optimizer bound to ``model``'s parameters.
    device : torch.device or str
        Device used for all batches.
    num_epochs : int
        Upper bound on the number of epochs (early stopping may end sooner).
    scheduler : object with ``step(metric)``, optional
        Learning-rate scheduler stepped with the validation MSE each epoch.
        When omitted, a notebook-level ``scheduler`` variable is used if one
        exists (earlier revisions relied on it implicitly); if there is none,
        the learning rate is simply not scheduled.

    Returns
    -------
    (model, test_results)
        The model restored to its best validation weights, and the dict
        returned by ``evaluate`` on the test loader.

    Side effects
    ------------
    Prints progress and test metrics, and writes ``training_curves.png``,
    ``prediction_results.png`` and ``residual_plot.png`` under the
    module-level ``plot_dir_name`` prefix.
    """
    if scheduler is None:
        # Backward compatibility with callers that only define a global scheduler.
        scheduler = globals().get('scheduler')

    train_losses = []
    val_losses = []
    val_r2s = []

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=15)

    for epoch in range(num_epochs):
        # Train
        train_loss = train(model, train_loader, optimizer, device)
        train_losses.append(train_loss)

        # Validate
        val_results = evaluate(model, val_loader, device)
        val_loss = val_results['mse']
        val_r2 = val_results['r2']
        val_losses.append(val_loss)
        val_r2s.append(val_r2)

        # Update learning rate scheduler (plateau-style: driven by the val loss)
        if scheduler is not None:
            scheduler.step(val_loss)

        # Print progress on the first epoch and every tenth one
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, '
                  f'Val Loss: {val_loss:.4f}, Val R²: {val_r2:.4f}')

        # Check early stopping
        if early_stopping(val_loss, model):
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load the best model. No snapshot exists if the loop never ran
    # (num_epochs == 0); in that case keep the current weights.
    if early_stopping.best_model_state is not None:
        model.load_state_dict(early_stopping.best_model_state)

    # Evaluate on test set
    test_results = evaluate(model, test_loader, device)
    print("\nTest Results:")
    print(f"MSE: {test_results['mse']:.4f}")
    print(f"RMSE: {test_results['rmse']:.4f}")
    print(f"MAE: {test_results['mae']:.4f}")
    print(f"R²: {test_results['r2']:.4f}")

    # Plot training curves
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(val_r2s, label='Validation R²')
    plt.axhline(y=test_results['r2'], color='r', linestyle='--', label='Test R²')
    plt.title('R² Score')
    plt.xlabel('Epoch')
    plt.ylabel('R² Score')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{plot_dir_name}training_curves.png')
    plt.close()

    # Plot predictions vs actual
    plt.figure(figsize=(10, 8))
    plt.scatter(test_results['actual'], test_results['predictions'], alpha=0.5)

    # Add identity line spanning the joint range of targets and predictions
    # (vectorised numpy reductions instead of Python-level min()/max()).
    min_val = min(np.min(test_results['actual']), np.min(test_results['predictions']))
    max_val = max(np.max(test_results['actual']), np.max(test_results['predictions']))
    plt.plot([min_val, max_val], [min_val, max_val], 'r--')

    plt.xlabel('Actual Tg (°C)')
    plt.ylabel('Predicted Tg (°C)')
    plt.title(f'GAT Model: Actual vs Predicted Tg (R² = {test_results["r2"]:.4f})')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}prediction_results.png')
    plt.close()

    # Create residual plot
    residuals = test_results['predictions'] - test_results['actual']

    plt.figure(figsize=(10, 8))
    plt.scatter(test_results['actual'], residuals, alpha=0.5)
    plt.axhline(y=0, color='r', linestyle='--')
    plt.xlabel('Actual Tg (°C)')
    plt.ylabel('Residuals (Predicted - Actual)')
    plt.title('Residual Plot')
    plt.grid(alpha=0.3)
    plt.savefig(f'{plot_dir_name}residual_plot.png')
    plt.close()

    return model, test_results

def visualize_molecule_with_attention(smiles, atom_importances, tg, idx, prediction=None, plot_dir_name=""):
    """
    Draw a molecule with every atom highlighted according to its importance
    and save the figure as ``{plot_dir_name}molecule_attention_{idx}.png``.

    Parameters:
    -----------
    smiles : str
        SMILES string of the molecule
    atom_importances : array-like
        Importance value for each atom, indexed by RDKit atom index
    tg : float
        Actual glass transition temperature
    idx : str/int
        Identifier for the saved file
    prediction : float, optional
        Predicted glass transition temperature; the "Predicted Tg" title line
        is omitted when this is None
    plot_dir_name : str, optional
        Prefix (typically a directory ending in '/') for the output file

    Returns None. If the SMILES cannot be parsed a message is printed and
    nothing is drawn.
    """
    # Convert SMILES to molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Could not convert SMILES to molecule: {smiles}")
        return

    # Accept plain lists as well as arrays (fancy indexing is used below).
    atom_importances = np.asarray(atom_importances)

    # Scale to [0, 1] by the maximum so the colormap range is fully used.
    if len(atom_importances) > 0 and max(atom_importances) > 0:
        norm_importances = atom_importances / max(atom_importances)
    else:
        norm_importances = atom_importances

    # Create atom-specific highlight colors
    highlight_atoms = []
    highlight_colors = {}

    # Set up the colormap
    cmap = plt.cm.coolwarm

    # Add atom properties for visualization
    for atom_idx, importance in enumerate(norm_importances):
        if atom_idx < mol.GetNumAtoms():
            highlight_atoms.append(atom_idx)

            # RDKit highlight colours are (r, g, b) floats in [0, 1] — the same
            # convention matplotlib colormaps return — so the RGBA tuple is
            # used directly (alpha dropped) rather than scaled to 0-255.
            color_rgba = cmap(float(importance))
            highlight_colors[atom_idx] = tuple(float(c) for c in color_rgba[:3])

    # Prepare the molecule drawing with minimal options
    d2d = rdMolDraw2D.MolDraw2DCairo(800, 800)

    # Show atom indices so the printed "top atoms" can be located in the image
    d2d.drawOptions().addAtomIndices = True

    # Draw the molecule with atom highlights
    rdMolDraw2D.PrepareAndDrawMolecule(
        d2d, mol,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_colors
    )
    d2d.FinishDrawing()

    # Get the PNG data and convert to PIL Image
    png_data = d2d.GetDrawingText()
    molecule_img = Image.open(BytesIO(png_data))

    # Create the figure with just the molecule and a colorbar
    fig, ax = plt.subplots(figsize=(10, 10))

    # Build the title; `prediction` is optional, so only format it when given.
    title = f"SMILES: {smiles}\nTg: {tg}°C"
    if prediction is not None:
        title += f"\nPredicted Tg: {prediction:.2f}°C"

    # Display the molecule image
    ax.imshow(molecule_img)
    ax.set_title(title, fontsize=14)
    ax.axis('off')

    # Add colorbar to show importance scale
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, shrink=0.7, pad=0.05, orientation='horizontal')
    cbar.set_label('Atom Importance', fontsize=14)

    # Add ticks at 0.0, 0.5 and 1.0 with labels
    cbar.set_ticks([0.0, 0.5, 1.0])
    cbar.set_ticklabels(['Low', 'Medium', 'High'])

    # Add an annotation explaining the colors
    plt.figtext(0.5, 0.01,
                "Blue = Low Importance, Red = High Importance for Tg Prediction",
                ha="center", fontsize=12,
                bbox={"facecolor":"lightgray", "alpha":0.3, "pad":5})

    # Up to five atoms with the largest raw importance, most important first
    top_indices = np.argsort(atom_importances)[-5:][::-1]
    importance_text = "Top 5 important atoms (indices): "
    importance_text += ", ".join([f"{i}({atom_importances[i]:.3f})" for i in top_indices])

    # Add the importance information as an annotation on the plot
    plt.figtext(0.5, 0.05, importance_text, ha="center", fontsize=10,
               bbox={"facecolor":"white", "edgecolor":"gray", "alpha":0.8, "pad":5})

    plt.tight_layout(rect=[0, 0.07, 1, 0.97])
    plt.savefig(f'{plot_dir_name}molecule_attention_{idx}.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Also print to console for reference
    print(f"Top 5 important atoms for molecule {idx} (indices): {top_indices}")
    print(f"Importance values: {atom_importances[top_indices]}")

# If the above approach has issues with coordinates, here's an even simpler alternative
def visualize_molecule_with_color_overlay(smiles, atom_importances, tg, idx, prediction=None, plot_dir_name=""):
    """
    Visualize a molecule with atom importance by overlaying two layers:
    1. A standard molecule drawing with atom indices
    2. Semi-transparent colored circles centred on each atom, colored by
       normalized importance (coolwarm: blue = low, red = high)

    The figure is saved as ``{plot_dir_name}molecule_attention_{idx}.png``.

    Parameters:
    -----------
    smiles : str
        SMILES string of the molecule
    atom_importances : array-like
        Importance value for each atom, indexed by RDKit atom index
    tg : float
        Actual glass transition temperature
    idx : str/int
        Identifier for the saved file
    prediction : float, optional
        Predicted glass transition temperature; the "Predicted Tg" title line
        is omitted when this is None
    plot_dir_name : str, optional
        Prefix (typically a directory ending in '/') for the output file

    Returns None. If the SMILES cannot be parsed a message is printed and
    nothing is drawn.
    """
    # Convert SMILES to molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Could not convert SMILES to molecule: {smiles}")
        return

    # Accept plain lists as well as arrays (fancy indexing is used below).
    atom_importances = np.asarray(atom_importances)

    # Scale to [0, 1] by the maximum so the colormap range is fully used.
    if len(atom_importances) > 0 and max(atom_importances) > 0:
        norm_importances = atom_importances / max(atom_importances)
    else:
        norm_importances = atom_importances

    # Set up the colormap
    cmap = plt.cm.coolwarm

    # Generate a standard molecule image with atom indices
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 800)
    drawer.drawOptions().addAtomIndices = True
    drawer.DrawMolecule(mol)

    # Pixel position of every atom on the drawer's canvas. These are only
    # meaningful after DrawMolecule() has laid the molecule out.
    n_atoms = min(len(norm_importances), mol.GetNumAtoms())
    atom_points = [drawer.GetDrawCoords(atom_idx) for atom_idx in range(n_atoms)]

    drawer.FinishDrawing()
    png_data = drawer.GetDrawingText()
    molecule_img = Image.open(BytesIO(png_data))

    # Create a figure
    fig, ax = plt.subplots(figsize=(12, 12))

    # Build the title; `prediction` is optional, so only format it when given.
    title = f"SMILES: {smiles}\nTg: {tg}°C"
    if prediction is not None:
        title += f"\nPredicted Tg: {prediction:.2f}°C"

    # Display the molecule image
    ax.imshow(molecule_img)
    ax.set_title(title, fontsize=14)
    ax.axis('off')

    # Overlay one translucent circle per atom, colored by importance. imshow
    # uses pixel coordinates with the origin at the top-left, which matches
    # the drawer's canvas coordinates.
    if n_atoms > 0:
        ax.scatter(
            [point.x for point in atom_points],
            [point.y for point in atom_points],
            c=np.asarray(norm_importances[:n_atoms], dtype=float),
            cmap=cmap, vmin=0.0, vmax=1.0,
            s=600, alpha=0.45, edgecolors='none',
        )
        # Keep the view locked to the image so the scatter cannot rescale it.
        img_width, img_height = molecule_img.size
        ax.set_xlim(-0.5, img_width - 0.5)
        ax.set_ylim(img_height - 0.5, -0.5)

    # Add colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, shrink=0.7, pad=0.05, orientation='horizontal')
    cbar.set_label('Atom Importance', fontsize=14)
    cbar.set_ticks([0.0, 0.5, 1.0])
    cbar.set_ticklabels(['Low', 'Medium', 'High'])

    plt.figtext(0.5, 0.01, "Blue = Low Importance, Red = High Importance for Tg Prediction",
                ha="center", fontsize=12,
                bbox={"facecolor":"lightgray", "alpha":0.3, "pad":5})

    plt.savefig(f'{plot_dir_name}molecule_attention_{idx}.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Also print to console: up to five atoms with the largest raw importance
    top_indices = np.argsort(atom_importances)[-5:][::-1]
    print(f"Top 5 important atoms for molecule {idx} (indices): {top_indices}")
    print(f"Importance values: {atom_importances[top_indices]}")

# Function to analyze model interpretation
def analyze_model_interpretation(model, data_loader, device, num_examples=5):
    """Visualize per-atom importance for the first few molecules of a loader.

    For up to ``num_examples`` molecules (taken in loader order) this computes
    the batch prediction and ``model.get_atom_importance`` for the single
    molecule, saves one figure per molecule under the module-level
    ``plot_dir_name`` prefix and prints the five most important atoms.

    Parameters
    ----------
    model : torch.nn.Module
        Must provide ``get_atom_importance(data)``.
    data_loader : iterable
        Yields batches that support integer indexing (``batch[i]``) and whose
        items carry ``.smiles`` and ``.y``.
    device : torch.device or str
        Device used for inference.
    num_examples : int
        Maximum number of molecules to visualize.

    Returns None.
    """
    model.eval()

    # Collect (smiles, importance, target) triples and matching predictions
    examples = []
    predictions = []
    for data in data_loader:
        data = data.to(device)
        with torch.no_grad():
            output = model(data)

        # Take only as many molecules from this batch as are still needed.
        remaining = num_examples - len(examples)
        for i in range(min(len(data.y), remaining)):
            # Extract single molecule
            single_data = data[i].to(device)

            # Get atom importance
            atom_importance = model.get_atom_importance(single_data)

            # Store example and prediction
            examples.append((single_data.smiles, atom_importance, single_data.y.item()))
            predictions.append(output[i].item())

        if len(examples) >= num_examples:
            break

    # Visualize examples with attention weights
    print("\nVisualizing molecules with attention weights...")
    for i, (smiles, atom_importance, tg) in enumerate(examples):
        # Save next to every other figure; without the explicit prefix the
        # helper's default ("") would write into the current working directory.
        visualize_molecule_with_color_overlay(
            smiles, atom_importance, tg, i, predictions[i],
            plot_dir_name=plot_dir_name,
        )

        # Print most important atoms
        top_indices = np.argsort(atom_importance)[-5:][::-1]
        print(f"\nMolecule {i+1} (SMILES: {smiles})")
        print(f"Actual Tg: {tg:.2f}°C, Predicted Tg: {predictions[i]:.2f}°C")
        print(f"Top 5 important atoms (indices): {top_indices}")
        print(f"Importance values: {atom_importance[top_indices]}")

# Define the GAT Model with attention
class GATTgPredictor(torch.nn.Module):
    """Three-layer graph-attention network that regresses one value per graph.

    Architecture: GATConv(heads) -> GATConv(heads) -> GATConv(1 head), each
    followed by batch norm and ReLU, then mean pooling over nodes and a
    two-layer MLP head.

    Parameters
    ----------
    node_features : int
        Width of the per-node input feature vector.
    edge_features : int
        Accepted for interface compatibility; the layers built here do not
        consume edge attributes.
    hidden_channels : int
        Output width per attention head.
    heads : int
        Number of attention heads in the first two layers.
    """

    def __init__(self, node_features, edge_features, hidden_channels=64, heads=4):
        super(GATTgPredictor, self).__init__()

        # Graph attention layers - these will provide interpretability
        self.conv1 = GATConv(node_features, hidden_channels, heads=heads, dropout=0.2)
        self.conv2 = GATConv(hidden_channels*heads, hidden_channels, heads=heads, dropout=0.2)
        self.conv3 = GATConv(hidden_channels*heads, hidden_channels, heads=1, dropout=0.2)

        # Batch normalization for stability
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.bn3 = torch.nn.BatchNorm1d(hidden_channels)

        # Fully connected layers for regression
        self.fc1 = torch.nn.Linear(hidden_channels, 32)
        self.fc2 = torch.nn.Linear(32, 1)

        # Dropout for regularization
        self.dropout = torch.nn.Dropout(0.2)

        # (edge_index, attention) pair from the last layer of the most recent
        # forward pass; None until forward() has been called.
        self.attention_weights = None

    def forward(self, data):
        """Return a 1-D tensor with one prediction per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``. As a side
        effect the final layer's attention is stored in
        ``self.attention_weights``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # First GAT layer (its attention is not needed, so it is not requested)
        x1 = self.conv1(x, edge_index)
        x1 = F.relu(self.bn1(x1))
        x1 = self.dropout(x1)

        # Second GAT layer
        x2 = self.conv2(x1, edge_index)
        x2 = F.relu(self.bn2(x2))
        x2 = self.dropout(x2)

        # Third GAT layer - final layer for capturing node importance
        x3, attention_weights3 = self.conv3(x2, edge_index, return_attention_weights=True)
        x3 = F.relu(self.bn3(x3))

        # Store attention weights from the final layer for interpretation
        # edge_index, attention (edge_index shape: [2, num_edges], attention shape: [num_edges, heads])
        self.attention_weights = attention_weights3

        # Global pooling - aggregate node features for each graph
        x = global_mean_pool(x3, batch)

        # Apply fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x.view(-1)

    # Method to get atom-level importance
    def get_atom_importance(self, data):
        """Return a numpy array with one importance score in [0, 1] per atom.

        The score of an atom is the total final-layer attention that all atoms
        (itself included, via the self-loop GATConv adds) pay to it when they
        aggregate messages, i.e. attention summed over edges whose *source*
        is that atom. Scores are divided by their maximum.

        Summing per *target* instead would be uninformative: attention is
        softmax-normalized over each target's incoming edges, so that sum is
        ~1.0 for every atom.

        Switches the module to eval mode and runs a forward pass on ``data``.
        """
        self.eval()
        with torch.no_grad():
            # Forward pass to populate self.attention_weights
            _ = self(data)

            # Edge list actually used by the layer, and per-edge attention
            edge_index, attn_weights = self.attention_weights

            # Collapse the head dimension (a single head here) to one value per edge
            if attn_weights.dim() > 1:
                per_edge = attn_weights.mean(dim=-1)
            else:
                per_edge = attn_weights

            # Accumulate attention on the source atom of every edge in one
            # vectorized scatter-add (row 0 = source, row 1 = target).
            num_nodes = data.x.size(0)
            importance = torch.zeros(num_nodes, device=data.x.device, dtype=per_edge.dtype)
            importance.index_add_(0, edge_index[0], per_edge)

            # Normalize importance scores
            if importance.max() > 0:
                importance = importance / importance.max()

            return importance.cpu().numpy()


# Convert SMILES to molecular graphs
def smiles_to_graph(smiles):
    """Build a torch_geometric ``Data`` graph from a SMILES string.

    Node features (109 per atom): 100-slot one-hot of the atomic number,
    formal charge, 5-slot one-hot of hybridization (SP, SP2, SP3, SP3D,
    SP3D2), aromatic flag, degree and total hydrogen count.

    Edge features (6 per directed edge): 4-slot one-hot of bond type (single,
    double, triple, aromatic), conjugation flag and in-ring flag. Every bond
    contributes two directed edges with identical features.

    Returns ``None`` when RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    # Slot order of the two categorical encodings; a value outside these
    # tuples yields an all-zero one-hot.
    hyb_kinds = Chem.rdchem.HybridizationType
    hybridization_slots = (hyb_kinds.SP, hyb_kinds.SP2, hyb_kinds.SP3,
                           hyb_kinds.SP3D, hyb_kinds.SP3D2)
    bond_kinds = Chem.rdchem.BondType
    bond_slots = (bond_kinds.SINGLE, bond_kinds.DOUBLE,
                  bond_kinds.TRIPLE, bond_kinds.AROMATIC)

    # ---- nodes ----
    node_rows = []
    for atom in molecule.GetAtoms():
        atomic_number = atom.GetAtomicNum()
        # Atomic numbers >= 100 match no slot and stay all-zero.
        element_one_hot = [int(slot == atomic_number) for slot in range(100)]

        atom_hybridization = atom.GetHybridization()
        hybridization_one_hot = [int(atom_hybridization == kind)
                                 for kind in hybridization_slots]

        node_rows.append(
            element_one_hot
            + [atom.GetFormalCharge()]
            + hybridization_one_hot
            + [1 if atom.GetIsAromatic() else 0]
            + [atom.GetDegree()]
            + [atom.GetTotalNumHs()]
        )

    # Node feature matrix (num_nodes x num_features)
    x = torch.tensor(node_rows, dtype=torch.float)

    # ---- edges ----
    directed_pairs = []
    edge_rows = []
    for bond in molecule.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()

        kind_of_bond = bond.GetBondType()
        bond_row = [int(kind_of_bond == kind) for kind in bond_slots]
        bond_row += [1 if bond.GetIsConjugated() else 0]
        bond_row += [1 if bond.IsInRing() else 0]

        # The graph is undirected: store both directions with shared features.
        directed_pairs.extend(([begin, end], [end, begin]))
        edge_rows.extend((bond_row, bond_row))

    if directed_pairs:
        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_rows, dtype=torch.float)
    else:
        # Molecules without bonds still need correctly shaped empty tensors.
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 6), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

# Prepare the dataset
def prepare_dataset(dataframe, smiles_col='SMILES', target_col='tg'):
    """Convert every row of ``dataframe`` into a labelled molecular graph.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Source table.
    smiles_col : str
        Column holding the SMILES strings.
    target_col : str
        Column holding the regression target.

    Returns
    -------
    (data_list, valid_indices)
        The graphs (each with ``y`` and ``smiles`` set) and the index labels
        of the rows they came from. Rows whose SMILES cannot be converted are
        skipped; a single warning with their count is printed.
    """
    graphs = []
    kept_index_labels = []
    rejected = []

    for index_label, record in dataframe.iterrows():
        smiles_string = record[smiles_col]
        molecule_graph = smiles_to_graph(smiles_string)

        if molecule_graph is None:
            rejected.append((index_label, smiles_string))
            continue

        # Attach the target and keep the SMILES on the graph for reference.
        molecule_graph.y = torch.tensor([record[target_col]], dtype=torch.float)
        molecule_graph.smiles = smiles_string
        graphs.append(molecule_graph)
        kept_index_labels.append(index_label)

    if rejected:
        print(f"\nWarning: {len(rejected)} invalid SMILES strings found and skipped.")

    return graphs, kept_index_labels


In [7]:
# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for repeatable results.

    Parameters
    ----------
    seed : int
        Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # harmless when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels from run to run, which
    # undermines the determinism requested above, so switch it off as well.
    torch.backends.cudnn.benchmark = False

set_seed()

# Load data
DATA_PATH = "./data/D1/tg_raw.csv"
data = pd.read_csv(DATA_PATH)
print(f"Dataset shape: {data.shape}")
print(data.head())

# Run the initial analysis (adds a 'smiles_length' column and saves EDA plots)
data = analyze_data(data)

# Convert every SMILES string into a labelled graph
data_list, valid_indices = prepare_dataset(data)

print(f"\nValid molecules processed: {len(data_list)} out of {len(data)}")

# Fail fast, before splitting: train_test_split cannot handle an empty list,
# and raising (rather than calling exit()) keeps the notebook kernel alive.
if len(data_list) == 0:
    raise RuntimeError("Error: No valid molecules processed. Can't initialize model.")

# Create train/val/test split (60 / 20 / 20)
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.25, random_state=42)  # 0.25 of 0.8 = 0.2 of total

print(f"Training set: {len(train_data)}")
print(f"Validation set: {len(val_data)}")
print(f"Test set: {len(test_data)}")

# Create data loaders (only the training loader is shuffled)
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model; feature widths are read off the first graph
node_features = data_list[0].x.shape[1]
edge_features = data_list[0].edge_attr.shape[1] if data_list[0].edge_attr.shape[0] > 0 else 0
print(f"Node features: {node_features}, Edge features: {edge_features}")

model = GATTgPredictor(node_features, edge_features).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# NOTE: `criterion` and `scheduler` must stay notebook-level names — the
# training helpers defined above look them up as globals.
criterion = torch.nn.MSELoss()

# Learning rate scheduler: halve the LR after 5 epochs without improvement.
# The `verbose` flag is deprecated in recent PyTorch releases and is therefore
# not passed; inspect `scheduler.get_last_lr()` if LR changes need logging.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

# Train the model
print("\nTraining the GAT model...")
model, test_results = train_and_evaluate(model, train_loader, val_loader, test_loader, optimizer, device)

# Analyze model interpretation
print("\nAnalyzing model interpretation with attention weights...")
analyze_model_interpretation(model, test_loader, device, num_examples=5)


print("\nModel training, evaluation, and interpretation complete.")

Dataset shape: (7174, 2)
        SMILES    tg
0          *C* -54.0
1      *CC(*)C  -3.0
2     *CC(*)CC -24.1
3    *CC(*)CCC -37.0
4  *CC(*)C(C)C  60.0

Data Statistics:
                tg
count  7174.000000
mean    141.948090
std     112.178143
min    -139.000000
25%      55.000000
50%     134.000000
75%     231.000000
max     495.000000

Valid molecules processed: 7174 out of 7174
Training set: 4304
Validation set: 1435
Test set: 1435
Using device: cpu
Node features: 109, Edge features: 6

Training the GAT model...
Epoch 1/100, Train Loss: 32430.0928, Val Loss: 30845.1680, Val R²: -1.3288
Epoch 10/100, Train Loss: 3775.8275, Val Loss: 4331.0649, Val R²: 0.6730
Epoch 20/100, Train Loss: 3416.6245, Val Loss: 2999.9929, Val R²: 0.7735
Epoch 30/100, Train Loss: 3319.8744, Val Loss: 2711.3923, Val R²: 0.7953
Epoch 40/100, Train Loss: 3204.9772, Val Loss: 2730.4980, Val R²: 0.7938
Epoch 50/100, Train Loss: 3076.7126, Val Loss: 2910.1516, Val R²: 0.7803
Epoch 60/100, Train Loss: 2998.3780, V