### Imports

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn

import analysis
import utils
from active_learning import compute_utility_scores_entropy
from active_learning import compute_utility_scores_gap
from active_learning import compute_utility_scores_greedy
from architectures.densenet_pre import densenetpre
from architectures.resnet_pre import resnetpre
from architectures.utils_architectures import pytorch2pickle


from datasets.utils import get_dataset_full_name
from datasets.utils import set_dataset
from datasets.utils import show_dataset_stats
from datasets.xray.xray_datasets import get_votes_only_for_dataset
from errors import check_perfect_balance_type
#from models.add_tau_per_model import set_taus
from models.big_ensemble_model import BigEnsembleModel
#from models.ensemble_model import EnsembleModel
from models.load_models import load_private_model_by_id
from models.load_models import load_private_models
from models.private_model import get_private_model_by_id
from models.utils_models import get_model_name_by_id
from models.utils_models import model_size
from params import get_parameters
from utils import eval_distributed_model
from utils import eval_model
from utils import from_result_to_str
from utils import get_unlabeled_indices
from utils import get_unlabeled_set
from utils import metric
from utils import result
from utils import train_model
from utils import update_summary
from utils import pick_labels_general
#from virtual_parties import query_ensemble_model_with_virtual_parties

In [2]:
# Load the program parameters (get_parameters is imported from params.py);
# every later cell reads its configuration from this `args` object.
args = get_parameters()

###################################################################
args: 
DPSGD :  False
DPSGD_BATCH_SIZE :  2
DPSGD_CCLIP :  0
DPSGD_EPOCHS :  10
DPSGD_LR :  0.001
DPSGD_NOISE_MULTIPLIER :  1.3
DPSGD_PASCAL_PATH :  /VOC2012/
adam_amsgrad :  False
adaptive_batch_size :  5
apply_data_independent_bound :  True
architecture :  MnistNetPate
architectures :  ['MnistNetPate']
attacker_dataset :  None
balance_type :  standard
batch_size :  64
begin_id :  0
bins_confidence :  10
budget :  2.5
budgets :  [2.5]
chexpert_dataset_type :  pos
class_type :  multiclass
coco_additional_datasets :  []
coco_data_loader :  custom
coco_datasets :  ['train', 'val']
coco_image_size :  448
coco_version :  2017
commands :  ['query_ensemble_model']
count_noise :  bounded
cuda :  True
cwd :  c:\Users\Ahmad\Desktop\UofT EngSci Year 2\Research\CleverHans\PATE Demo
data_aug :  True
data_aug_rot :  45
data_aug_scale :  0.15
data_aug_trans :  0.15
data_dir :  /home/nicolas/data
dataset :  mnist
dataset_type :  bala

###  Private Training / Model Loading

We use 250 models trained on disjoint partitions of the training set. We provide the code for loading our trained models (downloadable from Google Drive) below and the code to evaluate the performance of the models on a test set. The parameters are set in a separate file (params.py) and can be modified.

In [3]:
# Folder locations used by the rest of the notebook.
private_model_path = "250-models"  # Set path to a folder containing the downloaded models
ensemble_model_path = "ensemble-models"  # Create and select a separate folder.
# Derived from the variable above so the two locations cannot drift apart.
args.ensemble_model_path = ensemble_model_path
retrained_model_path = "retrained-private-models"  # Create and set path to a folder to store the student model

In [10]:
# Evaluate every private (teacher) model on the evaluation loader and print
# its mean per-class accuracy.
start_time = time.time()

if args.num_querying_parties > 0:
    # Checks
    assert 0 <= args.begin_id
    assert args.begin_id < args.end_id
    assert args.end_id <= args.num_models
    args.querying_parties = range(args.begin_id, args.end_id, 1)
else:
    other_querying_party = -1
    assert args.num_querying_parties == other_querying_party
    args.querying_parties = args.querying_party_ids


# Only the 'private' model type is handled by this cell; anything else
# (e.g. 'retrained') is rejected below.
test_type = args.test_models_type
if test_type == "private":
    args.save_model_path = private_model_path
else:
    raise Exception(f"Unknown test_type: {test_type}")

evalloader = utils.load_unlabeled_dataloader(args=args)
# This loads 1000 examples from the MNIST test set.
# evalloader = utils.load_private_data(args=args)[0]

if args.debug is True:
    # Logs about the eval set.
    # NOTE(review): `file` is only assigned by the later query-answer cell, so
    # on a fresh kernel this branch raises NameError when args.debug is True.
    show_dataset_stats(
        dataset=evalloader.dataset, args=args, file=file,
        dataset_name="eval"
    )

# Per-metric accumulators, filled by update_summary for each model.
summary = {
    metric.loss: [],
    metric.acc: [],
    metric.balanced_acc: [],
    metric.auc: [],
    metric.map: [],
}
# Iterate over the configured number of models instead of a hard-coded 250,
# and use names that do not shadow the builtin `id` or the `result` name
# imported from utils.
for model_id in range(args.num_models):

    model = load_private_model_by_id(
        args=args, id=model_id, model_path=args.save_model_path
    )

    eval_result = eval_distributed_model(
        model=model, dataloader=evalloader, args=args)

    model_name = get_model_name_by_id(id=model_id)
    eval_result["model_name"] = model_name
    result_str = from_result_to_str(result=eval_result, sep="\n",
                                    inner_sep=args.sep)
    acc_detailed = eval_result.get(metric.acc_detailed, None)
    if acc_detailed is not None and len(acc_detailed) > 0:
        print(f"Accuracy of model {model_id}",
              sum(acc_detailed) / len(acc_detailed))
    else:
        # The detailed accuracy is missing or empty: report that instead of
        # failing on sum()/len().
        print(f"Accuracy of model {model_id}", "N/A")
    summary = update_summary(summary=summary, result=eval_result)



