# Classification with Encrypted Neural Networks

In this tutorial, we'll look at how we can achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, suppose say Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We will see how CrypTen allows Alice and Bob to coordinate and classify the data, while achieving their privacy requirements.

To simulate this scenario, we will begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.

## Setup

We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. 

In [1]:
import crypten
import crypten.nn as nn
import torch
import logging

crypten.init()
torch.set_num_threads(1)

In [4]:
layer = nn.Embedding(5, 10)
print(layer.weight)
layer.encrypt(src=0)
data = torch.tensor([1, 2, 0, 1, 2, 3, 4])
data_enc = crypten.cryptensor(data)
output = layer.forward(data_enc)
print(output.get_plain_text())

Parameter containing:
tensor([[-1.3749, -1.7004, -0.7123, -0.6363, -0.0685, -0.0942, -0.5310,  3.0206,
         -0.3940, -1.6077],
        [ 0.8402, -1.5226, -0.7774, -0.9728,  0.2623,  1.9485, -1.1585,  1.8922,
          0.7155,  0.1293],
        [ 0.7431,  1.7675,  0.0258,  0.6285,  1.3089,  0.3253, -2.4253,  0.3428,
         -1.0613,  1.7637],
        [-0.1879, -0.9199,  1.0139,  2.8269,  1.6581, -1.7827,  0.8733,  1.2514,
         -1.3100, -2.0500],
        [-0.3762, -1.0989,  1.9507, -0.7418,  0.3374, -0.7515,  0.3272, -0.1313,
         -0.8547,  1.1360]], requires_grad=True)
tensor([[ 0.8402, -1.5226, -0.7774, -0.9728,  0.2623,  1.9484, -1.1585,  1.8922,
          0.7155,  0.1293],
        [ 0.7430,  1.7675,  0.0258,  0.6285,  1.3089,  0.3253, -2.4253,  0.3428,
         -1.0613,  1.7637],
        [-1.3748, -1.7004, -0.7123, -0.6362, -0.0685, -0.0942, -0.5310,  3.0206,
         -0.3940, -1.6077],
        [ 0.8402, -1.5226, -0.7774, -0.9728,  0.2623,  1.9484, -1.1585,  1.8922,
    

In [6]:
l = torch.nn.Embedding(5, 10)
l.weight = torch.nn.Parameter(layer.weight.get_plain_text())
l(data)

tensor([[ 0.8402, -1.5226, -0.7774, -0.9728,  0.2623,  1.9484, -1.1585,  1.8922,
          0.7155,  0.1293],
        [ 0.7430,  1.7675,  0.0258,  0.6285,  1.3089,  0.3253, -2.4253,  0.3428,
         -1.0613,  1.7637],
        [-1.3748, -1.7004, -0.7123, -0.6362, -0.0685, -0.0942, -0.5310,  3.0206,
         -0.3940, -1.6077],
        [ 0.8402, -1.5226, -0.7774, -0.9728,  0.2623,  1.9484, -1.1585,  1.8922,
          0.7155,  0.1293],
        [ 0.7430,  1.7675,  0.0258,  0.6285,  1.3089,  0.3253, -2.4253,  0.3428,
         -1.0613,  1.7637],
        [-0.1879, -0.9199,  1.0139,  2.8269,  1.6581, -1.7827,  0.8733,  1.2514,
         -1.3100, -2.0500],
        [-0.3762, -1.0989,  1.9507, -0.7418,  0.3374, -0.7515,  0.3272, -0.1313,
         -0.8547,  1.1360]], grad_fn=<EmbeddingBackward0>)

In [5]:
data_enc = crypten.cryptensor(torch.tensor([1, 2, 0, 1, 2, 3, 4])) #, dtype=torch.long))
print(data_enc.get_plain_text())
lut = crypten.cryptensor(torch.tensor([[10, 20, 30], [11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]])).share
# print(lut / 2**16)
data_enc.evaluate_embed(lut).get_plain_text()

tensor([1., 2., 0., 1., 2., 3., 4.])
<built-in method type of Tensor object at 0x38764ba70>


tensor([[11., 21., 31.],
        [12., 22., 32.],
        [10., 20., 30.],
        [11., 21., 31.],
        [12., 22., 32.],
        [13., 23., 33.],
        [14., 24., 34.]])

In [7]:
model = nn.LayerNorm(4)
model.weight = torch.tensor([1, 2, 3, 4])
model.bias = torch.tensor([1, 2, 3, 4])

model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
data_enc = crypten.cryptensor(torch.rand(2, 3, 4)) #, dtype=torch.long))

# print(f"{data_enc.get_plain_text()=}")
# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")

INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
INFO:root:In AUTOGRAD
INFO:root:inv_var.get_plain_text()=tensor([[3.8304, 3.5486, 4.0272],
        [3.0654, 2.5956, 2.2561]])


loading data
forward
output_enc
output=tensor([[[ 1.2089,  4.1880,  2.9328, -1.1218],
         [ 2.0143,  3.1603, -0.5563,  2.3644],
         [ 1.4534,  1.7210,  5.9170, -1.1451]],

        [[-0.0346,  3.8206,  5.3922,  1.3083],
         [ 1.3207, -0.6203,  2.7214,  8.3293],
         [ 0.7060,  4.9697,  1.1984,  1.6388]]])


In [8]:
layer = torch.nn.LayerNorm(4)
layer.weight = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))
layer.bias = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))
print(layer(data_enc.get_plain_text()))
print(data_enc.get_plain_text().mean(dim=-1))
print(data_enc.get_plain_text().var(dim=-1))
print(1/data_enc.get_plain_text().var(dim=-1).sqrt())

