# Get estcounts for each taxon at each sample



## Setup

In [None]:
import os 
import gc
import re
import csv
import glob
import math
import umap
import json
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
from time import time
from tqdm import tqdm
from scipy import stats
from collections import * 
from sklearn import cluster
from sklearn import decomposition
from ete4 import NCBITaxa, Tree
import matplotlib.pyplot as plt
import matplotlib.colors as pltc
from scipy.spatial import distance
from scipy.cluster import hierarchy
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches


In [None]:
import sys
sys.path.append('../repo-armbrust-metat-search')

In [None]:
import functions.fn_metat_files as fnf

In [None]:
# Handle to the NCBI taxonomy (ete4); later cells use it for topologies,
# lineages and taxid/name translation.
ncbi = NCBITaxa()

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are presumably picked up without restarting the kernel
# (NOTE(review): confirm the mode semantics in the IPython docs).
%load_ext autoreload
%autoreload 2

In [None]:
# Show the working directory before it is changed.
os.getcwd()

In [None]:
# Run the rest of the notebook from the project scratch directory; relative
# paths used later (e.g. ../iron_ko_contigs/...) resolve against it.
workdir = '/scratch/bgrodner/p_calceolata_enterobactin'
os.chdir(workdir)


In [None]:
# Confirm that the working directory is now workdir.
os.getcwd()

In [None]:
# List the contents of the working directory.
os.listdir()

Plotting

In [None]:
def general_plot(
    xlabel="", ylabel="", ft=12, dims=(5, 3), col="k", lw=1, pad=0, tr_spines=True
):
    """Create a one-axes figure with uniform spine, tick and label styling.

    Args:
        xlabel: x-axis label text.
        ylabel: y-axis label text.
        ft: font size applied to the tick labels and both axis labels.
        dims: (width, height) passed to matplotlib as the figure size.
        col: colour used for spines, ticks, tick labels and axis labels.
        lw: line width applied to all four spines.
        pad: padding handed to the tight-layout engine.
        tr_spines: when true the top and right spines are kept and coloured;
            when false they are hidden.

    Returns:
        The ``(fig, ax)`` pair.
    """
    width, height = dims[0], dims[1]
    fig, ax = plt.subplots(figsize=(width, height), tight_layout={"pad": pad})

    for spine in ax.spines.values():
        spine.set_linewidth(lw)

    # Top/right spines are either styled like the others or removed entirely.
    for side in ("top", "right"):
        if tr_spines:
            ax.spines[side].set_color(col)
        else:
            ax.spines[side].set_visible(False)
    for side in ("bottom", "left"):
        ax.spines[side].set_color(col)

    ax.tick_params(direction="in", labelsize=ft, color=col, labelcolor=col)
    ax.set_xlabel(xlabel, fontsize=ft, color=col)
    ax.set_ylabel(ylabel, fontsize=ft, color=col)
    # Transparent axes background.
    ax.patch.set_alpha(0)
    return (fig, ax)

def plot_umap(
    embedding,
    figsize=(10, 10),
    markersize=10,
    alpha=0.5,
    colors="k",
    xticks=(),
    yticks=(),
    markerstyle='o',
    cmap_name='tab20',
    cl_lab=False
):
    """Scatter-plot the first two columns of an embedding on equal-aspect axes.

    Args:
        embedding: array indexable as ``embedding[:, 0]`` / ``embedding[:, 1]``.
        figsize: figure size forwarded to ``general_plot`` as ``dims``.
        markersize: marker area passed to ``scatter`` as ``s``.
        alpha: marker opacity.
        colors: either one colour specification for all points or a sequence
            with one entry per point.
        xticks: explicit x tick positions; left to matplotlib when empty.
        yticks: explicit y tick positions; left to matplotlib when empty.
        markerstyle: a single marker string for all points, or a sequence with
            one marker per point (each point is then drawn individually).
        cmap_name: colormap name, used only in the single-marker branch.
        cl_lab: accepted for backward compatibility; not used by this function.

    Returns:
        The ``(fig, ax)`` pair.
    """
    fig, ax = general_plot(dims=figsize)
    if isinstance(markerstyle, str):
        ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            s=markersize,
            alpha=alpha,
            c=colors,
            edgecolors="none",
            marker=markerstyle,
            cmap=cmap_name
        )
    else:
        # A lone colour string must be repeated per point: zipping the string
        # itself would iterate its characters and truncate the plot.
        if isinstance(colors, str):
            colors = [colors] * len(embedding)
        for e0, e1, c, m in zip(
            embedding[:, 0],
            embedding[:, 1],
            colors,
            markerstyle
        ):
            ax.scatter(
                e0,
                e1,
                s=markersize,
                alpha=alpha,
                c=c,
                edgecolors="none",
                marker=m
            )
    ax.set_aspect("equal")
    if len(xticks) > 0:
        ax.set_xticks(xticks)
    if len(yticks) > 0:
        ax.set_yticks(yticks)
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    return fig, ax


#### Get KO dict

Get dataframe

