## Setup

In [None]:
import torch
import tqdm
from torchmetrics import AUROC
from torch.utils.data import DataLoader

from torchfm.dataset.avazu import AvazuDataset
from torchfm.dataset.criteo import CriteoDataset
from torchfm.dataset.movielens import MovieLens1MDataset, MovieLens20MDataset

from torchfm.model.dcn import DeepCrossNetworkModel
from torchfm.model.dfm import DeepFactorizationMachineModel
from torchfm.model.fm import FactorizationMachineModel
from torchfm.model.wd import WideAndDeepModel
from torchfm.model.afm import AttentionalFactorizationMachineModel

In [None]:
# Evaluate on the first GPU when one is available; otherwise fall back to the
# CPU so the notebook can still run (slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Mini-batch size; not referenced in the cells shown here (presumably used when
# building data loaders) — TODO confirm.
batch_size = 512
# Shared binary AUROC metric object called by test().
auroc = AUROC(task="binary")

## Model and Dataset

In [None]:
def get_dataset(name, path):
    """Instantiate the dataset registered under ``name``, reading it from ``path``.

    Recognised names are 'movielens1M', 'movielens20M', 'criteo' and 'avazu'.

    Raises:
        ValueError: if ``name`` is none of the recognised names.
    """
    if name == 'movielens1M':
        dataset_cls = MovieLens1MDataset
    elif name == 'movielens20M':
        dataset_cls = MovieLens20MDataset
    elif name == 'criteo':
        dataset_cls = CriteoDataset
    elif name == 'avazu':
        dataset_cls = AvazuDataset
    else:
        raise ValueError('unknown dataset name: ' + name)
    return dataset_cls(path)

def get_model(name, dataset):
    """Build a fresh, untrained model of the requested architecture.

    Args:
        name: one of 'fm', 'wd', 'dcn', 'dfm', 'afm'.
        dataset: object exposing ``field_dims``, which is passed to every
            model constructor.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: if ``name`` is not one of the names above.
    """
    field_dims = dataset.field_dims
    if name == 'fm':
        return FactorizationMachineModel(field_dims, embed_dim=16)
    elif name == 'wd':
        return WideAndDeepModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    elif name == 'dcn':
        return DeepCrossNetworkModel(field_dims, embed_dim=16, num_layers=3, mlp_dims=(16, 16), dropout=0.2)
    elif name == 'dfm':
        return DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2)
    elif name == 'afm':
        # LNN_dim / mlp_dims are AdaptiveFactorizationNetwork arguments; the
        # pytorch-fm AttentionalFactorizationMachineModel constructor takes
        # (field_dims, embed_dim, attn_size, dropouts) and rejects them.
        # NOTE(review): attn_size/dropouts below are the pytorch-fm example
        # defaults — confirm they match the hyperparameters of the trained
        # afm_* checkpoints.
        return AttentionalFactorizationMachineModel(field_dims, embed_dim=16, attn_size=16, dropouts=(0.2, 0.2))
    else:
        raise ValueError('unknown model name: ' + name)


class EarlyStopper(object):
    """Patience-based early stopping that checkpoints the best model.

    ``is_continuable`` is called once per evaluation with the current metric.
    A strict improvement over the best value seen so far resets the patience
    counter and writes ``model`` to ``save_path`` with ``torch.save``.
    """

    def __init__(self, num_trials, save_path):
        self.save_path = save_path
        # Patience budget: training stops on the non-improving call that would
        # bring the counter up to this value.
        self.num_trials = num_trials
        self.best_accuracy = 0  # best metric value seen so far (starts at 0)
        self.trial_counter = 0  # non-improving calls since the last best

    def is_continuable(self, model, accuracy):
        """Return True if training should continue after observing ``accuracy``."""
        if accuracy > self.best_accuracy:
            # New best: remember it, reset patience and checkpoint the model.
            self.best_accuracy, self.trial_counter = accuracy, 0
            torch.save(model, self.save_path)
            return True
        has_patience_left = self.trial_counter + 1 < self.num_trials
        if has_patience_left:
            self.trial_counter += 1
        return has_patience_left

