In [None]:
!pip install nibabel
!pip install torchio

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets,transforms, models
import torchvision.transforms.functional as TF
import nibabel as nib
from pathlib import Path
import onnxruntime
from sklearn.metrics import precision_recall_fscore_support
from scipy import stats

In [1]:
# to match fMRI ICA 100 components
batch_size = 100

StatementMeta(80619585-8c1b-471a-852e-6b5f92e90b7a, 0, 15, Finished, Available)

In [2]:

class MultilayerPerceptron2(nn.Module):
  # whatever op layer is becomes num ip for next hidden layer
  def __init__(self, input_size=45*54*45, output_size=58):
    super().__init__()
    N = 200
    self.d1 = nn.Linear(input_size, N)
    self.d2 = nn.Linear(N, N)
    self.d3 = nn.Linear(N, N)
    self.d4 = nn.Linear(N, N)
    self.d5 = nn.Linear(N, output_size)
    self.dropout = nn.Dropout(0.66)
    self.flat = nn.Flatten()
    

  def forward(self,X):

    X = self.flat(X) #X.view(-1,45*54*45)
    X = F.relu(self.d1(X))
    X = F.relu(self.d3(X))
    X = self.dropout(X)  
    X = F.relu(self.d4(X))
    X = self.d5(X)
    #X = torch.squeeze(X)

    return X

StatementMeta(80619585-8c1b-471a-852e-6b5f92e90b7a, 0, 16, Finished, Available)

In [None]:
# Replace the path below with the path to your model weights

device = 'cuda' if torch.cuda.is_available() else "cpu"
print(f'Device is {device}')
model_final =  torch.load('/path/to/model_mlp_final.pth', map_location=torch.device('cpu'))

model = MultilayerPerceptron2()


model.load_state_dict(model_final['model'])
model = model.to(device)
model.eval()


In [None]:
# Replace the path below with the path to your test data
volPath = "/home/user/path/to/test_data.nii.gz"
niii_mg = nib.load(volPath).get_fdata()

In [7]:
niii_mg.shape

(45, 54, 45, 176)

In [8]:
input_volume = niii_mg[:,:,:,0:batch_size].reshape(batch_size,1,45,54,45)
input_volume = torch.Tensor(input_volume) #.view(-1,45*54*45)
input_volume = input_volume.to(device)

In [9]:
print(input_volume.shape)
print(niii_mg[:,:,:,0:batch_size].shape)

torch.Size([100, 1, 45, 54, 45])
(45, 54, 45, 100)


In [54]:
model.load_state_dict(model_final['model'])
model = model.to(device)
model.eval()

probs = F.softmax(model(input_volume).detach(), dim=1) 

_, predictions = probs.max(1)
print(predictions)

tensor([42, 42, 42, 42, 42, 42, 42,  6, 42, 32, 56, 44, 10, 56, 51, 42, 10, 34,
        49, 37, 29, 10, 11,  4, 11, 11, 20, 11, 46, 46, 11, 44, 42, 50,  2, 11,
        11, 10, 28, 11, 44, 20, 24, 10, 11, 10, 21, 41, 11, 11, 11, 46, 11, 11,
        24, 53, 38, 10, 28, 11, 11,  7, 24, 11, 11, 10, 46, 49, 11, 11,  7, 11,
        11, 25, 56, 28,  2, 11, 19, 56, 23, 10, 11, 11, 56, 39, 56, 56, 56, 39,
        11, 46, 44, 42, 49, 42, 42, 42, 42, 42])


In [11]:
np.save("input_volume_100batch.npy", input_volume.detach().numpy())

In [12]:
output_onnx = str(Path("mlp_model.onnx"))

# Export an ONNX model.
with torch.no_grad():
    torch.onnx.export(
        model=model,
        # Using a fixed batch size of 1 since EzPC doesn't allow dynamic batch size
        args=(input_volume),
        f=output_onnx, 
        opset_version=13,
        verbose=True,
        input_names=["image"], 
        output_names=[ "score"],
    )

Exported graph: graph(%image : Float(100, 1, 45, 54, 45, strides=[109350, 109350, 2430, 45, 1], requires_grad=0, device=cpu),
      %d1.weight : Float(200, 109350, strides=[109350, 1], requires_grad=1, device=cpu),
      %d1.bias : Float(200, strides=[1], requires_grad=1, device=cpu),
      %d3.weight : Float(200, 200, strides=[200, 1], requires_grad=1, device=cpu),
      %d3.bias : Float(200, strides=[1], requires_grad=1, device=cpu),
      %d4.weight : Float(200, 200, strides=[200, 1], requires_grad=1, device=cpu),
      %d4.bias : Float(200, strides=[1], requires_grad=1, device=cpu),
      %d5.weight : Float(58, 200, strides=[200, 1], requires_grad=1, device=cpu),
      %d5.bias : Float(58, strides=[1], requires_grad=1, device=cpu)):
  %onnx::Gemm_11 : Float(100, 109350, strides=[109350, 1], requires_grad=0, device=cpu) = onnx::Flatten[axis=1, onnx_name="Flatten_0"](%image) # /anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/flatten.py:45:0
  %onnx::Relu_12 : 

