In [24]:
# Build a LightGBM model to predict RT using random split
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from rdkit import Chem
from rdkit.Chem import AllChem
import lightgbm as lgb
import numpy as np
import pandas as pd
import plotly.express as px

# Function to generate Morgan fingerprints using the latest generator method
def generate_fingerprint(smiles, radius=2, nBits=2048):
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        morgan_gen = AllChem.GetMorganGenerator(radius=radius, fpSize=nBits)
        return list(morgan_gen.GetFingerprint(mol))
    return [0] * nBits
dataset = pd.read_csv("dataset/SMRT_dataset_with_dates.csv").sort_values(by='date').reset_index(drop=True)
dataset['fingerprint'] = dataset['SMILES'].apply(lambda x: generate_fingerprint(x)) 

## Random Split LGBM Model

In [25]:

dataset_clean = dataset.iloc[:int(0.9*dataset.shape[0]),:]
ftest_set = dataset.iloc[int(0.9*dataset.shape[0]):,]
# Randomly split the dataset into training and validation sets
train_set, valid_set = train_test_split(dataset_clean, test_size=0.2, random_state=42)

# Prepare data for LightGBM
X_train = np.array(train_set['fingerprint'].tolist())
y_train = train_set['rt']
X_valid = np.array(valid_set['fingerprint'].tolist())
y_valid = valid_set['rt']
X_ftest = np.array(ftest_set['fingerprint'].tolist())
y_ftest = ftest_set['rt']

# Train a LightGBM model with default parameters
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)

print("Training LightGBM model...")
model = lgb.train({}, train_data, valid_sets=[valid_data])

# Make predictions on the training and validation sets
train_set['predicted_rt'] = model.predict(X_train)
valid_set['predicted_rt'] = model.predict(X_valid)

# Calculate R^2 and MSE for training and validation sets
train_r2 = r2_score(y_train, train_set['predicted_rt'])
train_mse = mean_squared_error(y_train, train_set['predicted_rt'])
valid_r2 = r2_score(y_valid, valid_set['predicted_rt'])
valid_mse = mean_squared_error(y_valid, valid_set['predicted_rt'])

ftest_r2 = r2_score(y_ftest, model.predict(X_ftest))
ftest_mse = mean_squared_error(y_ftest, model.predict(X_ftest))


# Combine training and validation for scatter plot
train_set['set'] = 'Training'
valid_set['set'] = 'Validation'
combined_data = pd.concat([train_set[['rt', 'predicted_rt', 'set']], valid_set[['rt', 'predicted_rt', 'set']]])

# Scatter plot for training and validation sets
title_text = (
    f"RT Prediction: Training vs Validation<br>"
    f"Training R²: {train_r2:.2f}, MSE: {train_mse:.2f}<br>"
    f"Validation R²: {valid_r2:.2f}, MSE: {valid_mse:.2f}<br>"
    f"ftest R²: {ftest_r2:.2f}, MSE: {ftest_mse:.2f}"
)

fig = px.scatter(
    combined_data, x='rt', y='predicted_rt', color='set',
    title=title_text,
    labels={'rt': 'Actual RT', 'predicted_rt': 'Predicted RT'},
    opacity=0.6
)

fig.update_layout(
    plot_bgcolor='rgba(229, 236, 246,1)',
    paper_bgcolor='rgba(229, 236, 246,1)',
    width=800,
    height=500
)

fig.show()

Training LightGBM model...


## Chronological Split LGBM Model

In [26]:
# Chronological split the dataset into training and validation sets
train_set, valid_set = dataset_clean.iloc[:int(0.8*dataset_clean.shape[0]),:], dataset_clean.iloc[int(0.8*dataset_clean.shape[0]):,:]

# Prepare data for LightGBM
X_train = np.array(train_set['fingerprint'].tolist())
y_train = train_set['rt']
X_valid = np.array(valid_set['fingerprint'].tolist())
y_valid = valid_set['rt']
X_ftest = np.array(ftest_set['fingerprint'].tolist())
y_ftest = ftest_set['rt']

# Train a LightGBM model with default parameters
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_valid, label=y_valid, reference=train_data)

print("Training LightGBM model...")
model = lgb.train({}, train_data, valid_sets=[valid_data])

# Make predictions on the training and validation sets
train_set['predicted_rt'] = model.predict(X_train)
valid_set['predicted_rt'] = model.predict(X_valid)

# Calculate R^2 and MSE for training and validation sets
train_r2 = r2_score(y_train, train_set['predicted_rt'])
train_mse = mean_squared_error(y_train, train_set['predicted_rt'])
valid_r2 = r2_score(y_valid, valid_set['predicted_rt'])
valid_mse = mean_squared_error(y_valid, valid_set['predicted_rt'])

ftest_r2 = r2_score(y_ftest, model.predict(X_ftest))
ftest_mse = mean_squared_error(y_ftest, model.predict(X_ftest))


# Combine training and validation for scatter plot
train_set['set'] = 'Training'
valid_set['set'] = 'Validation'
combined_data = pd.concat([train_set[['rt', 'predicted_rt', 'set']], valid_set[['rt', 'predicted_rt', 'set']]])

# Scatter plot for training and validation sets
title_text = (
    f"RT Prediction: Training vs Validation<br>"
    f"Training R²: {train_r2:.2f}, MSE: {train_mse:.2f}<br>"
    f"Validation R²: {valid_r2:.2f}, MSE: {valid_mse:.2f}<br>"
    f"ftest R²: {ftest_r2:.2f}, MSE: {ftest_mse:.2f}"
)

fig = px.scatter(
    combined_data, x='rt', y='predicted_rt', color='set',
    title=title_text,
    labels={'rt': 'Actual RT', 'predicted_rt': 'Predicted RT'},
    opacity=0.6
)

fig.update_layout(
    plot_bgcolor='rgba(229, 236, 246,1)',
    paper_bgcolor='rgba(229, 236, 246,1)',
    width=800,
    height=500
)

fig.show()

Training LightGBM model...




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/

## Hyperparameters optimization

In [None]:
import optuna
import plotly.graph_objects as go
from IPython.display import display


