# Model Inversion Attack

## Settings and Imports

In [None]:
# suppress warnings
import warnings

warnings.filterwarnings('ignore')

#autoreload other packages when code changed
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from torch import nn
from torch.utils.data import DataLoader
import torchvision

import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

In [None]:
#Own Code
from privacyflow.configs import path_configs
from privacyflow.datasets import faces_dataset, mi_dataset
from privacyflow.models import face_models, cifar_models, cifar_autoencoder, celeba_autoencoder

In [None]:
# Pick the compute device once for the whole notebook: CUDA when present,
# otherwise fall back to the CPU. The choice is also reported on stdout.
_cuda_available = torch.cuda.is_available()
print("GPU will be used" if _cuda_available else "No GPU available")
device = torch.device('cuda' if _cuda_available else 'cpu')

## CIFAR-10 - Denoising Autoencoder

In [None]:
# #Custom transformation for adding noise to the training data
# class AddGaussianNoise(object):
#     def __init__(self, mean:float=0.0, std:float=0.0):
#         self.std = std
#         self.mean = mean
# 
#     def __call__(self, tensor):
#         tensor = tensor + torch.randn(tensor.size()) * self.std + self.mean
#         return torch.clip(tensor,min=0.0,max=1.0)

In [None]:
# The images are only converted to tensors here; no noise transform is part
# of the dataset pipeline.
autoencoder_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Arguments shared by the train and the test split of CIFAR-10.
_cifar10_shared_kwargs = dict(root=path_configs.CIFAR_FOLDER_PATH,
                              transform=autoencoder_transform,
                              download=True)
cifar10_ds_train_autoencoder = torchvision.datasets.CIFAR10(train=True, **_cifar10_shared_kwargs)
cifar10_ds_test_autoencoder = torchvision.datasets.CIFAR10(train=False, **_cifar10_shared_kwargs)

# The autoencoder is trained on both splits, so merge them (test first, then
# train) and wrap the result in a shuffled loader with batches of 64.
cifar10_ds_autoencoder = torch.utils.data.ConcatDataset(
    [cifar10_ds_test_autoencoder, cifar10_ds_train_autoencoder]
)
cifar10_dl_autoencoder = DataLoader(cifar10_ds_autoencoder, batch_size=64, shuffle=True)

In [None]:
def train_autoencoder_no_logs(model: nn.Module,
                        train_dl: torch.utils.data.DataLoader,
                        optimizer: torch.optim.Optimizer,
                        criterion: nn.Module,
                        num_epochs: int = 15,
                        noise_factor: float = 0.3) -> None:
    """Train a denoising autoencoder in place without logging any metrics.

    Every batch is corrupted with additive uniform noise and the model is
    optimised to reproduce the clean batch from the corrupted one.

    Args:
        model: Autoencoder to train. It is switched to train mode and moved
            to the notebook-level ``device``.
        train_dl: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        optimizer: Optimizer that already holds the parameters of ``model``.
        criterion: Loss comparing the reconstruction with the clean images.
        num_epochs: Number of full passes over ``train_dl``.
        noise_factor: Scale of the noise; samples are drawn uniformly from
            ``[0, noise_factor)`` before the result is clipped to ``[0, 1]``.
    """
    model.train()
    model = model.to(device)
    for _ in tqdm(range(num_epochs), leave=False):
        for images, _ in train_dl:
            images = images.to(device)
            # Corrupt the input only; the loss target stays the clean batch.
            noise = torch.rand(images.size(), device=images.device) * noise_factor
            model_inputs = torch.clip(images + noise, min=0.0, max=1.0)
            optimizer.zero_grad()
            model_outputs = model(model_inputs)
            loss = criterion(model_outputs, images)
            loss.backward()
            optimizer.step()

In [None]:
# Denoising autoencoder for CIFAR-10, trained with MSE reconstruction loss.
autoencoder = cifar_autoencoder.CifarDenoisingAutoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.0001)

