In [None]:
import crypten
import torch
import psutil

from mpc.mpc_profile import *

# Initialise CrypTen once, before any cryptensor is created in the cells below.
crypten.init()

## Example usage: profile a plain torch tensor vs. CrypTen tensors (arithmetic and binary ptypes)

In [None]:
# Default benchmark settings. The "Produce plots" cells further down reassign
# tensor_size / tensor_dtype before re-running the benchmarks.
iters = 10  # number of iterations for each benchmark
tensor_dtype = torch.float
tensor_size = [60000, 1, 28, 28]  # same leading size as the MNIST training set

@profile_tensor
def torch_benchmark():
    """Allocate the plain (unencrypted) baseline tensor.

    Builds a tensor of ones with shape ``tensor_size`` and dtype ``tensor_dtype``
    (both notebook globals). As called in this notebook, the ``@profile_tensor``
    wrapper accepts ``iters=`` and yields ``(tensor, profile)``.
    """
    return torch.ones(tensor_size, dtype=tensor_dtype)
    
@profile_tensor
def crypten_benchmark(ptype="plain"):
    """Encrypt a tensor of ones (shape ``tensor_size``, dtype ``tensor_dtype``) with CrypTen.

    As called in this notebook, the ``@profile_tensor`` wrapper accepts
    ``iters=`` and yields ``(tensor, profile)``.

    Args:
        ptype: ``crypten.mpc.ptype.arithmetic`` or ``crypten.mpc.ptype.binary``.
            The default ``"plain"`` is kept only for signature compatibility;
            it is not a supported value.

    Returns:
        The cryptensor produced by ``crypten.cryptensor(..., ptype=ptype)``.

    Raises:
        ValueError: if ``ptype`` is not one of the two supported ptypes.
    """
    supported = (crypten.mpc.ptype.arithmetic, crypten.mpc.ptype.binary)
    # Fail loudly instead of falling through and returning an unbound local.
    if ptype not in supported:
        raise ValueError(
            f"unsupported ptype {ptype!r}; expected one of {supported}")
    plain = torch.ones(tensor_size, dtype=tensor_dtype)
    return crypten.cryptensor(plain, ptype=ptype)

In [None]:
# Plain torch baseline: keep the profile, then free the tensor itself.
tensor, torch_profile = torch_benchmark(iters=iters)
print(tensor.shape)
del tensor

In [None]:
# CrypTen run with the arithmetic ptype.
ptype = crypten.mpc.ptype.arithmetic
tensor, crypten_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor  # only the profile is kept

In [None]:
# CrypTen run with the binary ptype.
ptype = crypten.mpc.ptype.binary
tensor, crypten_bin_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor  # only the profile is kept

In [None]:
# Profile from torch_benchmark; the plotting functions read its "consumed" and "time" entries.
torch_profile

In [None]:
# Profile from the arithmetic-ptype crypten_benchmark run.
# NOTE(review): crypten_bin_profile (binary ptype) is computed above but is not
# displayed or plotted anywhere in the visible notebook.
crypten_profile

## Plots

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# Snapshot of the seaborn palette that is current when this cell runs. On a fresh
# kernel that is before any print_styling() call (which sets the "dark" palette),
# so the bars use this palette, not "dark"; re-running the cell later changes that.
# NOTE(review): confirm which palette is intended for the bar colours.
colors = sns.color_palette()

