# Load Data

In [1]:
import pandas as pd
import numpy as np

from util.data_handling.data_loader import load_dataset

In [2]:
MOMS_PI_TABLES_PATH = '../../data/interim/moms_pi/16s_tables.pkl'
MOMS_PI_METADATA_PATH = '../../data/interim/moms_pi/16s_metadata.tsv'
MOMS_PI_MANIFEST_PATH = '../../data/interim/moms_pi/16s_manifest.tsv'

ID_TO_EMBEDDING_SEQUENCE_PATH = '../../data/processed/greengenes/id_to_sequence_embedding.pickle'

In [3]:
otu_tables = load_dataset(MOMS_PI_TABLES_PATH)
metadata = pd.read_csv(MOMS_PI_METADATA_PATH, sep='\t')
manifest = pd.read_csv(MOMS_PI_MANIFEST_PATH, sep='\t')

id_to_embedding = load_dataset(ID_TO_EMBEDDING_SEQUENCE_PATH)

In [4]:
otu_tables

{'MCVD': 1246 x 162 <class 'biom.table.Table'> with 7185 nonzero entries (3% dense),
 'BC1D': 905 x 47 <class 'biom.table.Table'> with 2038 nonzero entries (4% dense),
 'BRCD': 2374 x 967 <class 'biom.table.Table'> with 31133 nonzero entries (1% dense),
 'BS1D': 2270 x 734 <class 'biom.table.Table'> with 30595 nonzero entries (1% dense),
 'MV1D': 3434 x 2055 <class 'biom.table.Table'> with 121282 nonzero entries (1% dense),
 'MCHD': 956 x 152 <class 'biom.table.Table'> with 3079 nonzero entries (2% dense),
 'BCKD': 2719 x 964 <class 'biom.table.Table'> with 40975 nonzero entries (1% dense),
 'BSTD': 538 x 37 <class 'biom.table.Table'> with 1602 nonzero entries (8% dense),
 'MCKD': 3748 x 2327 <class 'biom.table.Table'> with 286063 nonzero entries (3% dense),
 'MRCD': 6251 x 1725 <class 'biom.table.Table'> with 399902 nonzero entries (3% dense)}

In [5]:
manifest.head()

Unnamed: 0,file_id,md5,size,urls,sample_id
0,e50d0c183689e4053ccb35f8b26885b5,62250c3bd427c2c67f1e115ceff2ddd5,113363,https://downloads.hmpdacc.org/ihmp/ptb/genome/...,8938ac880f194a32bc6b736c15020e14
1,e50d0c183689e4053ccb35f8b21ef8a7,898138e9c5e1bb3f16d9f544d9d70e79,45248,https://downloads.hmpdacc.org/ihmp/ptb/genome/...,8938ac880f194a32bc6b736c1501d5b3
2,e50d0c183689e4053ccb35f8b2562491,86c6b76f324eee9e8534e3632b54f52b,37056,https://downloads.hmpdacc.org/ihmp/ptb/genome/...,8938ac880f194a32bc6b736c150200f9
3,6d91411d5ede0689305232cc77c163cc,0cff2481a77b30714f89b7554f6fa469,72337,https://downloads.hmpdacc.org/ihmp/ptb/genome/...,e50d0c183689e4053ccb35f8b2bbf6ff
4,6d91411d5ede0689305232cc77d9d224,b78623036ba3ca5c8abd36205096b1e2,37056,https://downloads.hmpdacc.org/ihmp/ptb/genome/...,8938ac880f194a32bc6b736c15021a0e


In [6]:
metadata.head()

