In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from attack import (
    reconstruct_interactions,
    interaction_mia_fedrec,
)
from dataset import (
    LearningToRankDataset,
)
from more_itertools import grouper
from ranker import (
    LinearPDGDRanker,
    Neural1LayerPDGDRanker,
    Neural2LayerPDGDRanker,
    CollaborativeFilteringRecommender,
    NeuralCollaborativeFilteringRecommender,
)
from tqdm.notebook import tqdm
from utils import (
    CascadeClickModel,
    Metrics,
)

In [27]:
# Simulation for LTR

torch.manual_seed(2023)
random.seed(2023)

num_sim_round = 10
num_features = 10
num_data = 100
lr = 1e-01
max_iter = 1000
num_atk = 1

metrics = Metrics()

models = {
    "linear_pdgd": LinearPDGDRanker(num_features),
    # "neural_1_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=5),
    # "neural_2_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=4, hidden_size2=2
    # ),
}

for _ in tqdm(range(num_sim_round)):
    features = torch.rand(num_data, num_features) * 2 - 1
    interactions = torch.randint(0, 2, (num_data,))
    while interactions.sum() == 0:
        interactions = torch.randint(0, 2, (num_data,))
    
    ranking = list(range(num_data))
    random.shuffle(ranking)
    ranking = torch.LongTensor(ranking)

    for model_name, model in models.items():
        params = model.gen_params()
        log_pos_bias_weight = model.calc_log_pos_bias_weight(
            ranking, model.forward_multiple(params, features), num_data
        )
        
        target = model.grad(
            params,
            features,
            ranking,
            interactions,
            log_pos_bias_weight=log_pos_bias_weight,
        )

        preds_raw = reconstruct_interactions(
            lambda I: model.grad(
                params, features, ranking, I, log_pos_bias_weight=log_pos_bias_weight
            ),
            target,
            num_data,
            lr=lr,
            max_iter=max_iter,
            num_rounds=num_atk,
            return_raw=True,
        )
        preds = preds_raw.sigmoid().round().long()

        metrics.update(model_name, interactions, preds, preds_raw=preds_raw)

    # Data manipulation
    if num_data > num_features:
        num_new_features = num_data - num_features
        new_features = torch.rand(num_data, num_new_features)
        features = torch.cat([features, new_features], dim=1)

        model = LinearPDGDRanker(num_features + num_new_features)
        params = model.gen_params()
        log_pos_bias_weight = model.calc_log_pos_bias_weight(
            ranking, model.forward_multiple(params, features), num_data
        )
        
        target = model.grad(
            params,
            features,
            ranking,
            interactions,
            log_pos_bias_weight=log_pos_bias_weight,
        )

        preds_raw = reconstruct_interactions(
            lambda I: model.grad(
                params, features, ranking, I, log_pos_bias_weight=log_pos_bias_weight
            ),
            target,
            num_data,
            lr=lr,
            max_iter=max_iter,
            num_rounds=num_atk,
            return_raw=True,
        )
        preds = preds_raw.sigmoid().round().long()

        metrics.update(model_name + "_DM", interactions, preds, preds_raw=preds_raw)

100%|██████████| 10/10 [05:11<00:00, 31.16s/it]


In [29]:
print(metrics.get_dataframe().to_string())

              name  accuracy        f1  precision    recall       auc    auc-pr extra_data
0      linear_pdgd    0.5425  0.548148   0.587302  0.513889  0.558399  0.607230         {}
1   linear_pdgd_DM    1.0000  1.000000   1.000000  1.000000  1.000000  1.000000         {}
2      linear_pdgd    0.4700  0.459184   0.505618  0.420561  0.500879  0.536752         {}
3   linear_pdgd_DM    1.0000  1.000000   1.000000  1.000000  1.000000  1.000000         {}
4      linear_pdgd    0.5250  0.520202   0.512438  0.528205  0.517724  0.485470         {}
5   linear_pdgd_DM    1.0000  1.000000   1.000000  1.000000  1.000000  1.000000         {}
6      linear_pdgd    0.4950  0.459893   0.457447  0.462366  0.504648  0.476432         {}
7   linear_pdgd_DM    1.0000  1.000000   1.000000  1.000000  1.000000  1.000000         {}
8      linear_pdgd    0.5200  0.505155   0.505155  0.505155  0.500651  0.472847         {}
9   linear_pdgd_DM    1.0000  1.000000   1.000000  1.000000  1.000000  1.000000         {}

