# Molecular set representation learning - Reaction yield prediction

## Imports

In [None]:
from multiprocessing import cpu_count

import torch

import pandas as pd
import lightning.pytorch as pl

from torch.utils.data import DataLoader

from molsetrep.models import (
    LightningDualSRRegressor,
)
from molsetrep.encoders import RXNSetEncoder

## Prepare the data

### Load the reaction yield data

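The following is a small reusable loader for the AZ reaction yield data set stored in `../data/az`. It joins the reactant, solvent and base SMILES with the product SMILES into a reaction SMILES (`reactants.solvent.base>>product`) and returns the train, validation and test splits of one fold.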

In [None]:
import os
import pickle
from typing import Optional


def reaction_loader(
    fold_idx: int = 0,
    data_dir: str = "../data/az",
    random_state: Optional[int] = None,
):
    """Load one train/validation/test fold of the AZ reaction yield data set.

    Reads precomputed split indices from ``train_test_idxs.pickle`` and the
    reactions from ``az_no_rdkit.csv`` in ``data_dir``. Each reaction is
    turned into a reaction SMILES ``reactant.solvent.base>>product``.

    Args:
        fold_idx: Fold to load; the split pickle is indexed with
            ``fold_idx + 1``.
        data_dir: Directory containing both data files.
        random_state: Seed forwarded to ``DataFrame.sample`` when drawing the
            validation subset. ``None`` (the default) leaves it unseeded.

    Returns:
        A 6-tuple ``(train_X, train_y, valid_X, valid_y, test_X, test_y)`` of
        lists holding the reaction SMILES strings and matching ``yield``
        values of each split.

    Note:
        The validation rows are a 10% random sample *of the training rows*
        and are not removed from the training split, so they are not a
        held-out set.
    """
    # The split file is unpickled: only use trusted local data here, since
    # loading a pickle can execute arbitrary code.
    with open(os.path.join(data_dir, "train_test_idxs.pickle"), "rb") as fh:
        splits = pickle.load(fh)

    # Folds are looked up at fold_idx + 1 (presumably the pickle stores them
    # with 1-based keys — confirm against the file that produced it).
    train_ids = splits["train_idx"][fold_idx + 1]
    test_ids = splits["test_idx"][fold_idx + 1]

    df = pd.read_csv(os.path.join(data_dir, "az_no_rdkit.csv"))

    # Build a reaction SMILES: reactant, solvent and base are the left-hand
    # side (dot-separated), the product is the right-hand side.
    df["smiles"] = (
        df.reactant_smiles
        + "."
        + df.solvent_smiles
        + "."
        + df.base_smiles
        + ">>"
        + df.product_smiles
    )

    # The split ids are used positionally (iloc), not as index labels.
    train = df.iloc[train_ids]
    test = df.iloc[test_ids]

    # Validate on a random sample drawn from (and still contained in) train.
    valid = train.sample(frac=0.1, random_state=random_state)

    return (
        train["smiles"].to_list(), train["yield"].to_list(),
        valid["smiles"].to_list(), valid["yield"].to_list(),
        test["smiles"].to_list(), test["yield"].to_list(),
    )

Now call the loader function.

In [None]:
# Load fold 0 (the loader's default fold) and unpack the three splits.
(
    train_X, train_y,
    valid_X, valid_y,
    test_X, test_y,
) = reaction_loader(fold_idx=0)

### Encode the data

In [None]:
enc = RXNSetEncoder()

# Encode every split with the same encoder, in train -> valid -> test order.
split_inputs = (
    (train_X, train_y),
    (valid_X, valid_y),
    (test_X, test_y),
)
dataset_train, dataset_valid, dataset_test = [
    enc.encode(smiles, targets, torch.float32)
    for smiles, targets in split_inputs
]

# Dimensions of the two encodings: the length of the first entry of
# components 0 and 1 of the first training sample (presumably the feature
# size of each set's elements — confirm against RXNSetEncoder).
first_sample = dataset_train[0]
dims_dual = [len(first_sample[0][0]), len(first_sample[1][0])]

### Get torch data loaders

In [None]:
# Use at most 8 worker processes; evaluate cpu_count() only once.
num_workers = min(cpu_count(), 8)

# Training: reshuffle every epoch and drop the last incomplete batch so every
# optimisation step sees a full batch of 64.
train_loader = DataLoader(
    dataset_train,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Evaluation loaders keep a fixed order and every sample. With
# drop_last=True a split smaller than one batch would yield no batches at
# all, and the test metrics would silently ignore up to 63 reactions.
valid_loader = DataLoader(
    dataset_valid,
    batch_size=64,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

test_loader = DataLoader(
    dataset_test,
    batch_size=64,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

## Train

### Initialise the model

In [None]:
# Positional arguments for LightningDualSRRegressor. The local names are a
# reviewer's reading of the constructor (presumably hidden-set counts and
# elements per hidden set, one entry per encoder) — TODO confirm in molsetrep.
n_hidden_sets = [64, 64]
n_elements = [8, 8]
model = LightningDualSRRegressor(n_hidden_sets, n_elements, dims_dual)

### Initialise the trainer and fit

In [None]:
trainer = pl.Trainer(
    max_epochs=150,
)

# Let's ignore the validation set for the example: no val_dataloaders are
# passed to fit.
trainer.fit(model, train_dataloaders=train_loader)

# Pass the fitted model explicitly so the test run evaluates its current
# weights instead of relying on Lightning to locate a checkpoint from fit.
trainer.test(model, dataloaders=test_loader)