# Molecular set representation learning - Binding affinity prediction

## Imports

In [None]:
from multiprocessing import cpu_count

import torch

import pandas as pd
import lightning.pytorch as pl

from torch.utils.data import DataLoader

from molsetrep.models import (
    LightningDualSRRegressor,
)
from molsetrep.encoders import LigandProtEncoder

## Prepare the data

### Load the PDBbind data set

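The following is a small, reusable loader for protein–ligand complexes that have been preprocessed with `scripts/preprocess_pdbbind.py`. It reads a metadata CSV and returns, for each of the train, validation and test splits, a list of (ligand, pocket) molecule pairs together with an array of labels.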

In [None]:
import numpy as np
from rdkit import Chem

def complex_loader(meta_path="../data/pdbbind/meta.csv"):
    """Load preprocessed protein-ligand complexes and their labels, grouped by split.

    Reads a metadata CSV for complexes preprocessed with
    ``scripts/preprocess_pdbbind.py``. Every row must provide the columns
    ``split`` (one of ``"train"``, ``"valid"``, ``"test"``), ``mol_path``
    (ligand Mol2 file), ``pocket_path`` (pocket PDB file) and ``label``.

    Args:
        meta_path: Path to the metadata CSV. The default is the location this
            notebook has always used, so ``complex_loader()`` is unchanged.

    Returns:
        A 6-tuple ``(train_X, train_y, valid_X, valid_y, test_X, test_y)``.
        Each ``*_X`` is a list of ``(ligand_mol, pocket_mol)`` RDKit molecule
        pairs in CSV row order. Each ``*_y`` is a NumPy array of the
        corresponding labels.

    Raises:
        KeyError: If a row's ``split`` is not ``"train"``, ``"valid"`` or
            ``"test"``.
    """
    splits = ("train", "valid", "test")
    complexes = {split: [] for split in splits}
    labels = {split: [] for split in splits}

    df = pd.read_csv(meta_path)

    # itertuples avoids building a Series per row (and the dtype upcasting
    # that comes with it), unlike iterrows.
    for row in df.itertuples(index=False):
        # Look the split up first so an unknown split fails before any file
        # is parsed.
        split_complexes = complexes[row.split]
        split_labels = labels[row.split]

        # Parse without RDKit sanitization and keep explicit hydrogens.
        # NOTE(review): RDKit returns None for files it cannot parse; such
        # entries are passed through unchanged to the encoder.
        ligand = Chem.MolFromMol2File(row.mol_path, sanitize=False, removeHs=False)
        pocket = Chem.MolFromPDBFile(row.pocket_path, sanitize=False, removeHs=False)

        split_complexes.append((ligand, pocket))
        split_labels.append(row.label)

    return (
        complexes["train"], np.array(labels["train"]),
        complexes["valid"], np.array(labels["valid"]),
        complexes["test"], np.array(labels["test"]),
    )

Now call the loader to obtain the train, validation and test splits.

In [None]:
# Unpack the (inputs, labels) pair for each split returned by the loader.
(
    train_X, train_y,
    valid_X, valid_y,
    test_X, test_y,
) = complex_loader()

### Encode the data

In [None]:
# Encode every split with one shared encoder instance, in train/valid/test
# order. torch.float32 is passed through to encode(); presumably it sets the
# tensor dtype — confirm against LigandProtEncoder.encode.
enc = LigandProtEncoder()
dataset_train, dataset_valid, dataset_test = (
    enc.encode(features, targets, torch.float32)
    for features, targets in (
        (train_X, train_y),
        (valid_X, valid_y),
        (test_X, test_y),
    )
)

# Get the dimensions of the encoding: the length of the first element of
# each of the two inputs in the first training sample.
# NOTE(review): assumes each sample is indexable as (ligand_set, pocket_set, ...)
# — confirm against the dataset returned by LigandProtEncoder.
dims_dual = [len(dataset_train[0][part][0]) for part in (0, 1)]

### Get torch data loaders

In [None]:
# Shared loader settings: one batch size for all splits and at most 8 worker
# processes.
BATCH_SIZE = 64
NUM_WORKERS = min(cpu_count(), 8)

# Training: reshuffle every epoch and drop the final incomplete batch so each
# optimisation step sees a full batch.
train_loader = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)

# Evaluation: keep a fixed order and keep the final partial batch so every
# validation/test complex is scored exactly once. With drop_last=True, up to
# BATCH_SIZE - 1 samples would be silently excluded from the metrics, and a
# split smaller than BATCH_SIZE would yield no batches at all.
valid_loader = DataLoader(
    dataset_valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)

test_loader = DataLoader(
    dataset_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)

## Train

### Initialise the model

In [None]:
# Build the dual set-representation regressor. The two positional lists are
# passed straight through. NOTE(review): their meaning (one entry per input
# set, e.g. set counts / element sizes) is defined by LightningDualSRRegressor
# — confirm against molsetrep.models.
# dims_dual carries the encoding dimensions measured for the two inputs.
model = LightningDualSRRegressor(
    [64, 64],
    [8, 8],
    dims_dual,
)

### Initialise the trainer and fit

In [None]:
# Train for a fixed number of epochs.
trainer = pl.Trainer(
    max_epochs=150,
)

# Validation is skipped in this example, so only the training loader is
# passed to fit().
trainer.fit(model, train_dataloaders=train_loader)

# Pass the model explicitly so the test run evaluates the weights that were
# just trained. Without it, Lightning reloads the "best" checkpoint from the
# preceding fit() (emitting a warning), which also requires checkpointing to
# be enabled.
trainer.test(model=model, dataloaders=test_loader)