<a href="https://colab.research.google.com/github/bahador1/BahadorColabNotes/blob/main/dlg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import pickle
import PIL.Image as Image

In [2]:
class LeNet(nn.Module):
    def __init__(self, channel=3, hideen=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(hideen, num_classes)
        )

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def weights_init(m):
    try:
        if hasattr(m, "weight"):
            m.weight.data.uniform_(-0.5, 0.5)
    except Exception:
        print('warning: failed in weights_init for %s.weight' % m._get_name())
    try:
        if hasattr(m, "bias"):
            m.bias.data.uniform_(-0.5, 0.5)
    except Exception:
        print('warning: failed in weights_init for %s.bias' % m._get_name())


In [3]:
class Dataset_from_Image(Dataset):
    def __init__(self, imgs, labs, transform=None):
        self.imgs = imgs # img paths
        self.labs = labs # labs is ndarray
        self.transform = transform
        del imgs, labs

    def __len__(self):
        return self.labs.shape[0]

    def __getitem__(self, idx):
        lab = self.labs[idx]
        img = Image.open(self.imgs[idx])
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = self.transform(img)
        return img, lab



In [4]:
# def main():
dataset = 'MNIST'
root_path = '.'
data_path = os.path.join(root_path, '../data').replace('\\', '/')
save_path = os.path.join(root_path, 'results/iDLG_%s'%dataset).replace('\\', '/')

# lr = 1.0
num_dummy = 1
Iteration = 300
# num_exp = 1000

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

tt = transforms.Compose([transforms.ToTensor()])
tp = transforms.Compose([transforms.ToPILImage()])

print(dataset, 'root_path:', root_path)
print(dataset, 'data_path:', data_path)
print(dataset, 'save_path:', save_path)

if not os.path.exists('results'):
    os.mkdir('results')
if not os.path.exists(save_path):
    os.mkdir(save_path)



''' load data '''
if dataset == 'MNIST':
    shape_img = (28, 28)
    num_classes = 10
    channel = 1
    hidden = 588
    dst = datasets.MNIST(data_path, download=True)

MNIST root_path: .
MNIST data_path: ./../data
MNIST save_path: ./results/iDLG_MNIST


100%|██████████| 9.91M/9.91M [00:00<00:00, 58.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.71MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.08MB/s]


In [5]:
''' train DLG and iDLG '''
# for idx_net in range(num_exp):
net = LeNet(channel=channel, hideen=hidden, num_classes=num_classes)
net.apply(weights_init)

net = net.to(device)

In [6]:
idx_shuffle = np.random.permutation(len(dst))

# for method in ['DLG', 'iDLG']:
method = 'DLG'
print('%s, Try to generate %d images' % (method, num_dummy))

DLG, Try to generate 1 images


In [7]:
imidx_list = []
for imidx in range(num_dummy):
    idx = idx_shuffle[imidx]
    imidx_list.append(idx)
    tmp_datum = tt(dst[idx][0]).float().to(device)
    tmp_datum = tmp_datum.view(1, *tmp_datum.size())
    tmp_label = torch.Tensor([dst[idx][1]]).long().to(device)
    tmp_label = tmp_label.view(1, )
    if imidx == 0:
        gt_data = tmp_datum
        gt_label = tmp_label
    # else:
    #     gt_data = torch.cat((gt_data, tmp_datum), dim=0)
    #     gt_label = torch.cat((gt_label, tmp_label), dim=0)

In [16]:
criterion = nn.CrossEntropyLoss().to(device)

# compute original gradient
out = net(gt_data)
loss = criterion(out, gt_label)
dloss_dx = torch.autograd.grad(loss , net.parameters())
original_dy_dx = list((_.detach().clone() for _ in dloss_dx))

# generate dummy data and label
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn((gt_data.shape[0], num_classes)).to(device).requires_grad_(True)

#this optimizer is for updating dummy data.
optimizer = torch.optim.Adam([dummy_data, dummy_label], lr=0.1)#LBFGS

#point 1
history = []
history_iters = []
losses = []
mses = []
train_iters = []

for iters in range(Iteration):

    def closure():
        optimizer.zero_grad()
        pred = net(dummy_data)
        # if method == 'DLG':
        dummy_loss = - torch.mean(torch.sum(torch.softmax(dummy_label, -1) * torch.log(torch.softmax(pred, -1)), dim=-1))
        # dummy_loss = criterion(pred, gt_label)

        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=False)

        grad_diff = 0
        for gx, gy in zip(dummy_dy_dx, original_dy_dx):
            grad_diff += ((gx - gy) ** 2).sum()

        grad_diff.backward()
        return grad_diff

    optimizer.step(closure)
    current_loss = closure().item()
    train_iters.append(iters)
    losses.append(current_loss)
    mses.append(torch.mean((dummy_data-gt_data)**2).item())


    if iters % int(Iteration / 30) == 0:
        current_time = str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
        print(current_time, iters, 'D_i = %.8f, rec error = %.8f' %(current_loss, mses[-1]))
        history.append([tp(dummy_data[imidx].cpu()) for imidx in range(num_dummy)])
        history_iters.append(iters)

        for imidx in range(num_dummy):
            plt.figure(figsize=(12, 8))
            plt.subplot(3, 10, 1)
            plt.imshow(tp(gt_data[imidx].cpu()))
            for i in range(min(len(history), 29)):
                plt.subplot(3, 10, i + 2)
                plt.imshow(history[i][imidx])
                plt.title('iter=%d' % (history_iters[i]))
                plt.axis('off')
            if method == 'DLG':
                plt.savefig('%s/DLG_on_%s_%05d.png' % (save_path, imidx_list, imidx_list[imidx]))
                plt.close()


        if current_loss < 0.000001: # converge
            print("it has converged")
            break

if method == 'DLG':
    loss_DLG = losses
    label_DLG = torch.argmax(dummy_label, dim=-1).detach().item()
    mse_DLG = mses


print('imidx_list:', imidx_list)
print('loss_DLG:', loss_DLG[-1],)
print('mse_DLG:', mse_DLG[-1], )
print('gt_label:', gt_label.detach().cpu().data.numpy(), 'lab_DLG:', label_DLG, )

print('----------------------\n\n')

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn