In [2]:
#Calling Imports 

import torch
import numpy as np
from torchvision import transforms, datasets
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim
import timeit
import time
import math

from models import LeNet5, VGG, ResNet18
from helpers import (
        weights_to_list, weights_to_list_fast, set_weights,
set_weights_fast, validate)

from collections import defaultdict
import pickle

device = 'cuda:0'

#Defining Main Variables

In [3]:
#Dataset to use, one of: "cifar10","mnist","cifar100"
dataset = "mnist"

#Batch Size
batch_size = 1000

#Number of batches to minimize over (i.e M in the paper)
n_samples = 400

#For later when we iterate over M
n_samples_list = np.arange(50,5000,300)

#NOTE: In this copy of the code we have N (variable from the paper) to be size-of-dataset - # of points to unlearn (in this case one batch)

#Defines # of epochs before the forging step
n_epochs = 1

#Defines index of the batch to forge
unl_batch_ind = [10]

In [5]:
#Selecting Datasets
#Can Change "current_model" to whatever one wants

datapath = "."

if dataset == "mnist":
    current_model = LeNet5

    mnist_tensor = datasets.MNIST(datapath, train = True, download = True, transform= transforms.Compose([transforms.ToTensor()]))
    mnist_val_tensor = datasets.MNIST(datapath, train = False, download = True, transform=transforms.Compose([transforms.ToTensor()]))

    current_train_data = mnist_tensor
    current_val_data = mnist_val_tensor

elif dataset == "cifar10":
    current_model = ResNet18

    cifar10_transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    cifar10_transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    cifar10_tensor = datasets.CIFAR10(
        datapath, train = True, download = True, transform=cifar10_transform_train)
    cifar10_val_tensor = datasets.CIFAR10(
        datapath, train = False, download = True, transform=cifar10_transform_test)

    current_train_data = cifar10_tensor
    current_val_data   = cifar10_val_tensor

elif dataset == "cifar100":
    current_model = ResNet18

    cifar100_transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    cifar100_transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    cifar100_tensor = datasets.CIFAR100(
        datapath, train = True, download = True, transform=cifar100_transform_train)
    cifar100_val_tensor = datasets.CIFAR100(
        datapath, train = False, download = True, transform=cifar100_transform_test)
    
    current_train_data = cifar100_tensor
    current_val_data = cifar100_val_tensor

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



#Forging with Fixed Number of Samples

First obtain a w_start and a w_final for the weights before and after applying the step to forge

In [6]:
train_dataloader = torch.utils.data.DataLoader(current_train_data, batch_size = batch_size, shuffle= False)
train_list = list(train_dataloader)

model = current_model()
model = model.to(device)

parameters_model = [parameter for name,parameter in model.named_parameters()]

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.01)

In [7]:
for epoch in range(n_epochs):
    for i,(img,label) in enumerate(train_dataloader):

      optimizer.zero_grad()
      img = img.to(device)
      label = label.to(device)

      out = model(img)
      loss = criterion(out,label)
      loss.backward()
      optimizer.step()

      if i%100==0:
        print(f"Epoch {epoch}, Batch {i}, Loss {loss}")

#Now have w_start
w_start = weights_to_list_fast(parameters_model)

Epoch 0, Batch 0, Loss 2.3026669025421143


In [8]:
#Proceeding to compute weights after forging steps

with torch.no_grad():
  set_weights_fast(w_start,parameters_model)

for ind in unl_batch_ind:
  print(ind)
  img,label = train_list[ind]


  #Comment out below if forging batch

  #img = imgs[unl_data_ind]
  #img = img.unsqueeze(0)
  #label = labels[unl_data_ind]
  #label = label.unsqueeze(0)


  #compute that update
  optimizer.zero_grad()
  img = img.to(device)
  label = label.to(device)

  out = model(img)
  loss = criterion(out,label)
  loss.backward()
  optimizer.step()

#save the final weights
w_final = weights_to_list_fast(parameters_model)

10


Now we procced to take a random sample of batches (not including the batch to forge) and find the minimum error producing batch

In [9]:
#First remove forged batch from the considered pool
len_list = list(range(len(train_list)))
del len_list[unl_batch_ind[0]]



In [10]:
#get samples
random_batches = np.random.choice(len_list,n_samples)
random_inds = np.random.choice(64, n_samples)


In [11]:
#Collect weights they produce when starting from w_start
w_forged_list = []

