In [1]:
import functorch
import math
import numpy as np
import random
import torch
import torchopt
import torch.nn as nn
import torch.nn.functional as F
import warnings
from attack import (
    reconstruct_interactions,
    interaction_mia_fedrec,
)
from dataset import (
    MovieLens,
    Steam200K,
)
from ranker import (
    CollaborativeFilteringRecommender,
    NeuralCollaborativeFilteringRecommender,
)
from tqdm.notebook import tqdm
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)
warnings.filterwarnings("ignore", message='.*make_functional.*')

In [2]:
def set_seed(seed=2023):
    """Seed every random number generator this notebook draws from.

    Seeds PyTorch's default generator, Python's ``random`` module and NumPy's
    legacy global RNG with the same value so a run can be repeated.

    Args:
        seed: Integer seed applied to all three libraries. Defaults to 2023,
            the value that used to be hard-coded, so ``set_seed()`` keeps its
            previous behaviour.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

# Dataset selection: uncomment the MovieLens line (and drop the Steam one) to switch datasets.
# data = MovieLens("../dataset/ML-100K/u.data")
steam_interactions_csv = "../dataset/STEAM-200K/steam-200k.csv"
data = Steam200K(steam_interactions_csv)

In [None]:
# Collaborative Filtering + DP

set_seed()

# Id vocabularies plus the id -> position lookups used to index the embedding tables.
user_ids = data.get_all_user_ids()
item_ids = data.get_all_item_ids()
num_users, num_items = len(user_ids), len(item_ids)
user_id_to_idx = {uid: row for row, uid in enumerate(user_ids)}
item_id_to_idx = {iid: row for row, iid in enumerate(item_ids)}

# Embedding width and number of negatives sampled per observed interaction.
embedding_dim = 64
neg_sample_ratio = 4

# Number of simulation rounds, followed by the settings forwarded to reconstruct_interactions.
num_sim_round = 1
atk_lr = 0.1
max_iter = 1000
num_atk = 1

# Privacy budgets swept per user; math.inf presumably means "no noise" — confirm in utils.
epsilons = [
    1.0,
    10.0,
    20.0,
    100.0,
    500.0,
    math.inf,
]
delta = 1e-8

# Local training schedule used by train_fncf_functional.
local_epoch = 20
local_lr = 1e-3

# Collector for the per-user, per-epsilon attack results.
metrics = Metrics()

def train_fncf_functional(model, user_embedding, item_embeddings, interactions, num_epoch, lr):
    """Run functional Adam steps and report how far the item embeddings moved.

    The user embedding, the item embeddings and the parameters extracted from
    ``model`` by ``functorch.make_functional`` are optimised together against a
    binary cross-entropy loss. Every step goes through
    ``torchopt.FuncOptimizer``; the code relies on each step returning new
    tensors, since the result is computed as a difference against the
    ``item_embeddings`` that was passed in.

    Args:
        model: Module turned into a functional form; called as
            ``f(params, user_embedding, item_embeddings)``.
        user_embedding: Tensor optimised alongside the item embeddings; its
            ``.grad`` is cleared on entry.
        item_embeddings: Tensor optimised alongside the user embedding; its
            ``.grad`` is cleared on entry.
        interactions: Targets for the binary cross-entropy loss, compared
            against the flattened predictions.
        num_epoch: Number of optimiser steps to take.
        lr: Learning rate handed to ``torchopt.adam``.

    Returns:
        A pair ``(item_delta, net_params)``: ``item_delta`` is
        ``item_embeddings`` minus the item embeddings after the last step, and
        ``net_params`` is the network-parameter part of the optimised tuple.
    """
    # Drop any stale gradients on the two embedding tensors.
    for leaf in (user_embedding, item_embeddings):
        leaf.grad = None

    stateless_forward, net_params = functorch.make_functional(model)

    # use_accelerated_op=True would be faster, but it seems to prevent the
    # reconstruction from working (possible bug?).
    # eps_root must be set.
    adam = torchopt.adam(lr=lr, eps_root=1e-08)
    optimizer = torchopt.FuncOptimizer(adam)

    # Tuple layout: [0] user embedding, [1] item embeddings, [2:] network parameters.
    state = (user_embedding, item_embeddings, *net_params)
    for _ in range(num_epoch):
        current_user, current_items, current_net = state[0], state[1], state[2:]
        scores = stateless_forward(current_net, current_user, current_items)
        loss = F.binary_cross_entropy(scores.view(-1), interactions)
        state = optimizer.step(loss, state)

    item_delta = item_embeddings - state[1]
    return item_delta, state[2:]

# Simulation driver. For every user: build a local training set (observed items plus
# sampled negatives), compute the item-embedding delta produced by local training
# (train_fncf_functional), pass it through apply_gaussian_mechanism for each epsilon,
# and run reconstruct_interactions against the perturbed delta. Results are stored in
# `metrics` under a key that encodes the embedding size and epsilon.
# NOTE: statement order matters here — nn.Embedding init, random.sample, torch.rand and
# the Gaussian mechanism all draw from the globally seeded RNGs.
for _ in tqdm(range(num_sim_round)):
    # Fresh embedding tables each round; in this cell they are only looked up, never trained.
    user_embeddings = nn.Embedding(num_users, embedding_dim)
    item_embeddings = nn.Embedding(num_items, embedding_dim)

    # fcf = CollaborativeFilteringRecommender()
    # Second argument is presumably the hidden-layer widths — confirm against ranker.py.
    fncf = NeuralCollaborativeFilteringRecommender(embedding_dim, [128, 64, 32])

    for user_id in tqdm(user_ids):
        # Set up training data
        # Both helpers take a list of users; [0] picks the entry for the single user requested.
        interacted_items = data.get_item_ids_for_users([user_id])[0]
        non_interacted_items = data.get_non_interacted_item_ids_for_users([user_id])[0]

        # Draw neg_sample_ratio negatives per positive, capped at the number available.
        num_pos = len(interacted_items)
        sampled_non_interacted_items = random.sample(
            non_interacted_items,
            min(num_pos * neg_sample_ratio, len(non_interacted_items)),
        )
        num_neg = len(sampled_non_interacted_items)
        num_data = num_pos + num_neg

        # Look up this user's row, detach it from the embedding table and flatten it to 1-D.
        user_embedding = (
            user_embeddings(torch.LongTensor([user_id_to_idx[user_id]]))
            .detach()
            .view(-1)
        )
        # Item rows are ordered positives first, then sampled negatives — the same
        # order as the labels in `interactions` below.
        item_embedding = item_embeddings(
            torch.cat(
                [
                    torch.LongTensor([item_id_to_idx[id] for id in interacted_items]),
                    torch.LongTensor(
                        [item_id_to_idx[id] for id in sampled_non_interacted_items]
                    ),
                ]
            )
        ).detach()
        # Re-enable gradients on the detached copies so they can be optimised as leaves.
        user_embedding.requires_grad_()
        item_embedding.requires_grad_()
        # Labels: 1 for each interacted item, 0 for each sampled negative.
        interactions = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
        # Uniform-random vector used in place of the true user embedding in the attack below.
        random_user_emb = torch.rand(embedding_dim, requires_grad=True)

        for epsilon in epsilons:
            # FNCF setup
            # NOTE(review): none of these arguments depend on epsilon, so the same local
            # training is repeated for every budget — could be hoisted if the model's
            # forward pass is deterministic (confirm before changing).
            target, target_model_params = train_fncf_functional(fncf, user_embedding, item_embedding, interactions, local_epoch, local_lr)
            # Perturb the detached item-embedding delta; this is what the attack gets to see.
            target = apply_gaussian_mechanism(target.detach(), epsilon, delta, sensitivity=1e-01)
            # Loss multiplier max(1, 100 / mean norm along dim 1 of the target), so the
            # attack loss is magnified when the target rows are small. Built-in max() works
            # here because both operands are single-element tensors.
            mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale

            # FNCF simple
            # The candidate interactions I are pushed through the same local training with
            # the random user embedding; both sides are divided by local_lr before comparison.
            preds_raw, _ = reconstruct_interactions(
                lambda I: train_fncf_functional(fncf, random_user_emb, item_embedding, I, local_epoch, local_lr)[0] / local_lr,
                target / local_lr,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_fn=custom_loss,
                return_raw=True,
            )
            # Squash the raw scores with a sigmoid and round to hard 0/1 predictions.
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # # FNCF jointly estimate user embedding
            # preds_raw, user_embedding_est, _ = reconstruct_interactions(
            #     lambda I, U: train_fncf_functional(fncf, U, item_embedding, I, local_epoch, local_lr)[0] / local_lr,
            #     target / local_lr,
            #     num_data,
            #     private_params_size=embedding_dim,
            #     lr=atk_lr,
            #     max_iter=max_iter,
            #     num_rounds=num_atk,
            #     loss_fn=custom_loss,
            #     return_raw=True,
            # )
            # preds = preds_raw.sigmoid().round().long()
            # embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            # metrics.update(
            #     f"FNCF_joint_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            #     preds_raw=preds_raw,
            #     extra_data={"est_user_emb_err": embedding_err},
            # )

            # # FNCF simple with neural net params
            # target_model_params = torch.cat([p.detach().view(-1) for p in target_model_params])
            # preds_raw, _ = reconstruct_interactions(
            #     lambda I: train_fncf_functional(fncf, random_user_emb, item_embedding, I, local_epoch, local_lr),
            #     (target, target_model_params),
            #     num_data,
            #     lr=atk_lr,
            #     max_iter=max_iter,
            #     num_rounds=num_atk,
            #     loss_fn=lambda t1, t2: custom_loss(t1[0] / local_lr, t2[0] / local_lr) + F.mse_loss(torch.cat([p.view(-1) for p in t1[1]]), t2[1]),
            #     return_raw=True,
            # )
            # preds = preds_raw.sigmoid().round().long()

            # metrics.update(
            #     f"FNCF_simple_model_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            #     preds_raw=preds_raw,
            # )

            # # FNCF jointly estimate user embedding with neural net params
            # NOTE(review): this variant passes private_params_size but unpacks two values,
            # whereas the other joint variants unpack three; it also reports `embedding_err`
            # without computing it in this section, and expects `target_model_params` to have
            # been flattened by the "simple with neural net params" section above.
            # Check before re-enabling.
            # preds_raw, _ = reconstruct_interactions(
            #     lambda I, U: train_fncf_functional(fncf, U, item_embedding, I, local_epoch, local_lr),
            #     (target, target_model_params),
            #     num_data,
            #     private_params_size=embedding_dim,
            #     lr=atk_lr,
            #     max_iter=max_iter,
            #     num_rounds=num_atk,
            #     loss_fn=lambda t1, t2: custom_loss(t1[0] / local_lr, t2[0] / local_lr) + F.mse_loss(torch.cat([p.view(-1) for p in t1[1]]), t2[1]),
            #     return_raw=True,
            # )
            # preds = preds_raw.sigmoid().round().long()

            # metrics.update(
            #     f"FNCF_joint_model_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            #     preds_raw=preds_raw,
            #     extra_data={"est_user_emb_err": embedding_err},
            # )


            # # FCF Simple
            # NOTE(review): needs the commented-out `fcf` above, and `rescale_grad_for_dp`
            # is not imported in the cells shown — confirm where it lives before re-enabling.
            # sens = 0.005
            # target = 0.01 * fcf.item_grad(user_embedding, item_embedding, interactions).detach()
            # target = apply_gaussian_mechanism(target, epsilon, delta, sensitivity=sens)
            # mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            # norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            # custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale

            # preds_raw, _ = reconstruct_interactions(
            #     lambda I: rescale_grad_for_dp(fcf.item_grad(random_user_emb, item_embedding, I, create_graph=True), epsilon, sens),
            #     target,
            #     num_data,
            #     lr=atk_lr,
            #     max_iter=max_iter,
            #     num_rounds=num_atk,
            #     loss_fn=custom_loss,
            #     return_raw=True,
            # )
            # preds = preds_raw.sigmoid().round().long()

            # metrics.update(
            #     f"FCF_simple_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            #     preds_raw=preds_raw,
            # )

            # # FCF jointly estimate user embedding
            # preds_raw, user_embedding_est, _ = reconstruct_interactions(
            #     lambda I, U: 0.01 * fcf.item_grad(U, item_embedding, I, create_graph=True),
            #     target,
            #     num_data,
            #     private_params_size=embedding_dim,
            #     lr=atk_lr,
            #     max_iter=max_iter,
            #     num_rounds=num_atk,
            #     loss_fn=custom_loss,
            #     return_raw=True,
            # )
            # preds = preds_raw.sigmoid().round().long()
            # embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            # metrics.update(
            #     f"FCF_joint_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds.detach(),
            #     preds_raw=preds_raw.detach(),
            #     extra_data={"est_user_emb_err": embedding_err},
            # )

            # # Random guess
            # preds_raw = torch.rand(num_data)
            # metrics.update(
            #     f"Random_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds_raw.round().long(),
            #     preds_raw=preds_raw,
            # )

            # IMIA FCF
            # target = fcf.item_grad(user_embedding, item_embedding, interactions).detach()
            # preds = interaction_mia_fedrec(
            #     lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
            #     target,
            #     num_data,
            #     select_ratio=interactions.mean(),
            # )

            # NOTE(review): the metric name below was missing its f-prefix, so the
            # placeholders would not have been filled in; prefix added here.
            # metrics.update(
            #     f"FCF_IMIA_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            # )

# metrics.save("../output/rec_metrics.csv")
# Summarise AUC and AUC-PR per experiment name and print the full, untruncated table.
report_columns = ["name", "auc", "auc-pr"]
summary_by_name = metrics.df[report_columns].groupby("name").describe()
print(summary_by_name.to_string())