# In [None]:
import os
import json
import time
import multiprocessing
import numpy as np
import pandas as pd
import warnings
from glob import glob
from tqdm import tqdm
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from torch.optim.lr_scheduler import ReduceLROnPlateau

# RDKit provides Chem/AllChem, which GraphBuilder uses for SMILES parsing and
# conformer generation. Fail immediately with an install hint if it is missing.
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem
    print("RDKit Imported successfully")
except ImportError:
    raise ImportError("rdkit: pip install rdkit")

# Report the PyTorch version and whether a CUDA device is visible.
print("="*40)
print(f"PyTorch versions: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("Graphics card presence")
else:
    print("Graphics card not present")
print("="*40)

# PyTorch Geometric supplies the graph containers, loader and GCN layers used below.
try:
    from torch_geometric.data import Data, Batch, Dataset
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn import GCNConv, global_mean_pool
except ImportError:
    raise ImportError("PyTorch Geometric: pip install torch_geometric")

# Pymatgen's Structure is used to read the zeolite CIF files.
try:
    from pymatgen.core import Structure
except ImportError:
    raise ImportError("Pymatgen: pip install pymatgen")

# NOTE(review): this silences *every* warning for the rest of the process,
# including deprecation and numerical warnings from the libraries above.
warnings.filterwarnings('ignore')

class Config:
    """Central collection of paths and hyper-parameters used throughout this script.

    All values are plain class attributes read as ``Config.NAME``; the class is
    never instantiated in this file.
    """

    # --- Input locations (all relative to BASE_PATH) ---
    BASE_PATH = r"./data"
    DATASET_PATH = os.path.join(BASE_PATH, "Energy_data.xlsx")
    CIF_PATH = os.path.join(BASE_PATH, "molecular_sieve")
    CONFORMER_3D_PATH = os.path.join(BASE_PATH, "conformer 3D")
    STRUCTURE_2D_PATH = os.path.join(BASE_PATH, "structure 2D")
    
    # File where the pre-built molecule/zeolite graphs are cached with torch.save.
    PROCESSED_CACHE_PATH = os.path.join(BASE_PATH, "cached_graphs_box64_cleaned.pt")
    
    # Spreadsheet columns used as regression targets; their count fixes the
    # width of the model's output layer.
    TARGET_COLS = [
        'Binding Energy (kJ/mol Si)',
        'Directivity Energy (kJ/mol Si)',
        'Competition Energy (kJ/mol Si)',
        'Binding Energy (kJ/mol OSDA)',
        'Competition Energy (kJ/mol OSDA)'
    ]

    # --- DataLoader settings ---
    BATCH_SIZE = 64
    NUM_WORKERS = 0  
    PIN_MEMORY = True

    # --- Model widths ---
    ATOM_EMBEDDING_DIM = 64
    HIDDEN_DIM = 128

    # Embedding widths of the categorical per-atom features
    # (degree, formal charge, hybridization, aromaticity, chirality).
    EMB_DIM_DEGREE = 8
    EMB_DIM_CHARGE = 8
    EMB_DIM_HYB = 8
    EMB_DIM_AROMATIC = 4
    EMB_DIM_CHIRAL = 4

    # --- Optimisation ---
    LR = 0.0008
    WEIGHT_DECAY = 1e-5
    EPOCHS = 200
    # Neighbour search radius passed to Structure.get_all_neighbors
    # (presumably in angstroms - confirm against the CIF units).
    CRYSTAL_RADIUS = 6.0

    # --- Voxel grid: VOXEL_SIZE cells per axis, each VOXEL_RES coordinate units wide ---
    VOXEL_SIZE = 64  
    VOXEL_RES = 0.5
    # Forwarded to coords_to_voxel, which currently does not use it.
    SIGMA = 0.5
    
    # Data cleaning: topologies with fewer rows than this are dropped (0 keeps everything)
    MIN_SAMPLES_PER_TOPO = 0

    # Early-stopping configuration
    EARLY_STOPPING_PATIENCE = 20
    EARLY_STOPPING_MIN_DELTA = 0.001

    # Learning-rate scheduling configuration (ReduceLROnPlateau)
    SCHEDULER_PATIENCE = 5
    SCHEDULER_FACTOR = 0.5
    MIN_LR = 1e-6
    
    # Number of conformers RDKit is asked to generate per molecule
    NUM_CONFORMERS = 3

# 2. 3D Transformations and Voxelization
def get_random_rotation_matrix():
    """Return a random 3x3 rotation matrix composed as Rz @ Ry @ Rx.

    Three angles are drawn from ``np.random.uniform(0, 2*pi)`` in the order
    x-axis, y-axis, z-axis, so results are reproducible under ``np.random.seed``.
    """
    # Keep the draw order (x, then y, then z): it determines the seeded output.
    angle_x = np.random.uniform(0, 2*np.pi)
    angle_y = np.random.uniform(0, 2*np.pi)
    angle_z = np.random.uniform(0, 2*np.pi)

    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    rot_x = np.array([
        [1, 0, 0],
        [0, cos_x, -sin_x],
        [0, sin_x, cos_x],
    ])
    rot_y = np.array([
        [cos_y, 0, sin_y],
        [0, 1, 0],
        [-sin_y, 0, cos_y],
    ])
    rot_z = np.array([
        [cos_z, -sin_z, 0],
        [sin_z, cos_z, 0],
        [0, 0, 1],
    ])

    return rot_z @ rot_y @ rot_x

def coords_to_voxel(coords, grid_size=32, res=0.5, sigma=0.5):
    """Rasterise points into a binary occupancy grid centred on the origin.

    Args:
        coords: array indexed as ``coords[:, 0..2]`` (x, y, z per row).
        grid_size: number of cells along each axis.
        res: edge length of one cell, in the same units as ``coords``.
        sigma: accepted for interface compatibility; not used by this function.

    Returns:
        A float32 array of shape ``(grid_size, grid_size, grid_size)`` holding
        1.0 in every cell of the (up to) 3x3x3 neighbourhood around each point
        that lies strictly inside the cube, and 0.0 elsewhere.
    """
    volume = np.zeros((grid_size, grid_size, grid_size), dtype=np.float32)
    half_extent = (grid_size * res) / 2.0

    # A point is kept only if all three components are strictly inside the cube.
    inside = np.ones(len(coords), dtype=bool)
    for axis in range(3):
        column = coords[:, axis]
        inside &= (column > -half_extent) & (column < half_extent)

    kept = coords[inside]
    if len(kept) == 0:
        return volume

    # Shift to the positive octant, convert to cell indices, clamp to the grid.
    cells = np.clip(((kept + half_extent) / res).astype(int), 0, grid_size - 1)
    lows = np.maximum(cells - 1, 0)
    highs = np.minimum(cells + 2, grid_size)

    # Stamp the neighbourhood of every point; overlaps are flattened by the final clip.
    for (x0, y0, z0), (x1, y1, z1) in zip(lows, highs):
        volume[x0:x1, y0:y1, z0:z1] += 1.0

    return np.clip(volume, 0, 1.0)

# 3. Graph construction
class GraphBuilder:
    def _get_atom_encoding_legacy(self, atomic_num):
        if atomic_num > 100: return 100
        return atomic_num - 1

    def _get_rich_atom_features(self, atom=None, element_symbol=None, is_crystal=False):

        # 1. Atomic Number (0-118)
        if atom:
            atomic_num = atom.GetAtomicNum()
        elif element_symbol:
            pt = Chem.GetPeriodicTable()
            atomic_num = pt.GetAtomicNumber(element_symbol)
        else:
            atomic_num = 0

        feat_atomic = min(atomic_num, 118)

        if is_crystal or atom is None:
            return [feat_atomic, 0, 5, 0, 0, 0] 

        # 2. Degree (0-10)
        degree = min(atom.GetDegree(), 10)

        # 3. Formal Charge
        charge = atom.GetFormalCharge()
        charge_idx = charge + 5 
        charge_idx = max(0, min(charge_idx, 14))

        # 4. Hybridization (0-6)
        hyb = atom.GetHybridization()
        hyb_map = {
            Chem.rdchem.HybridizationType.S: 0,
            Chem.rdchem.HybridizationType.SP: 1,
            Chem.rdchem.HybridizationType.SP2: 2,
            Chem.rdchem.HybridizationType.SP3: 3,
            Chem.rdchem.HybridizationType.SP3D: 4,
            Chem.rdchem.HybridizationType.SP3D2: 5,
            Chem.rdchem.HybridizationType.UNSPECIFIED: 6
        }
        hyb_idx = hyb_map.get(hyb, 6)

        # 5. Aromaticity (0 or 1)
        is_aromatic = 1 if atom.GetIsAromatic() else 0

        # 6. Chirality (0-3)
        chi_map = {
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED: 0,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW: 1,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW: 2,
            Chem.rdchem.ChiralType.CHI_OTHER: 3
        }
        chi_idx = chi_map.get(atom.GetChiralTag(), 0)

        return [feat_atomic, degree, charge_idx, hyb_idx, is_aromatic, chi_idx]

    def _calculate_shape_descriptors(self, coords):
        if coords is None or len(coords) < 2:
            return [0.0] * 10
        coords = coords - np.mean(coords, axis=0)
        cov_matrix = np.cov(coords.T)
        evals, evecs = np.linalg.eigh(cov_matrix)
        idx = evals.argsort()[::-1]
        evals = evals[idx]
        L, W, H = 0.0, 0.0, 0.0
        try:
            aligned_coords = np.dot(coords, evecs[:, idx])
            min_b = np.min(aligned_coords, axis=0)
            max_b = np.max(aligned_coords, axis=0)
            dims = max_b - min_b
            L, W, H = sorted(dims, reverse=True)
        except:
            pass
        rg = np.sqrt(np.mean(np.sum(coords**2, axis=1)))
        pm1 = max(evals[0], 1e-6)
        pm2 = evals[1] if len(evals) > 1 else 0.0
        pm3 = evals[2] if len(evals) > 2 else 0.0
        return [
            rg,
            pm1, pm2, pm3,
            L/(H+1e-6),
            np.sqrt(max(0, 1 - (pm3/pm1))),
            pm3/pm1,
            L, W, H
        ]

    def _extract_global_props(self, props_2d, props_3d, charge_val, coords_3d):
        props_map = {'logp': 0.0, 'tpsa': 0.0, 'rotatable': 0.0, 'h_acceptor': 0.0, 'h_donor': 0.0, 'volume': 0.0}
        for prop in props_2d:
            urn = prop.get('urn', {})
            label, name, val = urn.get('label', ''), urn.get('name', ''), prop.get('value', {})
            if label == 'Log P' and name == 'XLogP3-AA': props_map['logp'] = val.get('fval', 0.0)
            elif label == 'Topological' and name == 'Polar Surface Area': props_map['tpsa'] = val.get('fval', 0.0)
            elif label == 'Count' and name == 'Rotatable Bond': props_map['rotatable'] = float(val.get('ival', 0))
            elif label == 'Count' and name == 'Hydrogen Bond Acceptor': props_map['h_acceptor'] = float(val.get('ival', 0))
            elif label == 'Count' and name == 'Hydrogen Bond Donor': props_map['h_donor'] = float(val.get('ival', 0))
        for prop in props_3d:
            if prop.get('urn', {}).get('name') == 'Volume': props_map['volume'] = prop.get('value', {}).get('fval', 0.0)
        shape_feats = self._calculate_shape_descriptors(coords_3d)
        features_list = [props_map[k] for k in props_map] + [float(charge_val)] + shape_feats
        return torch.tensor(features_list, dtype=torch.float).unsqueeze(0)

    def _extract_partial_charges(self, num_atoms, props_list):
        charges = torch.zeros(num_atoms, 1, dtype=torch.float)
        for prop in props_list:
            if prop.get('urn', {}).get('name') == 'MMFF94 Partial':
                for item in prop.get('value', {}).get('slist', []):
                    try:
                        parts = item.split()
                        charges[int(parts[0])-1] = float(parts[1])
                    except: pass
                break
        return charges

    def build_molecule_graph(self, cid):
        """Build the molecular graph for one compound from its 2D JSON record.

        The SMILES string, property list and total charge are read from
        ``Structure2D_COMPOUND_CID_<cid>.json`` under ``Config.STRUCTURE_2D_PATH``.
        RDKit then adds hydrogens, embeds up to ``Config.NUM_CONFORMERS``
        conformers (fixed seed 42), runs MMFF optimisation and computes
        Gasteiger charges.

        Args:
            cid: compound identifier used to build the file name.

        Returns:
            A ``Data`` object with
            ``x`` (N x 6 long feature indices), ``edge_index`` / ``edge_weight``
            (both bond directions, weight ``1 / (distance + 0.1)`` in the first
            conformer), ``x_charge`` (N x 1 Gasteiger charges), ``global_attr``
            (1 x 17), ``pos`` (first conformer, centred) and ``pos_variants``
            (all conformers, centred). Returns None when the file is missing or
            unreadable, no SMILES is found, or any RDKit step fails.
        """
        file_2d = os.path.join(Config.STRUCTURE_2D_PATH, f"Structure2D_COMPOUND_CID_{cid}.json")
        props_2d = []
        total_charge = 0.0
        smiles_str = None

        # 1. Read SMILES, properties and total charge from the 2D record.
        if os.path.exists(file_2d):
            try:
                with open(file_2d, 'r', encoding='utf-8') as f:
                    d = json.load(f)
                if 'PC_Compounds' in d:
                    main_data_2d = d['PC_Compounds'][0]
                    props_2d = main_data_2d.get('props', [])
                    total_charge = float(main_data_2d.get('charge', 0.0))

                    for prop in props_2d:
                        if prop.get('urn', {}).get('label') == 'SMILES':
                            smiles_str = prop.get('value', {}).get('sval')
                            break
            except (OSError, ValueError, TypeError, KeyError, IndexError, AttributeError):
                # Best effort: an unreadable or malformed record leaves whatever
                # was parsed so far; a missing SMILES is handled just below.
                pass

        if not smiles_str:
            return None

        # 2. Generate the graph and conformers with RDKit.
        try:
            mol = Chem.MolFromSmiles(smiles_str)
            if not mol:
                return None

            mol = Chem.AddHs(mol)
            params = AllChem.ETKDGv2()
            params.randomSeed = 42
            cids_rdkit = list(AllChem.EmbedMultipleConfs(mol, numConfs=Config.NUM_CONFORMERS, params=params))
            if not cids_rdkit:
                return None

            AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=0)

            # Centre every conformer on its own centroid.
            pos_variants = []
            for conf_id in cids_rdkit:
                pos = mol.GetConformer(conf_id).GetPositions()
                pos_variants.append(pos - np.mean(pos, axis=0))

            # Pad with copies of the first conformer so the tensor shape is fixed.
            while len(pos_variants) < Config.NUM_CONFORMERS:
                pos_variants.append(pos_variants[0])

            atom_features = [self._get_rich_atom_features(atom=atom) for atom in mol.GetAtoms()]
            x = torch.tensor(atom_features, dtype=torch.long)  # Shape: [N, 6]
            pos_main = torch.tensor(pos_variants[0], dtype=torch.float)
            pos_variants_tensor = torch.tensor(np.array(pos_variants), dtype=torch.float)

            # Bonds become undirected edges weighted by inverse bond length.
            pos_main_np = pos_main.numpy()
            edge_indices, edge_weights = [], []
            for bond in mol.GetBonds():
                u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                dist = np.linalg.norm(pos_main_np[u] - pos_main_np[v])
                w = 1.0 / (dist + 0.1)
                edge_indices.extend([[u, v], [v, u]])
                edge_weights.extend([w, w])

            if edge_indices:
                edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
                edge_weight = torch.tensor(edge_weights, dtype=torch.float)
            else:
                edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_weight = torch.empty(0)

            AllChem.ComputeGasteigerCharges(mol)
            charges = [
                float(atom.GetProp('_GasteigerCharge')) if atom.HasProp('_GasteigerCharge') else 0.0
                for atom in mol.GetAtoms()
            ]
            x_charge = torch.tensor(charges, dtype=torch.float).unsqueeze(1)
            # A non-finite charge would propagate NaN through the network; fall
            # back to 0.0, the same value used when the property is absent.
            x_charge = torch.nan_to_num(x_charge, nan=0.0, posinf=0.0, neginf=0.0)

            return Data(
                x=x,
                edge_index=edge_index,
                edge_weight=edge_weight,
                x_charge=x_charge,
                global_attr=self._extract_global_props(props_2d, [], total_charge, pos_main_np),
                pos=pos_main,
                pos_variants=pos_variants_tensor
            )
        except Exception:
            # Deliberate best-effort boundary: any RDKit failure just drops this molecule.
            return None

    def build_zeolite_graph(self, topology):
        """Build the framework graph for one zeolite topology from its CIF file.

        CIF files are searched under ``Config.CIF_PATH`` with the patterns
        ``*<topology>*.cif*`` and ``<topology>/*.cif*``; the first match
        (alphabetical within a pattern) is used.

        Args:
            topology: topology code used in the file / directory name.

        Returns:
            A ``Data`` object with ``x`` (N x 6 long feature indices),
            ``edge_index`` / ``edge_attr`` (each site linked to its 12 nearest
            neighbours within ``Config.CRYSTAL_RADIUS``, attribute = distance),
            ``pos`` (cartesian coordinates of the unit cell) and ``pos_super``
            (centred coordinates of a 3x3x3 supercell). Returns None when no CIF
            is found or the structure cannot be processed.
        """
        patterns = [
            os.path.join(Config.CIF_PATH, f"*{topology}*.cif*"),
            os.path.join(Config.CIF_PATH, topology, "*.cif*"),
        ]
        cif_files = []
        for p in patterns:
            # glob order is arbitrary; sort so the same file is picked on every run.
            cif_files.extend(sorted(glob(p)))
        if not cif_files:
            return None

        try:
            struct = Structure.from_file(cif_files[0])

            atom_features = [
                self._get_rich_atom_features(element_symbol=site.specie.symbol, is_crystal=True)
                for site in struct
            ]
            x = torch.tensor(atom_features, dtype=torch.long)
            pos = torch.tensor(struct.cart_coords, dtype=torch.float)

            # Neighbour entries are read as nbr[1] = distance, nbr[2] = site index.
            nbrs = struct.get_all_neighbors(r=Config.CRYSTAL_RADIUS, include_index=True)

            edge_indices, edge_attrs = [], []
            for i, nbr_list in enumerate(nbrs):
                # Keep only the 12 closest neighbours of each site.
                for nbr in sorted(nbr_list, key=lambda entry: entry[1])[:12]:
                    edge_indices.append([i, nbr[2]])
                    edge_attrs.append(nbr[1])

            if edge_indices:
                edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
            else:
                # Keep the (2, E) layout even when there are no edges.
                edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.float).unsqueeze(1)

            # A centred 3x3x3 supercell is stored separately for voxelisation.
            struct_super = struct.copy()
            struct_super.make_supercell([3, 3, 3])
            pos_super = torch.tensor(struct_super.cart_coords, dtype=torch.float)
            pos_super = pos_super - torch.mean(pos_super, dim=0)

            return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, pos_super=pos_super)
        except Exception:
            # Deliberate best-effort boundary: an unparsable structure drops this topology.
            return None