def make_plotly_optuna_callback(
    title="Optuna Optimization Progress",
    yaxis_title="Objective Value",
):
    fig = go.FigureWidget()

    fig.add_scatter(
        x=[],
        y=[],
        mode="markers",
        name="Trial value",
        marker=dict(size=10)
    )

    fig.add_scatter(
        x=[],
        y=[],
        mode="lines+markers",
        name="Best so far"
    )

    fig.update_layout(
        title=title,
        plot_bgcolor='rgba(229, 236, 246,1)',
        paper_bgcolor='rgba(229, 236, 246,1)',
        xaxis_title="Trial",
        yaxis_title=yaxis_title,
    )

    display(fig)  # MUST be called once only

    def callback(study: optuna.Study, trial: optuna.trial.FrozenTrial):

        # Collect only completed trials
        trials = [t for t in study.trials if t.value is not None]
        if not trials:
            return

        xs = [t.number for t in trials]
        ys = [t.value for t in trials]

        # Compute running best
        best = []
        current = float("inf") if study.direction == optuna.study.StudyDirection.MINIMIZE else float("-inf")

        for v in ys:
            if study.direction == optuna.study.StudyDirection.MINIMIZE:
                current = min(current, v)
            else:
                current = max(current, v)
            best.append(current)

        with fig.batch_update():
            fig.data[0].x = xs
            fig.data[0].y = ys
            fig.data[1].x = xs
            fig.data[1].y = best

    return callback, fig


In [33]:
import plotly.graph_objects as go
from sklearn.model_selection import cross_val_score
from sklearn.metrics import make_scorer, mean_squared_error
import warnings

# Suppress the specific UserWarning about feature names
warnings.filterwarnings('ignore', category=UserWarning, 
                        message='X does not have valid feature names')


# Define the objective function for Optuna
def objective(trial):
    params = {
        "min_gain_to_split": 0.0,
        "min_data_in_leaf": 1,
        'n_estimators': trial.suggest_int('n_estimators', 20, 100),
        'learning_rate': trial.suggest_float('learning_rate', 1e-2, 1, log=True),
        'num_leaves': trial.suggest_int('num_leaves', 20, 300),
        'max_depth': trial.suggest_int('max_depth', 10, 50),
        'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-4, 0.1, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1e-4, 0.1, log=True),
    }

    # Perform cross-validation
    lgb_model = lgb.LGBMRegressor(**params,verbosity=-1)
    mse = -cross_val_score(lgb_model, X_train, y_train, cv=5, scoring=make_scorer(mean_squared_error,greater_is_better=False)).mean()
    return mse


cb, fig = make_plotly_optuna_callback(
    title="Visualization Optimization",
    yaxis_title="MSE"
)

study = optuna.create_study(direction="minimize")
study.optimize(
    objective,
    n_trials=20,
    callbacks=[cb]
)


FigureWidget({
    'data': [{'marker': {'size': 10},
              'mode': 'markers',
              'name': 'Trial value',
              'type': 'scatter',
              'uid': '94cffc1d-12bb-48f9-b2a7-fcdea0d697a2',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'name': 'Best so far',
              'type': 'scatter',
              'uid': '5a2ec2e7-ef79-4721-985c-7c8c335b4b3c',
              'x': [],
              'y': []}],
    'layout': {'template': '...',
               'title': {'text': 'Visualization Optimization'},
               'xaxis': {'title': {'text': 'Trial'}},
               'yaxis': {'title': {'text': 'MSE'}}}
})

[I 2025-12-22 15:13:53,020] A new study created in memory with name: no-name-a33d4951-e0e0-4a94-98ed-b203f77dbed0


[I 2025-12-22 15:14:12,633] Trial 0 finished with value: 13996.640789858735 and parameters: {'n_estimators': 49, 'learning_rate': 0.4688951267456498, 'num_leaves': 247, 'max_depth': 13, 'min_child_samples': 73, 'reg_alpha': 0.002108411551961366, 'reg_lambda': 0.0020166439961425003}. Best is trial 0 with value: 13996.640789858735.
[I 2025-12-22 15:14:36,358] Trial 1 finished with value: 14446.417832174955 and parameters: {'n_estimators': 67, 'learning_rate': 0.5394318324742686, 'num_leaves': 191, 'max_depth': 39, 'min_child_samples': 63, 'reg_alpha': 0.0006455090671622542, 'reg_lambda': 0.02155444139938929}. Best is trial 0 with value: 13996.640789858735.
[I 2025-12-22 15:15:19,636] Trial 2 finished with value: 11753.785023777158 and parameters: {'n_estimators': 96, 'learning_rate': 0.07300636642942446, 'num_leaves': 272, 'max_depth': 50, 'min_child_samples': 74, 'reg_alpha': 0.0004030310771538315, 'reg_lambda': 0.0001360188629569451}. Best is trial 2 with value: 11753.785023777158.
[I 

### Retrain the model with the best paramters


In [34]:
# Get the best parameters
best_params = study.best_params
print("Best parameters:", best_params)

# Train the LightGBM model with the best parameters
model = lgb.LGBMRegressor(**best_params,verbosity=-1)
#model = lgb.LGBMRegressor()
model.fit(X_train, y_train)

# Make predictions on the training and validation sets
train_set['predicted_rt'] = model.predict(X_train)
valid_set['predicted_rt'] = model.predict(X_valid)

# Calculate R^2 and MSE for training and validation sets
train_r2 = r2_score(y_train, train_set['predicted_rt'])
train_mse = mean_squared_error(y_train, train_set['predicted_rt'])
valid_r2 = r2_score(y_valid, valid_set['predicted_rt'])
valid_mse = mean_squared_error(y_valid, valid_set['predicted_rt'])
ftest_r2 = r2_score(y_ftest, model.predict(X_ftest))
ftest_mse = mean_squared_error(y_ftest, model.predict(X_ftest))


# Combine training and validation for scatter plot
train_set['set'] = 'Training'
valid_set['set'] = 'Validation'
combined_data = pd.concat([train_set[['rt', 'predicted_rt', 'set']], valid_set[['rt', 'predicted_rt', 'set']]])

# Scatter plot for training and validation sets
title_text = (
    f"RT Prediction: Training vs Validation<br>"
    f"Training R²: {train_r2:.2f}, MSE: {train_mse:.2f}<br>"
    f"Validation R²: {valid_r2:.2f}, MSE: {valid_mse:.2f}<br>"
    f"ftest R²: {ftest_r2:.2f}, MSE: {ftest_mse:.2f}"
)

fig = px.scatter(
    combined_data, x='rt', y='predicted_rt', color='set',
    title=title_text,
    labels={'rt': 'Actual RT', 'predicted_rt': 'Predicted RT'},
    opacity=0.6
)

fig.update_layout(
    plot_bgcolor='rgba(229, 236, 246,1)',
    paper_bgcolor='rgba(229, 236, 246,1)',
    width=800,
    height=500
)

fig.show()


Best parameters: {'n_estimators': 100, 'learning_rate': 0.15837953299683752, 'num_leaves': 297, 'max_depth': 49, 'min_child_samples': 41, 'reg_alpha': 0.00013897052651983726, 'reg_lambda': 0.00012956333347286168}




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/

# Graphic Neural Networks

In [81]:
# Optimize LightGBM hyperparameters using Optuna with cross-validation
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error, r2_score, make_scorer
from torch_geometric.utils import from_smiles
import pandas as pd
from pyg_chemprop_utils import smiles2data


dataset = pd.read_csv("dataset/SMRT_dataset_with_dates.csv").sort_values(by='date').reset_index(drop=True)
dataset_clean = dataset.iloc[:int(0.9*dataset.shape[0]),:]
ftest_set = dataset.iloc[int(0.9*dataset_clean.shape[0]):,]
# Randomly split the dataset into training and validation sets
train_set, valid_set = train_test_split(dataset_clean, test_size=0.2, random_state=42)

# Build graph dataset from train/valid splits
smiles_col = "SMILES"  # change if your column name differs
rt_col = "rt"          # change if your column name differs

train_graphs = []
for _, row in train_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        train_graphs.append(data)

valid_graphs = []
for _, row in valid_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        valid_graphs.append(data)

ftest_graphs = []
for _, row in ftest_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        ftest_graphs.append(data)
len(train_graphs), len(valid_graphs), len(ftest_graphs)

(57627, 14407, 15208)

In [36]:
from torch_geometric.data import InMemoryDataset
import torch

class MyDataset(InMemoryDataset):
    def __init__(self, data_list, transform=None):
        super().__init__(".", transform)
        self.data, self.slices = self.collate(data_list)

    def _download(self):
        pass

    def _process(self):
        pass

In [82]:
from pyg_chemprop import DMPNNEncoder, RevIndexedData,GCNEncoder

train_graphs = [RevIndexedData(graph) for graph in train_graphs]
valid_graphs = [RevIndexedData(graph) for graph in valid_graphs]
ftest_graphs = [RevIndexedData(graph) for graph in ftest_graphs]

In [83]:
from torch_geometric.loader import DataLoader
batch_size=128

train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_graphs, batch_size=batch_size, shuffle=False)
ftest_loader = DataLoader(ftest_graphs, batch_size=batch_size, shuffle=False)

