In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from attack import (
    reconstruct_interactions,
    interaction_mia_fedrec,
)
from dataset import (
    MovieLens,
    Yelp,
)
from more_itertools import grouper
from ranker import (
    CollaborativeFilteringRecommender,
    NeuralCollaborativeFilteringRecommender,
)
from scipy.stats import ks_2samp
from tqdm.notebook import tqdm
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    The same integer is handed to PyTorch's global generator, Python's
    built-in ``random`` module, and NumPy's legacy global generator, so
    that a run started with a given seed is repeatable.

    Args:
        seed: Integer seed applied to all three generators.
    """
    # The three generators are independent of one another, so they can be
    # seeded in any order; the original order is kept anyway.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

In [3]:
def worker(id: int, start: int, end: int):
    """Run the interaction-reconstruction experiments for one shard of Yelp users.

    For every user in ``data.get_all_user_ids()[start:end]`` and every privacy
    budget in ``epsilons``, a target is built from the recommender's
    ``item_grad`` output (plus ``feature_grad`` for the FNCF "model" variants),
    passed through ``apply_gaussian_mechanism``, and handed to
    ``reconstruct_interactions`` to guess the user's 0/1 interaction vector.
    Each guess is compared against the true vector via ``Metrics.update``
    under a key such as ``FCF_simple_emb_64_eps_1.0``. A uniformly random
    guess is recorded as a baseline.

    Args:
        id: Shard identifier. It seeds the RNGs (``set_seed(id)``) and is
            embedded in the output file name. NOTE(review): the name shadows
            the builtin ``id``; it cannot be renamed without touching callers.
        start: Inclusive start index into the list of all user ids.
        end: Exclusive end index into the list of all user ids.

    Returns:
        The ``Metrics`` object holding all recorded results. The same object
        is also written to ``../output/output_part_{id}.csv`` before returning.

    Note:
        The function draws from the seeded torch and ``random`` generators in
        several places (embedding initialisation, negative sampling,
        ``torch.rand``). Reordering those statements changes the results.
    """
    set_seed(id)

    data = Yelp()

    user_ids = data.get_all_user_ids()
    item_ids = data.get_all_item_ids()
    # Map raw ids to row indices of the embedding tables created below.
    # The comprehension variable `id` is local to the comprehension, so the
    # `id` parameter of this function is left untouched.
    user_id_to_idx = {id: idx for idx, id in enumerate(user_ids)}
    item_id_to_idx = {id: idx for idx, id in enumerate(item_ids)}
    num_users = len(user_ids)
    num_items = len(item_ids)
    embedding_dim = 64
    # Up to this many non-interacted items are sampled per interacted item.
    neg_sample_ratio = 4

    # Settings forwarded to reconstruct_interactions as lr / max_iter / num_rounds.
    atk_lr = 1e-01
    max_iter = 1000
    num_atk = 5

    # (epsilon, delta) values forwarded to apply_gaussian_mechanism.
    # NOTE(review): math.inf presumably means "no noise" -- confirm in utils.
    epsilons = [1.0, 10.0, 100.0, math.inf]
    delta = 1e-08

    metrics = Metrics()

    # Freshly initialised embedding tables; nothing in this function trains them.
    user_embeddings = nn.Embedding(num_users, embedding_dim)
    item_embeddings = nn.Embedding(num_items, embedding_dim)

    fcf = CollaborativeFilteringRecommender()
    fncf = NeuralCollaborativeFilteringRecommender(embedding_dim, [16, 8])

    for user_id in tqdm(user_ids[start:end]):
        # Set up training data: all interacted items plus a random sample of
        # non-interacted items as negatives.
        interacted_items = data.get_item_ids_for_users([user_id])[0]
        non_interacted_items = data.get_non_interacted_item_ids_for_users([user_id])[0]

        num_pos = len(interacted_items)
        # Cap the negative sample at the number of available non-interacted items.
        sampled_non_interacted_items = random.sample(
            non_interacted_items,
            min(num_pos * neg_sample_ratio, len(non_interacted_items)),
        )
        num_neg = len(sampled_non_interacted_items)
        num_data = num_pos + num_neg

        # True user embedding, flattened to a 1-D tensor and detached from autograd.
        user_embedding = (
            user_embeddings(torch.LongTensor([user_id_to_idx[user_id]]))
            .detach()
            .view(-1)
        )
        # Item embeddings ordered as [positives..., negatives...]; `interactions`
        # below uses the same ordering.
        item_embedding = item_embeddings(
            torch.cat(
                [
                    torch.LongTensor([item_id_to_idx[id] for id in interacted_items]),
                    torch.LongTensor(
                        [item_id_to_idx[id] for id in sampled_non_interacted_items]
                    ),
                ]
            )
        ).detach()
        # Ground-truth labels: 1 for interacted items, 0 for sampled negatives.
        interactions = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
        # Random stand-in for the user embedding, used by the "simple" attacks
        # that do not estimate the embedding. Drawn once per user and shared
        # across all epsilons.
        random_user_emb = torch.rand(embedding_dim)

        for epsilon in epsilons:
            # FCF simple: match the target using the random user embedding.
            target = fcf.item_grad(user_embedding, item_embedding, interactions)
            # NOTE(review): sensitivity=100 is a hard-coded constant; its
            # derivation is not visible here.
            target = apply_gaussian_mechanism(target, epsilon, delta, sensitivity=100)

            preds_raw, _ = reconstruct_interactions(
                lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            # Raw outputs are squashed with a sigmoid and rounded to 0/1 labels.
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FCF_simple_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FCF joint: the attack also estimates the user embedding
            # (private_params_size=embedding_dim) against the same FCF target.
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fcf.item_grad(U, item_embedding, I),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            # Mean squared error between the estimated and the true user embedding.
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FCF_joint_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF setup: from here on `target` refers to the FNCF target.
            target = fncf.item_grad(user_embedding, item_embedding, interactions)
            # NOTE(review): sensitivity=0.005 is a hard-coded constant; its
            # derivation is not visible here.
            target = apply_gaussian_mechanism(target, epsilon, delta, sensitivity=0.005)

            # Scale the loss by max(1, 100 / mean row norm of the target), so the
            # factor is never below 1. Presumably this compensates for small
            # FNCF target norms -- confirm.
            mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            # Mean pairwise distance between corresponding rows, times norm_scale.
            custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale

            # FNCF simple: random user embedding, custom scaled loss.
            preds_raw, _ = reconstruct_interactions(
                lambda I: fncf.item_grad(random_user_emb, item_embedding, I),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FNCF joint: also estimate the user embedding.
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fncf.item_grad(U, item_embedding, I),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF joint with neural net params: the attack matches a pair
            # (item_grad, feature_grad) instead of item_grad alone.
            feature_grad = fncf.feature_grad(user_embedding, item_embedding, interactions)
            # NOTE(review): sensitivity=0.7 is a hard-coded constant; its
            # derivation is not visible here.
            feature_grad = apply_gaussian_mechanism(feature_grad, epsilon, delta, sensitivity=0.7)

            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: (
                    fncf.item_grad(U, item_embedding, I),
                    fncf.feature_grad(U, item_embedding, I, retain_graph=True),
                ),
                (target, feature_grad),
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                # Scaled distance on the item_grad part plus MSE on the feature_grad part.
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_model_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF simple with neural net params: same paired target, but with
            # the random user embedding instead of an estimated one.
            preds_raw, _ = reconstruct_interactions(
                lambda I: (
                    fncf.item_grad(random_user_emb, item_embedding, I),
                    fncf.feature_grad(random_user_emb, item_embedding, I, retain_graph=True),
                ),
                (target, feature_grad),
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_model_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # Random guess baseline: uniform scores in [0, 1) rounded to 0/1.
            # No sigmoid is applied here, unlike the attack outputs above.
            preds_raw = torch.rand(num_data)
            metrics.update(
                f"Random_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds_raw.round().long(),
                preds_raw=preds_raw,
            )

            # IMIA FCF (disabled).
            # NOTE(review): if this is re-enabled, the metrics key below lacks
            # the f-string prefix, so the placeholders would be stored literally.
            # target = fcf.item_grad(user_embedding, item_embedding, interactions)
            # preds = interaction_mia_fedrec(
            #     lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
            #     target,
            #     num_data,
            #     select_ratio=interactions.mean(),
            # )

            # metrics.update(
            #     "FCF_IMIA_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            # )

    # Persist this shard's results; the path is relative to the working directory.
    metrics.save(f"../output/output_part_{id}.csv")

    return metrics

In [6]:
import os
from multiprocessing import Pool, cpu_count

# Driver cell: shard the users across one process per CPU, run `worker` on each
# shard in parallel, and merge the per-shard metrics into a single CSV.
if __name__ == "__main__":
    # The same count sizes both the pool and the user shards, so every shard
    # gets its own process.
    num_processes = cpu_count()

    # Only the user count is needed here; each worker loads its own dataset.
    data = Yelp()
    num_users = len(data.get_all_user_ids())
    num_users_per_process = math.ceil(num_users / num_processes)

    # One (worker id, start, end) tuple per process; `end` is clamped to the
    # number of users, so trailing shards may be empty.
    parameters = [
        (i, num_users_per_process * i, min(num_users, num_users_per_process * (i + 1)))
        for i in range(num_processes)
    ]

    # Workers and this cell both write into ../output; create it up front so a
    # missing directory does not fail the run after the expensive computation.
    os.makedirs("../output", exist_ok=True)

    # The context manager shuts the pool down once the blocking `starmap` has
    # returned, so no worker processes are left behind in the kernel.
    with Pool(num_processes) as pool:
        metrics = pool.starmap(worker, parameters)

    final_metrics_df = pd.concat([m.df for m in metrics])
    final_metrics_df.to_csv("../output/output_final.csv", index=False)