# 4. Data Preprocessing and Caching
def build_mol_helper(cid):
    """Build the molecule graph for ``cid`` with a fresh GraphBuilder; return ``(cid, graph_or_None)``."""
    graph = GraphBuilder().build_molecule_graph(cid)
    return cid, graph

def build_zeo_helper(topo):
    """Build the zeolite graph for ``topo`` with a fresh GraphBuilder; return ``(topo, graph_or_None)``."""
    graph = GraphBuilder().build_zeolite_graph(topo)
    return topo, graph

def prepare_and_cache_data(df):
    """Load the graph cache from disk, or build it in parallel and save it.

    Args:
        df: DataFrame with ``'CID'`` and ``'Topology Code'`` columns.

    Returns:
        A dict ``{'mol_cache': {cid: Data}, 'zeo_cache': {topology: Data}}``.
        Entries whose graph could not be built are omitted.
    """
    if os.path.exists(Config.PROCESSED_CACHE_PATH):
        print(f"Discover cache files: {Config.PROCESSED_CACHE_PATH}")
        try:
            # NOTE: weights_only=False unpickles arbitrary objects - only load
            # cache files produced by this script.
            return torch.load(Config.PROCESSED_CACHE_PATH, weights_only=False)
        except Exception as e:
            # A corrupt or incompatible cache is rebuilt below.
            print(f"failed to load: {e}")
    else:
        print("No cache files found")

    unique_cids = df['CID'].unique()
    unique_topos = df['Topology Code'].unique()

    print(f"  - Number of unique molecules: {len(unique_cids)}")
    print(f"  - Number of unique zeolites: {len(unique_topos)}")

    cpu_total = multiprocessing.cpu_count()

    print("Constructing molecular maps")
    mol_results = Parallel(n_jobs=cpu_total)(
        delayed(build_mol_helper)(cid) for cid in tqdm(unique_cids, desc="Molecules")
    )
    mol_cache = {key: graph for key, graph in mol_results if graph is not None}

    print("Constructing a zeolite diagram")
    # joblib rejects n_jobs=0, which min() would produce for an empty topology list.
    zeo_jobs = max(1, min(len(unique_topos), cpu_total))
    zeo_results = Parallel(n_jobs=zeo_jobs)(
        delayed(build_zeo_helper)(topo) for topo in tqdm(unique_topos, desc="Zeolites")
    )
    zeo_cache = {key: graph for key, graph in zeo_results if graph is not None}

    cache_data = {'mol_cache': mol_cache, 'zeo_cache': zeo_cache}
    print(f"Save the cache to: {Config.PROCESSED_CACHE_PATH}")
    torch.save(cache_data, Config.PROCESSED_CACHE_PATH)
    return cache_data