In [84]:
from tqdm import tqdm
from pyg_chemprop_utils import initialize_weights
def train(config, loader, device=torch.device("cpu")):
    criterion = config["loss"]
    model = config["model"]
    optimizer = config["optimizer"]
    scheduler = config["scheduler"]

    model = model.to(device)
    model.train()
    for batch in tqdm(loader, total=len(loader)):
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out.squeeze(), batch.y.float())
        loss.backward()
        optimizer.step()
        scheduler.step()
def make_prediction(config, loader, device=torch.device("cpu")):
    model = config["model"]

    model = model.to(device)
    model.eval()
    y_pred = []
    y_true = []
    for batch in tqdm(loader, total=len(loader)):
        batch = batch.to(device)
        with torch.no_grad():
            batch_preds = model(batch)
        y_pred.extend(batch_preds)
        y_true.extend(batch.y)
    return torch.stack(y_pred).cpu(), torch.stack(y_true).cpu()

In [85]:
num_epochs = 50
hidden_size = 512
depth = 3
out_dim = 1

In [86]:
from torch import nn
head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size//2, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_size//2, out_dim, bias=True),
)
model = nn.Sequential(
    GCNEncoder(
        hidden_size,
        train_loader.dataset[0].num_node_features,
        depth,
    ),
    head,
)
initialize_weights(model)

In [87]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=num_epochs
)
config = {
    "loss": criterion,
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
}

In [88]:
from sklearn.metrics import mean_squared_error, r2_score
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    train(config, train_loader)
    y_pred, y_true = make_prediction(config, valid_loader)
    mse = mean_squared_error(y_true, y_pred.squeeze())
    r2 = r2_score(y_true, y_pred.squeeze())
    print(f"val r2={r2:.4} mse={mse:.4}")

Epoch 1


100%|██████████| 451/451 [00:54<00:00,  8.34it/s]
100%|██████████| 113/113 [00:05<00:00, 20.72it/s]


val r2=0.07444 mse=3.921e+04
Epoch 2


100%|██████████| 451/451 [00:55<00:00,  8.18it/s]
100%|██████████| 113/113 [00:05<00:00, 21.95it/s]


val r2=0.1252 mse=3.706e+04
Epoch 3


100%|██████████| 451/451 [00:50<00:00,  8.85it/s]
100%|██████████| 113/113 [00:05<00:00, 20.20it/s]


val r2=0.1602 mse=3.558e+04
Epoch 4


100%|██████████| 451/451 [00:58<00:00,  7.66it/s]
100%|██████████| 113/113 [00:05<00:00, 20.86it/s]


val r2=0.214 mse=3.33e+04
Epoch 5


100%|██████████| 451/451 [00:53<00:00,  8.38it/s]
100%|██████████| 113/113 [00:05<00:00, 22.59it/s]


val r2=0.3102 mse=2.922e+04
Epoch 6


100%|██████████| 451/451 [00:49<00:00,  9.11it/s]
100%|██████████| 113/113 [00:05<00:00, 18.87it/s]


val r2=0.3971 mse=2.554e+04
Epoch 7


100%|██████████| 451/451 [00:54<00:00,  8.35it/s]
100%|██████████| 113/113 [00:05<00:00, 19.91it/s]


val r2=0.3656 mse=2.688e+04
Epoch 8


100%|██████████| 451/451 [00:56<00:00,  8.01it/s]
100%|██████████| 113/113 [00:06<00:00, 17.76it/s]


val r2=0.4502 mse=2.329e+04
Epoch 9


100%|██████████| 451/451 [00:54<00:00,  8.25it/s]
100%|██████████| 113/113 [00:05<00:00, 22.23it/s]


val r2=0.4145 mse=2.481e+04
Epoch 10


100%|██████████| 451/451 [00:54<00:00,  8.23it/s]
100%|██████████| 113/113 [00:06<00:00, 18.36it/s]


val r2=0.5241 mse=2.016e+04
Epoch 11


100%|██████████| 451/451 [00:56<00:00,  7.98it/s]
100%|██████████| 113/113 [00:04<00:00, 22.87it/s]


val r2=0.4615 mse=2.281e+04
Epoch 12


