In [1]:
import sys
sys.path.append("/home/ethan/mixture_embeddings/")

In [2]:
import torch
import numpy as np
import pandas as pd

from tqdm import tqdm

from icecream import ic

# local files
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.util.data_handling.data_loader import save_as_pickle, load_dataset
from src.util.distance_functions.distance_matrix import DISTANCE_MATRIX
from src.util.ml_and_math.loss_functions import AverageMeter
from src.util.data_handling.closest_string_dataset import ReferenceDataset, QueryDataset
from src.util.nearest_neighbors.bruteforce import BruteForceNearestNeighbors
from src.util.nearest_neighbors.hnsw import HNSW

In [3]:
def load_csr_dataset(path):
    sequences_references, sequences_queries, labels = load_dataset(path)
    reference_dataset = ReferenceDataset(sequences_references)
    query_dataset = QueryDataset(sequences_queries, labels)
    return reference_dataset, query_dataset

In [4]:
def load_model(encoder_path):
    
    # model
    encoder_model, state_dict = torch.load(encoder_path)
    encoder_model.load_state_dict(state_dict)

    # Restore best model
    print('Loading model ' + encoder_path)
    encoder_model.load_state_dict(state_dict)
    encoder_model.eval()
    
    return encoder_model

In [5]:
def embed_strings(loader, model, device, desc='Embedding sequences'):
    """ Embeds the sequences of a dataset one batch at the time given an encoder """
    embeddings = []

    for sequences in tqdm(loader, desc=desc):
        sequences = sequences.to(device)
        embedded = model.encode(sequences)
        embeddings.append(embedded.cpu().detach())

    embeddings = torch.cat(embeddings, axis=0)
    return embeddings

In [6]:
def test(query_loader, model, nn, device, num_neighbors, desc='Embedding queries'):
    """ Given the embedding of the references, embeds and checks the performance for one batch of queries at a time """
    
    # initial values
    avg_acc = AverageMeter(len_tuple=num_neighbors)
    nn_distances_pred = []
    nn_idxs_pred = []
    
    for query_sequences, labels in tqdm(query_loader, desc=desc):
        
        # embed query sequences
        query_sequences, labels = query_sequences.to(device), labels.to(device)
        embedded_query = model.encode(query_sequences)

        # compute nearest k nearest neighbors for each embedded query
        nn_distances, nn_idxs = nn.kneighbors(embedded_query)
        nn_distances_pred.append(nn_distances)
        nn_idxs_pred.append(nn_idxs)

        # compute top-k accuracy        
        correct = nn_idxs.eq(labels.unsqueeze(1)).expand_as(nn_idxs)[:10]
        rank = torch.cumsum(correct, 1)
        acc = [torch.mean((rank[:, i]).float()) for i in range(num_neighbors)]
        avg_acc.update(acc, query_sequences.shape[0])

    avg_acc = torch.vstack(avg_acc.avg).squeeze().detach().cpu()
    return avg_acc

In [7]:
def closest_string_retrieval(nn_alg, encoder_path, closest_strings_path, batch_size, num_neighbors=10, no_cuda=False, seed=42, verbose=True):
    
    # set the device
    cuda = not no_cuda and torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    print('Using device:', device)

    # set the random seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)    
        
    # load model
    encoder_model = load_model(encoder_path)
    distance_str = encoder_model.distance_str
    distance = DISTANCE_MATRIX[distance_str]
    
    # load data
    reference_dataset, query_dataset = load_csr_dataset(closest_strings_path)
    reference_loader = torch.utils.data.DataLoader(reference_dataset, batch_size=batch_size, shuffle=False)
    query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False)
    
    # embed reference data
    embedded_reference = embed_strings(reference_loader, encoder_model, device, desc='Embedding references')
    embedded_reference = embedded_reference.to(device)
    
    # get nearest neighbor algorithm
    nn = None
    if nn_alg == 'brute_force':
        nn = BruteForceNearestNeighbors(num_neighbors, distance, device, {'scaling': encoder_model.scaling})
    elif nn_alg == 'hnsw':
        nn = HNSW(num_neighbors, distance_str, device)
    else:
        raise ValueError("`nn` must be in `['brute_force', 'hnsw']`. `nn` is {}".format(nn))
    nn.fit(embedded_reference)
    
    # get closest strings by embedding queries and using nearest neighbor algorithm `nn`
    avg_acc = test(query_loader, encoder_model, nn, device, num_neighbors)
    avg_num_comparisons = nn.num_comparisons.avg
    
    if verbose:
        print('ACCURACY: Top1: {:.3f}  Top5: {:.3f}  Top10: {:.3f}'.format(avg_acc[0], avg_acc[4], avg_acc[9]))
        print('COMPARISONS: {:.3f}'.format(avg_num_comparisons))
    
    return avg_acc, avg_num_comparisons

