<a href="https://colab.research.google.com/github/bahador1/BahadorColabNotes/blob/main/dlg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import pickle
import PIL.Image as Image

In [None]:
class LeNet(nn.Module):
    """Small sigmoid-activated CNN: three 5x5 convolutions followed by one linear layer.

    Args:
        channel: number of channels of the input images.
        hideen: size of the flattened feature map fed to the linear layer
            (spelling kept because callers pass it by keyword).
        num_classes: number of output logits.
    """

    def __init__(self, channel=3, hideen=768, num_classes=10):
        super().__init__()
        # Every convolution shares the same kernel size and "same"-style padding.
        conv_cfg = dict(kernel_size=5, padding=2)
        feature_layers = [
            nn.Conv2d(channel, 12, stride=2, **conv_cfg),
            nn.Sigmoid(),
            nn.Conv2d(12, 12, stride=2, **conv_cfg),
            nn.Sigmoid(),
            nn.Conv2d(12, 12, stride=1, **conv_cfg),
            nn.Sigmoid(),
        ]
        self.body = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        """Return the logits for the image batch ``x``, one row per sample."""
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


def weights_init(m):
    """Fill ``m.weight`` and then ``m.bias`` in place with values drawn from U(-0.5, 0.5).

    Meant to be passed to ``Module.apply``. Modules without the attribute are
    skipped; any failure (for example an attribute that is ``None``) is reported
    with a printed warning instead of being raised.
    """
    for attr in ("weight", "bias"):
        try:
            if hasattr(m, attr):
                getattr(m, attr).data.uniform_(-0.5, 0.5)
        except Exception:
            print('warning: failed in weights_init for %s.%s' % (m._get_name(), attr))


In [None]:
class Dataset_from_Image(Dataset):
    """Dataset that loads images from file paths on access and pairs them with labels.

    Args:
        imgs: sequence of image file paths, indexed by sample position.
        labs: labels aligned with ``imgs`` (an ndarray in the original usage;
            any sized, indexable container works).
        transform: optional callable applied to the RGB PIL image. When it is
            ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, imgs, labs, transform=None):
        self.imgs = imgs  # image paths
        self.labs = labs  # labels
        self.transform = transform

    def __len__(self):
        return len(self.labs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is always RGB."""
        lab = self.labs[idx]
        # The context manager closes the underlying file; convert() loads the
        # pixel data into a new image first, so the result stays usable.
        with Image.open(self.imgs[idx]) as src:
            img = src.convert('RGB')
        # transform defaults to None, so it must not be called unconditionally.
        if self.transform is not None:
            img = self.transform(img)
        return img, lab



In [None]:
# ---- experiment configuration ----
dataset = 'MNIST'
root_path = '.'
data_path = os.path.join(root_path, '../data').replace('\\', '/')
save_path = os.path.join(root_path, 'results/iDLG_%s' % dataset).replace('\\', '/')

num_dummy = 1    # number of images to reconstruct
Iteration = 300  # number of optimisation steps

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

# PIL image -> tensor and tensor -> PIL image converters.
tt = transforms.Compose([transforms.ToTensor()])
tp = transforms.Compose([transforms.ToPILImage()])

print(dataset, 'root_path:', root_path)
print(dataset, 'data_path:', data_path)
print(dataset, 'save_path:', save_path)

# Creates the full save_path (including the intermediate 'results' directory)
# relative to root_path, and is a no-op when it already exists.
os.makedirs(save_path, exist_ok=True)

# ---- load data ----
if dataset == 'MNIST':
    shape_img = (28, 28)
    num_classes = 10
    channel = 1
    hidden = 588  # 12 channels * 7 * 7 after LeNet's two stride-2 convolutions on 28x28 input
    dst = datasets.MNIST(data_path, download=True)
else:
    # Fail here rather than with a NameError on `dst`/`hidden` in a later cell.
    raise ValueError('unsupported dataset: %s' % dataset)

MNIST root_path: .
MNIST data_path: ./../data
MNIST save_path: ./results/iDLG_MNIST


100%|██████████| 9.91M/9.91M [00:00<00:00, 58.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.71MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.08MB/s]


In [None]:
# Build the network, re-initialise its parameters with weights_init while it is
# still on the CPU, then move it to the selected device.
net = (
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes)
    .apply(weights_init)
    .to(device)
)

In [None]:
# Only the plain DLG attack is run in this notebook.
method = 'DLG'

# Random visiting order over the dataset; the first `num_dummy` entries are the targets.
idx_shuffle = np.random.permutation(len(dst))

banner = '%s, Try to generate %d images' % (method, num_dummy)
print(banner)

