# Heatmaps
This notebook generates heatmap figures for visually inspecting relative abundances across the pipelines, comparing each pipeline's observed relative abundances against the expected composition.

In [None]:
import sys #noqa
sys.path.append("../../") #noqa

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams.update({'figure.max_open_warning': 0})

from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import LogNorm, Normalize
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from scipy.special import rel_entr
import os
from python_src.compositions import multiplicative_aitchison, uniform_replace_zeroes
from python_src.figures_utils import get_all_expected, generate_experimental_df, get_relabund_files, fully_combined
from skbio.stats.composition import multiplicative_replacement
from scipy.spatial.distance import braycurtis, euclidean
from scipy.stats import pearsonr
from dataclasses import dataclass
from python_src.compositions import clr
from typing import Tuple, List
from utils.data_paths import make_data_list

import tarfile

# Multi-page PDFs that heatmap() appends one page per plotted sample to; both are closed at the end of the notebook.
# Figures for datasets compared sample-by-sample against the expected data with the same SampleID (many_to_one=False).
pdf_output = PdfPages("heatmaps.pdf")
# Figures for datasets where every sample is compared against the single expected composition (many_to_one=True).
pdf_replicate_output = PdfPages("heatmaps_replicates.pdf")

In [None]:
def rreplace(s, old, new, occurrence):
    """
    Replace occurrences of ``old`` in ``s`` with ``new``, counting from the right.

    At most ``occurrence`` matches are replaced, starting with the last one
    (a negative ``occurrence`` replaces all of them, as with ``str.rsplit``).
    """
    return new.join(s.rsplit(old, maxsplit=occurrence))

In [None]:
def add_plot_decorations(ax: plt.Axes, n_rows: int):
    """
    Outline the first column of a heatmap with a red, unfilled border.

    The outline starts at data coordinates (0, 0), is one cell wide and
    ``n_rows`` cells tall, so it frames the left-most column of the heatmap
    drawn on ``ax``.
    """
    outline = plt.Rectangle(
        xy=(0, 0),
        width=1,
        height=n_rows,
        facecolor="None",
        edgecolor="red",
        linewidth=2,
    )
    ax.add_patch(outline)

In [None]:
wanted_sources = ["biobakery3", "biobakery4", "jams", "wgsa2", "woltka"]


def _build_heatmap_frame(smpl_df: pd.DataFrame, exp_smpl_df: pd.DataFrame, rank: str) -> pd.DataFrame:
    """
    Build the taxon-by-source relative-abundance matrix for one sample.

    Rows are indexed by the expected data's ``{rank}`` values. The first column is
    "Expected", followed by one column per source in ``wanted_sources`` present in
    ``smpl_df``. Missing values (e.g. taxa a source did not report) are filled with 0.
    Returns an empty DataFrame if no wanted source is present.
    """
    frames = []
    for src, src_df in smpl_df.groupby("Source"):
        if src not in wanted_sources:
            continue
        # Left join on the index so every expected taxon keeps its row, even when
        # this source did not observe it.
        joined_df = exp_smpl_df.join(src_df, how="left", lsuffix="_exp", rsuffix="_obs")
        # Features on the index, one relative-abundance column per source.
        joined_df = (
            joined_df.set_index(f"{rank}_exp")[["RA_exp", "RA_obs"]]
            .fillna(0)
            .rename(columns={"RA_exp": "Expected", "RA_obs": src})
        )
        frames.append(joined_df)

    if not frames:
        return pd.DataFrame()

    combined = pd.concat(frames, axis=1)
    # Every per-source frame carries its own "Expected" column; keep only the first.
    return combined.loc[:, ~combined.columns.duplicated()]


