### 1.

In [611]:
##-- Importing Necessary Libraries --##
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import pickle
import PIL.Image as Image
from torch.utils.data import DataLoader
from path import Path

### 2.

In [612]:
##-- Define the Model Class --##
class CNN_Model(nn.Module):
    """Three sigmoid-activated 5x5 convolutions followed by a single linear classifier.

    Args:
        input_shape: number of input channels.
        hidden_units: number of channels produced by every convolution.
        output_shape: number of output logits (classes).
        img_size: input height/width, either an int for square images or an
            (H, W) pair. The classifier's ``in_features`` is derived from it;
            the default of 32 reproduces the previous hard-coded
            ``8 * 8 * hidden_units``.
    """

    def __init__(self, input_shape, hidden_units, output_shape, img_size=32):
        super().__init__()
        kernel_size = 5
        padding = kernel_size // 2  # kernel_size / 2, rounded down
        strides = (2, 2, 1)

        # Conv -> Sigmoid for each stride; module indices (0, 2, 4 for the convs)
        # match the original hand-written Sequential, so state_dict keys are unchanged.
        layers = []
        channels_in = input_shape
        for stride in strides:
            layers.append(nn.Conv2d(
                in_channels=channels_in,
                out_channels=hidden_units,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride))
            layers.append(nn.Sigmoid())
            channels_in = hidden_units
        self.body = nn.Sequential(*layers)

        # Track the spatial size through each conv (dilation 1):
        # out = (in + 2 * padding - kernel_size) // stride + 1
        height, width = (img_size, img_size) if isinstance(img_size, int) else img_size
        for stride in strides:
            height = (height + 2 * padding - kernel_size) // stride + 1
            width = (width + 2 * padding - kernel_size) // stride + 1

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=height * width * hidden_units,
                out_features=output_shape)
        )

    def forward(self, x):
        out = self.body(x)
        out = self.fc(out)
        return out

### 3.

In [613]:
##-- Initialize Weights & Biases With Uniform Values in [-0.5, 0.5] --##
##-- Sets Initial Random Weights & Biases Before Training Begins --##
def weights_init(m, low=-0.5, high=0.5):
    """Initialise a module's ``weight`` and ``bias`` uniformly in [low, high].

    Meant to be used as ``model.apply(weights_init)``, which calls it on every
    submodule. Modules without these attributes (e.g. Sigmoid, Flatten,
    Sequential) and parameters that are disabled (``bias=False`` leaves
    ``bias`` as None) are skipped, instead of raising and printing a
    misleading "failed" warning.

    Args:
        m: the module to initialise.
        low: lower bound of the uniform distribution (default -0.5).
        high: upper bound of the uniform distribution (default 0.5).
    """
    for attr in ("weight", "bias"):
        param = getattr(m, attr, None)
        if not isinstance(param, torch.Tensor):
            continue  # attribute absent or set to None: nothing to initialise
        # In-place initialisation must not be tracked by autograd; torch.no_grad()
        # is the supported replacement for mutating ``.data``.
        with torch.no_grad():
            param.uniform_(low, high)

### 4.

