In [76]:
import pathlib
import sys

import pandas as pd

import chemprop
import chempropstereo

from chemprop import data, models, nn
from lightning import pytorch as pl
import torch

In [77]:
choice = 2
run = 0
mode = "V2"

match choice:
    case 0:
        case = "chemprop"
        featurizer = chemprop.featurizers.SimpleMoleculeMolGraphFeaturizer(
            atom_featurizer=chemprop.featurizers.get_multi_hot_atom_featurizer(
                mode=mode
            )
        )
    case 1:
        case = "chempropstereo_diverge"
        featurizer = chempropstereo.featurizers.MoleculeStereoFeaturizer(
            mode=mode, divergent_bonds=True
        )
    case 2:
        case = "chempropstereo_converge"
        featurizer = chempropstereo.featurizers.MoleculeStereoFeaturizer(
            mode=mode, divergent_bonds=False
        )

### Directories

In [78]:
# The dataset is expected in the current working directory of the kernel.
dataset_dir = cwd = pathlib.Path.cwd()

### Parameters

In [79]:
num_workers: int = 8  # worker processes used by every DataLoader built below

## Prepare Data

In [80]:
input_path = dataset_dir.joinpath("ld_classification_dataset.csv.gz")  # gzip-compressed CSV

### Load data

In [None]:
# Read the gzip-compressed CSV and show it as the cell output for a quick check.
df_input = pd.read_csv(
    input_path,
    compression="gzip",
)
df_input

### Get molecule datapoints

In [82]:
# One datapoint per row: the SMILES string plus its target. `[["sign_rotation"]]`
# keeps the target 2-D, so each row's target is a length-1 array.
all_data = [
    data.MoleculeDatapoint.from_smi(smi, target)
    for smi, target in zip(
        df_input["smiles"].to_numpy(),
        df_input[["sign_rotation"]].to_numpy(),
    )
]

### Perform data splitting for training, validation, and testing

In [None]:
# Rows are assigned to splits by the precomputed `split_{run}` column. Use row
# *positions* (not index labels), because they index into the `all_data` list;
# the two only coincide while `df_input` keeps its default RangeIndex.
split_labels = df_input[f"split_{run}"]
train_indices, val_indices, test_indices = (
    # Single-element outer list: split_data_by_indices receives one index list
    # per replicate (presumably), and there is a single replicate here.
    [(split_labels == split).to_numpy().nonzero()[0].tolist()]
    for split in ("train", "val", "test")
)
assert all(
    indices[0] for indices in (train_indices, val_indices, test_indices)
), f"empty train/val/test split in column split_{run}"

# Alternative (unused): a random 0.8/0.1/0.1 split that keeps each pair of rows
# (2i, 2i + 1) together, via data.make_split_indices(range(len(all_data) // 2),
# "random", (0.8, 0.1, 0.1)) with each pair index i expanded to [2i, 2i + 1].

train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

### Get MoleculeDataset

In [84]:
# Featurize the first (and only) replicate of each split with the chosen featurizer.
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split_points[0], featurizer)
    for split_points in (train_data, val_data, test_data)
)

### Get DataLoader

In [85]:
# The training loader keeps build_dataloader's default `shuffle` setting;
# validation and test loaders explicitly disable shuffling.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

## Message-Passing Neural Network (MPNN)

In [86]:
# Bond message passing sized to the featurizer's atom/bond feature dimensions,
# mean aggregation, and a binary-classification head trained with BCE.
# Sub-modules are constructed in the same order as before (message passing,
# aggregation, predictor, metrics) so RNG-dependent initialisation is unchanged.
mpnn = models.MPNN(
    message_passing=nn.BondMessagePassing(
        d_v=featurizer.atom_fdim,
        d_e=featurizer.bond_fdim,
    ),
    agg=nn.MeanAggregation(),
    predictor=nn.BinaryClassificationFFN(criterion=nn.BCELoss()),
    batch_norm=False,
    metrics=(
        nn.BinaryAUROC(),  # area under the ROC curve
        nn.BinaryMCCMetric(),  # Matthews correlation coefficient
        nn.BinaryF1Score(),  # harmonic mean of precision and recall
        nn.BinaryAccuracy(),  # fraction of correct predictions
    ),
)

## Training

In [None]:
# Optional speed-up on GPUs that support it (trades float32 matmul precision):
# torch.set_float32_matmul_precision("medium")

max_epochs = 50
patience = 5  # epochs without val_loss improvement before early stopping

# Keep the best (lowest val_loss) and the last checkpoint in a directory that is
# specific to the featurizer variant, featurizer mode and run.
checkpointing = pl.callbacks.ModelCheckpoint(
    dirpath=f"checkpoints/{case}/{mode}/run{run}",
    filename="best-{epoch}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)

early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    # Use whatever accelerator is available (e.g. a GPU) instead of hard-failing
    # on machines without one.
    accelerator="auto",
    devices=1,
    max_epochs=max_epochs,
    callbacks=[checkpointing, early_stopping],
)

trainer.fit(mpnn, train_loader, val_loader)

## Test results

In [None]:
# Evaluate the best checkpoint (lowest val_loss) saved during fit, not the
# last-epoch weights currently in `mpnn`, which, with early stopping, are
# several epochs past the best validation loss.
results = trainer.test(mpnn, dataloaders=test_loader, ckpt_path="best")

In [None]:
# `results` is a list with one metrics dict per test dataloader; show it as a table.
pd.DataFrame(results)