In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import itertools
import torch, os
import numpy as np
from radbm.metrics.sswr import HaltingCounterSSWR
from radbm.search.mbsds import HashingMultiBernoulliSDS
from radbm.search.elba import Fbeta, MIHash, HashNet
state_dir = #TODO: fill in the path where the experiments took place (the same state_dir as in run_<type>.py); this line is a syntax error until a value is given

In [2]:
#util function
def avg_if_enough(values, minimum_for_avg):
    """Mean of the non-NaN entries of ``values``, or ``np.inf`` when too few exist.

    values: array-like of numbers, possibly containing NaN.
    minimum_for_avg: minimal number of non-NaN entries required to average.
    """
    arr = np.array(values)
    valid = arr[~np.isnan(arr)]
    if valid.size >= minimum_for_avg:
        return valid.mean()
    #not enough usable entries: +inf so that an ascending sort ranks it last
    return np.inf

def batch_avg_if_enough(values, minimum_for_avg):
    """Apply ``avg_if_enough`` to every element of ``values``.

    Returns a numpy array with one entry per element of ``values``.
    """
    averages = []
    for entry in values:
        averages.append(avg_if_enough(entry, minimum_for_avg))
    return np.array(averages)

def nparray_to_latextable(array, columns, rows):
    """Render a 2D numpy array as the body of a LaTeX table.

    array: 2D numpy array, one row per entry of ``rows``.
    columns: header cells (strings), including the header of the name column.
    rows: row names (strings), written as the first cell of each row.

    Cells are joined with ' & ' and rows with a LaTeX line break; values are
    rounded and formatted with four decimals.
    """
    lines = [' & '.join(columns)]
    for name, row in zip(rows, array.round(4)):
        cells = [name]
        cells.extend('{:.4f}'.format(v) for v in row)
        lines.append(' & '.join(cells))
    return ' \\\\\n'.join(lines)

In [3]:
#the model families compared in the experiments
model_types = ['fbeta', 'mihash', 'hashnet', 'shared_fbeta', 'shared_mihash', 'shared_hashnet']
#one name per (model type, run index) pair, with run indices 0..4 appended to the type
models = ['{}{}'.format(m, i) for m, i in itertools.product(model_types, range(5))]

In [4]:
#Experiment 1
#For every model type: in each run, keep the entries with the lowest averaged
#SSWR, then average these SSWRs (and the matching halt percentages) over all runs.
n_runs = 5 #runs per model type (the index appended to the type name in `models`)
n_best = 5 #number of lowest-SSWR entries kept per run
min_values = 2 #minimal number of non-NaN values required to average an entry

path = os.path.join(state_dir, 'current/{}')
results = {
    name: torch.load(path.format(name), map_location=torch.device('cpu'))['results']
    for name in models
}

def _best_entries(run_results, sswrs_key, halts_key):
    """Return (sswrs, halts in percent) of the n_best entries with the lowest averaged SSWR.

    Entries with fewer than min_values non-NaN values average to +inf and are
    therefore ranked last by the ascending argsort.
    """
    sswr_avgs = batch_avg_if_enough(run_results[sswrs_key], min_values)
    halt_avgs = batch_avg_if_enough(run_results[halts_key], min_values)
    best = sswr_avgs.argsort()[:n_best]
    return sswr_avgs[best], 100*halt_avgs[best]

#one row per model type; columns: HRS, HMBS, HRS-Halt%, HMBS-Halt%
sswrs_halts = np.zeros((len(model_types), 4))
for n, t in enumerate(model_types):
    hr_sswrs = np.zeros((n_runs, n_best))
    mb_sswrs = np.zeros((n_runs, n_best))
    hr_halts = np.zeros((n_runs, n_best))
    mb_halts = np.zeros((n_runs, n_best))
    for i in range(n_runs):
        run_results = results[t+str(i)]
        #NOTE(review): 'hr' / 'mb' presumably stand for hamming radius / multi-Bernoulli search - confirm
        hr_sswrs[i], hr_halts[i] = _best_entries(run_results, 'valid_2081hr_sswrs', 'valid_2081hr_halts')
        mb_sswrs[i], mb_halts[i] = _best_entries(run_results, 'valid_5001mb_sswrs', 'valid_5001mb_halts')
    sswrs_halts[n] = (hr_sswrs.mean(), mb_sswrs.mean(), hr_halts.mean(), mb_halts.mean())