Unnamed: 0,sample_id,subject_id,subject_uuid,sample_body_site,visit_number,subject_gender,subject_race,study_full_name,project_name
0,8938ac880f194a32bc6b736c15020e14,EP575820,88af6472fb03642dd5eaf8cddcbfadb5,rectum,3,female,unknown,momspi,Integrative Human Microbiome Project
1,8938ac880f194a32bc6b736c1501d5b3,EP575820,88af6472fb03642dd5eaf8cddcbfadb5,rectum,1,female,unknown,momspi,Integrative Human Microbiome Project
2,8938ac880f194a32bc6b736c150200f9,EP575820,88af6472fb03642dd5eaf8cddcbfadb5,buccal mucosa,3,female,unknown,momspi,Integrative Human Microbiome Project
3,e50d0c183689e4053ccb35f8b2bbf6ff,EP575820,88af6472fb03642dd5eaf8cddcbfadb5,buccal mucosa,4,female,unknown,momspi,Integrative Human Microbiome Project
4,8938ac880f194a32bc6b736c15021a0e,EP575820,88af6472fb03642dd5eaf8cddcbfadb5,buccal mucosa,7,female,unknown,momspi,Integrative Human Microbiome Project


# Clean OTU Data

Normalize the OTU data.

In [7]:
otu_tables_normed = {otu_type: otu_table.norm() for otu_type, otu_table in otu_tables.items()}

An OTU matrix may have sequence ids not found in the greengenes dataset. We must
remove these ids.

In [8]:
def drop_missing_ids(otu_tables, verbose=False):
    "Remove ids from the OTU matrix that are not found in the greengenes dataset."
    
    otu_tables_cleaned = {}
    
    for otu_type, otu_table in otu_tables.items():
        
        ids = otu_table.ids(axis='observation')
        valid_ids = [id_ for id_ in ids if int(id_) in id_to_embedding]
        otu_table_valid = otu_table.filter(valid_ids, axis='observation', inplace=False)
        otu_tables_cleaned[otu_type] = otu_table_valid
        
        if verbose:
            valid_ratio = len(valid_ids) / len(ids)
            print(f'{otu_type}: {valid_ratio:.2%} of ids are valid')
            
    return otu_tables_cleaned

In [13]:
otu_tables_cleaned = drop_missing_ids(otu_tables_normed, verbose=True)

MCVD: 78.81% of ids are valid
BC1D: 87.29% of ids are valid
BRCD: 84.12% of ids are valid
BS1D: 84.14% of ids are valid
MV1D: 87.42% of ids are valid
MCHD: 89.64% of ids are valid
BCKD: 82.53% of ids are valid
BSTD: 82.71% of ids are valid
MCKD: 83.35% of ids are valid
MRCD: 91.20% of ids are valid


# Compute Mixture Embeddings

In [14]:
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.geometry.poincare_ball import PoincareBall

In [15]:
poincare = PoincareBall(dim=128)
fmean = FrechetMean(poincare.metric)

In [16]:
otu_table = otu_tables_cleaned['MRCD']
ids = otu_table.ids(axis='observation').astype(int)
embeddings = np.array([id_to_embedding[id_] for id_ in ids])
embeddings.shape, ids.shape, otu_table.shape

((5701, 128), (5701,), (5701, 1725))

These dimensions make sense.

* 5701 = the total number of OTUs detected in each sample
  * this corresponds to the number of unique IDs we have
* 128 = embedding dimension
* 1725 = the number of biological samples we took

## Sanity Check

Check if a single embedding, `point`, is located on the poincare disk?

In [17]:
point = embeddings[0, :]
poincare.belongs(point)

False

Why are the embeddings we generated not on the pointcare disk?

Ideas:
1. The poincare disk I created in this file is somehow different than the disk
   used to generate the embeddings.
2. The NeuroSEED code does not use Geomstats package anywhere in their code.
   They instead manually define functions for hyperbolic distance and such.
   Perhaps the code they wrote in their own hyperbolic implementation differs
   from Geomstats's hyperbolic implementation.
3. Perhaps when we normalize the OTU matrix, we are moving the embeddings off of
   the Poincare disk. (I tried not normalizing the OTU matrix, but it didn't fix
   anything.)
4. Other ideas??

In [None]:
for i in range(otu_table.length()):
    weights = otu_table[:, i]
    mean = fmean.fit(embeddings, weights=weights).estimate_
    print(mean)
    break

ValueError: Points do not belong to the Poincare ball