100%|██████████| 451/451 [00:55<00:00,  8.08it/s]
100%|██████████| 113/113 [00:05<00:00, 22.42it/s]


val r2=0.565 mse=1.843e+04
Epoch 13


100%|██████████| 451/451 [00:50<00:00,  8.94it/s]
100%|██████████| 113/113 [00:05<00:00, 19.73it/s]


val r2=0.5729 mse=1.809e+04
Epoch 14


100%|██████████| 451/451 [00:51<00:00,  8.83it/s]
100%|██████████| 113/113 [00:04<00:00, 22.67it/s]


val r2=0.5754 mse=1.799e+04
Epoch 15


100%|██████████| 451/451 [00:49<00:00,  9.06it/s]
100%|██████████| 113/113 [00:05<00:00, 22.18it/s]


val r2=0.5875 mse=1.748e+04
Epoch 16


100%|██████████| 451/451 [00:48<00:00,  9.22it/s]
100%|██████████| 113/113 [00:05<00:00, 22.37it/s]


val r2=0.5962 mse=1.711e+04
Epoch 17


100%|██████████| 451/451 [00:50<00:00,  8.95it/s]
100%|██████████| 113/113 [00:04<00:00, 22.73it/s]


val r2=0.5435 mse=1.934e+04
Epoch 18


100%|██████████| 451/451 [00:50<00:00,  9.02it/s]
100%|██████████| 113/113 [00:05<00:00, 22.52it/s]


val r2=0.5992 mse=1.698e+04
Epoch 19


100%|██████████| 451/451 [00:50<00:00,  8.95it/s]
100%|██████████| 113/113 [00:04<00:00, 23.12it/s]


val r2=0.6245 mse=1.591e+04
Epoch 20


100%|██████████| 451/451 [00:49<00:00,  9.13it/s]
100%|██████████| 113/113 [00:05<00:00, 20.51it/s]


val r2=0.6474 mse=1.494e+04
Epoch 21


100%|██████████| 451/451 [00:53<00:00,  8.48it/s]
100%|██████████| 113/113 [00:05<00:00, 20.06it/s]


val r2=0.6609 mse=1.437e+04
Epoch 22


100%|██████████| 451/451 [00:58<00:00,  7.73it/s]
100%|██████████| 113/113 [00:05<00:00, 22.32it/s]


val r2=0.6665 mse=1.413e+04
Epoch 23


100%|██████████| 451/451 [00:49<00:00,  9.12it/s]
100%|██████████| 113/113 [00:05<00:00, 19.94it/s]


val r2=0.6533 mse=1.469e+04
Epoch 24


100%|██████████| 451/451 [00:51<00:00,  8.76it/s]
100%|██████████| 113/113 [00:06<00:00, 18.09it/s]


val r2=0.6699 mse=1.398e+04
Epoch 25


100%|██████████| 451/451 [00:56<00:00,  7.94it/s]
100%|██████████| 113/113 [00:06<00:00, 17.10it/s]


val r2=0.682 mse=1.347e+04
Epoch 26


100%|██████████| 451/451 [00:50<00:00,  8.93it/s]
100%|██████████| 113/113 [00:05<00:00, 19.66it/s]


val r2=0.6294 mse=1.57e+04
Epoch 27


100%|██████████| 451/451 [00:51<00:00,  8.71it/s]
100%|██████████| 113/113 [00:06<00:00, 17.77it/s]


val r2=0.6924 mse=1.303e+04
Epoch 28


100%|██████████| 451/451 [00:58<00:00,  7.72it/s]
100%|██████████| 113/113 [00:04<00:00, 23.11it/s]


val r2=0.6828 mse=1.344e+04
Epoch 29


100%|██████████| 451/451 [00:50<00:00,  8.87it/s]
100%|██████████| 113/113 [00:04<00:00, 22.64it/s]


val r2=0.6804 mse=1.354e+04
Epoch 30


100%|██████████| 451/451 [00:50<00:00,  9.02it/s]
100%|██████████| 113/113 [00:05<00:00, 22.35it/s]


val r2=0.6918 mse=1.306e+04
Epoch 31


100%|██████████| 451/451 [00:56<00:00,  7.92it/s]
100%|██████████| 113/113 [00:06<00:00, 18.35it/s]


val r2=0.7022 mse=1.262e+04
Epoch 32


100%|██████████| 451/451 [00:50<00:00,  8.86it/s]
100%|██████████| 113/113 [00:04<00:00, 22.69it/s]


val r2=0.6886 mse=1.319e+04
Epoch 33


100%|██████████| 451/451 [00:57<00:00,  7.79it/s]
100%|██████████| 113/113 [00:05<00:00, 20.32it/s]


val r2=0.7207 mse=1.183e+04
Epoch 34


100%|██████████| 451/451 [01:00<00:00,  7.43it/s]
100%|██████████| 113/113 [00:06<00:00, 17.91it/s]


val r2=0.6845 mse=1.337e+04
Epoch 35


100%|██████████| 451/451 [00:51<00:00,  8.81it/s]
100%|██████████| 113/113 [00:05<00:00, 22.55it/s]


val r2=0.7022 mse=1.262e+04
Epoch 36


100%|██████████| 451/451 [00:54<00:00,  8.22it/s]
100%|██████████| 113/113 [00:05<00:00, 20.85it/s]


val r2=0.7109 mse=1.225e+04
Epoch 37


100%|██████████| 451/451 [00:53<00:00,  8.47it/s]
100%|██████████| 113/113 [00:04<00:00, 22.88it/s]


val r2=0.6719 mse=1.39e+04
Epoch 38


100%|██████████| 451/451 [00:51<00:00,  8.84it/s]
100%|██████████| 113/113 [00:04<00:00, 22.70it/s]


val r2=0.7167 mse=1.2e+04
Epoch 39


100%|██████████| 451/451 [00:56<00:00,  8.00it/s]
100%|██████████| 113/113 [00:05<00:00, 22.51it/s]


val r2=0.7301 mse=1.143e+04
Epoch 40


100%|██████████| 451/451 [00:55<00:00,  8.16it/s]
100%|██████████| 113/113 [00:05<00:00, 21.75it/s]


val r2=0.728 mse=1.152e+04
Epoch 41


100%|██████████| 451/451 [00:58<00:00,  7.76it/s]
100%|██████████| 113/113 [00:06<00:00, 17.29it/s]


val r2=0.73 mse=1.144e+04
Epoch 42


100%|██████████| 451/451 [00:50<00:00,  8.98it/s]
100%|██████████| 113/113 [00:05<00:00, 22.29it/s]


