In [1]:
# Train a DMPNN graph regressor on SMILES-derived molecular graphs to predict the "rt" column
import optuna
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error, r2_score, make_scorer
from torch_geometric.utils import from_smiles
import pandas as pd
from pyg_chemprop_utils import smiles2data


dataset_clean = pd.read_csv("dataset/SMRT_dataset_with_dates.csv")
# Randomly split the dataset into training and validation sets (80/20, fixed seed)
train_set, valid_set = train_test_split(dataset_clean, test_size=0.2, random_state=42)

# Build graph dataset from train/valid splits
smiles_col = "SMILES"  # change if your column name differs
rt_col = "rt"          # change if your column name differs


def _frame_to_graphs(frame, smiles_column, target_column):
    """Turn each row of ``frame`` into a graph object labelled with its target.

    Args:
        frame: DataFrame holding at least ``smiles_column`` and ``target_column``.
        smiles_column: name of the column with the SMILES strings.
        target_column: name of the column with the regression target.

    Returns:
        list: one graph per row for which ``smiles2data`` produced a result,
        each with the row's target stored under the ``'y'`` key.
    """
    graphs = []
    for smiles, target in zip(frame[smiles_column], frame[target_column]):
        data = smiles2data(smiles)
        # The None check must come before the label assignment: the previous
        # order assigned data['y'] first, which raises TypeError on None and
        # made the guard unreachable for exactly the case it was meant for.
        # NOTE(review): presumably smiles2data returns None for SMILES it
        # cannot convert -- confirm against pyg_chemprop_utils.
        if data is None:
            continue
        data['y'] = target
        graphs.append(data)
    return graphs


train_graphs = _frame_to_graphs(train_set, smiles_col, rt_col)
valid_graphs = _frame_to_graphs(valid_set, smiles_col, rt_col)

len(train_graphs), len(valid_graphs)

(64030, 16008)

In [2]:
from torch_geometric.data import InMemoryDataset
import torch

class MyDataset(InMemoryDataset):
    """In-memory dataset built directly from an already-prepared list of graphs.

    The graphs are collated once in the constructor and stored on the
    instance as ``data`` / ``slices``.
    """

    def __init__(self, data_list, transform=None):
        """Collate ``data_list`` into this dataset.

        Args:
            data_list: list of graph objects to store.
            transform: optional transform forwarded to the base class.
        """
        # "." is passed as the dataset root directory.
        super().__init__(root=".", transform=transform)
        collated = self.collate(data_list)
        self.data, self.slices = collated

    def _download(self):
        """No-op override.

        NOTE(review): presumably this stops the base class from running a
        download step -- confirm against torch_geometric.
        """
        return None

    def _process(self):
        """No-op override.

        NOTE(review): presumably this stops the base class from running a
        processing step -- confirm against torch_geometric.
        """
        return None

In [71]:
from pyg_chemprop import DMPNNEncoder, RevIndexedData

# Wrap every graph in RevIndexedData before batching.
# NOTE: both names are rebound in place, so running this cell a second time
# would wrap the already-wrapped graphs again -- re-run the graph-building
# cell first if this one needs to be repeated.
train_graphs = list(map(RevIndexedData, train_graphs))
valid_graphs = list(map(RevIndexedData, valid_graphs))

In [72]:
# Display the first wrapped training graph to check which fields it carries.
train_graphs[0]

RevIndexedData(y=865.2, edge_index=[2, 104], x=[49, 133], edge_attr=[104, 14], revedge_index=[104])

In [73]:
from torch_geometric.loader import DataLoader
# Number of graphs per mini-batch, shared by both loaders.
batch_size = 512

# Training batches are reshuffled on every pass; validation keeps a fixed
# order so predictions line up the same way on every evaluation.
train_loader = DataLoader(
    dataset=train_graphs,
    batch_size=batch_size,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=valid_graphs,
    batch_size=batch_size,
    shuffle=False,
)

