<a href="https://colab.research.google.com/github/basics-lab/unlearning-MIA/blob/master/basic/unlearning_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src='https://unlearning-challenge.github.io/Unlearning-logo.png' width='100px'>




  * 💾 In the first section we'll load a sample dataset (CIFAR10) and a pre-trained model (ResNet18).

  * 🎯 In the second section we'll develop the unlearning algorithm. We start by splitting the original training set into a retain set and a forget set. The goal of an unlearning algorithm is to update the pre-trained model so that it approximates as much as possible a model that has been trained on the retain set but not on the forget set. We provide a simple unlearning algorithm as a starting point for participants to develop their own unlearning algorithms.

  * 🏅 In the third section we'll score our unlearning algorithm using a simple membership inference attack (MIA). Note that this is a different evaluation than the one that will be used in the competition's submission.

In [8]:
!git clone https://github.com/basics-lab/unlearning-MIA.git
!pip install opacus

SyntaxError: ignored

In [9]:
cd /unlearning-MIA/

/unlearning-MIA


In [3]:
cd ..

/Users/nived.rajaraman/Documents/GitHub/unlearning-MIA


In [4]:
"""This file is the main entry point for running the privacy auditing."""
import argparse
import logging
import os
import pickle
import random
import time
import requests
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import numpy as np
import torch
import yaml
from basic.augment import get_signal_on_augmented_data
from basic.core import (
    load_dataset_for_existing_models,
    load_existing_models,
    load_existing_target_model,
    prepare_datasets,
    prepare_datasets_for_reference_in_attack,
    prepare_datasets_for_sample_privacy_risk,
    prepare_information_source,
    prepare_models,
    prepare_priavcy_risk_report,
)
from basic.dataset import get_dataset, get_dataset_subset
from basic.plot import plot_roc, plot_signal_histogram
from scipy.stats import norm
from sklearn.metrics import auc, roc_curve
from torch import nn
from basic.util import (
    check_configs,
    load_leave_one_out_models,
    load_models_with_data_idx_list,
    load_models_without_data_idx_list,
    sweep,
)

from privacy_meter.audit import Audit
from privacy_meter.model import PytorchModelTensor

torch.backends.cudnn.benchmark = True

print("Packages imported.")

ModuleNotFoundError: No module named 'seaborn'

# Create logger for current run

In this section, we'll create a logger object for the current run.

In [None]:
def setup_log(name: str, save_file: bool) -> logging.Logger:
    """Generate the logger for the current run.

    The named logger is set to INFO level. When ``save_file`` is true, a
    ``FileHandler`` writing to ``log_{name}.log`` (relative to the current
    working directory, opened with mode ``"w"`` so it is truncated) is attached.

    Calling this again with the same ``name`` -- e.g. when a notebook cell is
    re-executed -- replaces the file handler previously attached for that same
    log file instead of stacking a second one. Without this, every message
    would be written once per call and the old file handles would leak.

    Args:
        name (str): Logger name; also used to build the log file name.
        save_file (bool): Flag about whether to save to file.

    Returns:
        logging.Logger: Logger object for the current run.
    """
    my_logger = logging.getLogger(name)
    my_logger.setLevel(logging.INFO)
    if save_file:
        filename = f"log_{name}.log"
        # FileHandler stores the absolute path of its file in ``baseFilename``;
        # drop (and close) any handler already writing to the same file.
        target_path = os.path.abspath(filename)
        for old_handler in list(my_logger.handlers):
            if (
                isinstance(old_handler, logging.FileHandler)
                and old_handler.baseFilename == target_path
            ):
                my_logger.removeHandler(old_handler)
                old_handler.close()

        log_format = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s")
        log_handler = logging.FileHandler(filename, mode="w")
        log_handler.setLevel(logging.INFO)
        log_handler.setFormatter(log_format)
        my_logger.addHandler(log_handler)

    return my_logger

# 💾 Download dataset and initialize models

