In [2]:
import math
import random
import warnings
from pathlib import Path

import functorch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from tqdm.notebook import tqdm

from attack import (
    reconstruct_interactions,
)
from dataset import (
    MovieLens,
    Steam200K,
)
from ranker import (
    NeuralCollaborativeFilteringRecommender,
)
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)
warnings.filterwarnings("ignore", message='.*make_functional.*')

In [7]:
# Load dataset. Set DATASET to either "ML-100K" (MovieLens) or "STEAM-200K".
DATASET = "ML-100K"
# Scaled down for artifact evaluation: only the first MAX_USERS users are attacked.
# Set MAX_USERS = None to run on the entire dataset.
MAX_USERS = 30

if DATASET == "ML-100K":
    data = MovieLens("../dataset/ML-100K/u.data")
elif DATASET == "STEAM-200K":
    data = Steam200K("../dataset/STEAM-200K/steam-200k.csv")
else:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected 'ML-100K' or 'STEAM-200K'")

user_ids = data.get_all_user_ids()
item_ids = data.get_all_item_ids()
# Index maps and sizes are built from ALL users/items (before truncation below),
# so embedding tables sized with num_users / num_items cover the full dataset.
user_id_to_idx = {uid: idx for idx, uid in enumerate(user_ids)}
item_id_to_idx = {iid: idx for idx, iid in enumerate(item_ids)}
num_users = len(user_ids)
num_items = len(item_ids)

# Restrict which users are attacked; slicing with None keeps every user.
user_ids = user_ids[:MAX_USERS]

In [None]:
# Main results in Table V and VI of Section VI.A as well as DP results in Table IX of Section VII.A

