# Figure 5 data: mouse association shifts
This notebook generates the XML (GraphML) files for the networks in Figure 5 of the paper, which show changes in the associations of *Akkermansia* and *Lactobacillus* across diets. The XML files can be viewed in Cytoscape and combined in Adobe Illustrator to produce Figure 5.

### Before you start
This notebook assumes the analysis of the mouse dataset has been run and its results are located in `MCSPACE_paper/results/analysis/Mouse`. Refer to the README in `scripts/analysis` for the analysis pipeline and more details.

In [1]:
import numpy as np
import torch
from mcspace.model import MCSPACE
from mcspace.trainer import train_model
from mcspace.data_utils import get_data, get_mouse_diet_perturbations_dataset

from mcspace import utils as ut
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from mcspace.dataset import DataSet
import pandas as pd

import ete3
from Bio import SeqIO, Phylo

import matplotlib.colors as mcolors
import networkx as nx
import scipy.cluster.hierarchy as sch

# Paths

Note: Paths are relative to this notebook, which is assumed to be located in `MCSPACE_paper/figures`

In [2]:
# Repository root (two levels up from this notebook) and the notebook's own directory.
rootpath, basepath = Path("../../"), Path("./")

In [3]:
# Location of the mouse analysis outputs (assemblages, proportions, Bayes factors, abundances).
respath = rootpath.joinpath("MCSPACE_paper", "results", "analysis", "Mouse")
# Phylogenetic tree inputs. NOTE(review): not used by any cell visible here — confirm before removing.
treepath = rootpath.joinpath("MCSPACE_paper", "datasets", "mouse_experiments", "output")
treefile = "newick_tree_query_reads.nhx"

In [4]:
# Destination for the per-taxon CSV and GraphML files; created (with parents) if missing.
outpath = basepath.joinpath("paper_figures", "mouse_association_networks")
outpath.mkdir(parents=True, exist_ok=True)

# Load analysis results

In [5]:
# Assemblage composition table; taxonomy columns are kept as ordinary columns here.
thetadf = pd.read_csv(respath.joinpath("assemblages.csv"))
# The remaining tables use their first CSV column as the row index.
betadf = pd.read_csv(respath.joinpath("assemblage_proportions.csv"), index_col=0)
pertsdf = pd.read_csv(respath.joinpath("perturbation_bayes_factors.csv"), index_col=0)
radf = pd.read_csv(respath.joinpath("relative_abundances.csv"), index_col=0)

In [6]:
taxlevels = ['Otu', 'Domain', 'Phylum', 'Class', 'Order', 'Family', 'Genus', 'Species']

# OTU -> lineage lookup, extracted before `thetadf` is re-indexed below
# (set_index returns a new frame, so `taxonomy` is independent of `thetadf`).
taxonomy = thetadf.loc[:, taxlevels].set_index("Otu")

# Move all taxonomy columns into a MultiIndex; the remaining columns
# (presumably one per assemblage — confirm against assemblages.csv) stay as data.
thetadf = thetadf.set_index(taxlevels)

# CSV headers are read as strings; cast to int so timepoints can be selected by number.
radf.columns = radf.columns.astype(int)

avebeta = ut.get_subj_averaged_assemblage_proportions(betadf)

# Export networks to GraphML files for Cytoscape

In [7]:
def average_diet_times(df, diet_times):
    """Average the columns of ``df`` within each diet's set of timepoints.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose columns are timepoint labels; rows are kept unchanged.
    diet_times : dict
        Maps each output column name (a diet label) to a list of ``df``
        column labels whose values are averaged row-wise.

    Returns
    -------
    pandas.DataFrame
        Same index as ``df``, with one column per key of ``diet_times`` in the
        dict's insertion order. Each column holds the row-wise mean over that
        key's timepoint columns. The column-axis name of ``df`` is preserved.

    Raises
    ------
    KeyError
        If a label listed in ``diet_times`` is not a column of ``df``.
    """
    # Build only the averaged columns rather than copying the entire input
    # frame and then discarding all of its original columns.
    out = pd.DataFrame(index=df.index)
    for diet, times in diet_times.items():
        # Row-wise mean over this diet's timepoints; the Series shares df's
        # index, so assignment does not re-align (safe with duplicate labels).
        out[diet] = df.loc[:, times].mean(axis=1)
    # Preserve the input's column-axis name, as selecting columns of a copy would.
    out.columns.name = df.columns.name
    return out

In [8]:
# Diet label -> list of timepoint column labels to average for that diet.
# Insertion order determines the column order of the diet-averaged tables.
diet_times = {
    'S1': [10],
    'HF': [18],
    'HFHF': [43],
    'LP': [65],
}

## Get main taxa

Select taxa of interest using the criterion that their diet-averaged relative abundance exceeds 5% in at least 3 diets.

In [9]:
# Relative abundances averaged within each diet's timepoints (one column per diet).
radiets = average_diet_times(df=radf, diet_times=diet_times)

In [10]:
otu_threshold = 0.05  # diet-averaged relative abundance must be strictly above this
n_diets = 3  # ...in at least this many diets

# Per OTU: how many diets clear the abundance threshold.
n_diets_above = (radiets > otu_threshold).sum(axis=1)
otu_plot = radiets.index[n_diets_above >= n_diets]

In [11]:
# Show which OTUs passed the abundance filter and will get a network.
print(otu_plot)

Index(['Otu2', 'Otu1'], dtype='object', name='Otu')


# Output xml files for each taxon

In [12]:
for oidx in otu_plot:
    # Human-readable taxon label; used as the XML file name below.
    otu_name = ut.get_lowest_level_name(oidx, taxonomy)

    # Association scores between `oidx` and other taxa, averaged per diet.
    alpha = ut.get_assoc_scores(thetadf, avebeta, oidx)
    alphasub = average_diet_times(df=alpha, diet_times=diet_times)

    # Edges: association scores kept by the abundance/score filters.
    # Nodes: diet-averaged relative abundances of the taxa on those edges.
    ew = ut.filter_assoc_scores(
        alphasub,
        radiets,
        oidx,
        ra_threshold=0.01,
        edge_threshold=0.01,
    )
    nw = radiets.loc[ew.index, :]

    # Relabel OTU identifiers with taxonomy names (nodes first, then edges).
    nw3, ew3 = ut.update_names(nw, taxonomy), ut.update_names(ew, taxonomy)

    nw3.to_csv(outpath.joinpath(f"node_data_{oidx}.csv"))
    ew3.to_csv(outpath.joinpath(f"edge_data_{oidx}.csv"))

    # NOTE(review): CSVs are keyed by OTU id but the XML by taxon label; two OTUs
    # sharing a lowest-level name would overwrite each other's XML — confirm unique.
    xml_file = outpath.joinpath(f'{otu_name}.xml')
    ut.output_association_network_to_graphML(oidx, nw3, ew3, taxonomy, xml_file)