In [None]:
# Flatten the nested KO hierarchy JSON into one row per fourth-level entry,
# carrying the names of its three ancestor levels.
ko_fn = "../iron_ko_contigs/ko00001.json"
database = []
for _, v in pd.read_json(ko_fn).iterrows():
    top = v["children"]
    cat_1 = top["name"]
    for child_1 in top["children"]:
        cat_2 = child_1["name"]
        for child_2 in child_1["children"]:
            cat_3 = child_2["name"]
            # Third-level nodes without children contribute no rows.
            database.extend(
                [cat_1, cat_2, cat_3, child_3["name"]]
                for child_3 in child_2.get("children", [])
            )
df_kegg = pd.DataFrame(database, columns=["Level_A", "Level_B", "Level_C", "Level_D"])
df_kegg.shape


In [None]:
# Pull the Level_D entries out as an array and preview the first five.
ld = df_kegg['Level_D'].values
ld[:5]

In [None]:
# Map each KO identifier to its full Level_D entry text. The identifier is
# taken as the leading run of word characters of the entry.
dict_ko_name = {}
for name in ld:
    # re.search returns None for an entry that does not start with a word
    # character, in which case the [0] lookup raises TypeError.
    ko = re.search(r"^\w+",name)[0]
    dict_ko_name[ko] = name

Get metadata table

In [None]:
# Load the batch metadata table and preview the rows at positions 1, 100 and 200.
metadata_path = "/scratch/bgrodner/iron_ko_contigs/metat_search_results/dicts_iron_KO_contig/tidy_tables/merge_all/iron_KOs.txt-metadata.csv"  # input('Enter the filepath of your batch metadata file:')
metadata = pd.read_csv(metadata_path)
metadata.iloc[[1,100,200],:]

Correct G1 S11C1 latitudes

In [None]:
# Replace the latitude of sample S11C1 in G1 assemblies with 36.569deg; every
# other row keeps the latitude it already has.
lats_new = [
    '36.569deg' if (sample == 'S11C1') & ('G1' in assembly) else latitude
    for sample, assembly, latitude in zip(
        metadata['sample'], metadata['assembly'], metadata['latitude']
    )
]
metadata['latitude'] = lats_new
metadata.iloc[[1,100,200],:]

## Example file

### Load contig taxon map

Filename

In [None]:
fn_contig_tax = '/mnt/nfs/projects/armbrust-metat/gradients2/g2_station_ns_metat/assemblies/MarMicro_MarFerr_Diamond_2024_04_14/G2NS_ALL_MarFer_MMDB.tab'
# Preview up to the first five lines. islice stops quietly at end-of-file, so
# a file with fewer than five lines does not raise StopIteration.
with open(fn_contig_tax, 'r') as f:
    for line in itertools.islice(f, 5):
        print(line)

Load dict

In [None]:
# Build the contig -> taxon lookup from the three-column, tab-separated map.
# The second column is stored as the value; the third is ignored.
dict_contig_tax = {}
with open(fn_contig_tax, 'r') as f:
    for line in f:
        contig, taxon, _ = line.split('\t')
        dict_contig_tax[contig] = taxon

print(f'{len(dict_contig_tax):,} lines read')
fnf.getmem()

### Add counts to taxon for each sample

Filenames

In [None]:
dir_kallisto = '/mnt/nfs/projects/armbrust-metat/gradients2/g2_station_ns_metat/assemblies/ReadCounts/'
# glob returns matches in arbitrary order; sorting makes the list, and the
# "first file" example taken from it later, reproducible between runs.
fns_kallisto = sorted(glob.glob(f'{dir_kallisto}/*/*.tsv'))
fns_kallisto

In [None]:
# Number of count files found.
len(fns_kallisto)

Iterate over samples

In [None]:
# Sum estimated counts per taxon for each sample: sample -> taxon -> total.
dict_sam_tax_estcounts = defaultdict(lambda: defaultdict(float))
# Only the first file is processed here, as a worked example.
for fn in fns_kallisto[:1]:
    # Sample name = file name without directory and extension.
    sam = os.path.split(fn)[1]
    sam = os.path.splitext(sam)[0]
    print(sam)
    with open(fn, 'r') as f:
        _ = next(f)  # skip header
        for i, row in enumerate(f):
            # Five tab-separated columns: the first is the contig and the
            # fourth is read as the count (presumably kallisto est_counts —
            # TODO confirm the column layout).
            contig, _, _, ec, _ = row.split('\t')
            ec = float(ec)
            # Zero-count contigs are skipped, so they never need a taxon lookup.
            if ec:
                # Raises KeyError if the contig is missing from the taxon map.
                tax = dict_contig_tax[contig]
                dict_sam_tax_estcounts[sam][tax] += ec
            # Progress indicator, overwritten in place every million lines.
            if i%1e6 == 0:
                print(f'{i:,} lines read', end='\r')


### Is it faster to load the contig-taxon dict once per sample in parallel, or to load it a single time and then process each sample in series?

Series

In [None]:
# Serial estimate: one fixed cost of 24.8 plus 60 units at 26.4 each
# (presumably seconds to load the contig-taxon dict once, then seconds per
# sample — TODO confirm).
# NOTE(review): this uses 60 units while the parallel estimate in the next
# cell uses 500 — confirm whether the two are meant to differ.
tl = 24.8 + 60*26.4
tl

Compare to parallel as cores increase

