In [None]:
# ==================================================
# Import
# ==================================================

import pandas as pd
import numpy as np
import os
import json
from sklearn.metrics import root_mean_squared_error, make_scorer
from sklearn.model_selection import KFold, cross_validate, GridSearchCV
from sklearn.base import clone
import optuna
from lightgbm import LGBMRegressor
from sklearn.ensemble import RandomForestRegressor
import torch
from torch.utils.data import DataLoader
from argparse import Namespace
from sklearn.linear_model import Ridge
import joblib

# Feature Extractors

In [None]:
# ==================================================
# Feature Extractor: Baseline Descriptors
# ==================================================

def DescriptorFarm(df: pd.DataFrame) -> pd.DataFrame:
    from rdkit import Chem
    from rdkit.Chem import AllChem, Descriptors
    from rdkit.Chem.Descriptors3D import descList as _desc3d_list
    from rdkit.Chem.Scaffolds import MurckoScaffold

    # 2D Descriptors
    desc2d = [(n, f) for n, f in Descriptors._descList if not n.startswith('fr_')]
    names2d = {n for n, _ in desc2d}

    # 3D Descriptors
    desc3d = [(n, f) for n, f in _desc3d_list if not n.startswith('fr_') and n not in names2d]

    # Significant SMARTS Patterns
    smarts_defs = {
        'Imidazole': ['c1ncnc1'],
        'Pyrazole': ['n1nccc1'],
        'Thiazole': ['c1cscn1'],
        'Triazole': ['n1nncc1', 'c1ncnn1'],
        'Toluene': ['Cc1ccccc1'],
        'N-Ethyllformamide': ['O=CNCC[aR]'],
        'Amino_Arylmethane': ['[aR]CN'],
        'N-Phenethylformamide': ['O=CNCCc1ccccc1'],
        'Carbamate': ['OC(=O)N'],
        'Benzodioxole': ['c1cc2OCOc2cc1'],
        'Furan': ['c1ccoc1'],
        'Terminal_Alkyne': ['[C]#[CH]'],
        'Primary_Amine': ['[CX4][NH2]'],
        'Sec_Cyclic_Amine': ['[CX4;R][NH1;R][CX4;R]'],
        'Cyclopropyl_Amine': [
        'C(C)(C)(C)N(C1CC1)C(C)(C)(C)',
           'C(=C)(C)N(C1CC1)C(=C)(C)',
           'C(#C)N(C1CC1)C(#C)',
           '[H]N(C1CC1)[H]'
        ],
        'Hydroquinone':[
            'Oc1ccc(O)cc1',
            'Oc1ccccc1(O)',
            'O=c1ccc(=O)cc1',
            'O=c1ccccc1(=O)'
        ],
        'Epoxide': ['C1OC1'],
        'Tertiary_Amine': ['N1([CX4])CCN(CC1)[CX4]'],
        'Alkylphenol': [
            '[OH]c1ccc(C([C,H])([C,H])[C,H])cc1',
            '[OH]c1ccccc1(C([C,H])([C,H])[C,H])'
        ],
        'Alkylaromatic_Ether': [
            'c1c([CX4])cc[c](O[CX4])c1',
            'c1cc[c](O[CX4])c([CX4])c1'
        ],
        'Arenes':[
            'c1ccc2c(c1)CCCO2',
            '[C,H]N(C(c1ccccc1)c2ccccc2)[C,H]',
            'c1ccccc1[CX4]C=C(C)(C)',
            'c1ccccc1[CX4]C=C(=C)'
        ],
        'Alkoxybenzene': ['c1ccc(OC)cc1'],
    }
    smarts_patterns = {
        name: [Chem.MolFromSmarts(s) for s in smarts_defs[name]]
        for name in smarts_defs
    }

    # Row-Wise Computation
    records = []
    for smi in df['Canonical_Smiles']:
        rec = {}
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            for name, _ in desc2d + desc3d:
                rec[name] = None
            for col in [
                'Murcko_num_rings',
                'Murcko_max_ring_size',
                'Murcko_min_ring_size',
                'Murcko_mean_ring_size',
                'Murcko_heavy_atom_count',
                'Murcko_heteroatom_count',
                'Murcko_mol_wt'
            ]:
                rec[col] = None
            for name in smarts_patterns:
                rec[name] = None
        else:
            # 2D Descriptors
            for name, func in desc2d:
                try:
                    rec[name] = func(mol)
                except:
                    rec[name] = None

            # 3D Descriptors
            mol3d = Chem.AddHs(mol)
            try:
                AllChem.EmbedMolecule(mol3d, AllChem.ETKDG())
                AllChem.MMFFOptimizeMolecule(mol3d)
                for name, func in desc3d:
                    try:
                        rec[name] = func(mol3d)
                    except:
                        rec[name] = None
            except:
                for name, _ in desc3d:
                    rec[name] = None

            # Murcko Scaffold
            try:
                scaffold = MurckoScaffold.GetScaffoldForMol(mol)
                ri = scaffold.GetRingInfo()
                ring_sizes = [len(r) for r in ri.AtomRings()]
                num_rings = len(ring_sizes)
                rec['Murcko_num_rings'] = num_rings
                rec['Murcko_max_ring_size'] = max(ring_sizes) if ring_sizes else 0
                rec['Murcko_min_ring_size'] = min(ring_sizes) if ring_sizes else 0
                rec['Murcko_mean_ring_size'] = sum(ring_sizes)/num_rings if num_rings else 0
                rec['Murcko_heavy_atom_count'] = scaffold.GetNumHeavyAtoms()
                heteros = sum(1 for atom in scaffold.GetAtoms()
                             if atom.GetAtomicNum() not in (6,1))
                rec['Murcko_heteroatom_count'] = heteros
                rec['Murcko_mol_wt'] = Descriptors.MolWt(scaffold)
            except Exception:
                for col in ['Murcko_num_rings',
                            'Murcko_max_ring_size',
                            'Murcko_min_ring_size',
                            'Murcko_mean_ring_size',
                            'Murcko_heavy_atom_count',
                            'Murcko_heteroatom_count',
                            'Murcko_mol_wt']:
                    rec[col] = None

            # Matching SMARTS Patterns
            for name, patterns in smarts_patterns.items():
                try:
                    rec[name] = any(mol.HasSubstructMatch(pat) for pat in patterns)
                except:
                    rec[name] = None

        records.append(rec)

    # Result
    desc_df = pd.DataFrame(records, index = df.index)

    return pd.concat([df, desc_df], axis = 1)

