In [1]:
import sys
# Make the project-local `src` package importable from this notebook.
# NOTE(review): this is a machine-specific absolute path; consider deriving it
# from the notebook location or an environment variable.
PROJECT_ROOT = "/home/ethan/mixture_embeddings/"
# Guard the append so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
import numpy as np
import pandas as pd

from tqdm import tqdm

from icecream import ic

# local files
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.util.data_handling.data_loader import save_as_pickle, load_dataset
from src.util.distance_functions.distance_matrix import DISTANCE_MATRIX
from src.util.nearest_neighbors.bruteforce import BruteForceNearestNeighbors
from src.util.nearest_neighbors.hnsw import HNSW
from src.embeddings.embeddings import get_otu_embeddings, embed_strings, load_model

INFO: Using numpy backend


In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device:', device)

In [3]:
# Pickled closest-string-retrieval dataset. The file name suggests 500 reference
# and 500 query sequences from Greengenes — TODO confirm against the data pipeline.
csr_data_path = '../data/interim/greengenes/closest_strings_ref500_query500.pickle'
# `load_dataset` is unpacked into three objects that later cells use as the
# reference sequences, the query sequences and the query labels.
references, queries, labels = load_dataset(csr_data_path)

In [4]:
# Encoder checkpoint and the settings shared by both embedding calls below.
encoder_path = '../models2/cnn_euclidean_128_model.pickle'
batch_size = 128
seed = 42
no_cuda = False

# References are processed without labels; queries carry their labels along.
reference_dataloader = get_otu_embeddings(
    references, encoder_path, batch_size,
    seed=seed, no_cuda=no_cuda,
)
query_dataloader = get_otu_embeddings(
    queries, encoder_path, batch_size,
    seed=seed, no_cuda=no_cuda, labels=labels,
)

Using device: cuda
Loading model ../models2/cnn_euclidean_128_model.pickle


Embedding sequences: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


Using device: cuda
Loading model ../models2/cnn_euclidean_128_model.pickle


Embedding sequences: 100%|██████████| 3/3 [00:00<00:00, 64.09it/s]


In [None]:
# Restore the trained encoder from its checkpoint.
encoder_model = load_model(encoder_path)

# Embed the reference data and keep the result on the selected device.
# NOTE(review): the cell that builds `reference_dataloader` already logged
# "Embedding sequences", so it may hold embeddings rather than a loader —
# confirm that `embed_strings` is the right call for it.
embedded_reference = embed_strings(
    reference_dataloader, encoder_model, device, desc='Embedding references'
).to(device)

In [None]:
def closest_string_retrieval(nn_alg, encoder_path, closest_strings_path, batch_size, num_neighbors=10, no_cuda=False, seed=42, verbose=True):
    """Embed a closest-string dataset and score nearest-neighbor retrieval on it.

    The encoder stored at `encoder_path` embeds the reference strings, a
    nearest-neighbor index is fitted on those embeddings, and the queries are
    then evaluated against that index through `test`.

    Args:
        nn_alg: Index to build; either 'brute_force' or 'hnsw'.
        encoder_path: Path handed to `load_model`.
        closest_strings_path: Path handed to `load_csr_dataset`.
        batch_size: Batch size used for both the reference and query loaders.
        num_neighbors: Number of neighbors passed to the index and to `test`.
        no_cuda: When True, stay on the CPU even if CUDA is available.
        seed: Seed applied to the numpy and torch random number generators.
        verbose: When True, print the accuracy and comparison summary.

    Returns:
        A tuple `(avg_acc, avg_num_comparisons)`: the value returned by `test`
        and `nn.num_comparisons.avg` of the fitted index.

    Raises:
        ValueError: If `nn_alg` is not one of the supported algorithm names.
    """
    # Validate before any expensive work (model loading, embedding) is done.
    if nn_alg not in ('brute_force', 'hnsw'):
        raise ValueError("`nn_alg` must be in `['brute_force', 'hnsw']`. `nn_alg` is {!r}".format(nn_alg))

    # set the device
    cuda = not no_cuda and torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    print('Using device:', device)

    # set the random seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    # load model; the encoder records which distance it was trained with
    encoder_model = load_model(encoder_path)
    distance_str = encoder_model.distance_str
    distance = DISTANCE_MATRIX[distance_str]

    # load data
    # NOTE(review): `load_csr_dataset` (and `test` below) are not among the
    # imports visible in this notebook — confirm they are defined before use.
    reference_dataset, query_dataset = load_csr_dataset(closest_strings_path)
    reference_loader = torch.utils.data.DataLoader(reference_dataset, batch_size=batch_size, shuffle=False)
    query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False)

    # embed reference data
    embedded_reference = embed_strings(reference_loader, encoder_model, device, desc='Embedding references')
    embedded_reference = embedded_reference.to(device)

    # build the nearest neighbor index (`nn_alg` was validated above)
    if nn_alg == 'brute_force':
        nn = BruteForceNearestNeighbors(num_neighbors, distance, device, {'scaling': encoder_model.scaling})
    else:
        nn = HNSW(num_neighbors, distance_str, device)
    nn.fit(embedded_reference)

    # get closest strings by embedding queries and using nearest neighbor algorithm `nn`
    avg_acc = test(query_loader, encoder_model, nn, device, num_neighbors)
    avg_num_comparisons = nn.num_comparisons.avg

    if verbose:
        # NOTE(review): indices 4 and 9 assume `avg_acc` has at least 10 entries,
        # which presumably requires num_neighbors >= 10 — confirm against `test`.
        print('ACCURACY: Top1: {:.3f}  Top5: {:.3f}  Top10: {:.3f}'.format(avg_acc[0], avg_acc[4], avg_acc[9]))
        print('COMPARISONS: {:.3f}'.format(avg_num_comparisons))

    return avg_acc, avg_num_comparisons