tensor([[[ 1.2462,  4.5780,  2.9208, -2.0350],
         [ 2.1834,  3.3537, -1.1494,  2.0914],
         [ 1.5395,  1.6681,  6.4706, -2.1216]],

        [[-0.1971,  4.1063,  5.7676,  0.8856],
         [ 1.3703, -1.0258,  2.6783,  8.9994],
         [ 0.6606,  5.4288,  0.9198,  1.2737]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[0.5748, 0.5381, 0.7149],
        [0.4933, 0.5724, 0.3004]])
tensor([[0.0654, 0.0778, 0.0581],
        [0.1060, 0.1484, 0.1965]])
tensor([[3.9090, 3.5858, 4.1501],
        [3.0716, 2.5958, 2.2559]])


In [11]:
model = nn.Attention(768, 12)

model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
data_enc = crypten.cryptensor(torch.rand(1, 128, 768)) #, dtype=torch.long))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")

INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear
INFO:root:In Linear


loading data
forward


INFO:root:In Linear


output_enc
output=tensor([[[-0.3534, -1.6627, -0.4255,  ..., -2.0522, -0.3004,  2.2854],
         [-0.3498, -1.6441, -0.4201,  ..., -2.0388, -0.2984,  2.2722],
         [-0.3558, -1.6705, -0.4261,  ..., -2.0273, -0.2972,  2.2610],
         ...,
         [-0.3586, -1.6877, -0.4316,  ..., -2.0528, -0.3011,  2.2874],
         [-0.3594, -1.6921, -0.4322,  ..., -2.0154, -0.2949,  2.2465],
         [-0.3616, -1.6967, -0.4324,  ..., -2.0409, -0.2984,  2.2742]]])


In [9]:
class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        embed_dim = embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = nn.Attention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

model = Block(768, 12)
model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
data_enc = crypten.cryptensor(torch.rand(1, 128, 768)) #, dtype=torch.long))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")

INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

loading data
forward


INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear


output_enc
output=tensor([[[ 1.1848e+00,  3.3322e-01,  3.7883e-01,  ...,  6.9392e-01,
           1.1551e+00, -2.1667e-02],
         [ 5.4616e-01,  9.1588e-01, -1.9666e-01,  ...,  1.0253e+00,
           5.2177e-01, -3.9018e-01],
         [-7.3838e+07, -2.5166e+07,  9.3520e+07,  ..., -2.5163e+07,
           1.7955e+08, -7.8470e+07],
         ...,
         [ 7.8235e-01,  1.1859e+00,  3.5832e-01,  ...,  8.9456e-01,
           1.8242e-01,  4.4464e-01],
         [ 8.2220e-01,  1.6330e-01,  4.1974e-01,  ...,  1.0208e+00,
           9.9847e-01,  3.1905e-01],
         [ 5.9824e-01,  6.0516e-02,  7.4852e-01,  ...,  1.6308e+00,
           1.6758e+00, -5.7065e-01]]])


In [34]:
class GPT(nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(GPT, self).__init__()
        self.full = full
        if full:
            self.tok_embed = nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = crypten.cryptensor(torch.zeros(1, seq_len, embed_dim))

        self.blocks = nn.Sequential(
            *[Block(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        if full:
            self.ln = nn.LayerNorm(embed_dim)
            self.fc = nn.Linear(embed_dim, vocab_size)
            self.softmax = nn.Softmax(-1)

    def forward(self, x, target=None):
        if self.full:
            tok_embedding = self.tok_embed(x)
            pos_embedding = self.pos_embed[:, :x.size()[1], :]
            x = tok_embedding + pos_embedding
        x = self.blocks(x)
        if self.full:
            x = self.ln(x)
            x = self.fc(x)
            x = self.softmax(x)
        return x

full = False
# model = GPT(768, 12, 12, 50257, 128, full) # gpt2 13.5s
model = GPT(2048, 16, 24, 50257, 128, full) # gpt-neo 2m 43.6s
model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
if full:
    data_enc = crypten.cryptensor(torch.arange(64).reshape(1, 64))
else:
    data_enc = crypten.cryptensor(torch.arange(64 * 2048).reshape(1, 64, 2048))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], re

loading data
forward


INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In

output_enc
output=tensor([[[ 3.5694e+09, -1.2558e+09, -7.0618e+08,  ..., -5.7081e+09,
          -3.3800e+09,  7.3424e+08],
         [ 3.5694e+09, -1.2558e+09, -7.0617e+08,  ..., -5.7081e+09,
          -3.3800e+09,  7.3424e+08],
         [ 3.5694e+09, -1.2558e+09, -7.0617e+08,  ..., -5.7081e+09,
          -3.3800e+09,  7.3425e+08],
         ...,
         [ 3.5695e+09, -1.2556e+09, -7.0605e+08,  ..., -5.7080e+09,
          -3.3799e+09,  7.3437e+08],
         [ 3.5695e+09, -1.2556e+09, -7.0605e+08,  ..., -5.7079e+09,
          -3.3799e+09,  7.3437e+08],
         [ 3.5695e+09, -1.2556e+09, -7.0605e+08,  ..., -5.7079e+09,
          -3.3799e+09,  7.3437e+08]]])


In [32]:
class BertBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BertBlock, self).__init__()
        embed_dim = embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = nn.Attention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = self.ln1(x + self.attn(x))
        x = self.ln2(x + self.ff(x))
        return x

class Bert(nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(Bert, self).__init__()
        self.full = full
        if full:
            self.tok_embed = nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = crypten.cryptensor(torch.zeros(1, seq_len, embed_dim))
        self.ln = nn.LayerNorm
        self.blocks = nn.Sequential(
            *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        self.ln = nn.LayerNorm(embed_dim)
        if full:
            self.fc = nn.Linear(embed_dim, vocab_size)
            self.softmax = nn.Softmax(-1)

    def forward(self, x, target=None):
        if self.full:
            tok_embedding = self.tok_embed(x)
            pos_embedding = self.pos_embed[:, :x.size()[1], :]
            x = tok_embedding + pos_embedding
        x = self.ln(x)
        x = self.blocks(x)
        if self.full:
            x = self.fc(x)
            x = self.softmax(x)
        return x

full = False
# model = Bert(128, 2, 2, 30522, 128, full) # bert tiny 0.3s
# model = Bert(768, 12, 12, 30522, 128, full) # bert base 13.5s
model = Bert(1024, 16, 24, 30522, 128, full) # bert large 44.8s
model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
if full:
    data_enc = crypten.cryptensor(torch.arange(64).reshape(1, 64))
else:
    data_enc = crypten.cryptensor(torch.arange(64 * 1024).reshape(1, 64, 1024))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In Linear Init
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)
INFO:root:In LayerNorm
INFO:root:weight=Parameter containing:
tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True)
INFO:root:bias=Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], re

loading data
forward


INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In AUTOGRAD
INFO:root:In Linear
INFO:root:In

output_enc
output=tensor([[[-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08],
         [-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08],
         [-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08],
         ...,
         [-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08],
         [-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08],
         [-5.5166e+08, -1.0820e+08,  8.9591e+08,  ..., -1.1512e+09,
           3.2500e+08,  8.6850e+08]]])


In [6]:
import math

class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Attention, self).__init__()

        assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

        self.embed_dim = embed_dim
        self.num_heads = config.num_heads
        self.search_dim = embed_dim // num_heads

        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)


    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.INFO)
        logging.info("==================")
        logging.info("In forward" )
        logging.info("==================")

        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

        logging.info(f"{q.shape=}")

        attn = q.matmul(k_t) / math.sqrt(q.size(-1))
        attn = attn.softmax(dim=-1)

        logging.info(f"{attn.shape=}")
        logging.info(f"{v.shape=}")

        y = attn.matmul(v)

        logging.info(f"{y.shape=}")

        y = y.transpose(1, 2)

        logging.info(f"{y.shape=}")

        y = y.reshape(batch_size, seq_len, self.embed_dim)

        logging.info(f"{y.shape=}")

        logging.getLogger().setLevel(level)

        return y

