In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import warnings

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

from geomstats.geometry.hyperbolic import Hyperbolic
from geomstats.learning.frechet_mean import FrechetMean

# local files
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.util.data_handling.data_loader import save_as_pickle, load_dataset, make_dir
from src.util.data_handling.closest_string_dataset import ReferenceDataset, QueryDataset

from icecream import ic

INFO: Using numpy backend


In [3]:
def full_stack():
    """Return the current call stack formatted like a Python traceback.

    If the function is called while an exception is being handled, the text
    produced by ``traceback.format_exc`` is appended after the stack frames.

    Source: https://stackoverflow.com/a/16589622/14773537.
    """
    import sys
    import traceback

    header = 'Traceback (most recent call last):\n'
    handling_exception = sys.exc_info()[0] is not None

    # The last extracted frame is full_stack() itself, so leave it out.
    frames = traceback.extract_stack()[:-1]
    if handling_exception:
        # Also drop the frame that called us; the formatted exception appended
        # below reports the caller of the caught exception instead.
        frames.pop()

    pieces = [header]
    pieces.extend(traceback.format_list(frames))
    if handling_exception:
        # str.lstrip treats `header` as a set of characters, so it removes the
        # header line plus the indentation that follows it; the two spaces
        # put that indentation back.
        pieces.append('  ' + traceback.format_exc().lstrip(header))
    return ''.join(pieces)

In [4]:
def load_model(encoder_path, map_location=None):
    """Load a saved encoder, restore its weights and switch it to eval mode.

    Parameters
    ----------
    encoder_path : str or path-like
        File holding a ``(model, state_dict)`` pair readable by ``torch.load``.
    map_location : optional
        Forwarded unchanged to ``torch.load``. Pass e.g. ``'cpu'`` to open a
        checkpoint saved on a GPU on a machine without one. The default
        (``None``) keeps ``torch.load``'s own default behaviour.

    Returns
    -------
    encoder_model
        The unpickled model after ``load_state_dict(state_dict)`` and
        ``eval()`` have been applied.
    """
    print('Loading model {}'.format(encoder_path))

    # NOTE(security): torch.load unpickles the whole model object, which can
    # execute arbitrary code. Only open checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases may need `weights_only=False` to
    # unpickle a full model object - confirm against the installed version.
    encoder_model, state_dict = torch.load(encoder_path, map_location=map_location)

    # Restore best model (the weights only need to be loaded once).
    encoder_model.load_state_dict(state_dict)
    encoder_model.eval()

    return encoder_model

In [5]:
def get_num_seq(ids, auxillary_data_path):
    """Translate sequence ids into numerical sequences.

    The auxillary dataset loaded from ``auxillary_data_path`` supplies the
    id -> string-sequence mapping, the alphabet key and the sequence length.
    Each id is looked up by its ``str`` form and the resulting string is
    converted with ``str_seq_to_num_seq``.

    Returns a list with one numerical sequence per entry of ``ids``, in order.
    """
    id_to_str_seq, _unused, alphabet_str, length = load_dataset(auxillary_data_path)
    alphabet = ALPHABETS[alphabet_str]

    # Resolve every id first; conversion only starts once all lookups succeed.
    str_seqs = []
    for _id in ids:
        str_seqs.append(id_to_str_seq[str(_id)])

    num_seqs = []
    for str_seq in tqdm(str_seqs, desc='Convert string sequences to numerical sequences'):
        num_seqs.append(str_seq_to_num_seq(str_seq, length=length, alphabet=alphabet))
    return num_seqs

In [6]:
def get_dataloader(num_seq, batch_size, labels=None):
    """Wrap ``num_seq`` (and optionally ``labels``) in an unshuffled DataLoader.

    Without labels the sequences go into a ``ReferenceDataset``; with labels
    they are paired up in a ``QueryDataset``. Batches are produced in the
    original order (``shuffle=False``).
    """
    if labels is not None:
        # iterate over num_seq and labels together
        dataset = QueryDataset(num_seq, labels)
    else:
        # iterate over just num_seq
        dataset = ReferenceDataset(num_seq)

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [7]:
def embed_strings(loader, model, device, desc='Embedding sequences'):
    """Embed every batch yielded by ``loader`` with ``model.encode``.

    Parameters
    ----------
    loader : iterable
        Yields either a tensor of sequences or a ``(sequences, labels)``
        list/tuple; in the latter case only the sequences are used.
    model : object
        Encoder exposing ``encode(batch)`` that returns a tensor.
    device : str or torch.device
        Device each batch is moved to before encoding.
    desc : str
        Label shown on the progress bar.

    Returns
    -------
    np.ndarray
        All batch embeddings stacked row-wise, in loader order.

    Raises
    ------
    ValueError
        From ``np.vstack`` when ``loader`` yields no batches.
    """
    embeddings = []

    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            # A query dataloader yields (sequences, labels); keep the sequences.
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            embedded = model.encode(batch.to(device))
            # Hand plain numpy arrays to vstack rather than relying on
            # implicit tensor -> array conversion.
            embeddings.append(embedded.detach().cpu().numpy())

    return np.vstack(embeddings)

In [8]:
def get_otu_embeddings(data, encoder_path, batch_size, seed=42, no_cuda=False, labels=None, auxillary_data_path=None):
    """Compute otu embeddings.

    Data can either be a list of ids or num_seq.
    * If it is a list of ids, then we will need auxillary_data_path and will
      automatically compute num_seq from the data and then get the embeddings.
    * Otherwise if data is num_seq to begin with we will just simply get the
      embeddings.

    Parameters
    ----------
    data : list
        Sequence ids (when ``auxillary_data_path`` is given) or num_seq.
    encoder_path : str
        Path handed to ``load_model``.
    batch_size : int
        Batch size of the dataloader.
    seed : int
        Seed applied to numpy and torch (and CUDA when it is used).
    no_cuda : bool
        Force CPU even when CUDA is available.
    labels : optional
        Forwarded to ``get_dataloader``.
    auxillary_data_path : str, optional
        When given, ``data`` is treated as ids and converted with
        ``get_num_seq``.

    Returns
    -------
    embeddings
        Whatever ``embed_strings`` returns for the loaded encoder.
    """

    # set device
    cuda = not no_cuda and torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    print('Using device:', device)

    # set random seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    # load model and place it on the same device the batches are sent to;
    # otherwise a checkpoint saved on another device fails at encode time
    # (e.g. CUDA-saved model with no_cuda=True).
    # NOTE(review): assumes the loaded encoder is a torch.nn.Module (it is
    # already used via load_state_dict/eval in load_model).
    encoder_model = load_model(encoder_path)
    encoder_model.to(device)

    # load data
    if auxillary_data_path is not None:
        ids = data
        num_seq = get_num_seq(ids, auxillary_data_path)
    else:
        num_seq = data

    # get dataloader
    loader = get_dataloader(num_seq, batch_size, labels)

    # embed strings
    embeddings = embed_strings(loader, encoder_model, device)
    return embeddings

In [9]:
def get_mixture_embeddings(data, otu_embeddings, distance_str):
    """Compute mixture embeddings.

    Parameters
    ----------
    data : pandas DataFrame of shape (n_samples, n_otus)
        Dataframe where the ijth entry is how much the ith person has of the
        jth otu.
    otu_embeddings : np.ndarray of shape (n_otus, embedding_size)
        Embedding of every otu; row j is weighted by column j of `data`.
    distance_str : string
        The distance metric used to generate the otu_embeddings. With
        `hyperbolic` a weighted Frechet mean on the Poincare ball is used;
        any other value (e.g. `euclidean`) uses a weighted arithmetic mean.

    Returns
    -------
    mixture_embeddings : np.ndarray of shape (n_samples, embedding_size)
        The mixture embeddings of each sample weighted by the otu abundances
        of that sample. In the hyperbolic case, a sample whose Frechet mean
        raised a warning or an exception gets an all-zero row instead.

    Warns
    -----
    RuntimeWarning
        When at least one hyperbolic sample fell back to the zero vector; the
        message names the affected samples.
    """

    # initialize values
    weights = data.to_numpy()
    embedding_size = otu_embeddings.shape[1]
    is_hyperbolic = distance_str == 'hyperbolic'
    mixture_embeddings = []
    failed_samples = []

    # the Frechet mean machinery is only needed for the hyperbolic case
    if is_hyperbolic:
        hyperbolic = Hyperbolic(dim=embedding_size, default_coords_type='ball')
        fmean = FrechetMean(hyperbolic.metric, max_iter=1000, method='adaptive')

    for i in tqdm(range(len(data)), desc='Mixture Embeddings', disable=True):

        if is_hyperbolic:
            # Temporarily turn warnings into errors so a Frechet mean that
            # does not converge can be detected. catch_warnings restores the
            # previous filters on exit instead of wiping every filter in the
            # process, as warnings.resetwarnings() would.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                try:
                    mixture_embedding = fmean.fit(otu_embeddings, weights=weights[i]).estimate_
                except Exception:
                    # Best-effort fallback kept from the original behaviour;
                    # the sample is recorded so it can be reported below.
                    mixture_embedding = np.zeros(embedding_size)
                    failed_samples.append(data.index[i])
        else:
            mixture_embedding = np.average(otu_embeddings, weights=weights[i], axis=0)

        mixture_embeddings.append(mixture_embedding)

    n_errors = len(failed_samples)
    ic(n_errors)
    if failed_samples:
        # Zero is a valid point of the ball, so make the substitution visible.
        warnings.warn(
            'Frechet mean failed for {} sample(s); zero vector used instead: {}'.format(
                n_errors, failed_samples),
            RuntimeWarning,
        )

    return np.array(mixture_embeddings)

In [10]:
def get_embeddings(encoder_path, ihmp_data_path, outdir, batch_size, auxillary_data_path='data/interim/greengenes/auxillary_data.pickle', seed=42, save=True, no_cuda=False):
    """Compute (and optionally save) otu and mixture embeddings for one dataset.

    Names are derived from the file names: the encoder file name minus its
    last ``_``-separated piece is the model name (``cnn_hyperbolic_2`` for
    ``cnn_hyperbolic_2_model.pickle``), the second piece of the model name is
    the distance, and the text before the first ``_`` of the data file name
    is the dataset name.

    Returns the pair ``(otu_embeddings_df, mixture_embeddings_df)``, indexed
    by 'OTU' and 'Sample' respectively. With ``save=True`` both frames are
    also written as CSV files below ``outdir``.
    """
    # names derived from the paths
    encoder_file = encoder_path.split('/')[-1]
    model_name = '_'.join(encoder_file.split('_')[:-1])
    distance_str = model_name.split('_')[1]
    data_name = ihmp_data_path.split('/')[-1].split("_")[0]
    rule = '-' * 5
    print('\n' + rule + 'Compute {} Embeddings'.format(data_name) + rule)

    # abundance table: one row per sample, one column per otu id
    data = pd.read_csv(ihmp_data_path, index_col='Sample')

    # otu embeddings
    otu_embeddings = get_otu_embeddings(
        data.columns.to_list(),
        encoder_path,
        batch_size,
        seed=seed,
        no_cuda=no_cuda,
        auxillary_data_path=auxillary_data_path,
    )
    otu_embeddings_df = pd.DataFrame(otu_embeddings, index=data.columns)
    otu_embeddings_df.index.name = 'OTU'
    if save:
        # make_dir's result is used as the output path - presumably it
        # creates the parent directories and returns the path; TODO confirm.
        otu_embeddings_df.to_csv(make_dir(
            '{}/otu_embeddings/{}/{}_otu_embeddings.csv'.format(outdir, data_name, model_name)
        ))

    # mixture embeddings
    mixture_embeddings = get_mixture_embeddings(data, otu_embeddings, distance_str)
    mixture_embeddings_df = pd.DataFrame(mixture_embeddings, index=data.index.to_list())
    mixture_embeddings_df.index.name = 'Sample'
    if save:
        mixture_embeddings_df.to_csv(make_dir(
            '{}/mixture_embeddings/{}/{}_mixture_embeddings.csv'.format(outdir, data_name, model_name)
        ))

    return otu_embeddings_df, mixture_embeddings_df

In [11]:
# ---- run configuration ----
data_name = 'ibd'        # prefix of the abundance file '<data_name>_data.csv'
distance = 'hyperbolic'  # distance the encoder was trained with (part of the model file name)
embedding_size = 2       # embedding dimension (part of the model file name)
batch_size = 128         # batch size used when embedding the otu sequences
no_cuda = False          # set True to force CPU even if CUDA is available
save = True              # write both embedding tables as CSV under `outdir`

# ---- paths (relative to the notebook's working directory) ----
encoder_path = '../models/cnn_{}_{}_model.pickle'.format(distance, embedding_size)
ihmp_data_path = '../data/interim/ihmp/{}_data.csv'.format(data_name)
outdir = '../data/processed'
auxillary_data_path = '../data/interim/greengenes/auxillary_data.pickle'


# `no_cuda` is forwarded explicitly; previously it was defined above but never
# passed on, so changing it had no effect.
otu_embeddings_df, mixture_embeddings_df = get_embeddings(
    encoder_path,
    ihmp_data_path,
    outdir,
    batch_size,
    auxillary_data_path=auxillary_data_path,
    save=save,
    no_cuda=no_cuda,
)


-----Compute ibd Embeddings-----
Using device: cuda
Loading model ../models/cnn_hyperbolic_2_model.pickle


Convert string sequences to numerical sequences: 100%|██████████| 1370/1370 [00:00<00:00, 3519.46it/s]
Embedding sequences: 100%|██████████| 11/11 [00:01<00:00,  6.06it/s]
ic| n_errors: 19


In [12]:
# Display the otu embedding table: indexed by 'OTU' (the abundance table's
# column labels), one column per embedding dimension.
otu_embeddings_df

Unnamed: 0_level_0,0,1
OTU,Unnamed: 1_level_1,Unnamed: 2_level_1
1000269,-0.971315,-0.233556
1008348,-0.188370,0.981080
1009894,-0.938192,0.343216
1012376,-0.713432,0.699297
1017181,-0.995324,-0.085623
...,...,...
976470,-0.775837,-0.629347
979707,-0.535303,0.843476
988375,-0.504512,0.862246
988932,-0.979092,-0.198445


In [13]:
# Display the mixture embedding table: indexed by 'Sample', one column per
# embedding dimension. In the hyperbolic case an all-zero row is the fallback
# written when the Frechet mean raised a warning or an exception.
mixture_embeddings_df

Unnamed: 0_level_0,0,1
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1
CSM5FZ3N,0.000000,0.000000
CSM5FZ3X,0.000000,0.000000
CSM5FZ3Z,-0.019282,0.780946
CSM5FZ44,0.000000,0.000000
CSM5FZ46,0.000000,0.000000
...,...,...
MSM5LLIO,-0.187492,0.394580
MSM5LLIQ,-0.359186,0.083998
MSM5LLIS,-0.237867,0.063742
MSM5ZOJY,-0.379982,-0.431924


# Analysis


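IBD Data (Num Errors = number of samples whose hyperbolic Fréchet mean raised a warning or an exception and was therefore replaced by the zero vector):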
| Distance         | Dimension     | Num Errors |
|--------------|-----------|------------|
| Hyperbolic |2     | 19     |
| Hyperbolic      | 4  | 0     |
| Hyperbolic      | 6  | 0     |
| Hyperbolic      | 8  | 0     |
| Hyperbolic      | 16  | 0     |
| Hyperbolic      | 32  | 0     |
| Hyperbolic      | 64  | 0     |
| Hyperbolic      | 128  | 0     |


Moms Data:
| Distance         | Dimension     | Num Errors |
|--------------|-----------|------------|
| Hyperbolic |2     | 0     |
| Hyperbolic      | 4  | 0     |
| Hyperbolic      | 6  | 0     |
| Hyperbolic      | 8  | 0     |
| Hyperbolic      | 16  | 0     |
| Hyperbolic      | 32  | 0     |
| Hyperbolic      | 64  | 0     |
| Hyperbolic      | 128  | 0     |