# Molecular set representation learning - Molecular property prediction

## Imports

In [1]:
from multiprocessing import cpu_count

import torch

import pandas as pd
import lightning.pytorch as pl

from torch.utils.data import DataLoader

from molsetrep.models import (
    LightningSRRegressor,
    LightningDualSRRegressor,
    LightningSRGNNRegressor,
)
from molsetrep.encoders import SingleSetEncoder, DualSetEncoder, GraphEncoder

## Prepare the data

### Load from file

In [2]:
df_train = pd.read_csv("../data/adme/ADME_MDR1_ER_train.csv")
df_test = pd.read_csv("../data/adme/ADME_MDR1_ER_test.csv")

### Encode the data

#### Single-set (MSR1)

In [3]:
enc_single = SingleSetEncoder()
dataset_train_single = enc_single.encode(df_train["smiles"], df_train["activity"], torch.float32)
dataset_test_single = enc_single.encode(df_test["smiles"], df_test["activity"], torch.float32)

# Get the dimensions of the encoding
dims_single = [
    len(dataset_train_single[0][i][0])
    for i in range(len(dataset_train_single[0]) - 1)
]

#### Dual-set (MSR2)

In [4]:
enc_dual = DualSetEncoder()
dataset_train_dual = enc_dual.encode(df_train["smiles"], df_train["activity"], torch.float32)
dataset_test_dual = enc_dual.encode(df_test["smiles"], df_test["activity"], torch.float32)

# Get the dimensions of the encoding
dims_dual = [
    len(dataset_train_dual[0][i][0])
    for i in range(len(dataset_train_dual[0]) - 1)
]

#### Set-enhanced GNN (SR-GNN)

In [5]:
enc_graph = GraphEncoder()
dataset_train_graph = enc_graph.encode(df_train["smiles"], df_train["activity"], label_dtype=torch.float32)
dataset_test_graph = enc_graph.encode(df_test["smiles"], df_test["activity"], label_dtype=torch.float32)

# Get the dimensions of the encoding
dims_graph = [
    dataset_train_graph.dataset[0].num_node_features,
    dataset_train_graph.dataset[0].num_edge_features,
]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2113/2113 [00:06<00:00, 328.23it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 529/529 [00:01<00:00, 320.94it/s]


### Get torch data loaders

#### Single-set (MSR1)

In [18]:
train_loader_single = DataLoader(
    dataset_train_single,
    batch_size=64,
    shuffle=True,
    num_workers=cpu_count() if cpu_count() < 8 else 8,
    drop_last=True
)

test_loader_single = DataLoader(
    dataset_test_single,
    batch_size=64,
    shuffle=True,
    num_workers=cpu_count() if cpu_count() < 8 else 8,
    drop_last=True
)

#### Dual-set (MSR2)

In [19]:
train_loader_dual = DataLoader(
    dataset_train_dual,
    batch_size=64,
    shuffle=True,
    num_workers=cpu_count() if cpu_count() < 8 else 8,
    drop_last=True
)

test_loader_dual = DataLoader(
    dataset_test_dual,
    batch_size=64,
    shuffle=True,
    num_workers=cpu_count() if cpu_count() < 8 else 8,
    drop_last=True
)

#### Set-enhanced GNN (SR-GNN)

In [20]:
train_loader_graph = dataset_train_graph
test_loader_graph = dataset_test_graph

## Train

### Initialise the model

#### Single-set (MSR1)

In [21]:
model_single = LightningSRRegressor([64], [8], dims_single)

#### Dual-set (MSR2)

In [22]:
model_dual = LightningDualSRRegressor([64, 64], [8, 8], dims_dual)

#### Set-enhanced GNN (SR-GNN)

In [23]:
model_graph = LightningSRGNNRegressor(
    [128, 128], [64, 64],
    n_hidden_channels=[128, 64],
    n_in_channels=dims_graph[0], 
    n_edge_channels=dims_graph[1],
    n_layers=8
)

### Initialise the trainer and fit

#### Single-set (MSR1)

In [25]:
trainer_single = pl.Trainer(
    max_epochs=250,
)

trainer_single.fit(model_single, train_dataloaders=train_loader_single)
trainer_single.test(dataloaders=test_loader_single)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params | Mode 
-------------------------------------------------------------
0  | sr_regressor  | SRRegressor       | 70.9 K | train
1  | train_r2      | R2Score           | 0      | train
2  | train_pearson | PearsonCorrCoef   | 0      | train
3  | train_rmse    | MeanSquaredError  | 0      | train
4  | train_mae     | MeanAbsoluteError | 0      | train
5  | val_r2        | R2Score           | 0      | train
6  | val_pearson   | PearsonCorrCoef   | 0      | train
7  | val_rmse      | MeanSquaredError  | 0      | train
8  | val_mae       | MeanAbsoluteError | 0      | train
9  | test_r2       | R2Score           | 0      | train
10 | test_pearson  | PearsonCorrCoef   | 0      | train
11 | test_rmse     | MeanSquaredError  | 0      | train
12 | test_mae      | MeanAbsoluteError | 0      | t

Epoch 249: 100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 117.71it/s, v_num=4]

