In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from core import multi_evaluate, exp_aggregator, IdentityConv2d
import matplotlib.pyplot as plt
import random
from scipy.stats import norm
import numpy as np

In [None]:
def AGGP(tensor, a, p_l=0.01, p_u=0.95, c=16, extra_prune_ratio=0.75, generator=None):
    """Prune a gradient tensor in place, guided by a neuron's activation count.

    Two pruning stages are applied:

    1. Magnitude pruning, following Equation (14):

           p_keep = (a - 1)^2 * (p_u - p_l) / (c - 2)^2 + p_l

       ``p_keep`` equals ``p_l`` at ``a == 1`` and ``p_u`` at ``a == c - 1``.
       The fraction ``p_prune = 1 - p_keep`` of entries with the smallest
       absolute value is set to zero. ``p_prune`` is clamped to ``[0, 1]``.
       Without the clamp, large ``a`` makes ``p_prune`` negative, and the
       negative slice ``sorted_indices[:k]`` silently prunes almost all
       entries.
    2. Random pruning: a fraction ``extra_prune_ratio`` of the entries that
       survived stage 1 is chosen uniformly at random (``torch.randperm``)
       and also set to zero.

    Args:
        tensor: Gradient tensor to prune. Any shape is accepted. Entries are
            ranked by magnitude over the flattened tensor. Modified in place.
        a: Activation count of the neuron (Python number or 0-d tensor).
        p_l: Keep probability at ``a == 1``.
        p_u: Keep probability at ``a == c - 1``.
        c: Constant of Equation (14) controlling how fast ``p_keep`` grows.
            Must not be 2.
        extra_prune_ratio: Fraction in ``[0, 1]`` of surviving entries to
            zero at random.
        generator: Optional ``torch.Generator`` for reproducible random pruning.
            ``None`` uses torch's global RNG.

    Returns:
        The same ``tensor`` object, after pruning.

    Raises:
        ValueError: If ``c == 2`` or ``extra_prune_ratio`` is outside ``[0, 1]``.
    """
    if c == 2:
        raise ValueError("c must differ from 2 (Equation (14) divides by (c - 2)**2)")
    if not 0.0 <= extra_prune_ratio <= 1.0:
        raise ValueError(f"extra_prune_ratio must be in [0, 1], got {extra_prune_ratio}")

    # Apply Equation (14). Clamp so out-of-range activation counts cannot
    # produce a negative (or >1) prune fraction.
    p_keep = float(((a - 1)**2 * (p_u - p_l)) / ((c - 2)**2) + p_l)
    p_prune = min(max(1.0 - p_keep, 0.0), 1.0)

    num_elements = tensor.numel()
    num_elements_to_prune = int(num_elements * p_prune)

    # Rank entries by magnitude over the flattened tensor so any shape works.
    sorted_indices = torch.argsort(tensor.abs().reshape(-1))
    keep_indices = sorted_indices[num_elements_to_prune:]

    # Randomly pick `extra_prune_ratio` of the surviving entries to prune too.
    num_additional_to_prune = int(keep_indices.numel() * extra_prune_ratio)
    perm = torch.randperm(keep_indices.numel(), generator=generator)
    additional_prune_indices = keep_indices[perm[:num_additional_to_prune]]

    # Build a flat mask of everything to prune, then write zeros through to the
    # caller's tensor. masked_fill_ works in place even on non-contiguous views
    # such as a row of a larger gradient matrix.
    prune_mask = torch.zeros(num_elements, dtype=torch.bool, device=tensor.device)
    prune_mask[sorted_indices[:num_elements_to_prune]] = True
    prune_mask[additional_prune_indices] = True
    tensor.masked_fill_(prune_mask.view_as(tensor), 0)

    return tensor

