## Setup

In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import tqdm
from torchmetrics import AUROC
from torch.utils.data import DataLoader

from torchfm.dataset.criteo import CriteoDataset
from torchfm.dataset.movielens import MovieLens1MDataset, MovieLens20MDataset

from torchfm.model.dcn import DeepCrossNetworkModel
from torchfm.model.dfm import DeepFactorizationMachineModel
from torchfm.model.fm import FactorizationMachineModel
from torchfm.model.wd import WideAndDeepModel
from torchfm.model.afm import AttentionalFactorizationMachineModel

In [None]:
# Configure device, batch_size and metric.
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# so the later `.to(device)` calls do not fail on CPU-only machines.
# NOTE(review): confirm that pytei's Injector also supports a CPU device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Number of samples per evaluation batch; read by testset_prepare().
batch_size = 512
# Binary-classification AUROC metric shared by every call to test().
auroc = AUROC(task='binary')

In [None]:
import sys

# Make the pytei sources importable for the `from pytei import Injector` cell.
# The membership check keeps the cell idempotent: re-running it in the notebook
# no longer appends a duplicate entry to sys.path each time.
_pytei_dir = '../../pytei'
if _pytei_dir not in sys.path:
    sys.path.append(_pytei_dir)

## Model and Dataset

- Using `torchfm` to implement model and dataset interfaces.
- Make sure to download the `Movielens1M`, `Movielens20M` and `Criteo` datasets and place them at the paths listed in `dataset_paths`:
   - MovieLens-1M: https://grouplens.org/datasets/movielens/1m/
   - MovieLens-20M: https://grouplens.org/datasets/movielens/20m/
   - Criteo DAC: https://ailab.criteo.com/ressources/ 
- Before running this notebook for the first time, train the DRS models. For example, the following bash command trains `DCN` on the `MovieLens-1M` dataset:
    ```bash
    python3 ./train.py --dataset_name movielens1M --dataset_path ./ml-1m/ratings.dat --model_name dcn --save_dir ./chkpt
    ```

In [None]:
def get_dataset(name, path):
    """Build the torchfm dataset selected by ``name`` from the file at ``path``.

    Args:
        name: Dataset identifier: 'movielens1M', 'movielens20M' or 'criteo'.
        path: Location of the raw data; passed unchanged to the dataset
            constructor.

    Returns:
        A MovieLens1MDataset, MovieLens20MDataset or CriteoDataset instance.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == 'movielens1M':
        return MovieLens1MDataset(path)
    if name == 'movielens20M':
        return MovieLens20MDataset(path)
    if name == 'criteo':
        return CriteoDataset(path)
    # No branch matched: report the offending name to the caller.
    raise ValueError('unknown dataset name: ' + name)

def get_model(name, dataset):
    """Construct an untrained torchfm model sized for ``dataset``.

    Args:
        name: Model identifier: 'fm', 'wd', 'dcn', 'dfm' or 'afm'.
        dataset: Object exposing ``field_dims``, which is forwarded to the
            model constructor as its first argument.

    Returns:
        A new instance of the selected torchfm model class.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    field_dims = dataset.field_dims
    if name == 'fm':
        return FactorizationMachineModel(field_dims, embed_dim=16)
    if name == 'wd':
        return WideAndDeepModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    if name == 'dcn':
        return DeepCrossNetworkModel(field_dims, embed_dim=16, num_layers=3, mlp_dims=(16, 16), dropout=0.2)
    if name == 'dfm':
        return DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    if name == 'afm':
        # Upstream torchfm's AttentionalFactorizationMachineModel takes
        # (field_dims, embed_dim, attn_size, dropouts). The LNN_dim/mlp_dims
        # keywords used here before are AdaptiveFactorizationNetwork's
        # arguments and make this constructor raise TypeError.
        # NOTE(review): values follow the upstream torchfm example; confirm
        # they match the hyper-parameters used by train.py.
        return AttentionalFactorizationMachineModel(field_dims, embed_dim=16, attn_size=16, dropouts=(0.2, 0.2))
    raise ValueError('unknown model name: ' + name)