# Train on the DataLoader (shuffled batches of 64), not on the raw
# ConcatDataset: iterating the dataset directly yields single, unbatched and
# unshuffled samples.
train_autoencoder_no_logs(autoencoder,
                    train_dl=cifar10_dl_autoencoder,
                    optimizer=optimizer,
                    criterion=criterion,
                    num_epochs=20)

## CIFAR-10 Model Inversion Attack

In [None]:
def reconstruction_attack_cifar10(
        model: nn.Module,
        autoencoder: nn.Module,
        start_tensor: torch.Tensor,
        target: torch.Tensor,
        num_epochs: int = 10_000,
        learning_rate: float = 0.01,
        use_autoencoder: bool = True) -> torch.Tensor:
    """Run a gradient-based model inversion attack against a classifier.

    Starting from ``start_tensor``, the input is optimised with Adam so that
    ``model`` assigns it to ``target`` (minimising ``nn.NLLLoss``). NLLLoss
    expects log-probabilities, so ``model`` is assumed to end in a
    ``log_softmax`` - TODO confirm against the model definition.

    Args:
        model: Attacked classifier.
        autoencoder: Denoising autoencoder applied to the candidate after
            every optimisation step when ``use_autoencoder`` is true.
        start_tensor: Initial input batch. It is copied, so the caller's
            tensor is left untouched.
        target: Class indices the reconstruction should be classified as.
        num_epochs: Number of optimisation steps.
        learning_rate: Adam learning rate for the input tensor.
        use_autoencoder: Whether to pass the candidate through
            ``autoencoder`` after every step.

    Returns:
        The optimised input as a leaf tensor with ``requires_grad=True``.
    """
    criterion_tensor = nn.NLLLoss()

    # Optimise a private leaf copy so the caller's tensor is not modified.
    candidate = start_tensor.detach().clone().requires_grad_(True)
    # One optimizer for the whole run: re-creating Adam on every iteration
    # would throw away its running moment estimates.
    optimizer_tensor = torch.optim.Adam([candidate], lr=learning_rate)

    # Attack the networks in eval mode (no dropout, frozen batch-norm
    # statistics) and restore their previous mode afterwards.
    involved_modules = [model] + ([autoencoder] if use_autoencoder else [])
    previous_modes = [(module, module.training) for module in involved_modules]
    for module in involved_modules:
        module.eval()
    try:
        for _ in tqdm(range(num_epochs), leave=False):
            # Gradient step on the input with respect to the target class.
            optimizer_tensor.zero_grad()
            loss = criterion_tensor(model(candidate), target)
            loss.backward()
            optimizer_tensor.step()
            if use_autoencoder:
                # Overwrite the candidate with its denoised version in place;
                # no graph is needed for this projection step.
                with torch.no_grad():
                    candidate.copy_(autoencoder(candidate))
    finally:
        for module, was_training in previous_modes:
            module.train(was_training)
    return candidate

In [None]:
# Load the trained CIFAR-10 classifier that will be attacked.
attacked_model = cifar_models.CifarCNNModel()
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
attacked_model.load_state_dict(
    torch.load(f"{path_configs.MODELS_TRAINED_BASE_PATH}/cifar_10_base.pl", map_location=device)
)
# The model is only used for inference-style forward passes from here on.
attacked_model.eval()

# Start from a uniform grey image (all values 0.5) with a leading batch
# dimension; a random start would be
# torch.rand([3,32,32],device=device).unsqueeze(0).requires_grad_()
start_tensor = torch.empty(3,32,32,device=device).fill_(0.5).unsqueeze(0).requires_grad_()
# Class index the reconstructed image should be classified as.
target_label = torch.tensor([8],dtype=torch.long,device=device)

In [None]:
# Both networks have to live on the same device as the start tensor.
attacked_model = attacked_model.to(device)
autoencoder = autoencoder.to(device)

_to_pil_image = torchvision.transforms.ToPILImage()