val r2=0.7347 mse=1.124e+04
Epoch 43


100%|██████████| 451/451 [00:50<00:00,  8.95it/s]
100%|██████████| 113/113 [00:05<00:00, 21.54it/s]


val r2=0.7366 mse=1.116e+04
Epoch 44


100%|██████████| 451/451 [00:54<00:00,  8.28it/s]
100%|██████████| 113/113 [00:06<00:00, 17.69it/s]


val r2=0.7364 mse=1.117e+04
Epoch 45


100%|██████████| 451/451 [00:57<00:00,  7.87it/s]
100%|██████████| 113/113 [00:05<00:00, 22.36it/s]


val r2=0.7383 mse=1.109e+04
Epoch 46


100%|██████████| 451/451 [00:49<00:00,  9.03it/s]
100%|██████████| 113/113 [00:05<00:00, 22.53it/s]


val r2=0.7388 mse=1.107e+04
Epoch 47


100%|██████████| 451/451 [00:49<00:00,  9.11it/s]
100%|██████████| 113/113 [00:06<00:00, 18.68it/s]


val r2=0.7392 mse=1.105e+04
Epoch 48


100%|██████████| 451/451 [00:59<00:00,  7.64it/s]
100%|██████████| 113/113 [00:05<00:00, 19.92it/s]


val r2=0.7393 mse=1.105e+04
Epoch 49


100%|██████████| 451/451 [00:54<00:00,  8.29it/s]
100%|██████████| 113/113 [00:05<00:00, 21.05it/s]


val r2=0.7392 mse=1.105e+04
Epoch 50


100%|██████████| 451/451 [00:53<00:00,  8.51it/s]
100%|██████████| 113/113 [00:05<00:00, 22.16it/s]

val r2=0.7395 mse=1.104e+04





In [89]:
train_y_pred, train_y_true = make_prediction(config, train_loader)
train_mse = mean_squared_error(train_y_true, train_y_pred.squeeze())
train_r2 = r2_score(train_y_true, train_y_pred.squeeze())
print(f"train r2={train_r2:.4} mse={train_mse:.4}")
valid_y_pred, valid_y_true = make_prediction(config, valid_loader)
valid_mse = mean_squared_error(valid_y_true, valid_y_pred.squeeze())
valid_r2 = r2_score(valid_y_true, valid_y_pred.squeeze())
print(f"valid r2={valid_r2:.4} mse={valid_mse:.4}")
ftest_y_pred, ftest_y_true = make_prediction(config, ftest_loader)
ftest_mse = mean_squared_error(ftest_y_true, ftest_y_pred.squeeze())
ftest_r2 = r2_score(ftest_y_true, ftest_y_pred.squeeze())
print(f"ftest r2={ftest_r2:.4} mse={ftest_mse:.4}")

100%|██████████| 451/451 [00:22<00:00, 19.73it/s]


train r2=0.7629 mse=1.01e+04


100%|██████████| 113/113 [00:05<00:00, 22.13it/s]


valid r2=0.7395 mse=1.104e+04


100%|██████████| 119/119 [00:05<00:00, 20.89it/s]

ftest r2=0.7562 mse=1.037e+04





## GCN Chronological Split

In [75]:
dataset = pd.read_csv("dataset/SMRT_dataset_with_dates.csv").sort_values(by='date').reset_index(drop=True)
dataset_clean = dataset.iloc[:int(0.9*dataset.shape[0]),:]
ftest_set = dataset.iloc[int(0.9*dataset_clean.shape[0]):,]
train_set,valid_set = dataset_clean.iloc[:int(0.8*dataset_clean.shape[0]),:], dataset_clean.iloc[int(0.8*dataset_clean.shape[0]):,:]
# Build graph dataset from train/valid splits
smiles_col = "SMILES"  # change if your column name differs
rt_col = "rt"          # change if your column name differs

train_graphs = []
for _, row in train_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        train_graphs.append(data)

valid_graphs = []
for _, row in valid_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        valid_graphs.append(data)

ftest_graphs = []
for _, row in ftest_set.iterrows():
    data = smiles2data(row[smiles_col])
    data['y'] = row[rt_col]
    if data is not None:
        ftest_graphs.append(data)
len(train_graphs), len(valid_graphs), len(ftest_graphs)

(57627, 14407, 15208)

In [76]:
from pyg_chemprop import DMPNNEncoder, RevIndexedData,GCNEncoder

train_graphs = [RevIndexedData(graph) for graph in train_graphs]
valid_graphs = [RevIndexedData(graph) for graph in valid_graphs]
ftest_graphs = [RevIndexedData(graph) for graph in ftest_graphs]

from torch_geometric.loader import DataLoader
batch_size=128

train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_graphs, batch_size=batch_size, shuffle=False)
ftest_loader = DataLoader(ftest_graphs, batch_size=batch_size, shuffle=False)


In [77]:
from torch import nn

depth=5
dropout = 0.1051376604780759
learning_rate = 0.0002541394185359693

head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size//2, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_size//2, out_dim, bias=True),
)
model = nn.Sequential(
    GCNEncoder(
        hidden_size,
        train_loader.dataset[0].num_node_features,
        depth,
        dropout=dropout,
    ),
    head,
)
initialize_weights(model)

In [78]:
#optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=num_epochs
)
config = {
    "loss": criterion,
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
}

In [79]:
from sklearn.metrics import mean_squared_error, r2_score
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    train(config, train_loader)
    y_pred, y_true = make_prediction(config, valid_loader)
    mse = mean_squared_error(y_true, y_pred.squeeze())
    r2 = r2_score(y_true, y_pred.squeeze())
    print(f"val r2={r2:.4} mse={mse:.4}")

Epoch 1


100%|██████████| 451/451 [02:06<00:00,  3.56it/s]
100%|██████████| 113/113 [00:08<00:00, 12.58it/s]


val r2=0.08428 mse=3.748e+04
Epoch 2


100%|██████████| 451/451 [02:07<00:00,  3.54it/s]
100%|██████████| 113/113 [00:09<00:00, 11.60it/s]


val r2=0.1386 mse=3.525e+04
Epoch 3


100%|██████████| 451/451 [02:05<00:00,  3.58it/s]
100%|██████████| 113/113 [00:08<00:00, 13.04it/s]


val r2=0.2426 mse=3.1e+04
Epoch 4


100%|██████████| 451/451 [01:53<00:00,  3.97it/s]
100%|██████████| 113/113 [00:08<00:00, 13.43it/s]


val r2=0.3172 mse=2.794e+04
Epoch 5