end_time = time.time()
elapsed_time = end_time - start_time


Accuracy of model 0 83.93798593595977
Accuracy of model 1 86.60838342440223
Accuracy of model 2 88.27946047186852
Accuracy of model 3 87.06090583075586
Accuracy of model 4 89.28529309763621
Accuracy of model 5 88.33479738568906
Accuracy of model 6 88.1626867504269
Accuracy of model 7 88.22563035656923
Accuracy of model 8 87.1157449500574
Accuracy of model 9 89.90879986673932
Accuracy of model 10 87.61553046672798
Accuracy of model 11 88.2154668241559
Accuracy of model 12 87.77934165144396
Accuracy of model 13 88.6107888535266
Accuracy of model 14 87.18472064944191
Accuracy of model 15 88.58194168262212
Accuracy of model 16 85.72795127370837
Accuracy of model 17 87.29744375262253
Accuracy of model 18 88.78757984199027
Accuracy of model 19 88.59012390242415
Accuracy of model 20 85.89302006459965
Accuracy of model 21 89.33770370611629
Accuracy of model 22 90.24157093580968
Accuracy of model 23 89.84666418814828
Accuracy of model 24 87.34175959196948
Accuracy of model 25 86.2015182124159


KeyboardInterrupt: 

### Prediction

The prediction phase of PATE consists of several steps. The first is inference on the private models, where the returned logits are turned into one-hot vector encodings; this happens in the inference() function of the class below. Next, the votes are aggregated into a histogram, which takes place in the query() function. This function also contains the noisy argmax mechanism, which produces the final prediction of the ensemble of models; in this case it is stored in the preds variable.

