# 🔮 Fluor-RLAT Prediction

Lightweight notebook for predicting fluorescent molecule properties using pretrained models.

**Quick Start:**
1. Go to Runtime → Change runtime type → Select **T4 GPU** (optional, CPU works too)
2. Run all cells
3. Add your molecules to the `molecules` list and run predictions

**Properties predicted:**
- `abs` - Absorption wavelength (nm)
- `em` - Emission wavelength (nm)  
- `plqy` - Photoluminescence quantum yield (0-1)
- `k` - Log molar absorptivity

---

## 1. Install Dependencies

In [1]:
# ============================================================================
# Install Dependencies
# ============================================================================
import os
# Sets TORCHDYNAMO_DISABLE=1 for this kernel process.
# NOTE(review): presumably read by PyTorch to switch TorchDynamo off — confirm.
os.environ['TORCHDYNAMO_DISABLE'] = '1'

# Detect the CUDA toolkit version.
# `name = !cmd` is IPython syntax: the command's stdout lines are captured as a list.
# The pipeline keeps only the version token that follows "release " in nvcc's output,
# so the list is empty when nvcc is missing or prints nothing that matches.
cuda_version = !nvcc --version 2>/dev/null | grep -oP 'release \K[\d.]+'
# NOTE(review): when nothing was detected this falls back to "12", so the message
# below still reads "Detected CUDA: 12" and the cu124 DGL wheel is installed.
# Confirm that this is the intended behaviour for CPU-only runtimes.
cuda_ver = cuda_version[0] if cuda_version else "12"
print(f"Detected CUDA: {cuda_ver}")

# Install packages
# NOTE(review): none of the installs below pin a package version, so results may
# change between runs as new releases appear.
print("Installing dependencies...")
!pip install rdkit -q

# Pick the DGL wheel index that matches the detected CUDA major version:
# "12*" uses the torch-2.4 / cu124 index, anything else the torch-2.1 / cu118 index.
if cuda_ver.startswith('12'):
    !pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html -q
else:
    !pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html -q

!pip install dgllife -q

print("‚úÖ Dependencies installed!")
print("‚ö†Ô∏è  If first run, restart runtime: Runtime ‚Üí Restart runtime")