100%|██████████| 451/451 [01:57<00:00,  3.85it/s]
100%|██████████| 113/113 [00:08<00:00, 12.62it/s]


val r2=0.3979 mse=2.464e+04
Epoch 6


100%|██████████| 451/451 [02:09<00:00,  3.48it/s]
100%|██████████| 113/113 [00:09<00:00, 12.06it/s]


val r2=0.4 mse=2.456e+04
Epoch 7


100%|██████████| 451/451 [01:58<00:00,  3.82it/s]
100%|██████████| 113/113 [00:09<00:00, 11.60it/s]


val r2=0.5129 mse=1.993e+04
Epoch 8


100%|██████████| 451/451 [02:00<00:00,  3.74it/s]
100%|██████████| 113/113 [00:09<00:00, 11.97it/s]


val r2=0.4254 mse=2.352e+04
Epoch 9


100%|██████████| 451/451 [01:56<00:00,  3.87it/s]
100%|██████████| 113/113 [00:08<00:00, 13.51it/s]


val r2=0.5416 mse=1.876e+04
Epoch 10


100%|██████████| 451/451 [02:00<00:00,  3.75it/s]
100%|██████████| 113/113 [00:09<00:00, 11.94it/s]


val r2=0.5729 mse=1.748e+04
Epoch 11


100%|██████████| 451/451 [02:00<00:00,  3.74it/s]
100%|██████████| 113/113 [00:08<00:00, 13.60it/s]


val r2=0.59 mse=1.678e+04
Epoch 12


100%|██████████| 451/451 [01:53<00:00,  3.96it/s]
100%|██████████| 113/113 [00:09<00:00, 11.79it/s]


val r2=-0.1967 mse=4.898e+04
Epoch 13


100%|██████████| 451/451 [02:02<00:00,  3.68it/s]
100%|██████████| 113/113 [00:08<00:00, 13.44it/s]


val r2=0.582 mse=1.711e+04
Epoch 14


100%|██████████| 451/451 [01:59<00:00,  3.79it/s]
100%|██████████| 113/113 [00:08<00:00, 13.31it/s]


val r2=0.6412 mse=1.468e+04
Epoch 15


100%|██████████| 451/451 [02:01<00:00,  3.70it/s]
100%|██████████| 113/113 [00:09<00:00, 12.18it/s]


val r2=0.6299 mse=1.515e+04
Epoch 16


100%|██████████| 451/451 [01:58<00:00,  3.82it/s]
100%|██████████| 113/113 [00:08<00:00, 13.70it/s]


val r2=0.6549 mse=1.413e+04
Epoch 17


100%|██████████| 451/451 [01:59<00:00,  3.77it/s]
100%|██████████| 113/113 [00:09<00:00, 12.34it/s]


val r2=0.6481 mse=1.44e+04
Epoch 18


100%|██████████| 451/451 [02:11<00:00,  3.42it/s]
100%|██████████| 113/113 [00:08<00:00, 13.06it/s]


val r2=0.5689 mse=1.764e+04
Epoch 19


100%|██████████| 451/451 [02:02<00:00,  3.68it/s]
100%|██████████| 113/113 [00:08<00:00, 12.69it/s]


val r2=0.6165 mse=1.57e+04
Epoch 20


100%|██████████| 451/451 [01:59<00:00,  3.78it/s]
100%|██████████| 113/113 [00:08<00:00, 12.86it/s]


val r2=0.6853 mse=1.288e+04
Epoch 21


100%|██████████| 451/451 [02:10<00:00,  3.46it/s]
100%|██████████| 113/113 [00:08<00:00, 12.91it/s]


val r2=0.587 mse=1.69e+04
Epoch 22


100%|██████████| 451/451 [01:59<00:00,  3.78it/s]
100%|██████████| 113/113 [00:09<00:00, 12.12it/s]


val r2=0.6087 mse=1.602e+04
Epoch 23


100%|██████████| 451/451 [01:58<00:00,  3.82it/s]
100%|██████████| 113/113 [00:09<00:00, 12.49it/s]


val r2=0.6602 mse=1.391e+04
Epoch 24


100%|██████████| 451/451 [02:08<00:00,  3.51it/s]
100%|██████████| 113/113 [00:09<00:00, 12.33it/s]


val r2=0.6999 mse=1.228e+04
Epoch 25


100%|██████████| 451/451 [02:10<00:00,  3.47it/s]
100%|██████████| 113/113 [00:08<00:00, 12.58it/s]


val r2=0.6976 mse=1.238e+04
Epoch 26


100%|██████████| 451/451 [01:58<00:00,  3.80it/s]
100%|██████████| 113/113 [00:08<00:00, 13.83it/s]


val r2=0.7168 mse=1.159e+04
Epoch 27


100%|██████████| 451/451 [02:08<00:00,  3.50it/s]
100%|██████████| 113/113 [00:08<00:00, 13.24it/s]


val r2=0.7204 mse=1.144e+04
Epoch 28


100%|██████████| 451/451 [02:05<00:00,  3.60it/s]
100%|██████████| 113/113 [00:09<00:00, 12.22it/s]


val r2=0.7065 mse=1.201e+04
Epoch 29


100%|██████████| 451/451 [02:00<00:00,  3.75it/s]
100%|██████████| 113/113 [00:09<00:00, 11.47it/s]


val r2=0.7282 mse=1.113e+04
Epoch 30


100%|██████████| 451/451 [01:57<00:00,  3.83it/s]
100%|██████████| 113/113 [00:09<00:00, 12.10it/s]


val r2=0.7276 mse=1.115e+04
Epoch 31


100%|██████████| 451/451 [02:00<00:00,  3.75it/s]
100%|██████████| 113/113 [00:08<00:00, 12.86it/s]


val r2=0.7396 mse=1.066e+04
Epoch 32


100%|██████████| 451/451 [02:09<00:00,  3.49it/s]
100%|██████████| 113/113 [00:08<00:00, 13.59it/s]


val r2=0.739 mse=1.068e+04
Epoch 33


100%|██████████| 451/451 [02:00<00:00,  3.74it/s]
100%|██████████| 113/113 [00:08<00:00, 12.66it/s]


val r2=0.7271 mse=1.117e+04
Epoch 34


100%|██████████| 451/451 [02:07<00:00,  3.53it/s]
100%|██████████| 113/113 [00:08<00:00, 12.89it/s]


val r2=0.7546 mse=1.004e+04
Epoch 35


100%|██████████| 451/451 [01:57<00:00,  3.85it/s]
100%|██████████| 113/113 [00:09<00:00, 12.32it/s]