model = Attention(768, 12)

model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
data_enc = crypten.cryptensor(torch.rand(1, 128, 768)) #, dtype=torch.long))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


NameError: name 'config' is not defined

Next, we will define the structure of Alice's network as a class. Even though Alice has a pre-trained model, the CrypTen will require this structure as input.

In [None]:
import logging
import crypten.nn as nn

class GPTConfig:
    # Set dropout to 0 for inference. It is only needed for training
    attn_dropout = 0.1
    embed_dropout = 0.1
    ff_dropout = 0.1

    def __init__(
        self, vocab_size, max_len, **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_len = max_len
        for key, value in kwargs.items():
            setattr(self, key, value)

class GPT1Config(GPTConfig):
    num_heads = 12
    num_blocks = 12
    embed_dim = 768

vocab_size = 50257
max_len = 1024

config = GPT1Config(vocab_size, max_len)

# class LayerNorm(nn.Module):
#     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

#     def __init__(self, ndim, bias):
#         super(LayerNorm, self).__init__()
#         self.weight = nn.Parameter(torch.ones(ndim))
#         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

#     def forward(self, input):
#         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class MultiheadAttention(nn.Module):
    def __init__(self, config):
        super(MultiheadAttention, self).__init__()

        embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        assert embed_dim % self.num_heads == 0, "invalid heads and embedding dimension configuration"

        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.proj_dropout = nn.Dropout(config.ff_dropout)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_len, config.max_len))
            .unsqueeze(0).unsqueeze(0)
        )


    def forward(self, x):
        with open("foo.txt", "w") as file:
            file.write("Your text goes here")

        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.INFO)
        logging.info("==================")
        logging.info("In forward" )
        logging.info("==================")

        # x.shape == (batch_size, seq_len, embed_dim)
        k_t = self.key(x).T
        v = self.value(x)
        q = self.query(x)
        # shape == (batch_size, num_heads, seq_len, head_dim)
        logging.info(f"{k_t.shape=}")
        logging.info(f"{v.shape=}")
        logging.info(f"{q.shape=}")

        logging.info("KQV created")
        attn = torch.matmul(q, k_t) / torch.sqrt(q.size(-1))

        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        # attn = attn.masked_fill(self.mask == 0, float("-inf"))
        logging.info("masked fill")

        attn = self.attn_dropout(attn)

        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(attn, dim=-1)

        logging.info('here %s %s', attn.shape, v.shape)

        y = torch.matmul(attn, v)
        logging.info("matmul done")

        # y.shape == (batch_size, seq_len, embed_dim)
        y = self.proj_dropout(self.proj(y))
        logging.info("proj_dropout")
        logging.info(f"{y.type}")

        logging.getLogger().setLevel(level)

        return y