def test(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and return its AUROC.

    Args:
        model: Module called as ``model(fields)`` to produce one score per
            sample; it is switched to eval mode before evaluation.
        data_loader: Iterable yielding ``(fields, target)`` tensor pairs.
        device: Device the input fields are moved to before the forward pass.

    Returns:
        The value produced by the module-level ``auroc`` metric for all
        predictions against all targets.
    """
    model.eval()
    target_batches, predict_batches = [], []
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            y = model(fields.to(device))
            # Keep whole batches (flattened to 1-D) instead of extending a
            # Python list element by element; they are concatenated once below.
            # Targets are only needed on the CPU, so they skip the device trip.
            target_batches.append(target.cpu().reshape(-1))
            predict_batches.append(y.cpu().reshape(-1))
    if predict_batches:
        targets = torch.cat(target_batches)
        predicts = torch.cat(predict_batches).float()
    else:
        # torch.cat rejects an empty list; mirror the empty tensors the
        # previous list-based implementation produced for an empty loader.
        targets = torch.empty(0)
        predicts = torch.empty(0)
    # torchmetrics documents binary targets as integer labels in {0, 1}.
    score = auroc(predicts, targets.long())
    # Calling the metric also accumulates its inputs into the metric's global
    # state; clear it so repeated evaluations in a sweep do not grow memory.
    auroc.reset()
    return score

def testset_prepare(dataset_name, dataset_path, seed=None):
    """Load a dataset and return a DataLoader over its held-out test split.

    The dataset is split 80% / 10% / remainder into train / validation / test
    parts with ``torch.utils.data.random_split``; only the test part is used.

    Args:
        dataset_name: Identifier understood by ``get_dataset``.
        dataset_path: Path to the raw dataset file.
        seed: Optional integer seed for the split. ``None`` (the default)
            keeps the original behaviour of drawing from PyTorch's global
            random generator.

    Returns:
        A DataLoader over the test split using the module-level
        ``batch_size`` and no worker processes.
    """
    dataset = get_dataset(dataset_name, dataset_path)
    train_length = int(len(dataset) * 0.8)
    valid_length = int(len(dataset) * 0.1)
    test_length = len(dataset) - train_length - valid_length
    lengths = (train_length, valid_length, test_length)
    # NOTE(review): the split is drawn afresh here, so it only matches the
    # partition used during training if train.py seeds its split the same
    # way; confirm against train.py.
    if seed is None:
        splits = torch.utils.data.random_split(dataset, lengths)
    else:
        generator = torch.Generator().manual_seed(seed)
        splits = torch.utils.data.random_split(dataset, lengths, generator=generator)
    test_dataset = splits[2]
    return DataLoader(test_dataset, batch_size=batch_size, num_workers=0)

## Error Injection

- Sweep over all combinations of datasets, models and bit error rates (BERs).
- Modify `./targets` to change targets for error injection / protection.

In [None]:
from pytei import Injector

In [None]:
# Sweep configuration. A checkpoint path is composed as
# base_dir + <model name> + <dataset suffix>.
base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['./ml-1m/ratings.dat', './MovieLens20M/rating.csv', './criteo-dac/train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Injection probabilities, reversed so the sweep runs from smallest to largest.
bers = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9][::-1]
# Ten repeated trials per (model, dataset, BER) combination.
folds = list(range(10))

# AUROC per trial, indexed as [model, dataset, ber, fold].
results = torch.zeros((len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, (dataset_name, dataset_path, dataset) in enumerate(
            zip(dataset_names, dataset_paths, datasets)):
        # One test loader per dataset, shared by every model/BER/fold below.
        test_data_loader = testset_prepare(dataset_name, dataset_path)
        for model_path_i, model_path in enumerate(model_paths):
            checkpoint = base_dir + model_path + dataset
            for ber_i, ber in enumerate(bers):
                for fold_i, _fold in enumerate(folds):
                    # Reload from disk so each trial starts from the saved checkpoint.
                    model = torch.load(checkpoint)
                    model = model.float().eval().to(device)
                    injector = Injector('./targets', p=ber, device=device, verbose=False)
                    injector.inject(model)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1

## Mitigation

- To add a custom mitigation, implement it as a `@classmethod` in `depytei.py`.

### 1. Activation Filtering

In [None]:
# Sweep configuration. A checkpoint path is composed as
# base_dir + <model name> + <dataset suffix>.
base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['./ml-1m/ratings.dat', './MovieLens20M/rating.csv', './criteo-dac/train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Injection probabilities, reversed so the sweep runs from smallest to largest.
bers = [1e-2, 1e-3, 1e-4, 1e-5][::-1]
# Ten repeated trials per (model, dataset, BER) combination.
folds = list(range(10))

# AUROC per trial, indexed as [model, dataset, ber, fold].
results = torch.zeros((len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, (dataset_name, dataset_path, dataset) in enumerate(
            zip(dataset_names, dataset_paths, datasets)):
        # One test loader per dataset, shared by every model/BER/fold below.
        test_data_loader = testset_prepare(dataset_name, dataset_path)
        for model_path_i, model_path in enumerate(model_paths):
            checkpoint = base_dir + model_path + dataset
            for ber_i, ber in enumerate(bers):
                for fold_i, _fold in enumerate(folds):
                    # Reload from disk so each trial starts from the saved checkpoint.
                    model = torch.load(checkpoint)
                    model = model.float().eval().to(device)
                    # Same injection as the plain sweep, with the 'clip' mitigation enabled.
                    injector = Injector('./targets', p=ber, device=device, verbose=False, mitigation='clip')
                    injector.inject(model, use_mitigation=True)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1

### 2. Selective Bit Protection

In [None]:
# Sweep configuration. A checkpoint path is composed as
# base_dir + <model name> + <dataset suffix>.
base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['./ml-1m/ratings.dat', './MovieLens20M/rating.csv', './criteo-dac/train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Injection probabilities, reversed so the sweep runs from smallest to largest.
bers = [1e-2, 1e-3, 1e-4, 1e-5][::-1]
# Ten repeated trials per (model, dataset, BER) combination.
folds = list(range(10))

# AUROC per trial, indexed as [model, dataset, ber, fold].
results = torch.zeros((len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, (dataset_name, dataset_path, dataset) in enumerate(
            zip(dataset_names, dataset_paths, datasets)):
        # One test loader per dataset, shared by every model/BER/fold below.
        test_data_loader = testset_prepare(dataset_name, dataset_path)
        for model_path_i, model_path in enumerate(model_paths):
            checkpoint = base_dir + model_path + dataset
            for ber_i, ber in enumerate(bers):
                for fold_i, _fold in enumerate(folds):
                    # Reload from disk so each trial starts from the saved checkpoint.
                    model = torch.load(checkpoint)
                    model = model.float().eval().to(device)
                    # Same injection as the plain sweep, with the 'SBP' mitigation enabled.
                    injector = Injector('./targets', p=ber, device=device, verbose=False, mitigation='SBP')
                    injector.inject(model, use_mitigation=True)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1