In this section, we'll load the configuration and dataset and initialize the models. The `__main__` cell then runs the MIA through either `audit_model` or `audit_model_rio`, depending on `configs["audit"]["privacy_game"]`.

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cf",
        type=str,
        default="basic/config_models_reference_in_out.yaml",
        help="Yaml file which contains the configurations",
    )

    # Load the parameters.
    # parse_known_args() (rather than parse_args()) tolerates the extra argv a
    # Jupyter/Colab kernel receives (e.g. "-f <kernel>.json"); parse_args()
    # would exit with "unrecognized arguments" when this cell runs in a notebook.
    args = parser.parse_known_args()[0]
    with open(args.cf, "rb") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point --cf at trusted configuration files.
        configs = yaml.load(f, Loader=yaml.Loader)

    check_configs(configs)
    # Set the random seed, log_dir and inference_game
    torch.manual_seed(configs["run"]["random_seed"])
    np.random.seed(configs["run"]["random_seed"])
    random.seed(configs["run"]["random_seed"])

    log_dir = configs["run"]["log_dir"]
    inference_game_type = configs["audit"]["privacy_game"].upper()

    # Set up the logger
    logger = setup_log("time_analysis", configs["run"]["time_log"])

    # Create folders for saving the logs if they do not exist
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    report_dir = f"{log_dir}/{configs['audit']['report_log']}"
    Path(report_dir).mkdir(parents=True, exist_ok=True)

    start_time = time.time()

    # Load or initialize models based on metadata.
    # NOTE(review): pickle.load executes code embedded in the file; the
    # metadata file is expected to be one this script wrote under log_dir.
    metadata_path = f"{log_dir}/models_metadata.pkl"
    if os.path.exists(metadata_path):
        with open(metadata_path, "rb") as f:
            model_metadata_list = pickle.load(f)
    else:
        model_metadata_list = {"model_metadata": {}, "current_idx": 0}
    # Load the dataset
    dataset = get_dataset(configs["data"]["dataset"], configs["data"]["data_dir"])

    privacy_game = configs["audit"]["privacy_game"]
    attack = configs["audit"]["algorithm"]

    # NOTE: audit_model and audit_model_rio are defined in later cells of this
    # notebook; execute those cells before this one, otherwise the calls below
    # raise NameError.
    if privacy_game in ["avg_privacy_loss_training_algo", "privacy_loss_model"]:
        audit_model(configs, privacy_game, dataset, model_metadata_list, logger)
    else:
        audit_model_rio(configs, privacy_game, dataset, model_metadata_list, logger)

    ############################
    # END
    ############################
    logger.info(
        "Run the privacy meter for the all steps costs %0.5f seconds",
        time.time() - start_time,
    )



# Privacy auditing for a model or an algorithm