DLG, Try to generate 1 images


In [None]:
# Gather `num_dummy` ground-truth samples into a single batch, so that
# gt_data[imidx] / gt_label[imidx] exist for every imidx < num_dummy.
imidx_list = []
gt_data_parts = []
gt_label_parts = []
for imidx in range(num_dummy):
    idx = int(idx_shuffle[imidx])  # plain int keeps printed indices and file names clean
    imidx_list.append(idx)
    image, target = dst[idx]  # index the dataset once per sample
    # ToTensor yields (C, H, W); unsqueeze adds the leading batch dimension.
    gt_data_parts.append(tt(image).float().to(device).unsqueeze(0))
    gt_label_parts.append(torch.tensor([target], dtype=torch.long, device=device))

gt_data = torch.cat(gt_data_parts, dim=0)
gt_label = torch.cat(gt_label_parts, dim=0)

In [None]:
criterion = nn.CrossEntropyLoss().to(device)

# Reference gradients of the true sample with respect to the network parameters;
# the reconstruction below tries to reproduce exactly these.
out = net(gt_data)
loss = criterion(out, gt_label)
dloss_dx = torch.autograd.grad(loss, net.parameters())
original_dy_dx = [g.detach().clone() for g in dloss_dx]

# Random starting guesses. dummy_label holds unnormalised scores that are
# turned into a soft label with softmax inside the loss.
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn((gt_data.shape[0], num_classes)).to(device).requires_grad_(True)

# Only the dummy tensors are handed to the optimizer; the network parameters are never stepped.
# NOTE(review): Adam with lr=1.0 is aggressive for this objective — confirm the choice is intentional.
optimizer = torch.optim.Adam([dummy_data, dummy_label], lr=1.0)

history = []        # reconstructed PIL images at each logging step
history_iters = []  # iteration index of each history entry
losses = []         # gradient distance after every step
mses = []           # mean squared error between dummy_data and gt_data after every step
train_iters = []


def _gradient_distance(create_graph):
    """Squared L2 distance between the dummy sample's parameter gradients and the reference ones.

    With ``create_graph=True`` the result stays differentiable with respect to
    ``dummy_data`` and ``dummy_label``; with ``False`` it is a plain value for logging.
    """
    pred = net(dummy_data)
    # Cross-entropy against the soft dummy label. log_softmax avoids log(0)
    # when a softmax probability underflows.
    dummy_loss = -torch.mean(
        torch.sum(torch.softmax(dummy_label, -1) * torch.log_softmax(pred, -1), dim=-1)
    )
    dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=create_graph)
    return sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))


def closure():
    """Zero the dummy gradients, back-propagate the gradient distance and return it."""
    optimizer.zero_grad()
    grad_diff = _gradient_distance(create_graph=True)
    grad_diff.backward()
    return grad_diff


