In [1]:
# Notebook setup: imports, warning filters, and global seeding.
import warnings

import torch as t
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Suppress Lightning's PossibleUserWarning category to keep cell output readable.
warnings.filterwarnings("ignore", category=PossibleUserWarning)

# Seed all global RNGs for reproducibility; workers=True also derives seeds
# for dataloader worker processes.
seed_everything(42, workers=True)

Global seed set to 42


In [2]:
from chemxor.data import OlindaCDataModule, OlindaRDataModule
from chemxor.model import FHEOlindaNet, FHEOlindaNetOne, FHEOlindaNetZero, OlindaNet, OlindaNetOne, OlindaNetZero
from chemxor.utils import prepare_fhe_input, evaluate_fhe_model

In [3]:
# Plaintext OlindaNetZero with a single output unit, and the FHE wrapper constructed from it.
# NOTE(review): presumably FHEOlindaNetZero reuses `model`'s weights — confirm.
model = OlindaNetZero(output=1)
fhe_model = FHEOlindaNetZero(model=model)

In [4]:
# Build the data module for the FHE model and run its setup for the "train" stage.
dm = OlindaRDataModule(model=fhe_model)
dm.setup("train")
# Plaintext training loader, plus an encrypted loader that uses the FHE model's
# encryption context.
train_loader = dm.train_dataloader()
enc_train_loader = dm.enc_train_dataloader(fhe_model.enc_context)

In [5]:
# Take the first plaintext batch and the first encrypted item for inspection.
sample = next(iter(train_loader))
enc_sample = next(iter(enc_train_loader))

In [6]:
# Inspect the encrypted item; its first entry is the FHE model input used below.
# TODO(review): document what the remaining entries represent.
enc_sample

[<tenseal.tensors.ckksvector.CKKSVector at 0x7effdf5dbf10>,
 <tenseal.tensors.ckkstensor.CKKSTensor at 0x7effdf5dbbe0>,
 100]

In [7]:
# Length of the decrypted input vector (presumably the model's input feature width — confirm).
len(enc_sample[0].decrypt())

1600

In [8]:
# Run step 0 of the FHE model on the encrypted input; the second argument
# appears to select which step to run (0, 1, 2 below).
enc_out_0 = fhe_model(enc_sample[0], 0)
enc_out_0.shape

[3200]

In [9]:
# Decrypt the step-0 output and pass it through pre_process[0] to build the step-1 input
# (presumably re-encrypted with enc_context, which requires the secret key between steps — confirm).
input_1 = prepare_fhe_input(enc_out_0.decrypt(), fhe_model.pre_process[0], fhe_model.enc_context)

In [10]:
# Run step 1 of the FHE model on the prepared encrypted input.
enc_out_1 = fhe_model(input_1, 1)
enc_out_1.shape

[256]

In [11]:
# Same round trip as for step 1: decrypt, apply pre_process[1], prepare as step-2 input.
input_2 = prepare_fhe_input(enc_out_1.decrypt(), fhe_model.pre_process[1], fhe_model.enc_context)

In [12]:
# Run the final step (2) of the FHE model.
enc_out_2 = fhe_model(input_2, 2)
enc_out_2.shape

[1]

In [13]:
# Decrypt the final FHE output to obtain the plaintext prediction.
enc_out_2.decrypt()

[-0.9786247181886254]

In [None]:
# Evaluate the FHE model end to end with a single call instead of the manual step-by-step cells above.
# NOTE(review): presumably equivalent to chaining the steps with prepare_fhe_input; the decrypted
# value may differ slightly from enc_out_2.decrypt() because CKKS is approximate — confirm.
enc_out = evaluate_fhe_model(fhe_model, enc_sample[0])
enc_out.decrypt()

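# Plaintext model evaluation

For comparison, run the unencrypted `model` on the first batch from `train_loader`. Note that this is a whole batch (32 rows, per the output below), whereas the FHE pipeline above processed a single encrypted item, so the values are not a one-to-one comparison.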

In [14]:
# Run the plaintext model on the first batch.
# Inference only: disable autograd so no computation graph is recorded for the outputs.
with t.no_grad():
    out_1 = model(sample[0])
out_1.shape

torch.Size([32, 1])

In [15]:
# Display the plaintext predictions, one value per row of the batch.
# NOTE(review): consider out_1[:5] to keep the saved notebook output short.
out_1

tensor([[-0.9755],
        [-0.9709],
        [-0.9707],
        [-0.9727],
        [-0.9759],
        [-0.9660],
        [-0.9790],
        [-0.9772],
        [-0.9723],
        [-0.9763],
        [-0.9691],
        [-0.9792],
        [-0.9784],
        [-0.9745],
        [-0.9795],
        [-0.9713],
        [-0.9779],
        [-0.9722],
        [-0.9724],
        [-0.9781],
        [-0.9720],
        [-0.9714],
        [-0.9774],
        [-0.9733],
        [-0.9744],
        [-0.9721],
        [-0.9694],
        [-0.9771],
        [-0.9832],
        [-0.9703],
        [-0.9754],
        [-0.9703]], grad_fn=<AddmmBackward0>)