In [5]:
import pandas as pd
import numpy as np
from io import StringIO
import ot
import os
# import sys
import json
from collections import defaultdict
from tqdm.notebook import tqdm
from typing import List

# import collections

from ete3 import Tree
from skbio import TreeNode
import Levenshtein
from Bio import SeqIO, Align
from skbio.diversity.beta import weighted_unifrac, unweighted_unifrac

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold, cross_val_predict, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, make_scorer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Classifiers to test
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.decomposition import PCA

from scipy.spatial import distance_matrix

import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings('ignore', category=ConvergenceWarning)

from src import fasta_to_kmer_vector

# Number of folds for K-fold validation: used as `n_splits` for the stratified
# cross-validation and as the number of columns in the raw per-fold score table.
N_FOLDS = 10

# Which BLAST database to look for OTU --> genome mappings.
# Passed to the %%bash cells and used as the sub-directory name under
# data/blast/ and data/genomes/.
GENOME_DATABASE = "core_nt"

In [6]:
# Output locations, all relative to the notebook's working directory.
DATA_DIR = "data"
SAMPLE_DFS_DIR = os.path.join(DATA_DIR, "samples_by_samples_dataframes")
OT_COST_MATRICES_DIR = os.path.join(DATA_DIR, "ot_cost_matrices")

# Create every output directory up front so later cells can write without checking.
for _directory in (DATA_DIR, SAMPLE_DFS_DIR, OT_COST_MATRICES_DIR):
    os.makedirs(_directory, exist_ok=True)

# Data Loading

In [None]:
%%bash

set -euo pipefail

ROOT=$(git rev-parse --show-toplevel)
TREE="${ROOT}/data/gg_13_5_otus_99_annotated.tree"

# Download the annotated GreenGenes tree if it's not already in data/.
# The existence check, the download target and the decompression step all
# refer to the same file, so a fresh checkout ends up with the uncompressed tree.
if [ ! -f "${TREE}" ]; then
    echo "Downloading annotated tree from GreenGenes"
    mkdir -p "${ROOT}/data"
    # NOTE(review): --no-check-certificate disables TLS verification; kept as in
    # the original because the server setup is not visible from here.
    wget --no-check-certificate https://ftp.microbio.me/greengenes_release/gg_13_5/gg_13_5_otus_99_annotated.tree.gz -O "${TREE}.gz"
    gzip -v -d "${TREE}.gz"
fi

In [7]:
# Parse the annotated GreenGenes tree with ete3, then hand its Newick
# serialisation to scikit-bio so both libraries hold the same tree.
tree = Tree(
    "data/gg_13_5_otus_99_annotated.tree",
    format=1,
    quoted_node_names=True,
)
_tree_newick = tree.write(format_root_node=True)
skbio_tree = TreeNode.read(StringIO(_tree_newick))

In [8]:
# Read the 16S abundance table.
# TODO: Use columns other than `sample`
ibd_data = pd.read_csv("ihmp/ibd_data.csv.gz", dtype={0: str}, compression='gzip')

# Index rows by the fourth column and discard the remaining metadata columns,
# leaving one row per sample and one column per OTU.
_otus = ibd_data.set_index(ibd_data.columns[3]).drop(columns=['patient', 'visit', 'site'])

# Transpose to OTUs x samples.
otus = _otus.T.copy()

In [9]:
# Column of interest:
# diagnosis: CD = Crohn's disease, UC = ulcerative colitis, nonIBD = control
ibd_metadata = pd.read_csv("ihmp/ibd_metadata_new.csv")

# Drop the nonIBD label since we only have 1 example of it.
# `.item()` raises unless exactly one nonIBD row exists.
_non_ibd_rows = ibd_metadata.loc[ibd_metadata["diagnosis"] == "nonIBD"]
non_ibd_index = _non_ibd_rows.index.item()
non_ibd_sample_id = _non_ibd_rows["sample"].item()

# Remove that sample from both the metadata and the OTU table.
ibd_metadata.drop(index=non_ibd_index, inplace=True)
otus.drop(columns=non_ibd_sample_id, inplace=True)

# Keep only the diagnosis, indexed by sample id.
ibd_metadata_diagnosis = ibd_metadata.set_index("sample")[["diagnosis"]].copy()

In [10]:
# Sanity check: every sample column in the OTU table must also appear in the metadata.
_metadata_sample_ids = set(ibd_metadata['sample'])
_data_sample_ids = set(ibd_data['sample'])

num_samples_data = len(otus.columns)
sample_intersection = len(_metadata_sample_ids.intersection(_data_sample_ids))

assert sample_intersection == num_samples_data

In [11]:
# All OTU ids present in the abundance table (no top-N filtering is applied here).
otu_names = otus.index
otus_subset = otus.loc[otu_names].copy()

# Index of the GreenGenes FASTA, keyed by record id.
gg = SeqIO.index("data/gg_13_5.fasta", "fasta")

# Fetch one sequence record per OTU, in the same order as `otu_names`.
topN_seqs = []
for _otu_id in otu_names:
    topN_seqs.append(gg[_otu_id])

In [12]:
# Restrict the full GreenGenes tree to the tips named in `otu_names`.
skbio_subtree = skbio_tree.shear(otu_names)

# Calculating OT Cost Matrices

## 16S Sequence Distances

### Levenshtein Distance

In [None]:
levenshtein_cost_path = os.path.join(OT_COST_MATRICES_DIR, "levenshtein_cost_matrix.npy")

if not os.path.exists(levenshtein_cost_path):
    print("Calculating Levenshtein distance based cost matrix")

    # Convert each record to a plain string once instead of once per pair.
    _sequences = [str(record.seq) for record in topN_seqs]
    _n_seqs = len(_sequences)
    levenshtein_cost_matrix = np.zeros((_n_seqs, _n_seqs), dtype=int)

    # The edit distance is symmetric: compute the upper triangle and mirror it.
    for i, _seq_i in enumerate(_sequences):
        for j in range(i + 1, _n_seqs):
            dist = Levenshtein.distance(_seq_i, _sequences[j])
            levenshtein_cost_matrix[i, j] = dist
            levenshtein_cost_matrix[j, i] = dist

    # Cache the result so later runs can skip the pairwise computation.
    with open(levenshtein_cost_path, "wb") as f:
        np.save(f, levenshtein_cost_matrix)
else:
    print("Loading Levenshtein distance based cost matrix from disk")

    with open(levenshtein_cost_path, "rb") as file_in:
        levenshtein_cost_matrix = np.load(file_in)

### Alignment-based Distance

In [None]:
alignment_cost_path = os.path.join(OT_COST_MATRICES_DIR, "alignment_cost_matrix.npy")

if os.path.exists(alignment_cost_path):
    # NOTE: a cache written by an earlier version of this cell only has its upper
    # triangle filled; delete the .npy file to regenerate a symmetric matrix.
    print("Loading alignment cost matrix from disk")

    with open(alignment_cost_path, "rb") as file_in:
        alignment_cost_matrix = np.load(file_in)
else:
    print("Calculating alignment cost matrix")
    alignment_cost_matrix = np.zeros((len(topN_seqs), len(topN_seqs)), dtype=int)
    aligner = Align.PairwiseAligner()

    # NOTE(review): the aligner score is a similarity (higher = more alike), yet it
    # is used as a transport cost below — confirm this is intended or convert it
    # to a distance.
    for i in range(len(topN_seqs)):
        for j in range(i + 1, len(topN_seqs)):
            # `aligner.score` returns the optimal score without materialising alignments.
            dist = aligner.score(topN_seqs[i].seq, topN_seqs[j].seq)
            # Mirror the value so the cost matrix is symmetric, like the Levenshtein one.
            alignment_cost_matrix[i, j] = alignment_cost_matrix[j, i] = dist

    with open(alignment_cost_path, "wb") as file_out:
        np.save(file_out, alignment_cost_matrix)

## Tree-based Distances

### Phylogenetic Distance (Sum of GreenGenes Edge Weights)

In [None]:
phylogenetic_distance_cost_path = os.path.join(OT_COST_MATRICES_DIR, "phylogenetic_cost_matrix.npy")

if os.path.exists(phylogenetic_distance_cost_path):
    # NOTE: a cache written by an earlier version of this cell is in tree-traversal
    # order rather than `otus.index` order; delete the .npy file to regenerate it.
    print("Loading phylogenetic distance based cost matrix from disk")

    with open(phylogenetic_distance_cost_path, "rb") as file_in:
        phylogenetic_cost_matrix = np.load(file_in)
else:
    print("Calculating phylogenetic distance based cost matrix")

    otu_order = otus.index.to_list()
    skbio_subtree = skbio_tree.shear(set(otu_order))

    # `tip_tip_distances()` labels rows/columns by tip name in the tree's own order.
    # Re-order them to match `otus.index`, because the OT step pairs row k of this
    # matrix with entry k of each sample's abundance vector.
    tip_distances = skbio_subtree.tip_tip_distances().filter(otu_order)
    phylogenetic_cost_matrix = tip_distances.data

    with open(phylogenetic_distance_cost_path, "wb") as file_out:
        np.save(file_out, phylogenetic_cost_matrix)

## Genome Based Distances

In [None]:
%%bash

# Writes one FASTA file per OTU into data/otu_fastas/, using the OTU ids listed
# in data/otus.txt (which is generated from the abundance table header if missing).

set -euo pipefail

# Get the root of the git repo
ROOT=$(git rev-parse --show-toplevel)

mkdir -p ${ROOT}/data/otu_fastas

# If there's no otus.txt file, extract otus from ibd_data.csv and put it in otus.txt
# `tail -n +5` drops the first four header fields, keeping the remaining column names.
# NOTE(review): the Python cells read ihmp/ibd_data.csv.gz; this expects an
# uncompressed ihmp/ibd_data.csv next to it — confirm that file exists.
if [ ! -f "${ROOT}/data/otus.txt" ]; then
    echo "Extracting OTUs from Samples x OTUs table"
    head -n 1 ihmp/ibd_data.csv | tr ',' '\n' | tail -n +5 > ${ROOT}/data/otus.txt
fi