In [None]:
# Parallel estimate for 1..99 workers: ceil(500 / j) rounds, each costing the
# full 24.8 + 26.4 (presumably because every job reloads the dict — TODO confirm).
j = np.arange(1,100)
tp = np.ceil(500/j) * (24.8 + 26.4)
# This bare expression is not the last statement of the cell, so it displays nothing.
tp
plt.scatter(j, tp)
# Horizontal reference line at the serial estimate tl.
plt.plot([0,100],[tl]*2, 'k')

## Collapse counts to the trimmed tree

Get tree trim file

In [None]:
fn_tree_trim = '/scratch/bgrodner/iron_ko_contigs/metat_search_results/dicts_iron_KO_contig/tree_trim/merge_all/iron_KOs.txt-barnacle_tensor_tidy-tree_trim_thresh_60_minsamples_20_minbatches_4.csv'
# Preview up to the first five lines. islice stops quietly at end-of-file, so
# a file with fewer than five lines does not raise StopIteration.
with open(fn_tree_trim, 'r') as f:
    for line in itertools.islice(f, 5):
        print(line)

Get set of taxa

In [None]:
# Collect the taxon_trim column of the tree-trim table and reduce it to the
# set of distinct taxa.
with open(fn_tree_trim, 'r') as f:
    list_taxa = [row['taxon_trim'] for row in csv.DictReader(f)]
set_taxa = set(list_taxa)
set_taxa

Get Tree

In [None]:
# Build the NCBI topology spanning the trimmed taxa and print it with
# scientific names and node names.
tree = ncbi.get_topology(set_taxa)

print(tree.to_str(props=['sci_name', 'name'], compact=True))

All taxa counts filename

In [None]:
fn_taxon_estcounts = '/scratch/bgrodner/relative_abundance/metat_search_results/sample_taxa_estcounts/G2NS/G2NS-G2NS.S18C1.15m.3um.C.tsv-count_sum.txt'
# Preview up to the first five lines. islice stops quietly at end-of-file, so
# a file with fewer than five lines does not raise StopIteration.
with open(fn_taxon_estcounts, 'r') as f:
    for line in itertools.islice(f, 5):
        print(line)


Get dict

In [None]:
# Read the two-column CSV (taxon, summed count) into taxon -> float count.
with open(fn_taxon_estcounts, 'r') as f:
    dict_tax_estcounts = {tax: float(ec) for tax, ec in csv.reader(f)}

Total counts

In [None]:
# Total count over every taxon in the sample, including the '0' bucket.
total_estcounts = sum(dict_tax_estcounts.values())

Get Tree

In [None]:
# Topology over every taxon seen in this sample. The '0' key is left out
# before the topology is built; filtering (rather than list.remove) also
# works when the sample has no '0' entry.
tax_untrim = [tax for tax in dict_tax_estcounts if tax != '0']
tree_untrim = ncbi.get_topology(tax_untrim)
print(tree_untrim.to_str(props=['sci_name','name'], compact=True))

Get taxtrim dict

In [None]:
# Collapse per-taxon counts onto the trimmed tree. A taxon that is already a
# tree node (or is '0') keeps its own counts; any other taxon gives its counts
# to its closest ancestor that is a tree node. Counts of taxa with no ancestor
# in the tree are dropped, which is why the following cells compare the
# collapsed total with total_estcounts.
dict_taxtrim_estcounts = defaultdict(float)
# A real set, so each membership test below is a hash lookup instead of a scan.
set_tax_tree = {n.name for n in tree.traverse()}
for tax, ec in dict_tax_estcounts.items():
    if (tax in set_tax_tree) or (int(tax) == 0):
        dict_taxtrim_estcounts[tax] += ec
        continue
    lin = ncbi.get_lineage(tax)
    # Walk the lineage in reverse so the first hit is the nearest ancestor
    # (assumes get_lineage lists the root first — TODO confirm).
    for t in reversed(lin):
        t = str(t)
        if t in set_tax_tree:
            dict_taxtrim_estcounts[t] += ec
            break

In [None]:
# Fraction of the sample's total counts that survived the collapse.
ec_tot = sum(dict_taxtrim_estcounts.values())

ec_tot / total_estcounts

Add counts to tree

In [None]:
# Attach the rounded collapsed count to every tree node that has one, and
# track how much of the total ends up on the tree.
ec_tot = 0
for n in tree.traverse():
    ec = dict_taxtrim_estcounts.get(n.name)
    if not ec:
        # Node has no counts (missing or zero): leave it unannotated.
        continue
    ec_tot += ec
    n.add_props(estcounts=round(ec))

print(ec_tot / total_estcounts)
print(tree.to_str(props=['sci_name','estcounts'], compact=True))

Add relative abundance to tree

In [None]:
# Attach each node's share of the sample total as a percentage string, and
# accumulate both the percentage sum and the raw count sum as a cross-check.
pct_total = 0
ec_total = 0
for n in tree.traverse():
    ec = dict_taxtrim_estcounts.get(n.name)
    if not ec:
        # Node has no counts (missing or zero): leave it unannotated.
        continue
    ec_total += ec
    pct = ec / total_estcounts * 100
    pct_total += pct
    n.add_props(pct_ec=f'{round(pct, 4)}%')
