# Tutorial 1: Basics of CrypTen Tensors

We now have a high-level understanding of how secure MPC works. Through these tutorials, we will explain how to use CrypTen to carry out secure operations on encrypted tensors. In this tutorial, we introduce a fundamental building block in CrypTen called a ```CrypTensor```. ```CrypTensor```s are encrypted ```torch``` tensors that can be used for secure computation on data.

CrypTen currently supports only secure MPC protocols (though we intend to add support for other advanced encryption protocols). Using the ```mpc``` backend, ```CrypTensor```s act as ```torch``` tensors whose values are encrypted using secure MPC protocols. Tensors created using the ```mpc``` backend are called ```MPCTensor```s. We will go into greater detail about ```MPCTensor```s in Tutorial 2.
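To make "encrypted using secure MPC protocols" a little more concrete, here is a toy sketch of additive secret sharing, a common MPC building block. It assumes a hypothetical ring of integers modulo `RING = 2**64`, and `share` / `reconstruct` are illustrative helpers, not CrypTen functions; Tutorial 2 describes what `MPCTensor`s actually do. Each party holds one share, every share on its own looks random, and only the sum of all shares reveals the value.

```python
import random

RING = 2 ** 64  # hypothetical ring size for this toy example


def share(secret, num_parties=2):
    """Split an integer into additive shares that sum to it modulo RING."""
    shares = [random.randrange(RING) for _ in range(num_parties - 1)]
    last = (secret - sum(shares)) % RING
    return shares + [last]


def reconstruct(shares):
    """Recombine all shares (the 'decryption' step) into a signed integer."""
    value = sum(shares) % RING
    # Interpret the upper half of the ring as negative numbers
    return value - RING if value >= RING // 2 else value


shares = share(42)
print(shares)               # two random-looking integers
print(reconstruct(shares))  # 42
```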

Let's begin by importing the ```crypten``` and ```torch``` libraries. (If the imports fail, please see the installation instructions in the README.)

```python
import torch
import crypten

crypten.init()
```

## Creating Encrypted Tensors
CrypTen provides a ```crypten.cryptensor``` factory function, similar to ```torch.tensor```, to make creating ```CrypTensors``` easy. 

First, let's create a ```torch``` tensor and encrypt it using ```crypten.cryptensor```. To decrypt a ```CrypTensor```, call ```get_plain_text()```, which returns the original tensor. (```CrypTensor```s can also be created directly from a list or an array.)


```python
# Create a torch tensor
x = torch.tensor([1.0, 2.0, 3.0])

# Encrypt x
x_enc = crypten.cryptensor(x)

# Decrypt x
x_dec = x_enc.get_plain_text()
print(x_dec)

# Create a Python list
y = [4.0, 5.0, 6.0]

# Encrypt y
y_enc = crypten.cryptensor(y)

# Decrypt y
y_dec = y_enc.get_plain_text()
print(y_dec)
```
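Secret sharing works on integers, so real values such as `[4.0, 5.0, 6.0]` need a fixed-point encoding before they can be shared, and decoding reverses it. The sketch below assumes a hypothetical ring `RING = 2**64` and scale `SCALE = 2**16`; CrypTen's actual precision settings may differ. It shows why values like 4.0 round-trip exactly while 1/3 comes back only approximately.

```python
RING = 2 ** 64   # hypothetical ring size
SCALE = 2 ** 16  # hypothetical fixed-point scale (16 fractional bits)


def encode(value):
    """Map a real number to an integer in [0, RING) with 1/SCALE precision."""
    return round(value * SCALE) % RING


def decode(encoded):
    """Map an element of [0, RING) back to a signed real number."""
    if encoded >= RING // 2:
        encoded -= RING
    return encoded / SCALE


for v in [4.0, 5.0, 6.0, -2.5, 1 / 3]:
    print(v, decode(encode(v)))
```

Values that are multiples of `1 / SCALE` survive unchanged; anything else is rounded to the nearest representable value, which is one source of the small numerical differences seen in later examples.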

## Operations on Encrypted Tensors
Now let's look at what we can do with our ```CrypTensors```.

#### Arithmetic Operations
We can carry out regular arithmetic operations between ```CrypTensors```, as well as between ```CrypTensors``` and plaintext tensors. These operations never reveal any information about the encrypted tensors (internally or externally), and they always return an encrypted tensor.

```python
# Arithmetic operations between CrypTensors and plaintext tensors
x_enc = crypten.cryptensor([1.0, 2.0, 3.0])

y = 2.0
y_enc = crypten.cryptensor(2.0)

# Addition
z_enc1 = x_enc + y      # Public
z_enc2 = x_enc + y_enc  # Private
print("\nPublic  addition:", z_enc1.get_plain_text())
print("Private addition:", z_enc2.get_plain_text())

# Subtraction
z_enc1 = x_enc - y      # Public
z_enc2 = x_enc - y_enc  # Private
print("\nPublic  subtraction:", z_enc1.get_plain_text())
print("Private subtraction:", z_enc2.get_plain_text())

# Multiplication
z_enc1 = x_enc * y      # Public
z_enc2 = x_enc * y_enc  # Private
print("\nPublic  multiplication:", z_enc1.get_plain_text())
print("Private multiplication:", z_enc2.get_plain_text())

# Division
z_enc1 = x_enc / y      # Public
z_enc2 = x_enc / y_enc  # Private
print("\nPublic  division:", z_enc1.get_plain_text())
print("Private division:", z_enc2.get_plain_text())
```
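Why can these operations run without revealing anything? Under additive sharing, some of them are purely local. This toy sketch, assuming the same hypothetical two-party sharing modulo `RING = 2**64` (all helper names are illustrative, not CrypTen APIs), shows private addition, public addition, and multiplication by a public scalar. Multiplying two private values cannot be done share by share like this and requires the parties to interact, which this sketch does not cover.

```python
import random

RING = 2 ** 64  # hypothetical ring size


def share(secret, num_parties=2):
    shares = [random.randrange(RING) for _ in range(num_parties - 1)]
    return shares + [(secret - sum(shares)) % RING]


def reconstruct(shares):
    value = sum(shares) % RING
    return value - RING if value >= RING // 2 else value


def add_private(x_shares, y_shares):
    # Each party adds the two shares it holds; no communication is needed.
    return [(a + b) % RING for a, b in zip(x_shares, y_shares)]


def add_public(x_shares, constant):
    # A public constant is added by exactly one party (party 0 here).
    return [(x_shares[0] + constant) % RING] + x_shares[1:]


def mul_public(x_shares, constant):
    # Every party scales its own share by the public constant.
    return [(s * constant) % RING for s in x_shares]


x_shares = share(3)
y_shares = share(2)
print(reconstruct(add_private(x_shares, y_shares)))  # 5
print(reconstruct(add_public(x_shares, 2)))          # 5
print(reconstruct(mul_public(x_shares, 2)))          # 6
```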

#### Comparisons
Similarly, we can compute element-wise comparisons on ```CrypTensors```. Like arithmetic operations, comparisons performed on ```CrypTensor```s return a ```CrypTensor``` result. Decrypting these results yields 0s and 1s, corresponding to ```False``` and ```True```, respectively.

```python
# Construct two example CrypTensors
x_enc = crypten.cryptensor([1.0, 2.0, 3.0, 4.0, 5.0])

y = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0])
y_enc = crypten.cryptensor(y)

# Print values
print("x: ", x_enc.get_plain_text())
print("y: ", y_enc.get_plain_text())

# Less than
z_enc1 = x_enc < y      # Public
z_enc2 = x_enc < y_enc  # Private
print("\nPublic  (x < y) :", z_enc1.get_plain_text())
print("Private (x < y) :", z_enc2.get_plain_text())

# Less than or equal
z_enc1 = x_enc <= y      # Public
z_enc2 = x_enc <= y_enc  # Private
print("\nPublic  (x <= y):", z_enc1.get_plain_text())
print("Private (x <= y):", z_enc2.get_plain_text())

# Greater than
z_enc1 = x_enc > y      # Public
z_enc2 = x_enc > y_enc  # Private
print("\nPublic  (x > y) :", z_enc1.get_plain_text())
print("Private (x > y) :", z_enc2.get_plain_text())

# Greater than or equal
z_enc1 = x_enc >= y      # Public
z_enc2 = x_enc >= y_enc  # Private
print("\nPublic  (x >= y):", z_enc1.get_plain_text())
print("Private (x >= y):", z_enc2.get_plain_text())

# Equal
z_enc1 = x_enc == y      # Public
z_enc2 = x_enc == y_enc  # Private
print("\nPublic  (x == y):", z_enc1.get_plain_text())
print("Private (x == y):", z_enc2.get_plain_text())

# Not equal
z_enc1 = x_enc != y      # Public
z_enc2 = x_enc != y_enc  # Private
print("\nPublic  (x != y):", z_enc1.get_plain_text())
print("Private (x != y):", z_enc2.get_plain_text())
```
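Because comparison results are 0/1 values rather than Python booleans, they can be combined with ordinary arithmetic. The plaintext sketch below uses `int(a < b)` and `int(a == b)` as stand-ins for the decrypted results of `x_enc < y` and `x_enc == y`, and derives the other comparisons from them. These identities hold for any 0/1 encoding; they are not a claim about how CrypTen implements each operator.

```python
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [5.0, 4.0, 3.0, 2.0, 1.0]

# 0/1 stand-ins for the decrypted results of x_enc < y and x_enc == y
lt = [int(a < b) for a, b in zip(x, y)]
eq = [int(a == b) for a, b in zip(x, y)]

# The remaining comparisons follow from arithmetic on the 0/1 values
le = [l + e for l, e in zip(lt, eq)]  # x <= y
gt = [1 - v for v in le]              # x > y
ge = [1 - v for v in lt]              # x >= y
ne = [1 - v for v in eq]              # x != y

print("x <= y:", le)  # [1, 1, 1, 0, 0]
print("x >  y:", gt)  # [0, 0, 0, 1, 1]
print("x >= y:", ge)  # [0, 0, 1, 1, 1]
print("x != y:", ne)  # [1, 1, 0, 1, 1]

# Summing a 0/1 mask counts how many elements satisfy the comparison
print("count of x < y:", sum(lt))  # 2
```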


#### Advanced Mathematics
We can also compute more advanced mathematical functions on ```CrypTensors``` using iterative approximations. CrypTen provides MPC support for functions such as reciprocal, exponential, logarithm, square root, and tanh. Note that these results are subject to numerical error due to the approximations used.

Additionally, some of these functions fail silently when input values fall outside the range of convergence of the approximations used. They do not raise errors because the values are encrypted and cannot be checked without decryption. Exercise caution when using these functions; for certain models, it is good practice to normalize input values so that they stay within the range where the approximations converge.
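To see how an iterative approximation can fail silently, here is a plaintext sketch of a Newton-Raphson reciprocal, `y <- y * (2 - x * y)`, started from a fixed initial guess. The initial guess of 0.1 and the 10 iterations are illustrative assumptions; CrypTen's own reciprocal may use a different starting point, iteration count, or method. With this guess the iteration converges only for `0 < x < 20`; outside that range it returns a wildly wrong number without raising anything.

```python
def newton_reciprocal(x, initial_guess=0.1, iterations=10):
    """Approximate 1/x with the Newton-Raphson update y <- y * (2 - x * y).

    The error e = 1 - x * y squares every step, so the iteration converges
    only when |1 - x * initial_guess| < 1, i.e. 0 < x < 2 / initial_guess.
    """
    y = initial_guess
    for _ in range(iterations):
        y = y * (2.0 - x * y)
    return y


for x in [0.1, 2.5, 25.0]:
    print(x, newton_reciprocal(x), 1.0 / x)
```

For `x = 2.5` the result matches `0.4` to floating-point precision, for `x = 0.1` it is off by roughly `3e-4` because the iteration has not fully converged, and for `x = 25.0` it diverges to a huge negative value. In an encrypted computation, nobody would see that last result until it was decrypted.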

```python
torch.set_printoptions(sci_mode=False)

# Construct example input CrypTensor
x = torch.tensor([0.1, 0.3, 0.5, 1.0, 1.5, 2.0, 2.5])
x_enc = crypten.cryptensor(x)

# Reciprocal
z = x.reciprocal()          # Public
z_enc = x_enc.reciprocal()  # Private
print("\nPublic  reciprocal:", z)
print("Private reciprocal:", z_enc.get_plain_text())

# Logarithm
z = x.log()          # Public
z_enc = x_enc.log()  # Private
print("\nPublic  logarithm:", z)
print("Private logarithm:", z_enc.get_plain_text())

# Exp
z = x.exp()          # Public
z_enc = x_enc.exp()  # Private
print("\nPublic  exponential:", z)
print("Private exponential:", z_enc.get_plain_text())

# Sqrt
z = x.sqrt()          # Public
z_enc = x_enc.sqrt()  # Private
print("\nPublic  square root:", z)
print("Private square root:", z_enc.get_plain_text())

# Tanh
z = x.tanh()          # Public
z_enc = x_enc.tanh()  # Private
print("\nPublic  tanh:", z)
print("Private tanh:", z_enc.get_plain_text())
```
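The size of the numerical error can depend on the input. As an illustration, here is a plaintext sketch of one common way to approximate the exponential using only multiplications: `exp(x) ≈ (1 + x / 2**n) ** (2**n)`, computed with `n` repeated squarings. The choice of this method and of `n = 8` are assumptions for illustration; this tutorial does not state how CrypTen approximates `exp`.

```python
import math


def exp_limit(x, squarings=8):
    """Approximate exp(x) as (1 + x / 2**n) ** (2**n) using n squarings."""
    y = 1.0 + x / (2 ** squarings)
    for _ in range(squarings):
        y = y * y
    return y


for x in [0.1, 1.0, 2.5]:
    approx = exp_limit(x)
    exact = math.exp(x)
    print(x, approx, exact, "relative error:", abs(approx - exact) / exact)
```

With 8 squarings the relative error is about `2e-5` at `x = 0.1` but grows to about `1.2e-2` at `x = 2.5`, which is why larger inputs tend to show larger gaps between the public and private results.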


## Control Flow using Encrypted Tensors

Note that ```CrypTensors``` cannot be used directly in conditional expressions. Because the tensor is encrypted, the boolean expression cannot be evaluated unless the tensor is decrypted first. Attempting to execute control flow using an encrypted condition will result in an error.

Some control flow can still be executed without decrypting, but it must be expressed using mathematical operations instead. We provide the function ```crypten.where(condition, x, y)``` to abstract this kind of conditional value setting.

The following example illustrates how to write this kind of conditional logic for ```CrypTensors```.

```python
x_enc = crypten.cryptensor(2.0)
y_enc = crypten.cryptensor(4.0)

a, b = 2, 3

# Ordinary control flow on an encrypted condition raises an error
try:
    if x_enc < y_enc:
        z = a
    else:
        z = b
except RuntimeError as error:
    print(f"RuntimeError caught: \"{error}\"\n")

# Instead, use a mathematical expression
use_a = (x_enc < y_enc)
z_enc = use_a * a + (1 - use_a) * b
print("z:", z_enc.get_plain_text())

# Or use the `where` function
z_enc = crypten.where(x_enc < y_enc, a, b)
print("z:", z_enc.get_plain_text())
```
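The same arithmetic trick works element-wise. The plaintext sketch below uses hypothetical helpers `select` and `select_one_mul` (not CrypTen functions) with 0/1 condition bits standing in for a decrypted comparison. Both inputs are always touched, so nothing about the condition leaks through which branch runs. The rewritten form `b + c * (a - b)` is algebraically identical; when both `a` and `b` are encrypted it needs one private multiplication instead of two, though whether `crypten.where` uses it is not stated here.

```python
def select(conditions, a_values, b_values):
    """Element-wise select driven by 0/1 condition bits.

    Both a_values and b_values are always read, so the computation
    does not branch on the (secret) condition.
    """
    return [c * a + (1 - c) * b for c, a, b in zip(conditions, a_values, b_values)]


def select_one_mul(condition, a, b):
    """Equivalent form b + c * (a - b): one multiplication by the condition."""
    return b + condition * (a - b)


x = [2.0, -1.0, 0.5]
y = [4.0, -3.0, 0.5]

# In plaintext, int(xi < yi) stands in for the decrypted result of x_enc < y_enc
bits = [int(xi < yi) for xi, yi in zip(x, y)]
print(bits)                                 # [1, 0, 0]
print(select(bits, [2, 2, 2], [3, 3, 3]))   # [2, 3, 3]
print(select_one_mul(1, 2, 3), select_one_mul(0, 2, 3))  # 2 3
```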

### Advanced Indexing
CrypTen supports many of the operations that work on ```torch``` tensors. Encrypted tensors can be indexed, concatenated, stacked, reshaped, etc. For a full list of operations, see the CrypTen documentation.

```python
x_enc = crypten.cryptensor([1.0, 2.0, 3.0])
y_enc = crypten.cryptensor([4.0, 5.0, 6.0])

# Indexing
z_enc = x_enc[:-1]
print("Indexing:\n", z_enc.get_plain_text())

# Concatenation
z_enc = crypten.cat([x_enc, y_enc])
print("\nConcatenation:\n", z_enc.get_plain_text())

# Stacking
z_enc = crypten.stack([x_enc, y_enc])
print('\nStacking:\n', z_enc.get_plain_text())

# Reshaping
w_enc = z_enc.reshape((-1, 6))
print('\nReshaping:\n', w_enc.get_plain_text())
```
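Indexing, concatenation, stacking, and reshaping only rearrange elements, so under additive sharing each party can apply them to its own shares independently. This toy sketch assumes the same hypothetical two-party sharing modulo `RING = 2**64`; `share_vector` and `reconstruct_vector` are illustrative helpers, not CrypTen APIs.

```python
import random

RING = 2 ** 64  # hypothetical ring size


def share_vector(values, num_parties=2):
    """Return one list of shares per party for a vector of integers."""
    parties = [[random.randrange(RING) for _ in values] for _ in range(num_parties - 1)]
    last = [(v - sum(p[i] for p in parties)) % RING for i, v in enumerate(values)]
    return parties + [last]


def reconstruct_vector(party_shares):
    """Sum the parties' shares element by element and map back to signed ints."""
    out = []
    for column in zip(*party_shares):
        value = sum(column) % RING
        out.append(value - RING if value >= RING // 2 else value)
    return out


x_shares = share_vector([1, 2, 3])
y_shares = share_vector([4, 5, 6])

# Indexing: every party slices its own share the same way
sliced = [s[:-1] for s in x_shares]
# Concatenation: every party concatenates its own shares
joined = [xs + ys for xs, ys in zip(x_shares, y_shares)]

print(reconstruct_vector(sliced))  # [1, 2]
print(reconstruct_vector(joined))  # [1, 2, 3, 4, 5, 6]
```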



### Implementation Note

Due to internal implementation details, ```CrypTensors``` must be the first operand of operations that combine ```CrypTensor```s and ```torch``` tensors. That is, for a ```CrypTensor``` ```x_enc``` and a plaintext tensor ```y```:
- The expression ```x_enc < y``` is valid, but the equivalent expression ```y > x_enc``` will result in an error.
- The expression ```x_enc + y``` is valid, but the equivalent expression ```y + x_enc``` will result in an error.

We intend to add support for both expressions in the future.
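The note above does not spell out the internal cause. One Python mechanism that produces exactly this kind of asymmetry is operator dispatch: for `p + e`, Python calls `type(p).__add__` first and falls back to `type(e).__radd__` only if the first call returns `NotImplemented`. If the left operand's method raises an error instead, the fallback never runs. The classes below are hypothetical toys, not CrypTen or `torch` types, and whether this is the precise cause in CrypTen is an assumption. The practical advice is unchanged: put the ```CrypTensor``` first, for example write `x_enc < y` rather than `y > x_enc`.

```python
class ToyEncrypted:
    """Hypothetical stand-in for an encrypted tensor that accepts either operand order."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        plain = other.value if isinstance(other, StrictPlain) else other
        return ToyEncrypted(self.value + plain)

    # Only used when the left operand's __add__ returns NotImplemented
    def __radd__(self, other):
        return self.__add__(other)


class StrictPlain:
    """Hypothetical plaintext type whose __add__ raises instead of returning NotImplemented."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if not isinstance(other, StrictPlain):
            raise TypeError("StrictPlain cannot add " + type(other).__name__)
        return StrictPlain(self.value + other.value)


enc = ToyEncrypted(2)
print((enc + StrictPlain(1)).value)  # 3: ToyEncrypted.__add__ runs
print((1 + enc).value)               # 3: int returns NotImplemented, so __radd__ runs
try:
    StrictPlain(1) + enc             # StrictPlain.__add__ raises; __radd__ is never tried
except TypeError as error:
    print("TypeError:", error)
```

This also explains why an expression such as `1 - use_a` in the control-flow example can work even though a plaintext tensor on the left may not: a Python `int` declines the operation and hands it to the right operand.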