In [None]:
# ==================================================
# Feature Extractor: Fingerprints
# ==================================================

def FingerprintFarm(
        df: pd.DataFrame,
        smiles_col: str = 'Canonical_Smiles',
        fp: str | None = None
) -> pd.DataFrame:
    import os
    from rdkit import Chem, RDConfig
    from rdkit.Chem import rdMolDescriptors, rdFingerprintGenerator
    from rdkit.Chem.rdmolops import PatternFingerprint
    from rdkit.Chem.Pharm2D.SigFactory import SigFactory
    from rdkit.Chem.Pharm2D import Generate as Pharm2DGen
    from rdkit.Chem.rdReducedGraphs import GetErGFingerprint
    from rdkit.Chem.rdMHFPFingerprint import MHFPEncoder
    from rdkit.Chem import ChemicalFeatures

    # Fingerprint Generator
    fp_name = None
    fp_gen = None
    if fp:
        nm = fp.lower()
        if nm == 'rdkit':
            fp_name, fp_gen = 'rdkit', rdFingerprintGenerator.GetRDKitFPGenerator(fpSize = 2048)
        elif nm == 'atompairs':
            fp_name, fp_gen = 'atompairs', rdFingerprintGenerator.GetAtomPairGenerator(fpSize = 2048)
        elif nm == 'topologicaltorsions':
            fp_name, fp_gen = 'topologicaltorsions', rdFingerprintGenerator.GetTopologicalTorsionGenerator(fpSize = 2048)
        elif nm == 'ecfp4':
            fp_name, fp_gen = 'ecfp4', rdFingerprintGenerator.GetMorganGenerator(radius = 2, fpSize = 2048)
        elif nm == 'ecfp6':
            fp_name, fp_gen = 'ecfp6', rdFingerprintGenerator.GetMorganGenerator(radius = 3, fpSize = 2048)
        elif nm == 'maccs':
            fp_name = 'maccs'
        elif nm == 'pattern':
            fp_name = 'pattern'
        elif nm == '2dpharmacophore':
            fdef = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
            featFactory = ChemicalFeatures.BuildFeatureFactory(fdef)
            sigFactory = SigFactory(featFactory, minPointCount = 2, maxPointCount = 3, trianglePruneBins = False)
            sigFactory.SetBins([(0, 2), (2, 5), (5, 8)])
            sigFactory.Init()
            fp_name, fp_gen = '2dpharmacophore', sigFactory
        elif nm == 'erg':
            fp_name = 'erg'
        elif nm == 'mhfp':
            fp_name, fp_gen = 'mhfp', MHFPEncoder()
        elif nm == 'secfp':
            fp_name, fp_gen = 'secfp', MHFPEncoder()

    # Row-Wise Computation
    records = []
    for smi in df[smiles_col]:
        rec = {}
        mol = Chem.MolFromSmiles(smi)
        if fp_name is None or mol is None:
            if fp_name:
                for i in range(2048):
                    rec[f'{fp_name}_{i}'] = None
        else:
            if fp_name in ('rdkit', 'atompairs', 'topologicaltorsions', 'ecfp4', 'ecfp6'):
                bv = fp_gen.GetFingerprint(mol)
                bits = [int(b) for b in bv.ToBitString()]
            elif fp_name == 'maccs':
                bv = rdMolDescriptors.GetMACCSKeysFingerprint(mol)
                bits = [int(b) for b in bv.ToBitString()]
            elif fp_name == 'pattern':
                bv = PatternFingerprint(mol)
                bits = [int(b) for b in bv.ToBitString()]
            elif fp_name == '2dpharmacophore':
                bv = Pharm2DGen.Gen2DFingerprint(mol, fp_gen)
                bits = [int(b) for b in bv.ToBitString()]
            elif fp_name == 'erg':
                try:
                    bv = GetErGFingerprint(mol)
                    bits = list(bv)
                except KeyError:
                    bits = [None] * 2048
            elif fp_name == 'mhfp':
                bv = fp_gen.EncodeMol(mol)
                bits = list(bv)
            elif fp_name == 'secfp':
                bv = fp_gen.EncodeSECFPMol(mol)
                bits = list(bv)
            else:
                bits = [0] * 2048
            for i, bit in enumerate(bits):
                rec[f'{fp_name}_{i}'] = bit
        records.append(rec)

    fp_df = pd.DataFrame(records, index = df.index)
    return pd.concat([df, fp_df], axis = 1)