val r2=0.7509 mse=1.02e+04
Epoch 36


100%|██████████| 451/451 [02:03<00:00,  3.66it/s]
100%|██████████| 113/113 [00:08<00:00, 13.04it/s]


val r2=0.7606 mse=9.797e+03
Epoch 37


100%|██████████| 451/451 [02:08<00:00,  3.52it/s]
100%|██████████| 113/113 [00:08<00:00, 13.75it/s]


val r2=0.7604 mse=9.806e+03
Epoch 38


100%|██████████| 451/451 [02:00<00:00,  3.74it/s]
100%|██████████| 113/113 [00:09<00:00, 11.77it/s]


val r2=0.7671 mse=9.53e+03
Epoch 39


100%|██████████| 451/451 [02:01<00:00,  3.70it/s]
100%|██████████| 113/113 [00:08<00:00, 12.99it/s]


val r2=0.767 mse=9.537e+03
Epoch 40


100%|██████████| 451/451 [02:06<00:00,  3.56it/s]
100%|██████████| 113/113 [00:08<00:00, 12.58it/s]


val r2=0.7718 mse=9.338e+03
Epoch 41


100%|██████████| 451/451 [01:59<00:00,  3.76it/s]
100%|██████████| 113/113 [00:08<00:00, 13.06it/s]


val r2=0.7683 mse=9.482e+03
Epoch 42


100%|██████████| 451/451 [01:58<00:00,  3.82it/s]
100%|██████████| 113/113 [00:09<00:00, 11.90it/s]


val r2=0.7711 mse=9.367e+03
Epoch 43


100%|██████████| 451/451 [02:08<00:00,  3.52it/s]
100%|██████████| 113/113 [00:08<00:00, 13.73it/s]


val r2=0.7743 mse=9.236e+03
Epoch 44


100%|██████████| 451/451 [01:59<00:00,  3.77it/s]
100%|██████████| 113/113 [00:08<00:00, 12.77it/s]


val r2=0.775 mse=9.211e+03
Epoch 45


100%|██████████| 451/451 [01:59<00:00,  3.78it/s]
100%|██████████| 113/113 [00:08<00:00, 14.05it/s]


val r2=0.7746 mse=9.226e+03
Epoch 46


100%|██████████| 451/451 [01:58<00:00,  3.80it/s]
100%|██████████| 113/113 [00:09<00:00, 12.07it/s]


val r2=0.7781 mse=9.082e+03
Epoch 47


100%|██████████| 451/451 [02:07<00:00,  3.54it/s]
100%|██████████| 113/113 [00:08<00:00, 13.40it/s]


val r2=0.7772 mse=9.119e+03
Epoch 48


100%|██████████| 451/451 [02:01<00:00,  3.71it/s]
100%|██████████| 113/113 [00:08<00:00, 13.03it/s]


val r2=0.7801 mse=8.999e+03
Epoch 49


100%|██████████| 451/451 [02:02<00:00,  3.69it/s]
100%|██████████| 113/113 [00:09<00:00, 11.40it/s]


val r2=0.7801 mse=8.999e+03
Epoch 50


100%|██████████| 451/451 [01:52<00:00,  4.02it/s]
100%|██████████| 113/113 [00:08<00:00, 12.94it/s]


val r2=0.7803 mse=8.992e+03


In [80]:
train_y_pred, train_y_true = make_prediction(config, train_loader)
train_mse = mean_squared_error(train_y_true, train_y_pred.squeeze())
train_r2 = r2_score(train_y_true, train_y_pred.squeeze())
print(f"train r2={train_r2:.4} mse={train_mse:.4}")
valid_y_pred, valid_y_true = make_prediction(config, valid_loader)
valid_mse = mean_squared_error(valid_y_true, valid_y_pred.squeeze())
valid_r2 = r2_score(valid_y_true, valid_y_pred.squeeze())
print(f"valid r2={valid_r2:.4} mse={valid_mse:.4}")
ftest_y_pred, ftest_y_true = make_prediction(config, ftest_loader)
ftest_mse = mean_squared_error(ftest_y_true, ftest_y_pred.squeeze())
ftest_r2 = r2_score(ftest_y_true, ftest_y_pred.squeeze())
print(f"ftest r2={ftest_r2:.4} mse={ftest_mse:.4}")

100%|██████████| 451/451 [00:33<00:00, 13.33it/s]


train r2=0.8101 mse=8.16e+03


100%|██████████| 113/113 [00:09<00:00, 12.09it/s]


valid r2=0.7803 mse=8.992e+03


100%|██████████| 119/119 [00:10<00:00, 11.89it/s]

ftest r2=0.7863 mse=9.092e+03





# GCN hyperparameters tuning with optuna

In [None]:
import optuna
import torch
from torch import nn
from torch_geometric.loader import DataLoader
from pyg_chemprop import GCNEncoder
from pyg_chemprop_utils import initialize_weights
from sklearn.metrics import mean_squared_error
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

# Define a silent train function to avoid tqdm clutter during optimization
def train_silent(config, loader, device=torch.device("cpu")):
    criterion = config["loss"]
    model = config["model"]
    optimizer = config["optimizer"]
    scheduler = config["scheduler"]

    model = model.to(device)
    model.train()
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out.squeeze(), batch.y.float())
        loss.backward()
        optimizer.step()
        scheduler.step()

def objective(trial):
    # Hyperparameters to optimize
    hidden_size = trial.suggest_categorical('hidden_size', [128, 256, 512])
    depth = trial.suggest_int('depth', 2, 5)
    dropout = trial.suggest_float('dropout', 0.0, 0.5)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical('batch_size', [128, 256, 512])
    
    # Create DataLoaders for this trial
    # Using train_graphs and valid_graphs from the global scope
    train_loader_trial = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    valid_loader_trial = DataLoader(valid_graphs, batch_size=batch_size, shuffle=False)
    
    # Model definition
    out_dim = 1
    head = nn.Sequential(
        nn.Linear(hidden_size, hidden_size, bias=True),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_size, hidden_size//2, bias=True),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_size//2, out_dim, bias=True),
    )
    
    model = nn.Sequential(
        GCNEncoder(
            hidden_size,
            train_graphs[0].num_node_features,
            depth,
            dropout=dropout
        ),
        head,
    )
    initialize_weights(model)
    
    # Training setup
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    
    num_epochs = 10 # Use fewer epochs for optimization
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate*10, steps_per_epoch=len(train_loader_trial), epochs=num_epochs
    )
    
    config = {
        "loss": criterion,
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
    }
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Training loop
    for epoch in range(num_epochs):
        train_silent(config, train_loader_trial, device=device)
        
    # Final Evaluation
    y_pred, y_true = make_prediction(config, valid_loader_trial, device=device)
    mse = mean_squared_error(y_true, y_pred.squeeze())
    
    return mse