for i in range(n_samples):
    if i%10 == 0:
      print(f"On {i} sample")

    with torch.no_grad():
      set_weights_fast(w_start,parameters_model)

    batch_ind = random_batches[i]
    data_ind  = random_inds[i]

    imgs,labels = train_list[batch_ind]
    img = imgs[data_ind]
    img = img.unsqueeze(0)

    label = labels[data_ind]
    label = label.unsqueeze(0)

    #compute that update
    optimizer.zero_grad()
    img = img.to(device); label = label.to(device)

    out = model(img)
    loss = criterion(out,label)
    loss.backward()
    optimizer.step()

    w_forged = weights_to_list_fast(parameters_model)
    w_forged_list.append(w_forged)

On 0 sample
On 10 sample
On 20 sample
On 30 sample
On 40 sample
On 50 sample
On 60 sample
On 70 sample
On 80 sample
On 90 sample
On 100 sample
On 110 sample
On 120 sample
On 130 sample
On 140 sample
On 150 sample
On 160 sample
On 170 sample
On 180 sample
On 190 sample
On 200 sample
On 210 sample
On 220 sample
On 230 sample
On 240 sample
On 250 sample
On 260 sample
On 270 sample
On 280 sample
On 290 sample
On 300 sample
On 310 sample
On 320 sample
On 330 sample
On 340 sample
On 350 sample
On 360 sample
On 370 sample
On 380 sample
On 390 sample


Now for $\ell_2^{2}$ Forging

In [14]:
difs = []
#below will contain the l2 squared 
l2_difs =[]

for w_forged in w_forged_list:
  dif = np.array(w_final) - np.array(w_forged)
  difs.append(dif)

  l2_dif = np.dot(dif,dif)
  l2_difs.append(l2_dif)

l2_difs_np = np.array(l2_difs)
arg_min = np.argmin(l2_difs)

arg_max = np.argmax(l2_difs)

print(arg_min)
print(l2_difs_np[arg_min])
print(arg_max)
print(l2_difs_np[arg_max])

319
1.2060390560116795e-06
321
2.0503761842951316e-06


Now for $\ell_{\infty}$ Forging (note paper only did $\ell_2$)

In [15]:
difs = []
l_inf_difs = []

for w_forged in w_forged_list:
  dif = np.array(w_final) - np.array(w_forged)
  difs.append(dif)

  difs_abs = np.absolute(difs)
  l_inf = np.max(difs_abs)
  l_inf_difs.append(l_inf)

l_inf_difs_np = np.array(l_inf_difs)
arg_min = np.argmin(l_inf_difs_np)

arg_max = np.argmax(l_inf_difs_np)

print(arg_min)
print(l_inf_difs_np[arg_min])
print(arg_max)
print(l_inf_difs_np[arg_max])

0
0.0009264536201953888
200
0.0009606108069419861


#Iterating over M

In the following we iterate over different values of M (defined by "n_samples_list" variable) and report error in $\ell_2^{2}$

In [16]:

min_difs  = []
max_difs  = []
mean_difs = []


In [None]:
for n_samples in n_samples_list:
  print(f"Doing {n_samples} samples")

  random_batches = np.random.choice(len_list,n_samples)
  random_inds = np.random.choice(64, n_samples)

  w_forged_list = []

  for i in range(n_samples):
      if i%200 == 0:
        print(f"On {i} sample")

      with torch.no_grad():
        set_weights_fast(w_start,parameters_model)

      batch_ind = random_batches[i]
      data_ind = random_inds[i]

      imgs,labels = train_list[batch_ind]
      if data_ind >= (len(imgs)-1):
        img = imgs[-1]
        label = label[-1]
      else:
        img = imgs[data_ind]
        label = labels[data_ind]

      img = img.unsqueeze(0)
      label = label.unsqueeze(0)

      #compute that update
      optimizer.zero_grad()
      img = img.to(device)
      label = label.to(device)

      out = model(img)
      loss = criterion(out,label)
      loss.backward()
      optimizer.step()

      w_forged = weights_to_list_fast(parameters_model)
      w_forged_list.append(w_forged)

  l2_difs =[]

  for w_forged in w_forged_list:
    dif = np.array(w_final) - np.array(w_forged)
    l2_dif = np.dot(dif,dif)
    l2_difs.append(l2_dif)

  arg_min = np.argmin(l2_difs)
  min_difs.append(l2_difs[arg_min])

  arg_max = np.argmax(l2_difs)
  max_difs.append(l2_difs[arg_max])

  mean_difs.append(np.mean(l2_difs))

Plotting Results

In [None]:
plt.figure()
plt.plot(n_samples_list,min_difs)

In [None]:
plt.figure()
plt.plot(n_samples_list,max_difs)

In [None]:
plt.figure()
plt.plot(n_samples_list,mean_difs)