# FHE OlindaNet Models

In [1]:
# Imports for reproducibility and warning configuration
import warnings

import torch as t  # NOTE(review): `t` does not appear to be used below — confirm before removing
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Seed RNGs globally (workers=True also seeds dataloader workers) for reproducible runs
SEED = 42
seed_everything(SEED, workers=True)

# Suppress Lightning's PossibleUserWarning noise
warnings.filterwarnings("ignore", category=PossibleUserWarning)

Global seed set to 42


In [2]:
from chemxor.data import OlindaCDataModule, OlindaRDataModule
from chemxor.model import FHEOlindaNet, FHEOlindaNetOne, FHEOlindaNetZero, OlindaNet, OlindaNetOne, OlindaNetZero
from chemxor.utils import prepare_fhe_input, evaluate_fhe_model

OlindaNet models can be wrapped with their respective FHE wrappers so that they can compute on FHE (encrypted) inputs.

In [3]:
# Initialize the plain (non-FHE) model.
# These models can also load their state from a checkpoint.
# NOTE(review): output=1 presumably means a single output unit — confirm.
model = OlindaNetZero(output=1)

# Wrap the model with its matching FHE wrapper so it can be evaluated on encrypted inputs
fhe_model = FHEOlindaNetZero(model=model)

Chemxor's DataModules can generate encrypted inputs for testing the models.

In [4]:
# Initialize the datamodule and pass it the FHE-wrapped model
# NOTE(review): the "R" in OlindaRDataModule presumably denotes the regression variant — confirm
dm = OlindaRDataModule(model=fhe_model)
# Prepare the training stage (presumably required before requesting the dataloader below)
dm.setup("train")

# Create a dataloader that yields encrypted samples, built from the
# FHE model's encryption context
enc_train_loader = dm.enc_train_dataloader(fhe_model.enc_context)

In [5]:
# Encrypted samples can be generated with the dataloader; take the first one
enc_sample = next(iter(enc_train_loader))

In [6]:
# The encrypted sample is a list of three items: the encrypted input
# (a TenSEAL CKKSVector), the encrypted target (a TenSEAL CKKSTensor), and an
# integer parameter used for image-to-column (im2col) convolutions.
# The integer parameter can be safely ignored.
enc_sample

[<tenseal.tensors.ckksvector.CKKSVector at 0x7effdf5dbf10>,
 <tenseal.tensors.ckkstensor.CKKSTensor at 0x7effdf5dbbe0>,
 100]

In [None]:
# Evaluate the FHE model on the encrypted input (first item of the sample)
# with the utility function
enc_out = evaluate_fhe_model(fhe_model, enc_sample[0])
# Decrypt the encrypted output to inspect the prediction
enc_out.decrypt()

## Serving FHE models

FHE models are partitioned to overcome some of the limitations of current FHE schemes. So, to evaluate a single FHE input, multiple round trips happen between the server and the client. Hence the names `PartitionNetServer` and `PartitionNetClient`.

In [None]:
from chemxor.service import PartitionNetServer, PartitionNetClient

In [None]:
# Wrap the FHE model in a server that handles the client/server round trips
model_server = PartitionNetServer(part_net=fhe_model)

# Starting the server from inside the notebook would presumably block the
# kernel, so it is disabled by default. Set RUN_SERVER = True, or run this
# code from a separate script, to start the server at http://localhost:5000.
RUN_SERVER = False
if RUN_SERVER:
    model_server.run()

During initialization, the client calls the model server to retrieve the encryption and model parameters.

At this stage, the client also generates the encryption keys.

In [None]:
# FHE endpoint of the PartitionNetServer started at http://localhost:5000.
# NOTE: the client contacts the server during initialization, so the server
# must already be running; otherwise this cell will presumably fail on a
# fresh "Restart & Run All".
FHE_SERVER_URL = "http://localhost:5000/v1/fhe"
model_client = PartitionNetClient(url=FHE_SERVER_URL)

Query the model with a convenient interface. It will take care of encrypting the input, the round trips between server and client, and preparing the input for each step.

In [None]:
# Example query molecule, written as a SMILES string
query_smiles = "COC(=O)C1=CC=CC2=C1C(=O)C1=CC([N+](=O)[O-])=CC=C21"
model_client.query(x=query_smiles)