print(nparray_to_latextable(sswrs_halts, ['Models','HRS','HMBS','HRS-Halt%','HMBS-Halt%'], model_types))

Models & HRS & HMBS & HRS-Halt% & HMBS-Halt% \\
fbeta & 0.3798 & 0.0169 & 22.5120 & 0.2040 \\
mihash & 1.3559 & 1.8907 & 92.7176 & 92.1903 \\
hashnet & 1.4162 & 2.0001 & 100.0000 & 100.0000 \\
shared_fbeta & 0.1636 & 0.0035 & 8.6440 & 0.0000 \\
shared_mihash & 0.2119 & 0.0083 & 11.8080 & 0.0960 \\
shared_hashnet & 0.2536 & 0.2828 & 16.1360 & 12.8786


In [None]:
#Experiment 2
#The result arrays of this experiment have max_halt+1 entries along their last axis.
max_halt = 10000
#Model types evaluated in this experiment (a subset of the names in model_types).
allowed_types = [
    'fbeta',
    'shared_fbeta',
    'shared_mihash',
    'shared_hashnet',
]

#NOTE(review): exec runs the whole content of common.py in this cell's namespace;
#presumably it defines MNISTResNet, batch_call, valid_d, valid_q and relevances,
#which are used below but defined nowhere else in this notebook - confirm.
#Only run this with a trusted common.py (exec executes arbitrary code).
with open('common.py', 'r') as f:
    exec(f.read())

def evaluate(model, documents, queries, relevances, max_halt, batch_size=100, ntables_settings=(1, 2, 4, 8, 16)):
    """Evaluate ``model`` with a HashingMultiBernoulliSDS for several numbers of tables.

    model: object exposing ``fd``, ``fq``, ``_log_sigmoid_pairs`` and ``eval``.
    documents, queries: sized collections passed (in batches of ``batch_size``)
        to ``model.fd`` and ``model.fq`` respectively through ``batch_call``.
    relevances: iterable giving, for each query, the relevance information
        forwarded to ``HaltingCounterSSWR``.
    max_halt: forwarded to ``HaltingCounterSSWR``; every curve has max_halt+1 entries.
    batch_size: batch size used to compute the logits.
    ntables_settings: the numbers of tables to evaluate, one search structure
        is built per entry.

    Returns (sswrs_mean, sswrs_std, halts_mean, halts_std), each of shape
    (len(ntables_settings), max_halt+1); mean and std are taken over the queries.
    """
    model.eval()
    N, M = len(documents), len(queries)
    with torch.no_grad():
        dlogits = torch.cat(batch_call(model.fd, documents, batch_size), dim=0)
        qlogits = torch.cat(batch_call(model.fq, queries, batch_size), dim=0)
        dls_pairs = model._log_sigmoid_pairs(dlogits)
        qls_pairs = model._log_sigmoid_pairs(qlogits)
    #axes: (table setting, query, halt)
    curves_shape = (len(ntables_settings), M, max_halt+1)
    sswrs = np.zeros(curves_shape)
    halts = np.zeros(curves_shape)
    for i, ntables in enumerate(ntables_settings):
        #NOTE(review): meaning of the second constructor argument (1) is not visible here - confirm
        struct = HashingMultiBernoulliSDS(ntables, 1)
        struct.batch_insert(dls_pairs, range(N)) #documents are indexed by their position
        gens = struct.batch_itersearch(qls_pairs, yield_empty=True)
        for j, (gen, rel) in enumerate(zip(gens, relevances)):
            sswrs[i, j], halts[i, j] = HaltingCounterSSWR(rel, gen, N, max_halt)
    #aggregate over the query axis
    return sswrs.mean(axis=1), sswrs.std(axis=1), halts.mean(axis=1), halts.std(axis=1)
        