# Show the image the attack starts from.
_to_pil_image(start_tensor.squeeze(0)).show()

# Stage 1: optimise the input against the classifier only.
reconstructed_image = reconstruction_attack_cifar10(
    model=attacked_model,
    autoencoder=autoencoder,
    start_tensor=start_tensor,
    target=target_label,
    num_epochs=10_000,
    learning_rate=0.01,
    use_autoencoder=False,
)
_to_pil_image(reconstructed_image.squeeze(0)).show()

# Stage 2: continue from the stage-1 result, now passing the candidate
# through the autoencoder after every step and using a larger learning rate.
reconstructed_image = reconstruction_attack_cifar10(
    model=attacked_model,
    autoencoder=autoencoder,
    start_tensor=reconstructed_image,
    target=target_label,
    num_epochs=10_000,
    learning_rate=0.1,
    use_autoencoder=True,
)
_to_pil_image(reconstructed_image.squeeze(0)).show()

## CelebA Denoising Autoencoder

In [None]:
# Resize every face image to 224x224 and convert it to a tensor.
celeba_autoencoder_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])

# Build the train / val / test splits with identical settings.
train_dataset_celeba, val_dataset_celeba, test_dataset_celeba = (
    faces_dataset.FacesDataset(label_cols='all',
                               mode=split_name,
                               transform=celeba_autoencoder_transform)
    for split_name in ("train", "val", "test")
)

# Combine all splits for the training of the autoencoder.
dataset_celeba_combines = torch.utils.data.ConcatDataset(
    [train_dataset_celeba, val_dataset_celeba, test_dataset_celeba]
)
dl_celeba_combines = DataLoader(dataset_celeba_combines,
                                batch_size=32,
                                num_workers=4,
                                shuffle=True)

In [None]:
# Fresh denoising autoencoder for the CelebA images; note that this rebinds
# the notebook-level names `autoencoder`, `criterion` and `optimizer`.
autoencoder = celeba_autoencoder.CelebADenoisingAutoencoder()
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.0001)
criterion = nn.MSELoss()

# Train for 10 epochs on the combined CelebA loader.
train_autoencoder_no_logs(
    model=autoencoder,
    train_dl=dl_celeba_combines,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=10,
)

In [None]:
# 1) Clean sample image taken from the combined CelebA dataset.
img = dataset_celeba_combines[128][0]
torchvision.transforms.ToPILImage()(img).save("./privacyflow/images/faces_autoencoder1.jpg")

# 2) The same image with the uniform noise that is also used during training.
img = torch.clip(img + torch.rand(img.size()) * 0.3, min=0.0, max=1.0)
torchvision.transforms.ToPILImage()(img).save("./privacyflow/images/faces_autoencoder2.jpg")

# 3) Reconstruction of the noisy image by the autoencoder.
autoencoder = autoencoder.to('cpu')
# Inference only: switch to eval mode and skip autograd. The model was
# trained on batches, so add a batch dimension for the forward pass and
# remove it again afterwards.
autoencoder.eval()
with torch.no_grad():
    img = autoencoder(img.unsqueeze(0)).squeeze(0)
# Keep the values inside [0, 1] so the conversion to 8-bit cannot overflow.
img = torch.clip(img, min=0.0, max=1.0)
torchvision.transforms.ToPILImage()(img).save("./privacyflow/images/faces_autoencoder3.jpg")

