In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from attack import (
    reconstruct_interactions,
    interaction_mia_fedrec,
)
from dataset import (
    MovieLens,
)
from more_itertools import grouper
from ranker import (
    CollaborativeFilteringRecommender,
    NeuralCollaborativeFilteringRecommender,
)
from tqdm.notebook import tqdm
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)

In [2]:
def set_seed(seed=2023):
    """Seed every random number generator the notebook draws from.

    Seeds PyTorch, Python's ``random`` module and NumPy's legacy global
    generator with the same value so that a cell starting with ``set_seed()``
    is reproducible on re-run.

    Args:
        seed: Integer seed applied to all three generators. Defaults to 2023,
            the value previously hard-coded, so ``set_seed()`` behaves as before.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

# Load the MovieLens 100K `u.data` file (path is relative to the notebook's
# working directory). `data` is reused by the real-data experiment cells below.
data = MovieLens("../dataset/ML-100K/u.data")

In [None]:
# Simulation for neural collaborative filtering (FedNCF) on synthetic data.
#
# Each round draws random item features, two random user embeddings and a random
# binary interaction vector, builds a new NCF recommender, and then scores
# several interaction-recovery attacks against the gradients that the "true"
# user embedding produces. Results are accumulated in `metrics` and summarised
# (AUC / AUC-PR per attack name) by the final print.

set_seed()

num_sim_round = 10  # number of independent simulation rounds
num_features = 64  # width of item features and of the user embeddings
num_data = 1000  # number of synthetic items (and interaction bits) per round
atk_lr = 1e-01  # passed as `lr` to reconstruct_interactions
max_iter = 100000  # passed as `max_iter` to reconstruct_interactions
num_atk = 10  # passed as `num_rounds` to reconstruct_interactions

metrics = Metrics()

for _ in tqdm(range(num_sim_round)):
    # Alternative initialisation (uniform in [-1, 1)), kept for reference:
    # features = torch.rand(num_data, num_features) * 2 - 1
    # user_embedding = torch.rand(num_features) * 2 - 1
    # user_embedding2 = torch.rand(num_features) * 2 - 1
    features = torch.normal(0, 1, (num_data, num_features))
    # `user_embedding` generates the observed gradients; `user_embedding2` is an
    # independent draw used by the attacks that do not estimate the embedding.
    user_embedding = torch.normal(0, 1, (num_features,))
    user_embedding2 = torch.normal(0, 1, (num_features,))

    # Random 0/1 ground truth; redraw until at least one positive exists.
    interactions = torch.randint(0, 2, (num_data,))
    while interactions.sum() == 0:
        interactions = torch.randint(0, 2, (num_data,))

    # Baseline: uniform random scores in [0, 1), rounded to get hard predictions.
    preds_raw = torch.rand((num_data),)
    metrics.update("Random", interactions, preds_raw.round().long(), preds_raw=preds_raw)

    ncf_rec = NeuralCollaborativeFilteringRecommender(num_features, [128, 64, 32])

    # --- FedNCF_simple: attacker substitutes user_embedding2 for the real embedding.
    target = ncf_rec.item_grad(user_embedding, features, interactions.float())
    # scale = max(1 / |mean(target)|, 1): amplifies the gradient when the
    # magnitude of its mean is below 1 and never shrinks it. The same factor is
    # applied to the attacker's candidate gradient inside the lambda.
    scale = max(1.0 / target.mean().abs(), 1.0)
    target = scale * target

    preds_raw, _ = reconstruct_interactions(
        lambda I: scale * ncf_rec.item_grad(user_embedding2, features, I),
        target,
        num_data,
        lr=atk_lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    # Raw outputs are treated as logits: sigmoid then round gives 0/1 predictions.
    preds = preds_raw.sigmoid().round().long()

    metrics.update(
        "FedNCF_simple",
        interactions,
        preds,
        preds_raw=preds_raw,
    )

    # --- FedNCF_private: the user embedding is an unknown passed to the lambda as U.
    # Target and scale are recomputed with the same expressions as above.
    target = ncf_rec.item_grad(user_embedding, features, interactions.float())
    scale = max(1.0 / target.mean().abs(), 1.0)
    target = scale * target

    # With `private_params_size` set, the call returns three values here; the
    # second is compared against the true user embedding below.
    # NOTE(review): presumably the attack jointly optimises U of that size — confirm in attack.py.
    preds_raw, user_embedding_est, _ = reconstruct_interactions(
        lambda I, U: scale * ncf_rec.item_grad(U, features, I),
        target,
        num_data,
        private_params_size=num_features,
        lr=atk_lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()
    
    # Mean squared error between the estimated and the true user embedding.
    embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

    metrics.update(
        "FedNCF_private",
        interactions,
        preds,
        preds_raw=preds_raw,
        extra_data={"est_user_emb_err": embedding_err},
    )

    # --- FedNCF_private2: match the flattened, scaled item gradient concatenated
    # with the output of `feature_grad` (only the item-gradient part is scaled).
    # NOTE(review): `feature_grad` presumably returns a flat gradient w.r.t. the
    # network parameters — confirm in ranker.py.
    item_grad = ncf_rec.item_grad(user_embedding, features, interactions.float()).flatten()
    scale = max(1.0 / item_grad.mean().abs(), 1.0)

    target = torch.cat(
        [
            scale * item_grad,
            ncf_rec.feature_grad(user_embedding, features, interactions.float()),
        ]
    )

    preds_raw, user_embedding_est, _ = reconstruct_interactions(
        lambda I, U: torch.cat(
            [
                scale * ncf_rec.item_grad(U, features, I).flatten(),
                ncf_rec.feature_grad(U, features, I, retain_graph=True),
            ]
        ),
        target,
        num_data,
        private_params_size=num_features,
        lr=atk_lr,
        max_iter=max_iter,
        num_rounds=num_atk,
        return_raw=True,
    )
    preds = preds_raw.sigmoid().round().long()

    embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

    metrics.update(
        "FedNCF_private2",
        interactions,
        preds,
        preds_raw=preds_raw,
        extra_data={"est_user_emb_err": embedding_err},
    )

    # --- FedNCF_IMIA: baseline attack from `interaction_mia_fedrec`, run on the
    # unscaled item gradient. `select_ratio` is the true fraction of positive
    # interactions. Only hard predictions are recorded (no `preds_raw`).
    target = ncf_rec.item_grad(user_embedding, features, interactions.float())

    preds = interaction_mia_fedrec(
        lambda I: ncf_rec.item_grad(user_embedding2, features, I.float()),
        target,
        num_data,
        select_ratio=interactions.float().mean(),
    )

    metrics.update(
        "FedNCF_IMIA",
        interactions,
        preds,
    )

# Per-attack summary statistics of AUC and AUC-PR across all recorded rows.
print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())

In [None]:
# Collaborative Filtering + DP
#
# For every MovieLens user, compute the gradients a client would upload for
# FCF (CollaborativeFilteringRecommender) and FNCF
# (NeuralCollaborativeFilteringRecommender), pass them through
# `apply_gaussian_mechanism` for each epsilon, and evaluate how well
# `reconstruct_interactions` recovers the user's interaction vector.

set_seed()

user_ids = data.get_all_user_ids()
item_ids = data.get_all_item_ids()
# Map raw dataset ids to contiguous row indices of the embedding tables.
user_id_to_idx = {id: idx for idx, id in enumerate(user_ids)}
item_id_to_idx = {id: idx for idx, id in enumerate(item_ids)}
num_users = len(user_ids)
num_items = len(item_ids)
embedding_dim = 64
neg_sample_ratio = 4  # negatives sampled per positive interaction

num_sim_round = 1
atk_lr = 1e-01  # passed as `lr` to reconstruct_interactions
max_iter = 1000  # passed as `max_iter` to reconstruct_interactions
num_atk = 5  # passed as `num_rounds` to reconstruct_interactions

# Privacy budgets to sweep; only math.inf is currently enabled.
# NOTE(review): math.inf presumably means "no noise" — confirm in apply_gaussian_mechanism.
# epsilons = [1.0, 10.0, 100.0, math.inf]
epsilons = [math.inf]
delta = 1e-08

metrics = Metrics()

# Unused sketch of multi-step local training, kept for reference:
# def train_batch(model, user_embedding, item_embedding, interactions, num_batch, local_lr=0.1):
#     item_embedding = item_embedding.clone()
#     for _ in range(num_batch):
#         item_grad = model.item_grad(user_embedding, item_embedding, interactions)
#         item_embedding = item_embedding - local_lr * item_grad
#     return item_embedding

for _ in tqdm(range(num_sim_round)):
    # New randomly initialised embedding tables and models per simulation round.
    user_embeddings = nn.Embedding(num_users, embedding_dim)
    item_embeddings = nn.Embedding(num_items, embedding_dim)

    fcf = CollaborativeFilteringRecommender()
    fncf = NeuralCollaborativeFilteringRecommender(embedding_dim, [16, 8])

    for user_id in tqdm(user_ids):
        # Set up training data
        interacted_items = data.get_item_ids_for_users([user_id])[0]
        non_interacted_items = data.get_non_interacted_item_ids_for_users([user_id])[0]

        # Sample up to `neg_sample_ratio` negatives per positive, capped by the
        # number of non-interacted items available.
        num_pos = len(interacted_items)
        sampled_non_interacted_items = random.sample(
            non_interacted_items,
            min(num_pos * neg_sample_ratio, len(non_interacted_items)),
        )
        num_neg = len(sampled_non_interacted_items)
        num_data = num_pos + num_neg

        # Detached 1-D copy of this user's embedding row.
        user_embedding = (
            user_embeddings(torch.LongTensor([user_id_to_idx[user_id]]))
            .detach()
            .view(-1)
        )
        # Item rows ordered as [positives..., sampled negatives...], detached
        # from the embedding table and then marked as requiring grad.
        item_embedding = item_embeddings(
            torch.cat(
                [
                    torch.LongTensor([item_id_to_idx[id] for id in interacted_items]),
                    torch.LongTensor(
                        [item_id_to_idx[id] for id in sampled_non_interacted_items]
                    ),
                ]
            )
        ).detach()
        item_embedding.requires_grad_()
        # Ground truth in the same order: ones for positives, zeros for negatives.
        interactions = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
        # Stand-in user embedding (uniform in [0, 1)) for attacks that do not
        # estimate the embedding themselves.
        random_user_emb = torch.rand(embedding_dim)

        for epsilon in epsilons:
            # FCF Simple
            target = fcf.item_grad(user_embedding, item_embedding, interactions)
            target = apply_gaussian_mechanism(target, epsilon, delta, sensitivity=100)

            preds_raw, _ = reconstruct_interactions(
                lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            # Raw outputs are treated as logits: sigmoid then round gives 0/1 predictions.
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FCF_simple_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FCF jointly estimate user embedding
            # (reuses the same noised FCF target; the lambda receives the
            # candidate user embedding as U)
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fcf.item_grad(U, item_embedding, I),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            # Mean squared error between estimated and true user embedding.
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FCF_joint_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF setup
            target = fncf.item_grad(user_embedding, item_embedding, interactions)
            target = apply_gaussian_mechanism(target, epsilon, delta, sensitivity=0.005)
            # Loss = mean row-wise pairwise distance, multiplied by
            # max(1, 100 / mean row norm of the target) so that targets with
            # small row norms still produce a sizeable loss.
            mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale

            # FNCF simple
            preds_raw, _ = reconstruct_interactions(
                lambda I: fncf.item_grad(random_user_emb, item_embedding, I, create_graph=True),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FNCF jointly estimate user embedding
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fncf.item_grad(U, item_embedding, I, create_graph=True),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF jointly estimate user embedding with neural net params
            # The attack target becomes the pair (item gradient, feature_grad
            # output); the loss adds an MSE term on the second element.
            # NOTE(review): `feature_grad` presumably returns gradients w.r.t.
            # the network parameters — confirm in ranker.py.
            feature_grad = fncf.feature_grad(user_embedding, item_embedding, interactions)
            feature_grad = apply_gaussian_mechanism(feature_grad, epsilon, delta, sensitivity=0.7)

            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: (
                    fncf.item_grad(U, item_embedding, I, create_graph=True),
                    fncf.feature_grad(U, item_embedding, I, create_graph=True),
                ),
                (target, feature_grad),
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_model_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF simple with neural net params
            # (same paired target, but with the fixed random user embedding)
            preds_raw, _ = reconstruct_interactions(
                lambda I: (
                    fncf.item_grad(random_user_emb, item_embedding, I, create_graph=True),
                    fncf.feature_grad(random_user_emb, item_embedding, I, create_graph=True),
                ),
                (target, feature_grad),
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_model_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # Random guess
            # Uniform scores in [0, 1), rounded to get hard predictions.
            preds_raw = torch.rand(num_data)
            metrics.update(
                f"Random_emb_{embedding_dim}_eps_{epsilon}",
                interactions,
                preds_raw.round().long(),
                preds_raw=preds_raw,
            )

            # IMIA FCF (disabled). If re-enabled, note the metric name must be
            # an f-string for the placeholders to be filled in.
            # target = fcf.item_grad(user_embedding, item_embedding, interactions)
            # preds = interaction_mia_fedrec(
            #     lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
            #     target,
            #     num_data,
            #     select_ratio=interactions.mean(),
            # )

            # metrics.update(
            #     f"FCF_IMIA_emb_{embedding_dim}_eps_{epsilon}",
            #     interactions,
            #     preds,
            # )

# Saving is disabled in this cell; only the summary table is printed.
# metrics.save("../output/rec_metrics.csv")
print(metrics.df[["name", "auc", "auc-pr"]].groupby("name").describe().to_string())

In [None]:
# Collaborative Filtering + Pruning
#
# For every MovieLens user, compute the gradients a client would upload for
# FCF (CollaborativeFilteringRecommender) and FNCF
# (NeuralCollaborativeFilteringRecommender), zero out the smallest-magnitude
# entries for each pruning percentage, and evaluate how well
# `reconstruct_interactions` recovers the user's interaction vector from the
# pruned gradients. Results are written to a CSV file at the end.

set_seed()

user_ids = data.get_all_user_ids()
item_ids = data.get_all_item_ids()
# Map raw dataset ids to contiguous row indices of the embedding tables.
user_id_to_idx = {id: idx for idx, id in enumerate(user_ids)}
item_id_to_idx = {id: idx for idx, id in enumerate(item_ids)}
num_users = len(user_ids)
num_items = len(item_ids)
embedding_dim = 64
neg_sample_ratio = 4  # negatives sampled per positive interaction

num_sim_round = 1
atk_lr = 1e-01  # passed as `lr` to reconstruct_interactions
max_iter = 1000  # passed as `max_iter` to reconstruct_interactions
num_atk = 5  # passed as `num_rounds` to reconstruct_interactions

# Quantile levels of |gradient| below which entries are zeroed.
prune_pct = [0.1, 0.3, 0.5, 0.7, 0.9, 0.99]

metrics = Metrics()


def prune_by_magnitude(grad, pct):
    """Zero every entry of ``grad`` whose magnitude is below the ``pct`` quantile.

    Entries whose absolute value is greater than or equal to the quantile are
    kept unchanged. The same rule is used for item gradients and for
    ``feature_grad`` so that all pruned targets are built consistently.

    Args:
        grad: Floating-point tensor to prune.
        pct: Quantile level in [0, 1] computed over all entries of ``|grad|``.

    Returns:
        A tensor shaped like ``grad`` with the small-magnitude entries set to 0.
    """
    magnitude = grad.abs()
    return grad * (magnitude >= magnitude.quantile(pct))


for _ in tqdm(range(num_sim_round)):
    # New randomly initialised embedding tables and models per simulation round.
    user_embeddings = nn.Embedding(num_users, embedding_dim)
    item_embeddings = nn.Embedding(num_items, embedding_dim)

    fcf = CollaborativeFilteringRecommender()
    fncf = NeuralCollaborativeFilteringRecommender(embedding_dim, [16, 8])

    for user_id in tqdm(user_ids):
        # Set up training data
        interacted_items = data.get_item_ids_for_users([user_id])[0]
        non_interacted_items = data.get_non_interacted_item_ids_for_users([user_id])[0]

        # Sample up to `neg_sample_ratio` negatives per positive, capped by the
        # number of non-interacted items available.
        num_pos = len(interacted_items)
        sampled_non_interacted_items = random.sample(
            non_interacted_items,
            min(num_pos * neg_sample_ratio, len(non_interacted_items)),
        )
        num_neg = len(sampled_non_interacted_items)
        num_data = num_pos + num_neg

        # Detached 1-D copy of this user's embedding row.
        user_embedding = (
            user_embeddings(torch.LongTensor([user_id_to_idx[user_id]]))
            .detach()
            .view(-1)
        )
        # Item rows ordered as [positives..., sampled negatives...], detached
        # from the embedding table and then marked as requiring grad.
        item_embedding = item_embeddings(
            torch.cat(
                [
                    torch.LongTensor([item_id_to_idx[id] for id in interacted_items]),
                    torch.LongTensor(
                        [item_id_to_idx[id] for id in sampled_non_interacted_items]
                    ),
                ]
            )
        ).detach()
        item_embedding.requires_grad_()
        # Ground truth in the same order: ones for positives, zeros for negatives.
        interactions = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
        # Stand-in user embedding (uniform in [0, 1)) for attacks that do not
        # estimate the embedding themselves.
        random_user_emb = torch.rand(embedding_dim)

        for pct in prune_pct:
            # FCF Simple
            target = fcf.item_grad(user_embedding, item_embedding, interactions)
            target = prune_by_magnitude(target, pct)

            preds_raw, _ = reconstruct_interactions(
                lambda I: fcf.item_grad(random_user_emb, item_embedding, I),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            # Raw outputs are treated as logits: sigmoid then round gives 0/1 predictions.
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FCF_simple_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FCF jointly estimate user embedding
            # (reuses the same pruned FCF target; the lambda receives the
            # candidate user embedding as U)
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fcf.item_grad(U, item_embedding, I),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            # Mean squared error between estimated and true user embedding.
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FCF_joint_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF setup
            target = fncf.item_grad(user_embedding, item_embedding, interactions)
            target = prune_by_magnitude(target, pct)

            # Loss = mean row-wise pairwise distance, multiplied by
            # max(1, 100 / mean row norm of the pruned target) so that targets
            # with small row norms still produce a sizeable loss.
            mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale

            # FNCF simple
            preds_raw, _ = reconstruct_interactions(
                lambda I: fncf.item_grad(random_user_emb, item_embedding, I, create_graph=True),
                target,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

            # FNCF jointly estimate user embedding
            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: fncf.item_grad(U, item_embedding, I, create_graph=True),
                target,
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=custom_loss,
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF jointly estimate user embedding with neural net params
            # The attack target becomes the pair (item gradient, feature_grad
            # output); the loss adds an MSE term on the second element.
            # NOTE(review): `feature_grad` presumably returns gradients w.r.t.
            # the network parameters — confirm in ranker.py.
            feature_grad = fncf.feature_grad(user_embedding, item_embedding, interactions)
            # Pruned with the same ">= quantile" rule as the item gradients
            # (this site previously used a strict ">").
            feature_grad = prune_by_magnitude(feature_grad, pct)

            preds_raw, user_embedding_est, _ = reconstruct_interactions(
                lambda I, U: (
                    fncf.item_grad(U, item_embedding, I, create_graph=True),
                    fncf.feature_grad(U, item_embedding, I, create_graph=True),
                ),
                (target, feature_grad),
                num_data,
                private_params_size=embedding_dim,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()
            embedding_err = F.mse_loss(user_embedding_est, user_embedding).item()

            metrics.update(
                f"FNCF_joint_model_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
                extra_data={"est_user_emb_err": embedding_err},
            )

            # FNCF simple with neural net params
            # (same paired target, but with the fixed random user embedding)
            preds_raw, _ = reconstruct_interactions(
                lambda I: (
                    fncf.item_grad(random_user_emb, item_embedding, I, create_graph=True),
                    fncf.feature_grad(random_user_emb, item_embedding, I, create_graph=True),
                ),
                (target, feature_grad),
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_func=lambda t1, t2: custom_loss(t1[0], t2[0]) + F.mse_loss(t1[1], t2[1]),
                return_raw=True,
            )
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"FNCF_simple_model_emb_{embedding_dim}_prune_{pct}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

# Persist all recorded rows for the pruning experiment.
metrics.save("../output/rec_ML100K_pruned_metrics.csv")