In [74]:
from tqdm import tqdm
from pyg_chemprop_utils import initialize_weights
def train(config, loader, device=torch.device("cpu")):
    """Run one optimisation pass (one epoch) over ``loader``.

    Args:
        config: dict with the keys ``"loss"`` (criterion), ``"model"``,
            ``"optimizer"`` and ``"scheduler"``.
        loader: sized iterable of batches; each batch must support
            ``.to(device)`` and expose its targets as ``.y``.
        device: device the model and the batches are moved to.

    Returns:
        float: mean of the per-batch loss values seen during the pass, or
        ``nan`` when ``loader`` yielded no batches.
    """
    criterion = config["loss"]
    model = config["model"]
    optimizer = config["optimizer"]
    scheduler = config["scheduler"]

    model = model.to(device)
    model.train()

    loss_total = 0.0
    batches_seen = 0
    for batch in tqdm(loader, total=len(loader)):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        # Drop only the trailing output axis. A bare squeeze() would also
        # collapse the batch axis for a one-sample batch, leaving a 0-d
        # prediction to be broadcast against a 1-d target.
        loss = criterion(out.squeeze(-1), batch.y.float())
        loss.backward()
        optimizer.step()
        # The scheduler is advanced once per batch, not once per epoch.
        scheduler.step()

        # Accumulate a detached value so the graph is not kept alive and the
        # conversion to a Python number happens only once, after the loop.
        loss_total = loss_total + loss.detach()
        batches_seen += 1

    if batches_seen == 0:
        return float("nan")
    return float(loss_total) / batches_seen
def make_prediction(config, loader, device=torch.device("cpu")):
    """Run the model over ``loader`` in eval mode and collect its outputs.

    Args:
        config: dict holding the model under the key ``"model"``.
        loader: sized iterable of batches; each batch must support
            ``.to(device)`` and expose its targets as ``.y``.
        device: device the model and the batches are moved to.

    Returns:
        tuple: ``(y_pred, y_true)`` -- predictions first, targets second --
        both as CPU tensors concatenated along the first axis in loader order.
    """
    model = config["model"]

    model = model.to(device)
    model.eval()

    pred_chunks = []
    true_chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader)):
            batch = batch.to(device)
            # Keep one tensor per batch. Extending a list with a tensor
            # iterates it row by row in Python and leaves thousands of tiny
            # tensors to stack afterwards.
            pred_chunks.append(model(batch).cpu())
            true_chunks.append(batch.y.cpu())
    return torch.cat(pred_chunks), torch.cat(true_chunks)

In [75]:
# Model size
hidden_size = 256  # width of the encoder and of the first head layer
depth = 3          # depth argument passed to DMPNNEncoder
out_dim = 1        # a single regression output per graph

# Training length
num_epochs = 50

In [76]:
from torch import nn
# Regression head: hidden_size -> hidden_size -> hidden_size // 2 -> out_dim,
# with a ReLU after each of the first two linear layers.
_half_hidden = hidden_size // 2
head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, _half_hidden),
    nn.ReLU(),
    nn.Linear(_half_hidden, out_dim),
)

# The encoder's input sizes are read off the first training graph.
_example_graph = train_loader.dataset[0]
model = nn.Sequential(
    DMPNNEncoder(
        hidden_size,
        _example_graph.num_node_features,
        _example_graph.num_edge_features,
        depth,
    ),
    head,
)
initialize_weights(model)

In [77]:
# Mean-squared-error objective for the regression target.
criterion = nn.MSELoss()

# NOTE(review): the scheduler below manages the learning rate itself, so the
# lr=1e-4 given here is presumably overridden as soon as OneCycleLR is
# constructed -- confirm against the torch documentation.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# One-cycle schedule sized for num_epochs * len(train_loader) steps; train()
# advances the scheduler once per batch, which matches that step count.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
)

