# Import packages

In [None]:
import pandas as pd

from lightning import pytorch as pl
from sklearn.model_selection import train_test_split

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [None]:
# Path to the input data .csv file.
input_path = "../tests/data/regression/mol/mol.csv"
# Number of workers for the dataloader; 0 means the main process loads the data.
num_workers = 0
# Name of the column containing the SMILES strings.
smiles_column = "smiles"
# Names of the columns containing the targets.
target_columns = ["lipo"]

## Load data

In [None]:
# Read the input CSV into a DataFrame; the bare name on the last line makes
# the notebook display it.
df_input = pd.read_csv(filepath_or_buffer=input_path)
df_input

## Get SMILES and targets

In [None]:
# Pull the SMILES column and the target column(s) out of the DataFrame.
smis = df_input[smiles_column].values
ys = df_input[target_columns].values

In [None]:
smis[:5]  # show the first 5 SMILES strings

In [None]:
ys[:5]  # show the first 5 rows of targets

## Get molecule datapoints

In [None]:
# Build one MoleculeDatapoint per (SMILES, targets) pair, in input order.
all_data = list(map(data.MoleculeDatapoint.from_smi, smis, ys))

## Perform data splitting for training, validation, and testing

In [None]:
# Fixed seed so the train/validation/test partition is identical on every run;
# change the value to draw a different split.
split_seed = 0

# Hold out 10% of the datapoints, then halve that hold-out set into the
# validation and test sets.
train_data, val_test_data = train_test_split(
    all_data, test_size=0.1, random_state=split_seed
)
val_data, test_data = train_test_split(
    val_test_data, test_size=0.5, random_state=split_seed
)

## Get MoleculeDataset

In [None]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Wrap each split in a MoleculeDataset; all three share the same featurizer.
train_dset = data.MoleculeDataset(train_data, featurizer)
val_dset = data.MoleculeDataset(val_data, featurizer)
test_dset = data.MoleculeDataset(test_data, featurizer)

# normalize_targets() on the training set returns a scaler, which is then
# handed to the validation and test sets.
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)
test_dset.normalize_targets(scaler)


## Get DataLoader

In [None]:
# The training loader keeps MolGraphDataLoader's default shuffle setting.
train_loader = data.MolGraphDataLoader(train_dset, num_workers=num_workers)

# Validation and test loaders are built identically, with shuffling disabled.
val_loader, test_loader = (
    data.MolGraphDataLoader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `Message passing` module uses message passing over the molecular graph to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [None]:
# Message-passing module; nn.AtomMessagePassing() is the alternative option.
mp = nn.BondMessagePassing()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`

In [None]:
# Print the aggregation registry.
print(nn.agg.AggregationRegistry)

In [None]:
# Aggregation module that is passed to the MPNN.
agg = nn.MeanAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` (will be available in a future version)

In [None]:
# Print the predictor registry.
print(nn.PredictorRegistry)

In [None]:
# loc and scale are the mean and scale of the training targets, taken from
# the scaler returned by normalizing the training set.
ffn = nn.RegressionFFN(loc=scaler.mean_, scale=scaler.scale_)

## Batch Norm
`Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set `batch_norm` to choose whether to use batch norm:

In [None]:
# Flag passed to the MPNN to switch batch norm on or off.
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [None]:
# Print the metric registry.
print(nn.metrics.MetricRegistry)

In [None]:
# Only the first metric is used for training and early stopping.
metric_list = [nn.metrics.RMSEMetric(), nn.metrics.MAEMetric()]

## Construct MPNN

In [None]:
# Assemble the model from the message passing, aggregation, FFN, batch norm
# flag, and metric list defined earlier in the notebook.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)

# Bare name as the last expression so the notebook displays the model.
mpnn

# Set up trainer

In [None]:
trainer = pl.Trainer(
    # Number of epochs to train for.
    max_epochs=20,
    accelerator="auto",
    devices=1,
    logger=False,
    enable_progress_bar=True,
    # Use `True` if you want to save model checkpoints.
    # The checkpoints will be saved in the `checkpoints` folder.
    enable_checkpointing=True,
)

# Start training

In [None]:
# Train on the training loader, validating on the validation loader.
trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Test results

In [None]:
# Evaluate the model on the test loader and keep what trainer.test returns.
results = trainer.test(model=mpnn, dataloaders=test_loader)