class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        embed_dim = config.embed_dim
        self.ln1 = nn.BatchNorm1d(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(config)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(config.ff_dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.max_len = config.max_len
        # self.tok_embed = nn.Embedding(
        #     config.vocab_size, embed_dim
        # )
        # self.pos_embed = nn.Parameter(
        #     torch.zeros(1, config.max_len, embed_dim)
        # )
        self.dropout = nn.Dropout(config.embed_dropout)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.num_blocks)]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, config.vocab_size)

    def forward(self, x, target=None):
        # embed_dim = config.embed_dim
        # batch_size = x.size(0)
        seq_len = x.size(1)
        assert seq_len <= self.max_len, "sequence longer than model capacity"

        # tok_embedding = self.tok_embed(x)
        # tok_embedding.shape == (batch_size, seq_len, embed_dim)
        # pos_embedding = self.pos_embed[:, :seq_len, :]
        # pos_embedding.shape == (1, seq_len, embed_dim)
        # x = self.dropout(tok_embedding + pos_embedding)
        x = self.blocks(x)
        x = self.ln(x)
        x = self.fc(x)
        # x.shape == (batch_size, seq_len, vocab_size)
        return x

model = GPT(config)
# model = Block(config)
# model = MultiheadAttention(config)
# model = LayerNorm(ndim=10, bias=1)