# Bundle consumed by train() and make_prediction().
config = dict(
    loss=criterion,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)


In [57]:
# NOTE(review): in the visible notebook, y_true and y_pred are first assigned
# by the make_prediction calls in later cells. This cell (execution count 57)
# therefore depends on leftover kernel state and raises NameError on a fresh
# Restart & Run All; move it after the training loop or remove it.
y_true, y_pred.squeeze()

(tensor([ 612.8000, 1065.5000, 1137.9000,  ...,  739.2000,  673.8000,
          954.4000]),
 tensor([775.3771, 790.3701, 862.6066,  ..., 777.1818, 767.7501, 792.4919]))

In [78]:
from sklearn.metrics import mean_squared_error, r2_score
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")

    # One optimisation pass over the training batches.
    train(config, train_loader)

    # make_prediction returns predictions first, targets second.
    y_pred, y_true = make_prediction(config, valid_loader)
    flat_pred = y_pred.squeeze()

    # Validation metrics for this epoch.
    r2 = r2_score(y_true, flat_pred)
    mse = mean_squared_error(y_true, flat_pred)
    print(f"val r2={r2:.6} mse={mse:.6}")

Epoch 1


100%|██████████| 126/126 [00:35<00:00,  3.52it/s]
100%|██████████| 32/32 [00:04<00:00,  7.72it/s]


val r2=-12.3121 mse=5.65318e+05
Epoch 2


100%|██████████| 126/126 [00:36<00:00,  3.47it/s]
100%|██████████| 32/32 [00:03<00:00,  8.23it/s]


val r2=0.00687665 mse=42174.5
Epoch 3


100%|██████████| 126/126 [00:36<00:00,  3.50it/s]
100%|██████████| 32/32 [00:04<00:00,  7.94it/s]


val r2=0.0640051 mse=39748.4
Epoch 4


100%|██████████| 126/126 [00:36<00:00,  3.48it/s]
100%|██████████| 32/32 [00:04<00:00,  7.56it/s]


val r2=0.117943 mse=37457.9
Epoch 5


100%|██████████| 126/126 [00:36<00:00,  3.42it/s]
100%|██████████| 32/32 [00:04<00:00,  7.11it/s]


val r2=0.151058 mse=36051.6
Epoch 6


100%|██████████| 126/126 [00:36<00:00,  3.46it/s]
100%|██████████| 32/32 [00:04<00:00,  7.47it/s]


val r2=0.193923 mse=34231.3
Epoch 7


100%|██████████| 126/126 [00:38<00:00,  3.28it/s]
100%|██████████| 32/32 [00:04<00:00,  7.84it/s]


val r2=0.2481 mse=31930.6
Epoch 8


100%|██████████| 126/126 [00:36<00:00,  3.43it/s]
100%|██████████| 32/32 [00:04<00:00,  7.54it/s]


val r2=0.279872 mse=30581.3
Epoch 9


100%|██████████| 126/126 [00:36<00:00,  3.44it/s]
100%|██████████| 32/32 [00:04<00:00,  7.58it/s]


val r2=0.22763 mse=32799.8
Epoch 10


100%|██████████| 126/126 [00:36<00:00,  3.44it/s]
100%|██████████| 32/32 [00:04<00:00,  7.75it/s]


val r2=0.352268 mse=27506.9
Epoch 11


100%|██████████| 126/126 [00:36<00:00,  3.44it/s]
100%|██████████| 32/32 [00:04<00:00,  7.25it/s]


val r2=0.348023 mse=27687.2
Epoch 12


100%|██████████| 126/126 [00:36<00:00,  3.42it/s]
100%|██████████| 32/32 [00:04<00:00,  7.63it/s]


val r2=0.422647 mse=24518.1
Epoch 13


100%|██████████| 126/126 [00:36<00:00,  3.46it/s]
100%|██████████| 32/32 [00:04<00:00,  7.86it/s]