In [4]:
class EnsembleModel(nn.Module):
    """
    Noisy ensemble of private models.
    All the models for the ensemble are pre-cached in memory.
    """

    def __init__(self, model_id: int, private_models, args):
        """

        :param model_id: id of the model (-1 denotes all private models).
        :param private_models: list of private models
        :param args: program parameters
        """
        super(EnsembleModel, self).__init__()
        self.id = model_id
        if self.id == -1:
            self.name = "ensemble(all)"
        else:
            # This is ensemble for private model_id.
            self.name = get_model_name_by_id(id=model_id)
        self.num_classes = args.num_classes
        print("Building ensemble model '{}'!".format(self.name))
        self.ensemble = private_models

    def __len__(self):
        """Return the number of models in the ensemble."""
        return len(self.ensemble)

    def _raw_votes(self, data, args):
        """
        Sum the votes of all ensemble members for one batch.

        :param data: batch of inputs, already on the device the models use
        :param args: program parameters; args.vote_type selects 'discrete'
            (one-hot of the argmax) or 'probability' (softmax) votes
        :return: CPU tensor of shape (batch_size, num_classes)
        :raises Exception: if args.vote_type is neither of the two modes
        """
        votes = torch.zeros((data.shape[0], self.num_classes))
        for model in self.ensemble:
            output = model(data)
            if args.vote_type == 'discrete':
                label = output.argmax(dim=1).cpu()
                model_votes = utils.one_hot(label, self.num_classes)
            elif args.vote_type == 'probability':
                model_votes = F.softmax(output, dim=1).cpu()
            else:
                raise Exception(
                    f"Unknown args.vote_type: {args.vote_type}.")
            votes += model_votes
        return votes

    def inference(self, unlabeled_dataloader, args):
        """
        Generate and save the raw (noise-free) ensemble votes.

        :param unlabeled_dataloader: loader yielding (data, _) batches
        :param args: program parameters (cuda, vote_type, mode,
            ensemble_model_path are read here)
        :return: numpy array of shape (len(dataset), num_classes)
        """
        all_votes = []
        with torch.no_grad():
            for data, _ in unlabeled_dataloader:
                if args.cuda:
                    data = data.cuda()
                # Generate raw ensemble votes.
                all_votes.append(self._raw_votes(data, args).numpy())

        all_votes = np.concatenate(all_votes, axis=0)
        assert all_votes.shape == (
            len(unlabeled_dataloader.dataset), self.num_classes)
        if args.vote_type == 'discrete':
            # Each model casts exactly one vote per sample.
            assert np.all(all_votes.sum(axis=-1) == len(self.ensemble))
        filename = '{}-raw-votes-mode-{}-vote-type-{}'.format(
            self.name, args.mode, args.vote_type)
        filepath = os.path.join(args.ensemble_model_path, filename)
        np.save(filepath, all_votes)
        return all_votes

    def inference_confidence_scores(self, unlabeled_dataloader, args):
        """Generate raw softmax confidence scores for RDP analysis_test.

        :param unlabeled_dataloader: loader yielding (data, _) batches
        :param args: program parameters
        :return: tensor of shape (num_models, len(dataset), num_classes)
        """
        dataset = unlabeled_dataloader.dataset
        dataset_len = len(dataset)
        num_models = len(self.ensemble)
        confidence_scores = torch.zeros(
            (num_models, dataset_len, self.num_classes))
        end = 0
        with torch.no_grad():
            for data, _ in unlabeled_dataloader:
                if args.cuda:
                    data = data.cuda()
                # Rows [begin, end) of the output belong to this batch.
                batch_size = data.shape[0]
                begin = end
                end = begin + batch_size
                for model_idx, model in enumerate(self.ensemble):
                    output = model(data)
                    softmax_scores = F.softmax(output, dim=1).cpu()
                    confidence_scores[model_idx, begin:end, :] = softmax_scores

        # NOTE(review): this is the same file name that inference() writes to.
        filename = '{}-raw-votes-mode-{}-vote-type-{}'.format(
            self.name, args.mode, args.vote_type)
        filepath = os.path.join(args.ensemble_model_path, filename)
        np.save(filepath, confidence_scores)
        return confidence_scores

    def query(self, queryloader, args, indices_queried, targets=None):
        """Query a noisy ensemble model.

        :param queryloader: loader yielding (data, target) batches, in the
            same order as indices_queried
        :param args: program parameters (sigma_threshold, threshold,
            sigma_gnmax, vote_type, cuda, num_classes are read here)
        :param indices_queried: dataset indices of the queried samples
        :param targets: unused
        :return: tuple (indices_answered, accuracy in percent, per-class
            accuracy in percent, mean top-2 vote gap, per-class mean gap)
        :raises Exception: if args.vote_type is unknown
        """
        indices_queried = np.array(indices_queried)
        indices_answered = []
        all_preds = []
        all_labels = []
        gaps_detailed = np.zeros(args.num_classes, dtype=np.float64)
        correct = np.zeros(args.num_classes, dtype=np.int64)
        wrong = np.zeros(args.num_classes, dtype=np.int64)
        with torch.no_grad():
            begin = 0
            end = 0
            for data, target in queryloader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                num_samples = data.shape[0]
                end += num_samples
                # Generate raw ensemble votes
                votes = self._raw_votes(data, args)

                # Threshold mechanism
                if args.sigma_threshold > 0:
                    noise_threshold = np.random.normal(0., args.sigma_threshold,
                                                       num_samples)
                    vote_counts = votes.data.max(dim=1)[0].numpy()
                    answered = (vote_counts + noise_threshold) > args.threshold
                    indices_answered.append(
                        indices_queried[begin:end][answered])
                else:
                    # No thresholding: every query is answered.
                    answered = [True for _ in range(num_samples)]
                    indices_answered.append(indices_queried[begin:end])

                # GNMax mechanism
                assert args.sigma_gnmax > 0
                noise_gnmax = np.random.normal(0., args.sigma_gnmax, (
                    data.shape[0], self.num_classes))
                preds = \
                    (votes + torch.from_numpy(noise_gnmax).float()).max(dim=1)[
                        1].numpy().astype(np.int64)[answered]
                all_preds.append(preds)
                # Gap between the ensemble votes of the two most probable
                # classes.
                sorted_votes = votes.sort(dim=-1, descending=True)[0]
                gaps = (sorted_votes[:, 0] - sorted_votes[:, 1]).numpy()[
                    answered]
                # Target labels
                target = target.data.cpu().numpy().astype(np.int64)[answered]
                all_labels.append(target)
                assert len(target) == len(preds) == len(gaps)
                for label, pred, gap in zip(target, preds, gaps):
                    gaps_detailed[label] += gap
                    if label == pred:
                        correct[label] += 1
                    else:
                        wrong[label] += 1
                begin += data.shape[0]
        indices_answered = np.concatenate(indices_answered, axis=0)
        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        total = correct.sum() + wrong.sum()
        assert len(indices_answered) == len(all_preds) == len(
            all_labels) == total
        filename = utils.get_aggregated_labels_filename(
            args=args, name=self.name)
        filepath = os.path.join(args.ensemble_model_path, filename)
        np.save(filepath, all_preds)
        return indices_answered, 100. * correct.sum() / total, 100. * correct / (
                correct + wrong), gaps_detailed.sum() / total, gaps_detailed / (
                       correct + wrong)

### Privacy Cost Computation

The RDP privacy cost of the queries is computed using the following function (or a slightly modified version depending on the thresholding mechanism used). 