In [4]:
# Simulation for collaborative filtering

torch.manual_seed(2023)
random.seed(2023)

num_sim_round = 10
num_features = 64
num_data = 1000
lr = 1e-01
max_iter = 100000
num_atk = 10

metrics = Metrics()

for _ in tqdm(range(num_sim_round)):
    # features = torch.rand(num_data, num_features) * 2 - 1
    # user_embedding = torch.rand(num_features) * 2 - 1
    # user_embedding2 = torch.rand(num_features) * 2 - 1
    features = torch.normal(0, 1, (num_data, num_features))
    user_embedding = torch.normal(0, 1, (num_features,))
    user_embedding2 = torch.normal(0, 1, (num_features,))

    interactions = torch.randint(0, 2, (num_data,))
    while interactions.sum() == 0:
        interactions = torch.randint(0, 2, (num_data,))

    preds_raw = torch.rand((num_data),)
    metrics.update("Random", interactions, preds_raw.sigmoid().round().long(), preds_raw=preds_raw)

    # cf_rec = CollaborativeFilteringRecommender()
    # target = cf_rec.federated_item_grad(user_embedding, features, interactions)

    # preds_raw = reconstruct_interactions(
    #     lambda I: cf_rec.federated_item_grad(user_embedding2, features, I),
    #     target,
    #     num_data,
    #     lr=lr,
    #     max_iter=max_iter,
    #     num_rounds=num_atk,layers
    #     select_ratio=interactions.float().mean(),
    # )

    # metrics.update(
    #     "FCF_IMIA",
    #     interactions,
    #     preds,
    # )

    ncf_rec = NeuralCollaborativeFilteringRecommender(num_features, [128, 64, 32])

    target = ncf_rec.item_grad(user_embedding, features, interactions.float())
    scale = max(1.0 / target.mean().abs(), 1.0)
    target = scale * target

    preds_raw = reconstruct_interactions(
        lambda I: scale * ncf_rec.item_grad(user_embedding2, features, I),
        target,
        num_data,
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()

    metrics.update(
        "FedNCF_simple",
        interactions,
        preds,
        preds_raw=preds_raw,
    )

    target = ncf_rec.item_grad(user_embedding, features, interactions.float())
    scale = max(1.0 / target.mean().abs(), 1.0)
    target = scale * target

    preds_raw, user_embedding_est = reconstruct_interactions(
        lambda I, U: scale * ncf_rec.item_grad(U, features, I),
        target,
        num_data,
        private_params_size=num_features,
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()
    
    embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

    metrics.update(
        "FedNCF_private",
        interactions,
        preds,
        preds_raw=preds_raw,
        extra_data={"embedding_err": embedding_err},
    )

    item_grad = ncf_rec.item_grad(user_embedding, features, interactions.float()).flatten()
    scale = max(1.0 / item_grad.mean().abs(), 1.0)

    target = torch.cat(
        [
            scale * item_grad,
            ncf_rec.feature_grad(user_embedding, features, interactions.float()),
        ]
    )

    preds_raw, user_embedding_est = reconstruct_interactions(
        lambda I, U: torch.cat(
            [
                scale * ncf_rec.item_grad(U, features, I).flatten(),
                ncf_rec.feature_grad(U, features, I, retain_graph=True),
            ]
        ),
        target,
        num_data,
        private_params_size=num_features,
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()

    embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

    metrics.update(
        "FedNCF_private2",
        interactions,
        preds,
        preds_raw=preds_raw,
        extra_data={"embedding_err": embedding_err},
    )

    target = ncf_rec.item_grad(user_embedding, features, interactions.float())

    preds = interaction_mia_fedrec(
        lambda I: ncf_rec.item_grad(user_embedding2, features, I.float()),
        target,
        num_data,
        select_ratio=interactions.float().mean(),
    )

    metrics.update(
        "FedNCF_IMIA",
        interactions,
        preds,
    )

100%|██████████| 10/10 [01:07<00:00,  6.77s/it]


In [5]:
print(metrics.get_dataframe().to_string())

               name  accuracy        f1  precision    recall       auc    auc-pr                             extra_data
0            Random     0.486  0.654105   0.486000  1.000000  0.514337  0.491345                                     {}
1     FedNCF_simple     0.968  0.967871   0.945098  0.991770  0.996982  0.996428                                     {}
2    FedNCF_private     1.000  1.000000   1.000000  1.000000  1.000000  1.000000  {"embedding_err": 0.9334539175033569}
3   FedNCF_private2     0.991  0.990769   0.987730  0.993827  0.999868  0.999861  {"embedding_err": 1.2190757989883423}
4       FedNCF_IMIA     0.516  0.502058   0.502058  0.502058       NaN       NaN                                     {}
5            Random     0.507  0.672860   0.507000  1.000000  0.521126  0.518385                                     {}
6     FedNCF_simple     0.971  0.970971   0.985772  0.956607  0.997447  0.997581                                     {}
7    FedNCF_private     0.994  0.994071 

In [None]:
# PDGD: MQ2008, single query, single epoch

torch.manual_seed(2023)
random.seed(2023)

num_item_per_ranking = 10
num_sim_round = 1
lr = 1e-01
max_iter = 1000
num_atk = 1

metrics = Metrics()

data = LearningToRankDataset("../dataset/MQ2008/Fold1/train.txt")
num_features = data.get_num_features()

models = {
    "linear_pdgd": LinearPDGDRanker(num_features),
    "neural_4_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=4),
    "neural_8_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=8),
    "neural_16_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=16),
    "neural_4_2_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=4, hidden_size2=2
    ),
    "neural_8_4_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=8, hidden_size2=4
    ),
    "neural_16_8_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=16, hidden_size2=4
    ),
}

