In [15]:
import torch
import numpy as np
from tqdm import tqdm

from icecream import ic
import torch.nn.functional as F

# local files
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.util.data_handling.data_loader import save_as_pickle, load_dataset
from src.models.pair_encoder import PairEmbeddingDistance
from src.util.distance_functions.distance_matrix import DISTANCE_MATRIX
from src.util.ml_and_math.loss_functions import AverageMeter
from src.util.data_handling.closest_string_dataset import ReferenceDataset, QueryDataset

In [16]:
def normalize_embeddings(embeddings, radius, distance_str):
    """Project embeddings onto a hypersphere of a given radius.

    Every row of ``embeddings`` is rescaled to unit L2 norm (along ``dim=1``)
    and then multiplied by ``radius`` after clamping it to a safe range.

    Args:
        embeddings: tensor with at least two dimensions; normalization is
            applied along dimension 1.
        radius: target norm, either a tensor or a plain Python number. Plain
            numbers are converted to a tensor with the dtype and device of
            ``embeddings``.
        distance_str: name of the distance in use. For ``'hyperbolic'`` the
            radius is capped at ``1 - 1e-3``; for any other value the cap is
            ``1e10``.

    Returns:
        Tensor of the same shape as ``embeddings`` whose rows have norm equal
        to the clamped radius (rows that are all zeros stay zero).
    """
    min_scale = 1e-7
    # NOTE(review): the sub-1 cap for 'hyperbolic' presumably keeps points
    # strictly inside the unit ball -- confirm against the distance function.
    max_scale = 1 - 1e-3 if distance_str == 'hyperbolic' else 1e10

    # Plain numbers have no ``clamp_min``; promote them so callers may pass
    # e.g. ``radius=0`` or ``radius=1.0`` as well as a tensor.
    if not torch.is_tensor(radius):
        radius = torch.as_tensor(radius, dtype=embeddings.dtype, device=embeddings.device)

    scale = radius.clamp_min(min_scale).clamp_max(max_scale)
    return F.normalize(embeddings, p=2, dim=1) * scale

In [17]:
def get_dataloader(ids, id_to_str_seq, alphabet_str, length, batch_size, num_seqs=None, labels=None, dataset_type='reference'):
    """Build a non-shuffled DataLoader over the sequences of the given ids.

    Args:
        ids: iterable of ids; each is looked up as ``str(_id)`` in
            ``id_to_str_seq``. Ignored when ``num_seqs`` is supplied.
        id_to_str_seq: mapping from string id to string sequence.
        alphabet_str: key into ``ALPHABETS`` selecting the alphabet used for
            the string-to-number conversion.
        length: forwarded to ``str_seq_to_num_seq`` as ``length``.
        batch_size: batch size of the returned DataLoader.
        num_seqs: optional pre-converted numerical sequences; when given, the
            id lookup and conversion are skipped.
        labels: labels passed to ``QueryDataset``; only used when
            ``dataset_type == 'query'``.
        dataset_type: ``'reference'`` (``ReferenceDataset``) or ``'query'``
            (``QueryDataset``).

    Returns:
        ``torch.utils.data.DataLoader`` iterating the dataset in order.

    Raises:
        ValueError: if ``dataset_type`` is neither ``'reference'`` nor
            ``'query'``.
    """
    # Fail fast, before the (potentially slow) sequence conversion, instead of
    # hitting an unbound ``dataset`` variable further down.
    if dataset_type not in ('reference', 'query'):
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected 'reference' or 'query'"
        )

    if num_seqs is None:
        alphabet = ALPHABETS[alphabet_str]
        str_seqs = [id_to_str_seq[str(_id)] for _id in ids]
        num_seqs = [
            str_seq_to_num_seq(s, length=length, alphabet=alphabet)
            for s in tqdm(str_seqs, desc='Convert string sequences to numerical sequences')
        ]

    if dataset_type == 'reference':
        dataset = ReferenceDataset(num_seqs)
    else:
        dataset = QueryDataset(num_seqs, labels)

    # shuffle=False keeps batch order aligned with the order of ``ids``.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [18]:
