In [1]:
# Notebook setup: all imports first (stdlib -> third-party), then global configuration.
import warnings

import torch as t
from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Seed the global RNGs so runs are reproducible; workers=True additionally makes
# dataloader-worker seeding deterministic.
SEED = 42
seed_everything(SEED, workers=True)

# Suppress Lightning's PossibleUserWarning messages to keep cell output readable.
warnings.filterwarnings("ignore", category=PossibleUserWarning)

Global seed set to 42


In [2]:
from chemxor.data import OlindaCDataModule, OlindaRDataModule
from chemxor.model import FHEOlindaNet, FHEOlindaNetOne, FHEOlindaNetZero, OlindaNet, OlindaNetOne, OlindaNetZero
from chemxor.utils import prepare_fhe_input, evaluate_fhe_model

In [3]:
# Plaintext reference network with a single output unit
# (model(sample[0]) further down yields shape [batch, 1]).
model = OlindaNet(output=1)
# FHE counterpart constructed from the plaintext model. It exposes the encryption
# context (`enc_context`) and per-stage `pre_process` steps used in the cells below.
# NOTE(review): presumably it mirrors `model`'s weights — confirm in chemxor.model.
fhe_model = FHEOlindaNet(model=model)

In [4]:
# Data module parameterised by the FHE model.
dm = OlindaRDataModule(model=fhe_model)
# NOTE(review): Lightning's standard stage names are "fit"/"validate"/"test"/"predict";
# confirm that OlindaRDataModule.setup actually handles "train".
dm.setup("train")
# Plaintext batches, used for the reference (non-FHE) evaluation at the end.
train_loader = dm.train_dataloader()
# Encrypted items, built with the FHE model's encryption context.
enc_train_loader = dm.enc_train_dataloader(fhe_model.enc_context)

In [5]:
# Take the first element from each loader: a plaintext batch and an encrypted item.
# NOTE(review): the two loaders are iterated independently, so `enc_sample` is not
# guaranteed to encode any row of `sample` (e.g. if shuffling is enabled) — confirm
# before comparing FHE and plaintext outputs value-for-value.
sample = next(iter(train_loader))
enc_sample = next(iter(enc_train_loader))

In [6]:
# Inspect the structure of one encrypted item (its element types are shown in the output).
enc_sample

[<tenseal.tensors.ckksvector.CKKSVector at 0x7f84245efe80>,
 <tenseal.tensors.ckkstensor.CKKSTensor at 0x7f8424499250>,
 100]

In [7]:
# Decrypt the encrypted input vector and check its length
# (presumably the flattened input size — confirm against the data module).
len(enc_sample[0].decrypt())

1600

In [8]:
# Run stage 0 of the FHE model on the encrypted input; the second argument selects
# the stage. The result is still encrypted (it is decrypted in the next cell).
enc_out_0 = fhe_model(enc_sample[0], 0)
enc_out_0.shape

[3200]

In [9]:
# Decrypt the stage-0 output, apply fhe_model.pre_process[0] and build the encrypted
# input for stage 1 with the same context (presumably re-encryption — confirm in
# chemxor.utils.prepare_fhe_input). Note the intermediate result is decrypted here,
# so the pipeline is not end-to-end encrypted.
input_1 = prepare_fhe_input(enc_out_0.decrypt(), fhe_model.pre_process[0], fhe_model.enc_context)

In [10]:
# Run stage 1 of the FHE model on the prepared encrypted input.
enc_out_1 = fhe_model(input_1, 1)
enc_out_1.shape

[2048]

In [11]:
# Decrypt the stage-1 output, apply fhe_model.pre_process[1] and prepare the
# encrypted input for stage 2.
input_2 = prepare_fhe_input(enc_out_1.decrypt(), fhe_model.pre_process[1], fhe_model.enc_context)

In [12]:
# Run stage 2 of the FHE model on the prepared encrypted input.
enc_out_2 = fhe_model(input_2, 2)
enc_out_2.shape

[512]

In [13]:
# Decrypt the stage-2 output, apply fhe_model.pre_process[2] and prepare the
# encrypted input for the final stage 3.
input_3 = prepare_fhe_input(enc_out_2.decrypt(), fhe_model.pre_process[2], fhe_model.enc_context)

In [14]:
# Run the final stage (3) of the FHE model; its output has a single element.
enc_out_3 = fhe_model(input_3, 3)
enc_out_3.shape

[1]

In [15]:
# Decrypt the final FHE output. CKKS is approximate, so expect small numeric
# deviations from the plaintext model.
enc_out_3.decrypt()

[1.4199801266488208]

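# Plaintext model evaluation

Run the same model without encryption on the first plaintext batch, for comparison with the decrypted FHE result above. `sample` and `enc_sample` come from separate dataloaders, so they may not contain the same example.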

In [16]:
# Plaintext forward pass on the first training batch. This is evaluation only, so
# run it under no_grad to avoid building an autograd graph for the outputs.
# NOTE(review): `model` is not switched to eval() mode; if OlindaNet contains
# dropout/batch-norm layers, call model.eval() first — confirm.
with t.no_grad():
    out_1 = model(sample[0])
out_1.shape

torch.Size([32, 1])

In [17]:
# Display the plaintext model outputs for the whole batch.
# NOTE: despite the name, `out_1` is the model's final output, not a stage-1
# intermediate (unlike `enc_out_1` above).
out_1

tensor([[1.4021],
        [1.4024],
        [1.4023],
        [1.4020],
        [1.4026],
        [1.4030],
        [1.4020],
        [1.4023],
        [1.4026],
        [1.4027],
        [1.4020],
        [1.4029],
        [1.4029],
        [1.4024],
        [1.4022],
        [1.4021],
        [1.4017],
        [1.4033],
        [1.4018],
        [1.4026],
        [1.4023],
        [1.4024],
        [1.4034],
        [1.4026],
        [1.4030],
        [1.4029],
        [1.4025],
        [1.4024],
        [1.4022],
        [1.4023],
        [1.4028],
        [1.4017]], grad_fn=<AddmmBackward0>)