click_models = {
    # "perfect": CascadeClickModel(prob_click=[0.0, 0.5, 1.0], prob_stop=[0.0, 0.0, 0.0]),
    "navigational": CascadeClickModel(
        prob_click=[0.05, 0.5, 0.95], prob_stop=[0.2, 0.5, 0.9]
    ),
    "informational": CascadeClickModel(
        prob_click=[0.4, 0.7, 0.9], prob_stop=[0.1, 0.3, 0.5]
    ),
}


def simulate_attack(model, features, relevances, click_model):
    params = model.gen_params()
    ranking = model.rank(params, features, sample=True)[:num_item_per_ranking]
    features = features[ranking]
    interactions = torch.Tensor(click_model.click(ranking, relevances))
    num_data = len(ranking)

    # Remap the original ranking into the correct range
    _, ranking = torch.where(
        torch.sort(ranking)[0].unsqueeze(1) == ranking.unsqueeze(0)
    )

    log_pos_bias_weight = model.calc_log_pos_bias_weight(
        ranking, model.forward_multiple(params, features), num_data
    )

    target = model.grad(
        params,
        features,
        ranking,
        interactions,
        log_pos_bias_weight=log_pos_bias_weight,
    )

    preds_raw = reconstruct_interactions(
        lambda I: model.grad(
            params,
            features,
            ranking,
            I,
            log_pos_bias_weight=log_pos_bias_weight,
        ),
        target,
        num_data,
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()

    return (interactions, preds, preds_raw)


for _ in tqdm(range(num_sim_round)):
    for qid in tqdm(data.get_all_query_ids()):
        relevances, features = data.get_data_for_queries([qid])[0]
        features = torch.Tensor(features)

        for model_name, model in models.items():
            for click_model_name, click_model in click_models.items():
                interactions, preds, preds_raw = simulate_attack(
                    model, features, relevances, click_model
                )
                metrics.update(
                    f"{model_name}_{click_model_name}",
                    interactions,
                    preds,
                    preds_raw=preds_raw,
                )

                # Random guess
                random_preds_raw = torch.rand(preds_raw.shape)
                random_preds = random_preds_raw.round()
                metrics.update(
                    f"{model_name}_{click_model_name}_random",
                    interactions,
                    random_preds,
                    preds_raw=random_preds_raw,
                )

print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())