Detected CUDA: 12.8
Installing dependencies...
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m36.7/36.7 MB[0m [31m55.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m347.8/347.8 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m797.2/797.2 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m410.6/410.6 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [12]:
# ============================================================================
# Import Libraries
# ============================================================================
import os
os.environ['TORCHDYNAMO_DISABLE'] = '1'

import torch
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import numpy as np
import pandas as pd
import dgl
from dgllife.model import AttentiveFPGNN, AttentiveFPReadout
from dgllife.utils import smiles_to_bigraph, AttentiveFPAtomFeaturizer, AttentiveFPBondFeaturizer
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Report the library versions in use and whether a CUDA device is visible.
print(
    f"PyTorch: {torch.__version__}",
    f"DGL: {dgl.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

PyTorch: 2.4.0+cu121
DGL: 2.4.0+cu124
CUDA available: True
Using device: cuda


In [13]:
# ============================================================================
# Mount Google Drive (for checkpoints)
# ============================================================================
# Google Drive only backs the checkpoint / custom-model folders; the pretrained
# models come from the cloned repository. Skip the mount instead of aborting the
# whole run when the notebook is executed outside Google Colab.
try:
    from google.colab import drive
except ImportError:
    print("Not running in Google Colab - skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
# ============================================================================
# Clone Repository (for models and data)
# ============================================================================
REPO_URL = "https://github.com/markste-in/fluor_tools.git"
REPO_DIR = "fluor_tools"

if not os.path.exists(REPO_DIR):
    print(f"üì• Cloning repository...")
    !git clone {REPO_URL} -q
    print("‚úÖ Repository cloned!")
else:
    print(f"‚úÖ Repository already exists")

# Paths
MODEL_DIR = f'./{REPO_DIR}/Fluor-RLAT'
DATA_DIR = f'./{REPO_DIR}/Fluor-RLAT/data'
CHECKPOINT_DIR = '/content/drive/MyDrive/fluor_checkpoints'

print(f"üìÅ Models: {MODEL_DIR}")
print(f"üìÅ Data: {DATA_DIR}")
print(f"üìÅ Checkpoints: {CHECKPOINT_DIR}")

‚úÖ Repository already exists
üìÅ Models: ./fluor_tools/Fluor-RLAT
üìÅ Data: ./fluor_tools/Fluor-RLAT/data
üìÅ Checkpoints: /content/drive/MyDrive/fluor_checkpoints


## 2. Model Definitions

In [15]:
# ============================================================================
# Model Architectures
# ============================================================================

# Width of the graph-level embedding handed to both model classes.
GRAPH_FEAT_SIZE = 256

# AttentiveFP featurizers; atom features are stored under 'hv', bond features under 'he'.
ATOM_FEATURIZER = AttentiveFPAtomFeaturizer(atom_data_field='hv')
BOND_FEATURIZER = AttentiveFPBondFeaturizer(bond_data_field='he')

# Per-target hyperparameters (must match pretrained models).
# 'model_class' names the architecture that is instantiated for that target.
MODEL_CONFIGS = {
    'abs': {
        'num_layers': 2,
        'num_timesteps': 2,
        'dropout': 0.3,
        'model_class': 'GraphFingerprintsModel',
    },
    'em': {
        'num_layers': 3,
        'num_timesteps': 1,
        'dropout': 0.3,
        'model_class': 'GraphFingerprintsModel',
    },
    'plqy': {
        'num_layers': 2,
        'num_timesteps': 3,
        'dropout': 0.4,
        'model_class': 'GraphFingerprintsModelFC',
    },
    'k': {
        'num_layers': 3,
        'num_timesteps': 1,
        'dropout': 0.3,
        'model_class': 'GraphFingerprintsModelFC',
    },
}


class FingerprintAttentionCNN(nn.Module):
    """Fingerprint encoder built from two parallel 1-D convolutions (abs/em models).

    One convolution produces a feature map, the other produces attention scores
    over the same positions. The output concatenates the attention-weighted sum
    of the feature map with its max-pooled value, giving ``2 * conv_channels``
    values per sample.
    """

    def __init__(self, input_dim, conv_channels=256):
        super().__init__()
        # `input_dim` is part of the constructor signature but no layer below uses it.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv_feat = nn.Conv1d(1, conv_channels, **conv_kwargs)
        self.conv_attn = nn.Conv1d(1, conv_channels, **conv_kwargs)
        self.softmax = nn.Softmax(dim=-1)
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, x):
        # Insert a singleton channel axis so Conv1d sees (batch, 1, length).
        signal = x[:, None]
        features = self.conv_feat(signal)
        # Softmax runs along the position axis, separately for every channel.
        weights = self.softmax(self.conv_attn(signal))
        attended = (features * weights).sum(dim=-1)
        peak = self.pool(features).squeeze(-1)
        return torch.cat((attended, peak), dim=1)


class GraphFingerprintsModel(nn.Module):
    """abs/em model: AttentiveFP graph encoder plus solvent MLP and fingerprint CNN.

    ``fingerprints`` is split along dim 1 at ``solvent_dim``: the leading part
    goes through the solvent MLP, the rest through ``FingerprintAttentionCNN``.
    The three embeddings are concatenated and fed to the prediction head.
    """

    def __init__(self, node_feat_size, edge_feat_size, solvent_dim, smiles_extra_dim,
                 graph_feat_size=256, num_layers=2, num_timesteps=2, n_tasks=1, dropout=0.3):
        super().__init__()
        self.solvent_dim = solvent_dim

        self.gnn = AttentiveFPGNN(
            node_feat_size=node_feat_size,
            edge_feat_size=edge_feat_size,
            num_layers=num_layers,
            graph_feat_size=graph_feat_size,
            dropout=dropout,
        )
        self.readout = AttentiveFPReadout(
            feat_size=graph_feat_size,
            num_timesteps=num_timesteps,
            dropout=dropout,
        )
        self.fp_extractor = FingerprintAttentionCNN(smiles_extra_dim, conv_channels=graph_feat_size)
        self.solvent_extractor = nn.Sequential(
            nn.Linear(solvent_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, graph_feat_size),
        )
        # Head input: graph (1x) + solvent (1x) + CNN branch (2x) = 4 * graph_feat_size.
        head_in = graph_feat_size * 4
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.Linear(128, n_tasks),
        )

    def forward(self, graph, node_feats, edge_feats, fingerprints):
        split_at = self.solvent_dim
        solvent_part = fingerprints[:, :split_at]
        remainder = fingerprints[:, split_at:]

        atom_repr = self.gnn(graph, node_feats, edge_feats)
        # Third positional argument False: same call as the pretrained setup used.
        graph_repr = self.readout(graph, atom_repr, False)
        solvent_repr = self.solvent_extractor(solvent_part)
        extra_repr = self.fp_extractor(remainder)

        return self.predict(torch.cat([graph_repr, solvent_repr, extra_repr], dim=1))


class GraphFingerprintsModelFC(nn.Module):
    """plqy/k model: AttentiveFP graph encoder plus a two-layer MLP on the fingerprints.

    The whole ``fingerprints`` tensor goes through one MLP; its output is
    concatenated with the graph embedding and fed to the prediction head.
    """

    def __init__(self, node_feat_size, edge_feat_size, fp_size,
                 graph_feat_size=256, num_layers=2, num_timesteps=2, n_tasks=1, dropout=0.3):
        super().__init__()

        self.gnn = AttentiveFPGNN(
            node_feat_size=node_feat_size,
            edge_feat_size=edge_feat_size,
            num_layers=num_layers,
            graph_feat_size=graph_feat_size,
            dropout=dropout,
        )
        self.readout = AttentiveFPReadout(
            feat_size=graph_feat_size,
            num_timesteps=num_timesteps,
            dropout=dropout,
        )
        self.fp_fc = nn.Sequential(
            nn.Linear(fp_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, graph_feat_size),
        )
        # Head input: graph embedding + fingerprint embedding = 2 * graph_feat_size.
        head_in = graph_feat_size * 2
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.Linear(128, n_tasks),
        )

    def forward(self, graph, node_feats, edge_feats, fingerprints):
        atom_repr = self.gnn(graph, node_feats, edge_feats)
        graph_repr = self.readout(graph, atom_repr, False)
        fp_repr = self.fp_fc(fingerprints)
        return self.predict(torch.cat([graph_repr, fp_repr], dim=1))


def smiles_to_graph(smiles):
    """Convert a SMILES string to a featurized DGL graph.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        Whatever ``smiles_to_bigraph`` returns for the molecule (built with the
        module-level atom/bond featurizers and no self loops), or ``None`` when
        the conversion raises an ordinary exception.
    """
    try:
        return smiles_to_bigraph(
            smiles,
            node_featurizer=ATOM_FEATURIZER,
            edge_featurizer=BOND_FEATURIZER,
            add_self_loop=False,
        )
    except Exception:
        # Best-effort on purpose: callers test for None. Only ordinary errors are
        # swallowed, so KeyboardInterrupt / SystemExit still stop the kernel.
        return None

print("‚úÖ Model classes defined")

‚úÖ Model classes defined


## 3. Prediction Function

In [16]:
# ============================================================================
# Prediction Function
# ============================================================================

def predict_properties(molecule_smiles, solvent_smiles, model_dir=MODEL_DIR,
                       data_dir=DATA_DIR, checkpoint_dir=CHECKPOINT_DIR, device='cuda'):
    """Predict every available property for one molecule in one solvent.

    For each target in ('abs', 'em', 'plqy', 'k') the function loads
    ``Model_<target>.pth`` from ``model_dir``, fits the label and numeric-feature
    scalers on ``train_<target>.csv`` from ``data_dir``, runs the model once and
    maps the output back to the label scale.

    Args:
        molecule_smiles: SMILES string of the molecule.
        solvent_smiles: SMILES string of the solvent.
        model_dir: Directory containing the ``Model_<target>.pth`` state dicts.
        data_dir: Directory containing the ``train_<target>.csv`` files.
        checkpoint_dir: Unused; kept so existing callers keep working.
        device: Torch device to run on. A CUDA device is replaced by ``'cpu'``
            (with a warning) when CUDA is not available.

    Returns:
        dict mapping target name to the predicted value. Targets whose model or
        training file is missing are left out, and a warning is emitted for each.

    Raises:
        ValueError: If the molecule or solvent SMILES cannot be parsed.
    """
    import warnings  # local import keeps this cell self-contained

    # The default is 'cuda'; fall back rather than crash on CPU-only runtimes.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        warnings.warn("CUDA was requested but is not available; running predictions on CPU.")
        device = 'cpu'

    # Generate molecular graph
    graph = smiles_to_graph(molecule_smiles)
    if graph is None:
        raise ValueError(f"Could not parse molecule SMILES: {molecule_smiles}")

    mol = Chem.MolFromSmiles(molecule_smiles)
    sol = Chem.MolFromSmiles(solvent_smiles)
    if mol is None or sol is None:
        raise ValueError("Invalid SMILES")

    # Generate Morgan fingerprints (1024-bit, radius 2)
    mol_fp = np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024), dtype=np.float32)
    sol_fp = np.array(AllChem.GetMorganFingerprintAsBitVect(sol, radius=2, nBits=1024), dtype=np.float32)

    # Compute molecular descriptors
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    tpsa = Descriptors.TPSA(mol)
    # Counts every bond that is either DOUBLE or flagged aromatic.
    # NOTE(review): confirm the training data counted aromatic bonds the same way.
    double_bonds = sum(1 for bond in mol.GetBonds()
                       if bond.GetBondType() == Chem.BondType.DOUBLE or bond.GetIsAromatic())
    ring_count = mol.GetRingInfo().NumRings()

    # Solvent mapping (simplified): exact string match on the SMILES as typed;
    # anything not listed here (including other spellings of a listed solvent) maps to 0.
    solvent_mapping = {'CC1=CC=CC=C1': 6, 'Cc1ccccc1': 6, 'CCO': 2, 'CO': 1, 'c1ccccc1': 5}
    solvent_num = solvent_mapping.get(solvent_smiles, 0)

    # Detect scaffold (BODIPY check): boron bonded to two F and two N atoms.
    bodipy_pattern = Chem.MolFromSmarts('[#5](-F)(-F)(-[#7])(-[#7])')
    tag = 5 if bodipy_pattern and mol.HasSubstructMatch(bodipy_pattern) else 0

    # Create scaffold flags
    # NOTE(review): the length 136 and the use of index 3 for BODIPY are presumably
    # dictated by the training feature layout — verify against the training code.
    scaffold_flags = np.zeros(136, dtype=np.float32)
    if tag == 5:
        scaffold_flags[3] = 1

    # NOTE(review): hard-coded constant used for every molecule; confirm its origin.
    unimol_plus = 3.49
    numeric_feats = np.array([solvent_num, tag, mw, logp, tpsa, double_bonds, ring_count, unimol_plus], dtype=np.float32)

    # The batched graph does not depend on the target, so build and move it once.
    graph_batch = dgl.batch([graph]).to(device)
    node_feats = graph_batch.ndata['hv']
    edge_feats = graph_batch.edata['he']
    n_feats = node_feats.shape[1]
    e_feats = edge_feats.shape[1]

    predictions = {}
    for target in ['abs', 'em', 'plqy', 'k']:
        model_path = os.path.join(model_dir, f'Model_{target}.pth')
        # Training data is needed to fit the scalers.
        train_path = os.path.join(data_dir, f'train_{target}.csv')
        missing = [path for path in (model_path, train_path) if not os.path.exists(path)]
        if missing:
            # Still best-effort (the target is skipped), but no longer silent.
            warnings.warn(f"Skipping '{target}' prediction; missing file(s): {', '.join(missing)}")
            continue

        train_df = pd.read_csv(train_path)

        # Fit scalers
        label_scaler = StandardScaler()
        label_scaler.fit(train_df[[target]].values)

        # NOTE(review): columns 8..15 are taken by position and are assumed to line
        # up with the eight entries of `numeric_feats` — confirm against the CSV header.
        num_scaler = MinMaxScaler()
        num_scaler.fit(train_df.iloc[:, 8:16].values)
        numeric_scaled = num_scaler.transform(numeric_feats.reshape(1, -1)).flatten()

        # Combine features: solvent bits first, then molecule bits, then the extras.
        extra_feats = np.concatenate([numeric_scaled, scaffold_flags]).astype(np.float32)
        fp = np.concatenate([sol_fp, mol_fp, extra_feats])

        config = MODEL_CONFIGS[target]

        if config['model_class'] == 'GraphFingerprintsModel':
            # The first 1024 entries (solvent fingerprint) feed the solvent branch.
            model = GraphFingerprintsModel(
                node_feat_size=n_feats, edge_feat_size=e_feats,
                solvent_dim=1024, smiles_extra_dim=len(fp)-1024,
                graph_feat_size=GRAPH_FEAT_SIZE, num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'], dropout=config['dropout'])
        else:
            model = GraphFingerprintsModelFC(
                node_feat_size=n_feats, edge_feat_size=e_feats, fp_size=len(fp),
                graph_feat_size=GRAPH_FEAT_SIZE, num_layers=config['num_layers'],
                num_timesteps=config['num_timesteps'], dropout=config['dropout'])

        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        model = model.to(device)
        model.eval()

        fp_tensor = torch.tensor(fp, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            pred = model(graph_batch, node_feats, edge_feats, fp_tensor)
            pred_scaled = pred.item()

        # Undo the label standardisation to get the value on the original scale.
        predictions[target] = label_scaler.inverse_transform([[pred_scaled]])[0, 0]

    return predictions

print("‚úÖ Prediction function defined")

‚úÖ Prediction function defined


## 4. Define Molecules to Predict

In [28]:
# ============================================================================
# üß™ DEFINE YOUR MOLECULES HERE
# ============================================================================
# Format: (name, SMILES)
# Add as many molecules as you want!

# Solvent (same for all molecules)
solvent = "CC1=CC=CC=C1"  # toluene

# One (name, SMILES) tuple per molecule to predict.
molecules = [
    (
        "BODIPY-phenyl",
        "C2=C1C7=C(C(=[N+]1[B-]([N]3C2=C5C(=C3C4=CC=CC=C4)C=CC=C5)(F)F)C6=CC=CC=C6)C=CC=C7",
    ),
    (
        "BODIPY-thiophene",
        "C2(=C1C(=C(C(=[N+]1[B-]([N]3C2=C(C(=C3C)C4=CC=CS4)C)(F)F)C)C5=CC=CS5)C)C6=C(C=C(C=C6C)C)C",
    ),
    (
        "BBOT",
        "CC(C)(C)c1ccc2oc(nc2c1)c1sc(cc1)c1oc2ccc(cc2n1)C(C)(C)C",
    ),
    (
        "BJ18023",
        "Cc1cc(C)cc(C)c1C=1c2cc(cn2[B-](F)(F)[N+]2=CC(=CC2=1)c1cccs1)c1cccs1",
    ),
    # Add more molecules here:
    # ("Name", "SMILES"),
]

# Echo the configuration so the run is self-describing.
print(f"üß´ Solvent: {solvent}")
print(f"\nüß™ Molecules to predict ({len(molecules)}):")
for name, smiles in molecules:
    print(f"   ‚Ä¢ {name}")

üß´ Solvent: CC1=CC=CC=C1

üß™ Molecules to predict (4):
   ‚Ä¢ BODIPY-phenyl
   ‚Ä¢ BODIPY-thiophene
   ‚Ä¢ BBOT
   ‚Ä¢ BJ18023


## 5. Run Predictions

In [29]:
# ============================================================================
# Run Predictions
# ============================================================================

print("üîÆ Running predictions...\n")
all_results = []

for name, smiles in molecules:
    try:
        preds = predict_properties(smiles, solvent, model_dir=MODEL_DIR, 
                                   data_dir=DATA_DIR, device=device)
        preds['name'] = name
        preds['smiles'] = smiles
        all_results.append(preds)
        print(f"   ‚úÖ {name}")
    except Exception as e:
        print(f"   ‚ùå {name}: {e}")

# Display results
rule = "=" * 80
print("\n" + rule)
print("üìä PREDICTION RESULTS")
print(rule)

# Targets that were not predicted are shown as nan.
missing = float('nan')
for preds in all_results:
    # Long SMILES are cut to 60 characters and marked with an ellipsis.
    shown_smiles = preds['smiles'][:60]
    suffix = '...' if len(preds['smiles']) > 60 else ''
    print(f"\nüß™ {preds['name']}")
    print(f"   SMILES: {shown_smiles}{suffix}")
    print(f"   ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ")
    print(f"   Absorption (abs):     {preds.get('abs', missing):>7.1f} nm")
    print(f"   Emission (em):        {preds.get('em', missing):>7.1f} nm")
    print(f"   Quantum Yield (plqy): {preds.get('plqy', missing):>7.3f}")
    print(f"   Log Œµ (k):            {preds.get('k', missing):>7.2f}")

print("\n" + rule)

# Summary table
if all_results:
    # A target is absent from every result dict when its model or training file was
    # skipped. `reindex` fills such columns with NaN instead of raising KeyError,
    # matching the NaN fallback used in the per-molecule report.
    summary_columns = ['name', 'abs', 'em', 'plqy', 'k']
    results_df = pd.DataFrame(all_results).reindex(columns=summary_columns)
    results_df.columns = ['Molecule', 'Abs (nm)', 'Em (nm)', 'PLQY', 'Log Œµ']
    print("\nüìã Summary Table:")
    display(results_df)

üîÆ Running predictions...

   ‚úÖ BODIPY-phenyl
   ‚úÖ BODIPY-thiophene
   ‚úÖ BBOT
   ‚úÖ BJ18023

üìä PREDICTION RESULTS

üß™ BODIPY-phenyl
   SMILES: C2=C1C7=C(C(=[N+]1[B-]([N]3C2=C5C(=C3C4=CC=CC=C4)C=CC=C5)(F)...
   ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Absorption (abs):       637.6 nm
   Emission (em):          656.5 nm
   Quantum Yield (plqy):   0.742
   Log Œµ (k):               5.00

üß™ BODIPY-thiophene
   SMILES: C2(=C1C(=C(C(=[N+]1[B-]([N]3C2=C(C(=C3C)C4=CC=CS4)C)(F)F)C)C...
   ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
   Absorption (abs):       524.0 nm
   Emission (em):          607.0 nm
   Quantum Yield (plqy):   0.315
   Log Œµ (k):               4.70

üß™ BBOT
   SMILES: CC(C)(C)c1ccc2oc(nc2c1)c1sc(cc1)c1oc2ccc(cc2n1)C(C)(C)C
   ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î

Unnamed: 0,Molecule,Abs (nm),Em (nm),PLQY,Log Œµ
0,BODIPY-phenyl,637.627266,656.546115,0.742101,5.000829
1,BODIPY-thiophene,524.000141,606.971069,0.315354,4.700419
2,BBOT,376.987715,428.494148,0.944382,4.683788
3,BJ18023,563.249698,664.983123,0.18302,4.342248


## 6. Compare with Custom Trained Models (Optional)

If you've trained your own models, compare predictions here.

In [30]:
# ============================================================================
# Compare Pretrained vs Custom Trained Models
# ============================================================================

# Custom models are saved to Google Drive by the training notebook
# Custom models are saved to Google Drive by the training notebook.
CUSTOM_MODEL_DIR = '/content/drive/MyDrive/fluor_models'
custom_checkpoints = [
    os.path.join(CUSTOM_MODEL_DIR, f'Model_{t}.pth') for t in ['abs', 'em', 'plqy', 'k']
]
# The comparison runs as soon as at least one custom checkpoint exists.
USE_CUSTOM_MODELS = os.path.exists(CUSTOM_MODEL_DIR) and any(map(os.path.exists, custom_checkpoints))

if not USE_CUSTOM_MODELS:
    print(f"‚ÑπÔ∏è  No custom trained models found in {CUSTOM_MODEL_DIR}")
    print("   Run the training notebook first to create custom models.")
    print("   Check that Google Drive is mounted and contains Model_*.pth files.")
else:
    print("üîÑ Comparing pretrained vs custom trained models...\n")

    comparison = []
    for name, smiles in molecules:
        try:
            # Same molecule, same training data for the scalers; only the weights differ.
            preds_pre = predict_properties(
                smiles, solvent, model_dir=MODEL_DIR, data_dir=DATA_DIR, device=device)
            preds_custom = predict_properties(
                smiles, solvent, model_dir=CUSTOM_MODEL_DIR, data_dir=DATA_DIR, device=device)
        except Exception as e:
            print(f"   ‚ùå {name}: {e}")
            continue

        # Missing targets show up as nan in the table.
        comparison.append({
            'Molecule': name,
            'Abs (Pre)': preds_pre.get('abs', float('nan')),
            'Abs (Custom)': preds_custom.get('abs', float('nan')),
            'Em (Pre)': preds_pre.get('em', float('nan')),
            'Em (Custom)': preds_custom.get('em', float('nan')),
            'PLQY (Pre)': preds_pre.get('plqy', float('nan')),
            'PLQY (Custom)': preds_custom.get('plqy', float('nan')),
        })

    if comparison:
        comp_df = pd.DataFrame(comparison)
        print("\nüìä Comparison Table:")
        display(comp_df)

üîÑ Comparing pretrained vs custom trained models...


üìä Comparison Table:


Unnamed: 0,Molecule,Abs (Pre),Abs (Custom),Em (Pre),Em (Custom),PLQY (Pre),PLQY (Custom)
0,BODIPY-phenyl,637.627266,637.649876,656.546115,662.416656,0.742101,0.861475
1,BODIPY-thiophene,524.000141,523.555714,606.971069,605.820536,0.315354,0.345119
2,BBOT,376.987715,374.083477,428.494148,431.178622,0.944382,0.919884
3,BJ18023,563.249698,571.337136,664.983123,674.192256,0.18302,0.408203



C2(=C1C(=C(C(=[N+]1[B-]([N]3C2=C(C(=C3C)C4=CC=CS4)C)(F)F)C)C5=CC=CS5)C)C6=C(C=C(C=C6C)C)C
CC1=CC=CC=C1
526-611

C2=C1C7=C(C(=[N+]1[B-]([N]3C2=C5C(=C3C4=CC=CC=C4)C=CC=C5)(F)F)C6=CC=CC=C6)C=CC=C7
CC1=CC=CC=C1
640-660