In [None]:
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

In [None]:
class ModelInversionAttack:
    """Model inversion: optimize an input image so trained model(s) assign it to a target class.

    The image is optimized with Adam under a negative-log-likelihood objective on
    ``target_class`` plus total-variation and L2 penalties, and is clamped to
    [0, 1] after every step.

    Models are instantiated from the ``CNN`` class, which is defined elsewhere in
    this notebook. NOTE(review): the loss takes ``log(pred + 1e-10)`` directly, so
    ``CNN`` is assumed to output class probabilities (e.g. softmax), not logits or
    log-probabilities -- confirm against its definition. The recovered image is
    fixed at shape (1, 1, 28, 28), presumably matching ``CNN``'s expected input.

    Attributes:
        model_params: State dict of the model attacked by ``perform_attack``.
        target_class: Index of the class whose representative input is recovered.
        device: Torch device on which the models and the image are placed.
    """

    def __init__(self, model_params, target_class, device):
        self.model_params = model_params
        self.target_class = target_class
        self.device = device

    def _load_model(self, state_dict):
        """Build a ``CNN`` on ``self.device``, load *state_dict*, and freeze it for inference."""
        model = CNN().to(self.device)
        model.load_state_dict({k: v.to(self.device) for k, v in state_dict.items()})
        model.eval()
        # Only the image is optimized. Without this, every backward() would also
        # accumulate (never-cleared) gradients into all model parameters.
        model.requires_grad_(False)
        return model

    def _target_nll(self, pred):
        """Return the summed negative log-probability of ``target_class`` in *pred*.

        *pred* is indexed as (batch, class); the 1e-10 guards against log(0).
        """
        return -torch.log(pred[:, self.target_class] + 1e-10).sum()

    @staticmethod
    def _total_variation(image):
        """Return the anisotropic (L1) total variation of an (N, C, H, W) image tensor."""
        horizontal = torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:]))
        vertical = torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
        return horizontal + vertical

    def perform_multimodel_attack(self, models_list, num_iterations=1000, learning_rate=0.01,
                                  reg_param=0.01, *, l2_param=0.001, log_interval=400):
        """
        Perform model inversion attack using multiple models with different parameters.

        The classification loss is averaged over all models, so the recovered image
        must be classified as ``target_class`` by the whole ensemble.

        Args:
            models_list: Iterable of state dicts, one per model to attack.
            num_iterations: Number of optimization iterations.
            learning_rate: Adam learning rate.
            reg_param: Weight of the total-variation penalty.
            l2_param: Weight of the L2 (squared pixel magnitude) penalty.
            log_interval: Record and print the classification loss every
                ``log_interval`` iterations (starting at iteration 0).

        Returns:
            Tuple of (recovered image as a numpy array of shape (1, 1, 28, 28),
            list of classification losses sampled every ``log_interval`` iterations).

        Raises:
            ValueError: If ``models_list`` is empty or ``log_interval`` is not positive.
        """
        if log_interval <= 0:
            raise ValueError(f"log_interval must be a positive integer, got {log_interval!r}")

        loaded_models = [self._load_model(params) for params in models_list]
        if not loaded_models:
            raise ValueError("models_list must contain at least one set of model parameters")

        # The image is the only optimized tensor; pixels start uniform in [0, 1).
        recovered_image = torch.rand(1, 1, 28, 28, requires_grad=True, device=self.device)
        optimizer = optim.Adam([recovered_image], lr=learning_rate)

        losses = []

        for i in range(num_iterations):
            optimizer.zero_grad()

            # Mean target-class negative log-likelihood across the ensemble.
            classification_loss = sum(
                self._target_nll(model(recovered_image)) for model in loaded_models
            ) / len(loaded_models)

            tv_loss = self._total_variation(recovered_image) * reg_param
            l2_loss = torch.sum(recovered_image ** 2) * l2_param
            total_loss = classification_loss + tv_loss + l2_loss

            total_loss.backward()
            optimizer.step()

            # Project back onto the valid pixel range after the unconstrained Adam step.
            with torch.no_grad():
                recovered_image.clamp_(0, 1)

            if i % log_interval == 0:
                loss_value = classification_loss.item()
                losses.append(loss_value)
                print(f"[{i}, {loss_value:.4f}]")

        # detach() is required: numpy() refuses tensors that require grad, and on a
        # CPU device .cpu() returns that same tensor (torch.no_grad() does not help).
        return recovered_image.detach().cpu().numpy(), losses

    def perform_attack(self, num_iterations=1000, learning_rate=0.01, reg_param=0.01, *,
                       l2_param=0.001, log_interval=800):
        """Run the inversion attack against the single model given by ``self.model_params``.

        This is ``perform_multimodel_attack`` with a one-model ensemble; see that
        method for the meaning of the arguments, the return value, and the errors.
        """
        return self.perform_multimodel_attack(
            [self.model_params],
            num_iterations,
            learning_rate,
            reg_param,
            l2_param=l2_param,
            log_interval=log_interval,
        )