In [None]:
def audit_model(configs, privacy_game, dataset, model_metadata_list, logger):
    """Audit the privacy risk of target models with Privacy Meter's ``Audit``.

    Steps (each one timed through ``logger``):
      1. When ``model_metadata_list["current_idx"] > 0``, reload the target
         models already recorded there together with their dataset splits.
      2. Split the dataset for the remaining target models and build them with
         ``prepare_models``.
      3. Build target/reference information sources and metrics with
         ``prepare_information_source``.
      4. Run ``Audit`` and write the privacy risk report to
         ``{log_dir}/{configs['audit']['report_log']}``.

    Args:
        configs (dict): Parsed configuration with ``run``, ``data``, ``train``
            and ``audit`` sections.
        privacy_game (str): The configured privacy game (the ``__main__`` cell
            passes ``configs["audit"]["privacy_game"]``); its upper-cased value
            is given to ``Audit`` as ``inference_game_type``.
        dataset: The full dataset (``get_dataset`` output in the ``__main__`` cell).
        model_metadata_list (dict): Model bookkeeping with at least the key
            ``"current_idx"``; the updated metadata returned by
            ``prepare_models``/``prepare_information_source`` is used for the
            following steps.
        logger (logging.Logger): Receives the timing messages.

    Raises:
        ValueError: If ``configs["audit"]["algorithm"]`` does not contain
            ``"reference_in_out"``.
    """
    if "reference_in_out" not in configs["audit"]["algorithm"]:
        raise ValueError(
            "Currently only reference_in_out is implemented. configs['audit']['algorithm'] must contain 'reference_in_out'"
        )

    # Derive these from the arguments instead of reading the module-level
    # globals created by the __main__ cell, so the function works on its own.
    log_dir = configs["run"]["log_dir"]
    inference_game_type = privacy_game.upper()

    # Start time of audit
    baseline_time = time.time()

    # Reuse target models (and their dataset splits) recorded by a previous run
    if model_metadata_list["current_idx"] > 0:
        target_model_idx_list = load_existing_target_model(
            len(dataset), model_metadata_list, configs
        )
        trained_target_dataset_list = load_dataset_for_existing_models(
            len(dataset),
            model_metadata_list,
            target_model_idx_list,
            configs["data"],
        )

        trained_target_models_list = load_existing_models(
            model_metadata_list,
            target_model_idx_list,
            configs["train"]["model_name"],
            dataset,
            configs["data"]["dataset"],
        )
        # Only the target models that are still missing need to be prepared
        num_target_models = configs["train"]["num_target_model"] - len(
            trained_target_dataset_list
        )
    else:
        target_model_idx_list = []
        trained_target_models_list = []
        trained_target_dataset_list = []
        num_target_models = configs["train"]["num_target_model"]

    # Prepare the datasets
    print(25 * ">" + "Prepare the the datasets")
    data_split_info = prepare_datasets(
        len(dataset), num_target_models, configs["data"]
    )

    logger.info(
        "Prepare the datasets costs %0.5f seconds", time.time() - baseline_time
    )

    # Prepare the target models
    print(25 * ">" + "Prepare the the target models")
    baseline_time = time.time()

    new_model_list, model_metadata_list, new_target_model_idx_list = prepare_models(
        log_dir,
        dataset,
        data_split_info,
        configs["train"],
        model_metadata_list,
        configs["data"]["dataset"],
    )

    # Combine the newly prepared models with the reloaded ones; the splits
    # and index lists are concatenated in the same order as the models.
    model_list = [*new_model_list, *trained_target_models_list]
    data_split_info["split"] = [
        *data_split_info["split"],
        *trained_target_dataset_list,
    ]
    target_model_idx_list = [*new_target_model_idx_list, *target_model_idx_list]

    logger.info(
        "Prepare the target model costs %0.5f seconds", time.time() - baseline_time
    )

    # Prepare the information sources
    print(25 * ">" + "Prepare the information source, including attack models")
    baseline_time = time.time()
    (
        target_info_source,
        reference_info_source,
        metrics,
        log_dir_list,
        model_metadata_list,
    ) = prepare_information_source(
        log_dir,
        dataset,
        data_split_info,
        model_list,
        configs["audit"],
        model_metadata_list,
        target_model_idx_list,
        configs["train"]["model_name"],
        configs["data"]["dataset"],
    )
    logger.info(
        "Prepare the information source costs %0.5f seconds",
        time.time() - baseline_time,
    )

    # Call core of Privacy Meter
    print(25 * ">" + "Auditing the privacy risk")
    baseline_time = time.time()
    audit_obj = Audit(
        metrics=metrics,
        inference_game_type=inference_game_type,
        target_info_sources=target_info_source,
        reference_info_sources=reference_info_source,
        fpr_tolerances=None,
        logs_directory_names=log_dir_list,
    )
    audit_obj.prepare()
    audit_results = audit_obj.run()
    logger.info(
        "Prepare the Privacy Meter result costs %0.5f seconds",
        time.time() - baseline_time,
    )

    # Generate the privacy risk report
    print(25 * ">" + "Generating privacy risk report")
    baseline_time = time.time()
    prepare_priavcy_risk_report(
        log_dir,
        audit_results,
        configs["audit"],
        save_path=f"{log_dir}/{configs['audit']['report_log']}",
    )
    print(100 * "#")

    logger.info(
        "Prepare the plot for the privacy risk report costs %0.5f seconds",
        time.time() - baseline_time,
    )

# Privacy Auditing for a model with reference_in_out attack

