In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from scipy.stats import norm

from core import multi_evaluate, exp_aggregator, IdentityConv2d
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# CONFIG
seed = 0
batch_size = 20
num_neurons = 200
data_path = 'data'

# Seed the global torch RNG so both the DataLoader shuffle and the later
# weight initialisation are reproducible across Restart & Run All.
torch.manual_seed(seed)

# Per-channel normalisation statistics; keep in sync with the inverse
# normalisation performed in plot20.
transforms = Compose([
    ToTensor(),
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
])


base_dataset = torchvision.datasets.CIFAR10(
    root=data_path, train=True, transform=transforms, download=True
)

# NOTE: despite its name, this loader draws from the CIFAR-10 *training* split
# (train=True above); batches are shuffled.
val_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=True)

In [None]:
def plot20(batch, name="", nrows=4, ncols=5,
           mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261),
           save_path=None):
    """Show a batch of normalised CHW images on an ``nrows`` x ``ncols`` grid.

    Each image is permuted from CHW to HWC and de-normalised with ``mean`` and
    ``std``. It is then clipped to [0, 1] and drawn without axes. Grid cells
    beyond ``len(batch)`` are left blank. Images beyond the grid capacity
    (``nrows * ncols``) are not shown.

    Args:
        batch: anything indexable with ``len()`` whose items are (C, H, W)
            tensors, e.g. a (N, C, H, W) tensor or a list of tensors.
        name: title placed above the whole figure.
        nrows, ncols: grid layout; the defaults reproduce the original 4x5 grid.
        mean, std: per-channel statistics used by the ``Normalize`` transform.
        save_path: if given, the figure is also written to this path.
    """
    # squeeze=False keeps `axs` 2-D even for a single row/column.
    fig, axs = plt.subplots(nrows, ncols, figsize=(11.5, 9), squeeze=False)
    fig.subplots_adjust(hspace=0.1, wspace=0)

    mean_arr = np.asarray(mean)
    std_arr = np.asarray(std)

    # axs.flat walks the grid row by row, matching batch order.
    for img_idx, ax in enumerate(axs.flat):
        ax.axis('off')
        if img_idx >= len(batch):
            continue  # fewer images than grid cells: leave the cell empty
        # detach() so tensors that still require grad can be converted to numpy.
        img = batch[img_idx].detach().permute(1, 2, 0).cpu().numpy()
        img = np.clip(img * std_arr + mean_arr, 0, 1)  # undo Normalize
        ax.imshow(img)

    fig.suptitle(name)

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Take a single shuffled batch of real images and their labels; this batch is
# what the later cells attempt to reconstruct from gradients.
batch_iter = iter(val_loader)
user_data, label = next(batch_iter)
plot20(user_data, "True user data")

In [None]:
device = torch.device("cpu")
input_dim = 3 * 32 * 32  # flattened 3-channel 32x32 image
layer = nn.Linear(input_dim, num_neurons).to(device)

# QBI: with N(0, 1) weights and inputs of roughly unit variance per component
# (an assumption about the normalised data), w.x is approximately
# N(0, input_dim). A bias of Phi^-1(1/batch_size) * sqrt(input_dim) therefore
# makes each neuron's pre-activation positive for roughly one image per batch.
optimal_bias = norm.ppf(1 / batch_size) * np.sqrt(input_dim)

with torch.no_grad():
    layer.weight.normal_()
    layer.bias.fill_(optimal_bias)

# Wrap the linear layer in the project model; 10 is presumably the number of
# CIFAR-10 classes (TODO confirm against core.IdentityConv2d).
model = IdentityConv2d(layer, 10)
criterion = nn.CrossEntropyLoss()

In [None]:
# For a linear layer y = W x + b, each sample contributes dL/dW_i = g_i * x and
# dL/db_i = g_i. When only one image in the batch reaches neuron i, the ratio
# dL/dW_i / dL/db_i is exactly that image. (Assumes model.fc1 is the linear
# layer built above — TODO confirm in core.IdentityConv2d.)

# Clear gradients from any previous run of this cell; otherwise backward()
# accumulates into .grad and w_grad/b_grad are no longer single-pass values.
model.fc1.weight.grad = None
model.fc1.bias.grad = None

output = model(user_data)
loss = criterion(output, label)
loss.backward()

w_grad = model.fc1.weight.grad.clone()
b_grad = model.fc1.bias.grad.clone()

# Neurons with an exactly-zero bias gradient carry no image information and
# would produce 0/0 = NaN rows, so keep only the others.
active = b_grad != 0
intermediate = w_grad[active] / b_grad[active].view(-1, 1)
intermediate = intermediate.reshape(-1, 3, 32, 32).to("cpu")

In [None]:
# One slot per image in the actual batch (not a hard-coded 20). Slots whose
# image is never recovered stay as all-zero tensors.
result = [torch.zeros_like(img) for img in user_data]

n_found = 0
for i, user_image in enumerate(user_data):
    # A candidate counts as a reconstruction only if it matches the true image
    # within torch.allclose's default tolerances.
    for n_grad in intermediate:
        if torch.allclose(user_image, n_grad):
            result[i] = n_grad
            n_found += 1
            print(f"Found image {i}")
            break

print(f"Recovered {n_found} / {len(user_data)} images")
plot20(result, "Reconstructed data")