# 5. Dataset 
class ZeoliteDataset(Dataset):
    def __init__(self, df, cache_data, target_scaler=None, props_scaler=None, is_train=False):
        """Collect the (molecule, zeolite, target) samples available for ``df``.

        Rows are kept only when both graphs exist in ``cache_data`` and none of
        the ``Config.TARGET_COLS`` values is NaN.

        Args:
            df: DataFrame with ``'CID'``, ``'Topology Code'`` and the target columns.
            cache_data: dict with ``'mol_cache'`` and ``'zeo_cache'`` graph dicts.
            target_scaler: scaler for the targets; a new StandardScaler when None.
            props_scaler: scaler for ``global_attr``; a new StandardScaler when None.
            is_train: when True both scalers are fitted on this split, and
                ``__getitem__`` applies random augmentation.

        Raises:
            ValueError: if ``is_train`` is True and no usable sample was found.
        """
        super().__init__()
        self.target_scaler = target_scaler if target_scaler is not None else StandardScaler()
        self.props_scaler = props_scaler if props_scaler is not None else StandardScaler()
        self.is_train = is_train

        mol_cache = cache_data['mol_cache']
        zeo_cache = cache_data['zeo_cache']

        self.mol_list = []
        self.zeo_list = []
        raw_y_list = []

        for _, row in df.iterrows():
            cid = row['CID']
            topo = row['Topology Code']
            if cid not in mol_cache or topo not in zeo_cache:
                continue
            targets = row[Config.TARGET_COLS].values.astype(float)
            if np.isnan(targets).any():
                continue
            self.mol_list.append(mol_cache[cid])
            self.zeo_list.append(zeo_cache[topo])
            raw_y_list.append(targets)

        # Force a 2-D (samples, targets) layout so an empty split is a valid
        # (0, n_targets) array rather than a 1-D one.
        y_all = np.array(raw_y_list, dtype=float).reshape(-1, len(Config.TARGET_COLS))

        if len(y_all) == 0:
            if is_train:
                raise ValueError("Training split contains no usable samples; cannot fit the scalers")
            y_norm = y_all
        elif is_train:
            y_norm = self.target_scaler.fit_transform(y_all)
        elif hasattr(self.target_scaler, 'mean_'):
            y_norm = self.target_scaler.transform(y_all)
        else:
            # No fitted scaler was supplied: targets stay in their original units.
            y_norm = y_all

        self.y_list = [torch.tensor(y, dtype=torch.float) for y in y_norm]

        if is_train and len(self.mol_list) > 0:
            all_props = torch.cat([m.global_attr for m in self.mol_list], dim=0).numpy()
            self.props_scaler.fit(all_props)

        self.length = len(self.mol_list)
        print(f"Dataset Finish building: {self.length} sample (Train={is_train})")

    def __len__(self):
        """Return the number of samples kept by ``__init__``."""
        return self.length

    def __getitem__(self, idx):
        """Return one sample as ``(mol_data, zeo_data, voxel_tensor, y)``.

        Both graphs are clones of the cached ones, so the cache is never
        modified. ``global_attr`` is standardised when the property scaler has
        been fitted. In training mode a random conformer is chosen, the
        molecule is randomly rotated, and Gaussian noise (std 0.02) is added to
        molecule and zeolite coordinates; otherwise the first conformer is used
        unchanged. ``voxel_tensor`` stacks the molecule grid and the zeolite
        grid along a leading channel axis.
        """
        mol_graph = self.mol_list[idx].clone()
        zeo_graph = self.zeo_list[idx].clone()
        target = self.y_list[idx]

        # Standardise the global molecular descriptors if a fitted scaler exists.
        if hasattr(self.props_scaler, 'mean_'):
            scaled = self.props_scaler.transform(mol_graph.global_attr.numpy())
            mol_graph.global_attr = torch.tensor(scaled, dtype=torch.float)

        # Pick the conformer; the stacked variants are dropped so they are not batched.
        mol_xyz = mol_graph.pos.numpy()
        if hasattr(mol_graph, 'pos_variants'):
            conformers = mol_graph.pos_variants
            chosen = np.random.randint(len(conformers)) if self.is_train else 0
            mol_xyz = conformers[chosen].numpy()
            del mol_graph.pos_variants

        # Voxelise the supercell when present, otherwise the unit cell.
        if hasattr(zeo_graph, 'pos_super'):
            zeo_xyz = zeo_graph.pos_super.numpy()
            del zeo_graph.pos_super
        else:
            zeo_xyz = zeo_graph.pos.numpy()

        if self.is_train:
            # Augmentation: rotate the molecule, then jitter both coordinate sets.
            mol_xyz = np.dot(mol_xyz, get_random_rotation_matrix())
            mol_jitter = np.random.normal(0, 0.02, mol_xyz.shape)
            zeo_jitter = np.random.normal(0, 0.02, zeo_xyz.shape)
            mol_xyz += mol_jitter
            zeo_xyz += zeo_jitter

        mol_graph.pos = torch.tensor(mol_xyz, dtype=torch.float)

        channels = [
            coords_to_voxel(mol_xyz, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA),
            coords_to_voxel(zeo_xyz, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA),
        ]
        voxel_tensor = torch.tensor(np.stack(channels, axis=0), dtype=torch.float)

        return mol_graph, zeo_graph, voxel_tensor, target

    @staticmethod
    def gpu_collate(batch):
        """Collate ``(mol, zeo, voxel, y)`` samples into batched objects.

        Returns:
            ``(mol_batch, zeo_batch, voxels, targets)`` where the graphs are
            merged with ``Batch.from_data_list`` and the tensors are stacked
            along a new leading dimension.
        """
        mols, zeos, voxels, targets = (
            [sample[slot] for sample in batch] for slot in range(4)
        )
        mol_batch = Batch.from_data_list(mols)
        zeo_batch = Batch.from_data_list(zeos)
        return (mol_batch, zeo_batch, torch.stack(voxels), torch.stack(targets))

