# Import packages

In [None]:
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl

from chemprop import data, featurizers
from chemprop.models import multi

# Change model input here

In [None]:
# Path to the checkpoint file of the trained model that will be loaded for prediction.
# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`.
checkpoint_path = '../tests/data/example_model_v2_regression_multi.ckpt'

## Load model

In [None]:
# Restore the multicomponent MPNN from the checkpoint file at `checkpoint_path`.
mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)
# Bare expression: in a notebook this displays the loaded model.
mcmpnn

# Change prediction input here

In [None]:
# Path to your .csv file containing the SMILES strings to make predictions for.
test_path = '../tests/data/regression/mol+mol.csv'
# Names of the columns containing SMILES strings, one column per component.
smiles_columns = ['smiles', 'solvent']

## Load test SMILES

In [None]:
# Read the prediction input CSV into a DataFrame.
df_test = pd.read_csv(test_path)
# Bare expression: in a notebook this displays the loaded table.
df_test

## Get SMILES

In [None]:
# Pull the SMILES columns out of the table as a 2-D array:
# one row per input row, one column per entry in `smiles_columns`.
smiss = df_test.loc[:, smiles_columns].to_numpy()
# Preview the first five rows.
smiss[:5]

## Get molecule datapoints

In [None]:
# Number of components equals the number of SMILES columns.
# NOTE: the spelling of `n_componenets` is kept as-is because other cells may reference this name.
n_componenets = len(smiles_columns)

# Build one list of datapoints per component; each list holds one datapoint
# per row, created from that component's column of `smiss`.
test_datapointss = []
for component_idx in range(n_componenets):
    component_smis = smiss[:, component_idx]
    test_datapointss.append(
        [data.MoleculeDatapoint.from_smi(smi) for smi in component_smis]
    )

## Get molecule datasets

In [None]:
# A single featurizer instance is shared by every component dataset.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Wrap each component's list of datapoints in its own molecule dataset,
# preserving the component order of `test_datapointss`.
test_dsets = []
for component_datapoints in test_datapointss:
    test_dsets.append(data.MoleculeDataset(component_datapoints, featurizer))

# Get multicomponent dataset and data loader

In [None]:
# Combine the per-component datasets into a single multicomponent dataset.
test_mcdset = data.MulticomponentDataset(test_dsets)
# shuffle=False: the predictions are later assigned straight onto `df_test`,
# so they need to come back in the same order as the rows that were read.
test_loader = data.MolGraphDataLoader(test_mcdset, shuffle=False)

# Set up trainer and make predictions

In [None]:
# Create the trainer and run prediction inside torch's inference mode.
with torch.inference_mode():
    # Trainer options gathered in one place: single device, accelerator
    # picked automatically, progress bar on.
    trainer_options = {
        "logger": None,
        "enable_progress_bar": True,
        "accelerator": "auto",
        "devices": 1,
    }
    trainer = pl.Trainer(**trainer_options)
    # The result is a sequence of per-batch outputs, which the next cell concatenates.
    test_preds = trainer.predict(mcmpnn, test_loader)

In [None]:
# Join the per-batch prediction outputs along the first axis into one array.
test_preds = np.concatenate(test_preds, axis=0)
# Store the predictions next to the inputs in a new `pred` column.
# NOTE(review): assumes the model yields one prediction value per row — confirm for multi-task models.
df_test['pred'] = test_preds
# Bare expression: in a notebook this displays the table with predictions.
df_test