In [614]:
##-- PyTorch dataset class that loads images from file paths and applies transformations --##
class Dataset_from_Image(Dataset):
    """Map-style dataset that lazily loads images from file paths.

    Args:
        imgs: sequence of image file paths, indexed in parallel with ``labs``.
        labs: ndarray of labels; its first dimension defines the dataset length.
        transform: optional callable applied to each RGB PIL image. When None
            (the default) the RGB image is returned untransformed.
    """

    def __init__(self, imgs, labs, transform=None):
        self.imgs = imgs  # image file paths
        self.labs = labs  # ndarray of labels
        self.transform = transform

    def __len__(self):
        return self.labs.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is converted to RGB."""
        lab = self.labs[idx]
        # Open in a context manager so the file handle is released right away.
        # convert('RGB') returns an in-memory copy, so the image outlives the file.
        with Image.open(self.imgs[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, lab

### 5.

In [615]:
##-- Prepares and returns an image dataset where each image has a corresponding label based on its folder name --##
def lfw_dataset(lfw_path, shape_img):
    """Build a Dataset_from_Image from a folder-per-class image tree.

    Each immediate sub-directory of ``lfw_path`` is one class. Its label is the
    sub-directory's position in sorted order, so labels are reproducible across
    runs and platforms (``os.listdir`` returns entries in arbitrary order).
    Non-directory entries at the top level are ignored. Only files whose names
    end in '.jpg' and are longer than the bare extension are included.

    Args:
        lfw_path: root directory containing one sub-directory per class.
        shape_img: target size passed to ``transforms.Resize``.

    Returns:
        Dataset_from_Image over the collected paths, with a Resize transform.
    """
    images_all = []
    labels_all = []
    folders = sorted(
        entry for entry in os.listdir(lfw_path)
        if os.path.isdir(os.path.join(lfw_path, entry))
    )
    for foldidx, fold in enumerate(folders):
        fold_dir = os.path.join(lfw_path, fold)
        for f in sorted(os.listdir(fold_dir)):
            if len(f) > 4 and f.endswith('.jpg'):
                images_all.append(os.path.join(fold_dir, f))
                labels_all.append(foldidx)

    transform = transforms.Compose([transforms.Resize(size=shape_img)])
    return Dataset_from_Image(images_all, np.asarray(labels_all, dtype=int), transform=transform)

### 6.

In [616]:
##-- Set Up Device-Agnostic Code: use the GPU whenever PyTorch can see one --##
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [617]:
##-- Define Hyperparameters --##
LEARNING_RATE = 0.1  # LBFGS learning rate used when optimising the dummy image
NUM_DUMMY = 1        # number of ground-truth samples (and dummy images) per reconstruction
ITERATION = 5000     # maximum number of LBFGS optimisation steps per reconstruction (not epochs)
NUM_EXP = 1          # number of independent repetitions, each with a freshly initialised model
SEED = 0             # seed for numpy and torch RNGs

# The experiment is stochastic (random sample choice via numpy, random weight and
# dummy-image initialisation via torch); seed both so Restart & Run All is reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [618]:
##-- Define Dataset and Dataset Path & Plots Path --##
dataset = 'covid_xray'
data_path = Path("data/covid_xray")
save_path = Path("results/iDLG_covid_xray")

# dataset = 'cifar100'
# root_path = '.' 
# data_path = os.path.join(root_path, './data').replace('\\', '/') 
# save_path = os.path.join(root_path, 'results/iDLG_%s'%dataset).replace('\\', '/')

print(dataset, 'data_path:', data_path)
print(dataset, 'save_path:', save_path)

# Create save_path together with any missing parent directory (e.g. 'results').
# exist_ok=True keeps the cell idempotent on re-runs without a separate exists() check.
os.makedirs(save_path, exist_ok=True)

covid_xray data_path: data/covid_xray
covid_xray save_path: results/iDLG_covid_xray


In [619]:
##-- Load Data --##
if dataset == 'covid_xray':
    shape_img = (32, 32)
    num_classes = 3
    in_channels = 3

    # Resize to shape_img (the 32x32 input CNN_Model's classifier head expects by
    # default) and replicate the grayscale X-rays to in_channels channels.
    transform = transforms.Compose([
        transforms.Resize(shape_img),
        transforms.Grayscale(num_output_channels=in_channels)
    ])

    dst = datasets.ImageFolder(root=data_path, transform=transform)

# elif dataset == 'cifar100':
#     shape_img = (32, 32)
#     num_classes = 100
#     in_channels = 3
#     dst = datasets.CIFAR100(root=data_path, 
#                         download=True,
#                         transform = transforms.Compose([
#         transforms.Resize((32, 32)),  
#         transforms.Grayscale(num_output_channels=3)
#     ]))

else:
    # Raise a normal exception rather than calling exit(): exit() is meant for
    # interactive shells and, inside a notebook kernel, does not report a clear cell error.
    raise ValueError(f'Unknown Dataset: {dataset}')

### 7.

In [620]:
def _to_image(tensor):
    """Convert a (C, H, W) float tensor to a PIL image for plotting.

    The tensor is detached, moved to the CPU and clamped to [0, 1] first.
    ToPILImage expects float images in that range, but the dummy image is
    unconstrained during optimisation.
    """
    return transforms.ToPILImage()(tensor.detach().cpu().clamp(0, 1))


def _load_sample(tt, idx):
    """Return ``dst[idx]`` as a (1, C, H, W) float tensor and a (1,) long label tensor on ``device``."""
    img, lab = dst[idx]  # index once; dataset __getitem__ typically reloads the image on every call
    data = tt(img).float().unsqueeze(0).to(device)
    label = torch.tensor([lab], dtype=torch.long, device=device)
    return data, label


def _save_progress(gt_data, history, history_iters, idx_list):
    """Save one figure per ground-truth sample: the original plus up to 48 reconstruction snapshots."""
    for k in range(gt_data.shape[0]):
        plt.figure(figsize=(15, 8))

        plt.subplot(5, 10, 1)
        plt.imshow(_to_image(gt_data[k]))  # original image
        plt.title("original")
        plt.axis(False)

        for i in range(min(len(history), 48)):  # subplot slots 2..49 of the 5x10 grid
            plt.subplot(5, 10, i + 2)
            plt.imshow(history[i][k])
            plt.title(f"iter={history_iters[i]}")
            plt.axis(False)

        plt.savefig('%s/iDLG_on_%s_%05d.png' % (save_path, idx_list, idx_list[k]))
        plt.close()


def _reconstruct(model, loss_fn, gt_data, original_dy_dx, label_pred, idx_list):
    """Optimise a random dummy image with LBFGS so its gradients match ``original_dy_dx``.

    Runs up to ITERATION steps. Roughly every ITERATION // 50 steps it prints
    progress, records a snapshot and re-saves the progress figure. At those
    checkpoints it stops early once the gradient-matching loss is below 1e-6.

    Returns:
        (losses, mses): per-iteration gradient-matching loss, and the MSE between
        the dummy image and ``gt_data``.
    """
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    # iDLG optimises only the dummy data; the label was inferred beforehand.
    optimizer = torch.optim.LBFGS([dummy_data], lr=LEARNING_RATE)

    def closure():
        """Gradient-matching objective re-evaluated by LBFGS."""
        optimizer.zero_grad()
        pred = model(dummy_data)
        dummy_loss = loss_fn(pred, label_pred)
        # create_graph=True so grad_diff can be back-propagated to dummy_data.
        dummy_dy_dx = torch.autograd.grad(dummy_loss, model.parameters(), create_graph=True)
        grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
        grad_diff.backward()
        return grad_diff

    # ITERATION // 50 is 0 when ITERATION < 50; clamp so the modulo below cannot divide by zero.
    log_every = max(1, ITERATION // 50)
    history, history_iters, losses, mses = [], [], [], []

    print('lr =', LEARNING_RATE)
    for iters in range(ITERATION):
        optimizer.step(closure)
        current_loss = closure().item()  # loss at the updated dummy_data
        losses.append(current_loss)
        with torch.no_grad():
            mses.append(torch.mean((dummy_data - gt_data) ** 2).item())

        if iters % log_every == 0:
            current_time = time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime())
            print(current_time, iters, 'loss = %.8f, mse = %.8f' % (current_loss, mses[-1]))
            history.append([_to_image(dummy_data[k]) for k in range(dummy_data.shape[0])])
            history_iters.append(iters)
            _save_progress(gt_data, history, history_iters, idx_list)

            if current_loss < 0.000001:  # converged
                break

    return losses, mses


def main():
    """Run the iDLG gradient-inversion attack NUM_EXP times.

    For each experiment:
      1. A fresh CNN_Model is created and initialised with weights_init.
      2. A random ground-truth sample is drawn from ``dst`` and the gradient of its
         loss w.r.t. the model parameters is recorded.
      3. The label is inferred from that gradient.
      4. A dummy image is optimised until its gradient matches.
    Progress figures go to ``save_path``; summary values are printed.

    Relies on the notebook globals dst, device, in_channels, num_classes,
    save_path, NUM_EXP, NUM_DUMMY, ITERATION and LEARNING_RATE.

    Raises:
        ValueError: if NUM_DUMMY != 1. The label inference below yields a
            single label, so only one sample can be reconstructed at a time.
    """
    if NUM_DUMMY != 1:
        raise ValueError(
            'iDLG infers a single label from the last-layer gradient, so NUM_DUMMY must be 1 '
            '(got %r)' % (NUM_DUMMY,))

    tt = transforms.Compose([transforms.ToTensor()])

    for idx_net in range(NUM_EXP):
        model = CNN_Model(input_shape=in_channels, hidden_units=12, output_shape=num_classes)
        model.apply(weights_init)  # custom uniform initialisation

        print(f"Running {idx_net}|{NUM_EXP} Experiment")
        model = model.to(device)

        # Pick NUM_DUMMY random samples and stack them into one ground-truth batch.
        idx_list = list(np.random.permutation(len(dst))[:NUM_DUMMY])

        print(f"\niDLG, Try to generate {NUM_DUMMY} images")

        loss_fn = nn.CrossEntropyLoss().to(device)
        samples = [_load_sample(tt, idx) for idx in idx_list]
        gt_data = torch.cat([data for data, _ in samples], dim=0)
        gt_label = torch.cat([label for _, label in samples], dim=0)

        # Gradient the attacker observes: loss on the real batch w.r.t. every parameter.
        y = loss_fn(model(gt_data), gt_label)
        original_dy_dx = [g.detach().clone() for g in torch.autograd.grad(y, model.parameters())]

        # iDLG label inference. original_dy_dx[-2] is the gradient of the final Linear
        # layer's weight, with one row per class; the row with the smallest sum gives the label.
        label_pred = torch.argmin(torch.sum(original_dy_dx[-2], dim=-1), dim=-1).detach().reshape((1,))

        loss_iDLG, mse_iDLG = _reconstruct(model, loss_fn, gt_data, original_dy_dx, label_pred, idx_list)
        label_iDLG = label_pred.item()

        print('idx_list:', idx_list)
        print('loss_iDLG:', loss_iDLG[-1])
        print('mse_iDLG:', mse_iDLG[-1])
        print('gt_label:', gt_label.detach().cpu().numpy(), 'lab_iDLG:', label_iDLG)

        print('--------------------------------------------\n')

In [621]:
# A notebook kernel runs cells with __name__ == '__main__', so this cell starts the experiments.
if '__main__' == __name__:
    main()

Running 0|1 Experiment

iDLG, Try to generate 1 images
lr = 0.1
[2025-03-26 02:29:45] 0 loss = 105.90216827, mse = 1.12503135
[2025-03-26 02:29:53] 100 loss = 0.00022188, mse = 0.00130259
[2025-03-26 02:30:04] 200 loss = 0.00000154, mse = 0.00000795
[2025-03-26 02:30:06] 300 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:07] 400 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:07] 500 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:08] 600 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:09] 700 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:10] 800 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:11] 900 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:11] 1000 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:13] 1100 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:15] 1200 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:15] 1300 loss = 0.00000153, mse = 0.00000787
[2025-03-26 02:30:16] 1400 loss = 0.00000153, mse = 0.00000787
[