`Trainer.fit` stopped: `max_epochs=250` reached.


Epoch 249: 100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 115.56it/s, v_num=4]

Restoring states from the checkpoint path at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_4/checkpoints/epoch=249-step=8250.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_4/checkpoints/epoch=249-step=8250.ckpt





/home/daenu/micromamba/envs/molsetrep-fix/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 218.30it/s]


[{'test/loss': 0.26776665449142456,
  'test/r2': 0.4690614938735962,
  'test/pearson': 0.6880572438240051,
  'test/rmse': 0.5174617171287537,
  'test/mae': 0.39242780208587646}]

#### Dual-set (MSR2)

In [26]:
trainer_dual = pl.Trainer(
    max_epochs=250,
)

trainer_dual.fit(model_dual, train_dataloaders=train_loader_dual)
trainer_dual.test(dataloaders=test_loader_dual)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params | Mode 
-------------------------------------------------------------
0  | sr_regressor  | DualSRRegressor   | 191 K  | train
1  | train_r2      | R2Score           | 0      | train
2  | train_pearson | PearsonCorrCoef   | 0      | train
3  | train_rmse    | MeanSquaredError  | 0      | train
4  | train_mae     | MeanAbsoluteError | 0      | train
5  | val_r2        | R2Score           | 0      | train
6  | val_pearson   | PearsonCorrCoef   | 0      | train
7  | val_rmse      | MeanSquaredError  | 0      | train
8  | val_mae       | MeanAbsoluteError | 0      | train
9  | test_r2       | R2Score           | 0      | train
10 | test_pearson  | PearsonCorrCoef   | 0      | train
11 | test_rmse     | MeanSquaredError  | 0      | train
12 | test_mae      | MeanAbsoluteError | 0      | t

Epoch 249: 100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 118.57it/s, v_num=5]

`Trainer.fit` stopped: `max_epochs=250` reached.


Epoch 249: 100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 115.55it/s, v_num=5]

Restoring states from the checkpoint path at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_5/checkpoints/epoch=249-step=8250.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_5/checkpoints/epoch=249-step=8250.ckpt



Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 191.97it/s]


[{'test/loss': 0.27378132939338684,
  'test/r2': 0.4561035633087158,
  'test/pearson': 0.6984140872955322,
  'test/rmse': 0.5232411623001099,
  'test/mae': 0.38792866468429565}]

#### Set-enhanced GNN (SR-GNN)

In [27]:
trainer_graph = pl.Trainer(
    max_epochs=900,
)

trainer_graph.fit(model_graph, train_dataloaders=train_loader_graph)
trainer_graph.test(dataloaders=test_loader_graph)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params | Mode 
-------------------------------------------------------------
0  | gnn_regressor | SRGNNRegressor    | 1.5 M  | train
1  | train_r2      | R2Score           | 0      | train
2  | train_pearson | PearsonCorrCoef   | 0      | train
3  | train_rmse    | MeanSquaredError  | 0      | train
4  | train_mae     | MeanAbsoluteError | 0      | train
5  | val_r2        | R2Score           | 0      | train
6  | val_pearson   | PearsonCorrCoef   | 0      | train
7  | val_rmse      | MeanSquaredError  | 0      | train
8  | val_mae       | MeanAbsoluteError | 0      | train
9  | test_r2       | R2Score           | 0      | train
10 | test_pearson  | PearsonCorrCoef   | 0      | train
11 | test_rmse     | MeanSquaredError  | 0      | train
12 | test_mae      | MeanAbsoluteError | 0      | t

Epoch 899: 100%|████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 72.84it/s, v_num=6]

`Trainer.fit` stopped: `max_epochs=900` reached.


Epoch 899: 100%|████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 68.08it/s, v_num=6]


Restoring states from the checkpoint path at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_6/checkpoints/epoch=899-step=29700.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/daenu/Code/fix/molsetrep/example/lightning_logs/version_6/checkpoints/epoch=899-step=29700.ckpt


Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 140.76it/s]


[{'test/loss': 0.18581531941890717,
  'test/r2': 0.6286168098449707,
  'test/pearson': 0.7939651012420654,
  'test/rmse': 0.4310629963874817,
  'test/mae': 0.3162570595741272}]