# 🔬 Fluor-RLAT Model Training

Train property prediction models for fluorescent molecules on Google Colab with GPU acceleration.

**Quick Start:**
1. Go to Runtime → Change runtime type → Select **T4 GPU**
2. Run all cells (Runtime → Run all)
3. The notebook automatically clones the repository, which contains the training data

**Models:**
- `abs` - Absorption wavelength (nm)
- `em` - Emission wavelength (nm)
- `plqy` - Photoluminescence quantum yield (0-1)
- `k` - Log molar absorptivity

**Architecture:** AttentiveFP GNN + Fingerprint CNN fusion (~2.3M parameters per model)

**Expected training time:** ~1-2 hours total on T4 GPU

---

## 1. Install Dependencies

In [1]:
# ============================================================================
# CRITICAL: Disable torch dynamo BEFORE importing torch
# This prevents version conflicts between Colab's PyTorch and DGL
# ============================================================================
import os
os.environ['TORCHDYNAMO_DISABLE'] = '1'

import torch

# Also disable dynamo via config (belt and suspenders)
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("✅ Torch dynamo disabled")

PyTorch version: 2.4.0+cu121
CUDA available: True
CUDA version: 12.1
GPU: Tesla T4
✅ Torch dynamo disabled


In [2]:
# Install DGL, dgllife, and RDKit (Colab has PyTorch pre-installed)
# DGL wheel selection based on CUDA version

cuda_version = torch.version.cuda
print(f"Detected CUDA version: {cuda_version}")

# Install RDKit first (required by dgllife)
print("Installing RDKit...")
!pip install rdkit -q

# Install DGL for the appropriate CUDA version
if cuda_version and cuda_version.startswith('12'):
    print("Installing DGL for CUDA 12.x...")
    !pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html -q
elif cuda_version and cuda_version.startswith('11'):
    print("Installing DGL for CUDA 11.x...")
    !pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html -q
else:
    # Fallback (no CUDA version detected): reuse the torch-2.4/cu124 wheel index
    print("Installing DGL (default)...")
    !pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html -q

# Install dgllife for molecular graph utilities
print("Installing dgllife...")
!pip install dgllife -q

# Install tqdm for progress bars (usually present but ensure it's available)
!pip install tqdm -q

print("\n✅ Dependencies installed!")

Detected CUDA version: 12.1
Installing RDKit...
Installing DGL for CUDA 12.x...
Installing dgllife...

✅ Dependencies installed!

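The branching in the install cell amounts to a lookup from `torch.version.cuda` to a DGL wheel index. As a sketch, here is that logic restated as a pure function; `dgl_wheel_index` is a hypothetical name (not part of DGL), and the two URLs are copied from the cell, so a CUDA 12.1 runtime resolves to the `torch-2.4/cu124` index, matching the `2.4.0+cu124` build reported below.

```python
from typing import Optional

# The two wheel indexes used by the install cell (copied verbatim).
CU12_INDEX = "https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html"
CU11_INDEX = "https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html"


def dgl_wheel_index(cuda_version: Optional[str]) -> str:
    """Hypothetical helper: return the DGL wheel index the install cell would use.

    `cuda_version` is the string from torch.version.cuda (e.g. "12.1"),
    or None on a CPU-only runtime.
    """
    if cuda_version and cuda_version.startswith("12"):
        return CU12_INDEX
    if cuda_version and cuda_version.startswith("11"):
        return CU11_INDEX
    # Fallback mirrors the cell: reuse the CUDA 12.4 index
    return CU12_INDEX


print(dgl_wheel_index("12.1"))
```

Isolating the mapping this way makes it a one-line change when Colab's preinstalled PyTorch/CUDA pair moves on; whether DGL publishes a matching index for a given pair still has to be checked on the DGL installation page.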

In [3]:
# Verify DGL installation
import dgl
print(f"DGL version: {dgl.__version__}")

# Test CUDA with DGL
if torch.cuda.is_available():
    g = dgl.graph(([0, 1], [1, 2]))
    g = g.to('cuda')
    print(f"DGL graph on CUDA: {g.device}")
    print("✅ DGL CUDA support working!")

DGL version: 2.4.0+cu124
DGL graph on CUDA: cuda:0
✅ DGL CUDA support working!


## 2. Mount Google Drive and Setup Directories

In [4]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
import os
os.makedirs('./data', exist_ok=True)
os.makedirs('./models', exist_ok=True)
os.makedirs('./checkpoints', exist_ok=True)

print("✅ Directories created")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Directories created


## 3. Clone Repository and Setup Data

Clone the fluor_tools repository directly from GitHub to access training data.

In [5]:
# Clone the repository
import os

REPO_URL = "https://github.com/markste-in/fluor_tools.git"
REPO_DIR = "fluor_tools"

if not os.path.exists(REPO_DIR):
    print(f"üì• Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL}
    print("‚úÖ Repository cloned!")
else:
    print(f"‚úÖ Repository already exists at {REPO_DIR}")
    # Optionally pull latest changes
    !cd {REPO_DIR} && git pull

# Set data path to the cloned repo
DATA_DIR = f'./{REPO_DIR}/Fluor-RLAT/data'
print(f"üìÅ Data directory: {DATA_DIR}")

✅ Repository already exists at fluor_tools
Already up to date.
📁 Data directory: ./fluor_tools/Fluor-RLAT/data


## 4. Verify GPU Availability

In [6]:
# Check GPU details
!nvidia-smi

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nüñ•Ô∏è Using device: {device}")

if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    print(f"   Memory: {props.total_memory / 1e9:.1f} GB")
    print(f"   Compute capability: {props.major}.{props.minor}")
else:
    print("⚠️ No GPU detected! Training will be slow.")
    print("   Go to Runtime → Change runtime type → GPU")

Mon Feb 16 11:03:33 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   75C    P0             30W /   70W |     325MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

## 5. Configure Environment Variables

In [7]:
# Configure CUDA memory management to prevent OOM errors
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Allow TF32 math for faster training; this only affects Ampere or newer GPUs (e.g. A100), not the T4
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

print("✅ Environment configured")

✅ Environment configured


## 6. Load and Verify Training Data

In [8]:
# Verify data files are present
DATA_DIR = './fluor_tools/Fluor-RLAT/data'

required_base = ['train', 'valid']
targets = ['abs', 'em', 'plqy', 'k']
file_types = ['', '_smiles', '_sol']

print(f"Checking data files in: {DATA_DIR}")
missing = []
found = []

for base in required_base:
    for target in targets:
        for ftype in file_types:
            filename = f'{base}{ftype}_{target}.csv'
            path = os.path.join(DATA_DIR, filename)
            if os.path.exists(path):
                size = os.path.getsize(path) / 1024 / 1024
                found.append((filename, size))
            else:
                missing.append(filename)

print(f"\n✅ Found {len(found)} files:")
for fname, size in found[:6]:
    print(f"   {fname} ({size:.1f} MB)")
if len(found) > 6:
    print(f"   ... and {len(found) - 6} more")

if missing:
    print(f"\n❌ Missing {len(missing)} files:")
    for fname in missing[:5]:
        print(f"   {fname}")
    print("\n⚠️ Please upload data before training!")
else:
    print("\n✅ All required files present!")

Checking data files in: ./fluor_tools/Fluor-RLAT/data

✅ Found 24 files:
   train_abs.csv (9.3 MB)
   train_smiles_abs.csv (42.9 MB)
   train_sol_abs.csv (42.9 MB)
   train_em.csv (7.2 MB)
   train_smiles_em.csv (32.9 MB)
   train_sol_em.csv (32.9 MB)
   ... and 18 more

✅ All required files present!

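The check expects 2 splits × 4 targets × 3 file types = 24 CSVs, which matches the count found above. A standalone sketch of the same naming scheme, `{split}{type}_{target}.csv`, is handy for listing exactly which names are absent via a set difference (the `present` set below is a made-up example, not a real directory listing):

```python
from itertools import product

SPLITS = ["train", "valid"]
TARGETS = ["abs", "em", "plqy", "k"]
# '' = labels/descriptors, '_smiles' = SMILES fingerprints, '_sol' = solvent fingerprints
FILE_TYPES = ["", "_smiles", "_sol"]


def expected_files():
    """Return the 24 CSV names the data check looks for, in the same order."""
    return [f"{split}{ftype}_{target}.csv"
            for split, target, ftype in product(SPLITS, TARGETS, FILE_TYPES)]


files = expected_files()
print(len(files), files[:3])

# Example: compare against a (hypothetical) directory listing
present = {"train_abs.csv", "train_smiles_abs.csv"}
missing = sorted(set(files) - present)
```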

In [9]:
# Load sample counts
import pandas as pd

print("Dataset sizes:")
for target in ['abs', 'em', 'plqy', 'k']:
    train_path = os.path.join(DATA_DIR, f'train_{target}.csv')
    valid_path = os.path.join(DATA_DIR, f'valid_{target}.csv')
    
    if os.path.exists(train_path) and os.path.exists(valid_path):
        train_df = pd.read_csv(train_path)
        valid_df = pd.read_csv(valid_path)
        print(f"  {target.upper()}: {len(train_df):,} train / {len(valid_df):,} valid")

Dataset sizes:
  ABS: 21,948 train / 3,132 valid
  EM: 16,833 train / 2,370 valid
  PLQY: 12,998 train / 1,855 valid
  K: 6,976 train / 952 valid


## 7. Initialize Model Architecture

Define the complete model architecture including:
- AttentiveFP Graph Neural Network
- Fingerprint CNN with attention (for abs/em)
- Simple FC network (for plqy/k)

In [10]:
# Import all required libraries
import copy
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Use tqdm.auto which automatically selects the right backend
# Falls back to text mode if notebook widgets aren't available
from tqdm.auto import tqdm

import dgl
from dgllife.model.gnn import AttentiveFPGNN
from dgllife.model.readout import AttentiveFPReadout
from dgllife.utils import (
    smiles_to_bigraph,
    AttentiveFPAtomFeaturizer,
    AttentiveFPBondFeaturizer,
)

print("✅ All imports successful")

✅ All imports successful


In [11]:
# ============================================================================
# Configuration
# ============================================================================

BATCH_SIZE = 32
LEARNING_RATE = 5e-4  # Reduced from 1e-3 for more stable training
GRAPH_FEAT_SIZE = 256

# Learning rate scheduler settings
USE_LR_SCHEDULER = True   # Set to False to disable scheduler
LR_SCHEDULER_FACTOR = 0.5  # Multiply the LR by this factor when the monitored loss plateaus
LR_SCHEDULER_PATIENCE = 10  # Epochs to wait before reducing LR
LR_SCHEDULER_MIN = 1e-6   # Minimum learning rate

# Model configurations per target (MATCHES pretrained model parameters)
# Verified from 02_property_prediction.py source code
MODEL_CONFIGS = {
    'abs': {
        'num_layers': 2,
        'num_timesteps': 2,
        'dropout': 0.3,
        'alpha': 0.1,  # LDS alpha
        'model_class': 'GraphFingerprintsModel',  # CNN attention + solvent extractor
    },
    'em': {
        'num_layers': 3,
        'num_timesteps': 1,
        'dropout': 0.3,
        'alpha': 0.0,
        'model_class': 'GraphFingerprintsModel',
    },
    'plqy': {
        'num_layers': 2,
        'num_timesteps': 3,
        'dropout': 0.4,
        'alpha': 0.2,
        'model_class': 'GraphFingerprintsModelFC',  # Simple FC for all fingerprints
    },
    'k': {
        'num_layers': 3,
        'num_timesteps': 1,
        'dropout': 0.3,
        'alpha': 0.6,
        'model_class': 'GraphFingerprintsModelFC',
    },
}

# Featurizers for molecular graphs
ATOM_FEATURIZER = AttentiveFPAtomFeaturizer(atom_data_field='hv')
BOND_FEATURIZER = AttentiveFPBondFeaturizer(bond_data_field='he')

print("✅ Configuration loaded")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   LR scheduler: {'Enabled' if USE_LR_SCHEDULER else 'Disabled'}")
if USE_LR_SCHEDULER:
    print(f"      Factor: {LR_SCHEDULER_FACTOR}, Patience: {LR_SCHEDULER_PATIENCE}, Min LR: {LR_SCHEDULER_MIN}")
print("   abs/em: GraphFingerprintsModel (num_layers=2/3, num_timesteps=2/1)")
print("   plqy/k: GraphFingerprintsModelFC (num_layers=2/3, num_timesteps=3/1)")

✅ Configuration loaded
   Learning rate: 0.0005
   LR scheduler: Enabled
      Factor: 0.5, Patience: 10, Min LR: 1e-06
   abs/em: GraphFingerprintsModel (num_layers=2/3, num_timesteps=2/1)
   plqy/k: GraphFingerprintsModelFC (num_layers=2/3, num_timesteps=3/1)

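The scheduler constants are presumably consumed later in the notebook by `torch.optim.lr_scheduler.ReduceLROnPlateau` in `'min'` mode; that wiring is an assumption, since the training loop is not part of this section. The pure-Python sketch below mimics that scheduler's default behaviour (relative threshold 1e-4, no cooldown) to show what the settings imply: the LR is halved only after more than 10 consecutive epochs without improvement, and 5e-4 hits the 1e-6 floor at the ninth reduction.

```python
def simulate_plateau_lr(val_losses, lr=5e-4, factor=0.5, patience=10,
                        min_lr=1e-6, threshold=1e-4):
    """Simplified model of ReduceLROnPlateau(mode='min', threshold_mode='rel').

    Returns the learning rate in effect after the scheduler step of each epoch.
    """
    best = float("inf")
    num_bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):  # relative improvement test
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        if num_bad_epochs > patience:      # strictly more than `patience` bad epochs
            lr = max(lr * factor, min_lr)  # never go below the floor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs


# A validation loss that stops improving after the first epoch
lrs = simulate_plateau_lr([1.0] * 12)
print(lrs[10], lrs[11])  # 0.0005 0.00025
```

If the constants are wired in as assumed, the real call would be `optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE, min_lr=LR_SCHEDULER_MIN)`, stepped once per epoch with `scheduler.step(val_loss)`.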

In [12]:
# ============================================================================
# Model Architecture Classes
# ============================================================================
# These MUST match the original architecture in 02_property_prediction.py
# to ensure compatibility with pretrained models and consistent results.
#
# VERIFIED against pretrained model state_dicts:
# - Model_abs.pth, Model_em.pth: use fp_extractor + solvent_extractor
# - Model_plqy.pth, Model_k.pth: use fp_fc only (no solvent split)

class FingerprintAttentionCNN(nn.Module):
    """CNN with attention for fingerprint processing (used by abs/em models).
    
    Layer names must match: fp_extractor.conv_feat, fp_extractor.conv_attn
    Output: 2 * conv_channels (512 when conv_channels=256)
    """
    
    def __init__(self, input_dim, conv_channels=256):
        super().__init__()
        self.conv_feat = nn.Conv1d(1, conv_channels, kernel_size=3, padding=1)
        self.conv_attn = nn.Conv1d(1, conv_channels, kernel_size=3, padding=1)
        self.softmax = nn.Softmax(dim=-1)
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, x):
        x = x.unsqueeze(1)  # [B, 1, D]
        feat_map = self.conv_feat(x)         # [B, C, D]
        attn_map = self.conv_attn(x)         # [B, C, D]
        attn_weights = self.softmax(attn_map)
        attn_out = torch.sum(feat_map * attn_weights, dim=-1)  # [B, C]
        pooled = self.pool(feat_map).squeeze(-1)               # [B, C]
        return torch.cat([attn_out, pooled], dim=1)            # [B, 2C]


class GraphFingerprintsModel(nn.Module):
    """Model for abs/em: AttentiveFP GNN + CNN attention for fingerprints.
    
    This is the ORIGINAL class name from 02_property_prediction.py.
    Used for absorption and emission prediction.
    
    Layer names verified against Model_abs.pth / Model_em.pth:
    - gnn.*, readout.*: AttentiveFP layers
    - fp_extractor.conv_feat, fp_extractor.conv_attn: CNN attention
    - solvent_extractor.0, solvent_extractor.3: FC for solvent (1024->256->256)
    - predict.1, predict.3: final prediction (1024->128->1)
    """
    
    def __init__(self, node_feat_size, edge_feat_size, solvent_dim, smiles_extra_dim,
                 graph_feat_size=256, num_layers=2, num_timesteps=2, n_tasks=1, dropout=0.3):
        super().__init__()
        
        self.solvent_dim = solvent_dim
        
        # GNN
        self.gnn = AttentiveFPGNN(
            node_feat_size=node_feat_size,
            edge_feat_size=edge_feat_size,
            num_layers=num_layers,
            graph_feat_size=graph_feat_size,
            dropout=dropout
        )
        
        self.readout = AttentiveFPReadout(
            feat_size=graph_feat_size,
            num_timesteps=num_timesteps,
            dropout=dropout
        )
        
        # Fingerprint part: smiles + extra -> CNN attention -> 2*graph_feat_size
        self.fp_extractor = FingerprintAttentionCNN(smiles_extra_dim, conv_channels=graph_feat_size)
        
        # Fingerprint part: solvent -> FC -> graph_feat_size
        # Original: nn.Linear(solvent_dim, 256), ReLU, Dropout, nn.Linear(256, graph_feat_size)
        self.solvent_extractor = nn.Sequential(
            nn.Linear(solvent_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, graph_feat_size)
        )
        
        # Prediction: graph(256) + solvent(256) + fp(512) = 1024 -> 128 -> 1
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size * 4, 128),  # 1024
            nn.ReLU(),
            nn.Linear(128, n_tasks)
        )
        
    def forward(self, graph, node_feats, edge_feats, fingerprints):
        node_out = self.gnn(graph, node_feats, edge_feats)
        graph_out = self.readout(graph, node_out, False)
        
        solvent_feat = fingerprints[:, :self.solvent_dim]
        smiles_extra_feat = fingerprints[:, self.solvent_dim:]
        
        solvent_out = self.solvent_extractor(solvent_feat)
        smiles_extra_out = self.fp_extractor(smiles_extra_feat)
        
        combined = torch.cat([graph_out, solvent_out, smiles_extra_out], dim=1)
        return self.predict(combined)


class GraphFingerprintsModelFC(nn.Module):
    """Model for plqy/k: AttentiveFP GNN + Simple FC for fingerprints.
    
    This uses a different architecture from abs/em - processes ALL fingerprints
    together through a simple FC, no separate solvent extractor.
    
    Layer names verified against Model_plqy.pth / Model_k.pth:
    - gnn.*, readout.*: AttentiveFP layers
    - fp_fc.0, fp_fc.3: FC for all fingerprints (2192->256->256)
    - predict.1, predict.3: final prediction (512->128->1)
    """
    
    def __init__(self, node_feat_size, edge_feat_size, fp_size,
                 graph_feat_size=256, num_layers=2, num_timesteps=2, n_tasks=1, dropout=0.3):
        super().__init__()
        
        # GNN
        self.gnn = AttentiveFPGNN(
            node_feat_size=node_feat_size,
            edge_feat_size=edge_feat_size,
            num_layers=num_layers,
            graph_feat_size=graph_feat_size,
            dropout=dropout
        )
        
        self.readout = AttentiveFPReadout(
            feat_size=graph_feat_size,
            num_timesteps=num_timesteps,
            dropout=dropout
        )
        
        # FC for ALL fingerprints -> graph_feat_size
        # Original: Linear(fp_size, 256), ReLU, Dropout, Linear(256, graph_feat_size)
        # Indices: 0=Linear, 1=ReLU, 2=Dropout, 3=Linear
        self.fp_fc = nn.Sequential(
            nn.Linear(fp_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, graph_feat_size)
        )
        
        # Prediction: graph(256) + fp(256) = 512 -> 128 -> 1
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size * 2, 128),  # 512
            nn.ReLU(),
            nn.Linear(128, n_tasks)
        )
        
    def forward(self, graph, node_feats, edge_feats, fingerprints):
        node_out = self.gnn(graph, node_feats, edge_feats)
        graph_out = self.readout(graph, node_out, False)
        fp_out = self.fp_fc(fingerprints)
        combined = torch.cat([graph_out, fp_out], dim=1)
        return self.predict(combined)


# Aliases to match different naming conventions
GraphFingerprintsModelAttentionCNN = GraphFingerprintsModel  # for abs/em
GraphFingerprintsModelSimpleFC = GraphFingerprintsModelFC     # for plqy/k

print("✅ Model classes defined (matching original 02_property_prediction.py)")
print("   - GraphFingerprintsModel (alias: GraphFingerprintsModelAttentionCNN): for abs/em")
print("   - GraphFingerprintsModelFC (alias: GraphFingerprintsModelSimpleFC): for plqy/k")
print("")
print("   Architecture summary:")
print("   abs/em:  fp_extractor (CNN) + solvent_extractor (FC) -> predict(1024->128->1)")
print("   plqy/k:  fp_fc (FC, all fingerprints combined) -> predict(512->128->1)")

✅ Model classes defined (matching original 02_property_prediction.py)
   - GraphFingerprintsModel (alias: GraphFingerprintsModelAttentionCNN): for abs/em
   - GraphFingerprintsModelFC (alias: GraphFingerprintsModelSimpleFC): for plqy/k

   Architecture summary:
   abs/em:  fp_extractor (CNN) + solvent_extractor (FC) -> predict(1024->128->1)
   plqy/k:  fp_fc (FC, all fingerprints combined) -> predict(512->128->1)

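To make the shapes in `FingerprintAttentionCNN.forward` concrete, here is a NumPy restatement of it with random placeholder weights (no trained parameters involved). Each of the C channels produces a softmax distribution over the D fingerprint positions; the attention-weighted sum and the max-pool of the feature map are concatenated into a 2C vector, which is why `fp_extractor` contributes 512 features when `conv_channels=256`.

```python
import numpy as np


def conv1d_k3(x, w, b):
    """Conv1d(1, C, kernel_size=3, padding=1) as cross-correlation, like PyTorch.

    x: [B, D] fingerprints, w: [C, 3] kernels, b: [C] biases -> [B, C, D]
    """
    d = x.shape[1]
    xpad = np.pad(x, ((0, 0), (1, 1)))                                # zero-pad one position per side
    windows = np.stack([xpad[:, k:k + d] for k in range(3)], axis=-1)  # [B, D, 3]
    return np.einsum("bdk,ck->bcd", windows, w) + b[None, :, None]


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def fingerprint_attention_cnn(x, w_feat, b_feat, w_attn, b_attn):
    feat_map = conv1d_k3(x, w_feat, b_feat)            # [B, C, D]
    attn = softmax(conv1d_k3(x, w_attn, b_attn))       # softmax over the D positions
    attn_out = (feat_map * attn).sum(axis=-1)          # [B, C]
    pooled = feat_map.max(axis=-1)                     # [B, C], like AdaptiveMaxPool1d(1)
    return np.concatenate([attn_out, pooled], axis=1)  # [B, 2C]


rng = np.random.default_rng(0)
B, D, C = 2, 16, 4
x = rng.integers(0, 2, size=(B, D)).astype(np.float32)  # toy binary fingerprints
out = fingerprint_attention_cnn(
    x, rng.normal(size=(C, 3)), rng.normal(size=C),
    rng.normal(size=(C, 3)), rng.normal(size=C))
print(out.shape)  # (2, 8)
```

Because the attention output is a convex combination of each channel's feature values, it can never exceed the max-pooled value for that channel; the two halves summarise the same feature map in complementary ways.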

In [13]:
# ============================================================================
# Model Configs for Original Pretrained Models
# ============================================================================
# Verified hyperparameters from 02_property_prediction.py:
#
# abs:  num_layers=2, num_timesteps=2, dropout=0.3, alpha=0.1
# em:   num_layers=3, num_timesteps=1, dropout=0.3, alpha=0.0  
# plqy: num_layers=2, num_timesteps=3, dropout=0.4, alpha=0.2
# k:    num_layers=3, num_timesteps=1, dropout=0.3, alpha=0.6

ORIGINAL_MODEL_CONFIGS = {
    'abs':  {'num_layers': 2, 'num_timesteps': 2, 'dropout': 0.3, 'alpha': 0.1, 'model_class': 'GraphFingerprintsModel'},
    'em':   {'num_layers': 3, 'num_timesteps': 1, 'dropout': 0.3, 'alpha': 0.0, 'model_class': 'GraphFingerprintsModel'},
    'plqy': {'num_layers': 2, 'num_timesteps': 3, 'dropout': 0.4, 'alpha': 0.2, 'model_class': 'GraphFingerprintsModelFC'},
    'k':    {'num_layers': 3, 'num_timesteps': 1, 'dropout': 0.3, 'alpha': 0.6, 'model_class': 'GraphFingerprintsModelFC'},
}

print("✅ Original model configs defined")
print("   abs/em use GraphFingerprintsModel (CNN attention + solvent extractor)")
print("   plqy/k use GraphFingerprintsModelFC (simple FC for all fingerprints)")

✅ Original model configs defined
   abs/em use GraphFingerprintsModel (CNN attention + solvent extractor)
   plqy/k use GraphFingerprintsModelFC (simple FC for all fingerprints)

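The `model_class` entry decides how wide the prediction head's input is. The plain-Python arithmetic below uses the layer sizes from the class definitions to show where 1024 and 512 come from and to count the parameters of the non-GNN layers. The solvent width (1024) and combined fingerprint width (2192) are taken from the docstring comments above rather than read from the data, and the AttentiveFP GNN/readout parameters are deliberately left out, so these are only partial totals.

```python
def linear_params(n_in, n_out):
    """Weights + biases of nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def conv1d_params(in_ch, out_ch, k):
    """Weights + biases of nn.Conv1d(in_ch, out_ch, kernel_size=k)."""
    return in_ch * out_ch * k + out_ch


G = 256              # GRAPH_FEAT_SIZE
SOLVENT_DIM = 1024   # assumed, from the solvent_extractor docstring
FP_SIZE = 2192       # assumed, from the fp_fc docstring

head_in_cnn = G + G + 2 * G  # abs/em: graph + solvent_extractor + fp_extractor (attn, pool)
head_in_fc = G + G           # plqy/k: graph + fp_fc

predict_cnn = linear_params(head_in_cnn, 128) + linear_params(128, 1)
predict_fc = linear_params(head_in_fc, 128) + linear_params(128, 1)
solvent_extractor = linear_params(SOLVENT_DIM, 256) + linear_params(256, G)
fp_extractor = 2 * conv1d_params(1, G, 3)  # conv_feat + conv_attn
fp_fc = linear_params(FP_SIZE, 256) + linear_params(256, G)

print(head_in_cnn, head_in_fc)                         # 1024 512
print(predict_cnn + solvent_extractor + fp_extractor)  # 461569 (abs/em, non-GNN)
print(predict_fc + fp_fc)                              # 692993 (plqy/k, non-GNN)
```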

In [14]:
# ============================================================================
# Dataset and Data Loading
# ============================================================================

class MolecularDataset(Dataset):
    """Dataset for molecular property prediction."""
    
    def __init__(self, graphs, fingerprints, labels, masks=None, weights=None):
        self.graphs = graphs
        self.fingerprints = fingerprints
        self.labels = labels
        self.masks = masks
        self.weights = weights
        
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return (
            self.graphs[idx],
            self.fingerprints[idx],
            self.labels[idx],
            self.masks[idx] if self.masks is not None else 1.0,
            self.weights[idx] if self.weights is not None else 1.0
        )


def collate_fn(batch):
    """Custom collate function for batching molecular graphs."""
    graphs, fps, labels, masks, weights = zip(*batch)
    batched_graph = dgl.batch(graphs)
    fps = torch.stack(fps)
    labels = torch.stack(labels)
    # __getitem__ substitutes 1.0 for missing masks/weights, so both are always tensors here
    masks = torch.tensor(masks, dtype=torch.float32) if masks[0] is not None else None
    weights = torch.tensor(weights, dtype=torch.float32) if weights[0] is not None else None
    return batched_graph, fps, labels, masks, weights


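def compute_lds_weights(labels, alpha=0.5, kernel_size=5):
    """Compute Label Distribution Smoothing (LDS) sample weights.

    Labels are binned into 100 equal-width bins, the bin counts are smoothed
    with a uniform (box) kernel of width `kernel_size`, blended with the raw
    counts as counts**alpha * smoothed**(1 - alpha), and inverted. Weights are
    normalized to mean 1. alpha == 0 disables reweighting (all weights are 1).
    """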
    if alpha == 0:
        return np.ones(len(labels))
    
    n_bins = 100
    label_min, label_max = labels.min(), labels.max()
    bins = np.linspace(label_min, label_max, n_bins + 1)
    bin_indices = np.digitize(labels, bins) - 1
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)
    
    bin_counts = np.bincount(bin_indices, minlength=n_bins).astype(float)
    kernel = np.ones(kernel_size) / kernel_size
    smoothed_counts = np.convolve(bin_counts, kernel, mode='same')
    smoothed_counts = np.maximum(smoothed_counts, 1)
    
    effective_counts = bin_counts ** alpha * smoothed_counts ** (1 - alpha)
    effective_counts = np.maximum(effective_counts, 1)
    
    weights = 1.0 / effective_counts[bin_indices]
    weights = weights / weights.mean()
    
    return weights.astype(np.float32)


def smiles_to_graph(smiles):
    """Convert SMILES to DGL graph with AttentiveFP featurization."""
    try:
        graph = smiles_to_bigraph(
            smiles,
            node_featurizer=ATOM_FEATURIZER,
            edge_featurizer=BOND_FEATURIZER,
            add_self_loop=False
        )
        return graph
    except Exception:
        # Featurization failed (e.g. invalid SMILES); the caller substitutes a placeholder graph
        return None


print("✅ Dataset classes defined")

✅ Dataset classes defined

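One detail of `compute_lds_weights` worth knowing: `np.convolve(..., mode='same')` pads with zeros, so the two outermost bins on each side see fewer than `kernel_size` real neighbours and their smoothed counts come out low. The small example below uses a perfectly uniform histogram (10 bins instead of the function's 100, for readability) with the same 5-wide box kernel. With `alpha < 1`, the low edge counts lower `effective_counts`, so samples in the outermost bins get larger weights even though every bin is equally populated.

```python
import numpy as np

bin_counts = np.full(10, 50.0)  # uniform histogram: 50 samples in each of 10 bins
kernel = np.ones(5) / 5         # same box kernel as kernel_size=5
smoothed = np.convolve(bin_counts, kernel, mode="same")
print(smoothed)  # [30. 40. 50. 50. 50. 50. 50. 50. 40. 30.]

# Per-bin weights for alpha=0.1, mirroring the blend inside compute_lds_weights
alpha = 0.1
effective = bin_counts ** alpha * smoothed ** (1 - alpha)
weights = 1.0 / effective
weights /= weights.mean()
print(weights.round(3))
```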

In [15]:
# ============================================================================
# Data Loading Functions
# ============================================================================

def load_and_process_data(target, data_dir, config):
    """Load and process training/validation data for a target property."""
    
    data_dir = Path(data_dir)
    
    # Load main data files
    train_df = pd.read_csv(data_dir / f'train_{target}.csv')
    valid_df = pd.read_csv(data_dir / f'valid_{target}.csv')
    
    # Load fingerprint files
    train_smiles_fp = pd.read_csv(data_dir / f'train_smiles_{target}.csv')
    train_sol_fp = pd.read_csv(data_dir / f'train_sol_{target}.csv')
    valid_smiles_fp = pd.read_csv(data_dir / f'valid_smiles_{target}.csv')
    valid_sol_fp = pd.read_csv(data_dir / f'valid_sol_{target}.csv')
    
    print(f"   Train samples: {len(train_df)}")
    print(f"   Valid samples: {len(valid_df)}")
    
    # Extract labels
    train_labels = train_df[target].values.reshape(-1, 1).astype(np.float32)
    valid_labels = valid_df[target].values.reshape(-1, 1).astype(np.float32)
    
    # Normalize labels
    scaler = StandardScaler()
    train_labels_scaled = scaler.fit_transform(train_labels)
    valid_labels_scaled = scaler.transform(valid_labels)
    
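    # Extract and normalize numeric descriptor features.
    # Assumes they sit at 0-based column positions 8-15; skipped if the CSV has 16 or fewer columns.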
    numeric_cols = train_df.columns[8:16].tolist() if len(train_df.columns) > 16 else []
    
    scaler_num = MinMaxScaler()
    if numeric_cols:
        train_numeric = scaler_num.fit_transform(train_df[numeric_cols].values)
        valid_numeric = scaler_num.transform(valid_df[numeric_cols].values)
    else:
        train_numeric = np.zeros((len(train_df), 0))
        valid_numeric = np.zeros((len(valid_df), 0))
    
    # Combine fingerprints
    train_sol = train_sol_fp.values.astype(np.float32)
    train_smiles = train_smiles_fp.values.astype(np.float32)
    valid_sol = valid_sol_fp.values.astype(np.float32)
    valid_smiles = valid_smiles_fp.values.astype(np.float32)
    
    # Get scaffold flags if present
    scaffold_cols = [c for c in train_df.columns if c.startswith('fragment_')]
    if scaffold_cols:
        train_scaffold = train_df[scaffold_cols].values.astype(np.float32)
        valid_scaffold = valid_df[scaffold_cols].values.astype(np.float32)
    else:
        train_scaffold = np.zeros((len(train_df), 0), dtype=np.float32)
        valid_scaffold = np.zeros((len(valid_df), 0), dtype=np.float32)
    
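    # Combine all features. Order matters: solvent columns must come first because
    # GraphFingerprintsModel.forward slices fingerprints[:, :solvent_dim] as the solvent part.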
    train_fp = np.hstack([train_sol, train_smiles, train_numeric, train_scaffold])
    valid_fp = np.hstack([valid_sol, valid_smiles, valid_numeric, valid_scaffold])
    
    solvent_dim = train_sol.shape[1]
    smiles_extra_dim = train_fp.shape[1] - solvent_dim
    fp_size = train_fp.shape[1]
    
    print(f"   Total fingerprint dimensions: {fp_size}")
    
    # Compute LDS weights
    alpha = config['alpha']
    train_weights = compute_lds_weights(train_labels.flatten(), alpha=alpha)
    print(f"   LDS alpha: {alpha}, weight range: [{train_weights.min():.2f}, {train_weights.max():.2f}]")
    
    # Convert SMILES to graphs
    print("   Converting SMILES to graphs...")
    
    train_smiles_list = train_df['smiles'].tolist()
    valid_smiles_list = valid_df['smiles'].tolist()
    
    train_graphs = []
    for smi in tqdm(train_smiles_list, desc="   Train graphs", leave=False):
        g = smiles_to_graph(smi)
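        if g is None:
            # Placeholder single-node graph for unparseable SMILES;
            # 39 / 10 are the AttentiveFP atom / bond feature sizes
            g = dgl.graph(([0], [0]))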
            g.ndata['hv'] = torch.zeros(1, 39)
            g.edata['he'] = torch.zeros(1, 10)
        train_graphs.append(g)
    
    valid_graphs = []
    for smi in tqdm(valid_smiles_list, desc="   Valid graphs", leave=False):
        g = smiles_to_graph(smi)
        if g is None:
            g = dgl.graph(([0], [0]))
            g.ndata['hv'] = torch.zeros(1, 39)
            g.edata['he'] = torch.zeros(1, 10)
        valid_graphs.append(g)
    
    print(f"   ✅ Converted {len(train_graphs)} train + {len(valid_graphs)} valid graphs")
    
    # Create datasets
    train_data = MolecularDataset(
        graphs=train_graphs,
        fingerprints=torch.tensor(train_fp, dtype=torch.float32),
        labels=torch.tensor(train_labels_scaled, dtype=torch.float32),
        masks=None,
        weights=train_weights
    )
    
    valid_data = MolecularDataset(
        graphs=valid_graphs,
        fingerprints=torch.tensor(valid_fp, dtype=torch.float32),
        labels=torch.tensor(valid_labels_scaled, dtype=torch.float32),
        masks=None,
        weights=None
    )
    
    return {
        'train_data': train_data,
        'valid_data': valid_data,
        'scaler': scaler,
        'scaler_num': scaler_num,
        'config': config,
        'solvent_dim': solvent_dim,
        'smiles_extra_dim': smiles_extra_dim,
        'fp_size': fp_size,
    }


print("✅ Data loading functions defined")

✅ Data loading functions defined

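Because `load_and_process_data` standardizes the labels, `validate` returns predictions and labels in scaled units. The NumPy sketch below restates what `StandardScaler` does (subtract the mean, divide by the population standard deviation) to show that converting back is `z * scale + mean`, and that an MAE computed on scaled values is simply multiplied by the training-label standard deviation. The label values are made up for illustration; in the notebook, `scaler.inverse_transform(...)` on the returned `scaler` performs the conversion.

```python
import numpy as np

# Toy absorption labels in nm (made-up values, not from the dataset)
y_train = np.array([[420.0], [510.0], [600.0], [480.0]])
mean = y_train.mean(axis=0)
scale = y_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses


def to_scaled(y):
    return (y - mean) / scale


def from_scaled(z):
    return z * scale + mean  # what scaler.inverse_transform computes


y_true = np.array([[450.0], [550.0]])
y_pred = np.array([[460.0], [530.0]])
mae_scaled = np.abs(to_scaled(y_pred) - to_scaled(y_true)).mean()
mae_nm = np.abs(y_pred - y_true).mean()
print(mae_nm, mae_scaled * scale[0])  # both 15.0 (up to float rounding)
```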

In [16]:
# ============================================================================
# Training Functions
# ============================================================================

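def train_epoch(model, dataloader, optimizer, criterion, device, use_lds=True):
    """Train for one epoch.

    `criterion` must return per-sample losses (e.g. nn.MSELoss(reduction='none'))
    so that the LDS weights are applied to each sample before averaging.
    """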
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch in dataloader:
        graphs, fps, labels, masks, weights = batch
        
        graphs = graphs.to(device)
        fps = fps.to(device)
        labels = labels.to(device)
        
        node_feats = graphs.ndata['hv']
        edge_feats = graphs.edata.get('he', None)
        
        optimizer.zero_grad()
        predictions = model(graphs, node_feats, edge_feats, fps)
        
        if use_lds and weights is not None:
            weights = weights.to(device)
            loss = (criterion(predictions, labels) * weights.unsqueeze(1)).mean()
        else:
            loss = criterion(predictions, labels).mean()
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches


def validate(model, dataloader, criterion, device):
    """Validate model on validation set."""
    model.eval()
    total_loss = 0
    num_batches = 0
    
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in dataloader:
            graphs, fps, labels, masks, _ = batch
            
            graphs = graphs.to(device)
            fps = fps.to(device)
            labels = labels.to(device)
            
            node_feats = graphs.ndata['hv']
            edge_feats = graphs.edata.get('he', None)
            
            predictions = model(graphs, node_feats, edge_feats, fps)
            loss = criterion(predictions, labels).mean()
            
            total_loss += loss.item()
            num_batches += 1
            
            all_preds.append(predictions.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    return total_loss / num_batches, all_preds, all_labels


print("✅ Training functions defined")

✅ Training functions defined

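`train_epoch` multiplies the loss by per-sample LDS weights before averaging, which only has an effect if `criterion` returns one loss per sample. The NumPy sketch below stands in for `nn.MSELoss(reduction='none')` versus the default `reduction='mean'` (which criterion the notebook actually builds is not shown in this section). With per-sample losses the weights shift emphasis toward high-weight samples; with an already-averaged scalar, multiplying by the weights and taking the mean just rescales the loss by the batch's mean weight.

```python
import numpy as np

preds = np.array([[0.0], [0.0], [0.0]])
labels = np.array([[1.0], [2.0], [3.0]])
weights = np.array([0.5, 0.5, 2.0])  # e.g. LDS weights; the rare label gets 2.0

# reduction='none': one squared error per sample, shape [B, 1]
per_sample = (preds - labels) ** 2
weighted = (per_sample * weights[:, None]).mean()  # same pattern as train_epoch

# reduction='mean': a scalar, so the weights can no longer tell samples apart
scalar = ((preds - labels) ** 2).mean()
rescaled = (scalar * weights[:, None]).mean()      # equals scalar * weights.mean()

print(weighted, rescaled)
```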

## 8. Run Training Loop

In [17]:
# ============================================================================
# Main Training Function with Checkpoint Resumption
# ============================================================================

def find_checkpoint(checkpoint_dir, target):
    """Check if a checkpoint exists for the given target."""
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{target}.pth')
    if os.path.exists(checkpoint_path):
        return checkpoint_path
    return None


def check_checkpoint_status(checkpoint_path, epochs, patience):
    """Check if training is already complete without loading full checkpoint into model."""
    if not checkpoint_path:
        return None
    
    try:
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        epoch = ckpt.get('epoch', 0)
        epochs_without_improvement = ckpt.get('epochs_without_improvement', 0)
        best_val_loss = ckpt.get('best_val_loss', float('inf'))
        history = ckpt.get('history', {'train_loss': [], 'val_loss': []})
        
        is_complete = epochs_without_improvement >= patience or epoch >= epochs
        
        return {
            'epoch': epoch,
            'best_val_loss': best_val_loss,
            'epochs_without_improvement': epochs_without_improvement,
            'history': history,
            'is_complete': is_complete,
            'scaler_mean': ckpt.get('scaler_mean'),
            'scaler_scale': ckpt.get('scaler_scale'),
            'best_model_state': ckpt.get('model_state_dict'),
        }
    except Exception as e:
        print(f"   ⚠️ Failed to read checkpoint: {e}")
        return None


def load_checkpoint(checkpoint_path, model, optimizer, device):
    """Load checkpoint and return training state."""
    print(f"   📂 Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # Move optimizer states to device
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
    
    return {
        'epoch': checkpoint['epoch'],
        'best_val_loss': checkpoint['best_val_loss'],
        'history': checkpoint.get('history', {'train_loss': [], 'val_loss': []}),
        'epochs_without_improvement': checkpoint.get('epochs_without_improvement', 0),
        'best_model_state': checkpoint['model_state_dict'],
    }


def save_checkpoint(checkpoint_dir, target, epoch, model, optimizer, best_val_loss, 
                    best_model_state, history, epochs_without_improvement, scaler, config):
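    """Save training state for resumption.

    Note: 'model_state_dict' stores best_model_state (the best weights so far),
    not the current weights, so `model` is unused here. The LR scheduler state
    is not saved.
    """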
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{target}.pth')
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': best_model_state,
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'history': history,
        'epochs_without_improvement': epochs_without_improvement,
        'scaler_mean': scaler.mean_.tolist(),
        'scaler_scale': scaler.scale_.tolist(),
        'config': config,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }, checkpoint_path)
    
    return checkpoint_path


def train_model(target, data_dir='./data', output_dir='./models', 
                checkpoint_dir='./checkpoints', epochs=200, patience=20, device='cuda',
                resume=True):
    """Train a single property prediction model with checkpoint saving and resumption.
    
    Uses MODEL_CONFIGS to determine model architecture:
    - abs/em: GraphFingerprintsModel (CNN attention + solvent extractor)
    - plqy/k: GraphFingerprintsModelFC (simple FC for all fingerprints)
    """
    
    print(f"\n{'='*60}")
    print(f"🚀 Training model for: {target.upper()}")
    print(f"{'='*60}")
    
    start_time = time.time()
    
    # Get model config
    config = MODEL_CONFIGS[target]
    model_class_name = config['model_class']
    print(f"   Config: {config}")
    
    # =========================================================================
    # EARLY CHECK: See if training is already complete BEFORE loading data
    # =========================================================================
    checkpoint_path = find_checkpoint(checkpoint_dir, target) if resume else None
    
    if checkpoint_path:
        ckpt_status = check_checkpoint_status(checkpoint_path, epochs, patience)
        
        if ckpt_status and ckpt_status['is_complete']:
            # Training already done - return cached results without loading data
            print(f"   📂 Found completed checkpoint")
            print(f"   ✅ Already trained: epoch {ckpt_status['epoch']}, " + 
                  f"val_loss={ckpt_status['best_val_loss']:.4f}")
            
            if ckpt_status['epochs_without_improvement'] >= patience:
                print(f"   ⏹️ Early stopped at epoch {ckpt_status['epoch']} " +
                      f"(no improvement for {patience} epochs)")
            else:
                print(f"   ⏹️ Completed {epochs} epochs")
            
            # Save model to output_dir from checkpoint (so predictions work)
            os.makedirs(output_dir, exist_ok=True)
            model_path = os.path.join(output_dir, f'Model_{target}.pth')
            if ckpt_status['best_model_state'] is not None:
                torch.save(ckpt_status['best_model_state'], model_path)
                print(f"   💾 Model saved to: {model_path}")
            
            # Return cached results without loading data (MAE/RMSE/R² need validation data, so they are NaN)
            return {
                'target': target,
                'best_val_loss': ckpt_status['best_val_loss'],
                'mae': float('nan'),  # Would need validation data to compute
                'rmse': float('nan'),
                'r2': float('nan'),
                'epochs_trained': ckpt_status['epoch'],
                'training_time': 0,
                'history': ckpt_status['history'],
                'resumed': True,
                'already_complete': True,
            }
        elif ckpt_status:
            print(f"   📂 Found in-progress checkpoint at epoch {ckpt_status['epoch']}")
            print(f"   🔄 Will resume training...")
    else:
        print(f"   📝 No checkpoint found, will train from scratch")
    
    # =========================================================================
    # Only load data if we actually need to train
    # =========================================================================
    
    # Get feature dimensions
    sample_graph = smiles_to_graph('CCO')
    n_feats = sample_graph.ndata['hv'].shape[1]
    e_feats = sample_graph.edata['he'].shape[1]
    print(f"   Node features: {n_feats}, Edge features: {e_feats}")
    
    # Load data
    print(f"\n📂 Loading data for {target}...")
    data = load_and_process_data(target, data_dir, config)
    
    # Create data loaders with GPU optimization
    train_loader = DataLoader(
        data['train_data'],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )
    
    valid_loader = DataLoader(
        data['valid_data'],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )
    
    # Initialize model based on model_class
    if model_class_name == 'GraphFingerprintsModel':
        # abs/em: CNN attention + solvent extractor
        model = GraphFingerprintsModel(
            node_feat_size=n_feats,
            edge_feat_size=e_feats,
            solvent_dim=data['solvent_dim'],
            smiles_extra_dim=data['smiles_extra_dim'],
            graph_feat_size=GRAPH_FEAT_SIZE,
            num_layers=config['num_layers'],
            num_timesteps=config['num_timesteps'],
            n_tasks=1,
            dropout=config['dropout']
        ).to(device)
    else:
        # plqy/k: Simple FC for all fingerprints
        model = GraphFingerprintsModelFC(
            node_feat_size=n_feats,
            edge_feat_size=e_feats,
            fp_size=data['fp_size'],
            graph_feat_size=GRAPH_FEAT_SIZE,
            num_layers=config['num_layers'],
            num_timesteps=config['num_timesteps'],
            n_tasks=1,
            dropout=config['dropout']
        ).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"   Model parameters: {total_params:,}")
    
    # Loss and optimizer
    criterion = nn.MSELoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # Learning rate scheduler (reduce on plateau)
    scheduler = None
    if USE_LR_SCHEDULER:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        scheduler = ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=LR_SCHEDULER_FACTOR, 
            patience=LR_SCHEDULER_PATIENCE,
            min_lr=LR_SCHEDULER_MIN
        )
    
    # Initialize training state
    start_epoch = 1
    best_val_loss = float('inf')
    best_model_state = None
    epochs_without_improvement = 0
    history = {'train_loss': [], 'val_loss': []}
    
    # Load checkpoint if resuming
    if checkpoint_path:
        try:
            ckpt = load_checkpoint(checkpoint_path, model, optimizer, device)
            start_epoch = ckpt['epoch'] + 1
            best_val_loss = ckpt['best_val_loss']
            best_model_state = ckpt['best_model_state']
            history = ckpt['history']
            epochs_without_improvement = ckpt['epochs_without_improvement']
            
            print(f"   ✅ Resuming from epoch {start_epoch}")
            print(f"   Previous best val loss: {best_val_loss:.4f}")
            print(f"   Epochs without improvement: {epochs_without_improvement}")
                
        except Exception as e:
            print(f"   ⚠️ Failed to load checkpoint: {e}")
            print(f"   Starting fresh training...")
            start_epoch = 1
            best_val_loss = float('inf')
            best_model_state = None
            epochs_without_improvement = 0
            history = {'train_loss': [], 'val_loss': []}
    
    use_lds = config['alpha'] > 0
    
    remaining_epochs = epochs - start_epoch + 1
    print(f"\n📈 Training epochs {start_epoch} to {epochs} ({remaining_epochs} remaining)")
    print(f"   Patience: {patience}, Using LDS: {use_lds}")
    
    # Monitor GPU memory
    if str(device).startswith('cuda'):  # works for both 'cuda' strings and torch.device objects
        torch.cuda.reset_peak_memory_stats()
    
    pbar = tqdm(range(start_epoch, epochs + 1), desc=f"{target.upper()}", unit="epoch", 
                dynamic_ncols=True, leave=True)
    final_epoch = start_epoch - 1
    last_lr = LEARNING_RATE
    
    for epoch in pbar:
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, use_lds)
        val_loss, val_preds, val_labels = validate(model, valid_loader, criterion, device)
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        
        # Step the learning rate scheduler
        if scheduler is not None:
            old_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < old_lr:
                print(f"   📉 LR reduced: {old_lr:.2e} → {new_lr:.2e}")
                last_lr = new_lr
        
        # Check for improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
            improved = True
        else:
            epochs_without_improvement += 1
            improved = False
        
        # Save checkpoint (every epoch for robustness against disconnection)
        save_checkpoint(
            checkpoint_dir, target, epoch, model, optimizer,
            best_val_loss, best_model_state, history, 
            epochs_without_improvement, data['scaler'], config
        )
        
        # Update progress bar
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix({
            'train': f'{train_loss:.4f}',
            'val': f'{val_loss:.4f}',
            'best': f'{best_val_loss:.4f}',
            'no_improv': epochs_without_improvement,
            'lr': f'{current_lr:.1e}',
        })
        
        # Print periodic updates (every 10 epochs or on improvement) as fallback
        if epoch % 10 == 0 or improved or epoch == start_epoch:
            star = " ★ new best!" if improved else ""
            print(f"   Epoch {epoch:3d}/{epochs}: train={train_loss:.4f}, val={val_loss:.4f}, best={best_val_loss:.4f}{star}")
        
        final_epoch = epoch
        
        # Early stopping
        if epochs_without_improvement >= patience:
            print(f"\n⏹️ Early stopping at epoch {epoch}")
            break
    
    # Load best model and evaluate
    model.load_state_dict(best_model_state)
    
    _, val_preds, val_labels = validate(model, valid_loader, criterion, device)
    
    # Inverse transform predictions
    val_preds_orig = data['scaler'].inverse_transform(val_preds)
    val_labels_orig = data['scaler'].inverse_transform(val_labels)
    
    mae = mean_absolute_error(val_labels_orig, val_preds_orig)
    rmse = np.sqrt(mean_squared_error(val_labels_orig, val_preds_orig))
    r2 = r2_score(val_labels_orig, val_preds_orig)
    
    print(f"\n📊 Final Metrics (original scale):")
    print(f"   MAE:  {mae:.4f}")
    print(f"   RMSE: {rmse:.4f}")
    print(f"   R²:   {r2:.4f}")
    
    # Report GPU memory usage
    if str(device).startswith('cuda'):  # works for both 'cuda' strings and torch.device objects
        peak_memory = torch.cuda.max_memory_allocated() / 1e9
        print(f"   Peak GPU memory: {peak_memory:.2f} GB")
    
    # Save final model
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, f'Model_{target}.pth')
    torch.save(best_model_state, model_path)
    print(f"\n💾 Model saved to: {model_path}")
    
    # Calculate training time
    elapsed_time = time.time() - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    time_str = f"{int(hours)}h {int(minutes)}m {int(seconds)}s"
    print(f"⏱️ Training time: {time_str}")
    
    return {
        'target': target,
        'best_val_loss': best_val_loss,
        'mae': mae,
        'rmse': rmse,
        'r2': r2,
        'epochs_trained': final_epoch,
        'training_time': elapsed_time,
        'history': history,
        'resumed': checkpoint_path is not None,
    }


print("✅ Main training function defined (with early checkpoint check)")

✅ Main training function defined (with early checkpoint check)
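`save_checkpoint()` overwrites `checkpoint_{target}.pth` in place every epoch. If the runtime disconnects mid-write, the only checkpoint can be left truncated, and the next `torch.load` fails. A common pattern is to write to a temporary file in the same directory and then swap it in with `os.replace()`. Here is a minimal sketch using a hypothetical `atomic_save` helper. It defaults to `pickle.dump` so it runs without PyTorch; pass `save_fn=torch.save` in this notebook. `os.replace()` is atomic on a local POSIX filesystem. Whether the Colab Google Drive mount preserves that guarantee is not verified here.

```python
import os
import pickle
import tempfile

def atomic_save(obj, path, save_fn=None):
    """Write obj to a temp file next to `path`, then rename it over `path`.

    save_fn(obj, file) does the serialization. It defaults to pickle.dump;
    torch.save uses the same (obj, file) calling convention.
    """
    save_fn = save_fn or pickle.dump
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    # Keep the temp file in the same directory: os.replace is a same-filesystem
    # rename and fails across filesystems
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            save_fn(obj, f)
        os.replace(tmp_path, path)  # readers see the old file or the new one, never half of it
    except BaseException:
        # Remove the partial temp file, then re-raise the original error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return path

# Demo with pickle (in the notebook: save_fn=torch.save)
demo_dir = tempfile.mkdtemp()
demo_path = atomic_save({'epoch': 3}, os.path.join(demo_dir, 'checkpoint_demo.pth'))
with open(demo_path, 'rb') as f:
    print(pickle.load(f))  # {'epoch': 3}
```

Inside `save_checkpoint()`, the `torch.save({...}, checkpoint_path)` call would become `atomic_save({...}, checkpoint_path, save_fn=torch.save)`.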


In [None]:
# ============================================================================
# Training Configuration
# ============================================================================

# Paths - using cloned repository data
DATA_DIR = './fluor_tools/Fluor-RLAT/data'
OUTPUT_DIR = './models'
CHECKPOINT_DIR = '/content/drive/MyDrive/fluor_checkpoints'  # Save checkpoints to Drive

# Training parameters
EPOCHS = 200
PATIENCE = 30

# Resume from checkpoints if they exist
RESUME_FROM_CHECKPOINT = True  # Set to False to start fresh

# Select which models to train
# Options: 'abs', 'em', 'plqy', 'k'
TARGETS = ['abs', 'em', 'plqy', 'k']  # Train all models
# TARGETS = ['abs']  # Train only absorption model

print(f"🎯 Models to train: {TARGETS}")
print(f"📁 Data directory: {DATA_DIR}")
print(f"📁 Output directory: {OUTPUT_DIR}")
print(f"💾 Checkpoints: {CHECKPOINT_DIR}")
print(f"⚙️  Epochs: {EPOCHS}, Patience: {PATIENCE}")
print(f"🔄 Resume from checkpoint: {RESUME_FROM_CHECKPOINT}")

🎯 Models to train: ['abs', 'em', 'plqy', 'k']
📁 Data directory: ./fluor_tools/Fluor-RLAT/data
📁 Output directory: ./models
💾 Checkpoints: /content/drive/MyDrive/fluor_checkpoints
⚙️  Epochs: 200, Patience: 30
🔄 Resume from checkpoint: True
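`PATIENCE` here controls early stopping, while `ReduceLROnPlateau` uses its own `LR_SCHEDULER_PATIENCE`, which is set in an earlier cell. If early-stopping patience is not larger than the scheduler's, training can stop before the learning rate is ever reduced. The sketch below simulates both counters on a flat validation curve. It is a simplification: it ignores the scheduler's relative `threshold` (1e-4 by default) and `cooldown`. The function name and numbers are illustrative, not this notebook's settings.

```python
def simulate_plateau(val_losses, lr, factor, lr_patience, min_lr, stop_patience):
    """Simplified ReduceLROnPlateau + early stopping, following train_model's loop.

    Returns (final_lr, stop_epoch or None, epochs at which the LR was cut).
    """
    best = float('inf')
    bad_sched = 0   # scheduler's bad-epoch counter (reset after each cut)
    bad_stop = 0    # early-stopping counter (epochs_without_improvement)
    cuts = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            bad_sched = bad_stop = 0
        else:
            bad_sched += 1
            bad_stop += 1
        # The scheduler cuts once its bad-epoch count exceeds its patience
        if bad_sched > lr_patience:
            new_lr = max(lr * factor, min_lr)
            if new_lr < lr:
                lr = new_lr
                cuts.append(epoch)
            bad_sched = 0
        # Early stopping, as in train_model: stop when the counter reaches patience
        if bad_stop >= stop_patience:
            return lr, epoch, cuts
    return lr, None, cuts

flat = [1.0] * 60  # validation loss that never improves after epoch 1
print(simulate_plateau(flat, 1e-3, 0.5, 10, 1e-6, stop_patience=30))  # stops at 31, cuts at 12 and 23
print(simulate_plateau(flat, 1e-3, 0.5, 10, 1e-6, stop_patience=10))  # stops at 11, no cuts
```

With early-stopping patience at or below the scheduler patience, the second call shows the LR never changes before training halts.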


In [None]:
# ============================================================================
# Check Existing Checkpoints
# ============================================================================

print(f"🔍 Checking for existing checkpoints in: {CHECKPOINT_DIR}\n")

if os.path.exists(CHECKPOINT_DIR):
    found_checkpoints = []
    for target in TARGETS:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_{target}.pth')
        if os.path.exists(ckpt_path):
            try:
                ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
                epoch = ckpt.get('epoch', 0)
                val_loss = ckpt.get('best_val_loss', float('nan'))  # numeric default so :.4f works
                timestamp = ckpt.get('timestamp', 'unknown')
                epochs_no_improv = ckpt.get('epochs_without_improvement', 0)
                
                status = "✅ Complete" if epochs_no_improv >= PATIENCE or epoch >= EPOCHS else "🔄 In progress"
                
                found_checkpoints.append({
                    'target': target,
                    'epoch': epoch,
                    'val_loss': val_loss,
                    'timestamp': timestamp,
                    'status': status,
                    'epochs_no_improv': epochs_no_improv,
                })
                print(f"  📦 {target.upper()}: Epoch {epoch}, Val Loss: {val_loss:.4f}, {status}")
                print(f"      Last saved: {timestamp}, No improvement: {epochs_no_improv}/{PATIENCE}")
            except Exception as e:
                print(f"  ⚠️  {target.upper()}: Checkpoint exists but failed to read: {e}")
        else:
            print(f"  ⬜ {target.upper()}: No checkpoint found")
    
    if found_checkpoints and RESUME_FROM_CHECKPOINT:
        print(f"\n‚úÖ Will resume training from existing checkpoints")
    elif found_checkpoints and not RESUME_FROM_CHECKPOINT:
        print(f"\n‚ö†Ô∏è  Checkpoints exist but RESUME_FROM_CHECKPOINT=False")
        print(f"   Training will start fresh (existing checkpoints ignored)")
else:
    print(f"  üìÅ Checkpoint directory does not exist yet")
    print(f"  üìù Will start fresh training for all models")
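The "complete" rule (`epochs_without_improvement >= patience or epoch >= epochs`) is written out both in `check_checkpoint_status()` and in the inspection cell above. A single helper keeps the two from drifting apart, for example if one is later changed to use a different patience. Here is a minimal sketch; `checkpoint_is_complete` and `format_loss` are hypothetical names, not part of the notebook.

```python
import math

def checkpoint_is_complete(ckpt, epochs, patience):
    """True if a checkpoint dict records a finished run.

    Same rule as check_checkpoint_status(): early-stopped (patience exhausted)
    or all epochs done. Missing keys count as 'not started' instead of raising.
    """
    epoch = ckpt.get('epoch', 0)
    no_improv = ckpt.get('epochs_without_improvement', 0)
    return no_improv >= patience or epoch >= epochs

def format_loss(value):
    """Format a loss for display, returning 'n/a' for missing or non-numeric values."""
    if isinstance(value, (int, float)) and not math.isnan(value):
        return f"{value:.4f}"
    return "n/a"

print(checkpoint_is_complete({'epoch': 57, 'epochs_without_improvement': 30}, 200, 30))  # True
print(checkpoint_is_complete({'epoch': 57, 'epochs_without_improvement': 12}, 200, 30))  # False
print(format_loss(0.25), format_loss(float('nan')))  # 0.2500 n/a
```

Both `check_checkpoint_status()` and the `status = ...` line in the cell above could then call `checkpoint_is_complete(ckpt, EPOCHS, PATIENCE)`.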

In [None]:
# ============================================================================
# 🚀 START TRAINING
# ============================================================================

total_start = time.time()
results = []

for target in TARGETS:
    result = train_model(
        target=target,
        data_dir=DATA_DIR,
        output_dir=OUTPUT_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        epochs=EPOCHS,
        patience=PATIENCE,
        device=device,
        resume=RESUME_FROM_CHECKPOINT  # Use checkpoint if available
    )
    results.append(result)
    
    # Clear GPU cache between models
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

total_elapsed = time.time() - total_start

# ============================================================================
# Summary
# ============================================================================
print(f"\n{'='*60}")
print("📋 TRAINING SUMMARY")
print(f"{'='*60}")

for r in results:
    t = r['training_time']
    h, rem = divmod(t, 3600)
    m, s = divmod(rem, 60)
    time_str = f"{int(h)}h {int(m)}m {int(s)}s" if t > 0 else "0s (loaded from checkpoint)"
    
    resumed = " (resumed)" if r.get('resumed', False) else ""
    complete = " [already complete]" if r.get('already_complete', False) else ""
    
    print(f"\n{r['target'].upper()}{resumed}{complete}:")
    print(f"   Epochs: {r['epochs_trained']}, Time: {time_str}")
    print(f"   MAE: {r['mae']:.4f}, RMSE: {r['rmse']:.4f}, R²: {r['r2']:.4f}")

total_h, total_rem = divmod(total_elapsed, 3600)
total_m, total_s = divmod(total_rem, 60)
print(f"\n{'='*60}")
print(f"⏱️ Total training time: {int(total_h)}h {int(total_m)}m {int(total_s)}s")
print(f"✅ Training complete!")
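The summary metrics are in original units (nm for `abs`/`em`) because `train_model()` inverse-transforms predictions before calling `mean_absolute_error`, `mean_squared_error` and `r2_score`. By contrast, `best_val_loss` is an MSE on the standardized scale. For a `StandardScaler`, an original-unit error is the scaled error times `scale_`, so RMSE in nm is roughly `sqrt(val_loss) * scale_`. It is only approximate because `validate()` averages per-batch means. Here is a small worked example in plain NumPy with made-up numbers (mean 500 nm, scale 50 nm).

```python
import numpy as np

def inverse_standard_scale(z, mean, scale):
    """Undo StandardScaler: x = z * scale + mean."""
    return np.asarray(z, dtype=float) * scale + mean

def regression_metrics(y_true, y_pred):
    """MAE, RMSE and R^2 for a single target (same definitions as the sklearn metrics)."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    err = y_pred - y_true
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))
    r2 = 1.0 - np.sum(err ** 2) / np.sum((y_true - y_true.mean()) ** 2)
    return mae, rmse, r2

MEAN, SCALE = 500.0, 50.0            # hypothetical scaler_mean / scaler_scale
z_true = np.array([-1.0, 0.0, 1.0])  # standardized labels
z_pred = np.array([-0.8, 0.0, 1.2])  # standardized predictions

scaled_mse = np.mean((z_pred - z_true) ** 2)  # what the training loss reports, about 0.0267
mae, rmse, r2 = regression_metrics(inverse_standard_scale(z_true, MEAN, SCALE),
                                   inverse_standard_scale(z_pred, MEAN, SCALE))
print(f"MAE={mae:.3f} nm, RMSE={rmse:.3f} nm, R2={r2:.3f}")   # MAE=6.667 nm, RMSE=8.165 nm, R2=0.960
print(f"sqrt(scaled MSE) * scale = {np.sqrt(scaled_mse) * SCALE:.3f} nm")  # 8.165 nm
```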

## 9. Download Trained Models

In [None]:
# List trained models
print("Trained models:")
for f in os.listdir(OUTPUT_DIR):
    if f.endswith('.pth'):
        size = os.path.getsize(os.path.join(OUTPUT_DIR, f)) / 1e6
        print(f"  📦 {f} ({size:.1f} MB)")

In [None]:
# Copy trained models to Google Drive (Drive must already be mounted)
import shutil

DRIVE_OUTPUT = "/content/drive/MyDrive/fluor_models"
os.makedirs(DRIVE_OUTPUT, exist_ok=True)

for f in os.listdir(OUTPUT_DIR):
    if f.endswith('.pth'):
        src = os.path.join(OUTPUT_DIR, f)
        dst = os.path.join(DRIVE_OUTPUT, f)
        shutil.copy2(src, dst)  # same result as `!cp`, without spawning a shell

print(f"✅ Models copied to: {DRIVE_OUTPUT}")

## 10. Training History Visualization

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for idx, r in enumerate(results):
    if idx >= 4:
        break
    ax = axes[idx]
    ax.plot(r['history']['train_loss'], label='Train Loss', alpha=0.8)
    ax.plot(r['history']['val_loss'], label='Val Loss', alpha=0.8)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f"{r['target'].upper()} - MAE: {r['mae']:.2f}, R²: {r['r2']:.3f}")
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide unused subplots
for idx in range(len(results), 4):
    axes[idx].axis('off')

plt.tight_layout()

# Save to Google Drive
plot_path = os.path.join(DRIVE_OUTPUT, 'training_history.png')
plt.savefig(plot_path, dpi=150)
plt.show()
print(f"📊 Plot saved to: {plot_path}")

## 11. Make Predictions

Use the trained models to predict properties for a new molecule.

In [None]:
# ============================================================================
# Prediction Input
# ============================================================================

molecule = "C2=C1C7=C(C(=[N+]1[B-]([N]3C2=C5C(=C3C4=CC=CC=C4)C=CC=C5)(F)F)C6=CC=CC=C6)C=CC=C7"
solvent = "CC1=CC=CC=C1"

print(f"🧪 Molecule: {molecule}")
print(f"🧫 Solvent:  {solvent}")
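Both prediction functions flatten their inputs into one fingerprint vector per molecule. The order is solvent Morgan bits, then molecule Morgan bits, then the 8 MinMax-scaled numeric descriptors, then 136 scaffold flags: 1024 + 1024 + 8 + 136 = 2192 values. The `abs`/`em` models receive it with `solvent_dim=1024` and `smiles_extra_dim=1168`. Below is a NumPy sketch that assembles the vector with explicit length checks. The helper names are hypothetical. The split helper assumes the CNN model reads the first `solvent_dim` entries as the solvent block; that follows from the concatenation order but is not confirmed from the model code shown here.

```python
import numpy as np

# Block sizes used by predict_properties()
SOLVENT_FP_BITS = 1024    # Morgan fingerprint of the solvent (radius 2)
MOLECULE_FP_BITS = 1024   # Morgan fingerprint of the dye (radius 2)
N_NUMERIC = 8             # solvent_num, tag, MW, LogP, TPSA, double_bonds, ring_count, unimol_plus
N_SCAFFOLD = 136          # one flag per scaffold substructure

def build_fp_vector(sol_fp, mol_fp, numeric_scaled, scaffold_flags):
    """Concatenate the feature blocks in order, checking each block's length."""
    blocks = [sol_fp, mol_fp, numeric_scaled, scaffold_flags]
    sizes = [SOLVENT_FP_BITS, MOLECULE_FP_BITS, N_NUMERIC, N_SCAFFOLD]
    for block, size in zip(blocks, sizes):
        if len(block) != size:
            raise ValueError(f"expected {size} values, got {len(block)}")
    return np.concatenate(blocks).astype(np.float32)

def split_for_cnn_model(fp, solvent_dim=SOLVENT_FP_BITS):
    """Assumed split: solvent block first, then molecule bits + extra features."""
    return fp[:solvent_dim], fp[solvent_dim:]

fp = build_fp_vector(np.zeros(1024), np.ones(1024), np.zeros(8), np.zeros(136))
solvent_part, smiles_extra_part = split_for_cnn_model(fp)
print(fp.shape, solvent_part.shape, smiles_extra_part.shape)  # (2192,) (1024,) (1168,)
```

A length check like this fails loudly if, for example, the scaffold list or the numeric feature set changes size, instead of producing a shape error deep inside the model.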

In [None]:
# ============================================================================
# Run Prediction
# ============================================================================
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from sklearn.preprocessing import StandardScaler, MinMaxScaler

def predict_properties(molecule_smiles, solvent_smiles, model_dir='./models', 
                       checkpoint_dir='/content/drive/MyDrive/fluor_checkpoints', device='cuda'):
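    """Predict all properties for a single molecule, with preprocessing and inverse label scaling.
    
    Approximates the preprocessing in 02_property_prediction.py, with simplifications:
    the solvent lookup covers only a few solvents (unknown ones map to 0), only the
    BODIPY scaffold is detected, and unimol_plus is a fixed placeholder.
    Uses MODEL_CONFIGS, which matches ORIGINAL_MODEL_CONFIGS for consistency.
    """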
    
    # Generate molecular graph
    graph = smiles_to_graph(molecule_smiles)
    if graph is None:
        raise ValueError(f"Could not parse molecule SMILES: {molecule_smiles}")
    
    mol = Chem.MolFromSmiles(molecule_smiles)
    sol = Chem.MolFromSmiles(solvent_smiles)
    if mol is None or sol is None:
        raise ValueError("Invalid SMILES")
    
    # Generate Morgan fingerprints (1024-bit, radius 2)
    mol_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024)
    sol_fp = AllChem.GetMorganFingerprintAsBitVect(sol, radius=2, nBits=1024)
    mol_fp_arr = np.array(mol_fp, dtype=np.float32)
    sol_fp_arr = np.array(sol_fp, dtype=np.float32)
    
    # Compute molecular descriptors
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    tpsa = Descriptors.TPSA(mol)
    double_bonds = sum(1 for bond in mol.GetBonds() 
                       if bond.GetBondType() == Chem.BondType.DOUBLE or bond.GetIsAromatic())
    ring_count = mol.GetRingInfo().NumRings()
    
    # Get solvent_num from mapping (simplified)
    solvent_mapping = {
        'CC1=CC=CC=C1': 6, 'Cc1ccccc1': 6,  # toluene
        'CCO': 2, 'CO': 1, 'c1ccccc1': 5,
    }
    solvent_num = solvent_mapping.get(solvent_smiles, 0)
    
    # Detect scaffold (simplified - check for BODIPY)
    bodipy_smarts = '[#5](-F)(-F)(-[#7])(-[#7])'
    tag = 0
    bodipy_pattern = Chem.MolFromSmarts(bodipy_smarts)
    if bodipy_pattern and mol.HasSubstructMatch(bodipy_pattern):
        tag = 5  # BODIPY
    
    # Create 136 scaffold flags
    scaffold_flags = np.zeros(136, dtype=np.float32)
    if tag == 5:  # BODIPY
        scaffold_flags[3] = 1  # fragment_4
    
    unimol_plus = 3.49  # placeholder
    
    # Numeric features: [solvent_num, tag, MW, LogP, TPSA, double_bonds, ring_count, unimol_plus]
    numeric_feats = np.array([solvent_num, tag, mw, logp, tpsa, double_bonds, ring_count, unimol_plus], dtype=np.float32)
    
    predictions = {}
    n_feats = graph.ndata['hv'].shape[1]
    e_feats = graph.edata['he'].shape[1]
    
    for target in ['abs', 'em', 'plqy', 'k']:
        model_path = os.path.join(model_dir, f'Model_{target}.pth')
        if not os.path.exists(model_path):
            print(f"   ⚠️ Model not found: {model_path}")
            continue
        
        # Load scaler parameters from checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{target}.pth')
        scaler_mean = None
        scaler_scale = None
        num_scaler_min = None
        num_scaler_scale = None
        
        if os.path.exists(checkpoint_path):
            ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            scaler_mean = ckpt.get('scaler_mean', [0])[0]
            scaler_scale = ckpt.get('scaler_scale', [1])[0]
            # Load numeric scaler if saved
            num_scaler_min = ckpt.get('num_scaler_min', None)
            num_scaler_scale = ckpt.get('num_scaler_scale', None)
        
        # Apply numeric scaling if available, otherwise use raw values
        if num_scaler_min is not None and num_scaler_scale is not None:
            numeric_scaled = (numeric_feats - np.array(num_scaler_min)) / np.array(num_scaler_scale)
        else:
            # Fallback: use MinMaxScaler from training data
            train_path = os.path.join(DATA_DIR, f'train_{target}.csv')
            if os.path.exists(train_path):
                train_df = pd.read_csv(train_path)
                train_numeric = train_df.iloc[:, 8:16].values
                num_scaler = MinMaxScaler()
                num_scaler.fit(train_numeric)
                numeric_scaled = num_scaler.transform(numeric_feats.reshape(1, -1)).flatten()
            else:
                numeric_scaled = numeric_feats  # no scaling if no data
        
        # Combine extra features
        extra_feats = np.concatenate([numeric_scaled, scaffold_flags]).astype(np.float32)
        
        config = MODEL_CONFIGS[target]
        model_class = config['model_class']
        
        # Build fingerprint tensor: [solvent_fp(1024), smiles_fp(1024), extra(144)] = 2192
        fp = np.concatenate([sol_fp_arr, mol_fp_arr, extra_feats])
        
        if model_class == 'GraphFingerprintsModel':
            # abs/em use CNN attention architecture
            solvent_dim = 1024
            smiles_extra_dim = len(fp) - solvent_dim  # 1024 + 144 = 1168
            
            model = GraphFingerprintsModel(
                node_feat_size=n_feats,
                edge_feat_size=e_feats,
                solvent_dim=solvent_dim,
                smiles_extra_dim=smiles_extra_dim,
                graph_feat_size=GRAPH_FEAT_SIZE,
                num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'],
                n_tasks=1,
                dropout=config['dropout']
            )
        else:
            # plqy/k use simple FC architecture
            model = GraphFingerprintsModelFC(
                node_feat_size=n_feats,
                edge_feat_size=e_feats,
                fp_size=len(fp),  # 2192
                graph_feat_size=GRAPH_FEAT_SIZE,
                num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'],
                n_tasks=1,
                dropout=config['dropout']
            )
        
        # Load weights
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        model = model.to(device)
        model.eval()
        
        # Prepare inputs
        graph_batch = dgl.batch([graph]).to(device)
        node_feats = graph_batch.ndata['hv']
        edge_feats = graph_batch.edata['he']
        fp_tensor = torch.tensor(fp, dtype=torch.float32).unsqueeze(0).to(device)
        
        # Predict
        with torch.no_grad():
            pred = model(graph_batch, node_feats, edge_feats, fp_tensor)
            pred_scaled = pred.item()
        
        # Inverse transform: original = scaled * scale + mean
        if scaler_mean is not None and scaler_scale is not None:
            pred_value = pred_scaled * scaler_scale + scaler_mean
        else:
            pred_value = pred_scaled
        
        predictions[target] = pred_value
    
    return predictions

# Run prediction
print("\n🔮 Running predictions...")
preds = predict_properties(molecule, solvent, model_dir=OUTPUT_DIR, 
                           checkpoint_dir=CHECKPOINT_DIR, device=device)

def fmt(value, spec):
    """Format a prediction, or return 'N/A' if that model was not found."""
    return 'N/A' if value is None else format(value, spec)

print("\n" + "="*50)
print("📊 PREDICTION RESULTS")
print("="*50)
print(f"   Absorption (abs):     {fmt(preds.get('abs'), '.1f')} nm")
print(f"   Emission (em):        {fmt(preds.get('em'), '.1f')} nm")
print(f"   Quantum Yield (plqy): {fmt(preds.get('plqy'), '.3f')}")
print(f"   Log ε (k):            {fmt(preds.get('k'), '.2f')}")
print("="*50)
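`predict_properties()` looks for `num_scaler_min` / `num_scaler_scale` in each checkpoint and applies `(x - min) / scale`. However, `save_checkpoint()` never writes those keys, so every prediction refits a `MinMaxScaler` on `train_{target}.csv` and therefore needs `DATA_DIR` at prediction time. If you choose to store them, note that the formula expects the per-column minimum and range (sklearn's `data_min_` and `data_range_`). It does not match `MinMaxScaler.min_` / `scale_`, which follow `X * scale_ + min_`. Below is a NumPy sketch of fitting and applying that form; the function names and data are hypothetical.

```python
import numpy as np

def fit_minmax(train_numeric):
    """Per-column minimum and range; a zero range is replaced by 1.0, as MinMaxScaler does internally."""
    train_numeric = np.asarray(train_numeric, dtype=float)
    data_min = train_numeric.min(axis=0)
    data_range = train_numeric.max(axis=0) - data_min
    data_range = np.where(data_range == 0, 1.0, data_range)  # avoid division by zero
    return data_min, data_range

def minmax_transform(x, data_min, data_range):
    """Scale with training-set statistics: (x - min) / range, matching the checkpoint branch."""
    return (np.asarray(x, dtype=float) - data_min) / data_range

# Hypothetical checkpoint extras; save_checkpoint() in this notebook does not write them
train = np.array([[1.0, 200.0], [3.0, 600.0], [5.0, 400.0]])
data_min, data_range = fit_minmax(train)
extras = {'num_scaler_min': data_min.tolist(), 'num_scaler_scale': data_range.tolist()}

x = np.array([2.0, 500.0])
scaled = minmax_transform(x, np.array(extras['num_scaler_min']), np.array(extras['num_scaler_scale']))
print(scaled)  # [0.25 0.75]
```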

In [None]:
# ============================================================================
# Compare Old (Pretrained) vs New (Just Trained) Models
# ============================================================================
# 
# IMPORTANT: The original models require the FULL preprocessing pipeline:
# 1. Solvent mapping (solvent SMILES -> solvent_num via 00_solvent_mapping.csv)
# 2. Scaffold detection (match against 136 substructures)
# 3. MinMaxScaler normalization on 8 numeric features (fit on training data)
# 4. Scaler inverse transform on predictions
#
# For a quick comparison, we'll load the preprocessed input.csv that was 
# generated by the original pipeline.

OLD_MODEL_DIR = './fluor_tools/Fluor-RLAT'  # Original pretrained models from repo
NEW_MODEL_DIR = './models'  # Newly trained models
ORIGINAL_DATA_DIR = './fluor_tools/Fluor-RLAT/data'  # Training data for scalers

def predict_with_original_model_proper(molecule_smiles, solvent_smiles, model_dir, data_dir, device='cuda'):
    """Predict using original pretrained models with proper preprocessing.
    
    This function mimics the original 02_property_prediction.py pipeline.
    Uses ORIGINAL_MODEL_CONFIGS which matches the pretrained model architectures.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem, Descriptors
    from sklearn.preprocessing import StandardScaler, MinMaxScaler
    
    # Step 1: Create molecular graph
    graph = smiles_to_graph(molecule_smiles)
    if graph is None:
        raise ValueError(f"Could not parse molecule SMILES: {molecule_smiles}")
    
    mol = Chem.MolFromSmiles(molecule_smiles)
    sol = Chem.MolFromSmiles(solvent_smiles)
    if mol is None or sol is None:
        raise ValueError("Invalid SMILES")
    
    # Step 2: Generate Morgan fingerprints (1024-bit, radius 2)
    mol_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024)
    sol_fp = AllChem.GetMorganFingerprintAsBitVect(sol, radius=2, nBits=1024)
    mol_fp_arr = np.array(mol_fp, dtype=np.float32)
    sol_fp_arr = np.array(sol_fp, dtype=np.float32)
    
    # Step 3: Compute molecular descriptors
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    tpsa = Descriptors.TPSA(mol)
    double_bonds = sum(1 for bond in mol.GetBonds() 
                       if bond.GetBondType() == Chem.BondType.DOUBLE or bond.GetIsAromatic())
    ring_count = mol.GetRingInfo().NumRings()
    
    # Step 4: Get solvent_num from mapping (simplified - use 6 for toluene)
    solvent_mapping = {
        'CC1=CC=CC=C1': 6,  # toluene
        'Cc1ccccc1': 6,    # toluene (alternative)
        'CCO': 2,          # ethanol
        'CO': 1,           # methanol
        'c1ccccc1': 5,     # benzene
    }
    solvent_num = solvent_mapping.get(solvent_smiles, 0)
    
    # Step 5: Detect scaffold (simplified - check for BODIPY)
    bodipy_smarts = '[#5](-F)(-F)(-[#7])(-[#7])'
    tag = 0
    tag_name = 'Unknown'
    bodipy_pattern = Chem.MolFromSmarts(bodipy_smarts)
    if bodipy_pattern and mol.HasSubstructMatch(bodipy_pattern):
        tag = 5  # BODIPY tag
        tag_name = 'BODIPY'
    
    # Step 6: Create 136 scaffold flags (simplified - just set fragment_4 for BODIPY)
    scaffold_flags = np.zeros(136, dtype=np.float32)
    if tag == 5:  # BODIPY
        scaffold_flags[3] = 1  # fragment_4 (0-indexed as 3)
    
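    # Step 7: unimol_plus feature. It is not computed here; a fixed placeholder
    # is used for every molecule, so this feature carries no per-molecule information
    unimol_plus = 3.49  # hardcoded placeholder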
    
    # Numeric features in correct order: [solvent_num, tag, MW, LogP, TPSA, double_bonds, ring_count, unimol_plus]
    numeric_feats = np.array([solvent_num, tag, mw, logp, tpsa, double_bonds, ring_count, unimol_plus], dtype=np.float32)
    
    predictions = {}
    raw_predictions = {}
    n_feats = graph.ndata['hv'].shape[1]
    e_feats = graph.edata['he'].shape[1]
    
    for target in ['abs', 'em', 'plqy', 'k']:
        model_path = os.path.join(model_dir, f'Model_{target}.pth')
        if not os.path.exists(model_path):
            continue
        
        # Load training data to fit scalers (same as original pipeline)
        train_path = os.path.join(data_dir, f'train_{target}.csv')
        if not os.path.exists(train_path):
            print(f"   ⚠️ Training data not found: {train_path}")
            continue
            
        train_df = pd.read_csv(train_path)
        
        # Fit label scaler
        label_scaler = StandardScaler()
        label_scaler.fit(train_df[[target]].values)
        
        # Fit numeric scaler on columns 8:16 of the training CSV; this assumes those
        # columns hold the same 8 features, in the same order, as numeric_feats
        train_numeric = train_df.iloc[:, 8:16].values
        num_scaler = MinMaxScaler()
        num_scaler.fit(train_numeric)
        
        # Apply numeric scaler to our features
        numeric_scaled = num_scaler.transform(numeric_feats.reshape(1, -1)).flatten()
        
        # Combine extra features: scaled_numeric (8) + scaffold_flags (136) = 144
        extra_feats = np.concatenate([numeric_scaled, scaffold_flags]).astype(np.float32)
        
        config = ORIGINAL_MODEL_CONFIGS[target]
        model_class = config['model_class']
        
        # Build fingerprint tensor: [solvent_fp(1024), smiles_fp(1024), extra(144)] = 2192
        fp = np.concatenate([sol_fp_arr, mol_fp_arr, extra_feats])
        
        if model_class == 'GraphFingerprintsModel':
            # abs/em use CNN attention architecture
            solvent_dim = 1024
            smiles_extra_dim = len(fp) - solvent_dim  # 1024 + 144 = 1168
            
            model = GraphFingerprintsModel(
                node_feat_size=n_feats,
                edge_feat_size=e_feats,
                solvent_dim=solvent_dim,
                smiles_extra_dim=smiles_extra_dim,
                graph_feat_size=256,
                num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'],
                n_tasks=1,
                dropout=config['dropout']
            )
        else:
            # plqy/k use simple FC architecture
            model = GraphFingerprintsModelFC(
                node_feat_size=n_feats,
                edge_feat_size=e_feats,
                fp_size=len(fp),  # 2192
                graph_feat_size=256,
                num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'],
                n_tasks=1,
                dropout=config['dropout']
            )
        
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        model = model.to(device)
        model.eval()
        
        graph_batch = dgl.batch([graph]).to(device)
        node_feats = graph_batch.ndata['hv']
        edge_feats = graph_batch.edata['he']
        fp_tensor = torch.tensor(fp, dtype=torch.float32).unsqueeze(0).to(device)
        
        with torch.no_grad():
            pred = model(graph_batch, node_feats, edge_feats, fp_tensor)
            raw_pred = pred.item()
            raw_predictions[target] = raw_pred
            
            # Inverse transform using label scaler
            pred_value = label_scaler.inverse_transform([[raw_pred]])[0, 0]
            predictions[target] = pred_value
    
    return predictions, raw_predictions
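The fingerprint tensor built inside the function concatenates four pieces: solvent fingerprint, molecule fingerprint, scaled numeric features, and scaffold flags. A minimal numpy sketch of that layout, assuming the dimensions stated in the function's comments (1024-bit Morgan fingerprints, 8 numeric features, 136 scaffold flags). It also assumes `GraphFingerprintsModel` treats the first `solvent_dim` entries as the solvent block, which matches the concatenation order but is not confirmed by the code shown. The helper names are hypothetical:

```python
import numpy as np

FP_BITS = 1024     # Morgan fingerprint length for both solvent and molecule
N_NUMERIC = 8      # [solvent_num, tag, MW, LogP, TPSA, double_bonds, ring_count, unimol_plus]
N_SCAFFOLD = 136   # scaffold flags fragment_1 ... fragment_136


def scaffold_flags_from_fragments(fragment_ids):
    """Build the 136-dim flag vector from 1-based fragment numbers (fragment_4 -> index 3)."""
    flags = np.zeros(N_SCAFFOLD, dtype=np.float32)
    for frag in fragment_ids:
        if not 1 <= frag <= N_SCAFFOLD:
            raise ValueError(f"fragment number out of range: {frag}")
        flags[frag - 1] = 1.0
    return flags


def build_fp_vector(sol_fp, mol_fp, numeric_scaled, scaffold_flags):
    """Concatenate in the function's order: solvent, molecule, numeric, scaffold."""
    parts = [sol_fp, mol_fp, numeric_scaled, scaffold_flags]
    sizes = [FP_BITS, FP_BITS, N_NUMERIC, N_SCAFFOLD]
    for part, size in zip(parts, sizes):
        shape = np.asarray(part).shape
        if shape != (size,):
            raise ValueError(f"expected shape ({size},), got {shape}")
    return np.concatenate(parts).astype(np.float32)


def split_for_cnn(fp, solvent_dim=FP_BITS):
    """Split into the solvent block and the molecule+extra block (1024 / 1168)."""
    return fp[:solvent_dim], fp[solvent_dim:]


fp = build_fp_vector(
    np.zeros(FP_BITS, dtype=np.float32),
    np.zeros(FP_BITS, dtype=np.float32),
    np.full(N_NUMERIC, 0.5, dtype=np.float32),
    scaffold_flags_from_fragments([4]),  # BODIPY case: fragment_4
)
solvent_part, smiles_extra_part = split_for_cnn(fp)
print(fp.shape, solvent_part.shape, smiles_extra_part.shape)  # (2192,) (1024,) (1168,)
```

Checking shapes explicitly turns a silent misalignment (for example a 2048-bit fingerprint) into an immediate error instead of a model that loads but predicts nonsense.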


print("🔬 Comparing pretrained vs newly trained models\n")
print(f"Molecule: {molecule}")
print(f"Solvent:  {solvent}\n")
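Inside the prediction function, `solvent_mapping.get(..., 0)` falls back to 0 for any solvent missing from the dictionary without printing anything. A minimal sketch of a stricter lookup that warns instead; `lookup_solvent_num` and `SOLVENT_MAPPING` are hypothetical names, the codes are copied from the notebook's mapping, and whether 0 is a meaningful category for the trained models is not stated in the source. The optional `canonicalize` argument is an assumption; in the notebook it could be `lambda s: Chem.MolToSmiles(Chem.MolFromSmiles(s))`:

```python
import warnings

# Subset of the notebook's mapping (canonical SMILES only); codes as given there
SOLVENT_MAPPING = {
    'Cc1ccccc1': 6,  # toluene
    'CCO': 2,        # ethanol
    'CO': 1,         # methanol
    'c1ccccc1': 5,   # benzene
}


def lookup_solvent_num(smiles, mapping=SOLVENT_MAPPING, canonicalize=None, default=0):
    """Return the solvent code, warning instead of silently falling back to `default`."""
    key = canonicalize(smiles) if canonicalize is not None else smiles
    if key in mapping:
        return mapping[key]
    warnings.warn(f"Unknown solvent {smiles!r}; using solvent_num={default}")
    return default


print(lookup_solvent_num('CCO'))  # 2
```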

# Check which models exist
old_models_exist = all(os.path.exists(os.path.join(OLD_MODEL_DIR, f'Model_{t}.pth')) for t in ['abs', 'em', 'plqy', 'k'])
new_models_exist = all(os.path.exists(os.path.join(NEW_MODEL_DIR, f'Model_{t}.pth')) for t in ['abs', 'em', 'plqy', 'k'])

print(f"Old models found: {'✅' if old_models_exist else '❌'} ({OLD_MODEL_DIR})")
print(f"New models found: {'✅' if new_models_exist else '❌'} ({NEW_MODEL_DIR})")

old_preds = {}
old_raw = {}
new_preds = {}

if old_models_exist:
    print("\n🔮 Running predictions with original pretrained models...")
    try:
        old_preds, old_raw = predict_with_original_model_proper(
            molecule, solvent,
            model_dir=OLD_MODEL_DIR,
            data_dir=ORIGINAL_DATA_DIR,
            device=device
        )
        print("   ✅ Original model predictions complete")
        # Format each raw output separately: for a missing target, applying ':.3f'
        # to the fallback string 'N/A' would raise ValueError
        raw_str = ", ".join(
            f"{t}={old_raw[t]:.3f}" if t in old_raw else f"{t}=N/A"
            for t in ['abs', 'em', 'plqy', 'k']
        )
        print(f"   Raw outputs (normalized): {raw_str}")
    except Exception as e:
        print(f"   ❌ Error: {e}")
        import traceback
        traceback.print_exc()
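The "raw outputs (normalized)" are in the label scaler's standardized space, and the numeric features fed to the model were min-max scaled against the training CSV. The arithmetic behind both scikit-learn scalers is short enough to check by hand; this numpy sketch mirrors `MinMaxScaler` and `StandardScaler` with default settings (population standard deviation, no clipping, zero ranges or deviations replaced by 1). The function names are hypothetical, and the sketch is for checking values, not a replacement for the scikit-learn objects used in the function:

```python
import numpy as np


def minmax_fit(train):
    """Per-column min and range, as MinMaxScaler computes them (zero ranges become 1)."""
    train = np.asarray(train, dtype=np.float64)
    col_min = train.min(axis=0)
    col_range = train.max(axis=0) - col_min
    return col_min, np.where(col_range == 0, 1.0, col_range)


def minmax_transform(x, col_min, col_range):
    # No clipping (MinMaxScaler's default clip=False): values outside the
    # training range map outside [0, 1]
    return (np.asarray(x, dtype=np.float64) - col_min) / col_range


def standard_fit(labels):
    """Mean and population std (ddof=0), as StandardScaler computes them."""
    labels = np.asarray(labels, dtype=np.float64)
    std = labels.std(axis=0)
    return labels.mean(axis=0), np.where(std == 0, 1.0, std)


def standard_transform(y, mean, std):
    return (np.asarray(y, dtype=np.float64) - mean) / std


def standard_inverse(z, mean, std):
    """Map a normalized model output back to physical units."""
    return np.asarray(z, dtype=np.float64) * std + mean


train_numeric = np.array([[1.0, 300.0], [3.0, 500.0], [5.0, 700.0]])
col_min, col_range = minmax_fit(train_numeric)
print(minmax_transform([[6.0, 400.0]], col_min, col_range))  # [[1.25 0.25]]

mean, std = standard_fit(np.array([[600.0], [620.0], [640.0]]))
print(standard_inverse([[0.0]], mean, std))  # [[620.]]
```

The first example shows why out-of-range inputs matter: a feature value above the training maximum is scaled to 1.25, a region the model never saw during training.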

if new_models_exist:
    print("\n🔮 Running predictions with newly trained models...")
    try:
        new_preds = predict_properties(molecule, solvent, model_dir=NEW_MODEL_DIR,
                                       checkpoint_dir=CHECKPOINT_DIR, device=device)
        print("   ✅ New model predictions complete")
    except Exception as e:
        print(f"   ❌ Error: {e}")
        import traceback
        traceback.print_exc()
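The `k` target is reported as a log molar absorptivity (the "Log ε" row in the comparison). Assuming the usual base-10 convention and units of M⁻¹ cm⁻¹ (neither is stated in this section), converting back to ε is a single exponentiation; for the reference value of 5.0 that gives ε = 10⁵ = 100,000 M⁻¹ cm⁻¹. The helper names are hypothetical:

```python
import math


def log_eps_to_eps(log_eps):
    """Convert log10(epsilon) back to epsilon (assumed units: M^-1 cm^-1)."""
    return 10.0 ** log_eps


def eps_to_log_eps(eps):
    """Inverse conversion; epsilon must be positive."""
    if eps <= 0:
        raise ValueError("molar absorptivity must be positive")
    return math.log10(eps)


print(f"{log_eps_to_eps(5.0):,.0f}")  # 100,000
```

Because the scale is logarithmic, a difference of 0.1 in the `Diff` column corresponds to roughly a 26% change in ε (10^0.1 ≈ 1.26).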

# Display comparison
if old_preds or new_preds:
    print("\n" + "="*70)
    print("📊 MODEL COMPARISON")
    print("="*70)
    print(f"{'Property':<20} {'Original':<18} {'Retrained':<18} {'Diff':<12}")
    print("-"*70)
    
    # Reference values from running the original pipeline (run.py) on the example
    # molecule/solvent pair; they do not apply to other inputs
    expected = {'abs': 639.89, 'em': 660.38, 'plqy': 0.76, 'k': 5.0}
    
    for prop, unit, fmt in [('abs', 'nm', '.1f'), ('em', 'nm', '.1f'), ('plqy', '', '.3f'), ('k', '', '.2f')]:
        old_val = old_preds.get(prop, float('nan'))
        new_val = new_preds.get(prop, float('nan'))
        
        old_str = f"{old_val:{fmt}} {unit}".strip() if not np.isnan(old_val) else "N/A"
        new_str = f"{new_val:{fmt}} {unit}".strip() if not np.isnan(new_val) else "N/A"
        
        if not np.isnan(old_val) and not np.isnan(new_val):
            diff = new_val - old_val
            diff_str = f"{diff:+{fmt}}"
        else:
            diff_str = "-"
        
        prop_name = {'abs': 'Absorption', 'em': 'Emission', 'plqy': 'Quantum Yield', 'k': 'Log ε'}[prop]
        print(f"   {prop_name:<17} {old_str:<18} {new_str:<18} {diff_str:<12}")
    
    print("="*70)
    
    # Show expected from original pipeline
    print("\n📋 Reference values from original pipeline (run.py):")
    print(f"   Absorption:    {expected['abs']:.2f} nm")
    print(f"   Emission:      {expected['em']:.2f} nm")
    print(f"   Quantum Yield: {expected['plqy']:.2f}")
    print(f"   Log ε:         {expected['k']:.1f}")
    
    print("\n📏 Typical value ranges (approximate):")
    print("   Absorption:    ~300-900 nm")
    print("   Emission:      ~350-1000 nm")
    print("   Quantum Yield: 0.0-1.0")
    print("   Log ε:         ~3.0-5.5")
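The approximate value ranges printed at the end of the cell are a sanity check rather than hard limits. A minimal sketch that flags predictions outside them (or NaN), using those printed bounds as assumptions; `TYPICAL_RANGES` and `out_of_range` are hypothetical names:

```python
import math

# Approximate bounds copied from the printout; rough guides, not physical limits
TYPICAL_RANGES = {
    'abs': (300.0, 900.0),   # nm
    'em': (350.0, 1000.0),   # nm
    'plqy': (0.0, 1.0),
    'k': (3.0, 5.5),         # log epsilon
}


def out_of_range(preds, ranges=TYPICAL_RANGES):
    """Return {property: value} for predictions that are NaN or outside their typical range."""
    flagged = {}
    for prop, value in preds.items():
        if prop not in ranges:
            continue
        low, high = ranges[prop]
        if math.isnan(value) or not (low <= value <= high):
            flagged[prop] = value
    return flagged


# Reference values from the original pipeline, as listed in the cell
print(out_of_range({'abs': 639.89, 'em': 660.38, 'plqy': 0.76, 'k': 5.0}))  # {}
print(out_of_range({'plqy': 1.08, 'k': 5.0}))  # {'plqy': 1.08}
```

In the notebook this could be called as `out_of_range(old_preds)` and `out_of_range(new_preds)`; a quantum yield above 1 is a typical symptom of a label scaler fitted on the wrong training file.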