In [20]:
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import pickle
import PIL.Image as Image

In [21]:
# CIFAR-100 training split, downloaded into ../Data on first use.
dataset = datasets.CIFAR100(root="../Data", download=True)

# tt: PIL image -> resized/center-cropped 32x32 tensor; tp: tensor -> PIL image (for display).
tt = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
])
tp = transforms.ToPILImage()

# Prefer the GPU when one is visible to PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=100):
    """Convert a 1-D tensor of integer class indices into a float one-hot matrix.

    Args:
        target: integer tensor of shape (batch,) holding class indices.
        num_classes: width of the one-hot encoding.

    Returns:
        Tensor of shape (batch, num_classes), default float dtype, on ``target``'s device,
        with a single 1 per row at the given class index.
    """
    column = target.unsqueeze(1)  # (batch, 1) index layout that scatter_ expects
    onehot = torch.zeros(column.size(0), num_classes, device=column.device)
    return onehot.scatter_(1, column, 1)

def cross_entropy_for_onehot(pred, target):
    """Mean cross-entropy between logits and one-hot (or soft) target rows.

    Args:
        pred: logits; expected 2-D (batch, num_classes), since the class sum below is over
            dim 1 while the softmax is over the last dim.
        target: same shape as ``pred``; per-row class weights, e.g. from ``label_to_onehot``.

    Returns:
        Scalar tensor ``mean_i sum_c -target[i, c] * log_softmax(pred)[i, c]``.
    """
    # Fully qualified on purpose: the notebook never imports ``torch.nn.functional as F``,
    # so the previous ``F.log_softmax`` raised NameError on the first call.
    log_probs = torch.nn.functional.log_softmax(pred, dim=-1)
    return torch.mean(torch.sum(-target * log_probs, 1))

Files already downloaded and verified
Running on cuda


In [22]:

class LeNet(nn.Module):
    """Small three-conv, sigmoid-activated CNN used as the attacked model in this notebook.

    Each conv is 5x5 with 12 output channels and padding 2; the first two use stride 2,
    the third stride 1. The flattened feature map feeds a single linear classifier.

    Args:
        channel: input image channels.
        hideen: flattened feature size entering the classifier (12 * H/4 * W/4 for an
            HxW input). The misspelling is kept because callers pass ``hideen=``.
        num_classes: number of output logits.
    """

    def __init__(self, channel=3, hideen=768, num_classes=10):
        super().__init__()
        # NOTE(review): sigmoid is presumably chosen because the attack differentiates
        # through the gradient (create_graph=True) and needs smooth second derivatives.
        activation = nn.Sigmoid
        # (in_channels, stride) for each conv, in forward order.
        conv_specs = [(channel, 2), (12, 2), (12, 1)]
        layers = []
        for in_channels, stride in conv_specs:
            layers.append(nn.Conv2d(in_channels, 12, kernel_size=5, padding=5 // 2, stride=stride))
            layers.append(activation())
        self.body = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        """Map a (batch, channel, H, W) image batch to (batch, num_classes) logits."""
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


In [23]:
def weights_init(m, bound=0.5):
    """Re-initialise a module's own ``weight`` and ``bias`` uniformly in [-bound, bound].

    Meant to be passed to ``nn.Module.apply``, which calls it on every submodule.

    Args:
        m: the module to initialise.
        bound: half-width of the uniform range (default 0.5, the original hard-coded value).

    Modules lacking the attribute, or holding ``None`` there (e.g. ``nn.Linear(..., bias=False)``),
    are skipped silently. Previously ``None.data`` raised and printed a misleading "failed"
    warning. Any other failure is still reported as a warning rather than raised.
    """
    for attr in ("weight", "bias"):  # weight first, then bias: same RNG draw order as before
        param = getattr(m, attr, None)
        if param is None:
            continue
        try:
            # In-place init on a leaf Parameter must not be recorded by autograd.
            with torch.no_grad():
                param.uniform_(-bound, bound)
        except Exception:
            print('warning: failed in weights_init for %s.%s' % (m._get_name(), attr))


In [24]:
# Experiment configuration.
dataset_name = 'cifar100'
root_path = '.'
data_path = os.path.join(root_path, '../Data')
save_path = os.path.join(root_path, 'results/DLG_%s'%dataset_name)

lr = 1.0          # LBFGS learning rate
num_dummy = 1     # images reconstructed jointly per experiment
Iteration = 300   # maximum LBFGS iterations per experiment
num_exp = 10      # independent experiments (fresh network and target image each)
seed = 0          # RNG seed so the sampled images and random inits are reproducible

# Seed the RNGs the experiment loop draws from (np.random.permutation, uniform_ init, torch.randn).
np.random.seed(seed)
torch.manual_seed(seed)

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

tt = transforms.Compose([transforms.ToTensor()])
tp = transforms.Compose([transforms.ToPILImage()])

print(dataset_name, 'root_path:', root_path)
print(dataset_name, 'data_path:', data_path)
print(dataset_name, 'save_path:', save_path)

# Creates missing parents and tolerates an existing directory. The previous os.mkdir pair
# hard-coded 'results' relative to the CWD instead of root_path.
os.makedirs(save_path, exist_ok=True)

# Load data and the matching model dimensions.
if dataset_name == 'cifar100':
    shape_img = (32, 32)
    num_classes = 100
    channel = 3
    hidden = 768  # 12 channels * 8 * 8 after LeNet's two stride-2 convs on a 32x32 image
    dataset = datasets.CIFAR100(data_path, download=True)
else:
    # Fail fast: the attack cell needs channel/hidden/num_classes/dataset for the chosen data.
    raise ValueError('unsupported dataset: %s' % dataset_name)



cifar100 root_path: .
cifar100 data_path: ./../Data
cifar100 save_path: ./results/DLG_cifar100
Files already downloaded and verified


In [25]:


# Run the DLG (Deep Leakage from Gradients) attack num_exp times.
# Only the DLG variant (learnable soft dummy label) is implemented in this cell.

# Snapshot/print cadence. max(1, ...) avoids a modulo by zero when Iteration < 30.
log_every = max(1, Iteration // 30)

for idx_net in range(num_exp):
    # Fresh victim network, re-initialised by weights_init, for every experiment.
    net = LeNet(channel=channel, hideen=hidden, num_classes=num_classes)
    net.apply(weights_init)

    print('running %d|%d experiment'%(idx_net, num_exp))
    net = net.to(device)
    idx_shuffle = np.random.permutation(len(dataset))

    print('%s, Try to generate %d images' % ("DLG", num_dummy))

    criterion = nn.CrossEntropyLoss().to(device)

    # Ground-truth batch: the first num_dummy indices of a random permutation.
    # Plain ints, so prints and filenames don't show NumPy scalar reprs.
    imidx_list = [int(i) for i in idx_shuffle[:num_dummy]]
    samples = [dataset[idx] for idx in imidx_list]  # each sample is fetched/decoded once
    gt_data = torch.stack([tt(img).float() for img, _ in samples]).to(device)
    gt_label = torch.tensor([label for _, label in samples], dtype=torch.long, device=device)

    # Gradient of the true loss w.r.t. the network parameters: the attacker's only input.
    out = net(gt_data)
    y = criterion(out, gt_label)
    dy_dx = torch.autograd.grad(y, net.parameters())
    original_dy_dx = [g.detach().clone() for g in dy_dx]

    # Random dummy image and dummy (soft) label, optimised to reproduce original_dy_dx.
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn((gt_data.shape[0], num_classes)).to(device).requires_grad_(True)

    optimizer = torch.optim.LBFGS([dummy_data, dummy_label], lr=lr)

    def closure():
        """Backprop and return the squared L2 distance between dummy and true gradients."""
        optimizer.zero_grad()
        pred = net(dummy_data)
        # Cross-entropy against softmax(dummy_label), i.e. a learnable soft target.
        dummy_loss = - torch.mean(torch.sum(torch.softmax(dummy_label, -1) * torch.log(torch.softmax(pred, -1)), dim=-1))
        # create_graph=True so grad_diff can itself be differentiated w.r.t. the dummies.
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

        grad_diff = 0
        for gx, gy in zip(dummy_dy_dx, original_dy_dx):
            grad_diff += ((gx - gy) ** 2).sum()
        grad_diff.backward()
        return grad_diff

    history = []
    history_iters = []
    losses = []
    mses = []
    train_iters = []

    print('lr =', lr)
    for iters in range(Iteration):
        optimizer.step(closure)
        # Extra evaluation to record the gradient-matching loss after this step.
        current_loss = closure().item()
        train_iters.append(iters)
        losses.append(current_loss)
        with torch.no_grad():
            mses.append(torch.mean((dummy_data - gt_data) ** 2).item())

        if iters % log_every == 0:
            current_time = str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
            print(current_time, iters, 'loss = %.8f, mse = %.8f' %(current_loss, mses[-1]))
            # dummy_data starts as unbounded Gaussian noise; clamp to [0, 1] before ToPILImage.
            history.append([tp(dummy_data[i].detach().clamp(0, 1).cpu()) for i in range(num_dummy)])
            history_iters.append(iters)

            for imidx in range(num_dummy):
                fig = plt.figure(figsize=(12, 8))
                plt.subplot(3, 10, 1)
                plt.imshow(tp(gt_data[imidx].cpu()))
                for i in range(min(len(history), 29)):
                    plt.subplot(3, 10, i + 2)
                    plt.imshow(history[i][imidx])
                    plt.title('iter=%d' % (history_iters[i]))
                    plt.axis('off')
                # Save and close once, after every panel is drawn. Inside the panel loop,
                # each PNG held only the last snapshot on a fresh default-size figure.
                plt.savefig('%s/DLG_on_%s_%05d.png' % (save_path, imidx_list, imidx_list[imidx]))
                plt.close(fig)

            # Convergence is only checked on logging iterations, so the final state is always saved.
            if current_loss < 0.000001:
                break

    loss_DLG = losses
    # One predicted label per dummy image. .item() would fail for num_dummy > 1.
    label_DLG = torch.argmax(dummy_label, dim=-1).detach().cpu().numpy()
    mse_DLG = mses

    print('imidx_list:', imidx_list)
    print('loss_DLG:', loss_DLG[-1])
    print('mse_DLG:', mse_DLG[-1])
    print('gt_label:', gt_label.cpu().numpy(), 'lab_DLG:', label_DLG)

    print('----------------------\n\n')


running 0|10 experiment
DLG, Try to generate 1 images
lr = 1.0
[2025-02-09 06:06:04] 0 loss = 59.36498260, mse = 1.30176103
[2025-02-09 06:06:06] 10 loss = 1.09616494, mse = 0.47565717
[2025-02-09 06:06:09] 20 loss = 0.08930164, mse = 0.13494834
[2025-02-09 06:06:11] 30 loss = 0.01289919, mse = 0.03734213
[2025-02-09 06:06:14] 40 loss = 0.00215783, mse = 0.00954075
[2025-02-09 06:06:16] 50 loss = 0.00048877, mse = 0.00256131
[2025-02-09 06:06:19] 60 loss = 0.00011931, mse = 0.00069342
[2025-02-09 06:06:21] 70 loss = 0.00003694, mse = 0.00022138
[2025-02-09 06:06:24] 80 loss = 0.00001551, mse = 0.00008086
[2025-02-09 06:06:26] 90 loss = 0.00000980, mse = 0.00003897
[2025-02-09 06:06:29] 100 loss = 0.00000671, mse = 0.00001751
[2025-02-09 06:06:31] 110 loss = 0.00000538, mse = 0.00000969
[2025-02-09 06:06:33] 120 loss = 0.00000444, mse = 0.00000492
[2025-02-09 06:06:36] 130 loss = 0.00000393, mse = 0.00000299
[2025-02-09 06:06:38] 140 loss = 0.00000356, mse = 0.00000246
[2025-02-09 06:06

In [None]:
-------