In [12]:
def analyze_multiclass_gnmax(
        votes, threshold, sigma_threshold, sigma_gnmax, budget, delta,
        file=None, show_dp_budget='disable', args=None):
    """
    Analyze how the pre-defined privacy budget will be exhausted when answering
    queries using the gaussian noisy max algorithm but without the
    thresholding mechanism.

    The helpers compute_logq_gnmax, compute_rdp_data_dependent_gnmax,
    compute_rdp_data_dependent_gnmax_no_upper_bound and rdp_to_dp must be in
    scope when this function is called; they are not defined in this cell.

    Args:
        votes: a 2-D numpy array of raw ensemble votes, with each row
        corresponding to a query.
        threshold: not used but for compatibility with confident gnmax it
            is here
        sigma_threshold: not used but for compatibility is here
        sigma_gnmax: std of the Gaussian noise for the DP mechanism.
        budget: pre-defined epsilon value for (eps, delta)-DP.
        delta: pre-defined delta value for (eps, delta)-DP.
        file: for logs; when not None a summary row is appended to
            'privacy_budget_analysis.csv'.
        show_dp_budget: show the current cumulative dp budget; 'apply' appends
            one row per query to 'queries_answered_privacy_budget.txt'.
        args: all args of the program; args.apply_data_independent_bound is
            read for every query, so args must be provided whenever votes is
            non-empty.

    Returns:
        max_num_query: when the pre-defined privacy budget is exhausted.
        dp_eps: a numpy array of length L = num-queries, with each entry corresponding
            to the privacy cost at a specific moment.
        partition: a list of length L = num-queries, with each entry corresponding
            to the partition of privacy cost at a specific moment.
        answered: a list [1, ..., max_num_query], i.e. the number of answered
            queries at a specific moment.
        order_opt: a numpy array of length L = num-queries, with each entry corresponding
            to the order minimizing the privacy cost at a specific moment.
    """
    # Imported locally because `math` is not among the notebook's imports.
    import math

    max_num_query = 0

    def compute_partition(order_opt, eps):
        """Analyze how the current privacy cost is divided."""
        idx = np.searchsorted(orders, order_opt)
        rdp_eps_gnmax = rdp_eps_total_curr[idx]
        p = np.array([rdp_eps_gnmax, -math.log(delta) / (order_opt - 1)])
        # assert sum(p) == eps
        # Normalize p so that sum(p) = 1
        return p / eps

    # RDP orders.
    orders = np.concatenate((np.arange(2, 100, .5),
                             np.logspace(np.log10(100), np.log10(1000),
                                         num=200)))
    # Number of queries
    n = votes.shape[0]

    # All cumulative results
    dp_eps = np.zeros(n)
    partition = [None] * n
    order_opt = np.full(n, np.nan, dtype=float)

    # Current cumulative results
    rdp_eps_total_curr = np.zeros(len(orders))
    # Iterating over all queries
    for i in range(n):
        v = votes[i]
        logq = compute_logq_gnmax(v, sigma_gnmax)
        if args.apply_data_independent_bound:
            rdp_eps_gnmax = compute_rdp_data_dependent_gnmax(
                logq, sigma_gnmax, orders)
        else:
            rdp_eps_gnmax = compute_rdp_data_dependent_gnmax_no_upper_bound(
                logq, sigma_gnmax, orders)

        # Update current cumulative results.
        rdp_eps_total_curr += rdp_eps_gnmax
        # Update all cumulative results.
        dp_eps[i], order_opt[i] = rdp_to_dp(orders, rdp_eps_total_curr, delta)
        partition[i] = compute_partition(order_opt[i], dp_eps[i])
        # Verify if the pre-defined privacy budget is exhausted.
        if dp_eps[i] <= budget:
            max_num_query = i + 1
        else:
            break
        # Logs
        # if i % 100000 == 0 and i > 0:
        if show_dp_budget == 'apply':
            raw_file = 'queries_answered_privacy_budget.txt'
            with open(raw_file, 'a+') as writer:
                if i == 0:
                    header = "queries answered,privacy budget"
                    writer.write(f"{header}\n")
                    writer.write("0,0\n")
                info = f"{i + 1},{dp_eps[i]}"
                writer.write(f"{info}\n")

    # With no queries there is no last entry to report, so skip the summary
    # row instead of indexing an empty array.
    # NOTE(review): when the loop stops early, dp_eps[n - 1] was never
    # computed and is still 0 - confirm this is the value meant to be logged.
    if file is not None and n > 0:
        with open('privacy_budget_analysis.csv', 'a+') as writer:
            info = f"{n},{dp_eps[n - 1]}"
            writer.write(f"{info}\n")

    # print(f"{threshold},{sigma_threshold},{sigma_gnmax}")
    # analyze_results(votes=votes, max_num_query=max_num_query, dp_eps=dp_eps)
    # answered is the probability of a given label being answered. For the GNMax
    # without the confidence (no thresholding mechanism) each
    # label < max_num_query is answered.
    # answered = np.zeros(n, dtype=float)
    # answered[0:max_num_query] = 1
    answered = [x for x in range(1, max_num_query + 1)]
    return max_num_query, dp_eps, partition, answered, order_opt

These steps are all combined into a query answer process in the code below. A part of the MNIST test set (distinct from the part used to evaluate the models) is used to select query samples which are then queried using the Noisy Argmax mechanism as described in more detail above. A dataset is then created consisting of the query answer pairs (stored in unlabeled_dataloaders). 

