# Molecular set representation learning

## Imports

In [20]:
from multiprocessing import cpu_count

import torch

import pandas as pd
import lightning.pytorch as pl

from torch.utils.data import DataLoader

from molsetrep.models import LightningDualSRRegressor
from molsetrep.encoders import DualSetEncoder

## Prepare the data

### Load from file

In [2]:
# Locations of the pre-split MDR1 ER train and test CSV files.
train_csv = "../data/adme/ADME_MDR1_ER_train.csv"
test_csv = "../data/adme/ADME_MDR1_ER_test.csv"

# Read both splits with pandas' default CSV settings (train first, then test).
df_train, df_test = (pd.read_csv(csv_path) for csv_path in (train_csv, test_csv))

### Encode the data

In [7]:
enc = DualSetEncoder()

# Encode each split from its "smiles" and "activity" columns, requesting
# float32 from the encoder. The training split is encoded first.
dataset_train, dataset_test = (
    enc.encode(frame["smiles"], frame["activity"], torch.float32)
    for frame in (df_train, df_test)
)

# Get the dimensions of the encoding: for each of the first two components of
# the first training sample, the length of that component's first entry.
# NOTE(review): presumably components 0 and 1 are the two sets produced by the
# dual encoder -- confirm against molsetrep.
d = [len(dataset_train[0][set_idx][0]) for set_idx in (0, 1)]

### Get torch data loaders

In [14]:
# Use one worker per CPU core, capped at 8.
num_workers = min(cpu_count(), 8)

# Training batches are reshuffled every epoch; the trailing partial batch is
# dropped so every optimisation step sees a full batch of 64.
train_loader = DataLoader(
    dataset_train,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Evaluation must cover every test sample in a fixed order: shuffling combined
# with drop_last would discard a different random subset of the test set on
# each run, making the reported metrics incomplete and non-deterministic.
# NOTE(review): the final batch may now be smaller than 64 -- confirm the
# model's test step copes with a short (possibly single-sample) batch.
test_loader = DataLoader(
    dataset_test,
    batch_size=64,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

## Train

### Initialise the model

In [12]:
# Fix the global RNG state before any stochastic step so that weight
# initialisation and the subsequent batch shuffling start from a known seed
# when the notebook is re-run from a fresh kernel.
pl.seed_everything(42, workers=True)

# Build the regressor from two 2-element hyperparameter lists and the
# encoding dimensions `d`.
# NOTE(review): presumably [64, 64] and [8, 8] are per-set sizes for the two
# encoded sets -- confirm their meaning against the molsetrep signature.
model = LightningDualSRRegressor([64, 64], [8, 8], d)

### Initialise the trainer and fit

In [21]:
# Trainer settings: only the epoch budget is overridden, everything else is
# left at the Lightning defaults.
trainer_options = {"max_epochs": 250}
trainer = pl.Trainer(**trainer_options)

# Fit on the training loader only (no validation loader is supplied).
trainer.fit(model, train_dataloaders=train_loader)

# Evaluate on the test loader. No model is passed here, so the trainer itself
# selects which weights from the preceding fit to evaluate.
trainer.test(dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params
-----------------------------------------------------
0  | sr_regressor  | DualSRRegressor   | 191 K 
1  | train_r2      | R2Score           | 0     
2  | train_pearson | PearsonCorrCoef   | 0     
3  | train_rmse    | MeanSquaredError  | 0     
4  | train_mae     | MeanAbsoluteError | 0     
5  | val_r2        | R2Score           | 0     
6  | val_pearson   | PearsonCorrCoef   | 0     
7  | val_rmse      | MeanSquaredError  | 0     
8  | val_mae       | MeanAbsoluteError | 0     
9  | test_r2       | R2Score           | 0     
10 | test_pearson  | PearsonCorrCoef   | 0     
11 | test_rmse     | MeanSquaredError  | 0     
12 | test_mae      | MeanAbsoluteError | 0     
-----------------------------------------------------
191 K     Trainable param