# Extract sequence information from the master fasta file into fasta file per OTU
# Each lookup greps for the exact header line ">OTU" plus the one line after it.
# NOTE(review): under `set -e` a grep with no match exits non-zero and aborts the
# cell; the fasta path is relative to the working directory, not ${ROOT}.
while read -r otu; do 
    grep "^>${otu}$" -A 1 data/gg_13_5.fasta > ${ROOT}/data/otu_fastas/${otu}.fasta; 
    echo $otu; 
done < ${ROOT}/data/otus.txt

In [None]:
%%bash -s {GENOME_DATABASE}

# Runs src/search_blast.py once per OTU (skipping OTUs that already have output)
# and then unpacks every result archive into a directory named after the OTU.

set -euo pipefail

# Grab GENOME_DATABASE which is declared in the top-most cell
GENOME_DATABASE=$1

# Get the root of the git repo
ROOT=$(git rev-parse --show-toplevel)
BLAST_DIR="${ROOT}/data/blast/${GENOME_DATABASE}"

mkdir -p "${BLAST_DIR}"

while read -r otu; do
    if [ ! -f "${BLAST_DIR}/${otu}.zip" ] && [ ! -d "${BLAST_DIR}/${otu}" ]; then
        python "${ROOT}/src/search_blast.py" \
            -d "${GENOME_DATABASE}" \
            -o "${BLAST_DIR}/${otu}.zip" \
            -t 300 \
            "${ROOT}/data/otu_fastas/${otu}.fasta"
    else
        echo "Skipping ${otu} because output exists";
    fi
done < "${ROOT}/data/otus.txt"