# 6. Model Definition
class Voxel3DCNN(nn.Module):
    """Three-stage 3D CNN that encodes a two-channel voxel grid into a feature vector.

    Each stage is Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU -> MaxPool3d(2),
    so the spatial size is halved three times before the final linear layer.

    Args:
        grid_size: edge length of the cubic input grid. Must be at least 8.
            The default of 64 reproduces the original fixed layout.
        out_dim: size of the output feature vector (default 128).

    Raises:
        ValueError: if ``grid_size`` is smaller than 8.
    """

    def __init__(self, grid_size=64, out_dim=128):
        super().__init__()
        pooled = grid_size // 8  # three MaxPool3d(2) stages
        if pooled < 1:
            raise ValueError(f"grid_size must be >= 8, got {grid_size}")

        self.conv1 = nn.Conv3d(2, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.pool1 = nn.MaxPool3d(2)

        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.pool2 = nn.MaxPool3d(2)

        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(64)
        self.pool3 = nn.MaxPool3d(2)

        # Flattened size is derived from the grid instead of assuming 64 -> 8.
        self.fc = nn.Linear(64 * pooled ** 3, out_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, 2, G, G, G) into (batch, out_dim)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        return x

class DualBranchGNN(nn.Module):
    def __init__(self):
        """Create the embeddings, both GCN branches, the voxel CNN, the global-feature encoder and the regression head."""
        super().__init__()

        # One embedding table per categorical atom feature (column order of x).
        self.emb_atom = nn.Embedding(120, Config.ATOM_EMBEDDING_DIM)
        self.emb_degree = nn.Embedding(12, Config.EMB_DIM_DEGREE)
        self.emb_charge = nn.Embedding(15, Config.EMB_DIM_CHARGE)
        self.emb_hyb = nn.Embedding(8, Config.EMB_DIM_HYB)
        self.emb_aromatic = nn.Embedding(2, Config.EMB_DIM_AROMATIC)
        self.emb_chiral = nn.Embedding(4, Config.EMB_DIM_CHIRAL)

        total_emb_dim = (Config.ATOM_EMBEDDING_DIM + Config.EMB_DIM_DEGREE +
                         Config.EMB_DIM_CHARGE + Config.EMB_DIM_HYB +
                         Config.EMB_DIM_AROMATIC + Config.EMB_DIM_CHIRAL)

        # Molecule nodes carry one extra continuous channel (partial charge).
        self.mol_conv1 = GCNConv(total_emb_dim + 1, Config.HIDDEN_DIM)
        self.mol_conv2 = GCNConv(Config.HIDDEN_DIM, Config.HIDDEN_DIM)

        self.zeo_conv1 = GCNConv(total_emb_dim, Config.HIDDEN_DIM)
        self.zeo_conv2 = GCNConv(Config.HIDDEN_DIM, Config.HIDDEN_DIM)

        self.voxel_cnn = Voxel3DCNN()

        # 6 tabulated properties + total charge + 10 shape descriptors.
        self.global_feat_dim = 17
        self.global_encoder = nn.Sequential(
            nn.Linear(self.global_feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, Config.HIDDEN_DIM),
            nn.BatchNorm1d(Config.HIDDEN_DIM),
            nn.ReLU()
        )

        # Three branches emit HIDDEN_DIM features; the voxel branch has its own
        # output width, so read it from the layer instead of assuming it
        # equals HIDDEN_DIM.
        fusion_dim = Config.HIDDEN_DIM * 3 + self.voxel_cnn.fc.out_features
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, len(Config.TARGET_COLS))
        )

    def _embed_features(self, x_idx):
        e1 = self.emb_atom(x_idx[:, 0])
        e2 = self.emb_degree(x_idx[:, 1])
        e3 = self.emb_charge(x_idx[:, 2])
        e4 = self.emb_hyb(x_idx[:, 3])
        e5 = self.emb_aromatic(x_idx[:, 4])
        e6 = self.emb_chiral(x_idx[:, 5])
        return torch.cat([e1, e2, e3, e4, e5, e6], dim=1)

    def forward(self, mol_batch, zeo_batch, voxel_batch):
        """Predict all targets for a batch.

        Args:
            mol_batch: batched molecule graphs with ``x``, ``x_charge``,
                ``edge_index``, ``edge_weight``, ``batch`` and ``global_attr``.
            zeo_batch: batched zeolite graphs with ``x``, ``edge_index`` and ``batch``.
            voxel_batch: voxel grids for the 3D CNN.

        Returns:
            A tensor with one row per sample and one column per target.
        """
        # --- Molecule GNN: embeddings plus the partial-charge channel ---
        mol_nodes = torch.cat([self._embed_features(mol_batch.x), mol_batch.x_charge], dim=1)
        mol_edges = mol_batch.edge_index
        mol_weights = mol_batch.edge_weight
        mol_hidden = F.relu(self.mol_conv1(mol_nodes, mol_edges, edge_weight=mol_weights))
        mol_hidden = F.relu(self.mol_conv2(mol_hidden, mol_edges, edge_weight=mol_weights))
        mol_vec = global_mean_pool(mol_hidden, mol_batch.batch)

        # --- Zeolite GNN: unweighted message passing on the neighbour graph ---
        zeo_edges = zeo_batch.edge_index
        zeo_hidden = F.relu(self.zeo_conv1(self._embed_features(zeo_batch.x), zeo_edges))
        zeo_hidden = F.relu(self.zeo_conv2(zeo_hidden, zeo_edges))
        zeo_vec = global_mean_pool(zeo_hidden, zeo_batch.batch)

        # --- 3D CNN on the voxel grids ---
        voxel_vec = self.voxel_cnn(voxel_batch)

        # --- Global molecular descriptors ---
        descriptors = mol_batch.global_attr
        if descriptors.dim() == 3:
            # Drop the singleton middle axis so the encoder sees (batch, features).
            descriptors = descriptors.squeeze(1)
        global_vec = self.global_encoder(descriptors)

        fused = torch.cat([mol_vec, zeo_vec, global_vec, voxel_vec], dim=1)
        return self.head(fused)