In [5]:
# Open the log file for the query-answer process and write the run header.
# file_name = r"logs-(num-models:{})-(num-query-parties:{})-(query-mode:{})-(threshold:{:.1f})-(sigma-gnmax:{:.1f})-(sigma-threshold:{:.1f})-(budget:{:.2f}).txt".format(
#     args.num_models,
#     args.num_querying_parties,
#     args.mode,
#     args.threshold,
#     args.sigma_gnmax,
#     args.sigma_threshold,
#     args.budget,
# )
file_name = "logs.txt"
print("ensemble_model_path: ", ensemble_model_path)
print("file_name: ", file_name)
# Create the output folder if it is missing so that open() below cannot fail
# with FileNotFoundError on a fresh checkout.
os.makedirs(ensemble_model_path, exist_ok=True)
# The handle stays open for the rest of this cell and is closed at its end.
file = open(os.path.join(ensemble_model_path, file_name), "w")
args.save_model_path = ensemble_model_path
utils.augmented_print("##########################################", file)
utils.augmented_print(
    "Query-answer process on '{}' dataset!".format(args.dataset), file
)
utils.augmented_print(
    "Number of private models: {:d}".format(args.num_models), file
)
utils.augmented_print(
    "Number of querying parties: {:d}".format(args.num_querying_parties),
    file
)
utils.augmented_print("Querying mode: {}".format(args.mode), file)
utils.augmented_print("Confidence threshold: {:.1f}".format(args.threshold),
                      file)
utils.augmented_print(
    "Standard deviation of the Gaussian noise in the GNMax mechanism: {:.1f}".format(
        args.sigma_gnmax
    ),
    file,
)
utils.augmented_print(
    "Standard deviation of the Gaussian noise in the threshold mechanism: {:.1f}".format(
        args.sigma_threshold
    ),
    file,
)
utils.augmented_print(
    "Pre-defined privacy budget: ({:.2f}, {:.0e})-DP".format(
        args.budget, args.delta
    ),
    file,
)
utils.augmented_print("##########################################", file)

# Load the private models and build, for every querying party, the ensemble
# of answering parties it will query.
model_path = private_model_path
private_models = load_private_models(args=args, model_path=model_path)
# Querying parties
prev_num_models = args.num_models


parties_q = private_models[: args.num_querying_parties]
args.querying_parties = parties_q

# Answering parties.
parties_a = []
for i in range(args.num_querying_parties):
    # For a given querying party, skip this very querying party as its
    # own answering party.
    if args.test_virtual is True:
        # Virtual party i owns the contiguous slice [start, end) of
        # num_private models. The previous `start + (i + 1) * num_private`
        # end bound removed models of later parties as well for i >= 1.
        num_private = len(private_models) // args.num_querying_parties
        start = i * num_private
        end = start + num_private
        private_subset = private_models[0:start] + private_models[end:]
    else:
        private_subset = private_models[:i] + private_models[i + 1:]

    ensemble_model = EnsembleModel(
        model_id=i, private_models=private_subset, args=args
    )
    parties_a.append(ensemble_model)

# Choose the pool of unlabeled samples and decide, per querying party, the
# order in which the samples will be submitted as queries.
utils.augmented_print(
    "##########################################", file, flush=True
)
if not args.attacker_dataset:
    unlabeled_dataset = utils.get_unlabeled_set(args=args)
else:
    unlabeled_dataset = utils.get_attacker_dataset(
        args=args, dataset_name=args.attacker_dataset
    )
    print("attacker uses {} dataset".format(args.attacker_dataset))

