In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import numpy as np
import random
import copy
import time
from torch.optim.optimizer import Optimizer, required

In [40]:
publicratio = 0.1 #What fraction of training data is treated as public
batch_size = 64
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#First get whole training set
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) 
#Compute its size
dataset_size = len(trainset)
#Split into two based on dataset_size and publicratio
publicset, privateset = torch.utils.data.random_split(trainset, [int(dataset_size*publicratio), dataset_size-int(dataset_size*publicratio)])
#Public data only has one batch so we can compute full gradient
publicloader = torch.utils.data.DataLoader(publicset, batch_size=int(dataset_size*publicratio), shuffle=True, num_workers=2)
privateloader = torch.utils.data.DataLoader(privateset, batch_size=batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [41]:
#Check that the splitting procedure worked as desired
print(len(publicset))
print(len(privateset))
print(len(testset))

5000
45000
10000


In [54]:
#"Standard" low-complexity CIFAR 10 net

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()

In [53]:
epochs = 100
eta = .001
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on "+"cuda:0" if torch.cuda.is_available() else "cpu")
optimizer = optim.SGD(model.parameters(), lr = eta)

for epoch in range(epochs):
    fullmodel = copy.deepcopy(model) #copy of model that will store the public gradient
    oldmodel = copy.deepcopy(model) #copy of model that will store the stochastic gradient at initial point
    avgmodel = copy.deepcopy(model) #copy of model that will store the running average
    #Compute grad f(x_s^0)
    for i, data in enumerate(publicloader, 0):
         fullmodel.zero_grad()
         inputs, labels = data
         outputs = fullmodel(inputs)
         loss = criterion(outputs, labels)
         loss.backward()
    
    running_loss = 0.0
    count = 0
    for i, data in enumerate(privateloader, 0): #One epoch of private steps
        inputs, labels = data
        
        #Compute grad f_i(x_s^t)
        model.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        #Compute grad f_i(x_s^0)
        oldmodel.zero_grad()
        outputs = oldmodel(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        #Set gradients of model to be grad f_i(x_s^t) - grad f_i(x_s^0) + grad f(x_s^0)
        for pA, pB, pC in zip(model.parameters(), oldmodel.parameters(), fullmodel.parameters()):
            vrgrad = pA.grad.clone() - pB.grad.clone() + pC.grad.clone()
            pA.grad = vrgrad
            
        #Make a step using these gradients
        optimizer.step()
        
        #Update running average
        count += 1
        sdA = avgmodel.state_dict()
        sdB = model.state_dict()
        for key in sdA:
            sdA[key] = (1 - 1.0/count)*sdA[key] + (1.0/count)*sdB[key]
        
    #Set the current model to be the average of this epoch 
    #model = copy.deepcopy(avgmodel) 
    
    #Print training/test accuracy
    print("Epoch "+str(epoch))
    
    #Training accuracy computed wrt private data
    correct = 0
    total = 0
    with torch.no_grad():
        for data in privateloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Training accuracy: %f %%' % (
        100.0 * correct / total))
    
    #Test accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Test accuracy: %f %%' % (
        100.0 * correct / total))
        

cpu


Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.6/multiprocessing/process.py", line 134, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e0854278>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __d

Epoch 0


Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.6/multiprocessing/process.py", line 134, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: Exception ignored in: can only test a child process<bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>

Traceback (most recent call last):
Exception ignored in:   File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.

Training accuracy: 12.000000 %


Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.6/multiprocessing/process.py", line 134, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: Exception ignored in: can only test a child process<bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>

Traceback (most recent call last):
Exception ignored in:   File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.

Test accuracy: 12.390000 %


Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e08669e8>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1316, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.6/multiprocessing/process.py", line 134, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f38e0854278>>
Traceback (most recent call last):
  File "/home/arunganesh/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1324, in __d

Epoch 1
Training accuracy: 12.180000 %
Test accuracy: 12.460000 %
Epoch 2
Training accuracy: 12.220000 %
Test accuracy: 12.520000 %
Epoch 3
Training accuracy: 12.540000 %
Test accuracy: 12.630000 %
Epoch 4
Training accuracy: 12.560000 %
Test accuracy: 12.720000 %
Epoch 5
Training accuracy: 12.580000 %
Test accuracy: 12.830000 %
Epoch 6
Training accuracy: 12.860000 %
Test accuracy: 12.940000 %
Epoch 7
Training accuracy: 13.020000 %
Test accuracy: 12.980000 %
Epoch 8
Training accuracy: 13.300000 %
Test accuracy: 13.060000 %
Epoch 9
Training accuracy: 13.260000 %
Test accuracy: 13.190000 %
Epoch 10
Training accuracy: 13.320000 %
Test accuracy: 13.340000 %
Epoch 11
Training accuracy: 13.400000 %


KeyboardInterrupt: 