# 7. Early stopping and main program
class EarlyStopping:
    """Track the best validation loss and signal when training should stop.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``__call__`` returns True.
        min_delta: minimum decrease of the loss that counts as an improvement.

    Attributes:
        best_loss: lowest validation loss seen so far (None before the first call).
        best_state_dict: an independent copy of the model weights taken when
            ``best_loss`` was recorded.
        counter: number of consecutive calls without improvement.
    """

    def __init__(self, patience=20, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.best_state_dict = None

    @staticmethod
    def _snapshot(model):
        """Return a detached copy of the model's state dict."""
        # state_dict() hands out references to the live parameters; without
        # cloning, the stored "best" weights would follow later training steps.
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    def __call__(self, val_loss, model):
        """Record ``val_loss``; return True when patience is exhausted."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_state_dict = self._snapshot(model)
            return False

        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_state_dict = self._snapshot(model)
            return False

        self.counter += 1
        print(f"precession count: {self.counter}/{self.patience} (Best: {self.best_loss:.6f})")
        return self.counter >= self.patience

if __name__ == "__main__":
    # End-to-end run: load the spreadsheet, build or load the graph cache,
    # split 80/10/10, train with mixed precision, LR scheduling and early
    # stopping, then evaluate the best checkpoint on the test split.
    multiprocessing.freeze_support()

    if not os.path.exists(Config.DATASET_PATH):
        print(f"Error: File does not exist {Config.DATASET_PATH}")
        raise SystemExit(1)

    df = pd.read_excel(Config.DATASET_PATH, engine='openpyxl')

    print(f"Total raw data: {len(df)}")
    topo_counts = df['Topology Code'].value_counts()

    # Drop topologies with too few rows (threshold 0 keeps everything).
    valid_topos = topo_counts[topo_counts >= Config.MIN_SAMPLES_PER_TOPO].index

    df_filtered = df[df['Topology Code'].isin(valid_topos)].reset_index(drop=True)

    print(f"Total amount of data after cleaning: {len(df_filtered)} (decline {len(df) - len(df_filtered)} samples, thresholds={Config.MIN_SAMPLES_PER_TOPO})")
    print(f"Types of zeolites retained: {len(valid_topos)}")

    cache_data = prepare_and_cache_data(df_filtered)

    # 80 % train, then the remaining 20 % split evenly into validation and test.
    indices = list(range(len(df_filtered)))
    train_idx, temp_idx = train_test_split(indices, train_size=0.8, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, train_size=0.5, random_state=42)

    print(f"\nData set segmentation: train {len(train_idx)} | validation {len(val_idx)} | test {len(test_idx)}")

    # Scalers are fitted on the training split only and shared with val/test.
    train_dataset = ZeoliteDataset(df_filtered.iloc[train_idx].reset_index(drop=True), cache_data=cache_data, is_train=True)
    val_dataset = ZeoliteDataset(df_filtered.iloc[val_idx].reset_index(drop=True), cache_data=cache_data,
                                 target_scaler=train_dataset.target_scaler, props_scaler=train_dataset.props_scaler, is_train=False)
    test_dataset = ZeoliteDataset(df_filtered.iloc[test_idx].reset_index(drop=True), cache_data=cache_data,
                                  target_scaler=train_dataset.target_scaler, props_scaler=train_dataset.props_scaler, is_train=False)

    loader_kwargs = {
        'batch_size': Config.BATCH_SIZE,
        'num_workers': Config.NUM_WORKERS,
        # Pinned memory only helps host-to-GPU copies; skip it on CPU-only machines.
        'pin_memory': Config.PIN_MEMORY and torch.cuda.is_available(),
        'collate_fn': ZeoliteDataset.gpu_collate,
        'persistent_workers': Config.NUM_WORKERS > 0
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DualBranchGNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    criterion = nn.MSELoss()

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=Config.SCHEDULER_FACTOR,
        patience=Config.SCHEDULER_PATIENCE,
        min_lr=Config.MIN_LR
    )

    early_stopper = EarlyStopping(patience=Config.EARLY_STOPPING_PATIENCE, min_delta=Config.EARLY_STOPPING_MIN_DELTA)

    # Mixed-precision loss scaling; a no-op when CUDA is unavailable.
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    train_losses, val_losses = [], []

    print(f"\nstart.")

    for epoch in range(Config.EPOCHS):
        # ---- training pass ----
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS}", unit="batch")

        for mol, zeo, voxel, y in pbar:
            mol, zeo, voxel, y = mol.to(device), zeo.to(device), voxel.to(device), y.to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                pred = model(mol, zeo, voxel)
                loss = criterion(pred, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for mol, zeo, voxel, y in val_loader:
                mol, zeo, voxel, y = mol.to(device), zeo.to(device), voxel.to(device), y.to(device)
                with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                    pred = model(mol, zeo, voxel)
                    val_loss += criterion(pred, y).item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        scheduler.step(avg_val_loss)

        if early_stopper(avg_val_loss, model):
            print(f"\n Early Stop ! Best : {early_stopper.best_loss:.4f}")
            break

        # ---- live loss curve ----
        clear_output(wait=True)
        plt.figure(figsize=(10, 5))
        plt.plot(train_losses, label='Train')
        plt.plot(val_losses, label='Val')
        plt.title(f'Loss (Epoch {epoch+1}) - Best Val: {early_stopper.best_loss:.4f}')
        plt.legend()
        plt.grid()
        plt.show()
        plt.close()  # release the figure so one is not kept alive per epoch

        print(f"Epoch {epoch+1}: Train={avg_train_loss:.4f}, Val={avg_val_loss:.4f}, LR={optimizer.param_groups[0]['lr']:.2e}")

    # Evaluate and save the best checkpoint whether training stopped early or
    # ran through every epoch.
    if early_stopper.best_state_dict is not None:
        model.load_state_dict(early_stopper.best_state_dict)

    print("\nFinal evaluation...")
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for mol, zeo, voxel, y in test_loader:
            mol, zeo, voxel, y = mol.to(device), zeo.to(device), voxel.to(device), y.to(device)
            preds.append(model(mol, zeo, voxel).cpu().numpy())
            targets.append(y.cpu().numpy())

    # Report metrics in the original target units.
    y_pred = train_dataset.target_scaler.inverse_transform(np.vstack(preds))
    y_true = train_dataset.target_scaler.inverse_transform(np.vstack(targets))

    print("="*60)
    for i, col in enumerate(Config.TARGET_COLS):
        r2 = r2_score(y_true[:, i], y_pred[:, i])
        mae = np.mean(np.abs(y_true[:, i] - y_pred[:, i]))
        print(f"{col[:25]:25} : R² = {r2:.4f}, MAE = {mae:.4f}")
    print("="*60)

    torch.save(model.state_dict(), "zeolite_3d_gnn_enriched_cleaned.pth")
    print("Model saved")