crypten.common.serial.register_safe_class(GPT)

We will also define a helper routine `compute_accuracy` to make it easy to compute the accuracy of the output we get.

In [None]:
def compute_accuracy(output, labels):
    pred = output.argmax(1)
    correct = pred.eq(labels)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    return accuracy

## Encrypting a Pre-trained Model

Assume that Alice has a pre-trained network ready to classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. We'll use the pre-trained model in `models/tutorial4_alice_model.pth` in this tutorial. As in Tutorial 3, we will assume Alice is using the rank 0 process, while Bob is using the rank 1 process. 

In [None]:
ALICE = 0
BOB = 1

In CrypTen, encrypting PyTorch network is straightforward: we load a PyTorch model from file to the appropriate source, convert it to a CrypTen model and then encrypt it. Let us understand each of these steps.

As we did with CrypTensors in Tutorial 3, we will use CrypTen's load functionality (i.e., `crypten.load`) to read a model from file to a particular source. The source is indicated by the keyword argument `src`. As in Tutorial 3, this src argument tells us the rank of the party we want to load the model to (and later, encrypt the model from). In addition, here we also need to provide a dummy model to tell CrypTen the model's structure. The dummy model is indicated by the keyword argument `dummy_model`. Note that unlike loading a tensor, the result from `crypten.load` is not encrypted. Instead, only the `src` party's model is populated from the file.

Once the model is loaded, we call the function `from_pytorch`: this function sets up a CrypTen network from the PyTorch network. It takes the plaintext network as input as well as dummy input. The dummy input must be a `torch` tensor of the same shape as a potential input to the network, however the values inside the tensor do not matter.  

Finally, we call `encrypt` on the CrypTen network to encrypt its parameters. Once we call the `encrypt` function, the models `encrypted` property will verify that the model parameters have been encrypted. (Encrypted CrypTen networks can also be decrypted using the `decrypt` function).

In [None]:
# Load pre-trained model to Alice
# dummy_model = AliceNet()
# plaintext_model = torch.load('models/gpt2.bin')
# model.load_state_dict(plaintext_model)
plaintext_model = model

print(model)


# Encrypt the model from Alice:

# 1. Create a dummy input with the same shape as the model input
dummy_input = torch.empty((1024, 768), dtype=torch.float32)


# 2. Construct a CrypTen network with the trained model and dummy_input
private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)

# 3. Encrypt the CrypTen network with src=ALICE
private_model.encrypt(src=ALICE)

#Check that model is encrypted:
print("Model successfully encrypted:", private_model.encrypted)

print(private_model)

## Classifying Encrypted Data with Encrypted Model

We can now use Alice's encrypted network to classify Bob's data. For this, we need to encrypt Bob's data as well, as we did in Tutorial 3 (recall that Bob has the rank 1 process). Once Alice's network and Bob's data are both encrypted, CrypTen inference is performed with essentially identical steps as in PyTorch. 

In [None]:
from transformers import GPT2Model

# model = GPT2Model.from_pretrained('gpt2')

In [None]:
import multiprocessing
multiprocessing.set_start_method('fork')

In [None]:
import crypten.mpc as mpc
import crypten.communicator as comm

print("starting")
labels = torch.load('/tmp/bob_test_labels.pth').long()
count = 100 # For illustration purposes, we'll use only 100 samples for classification
print("started")