# Log roughly 30 times per run; max() avoids a zero modulus when Iteration < 30.
log_every = max(Iteration // 30, 1)

for iters in range(Iteration):
    optimizer.step(closure)  # closure-style step, the same call shape LBFGS needs
    # Re-evaluate after the step for logging only, so no second-order graph is built.
    current_loss = _gradient_distance(create_graph=False).item()
    train_iters.append(iters)
    losses.append(current_loss)
    mses.append(torch.mean((dummy_data.detach() - gt_data) ** 2).item())

    if iters % log_every == 0:
        current_time = str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
        print(current_time, iters, 'D_i = %.8f, rec error = %.8f' % (current_loss, mses[-1]))
        history.append([tp(dummy_data[imidx].detach().cpu()) for imidx in range(num_dummy)])
        history_iters.append(iters)

        for imidx in range(num_dummy):
            plt.figure(figsize=(12, 8))
            # Panel 1 is the ground truth; panels 2..30 show up to 29 snapshots.
            plt.subplot(3, 10, 1)
            plt.imshow(tp(gt_data[imidx].cpu()))
            for i in range(min(len(history), 29)):
                plt.subplot(3, 10, i + 2)
                plt.imshow(history[i][imidx])
                plt.title('iter=%d' % (history_iters[i]))
                plt.axis('off')
            if method == 'DLG':
                # File name is <method>_on_<dataset>_<sample index>.png.
                plt.savefig('%s/DLG_on_%s_%05d.png' % (save_path, dataset, imidx_list[imidx]))
            # Always close, otherwise one figure per logging step stays in memory.
            plt.close()

        # Convergence is only checked on logging iterations.
        if current_loss < 0.000001:  # converge
            print("it has converged")
            break

if method == 'DLG':
    loss_DLG = losses
    predicted_labels = torch.argmax(dummy_label, dim=-1).detach().cpu()
    # A single sample yields a plain int; a batch yields a list of ints.
    label_DLG = predicted_labels.item() if predicted_labels.numel() == 1 else predicted_labels.tolist()
    mse_DLG = mses


print('imidx_list:', imidx_list)
print('loss_DLG:', loss_DLG[-1],)
print('mse_DLG:', mse_DLG[-1], )
print('gt_label:', gt_label.detach().cpu().data.numpy(), 'lab_DLG:', label_DLG, )

print('----------------------\n\n')

[2025-04-06 15:13:34] 0 D_i = 46.45686340, rec error = 2.02256799
[2025-04-06 15:13:34] 10 D_i = 18.11983490, rec error = 17.65144157
[2025-04-06 15:13:36] 20 D_i = 12.07498360, rec error = 26.97108650
[2025-04-06 15:13:36] 30 D_i = 10.43585396, rec error = 31.13733673
[2025-04-06 15:13:36] 40 D_i = 9.93463135, rec error = 32.86661530
[2025-04-06 15:13:37] 50 D_i = 9.51446533, rec error = 33.58797455
[2025-04-06 15:13:37] 60 D_i = 9.20694733, rec error = 33.92465210
[2025-04-06 15:13:37] 70 D_i = 8.90260029, rec error = 34.12365341
[2025-04-06 15:13:38] 80 D_i = 8.69240761, rec error = 34.26870346
[2025-04-06 15:13:38] 90 D_i = 8.52081490, rec error = 34.38152695
[2025-04-06 15:13:39] 100 D_i = 8.37277222, rec error = 34.46982956
[2025-04-06 15:13:39] 110 D_i = 8.23149872, rec error = 34.54435730
[2025-04-06 15:13:40] 120 D_i = 8.09077168, rec error = 34.61719131
[2025-04-06 15:13:40] 130 D_i = 7.95240927, rec error = 34.69346237
[2025-04-06 15:13:41] 140 D_i = 7.82345533, rec error = 

### `create_graph` in `torch.autograd.grad()`

In [1]:
import torch

In [6]:
# x is a leaf tensor that tracks gradients; y = x^2 is computed from it.
x = torch.tensor(2., requires_grad = True)
y = x**2

In [7]:
# First-order gradient, computed with the default create_graph=False.
grad = torch.autograd.grad(y, x )# dy/dx

In [8]:
# Differentiating `grad` again fails: it was computed without create_graph=True,
# so it has no grad_fn. The error is caught and printed so the demonstration
# does not stop a full run of the notebook.
try:
    torch.autograd.grad(grad, x)
except RuntimeError as err:
    print('RuntimeError:', err)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

### This example shows why it is harder to get the gradient of an intermediate tensor with `.backward()` than with `torch.autograd.grad()`.


In [50]:
import torch

In [51]:
# x is the leaf; y and z are intermediate (non-leaf) results, with z = (x^2)^3.
x = torch.tensor(2., requires_grad = True)
y = x**2
z = y**3

In [52]:
# Back-propagate from z through y down to the leaf x.
z.backward()

In [53]:
# y is not a leaf tensor, so backward() stored no gradient on it;
# reading .grad here makes PyTorch emit a warning instead of a value.
y.grad

  y.grad


### `retain_graph` does not help with intermediate tensors' gradients either

In [60]:
import torch

In [67]:
# Fresh graph for the retain_graph experiment: leaf x, intermediates y = x^2 and z = y^3.
x = torch.tensor(2., requires_grad = True)
y = x**2
z = y**3

In [68]:
# Same backward pass as before, but the graph is kept alive afterwards.
z.backward(retain_graph =True)

In [69]:
# Keeping the graph does not change where gradients are stored: the non-leaf y still has none.
y.grad

  y.grad


In [70]:
# The graph kept by the previous backward call lets y be back-propagated as well (from y into x).
y.backward(retain_graph = True)

In [71]:
# Still unset: backward() accumulates gradients into leaf tensors such as x, not into y itself.
y.grad

  y.grad


In [82]:
import torch
# Leaf x with intermediates y = x^2 and z = y^3.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = y ** 3

# dz/dx with respect to the leaf x. create_graph=True keeps the graph, so the
# result can be differentiated again and z can still be differentiated afterwards.
dz_dx  = torch.autograd.grad(z, x, create_graph = True)

In [80]:
# Gradient of z with respect to the intermediate tensor y, requested directly.
dz_dy = torch.autograd.grad(z, y)

In [81]:
# Expected value: 3 * y**2 = 3 * 4**2 = 48.
dz_dy

(tensor(48.),)