In [1]:
import numpy as np
import pandas as pd

from utils.learnedbloomfilter import LearnedModel, LearnedBloomFilter

import torch

from tqdm import tqdm, trange

In [2]:
dataset = pd.read_csv('data/Japan_dataset_octet_3.csv')
dataset

Unnamed: 0,integer,label
0,65552,1
1,65553,1
2,65554,1
3,65555,1
4,65556,1
...,...,...
1893236,14187820,0
1893237,15423251,0
1893238,15193767,0
1893239,9174177,0


In [3]:
## LOAD MODEL

model_in = 24
model_arch = [model_in, 256, 128, 64, 32]
model_out = 1

model = torch.nn.Sequential()
for i in range(1, len(model_arch)):
    model.append(torch.nn.Linear(model_arch[i - 1], model_arch[i]))
    model.append(torch.nn.ReLU())
model.append(torch.nn.Linear(model_arch[-1], model_out))

model.load_state_dict(torch.load("saved_model/Japan_256_128_64_32_fp_4430.pth"))

<All keys matched successfully>

In [4]:
positives = dataset[dataset.label == 1].integer.to_numpy()

In [5]:
device = "cpu"
input_size = 24
lm = LearnedModel(model=model, input_size=input_size, device=device)

fpr = 0.01
lbf = LearnedBloomFilter(lm=lm, fpr=fpr, positives=positives)

In [6]:
lbf.n_bfilter

12674

### FPR ANALYSIS

In [7]:
l = 2**24
universe = np.arange(l)
preds = np.zeros(l)

In [8]:
labels = np.zeros(l)
labels[positives] = 1

In [9]:
batch_size = 1024

for i in trange(0, len(universe), batch_size):
    start = i
    end = min(i + batch_size, len(universe))
    inputs = universe[i : i + batch_size]
    outputs = lbf.query(inputs)
    preds[i: i + batch_size] = outputs

100%|████████████████████████████████████| 16384/16384 [01:25<00:00, 191.50it/s]


In [10]:
import sklearn.metrics as skm

In [11]:
tn, fp, fn, tp = skm.confusion_matrix(labels, preds).ravel()

In [12]:
fpr = (fp / (fp + tn)) * 100
precision = tp / (tp + fp)
recall = tp / (tp + fn)
accuracy = (tp + tn) / (tp + tn + fp + fn)
f1 = (2 * (precision * recall)) / (precision + recall)

In [13]:
print(f'FALSE POSITIVE RATE: {fpr:.4f}%')
print(f'Recall : {recall}, Precision : {precision}, f1: {f1}')
print(f'Accuracy : {accuracy}')

FALSE POSITIVE RATE: 1.4673%
Recall : 1.0, Precision : 0.7684516722775862, f1: 0.8690672007880186
Accuracy : 0.9860082268714905


In [None]:
tn, fp, fn, tp 