# Training Regression - Multicomponent

In [None]:
import pandas as pd
from lightning import pytorch as pl

from chemprop import data
from chemprop import featurizers
from chemprop import models
from chemprop import nn
from chemprop.nn import metrics
from chemprop.models import multi


# Load data

## Change your data inputs here

In [None]:
# Path to your data .csv file containing SMILES strings and target values.
input_path = "../tests/data/regression/mol+mol.csv"
# Names of the columns containing SMILES strings (one column per component).
smiles_columns = ["smiles", "solvent"]
# Names of the columns containing targets.
target_columns = ["peakwavs_max"]

## Read data

In [None]:
# Load the .csv file at `input_path` into a DataFrame.
df_input = pd.read_csv(input_path)
# Bare expression: the notebook shows the table as the cell output.
df_input

## Get SMILES and targets

In [None]:
# SMILES array: one row per datapoint, one column per entry of `smiles_columns`.
smiss = df_input[smiles_columns].to_numpy()
# Target array: one row per datapoint, one column per entry of `target_columns`.
ys = df_input[target_columns].to_numpy()

In [None]:
# Take a look at the first 5 rows of SMILES strings and the first 5 rows of
# targets; the resulting tuple is shown as the cell output.
smiss[:5], ys[:5]

## Make molecule datapoints
Create a list of lists containing the molecule datapoints for each component. The targets are stored in the datapoints of the 0th component.

In [None]:
# Component 0: built from the first SMILES column and carries the targets.
all_data = [
    [data.MoleculeDatapoint.from_smi(row[0], target) for row, target in zip(smiss, ys)]
]
# Remaining components: built from their SMILES column only, without targets.
all_data.extend(
    [data.MoleculeDatapoint.from_smi(row[col]) for row in smiss]
    for col in range(1, len(smiles_columns))
)


# Split data

## Change your data splitting inputs here

In [None]:
# Index of the key molecule (component) used for splitting.
split_key_molecule_index = 0
# Type of split.
split = "random"
# Sizes of the train, validation, and test sets.
sizes = (0.8, 0.1, 0.1)

In [None]:
# Available split types: the keys of `data.SplitType`, shown as the cell
# output. Any of them can be assigned to `split`.
list(data.SplitType.keys())

## Split data based on key molecule

In [None]:
# Split every component consistently, using the key molecule to decide the
# train / validation / test assignment.
train_data, val_data, test_data = data.split_component(
    all_data,
    key_index=split_key_molecule_index,
    split=split,
    sizes=sizes,
)


# Get MoleculeDataset for each component

In [None]:
# A single featurizer instance is shared by every component and every split.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Build, in order, the train / validation / test lists, each holding one
# MoleculeDataset per SMILES column.
train_datasets, val_datasets, test_datasets = (
    [data.MoleculeDataset(partition[k], featurizer) for k in range(len(smiles_columns))]
    for partition in (train_data, val_data, test_data)
)

# Construct multicomponent dataset and scale the targets

In [None]:
# Normalize the training targets and keep the returned scaler so it can be
# reused for the other splits and for the regression FFN.
train_mcdset = data.MulticomponentDataset(train_datasets)
scaler = train_mcdset.normalize_targets()
# Reuse the scaler obtained from the training set for the validation targets.
val_mcdset = data.MulticomponentDataset(val_datasets)
val_mcdset.normalize_targets(scaler)
# Reuse the scaler obtained from the training set for the test targets.
# NOTE(review): the RegressionFFN is given `loc`/`scale` from this scaler;
# confirm whether test metrics are meant to be computed against scaled or
# unscaled targets for the chemprop version in use.
test_mcdset = data.MulticomponentDataset(test_datasets)
test_mcdset.normalize_targets(scaler)


# Construct data loader

In [None]:
# Training loader keeps the loader's default shuffling behaviour.
train_loader = data.MolGraphDataLoader(train_mcdset)
# Validation and test loaders are created with shuffling turned off.
val_loader, test_loader = (
    data.MolGraphDataLoader(held_out, shuffle=False)
    for held_out in (val_mcdset, test_mcdset)
)

# Construct multicomponent MPNN

## MulticomponentMessagePassing
- `blocks`: a list of message passing blocks, one for each component
- `n_components`: number of components

In [None]:
# One independent BondMessagePassing block is created per SMILES column.
mcmp = nn.MulticomponentMessagePassing(
    n_components=len(smiles_columns),
    blocks=[nn.BondMessagePassing() for _ in smiles_columns],
)

## Aggregation

In [None]:
# Aggregation module that is handed to the MulticomponentMPNN below.
agg = nn.MeanAggregation()

## RegressionFFN

In [None]:
ffn = nn.RegressionFFN(
    # The FFN input width must match the message passing output width.
    input_dim=mcmp.output_dim,
    # Pass in the mean of the training targets.
    loc=scaler.mean_,
    # Pass in the scale of the training targets.
    scale=scaler.scale_,
)

## Metrics

In [None]:
# Only the first metric is used for training and early stopping.
metric_list = [
    metrics.RMSEMetric(),
    metrics.MAEMetric(),
]

## MulticomponentMPNN

In [None]:
# Assemble the full model from the message passing module, the aggregation
# module, and the regression FFN built above, together with the metrics.
mcmpnn = multi.MulticomponentMPNN(
    mcmp,
    agg,
    ffn,
    metrics=metric_list,
)

# Bare expression: the notebook shows the model's repr as the cell output.
mcmpnn

# Set up trainer

In [None]:
trainer = pl.Trainer(
    # Number of epochs to train for.
    max_epochs=20,
    # Hardware selection.
    accelerator="auto",
    devices=1,
    # Output options.
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
)

# Start training

In [None]:
# Train the model on the training loader, using the validation loader for
# validation.
trainer.fit(mcmpnn, train_loader, val_loader)

# Test results

In [None]:
# Evaluate the trained model on the test loader and keep what the trainer
# returns in `results`.
results = trainer.test(mcmpnn, test_loader)
