In [18]:
import numpy as np
from scipy.stats import wasserstein_distance
from scipy.optimize import linprog
import matplotlib.pyplot as plt
import biom
from ete3 import Tree



In [6]:
#!pip install biom-format



Collecting biom-format
  Downloading biom-format-2.1.16.tar.gz (11.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.7/11.7 MB[0m [31m46.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: biom-format
  Building wheel for biom-format (pyproject.toml) ... [?25l[?25hdone
  Created wheel for biom-format: filename=biom_format-2.1.16-cp310-cp310-linux_x86_64.whl size=12158973 sha256=a459308b52123129fe60ffce531eb697983e810f38d4d22f7a07c82f154ce630
  Stored in directory: /root/.cache/pip/wheels/8e/a9/f9/197fd5a0e5bbab5f2e03c89194f6c194bed7af5d7a8c8759f3
Successfully built biom-format
Installing collected packages: biom-format
Successfully installed biom-format-2.1.16


In [7]:
#!pip install ete3

Collecting ete3
  Downloading ete3-3.1.3.tar.gz (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ete3
  Building wheel for ete3 (setup.py) ... [?25l[?25hdone
  Created wheel for ete3: filename=ete3-3.1.3-py3-none-any.whl size=2273787 sha256=38b2497af02238cd91e2d0fe42d725ac57079cd54909c4042cb6ad73ef39921d
  Stored in directory: /root/.cache/pip/wheels/a0/72/00/1982bd848e52b03079dbf800900120bc1c20e92e9a1216e525
Successfully built ete3
Installing collected packages: ete3
Successfully installed ete3-3.1.3


In [19]:

# Load the BIOM feature table; the path is relative to the notebook's working directory.
table = biom.load_table('feature-table.biom')


# Observation IDs label the rows of the table (one per OTU / feature).
print("OTUs (rows) in the feature table:")
print(table.ids('observation'))  # List of OTUs

# Sample IDs label the columns of the table.
print("\nSample IDs (columns):")
print(table.ids('sample'))

# Dense view of the underlying sparse matrix. Rows are OTUs and columns are
# samples, consistent with the two ID listings printed above.
# NOTE(review): toarray() materialises the whole table; fine for inspection,
# but consider slicing first if the table is large.
print("\nData matrix (OTUs x samples):")
print(table.matrix_data.toarray())


OTUs (rows) in the feature table:
['668fdb718997fc1589c7817655d4bb5f' 'a3f36ef32153f2fc2aaeac2feb23777f'
 '9496d87b94d90dff068f0716603930bd' ... '46b10e705d5fbdc5c8d8f3a24249591e'
 '4c5ce916f019b3ba5d0fa994d24aee1a' '530d3f556e633849ac4580c76fdda317']

Sample IDs (columns):
['206534' '206536' '206538' '206548' '206561' '206562' '206563' '206569'
 '206570' '206571' '206572' '206603' '206604' '206605' '206608' '206609'
 '206614' '206615' '206616' '206617' '206618' '206619' '206620' '206621'
 '206622' '206623' '206624' '206625' '206626' '206627' '206628' '206629'
 '206630' '206635' '206636' '206643' '206644' '206645' '206646' '206647'
 '206648' '206655' '206656' '206657' '206658' '206659' '206660' '206667'
 '206668' '206669' '206670' '206671' '206672' '206673' '206675' '206676'
 '206677' '206678' '206681' '206682' '206683' '206684' '206695' '206700'
 '206701' '206702' '206703' '206704' '206708' '206709' '206710' '206711'
 '206712' '206713' '206718' '206719' '206720' '206721' '206723' '206

In [20]:

# Synthetic example data (made up; not read from the feature table).

# Proportions of four microbial species in each sample.
# As labelled by the author: sample 1 is Crohn's, sample 2 is control.
sample1 = np.array([0.4, 0.3, 0.2, 0.1])
sample2 = np.array([0.35, 0.3, 0.25, 0.1])

# Placeholder tree: four edges, identified by their index 0..3.
edges = list(range(4))

# Edge lengths (evolutionary distances between nodes), one entry per edge index.
Le = [1.0, 0.5, 0.8, 1.2]

# Per-edge empirical masses as plain Python lists:
# Pe belongs to the Crohn's sample, Qe to the control sample.
Pe = sample1.tolist()
Qe = sample2.tolist()


In [21]:
def moment_screening_estimator(Pe, Qe, Le, polynomial_degree=1):
    """
    Edge-length-weighted power of the absolute difference between two edge masses.

    Evaluates ``Le * |Pe - Qe| ** polynomial_degree`` element-wise. No other
    correction term is applied by this function.

    Parameters:
    - Pe: Empirical mass of sample P at a tree edge (scalar or array-like)
    - Qe: Empirical mass of sample Q at a tree edge (scalar or array-like)
    - Le: Length of the tree edge(s); broadcast against Pe and Qe
    - polynomial_degree: Exponent applied to |Pe - Qe| (1 gives the plain
      length-weighted absolute difference)

    Returns:
    - The weighted term for the edge(s): a NumPy scalar for scalar inputs,
      a NumPy array for array-like inputs. This is a single contribution,
      not a complete distance.
    """
    # np.subtract / np.asarray accept plain Python lists, which the bare `-`
    # and `*` operators do not (list - list raises TypeError).
    mass_gap = np.abs(np.subtract(Pe, Qe))
    bias_corrected_distance = np.asarray(Le) * mass_gap ** polynomial_degree
    return bias_corrected_distance


In [22]:
def transport(Pe, Qe, Le):
    """
    Solve a linear program whose cost vector is the length-weighted absolute
    difference between the two distributions.

    Parameters:
    - Pe: Empirical distribution for sample P at each tree edge
    - Qe: Empirical distribution for sample Q at each tree edge
    - Le: Length of the tree edges

    Returns:
    - Optimal objective value reported by ``linprog``

    Raises:
    - RuntimeError: if the solver does not report success

    NOTE(review): no equality/inequality constraints are passed to ``linprog``,
    so only its default bounds (x >= 0) apply. With non-negative costs the
    optimum is x = 0 and the returned value is 0. Marginal (mass-conservation)
    constraints would be needed for this to be an optimal-transport problem.
    """
    # Cost vector: absolute differences between distributions, weighted by edge lengths
    cost_vector = np.asarray(np.abs(np.subtract(Pe, Qe)) * Le).ravel()

    # 'simplex' was deprecated and later removed from SciPy; 'highs' is the
    # supported replacement and reaches the same optimum for this problem.
    result = linprog(cost_vector, method='highs')
    if not result.success:
        raise RuntimeError(f"linprog did not converge: {result.message}")

    return result.fun


In [23]:
def compute_Wasserstein_distance(Pe, Qe, Le, edges, polynomial_degree=1):
    """
    Combine the per-edge moment-screening terms with the transport LP cost.

    Parameters:
    - Pe: Empirical distribution for sample P, indexable by edge
    - Qe: Empirical distribution for sample Q, indexable by edge
    - Le: Tree edge lengths, indexable by edge
    - edges: Iterable of edge indices to accumulate over
    - polynomial_degree: Exponent forwarded to moment_screening_estimator

    Returns:
    - Sum of the moment-screening term over `edges`, plus the value returned
      by transport(Pe, Qe, Le)
    """
    # Per-edge contributions: look up both samples' mass and the edge length
    # for each index and sum the resulting moment-screening terms.
    screened_total = sum(
        moment_screening_estimator(Pe[idx], Qe[idx], Le[idx], polynomial_degree)
        for idx in edges
    )

    # The transport cost is computed once over the full vectors, not per edge.
    lp_cost = transport(Pe, Qe, Le)

    return screened_total + lp_cost



In [24]:
# Total distance between the two example samples; polynomial_degree=2 is
# forwarded to the per-edge moment-screening term.
Wasserstein_distance = compute_Wasserstein_distance(
    Pe=Pe,
    Qe=Qe,
    Le=Le,
    edges=edges,
    polynomial_degree=2,
)
print(f"Total Wasserstein distance using MET: {Wasserstein_distance}")


Total Wasserstein distance using MET: 0.004500000000000004


  result = linprog(cost_matrix.flatten(), method='simplex')