if args.mode != "random":
    unlabeled_dataloaders = utils.load_unlabeled_dataloaders(
        args=args, unlabeled_dataset=unlabeled_dataset
    )
    utility_scores = []

    # Map each active-learning mode to its scoring function.
    utility_functions_by_mode = {
        "entropy": compute_utility_scores_entropy,
        "gap": compute_utility_scores_gap,
        "greedy": compute_utility_scores_greedy,
    }
    if args.mode not in utility_functions_by_mode:
        raise Exception(f"Unknown query selection mode: {args.mode}.")
    utility_function = utility_functions_by_mode[args.mode]

    for i in range(args.num_querying_parties):
        querying_party = parties_q[i]
        filename = "{}-utility-scores-(mode-{})-dataset-{}.npy".format(
            querying_party.name, args.mode, args.dataset
        )
        filepath = os.path.join(ensemble_model_path, filename)
        # Cached scores are only reused in debug mode.
        use_cached = os.path.isfile(filepath) and args.debug is True
        if use_cached:
            utils.augmented_print(
                "Loading utility scores for '{}' in '{}' mode!".format(
                    querying_party.name, args.mode
                ),
                file,
            )
            utility = np.load(filepath)
        else:
            utils.augmented_print(
                "Computing utility scores for '{}' in '{}' mode!".format(
                    querying_party.name, args.mode
                ),
                file,
            )
            utility = utility_function(
                model=querying_party, dataloader=unlabeled_dataloaders[i],
                args=args
            )
        utility_scores.append(utility)

    # Highest utility first; each party's positions are shifted by an offset
    # of i * (num_unlabeled_samples // num_querying_parties).
    all_indices = []
    for i in range(args.num_querying_parties):
        samples_per_party = (
            args.num_unlabeled_samples // args.num_querying_parties)
        offset = i * samples_per_party
        descending_order = utility_scores[i].argsort()[::-1]
        indices = descending_order + offset
        all_indices.append(indices)
        assert len(indices) == len(set(indices))
    if not args.attacker_dataset:
        # this assertion seems only fails in entropy mode when using a different attacker dataset, is this okay?
        merged_indices = np.concatenate(all_indices, axis=0)
        assert len(set(merged_indices)) == args.num_unlabeled_samples
else:
    # Random mode: no scoring, the helper supplies the indices directly.
    all_indices = get_unlabeled_indices(args=args,
                                        dataset=unlabeled_dataset)

utils.augmented_print(
    "##########################################", file, flush=True
)
utils.augmented_print(
    "Select queries according to their utility scores subject to the pre-defined privacy budget",
    file,
    flush=True,
)

# For every querying party: obtain the raw votes of its answering ensemble,
# run the privacy analysis to find how many queries fit into the budget,
# answer that many queries with the noisy ensemble and log the statistics.
for i in range(args.num_querying_parties):
    # Raw ensemble votes
    if args.attacker_dataset is None:
        attacker_dataset = ""
    else:
        attacker_dataset = args.attacker_dataset
    filename = "{}-raw-votes-(mode-{})-dataset-{}-attacker-{}.npy".format(
        parties_a[i].name, args.mode, args.dataset, attacker_dataset
    )
    filepath = os.path.join(ensemble_model_path, filename)
    utils.augmented_print(f"filepath: {filepath}", file=file)
    # Previously saved votes are only reused in debug mode.
    if os.path.isfile(filepath) and args.debug is True:
        utils.augmented_print(
            "Loading raw ensemble votes for '{}' in '{}' mode!".format(
                parties_a[i].name, args.mode
            ),
            file,
        )
        votes = np.load(filepath)
    else:
        utils.augmented_print(
            "Generating raw ensemble votes for '{}' in '{}' mode!".format(
                parties_a[i].name, args.mode
            ),
            file,
        )
        # Load unlabeled data according to a specific order
        unlabeled_dataloader_ordered = utils.load_ordered_unlabeled_data(
            args, all_indices[i], unlabeled_dataset=unlabeled_dataset
        )
        if args.vote_type == "confidence_scores":
            votes = parties_a[i].inference_confidence_scores(
                unlabeled_dataloader_ordered, args
            )
        else:
            votes = parties_a[i].inference(unlabeled_dataloader_ordered,
                                           args)
        np.save(file=filepath, arr=votes)

    # Analyze how the pre-defined privacy budget will be exhausted when
    # answering queries.
    (
        max_num_query,
        dp_eps,
        partition,
        answered,
        order_opt,
    ) = analysis.analyze_privacy(votes=votes, args=args, file=file)

    utils.augmented_print("Querying party: {}".format(parties_q[i].name),
                          file)
    utils.augmented_print(
        "Maximum number of queries: {}".format(max_num_query), file
    )
    # NOTE(review): the three reports below index with max_num_query - 1; if
    # max_num_query is 0 that is index -1 - confirm this case cannot occur.
    utils.augmented_print(
        "Privacy guarantee achieved: ({:.4f}, {:.0e})-DP".format(
            dp_eps[max_num_query - 1], args.delta
        ),
        file,
    )
    utils.augmented_print(
        "Expected number of queries answered: {:.3f}".format(
            answered[max_num_query - 1]
        ),
        file,
    )
    utils.augmented_print(
        "Partition of privacy cost: {}".format(
            np.array2string(
                partition[max_num_query - 1], precision=3, separator=", "
            )
        ),
        file,
    )

    utils.augmented_print(
        "##########################################", file, flush=True
    )
    utils.augmented_print("Generate query-answer pairs.", file)
    # Only the first max_num_query samples of this party's ordering are sent
    # to the noisy ensemble.
    indices_queried = all_indices[i][:max_num_query]
    queryloader = utils.load_ordered_unlabeled_data(
        args=args, indices=indices_queried,
        unlabeled_dataset=unlabeled_dataset
    )
    indices_answered, acc, acc_detailed, gap, gap_detailed = parties_a[
        i].query(
        queryloader, args, indices_queried
    )
    utils.save_raw_queries_targets(
        args=args,
        indices=indices_answered,
        dataset=unlabeled_dataset,
        name=parties_q[i].name,
    )
    utils.augmented_print("Accuracy on queries: {:.2f}%".format(acc), file)
    utils.augmented_print(
        "Detailed accuracy on queries: {}".format(
            np.array2string(acc_detailed, precision=2, separator=", ")
        ),
        file,
    )
    # The gap is also reported as a percentage of the ensemble size.
    utils.augmented_print(
        "Gap on queries: {:.2f}% ({:.2f}|{:d})".format(
            100.0 * gap / len(parties_a[i].ensemble),
            gap,
            len(parties_a[i].ensemble),
        ),
        file,
    )
    utils.augmented_print(
        "Detailed gap on queries: {}".format(
            np.array2string(gap_detailed, precision=2, separator=", ")
        ),
        file,
    )

    utils.augmented_print(
        "##########################################", file, flush=True
    )
    utils.augmented_print("Check query-answer pairs.", file)
    # Reload only the answered samples to report their class statistics.
    queryloader = utils.load_ordered_unlabeled_data(
        args=args, indices=indices_answered,
        unlabeled_dataset=unlabeled_dataset
    )
    counts, ratios = utils.class_ratio(queryloader.dataset, args)
    utils.augmented_print(
        "Label counts: {}".format(np.array2string(counts, separator=", ")),
        file
    )
    utils.augmented_print(
        "Class ratios: {}".format(
            np.array2string(ratios, precision=2, separator=", ")
        ),
        file,
    )
    utils.augmented_print(
        "Number of samples: {:d}".format(len(queryloader.dataset)), file
    )
    utils.augmented_print(
        "##########################################", file, flush=True
    )
# Close the log file opened at the top of this cell.
file.close()
# Restore the model count that was saved before the parties were built.
args.num_models = prev_num_models

ensemble_model_path:  ensemble-models
file_name:  logs.txt
##########################################
Query-answer process on 'mnist' dataset!
Number of private models: 250
Number of querying parties: 1
Querying mode: random
Confidence threshold: 200.0
Standard deviation of the Gaussian noise in the GNMax mechanism: 40.0
Standard deviation of the Gaussian noise in the threshold mechanism: 150.0
Pre-defined privacy budget: (2.50, 1e-05)-DP
##########################################
Building ensemble model 'model(1)'!
##########################################
##########################################
Select queries according to their utility scores subject to the pre-defined privacy budget
filepath: ensemble-models\model(1)-raw-votes-(mode-random)-dataset-mnist-attacker-.npy
Generating raw ensemble votes for 'model(1)' in 'random' mode!
queries answered,privacy budget
0.6280399927976317,0.1386180709072033
Number of queries: 1 | E[answered]: 0.628 | E[eps] at order 168.318: 0.1386 (cont

### Student Training 

Once the query answer process has been completed using the Noisy GNMax mechanism with the ensemble of 250 models, a student model is trained on the labels returned by the ensemble model. The training process used is standard and the same as the process to train the initial private models.  

In [6]:
# Validate the id range of the models to retrain. The upper bound against
# args.num_models was missing (the expression ended in a bare
# `and args.end_id`); it now matches the checks of the evaluation cell.
assert 0 <= args.begin_id < args.end_id <= args.num_models

if args.num_querying_parties > 0:
    args.querying_parties = range(args.begin_id, args.end_id, 1)
else:
    # -1 is the only other accepted value; the party ids are then taken
    # from args.querying_party_ids.
    other_querying_party = -1
    assert args.num_querying_parties == other_querying_party
    args.querying_parties = args.querying_party_ids

# Logs
# filename = "logs-(num_models:{:d})-(id:{:d}-{:d})-(num-epochs:{:d})-(budget:{:f})-(dataset:{})-(architecture:{}).txt".format(
#     args.num_models,
#     args.begin_id + 1,
#     args.end_id,
#     args.num_epochs,
#     args.budget,
#     args.dataset,
#     args.architecture,
# )
filename = "logs"
print("filename: ", filename)
# Create the output folder if it is missing so that open() below cannot fail
# with FileNotFoundError on a fresh checkout.
os.makedirs(retrained_model_path, exist_ok=True)
file = open(os.path.join(retrained_model_path, filename), "w")
args.save_model_path = retrained_model_path
utils.augmented_print("##########################################", file)
utils.augmented_print(
    "Retraining the private models of all querying parties on '{}' dataset!".format(
        args.dataset
    ),
    file,
)
utils.augmented_print(
    "Number of querying parties: {:d}".format(len(args.querying_parties)),
    file
)
utils.augmented_print("Initial learning rate: {:.2f}".format(args.lr), file)
utils.augmented_print(
    "Number of epochs for retraining each model: {:d}".format(
        args.num_epochs), file
)
# Virtual-party setup: the number of models is temporarily replaced by the
# number of querying parties before the private data is loaded.
if args.test_virtual:
    assert args.num_querying_parties > 0
    prev_num_models = args.num_models
    args.num_models = args.num_querying_parties
    if args.dataset_type not in ("imbalanced", "balanced"):
        raise Exception(
            "Unknown dataset type: {}".format(args.dataset_type))
    if args.dataset_type == "imbalanced":
        private_data_loader = utils.load_private_data_imbalanced
    else:
        private_data_loader = utils.load_private_data
    all_private_trainloaders = private_data_loader(args)
    evalloader = utils.load_evaluation_dataloader(args)
# Dataloaders: private data plus the query-answer pairs.
uses_imbalanced_data = args.dataset_type == "imbalanced"
if not uses_imbalanced_data and args.dataset_type != "balanced":
    raise Exception(f"Unknown dataset type: {args.dataset_type}.")
if uses_imbalanced_data:
    all_augmented_dataloaders = utils.load_private_data_and_qap_imbalanced(
        args=args
    )
elif args.balance_type == "standard":
    all_augmented_dataloaders = utils.load_private_data_and_qap(
        args=args)
elif args.balance_type == "perfect":
    # The 'perfect' balance type is checked first and then uses the
    # imbalanced loader.
    check_perfect_balance_type(args=args)
    all_augmented_dataloaders = utils.load_private_data_and_qap_imbalanced(
        args=args
    )
else:
    raise Exception(f"Unknown args.balance_type: {args.balance_type}.")
evalloader = utils.load_evaluation_dataloader(args)
# Training
utils.augmented_print("##########################################", file)
# Only the configured seed is used here; a longer list (e.g. [11, 13, 17])
# could be substituted to repeat the experiment with different random seeds.
seed_list = [args.seed]
model_name = get_model_name_by_id(id=0)
# One empty list per tracked metric; scalar metrics first, then the
# "detailed" variants.
summary = {
    metric_key: []
    for metric_key in (
        metric.loss,
        metric.acc,
        metric.balanced_acc,
        metric.auc,
        metric.map,
        metric.acc_detailed,
        metric.balanced_acc_detailed,
        metric.auc_detailed,
        metric.map_detailed,
    )
}
# Only the first augmented dataloader (querying party 0) is retrained on.
trainloader = all_augmented_dataloaders[0]
print(trainloader.dataset)
# Choose the model to retrain: a pretrained backbone when fine-tuning on
# pascal / cxpert, otherwise party 0's private model, either loaded from
# private_model_path ('load') or freshly constructed ('raw').
if args.dataset == "pascal" and args.retrain_fine_tune:
    model = resnetpre()
    print("Loaded pretrained resnet50")
elif args.dataset == "cxpert" and args.retrain_fine_tune:
    model = densenetpre()
    print("Loaded pretrained densenet")
elif args.retrain_model_type == 'load':
    model = load_private_model_by_id(
        args=args, id=0, model_path=private_model_path)
elif args.retrain_model_type == 'raw':
    model = get_private_model_by_id(args=args, id=0)
    model.name = model_name
else:
    raise Exception(
        f"Unknown args.retrain_model_type: {args.retrain_model_type}")

# Seed every random number generator used in this notebook (torch, numpy,
# Python's random) with the same value so the retraining run is repeatable.
args.seed = seed_list[0]
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.cuda:
    # Disable cuDNN autotuning and request deterministic cuDNN kernels, then
    # seed the CUDA generator as well.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
# Train the model, evaluate it, log the aggregated summary and save the
# checkpoint. The whole pipeline runs inside try/finally so that the log file
# handle is closed even when a step raises (for example a DataLoader worker
# error during training). The exception itself still propagates unchanged,
# and the behaviour on success is identical to running the steps in sequence.
try:
    train_model(args=args, model=model, trainloader=trainloader,
                evalloader=evalloader)

    # NOTE(review): this rebinds the notebook-level name `result`, which the
    # imports cell also brings in via `from utils import result`; code run
    # after this cell sees the evaluation result under that name instead.
    result = eval_model(model=model, dataloader=evalloader, args=args)
    summary = update_summary(summary=summary, result=result)
    summary["model_name"] = model_name
    # Copy the run configuration values that identify this experiment.
    from_args = ["dataset", "num_models", "budget", "architecture"]
    for arg in from_args:
        summary[arg] = getattr(args, arg)

    # Aggregate results from different seeds: scalar metrics are averaged,
    # metrics without any recorded value are reported as "N/A".
    for metric_key in [metric.loss, metric.acc, metric.balanced_acc,
                       metric.auc, metric.map]:
        value = summary[metric_key]
        if len(value) > 0:
            avg_value = np.mean(value)
            summary[metric_key] = avg_value
        else:
            summary[metric_key] = "N/A"

    # Detailed metrics are stacked into an array and reduced along axis 0
    # (one row per recorded result); the standard deviation is stored under
    # "<metric name>_std".
    for metric_key in [metric.acc_detailed, metric.balanced_acc_detailed,
                       metric.auc_detailed, metric.map_detailed]:
        detailed_value = summary[metric_key]
        if len(detailed_value) > 0:
            detailed_value = np.array(detailed_value)
            summary[metric_key] = detailed_value.mean(axis=0)
            summary[metric_key.name + "_std"] = detailed_value.std(axis=0)
        else:
            summary[metric_key] = "N/A"

    summary_str = from_result_to_str(
        result=summary, sep=" | ", inner_sep=": ")
    utils.augmented_print(text=summary_str, file=file, flush=True)

    if model is not None:
        utils.save_model(args=args, model=model, result_test=summary)

    utils.augmented_print(
        "##########################################", file)

    utils.augmented_print(
        "##########################################", file)
finally:
    # Always release the log file, including on failure, so a re-run of the
    # cell does not leave a stale open handle behind.
    file.close()

filename:  logs
##########################################
Retraining the private models of all querying parties on 'mnist' dataset!
Number of querying parties: 1
Initial learning rate: 0.10
Number of epochs for retraining each model: 20
In QuerySet:
number of new labeled data points:  883
number of all new items:  883
number of not answered:  0
number of answered:  883
##########################################
<torch.utils.data.dataset.ConcatDataset object at 0x000001B061357988>
STARTED TRAINING


AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\collate.py", line 52, in default_collate
    numel = sum([x.numel() for x in batch])
  File "C:\Users\Ahmad\anaconda3\envs\esc180\lib\site-packages\torch\utils\data\_utils\collate.py", line 52, in <listcomp>
    numel = sum([x.numel() for x in batch])
AttributeError: 'int' object has no attribute 'numel'


Algorithms such as MixMatch can also be used to further improve the student training process, especially in cases where the amount of data is limited, which is often true when using medical datasets. We provide code to do so in the subfolder `mix_match`.