# Import packages

In [None]:
import pandas as pd
from lightning import pytorch as pl
from sklearn.model_selection import train_test_split

from chemprop import data
from chemprop import featurizers
from chemprop import models
from chemprop import nn

# Change data inputs here

In [None]:
# ---- Data inputs (edit these) ----
# CSV holding a reaction-SMILES column plus one column per regression target.
input_path = '../tests/data/regression/rxn.csv'
smiles_column = 'smiles'
target_columns = ['ea']

# DataLoader worker processes; 0 means data is loaded in the main process.
num_workers = 0

## Load data

In [None]:
# Read the input CSV into a DataFrame and display it (last expression in the cell).
df_input = pd.read_csv(input_path)
df_input

## Load smiles and targets

In [None]:
# Reaction SMILES as a 1-D array; targets as a 2-D array with one column per
# entry in `target_columns`.
smis = df_input[smiles_column].to_numpy()
ys = df_input[target_columns].to_numpy()

# Peek at the first few examples.
smis[:5], ys[:5]

## Get datapoints

In [None]:
# Pair each reaction SMILES with its target row to build one datapoint per row.
all_data = [
    data.ReactionDatapoint.from_smi(rxn_smi, target)
    for rxn_smi, target in zip(smis, ys)
]

## Perform data splitting for training, validation, and testing

In [None]:
# Hold out 10% of the datapoints, then split that hold-out in half, which gives
# a 90/5/5 train/validation/test split. A fixed `random_state` makes the split
# reproducible between runs; change `split_seed` to draw a different split.
split_seed = 0
train_data, val_test_data = train_test_split(all_data, test_size=0.1, random_state=split_seed)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=split_seed)

# Defining the featurizer

Reactions can be featurized using the `CondensedGraphOfReactionFeaturizer` (also available as `CGRFeaturizer`).


Use the `mode_` keyword argument to set how a reaction is featurized into a `MolGraph`.

The available options can be listed with `featurizers.RxnMode.keys()`:

In [None]:
# Print every RxnMode key, one per line; these are the accepted values for the
# featurizer's `mode_` argument.
for mode_name in featurizers.RxnMode.keys():
    print(mode_name)

In [None]:
# CGR featurizer; `mode_` must be one of the RxnMode keys listed by the cell above.
featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_="PROD_DIFF")

## Get ReactionDatasets

In [None]:
# Normalize the training targets and keep the returned scaler; the held-out
# sets are normalized with that same scaler rather than their own statistics.
train_dset = data.ReactionDataset(train_data, featurizer)
scaler = train_dset.normalize_targets()

# NOTE(review): the test targets are scaled as well, so test metrics are
# presumably reported in normalized units unless the model unscales its
# outputs — confirm against the chemprop version in use.
val_dset = data.ReactionDataset(val_data, featurizer)
test_dset = data.ReactionDataset(test_data, featurizer)
for heldout_dset in (val_dset, test_dset):
    heldout_dset.normalize_targets(scaler)

## Get dataloaders

In [None]:
# All loaders share the same worker count. The validation and test loaders keep
# a fixed order (shuffle=False); the training loader leaves `shuffle` at the
# loader's default.
loader_kwargs = {"num_workers": num_workers}
train_loader = data.MolGraphDataLoader(train_dset, **loader_kwargs)
val_loader = data.MolGraphDataLoader(val_dset, shuffle=False, **loader_kwargs)
test_loader = data.MolGraphDataLoader(test_dset, shuffle=False, **loader_kwargs)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message passing

Message passing blocks must be given the shape of the featurizer's outputs.

Options are `nn.BondMessagePassing` and `nn.AtomMessagePassing`; each is constructed with the featurizer dimensions, e.g. `mp = nn.BondMessagePassing(*fdims)`.

In [None]:
# The message-passing block is sized from the featurizer's output dimensions.
fdims = featurizer.shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
# nn.AtomMessagePassing is the alternative option; it is also given *fdims.
mp = nn.BondMessagePassing(*fdims)

## Aggregation

In [None]:
# Show the registered aggregation options.
print(nn.agg.AggregationRegistry)

In [None]:
# Mean aggregation; presumably averages the per-node embeddings into one
# vector per reaction graph — confirm in the chemprop docs.
agg = nn.MeanAggregation()

## Feed-Forward Network (FFN)

In [None]:
# Show the registered predictor (FFN head) options.
print(nn.PredictorRegistry)

In [None]:
# Regression FFN head, constructed with its default settings.
ffn = nn.RegressionFFN()

## Batch norm

In [None]:
# Batch-norm flag, passed to `models.MPNN` when the model is constructed.
batch_norm = True

## Metrics

In [None]:
# Show the registered metric options.
print(nn.metrics.MetricRegistry)

In [None]:
# Evaluation metrics. Only the first metric is used for training and early
# stopping.
metric_list = [
    nn.metrics.RMSEMetric(),
    nn.metrics.MAEMetric(),
]

## Construct MPNN

In [None]:
# Assemble the model from the message-passing block, aggregation, FFN head,
# batch-norm flag and metric list (passed positionally, in that order).
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
mpnn  # display the model in the notebook

# Training and testing

## Set up trainer

In [None]:
# Lightning trainer settings: no logger, checkpointing on, progress bar on,
# one automatically selected device, and a fixed epoch budget.
trainer_config = dict(
    logger=False,
    # Use `True` to save model checkpoints; they are saved in the `checkpoints` folder.
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,  # number of epochs to train for
)
trainer = pl.Trainer(**trainer_config)

## Start training

In [None]:
# Train on the training loader, evaluating on the validation loader.
trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

## Test results

In [None]:
# Evaluate the model passed in (its current in-memory weights) on the test loader.
results = trainer.test(mpnn, dataloaders=test_loader)