In [1]:
!pip install -q pytorch-msssim

In [2]:
import seaborn as sns
from matplotlib import ticker
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets, transforms
import pickle
import PIL.Image as Image
from pytorch_msssim import ssim
from matplotlib import rc
import csv
from os.path import join as oj

In [3]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
import torch.distributions as dist

class LeNet(nn.Module):
    """Small three-stage convolutional classifier with sigmoid activations.

    Args:
        channel: number of input image channels.
        hideen: flattened size of the final conv feature map that feeds the
            linear classifier (the misspelled name is kept so existing
            keyword callers keep working).
        num_classes: number of output logits.
    """

    def __init__(self, channel=3, hideen=768, num_classes=10):
        super().__init__()
        stages = []
        in_channels = channel
        # Three 5x5 convs with 12 filters each; the first two use stride 2,
        # the last keeps the spatial resolution. Each is followed by a sigmoid.
        for stride in (2, 2, 1):
            stages.append(nn.Conv2d(in_channels, 12, kernel_size=5, padding=5 // 2, stride=stride))
            stages.append(nn.Sigmoid())
            in_channels = 12
        self.body = nn.Sequential(*stages)
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        """Return the logits for batch ``x``, one row of ``num_classes`` per sample."""
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


def weights_init(m):
    """Re-initialise ``m.weight`` and ``m.bias`` in place from U(-0.5, 0.5).

    Intended for ``nn.Module.apply``. Modules without these attributes, or
    whose attribute is ``None`` (e.g. a bias-free layer), are skipped silently.
    The weight is drawn before the bias, so a given RNG state always yields the
    same initialisation.

    Args:
        m: the module visited by ``nn.Module.apply``.
    """
    for name in ('weight', 'bias'):
        param = getattr(m, name, None)
        if param is None:
            # Nothing to initialise; not a failure.
            continue
        try:
            # nn.init.uniform_ writes under no_grad, replacing the discouraged
            # `.data` mutation while drawing the same random numbers.
            nn.init.uniform_(param, -0.5, 0.5)
        except (AttributeError, RuntimeError, TypeError):
            # Best-effort: e.g. a non-tensor attribute named `weight`, or an
            # integer tensor that cannot hold uniform samples.
            print('warning: failed in weights_init for %s.%s' % (m._get_name(), name))


def _load_dataset(dataset, data_path):
    """Return ``(num_classes, channel, hidden, dst)`` for a supported dataset.

    ``hidden`` is the flattened LeNet feature size for that input resolution:
    12 * 7 * 7 = 588 for 28x28 MNIST and 12 * 8 * 8 = 768 for 32x32 CIFAR-10.

    Raises:
        ValueError: if ``dataset`` is not 'MNIST' or 'cifar10'.
    """
    if dataset == 'MNIST':
        return 10, 1, 588, datasets.MNIST(data_path, download=True)
    if dataset == 'cifar10':
        return 10, 3, 768, datasets.CIFAR10(data_path, download=True)
    raise ValueError('unknown dataset: %s' % dataset)


def _save_history_plot(save_path, gt_data, history, history_iters, imidx_list, tp):
    """Save one PNG per target image: ground truth followed by up to 29 snapshots."""
    for imidx, img_id in enumerate(imidx_list):
        plt.figure(figsize=(12, 8))
        plt.subplot(5, 10, 1)
        plt.imshow(tp(gt_data[imidx].cpu()))
        plt.title('Ground truth')
        plt.axis('off')
        for i, (snapshots, it) in enumerate(zip(history[:29], history_iters)):
            plt.subplot(5, 10, i + 2)
            plt.imshow(snapshots[imidx])
            plt.title('iter=%d' % (it + 1))
            plt.axis('off')
        plt.savefig('%s/EPAFL_on_%s_%05d.png' % (save_path, imidx_list, img_id))
        plt.close()


def _summary_section(prefix, times, losses, mses, ssims, iters):
    """Format one block of results.txt (e.g. prefix 'All' or 'Succ').

    When the group is empty (no experiment qualified), NaN placeholders are
    written instead of calling np.max/np.min, which raise on empty input.
    """
    if iters:
        stats = [np.average(losses), np.average(mses), np.average(ssims),
                 np.max(iters), np.min(iters), np.average(iters), np.std(iters)]
    else:
        stats = [float('nan')] * 7
    labels = ['Loss', 'MSE', 'SSIM', 'Max iteration', 'Min iteration',
              'Average iteration', 'Stand. dev. iteration']
    lines = ['%s Total time = %s\n' % (prefix, np.sum(times))]
    lines += ['%s %s = %s\n' % (prefix, label, value) for label, value in zip(labels, stats)]
    return ''.join(lines) + '\n'


def main():
    """Run the gradient-inversion experiment sweep and write summary statistics.

    For each of ``num_exp`` experiments a freshly initialised LeNet computes
    the cross-entropy gradient on ground-truth training image(s). The label is
    inferred from the final FC-layer weight gradient, and a random dummy image
    is optimised with L-BFGS so that its gradient matches the observed one.
    Optimisation stops when the gradient-matching loss drops below
    ``threshold`` or stops improving for more than ``patience`` iterations.
    Progress images go to ``save_path``; aggregate statistics go to
    ``results/<dataset>/plateau+threshold/p=<patience>_th=<threshold>/results.txt``.
    """
    dataset = 'cifar10'
    root_path = '.'
    data_path = os.path.join(root_path, '../data').replace('\\', '/')
    save_path = os.path.join(root_path, 'results/EPAFL_%s' % dataset).replace('\\', '/')

    lr = 1.0
    num_dummy = 1  # label inference below supports exactly one image
    Iteration = 300
    num_exp = 100
    threshold = 0.001
    patience = 15

    succ_rec_img = 0
    succ_rec_label = 0
    all_loss, all_mse, all_ssim, all_iters, all_time = [], [], [], [], []
    all_succ_loss, all_succ_mse, all_succ_ssim, all_succ_iters, all_succ_time = [], [], [], [], []

    tt = transforms.Compose([transforms.ToTensor()])
    tp = transforms.Compose([transforms.ToPILImage()])

    print(dataset, 'root_path:', root_path)
    print(dataset, 'data_path:', data_path)
    print(dataset, 'save_path:', save_path)

    # Creates ./results as well; no-op if the directories already exist.
    os.makedirs(save_path, exist_ok=True)

    num_classes, channel, hidden, dst = _load_dataset(dataset, data_path)

    np.random.seed(42)
    idx_shuffle = np.random.permutation(len(dst))

    for idx_net in range(num_exp):
        net = LeNet(channel=channel, hideen=hidden, num_classes=num_classes)
        net.apply(weights_init)

        print('running %d|%d experiment' % (idx_net, num_exp))
        net = net.to(device)

        print('Try to generate %d images' % (num_dummy))

        criterion = nn.CrossEntropyLoss().to(device)

        # Distinct images per slot; equals idx_shuffle[idx_net] when num_dummy == 1.
        imidx_list = [int(idx_shuffle[idx_net * num_dummy + k]) for k in range(num_dummy)]
        gt_data = torch.stack([tt(dst[idx][0]).float() for idx in imidx_list]).to(device)
        gt_label = torch.tensor([dst[idx][1] for idx in imidx_list], dtype=torch.long, device=device)

        # compute original gradient
        out = net(gt_data)
        y = criterion(out, gt_label)
        dy_dx = torch.autograd.grad(y, net.parameters())
        original_dy_dx = [g.detach().clone() for g in dy_dx]

        # generate dummy data
        dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
        optimizer = torch.optim.LBFGS([dummy_data], lr=lr)

        # Predict the label from the FC weight gradient (parameters()[-2]).
        # Under cross-entropy only the true class gets a negative logit
        # gradient, and the sigmoid features feeding the FC layer are
        # positive, so that class's row has the smallest sum.
        label_pred = torch.argmin(torch.sum(original_dy_dx[-2], dim=-1), dim=-1).detach().reshape((1,))

        def closure():
            optimizer.zero_grad()
            pred = net(dummy_data)
            dummy_loss = criterion(pred, label_pred)
            dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
            # Squared Euclidean distance between dummy and observed gradients.
            grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
            grad_diff.backward()
            return grad_diff

        history = []
        history_iters = []
        losses = []
        mses = []
        ssims = []
        best_val_loss = None
        wait = 0
        train_iters = Iteration  # overwritten on early stop

        iter_start_time = time.time()
        for iters in range(Iteration):
            optimizer.step(closure)
            # step() returns the pre-step loss; re-evaluate to get the post-step value.
            current_loss = closure().item()
            losses.append(current_loss)

            with torch.no_grad():
                mses.append(torch.mean((dummy_data - gt_data) ** 2).item())
                # ToTensor() images lie in [0, 1], so data_range must be 1;
                # data_range=0 zeroes SSIM's stabilising constants.
                ssims.append(ssim(dummy_data, gt_data, data_range=1.0).item())

            if (iters + 1) % 10 == 0:
                current_time = str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
                print(current_time, iters, 'loss = %.8f, mse = %.8f, ssim = %.8f' % (current_loss, mses[-1], ssims[-1]))
                history.append([tp(dummy_data[k].detach().cpu()) for k in range(num_dummy)])
                history_iters.append(iters)
                _save_history_plot(save_path, gt_data, history, history_iters, imidx_list, tp)

            # Early stopping: threshold on improvement, or plateau detection.
            if best_val_loss is None or current_loss < best_val_loss:
                best_val_loss = current_loss
                wait = 0
                if current_loss < threshold:
                    print('Iteration required: ', iters + 1)
                    train_iters = iters + 1
                    break  # converge
            elif wait >= patience:
                # No improvement for more than `patience` consecutive iterations.
                print('Iteration required: ', iters + 1)
                train_iters = iters + 1
                break  # plateau
            else:
                wait += 1

        all_time.append(time.time() - iter_start_time)

        pred_label = label_pred.item()
        true_label = gt_label[0].item()

        print('imidx_list:', imidx_list)
        print('loss:', losses[-1])
        print('mse:', mses[-1])
        print('ssim:', ssims[-1])
        print('gt_label:', true_label, 'pred_label:', pred_label)

        all_loss.append(losses[-1])
        all_mse.append(mses[-1])
        all_ssim.append(ssims[-1])
        all_iters.append(train_iters)

        if true_label == pred_label:
            succ_rec_label += 1

        if ssims[-1] > 0.9:
            succ_rec_img += 1
            all_succ_loss.append(losses[-1])
            all_succ_mse.append(mses[-1])
            all_succ_ssim.append(ssims[-1])
            all_succ_iters.append(train_iters)
            all_succ_time.append(all_time[-1])

        print('----------------------\n\n')

    folder = os.path.join('results', dataset, "plateau+threshold",
                          'p={}_th={}'.format(str(patience), str(threshold)))
    os.makedirs(folder, exist_ok=True)

    with open(os.path.join(folder, 'results.txt'), 'w') as file:
        # Attack success rates as a percentage of all experiments.
        file.write('ASRl = ' + str((succ_rec_label * 100) / num_exp) + '\n')
        file.write('ASRc = ' + str((succ_rec_img * 100) / num_exp) + '\n')
        file.write(_summary_section('All', all_time, all_loss, all_mse, all_ssim, all_iters))
        file.write(_summary_section('Succ', all_succ_time, all_succ_loss, all_succ_mse,
                                    all_succ_ssim, all_succ_iters))


# Launch the experiment sweep only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()