# 02 Basic classifier


## Data and processing

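Load the training and test Parquet files as Arrow datasets and build one scanner per value of the `binds` column (0 = no bind, 1 = bind). Then sample random train/validation row indices separately for each class.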

In [1]:
import pyarrow.dataset as ds
import pyarrow.compute as pc

# Parquet files, given relative to the notebook's working directory.
PATH_TRAIN_DATA = "../../../data/train.parquet"
PATH_TEST_DATA = "../../../data/test.parquet"

# Arrow datasets over the train and test files (`source` passed positionally).
DATA = ds.dataset(PATH_TRAIN_DATA, format="parquet")
DATA_TEST = ds.dataset(PATH_TEST_DATA, format="parquet")

In [2]:
# One scanner per class of the "binds" column: value 0 first, then value 1.
# NOTE(review): if the dataset stores several targets (e.g. one row per
# protein), these filters pool all of them together — confirm this is intended.
scanner_no_bind, scanner_bind = (
    DATA.scanner(filter=pc.field("binds") == label) for label in (0, 1)
)

In [3]:
import numpy as np
import numpy.typing as npt

In [4]:
def split_indices(
    n_rows, train_split: float = 0.8, size: str | None = None
) -> (npt.NDArray[np.uint64], npt.NDArray[np.uint64]):
    """
    Splits the indices of rows into training and validation sets.

    Args:
        n_rows: The total number of rows to generate indices for.
        train_split: The proportion of indices to allocate to the training set.
        size: If provided, trims the number of indices to this size.

    Returns:
        A tuple containing two arrays:

            - train_indices: Indices for the training set.
            - val_indices: Indices for the validation set.
    """

    # Generate indices and shuffle them in place.
    indices = np.arange(n_rows)
    np.random.shuffle(indices)
    
    # Trim the number of indices to size.
    indices = indices[None:size]

    # Split indices into training and validation sets
    train_size = int(indices.shape[0] * train_split)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    return train_indices, val_indices

In [5]:
# Hard-coded number of rows in each class.
# NOTE(review): presumably equal to scanner_no_bind.count_rows() and
# scanner_bind.count_rows(); re-verify if the data files change.
N_NO_BIND = 293_656_924
N_BIND = 1_589_906

In [6]:
TRAIN_SPLIT = 0.8
# Number of rows sampled from each class before the train/validation split.
SIZE = 50000
# Fixed seed so the sampled train/validation rows are reproducible across runs.
SEED = 42

# split_indices draws from NumPy's global RNG, so seed it right before sampling.
np.random.seed(SEED)

train_indices_no_bind, valid_indices_no_bind = split_indices(
    N_NO_BIND, TRAIN_SPLIT, size=SIZE
)
train_indices_bind, valid_indices_bind = split_indices(
    N_BIND, TRAIN_SPLIT, size=SIZE
)


## Features

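Clean each molecule's SMILES string by removing the `[Dy]` atom, keeping the largest fragment and canonicalising it. Then encode the cleaned molecule as a Morgan fingerprint bit vector (radius 3, 2048 bits by default).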

In [7]:
from rdkit import Chem
from rdkit.Chem import AllChem
from concurrent.futures import ProcessPoolExecutor

In [8]:
def clean_smiles(smiles):
    """
    Normalise a SMILES string to the canonical SMILES of its largest fragment.

    Steps: strip every literal ``[Dy]`` token, parse with RDKit, remove
    explicit hydrogens, keep the fragment with the most atoms, and write it
    back out as canonical SMILES.

    Args:
        smiles: Input SMILES string.

    Returns:
        Canonical SMILES of the largest fragment.

    Raises:
        ValueError: If RDKit cannot parse the string after ``[Dy]`` removal,
            or the parsed molecule has no atoms.
    """
    # Drop the [Dy] atom. NOTE(review): this is a plain text substitution; if
    # [Dy] is ever the sole content of a branch, e.g. "C([Dy])C", the result
    # "C()C" may not parse — confirm against the data.
    smiles = smiles.replace("[Dy]", "")

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    # An empty string parses to a zero-atom molecule, which would otherwise
    # silently become an all-zero fingerprint downstream.
    if mol.GetNumAtoms() == 0:
        raise ValueError(f"SMILES has no atoms: {smiles!r}")

    # Remove explicit hydrogens.
    mol = Chem.RemoveHs(mol)

    # Disconnected fragments (e.g. salts/counter-ions) are split apart; keep
    # the one with the most atoms.
    fragments = Chem.GetMolFrags(mol, asMols=True)
    largest_fragment = max(fragments, default=mol, key=lambda m: m.GetNumAtoms())

    # No 2D coordinates are needed to write canonical SMILES, so the former
    # Compute2DCoords call (pure overhead per molecule) is gone.
    return Chem.MolToSmiles(largest_fragment, canonical=True)

In [9]:
def get_mol(smiles, optimize=False, random_seed=-1):
    """
    Build an RDKit molecule with explicit hydrogens and an embedded 3D conformer.

    Args:
        smiles: SMILES string to parse.
        optimize: If True, refine the embedded geometry with MMFF
            (at most 200 iterations).
        random_seed: Seed passed to ``AllChem.EmbedMolecule``. The default -1
            is RDKit's own default (no fixed seed), matching the previous
            behaviour; pass a non-negative value for reproducible coordinates.

    Returns:
        The hydrogen-added molecule with its embedded conformer.

    Raises:
        ValueError: If the SMILES cannot be parsed or embedding fails.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Without this check AddHs fails with an opaque Boost argument error.
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule signals failure by returning -1 instead of raising; a mol
    # without a conformer would otherwise be returned (or crash MMFF below).
    if AllChem.EmbedMolecule(mol, randomSeed=random_seed) == -1:
        raise ValueError(f"Failed to embed molecule: {smiles!r}")
    if optimize:
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
    return mol

In [10]:
def get_features(smiles: str, radius: int = 3, nBits: int = 2048):
    """
    Encode a SMILES string as a Morgan fingerprint bit vector.

    Args:
        smiles: SMILES string to featurise.
        radius: Morgan fingerprint radius.
        nBits: Length of the folded bit vector.

    Returns:
        A 1-D NumPy array built from the ``nBits``-long fingerprint bit vector.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Passing None to RDKit raises a Boost.Python ArgumentError, which
        # callers such as process_batch (catching only ValueError) would miss.
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    # NOTE(review): GetMorganFingerprintAsBitVect is deprecated in recent RDKit
    # releases in favour of rdFingerprintGenerator.GetMorganGenerator; kept
    # here to preserve the exact output and dtype.
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits)
    return np.array(fingerprint)

In [11]:
def process_row(row):
    """Return the fingerprint of ``row["molecule_smiles"]`` after cleaning it.

    Any exception raised while cleaning or featurising (e.g. the ValueError
    clean_smiles raises for an unparsable SMILES) propagates to the caller.
    """
    return get_features(clean_smiles(row["molecule_smiles"]))

In [12]:
def process_batch(batch):
    """Fingerprint every row in ``batch``, skipping rows that raise ValueError.

    Returns a NumPy array of the successful fingerprints in input order; when
    no row succeeds this is an empty array.
    """
    kept = []
    for record in batch:
        try:
            kept.append(process_row(record))
        except ValueError:
            # Unparsable molecules are dropped rather than aborting the batch.
            pass
    return np.array(kept)

In [13]:
def compute_fingerprints(scanner, indices, chunk_size=1000, n_batches=8):
    """
    Compute fingerprints for the rows of ``scanner`` at the given positions.

    Rows are fetched once with ``Scanner.take`` and then processed in chunks
    of ``chunk_size`` rows. Each chunk is split into at most ``n_batches``
    non-empty batches that ``process_batch`` handles in parallel worker
    processes. Rows that fail with ``ValueError`` are dropped inside
    ``process_batch``, so the result can have fewer rows than ``indices``.

    Args:
        scanner: A ``pyarrow.dataset.Scanner`` whose rows have a
            ``molecule_smiles`` column.
        indices: Row positions (within the scanner's output) to featurise.
        chunk_size: Number of rows dispatched to the worker pool per round.
        n_batches: Maximum number of batches each chunk is split into.

    Returns:
        A 2-D array with one fingerprint per successfully processed row, or an
        empty 1-D array if there were no indices or no row succeeded.

    Raises:
        ValueError: If ``chunk_size`` or ``n_batches`` is smaller than 1.

    NOTE(review): ProcessPoolExecutor must pickle ``process_batch``; under the
    "spawn" start method (the default on Windows and macOS) functions defined
    inside a notebook may not be importable by the workers.
    """
    if chunk_size < 1 or n_batches < 1:
        raise ValueError("chunk_size and n_batches must be >= 1")
    if len(indices) == 0:
        return np.array([])

    # A single take() replaces the old per-row `scanner.to_table().slice(...)`,
    # which materialised the entire filtered dataset once for every index.
    table = scanner.take(indices)

    results = []
    with ProcessPoolExecutor() as executor:
        for chunk_start in range(0, table.num_rows, chunk_size):
            rows = table.slice(chunk_start, chunk_size).to_pylist()
            # Contiguous, non-empty batches. np.array_split could produce empty
            # batches for chunks shorter than n_batches, and their empty 1-D
            # results broke the concatenation of 2-D fingerprint arrays.
            batch_len = -(-len(rows) // n_batches)  # ceiling division
            batches = [rows[i:i + batch_len] for i in range(0, len(rows), batch_len)]
            for batch_result in executor.map(process_batch, batches):
                # A batch in which every row failed comes back empty; skip it.
                if batch_result.size:
                    results.append(batch_result)

    if not results:
        return np.array([])
    return np.concatenate(results)

: 

In [14]:
# Non-binding molecules (binds == 0): train split, then validation split.
no_bind_fingerprints_train = compute_fingerprints(
    scanner=scanner_no_bind, indices=train_indices_no_bind
)
no_bind_fingerprints_val = compute_fingerprints(
    scanner=scanner_no_bind, indices=valid_indices_no_bind
)

# Binding molecules (binds == 1): train split, then validation split.
bind_fingerprints_train = compute_fingerprints(
    scanner=scanner_bind, indices=train_indices_bind
)
bind_fingerprints_val = compute_fingerprints(
    scanner=scanner_bind, indices=valid_indices_bind
)

In [None]:
# Stack the per-class fingerprints: non-binding rows first, then binding rows.
train_data = np.concatenate([no_bind_fingerprints_train, bind_fingerprints_train])
val_data = np.concatenate([no_bind_fingerprints_val, bind_fingerprints_val])

# Matching 0/1 labels (0 = binds == 0, 1 = binds == 1) in the same row order.
# They are sized from the fingerprint arrays themselves, not from the sampled
# indices, because rows with invalid SMILES are dropped during featurisation.
train_labels = np.concatenate([
    np.zeros(len(no_bind_fingerprints_train), dtype=np.int64),
    np.ones(len(bind_fingerprints_train), dtype=np.int64),
])
val_labels = np.concatenate([
    np.zeros(len(no_bind_fingerprints_val), dtype=np.int64),
    np.ones(len(bind_fingerprints_val), dtype=np.int64),
])

## Training