print(pct_total, ec_total / total_estcounts)
print(tree.to_str(props=['sci_name','pct_ec'], compact=True))

Fraction of counts at root or unannotated

In [None]:
# Share of the sample's counts stored under taxon key '0', as a percentage.
pct = dict_tax_estcounts['0'] / total_estcounts * 100
print(f'{round(pct , 5)}% unannotated')


Sort by which taxa have the most reads

In [None]:
# Rank the untrimmed taxa by their share of the sample total and print the
# top ten (fewer if the sample has fewer taxa).
ranked = sorted(
    ((ec / total_estcounts * 100, tax) for tax, ec in dict_tax_estcounts.items()),
    reverse=True,
)
# Keep the parallel lists available under their original names.
pcts = [pct for pct, _ in ranked]
taxs = [tax for _, tax in ranked]
for pct, tax in ranked[:10]:
    print(round(pct,4), '%','\t',tax, ncbi.get_taxid_translator([tax]) )

In [None]:
# Rank the tree-collapsed taxa by their share of the sample total, print the
# summed percentage, then list every taxon in descending order.
ranked = sorted(
    ((ec / total_estcounts * 100, tax) for tax, ec in dict_taxtrim_estcounts.items()),
    reverse=True,
)
pcts = [pct for pct, _ in ranked]
taxs = [tax for _, tax in ranked]
print(sum(pcts))
for pct, tax in ranked:
    print(round(pct,4), '%','\t',tax, ncbi.get_taxid_translator([tax]) )

## Get cruise-sample profiles after running snakemake

Running `Snakefile_all_taxa_estcounts` produces a tidy table with all of the relative abundance information.

In [None]:
# Merged tidy table of per-sample, per-taxid counts; read row by row below.
fn_relabund = '/scratch/bgrodner/relative_abundance/metat_search_results/sample_taxa_estcounts/merge-sample_taxid_estcounts-tidy.txt'

Map fn_sample_counts to metadata

In [None]:
# Index the metadata rows by their fn_sample_counts value, so each
# relative-abundance row can look up its sample's metadata directly.
dict_meta = {
    row['fn_sample_counts']: row.to_dict()
    for _, row in metadata.iterrows()
}

Add metadata to rows

In [None]:
# Build a column-oriented dict (column -> list of values) in which every
# relative-abundance row is merged with its sample's metadata and tagged with
# a selection type.
dict_relabund = defaultdict(list)
with open(fn_relabund, 'r') as f:
    for row in csv.DictReader(f):
        # Metadata for this row's sample (KeyError if the sample is unknown).
        meta_row = dict_meta[row['fn_sample_counts']]
        # Batches whose name contains 'NS' or 'G5' are labelled 'NS', all others 'PA'.
        is_ns = any(s in row['batch'] for s in ['NS','G5'])
        dict_relabund['selection'].append('NS' if is_ns else 'PA')
        # Metadata values win over row values when both define the same key.
        # NOTE(review): a 'selection' key in either source would be appended
        # twice per row — confirm neither table has one.
        for k, v in (row | meta_row).items():
            dict_relabund[k].append(v)
        

Build dataframe

In [None]:
# Assemble the merged rows into a dataframe; missing values become empty strings.
df_relabund = pd.DataFrame(dict_relabund).fillna('')

Separate out selection and size fraction

In [None]:
# Distinct size, selection and batch values present in the table.
df_relabund['size'].unique(), df_relabund['selection'].unique(), df_relabund['batch'].unique()

In [None]:
# Split the table into one sub-frame per (selection, size) pair. An empty
# size value is stored under the key 'none'.
dict_sel_size_df = defaultdict(dict)
for sel in df_relabund['selection'].unique():
    bool_sel = df_relabund['selection'] == sel
    for size in df_relabund['size'].unique():
        bool_size = df_relabund['size'] == size
        sz = size if size else 'none'
        dict_sel_size_df[sel][sz] = df_relabund[bool_sel & bool_size]

Convert assm_sample to cruise-location

In [None]:
# Cruise label = first two characters of the assembly name.
metadata['cruise'] = [assm[:2] for assm in metadata.assembly]
metadata.cruise.unique(), metadata.assembly.unique()

In [None]:
# Map each assm_sample to a cruise-location label. Samples outside cruises
# D1/G5 and outside the G3PA.diel / G3PA.PM assemblies are labelled
# "<cruise>-<latitude>"; all others keep their assm_sample as the label.
dict_sam_csam = {}
for _, row in metadata.iterrows():
    by_latitude = (row.cruise not in ['D1', 'G5']) and (row.assembly not in ['G3PA.diel', 'G3PA.PM'])
    dict_sam_csam[row.assm_sample] = (
        f"{row.cruise}-{row.latitude}" if by_latitude else row.assm_sample
    )
dict_sam_csam

Cruise-location profiles for each taxon