In [8]:
def test_all(closest_strings_path='../data/interim/greengenes/closest_strings_ref500_query500.pickle', outdir='../data/processed', model_dir='../models2', batch_size=128, seed=42, no_cuda=False):

    dimensions = [2, 4, 6, 8]
    distance_strs = ['euclidean', 'hyperbolic']
    nn_algs = ['hnsw', 'brute_force']
    results = []

    for dim in dimensions:
        for dist_str in distance_strs:
            for nn_alg in nn_algs:
                encoder_path = '{}/cnn_{}_{}_model.pickle'.format(model_dir, dist_str, dim)
                avg_acc, avg_num_comparisons = closest_string_retrieval(nn_alg, encoder_path, closest_strings_path, batch_size, seed=seed, no_cuda=no_cuda)
                print()
                
                result = {
                    'distance': dist_str, 
                    'dim': dim, 
                    'nn_alg': nn_alg,
                    'top 1 acc': avg_acc[0].item(),
                    'top 5 acc': avg_acc[4].item(),
                    'top 10 acc': avg_acc[9].item(),
                    'comparisons': avg_num_comparisons
                    }
                results.append(result)             
                save_as_pickle(result, '{}/csr_results_{}_{}_{}.pickle'.format(outdir, nn_alg, dist_str, dim))
             
    filename = '{}/csr_results_all.pickle'.format(outdir)
    save_as_pickle(results, filename)
    return filename

In [15]:
x = False
y = 5

In [16]:
y += 1 & x
y

5

In [9]:
test_all()

Using device: cuda
Loading model ../models2/cnn_euclidean_2_model.pickle


Embedding references: 100%|██████████| 4/4 [00:03<00:00,  1.09it/s]
Building HNSW graph: 100%|██████████| 500/500 [01:06<00:00,  7.54it/s]
Querying: 100%|██████████| 128/128 [00:20<00:00,  6.38it/s]
Querying: 100%|██████████| 128/128 [00:20<00:00,  6.30it/s]s/it]
Querying: 100%|██████████| 65/65 [00:10<00:00,  6.26it/s]24s/it]
Embedding queries: 100%|██████████| 3/3 [00:50<00:00, 16.95s/it]


ACCURACY: Top1: 0.199  Top5: 0.460  Top10: 0.660
COMPARISONS: 266.629

Using device: cuda
Loading model ../models2/cnn_euclidean_2_model.pickle


Embedding references: 100%|██████████| 4/4 [00:00<00:00, 134.04it/s]
Embedding queries: 100%|██████████| 3/3 [00:00<00:00, 24.31it/s]


ACCURACY: Top1: 0.199  Top5: 0.460  Top10: 0.660
COMPARISONS: 500.000

Using device: cuda
Loading model ../models2/cnn_hyperbolic_2_model.pickle


Embedding references: 100%|██████████| 4/4 [00:00<00:00, 108.60it/s]
Building HNSW graph:  29%|██▉       | 146/500 [00:10<00:26, 13.46it/s]


KeyboardInterrupt: 

# Plots

In [15]:
import plotly.express as px
import plotly.graph_objects as go

In [12]:
suffixes = ['euclidean_2', 'euclidean_4', 'euclidean_6', 'euclidean_8', 'hyperbolic_2', 'hyperbolic_4', 'hyperbolic_6', 'hyperbolic_8']
nn_algs = ['hnsw', 'brute_force']
outdir = '../data/processed'
filenames = ['{}/csr_results_{}_{}.pickle'.format(outdir, nn_alg, suffix) for suffix in suffixes for nn_alg in nn_algs]
results = [load_dataset(filename) for filename in filenames]

In [13]:
df = pd.DataFrame(results)
df

Unnamed: 0,distance,dim,nn_alg,top 1 acc,top 5 acc,top 10 acc,comparisons
0,euclidean,2,hnsw,0.199377,0.460125,0.660125,266.239875
1,euclidean,2,brute_force,0.199377,0.460125,0.660125,500.0
2,euclidean,4,hnsw,0.280374,0.7,0.860125,279.968847
3,euclidean,4,brute_force,0.280374,0.7,0.860125,500.0
4,euclidean,6,hnsw,0.360748,0.720872,0.760748,290.277259
5,euclidean,6,brute_force,0.360748,0.720872,0.760748,500.0
6,euclidean,8,hnsw,0.520872,0.800623,0.880374,306.529595
7,euclidean,8,brute_force,0.520872,0.800623,0.880374,500.0
8,hyperbolic,2,hnsw,0.060125,0.160125,0.320872,236.003115
9,hyperbolic,2,brute_force,0.060125,0.160125,0.320872,500.0