def print_styling(figsize=(10,8)):
    """Set the global matplotlib/seaborn look used by the plotting functions.

    Applies matplotlib's bundled seaborn style sheet and seaborn's "dark"
    palette, then overrides figure size, font sizes and line width via
    ``plt.rc``. The style sheet must be applied before the ``plt.rc``
    overrides: ``plt.style.use`` rewrites rcParams (the seaborn sheet sets its
    own ``figure.figsize``), so applying it afterwards would silently discard
    ``figsize`` and the font settings.

    Args:
        figsize: default ``(width, height)`` in inches for figures created
            after this call.
    """
    # Matplotlib 3.6 renamed the bundled sheet to "seaborn-v0_8" and 3.8
    # removed the old "seaborn" name; fall back to it only on older installs.
    style_name = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(style_name)
    sns.set_palette("dark")

    SMALL_SIZE = 15
    MEDIUM_SIZE = 18
    BIGGER_SIZE = 26

    plt.rc('figure', figsize=figsize)
    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('lines', linewidth=2)

    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels

    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
    
def _tensor_size_label():
    """Annotation text describing the benchmarked tensor (reads global ``tensor_size``)."""
    dims = ", ".join(str(dim) for dim in tensor_size)
    label = f"tensor.Size([{dims}])"
    # Only the full 60000 x 1 x 28 x 28 shape actually matches the MNIST training set.
    if list(tensor_size) == [60000, 1, 28, 28]:
        label += "\ni.e. size of MNIST training set"
    return label


def _ratio_to_torch(value, torch_value):
    """Format ``value`` relative to the torch baseline, guarding a zero baseline."""
    if not torch_value:
        return "(torch baseline is 0)"
    return f"~{value / torch_value:.1f}x torch"


def _draw_size_label(ax, default_y, pos_text, color_text):
    """Draw the tensor-size annotation on ``ax``.

    With ``pos_text=None`` it is top-aligned at ``default_y`` over the torch
    column; otherwise it is drawn at ``pos_text`` (data coordinates) in
    ``color_text`` on a gray box.
    """
    if pos_text is None:
        ax.text(0, default_y, _tensor_size_label(), ha="center", va="top")
    else:
        ax.text(*pos_text, _tensor_size_label(), ha="center", va="top",
                color=color_text, bbox={"facecolor": "gray"})


def _save_figure(fig, metric):
    """Save ``fig`` as a PDF under docs/figs/, creating the directory if it is missing."""
    import os

    out_dir = "docs/figs"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(
        f"{out_dir}/{metric}_crypt_vs_torch_tensor_size{tensor_size[0]}_dtype-{str(tensor_dtype)}.pdf",
        bbox_inches="tight", pad_inches=0.2)


def plot_mem_usage(pos_text=None, color_text="black"):
    """Bar-plot memory consumed by the torch tensor vs the cryptensor and save it as PDF.

    Reads the globals ``torch_profile["consumed"]`` and
    ``crypten_profile["consumed"]`` (labelled in MB), plus ``tensor_size``,
    ``tensor_dtype`` and ``colors``. Draws on a new figure, so consecutive
    calls never overlay each other.

    Args:
        pos_text: optional ``(x, y)`` data position for the tensor-size
            annotation; by default it is top-aligned at the crypten bar height
            over the torch column.
        color_text: annotation text colour, used only when ``pos_text`` is given.
    """
    print_styling()
    torch_mem = torch_profile["consumed"]
    crypten_mem = crypten_profile["consumed"]

    fig, ax = plt.subplots()
    ax.bar(x=["torch", "crypten"], height=[torch_mem, crypten_mem], color=colors)
    ax.set_title("Memory usage of torch tensor vs cryptensor in memory")
    ax.set_ylabel("MB")
    _draw_size_label(ax, crypten_mem, pos_text, color_text)

    # Value labels sit just inside the top of each bar. A fixed offset in points
    # is independent of the data scale, unlike a "-10 MB" offset, which would
    # go negative for bars under 10 MB.
    ax.annotate(f"{torch_mem:.0f}MB", xy=(0, torch_mem), xytext=(0, -4),
                textcoords="offset points", color="white", ha="center", va="top")
    ax.annotate(f"{crypten_mem:.0f}MB\n{_ratio_to_torch(crypten_mem, torch_mem)}",
                xy=(1, crypten_mem), xytext=(0, -4),
                textcoords="offset points", color="white", ha="center", va="top")
    _save_figure(fig, "mem")