In [None]:
# For every non-empty (selection, size) table, compute the mean and standard
# deviation of frac_total_estcounts per (taxid, assm_sample), pivot so taxids
# are columns, and relabel the rows with their cruise-location labels.
dict_sel_size_meanstd = defaultdict(dict)
for sel, dict_size_df in dict_sel_size_df.items():
    for size, df in dict_size_df.items():
        df = df.copy()
        print(sel, size, df.shape)
        if df.shape[0] == 0:
            continue
        df['frac_total_estcounts'] = df['frac_total_estcounts'].astype(float)
        grouped = df.groupby(['taxid','assm_sample'])['frac_total_estcounts']
        # Missing taxid/sample combinations (and undefined std values) become 0.
        mean_assm_sample = grouped.mean().unstack(level=0).fillna(0)
        std_assm_sample = grouped.std().unstack(level=0).fillna(0)
        mean_assm_sample.index = [dict_sam_csam[s] for s in mean_assm_sample.index]
        std_assm_sample.index = [dict_sam_csam[s] for s in std_assm_sample.index]
        dict_sel_size_meanstd[sel][size] = [mean_assm_sample, std_assm_sample]


## Correlate taxon location profiles

Plot linkages

In [None]:
# Correlation plots
# Number of flat clusters requested from fcluster (criterion 'maxclust') for
# each selection / size-fraction table.
dict_sel_size_t = {
    'PA': {
        'none': 20,
        '3.0um': 25,
        '0.2um': 28,
    },
    'NS': {
        'none': 20,
        '3.0um': 18,
        '0.2um': 35,
    }
}
crit = 'maxclust'
max_sam = 10

filt_frac = 0.01 # taxa must have more than this fraction when summed across samples

corr_method = 'pearson'

# A cluster needs more than this many taxa to get a row in the spatial plots.
min_taxa_per_cluster = 3

# spatial plots
ft0 = 8
ft1 = 12
plt.rcParams['font.size'] = ft1
dims_sub_lat = (22,12)
dims_exp = (16,12)
# Column index of each cruise key within the figure it is drawn on.
dict_cruise_j = {
    'G1': 0,
    'G2': 1,
    'G3': 2,
    'D1PA': 0,
    'G3PA.diel': 1,
    'G3PA.PM': 2,
    'G5': 3,
}
# Cruise keys drawn on the second figure by label instead of against latitude.
cruises_exp = ['D1PA', 'G5','G3PA.diel','G3PA.PM']

for sel, dict_size_dfs in dict_sel_size_meanstd.items():
    for size, dfs in dict_size_dfs.items():
        mean_df = dfs[0]
        ms = mean_df.sum(axis=0)
        mean_df_filt = mean_df.loc[:,ms > filt_frac]
        print(mean_df.shape, mean_df_filt.shape)

        # Correlation between taxa (columns); undefined correlations become 0.
        corr_df = mean_df_filt.corr(method=corr_method)
        corr_df = corr_df.replace({np.nan: 0})
        # Precalculate the linkage so the same tree drives both the heatmap
        # and the flat clusters.
        link = hierarchy.linkage(distance.pdist(np.asarray(corr_df)))

        t = dict_sel_size_t[sel][size]

        clust = hierarchy.fcluster(link, t=t, criterion=crit)
        nclust = np.unique(clust).shape[0]

        # Repeat the 20-colour palette as often as needed to cover all clusters.
        cmap = list(plt.get_cmap('tab20').colors)
        cmap *= math.ceil(nclust / len(cmap))
        # Clusters in the order they first appear along the dendrogram leaves.
        clorder = []
        for idx in hierarchy.leaves_list(link):
            cl = clust[idx]
            if cl not in clorder:
                clorder.append(cl)
        lut = dict(zip(clorder, cmap))
        row_colors = [lut[cl] for cl in clust]

        # Clustered heatmap using the precalculated linkage.
        g = sns.clustermap(
            corr_df.fillna(0),
            row_linkage=link, col_linkage=link,
            row_colors=row_colors,
            col_colors=row_colors,
            mask=corr_df.isna(),
            cmap='PuOr_r', vmin=-1, vmax=1,
            cbar_kws={'shrink':0.5, 'label':f'{corr_method}\nCorrelation'},
        )
        g.fig.suptitle(f"Similarity of relative abundance between Taxa in {sel} selection and {size} filter", y=1.02)
        # pyplot.show() takes no figure argument; it shows all open figures.
        plt.show()
        plt.close()

        # Plot clusters over space: one row per sufficiently large cluster.
        clorder_trim = [cl for cl in clorder if sum(clust == cl) > min_taxa_per_cluster]
        nrows = len(clorder_trim)
        if nrows == 0:
            # plt.subplots cannot build a zero-row grid.
            print(f'{sel} {size}: no cluster has more than {min_taxa_per_cluster} taxa')
            continue
        # squeeze=False keeps the axes arrays 2-D even when nrows == 1, so the
        # [i, j] indexing below always works.
        fig_lat, axes_lat = plt.subplots(
            nrows=nrows,
            ncols=3,
            sharex=True,
            figsize=(20,nrows*1.5),
            squeeze=False
        )
        fig_exp, axes_exp = plt.subplots(
            nrows=nrows,
            ncols=4,
            figsize=(20,nrows),
            squeeze=False
        )
        # Scale each taxon to its own maximum, then put taxa on the rows.
        profile_df = mean_df_filt.copy()
        profile_df = profile_df / profile_df.max(axis=0)
        profile_df = profile_df.T
        i = 0
        for ic, cl in enumerate(clorder):
            bool_cl = clust == cl
            if sum(bool_cl) <= min_taxa_per_cluster:
                continue
            dict_cruise_loc_weigts = defaultdict(dict)

            profile_cl = profile_df[bool_cl]
            for csam in profile_cl.columns:
                cruise, info = csam.split('-',1)
                if cruise not in cruises_exp:
                    # Latitude labels look like '36.569deg'.
                    info = float(info.strip('deg'))
                vals = profile_cl[csam].values
                # Several samples can share one cruise-location label, in which
                # case the selection is 2-D; average those columns per taxon.
                if len(vals.shape) > 1:
                    vals = np.mean(vals, axis=1)
                dict_cruise_loc_weigts[cruise][info] = vals

            # Plot each cruise separately
            for cruise, dict_loc_weights in dict_cruise_loc_weigts.items():
                dfc = pd.DataFrame(dict_loc_weights)
                j = dict_cruise_j[cruise]
                if cruise not in cruises_exp:
                    ax = axes_lat[i,j]
                    # One box per latitude, positioned at that latitude.
                    bp = ax.boxplot(dfc.values, positions=dfc.columns, patch_artist=True)
                    if i == nrows - 1:
                        xticks = np.arange(20,45,5).astype(int)
                        ax.set_xticks(xticks)
                        ax.set_xticklabels(xticks)
                        ax.tick_params(axis='both', labelsize=ft1)
                else:
                    ax = axes_exp[i,j]
                    dfc = dfc.sort_index()
                    bp = ax.boxplot(dfc.values, patch_artist=True)
                    # Only the bottom row carries the per-sample tick labels.
                    if i == nrows - 1:
                        xticks = np.arange(len(dfc.columns)) + 1
                        ax.set_xticks(xticks)
                        ax.set_xticklabels(dfc.columns, rotation=90)
                        ax.tick_params(axis='both', labelsize=ft1)
                    else:
                        ax.set_xticks([])

                # Colour the boxes like the cluster's heatmap colour.
                color = cmap[ic]
                for item in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
                    plt.setp(bp[item], color=color)
                plt.setp(bp["fliers"], markeredgecolor=color)

                # Label the row with its cluster id on the panel in column 2.
                if j == 2:
                    ax.set_ylabel(f'Cluster {cl}', rotation=0, fontsize=ft0, ha='left')
                    ax.yaxis.set_label_position("right")
                ax.grid(False)
            i += 1