# Visualization callback
cb, fig = make_plotly_optuna_callback(
    title="GCN Hyperparameter Optimization",
    yaxis_title="MSE"
)

# Run optimization
study_gcn = optuna.create_study(direction="minimize")
study_gcn.optimize(
    objective,
    n_trials=20,
    callbacks=[cb]
)

In [None]:
from sklearn.metrics import mean_squared_error, r2_score

# Get the best parameters
best_params = study_gcn.best_params
print("Best parameters:", best_params)

# Extract params
hidden_size = best_params['hidden_size']
depth = best_params['depth']
dropout = best_params['dropout']
learning_rate = best_params['learning_rate']
batch_size = best_params['batch_size']

# Recreate loaders with best batch size
train_loader_best = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
valid_loader_best = DataLoader(valid_graphs, batch_size=batch_size, shuffle=False)
ftest_loader_best = DataLoader(ftest_graphs, batch_size=batch_size, shuffle=False)

# Rebuild model
out_dim = 1
head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size, bias=True),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden_size, hidden_size//2, bias=True),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(hidden_size//2, out_dim, bias=True),
)

model = nn.Sequential(
    GCNEncoder(
        hidden_size,
        train_graphs[0].num_node_features,
        depth,
        dropout=dropout
    ),
    head,
)
initialize_weights(model)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
num_epochs = 50 # Full training
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=learning_rate*10, steps_per_epoch=len(train_loader_best), epochs=num_epochs
)

config = {
    "loss": criterion,
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    train(config, train_loader_best, device=device) # Use the original train with tqdm
    y_pred, y_true = make_prediction(config, valid_loader_best, device=device)
    mse = mean_squared_error(y_true, y_pred.squeeze())
    r2 = r2_score(y_true, y_pred.squeeze())
    print(f"val r2={r2:.4} mse={mse:.4}")

# Final evaluation
train_y_pred, train_y_true = make_prediction(config, train_loader_best, device=device)
train_mse = mean_squared_error(train_y_true, train_y_pred.squeeze())
train_r2 = r2_score(train_y_true, train_y_pred.squeeze())
print(f"train r2={train_r2:.4} mse={train_mse:.4}")

valid_y_pred, valid_y_true = make_prediction(config, valid_loader_best, device=device)
valid_mse = mean_squared_error(valid_y_true, valid_y_pred.squeeze())
valid_r2 = r2_score(valid_y_true, valid_y_pred.squeeze())
print(f"valid r2={valid_r2:.4} mse={valid_mse:.4}")

ftest_y_pred, ftest_y_true = make_prediction(config, ftest_loader_best, device=device)
ftest_mse = mean_squared_error(ftest_y_true, ftest_y_pred.squeeze())
ftest_r2 = r2_score(ftest_y_true, ftest_y_pred.squeeze())
print(f"ftest r2={ftest_r2:.4} mse={ftest_mse:.4}")

In [None]:
import plotly.express as px
import pandas as pd

# Combine training and validation for scatter plot
# We need to get the original RT values and the predicted ones.
# Since we used loaders, the order might be shuffled for training, but we can just use the y_true and y_pred tensors we collected.

# Convert tensors to numpy
train_true_np = train_y_true.numpy()
train_pred_np = train_y_pred.squeeze().numpy()
valid_true_np = valid_y_true.numpy()
valid_pred_np = valid_y_pred.squeeze().numpy()

train_df = pd.DataFrame({'rt': train_true_np, 'predicted_rt': train_pred_np, 'set': 'Training'})
valid_df = pd.DataFrame({'rt': valid_true_np, 'predicted_rt': valid_pred_np, 'set': 'Validation'})
combined_data = pd.concat([train_df, valid_df])

# Scatter plot for training and validation sets
title_text = (
    f"GCN RT Prediction: Training vs Validation (Best Params)<br>"
    f"Training R²: {train_r2:.2f}, MSE: {train_mse:.2f}<br>"
    f"Validation R²: {valid_r2:.2f}, MSE: {valid_mse:.2f}<br>"
    f"ftest R²: {ftest_r2:.2f}, MSE: {ftest_mse:.2f}"
)

fig = px.scatter(
    combined_data, x='rt', y='predicted_rt', color='set',
    title=title_text,
    labels={'rt': 'Actual RT', 'predicted_rt': 'Predicted RT'},
    opacity=0.6
)

fig.update_layout(
    plot_bgcolor='rgba(229, 236, 246,1)',
    paper_bgcolor='rgba(229, 236, 246,1)',
    width=800,
    height=500
)

fig.show()

In [1]:
import plotly.express as px
import pandas as pd

# Define the data for the 4 groups
# Note: Replace the values below with the actual R2 scores you obtained from your runs
data = {
    'Model': [
        'LGBM Random', 'LGBM Random',
        'LGBM Chronological', 'LGBM Chronological',
        'GCN Random', 'GCN Random',
        'GCN Chronological', 'GCN Chronological'
    ],
    'Metric': [
        'Valid R2', 'Ftest R2',
        'Valid R2', 'Ftest R2',
        'Valid R2', 'Ftest R2',
        'Valid R2', 'Ftest R2'
    ],
    'R2': [
        0.85, 0.82,  # Placeholder for LGBM Random
        0.80, 0.78,  # Placeholder for LGBM Chronological
        0.88, 0.85,  # Placeholder for GCN Random
        0.83, 0.81   # Placeholder for GCN Chronological
    ]
}

# Create DataFrame
df_comparison = pd.DataFrame(data)

# Create grouped bar chart
fig = px.bar(
    df_comparison, 
    x='Model', 
    y='R2', 
    color='Metric', 
    barmode='group',
    title='Model Performance Comparison (R²)',
    text='R2',
    opacity=0.8
)

# Update layout for better visualization
fig.update_traces(texttemplate='%{text:.3f}', textposition='outside')
fig.update_layout(
    plot_bgcolor='rgba(229, 236, 246,1)',
    paper_bgcolor='rgba(229, 236, 246,1)',
    yaxis_title="R² Score",
    xaxis_title="Model Group",
    width=900,
    height=600,
    yaxis=dict(range=[0, 1.1]) # Extend y-axis slightly for text labels
)

fig.show()