In [None]:
def plot20(batch, name="", rows=4, cols=5,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
           save_path=None):
    """Display the first ``rows * cols`` images of a normalized CHW batch in a grid.

    Each image is inverse-normalized with ``img * std + mean`` (per channel).
    Channel values that land exactly on ``mean`` were exactly 0 before inverse
    normalization, e.g. gradient entries zeroed by pruning. They carry no
    information and are set to 0 so they render dark. The result is clipped
    to [0, 1].

    Args:
        batch: Indexable batch of (C, H, W) tensors, e.g. a (N, 3, H, W)
            tensor. Grid cells beyond ``len(batch)`` are left blank.
        name: Figure title.
        rows, cols: Grid shape (default 4 x 5 = 20 images).
        mean, std: Per-channel normalization used on the data. The defaults
            match the ``Normalize`` transform in the config cell.
        save_path: If given, the figure is also saved there before showing.
    """
    fig, axs = plt.subplots(rows, cols,
                            figsize=(11.5 * cols / 5, 9 * rows / 4),
                            squeeze=False)
    fig.subplots_adjust(hspace=0.1, wspace=0)

    mean = np.asarray(mean)
    std = np.asarray(std)

    for img_idx, ax in enumerate(axs.flat):
        ax.axis('off')
        if img_idx >= len(batch):
            continue  # fewer images than grid cells: leave this cell blank

        # CHW tensor -> HWC numpy array (detach in case the tensor tracks grad).
        img = batch[img_idx].detach().permute(1, 2, 0).cpu().numpy()
        img = img * std + mean  # undo Normalize
        img[img == mean] = 0  # per-channel: values that were exactly 0 -> black
        img = np.clip(img, 0, 1)

        ax.imshow(img)

    fig.suptitle(name)
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# CONFIG
batch_size = 20
num_neurons = 200
data_path = 'data/imagenet'
seed = 0  # controls batch shuffling, layer initialization and AGGP's random pruning

