In [None]:
import pandas as pd
import skbio

from itertools import combinations
from multiprocessing import Pool
from skbio import TreeNode
from skbio.diversity.alpha import faith_pd
from skbio.diversity.beta import unweighted_unifrac, weighted_unifrac

In [None]:
# Module-level (global) names are written in ALL CAPS.
# Feature table, tab-separated. Expected layout: an 'asv' feature-id column
# plus one count column per sample.
MYBIOM = pd.read_csv('./data/updated_lacto_sum_scaled.tsv', sep='\t')
# Phylogenetic tree (Newick) used by Faith's PD and UniFrac
MYTREE = skbio.read('./data/tree.nwk', into=TreeNode, format='newick')

In [None]:
# Sample ids are every column except the feature-id column. Dropping 'asv'
# by name, rather than slicing off the first column, keeps this correct even
# if 'asv' is not the first column of the table.
SAMPLE_IDS = MYBIOM.columns.drop('asv')
# Feature (ASV) ids, aligned row-for-row with each sample's count column
OTU_IDS = MYBIOM['asv']
# Number of worker processes for each multiprocessing Pool
NCORES = 6

## Alpha Diversity
To calculate the Faith's PD for each of the samples, we must iterate over each of the feature table's columns, extract that column, pass that column into the `faith_pd` function, and then save that output to a list.

This could be accomplished like so:

```python
alpha_list = []
for sample_id in SAMPLE_IDS:
    sample = MYBIOM[sample_id]
    tmp = faith_pd(sample, otu_ids=OTU_IDS, tree=MYTREE)
    alpha_list.append(tmp)
```

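`faith_pd()` is slow, but because each sample's value is computed independently of every other sample, this is an easy problem to parallelize. To do so, we will use the `multiprocessing` library's `Pool` object.
This lets us create a pool of worker processes and throw it at the problem via `Pool.map()`. `map()` accepts a function (here, `faith_pd()`) and an iterable
(our samples), applies the function to each item in a worker process, and returns the results in input order.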
In this case, our function `faith_pd()` requires additional arguments so we need to do some wrangling before we can just pop it into `Pool.map()`.
We can accomplish this by using `functools.partial()` or by creating a wrapper function that accepts the column name and then performs the `faith_pd()` calculation.

In [16]:
def faith_fun(sample_id):
    """Return Faith's PD for a single sample of the feature table.

    Wrapping ``faith_pd`` in a one-argument module-level function lets it be
    handed straight to ``Pool.map`` (which supplies one sample id per call)
    without using ``functools.partial``. The counts, feature ids and tree
    come from the globals MYBIOM, OTU_IDS and MYTREE.

    Parameters
    ----------
    sample_id : column label of MYBIOM identifying the sample.

    Returns
    -------
    Whatever ``skbio.diversity.alpha.faith_pd`` returns for that sample.
    """
    return faith_pd(MYBIOM[sample_id], otu_ids=OTU_IDS, tree=MYTREE)

In [17]:
# Using the Pool as a context manager means it is cleaned up automatically
# when the block exits, so there is no manual close() to forget.
# NOTE(review): the workers rely on inheriting the notebook's globals
# (MYBIOM, OTU_IDS, MYTREE) and on finding faith_fun by name. That works with
# the 'fork' start method; with 'spawn' (the default on macOS and Windows),
# functions defined in a notebook generally cannot be located by the
# workers. Confirm the target platform.
with Pool(processes=NCORES) as p:
    alphas = p.map(faith_fun, SAMPLE_IDS)

# Pool.map returns results in input order, so each value lines up with its
# sample id. Build a two-column table for export.
alphas = pd.DataFrame(dict(sampleid=SAMPLE_IDS, faith_pd=alphas))

print(alphas.head())

      sampleid   faith_pd
0  1_0403_9699  41.863220
1  1_0410_9686  47.256110
2  1_0410_9687  39.344774
3  1_0410_9689  44.571748
4  1_0410_9691  40.486682


## Beta Diversity
The code below is very inefficient because it computes every off-diagonal entry of the symmetric distance matrix instead of just the lower triangle, so each distance is computed twice. It would be better to pass the list of unique sample pairs (combinations) to the `Pool.map()` function instead.

In [None]:
def unweighted_unifrac_fun_slow(sample_id):
    """Return one full row of the unweighted UniFrac distance matrix.

    Sample ``sample_id`` (u) is compared against every sample in SAMPLE_IDS
    (v), in SAMPLE_IDS order. The self-comparison is short-circuited to 0
    instead of being computed.
    """
    u = MYBIOM[sample_id]
    return [
        0 if v_id == sample_id
        else unweighted_unifrac(u, MYBIOM[v_id], tree=MYTREE,
                                otu_ids=OTU_IDS, validate=False)
        for v_id in SAMPLE_IDS
    ]


# One worker call per row; Pool.map returns the rows in SAMPLE_IDS order.
with Pool(NCORES) as p:
    uw_betas = p.map(unweighted_unifrac_fun_slow, SAMPLE_IDS)

# Square matrix with default integer labels: position i corresponds to
# SAMPLE_IDS[i] on both axes.
uw = pd.DataFrame(data=uw_betas)

By using `itertools`' `combinations` function instead, we pass only the unique (u, v) pairs and cut our computational workload in half. The slow version already skips self-comparisons, but it still computes every distinct pair twice: once as (u, v) and once as (v, u). The code is also more straightforward and compact. Unfortunately, this is still slower than molasses. The `unifrac` library is a godsend. Use that instead.

In [None]:
def unweighted_unifrac_fun_fast(combs: tuple):
    """Return the unweighted UniFrac distance between one pair of samples.

    ``combs`` holds two MYBIOM column labels ``(u_id, v_id)``. For example,
    it can be one item produced by ``itertools.combinations(SAMPLE_IDS, 2)``.
    """
    u_id, v_id = combs[0], combs[1]
    return unweighted_unifrac(MYBIOM[u_id], MYBIOM[v_id], tree=MYTREE,
                              otu_ids=OTU_IDS, validate=False)


# combinations(SAMPLE_IDS, 2) yields every pair of distinct sample positions
# exactly once, so each distance is computed a single time. uw_betas is a
# flat list in the same order as those pairs.
with Pool(NCORES) as p:
    uw_betas = p.map(unweighted_unifrac_fun_fast, combinations(SAMPLE_IDS, 2))