# Models

In [None]:
# ==================================================
# Trained Models & Ensemble Weights
# ==================================================

# Models Dictionary
# Mapping of key -> fitted estimator; later cells iterate its keys and call
# `.predict` on every value.
md = joblib.load('model_dict.joblib')

# Ensemble Weights
# Only the JSON values are kept, in file order; the keys are discarded.
# NOTE(review): the weights are later applied positionally to the columns built
# from `md` - presumably both share the same ordering; confirm.
with open('params/params_ens.json', 'r') as f:
    w = np.array([weight for weight in json.load(f).values()])

# Preprocessing

In [None]:
# ==================================================
# Preprocessing: LGBM
# ==================================================

# Fingerprints List
# One feature set is built per fingerprint family, next to the 'base' set.
fp_list = [
    'rdkit',
    'atompairs',
    'topologicaltorsions',
    'ecfp4',
    'ecfp6',
    'maccs',
    'pattern',
    '2dpharmacophore',
    'erg',
    'mhfp',
    'secfp',
]

# Identifier / target columns that are not model inputs
train_non_features = ['ID', 'Canonical_Smiles', 'Inhibition']
test_non_features = ['ID', 'Canonical_Smiles']

# Train Data: Baseline Descriptors
train = DescriptorFarm(pd.read_csv('train.csv'))

# Train Data: Inputs Dictionary
# Each fingerprint set is computed on the descriptor-augmented frame, so it
# carries the baseline descriptor columns as well.
x = {'base': train.drop(columns = train_non_features)}
x.update({
    fp_key: FingerprintFarm(train, fp = fp_key).drop(columns = train_non_features)
    for fp_key in fp_list
})

# Test Data: Baseline Descriptors
test = DescriptorFarm(pd.read_csv('test.csv'))

# Test Data: Inputs Dictionary
x_test = {'base': test.drop(columns = test_non_features)}
x_test.update({
    fp_key: FingerprintFarm(test, fp = fp_key).drop(columns = test_non_features)
    for fp_key in fp_list
})

# Constant Columns Removal & Columns Matching
# Keep only the columns with more than one distinct value in the TRAIN frame,
# and select exactly those columns (same order) from the test frame.
for set_key, train_features in x.items():
    varying_cols = train_features.columns[train_features.nunique() > 1]
    x_test[set_key] = x_test[set_key][varying_cols]

# The train features were only needed to derive the column selection
del x

In [None]:
# ==================================================
# Preprocessing: RF
# ==================================================