#Crunch every saved state of every allowed model type.
#batchs[n, i, j] holds the 'nbatch' entry stored in the state of run i, saved state j of type n.
batchs = np.zeros((len(allowed_types), 5, 5), dtype=np.int64)
#axes of the four arrays below: (model type, table setting returned by evaluate, halt)
sswrs = np.zeros((len(allowed_types), 5, max_halt+1))
sswrs_std = np.zeros((len(allowed_types), 5, max_halt+1))
halts = np.zeros((len(allowed_types), 5, max_halt+1))
halts_std = np.zeros((len(allowed_types), 5, max_halt+1))
#model_name = '{}_hr{}'
#path = os.path.join(state_dir, 'hr', model_name)
model_name = '{}_mb{}' #comment out this line and the next one, and uncomment the two lines above, to crunch w.r.t. hamming radius
path = os.path.join(state_dir, 'mb', model_name)
for n, t in enumerate(allowed_types):
    fq = MNISTResNet(64)
    #'shared' types use the very same network for queries and documents
    fd = fq if t.startswith('shared') else MNISTResNet(64)
    #NOTE(review): the constructor arguments below presumably mirror the ones used for training in run_<type>.py - confirm
    if t.endswith('fbeta'):
        model = Fbeta(fq, fd, None, -np.log(10000))
    elif t.endswith('mihash'):
        model = MIHash(fq, fd, None, 64, 1/8)
    elif t.endswith('hashnet'):
        model = HashNet(fq, fd, None, 1/2)
    else: raise ValueError()
    model.cuda() #a CUDA device is required
    model.eval()
    #axes of the per-type buffers: (run i, saved state j, table setting, halt)
    t_sswrs = np.zeros((5, 5, 5, max_halt+1))
    t_sswrs_std = np.zeros((5, 5, 5, max_halt+1))
    t_halts = np.zeros((5, 5, 5, max_halt+1))
    t_halts_std = np.zeros((5, 5, 5, max_halt+1))
    for i, j in itertools.product(range(5), range(5)):
        state = torch.load(path.format(t+str(i), j))
        #'nbatch' is removed from the loaded dict before it is given to load_state_dict
        batchs[n, i, j] = state.pop('nbatch')
        model.load_state_dict(state)
        t_sswrs[i, j], t_sswrs_std[i, j], t_halts[i, j], t_halts_std[i, j] = evaluate(
            model, valid_d, valid_q, relevances, max_halt)
        print('{} crunched'.format(model_name.format(t+str(i), j)))
    #average over the runs and the saved states (the *_std arrays are averages of per-state stds)
    sswrs[n] = t_sswrs.mean(axis=(0,1))
    sswrs_std[n] = t_sswrs_std.mean(axis=(0,1))
    halts[n] = t_halts.mean(axis=(0,1))
    halts_std[n] = t_halts_std.mean(axis=(0,1))

fbeta0_mb0 crunched
fbeta0_mb1 crunched
fbeta0_mb2 crunched
fbeta0_mb3 crunched
fbeta0_mb4 crunched
fbeta1_mb0 crunched
fbeta1_mb1 crunched
fbeta1_mb2 crunched
fbeta1_mb3 crunched
fbeta1_mb4 crunched
fbeta2_mb0 crunched
fbeta2_mb1 crunched
fbeta2_mb2 crunched
fbeta2_mb3 crunched
fbeta2_mb4 crunched
fbeta3_mb0 crunched
fbeta3_mb1 crunched
fbeta3_mb2 crunched
fbeta3_mb3 crunched
fbeta3_mb4 crunched
fbeta4_mb0 crunched
fbeta4_mb1 crunched
fbeta4_mb2 crunched
fbeta4_mb3 crunched
fbeta4_mb4 crunched
shared_fbeta0_mb0 crunched
shared_fbeta0_mb1 crunched
shared_fbeta0_mb2 crunched
shared_fbeta0_mb3 crunched
shared_fbeta0_mb4 crunched
shared_fbeta1_mb0 crunched
shared_fbeta1_mb1 crunched
shared_fbeta1_mb2 crunched
shared_fbeta1_mb3 crunched
shared_fbeta1_mb4 crunched
shared_fbeta2_mb0 crunched
shared_fbeta2_mb1 crunched
shared_fbeta2_mb2 crunched
shared_fbeta2_mb3 crunched
shared_fbeta2_mb4 crunched
shared_fbeta3_mb0 crunched
shared_fbeta3_mb1 crunched
shared_fbeta3_mb2 crunched
shared_fbeta3_

In [None]:
#uncomment to save

#import pickle
#with open('mbcrunch.pkl', 'wb') as f:
#    data = (batchs, sswrs, sswrs_std, halts, halts_std)
#    pickle.dump(data, f, -1)