for file in "${BLAST_DIR}"/*.zip; do
    # When no archive exists the glob stays literal; skip it instead of failing.
    [ -e "${file}" ] || continue

    # Only unpack archives that have not been extracted yet, so re-running the
    # cell does not stop on unzip's overwrite prompt.
    if [ ! -d "${file%.zip}" ]; then
        unzip -d "${file%.zip}" "${file}"
    fi
done

In [None]:
blast_search_dir = os.path.join(f"data/blast/{GENOME_DATABASE}")

# For every OTU, list all taxids corresponding to a complete genome
otu_to_taxid = defaultdict(set)

# For every OTU, the taxid of the first "complete genome" hit in the order the
# BLAST report lists them (presumably best hit first — verify against search_blast.py).
# Picking an element out of the set instead would be arbitrary.
otu_to_taxid_first_hit = {}

# List the directory once and sort it so runs are reproducible.
blast_entries = sorted(os.listdir(blast_search_dir))
pbar = tqdm(blast_entries, total=len(blast_entries))

for otu in pbar:
    subdir = os.path.join(blast_search_dir, otu)
    # Archives (*.zip) and other files live next to the extracted directories.
    if not os.path.isdir(subdir):
        continue

    reports = sorted(f for f in os.listdir(subdir) if f.endswith("_1.json"))
    if not reports:
        # An incomplete search result should not abort the whole scan.
        print(f"Skipping {otu}: no '*_1.json' BLAST report found in {subdir}")
        continue

    info_path = os.path.join(subdir, reports[0])

    with open(info_path, "r") as f:
        info_dict = json.load(f)

    hits = info_dict['BlastOutput2']['report']['results']['search']['hits']
    for hit in hits:
        description = hit['description'][0]
        # Only hits whose title mentions a complete genome are of interest.
        if 'title' in description and "complete genome" in description['title']:
            taxid = description['taxid']
            otu_to_taxid[otu].add(taxid)
            otu_to_taxid_first_hit.setdefault(otu, taxid)

# Sorted so that data/unique_taxids.txt is stable across runs.
unique_taxids = sorted(set(otu_to_taxid_first_hit.values()))

with open("data/unique_taxids.txt", "w") as file_out:
    for taxid in unique_taxids:
        file_out.write(str(taxid) + "\n")

In [None]:
%%bash -s {GENOME_DATABASE}

# Downloads one reference genome archive per taxid listed in data/unique_taxids.txt
# (skipping taxids that already have an archive or an extracted directory), then
# unpacks each archive into a directory named after the taxid.

# Activate the Conda environment
source $(conda info --base)/etc/profile.d/conda.sh
conda activate base  # Replace "base" with your desired Conda environment name

echo "Using Conda environment: $(conda env list | grep '*' | awk '{print $1}')"

# Check if ncbi-datasets-cli is installed in the base environment
if ! command -v datasets &> /dev/null; then
    echo "Error: 'datasets' command is not installed. Please install it before running this script."
    exit 1
fi

GENOME_DATABASE=$1
ROOT=$(git rev-parse --show-toplevel)
GENOME_DIR="${ROOT}/data/genomes/${GENOME_DATABASE}"

echo "Downloading genomes"
mkdir -p "${GENOME_DIR}"

while read -r line; do 
    OUTFILE="${GENOME_DIR}/${line}.zip"
    OUTDIR="${GENOME_DIR}/${line}"

    if [ ! -d "${OUTDIR}" ] && [ ! -f "${OUTFILE}" ]; then
        echo ${line}
        datasets download genome taxon "${line}" \
            --fast-zip-validation \
            --reference \
            --assembly-version latest \
            --assembly-level complete \
            --exclude-atypical \
            --no-progressbar \
            --filename "${OUTFILE}"; 
    else
        echo "Skipping ${line} since ${OUTFILE} or ${OUTDIR} already exists"
    fi
done < "${ROOT}/data/unique_taxids.txt"

# Unpack every downloaded archive that has not been extracted yet. The variable
# name must match GENOME_DATABASE exactly, otherwise the glob looks in
# data/genomes// and never finds the archives.
for file in "${GENOME_DIR}"/*.zip; do
    # When no archive exists the glob stays literal; skip it.
    [ -e "${file}" ] || continue

    if [ ! -d "${file%.zip}" ]; then
        unzip -d "${file%.zip}" "${file}"
        rm -f "${file}"
    fi
done

In [35]:
%%bash -s {GENOME_DATABASE}

# Copies every .fna file found below <taxid>/ncbi_dataset/data/ up into the
# taxid directory itself, so later cells can find the genomes at a fixed depth.

set -euo pipefail

GENOME_DATABASE=$1
ROOT=$(git rev-parse --show-toplevel)

for taxid in "${ROOT}/data/genomes/${GENOME_DATABASE}"/*; do
    # Skip stray files (archives, tables) that sit next to the taxid directories.
    [ -d "${taxid}" ] || continue

    source_directory="${taxid}/ncbi_dataset/data/"
    target_directory="${taxid}/"

    # A taxid without an extracted dataset would make `find` fail under `set -e`.
    if [ ! -d "${source_directory}" ]; then
        echo "Skipping $(basename "${taxid}"): ${source_directory} does not exist"
        continue
    fi

    # Find all .fna files recursively and copy them to the taxid directory.
    # `{}` is the path of the matched file itself, so it is passed to cp as-is.
    find "${source_directory}" -type f -name "*.fna" -exec cp -v {} "${target_directory}" \;
done

In [None]:
%%bash -s {GENOME_DATABASE}

# Computes a k-mer vector for every .fna file that sits directly inside a taxid
# directory. Each vector is written next to its input file, i.e. inside the taxid
# directory, which is where the following cell looks for `*.kmer_vector.tsv`.

# NOTE: if you have GNU parallel installed, run this in the terminal to get the same output much faster (from within phylosig):
# parallel python src/fasta_to_kmer_vector.py -i {} -o {.}.kmer_vector.tsv -v ::: $(find data/genomes/core_nt -maxdepth 2 -type f -name "*.fna")

set -euo pipefail

GENOME_DATABASE=$1
ROOT=$(git rev-parse --show-toplevel)
GENOME_DIR="${ROOT}/data/genomes/${GENOME_DATABASE}"

# NUL-delimited so that paths containing whitespace are handled correctly.
find "${GENOME_DIR}" -maxdepth 2 -type f -name "*.fna" -print0 |
while IFS= read -r -d '' input_file; do
    output_file="${input_file%.fna}.kmer_vector.tsv"

    if [[ ! -f "$output_file" ]]; then
        echo "${input_file} -> ${output_file}"
        # stdin is redirected so the script cannot consume the file list.
        python "${ROOT}/src/fasta_to_kmer_vector.py" -i "${input_file}" -o "${output_file}" -v < /dev/null
    else
        echo "${output_file} already exists. Skipping."
    fi
done

In [52]:
%%bash -s {GENOME_DATABASE}

# Writes data/taxid_and_kmer.txt: one "<taxid><TAB><k-mer vector path>" row per
# taxid directory that contains at least one *.kmer_vector.tsv file.

set -euo pipefail

GENOME_DATABASE=$1
ROOT=$(git rev-parse --show-toplevel)

# Unmatched globs expand to nothing instead of staying literal.
shopt -s nullglob

# NOTE: This is simply assigning the lexicographically first kmer vector.
# There are probably smarter ways to collate this (e.g. take median of values in each k-mer bin)
for taxid in "${ROOT}/data/genomes/${GENOME_DATABASE}"/*/; do
    taxid="${taxid%/}"
    id_only=$(basename "${taxid}")
    kmer_vectors=( "${taxid}"/*.kmer_vector.tsv )

    # A row without a path would be read as a missing value by the next cell.
    if [ "${#kmer_vectors[@]}" -eq 0 ]; then
        echo "No k-mer vector found for ${id_only}; skipping" >&2
        continue
    fi

    # Emit the tab directly so paths containing spaces stay intact.
    printf '%s\t%s\n' "${id_only}" "${kmer_vectors[0]}"
done > "${ROOT}/data/taxid_and_kmer.txt"

In [92]:
# taxid -> path of the k-mer vector chosen for that taxid. Taxids are read as
# strings so the lookup does not depend on how pandas infers the column type.
_taxid_kmer_df = (
    pd.read_csv("data/taxid_and_kmer.txt", sep="\t", names=["taxid", "kmer_vector_path"], dtype=str)
    .dropna(subset=["kmer_vector_path"])  # rows without a k-mer vector cannot be loaded
    .drop_duplicates(subset="taxid")
)
_taxid_to_kmer_path = _taxid_kmer_df.set_index("taxid")["kmer_vector_path"].to_dict()

# Kept for backward compatibility. This inversion is lossy: when several OTUs
# share a taxid only one of them survives, so it is no longer used below.
taxid_to_otu = {v:k for k,v in otu_to_taxid_first_hit.items()}

# One row per OTU whose taxid has a k-mer vector. Going OTU -> taxid -> path
# (instead of taxid -> OTU) keeps every OTU, including those sharing a taxid.
df = pd.DataFrame(
    [
        (otu, _taxid_to_kmer_path[str(taxid)])
        for otu, taxid in otu_to_taxid_first_hit.items()
        if str(taxid) in _taxid_to_kmer_path
    ],
    columns=["otu", "kmer_vector_path"],
).set_index("otu")

if df.empty:
    raise ValueError("No OTU could be matched to a k-mer vector in data/taxid_and_kmer.txt")

# Read each distinct k-mer file once, even when several OTUs point at it.
_kmer_tables = {
    vector_path: pd.read_csv(vector_path, sep="\t", index_col=0)
    for vector_path in df['kmer_vector_path'].unique()
}

otu_to_vector = {
    otu: _kmer_tables[vector_path].to_numpy().flatten()
    for otu, vector_path in df['kmer_vector_path'].to_dict().items()
}

# Row labels are taken from the first OTU's file.
# NOTE(review): assumes every k-mer file lists the same rows in the same order — confirm.
index = _kmer_tables[df['kmer_vector_path'].iloc[0]].index.to_list()

genome_df = pd.DataFrame(otu_to_vector, index=index)

# Ensure that genome_df's columns are in the same order as otus index
genome_df_columns = [col for col in otus.index if col in genome_df.columns]
genome_df = genome_df[genome_df_columns].copy()

In [93]:
# Pairwise Euclidean distances between the k-mer vectors of the OTUs that have a
# genome: once on the raw vectors and once on vectors scaled to sum to 1.
_n_genome_otus = len(genome_df_columns)
genome_cost_matrix = np.zeros((_n_genome_otus, _n_genome_otus), dtype=float)
genome_normalized_cost_matrix = np.zeros((_n_genome_otus, _n_genome_otus), dtype=float)

# Pull each column out once and pre-compute its normalised form.
_raw_vectors = [genome_df[column].to_numpy() for column in genome_df_columns]
_normalized_vectors = [vector / np.sum(vector) for vector in _raw_vectors]

for i, (raw_i, norm_i) in enumerate(zip(_raw_vectors, _normalized_vectors)):
    for j, (raw_j, norm_j) in enumerate(zip(_raw_vectors, _normalized_vectors)):
        genome_cost_matrix[i, j] = np.linalg.norm(raw_i - raw_j, 2)
        genome_normalized_cost_matrix[i, j] = np.linalg.norm(norm_i - norm_j, 2)

# Calculating Sample x Sample Cost Matrices

## 16S Sequence Distances

### Levenshtein

In [94]:
levenshtein_sample_df_path = os.path.join(SAMPLE_DFS_DIR, "levenshtein_sample_cost_df.csv")

if os.path.exists(levenshtein_sample_df_path):
    # NOTE: a cache written by an earlier version of this cell was computed from
    # raw counts instead of normalised distributions; delete the CSV to recompute.
    levenshtein_sample_cost_df = pd.read_csv(levenshtein_sample_df_path, index_col=0)
else:
    num_samples = otus.shape[1]
    levenshtein_sample_cost_matrix = np.zeros((num_samples, num_samples))

    # Turn every sample's OTU counts into a probability distribution once.
    # Raw counts generally have different totals per sample, so they are not
    # valid marginals for a balanced OT problem.
    sample_distributions = []
    for sample_id in otus.columns:
        counts = otus[sample_id].to_numpy(dtype=float)
        sample_distributions.append(counts / counts.sum())

    # Fill the full (samples x samples) matrix.
    # Each entry is the total transport cost between two samples.
    for i in range(num_samples):
        p_norm = sample_distributions[i]
        for j in range(num_samples):
            q_norm = sample_distributions[j]

            # Calculate total cost which is equivalent to the Wasserstein distance
            total_cost = ot.emd2(p_norm, q_norm, levenshtein_cost_matrix, log=False, check_marginals=False)

            # Assign the cost to the corresponding entry
            levenshtein_sample_cost_matrix[i, j] = total_cost

    levenshtein_sample_cost_df = pd.DataFrame(levenshtein_sample_cost_matrix, index=otus.columns, columns=otus.columns)
    levenshtein_sample_cost_df.to_csv(levenshtein_sample_df_path)


In [95]:
alignment_sample_df_path = os.path.join(SAMPLE_DFS_DIR, "alignment_sample_cost_df.csv")

if os.path.exists(alignment_sample_df_path):
    # NOTE: a cache written by an earlier version of this cell was computed from
    # raw counts instead of normalised distributions; delete the CSV to recompute.
    alignment_sample_cost_df = pd.read_csv(alignment_sample_df_path, index_col=0)
else:
    num_samples = otus.shape[1]
    alignment_sample_cost_matrix = np.zeros((num_samples, num_samples))

    # Turn every sample's OTU counts into a probability distribution once.
    # Raw counts generally have different totals per sample, so they are not
    # valid marginals for a balanced OT problem.
    sample_distributions = []
    for sample_id in otus.columns:
        counts = otus[sample_id].to_numpy(dtype=float)
        sample_distributions.append(counts / counts.sum())

    # Fill the full (samples x samples) matrix.
    # Each entry is the total transport cost between two samples.
    for i in range(num_samples):
        p_norm = sample_distributions[i]
        for j in range(num_samples):
            q_norm = sample_distributions[j]

            # Calculate total cost which is equivalent to the Wasserstein distance
            total_cost = ot.emd2(p_norm, q_norm, alignment_cost_matrix, log=False, check_marginals=False)

            # Assign the cost to the corresponding entry
            alignment_sample_cost_matrix[i, j] = total_cost

    alignment_sample_cost_df = pd.DataFrame(alignment_sample_cost_matrix, index=otus.columns, columns=otus.columns)
    alignment_sample_cost_df.to_csv(alignment_sample_df_path)

## Tree-based Distances

### Phylogenetic Distance

In [96]:
phylogenetic_sample_cost_path = os.path.join(SAMPLE_DFS_DIR, "phylogenetic_sample_cost_df.csv")

if os.path.exists(phylogenetic_sample_cost_path):
    # NOTE: a cache written by an earlier version of this cell was computed from
    # raw counts instead of normalised distributions; delete the CSV to recompute.
    phylogenetic_sample_cost_df = pd.read_csv(phylogenetic_sample_cost_path, index_col=0)
else:
    num_samples = otus.shape[1]
    phylogenetic_sample_cost_matrix = np.zeros((num_samples, num_samples))

    # Turn every sample's OTU counts into a probability distribution once.
    # Raw counts generally have different totals per sample, so they are not
    # valid marginals for a balanced OT problem.
    sample_distributions = []
    for sample_id in otus.columns:
        counts = otus[sample_id].to_numpy(dtype=float)
        sample_distributions.append(counts / counts.sum())

    # Fill the full (samples x samples) matrix.
    # Each entry is the total transport cost between two samples.
    for i in range(num_samples):
        p_norm = sample_distributions[i]
        for j in range(num_samples):
            q_norm = sample_distributions[j]

            # Calculate total cost which is equivalent to the Wasserstein distance
            total_cost = ot.emd2(p_norm, q_norm, phylogenetic_cost_matrix, log=False, check_marginals=False)

            # Assign the cost to the corresponding entry
            phylogenetic_sample_cost_matrix[i, j] = total_cost

    phylogenetic_sample_cost_df = pd.DataFrame(phylogenetic_sample_cost_matrix, index=otus.columns, columns=otus.columns)
    phylogenetic_sample_cost_df.to_csv(phylogenetic_sample_cost_path)

## Genome Distance

In [97]:
genome_normalized_sample_cost_path = os.path.join(SAMPLE_DFS_DIR, "genome_normalized_sample_cost_df.csv")
genome_sample_cost_path = os.path.join(SAMPLE_DFS_DIR, "genome_sample_cost_df.csv")


if os.path.exists(genome_sample_cost_path) and os.path.exists(genome_normalized_sample_cost_path):
    # NOTE: caches written by an earlier version of this cell were computed from
    # raw counts instead of normalised distributions; delete the CSVs to recompute.
    genome_sample_cost_df = pd.read_csv(genome_sample_cost_path, index_col=0)
    genome_normalized_sample_cost_df = pd.read_csv(genome_normalized_sample_cost_path, index_col=0)
else:
    # Keep only the samples whose counts over the OTUs with a genome
    # (`genome_df_columns`) do not sum to zero.
    # This is to prevent infeasible problems in the OT formulation
    samples = otus.T.loc[otus.loc[genome_df_columns].sum(axis=0) != 0].T.columns.to_list()

    num_samples = len(samples)
    genome_sample_cost_matrix = np.zeros((num_samples, num_samples))
    genome_normalized_sample_cost_matrix = np.zeros((num_samples, num_samples))

    # Counts of each retained sample, restricted to the OTUs that have a genome.
    sample_counts = [
        otus[sample_id][genome_df_columns].to_numpy(dtype=float)
        for sample_id in samples
    ]

    # Fill the full (samples x samples) matrices.
    # Each entry is the total transport cost between two samples.
    for i in range(num_samples):
        p = sample_counts[i]
        for j in range(num_samples):
            q = sample_counts[j]

            # Sometimes, subsetting p and q to only the OTUs for the genomes we have results in a 0 vector
            # When this is the case, the OT problem is infeasible
            if (p.sum().item() != 0) and (q.sum().item() != 0):
                # Raw counts generally have different totals per sample, so the
                # marginals handed to OT must be the normalised distributions.
                p_norm = p / p.sum()
                q_norm = q / q.sum()
                total_cost = ot.emd2(p_norm, q_norm, genome_cost_matrix, log=False, check_marginals=False)
                total_cost_normalized = ot.emd2(p_norm, q_norm, genome_normalized_cost_matrix, log=False, check_marginals=False)
            else:
                # Fallback for an all-zero vector: sum of the row-wise maxima of the cost matrix.
                total_cost_normalized = genome_normalized_cost_matrix.max(axis=1).sum()
                total_cost = genome_cost_matrix.max(axis=1).sum()

            # Assign the cost to the corresponding entry
            genome_normalized_sample_cost_matrix[i, j] = total_cost_normalized
            genome_sample_cost_matrix[i, j] = total_cost

    genome_sample_cost_df = pd.DataFrame(genome_sample_cost_matrix, index=samples, columns=samples)
    genome_normalized_sample_cost_df = pd.DataFrame(genome_normalized_sample_cost_matrix, index=samples, columns=samples)

    genome_sample_cost_df.to_csv(genome_sample_cost_path)
    genome_normalized_sample_cost_df.to_csv(genome_normalized_sample_cost_path)

## UniFrac

### Unweighted UniFrac

In [98]:
unweighted_unifrac_sample_cost_path = os.path.join(SAMPLE_DFS_DIR, "unweighted_unifrac_sample_cost_df.csv")

if not os.path.exists(unweighted_unifrac_sample_cost_path):
    # Pairwise unweighted UniFrac between all samples, computed on the sheared tree.
    COLUMNS = otus.columns
    taxa = otus.index.to_list()
    unweighted_unifrac_cost_matrix = np.zeros((len(COLUMNS), len(COLUMNS)))

    for i, sample_i in enumerate(COLUMNS):
        for j, sample_j in enumerate(COLUMNS):
            if i == j:
                # A sample's distance to itself is zero.
                unweighted_unifrac_cost_matrix[i, j] = 0
                continue

            u_counts = otus[sample_i].values.copy()
            v_counts = otus[sample_j].values.copy()
            unweighted_unifrac_cost_matrix[i, j] = unweighted_unifrac(
                u_counts=u_counts,
                v_counts=v_counts,
                taxa=taxa,
                tree=skbio_subtree,
            )

    unweighted_unifrac_cost_df = pd.DataFrame(unweighted_unifrac_cost_matrix, index=COLUMNS, columns=COLUMNS)
    unweighted_unifrac_cost_df.to_csv(unweighted_unifrac_sample_cost_path)
else:
    # Re-use the cached result.
    unweighted_unifrac_cost_df = pd.read_csv(unweighted_unifrac_sample_cost_path, index_col=0)

### Weighted UniFrac

In [99]:
weighted_unifrac_sample_cost_path = os.path.join(SAMPLE_DFS_DIR, "weighted_unifrac_sample_cost_df.csv")

if not os.path.exists(weighted_unifrac_sample_cost_path):
    # Pairwise weighted UniFrac between all samples, computed on the sheared tree.
    taxa = otus.index.to_list()
    weighted_unifrac_cost_matrix = np.zeros((len(otus.columns), len(otus.columns)))

    for i, sample_i in enumerate(otus.columns):
        for j, sample_j in enumerate(otus.columns):
            if i == j:
                # A sample's distance to itself is zero.
                weighted_unifrac_cost_matrix[i, j] = 0
                continue

            u_counts = otus[sample_i].values
            v_counts = otus[sample_j].values
            weighted_unifrac_cost_matrix[i, j] = weighted_unifrac(
                u_counts=u_counts,
                v_counts=v_counts,
                taxa=taxa,
                tree=skbio_subtree,
            )

    weighted_unifrac_cost_df = pd.DataFrame(weighted_unifrac_cost_matrix, index=otus.columns, columns=otus.columns)
    weighted_unifrac_cost_df.to_csv(weighted_unifrac_sample_cost_path)
else:
    # Re-use the cached result.
    weighted_unifrac_cost_df = pd.read_csv(weighted_unifrac_sample_cost_path, index_col=0)

# Diagnosis Classification

## Baseline: Metadata-only 

In [100]:
# Continuous metadata columns (declared here; not used in this cell).
metadata_columns_continuous = ["Age at diagnosis"]

# Categorical metadata columns used for the metadata-only baseline.
# "sample" is included so it can become the index below.
metadata_columns_categorical = [
    # "Bowel frequency during the day",
    "Soft drinks, tea or coffee with sugar (corn syrup, maple syrup, cane sugar, etc)",
    "Diet soft drinks, tea or coffee with sugar (Stevia, Equal, Splenda etc)",
    "Antibiotics",
    "Occupation",
    "Specify race",
    "sex",
    "sample",
]

# Categorical metadata, one row per sample id.
subset_metadata_df = ibd_metadata[metadata_columns_categorical].set_index("sample")

In [None]:
# Keep only the samples whose counts over the OTUs with a genome
# (`genome_df_columns`) do not sum to zero.
# This is to prevent infeasible problems in the OT formulation
_has_genome_counts = otus.loc[genome_df_columns].sum(axis=0) != 0
samples = otus.T.loc[_has_genome_counts].T.columns.to_list()

# Create a one-hot encoded metadata numpy array and pandas dataframe
enc = OneHotEncoder(categories='auto', sparse_output=False)
subset_metadata_one_hot = enc.fit_transform(subset_metadata_df.loc[samples])
_one_hot_column_names = [f"metadata_{i}" for i in range(subset_metadata_one_hot.shape[1])]
subset_metadata_one_hot_df = pd.DataFrame(subset_metadata_one_hot,
                                          index=samples,
                                          columns=_one_hot_column_names)

# OT-based sample x sample frames that also get a "<name>_plus_Metadata" variant.
_ot_frames = {
    "OT_Levenshtein": levenshtein_sample_cost_df,
    "OT_Alignment": alignment_sample_cost_df,
    "OT_Genome": genome_sample_cost_df,
    "OT_Genome_Normalized": genome_normalized_sample_cost_df,
}

# Frames that get a 6-component PCA variant ("<name>_PCA_6").
_pca_frames = {
    **_ot_frames,
    "Weighted_UniFrac": weighted_unifrac_cost_df,
    "Unweighted_UniFrac": unweighted_unifrac_cost_df,
}

# Every candidate input matrix, keyed by method name. Insertion order matters:
# it fixes the column order of the score tables and the iteration order later on.
sample_cost_dfs = {}

# 1) Plain distance matrices restricted to `samples` on both axes.
for _label, _frame in _ot_frames.items():
    sample_cost_dfs[_label] = _frame[samples].loc[samples].copy()
sample_cost_dfs["OT_Phylogenetic"] = phylogenetic_sample_cost_df[samples].loc[samples].copy()

# 2) Raw categorical metadata (one-hot encoded later by generate_input_matrix).
sample_cost_dfs["Metadata"] = subset_metadata_df.loc[samples]

sample_cost_dfs["Unweighted UniFrac"] = unweighted_unifrac_cost_df[samples].loc[samples].copy()
sample_cost_dfs["Weighted UniFrac"] = weighted_unifrac_cost_df[samples].loc[samples].copy()

# 3) Distance matrix with the one-hot metadata columns appended.
for _label, _frame in _ot_frames.items():
    sample_cost_dfs[f"{_label}_plus_Metadata"] = pd.concat(
        [_frame[samples].loc[samples], subset_metadata_one_hot_df],
        join="inner",
        axis=1,
    )

# 4) First six principal components of the rows for `samples`
#    (all columns of the frame are kept as PCA features).
for _label, _frame in _pca_frames.items():
    _rows = _frame.loc[samples]
    sample_cost_dfs[f"{_label}_PCA_6"] = pd.DataFrame(
        PCA(n_components=6).fit_transform(_rows.to_numpy()),
        index=_rows.index,
        columns=[f"component_{i+1}" for i in range(6)],
    )

# Dimensions of the result tables.
all_classifiers = ["random_forest", "decision_tree", "knn", "svc"]
score_types = ["f1", "accuracy"]
all_input_matrices = list(sample_cost_dfs.keys())
METHODS_SHAPE = len(all_input_matrices)
CLASSIFIERS_SHAPE = len(all_classifiers) * len(score_types)
RAW_SCORE_DF_SHAPE = len(all_classifiers) * len(score_types) * len(all_input_matrices)

# Row index of the summary tables: (score type, classifier).
multiindex_tuples = [
    (score_type, classifier_name)
    for score_type in score_types
    for classifier_name in all_classifiers
]
multiindex = pd.MultiIndex.from_tuples(multiindex_tuples, names=["score_type", "classifier"])

# Row index of the per-fold table: (score type, classifier, input matrix).
multiindex_tuples_raw = [
    (score_type, classifier_name, input_matrix_name)
    for score_type in score_types
    for classifier_name in all_classifiers
    for input_matrix_name in all_input_matrices
]
multiindex_raw = pd.MultiIndex.from_tuples(multiindex_tuples_raw, names=["score_type", "classifier", "input_matrix"])

# Zero-initialised result tables; `copy=True` keeps them independent of the templates.
empty_arr = np.zeros((CLASSIFIERS_SHAPE, METHODS_SHAPE), dtype=float)
raw_empty_arr = np.zeros((RAW_SCORE_DF_SHAPE, N_FOLDS), dtype=float)

# One column per cross-validation fold.
raw_score_df = pd.DataFrame(raw_empty_arr,
                            columns=[f"{i}" for i in range(N_FOLDS)],
                            index=multiindex_raw,
                            copy=True)

# One column per input matrix.
score_df = pd.DataFrame(data=empty_arr,
                        columns=all_input_matrices,
                        index=multiindex,
                        copy=True)

std_deviation_df = pd.DataFrame(data=empty_arr,
                                columns=all_input_matrices,
                                index=multiindex,
                                copy=True)

# Cross-validated predictions, appended to by the evaluation loop.
predictions: List[pd.Series] = []

def generate_input_matrix(cost_df, _type, samples):
    """Build the array handed to the classifiers for one input type.

    Parameters
    ----------
    cost_df : pandas.DataFrame
        Frame for this input type; used for every ``_type`` except "Metadata".
    _type : str
        Name of the input type. Only "Metadata" is special-cased.
    samples : list
        Row labels selected from the notebook-level ``subset_metadata_df``;
        only used when ``_type == "Metadata"``.

    Returns
    -------
    The one-hot encoding of ``subset_metadata_df.loc[samples]`` for "Metadata",
    otherwise ``cost_df.to_numpy()``.
    """
    if _type != "Metadata":
        return cost_df.to_numpy()

    encoder = OneHotEncoder(categories='auto', sparse_output=False)
    return encoder.fit_transform(subset_metadata_df.loc[samples])

# Input types whose rows are plain feature vectors (metadata, matrix rows with
# metadata appended, PCA projections). The SVC uses a linear kernel for these
# and kernel='precomputed' for every other input matrix.
# NOTE(review): kernel='precomputed' makes SVC treat X as a kernel (Gram)
# matrix; confirm that passing the cost matrices directly is intended.
non_distance_matrices = [
    "Metadata",
    "OT_Levenshtein_plus_Metadata",
    "OT_Alignment_plus_Metadata",
    "OT_Genome_plus_Metadata",
    "OT_Genome_Normalized_plus_Metadata",
    "OT_Levenshtein_PCA_6",
    "OT_Alignment_PCA_6",
    "OT_Genome_PCA_6",
    "OT_Genome_Normalized_PCA_6",
    "Weighted_UniFrac_PCA_6",
    "Unweighted_UniFrac_PCA_6",
]

# With an integer random_state every call to cv.split() yields the same folds,
# so one splitter can be shared by all input matrices and classifiers.
cv = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

outer_prog_bar = tqdm(sample_cost_dfs.items(), total=len(sample_cost_dfs.keys()))

for name, cost_df in outer_prog_bar:
    outer_prog_bar.set_description(desc=name)

    # Labels, their integer encoding and the folds depend only on the input
    # matrix, not on the classifier, so they are computed once per matrix.
    labels = ibd_metadata_diagnosis.loc[cost_df.index].values.ravel()
    le = LabelEncoder()
    y = le.fit_transform(labels)
    X = generate_input_matrix(cost_df, _type=name, samples=samples)

    # Row positions held out in each fold; used below to score fold by fold.
    fold_test_indices = [test_idx for _, test_idx in cv.split(X, y)]

    svc_kernel = 'linear' if name in non_distance_matrices else 'precomputed'
    classifiers = {"random_forest": RandomForestClassifier(random_state=42),
                   "decision_tree": DecisionTreeClassifier(random_state=42),
                   "knn": KNeighborsClassifier(n_neighbors=5),
                   "svc": SVC(kernel=svc_kernel, max_iter=1000000, class_weight='balanced')}

    for cl_name, classifier in classifiers.items():
        # Out-of-fold prediction for every sample (one fit per fold).
        y_pred = cross_val_predict(classifier, X, y, cv=cv)
        prediction_series = pd.Series(y_pred, index=labels, name=f"{name}_{cl_name}")
        predictions.append(prediction_series)

        # Score each fold from the predictions above instead of re-running the
        # whole cross-validation once per metric. The folds are identical and
        # every estimator here is deterministic, so these are the same per-fold
        # values cross_val_score reports for 'accuracy' and 'f1_weighted'.
        acc = np.array([accuracy_score(y[test_idx], y_pred[test_idx])
                        for test_idx in fold_test_indices])
        f1 = np.array([f1_score(y[test_idx], y_pred[test_idx], average='weighted')
                       for test_idx in fold_test_indices])

        raw_score_df.loc[("accuracy", cl_name, name)] = acc
        raw_score_df.loc[("f1", cl_name, name)] = f1

        score_df.at[("accuracy", cl_name), name] = acc.mean().item()
        score_df.at[("f1", cl_name), name] = f1.mean().item()

        std_deviation_df.at[("accuracy", cl_name), name] = acc.std().item()
        std_deviation_df.at[("f1", cl_name), name] = f1.std().item()

# Ground-truth column. `y` and `labels` come from the last input matrix of the
# loop; this relies on every input matrix listing the samples in the same order.
predictions.append(pd.Series(y, index=labels, name="Original"))
predictions_df = pd.concat(predictions, axis=1)

In [None]:
# Bar colour for each input matrix (hand-picked from matplotlib's tab20 palette).
# Every key has its own colour: "OT_Phylogenetic" previously shared '#9467bd'
# with "Metadata", which made the two indistinguishable in the same panel.
# TODO: Create better way to do this
colors = {
    "OT_Levenshtein": '#1f77b4',
    "OT_Alignment": '#ff7f0e',
    "OT_Genome": '#2ca02c',
    "OT_Genome_Normalized": '#d62728',
    "OT_Phylogenetic": '#9edae5',
    "Metadata": '#9467bd',
    "Unweighted UniFrac": '#8c564b',
    "Weighted UniFrac": '#e377c2',
    "OT_Levenshtein_plus_Metadata": '#7f7f7f',
    "OT_Alignment_plus_Metadata": '#bcbd22',
    "OT_Genome_plus_Metadata": '#17becf',
    "OT_Genome_Normalized_plus_Metadata": '#aec7e8',
    "OT_Levenshtein_PCA_6": '#ff9896',
    "OT_Alignment_PCA_6": '#98df8a',
    "OT_Genome_PCA_6": '#c5b0d5',
    "OT_Genome_Normalized_PCA_6": '#c49c94',
    "Weighted_UniFrac_PCA_6": '#f7b6d2',
    "Unweighted_UniFrac_PCA_6": '#dbdb8d'
}

# One row of panels per classifier, one column per score type.
fig, axes = plt.subplots(4, 2, figsize=(16, 24))

for idx, classifier in enumerate(["random_forest", "decision_tree", "knn", "svc"]):
    for score_idx, score_type in enumerate(["f1", "accuracy"]):
        # Mean score per input matrix, worst first; error bars are the
        # standard deviation over folds, reordered to match.
        data = score_df.loc[score_type, classifier].sort_values().copy()
        errors = std_deviation_df.loc[score_type, classifier][data.index]

        axes[idx, score_idx].barh(data.index,
                                 data.values,
                                 color=[colors[x] for x in data.index],
                                 xerr=errors,
                                 error_kw={'marker': "x", "capsize": 4, "markeredgecolor": "black", "markerfacecolor": "black", "markersize": 4})
        axes[idx, score_idx].set_xlim(0, 1)
        axes[idx, score_idx].set_title(f"{score_type.capitalize()} - {classifier.replace('_', ' ').upper()}")

fig.suptitle("All Scores Not Stratified by Sex", y=1.01, fontsize=16)
fig.tight_layout()

# savefig uses the file name verbatim, so the ".pdf" extension has to be part
# of it; the output directory is created on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig(os.path.join("figures", "all_scores.pdf"), format="pdf", dpi=300, bbox_inches='tight')

## Visualization of Performance per Sample

In [None]:
# Split the prediction table by true diagnosis (the index holds the string
# labels) and order the rows lexicographically by all columns, descending.
cd_df = predictions_df.loc["CD"].sort_values(by=predictions_df.columns.to_list(), axis=0, ascending=False)
uc_df = predictions_df.loc["UC"].sort_values(by=predictions_df.columns.to_list(), axis=0, ascending=False)


def _order_columns_original_last(class_df, ascending):
    """Order columns by their column sum and put "Original" in the last position.

    The sort alone does not fix where "Original" lands when other columns tie
    with it, so it is moved to the end explicitly; the separator line drawn at
    the last column then always marks the ground-truth column.
    """
    column_order = class_df.sum().sort_values(ascending=ascending).index.to_list()
    if "Original" in column_order:
        column_order.remove("Original")
        column_order.append("Original")
    return class_df[column_order].copy()


def _color_tick_labels(axis, perfect_cd, perfect_uc):
    """Colour x tick labels: purple = perfect on both classes, blue = CD only, red = UC only."""
    for label in axis.get_xticklabels():
        txt = label.get_text()

        if txt == "Original":
            continue

        is_perfect_cd = txt in perfect_cd
        is_perfect_uc = txt in perfect_uc

        # The combined case must be tested first, otherwise it can never be reached.
        if is_perfect_cd and is_perfect_uc:
            label.set_color('purple')
        elif is_perfect_cd:
            label.set_color('blue')
        elif is_perfect_uc:
            label.set_color('red')
        else:
            label.set_color('black')


fig, ax = plt.subplots(2, 1, figsize=(10, 20))

# Common kwargs for both heatmaps
kwargs = {
    'cmap': 'binary',
    'linewidths': 0.5,
    'linecolor': 'black',
    'xticklabels': True,
    'yticklabels': False,
    'annot': True,
    'cbar': False
}

# CD: a column is "perfect" when every CD sample is predicted as 0.
# (assumes the label encoding maps CD -> 0 and UC -> 1, i.e. only these two
# diagnoses are present — verify against the label encoder)
cd_df_sorted_columnwise = _order_columns_original_last(cd_df, ascending=False)
original_column_index_cd = len(cd_df_sorted_columnwise.columns) - 1
perfect_classification_columns_cd = [col for col in cd_df_sorted_columnwise.columns if (cd_df_sorted_columnwise[col] == 0).all()]

sns.heatmap(cd_df_sorted_columnwise, ax=ax[0], **kwargs)
ax[0].axvline(x=original_column_index_cd, color='red', linewidth=2, linestyle='--')
ax[0].set_title(f"CD Classifications on {cd_df.shape[0]} Samples", fontsize=12)

# UC: a column is "perfect" when every UC sample is predicted as 1.
uc_df_sorted_columnwise = _order_columns_original_last(uc_df, ascending=True)
original_column_index_uc = len(uc_df_sorted_columnwise.columns) - 1
perfect_classification_columns_uc = [col for col in uc_df_sorted_columnwise.columns if (uc_df_sorted_columnwise[col] == 1).all()]

sns.heatmap(uc_df_sorted_columnwise, ax=ax[1], **kwargs)
ax[1].axvline(x=original_column_index_uc, color='red', linewidth=2, linestyle='--')
ax[1].set_title(f"UC Classifications on {uc_df.shape[0]} Samples", fontsize=12)

# Highlight the columns that classify a whole class correctly, on both panels.
for axis in ax:
    _color_tick_labels(axis,
                       set(perfect_classification_columns_cd),
                       set(perfect_classification_columns_uc))

fig.tight_layout()
plt.show()

## Visualization of Metadata Decision Tree

In [104]:
# enc = OneHotEncoder(categories='auto', sparse_output=False)
# X = enc.fit_transform(subset_metadata_df)

# plt.figure(figsize=(24, 10))

# dtree = DecisionTreeClassifier(random_state=42)

# X = generate_input_matrix(sample_cost_dfs["Metadata"], "Metadata", samples)
# labels = ibd_metadata_diagnosis.loc[sample_cost_dfs["Metadata"].index].values.ravel()
# le = LabelEncoder()
# y = le.fit_transform(labels)

# full_feature_names = enc.get_feature_names_out(subset_metadata_df.columns)
# truncated_feature_names = [f"{name[:8]}...{name[-20:]}" for name in full_feature_names]

# dtree.fit(X, y)
# plot_tree(dtree, 
#           fontsize=8, 
#           class_names=list(le.classes_),
#           feature_names=truncated_feature_names,
#           label='all',
#           filled=True)

# plt.show()

## Plotting Intra vs Inter Class Distances

In [None]:
# Keep only the samples (columns of `otus`) whose counts over the OTUs in
# `genome_df_columns` do not sum to zero.
# This is to prevent infeasible problems in the OT formulation
samples = otus.T.loc[otus.loc[genome_df_columns].sum(axis=0) != 0].T.columns.to_list()

# Sample IDs per diagnosis, restricted to `samples` and kept in `samples` order.
cd_ids = set(ibd_metadata_diagnosis[ibd_metadata_diagnosis["diagnosis"] == "CD"].index)
uc_ids = set(ibd_metadata_diagnosis[ibd_metadata_diagnosis["diagnosis"] == "UC"].index)
cd_indices = [sample for sample in samples if sample in cd_ids]
uc_indices = [sample for sample in samples if sample in uc_ids]


def _off_diagonal(square):
    """Flatten a square array, leaving out the diagonal (each sample against itself)."""
    keep = ~np.eye(square.shape[0], dtype=bool)
    return square[keep]


dfs = {
    'OT_Levenshtein': levenshtein_sample_cost_df,
    'OT_Alignment': alignment_sample_cost_df,
    "OT_Phylogenetic": phylogenetic_sample_cost_df,
    'Weighted UniFrac': weighted_unifrac_cost_df,
    'Unweighted UniFrac': unweighted_unifrac_cost_df,
    'OT_Genome_Normalized': genome_normalized_sample_cost_df,
    'OT_Genome': genome_sample_cost_df,
}

# One panel per matrix in `dfs`, plus a final panel for plain L2 distances.
fig, ax = plt.subplots(len(dfs) + 1, 1, figsize=(10, 10), sharex=False, sharey=False)

for idx, (df_name, df) in enumerate(dfs.items()):
    # Within-class values exclude the diagonal: a sample's value against itself
    # is not a distance between two members of the class.
    intra_cd = _off_diagonal(df.loc[cd_indices, cd_indices].to_numpy())
    intra_uc = _off_diagonal(df.loc[uc_indices, uc_indices].to_numpy())
    inter_1 = df.loc[cd_indices, uc_indices].to_numpy().flatten()

    sns.kdeplot(data=intra_cd, label="Intra Class Distance (CD)", ax=ax[idx], alpha=0.5, fill=True, warn_singular=False)
    sns.kdeplot(data=intra_uc, label="Intra Class Distance (UC)", ax=ax[idx], alpha=0.5, fill=True, warn_singular=False)
    sns.kdeplot(data=inter_1, label="Inter Class Distance", ax=ax[idx], alpha=0.5, fill=True, warn_singular=False)

    ax[idx].set_title(df_name)
    ax[idx].legend()

# distance_matrix compares the ROWS of its inputs. Samples are the columns of
# `otus`, so the selection is transposed to get sample-to-sample distances
# (untransposed, this computed distances between OTUs instead).
otus_cd = otus[cd_indices].T.values
otus_uc = otus[uc_indices].T.values
otus_l2_distances_cd = _off_diagonal(distance_matrix(otus_cd, otus_cd, p=2))
otus_l2_distances_uc = _off_diagonal(distance_matrix(otus_uc, otus_uc, p=2))

l2_ax = ax[len(dfs)]
sns.kdeplot(data=otus_l2_distances_cd, label="Intra Class L2 Distances (CD)", ax=l2_ax, alpha=0.5, fill=True)
sns.kdeplot(data=otus_l2_distances_uc, label="Intra Class L2 Distances (UC)", ax=l2_ax, alpha=0.5, fill=True)
l2_ax.set_title("L2 Distance between Pairwise Normalized OTU Counts")
l2_ax.legend()

fig.tight_layout()

# Make sure the output directory exists before writing the figure.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/intra_vs_inter_class_distances.pdf", format="pdf")
plt.show()

## Stratifying by Sex

In [106]:
# Sample IDs by recorded sex.
male_samples_set = set(ibd_metadata[ibd_metadata["sex"] == "Male"]["sample"].to_list())
female_samples_set = set(ibd_metadata[ibd_metadata["sex"] == "Female"]["sample"].to_list())

# Keep only the samples (columns of `otus`) whose counts over the OTUs in
# `genome_df_columns` do not sum to zero.
# This is to prevent infeasible problems in the OT formulation
valid_genome_samples_list = otus.T.loc[otus.loc[genome_df_columns].sum(axis=0) != 0].T.columns.to_list()
valid_genome_samples_set = set(valid_genome_samples_list)

# Names kept from the earlier version of this cell (same contents as the *_set variables).
male_samples = set(male_samples_set)
female_samples = set(female_samples_set)
valid_genome_samples = set(valid_genome_samples_set)

# Per-sex sample lists in the column order of `otus`. Building them with
# list(set & set) would give an order that depends on string hashing and can
# change between kernel restarts, which in turn changes any seeded
# cross-validation split taken over these lists.
male_samples_list = [sample for sample in valid_genome_samples_list if sample in male_samples_set]
female_samples_list = [sample for sample in valid_genome_samples_list if sample in female_samples_set]

### Males

In [None]:
# Cost matrices to project; copies so the originals are never touched.
dfs = {
    "levenshtein": levenshtein_sample_cost_df.copy(),
    "alignment": alignment_sample_cost_df.copy(),
    "genome": genome_sample_cost_df.copy(),
    "genome_normalized": genome_normalized_sample_cost_df.copy()
}

# One figure row per cohort: (row label, sample IDs).
cohorts = [
    ("Male + Female", valid_genome_samples_list),
    ("Male", male_samples_list),
    ("Female", female_samples_list),
]

fig, ax = plt.subplots(len(cohorts), len(dfs), figsize=(14, 9))

for idx, (name, df) in enumerate(dfs.items()):
    for row, (cohort_label, cohort_samples) in enumerate(cohorts):
        # PCA is fit separately for each cohort on its own square sub-matrix.
        # The projection is held in a local name so that the sample sets
        # `male_samples` / `female_samples` are not overwritten by arrays.
        components = PCA(n_components=6).fit_transform(df.loc[cohort_samples, cohort_samples].to_numpy())
        point_colors = ibd_metadata_diagnosis.loc[cohort_samples]["diagnosis"].apply(lambda x: "red" if x == "CD" else "blue").values

        # Only the first two components are drawn.
        ax[row][idx].scatter(components[:, 0], components[:, 1], color=point_colors)

        # Row captions and y labels go on the left-most column only.
        if idx == 0:
            ax[row][idx].text(-0.6, 0.5, cohort_label, rotation=90, transform=ax[row][idx].transAxes, va='center')
            ax[row][idx].set_ylabel("PCA 2")

    ax[0][idx].set_title(name)
    ax[len(cohorts) - 1][idx].set_xlabel("PCA 1")

    # Add legend next to the bottom-right subplot
    if idx == len(dfs) - 1:
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', label='CD', markersize=10),
                           plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', label='UC', markersize=10)]
        ax[len(cohorts) - 1][idx].legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))

fig.suptitle("6-component PCA by OT Cost Matrix Stratified by Sex")
fig.tight_layout()

In [None]:
# Male samples that are usable for the OT problems, in the column order of
# `otus`. list(set & set) would give a hash-dependent order that can change
# between kernel restarts and would make the seeded K-fold splits below
# irreproducible.
samples = [sample for sample in valid_genome_samples_list if sample in male_samples_set]

# Create a one-hot encoded metadata numpy array and pandas dataframe
enc = OneHotEncoder(categories='auto', sparse_output=False)
subset_metadata_one_hot = enc.fit_transform(subset_metadata_df.loc[samples])
subset_metadata_one_hot_df = pd.DataFrame(subset_metadata_one_hot, 
                                          index=samples, 
                                          columns=[f"metadata_{i}" for i in range(subset_metadata_one_hot.shape[1])])

# Every input matrix evaluated for this cohort, keyed by display name.
# Note that the *_PCA_6 entries select rows only (`.loc[samples]`), so PCA is
# fit on each sample's values against all columns of the matrix.
sample_cost_dfs = {"OT_Levenshtein": levenshtein_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Alignment": alignment_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Genome": genome_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Genome_Normalized": genome_normalized_sample_cost_df[samples].loc[samples].copy(),
                   "Metadata": subset_metadata_df.loc[samples],
                   "Unweighted UniFrac": unweighted_unifrac_cost_df[samples].loc[samples].copy(),
                   "Weighted UniFrac": weighted_unifrac_cost_df[samples].loc[samples].copy(),
                   "OT_Levenshtein_plus_Metadata": pd.concat([levenshtein_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Alignment_plus_Metadata": pd.concat([alignment_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Genome_plus_Metadata": pd.concat([genome_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Genome_Normalized_plus_Metadata": pd.concat([genome_normalized_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Levenshtein_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(levenshtein_sample_cost_df.loc[samples].to_numpy()), index=levenshtein_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Alignment_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(alignment_sample_cost_df.loc[samples].to_numpy()), index=alignment_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Genome_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(genome_sample_cost_df.loc[samples].to_numpy()), index=genome_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Genome_Normalized_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(genome_normalized_sample_cost_df.loc[samples].to_numpy()), index=genome_normalized_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "Weighted_UniFrac_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(weighted_unifrac_cost_df.loc[samples].to_numpy()), index=weighted_unifrac_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "Unweighted_UniFrac_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(unweighted_unifrac_cost_df.loc[samples].to_numpy()), index=unweighted_unifrac_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
}

# Result tables: rows are (score type, classifier), columns are input matrices.
METHODS_SHAPE = len(sample_cost_dfs.keys())
all_classifiers = ["random_forest", "decision_tree", "knn", "svc"]
score_types = ["f1", "accuracy"]
CLASSIFIERS_SHAPE = len(all_classifiers) * len(score_types)

multiindex_tuples = [
    (_type, _classifier)
    for _type in score_types
    for _classifier in all_classifiers
]
multiindex = pd.MultiIndex.from_tuples(multiindex_tuples, names=["score_type", "classifier"])

# Both tables start from the same zero array; copy=True gives each its own storage.
empty_arr = np.zeros((CLASSIFIERS_SHAPE, METHODS_SHAPE), dtype=float)
score_df = pd.DataFrame(data=empty_arr,
                        columns=list(sample_cost_dfs.keys()),
                        index=multiindex,
                        copy=True)

std_deviation_df = pd.DataFrame(data=empty_arr,
                                columns=list(sample_cost_dfs.keys()),
                                index=multiindex,
                                copy=True)

# One Series of cross-validated predictions per (input matrix, classifier).
predictions = []

def generate_input_matrix(cost_df, _type, samples):
    """Build the array handed to the classifiers for one input type.

    Parameters
    ----------
    cost_df : pandas.DataFrame
        Frame for this input type; used for every ``_type`` except "Metadata".
    _type : str
        Name of the input type. Only "Metadata" is special-cased.
    samples : list
        Row labels selected from the notebook-level ``subset_metadata_df``;
        only used when ``_type == "Metadata"``.

    Returns
    -------
    The one-hot encoding of ``subset_metadata_df.loc[samples]`` for "Metadata",
    otherwise ``cost_df.to_numpy()``.
    """
    if _type != "Metadata":
        return cost_df.to_numpy()

    encoder = OneHotEncoder(categories='auto', sparse_output=False)
    return encoder.fit_transform(subset_metadata_df.loc[samples])

# Input types whose rows are plain feature vectors (metadata, matrix rows with
# metadata appended, PCA projections). The SVC uses a linear kernel for these
# and kernel='precomputed' for every other input matrix.
non_distance_matrices = [
    "Metadata",
    "OT_Levenshtein_plus_Metadata",
    "OT_Alignment_plus_Metadata",
    "OT_Genome_plus_Metadata",
    "OT_Genome_Normalized_plus_Metadata",
    "OT_Levenshtein_PCA_6",
    "OT_Alignment_PCA_6",
    "OT_Genome_PCA_6",
    "OT_Genome_Normalized_PCA_6",
    "Weighted_UniFrac_PCA_6",
    "Unweighted_UniFrac_PCA_6",
]

# With an integer random_state every call to cv.split() yields the same folds,
# so one splitter can be shared by all input matrices and classifiers.
# NOTE(review): the analysis on all samples earlier in this notebook uses
# StratifiedKFold and SVC(class_weight='balanced'); this cell uses plain KFold
# and an unweighted SVC. Confirm the difference is intended before comparing
# scores across the two analyses.
cv = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

outer_prog_bar = tqdm(sample_cost_dfs.items(), total=len(sample_cost_dfs.keys()))

for name, cost_df in outer_prog_bar:
    outer_prog_bar.set_description(desc=name)

    # Labels, their integer encoding and the folds depend only on the input
    # matrix, not on the classifier, so they are computed once per matrix.
    labels = ibd_metadata_diagnosis.loc[cost_df.index].values.ravel()
    le = LabelEncoder()
    y = le.fit_transform(labels)
    X = generate_input_matrix(cost_df, _type=name, samples=samples)

    # Row positions held out in each fold; used below to score fold by fold.
    fold_test_indices = [test_idx for _, test_idx in cv.split(X, y)]

    svc_kernel = 'linear' if name in non_distance_matrices else 'precomputed'
    classifiers = {"random_forest": RandomForestClassifier(random_state=42),
                   "decision_tree": DecisionTreeClassifier(random_state=42),
                   "knn": KNeighborsClassifier(n_neighbors=5),
                   "svc": SVC(kernel=svc_kernel, max_iter=1000000)}

    for cl_name, classifier in classifiers.items():
        # Out-of-fold prediction for every sample (one fit per fold).
        y_pred = cross_val_predict(classifier, X, y, cv=cv)
        prediction_series = pd.Series(y_pred, index=labels, name=f"{name}_{cl_name}")
        predictions.append(prediction_series)

        # Score each fold from the predictions above instead of re-running the
        # whole cross-validation once per metric. The folds are identical and
        # every estimator here is deterministic, so these are the same per-fold
        # values cross_val_score reports for 'accuracy' and 'f1_weighted'.
        acc = np.array([accuracy_score(y[test_idx], y_pred[test_idx])
                        for test_idx in fold_test_indices])
        f1 = np.array([f1_score(y[test_idx], y_pred[test_idx], average='weighted')
                       for test_idx in fold_test_indices])

        score_df.at[("accuracy", cl_name), name] = acc.mean().item()
        score_df.at[("f1", cl_name), name] = f1.mean().item()

        std_deviation_df.at[("accuracy", cl_name), name] = acc.std().item()
        std_deviation_df.at[("f1", cl_name), name] = f1.std().item()

# Ground-truth column. `y` and `labels` come from the last input matrix of the
# loop; this relies on every input matrix listing the samples in the same order.
predictions.append(pd.Series(y, index=labels, name="Original"))
predictions_df = pd.concat(predictions, axis=1)

In [None]:
# Bar colour for each input matrix (hand-picked from matplotlib's tab20 palette).
# Every key has its own colour: "OT_Phylogenetic" previously shared '#9467bd'
# with "Metadata".
# TODO: Create better way to do this
colors = {
    "OT_Levenshtein": '#1f77b4',
    "OT_Alignment": '#ff7f0e',
    "OT_Genome": '#2ca02c',
    "OT_Genome_Normalized": '#d62728',
    "OT_Phylogenetic": '#9edae5',
    "Metadata": '#9467bd',
    "Unweighted UniFrac": '#8c564b',
    "Weighted UniFrac": '#e377c2',
    "OT_Levenshtein_plus_Metadata": '#7f7f7f',
    "OT_Alignment_plus_Metadata": '#bcbd22',
    "OT_Genome_plus_Metadata": '#17becf',
    "OT_Genome_Normalized_plus_Metadata": '#aec7e8',
    "OT_Levenshtein_PCA_6": '#ff9896',
    "OT_Alignment_PCA_6": '#98df8a',
    "OT_Genome_PCA_6": '#c5b0d5',
    "OT_Genome_Normalized_PCA_6": '#c49c94',
    "Weighted_UniFrac_PCA_6": '#f7b6d2',
    "Unweighted_UniFrac_PCA_6": '#dbdb8d'
}

# One row of panels per classifier, one column per score type.
fig, axes = plt.subplots(4, 2, figsize=(16, 24))

for idx, classifier in enumerate(["random_forest", "decision_tree", "knn", "svc"]):
    for score_idx, score_type in enumerate(["f1", "accuracy"]):
        # Mean score per input matrix, worst first; error bars are the
        # standard deviation over folds, reordered to match.
        data = score_df.loc[score_type, classifier].sort_values().copy()
        errors = std_deviation_df.loc[score_type, classifier][data.index]

        axes[idx, score_idx].barh(data.index,
                                 data.values,
                                 color=[colors[x] for x in data.index],
                                 xerr=errors,
                                 error_kw={'marker': "x", "capsize": 4, "markeredgecolor": "black", "markerfacecolor": "black", "markersize": 4})
        axes[idx, score_idx].set_xlim(0, 1)
        axes[idx, score_idx].set_title(f"{score_type.capitalize()} - {classifier.replace('_', ' ').upper()}")

fig.suptitle("All Scores - Males Only", y=1.01, fontsize=16)
fig.tight_layout()

In [None]:
# Split the prediction table by true diagnosis (the index holds the string
# labels) and order the rows lexicographically by all columns, descending.
sort_columns = predictions_df.columns.to_list()
cd_df = predictions_df.loc["CD"].sort_values(by=sort_columns, axis=0, ascending=False)
uc_df = predictions_df.loc["UC"].sort_values(by=sort_columns, axis=0, ascending=False)

# Side-by-side panels: CD on the left, UC on the right.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
# NOTE: set_theme changes the plotting defaults globally, so it also affects
# every figure created after this cell.
sns.set_theme(font_scale=0.6)

# Keyword arguments shared by both heatmaps
kwargs = dict(
    cmap='binary',
    linewidths=0.5,
    linecolor='black',
    xticklabels=True,
    yticklabels=True,
    annot=True,
    cbar=False,
)

for panel, class_df in zip(ax, (cd_df, uc_df)):
    sns.heatmap(class_df, ax=panel, **kwargs)

# Column positions of the first "Metadata..." column, the first column that
# mentions "UniFrac", and the trailing ground-truth column. They are only
# needed for the optional separator lines shown (commented out) below.
metadata_positions = [i for i, col in enumerate(cd_df.columns) if col.startswith("Metadata")]
unifrac_positions = [i for i, col in enumerate(cd_df.columns) if "UniFrac" in col]
first_metadata_col_index = metadata_positions[0]
first_unifrac_col_index = unifrac_positions[0]
original_column_index = len(cd_df.columns) - 1

# Tick formatting and titles
for a, panel_title in zip(ax, ('CD Predictions', 'UC Predictions')):
    # Slanted column names, horizontal row names
    a.set_xticklabels(a.get_xticklabels(), rotation=45, ha='right')
    a.set_yticklabels(a.get_yticklabels(), rotation=0)

    # # Optional: Add vertical lines to further emphasize groups
    # a.axvline(x=first_metadata_col_index, color='red', linewidth=2, linestyle='--')  # After OT methods
    # a.axvline(x=first_unifrac_col_index, color='red', linewidth=2, linestyle='--')  # After Metadata
    # a.axvline(x=original_column_index, color='red', linewidth=2, linestyle='--') # After UniFrac

    a.set_title(panel_title)

fig.suptitle("Predictions vs Original Data by Label - Males Only")

# Keep the rotated labels from being clipped
plt.tight_layout()
plt.show()

### Females

In [None]:
# Female samples that are usable for the OT problems, in the column order of
# `otus`. list(set & set) would give a hash-dependent order that can change
# between kernel restarts and would make the seeded K-fold splits below
# irreproducible.
samples = [sample for sample in valid_genome_samples_list if sample in female_samples_set]

# Create a one-hot encoded metadata numpy array and pandas dataframe
enc = OneHotEncoder(categories='auto', sparse_output=False)
subset_metadata_one_hot = enc.fit_transform(subset_metadata_df.loc[samples])
subset_metadata_one_hot_df = pd.DataFrame(subset_metadata_one_hot, 
                                          index=samples, 
                                          columns=[f"metadata_{i}" for i in range(subset_metadata_one_hot.shape[1])])

# Every input matrix evaluated for this cohort, keyed by display name.
# Note that the *_PCA_6 entries select rows only (`.loc[samples]`), so PCA is
# fit on each sample's values against all columns of the matrix.
sample_cost_dfs = {"OT_Levenshtein": levenshtein_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Alignment": alignment_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Genome": genome_sample_cost_df[samples].loc[samples].copy(),
                   "OT_Genome_Normalized": genome_normalized_sample_cost_df[samples].loc[samples].copy(),
                   "Metadata": subset_metadata_df.loc[samples],
                   "Unweighted UniFrac": unweighted_unifrac_cost_df[samples].loc[samples].copy(),
                   "Weighted UniFrac": weighted_unifrac_cost_df[samples].loc[samples].copy(),
                   "OT_Levenshtein_plus_Metadata": pd.concat([levenshtein_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Alignment_plus_Metadata": pd.concat([alignment_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Genome_plus_Metadata": pd.concat([genome_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Genome_Normalized_plus_Metadata": pd.concat([genome_normalized_sample_cost_df[samples].loc[samples], subset_metadata_one_hot_df], join="inner", axis=1),
                   "OT_Levenshtein_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(levenshtein_sample_cost_df.loc[samples].to_numpy()), index=levenshtein_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Alignment_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(alignment_sample_cost_df.loc[samples].to_numpy()), index=alignment_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Genome_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(genome_sample_cost_df.loc[samples].to_numpy()), index=genome_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "OT_Genome_Normalized_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(genome_normalized_sample_cost_df.loc[samples].to_numpy()), index=genome_normalized_sample_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "Weighted_UniFrac_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(weighted_unifrac_cost_df.loc[samples].to_numpy()), index=weighted_unifrac_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
                   "Unweighted_UniFrac_PCA_6": pd.DataFrame(PCA(n_components=6).fit_transform(unweighted_unifrac_cost_df.loc[samples].to_numpy()), index=unweighted_unifrac_cost_df.loc[samples].index, columns=[f"component_{i+1}" for i in range(6)]),
}

# Result tables: rows are (score type, classifier), columns are input matrices.
METHODS_SHAPE = len(sample_cost_dfs.keys())
all_classifiers = ["random_forest", "decision_tree", "knn", "svc"]
score_types = ["f1", "accuracy"]
CLASSIFIERS_SHAPE = len(all_classifiers) * len(score_types)

multiindex_tuples = [
    (_type, _classifier)
    for _type in score_types
    for _classifier in all_classifiers
]
multiindex = pd.MultiIndex.from_tuples(multiindex_tuples, names=["score_type", "classifier"])

# Both tables start from the same zero array; copy=True gives each its own storage.
empty_arr = np.zeros((CLASSIFIERS_SHAPE, METHODS_SHAPE), dtype=float)
score_df = pd.DataFrame(data=empty_arr,
                        columns=list(sample_cost_dfs.keys()),
                        index=multiindex,
                        copy=True)

std_deviation_df = pd.DataFrame(data=empty_arr,
                                columns=list(sample_cost_dfs.keys()),
                                index=multiindex,
                                copy=True)

# One Series of cross-validated predictions per (input matrix, classifier).
predictions = []

def generate_input_matrix(cost_df, _type, samples):
    """Build the feature matrix handed to the classifiers for one method.

    Parameters
    ----------
    cost_df : pandas.DataFrame
        Matrix for the current method; returned as a NumPy array for every
        method except "Metadata".
    _type : str
        Name of the method. Only the exact value "Metadata" is special-cased.
    samples : list-like
        Row labels used to select rows of the module-level
        `subset_metadata_df`; only read when `_type == "Metadata"`.

    Returns
    -------
    numpy.ndarray
        `cost_df.to_numpy()`, or for "Metadata" the dense one-hot encoding of
        `subset_metadata_df.loc[samples]` (in which case `cost_df` is ignored).
    """
    if _type != "Metadata":
        return cost_df.to_numpy()

    # Metadata is one-hot encoded so the classifiers receive a purely
    # numeric, dense matrix.
    encoder = OneHotEncoder(categories='auto', sparse_output=False)
    return encoder.fit_transform(subset_metadata_df.loc[samples])

# Cross-validate every classifier on every method and record, per
# (score type, classifier, method): the mean and standard deviation of the
# per-fold score, plus the out-of-fold predictions.

# Methods listed here get a linear-kernel SVC; every other method's matrix is
# handed to SVC as a precomputed kernel.
# NOTE(review): a precomputed kernel is normally a similarity (Gram) matrix;
# confirm that passing cost/distance matrices directly is intended.
non_distance_matrices = [
    "Metadata",
    "OT_Levenshtein_plus_Metadata",
    "OT_Alignment_plus_Metadata",
    "OT_Genome_plus_Metadata",
    "OT_Genome_Normalized_plus_Metadata",
    "OT_Levenshtein_PCA_6",
    "OT_Alignment_PCA_6",
    "OT_Genome_PCA_6",
    "OT_Genome_Normalized_PCA_6",
    "Weighted_UniFrac_PCA_6",
    "Unweighted_UniFrac_PCA_6",
]

# The integer random_state makes every call to cv.split() produce the same
# folds, so one splitter can be shared by all methods and classifiers.
cv = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

outer_prog_bar = tqdm(sample_cost_dfs.items(), total=len(sample_cost_dfs))

for name, cost_df in outer_prog_bar:
    outer_prog_bar.set_description(desc=name)

    labels = ibd_metadata_diagnosis.loc[cost_df.index].values.ravel()
    X = generate_input_matrix(cost_df, _type=name, samples=samples)

    # Label encoding and fold membership depend only on the method, not on
    # the classifier, so compute them once per method.
    le = LabelEncoder()
    y = le.fit_transform(labels)
    test_folds = [test_idx for _, test_idx in cv.split(X, y)]

    svc_kernel = 'linear' if name in non_distance_matrices else 'precomputed'
    classifiers = {"random_forest": RandomForestClassifier(random_state=42),
                   "decision_tree": DecisionTreeClassifier(random_state=42),
                   "knn": KNeighborsClassifier(n_neighbors=5),
                   "svc": SVC(kernel=svc_kernel, max_iter=1000000)}

    for cl_name, classifier in classifiers.items():
        # Single cross-validation pass: every sample is predicted by the model
        # fitted on the folds that do not contain it.
        y_pred = cross_val_predict(classifier, X, y, cv=cv)

        # NOTE(review): the Series is indexed by diagnosis label, not sample
        # id; the final concat presumably relies on every method sharing the
        # same row order - confirm.
        prediction_series = pd.Series(y_pred, index=labels, name=f"{name}_{cl_name}")
        predictions.append(prediction_series)

        # Score each test fold separately (not the pooled predictions) so the
        # mean/std match what cross_val_score reports for the same folds,
        # without refitting the classifier two more times.
        acc = np.array([accuracy_score(y[idx], y_pred[idx]) for idx in test_folds])
        f1 = np.array([f1_score(y[idx], y_pred[idx], average='weighted') for idx in test_folds])

        score_df.at[("accuracy", cl_name), name] = acc.mean().item()
        score_df.at[("f1", cl_name), name] = f1.mean().item()

        std_deviation_df.at[("accuracy", cl_name), name] = acc.std().item()
        std_deviation_df.at[("f1", cl_name), name] = f1.std().item()

# Ground-truth column: encoded labels of the last method processed.
predictions.append(pd.Series(y, index=labels, name="Original"))
predictions_df = pd.concat(predictions, axis=1)

In [None]:
# Bar colour for each method (hex values taken from a 20-colour palette so
# that every method gets its own colour).
# TODO: Create better way to do this
colors = {
    "OT_Levenshtein": '#1f77b4',
    "OT_Alignment": '#ff7f0e',
    "OT_Genome": '#2ca02c',
    "OT_Genome_Normalized": '#d62728',
    "OT_Phylogenetic": '#9467bd',
    # Previously shared '#9467bd' with "OT_Phylogenetic"; now distinct.
    "Metadata": '#ffbb78',
    "Unweighted UniFrac": '#8c564b',
    "Weighted UniFrac": '#e377c2',
    "OT_Levenshtein_plus_Metadata": '#7f7f7f',
    "OT_Alignment_plus_Metadata": '#bcbd22',
    "OT_Genome_plus_Metadata": '#17becf',
    "OT_Genome_Normalized_plus_Metadata": '#aec7e8',
    "OT_Levenshtein_PCA_6": '#ff9896',
    "OT_Alignment_PCA_6": '#98df8a',
    "OT_Genome_PCA_6": '#c5b0d5',
    "OT_Genome_Normalized_PCA_6": '#c49c94',
    "Weighted_UniFrac_PCA_6": '#f7b6d2',
    "Unweighted_UniFrac_PCA_6": '#dbdb8d'
}

# Used for any method that has no entry in `colors`, so a new or renamed
# method shows up grey instead of aborting the whole figure with a KeyError.
DEFAULT_BAR_COLOR = '#c7c7c7'

# One row of panels per classifier, one column per score type; the grid is
# sized from the same lists that define the rows of `score_df`.
n_rows, n_cols = len(all_classifiers), len(score_types)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False)

for idx, classifier in enumerate(all_classifiers):
    for score_idx, score_type in enumerate(score_types):
        ax = axes[idx, score_idx]

        # Explicit (row key, columns) form: selects one MultiIndex row and
        # cannot be misread as a (row, column) lookup.
        data = score_df.loc[(score_type, classifier), :].sort_values()
        # Error bars are reordered to follow the sorted scores.
        errors = std_deviation_df.loc[(score_type, classifier), data.index]

        ax.barh(data.index,
                data.values,
                color=[colors.get(x, DEFAULT_BAR_COLOR) for x in data.index],
                xerr=errors,
                error_kw={'marker': "x", "capsize": 4, "markeredgecolor": "black", "markerfacecolor": "black", "markersize": 4})
        ax.set_xlim(0, 1)
        ax.set_xlabel("Mean cross-validated score (error bars: +/- 1 SD)")
        ax.set_title(f"{score_type.capitalize()} - {classifier.replace('_', ' ').upper()}")

fig.suptitle("All Scores - Females Only", y=1.01, fontsize=16)
fig.tight_layout()