In [55]:
session = onnxruntime.InferenceSession(output_onnx)
ort_output = session.run(
    output_names=None,
    input_feed={ 
        "image": input_volume.numpy(),
    },
)
ort_probs = F.softmax(torch.tensor(ort_output[0]), dim=1) 
_, ort_predictions = ort_probs.max(1)
print(ort_predictions)

tensor([42, 42, 42, 42, 42, 42, 42,  6, 42, 32, 56, 44, 10, 56, 51, 42, 10, 34,
        49, 37, 29, 10, 11,  4, 11, 11, 20, 11, 46, 46, 11, 44, 42, 50,  2, 11,
        11, 10, 28, 11, 44, 20, 24, 10, 11, 10, 21, 41, 11, 11, 11, 46, 11, 11,
        24, 53, 38, 10, 28, 11, 11,  7, 24, 11, 11, 10, 46, 49, 11, 11,  7, 11,
        11, 25, 56, 28,  2, 11, 19, 56, 23, 10, 11, 11, 56, 39, 56, 56, 56, 39,
        11, 46, 44, 42, 49, 42, 42, 42, 42, 42])


In [14]:
# Setup EzPC
!./ezpc_setup.sh

Get:1 file:/var/nccl-repo-2.2.13-ga-cuda9.2  InRelease
Ign:1 file:/var/nccl-repo-2.2.13-ga-cuda9.2  InRelease
Get:2 file:/var/nccl-repo-2.2.13-ga-cuda9.2  Release [574 B]                   [0m
Hit:3 http://azure.archive.ubuntu.com/ubuntu focal InRelease                   [0m
Hit:4 http://azure.archive.ubuntu.com/ubuntu focal-updates InRelease           
Hit:5 http://azure.archive.ubuntu.com/ubuntu focal-backports InRelease         
Hit:6 http://azure.archive.ubuntu.com/ubuntu focal-security InRelease          
Get:2 file:/var/nccl-repo-2.2.13-ga-cuda9.2  Release [574 B]                   [0m
Hit:7 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64  InRelease
Hit:8 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease     [0m
Hit:9 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64  InRelease
Hit:10 https://apt.repos.intel.com/mkl all InRelease                           [0m
Hit:11 https://nvidia.github.io/nvidia-docker/ubunt

In [15]:
# run EzPC secure compilation and key generation
!./ezpc_secure_compile.sh

[32;3m2023-07-21 20:14:01,709 - OnnxBridge - INFO <<<<- Application Started ->>>> (backend.py:58)
[0m[32;3m2023-07-21 20:14:01,709 - OnnxBridge - INFO <<<<- Loading onnx graph: mlp_model ->>>> (backend.py:62)
[0m[32;3m2023-07-21 20:14:03,608 - OnnxBridge - INFO <<<<- Model Received : opset version : 13 ->>>> (backend.py:66)
[0m[32;3m2023-07-21 20:14:03,609 - OnnxBridge - INFO <<<<- Batch Size : 100 ->>>> (backend.py:69)
[0m[32;3m2023-07-21 20:14:05,976 - OnnxBridge - INFO <<<<- Model Optimized ->>>> (backend.py:75)
[0m[32;3m2023-07-21 20:14:06,231 - OnnxBridge - INFO <<<<- Shape Inference Done ->>>> (backend.py:78)
[0m[32;3m2023-07-21 20:14:06,231 - OnnxBridge - INFO <<<<- Model is OK! ->>>> (backend.py:83)
[0m[32;3m2023-07-21 20:14:36,059 - OnnxBridge - INFO <<<<- Dumping model weights in:
 /mnt/batch/tasks/shared/LS_root/mounts/clusters/user2/code/Users/user/EzPC/OnnxBridge/workdir/mlp_model_input_weights.dat ->>>> (backend.py:101)
[0m[32;3m2023-07-21 20:14:36,059 - 

In [16]:
# Replace the path below with the path to your files
!./ezpc_run_on_single_machine.sh "/home/user/path/to/input_volume_100batch.npy" {batch_size}

FLoat point output saved in  input_volume_100batch.inp
Key Size: 264124032 bytes
waiting for connection from client...Key Size: 528349616 bytes
trying to connect with server...connectedconnected

=== COMPUTATION START ===

=== COMPUTATION START ===

>> MatMul2D - Start
>> MatMul2D - Start
   Key Read Time =    Key Read Time = 376 milliseconds
   Compute Time = 15645.2 milliseconds
      Eigen Time = 15645.2 milliseconds
   Reconstruct Time = 0.737 milliseconds
   Online Time = 15645.9 milliseconds
   Online Comm = 320000 bytes
6313 milliseconds
   Compute Time = 15637.6 milliseconds
      Eigen Time = 15637.6 milliseconds
   Reconstruct Time = 8.291 milliseconds
   Online Time = 15645.9 milliseconds
   Online Comm = 320000 bytes
>> MatMul2D - End
>> MatMul2D - End
>> Relu (Spline) - Start
>> Relu (Spline) - Start
   Key Read Time = 1033 milliseconds
   Compute Time = 150.222 milliseconds
   Reconstruct Time = 0.77 milliseconds
   Online Time = 150.992 milliseconds
   Online Comm = 3600

In [17]:
# Replace the path below with the path to your files
ezpc_output_file="/home/user/EzPC/OnnxBridge/workdir/ezpc_prediction.dat"

In [56]:
with open(ezpc_output_file) as f:
    lines = [[float(x) for x in line.strip().split(" ")] for line in f]

ezpc_probs = F.softmax(torch.as_tensor(lines), dim=1)
_, ezpc_predictions = ezpc_probs.max(1)
print(ezpc_predictions)

tensor([42, 42, 42, 42, 42, 42, 42,  6, 42, 32, 56, 44, 10, 56, 51, 42, 10, 34,
        49, 37, 29, 10, 11,  4, 11, 11, 20, 11, 46, 46, 11, 44, 42, 50,  2, 11,
        11, 10, 28, 11, 44, 20, 24, 10, 11, 10, 21, 41, 11, 11, 11, 46, 11, 11,
        24, 53, 38, 10, 28, 11, 11,  7, 24, 11, 11, 10, 46, 49, 11, 11,  7, 11,
        11, 25, 56, 28,  2, 11, 19, 56, 23, 10, 11, 11, 56, 39, 56, 56, 56, 39,
        11, 46, 44, 42, 49, 42, 42, 42, 42, 42])


In [49]:
def compare_tensor_outputs(actual, predicted):
    # calculate absolute error across all classes
    error = torch.abs(actual - predicted).flatten()
    mean = torch.mean(error)
    std_dev = torch.std(error)
    ci = stats.norm.interval(alpha=0.95, loc=mean, scale=stats.sem(error))
    print(f"absolute error: mean: {mean}, standard deviation: {std_dev}, confidence interval: {ci}, max: {torch.max(error)}")
    print(f"===== Mean \u00b1 CI: {mean:.4E} \u00b1 {(ci[1] - ci[0]) / 2.0:.2E} ======")

In [57]:
# errors between original pyTorch outputs and EzPC 
compare_tensor_outputs(probs, ezpc_probs)

absolute error: mean: 5.204777335166e-05, standard deviation: 0.00016560518997721374, confidence interval: (4.778582642261696e-05, 5.6309720280703034e-05), max: 0.0060277581214904785


In [59]:
# errors between EzPC and plain ONNX inference
compare_tensor_outputs(ort_probs, ezpc_probs)

absolute error: mean: 5.20477224199567e-05, standard deviation: 0.0001656016247579828, confidence interval: (4.778586724392669e-05, 5.630957759598671e-05), max: 0.006027638912200928


In [60]:
# errors between original pyTorch outputs and plain ONNX inference
compare_tensor_outputs(probs, ort_probs)

absolute error: mean: 4.411621290500989e-09, standard deviation: 2.207849369995074e-08, confidence interval: (3.8434182744279515e-09, 4.979824306574026e-09), max: 7.748603820800781e-07


In [36]:
def print_metrics(actual, predicted):
    scores = precision_recall_fscore_support(actual,  predicted, average="micro")
    print(f"precision: {scores[0]}, recall: {scores[1]}, f1 score: {scores[2]}")

In [37]:
# Metrics for original pyTorch and EzPC predictions
print_metrics(predictions,  ezpc_predictions)

precision: 1.0, recall: 1.0, f1 score: 1.0


In [38]:
# Metrics for original pyTorch and plain ONNX predictions
print_metrics(predictions,  ort_predictions)

precision: 1.0, recall: 1.0, f1 score: 1.0