In [None]:
def reconstruction_attack_cifar10(
        model:nn.Module,
        autoencoder:nn.Module,
        start_tensor:torch.Tensor,
        target:torch.Tensor,
        num_epochs:int=10_000,
        learning_rate:float =0.01,
        use_autoencoder:bool=True) -> torch.Tensor:
    """Run a gradient-based model inversion attack using a BCE objective.

    NOTE: this definition replaces the earlier function of the same name in
    this notebook. It minimises ``nn.BCELoss`` instead of ``nn.NLLLoss``; the
    name is kept because the cells below call it. BCELoss expects
    probabilities in ``[0, 1]`` and a float target of the same shape as the
    model output - presumably the attacked model ends in a sigmoid; verify.

    Args:
        model: Attacked classifier.
        autoencoder: Denoising autoencoder applied to the candidate after
            every optimisation step when ``use_autoencoder`` is true.
        start_tensor: Initial input batch. It is copied, so the caller's
            tensor is left untouched.
        target: Label tensor the model output is pushed towards.
        num_epochs: Number of optimisation steps.
        learning_rate: Adam learning rate for the input tensor.
        use_autoencoder: Whether to pass the candidate through
            ``autoencoder`` after every step.

    Returns:
        The optimised input as a leaf tensor with ``requires_grad=True``.
    """
    criterion_tensor = nn.BCELoss()

    # Optimise a private leaf copy so the caller's tensor is not modified.
    candidate = start_tensor.detach().clone().requires_grad_(True)
    # One optimizer for the whole run: re-creating Adam on every iteration
    # would throw away its running moment estimates.
    optimizer_tensor = torch.optim.Adam([candidate], lr=learning_rate)

    # Attack the networks in eval mode (no dropout, frozen batch-norm
    # statistics) and restore their previous mode afterwards.
    involved_modules = [model] + ([autoencoder] if use_autoencoder else [])
    previous_modes = [(module, module.training) for module in involved_modules]
    for module in involved_modules:
        module.eval()
    try:
        for _ in tqdm(range(num_epochs), leave=False):
            # Gradient step on the input with respect to the target labels.
            optimizer_tensor.zero_grad()
            loss = criterion_tensor(model(candidate), target)
            loss.backward()
            optimizer_tensor.step()
            if use_autoencoder:
                # Overwrite the candidate with its denoised version in place;
                # no graph is needed for this projection step.
                with torch.no_grad():
                    candidate.copy_(autoencoder(candidate))
    finally:
        for module, was_training in previous_modes:
            module.train(was_training)
    return candidate

In [None]:
# Load the trained face attribute model that will be attacked.
attacked_model = face_models.get_FaceModelResNet(40)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
attacked_model.load_state_dict(
    torch.load(f"{path_configs.MODELS_TRAINED_BASE_PATH}/face_base_model.pl", map_location=device)
)
# The model is only used for inference-style forward passes from here on.
attacked_model.eval()

# Start from a uniform grey image (all values 0.5) with a leading batch
# dimension; a random start would be
# torch.rand([3,224,224],device=device).unsqueeze(0).requires_grad_()
start_tensor = torch.empty(3,224,224,device=device).fill_(0.5).unsqueeze(0).requires_grad_()
# Use the labels of dataset sample 128 as the attack target (batch of one).
target_label = dataset_celeba_combines[128][1].unsqueeze(0).to(device)

In [None]:
# Both networks have to live on the same device as the start tensor.
attacked_model = attacked_model.to(device)
autoencoder = autoencoder.to(device)

_to_pil_image = torchvision.transforms.ToPILImage()

# Show the image the attack starts from.
_to_pil_image(start_tensor.squeeze(0)).show()

# Stage 1: optimise the input against the face model only.
reconstructed_image = reconstruction_attack_cifar10(
    model=attacked_model,
    autoencoder=autoencoder,
    start_tensor=start_tensor,
    target=target_label,
    num_epochs=10_000,
    learning_rate=0.01,
    use_autoencoder=False,
)
_to_pil_image(reconstructed_image.squeeze(0)).show()

# Stage 2: continue from the stage-1 result, now passing the candidate
# through the autoencoder after every step and using a larger learning rate.
reconstructed_image = reconstruction_attack_cifar10(
    model=attacked_model,
    autoencoder=autoencoder,
    start_tensor=reconstructed_image,
    target=target_label,
    num_epochs=10_000,
    learning_rate=0.1,
    use_autoencoder=True,
)
_to_pil_image(reconstructed_image.squeeze(0)).show()