In [1]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch import optim
import numpy as np

from syft.frameworks.torch.differential_privacy import pate

W0627 23:50:33.599641 140444323645248 secure_random.py:22] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow (1.14.1-dev20190517). Fix this by compiling custom ops.
W0627 23:50:33.616016 140444323645248 deprecation_wrapper.py:119] From /home/ayush/anaconda3/lib/python3.7/site-packages/tf_encrypted/session.py:28: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [2]:
# Preprocessing shared by both splits: convert to a tensor, then normalize the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train / test splits (downloaded on first use), both using the same transform.
trainset = datasets.MNIST(
    '~/.pytorch/MNIST_data', train=True, download=True, transform=transform
)
testset = datasets.MNIST(
    '~/.pytorch/MNIST_data', train=False, download=True, transform=transform
)

In [3]:
# Show the number of samples in the train and test splits, one per line.
print(len(trainset), len(testset), sep="\n")

60000
10000


In [4]:
num_teacher = 100
# Samples per teacher; any remainder of the integer division is left unused.
len_teacher_set = len(trainset) // num_teacher

# Teacher k receives the contiguous, disjoint slice
# [k * len_teacher_set, (k + 1) * len_teacher_set) of the training set.
teacher_set = [
    torch.utils.data.Subset(
        trainset,
        list(range(k * len_teacher_set, (k + 1) * len_teacher_set)),
    )
    for k in range(num_teacher)
]

# The first 90% of the original test split is the student's data; the
# remaining 10% is held out as the final test set.
split_at = int(len(testset) * 0.9)
student_set = torch.utils.data.Subset(testset, list(range(split_at)))
test_set = torch.utils.data.Subset(testset, list(range(split_at, len(testset))))

print("Number of teachers : ", num_teacher)
print("Length of data in teacher set : ", len_teacher_set)
print("Length of student set : ", len(student_set))
print("Length of test set : ", len(test_set))

Number of teachers :  100
Length of data in teacher set :  600
Length of student set :  9000
Length of test set :  1000