### Correlate Taxon-sel-size across cruise-sample

Since G5, G3PA.PM, D1, and G3PA.diel don't have samples for each sel-size combo, I'm going to remove them here. Then each sel-size combo should have the same, or nearly the same, set of cruise-locations.

Subset the df

In [None]:
# Distinct batch values, for choosing which batches to drop below.
df_relabund['batch'].unique()

In [None]:
# Subset the dataframe to only those batches we want
excluded_batches = ['D1', 'G5.RR', 'G5.mix','G3PA.diel','G3PA.PM']
batchs = df_relabund['batch'].values
# Boolean keep-mask: a row survives only if its batch matches none of the
# excluded names.
bools = np.ones(df_relabund.shape[0], dtype=bool)
for b in excluded_batches:
    bools &= (batchs != b)
# Filter first, then copy, so the column added below is written to an
# independent frame rather than to a slice of another one.
df_relabund_transct = df_relabund[bools].copy()
# Give a new taxon-selection-size name
df_relabund_transct['taxid_sel_size'] = (
    df_relabund_transct['selection'].astype(str)
    + '-' + df_relabund_transct['size'].astype(str)
    + '-' + df_relabund_transct['taxid'].astype(str)
)

df_relabund.shape, df_relabund_transct.shape, df_relabund_transct['taxid_sel_size'][:3]

Separate out selection and size fraction

In [None]:
# Distinct size, selection and batch values left after the subset.
df_relabund_transct['size'].unique(), df_relabund_transct['selection'].unique(), df_relabund_transct['batch'].unique()

In [None]:
# Split the subset table into one sub-frame per (selection, size) pair. An
# empty size value is stored under the key 'none'.
dict_sel_size_df = defaultdict(dict)
for sel in df_relabund_transct['selection'].unique():
    bool_sel = df_relabund_transct['selection'] == sel
    for size in df_relabund_transct['size'].unique():
        bool_size = df_relabund_transct['size'] == size
        sz = size if size else 'none'
        dict_sel_size_df[sel][sz] = df_relabund_transct[bool_sel & bool_size]

Get the mean across replicates and pivot the table

In [None]:
# For every non-empty (selection, size) table, compute the mean and standard
# deviation of frac_total_estcounts per (taxid_sel_size, assm_sample), pivot
# so taxid_sel_size values are columns, relabel the rows with their
# cruise-location labels, and collect the tables for merging.
list_df_mean = []
list_df_std = []
for sel, dict_size_df in dict_sel_size_df.items():
    for size, df in dict_size_df.items():
        df = df.copy()
        print(sel, size, df.shape)
        if df.shape[0] == 0:
            continue
        df['frac_total_estcounts'] = df['frac_total_estcounts'].astype(float)
        grouped = df.groupby(['taxid_sel_size','assm_sample'])['frac_total_estcounts']
        # Missing combinations (and undefined std values) become 0.
        mean_assm_sample = grouped.mean().unstack(level=0).fillna(0)
        std_assm_sample = grouped.std().unstack(level=0).fillna(0)
        mean_assm_sample.index = [dict_sam_csam[s] for s in mean_assm_sample.index]
        std_assm_sample.index = [dict_sam_csam[s] for s in std_assm_sample.index]
        list_df_mean.append(mean_assm_sample)
        list_df_std.append(std_assm_sample)
            