val r2=0.436474 mse=23931.0
Epoch 14


100%|██████████| 126/126 [00:37<00:00,  3.33it/s]
100%|██████████| 32/32 [00:04<00:00,  7.28it/s]


val r2=0.467063 mse=22632.0
Epoch 15


100%|██████████| 126/126 [00:36<00:00,  3.45it/s]
100%|██████████| 32/32 [00:04<00:00,  7.23it/s]


val r2=0.238001 mse=32359.4
Epoch 16


100%|██████████| 126/126 [00:37<00:00,  3.36it/s]
100%|██████████| 32/32 [00:04<00:00,  7.85it/s]


val r2=0.536349 mse=19689.6
Epoch 17


100%|██████████| 126/126 [00:37<00:00,  3.34it/s]
100%|██████████| 32/32 [00:04<00:00,  6.78it/s]


val r2=0.564482 mse=18494.9
Epoch 18


100%|██████████| 126/126 [00:37<00:00,  3.40it/s]
100%|██████████| 32/32 [00:04<00:00,  6.98it/s]


val r2=0.491863 mse=21578.8
Epoch 19


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  7.72it/s]


val r2=0.592994 mse=17284.1
Epoch 20


100%|██████████| 126/126 [00:36<00:00,  3.43it/s]
100%|██████████| 32/32 [00:04<00:00,  7.75it/s]


val r2=0.605388 mse=16757.8
Epoch 21


100%|██████████| 126/126 [00:36<00:00,  3.42it/s]
100%|██████████| 32/32 [00:04<00:00,  6.97it/s]


val r2=0.273676 mse=30844.4
Epoch 22


100%|██████████| 126/126 [00:36<00:00,  3.48it/s]
100%|██████████| 32/32 [00:04<00:00,  6.88it/s]


val r2=0.596052 mse=17154.2
Epoch 23


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  6.75it/s]


val r2=0.634812 mse=15508.3
Epoch 24


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  6.88it/s]


val r2=0.640123 mse=15282.7
Epoch 25


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  7.37it/s]


val r2=0.64831 mse=14935.0
Epoch 26


100%|██████████| 126/126 [00:38<00:00,  3.28it/s]
100%|██████████| 32/32 [00:04<00:00,  7.22it/s]


val r2=0.586824 mse=17546.2
Epoch 27


100%|██████████| 126/126 [00:36<00:00,  3.43it/s]
100%|██████████| 32/32 [00:04<00:00,  7.64it/s]


val r2=0.667099 mse=14137.2
Epoch 28


100%|██████████| 126/126 [00:37<00:00,  3.38it/s]
100%|██████████| 32/32 [00:04<00:00,  6.41it/s]


val r2=0.65986 mse=14444.5
Epoch 29


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  7.63it/s]


val r2=0.669433 mse=14038.0
Epoch 30


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  7.67it/s]


val r2=0.689215 mse=13197.9
Epoch 31


100%|██████████| 126/126 [00:37<00:00,  3.33it/s]
100%|██████████| 32/32 [00:04<00:00,  6.93it/s]


val r2=0.687362 mse=13276.7
Epoch 32


100%|██████████| 126/126 [00:36<00:00,  3.45it/s]
100%|██████████| 32/32 [00:04<00:00,  7.15it/s]


val r2=0.69862 mse=12798.5
Epoch 33


100%|██████████| 126/126 [00:36<00:00,  3.47it/s]
100%|██████████| 32/32 [00:04<00:00,  7.43it/s]


val r2=0.645926 mse=15036.3
Epoch 34


100%|██████████| 126/126 [00:36<00:00,  3.48it/s]
100%|██████████| 32/32 [00:04<00:00,  7.78it/s]


val r2=0.700736 mse=12708.7
Epoch 35


100%|██████████| 126/126 [00:37<00:00,  3.38it/s]
100%|██████████| 32/32 [00:04<00:00,  7.94it/s]