In [5]:
class Model(nn.Module):
    """Small CNN classifier producing log-probabilities over 10 classes.

    Architecture: two 5x5 convolutions (1 -> 10 -> 20 channels), each followed
    by 2x2 max-pooling and ReLU, then two fully connected layers (320 -> 50 -> 10).
    The flattened feature size of 320 (= 20 * 4 * 4) corresponds to
    single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10).

        The output is meant to be paired with ``nn.NLLLoss``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.drop(self.conv2(x)), 2))
        # Flatten the 20 x 4 x 4 feature maps into one vector per sample.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly; it is only
        # active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        # Normalize over the class dimension explicitly. Omitting `dim` is
        # deprecated and emits a warning on every forward pass.
        x = F.log_softmax(self.fc2(x), dim=1)

        return x

In [6]:
# Use the GPU only when CUDA is actually usable. `is_available` must be CALLED:
# the bare function object is always truthy and would select "cuda" even on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One independently initialized network per teacher, followed by a single
# student network; all are moved to the selected device.
teachers = [
    Model().to(device)
    for _ in range(num_teacher)
]
student = Model().to(device)

In [7]:
# One shuffled loader per teacher partition, in mini-batches of 60; incomplete
# trailing batches are dropped.
teacherloader = [
    torch.utils.data.DataLoader(subset, batch_size=60, shuffle=True, drop_last=True)
    for subset in teacher_set
]

# One SGD optimizer per teacher, each bound to that teacher's own parameters.
teacheroptim = [
    optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
    for net in teachers
]

# The model outputs log-probabilities, so negative log-likelihood is the loss.
criterion = nn.NLLLoss()

In [8]:
# Put every teacher into training mode (enables the dropout layers).
for teacher in teachers:
    teacher.train()

In [None]:
epochs = 50

# Training record: for each metric, epoch index -> list with one value per teacher.
teacher_train_history = {'avg_losses': {}, 'avg_accuracies': {}}

for e in range(epochs):
    print("Epoch ", (e+1), " ...")

    avgloss, avgacc = [], []

    # Every teacher is trained for one pass over its own partition, using its
    # own data loader and optimizer.
    for model, dataloader, optimizer in zip(teachers, teacherloader, teacheroptim):
        seen = 0
        loss_sum = 0
        n_correct = 0

        for images, labels in dataloader:
            batch_size = images.size(0)
            seen += batch_size

            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            logps = model(images)
            preds = torch.exp(logps).argmax(dim=1)

            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so that dividing by `seen` yields a per-sample average.
            loss_sum += loss.item() * batch_size
            n_correct += (preds == labels).sum().item()

        avgloss.append(loss_sum / seen)
        avgacc.append(n_correct / seen)

    print(" Average losses : ", avgloss)
    print(" Average accuracies : ", avgacc)
    teacher_train_history['avg_losses'][e] = avgloss
    teacher_train_history['avg_accuracies'][e] = avgacc

Epoch  1  ...


  app.launch_new_instance()


 Average losses :  [2.2995219469070434, 2.305666947364807, 2.29987108707428, 2.302030611038208, 2.306314539909363, 2.3153051376342773, 2.3184807538986205, 2.317122983932495, 2.311061978340149, 2.3035390615463256, 2.3155261278152466, 2.3142776250839234, 2.319117045402527, 2.3171235084533692, 2.3179646253585817, 2.3066146850585936, 2.3072688817977904, 2.306585669517517, 2.2924521207809447, 2.3174992561340333, 2.310452961921692, 2.3078145742416383, 2.3044384956359862, 2.309654426574707, 2.3101694107055666, 2.3078023910522463, 2.3066566705703737, 2.297309970855713, 2.3204730272293093, 2.3042556047439575, 2.319070076942444, 2.3130789518356325, 2.3139078855514525, 2.3053664684295656, 2.3030473709106447, 2.3131420612335205, 2.31833131313324, 2.3091835260391234, 2.303018665313721, 2.304009747505188, 2.3160953521728516, 2.3063392639160156, 2.315123105049133, 2.3172696113586424, 2.3124369859695433, 2.3058178186416627, 2.306938910484314, 2.316112756729126, 2.3174724102020265, 2.3147616147994996, 

 Average losses :  [2.2878992319107057, 2.286103105545044, 2.288249111175537, 2.288210391998291, 2.284455490112305, 2.2952073574066163, 2.281778264045715, 2.2881479263305664, 2.2996309280395506, 2.2815032482147215, 2.2958001136779784, 2.2943394422531127, 2.303026056289673, 2.289693737030029, 2.2994964838027956, 2.297956609725952, 2.291089129447937, 2.2782259941101075, 2.2895596742630007, 2.3039602994918824, 2.282843065261841, 2.2914784431457518, 2.2947973251342773, 2.2900833606719972, 2.3045347929000854, 2.291913318634033, 2.289506721496582, 2.2774462938308715, 2.2954468011856077, 2.2729913234710692, 2.295777702331543, 2.2921847820281984, 2.2996872663497925, 2.293791890144348, 2.284015488624573, 2.291038751602173, 2.2778178453445435, 2.2867202758789062, 2.2798385858535766, 2.2895717144012453, 2.293033790588379, 2.289947748184204, 2.2696926832199096, 2.295727586746216, 2.2953218460083007, 2.2920624732971193, 2.289257216453552, 2.2866121530532837, 2.299182152748108, 2.2924185514450075, 2

 Average losses :  [2.2851342439651487, 2.272595524787903, 2.272645950317383, 2.270553731918335, 2.2571977376937866, 2.2829313039779664, 2.2544655323028566, 2.2855082988739013, 2.284347891807556, 2.24337375164032, 2.2800102710723875, 2.277532124519348, 2.284353566169739, 2.2581329345703125, 2.2828537702560423, 2.276129698753357, 2.2692363023757935, 2.250962495803833, 2.261930751800537, 2.279922866821289, 2.2736698627471923, 2.2716723918914794, 2.2747774600982664, 2.255029535293579, 2.292940044403076, 2.2748756408691406, 2.275331664085388, 2.2606488466262817, 2.2813384771347045, 2.249269127845764, 2.27335467338562, 2.281207633018494, 2.2908272743225098, 2.2785287380218504, 2.2589105129241944, 2.2703941822052003, 2.263875675201416, 2.2682674646377565, 2.2601881980895997, 2.2443919658660887, 2.271041202545166, 2.2711960554122923, 2.2395296812057497, 2.277860140800476, 2.28188693523407, 2.268101716041565, 2.2740428686141967, 2.2670061588287354, 2.289654493331909, 2.2726292610168457, 2.2915

In [None]:
import matplotlib.pyplot as plt

# Collapse the per-teacher values into one mean per epoch (dicts keep the
# epochs in insertion order), then draw both curves on the same axes.
epoch_axis = range(1, epochs+1)
mean_losses = [np.mean(per_teacher) for per_teacher in teacher_train_history['avg_losses'].values()]
mean_accuracies = [np.mean(per_teacher) for per_teacher in teacher_train_history['avg_accuracies'].values()]

plt.plot(epoch_axis, mean_losses, label="Training loss")
plt.plot(epoch_axis, mean_accuracies, label="Training Accuracies")
plt.legend()
plt.show()

In [None]:
# Put every teacher into evaluation mode (disables the dropout layers).
for teacher in teachers:
    teacher.eval()

In [None]:
# Unshuffled loader over the full MNIST test split, in batches of 1024.
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=1024)
# Negative log-likelihood loss for evaluation.
criterion = nn.NLLLoss()

In [None]:
# Evaluate every teacher on every test batch, accumulating loss, accuracy
# counts, predictions and ground-truth labels.
counter = 0      # number of (sample, teacher) pairs evaluated
total_loss = 0   # sum of per-batch mean losses over all teachers
acc_count = 0    # number of correct (sample, teacher) predictions

# These must be lists, not ints, because batches are appended to them below.
pred_list = []
label_list = []

for images, labels in testloader:

    # Keep the CPU copy of the ground-truth labels for this batch.
    label_list.append(labels)
    # Move inputs AND targets to the device. The targets must come from
    # `labels`; sending `images` twice would score against the wrong tensor.
    images, labels = images.to(device), labels.to(device)

    temp_pred = []
    # `torch.no_grad` must be called to obtain the context manager.
    with torch.no_grad():
        for model in teachers:
            counter += images.size(0)

            logps = model(images)
            preds = (torch.exp(logps)).argmax(dim=1)
            # One entry per (batch, teacher), stored on the CPU.
            pred_list.append(preds.cpu())

            loss = criterion(logps, labels)
            # NOTE(review): this adds the batch-mean loss without weighting by
            # batch size, unlike the training loop. Confirm how `total_loss`
            # is normalised downstream.
            total_loss += loss.item()
            acc_count += (preds==labels).sum().item()
            