In [None]:
def _rio_model_signal(model, data, targets, configs, batch_size):
    """Wrap one model for Privacy Meter and return its signal on ``data``/``targets``."""
    model_pm = PytorchModelTensor(
        model_obj=model,
        loss_fn=nn.CrossEntropyLoss(),
        device=configs["audit"]["device"],
        batch_size=batch_size,
    )
    return get_signal_on_augmented_data(
        model_pm,
        data,
        targets,
        method=configs["audit"]["augmentation"],
        signal=configs["audit"]["signal"],
    )


def _rio_membership_scores(signals, keep_matrix, num_target, dataset_size, audit_configs):
    """Score every (target model, data point) pair against reference in/out signals.

    The first ``num_target`` rows of ``signals``/``keep_matrix`` belong to the
    target models; the remaining rows are reference models.

    Returns:
        tuple: ``(prediction, answers)`` -- flat arrays of per-point scores and
        the matching boolean labels taken from ``keep_matrix``'s target rows.
        A lower score means "more likely a member"; the caller negates it
        before ``roc_curve``.
    """
    target_signal = signals[:num_target, :]
    reference_signals = signals[num_target:, :]
    reference_keep_matrix = keep_matrix[num_target:, :]
    membership = keep_matrix[:num_target, :]

    # Per data point, split the reference signals by whether keep_matrix marks
    # the point as kept ("in") or not ("out") for each reference model.
    in_signals = []
    out_signals = []
    for data_idx in range(dataset_size):
        kept = reference_keep_matrix[:, data_idx]
        in_signals.append(reference_signals[kept, data_idx])
        out_signals.append(reference_signals[~kept, data_idx])

    # Truncate to a common length so the groups stack into rectangular arrays.
    in_size = min(min(map(len, in_signals)), audit_configs["num_in_models"])
    out_size = min(min(map(len, out_signals)), audit_configs["num_out_models"])
    in_signals = np.array([x[:in_size] for x in in_signals]).astype("float32")
    out_signals = np.array([x[:out_size] for x in out_signals]).astype("float32")
    # Location estimates use the median (kept from the original code).
    mean_in = np.median(in_signals, 1)
    mean_out = np.median(out_signals, 1)
    if audit_configs["fix_variance"]:
        std_in = np.std(in_signals)
        # NOTE(review): std_out is also estimated from in_signals, i.e. one
        # shared global variance; presumably mirrors the upstream LiRA code this
        # is adapted from -- confirm out_signals was not intended.
        std_out = np.std(in_signals)
    else:
        std_in = np.std(in_signals, 1)
        std_out = np.std(out_signals, 1)

    prediction = []
    answers = []
    for ans, sc in zip(membership, target_signal):
        pr_in = -norm.logpdf(sc, mean_in, std_in + 1e-30)
        pr_out = -norm.logpdf(sc, mean_out, std_out + 1e-30)
        score = pr_in - pr_out
        if score.ndim == 2:  # the score is of size (data_size, num_augments)
            score = score.mean(1)
        prediction.extend(score)
        answers.extend(ans)
    return np.array(prediction), np.array(answers, dtype=bool)