val r2=0.702279 mse=12643.2
Epoch 36


100%|██████████| 126/126 [00:37<00:00,  3.32it/s]
100%|██████████| 32/32 [00:04<00:00,  7.67it/s]


val r2=0.626861 mse=15845.9
Epoch 37


100%|██████████| 126/126 [00:37<00:00,  3.34it/s]
100%|██████████| 32/32 [00:03<00:00,  8.12it/s]


val r2=0.718572 mse=11951.2
Epoch 38


100%|██████████| 126/126 [00:38<00:00,  3.29it/s]
100%|██████████| 32/32 [00:03<00:00,  8.27it/s]


val r2=0.714921 mse=12106.3
Epoch 39


100%|██████████| 126/126 [00:38<00:00,  3.31it/s]
100%|██████████| 32/32 [00:04<00:00,  7.04it/s]


val r2=0.717883 mse=11980.5
Epoch 40


100%|██████████| 126/126 [00:36<00:00,  3.47it/s]
100%|██████████| 32/32 [00:04<00:00,  7.34it/s]


val r2=0.719578 mse=11908.5
Epoch 41


100%|██████████| 126/126 [00:35<00:00,  3.51it/s]
100%|██████████| 32/32 [00:04<00:00,  7.44it/s]


val r2=0.725761 mse=11646.0
Epoch 42


100%|██████████| 126/126 [00:36<00:00,  3.45it/s]
100%|██████████| 32/32 [00:04<00:00,  7.82it/s]


val r2=0.728024 mse=11549.9
Epoch 43


100%|██████████| 126/126 [00:37<00:00,  3.36it/s]
100%|██████████| 32/32 [00:04<00:00,  7.94it/s]


val r2=0.728579 mse=11526.3
Epoch 44


100%|██████████| 126/126 [00:36<00:00,  3.42it/s]
100%|██████████| 32/32 [00:03<00:00,  8.00it/s]


val r2=0.729427 mse=11490.3
Epoch 45


100%|██████████| 126/126 [00:36<00:00,  3.47it/s]
100%|██████████| 32/32 [00:04<00:00,  7.38it/s]


val r2=0.73066 mse=11437.9
Epoch 46


100%|██████████| 126/126 [00:36<00:00,  3.47it/s]
100%|██████████| 32/32 [00:04<00:00,  7.85it/s]


val r2=0.730371 mse=11450.2
Epoch 47


100%|██████████| 126/126 [00:37<00:00,  3.38it/s]
100%|██████████| 32/32 [00:03<00:00,  8.12it/s]


val r2=0.73113 mse=11418.0
Epoch 48


100%|██████████| 126/126 [00:36<00:00,  3.42it/s]
100%|██████████| 32/32 [00:04<00:00,  7.73it/s]


val r2=0.731706 mse=11393.5
Epoch 49


100%|██████████| 126/126 [00:37<00:00,  3.37it/s]
100%|██████████| 32/32 [00:04<00:00,  7.42it/s]


val r2=0.7317 mse=11393.8
Epoch 50


100%|██████████| 126/126 [00:35<00:00,  3.55it/s]
100%|██████████| 32/32 [00:04<00:00,  7.84it/s]


val r2=0.731699 mse=11393.8


In [80]:
# Score the fitted model on the training loader, for comparison with the
# validation numbers printed during training. These graphs were used for
# fitting, so the result is labelled "train": the previous "test" label
# presented an in-sample score as if it were a held-out one. No separate
# test split exists in this notebook.
y_pred, y_true = make_prediction(config, train_loader)
mse = mean_squared_error(y_true, y_pred.squeeze())
r2 = r2_score(y_true, y_pred.squeeze())
print(f"train r2={r2:.6} mse={mse:.6}")

100%|██████████| 126/126 [00:17<00:00,  7.29it/s]


test r2=0.740261 mse=11110.7