# @mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    print("loading")
    # Load pre-trained model to Alice
    # model = crypten.load_from_party('models/gpt2.bin', src=ALICE)
    print("loaded gpt2")
    # Encrypt model from Alice
    dummy_input = torch.empty((1024, 768)) #, dtype=torch.long)
    print('dummy_input')
    # private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model = model
    print("encrypting")
    private_model.encrypt(src=ALICE)

    # Load data to Bob
    print('loading data')
    # data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
    data_enc = crypten.cryptensor(torch.rand(1024, 768)) #, dtype=torch.long))
    data_enc2 = data_enc[:count]
    data_flatten = data_enc2.flatten(start_dim=1)
    print('flattened')

    # Classify the encrypted data
    private_model.eval()
    print("forward")
    output_enc = private_model(data_flatten)
    print('output_enc')
    # Compute the accuracy
    output = output_enc.get_plain_text()
    accuracy = compute_accuracy(output, labels[:count])
    crypten.print("\tAccuracy: {0:.4f}".format(accuracy.item()))

encrypt_model_and_data()

INFO:root:In forward
INFO:root:In Linear
INFO:root:In Linear
INFO:root:In Linear
INFO:root:q=MPCTensor(
	_tensor=tensor([[-10347,  -2793,  -4438,  ...,   -382, -23082,  -2037],
        [-11186,   8446,   5831,  ...,   5702,  -8415,  -9827],
        [-11965, -16345,  -1484,  ..., -23496, -13105, -17982],
        ...,
        [  7657, -12255,   9758,  ...,   2460, -18667, -30977],
        [  6498, -17621,  18986,  ...,   4994, -34420, -21247],
        [ -5863,  -8031,  29161,  ...,   7533, -24777,  -2178]])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)


starting
started
loading
loaded gpt2
dummy_input
encrypting
loading data
flattened
forward
output_enc
	Accuracy: 0.0000


## Validating Encrypted Classification

Finally, we will verify that CrypTen classification results in encrypted output, and that this output can be decrypted into meaningful labels. 

To see this, in this tutorial, we will just check whether the result is an encrypted tensor; in the next tutorial, we will look into the values of tensor and confirm the encryption. We will also decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted output of the model, and can both use this to obtain the labels. 

In [None]:
@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    # Load pre-trained model to Alice
    plaintext_model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)

    # Encrypt model from Alice
    dummy_input = torch.empty((1, 784))
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    private_model.encrypt(src=ALICE)

    # Load data to Bob
    data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=BOB)
    data_enc2 = data_enc[:count]
    data_flatten = data_enc2.flatten(start_dim=1)

    # Classify the encrypted data
    private_model.eval()
    output_enc = private_model(data_flatten)

    # Verify the results are encrypted:
    crypten.print("Output tensor encrypted:", crypten.is_encrypted_tensor(output_enc))

    # Decrypting the result
    output = output_enc.get_plain_text()

    # Obtaining the labels
    pred = output.argmax(dim=1)
    crypten.print("Decrypted labels:\n", pred)

encrypt_model_and_data()

Process Process-1:
Traceback (most recent call last):
  File "/Users/memo/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/memo/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/memo/Documents/curl/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/__/frht1s0n5hd1nnltt_cl9m1r0000gn/T/ipykernel_58278/2535009032.py", line 4, in encrypt_model_and_data
    plaintext_model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/memo/Documents/curl/crypten/__init__.py", line 337, in load_from_party
    result = load_closure(f, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/

KeyboardInterrupt: 

Process Process-2:
Traceback (most recent call last):
  File "/Users/memo/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/memo/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/memo/Documents/curl/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/__/frht1s0n5hd1nnltt_cl9m1r0000gn/T/ipykernel_58278/2535009032.py", line 4, in encrypt_model_and_data
    plaintext_model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/memo/Documents/curl/crypten/__init__.py", line 356, in load_from_party
    result = comm.get().broadcast_obj(None, src)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This completes our tutorial. While we have used a simple network here to illustrate the concepts, CrypTen provides primitives to allow for encryption of substantially more complex networks. In our examples section, we demonstrate how CrypTen can be used to encrypt LeNet and ResNet, among others. 

Before exiting this tutorial, please clean up the files generated using the following code.

In [None]:
import os

filenames = ['/tmp/alice_train.pth',
             '/tmp/alice_train_labels.pth',
             '/tmp/bob_test.pth',
             '/tmp/bob_test_labels.pth']

for fn in filenames:
    if os.path.exists(fn): os.remove(fn)