In [None]:
import sys
import os
sys.path.append("../src/olinda/utils/")
from precalc_descriptors import DescriptorCalculator

# Paths handed to DescriptorCalculator: the reference-library CSV and the
# descriptor folder (presumably where results are written -- confirm in precalc_descriptors).
reference_library_csv = "../olinda_reference_library_1k.csv"
descriptors_dir = "../precalculated_descriptors"

# Create the descriptor folder up front; exist_ok makes this cell safe to re-run.
os.makedirs(descriptors_dir, exist_ok=True)

dc = DescriptorCalculator(reference_library_csv, descriptors_dir)
dc.calculate()

## A model to distill

You need a trained model for the distillation process. Here we create a simple (untrained) PyTorch model for demonstration purposes. TensorFlow models are also supported.

In [None]:
from typing import Any

import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F


class DemoModel(pl.LightningModule):
    """Demo network: three stacked linear layers (1024 -> 512 -> 256 -> 1) with ReLU in between."""

    def __init__(self: "DemoModel") -> None:
        """Create the three fully connected layers."""
        super().__init__()
        # Layers are declared in input-to-output order.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self: "DemoModel", x: Any) -> Any:
        """Run the input through the network.

        Args:
            x (Any): model input; its last dimension must be 1024 to match ``fc1``

        Returns:
            Any: output of the final linear layer (last dimension 1, no activation applied)
        """
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
###############

# Instantiate the demo network; its weights are freshly initialized (no training is performed here).
model = DemoModel()

## Distillation

In [None]:
from olinda.distillation import distill
from olinda.featurizer import MorganFeaturizer

Quickly test the distillation pipeline for your model with a small reference SMILES dataset.

In [None]:
# Location of the trained model to distill.
# NOTE(review): the default below looks like a ZairaChem model folder on the original
# author's machine -- set the OLINDA_DEMO_MODEL_PATH environment variable to point at
# your own model instead of editing this cell.
# Alternative: distill the in-memory demo model on a small subset with
#   student_model = distill(model, num_data=100)
zairachem_model_path = os.environ.get(
    "OLINDA_DEMO_MODEL_PATH",
    "/home/jason/zairachem_models/h3d_plasmodium_NF54_June2023",
)
student_model = distill(zairachem_model_path)  # add num_data=100 for an even smaller test set

In [None]:
# Featurize a single test molecule (chloroquine) and feed it to the distilled model.
chloroquine_smiles = "CCN(CC)CCCC(C)NC1=C2C=CC(=CC2=NC=C1)Cl"
featurizer = MorganFeaturizer()
x = featurizer.featurize([chloroquine_smiles])
student_model(x)

In [None]:
save_path = "path/to/distilled/model.onnx"  # placeholder destination; adjust as needed

# Make sure the destination folder exists so that saving cannot fail on a missing directory.
save_dir = os.path.dirname(save_path)
if save_dir:  # an empty dirname means the current working directory
    os.makedirs(save_dir, exist_ok=True)

student_model.save(save_path)

In [None]:
import onnx
import onnxruntime as rt
# Reload the exported graph from disk and collect the names of all of its outputs.
onnx_model = onnx.load(save_path)
output_names = [graph_output.name for graph_output in onnx_model.graph.output]

# Build an inference session from the serialized model and run it on the features
# computed earlier; the feed key "input" must match the exported graph's input name.
onnx_rt = rt.InferenceSession(onnx_model.SerializeToString())
onnx_rt.run(output_names, {"input": x})