In [None]:
# PDGD: MQ2008, single query, multiple epochs

torch.manual_seed(2023)
random.seed(2023)

num_item_per_ranking = 10
num_local_epochs = 5
local_lr = 1e-01

num_sim_round = 1
lr = 1e-01
max_iter = 1000
num_atk = 1

metrics = Metrics()

data = LearningToRankDataset("../dataset/MQ2008/Fold1/train.txt")
num_features = data.get_num_features()

models = {
    "linear_pdgd": LinearPDGDRanker(num_features),
    # "neural_4_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=4),
    # "neural_8_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=8),
    # "neural_16_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=16),
    "neural_4_2_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=4, hidden_size2=2
    ),
    # "neural_8_4_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=8, hidden_size2=4
    # ),
    # "neural_16_8_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=16, hidden_size2=4
    # ),
}

click_models = {
    # "perfect": CascadeClickModel(prob_click=[0.0, 0.5, 1.0], prob_stop=[0.0, 0.0, 0.0]),
    "navigational": CascadeClickModel(
        prob_click=[0.05, 0.5, 0.95], prob_stop=[0.2, 0.5, 0.9]
    ),
    "informational": CascadeClickModel(
        prob_click=[0.4, 0.7, 0.9], prob_stop=[0.1, 0.3, 0.5]
    ),
}

def train(model, params, features, ranking, interactions, num_local_epochs, local_lr):
    cur_params = params.clone()

    for _ in range(num_local_epochs):
        cur_grad = model.grad(
            cur_params,
            features,
            ranking,
            interactions,
        )

        cur_params = cur_params + local_lr * cur_grad

    return cur_params


def simulate_attack(model, features, relevances, click_model):
    params = model.gen_params()
    ranking = model.rank(params, features, sample=True)[:num_item_per_ranking]
    features = features[ranking]
    interactions = torch.Tensor(click_model.click(ranking, relevances))
    num_data = len(ranking)

    # Remap the original ranking into the correct range
    _, ranking = torch.where(
        torch.sort(ranking)[0].unsqueeze(1) == ranking.unsqueeze(0)
    )

    target = train(model, params, features, ranking, interactions, num_local_epochs, local_lr)

    preds_raw = reconstruct_interactions(
        lambda I: train(model, params, features, ranking, I, num_local_epochs, local_lr),
        target,
        num_data,
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()

    return (interactions, preds, preds_raw)


for _ in tqdm(range(num_sim_round)):
    for qid in tqdm(data.get_all_query_ids()):
        relevances, features = data.get_data_for_queries([qid])[0]
        features = torch.Tensor(features)

        for model_name, model in models.items():
            for click_model_name, click_model in click_models.items():
                interactions, preds, preds_raw = simulate_attack(
                    model, features, relevances, click_model
                )
                metrics.update(
                    f"{model_name}_{click_model_name}",
                    interactions,
                    preds,
                    preds_raw=preds_raw,
                )

                # Random guess
                random_preds_raw = torch.rand(preds_raw.shape)
                random_preds = random_preds_raw.round()
                metrics.update(
                    f"{model_name}_{click_model_name}_random",
                    interactions,
                    random_preds,
                    preds_raw=random_preds_raw,
                )

print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())

In [None]:
# PDGD: MQ2008, multiple queries, no randomness

torch.manual_seed(2023)
random.seed(2023)