def set_seed(seed=2023):
    """Seed the global RNGs of Python's `random`, NumPy and PyTorch.

    Args:
        seed: integer seed applied to all three generators (default 2023).
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


set_seed()

# --- Model: randomly initialised embeddings + FNCF ranker -------------------
# nn.Embedding draws its initial weights from the torch RNG (seeded by set_seed
# above), and the ranker's layers likely do too, so keep these in this order.
embedding_dim = 64
neg_sample_ratio = 4  # non-interacted items sampled per interacted item
user_embeddings = nn.Embedding(num_users, embedding_dim)
item_embeddings = nn.Embedding(num_items, embedding_dim)
fncf = NeuralCollaborativeFilteringRecommender(embedding_dim, [128, 64, 32])

# --- Local (client-side) training -------------------------------------------
local_epoch = 20
local_lr = 1e-3
reg_factors = [0.0, 1.0]  # IMIA defense strength mu; 0.0 disables the defense

# --- Differential privacy (Gaussian mechanism) ------------------------------
epsilons = [1.0, 20.0, 100.0, 500.0, math.inf]
delta = 1e-8
sensitivity = 0.1

# --- Interaction-reconstruction attack --------------------------------------
atk_lr = 0.1
max_iter = 1000
num_atk = 1  # forwarded as num_rounds to reconstruct_interactions

metrics = Metrics()

# Local training algorithm
def train_fncf_functional(model, user_embedding, item_embeddings, interactions, num_epoch, lr, reg_factor, eps_root=1e-08):
    """Train an FNCF model locally for one user with differentiable Adam.

    The module is turned into a stateless function with ``functorch.make_functional``.
    It is then optimised with ``torchopt.FuncOptimizer``, which returns new parameter
    tensors on every step. The whole training run therefore stays on the autograd
    graph, so the returned update can be differentiated w.r.t. the inputs. The
    reconstruction attack in this notebook does this, differentiating w.r.t. the
    ``interactions`` argument.

    Args:
        model: FNCF ``nn.Module``. A functional copy of its parameters is trained.
        user_embedding: the user's embedding (a 1-D tensor as called in this
            notebook). It is optimised jointly with the model.
        item_embeddings: embeddings of the sampled items, one row per entry of
            ``interactions``. They are optimised jointly. Note: this is a plain
            tensor that shadows the global ``nn.Embedding`` of the same name.
        interactions: float targets for ``binary_cross_entropy``. These are the
            ground-truth 0/1 labels, or the attacker's candidate vector.
        num_epoch: number of full-batch Adam steps.
        lr: Adam learning rate.
        reg_factor: weight ``mu`` of the IMIA-defense L1 term. The term penalises
            drift of the item embeddings from their initial values; 0.0 disables it.
        eps_root: forwarded to ``torchopt.adam``. Per the torchopt docs, a positive
            value is needed when computing gradients through Adam; otherwise the
            square root in its denominator can yield infinite gradients. Defaults to
            the previously hard-coded 1e-08.

    Returns:
        A pair ``(item_update, model_params)``:
        - ``item_update`` is the initial minus the trained item embeddings.
        - ``model_params`` is the tuple of trained model parameters.
    """
    # Drop stale gradients on the (reused) leaf tensors in case an earlier
    # backward pass, e.g. inside the attack, populated them.
    user_embedding.grad = None
    item_embeddings.grad = None
    func_model, model_params = functorch.make_functional(model)
    opt_params = (user_embedding, item_embeddings, *model_params)
    # use_accelerated_op=True would be faster, but it was observed to prevent
    # reconstruction (cause not investigated; possibly a torchopt bug).
    optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=lr, eps_root=eps_root))
    for _ in range(num_epoch):
        preds = func_model(opt_params[2:], opt_params[0], opt_params[1])
        # IMIA defense: mean absolute difference between current and initial item embeddings.
        reg_loss = reg_factor * F.l1_loss(opt_params[1], item_embeddings)
        loss = F.binary_cross_entropy(preds.view(-1), interactions) + reg_loss
        opt_params = optimizer.step(loss, opt_params)
    return item_embeddings - opt_params[1], opt_params[2:]

# Simulate attack on each user
for user_id in tqdm(user_ids):
    # Simulate the attack on each user in user_ids. For each user, the attacker
    # tries to recover the interaction labels of a set of sampled items. Its only
    # input is the item-embedding update from local training, which may be
    # perturbed by the Gaussian mechanism.
    #
    # Sample items and interactions: every item the user interacted with, plus up
    # to neg_sample_ratio non-interacted items per interacted item, drawn without
    # replacement by random.sample.
    interacted_items = data.get_item_ids_for_users([user_id])[0]
    non_interacted_items = data.get_non_interacted_item_ids_for_users([user_id])[0]

    num_pos = len(interacted_items)
    # NOTE(review): random.sample() rejects sets on Python >= 3.11; this assumes
    # non_interacted_items is a sequence (e.g. a list). Confirm in dataset.py.
    sampled_non_interacted_items = random.sample(
        non_interacted_items,
        min(num_pos * neg_sample_ratio, len(non_interacted_items)),
    )
    num_neg = len(sampled_non_interacted_items)
    num_data = num_pos + num_neg

    # This user's embedding, flattened to a 1-D vector of length embedding_dim.
    # detach() makes it a fresh leaf tensor, so training never back-propagates
    # into the nn.Embedding table.
    user_embedding = (
        user_embeddings(torch.LongTensor([user_id_to_idx[user_id]]))
        .detach()
        .view(-1)
    )
    # One embedding row per sampled item. Interacted items come first, then the
    # sampled non-interacted ones, matching the order of `interactions` below.
    item_embedding = item_embeddings(
        torch.cat(
            [
                torch.LongTensor([item_id_to_idx[id] for id in interacted_items]),
                torch.LongTensor(
                    [item_id_to_idx[id] for id in sampled_non_interacted_items]
                ),
            ]
        )
    ).detach()
    # Local training optimises both embeddings, so they must require grad.
    user_embedding.requires_grad_()
    item_embedding.requires_grad_()
    interactions = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)]) # Ground truth interactions: 1 = interacted, 0 = sampled non-interacted
    random_user_emb = torch.rand(embedding_dim, requires_grad=True) # The server doesn't know the real user embedding, so the attacker uses a uniform-random one

    # Evaluate every (DP epsilon, IMIA strength) combination for this user.
    for epsilon in epsilons:
        for reg_factor in reg_factors:
            # Local training. `target` is the item-embedding update (initial minus
            # trained embeddings); the trained model parameters are unused here.
            # NOTE(review): training does not depend on epsilon, so it is repeated
            # with identical inputs for every epsilon.
            target, target_model_params = train_fncf_functional(fncf, user_embedding, item_embedding, interactions, local_epoch, local_lr, reg_factor)
            # Presumably adds Gaussian noise calibrated to (epsilon, delta, sensitivity); see utils.apply_gaussian_mechanism.
            target = apply_gaussian_mechanism(target.detach(), epsilon, delta, sensitivity=sensitivity)

            # Attack
            # Mean L2 norm of the target's rows (one row per sampled item).
            mean_norm = torch.linalg.vector_norm(target, dim=1).mean()
            # Scale the matching loss up by 100 / mean_norm when mean_norm < 100; never scale it down.
            norm_scale = max(torch.Tensor([1.0]), torch.Tensor([1e+02]) / mean_norm)
            # Mean Euclidean distance between corresponding rows, times norm_scale.
            custom_loss = lambda e1, e2: F.pairwise_distance(e1, e2).mean() * norm_scale
            # The attacker replays local training for a candidate interaction vector I.
            # It uses random_user_emb and the known item embeddings, and (presumably)
            # optimises I so the simulated update matches the observed one. Both
            # sides are divided by local_lr. reconstruct_interactions lives in attack.py;
            # its exact semantics are inferred here from its arguments.
            preds_raw, _ = reconstruct_interactions(
                lambda I: train_fncf_functional(fncf, random_user_emb, item_embedding, I, local_epoch, local_lr, reg_factor)[0] / local_lr,
                target / local_lr,
                num_data,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                loss_fn=custom_loss,
                return_raw=True,
            )
            # preds_raw is treated as logits: sigmoid, then round to hard 0/1 labels.
            preds = preds_raw.sigmoid().round().long()

            # Update attack performance under a key encoding epsilon and the IMIA factor.
            metrics.update(
                f"FNCF_eps_{epsilon}_IMIA_{reg_factor}",
                interactions,
                preds,
                preds_raw=preds_raw,
            )

# Persist the per-configuration attack metrics and print the AUC/F1 summary.
# Create the output directory first (in case Metrics.save does not), so that a
# missing ../output/ cannot discard results after the long attack loop finishes.
rec_metrics_path = "../output/rec_metrics.csv"
Path(rec_metrics_path).parent.mkdir(parents=True, exist_ok=True)
metrics.save(rec_metrics_path)
metrics.print_summary(["auc", "f1"])