def test(model, data_loader, device):
    """Compute the binary AUROC of ``model`` over every batch of ``data_loader``.

    Args:
        model: module mapping a batch of ``fields`` to one score per sample.
        data_loader: iterable of ``(fields, target)`` batches.
        device: device the input fields are moved to for the forward pass.

    Returns:
        The AUROC computed by the module-level ``auroc`` metric.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    target_chunks, predict_chunks = [], []
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            y = model(fields.to(device))
            # Keep whole batches as flat CPU tensors and concatenate once at the
            # end; extending a Python list element by element is very slow on
            # large test sets. Targets never need to visit the device.
            target_chunks.append(target.reshape(-1).cpu())
            predict_chunks.append(y.reshape(-1).cpu())
    if not predict_chunks:
        raise ValueError('data_loader yielded no batches')
    predicts = torch.cat(predict_chunks).float()
    # torchmetrics' binary AUROC expects integer targets; labels are assumed
    # to be 0/1 — TODO confirm for every dataset used.
    targets = torch.cat(target_chunks).long()
    value = auroc(predicts, targets)
    # Calling the metric object also accumulates its inputs into the metric's
    # internal state; reset so repeated evaluations don't retain every past
    # prediction.
    auroc.reset()
    return value

## Error Injection

In [None]:
from terrorch.terrorch import Injector

In [None]:
import copy

base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['D://Datasets//Rec//ml-1m//ratings.dat', 'D://Datasets//Rec//MovieLens20M//rating.csv', 'D://Datasets//Rec//criteo-dac//train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
# Checkpoint suffixes, index-aligned with dataset_names / dataset_paths.
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Bit-error rates, swept from the smallest (1e-9) to the largest (1e-2).
bers = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, ]
bers = bers[::-1]
# Repeated injection trials per (model, dataset, BER) setting.
folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]

# results[model, dataset, ber, fold] = test AUROC of one fault-injected model.
results = torch.zeros(size = (len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, dataset in enumerate(datasets):
        # NOTE(review): testset_prepare is not defined in the cells shown here;
        # presumably it returns the held-out DataLoader for this dataset — confirm.
        test_data_loader = testset_prepare(dataset_names[dataset_i], dataset_paths[dataset_i])
        for model_path_i, model_path in enumerate(model_paths):
            # Load the clean checkpoint once per (model, dataset). inject() is
            # assumed to corrupt the model in place, so every trial below works
            # on its own deep copy of the clean weights.
            # map_location lets GPU-saved checkpoints load when `device` is the CPU.
            # NOTE(review): torch.load unpickles arbitrary objects — only use
            # trusted checkpoints; PyTorch >= 2.6 defaults to weights_only=True,
            # which rejects whole pickled models (pass weights_only=False there).
            clean_model = torch.load(base_dir + model_path + dataset, map_location=device)
            clean_model = clean_model.float().eval().to(device)
            for ber_i, ber in enumerate(bers):
                for fold_i, fold in enumerate(folds):
                    model = copy.deepcopy(clean_model)
                    injector = Injector(ber, param_names = ['mlp', 'fc', 'afm'], device = device, verbose = False)
                    injector.inject(model)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1

# The mitigation cells rebind `results`; keep this campaign's tensor under its own name.
results_no_mitigation = results

## Mitigation

### 1. Activation Filtering

In [None]:
import copy

base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['D://Datasets//Rec//ml-1m//ratings.dat', 'D://Datasets//Rec//MovieLens20M//rating.csv', 'D://Datasets//Rec//criteo-dac//train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
# Checkpoint suffixes, index-aligned with dataset_names / dataset_paths.
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Bit-error rates, swept from the smallest (1e-5) to the largest (1e-2).
bers = [1e-2, 1e-3, 1e-4, 1e-5,]
bers = bers[::-1]
# Repeated injection trials per (model, dataset, BER) setting.
folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]

# results[model, dataset, ber, fold] = test AUROC of one fault-injected,
# clip-mitigated model.
results = torch.zeros(size = (len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, dataset in enumerate(datasets):
        # NOTE(review): testset_prepare is not defined in the cells shown here;
        # presumably it returns the held-out DataLoader for this dataset — confirm.
        test_data_loader = testset_prepare(dataset_names[dataset_i], dataset_paths[dataset_i])
        for model_path_i, model_path in enumerate(model_paths):
            # Load the clean checkpoint once per (model, dataset). inject() is
            # assumed to corrupt the model in place, so every trial below works
            # on its own deep copy of the clean weights.
            # map_location lets GPU-saved checkpoints load when `device` is the CPU.
            # NOTE(review): torch.load unpickles arbitrary objects — only use
            # trusted checkpoints; PyTorch >= 2.6 defaults to weights_only=True,
            # which rejects whole pickled models (pass weights_only=False there).
            clean_model = torch.load(base_dir + model_path + dataset, map_location=device)
            clean_model = clean_model.float().eval().to(device)
            for ber_i, ber in enumerate(bers):
                for fold_i, fold in enumerate(folds):
                    model = copy.deepcopy(clean_model)
                    injector = Injector(ber, param_names = ['mlp', 'fc', 'afm'], device = device, verbose = False, mitigation = 'clip')
                    injector.inject(model)
                    # Activation filtering is applied after the faults are injected.
                    injector.perform_mitigation(model)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1

# Other cells rebind `results`; keep this campaign's tensor under its own name.
results_clip = results

### 2. Selective Bit Protection

In [None]:
import copy

base_dir = './chkpt/'
dataset_names = ['movielens1M', 'movielens20M', 'criteo']
dataset_paths = ['D://Datasets//Rec//ml-1m//ratings.dat', 'D://Datasets//Rec//MovieLens20M//rating.csv', 'D://Datasets//Rec//criteo-dac//train.txt']
model_paths = ['fm', 'dcn', 'afm', 'wd', 'dfm']
# Checkpoint suffixes, index-aligned with dataset_names / dataset_paths.
datasets = ['_movielens1M.pt', '_movielens20M.pt', '_criteo.pt']
# Bit-error rates, swept from the smallest (1e-5) to the largest (1e-2).
bers = [1e-2, 1e-3, 1e-4, 1e-5,]
bers = bers[::-1]
# Repeated injection trials per (model, dataset, BER) setting.
folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]

# results[model, dataset, ber, fold] = test AUROC of one SBP-protected,
# fault-injected model.
results = torch.zeros(size = (len(model_paths), len(datasets), len(bers), len(folds)))

cnt = 0
with torch.no_grad():
    for dataset_i, dataset in enumerate(datasets):
        # NOTE(review): testset_prepare is not defined in the cells shown here;
        # presumably it returns the held-out DataLoader for this dataset — confirm.
        test_data_loader = testset_prepare(dataset_names[dataset_i], dataset_paths[dataset_i])
        for model_path_i, model_path in enumerate(model_paths):
            # Load the clean checkpoint once per (model, dataset). inject() is
            # assumed to corrupt the model in place, so every trial below works
            # on its own deep copy of the clean weights.
            # map_location lets GPU-saved checkpoints load when `device` is the CPU.
            # NOTE(review): torch.load unpickles arbitrary objects — only use
            # trusted checkpoints; PyTorch >= 2.6 defaults to weights_only=True,
            # which rejects whole pickled models (pass weights_only=False there).
            clean_model = torch.load(base_dir + model_path + dataset, map_location=device)
            clean_model = clean_model.float().eval().to(device)
            for ber_i, ber in enumerate(bers):
                for fold_i, fold in enumerate(folds):
                    model = copy.deepcopy(clean_model)
                    injector = Injector(ber, param_names = ['mlp', 'fc', 'afm'], device = device, verbose = False, mitigation = 'SBP')
                    # Selective bit protection is set up before injection. The
                    # mitigation targets the model (as in the clip cell's
                    # perform_mitigation(model) call), not the injector itself.
                    injector.perform_mitigation(model)
                    injector.inject(model)
                    del injector
                    result = test(model, test_data_loader, device)
                    print(cnt, result)
                    results[model_path_i, dataset_i, ber_i, fold_i] = result
                    cnt += 1

# Other cells rebind `results`; keep this campaign's tensor under its own name.
results_sbp = results