def embed_strings(loader, model, device, radius=0, distance_str=''):
    """Embed every batch yielded by ``loader`` and return the stacked result.

    Each batch is moved to ``device``, passed through ``model``, projected with
    ``normalize_embeddings`` and collected on the CPU.

    Args:
        loader: iterable yielding batches of sequences as tensors.
        model: callable encoder; it is NOT moved here, so it must already live
            on ``device``.
        device: device the input batches are moved to.
        radius: radius forwarded to ``normalize_embeddings``; a tensor or a
            plain number.
        distance_str: distance name forwarded to ``normalize_embeddings``.

    Returns:
        CPU tensor with the embeddings of all batches concatenated along
        dimension 0, in loader order.
    """
    # ``normalize_embeddings`` calls tensor methods on the radius, so promote
    # plain numbers (including the default 0) to a 0-dim tensor first.
    if not torch.is_tensor(radius):
        radius = torch.as_tensor(radius, dtype=torch.float32)

    batch_embeddings = []

    # Pure inference: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for sequences in tqdm(loader, desc='Embed the sequences'):
            sequences = sequences.to(device)
            embedded = normalize_embeddings(model(sequences), radius, distance_str)
            batch_embeddings.append(embedded.detach().cpu())

    return torch.cat(batch_embeddings, dim=0)

In [19]:
def load_model(encoder_path):
    """Rebuild a trained encoder from a checkpoint file.

    The checkpoint is expected to be a 6-tuple
    ``(model_class, model_args, state_dict, distance_str, radius, scaling)``.
    The model is instantiated as ``model_class(**vars(model_args))``, its
    weights are restored from ``state_dict`` and it is switched to eval mode.

    Args:
        encoder_path: path of the checkpoint written with ``torch.save``.

    Returns:
        Tuple ``(encoder_model, distance_str, radius, scaling)``.
    """
    import inspect

    # SECURITY: the checkpoint stores a class object and an args namespace, so
    # it has to be fully unpickled, which can execute arbitrary code. Only load
    # checkpoints from a trusted source.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # PyTorch releases where ``weights_only`` defaults to True refuse to
        # unpickle the class/namespace; request full unpickling explicitly
        # when the keyword exists (older releases do not accept it).
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(encoder_path, **load_kwargs)
    model_class, model_args, state_dict, distance_str, radius, scaling = checkpoint

    encoder_model = model_class(**vars(model_args))

    # Restore the saved weights; f-string also accepts pathlib.Path inputs.
    print(f'Loading model {encoder_path}')
    encoder_model.load_state_dict(state_dict)
    encoder_model.eval()

    return encoder_model, distance_str, radius, scaling

In [20]:
def ids_to_embeddings(ids, auxillary_data_path, encoder_path, batch_size=128, no_cuda=False, seed=42):
    """Convert a list of greengenes otu ids to PyTorch embeddings.

    Args:
        ids: ids of the sequences to embed.
        auxillary_data_path: path given to ``load_dataset``; its result supplies
            the id-to-sequence mapping, the alphabet and the sequence length.
        encoder_path: checkpoint path given to ``load_model``.
        batch_size: number of sequences embedded per batch.
        no_cuda: when True, run on the CPU even if CUDA is available.
        seed: seed applied to NumPy and PyTorch (and CUDA when used).

    Returns:
        The tensor produced by ``embed_strings`` for the requested ids.
    """
    # set device
    cuda = not no_cuda and torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    print('Using device:', device)

    # set the random seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    # restore best model; the returned scaling is not needed for embedding
    encoder_model, distance_str, radius, _ = load_model(encoder_path)
    # The batches are moved to ``device`` while embedding, so the model has to
    # live there too; otherwise the forward pass can hit a device mismatch.
    encoder_model = encoder_model.to(device)

    # turn greengene ids into dataloader of (numerical) greengene sequences
    id_to_str_seq, _, alphabet, length = load_dataset(auxillary_data_path)
    sequence_dataloader = get_dataloader(ids, id_to_str_seq, alphabet, length, batch_size)

    # get embeddings
    return embed_strings(
        sequence_dataloader,
        encoder_model,
        device,
        radius=radius,
        distance_str=distance_str,
    )

In [29]:
# --- run configuration ---
seed = 42
no_cuda = False
batch_size = 128
encoder_path = '../models/transformer_hyperbolic_16_model.pickle'
auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'

# --- pick the ids to embed: the first 300 entries of the 'ref' split ---
id_to_str_seq, split_to_ids, alphabet, length = load_dataset(auxillary_data_path)
ids = split_to_ids['ref'][:300]

# --- embed them with the saved encoder ---
embeddings = ids_to_embeddings(
    ids,
    auxillary_data_path,
    encoder_path,
    batch_size=batch_size,
    no_cuda=no_cuda,
    seed=seed,
)

Using device: cuda
padding 0
Loading model ../models/transformer_hyperbolic_16_model.pickle


Convert string sequences to numerical sequences: 100%|██████████| 300/300 [00:00<00:00, 3092.83it/s]
Embed the sequences: 100%|██████████| 3/3 [00:00<00:00, 99.27it/s]


In [22]:
# Sanity check: the first dimension should equal the number of requested ids
# (300 above); the second is the size of each embedding vector.
embeddings.shape

torch.Size([300, 16])