num_query_per_user = [1, 8, 16, 24, 32]
num_item_per_ranking = 10
local_lr = 1e-01

num_sim_round = 1
lr = 1e-01
max_iter = 1000
num_atk = 1

metrics = Metrics()

data = LearningToRankDataset("../dataset/MQ2008/Fold1/train.txt")
num_features = data.get_num_features()

models = {
    "linear_pdgd": LinearPDGDRanker(num_features),
    # "neural_4_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=4),
    # "neural_8_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=8),
    # "neural_16_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=16),
    "neural_4_2_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=4, hidden_size2=2
    ),
    # "neural_8_4_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=8, hidden_size2=4
    # ),
    # "neural_16_8_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=16, hidden_size2=4
    # ),
}

click_models = {
    # "perfect": CascadeClickModel(prob_click=[0.0, 0.5, 1.0], prob_stop=[0.0, 0.0, 0.0]),
    "navigational": CascadeClickModel(
        prob_click=[0.05, 0.5, 0.95], prob_stop=[0.2, 0.5, 0.9]
    ),
    "informational": CascadeClickModel(
        prob_click=[0.4, 0.7, 0.9], prob_stop=[0.1, 0.3, 0.5]
    ),
}


def train(model, params, grouped_train_data, local_lr):
    cur_params = params.clone()

    for features, ranking, interactions in grouped_train_data:
        cur_grad = model.grad(
            cur_params,
            features,
            ranking,
            interactions,
        )

        cur_params = cur_params + local_lr * cur_grad

    return cur_params


def simulate_attack(model, grouped_data, click_model):
    params = model.gen_params()

    grouped_train_data = []
    indices = []
    start_ind = 0
    for relevances, features in grouped_data:
        features = torch.Tensor(features)
        ranking = model.rank(params, features, sample=True)[:num_item_per_ranking]
        features = features[ranking]
        interactions = torch.Tensor(click_model.click(ranking, relevances))

        # Remap the original ranking into the correct range
        _, ranking = torch.where(
            torch.sort(ranking)[0].unsqueeze(1) == ranking.unsqueeze(0)
        )
        grouped_train_data.append((features, ranking, interactions))
        indices.append((start_ind, start_ind + len(ranking)))
        start_ind += len(ranking)

    target = train(model, params, grouped_train_data, local_lr)

    preds_raw = reconstruct_interactions(
        lambda I: train(
            model,
            params,
            [
                (features, ranking, I[indices[idx][0] : indices[idx][1]])
                for idx, (features, ranking, _) in enumerate(grouped_train_data)
            ],
            local_lr,
        ),
        target,
        indices[-1][1],
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()
    interactions = torch.cat([I for (_, _, I) in grouped_train_data])
    return (interactions, preds, preds_raw)


for num_query in num_query_per_user:
    print(f"Num query: {num_query}")
    for _ in tqdm(range(num_sim_round)):
        for qids in tqdm(
            grouper(data.get_all_query_ids(), num_query, incomplete="ignore"),
            total=len(data.get_all_query_ids()) // num_query
        ):
            grouped_data = data.get_data_for_queries(list(qids))

            for model_name, model in models.items():
                for click_model_name, click_model in click_models.items():
                    interactions, preds, preds_raw = simulate_attack(
                        model, grouped_data, click_model
                    )
                    metrics.update(
                        f"{model_name}_{click_model_name}_{num_query}_query",
                        interactions,
                        preds,
                        preds_raw=preds_raw,
                    )

                    # Random guess
                    random_preds_raw = torch.rand(preds_raw.shape)
                    random_preds = random_preds_raw.round()
                    metrics.update(
                        f"{model_name}_{click_model_name}_{num_query}_query_random",
                        interactions,
                        random_preds,
                        preds_raw=random_preds_raw,
                    )

print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())

In [None]:
# PDGD: MQ2008, multiple queries, random order

torch.manual_seed(2023)
random.seed(2023)

num_query_per_user = [8, 16, 24, 32]
num_item_per_ranking = 10
local_lr = 1e-01