Merge on cruise-sample

In [None]:
# Join the per-(selection, size) mean tables side by side, aligned on the
# cruise-location row labels.
tax_sel_sz_csam_profile = pd.concat(list_df_mean, axis=1)
tax_sel_sz_csam_profile.shape

Plot linkages

In [None]:

# Flat clusters are cut from the dendrogram at this distance threshold.
t = 2.75
criterion = 'distance'

corr_method = 'pearson'

filt_frac = 0.01 # taxa must have more than this fraction when summed across samples
ms = tax_sel_sz_csam_profile.sum(axis=0)
mean_df_filt = tax_sel_sz_csam_profile.loc[:,ms > filt_frac]

# Correlation between taxon-selection-size columns; undefined values become 0.
corr_df = mean_df_filt.corr(method=corr_method)
corr_df = corr_df.replace({np.nan: 0})
# Precalculate the linkage so the same tree drives both the heatmap and the
# flat clusters.
link = hierarchy.linkage(distance.pdist(np.asarray(corr_df)))

clust = hierarchy.fcluster(link, t=t, criterion=criterion)
nclust = np.unique(clust).shape[0]

# Repeat the 20-colour palette as often as needed to cover all clusters.
cmap = list(plt.get_cmap('tab20').colors)
cmap *= math.ceil(nclust / len(cmap))
# Clusters in the order they first appear along the dendrogram leaves.
clorder = []
for idx in hierarchy.leaves_list(link):
    cl = clust[idx]
    if cl not in clorder:
        clorder.append(cl)
lut = dict(zip(clorder, cmap))
row_colors = [lut[cl] for cl in clust]

# Clustered heatmap using the precalculated linkage.
g = sns.clustermap(
    corr_df.fillna(0),
    row_linkage=link, col_linkage=link,
    row_colors=row_colors,
    col_colors=row_colors,
    mask=corr_df.isna(),
    cmap='PuOr_r', vmin=-1, vmax=1,
    cbar_kws={'shrink':0.5, 'label':f'{corr_method}\nCorrelation'},
)
g.fig.suptitle("Similarity of Cruise-location relative abundance across taxa-selection-size", y=1.02)
# pyplot.show() takes no figure argument; it shows all open figures.
plt.show()


Plot sample profiles

In [None]:
# spatial plots
n_tax_filt = 5  # Only plot clusters with more than this number of taxa

ft0 = 8
ft1 = 12
plt.rcParams['font.size'] = ft1
dims_sub_lat = (22,12)
dims_exp = (16,12)
# Column index of each cruise key within the figure it is drawn on.
dict_cruise_j = {
    'G1': 0,
    'G2': 1,
    'G3': 2,
    'D1PA': 0,
    'G3PA.diel': 1,
    'G3PA.PM': 2,
    'G5': 3,
}
# Cruise keys that are not plotted against latitude; this cell skips them.
cruises_exp = ['D1PA', 'G5','G3PA.diel','G3PA.PM']

# Plot clusters over space: one row per sufficiently large cluster.
clorder_trim = [cl for cl in clorder if sum(clust == cl) > n_tax_filt]
nrows = len(clorder_trim)
# squeeze=False keeps axes_lat 2-D even when nrows == 1, so the [i, j]
# indexing below always works.
fig_lat, axes_lat = plt.subplots(
    nrows=nrows,
    ncols=3,
    sharex=True,
    sharey=True,
    figsize=(20,nrows*1.5),
    squeeze=False
)
# Scale each taxon-selection-size column to its own maximum, then put them
# on the rows.
profile_df = mean_df_filt.copy()
profile_df = profile_df / profile_df.max(axis=0)
profile_df = profile_df.T
i = 0
for ic, cl in enumerate(clorder):
    bool_cl = clust == cl
    if sum(bool_cl) <= n_tax_filt:
        continue
    dict_cruise_loc_weigts = defaultdict(dict)

    profile_cl = profile_df[bool_cl]
    for csam in profile_cl.columns:
        cruise, info = csam.split('-',1)
        if cruise not in cruises_exp:
            # Latitude labels look like '36.569deg'.
            info = float(info.strip('deg'))
        vals = profile_cl[csam].values
        # Several samples can share one cruise-location label, in which case
        # the selection is 2-D; average those columns per row.
        if len(vals.shape) > 1:
            vals = np.mean(vals, axis=1)
        dict_cruise_loc_weigts[cruise][info] = vals

    # Plot each cruise separately
    for cruise, dict_loc_weights in dict_cruise_loc_weigts.items():
        if cruise in cruises_exp:
            # No axes are drawn for these here; skipping avoids restyling the
            # axes and boxes left over from a previous cruise.
            continue
        dfc = pd.DataFrame(dict_loc_weights)
        j = dict_cruise_j[cruise]
        ax = axes_lat[i,j]
        # One box per latitude, positioned at that latitude.
        bp = ax.boxplot(dfc.values, positions=dfc.columns, patch_artist=True)
        if i == nrows - 1:
            xticks = np.arange(20,45,5).astype(int)
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticks)
            ax.tick_params(axis='both', labelsize=ft1)

        # Colour the boxes like the cluster's heatmap colour.
        color = cmap[ic]
        for item in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
            plt.setp(bp[item], color=color)
        plt.setp(bp["fliers"], markeredgecolor=color)

        # Label the row with its cluster id on the panel in column 2.
        if j == 2:
            ax.set_ylabel(f'Cluster {cl}', rotation=0, fontsize=ft0, ha='left')
            ax.yaxis.set_label_position("right")
        ax.grid(False)
    i += 1