# Inputs Dictionary Update
# Every model key containing '_additional' receives its own copy of the feature
# frame stored under the same key with '_additional' removed.
for model_key in md.keys():
    if '_additional' not in model_key:
        continue
    source_key = model_key.replace('_additional', '')
    x_test[model_key] = x_test[source_key].copy()

In [None]:
# ==================================================
# Preprocessing: GROVER (https://github.com/tencent-ailab/grover) + Ridge
# ==================================================

# Purpose: run the pretrained GROVER network over the test SMILES and store the
# resulting embeddings as the 'grover' feature set in `x_test`.
from grover.model.models import GroverFpGeneration
from grover.data.molgraph import MolCollator

# `MolCollator`-Compatible Class
# The SMILES are re-read from disk, so this cell does not depend on the
# descriptor-augmented `test` frame built earlier.
smiles_list = pd.read_csv('test.csv')['Canonical_Smiles'].to_list()

class Record:
    # Minimal stand-in for a dataset item: carries only the three attributes set
    # below (presumably the ones `MolCollator` reads - TODO confirm against GROVER).
    __slots__ = ("smiles", "features", "targets")
    def __init__(self, s):
        # No extra features and a single empty target: this is inference only
        self.smiles, self.features, self.targets = s, None, [None]

records = [Record(s) for s in smiles_list]

# First argument is presumably the collator's shared cache dict (it is reset via
# `shared_dict` in the loop below); the Namespace supplies the options it reads.
collator = MolCollator({}, Namespace(bond_drop_rate = 0, no_cache = True))

# Data Loader
# shuffle = False keeps the embeddings in the same order as `smiles_list`
loader = DataLoader(
    records,
    batch_size = 128,
    shuffle = False,
    collate_fn = collator
)

# Pretrained GROVER
# NOTE(review): weights_only = False unpickles arbitrary Python objects - only
# load checkpoints from a trusted source.
grover_large = torch.load(
    'grover/grover_large.pt', map_location = 'cpu', weights_only = False
)
grover_state = grover_large['state_dict']

# Additional Arguments Required
# Start from the arguments stored in the checkpoint and set the fields below.
grover_args = grover_large['args']
grover_args.cuda = torch.cuda.is_available()
grover_args.dropout = 0.1
grover_args.fingerprint_source = 'both'

# Model Definition
# NOTE(review): strict = False ignores missing/unexpected keys silently; the
# return value (which lists them) is not inspected here.
grover = GroverFpGeneration(grover_args)
grover.load_state_dict(grover_state, strict = False)

# Device Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Embedding Extraction
# eval() + no_grad(): inference mode, no gradient bookkeeping
grover.to(device).eval()
emb_list = []

with torch.no_grad():
    # Only the second element of each collated batch is used
    for _, graph_components, *_ in loader:
        # Move tensor components to the target device; leave the rest untouched
        graph_components = tuple(
            t.to(device) if torch.is_tensor(t) else t
            for t in graph_components
        )
        
        emb = grover(graph_components, [None]).cpu().numpy()
        emb_list.append(emb)
        
        # Resource Optimization
        # Drop per-batch references, reset the collator's shared dict and
        # release cached CUDA memory before the next batch
        del emb, graph_components
        loader.collate_fn.shared_dict = {}
        torch.cuda.empty_cache()

# Stack the per-batch arrays along axis 0 into one frame (default integer
# column labels and RangeIndex)
grover_test = pd.DataFrame(np.concatenate(emb_list, axis = 0))

# Inputs Dictionary Update
x_test['grover'] = grover_test

# Inference

In [None]:
# ==================================================
# Inference
# ==================================================

# Meta Features
# One column of base-model predictions per entry of `md`, each model being fed
# the feature set stored under the same key in `x_test`.
meta_x_test = pd.DataFrame({
    model_key: md[model_key].predict(x_test[model_key])
    for model_key in md.keys()
})

# Prediction
# Weighted average of the base-model predictions. `w` is a plain array, so the
# weights are matched to the columns by POSITION, not by name.
# NOTE(review): presumably the weight file lists the models in the same order
# as `md` - confirm.
y_pred = meta_x_test.mul(w, axis = 1).sum(axis = 1) / w.sum()

# Submission
# Predictions are clipped to the [0, 100] range before being written.
submission = pd.read_csv('test.csv')
submission['Inhibition'] = np.clip(y_pred, 0, 100)
# index = False: do not write the RangeIndex as an extra unnamed first column
submission.to_csv('submission.csv', index = False)