# Introduction to Encrypted Tensors

Following along with this video: <https://www.youtube.com/watch?v=CLunSEdSDaA>

## Docs

Main: <https://crypten.readthedocs.io/en/latest/mpctensor.html>

- CrypTensor: https://crypten.readthedocs.io/en/latest/cryptensor.html
- MPCTensor: https://crypten.readthedocs.io/en/latest/mpctensor.html
- Neural Nets: https://crypten.readthedocs.io/en/latest/nn.html

In [None]:
import sys

import torch
import torchvision
import crypten

assert sys.version_info[:2] == (3, 7), "Python 3.7 is required!"

print(f"Okay, good! You have: {sys.version_info[:3]}")
# Now we can init crypten!
crypten.init()

In [None]:
x = crypten.cryptensor([1, 2, 3])
x

In [None]:
# Decrypt back to a plain torch tensor so we can read it
x.get_plain_text()
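
So what is actually inside `x`? As far as I understand the CrypTen docs, an `MPCTensor` never stores `[1, 2, 3]` in the clear. Each value is first scaled to a fixed-point integer and then split into random additive shares, one per party, modulo 2^64. `get_plain_text()` adds the shares back up and undoes the scaling, which is why the result comes back as floats. The toy sketch below shows that idea in plain Python. The two-party setup, the 16-bit scale and the helper names (`encode`, `decode`, `share`, `reconstruct`) are my own assumptions for illustration, not CrypTen's actual code.

```python
import secrets

RING = 2 ** 64   # assumed: shares live in the integers mod 2^64
SCALE = 2 ** 16  # assumed: 16 fractional bits of fixed-point precision


def encode(values):
    """Turn floats into fixed-point integers in the ring."""
    return [round(v * SCALE) % RING for v in values]


def decode(ints):
    """Map ring elements back to signed floats."""
    return [(i - RING if i >= RING // 2 else i) / SCALE for i in ints]


def share(ints):
    """Split each integer into two random shares that sum to it mod RING."""
    share0 = [secrets.randbelow(RING) for _ in ints]
    share1 = [(v - s) % RING for v, s in zip(ints, share0)]
    return share0, share1


def reconstruct(share0, share1):
    """Add the parties' shares back together: this is the 'decryption'."""
    return [(p + q) % RING for p, q in zip(share0, share1)]


x0, x1 = share(encode([1, 2, 3]))
print(x0)                           # random-looking integers; one share alone reveals nothing
print(decode(reconstruct(x0, x1)))  # [1.0, 2.0, 3.0]
```

Each party only ever sees its own list of random-looking numbers; the real values only appear once the shares are combined.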

## Let's test some operations

More operations here: [docs](https://crypten.readthedocs.io/en/latest/cryptensor.html#tensor-operations)

In [None]:
a = (2+x)
a.get_plain_text()

In [None]:
b = (a+x)
b.get_plain_text()
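
Why are `a` and `b` cheap? Continuing the toy two-party sketch (plain Python with my own hypothetical helper names, not CrypTen internals): addition needs no communication at all. Each party adds its own shares locally, and a public constant such as the `2` in `2 + x` only has to be added by one of the parties.

```python
import secrets

RING, SCALE = 2 ** 64, 2 ** 16  # assumed ring size and fixed-point scale


def encode(values):
    return [round(v * SCALE) % RING for v in values]


def decode(ints):
    return [(i - RING if i >= RING // 2 else i) / SCALE for i in ints]


def share(ints):
    s0 = [secrets.randbelow(RING) for _ in ints]
    return s0, [(v - r) % RING for v, r in zip(ints, s0)]


def add_shares(u, v):
    """Elementwise sum mod RING: each party does this locally, no messages exchanged."""
    return [(p + q) % RING for p, q in zip(u, v)]


def add_public(shares, const, party):
    """Only party 0 adds the (encoded) public constant to its share."""
    if party != 0:
        return list(shares)
    c = encode([const])[0]
    return [(s + c) % RING for s in shares]


x0, x1 = share(encode([1, 2, 3]))
a0, a1 = add_public(x0, 2, party=0), add_public(x1, 2, party=1)  # a = 2 + x
b0, b1 = add_shares(a0, x0), add_shares(a1, x1)                  # b = a + x

# Summing the two parties' shares reconstructs the plain values
print(decode(add_shares(a0, a1)))  # [3.0, 4.0, 5.0]
print(decode(add_shares(b0, b1)))  # [4.0, 6.0, 8.0]
```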

In [None]:
c = x*a
c.get_plain_text()
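
Multiplication is where the parties have to talk to each other. A standard trick, and as far as I know the one CrypTen's MPC multiplication is built on, is the Beaver triple: someone hands out shares of random `a`, `b` and `c = a * b`; the parties publicly open the masked values `e = x - a` and `f = y - b`, and then compute shares of `x * y` locally. The sketch below is a plain-Python illustration under my own assumptions: a simulated dealer, two parties, and plain integers (with fixed-point values the product would also need rescaling by 2^16, which is skipped here). All names are hypothetical.

```python
import secrets

RING = 2 ** 64  # assumed ring size


def share(v):
    """Split one integer into two additive shares mod RING."""
    r = secrets.randbelow(RING)
    return r, (v - r) % RING


def reconstruct(s0, s1):
    return (s0 + s1) % RING


def beaver_triple():
    """Simulated dealer: shares of random a, b and c = a * b."""
    a, b = secrets.randbelow(RING), secrets.randbelow(RING)
    return share(a), share(b), share(a * b % RING)


def mul(x_sh, y_sh):
    """Multiply two secret-shared integers using one Beaver triple."""
    a_sh, b_sh, c_sh = beaver_triple()
    # Both parties publish their shares of the masked differences;
    # e and f are uniformly random, so they leak nothing about x or y.
    e = reconstruct(*[(x - a) % RING for x, a in zip(x_sh, a_sh)])
    f = reconstruct(*[(y - b) % RING for y, b in zip(y_sh, b_sh)])
    # Local step: z_i = c_i + e*b_i + f*a_i, and party 0 also adds e*f.
    z = [(c + e * b + f * a) % RING for a, b, c in zip(a_sh, b_sh, c_sh)]
    z[0] = (z[0] + e * f) % RING
    return tuple(z)


# Elementwise x * (2 + x) for x = [1, 2, 3], as in the cell above
xs, ys = [1, 2, 3], [3, 4, 5]
products = [reconstruct(*mul(share(x), share(y))) for x, y in zip(xs, ys)]
print(products)  # [3, 8, 15]
```

The identity behind it is `x*y = c + e*b + f*a + e*f`, which holds because `x = e + a` and `y = f + b`.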

In [None]:
d = x.dot(a)
d.get_plain_text()
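
`x.dot(a)` is just the sum of the elementwise products, so it reuses the multiplication above plus a local sum. To check the decrypted outputs of these cells, here is the same arithmetic in plain Python for `x = [1, 2, 3]`. This is only a reference calculation, not CrypTen code; the CrypTen results should match up to small fixed-point rounding.

```python
x = [1, 2, 3]
a = [2 + xi for xi in x]               # a = 2 + x
b = [ai + xi for ai, xi in zip(a, x)]  # b = a + x
c = [xi * ai for xi, ai in zip(x, a)]  # c = x * a (elementwise)
d = sum(c)                             # d = x.dot(a) = 3 + 8 + 15
print(a, b, c, d)                      # [3, 4, 5] [4, 6, 8] [3, 8, 15] 26
```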

In [None]:
# Let's compute the mean squared error between x and c

sql = (x - c)**2
msql = sql.mean()

msql.get_plain_text()

In [None]:
# The same computation in plain PyTorch, for comparison
x_pt = torch.tensor([1, 2, 3.])
c_pt = x_pt * (2 + x_pt)

sql_pt = (x_pt - c_pt)**2
msql_pt = sql_pt.mean()  # a mean of squares is never negative, so .abs() is not needed
print(msql_pt)
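
Worked by hand for `x = [1, 2, 3]` and `c = x * (2 + x) = [3, 8, 15]`, both cells compute

$$
\mathrm{MSE} = \frac{1}{3}\sum_{i=1}^{3}(x_i - c_i)^2 = \frac{(1-3)^2 + (2-8)^2 + (3-15)^2}{3} = \frac{4 + 36 + 144}{3} = \frac{184}{3} \approx 61.33
$$

so the decrypted CrypTen value should agree with PyTorch's `tensor(61.3333)` up to fixed-point error.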

## Neural Nets

[Docs](https://crypten.readthedocs.io/en/latest/nn.html)

`crypten.nn` provides modules for defining and training neural networks similar to `torch.nn`.

### From PyTorch to CrypTen

The simplest way to create a CrypTen network is to start with a PyTorch network, and use the `from_pytorch` function to convert it to a CrypTen network. This is particularly useful for pre-trained PyTorch networks that need to be encrypted before use.

In [None]:
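import os
import sys

# Add the project root (the notebook's parent folder) to the import path. Inside Jupyter,
# inspect.currentframe() does not point at the notebook file; the working directory does by default.
parentdir = os.path.dirname(os.path.abspath(os.getcwd()))
sys.path.insert(0, parentdir)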

from config import PETER_ROOT, DATA_DIR, MNIST_SIZE
from ZeNet.nets import *

In [None]:
torch.set_num_threads(1)

subset = 1/6
train_ratio = 0.75
test_ratio = 1 - train_ratio
batch_size_train = int((subset * MNIST_SIZE) * train_ratio)
batch_size_test = int((subset * MNIST_SIZE) * test_ratio)

print(f"Using train_test ratios: {train_ratio} : {test_ratio}")
print(f"Train batch size: {batch_size_train}")
print(f"Test batch size: {batch_size_test}")
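
`MNIST_SIZE` comes from the local `config` module, which isn't shown here. Assuming it is 60,000 (the size of the standard MNIST training split), the arithmetic works out as below. Note that these numbers are passed to the DataLoaders as `batch_size`, so a single batch is the whole 1/6 slice split 75/25.

```python
MNIST_SIZE = 60_000  # assumption: standard MNIST training-set size

subset = 1 / 6
train_ratio = 0.75
test_ratio = 1 - train_ratio

n_subset = subset * MNIST_SIZE                  # about 10,000 images
batch_size_train = int(n_subset * train_ratio)  # 75% of the slice
batch_size_test = int(n_subset * test_ratio)    # remaining 25%
print(batch_size_train, batch_size_test)        # 7500 2500
```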

In [None]:
net = Net1()
net

In [None]:
# Both splits use the same preprocessing: convert to a tensor, then standardize
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_DIR, train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)

# If MNIST is not already in DATA_DIR, download=True fetches it and reports the download below.
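
`Normalize((0.1307,), (0.3081,))` standardizes every pixel as `(p - mean) / std`, where 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of MNIST pixels after `ToTensor()` has scaled them to [0, 1]. A small plain-Python illustration (the `normalize` helper is mine, written to mirror that formula) of where the two extreme pixel values end up:

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to Normalize above


def normalize(pixel):
    """Per-pixel formula applied by torchvision.transforms.Normalize."""
    return (pixel - MEAN) / STD


black = normalize(0.0)  # background pixels, roughly -0.424
white = normalize(1.0)  # fully inked pixels, roughly 2.821
print(black, white)
```

So the normalized images are no longer in [0, 1]; the background sits slightly below zero and the strokes well above one.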

In [None]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [None]:
print(f"type: {type(example_data)}")
print(example_data.shape)

print(f"This means we have {example_data.shape[0]} images of size {example_data.shape[2]}x{example_data.shape[3]} with {example_data.shape[1]} color channel(s) (1 channel = greyscale)")

In [None]:
from plot_mnist import plot_batch

In [None]:
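# from_pytorch traces the network, so it also needs a dummy input of the right shape
private_net = crypten.nn.from_pytorch(net, example_data[:1])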