In [1]:
import torch
import numpy as np
from tqdm import tqdm

from icecream import ic

# local files
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.util.data_handling.data_loader import save_as_pickle, load_dataset
from src.util.distance_functions.distance_matrix import DISTANCE_MATRIX
from src.util.data_handling.closest_string_dataset import ReferenceDataset, QueryDataset

In [2]:
import geomstats.backend as gs

# Seed geomstats' backend RNG so anything drawn through `gs.random` is reproducible.
gs.random.seed(2020)

INFO: Using numpy backend


# Get Embeddings

In [3]:
def load_model(encoder_path, map_location=None, **load_kwargs):
    """Load a pickled encoder and restore its saved weights.

    The file at ``encoder_path`` must hold a ``(model, state_dict)`` pair as
    written by ``torch.save``; that pair is unpacked below.

    Parameters
    ----------
    encoder_path : str or os.PathLike
        Path to the saved ``(model, state_dict)`` pair.
    map_location : optional
        Forwarded to ``torch.load``. For example, ``'cpu'`` lets a checkpoint
        saved on a GPU be opened on a CPU-only machine. ``None`` (the default)
        keeps torch's default behaviour.
    **load_kwargs
        Extra keyword arguments forwarded to ``torch.load``, e.g.
        ``weights_only=False`` on torch versions whose default refuses to
        unpickle whole modules.

    Returns
    -------
    torch.nn.Module
        The encoder with ``state_dict`` loaded, switched to eval mode.
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    print('Loading model {}'.format(encoder_path))
    encoder_model, state_dict = torch.load(encoder_path, map_location=map_location, **load_kwargs)

    # Restore the saved weights (a single load is sufficient).
    encoder_model.load_state_dict(state_dict)
    encoder_model.eval()
    return encoder_model

In [4]:
def get_num_seq(ids, auxillary_data_path):
    """Look up the string sequence of every id and encode it numerically.

    ``load_dataset(auxillary_data_path)`` is unpacked as
    ``(id -> string map, <unused>, alphabet name, length)``. Each id is
    converted with ``str`` before lookup, and each string is encoded with
    ``str_seq_to_num_seq`` using that alphabet and length.

    Returns a list with one encoded sequence per id, in the order of ``ids``.
    """
    lookup, _, alphabet_name, seq_length = load_dataset(auxillary_data_path)
    alphabet = ALPHABETS[alphabet_name]

    strings = [lookup[str(seq_id)] for seq_id in ids]
    progress = tqdm(strings, desc='Convert string sequences to numerical sequences')
    return [str_seq_to_num_seq(s, length=seq_length, alphabet=alphabet) for s in progress]

In [5]:
def get_dataloader(num_seq, batch_size, labels=None):
    """Wrap ``num_seq`` (optionally with ``labels``) in a non-shuffling DataLoader.

    Without labels the sequences alone go into a ``ReferenceDataset``; with
    labels, a ``QueryDataset`` pairs each sequence with its label.
    """
    dataset = (
        ReferenceDataset(num_seq)
        if labels is None
        else QueryDataset(num_seq, labels)
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [6]:
def embed_strings(loader, model, device, desc='Embedding sequences'):
    """Embed the sequences of a dataset one batch at a time with an encoder.

    Parameters
    ----------
    loader : iterable
        Yields either a batch tensor of sequences, or a tuple/list whose first
        element is the batch of sequences (e.g. a labelled dataset that
        collates to ``[sequences, labels]``). Only the sequences are embedded.
    model : torch.nn.Module
        Object exposing ``encode(sequences)``. It must already be on ``device``.
    device : str or torch.device
        Device each batch is moved to before encoding.
    desc : str
        Label for the progress bar.

    Returns
    -------
    torch.Tensor
        Per-batch encodings concatenated along dim 0, on the CPU.
    """
    embeddings = []

    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            # Labelled datasets collate to [sequences, labels]; keep the sequences.
            sequences = batch[0] if isinstance(batch, (list, tuple)) else batch
            embedded = model.encode(sequences.to(device))
            embeddings.append(embedded.detach().cpu())

    return torch.cat(embeddings, dim=0)

In [7]:
def get_embeddings(data, encoder_path, batch_size, seed=42, no_cuda=False, labels=None, auxillary_data_path=None, save=True, outdir=None):
    """Get embeddings from data.

    Data can either be a list of ids or num_seq.
    * If it is a list of ids, then we will need auxillary_data_path and will
      automatically compute num_seq from the data and then get the embeddings.
    * Otherwise if data is num_seq to begin with we will just simply get the
      embeddings.

    Parameters
    ----------
    data : list
        Sequence ids (when ``auxillary_data_path`` is given) or already-encoded
        numerical sequences.
    encoder_path : str or os.PathLike
        Checkpoint readable by ``load_model``.
    batch_size : int
        Batch size of the DataLoader used for embedding.
    seed : int
        Seed for numpy and torch (and CUDA when it is used).
    no_cuda : bool
        Force CPU even when CUDA is available.
    labels : optional
        When given, sequences and labels are loaded together via
        ``get_dataloader``.
    auxillary_data_path : str, optional
        Auxiliary dataset used to turn ids into numerical sequences.
    save : bool
        Whether to pickle the embeddings to ``<outdir>/<basename of encoder_path>``.
    outdir : str, optional
        Output directory. When ``None``, saving is skipped with a message
        instead of building a path under a directory named "None".

    Returns
    -------
    torch.Tensor
        The embeddings returned by ``embed_strings``.
    """
    import os

    # set device
    cuda = not no_cuda and torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    print('Using device:', device)

    # set random seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    # Put the encoder on the same device that embed_strings sends the batches
    # to; otherwise a checkpoint restored on another device (e.g. GPU with
    # no_cuda=True) causes a device mismatch.
    encoder_model = load_model(encoder_path).to(device)

    # ids need converting to numerical sequences; otherwise data already is one
    if auxillary_data_path is not None:
        num_seq = get_num_seq(data, auxillary_data_path)
    else:
        num_seq = data

    loader = get_dataloader(num_seq, batch_size, labels)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        embeddings = embed_strings(loader, encoder_model, device)

    if save:
        if outdir is None:
            print('outdir is None; skipping save of embeddings')
        else:
            # NOTE: the file is named after the encoder checkpoint; do not
            # point outdir at the checkpoint's own directory or it is overwritten.
            model_name = os.path.basename(encoder_path)
            save_as_pickle(embeddings, os.path.join(outdir, model_name))

    return embeddings

In [8]:
# NOTE(review): looks like leftover exploration -- `ids` is reassigned from the
# iHMP data columns in a later cell, and `id_to_str_seq`, `alphabet`, `length`
# are not used afterwards in the visible notebook.
auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'
# The auxiliary pickle unpacks into four values via load_dataset.
id_to_str_seq, split_to_ids, alphabet, length = load_dataset(auxillary_data_path)
# First 300 ids of the 'ref' split.
ids = split_to_ids['ref'][:300]

In [9]:
# encoder_path = '../models2/cnn_hyperbolic_16_model.pickle'
# auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'
# batch_size = 128
# no_cuda = False
# seed = 42

# embeddings = get_embeddings(ids, encoder_path, batch_size, seed=seed, no_cuda=no_cuda, auxillary_data_path=auxillary_data_path)
# embeddings.shape

# Mixture Embeddings

In [10]:
# NOTE(review): `torch`, `load_dataset`, `save_as_pickle` and `geomstats.backend`
# are already imported in the first cells; those repeats here are redundant.
import torch
from geomstats.geometry.hyperbolic import Hyperbolic
from geomstats.learning.frechet_mean import FrechetMean

# local files
from src.util.data_handling.data_loader import load_dataset, save_as_pickle

import geomstats.backend as gs

# Re-seeds geomstats' backend RNG with the same seed used earlier.
gs.random.seed(2020)

In [11]:
def get_mixture_embeddings(data, otu_embeddings, distance_str, log_weights=False):
    """Compute one mixture embedding per sample as a weighted mean of OTU embeddings.

    Parameters
    ----------
    data : pandas DataFrame (or array-like) of shape (num_samples, num_otus)
        The ij-th entry is how much the i-th sample contains of the j-th OTU.
        Row i supplies the weights for sample i. Columns must be in the same
        order as the rows of ``otu_embeddings``.
    otu_embeddings : np.ndarray or torch.Tensor of shape (num_otus, embedding_size)
        One embedding per OTU. Tensors are moved to the CPU and converted to numpy.
    distance_str : str
        ``'hyperbolic'`` takes the weighted Frechet mean with geomstats'
        ``Hyperbolic`` space in ``'ball'`` coordinates. Any other value takes
        the ordinary weighted (Euclidean) average.
    log_weights : bool, default False
        If True, ``data`` holds log-abundances and is exponentiated before being
        used as weights.

    Returns
    -------
    np.ndarray of shape (num_samples, embedding_size)
        The mixture embedding of each sample.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D or its column count differs from the number of
        OTU embeddings.
    """
    import warnings

    weights = np.asarray(data, dtype=float)
    if log_weights:
        weights = np.exp(weights)

    if torch.is_tensor(otu_embeddings):
        otu_embeddings = otu_embeddings.detach().cpu().numpy()
    else:
        otu_embeddings = np.asarray(otu_embeddings)

    if weights.ndim != 2 or weights.shape[1] != otu_embeddings.shape[0]:
        raise ValueError(
            'data must have shape (num_samples, {}), got {}'.format(
                otu_embeddings.shape[0], weights.shape))
    if (weights < 0).any():
        # A weighted mean with negative weights is almost never intended.
        # Log-transformed abundances are a common cause.
        warnings.warn('data contains negative weights; if it holds '
                      'log-abundances pass log_weights=True')

    if distance_str == 'hyperbolic':
        # Only build the manifold / estimator when it is actually used.
        embedding_size = otu_embeddings.shape[1]
        hyperbolic = Hyperbolic(dim=embedding_size, default_coords_type='ball')
        fmean = FrechetMean(hyperbolic.metric, max_iter=100)

        def mix(row_weights):
            return fmean.fit(otu_embeddings, weights=row_weights).estimate_
    else:
        def mix(row_weights):
            # axis=0 averages over OTUs. Without it np.average rejects 1-D
            # weights for a 2-D embedding array.
            return np.average(otu_embeddings, axis=0, weights=row_weights)

    mixture_embeddings = [mix(row) for row in tqdm(weights, desc='Mixture Embeddings')]
    return np.array(mixture_embeddings)

# Run it on the iHMP IBD data

In [15]:
# iHMP IBD table. The preview below shows one row per sample id and one column
# per OTU id.
# NOTE(review): the previewed values are negative (e.g. -23.025851 == ln(1e-10)),
# so this looks like log-abundances rather than raw amounts -- confirm.
data_path = '../data/interim/ihmp/ibd_data.pickle'
data = load_dataset(data_path)
# OTU ids (column labels); their sequences are embedded in the next step.
ids = data.columns.to_list()

Unnamed: 0_level_0,1000269,1008348,1009894,1012376,1017181,1017413,1019823,1019878,102222,1023075,...,964363,968675,968954,971907,975306,976470,979707,988375,988932,999046
sample id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CSM5FZ3N,-12.798335,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,...,-23.025851,-10.718925,-10.090319,-23.025851,-7.704621,-7.276910,-23.025851,-23.025851,-11.412068,-23.025851
CSM5FZ3X,-12.727710,-12.034579,-23.025851,-23.025851,-11.341440,-23.025851,-23.025851,-11.629120,-23.025851,-23.025851,...,-23.025851,-9.683219,-10.088684,-10.781828,-7.316097,-11.629120,-23.025851,-8.966542,-12.034579,-12.727710
CSM5FZ3Z,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-11.321341,-23.025851,-11.321341,-23.025851,-23.025851,...,-23.025851,-9.711910,-12.419937,-23.025851,-8.468717,-23.025851,-23.025851,-23.025851,-10.810519,-23.025851
CSM5FZ44,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,...,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851
CSM5FZ46,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-12.192481,-23.025851,-23.025851,-23.025851,-23.025851,...,-23.025851,-10.806202,-11.499344,-12.192481,-8.321300,-23.025851,-23.025851,-23.025851,-11.499344,-23.025851
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MSM5LLIO,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-11.003752,-12.256500,-23.025851,...,-23.025851,-10.384716,-11.157902,-11.563363,-6.647049,-11.563363,-23.025851,-6.559427,-12.256500,-23.025851
MSM5LLIQ,-8.727984,-23.025851,-23.025851,-23.025851,-10.859606,-11.419218,-23.025851,-7.380571,-23.025851,-23.025851,...,-23.025851,-5.957516,-7.284061,-6.873276,-8.362870,-10.608293,-23.025851,-23.025851,-8.488033,-11.706897
MSM5LLIS,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,...,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851,-23.025851
MSM5ZOJY,-10.368532,-23.025851,-23.025851,-23.025851,-10.997138,-10.773996,-23.025851,-6.245711,-23.025851,-23.025851,...,-23.025851,-5.533312,-8.745851,-6.110561,-5.273559,-8.142111,-23.025851,-23.025851,-10.591675,-13.076538


In [16]:
# Encoder checkpoint. Its file name is also parsed later for the distance type.
encoder_path = '../models2/cnn_hyperbolic_16_model.pickle'
auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'
batch_size = 128
no_cuda = False
seed = 42

# NOTE(review): this call relies on get_embeddings' defaults save=True and
# outdir=None. Pass an explicit outdir (or save=False) so it is clear where,
# if anywhere, the embeddings get written.
otu_embeddings = get_embeddings(ids, encoder_path, batch_size, seed=seed, no_cuda=no_cuda, auxillary_data_path=auxillary_data_path)
otu_embeddings.shape

Using device: cuda
Loading model ../models2/cnn_hyperbolic_16_model.pickle


Convert string sequences to numerical sequences: 100%|██████████| 1370/1370 [00:00<00:00, 4055.89it/s]
Embedding sequences: 100%|██████████| 11/11 [00:00<00:00, 277.17it/s]


torch.Size([1370, 16])

In [18]:
# Take the distance name from the checkpoint file name by convention:
# 'cnn_hyperbolic_16_model.pickle' -> 'hyperbolic' (second '_'-separated field).
distance_str = encoder_path.split('/')[-1].split('_')[1]
# NOTE(review): the previewed `data` values are negative (e.g. -23.025851 ==
# ln(1e-10)), so they look like log-abundances. Using them directly as mixture
# weights is probably unintended -- confirm, and exponentiate first if so.
embeddings = get_mixture_embeddings(data, otu_embeddings, distance_str)

Mixture Embeddings: 100%|██████████| 96/96 [00:00<00:00, 319.56it/s]