num_sim_round = 1
lr = 1e-01
max_iter = 1000
num_atk = 1

metrics = Metrics()

data = LearningToRankDataset("../dataset/MQ2008/Fold1/test.txt")
num_features = data.get_num_features()

models = {
    "linear_pdgd": LinearPDGDRanker(num_features),
    # "neural_4_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=4),
    # "neural_8_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=8),
    # "neural_16_pdgd": Neural1LayerPDGDRanker(num_features, hidden_size=16),
    "neural_4_2_pdgd": Neural2LayerPDGDRanker(
        num_features, hidden_size=4, hidden_size2=2
    ),
    # "neural_8_4_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=8, hidden_size2=4
    # ),
    # "neural_16_8_pdgd": Neural2LayerPDGDRanker(
    #     num_features, hidden_size=16, hidden_size2=4
    # ),
}

click_models = {
    # "perfect": CascadeClickModel(prob_click=[0.0, 0.5, 1.0], prob_stop=[0.0, 0.0, 0.0]),
    "navigational": CascadeClickModel(
        prob_click=[0.05, 0.5, 0.95], prob_stop=[0.2, 0.5, 0.9]
    ),
    "informational": CascadeClickModel(
        prob_click=[0.4, 0.7, 0.9], prob_stop=[0.1, 0.3, 0.5]
    ),
}


def train(model, params, grouped_train_data, local_lr):
    cur_params = params.clone()

    for features, ranking, interactions in grouped_train_data:
        cur_grad = model.grad(
            cur_params,
            features,
            ranking,
            interactions,
        )

        cur_params = cur_params + local_lr * cur_grad

    return cur_params


def simulate_attack(model, grouped_data, click_model):
    params = model.gen_params()

    grouped_train_data = []
    indices = []
    start_ind = 0
    for relevances, features in grouped_data:
        features = torch.Tensor(features)
        ranking = model.rank(params, features, sample=True)[:num_item_per_ranking]
        features = features[ranking]
        interactions = torch.Tensor(click_model.click(ranking, relevances))

        # Remap the original ranking into the correct range
        _, ranking = torch.where(
            torch.sort(ranking)[0].unsqueeze(1) == ranking.unsqueeze(0)
        )
        grouped_train_data.append((features, ranking, interactions))
        indices.append((start_ind, start_ind + len(ranking)))
        start_ind += len(ranking)

    target = train(
        model,
        params,
        random.sample(grouped_train_data, len(grouped_train_data)),
        local_lr,
    )

    preds_raw = reconstruct_interactions(
        lambda I: train(
            model,
            params,
            [
                (features, ranking, I[indices[idx][0] : indices[idx][1]])
                for idx, (features, ranking, _) in enumerate(grouped_train_data)
            ],
            local_lr,
        ),
        target,
        indices[-1][1],
        lr=lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()
    interactions = torch.cat([I for (_, _, I) in grouped_train_data])
    return (interactions, preds, preds_raw)


for num_query in num_query_per_user:
    print(f"Num query: {num_query}")
    for _ in tqdm(range(num_sim_round)):
        for qids in tqdm(
            grouper(data.get_all_query_ids(), num_query, incomplete="ignore"),
            total=len(data.get_all_query_ids()) // num_query,
        ):
            grouped_data = data.get_data_for_queries(list(qids))

            for model_name, model in models.items():
                for click_model_name, click_model in click_models.items():
                    interactions, preds, preds_raw = simulate_attack(
                        model, grouped_data, click_model
                    )
                    metrics.update(
                        f"{model_name}_{click_model_name}_{num_query}_query",
                        interactions,
                        preds,
                        preds_raw=preds_raw,
                    )

                    # Random guess
                    random_preds_raw = torch.rand(preds_raw.shape)
                    random_preds = random_preds_raw.round()
                    metrics.update(
                        f"{model_name}_{click_model_name}_{num_query}_query_random",
                        interactions,
                        random_preds,
                        preds_raw=random_preds_raw,
                    )

print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())