# Seed every RNG the notebook draws from so Restart & Run All reproduces the figures.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# mean/std here must match the inverse normalization used by plot20.
transforms = Compose([
    Resize(size=256),
    CenterCrop(size=(224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


base_dataset = torchvision.datasets.ImageNet(
    root=data_path, split="val", transform=transforms
)

# A dedicated, seeded generator fixes which images land in the shuffled batch,
# independently of other draws from torch's global RNG.
val_loader = DataLoader(
    base_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)

## True user data

In [None]:
# One shuffled validation batch stands in for the client's private data;
# the attacks below try to reconstruct these images from gradients.
user_data, label = next(iter(val_loader))
plot20(user_data, name="True user data")

### Basic setup

In [None]:
# Benign baseline: a default-initialized fully connected layer over flattened
# 3x224x224 inputs, wrapped by the project's IdentityConv2d model.
device = torch.device("cpu")
layer = nn.Linear(in_features=3 * 224 * 224, out_features=num_neurons)
layer = layer.to(device)

# NOTE(review): 1000 is presumably the number of ImageNet classes — confirm in core.
model = IdentityConv2d(layer, 1000)
criterion = nn.CrossEntropyLoss()

## Passive data leak of first 20 neurons

In [None]:
# Drop gradients from any earlier run so re-executing this cell does not accumulate them.
layer.zero_grad()

output = model(user_data)
loss = criterion(output, label)
loss.backward()

w_grad = layer.weight.grad.clone()
b_grad = layer.bias.grad.clone()

# For a linear layer, row i of dL/dW is a dL/db-weighted sum of the inputs,
# so dividing by dL/db_i recovers an input (or an average of inputs).
# Neurons whose bias gradient is exactly 0 would give 0/0 = NaN; map them to
# an all-zero image instead.
b_col = b_grad.view(-1, 1)
intermediate = torch.where(b_col != 0, w_grad / b_col, torch.zeros_like(w_grad))
intermediate = intermediate.reshape(-1, 3, 224, 224).to("cpu")

first_20_neurons = intermediate[:20]
plot20(first_20_neurons, "Passive leak of first 20 neurons (benign network)")

### Impact of gradient pruning on passive leak

In [None]:
# Recompute the benign gradients from scratch. zero_grad() also works on a fresh
# kernel where no gradient exists yet (grad.zero_() would fail on None).
layer.zero_grad()

output = model(user_data)
loss = criterion(output, label)
loss.backward()

w_grad = layer.weight.grad.clone()
b_grad = layer.bias.grad.clone()

# Gradient pruning: apply AGGP to each of the first 20 weight-gradient rows,
# skipping neurons whose activation count is 0 or above 10.
# NOTE(review): activation_counts is presumably the per-neuron number of samples
# that activated it during the forward pass — confirm in core.
activation_counts = model.activation_counts
print(activation_counts[:20])
with torch.no_grad():
    for i, a in enumerate(activation_counts[:20]):
        if a == 0 or a > 10:
            continue
        w_grad[i] = AGGP(w_grad[i], a)

# Reconstruction: divide each weight-gradient row by its bias gradient; rows
# with a zero bias gradient become all-zero images instead of NaN.
b_col = b_grad.view(-1, 1)
intermediate = torch.where(b_col != 0, w_grad / b_col, torch.zeros_like(w_grad))
intermediate = intermediate.reshape(-1, 3, 224, 224).to("cpu")

first_20_neurons = intermediate[:20]
plot20(first_20_neurons, "Passive data leakage")

## Perform active attack

In [None]:
# Malicious model: same architecture, but weights are redrawn from N(0, 1) and
# every bias is set by QBI. If w ~ N(0, 1)^d and ||x||^2 is about d (assumes
# standardized inputs — not verified here), then w.x ~ N(0, d). A bias of
# Phi^{-1}(1 / batch_size) * sqrt(d) then makes w.x + b > 0 with probability
# about 1 / batch_size.
device = torch.device("cpu")
layer = nn.Linear(3 * 224 * 224, num_neurons).to(device)

optimal_bias = norm.ppf(1 / batch_size) * np.sqrt(3 * 224 * 224)

with torch.no_grad():
    layer.weight.normal_()
    layer.bias.fill_(optimal_bias)

model = IdentityConv2d(layer, 1000)
criterion = nn.CrossEntropyLoss()

## Plot active data leak of first 20 neurons

In [None]:
# Drop gradients from any earlier run so re-executing this cell does not accumulate them.
for param in (model.fc1.weight, model.fc1.bias):
    param.grad = None

output = model(user_data)
loss = criterion(output, label)
loss.backward()

w_grad = model.fc1.weight.grad.clone()
b_grad = model.fc1.bias.grad.clone()

# Divide each weight-gradient row by its bias gradient. With the QBI bias, many
# neurons may stay inactive for the whole batch. Where the bias gradient is
# exactly 0, emit an all-zero image instead of a 0/0 = NaN one.
b_col = b_grad.view(-1, 1)
intermediate = torch.where(b_col != 0, w_grad / b_col, torch.zeros_like(w_grad))
intermediate = intermediate.reshape(-1, 3, 224, 224).to("cpu")

first_20_neurons = intermediate[:20]
plot20(first_20_neurons, "Active leak of first 20 neurons in maliciously initialized model")

## Plot impact of activation-based gradient pruning (`AGGP`) on the active leak

In [None]:
# Recompute the malicious model's gradients from scratch. Reset the gradients
# that are read below (model.fc1); this also works when none exist yet.
for param in (model.fc1.weight, model.fc1.bias):
    param.grad = None

output = model(user_data)
loss = criterion(output, label)
loss.backward()

w_grad = model.fc1.weight.grad.clone()
b_grad = model.fc1.bias.grad.clone()

# Gradient pruning: apply AGGP to each of the first 20 weight-gradient rows,
# skipping neurons whose activation count is 0 or above 10.
# NOTE(review): activation_counts is presumably the per-neuron number of samples
# that activated it during the forward pass — confirm in core.
activation_counts = model.activation_counts
print(activation_counts[:20])
with torch.no_grad():
    for i, a in enumerate(activation_counts[:20]):
        if a == 0 or a > 10:
            continue
        w_grad[i] = AGGP(w_grad[i], a)

# Reconstruction: divide each weight-gradient row by its bias gradient; rows
# with a zero bias gradient become all-zero images instead of NaN.
b_col = b_grad.view(-1, 1)
intermediate = torch.where(b_col != 0, w_grad / b_col, torch.zeros_like(w_grad))
intermediate = intermediate.reshape(-1, 3, 224, 224).to("cpu")

first_20_neurons = intermediate[:20]
plot20(first_20_neurons, "Impact of AGGP on Active leak")