Print taxa for clusters

In [None]:
# Look up the name for taxid 35677.
ncbi.get_taxid_translator([35677])

In [None]:
# For every cluster with more than n_tax_filt members, print and write a
# taxonomy tree of its taxa. Each node is annotated with the selection-size
# combinations under which that taxon fell into the cluster.
out_fn = '/scratch/bgrodner/relative_abundance/metat_search_results/plots/relabund_tax_clusters.txt'
with open(out_fn, 'w') as f:
    for cl in clorder:
        bool_cl = clust == cl
        if sum(bool_cl) <= n_tax_filt:
            continue
        print('Cluster: ',cl)
        f.write(f'\nCluster: {cl}\n')
        profile_cl = profile_df[bool_cl]
        # taxid -> list of 'selection-size' labels seen in this cluster.
        dict_tax_selsz = defaultdict(list)
        for sel_sz_tax in profile_cl.index:
            sel, sz, tax = sel_sz_tax.split('-')
            dict_tax_selsz[tax].append(f'{sel}-{sz}')
        tree_ = ncbi.get_topology(list(dict_tax_selsz.keys()))
        for n in tree_.traverse():
            # Nodes that are not cluster members get an empty annotation.
            # Sorting makes the written text identical from run to run.
            selszs = sorted(set(dict_tax_selsz.get(n.name, [])))
            n.add_props(selsizes=','.join(selszs))
        treestr = tree_.to_str(props=['sci_name','selsizes'])
        print(treestr)
        f.write(treestr)


In [None]:
# Per-column mean and standard deviation across rows of mean_assm_sample.
# NOTE(review): mean_assm_sample is whatever table the last iteration of the
# earlier mean/std loop left behind, not the merged profile — confirm this is
# the intended input.
mm = mean_assm_sample.mean(axis=0)
# NOTE(review): this rebinds ms, which an earlier cell used as the filter sum.
ms = mean_assm_sample.std(axis=0)

# Twenty columns with the highest mean.
mm.sort_values(ascending=False)[:20]

In [None]:
# Translate a scientific name to a taxid and keep the first match.
# Note: this rebinds t, which an earlier cell used as the clustering threshold.
name = 'Cyanophyceae'
t = ncbi.get_name_translator([name])[name][0]
t

In [None]:
# spatial plots
ft0 = 8
ft1 = 12
plt.rcParams['font.size'] = ft1
dims_sub_lat = (22,12)
dims_exp = (16,12)
# Column index of each cruise key within the figure it is drawn on.
dict_cruise_j = {
    'G1': 0,
    'G2': 1,
    'G3': 2,
    'D1PA': 0,
    'G3PA.diel': 1,
    'G3PA.PM': 2,
    'G5': 3,
}
# Cruise keys that are not plotted against latitude; this cell skips them.
cruises_exp = ['D1PA', 'G5','G3PA.diel','G3PA.PM']

# One row per selection-size combination of taxid t, one column per cruise.
taxa = [f'NS-3.0um-{t}',f'NS-0.2um-{t}',f'PA-3.0um-{t}',f'PA-0.2um-{t}']
nrows = len(taxa)
fig_lat, axes_lat = plt.subplots(
    nrows=nrows,
    ncols=3,
    sharex=True,
    figsize=(20,nrows*1.5)
)
for i, cln in enumerate(taxa):
    if cln not in tax_sel_sz_csam_profile.columns:
        # Not every selection-size combination exists for every taxon.
        print(f'{cln} not found in profile table; row left empty')
        continue
    for csam, val in tax_sel_sz_csam_profile[cln].items():
        cruise, info = csam.split('-',1)
        if cruise in cruises_exp:
            # No axes exist for these here; skipping avoids labelling the
            # axes left over from a previous point.
            continue
        j = dict_cruise_j[cruise]
        ax = axes_lat[i,j]
        # Latitude labels look like '36.569deg'.
        ax.scatter([float(info.strip('deg'))], [float(val)],c='k')
        if i == nrows - 1:
            xticks = np.arange(20,45,5).astype(int)
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticks)
            ax.tick_params(axis='both', labelsize=ft1)
        # Label the row with the column name on the panel in column 2.
        if j == 2:
            ax.set_ylabel(f'{cln}', rotation=0, fontsize=ft0, ha='left')
            ax.yaxis.set_label_position("right")
