In [1]:
%matplotlib inline

import os
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import grad
from torchvision import datasets, models, transforms
SEED = 50
torch.manual_seed(SEED)  # fixed seed so the torch.randn dummy inputs drawn later repeat across runs

# Record the library versions next to the outputs they produced.
print(torch.__version__, torchvision.__version__)

2.0.0+cu118 0.15.1+cu118


In [None]:
# Root directory for the LFW download/cache. Override it with the LFW_DATA_DIR
# environment variable instead of editing this machine-specific default path.
DATA_DIR = os.environ.get("LFW_DATA_DIR", "/home/raoxy/data")
# NOTE(review): the model and one-hot labels below assume 5749 identities —
# confirm with len(dst.class_to_idx) if the split/image_set defaults change.
dst = datasets.LFWPeople(DATA_DIR, download=True)

Downloading http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz to /home/raoxy/data/lfw-py/lfw-funneled.tgz


 14%|█▍        | 34635776/243346528 [00:44<01:25, 2447109.91it/s]

In [None]:
 tp=transforms.Compose([
    transforms.Resize((32, 32)), # 缩放图像到 32 x 32
    transforms.Grayscale(), # 灰度化图像
    transforms.ToTensor(), # 转化为张量
    transforms.Normalize((0.5,), (0.5,)) # 归一化
])

tt = transforms.ToPILImage()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=100):
    """Turn integer class indices into a float one-hot matrix.

    Args:
        target: integer (int64) tensor of class indices, expected to be 1-D of shape (N,).
        num_classes: width of the encoding (number of columns).

    Returns:
        Float tensor of shape (N, num_classes) on ``target``'s device, with a
        single 1.0 per row at the column given by ``target``.
    """
    index = target.unsqueeze(1)  # (N,) -> (N, 1): scatter_ along dim 1 needs a column of indices
    encoded = torch.zeros(index.size(0), num_classes, device=index.device)
    return encoded.scatter_(1, index, 1)