def audit_model_rio(configs, privacy_game, dataset, model_metadata_list, logger):
    """Run the reference in/out membership inference attack and report its ROC.

    The first ``configs["train"]["num_target_model"]`` models are the attack
    targets and the remaining ones are reference models. Each target signal is
    scored with ``_rio_membership_scores``: the difference of Gaussian negative
    log-likelihoods under the reference "in" and "out" groups. The function
    prints AUC, the best-threshold accuracy (max over thresholds of
    ``(TPR + 1 - FPR) / 2``) and TPR at 0.1% FPR. It saves the ROC curve to
    ``{log_dir}/{configs['audit']['report_log']}/ROC.png``.

    Args:
        configs (dict): Parsed configuration with ``run``, ``data``, ``train``
            and ``audit`` sections.
        privacy_game (str): Configured privacy game; not used here (kept for
            signature parity with ``audit_model``).
        dataset: The full dataset.
        model_metadata_list (dict): Model bookkeeping. When ``"current_idx"`` is
            0, models are built with ``prepare_models``. Otherwise the first
            ``current_idx`` saved models are reloaded.
        logger (logging.Logger): Receives the timing messages.

    Raises:
        ValueError: If ``configs["audit"]["algorithm"]`` does not contain
            ``"reference_in_out"``.
    """
    if "reference_in_out" not in configs["audit"]["algorithm"]:
        raise ValueError(
            "reference_in_out should be present in configs['audit']['algorithm']"
        )

    # Read from the config instead of the module-level global set in __main__.
    log_dir = configs["run"]["log_dir"]
    # Same batch size for freshly prepared and reloaded models.
    audit_batch_size = configs["audit"]["audit_batch_size"]

    # The following code of generating the data is modified from the original code in the repo: https://github.com/tensorflow/privacy/tree/master/research/mi_lira_2021
    baseline_time = time.time()
    p_ratio = configs["data"]["keep_ratio"]
    dataset_size = configs["data"]["dataset_size"]
    number_of_models_total = (
        configs["audit"]["num_in_models"]
        + configs["audit"]["num_out_models"]
        + configs["train"]["num_target_model"]
    )
    # NOTE(review): when models are reloaded below, membership labels still come
    # from this freshly generated keep_matrix; this assumes the split is
    # reproducible (seeds are set in the __main__ cell) -- confirm.
    (
        data_split_info,
        keep_matrix,
        target_data_index,
    ) = prepare_datasets_for_reference_in_attack(
        len(dataset),
        dataset_size,
        num_models=number_of_models_total,
        keep_ratio=p_ratio,
        is_uniform=False,
    )
    data, targets = get_dataset_subset(
        dataset,
        target_data_index,
        configs["train"]["model_name"],
        device=configs["train"]["device"],
    )  # only the train dataset we want to attack
    logger.info(
        "Preparing the datasets costs %0.5f seconds",
        time.time() - baseline_time,
    )

    baseline_time = time.time()
    if model_metadata_list["current_idx"] == 0:
        # No models recorded yet: build them for the split above via prepare_models.
        model_list, _, _ = prepare_models(
            log_dir,
            dataset,
            data_split_info,
            configs["train"],
            model_metadata_list,
            configs["data"]["dataset"],
        )
        logger.info(
            "Preparing the models costs %0.5f seconds",
            time.time() - baseline_time,
        )
        baseline_time = time.time()
        signals = [
            _rio_model_signal(model, data, targets, configs, audit_batch_size)
            for model in model_list
        ]
    else:
        # Models were saved by an earlier run: reload them one at a time.
        signals = []
        for idx in range(model_metadata_list["current_idx"]):
            print("Load the model and compute signals for model %d" % idx)
            model = load_existing_models(
                model_metadata_list,
                [idx],
                configs["train"]["model_name"],
                dataset,
                configs["data"]["dataset"],
            )[0]
            signals.append(
                _rio_model_signal(model, data, targets, configs, audit_batch_size)
            )
    logger.info(
        "Preparing the signals costs %0.5f seconds",
        time.time() - baseline_time,
    )

    baseline_time = time.time()
    signals = np.array(signals)
    # number of models we want to consider as test
    num_target = configs["train"]["num_target_model"]
    prediction, answers = _rio_membership_scores(
        signals, keep_matrix, num_target, dataset_size, configs["audit"]
    )

    # Pool all target models into one ROC; negate so higher = more likely member.
    fpr_list, tpr_list, _ = roc_curve(answers.ravel(), -prediction.ravel())
    acc = np.max(1 - (fpr_list + (1 - tpr_list)) / 2)
    roc_auc = auc(fpr_list, tpr_list)
    logger.info(
        "Preparing the privacy risks results costs %0.5f seconds",
        time.time() - baseline_time,
    )
    low = tpr_list[np.where(fpr_list < 0.001)[0][-1]]
    print("AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f" % (roc_auc, acc, low))
    plot_roc(
        fpr_list,
        tpr_list,
        roc_auc,
        f"{log_dir}/{configs['audit']['report_log']}/ROC.png",
    )