In [21]:
y_text_shift = 5
exp_offset = 3
num_decimals = 3

df = pd.DataFrame(results)
df['top 1 acc rounded'] = df['top 1 acc'].round(3)
df['top 1 acc exp'] = df['top 1 acc'].pow(exp_offset)
df['comparisons shifted'] = df['comparisons'] + y_text_shift * (df['distance'] == 'euclidean')

In [22]:
fig1 = px.scatter(df, x='dim', y='comparisons shifted', color='distance', size='top 1 acc exp', symbol='nn_alg')
fig2 = px.line(df, x="dim", y="comparisons shifted", color='distance', symbol='nn_alg')
fig = go.Figure(data = fig1.data + fig2.data)

fig.update_traces(textposition="bottom right")
fig.update_layout(title={'text': 'Nearest Neighbor Search in Hyperbolic and Euclidean Space', 'xanchor': 'center', 'x':0.5})
fig.update_xaxes(title="Embedding Dimensions")
fig.update_yaxes(title="Number of Comparisons (per query per dimension)")

fig.show()

In [40]:
df['dim + nn_alg'] = df[['dim', 'nn_alg']].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
df['top 1 acc rounded'] = df['top 1 acc'].round(3)
df

Unnamed: 0,distance,dim,nn_alg,top 1 acc,top 5 acc,top 10 acc,comparisons,dim + nn_alg,top 1 acc rounded
0,euclidean,2,hnsw,0.199377,0.460125,0.660125,266.239875,2_hnsw,0.199
1,euclidean,2,brute_force,0.199377,0.460125,0.660125,500.0,2_brute_force,0.199
2,euclidean,4,hnsw,0.280374,0.7,0.860125,279.968847,4_hnsw,0.28
3,euclidean,4,brute_force,0.280374,0.7,0.860125,500.0,4_brute_force,0.28
4,euclidean,6,hnsw,0.360748,0.720872,0.760748,290.277259,6_hnsw,0.361
5,euclidean,6,brute_force,0.360748,0.720872,0.760748,500.0,6_brute_force,0.361
6,euclidean,8,hnsw,0.520872,0.800623,0.880374,306.529595,8_hnsw,0.521
7,euclidean,8,brute_force,0.520872,0.800623,0.880374,500.0,8_brute_force,0.521
8,hyperbolic,2,hnsw,0.060125,0.160125,0.320872,236.003115,2_hnsw,0.06
9,hyperbolic,2,brute_force,0.060125,0.160125,0.320872,500.0,2_brute_force,0.06


In [41]:
fig = px.bar(df, x='dim + nn_alg', y='top 1 acc', color='distance', barmode="group", text='top 1 acc rounded')
fig.update_layout(title={'text': 'Top 1% Accuracy for Predicting Nearest Neighbor', 'xanchor': 'center', 'x':0.5})
fig

# Plots

In [9]:
# suffixes = ['euclidean_2', 'euclidean_4', 'euclidean_6', 'euclidean_8', 'hyperbolic_2', 'hyperbolic_4', 'hyperbolic_6']
# outdir = '../data/processed'
# filenames = ['{}/csr_results_{}.pickle'.format(outdir, suffix) for suffix in suffixes]
# results = [load_dataset(filename) for filename in filenames]

In [10]:
y_text_shift = 1
exp_offset = 3
num_decimals = 3

df = pd.DataFrame(results)
df['top 1 acc rounded'] = df['top 1 acc'].round(3)
df['top 1 acc exp'] = df['top 1 acc'].pow(exp_offset)
df['comparisons shifted'] = df['comparisons'] + y_text_shift * (df['distance'] == 'euclidean')

df

NameError: name 'results' is not defined

In [None]:
import plotly.express as px
import plotly.graph_objects as go

In [None]:
fig1 = px.scatter(df, x='dim', y='comparisons shifted', color='distance', size='top 1 acc exp', symbol='nn_alg', text='top 1 acc rounded')
fig2 = px.line(df, x="dim", y="comparisons shifted", color='distance')
fig = go.Figure(data = fig1.data + fig2.data)

fig.update_traces(textposition="bottom right")
fig.update_layout(title={'text': 'Nearest Neighbor Search in Hyperbolic and Euclidean Space', 'xanchor': 'center', 'x':0.5})
fig.update_xaxes(title="Embedding Dimensions")
fig.update_yaxes(title="Number of Comparisons (per query per dimension)")

fig.show()