# Training OlindaNet Models

In [1]:
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
seed_everything(42, workers=True)

# Suppress warnings
import warnings
from pytorch_lightning.utilities.warnings import PossibleUserWarning
warnings.filterwarnings("ignore", category=PossibleUserWarning)
import torch as t

Global seed set to 42


In [2]:
from chemxor.data import OlindaCDataModule, OlindaRDataModule
from chemxor.model import FHEOlindaNet, FHEOlindaNetOne, FHEOlindaNetZero, OlindaNet, OlindaNetOne, OlindaNetZero
from chemxor.utils import prepare_fhe_input, evaluate_fhe_model, get_package_root_path

## Initialize Model

All the models accept an `output` parameter. Here the model is initialized with `output=1` for a regression task.

In [3]:
model = OlindaNetZero(output=1)

## Prepare Dataset

ChemXor provides generic `datamodules` for regression and classification tasks. Each datamodule accepts a `csv_path` parameter for loading a custom CSV dataset, which is expected to have two columns (target, SMILES). The ChemXor library ships a demo CSV dataset for testing.
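As a minimal sketch of preparing such a file, the snippet below writes a two-column CSV using only the standard library. The header names `target` and `smiles`, the file name, and the example rows are assumptions made for illustration; the text above only states that the file has two columns (target, SMILES), so check the demo dataset for the exact header the datamodules expect.

```python
import csv
import tempfile
from pathlib import Path

# Hypothetical example rows: (target value, SMILES string).
rows = [
    (0.12, "CCO"),       # ethanol
    (0.87, "c1ccccc1"),  # benzene
    (0.45, "CC(=O)O"),   # acetic acid
]

# Write to a temporary directory; the column names are assumptions.
csv_path = Path(tempfile.mkdtemp()) / "custom_dataset.csv"
with csv_path.open("w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["target", "smiles"])
    writer.writerows(rows)

# Read the file back to confirm the two-column layout.
with csv_path.open(newline="") as f:
    loaded = list(csv.reader(f))
```

A path like `csv_path` could then be passed to a datamodule through its `csv_path` parameter, in the same way as the demo dataset.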

In [4]:
dataset_path = get_package_root_path()/"ersilia_output_slim.csv"
dm = OlindaRDataModule(csv_path=dataset_path)
dm.setup("train")
train_loader = dm.train_dataloader()

The dataloaders are standard PyTorch dataloaders. Iterate over them to look at training samples.

In [5]:
sample = next(iter(train_loader))

## Training

OlindaNet models are compatible with PyTorch Lightning. Using the PyTorch Lightning trainer is recommended for training the models.

In [None]:
# Create a callback to checkpoint models
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="olindanet_zero",
    save_top_k=3,
    monitor="VAL_Loss",
)

In [None]:
# Initialize trainer
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    accelerator="auto",
    gradient_clip_val=0.5,  # Use gradient clipping to control exploding gradients
    val_check_interval=0.10,  # Run validation after every 10% of a training epoch
)
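To make the effect of `gradient_clip_val=0.5` concrete, here is a minimal pure-Python sketch of clipping by global L2 norm, which is the default clipping algorithm in PyTorch Lightning. The function name `clip_by_global_norm` and the toy gradients are hypothetical and only illustrate the idea; this is not the trainer's actual implementation.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradient vectors so their combined L2 norm is at most max_norm."""
    # Global norm is computed over all gradient entries together.
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Only shrink gradients; never scale them up.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


# Toy gradients for two parameter groups; their global norm is 5.
grads = [[3.0, 4.0], [0.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=0.5)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

# Gradients already below the threshold are left unchanged.
small, _ = clip_by_global_norm([[0.1]], max_norm=0.5)
```

Because every entry is multiplied by the same factor, the direction of the update is preserved and only its magnitude is capped.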

In [None]:
# pass the model and datamodule to the trainer
trainer.fit(model=model, datamodule=dm)

The trainer logs metrics in a format compatible with TensorBoard.