def cross_entropy_for_onehot(pred, target):
    """Batch-mean cross-entropy between logits and a (possibly soft) target distribution.

    Computes ``mean_n( -sum_c target[n, c] * log_softmax(pred)[n, c] )``. Because
    ``target`` may be any probability vector rather than a class index, it also
    works with the softmax-ed dummy label optimised in ``Raw``.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = (-target * log_probs).sum(1)
    return per_sample.mean()

In [None]:
def weights_init(m):
    """Re-initialise a module's ``weight`` and ``bias`` uniformly in [-0.5, 0.5].

    Meant for ``net.apply(weights_init)``, which calls it on every submodule.
    Modules without these attributes (activations, containers) are left untouched,
    and so are parameters registered as ``None`` (e.g. layers built with
    ``bias=False``), which previously raised ``AttributeError``.
    """
    # no_grad instead of `.data`: in-place init of leaf parameters without autograd tracking.
    with torch.no_grad():
        # Same order as before (weight, then bias) so seeded RNG draws are unchanged.
        for name in ("weight", "bias"):
            param = getattr(m, name, None)
            if param is not None:
                param.uniform_(-0.5, 0.5)

class LeNet(nn.Module):
    """Small sigmoid CNN used as the victim model for gradient inversion.

    Four 5x5 convolutions (the first two with stride 2) map a
    ``(N, in_channels, 32, 32)`` image to ``(N, 12, 8, 8)`` features, which are
    flattened (12 * 8 * 8 = 768) and fed to a single linear classifier.

    Args:
        num_classes: number of output logits. The default 5749 matches the
            ``num_classes=5749`` one-hot labels used for LFW in this notebook.
        in_channels: number of input channels; default 1 (grayscale, see ``tp``).
    """

    # 12 channels * 8 * 8 spatial positions produced by `body` for a 32x32 input.
    FLAT_FEATURES = 12 * 8 * 8

    def __init__(self, num_classes=5749, in_channels=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, 12, kernel_size=5, stride=2, padding=2),
            nn.Sigmoid(),
            nn.Conv2d(12, 12, kernel_size=5, stride=2, padding=2),
            nn.Sigmoid(),
            nn.Conv2d(12, 12, kernel_size=5, stride=1, padding=2),
            nn.Sigmoid(),
            nn.Conv2d(12, 12, kernel_size=5, stride=1, padding=2),
            nn.Sigmoid()
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=self.FLAT_FEATURES, out_features=num_classes, bias=True)
        )

    def forward(self, x):
        features = self.body(x)
        # Keep the batch dimension explicit. The old view(-1, 768) silently turned
        # extra features from non-32x32 inputs into fake batch rows; now such
        # inputs fail loudly in the Linear layer. Unbatched (C, H, W) input still
        # yields a single row, as before.
        batch = features.size(0) if features.dim() == 4 else 1
        return self.fc(features.view(batch, -1))
# class LeNet(nn.Module):
#     def __init__(self):
#         super(LeNet, self).__init__()
#         act = nn.Sigmoid
#         self.body = nn.Sequential(
#             nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
#             act(),
#             nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
#             act(),
#             nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
#             act(),
#             nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
#             act(),
#         )
#         self.fc = nn.Sequential(
#             nn.Linear(768, 100)
#         )

#     def forward(self, x):
#         out = self.body(x)
#         out = out.view(out.size(0), -1)
#         # print(out.size())
#         out = self.fc(out)
#         return out
# class LeNet(nn.Module):
#     def __init__(self):
#         super(LeNet, self).__init__()
#         self.conv1 = nn.Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
#         self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
#         self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
#         self.fc1 = nn.Linear(64 * 7 * 7, 1000)
#         self.fc2 = nn.Linear(1000, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = nn.functional.relu(x)
#         x = self.pool1(x)
#         x = self.conv2(x)
#         x = nn.functional.relu(x)
#         x = self.pool2(x)
#         x = x.view(-1, 64 * 7 * 7)
#         x = self.fc1(x)
#         x = nn.functional.relu(x)
#         x= self.fc2(x)
#         x = nn.functional.softmax(x, dim=1)
#         return x
# Instantiate the victim model on the selected device, then overwrite every
# weight/bias with weights_init (uniform in [-0.5, 0.5]).
net = LeNet()
net = net.to(device)
net.apply(weights_init)

In [None]:
# Soft-label cross-entropy: Raw() passes softmax(dummy_label) — a probability
# vector, not a class index — so the one-hot-based loss is used for both the
# real and the dummy gradients.
criterion = cross_entropy_for_onehot
# Alternative (unused): criterion =  nn.CrossEntropyLoss().to(device)

In [None]:
def Raw(img_index, iterations=300, verbose=False):
    """Run DLG (Deep Leakage from Gradients) on one LFW sample.

    1. Compute the gradient of ``criterion`` w.r.t. ``net``'s parameters for the
       real sample ``dst[img_index]`` (the gradient a client would share).
    2. Optimise a random dummy image and dummy label logits with L-BFGS
       (default settings) so the gradient they induce matches the shared one
       under squared L2 distance.

    Args:
        img_index: index of the sample in ``dst``.
        iterations: number of ``optimizer.step`` calls (default 300).
        verbose: if True, print the gradient-matching loss every 10 steps
            (the loss L-BFGS evaluated at the start of that step).

    Returns:
        List with one PIL image (via ``tt``) of the dummy data after each step;
        ``[-1]`` is the final reconstruction.

    Side effect: calls ``Print(History)`` on the module-level ``History`` list.
    NOTE(review): that global is defined by the driver cell; calling Raw
    without it raises NameError.
    """
    img, label = dst[img_index]  # load (and decode) the sample once
    gt_data = tp(img).to(device).unsqueeze(0)  # add a batch dim of 1
    gt_label = torch.tensor([label], dtype=torch.long, device=device)

    # Compute the original gradient; the label width follows the model's output size.
    out = net(gt_data)
    gt_onehot_label = label_to_onehot(gt_label, num_classes=out.size(-1))
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())

    # The gradients "shared with other clients": fixed targets for the attack.
    original_dy_dx = [g.detach().clone() for g in dy_dx]

    # Dummy data and label. They are drawn on CPU and then moved, as before, so
    # the seeded RNG stream (and therefore the result) is unchanged.
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

    def closure():
        optimizer.zero_grad()
        pred = net(dummy_data)
        dummy_onehot_label = F.softmax(dummy_label, dim=-1)
        dummy_loss = criterion(pred, dummy_onehot_label)
        # create_graph=True so the gradient-matching loss can itself be differentiated.
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
        grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
        grad_diff.backward()
        return grad_diff

    history = []
    for iters in range(iterations):
        # LBFGS.step returns the closure's loss from the start of the step, so no
        # extra forward/double-backward pass is needed just to monitor progress.
        loss = optimizer.step(closure)
        if verbose and iters % 10 == 0:
            print(iters, "%.4f" % loss.item())
        history.append(tt(dummy_data[0].detach().cpu()))

    Print(History)
    return history

In [None]:
def Print(History, ncols=10):
    """Show a list of images in a grid, each titled with its list index.

    Args:
        History: sequence of images accepted by ``plt.imshow`` (in this notebook,
            single-channel PIL images produced by ``tt``).
        ncols: maximum number of columns. Rows are added as needed, so lists
            longer than 50 no longer overflow the old fixed 5 x 10 grid.

    An empty list draws nothing; this happens on the first call from ``Raw``.
    """
    if len(History) == 0:
        return
    ncols = max(1, min(ncols, len(History)))
    nrows = (len(History) + ncols - 1) // ncols  # ceil division
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 8), squeeze=False)
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    for j, ax in enumerate(axes.flat):
        ax.axis('off')
        if j < len(History):
            # Grayscale images would otherwise be drawn with the default (viridis) colormap.
            ax.imshow(History[j], cmap="gray")
            ax.set_title("title={}".format(j))
    plt.show()

In [None]:
NUM_IMAGES = 50  # attack dst[0] .. dst[NUM_IMAGES - 1]

# Final reconstruction of each attacked sample. Raw() also displays this list
# (before appending its own result) to show progress.
History = []
for i in range(NUM_IMAGES):
    print("#" * 50 + "[" + str(i) + "]" + "#" * 50)
    History.append(Raw(i)[-1])

# Raw() never displays the result of its own run, so the last reconstruction has
# not been shown yet. Display the complete gallery once.
Print(History)