In [2]:
import numpy as np
import sketches
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import torch

n = 1000000
def get_f_error(true, pred):
    return torch.sum(true * abs(pred - true)) / torch.sum(true)

def get_l2_error(true, pred):
    delta = (true - pred).float()
    return torch.sqrt(torch.mean(delta ** 2))

def get_errors(sketch, title, orig_freqs, mask, space, trials=1):
    widths = []
    new_freqs = orig_freqs.detach().clone()
    if len(mask) > 0:
        new_freqs[mask] = 0
    errs = []
    l2errs = []
    eresid = []
    for nhashes in range(1, 5): #modify this if you want to change the hyperparameters. 
        running_err = 0
        running_l2err = 0
        running_eresid = 0
        for trial in tqdm(range(trials)):
            if sketch == "Count-Sketch":
                preds = torch.Tensor(sketches.count_sketch_preds(nhashes, new_freqs.numpy(), space // nhashes, 10 + trial)).int()
            if sketch == "Count-Min":
                preds = torch.Tensor(sketches.cm_sketch_preds(nhashes, new_freqs.numpy(), space // nhashes, 10 + trial)).int()
            if len(mask) > 0:
                preds[mask] = orig_freqs.detach().clone()[mask]
            running_err += get_f_error(orig_freqs, preds).item()
            running_eresid += get_f_error(orig_freqs, preds).item() * orig_freqs.sum() / new_freqs.sum()
            running_l2err += get_l2_error(orig_freqs, preds).item()
        errs.append(running_err / trials)
        l2errs.append(running_l2err / trials)
        eresid.append(running_eresid / trials)
    print(f"{title}, {space} cells, {trials} trials, average f error")
    for i in range(nhashes):
        print(f"{i} layers (error): {errs[i]:.02f}")
    print()
    
    print(f"{title}, {space} cells, {trials} trials, average e resid")
    for i in range(nhashes):
        print(f"{i} layers (e resid): {eresid[i]:.02f}")
    print()
    
    print(f"{title}, {space} cells, {trials} trials, average l2 error")
    for i in range(nhashes):
        print(f"{i} layers (l2 err): {l2errs[i]:.02f}")
    print()
        
def get_all_errors(freqs, trials=1000):
    no_mask = torch.IntTensor([])
    bad_mask = torch.multinomial(torch.ones(12500), 10000, replacement=False)
    good_mask = torch.multinomial(torch.ones(10000), 10000, replacement=False)
    trials=1000
    for space in [20000]:
        get_errors("Count-Min", "No Screening", freqs, no_mask, space, trials)
        get_errors("Count-Min", "Imperfect Screening", freqs, bad_mask, space, trials)
        get_errors("Count-Min", "Perfect Screening", freqs, good_mask, space, trials)
    
for zipf in [1]:
    freqs = ((n / torch.arange(1, n+1)) ** zipf).int()
    get_all_errors(freqs)

100%|██████████| 1000/1000 [01:03<00:00, 15.79it/s]
100%|██████████| 1000/1000 [01:37<00:00, 10.24it/s]
100%|██████████| 1000/1000 [01:46<00:00,  9.35it/s]
100%|██████████| 1000/1000 [02:39<00:00,  6.29it/s]
  0%|          | 2/1000 [00:00<01:06, 14.95it/s]

No Screening, 20000 cells, 1000 trials, average f error
0 layers (error): 592.16
1 layers (error): 529.23
2 layers (error): 731.38
3 layers (error): 952.03

No Screening, 20000 cells, 1000 trials, average e resid
0 layers (e resid): 592.16
1 layers (e resid): 529.23
2 layers (e resid): 731.38
3 layers (e resid): 952.03

No Screening, 20000 cells, 1000 trials, average l2 error
0 layers (l2 err): 9098.19
1 layers (l2 err): 686.32
2 layers (l2 err): 769.70
3 layers (l2 err): 983.48



100%|██████████| 1000/1000 [01:06<00:00, 14.93it/s]
100%|██████████| 1000/1000 [01:43<00:00,  9.66it/s]
100%|██████████| 1000/1000 [01:51<00:00,  8.94it/s]
100%|██████████| 1000/1000 [02:44<00:00,  6.07it/s]
  0%|          | 2/1000 [00:00<01:05, 15.29it/s]

Imperfect Screening, 20000 cells, 1000 trials, average f error
0 layers (error): 133.58
1 layers (error): 167.30
2 layers (error): 241.49
3 layers (error): 318.94

Imperfect Screening, 20000 cells, 1000 trials, average e resid
0 layers (e resid): 297.91
1 layers (e resid): 373.13
2 layers (e resid): 538.59
3 layers (e resid): 711.32

Imperfect Screening, 20000 cells, 1000 trials, average l2 error
0 layers (l2 err): 3169.78
1 layers (l2 err): 394.08
2 layers (l2 err): 543.02
3 layers (l2 err): 713.80



100%|██████████| 1000/1000 [01:10<00:00, 14.26it/s]
100%|██████████| 1000/1000 [01:52<00:00,  8.91it/s]
100%|██████████| 1000/1000 [01:50<00:00,  9.07it/s]
100%|██████████| 1000/1000 [02:34<00:00,  6.46it/s]

Perfect Screening, 20000 cells, 1000 trials, average f error
0 layers (error): 62.75
1 layers (error): 109.17
2 layers (error): 158.90
3 layers (error): 210.12

Perfect Screening, 20000 cells, 1000 trials, average e resid
0 layers (e resid): 209.36
1 layers (e resid): 364.22
2 layers (e resid): 530.13
3 layers (e resid): 701.01

Perfect Screening, 20000 cells, 1000 trials, average l2 error
0 layers (l2 err): 219.29
1 layers (l2 err): 369.31
2 layers (l2 err): 533.03
3 layers (l2 err): 702.36