def plot_exec_time(pos_text=None, color_text="black"):
    """Bar-plot benchmark execution time of torch vs crypten and save it as PDF.

    Reads the globals ``torch_profile["time"]`` and ``crypten_profile["time"]``
    (labelled in seconds), plus ``tensor_size``, ``tensor_dtype`` and
    ``colors``. Draws on a new figure, so consecutive calls never overlay each
    other.

    Args:
        pos_text: optional ``(x, y)`` data position for the tensor-size
            annotation; by default it is top-aligned at the crypten bar height
            over the torch column.
        color_text: annotation text colour, used only when ``pos_text`` is given.
    """
    print_styling()
    torch_time = torch_profile["time"]
    crypten_time = crypten_profile["time"]

    fig, ax = plt.subplots()
    ax.bar(x=["torch", "crypten"], height=[torch_time, crypten_time], color=colors)
    ax.set_title("Execution time to load into memory")
    ax.set_ylabel("s")
    _draw_size_label(ax, crypten_time, pos_text, color_text)

    # Torch label above its bar in black; crypten label inside the top of its bar in white.
    ax.text(0, torch_time, f"{torch_time:.3f}s", color="black",
            ha="center", va="bottom")
    ax.text(1, crypten_time,
            f"{crypten_time:.2f}s\n{_ratio_to_torch(crypten_time, torch_time)}",
            color="white", ha="center", va="top")
    _save_figure(fig, "time")

## Produce plots (memory and execution time) for several tensor sizes and dtypes

In [None]:
# Configuration: 60000 x 1 x 28 x 28, float32.
tensor_size = [60000, 1, 28, 28]
tensor_dtype = torch.float

# Re-profile both sides; plot_mem_usage / plot_exec_time read the
# torch_profile and crypten_profile globals set here.
tensor, torch_profile = torch_benchmark(iters=iters)
print(tensor.shape)
del tensor

ptype = crypten.mpc.ptype.arithmetic
tensor, crypten_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor

plot_mem_usage()
plt.show()
plot_exec_time()

In [None]:
# Configuration: 60000 x 1 x 28 x 28, float64.
tensor_size = [60000, 1, 28, 28]
tensor_dtype = torch.double

# Re-profile both sides; the plot functions read the profiles as globals.
tensor, torch_profile = torch_benchmark(iters=iters)
print(tensor.shape)
del tensor

ptype = crypten.mpc.ptype.arithmetic
tensor, crypten_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor

# Size annotation moved to (0, 50) in white so it stays readable.
plot_mem_usage((0, 50), color_text="white")
plt.show()
plot_exec_time()

In [None]:
# Configuration: 15000 x 1 x 28 x 28, float64.
tensor_size = [15000, 1, 28, 28]
tensor_dtype = torch.double

# Re-profile both sides; the plot functions read the profiles as globals.
tensor, torch_profile = torch_benchmark(iters=iters)
print(tensor.shape)
del tensor

ptype = crypten.mpc.ptype.arithmetic
tensor, crypten_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor

# Size annotation moved to (0, 50) in white so it stays readable.
plot_mem_usage((0, 50), color_text="white")
plt.show()
plot_exec_time()

In [None]:
# Configuration: 30000 x 1 x 28 x 28, float64.
tensor_size = [30000, 1, 28, 28]
tensor_dtype = torch.double

# Re-profile both sides; the plot functions read the profiles as globals.
tensor, torch_profile = torch_benchmark(iters=iters)
print(tensor.shape)
del tensor

ptype = crypten.mpc.ptype.arithmetic
tensor, crypten_profile = crypten_benchmark(ptype=ptype, iters=iters)
print(tensor.shape)
del tensor

# Size annotation moved to (0, 50) in white so it stays readable.
plot_mem_usage((0, 50), color_text="white")
plt.show()
plot_exec_time()