def heatmap(obs_root: str, exp_root: str, community: str, rank: str = "genus", many_to_one: bool = True, png: bool = False):
    """
    Plots one heatmap per observed sample.

    The far-left column (outlined in red) holds the expected relative abundances and
    each remaining column holds one wanted pipeline's observed relative abundances.
    The colour scale is logarithmic (LogNorm).

    Each figure is appended to ``pdf_replicate_output`` when ``many_to_one`` is True,
    otherwise to ``pdf_output``.

    Parameters
        obs_root: str
            The root directory of the observed data.
        exp_root: str
            The root directory of the expected data.
        community: str
            The community name, used in the plot title.
        rank: str
            The taxonomic rank to plot (e.g. "genus" or "species").
        many_to_one: bool
            If True, every observed sample is compared against all of the expected
            data (a single expected sample). If False, it is compared against the
            expected rows that share its SampleID.
        png: bool
            If True, also save each figure as a PNG in the "images" folder.
    Returns
        None. Samples that yield an empty matrix (no wanted source, or no matching
        expected rows) are skipped.
    """
    full_df = fully_combined(obs_root, exp_root, rank=rank)
    is_expected = full_df["Source"] == "Expected"
    exp_df = full_df[is_expected]
    obs_df = full_df[~is_expected]

    pdf = pdf_replicate_output if many_to_one else pdf_output
    if png:
        # savefig does not create missing directories.
        os.makedirs("images", exist_ok=True)

    for smpl, smpl_df in obs_df.groupby("SampleID"):
        exp_smpl_df = exp_df if many_to_one else exp_df[exp_df["SampleID"] == smpl]

        heatmap_df = _build_heatmap_frame(smpl_df, exp_smpl_df, rank)
        if heatmap_df.empty:
            # Nothing to draw; seaborn cannot plot an empty matrix.
            continue

        fig, ax = plt.subplots()
        try:
            sns.heatmap(heatmap_df, ax=ax, cmap="viridis", cbar_kws={'label': 'Relative Abundance'}, annot=False, linewidths=0.5, norm=LogNorm())
            ax.set(ylabel=rank)
            add_plot_decorations(ax, len(heatmap_df.index))

            title = f"{smpl} for {community.capitalize()} at {rank.capitalize()} Heatmap"
            ax.set_title(title)

            if png:
                # Explicit extension and format: a "." in the SampleID would otherwise be
                # parsed as the file format, and "/" would be treated as a directory.
                file_stem = title.replace(" ", "_").replace("/", "_")
                fig.savefig(os.path.join("images", f"{file_stem}.png"), format="png", bbox_inches='tight', dpi=300)

            pdf.savefig(fig, bbox_inches='tight', dpi=300)
        finally:
            plt.close(fig)

In [None]:
def make_heatmaps(png: bool, *, ranks: Tuple[str, ...] = ("species",)):
    """
    Makes heatmaps for all of the data.

    Takes paths from the make_data_list function from data_paths.py and calls
    heatmap for each dataset (except "gut" and "tongue") at every rank in ``ranks``.
    The expected-data directory is derived by replacing the last "pipeline" in
    the dataset path with "expected_pipeline".

    Parameters:
        png: bool
            Whether or not to save the plots as PNGs.
        ranks: tuple of str
            Taxonomic ranks to plot for each dataset. Defaults to species only;
            pass ("genus", "species") to also plot genus.
    Raises:
        FileNotFoundError: if a dataset's expected-pipeline directory does not exist.
    """
    # Datasets whose samples are all compared against a single expected sample.
    many_to_one_pipelines = {"hilo", "mixed", "tourlousse"}

    for p in make_data_list():
        # normpath drops a trailing separator so basename yields the last directory name.
        pipeline = os.path.basename(os.path.normpath(p.path))
        if pipeline in ("gut", "tongue"):
            continue

        # Replace only the last occurrence of the string "pipeline" with "expected_pipeline".
        exp_root = rreplace(p.path, "pipeline", "expected_pipeline", 1)

        print(pipeline)
        print(exp_root)

        if not os.path.exists(exp_root):
            raise FileNotFoundError(f"Expected pipeline directory {exp_root} does not exist.")

        many_to_one = pipeline in many_to_one_pipelines
        for rank in ranks:
            heatmap(p.path, exp_root, pipeline, rank=rank, many_to_one=many_to_one, png=png)

In [None]:
def tar_images(folder: str = "images", archive: str = "images.tar.gz"):
    """
    Packs ``folder`` into the gzip-compressed tarball ``archive``, then empties ``folder``.

    Does nothing if ``folder`` does not exist or is empty. An existing ``archive`` is
    overwritten. Subdirectories are archived and then removed along with the files.

    Parameters:
        folder: str
            Directory to archive (default "images").
        archive: str
            Path of the .tar.gz file to write (default "images.tar.gz").
    """
    import shutil

    if not os.path.isdir(folder) or not os.listdir(folder):
        return

    with tarfile.open(archive, "w:gz") as tar:
        tar.add(folder, arcname=os.path.basename(os.path.normpath(folder)))

    # Clear the folder. Snapshot the entries first so we don't delete while iterating.
    with os.scandir(folder) as it:
        entries = list(it)
    for entry in entries:
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)
        else:
            os.remove(entry.path)

In [None]:
# Make sure the PNG output folder exists before plotting and archiving.
os.makedirs("images", exist_ok=True)

# Close the PDFs even if plotting fails part-way: PdfPages only finalizes
# the file into a complete PDF when close() is called.
try:
    make_heatmaps(True)
    tar_images()
finally:
    pdf_output.close()
    pdf_replicate_output.close()