# Step 5: Analyze final Barnacle model

Use this notebook to compile and analyze the final version of your Barnacle model. This should be the version of the model that is fit with the optimal parameters you identified in step 4. There are several parts of this compilation and analysis notebook:
1. Align the components between bootstraps of your final model.
    - The order of components is not fixed in this tensor decomposition model. Therefore, in order to compare between bootstraps, the components must first be aligned to one another.
    - The aligned bootstraps will be saved as an xarray.Dataset so that you can access them for further analysis.
1. Summarize the model weights for each component.
    - Each component can be understood to model a different pattern in the data. Depending on how you set up your data and your Barnacle model, each pattern might also be associated with a different cluster (e.g. gene clusters). This step separates out each component so you can more closely examine the pattern and/or cluster each is modeling.
1. Visualize your model.
    - Effective visualization depends on your data type, size, dimensions, and the questions you are asking. A few potential visualizations are suggested below to help get you started.

In [None]:
# imports

import itertools
import numpy as np
import os
import pandas as pd
import seaborn as sns
import tensorly as tl
import tlviz
import xarray as xr

from barnacle.tensors import SparseCPTensor
from barnacle.utils import subset_cp_tensor
from functools import reduce
from matplotlib import pyplot as plt
from tlab.cp_tensor import load_cp_tensor
from tqdm.notebook import tqdm
from scipy.spatial import distance
from scipy.cluster import hierarchy

# default color cycle used by every plot in this notebook
sns.set_palette([
    '#9B5DE5', '#FFAC69', '#00C9AE', '#FD3F92', '#0F0A0A', '#959AB1', '#FFDB66', '#FFB1CA', '#63B9FF', '#4F1DD7'
])


### Part A: Align model bootstraps

In [None]:
# USER INPUTS -- edit these variables as needed

# path to directory where the outputs from your parameter search were saved (e.g. 'directory/barnacle/fitting/')
fitpath = '/scratch/bgrodner/iron_ko_contigs/metat_search_results/barnacle/iron_KOs.txt-tidy_all_trim/sub_taxa_01/barnacle/fitting'

# path to the normalized data tensor used to fit barnacle (e.g. 'directory/data-tensor.nc')
datapath = '/scratch/bgrodner/iron_ko_contigs/metat_search_results/barnacle/iron_KOs.txt-tidy_all_trim/sub_taxa_01/data-tensor.nc'

# optimal rank parameter (number of components) used to fit your final model
optimal_rank = 100

# optimal lambda parameter (sparsity coefficient) used to fit your final model
optimal_lambda = 1.0

# number of bootstraps used for final model
n_bootstraps = 10

# output directory where files produced by this notebook will be saved: a 'model' directory
# placed alongside the fitting directory.
# The parent directory is derived with os.path rather than str.strip('fitting'), because
# str.strip removes any of the characters f/i/t/n/g from BOTH ends of the string (and does
# nothing at all when the path ends with a slash).
outdir = os.path.join(os.path.dirname(os.path.normpath(fitpath)), 'model')
os.makedirs(outdir, exist_ok=True)
print(f"File outputs of this notebook will be saved here: {outdir}")


In [None]:
# align bootstraps of final model

# set up parameters and data structures
input_ds = xr.load_dataset(f"{fitpath}/bootstrap0/dataset-bootstrap0.nc")
# sorted so that the replicate order is the same on every run (set iteration order is not guaranteed)
replicates = sorted(str(rep_id) for rep_id in set(input_ds.replicate_id.data))
bootstraps = np.arange(n_bootstraps)
samplenames = {rep: [] for rep in replicates}    # sample names
cps = {rep: [] for rep in replicates}   # cp tensors with all samples present
subset_cps = {rep: [] for rep in replicates}    # cp tensors subset to just common samples

# collect sample names of each bootstrap/replicate pair
for boot in tqdm(bootstraps, desc='Extracting sample names'):
    for rep in replicates:
        # the context manager closes the file as soon as the sample names have been read
        with xr.open_dataset(f"{fitpath}/bootstrap{boot}/replicate{rep}/shuffled-replicate-{rep}.nc") as rep_ds:
            samplenames[rep].append(rep_ds.sample_id.values)
# compile set of samplenames common to all bootstrap / replicate splits
samplenames['common'] = reduce(np.intersect1d, itertools.chain.from_iterable([samplenames[r] for r in replicates]))

# import all fitted models
for boot in tqdm(bootstraps, desc='Importing model bootstraps'):
    for rep in replicates:
        # put together data path
        path_cp = f"bootstrap{boot}/replicate{rep}/rank{optimal_rank}/lambda{optimal_lambda}/fitted-model.h5"
        # store normalized cp tensor to cps
        cp = tl.cp_normalize(load_cp_tensor(f"{fitpath}/{path_cp}"))
        cps[rep].append(cp)
        # pull out common samplenames and store the subset cp tensor in subset_cps
        idx = np.where(np.isin(samplenames[rep][boot], samplenames['common']))[0]
        subset_cps[rep].append(subset_cp_tensor(cp, {2: idx}))
print(f"Successfully imported {len(bootstraps)} model bootstraps, each with {len(replicates)} replicates.")

# find best representative reference cp tensor
results = []
combos = list(itertools.product(replicates, bootstraps))
# limit comparisons to a random sample of 100 models; the sample is drawn once, up front,
# so every candidate reference is scored against the same set of comparison models
if len(combos) > 100:
    comparison_combos = [combos[i] for i in np.random.choice(len(combos), size=100, replace=False)]
else:
    comparison_combos = combos
for ref_rep, ref_boot in tqdm(combos, desc='Identifying best reference model from bootstraps'):
    reference_cp = subset_cps[ref_rep][ref_boot]
    for comp_rep, comp_boot in comparison_combos:
        # no point in comparing to self
        if ref_rep == comp_rep and ref_boot == comp_boot:
            continue
        comparison_cp = subset_cps[comp_rep][comp_boot]
        fms = tlviz.factor_tools.factor_match_score(reference_cp, comparison_cp, consider_weights=False)
        results.append({
            'reference_bootstrap': ref_boot,
            'reference_replicate': ref_rep,
            'comparison_bootstrap': comp_boot,
            'comparison_replicate': comp_rep,
            'fms': fms,
        })
# summarize overall mean fms
fms_df = pd.DataFrame(results)
fms_summary_df = fms_df.groupby([
    'reference_bootstrap',
    'reference_replicate'
]).agg(
    mean_fms=('fms', 'mean'),
    median_fms=('fms', 'median'),
).reset_index()
# find the best representative bootstrap model based on maximum mean FMS
# (idxmax returns an index label, so select the row with .loc)
best_ref = fms_summary_df.loc[fms_summary_df.mean_fms.idxmax(), :]
print('All bootstraps will be aligned to the following reference model:')
display(pd.DataFrame(best_ref).T.reset_index(drop=True))

# permute reference cp so that components are in descending order of explained variation
ref_cp = tlviz.factor_tools.permute_cp_tensor(
    subset_cps[best_ref['reference_replicate']][best_ref['reference_bootstrap']],
    consider_weights=False
)

# realign all the other cp tensors against the best representative cp tensor
for rep in replicates:
    for boot in bootstraps:
        # permute components to line up with best representative reference cp
        perm = tlviz.factor_tools.get_cp_permutation(subset_cps[rep][boot], reference_cp_tensor=ref_cp, consider_weights=False)
        cps[rep][boot] = tlviz.factor_tools.permute_cp_tensor(cps[rep][boot], permutation=perm)
        subset_cps[rep][boot] = tlviz.factor_tools.permute_cp_tensor(subset_cps[rep][boot], permutation=perm)
print('All model bootstraps successfully aligned.')


In [None]:
# compile aligned model weights into a single xarray.Dataset

# set up data structures
component_labels = np.arange(optimal_rank) + 1 # 1-based indexing for ease of communication
weights = {'mode0': [], 'mode1': [], 'component': [], 'pct_var': [], 'fms': []}
sample_info_df = pd.merge(
    input_ds.sample_id.to_series().reset_index(),
    input_ds.replicate_id.to_series().reset_index(),
    on='sample_replicate_id', how='inner'
)[['sample_id', 'replicate_id']].rename(columns={'sample_id': 'sample', 'replicate_id': 'replicate'})
# sample weights are gathered per bootstrap and concatenated once at the end
boot_sample_dfs = []

# separate out reference components for component specific FMS score
ref_components = SparseCPTensor(ref_cp).get_components()

# pull model weights from each bootstrap
for boot in tqdm(bootstraps, desc='Compiling weights from each bootstrap'):
    rep_sample_dfs = []
    for key in weights.keys():
        weights[key].append([])
    for rep in replicates:
        # fetch aligned cp tensor
        cp = cps[rep][boot]
        # add mode 0 weights to list
        weights['mode0'][boot].append(cp.factors[0].T)
        # add mode 1 weights to list
        weights['mode1'][boot].append(cp.factors[1].T)
        # add component weights to list
        weights['component'][boot].append(cp.weights)
        # calculate percent variation explained by components and add to list
        weights['pct_var'][boot].append(tlviz.factor_tools.percentage_variation(cp, dataset=input_ds.data.data, method='data'))
        # put mode 2 (sample) weights into a pd.DataFrame
        rep_sample_df = pd.DataFrame(
            cp.factors[2], index=samplenames[rep][boot], columns=component_labels
        ).reset_index().rename(columns={'index': 'sample'})
        rep_sample_df['replicate'] = rep
        rep_sample_dfs.append(rep_sample_df)
        # calculate component specific FMS for each component vs reference bootstrap
        fms_scores = []
        for i, comp_component in enumerate(SparseCPTensor(subset_cps[rep][boot]).get_components()):
            # the None-type components included for aligning smaller cp tensors have no score;
            # record NaN so that every score stays lined up with its component label and the
            # per-replicate lists all have the same length
            if np.all(np.isnan(comp_component.factors[0])):
                fms_scores.append(np.nan)
                continue
            # compare components
            fms_scores.append(tlviz.factor_tools.factor_match_score(comp_component, ref_components[i], consider_weights=False))
        weights['fms'][boot].append(fms_scores)
    # merge sample info into the sample weights of all replicates
    boot_sample_df = pd.merge(
        left=sample_info_df, right=pd.concat(rep_sample_dfs), on=['sample', 'replicate'], how='left'
    )
    boot_sample_df['bootstrap'] = boot
    boot_sample_dfs.append(boot_sample_df)
# concatenate sample weights of all bootstraps
sample_df = pd.concat(boot_sample_dfs)

# compile everything into an xarray.Dataset
modes = list(input_ds.coords)
modes[2] = 'sample'
ds = xr.Dataset({
    f"{modes[0]}_weights": xr.DataArray(
        np.array(weights['mode0']),
        coords=[bootstraps, replicates, component_labels, input_ds[modes[0]].data],
        dims=['bootstrap', 'replicate', 'component', modes[0]]
    ),
    f"{modes[1]}_weights": xr.DataArray(
        np.array(weights['mode1']),
        coords=[bootstraps, replicates, component_labels, input_ds[modes[1]].data],
        dims=['bootstrap', 'replicate', 'component', modes[1]]
    ),
    f"{modes[2]}_weights": xr.DataArray.from_series(
        sample_df.melt(
            id_vars=['bootstrap', 'replicate', modes[2]],
            value_vars=component_labels,
            var_name='component',
            value_name=f"{modes[2]}_weights"
        ).set_index(['bootstrap', 'replicate', 'component', modes[2]])[f"{modes[2]}_weights"]
    ),
    'component_weights': xr.DataArray(
            np.array(weights['component']),
            coords=[bootstraps, replicates, component_labels],
            dims=['bootstrap', 'replicate', 'component']
    ),
    'percent_variation': xr.DataArray(
            np.array(weights['pct_var']),
            coords=[bootstraps, replicates, component_labels],
            dims=['bootstrap', 'replicate', 'component']
    ),
    'fms_component': xr.DataArray(
            np.array(weights['fms']),
            coords=[bootstraps, replicates, component_labels],
            dims=['bootstrap', 'replicate', 'component']
    )
})

# add reference tensor, rank, and sparsity coefficient as attributes
ds.attrs['rank'] = optimal_rank
ds.attrs['lambda'] = optimal_lambda
ds.attrs['n_bootstraps'] = n_bootstraps
ds.attrs['align_ref_bootstrap'] = best_ref['reference_bootstrap']
ds.attrs['align_ref_replicate'] = best_ref['reference_replicate']

# save Dataset as netCDF4 file
ds.to_netcdf(f"{outdir}/aligned-bootstraps.nc")

# examine Dataset
ds

In [None]:
# preview the per-bootstrap sample weights compiled above
sample_df.head()

### Part B: Summarize model weights for each component

In [None]:
# summarize the weights of a single component for one mode
component = 3
mode = 'KO'
# the statistics are computed straight from `ds` so that this cell does not depend on
# summary datasets that are only created in a later cell
summary_dims = ['bootstrap', 'replicate']
component_weights = ds[f"{mode}_weights"].sel(component=component)
df = pd.DataFrame({
    'mean_weight': component_weights.mean(dim=summary_dims).to_series(),
    'std_weight': component_weights.std(dim=summary_dims).to_series(),
    'median_weight': component_weights.median(dim=summary_dims).to_series(),
    # bootstrap support: percent of bootstrap/replicate fits with a nonzero weight
    'pct_bootstraps_nonzero': ((component_weights != 0).mean(dim=summary_dims) * 100).to_series(),
}).sort_values(by=['mean_weight', 'pct_bootstraps_nonzero'], key=abs, ascending=False)
df

In [None]:
# mean FMS of component 1 over every bootstrap/replicate fit
np.mean(np.ravel(ds['fms_component'].sel(component=1).values))

In [None]:
# summarize model weights for each component, and save summaries as csv files

# USER INPUTS -- edit these variables as needed
# top weights are plotted only for components labelled 1 through max_components
max_components = 10
# maximum number of top-weighted elements shown in each plot
max_elements = 20

# summarize weights across bootstraps and replicates
mean_ds = ds.mean(dim=['bootstrap', 'replicate'])    # mean of weights
std_ds = ds.std(dim=['bootstrap', 'replicate'])    # standard deviation of weights
med_ds = ds.median(dim=['bootstrap', 'replicate'])    # median of weights
support_ds = (ds != 0).mean(dim=['bootstrap', 'replicate']) * 100     # bootstrap support (percent of weights nonzero)

# plot % variation explained and FMS for each component
fig, axes = plt.subplots(2, 1, figsize=(min(15, len(ds.component.data)), 10), sharex=True)
sns.boxplot(ds.percent_variation.to_series().reset_index(), x='component', y='percent_variation', ax=axes[0]);
axes[0].set(title='Explanatory Power of Components', xlabel='Component', ylabel='% Variation Explained');
sns.boxplot(ds.fms_component.to_series().reset_index(), x='component', y='fms_component', color=sns.color_palette()[8], ax=axes[1]);
axes[1].set(title='Consistency of Components\n(Between Bootstraps)', xlabel='Component', ylabel='FMS\n(Relative to Alignment Reference)');
plt.show()

# extract summary of weights for each component
for component in ds.component.data:
    os.makedirs(f"{outdir}/component{component}", exist_ok=True)
    for i, mode in enumerate(modes):
        # make dataframe summarizing weights
        df = pd.DataFrame({
            'mean_weight': mean_ds.sel(component=component)[f"{mode}_weights"].to_series(),
            'std_weight': std_ds.sel(component=component)[f"{mode}_weights"].to_series(),
            'median_weight': med_ds.sel(component=component)[f"{mode}_weights"].to_series(),
            'pct_bootstraps_nonzero': support_ds.sel(component=component)[f"{mode}_weights"].to_series(),
        }).sort_values(by=['mean_weight', 'pct_bootstraps_nonzero'], key=abs, ascending=False)
        # drop values with a median weight of zero (equivalent to 50% bootstrap support threshold)
        df = df[~df.median_weight.eq(0)].reset_index()
        # save data
        df.to_csv(f"{outdir}/component{component}/{mode}_weights_summary.csv", index=False)
        # plot top weights
        if component <= max_components:
            # copy so that the label conversion below does not write into a slice of the full summary
            df = df.head(max_elements).copy() # just look at top weights
            df[mode] = df[mode].astype(str) # in case labels aren't strings
            print(df[mode].values.tolist())
            fig, axis = plt.subplots(figsize=(min(15, len(df)), 3))
            sns.barplot(df, x=mode, y='mean_weight', color=sns.color_palette()[i+1], legend=False, ax=axis);
            axis.errorbar(x=df[mode], y=df['mean_weight'], yerr=df['std_weight'], fmt='none', color=sns.color_palette()[5])
            axis.set(title=f"Component {component} Top {mode.capitalize()} Weights", xlabel=None, ylabel='Weight');
            fig.autofmt_xdate(rotation=90)
            plt.show()
            # release the figure; many are created in this loop
            plt.close(fig)
        

Get components with multiple taxa and high mean FMS

In [None]:
# collect the components whose second-largest taxon weight is a sizeable fraction of the
# largest one and whose mean FMS exceeds a minimum value
mode = 'taxon_trim'
# the second-ranked weight must exceed this fraction of the top-ranked weight
min_second_weight_ratio = 0.4
# minimum mean FMS (averaged over all bootstraps and replicates)
min_mean_fms = 0.5

comp_mult = []
for component in ds.component.data:
    df = pd.DataFrame({
        'mean_weight': mean_ds.sel(component=component)[f"{mode}_weights"].to_series(),
        'std_weight': std_ds.sel(component=component)[f"{mode}_weights"].to_series(),
        'median_weight': med_ds.sel(component=component)[f"{mode}_weights"].to_series(),
        'pct_bootstraps_nonzero': support_ds.sel(component=component)[f"{mode}_weights"].to_series(),
    }).sort_values(by=['mean_weight', 'pct_bootstraps_nonzero'], key=abs, ascending=False)
    # two ranked weights are needed for the comparison below
    if len(df) < 2:
        continue
    # NOTE(review): rows are ranked by absolute weight but compared here by signed value;
    # confirm that weights in this mode are non-negative
    w0, w1 = df['mean_weight'].iloc[:2].values
    mean_fms = ds.fms_component.sel(component=component).values.ravel().mean()
    if ((w0 * min_second_weight_ratio) < w1) and (mean_fms > min_mean_fms):
        print(f"\n\nComponent: {component}")
        print(f"Mean FMS: {mean_fms}")
        print(df['mean_weight'].iloc[:5])
        comp_mult.append(component)


In [None]:
# summarize, save, and plot the weights of every component collected in comp_mult
# NOTE(review): `max_elements` is read below but not assigned in this cell; confirm that it
# is defined before this cell runs

# summary column name -> dataset holding that statistic
summary_sources = {
    'mean_weight': mean_ds,
    'std_weight': std_ds,
    'median_weight': med_ds,
    'pct_bootstraps_nonzero': support_ds,
}
for component in comp_mult:
    mean_fms = ds.fms_component.sel(component=component).values.ravel().mean()
    print(f"\n\nComponent: {component}")
    print(f"Mean FMS: {mean_fms}")
    component_dir = f"{outdir}/component{component}"
    if not os.path.isdir(component_dir):
        os.makedirs(component_dir)
    for i, mode in enumerate(modes):
        weights_var = f"{mode}_weights"
        # one column per summary statistic, ranked by absolute mean weight and then support
        df = pd.DataFrame({
            column: stat_ds.sel(component=component)[weights_var].to_series()
            for column, stat_ds in summary_sources.items()
        })
        df = df.sort_values(by=['mean_weight', 'pct_bootstraps_nonzero'], key=abs, ascending=False)
        # keep only elements whose median weight is nonzero (a 50% bootstrap support threshold)
        df = df[df.median_weight.ne(0)].reset_index()
        # write the full summary to disk
        df.to_csv(f"{component_dir}/{mode}_weights_summary.csv", index=False)
        # bar plot of the highest ranked weights
        df = df.head(max_elements)
        df[mode] = df[mode].astype(str)    # labels may not be strings
        print(df[mode].values.tolist())
        fig, axis = plt.subplots(figsize=(min(15, len(df)), 3))
        sns.barplot(df, x=mode, y='mean_weight', color=sns.color_palette()[i+1], legend=False, ax=axis);
        axis.errorbar(x=df[mode], y=df['mean_weight'], yerr=df['std_weight'], fmt='none', color=sns.color_palette()[5])
        axis.set(title=f"Component {component} Top {mode.capitalize()} Weights", xlabel=None, ylabel='Weight');
        fig.autofmt_xdate(rotation=90)
        plt.show()


### Part C: Visualization

In [None]:
# compare weight profiles between components

# weight profiles of the individual modes, kept for the concatenated comparison
mode_profile_dfs = []

# iterate through comparisons
comparisons = modes + ['concatenated']
# dicts for storing the linkage and the profile columns of each comparison
links = {}
columns = {}
for mode in comparisons:
    print(f"Comparing {mode} weights:")
    if mode != 'concatenated':
        # calculate median weight profile for each component in model
        profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()
        mode_profile_dfs.append(profile_df)
    else:
        # join the weight profiles of all modes side by side
        concat_profile_df = pd.concat(mode_profile_dfs, axis=1)
        profile_df = concat_profile_df
    # calculate correlation matrix
    corr_df = profile_df.T.corr()
    # undefined correlations are NaN; they are treated as 0 both for clustering and for
    # plotting so that the linkage is computed on finite values and matches the heatmap
    corr_filled_df = corr_df.fillna(0)
    # precalculate linkage to extract clusters later
    link = hierarchy.linkage(distance.pdist(np.asarray(corr_filled_df)))
    links[mode] = link
    columns[mode] = profile_df.columns
    # make clustered heatmap using the precalculated linkage
    g = sns.clustermap(
        corr_filled_df,
        row_linkage=link, col_linkage=link,
        mask=corr_df.isna(),
        cmap='PuOr_r', vmin=-1, vmax=1,
        cbar_kws={'shrink':0.5, 'label':'Pearson\nCorrelation'},
        xticklabels=True, yticklabels=True
    )
    g.fig.suptitle(f"Similarity of {mode.capitalize()} Weights Between Components", y=1.02);
    plt.show()
    

Get taxon clusters

In [None]:
# cluster components by the similarity of their taxon weight profiles
mode = 'taxon_trim'
t = 20    # maximum number of flat clusters (the 'tab20' colormap provides 20 colors)
profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()

# calculate correlation matrix
corr_df = profile_df.T.corr()
# undefined correlations are NaN; they are treated as 0 both for clustering and for
# plotting so that the linkage is computed on finite values and matches the heatmap
corr_filled_df = corr_df.fillna(0)
# precalculate linkage to extract clusters
link = hierarchy.linkage(distance.pdist(np.asarray(corr_filled_df)))
links[mode] = link
columns[mode] = profile_df.columns

# cut the tree into flat clusters; the labels come from the same linkage that is drawn below
clust = hierarchy.fcluster(link, t=t, criterion='maxclust')
cmap = plt.get_cmap('tab20').colors
lut = dict(zip(np.unique(clust), cmap))
row_colors = [lut[cl] for cl in clust]

# make clustered heatmap using the precalculated linkage
g = sns.clustermap(
    corr_filled_df,
    row_linkage=link, col_linkage=link,
    row_colors=row_colors,
    col_colors=row_colors,
    mask=corr_df.isna(),
    cmap='PuOr_r', vmin=-1, vmax=1,
    cbar_kws={'shrink':0.5, 'label':'Pearson\nCorrelation'},
    xticklabels=True, yticklabels=True
)
g.fig.suptitle(f"Similarity of {mode.capitalize()} Weights Between Components", y=1.02);
plt.show()

# plot the mean weight profile of each cluster
for cl in np.unique(clust):
    color = lut[cl]
    bool_cl = clust == cl
    profile_cl = profile_df[bool_cl]
    print(profile_cl.index)

    # mean and standard deviation across the components of this cluster
    pcl_mean = profile_cl.mean(axis=0).rename('mean')
    pcl_std = profile_cl.std(axis=0).rename('std')
    pcl_ms = pd.concat([pcl_mean, pcl_std], axis=1).T
    # after transposing: one row per element of this mode, one column per component plus 'mean' and 'std'
    profile_cl = pd.concat(
        [profile_cl, pcl_ms], axis=0
    ).T
    profile_cl.index = profile_cl.index.map(str)
    # figure width follows the number of bars in this plot
    fig, axis = plt.subplots(figsize=(min(15, len(profile_cl)), 3))
    sns.barplot(profile_cl, x=mode, y='mean', color=color, legend=False, ax=axis);
    axis.errorbar(x=profile_cl.index, y=profile_cl['mean'], yerr=profile_cl['std'], fmt='none', color=sns.color_palette()[5])
    axis.set(title=f"Cluster {cl} Top {mode.capitalize()} Mean Weights", xlabel=None, ylabel='Weight');
    fig.autofmt_xdate(rotation=90)
    plt.show()


In [None]:
# Display labels used when plotting.
# dkn: KO identifier -> label made of the identifier, gene symbol(s), and description.
dkn = {'K00265': 'K00265  gltB; glutamate synthase (NADPH) large chain [EC:1.4.1.13]',
 'K01012': 'K01012  bioB; biotin synthase [EC:2.8.1.6]',
 'K01595': 'K01595  ppc; phosphoenolpyruvate carboxylase [EC:4.1.1.31]',
 'K01672': 'K01672  CA; carbonic anhydrase [EC:4.2.1.1]',
 'K01726': 'K01726  GAMMACA; gamma-carbonic anhydrase [EC:4.2.1.-]',
 'K02217': 'K02217  ftnA, ftn; ferritin [EC:1.16.3.2]',
 'K02255': 'K02255  ftnB; ferritin-like protein 2',
 'K02364': 'K02364  entF; L-serine---[L-seryl-carrier protein] ligase [EC:6.3.2.14 6.2.1.72]',
 'K02638': 'K02638  petE; plastocyanin',
 'K02639': 'K02639  petF; ferredoxin',
 'K03320': 'K03320  amt, AMT, MEP; ammonium transporter, Amt family',
 'K04564': 'K04564  SOD2; superoxide dismutase, Fe-Mn family [EC:1.15.1.1]',
 'K04641': 'K04641  bop; bacteriorhodopsin',
 'K04783': 'K04783  irp5, ybtE; yersiniabactin salicyl-AMP ligase [EC:6.3.2.-]',
 'K04784': 'K04784  irp2, HMWP2; yersiniabactin nonribosomal peptide synthetase',
 'K05524': 'K05524  fdxA; ferredoxin',
 'K07214': 'K07214  fes; iron(III)-enterobactin esterase [EC:3.1.1.108]',
 'K12237': 'K12237  vibF; nonribosomal peptide synthetase VibF',
 'K16087': 'K16087  TC.FEV.OM3, tbpA, hemR, lbpA, hpuB, bhuR, hugA, hmbR; hemoglobin/transferrin/lactoferrin receptor protein',
 'K19611': 'K19611  fepA, pfeA, iroN, pirA; ferric enterobactin receptor',
 'K22336': 'K22336  bfrB; bacterioferritin B [EC:1.16.3.1]',
 'K22552': 'K22552  mmcO; multicopper oxidase [EC:1.16.3.1]',
 'K23910': 'K23910  TFR2; transferrin receptor protein 2',
 'K25224': 'K25224  gapdh; glyceraldehyde-3-phosphate dehydrogenase (arsenate-transferring) [EC:1.2.1.107]'}
# dtn: taxon identifier (as a string) -> taxon name.
# NOTE(review): the keys look like NCBI taxonomy IDs -- confirm.
dtn = {
    '226': 'Alteromonas',
    '2864': 'Dinophyceae',
    '3041': 'Chlorophyta',
    '28211': 'Alphaproteobacteria',
    '31989': 'Paracoccaceae',
    '35127': 'Thalassiosira',
    '35677': 'Pelagomonas calceolata',
    '49546': 'Flavobacteriaceae',
    '135623': 'Vibrionales',
    '304208': "Pseudoalteromonas sp. '520P1 No. 412'",
    '487796': 'Flavobacteria bacterium MS024-2A',
    '1735725': 'Stramenopiles MAST-4',
    '2696291': 'Ochrophyta',
    '2854170': 'Roseobacteraceae'
 }


In [None]:
# plot the per-bootstrap weights of the components in selected clusters
# `clust` holds one cluster label per component, from the most recent clustering cell that was run

# clusters to plot
cl_select = [1,4]
# cl_select = [7, 17]
# per mode: an element is shown when its largest absolute median weight among the cluster's
# components reaches this value ('mean_fms' is not applied in this cell)
threshs = {
    'KO':0.2,
    'taxon_trim':0.1,
    'sample': 0.2,
    'mean_fms': 0.3
}
for cl in cl_select:
    print(cl)
    bool_cl = clust == cl
    for mode in modes:
        print(mode)
        # median weight profile of the components in this cluster
        profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()
        profile_cl = profile_df[bool_cl]
        # keep the elements whose largest absolute weight reaches the threshold
        pcl_max = np.abs(profile_cl).max(axis=0).rename('max')
        bool_mode = pcl_max >= threshs[mode]
        pcl_sub = profile_cl.loc[:,bool_mode]

        # long-format weights of every bootstrap/replicate, limited to the selected
        # components and elements (copied because the labels are rewritten below)
        df = ds[f"{mode}_weights"].to_series().reset_index()
        df = df[
            df.component.isin(pcl_sub.index)
            & df[mode].isin(pcl_sub.columns)
        ].copy()
        # mean FMS of each selected component
        mfms = [ds.fms_component.sel(component=c).values.ravel().mean() for c in pcl_sub.index]
        print(dict(zip(pcl_sub.index, mfms)))
        if df.empty:
            # nothing to draw for this cluster and mode
            print(f"No {mode} weights >= {threshs[mode]} in cluster {cl}")
            continue
        # swap identifiers for descriptive labels, falling back to the identifier when it is not listed
        if mode == 'KO':
            df[mode] = [dkn.get(k, k) for k in df[mode]]
        if mode == 'taxon_trim':
            df[mode] = [dtn.get(str(t), str(t)) for t in df[mode]]
        g = sns.FacetGrid(df, row='component', aspect=5)
        g.map(sns.barplot, mode, f"{mode}_weights", order=df[mode].unique(), color=sns.color_palette()[modes.index(mode)+1], errorbar='sd');
        g.fig.suptitle(f"Cluster {cl} {mode.capitalize()} Component Weights > {threshs[mode]}", y = 1.005);
        g.set_xticklabels(df[mode].unique(), rotation=45, ha='right');
        plt.show()


Get sample clusters

In [None]:
# cluster components by the similarity of their sample weight profiles
mode = 'sample'
max_sam = 10    # number of highest mean-weight samples shown for each cluster
t = 20    # maximum number of flat clusters (the 'tab20' colormap provides 20 colors)
profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()

# calculate correlation matrix
corr_df = profile_df.T.corr()
# undefined correlations are NaN; they are treated as 0 both for clustering and for
# plotting so that the linkage is computed on finite values and matches the heatmap
corr_filled_df = corr_df.fillna(0)
# precalculate linkage to extract clusters
link = hierarchy.linkage(distance.pdist(np.asarray(corr_filled_df)))
links[mode] = link
columns[mode] = profile_df.columns

# cut the tree into flat clusters; the labels come from the same linkage that is drawn below
clust = hierarchy.fcluster(link, t=t, criterion='maxclust')
cmap = plt.get_cmap('tab20').colors
lut = dict(zip(np.unique(clust), cmap))
row_colors = [lut[cl] for cl in clust]

# make clustered heatmap using the precalculated linkage
g = sns.clustermap(
    corr_filled_df,
    row_linkage=link, col_linkage=link,
    row_colors=row_colors,
    col_colors=row_colors,
    mask=corr_df.isna(),
    cmap='PuOr_r', vmin=-1, vmax=1,
    cbar_kws={'shrink':0.5, 'label':'Pearson\nCorrelation'},
    xticklabels=True, yticklabels=True
)
g.fig.suptitle(f"Similarity of {mode.capitalize()} Weights Between Components", y=1.02);
plt.show()

# plot the highest mean sample weights of each cluster
for cl in np.unique(clust):
    color = lut[cl]
    bool_cl = clust == cl
    profile_cl = profile_df[bool_cl]
    print(profile_cl.index)

    # mean and standard deviation across the components of this cluster
    pcl_mean = profile_cl.mean(axis=0).rename('mean')
    pcl_std = profile_cl.std(axis=0).rename('std')
    pcl_ms = pd.concat([pcl_mean, pcl_std], axis=1).T
    # after transposing: one row per sample, one column per component plus 'mean' and 'std';
    # only the max_sam samples with the highest mean weight are kept
    profile_cl = pd.concat(
        [profile_cl, pcl_ms], axis=0
    ).T.sort_values(by='mean',ascending=False)[:max_sam]
    profile_cl.index = profile_cl.index.map(str)
    # figure width follows the number of bars in this plot
    fig, axis = plt.subplots(figsize=(min(15, len(profile_cl)), 3))
    sns.barplot(profile_cl, x=mode, y='mean', color=color, legend=False, ax=axis);
    axis.errorbar(x=profile_cl.index, y=profile_cl['mean'], yerr=profile_cl['std'], fmt='none', color=sns.color_palette()[5])
    axis.set(title=f"Cluster {cl} Top {mode.capitalize()} Mean Weights", xlabel=None, ylabel='Weight');
    fig.autofmt_xdate(rotation=90)
    plt.show()


In [None]:
dkn = {'K00265': 'K00265  gltB; glutamate synthase (NADPH) large chain [EC:1.4.1.13]',
 'K00368': 'K00368  nirK; nitrite reductase (NO-forming) [EC:1.7.2.1]',
 'K00615': 'K00615  E2.2.1.1, tktA, tktB; transketolase [EC:2.2.1.1]',
 'K00855': 'K00855  PRK, prkB; phosphoribulokinase [EC:2.7.1.19]',
 'K01012': 'K01012  bioB; biotin synthase [EC:2.8.1.6]',
 'K01601': 'K01601  rbcL, cbbL; ribulose-bisphosphate carboxylase large chain [EC:4.1.1.39]',
 'K01602': 'K01602  rbcS, cbbS; ribulose-bisphosphate carboxylase small chain [EC:4.1.1.39]',
 'K01624': 'K01624  FBA, fbaA; fructose-bisphosphate aldolase, class II [EC:4.1.2.13]',
 'K01672': 'K01672  CA; carbonic anhydrase [EC:4.2.1.1]',
 'K01673': 'K01673  cynT, can; carbonic anhydrase [EC:4.2.1.1]',
 'K02364': 'K02364  entF; L-serine---[L-seryl-carrier protein] ligase [EC:6.3.2.14 6.2.1.72]',
 'K02639': 'K02639  petF; ferredoxin',
 'K02689': 'K02689  psaA; photosystem I P700 chlorophyll a apoprotein A1 [EC:1.97.1.12]',
 'K02695': 'K02695  psaH; photosystem I subunit VI',
 'K02699': 'K02699  psaL; photosystem I subunit XI',
 'K02703': 'K02703  psbA; photosystem II P680 reaction center D1 protein [EC:1.10.3.9]',
 'K02704': 'K02704  psbB; photosystem II CP47 chlorophyll apoprotein',
 'K02705': 'K02705  psbC; photosystem II CP43 chlorophyll apoprotein',
 'K02706': 'K02706  psbD; photosystem II P680 reaction center D2 protein [EC:1.10.3.9]',
 'K02708': 'K02708  psbF; photosystem II cytochrome b559 subunit beta',
 'K02711': 'K02711  psbJ; photosystem II PsbJ protein',
 'K02717': 'K02717  psbP; photosystem II oxygen-evolving enhancer protein 2',
 'K02718': 'K02718  psbT; photosystem II PsbT protein',
 'K02721': 'K02721  psbW; photosystem II PsbW protein',
 'K02722': 'K02722  psbX; photosystem II PsbX protein',
 'K02724': 'K02724  psbZ; photosystem II PsbZ protein',
 'K03320': 'K03320  amt, AMT, MEP; ammonium transporter, Amt family',
 'K03542': 'K03542  psbS; photosystem II 22kDa protein',
 'K03594': 'K03594  bfr; bacterioferritin [EC:1.16.3.1]',
 'K03839': 'K03839  fldA, nifF, isiB; flavodoxin I',
 'K03841': 'K03841  FBP, fbp; fructose-1,6-bisphosphatase I [EC:3.1.3.11]',
 'K04564': 'K04564  SOD2; superoxide dismutase, Fe-Mn family [EC:1.15.1.1]',
 'K04565': 'K04565  SOD1; superoxide dismutase, Cu-Zn family [EC:1.15.1.1]',
 'K04755': 'K04755  fdx; ferredoxin, 2Fe-2S',
 'K04759': 'K04759  feoB; ferrous iron transport protein B',
 'K04784': 'K04784  irp2, HMWP2; yersiniabactin nonribosomal peptide synthetase',
 'K04787': 'K04787  mbtA; mycobactin salicyl-AMP ligase [EC:6.3.2.-]',
 'K05374': 'K05374  irp4, ybtT; yersiniabactin synthetase, thioesterase component',
 'K05524': 'K05524  fdxA; ferredoxin',
 'K07214': 'K07214  fes; iron(III)-enterobactin esterase [EC:3.1.1.108]',
 'K08940': 'K08940  pscA; photosystem P840 reaction center large subunit',
 'K10850': 'K10850  narT; MFS transporter, NNP family, putative nitrate transporter',
 'K11645': 'K11645  fbaB; fructose-bisphosphate aldolase, class I [EC:4.1.2.13]',
 'K12237': 'K12237  vibF; nonribosomal peptide synthetase VibF',
 'K13859': 'K13859  SLC4A8; solute carrier family 4 (sodium bicarbonate cotransporter), member 8',
 'K13860': 'K13860  SLC4A9, AE4; solute carrier family 4 (sodium bicarbonate cotransporter), member 9',
 'K16087': 'K16087  TC.FEV.OM3, tbpA, hemR, lbpA, hpuB, bhuR, hugA, hmbR; hemoglobin/transferrin/lactoferrin receptor protein',
 'K21567': 'K21567  fnr; ferredoxin/flavodoxin---NADP+ reductase [EC:1.18.1.2 1.19.1.1]',
 'K22336': 'K22336  bfrB; bacterioferritin B [EC:1.16.3.1]',
 'K22552': 'K22552  mmcO; multicopper oxidase [EC:1.16.3.1]',
 'K23723': 'K23723  iroD; iron(III)-salmochelin esterase [EC:3.1.1.109]',
 'K24110': 'K24110  asbC; 3,4-dihydroxybenzoate---[aryl-carrier protein] ligase [EC:6.2.1.62]',
 'K00264': 'K00264  GLT1; glutamate synthase (NADH) [EC:1.4.1.14]',
 'K00266': 'K00266  gltD; glutamate synthase (NADPH) small chain [EC:1.4.1.13]',
 'K00284': 'K00284  GLU, gltS; glutamate synthase (ferredoxin) [EC:1.4.7.1]',
 'K00362': 'K00362  nirB; nitrite reductase (NADH) large subunit [EC:1.7.1.15]',
 'K00522': 'K00522  FTH1; ferritin heavy chain [EC:1.16.3.1]',
 'K00532': 'K00532  E1.12.7.2; ferredoxin hydrogenase [EC:1.12.7.2]',
 'K02011': 'K02011  afuB, fbpB; iron(III) transport system permease protein',
 'K02574': 'K02574  napH; ferredoxin-type protein NapH',
 'K02697': 'K02697  psaJ; photosystem I subunit IX',
 'K02714': 'K02714  psbM; photosystem II PsbM protein',
 'K02716': 'K02716  psbO; photosystem II oxygen-evolving enhancer protein 1',
 'K04641': 'K04641  bop; bacteriorhodopsin',
 'K04786': 'K04786  irp1, HMWP1; yersiniabactin nonribosomal peptide/polyketide synthase',
 'K06441': 'K06441  E1.12.7.2G; ferredoxin hydrogenase gamma subunit [EC:1.12.7.2]',
 'K06503': 'K06503  TFRC, CD71; transferrin receptor',
 'K08906': 'K08906  petJ; cytochrome c6',
 'K11959': 'K11959  urtA; urea transport system substrate-binding protein',
 'K13575': 'K13575  SLC4A4, NBC1; solute carrier family 4 (sodium bicarbonate cotransporter), member 4',
 'K14578': 'K14578  nahAb, nagAb, ndoA, nbzAb, dntAb; naphthalene 1,2-dioxygenase ferredoxin component',
 'K15579': 'K15579  nrtD, cynD; nitrate/nitrite transport system ATP-binding protein',
 'K19611': 'K19611  fepA, pfeA, iroN, pirA; ferric enterobactin receptor',
 'K19791': 'K19791  FET3_5; iron transport multicopper oxidase',
 'K21949': 'K21949  sbnA; N-(2-amino-2-carboxyethyl)-L-glutamate synthase [EC:2.5.1.140]',
 'K22338': 'K22338  hylA; formate dehydrogenase (NAD+, ferredoxin) subunit A [EC:1.17.1.11]',
 'K22339': 'K22339  hylB; formate dehydrogenase (NAD+, ferredoxin) subunit B [EC:1.17.1.11]',
 'K23184': 'K23184  fecE; ferric citrate transport system ATP-binding protein [EC:7.2.2.18]',
 'K23910': 'K23910  TFR2; transferrin receptor protein 2',
 'K25286': 'K25286  fagD, cchF, irp1A, piaA; iron-siderophore transport system substrate-binding protein',
 'K02012': 'K02012  afuA, fbpA; iron(III) transport system substrate-binding protein',
 'K02690': 'K02690  psaB; photosystem I P700 chlorophyll a apoprotein A2 [EC:1.97.1.12]',
 'K00372': 'K00372  nasC,  nasA; assimilatory nitrate reductase catalytic subunit [EC:1.7.99.-]',
 'K01595': 'K01595  ppc; phosphoenolpyruvate carboxylase [EC:4.1.1.31]',
 'K02719': 'K02719  psbU; photosystem II PsbU protein',
 'K02720': 'K02720  psbV; photosystem II cytochrome c550',
 'K04783': 'K04783  irp5, ybtE; yersiniabactin salicyl-AMP ligase [EC:6.3.2.-]',
 'K05710': 'K05710  hcaC; 3-phenylpropionate/trans-cinnamate dioxygenase ferredoxin component',
 'K14718': 'K14718  SLC39A12, ZIP12; solute carrier family 39 (zinc transporter), member 12',
 'K23725': 'K23725  iroB; enterobactin C-glucosyltransferase [EC:2.4.1.369]',
 'K24245': "K24245  ligXd; 5,5'-dehydrodivanillate O-demethylase ferredoxin reductase subunit [EC:1.18.1.-]"}
# Lookup table: taxon identifier (string key) -> readable taxon name.
# Used to relabel the 'taxon_trim' mode before plotting.
# NOTE(review): keys look like NCBI taxonomy IDs -- confirm against the data source.
dtn = dict([
    ('226', 'Alteromonas'),
    ('2864', 'Dinophyceae'),
    ('3041', 'Chlorophyta'),
    ('28211', 'Alphaproteobacteria'),
    ('31989', 'Paracoccaceae'),
    ('35127', 'Thalassiosira'),
    ('35677', 'Pelagomonas calceolata'),
    ('49546', 'Flavobacteriaceae'),
    ('135623', 'Vibrionales'),
    ('304208', "Pseudoalteromonas sp. '520P1 No. 412'"),
    ('487796', 'Flavobacteria bacterium MS024-2A'),
    ('1735725', 'Stramenopiles MAST-4'),
    ('2696291', 'Ochrophyta'),
    ('2854170', 'Roseobacteraceae'),
])


In [None]:
# Plot the weights of every component in the selected cluster(s), one mode at a time.
# For each mode, only elements whose absolute median weight reaches that mode's
# threshold in at least one component of the cluster are plotted.
# NOTE(review): relies on `clust`, `modes`, `ds`, `dkn` and `dtn` defined in other cells;
# `clust` is used as a per-component label array aligned with the rows of the
# median-weight table -- confirm it was computed before running this cell.

cl_select = [17]    # cluster label(s) to inspect
# cl_select = [7, 17]
# minimum absolute median weight (in any component of the cluster) for an element to be shown
threshs = {
    'KO':0.2,
    'taxon_trim':0.2,
    'sample': 0.2
}
for cl in cl_select:
    print(cl)
    for mode in modes:
        print(mode)
        # median weight across bootstraps and replicates (rows: components, columns: elements of `mode`)
        profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()
        # keep only the components that belong to this cluster
        bool_cl = clust == cl
        profile_cl = profile_df[bool_cl]
        # largest absolute weight of each element over the cluster's components
        pcl_max = np.abs(profile_cl).max(axis=0).rename('max')
        bool_mode = pcl_max >= threshs[mode]
        pcl_sub = profile_cl.loc[:,bool_mode]

        # long-format weights (every bootstrap/replicate) for the selected components and elements
        df = ds.sel({mode: ds[mode]})[f"{mode}_weights"].to_series().reset_index()
        # .copy() makes the filtered table independent, so the relabeling below
        # does not raise a SettingWithCopyWarning
        df = df[
            df.component.isin(pcl_sub.index)
            & df[mode].isin(pcl_sub.columns)
        ].copy()
        if mode == 'KO':
            # replace KO identifiers by their descriptive names for plotting
            df[mode] = [dkn[k] for k in df[mode]]
            # text summary per component: mean FMS and the KOs with median weight above 0.1
            for comp, row in pcl_sub.iterrows():
                mean_fms = ds.fms_component.sel(component=comp).values.ravel().mean()
                print(f"Component {comp}")
                print(f"Mean FMS {round(mean_fms,2)}")
                for k, v in row.items():
                    if v > 0.1:
                        print(round(v, 2), dkn[k])
        if mode == 'taxon_trim':
            # replace taxon identifiers by taxon names for plotting
            df[mode] = [dtn[str(t)] for t in df[mode]]
        # one bar-plot row per component; error bars show the standard deviation
        g = sns.FacetGrid(df, row='component', aspect=5)
        g.map(sns.barplot, mode, f"{mode}_weights", order=df[mode].unique(), color=sns.color_palette()[modes.index(mode)+1], errorbar='sd')
        g.fig.suptitle(f"Cluster {cl} {mode.capitalize()} Component Weights > {threshs[mode]}", y = 1.005)
        g.set_xticklabels(df[mode].unique(), rotation=45, ha='right')
        plt.show()

Get KO clusters

In [None]:
# Cluster components by the similarity of their KO weight profiles, draw the
# clustered correlation heatmap, then plot the top mean weights of each cluster.
# NOTE(review): expects the `links` and `columns` dicts to already exist (created in another cell).
mode = 'KO'
max_sam = 10    # maximum number of elements shown in each cluster bar plot
t = 15    # maximum number of flat clusters; the tab20 palette below provides 20 colors

# median weight across bootstraps and replicates (rows: components, columns: elements of `mode`)
profile_df = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate']).to_pandas()

# calculate correlation matrix (component vs. component)
corr_df = profile_df.T.corr()
# undefined correlations (NaN) are replaced by 0 for both the linkage and the heatmap,
# because the linkage requires finite distances; the heatmap additionally masks those cells
corr_filled = corr_df.fillna(0)
# Precalculate linkage to extract clusters later
link = hierarchy.linkage(distance.pdist(np.asarray(corr_filled)))
links[mode] = link
columns[mode] = profile_df.columns

# cut the dendrogram only after the linkage has been computed in this cell, so the
# cluster labels always correspond to the heatmap below (no dependence on an older `links[mode]`)
clust = hierarchy.fcluster(link, t=t, criterion='maxclust')

# one color per cluster label, used to annotate rows and columns of the heatmap
cmap = plt.get_cmap('tab20').colors
lut = dict(zip(np.unique(clust), cmap))
row_colors = [lut[cl] for cl in clust]

# make clustered heatmap
# using precalculated linkage
g = sns.clustermap(
    corr_filled,
    row_linkage=link, col_linkage=link,
    row_colors=row_colors,
    col_colors=row_colors,
    mask=corr_df.isna(),
    cmap='PuOr_r', vmin=-1, vmax=1,
    cbar_kws={'shrink':0.5, 'label':'Pearson\nCorrelation'},
    xticklabels=True, yticklabels=True
)
g.fig.suptitle(f"Similarity of {mode.capitalize()} Weights Between Components", y=1.02);
# pyplot.show() only takes the keyword argument `block`; the figure is found automatically
plt.show()


for cl in np.unique(clust):
    color = lut[cl]
    bool_cl = clust == cl

    # only summarize clusters that contain more than two components
    if bool_cl.sum() > 2:
        profile_cl = profile_df[bool_cl]
        print(profile_cl.index)

        # mean and standard deviation of each element's weight over the cluster's components
        pcl_mean = profile_cl.mean(axis=0).rename('mean')
        pcl_std = profile_cl.std(axis=0).rename('std')
        pcl_ms = pd.concat([pcl_mean, pcl_std], axis=1).T
        # transpose to one row per element and keep the `max_sam` elements with the highest mean
        profile_cl = pd.concat(
            [profile_cl, pcl_ms], axis=0
        ).T.sort_values(by='mean',ascending=False)[:max_sam]
        profile_cl.index = profile_cl.index.map(str)
        print(profile_cl.index)
        # size the figure from the table that is actually plotted (not a leftover `df` from another cell)
        fig, axis = plt.subplots(figsize=(min(15, max(1, len(profile_cl))), 3))
        sns.barplot(profile_cl, x=mode, y='mean', color=color, legend=False, ax=axis);
        axis.errorbar(x=profile_cl.index, y=profile_cl['mean'], yerr=profile_cl['std'], fmt='none', color=sns.color_palette()[5])
        axis.set(title=f"Cluster {cl} Top {mode.capitalize()} Mean Weights", xlabel=None, ylabel='Weight');
        fig.autofmt_xdate(rotation=90)
        plt.show()


In [None]:
# compare top weights (mode 0) between components

# USER INPUTS -- edit as needed
max_elements = 20    # maximum number of elements visualized in any one component
viz_components = [1, 2, 3, 4, 5]    # list components you want to compare against one another
heuristic = 'max_weight'    # pull out the top weights across any component
# heuristic = 'max_variation'    # pull out the weights that vary the most across components

# pull out the top weights
mode = modes[0]
# median weight of every element in every component, across bootstraps and replicates
median_weights = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate'])
if heuristic == 'max_weight':
    # weights sorted by maximum across all components
    idx = np.abs(median_weights).max(dim='component').argsort()
elif heuristic == 'max_variation':
    # weights sorted by standard deviation across components
    idx = median_weights.std(dim='component').argsort()
else:
    # fail loudly instead of silently reusing an `idx` left over from another cell
    raise ValueError(f"Unknown heuristic {heuristic!r}; expected 'max_weight' or 'max_variation'")
# argsort is ascending: take the last `max_elements` positions and reverse them (largest first)
df = ds.sel({mode: ds[mode][idx[-max_elements:].data[::-1]]})[f"{mode}_weights"].to_series().reset_index()

# plot data, only including listed components
g = sns.FacetGrid(df[df.component.isin(viz_components)], row='component', aspect=5)
g.map(sns.barplot, mode, f"{mode}_weights", order=df[mode].unique(), color=sns.color_palette()[modes.index(mode)+1], errorbar='sd');
g.fig.suptitle(f"Component Weights of Top {max_elements} {mode.capitalize()}s", y=1.02);
g.set_xticklabels(df[mode].unique(), rotation=90);


# display the selected elements, in plotting order
df[mode].unique()

In [None]:
# compare top weights (mode 1) between components

# USER INPUTS -- edit as needed
max_elements = 20    # maximum number of elements visualized in any one component
viz_components = [1, 2, 3, 4, 5]    # list components you want to compare against one another
heuristic = 'max_weight'    # pull out the top weights across any component
# heuristic = 'max_variation'    # pull out the weights that vary the most across components

# pull out the top weights
mode = modes[1]
# median weight of every element in every component, across bootstraps and replicates
median_weights = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate'])
if heuristic == 'max_weight':
    # weights sorted by maximum across all components
    idx = np.abs(median_weights).max(dim='component').argsort()
elif heuristic == 'max_variation':
    # weights sorted by standard deviation across components
    idx = median_weights.std(dim='component').argsort()
else:
    # fail loudly instead of silently reusing an `idx` left over from another cell
    raise ValueError(f"Unknown heuristic {heuristic!r}; expected 'max_weight' or 'max_variation'")
# argsort is ascending: take the last `max_elements` positions and reverse them (largest first)
df = ds.sel({mode: ds[mode][idx[-max_elements:].data[::-1]]})[f"{mode}_weights"].to_series().reset_index()

# plot data, only including listed components
g = sns.FacetGrid(df[df.component.isin(viz_components)], row='component', aspect=5)
g.map(sns.barplot, mode, f"{mode}_weights", order=df[mode].unique(), color=sns.color_palette()[modes.index(mode)+1], errorbar='sd');
g.fig.suptitle(f"Component Weights of Top {max_elements} {mode.capitalize()}s", y=1.02);
g.set_xticklabels(df[mode].unique(), rotation=90);

# display the selected elements, in plotting order
df[mode].unique()

In [None]:
# compare top weights (mode 2 -- sample mode) between components

# USER INPUTS -- edit as needed
max_elements = 20    # maximum number of elements visualized in any one component
viz_components = [1, 2, 3, 4, 5]    # list components you want to compare against one another
heuristic = 'max_weight'    # pull out the top weights across any component
# heuristic = 'max_variation'    # pull out the weights that vary the most across components

# pull out the top weights
mode = modes[2]
# median weight of every element in every component, across bootstraps and replicates
median_weights = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate'])
if heuristic == 'max_weight':
    # weights sorted by maximum across all components
    idx = np.abs(median_weights).max(dim='component').argsort()
elif heuristic == 'max_variation':
    # weights sorted by standard deviation across components
    idx = median_weights.std(dim='component').argsort()
else:
    # fail loudly instead of silently reusing an `idx` left over from another cell
    raise ValueError(f"Unknown heuristic {heuristic!r}; expected 'max_weight' or 'max_variation'")
# argsort is ascending: take the last `max_elements` positions and reverse them (largest first)
df = ds.sel({mode: ds[mode][idx[-max_elements:].data[::-1]]})[f"{mode}_weights"].to_series().reset_index()

# plot data, only including listed components
g = sns.FacetGrid(df[df.component.isin(viz_components)], row='component', aspect=5)
g.map(sns.barplot, mode, f"{mode}_weights", order=df[mode].unique(), color=sns.color_palette()[modes.index(mode)+1], errorbar='sd');
g.fig.suptitle(f"Component Weights of Top {max_elements} {mode.capitalize()}s", y=1.02);
g.set_xticklabels(df[mode].unique(), rotation=90);


In [None]:
# combined figure that looks at top weights across all three modes for a subset of components

# USER INPUTS -- edit as needed
max_elements = 10    # maximum number of elements visualized in any one component
viz_components = [1, 2, 3, 4, 5]    # list components you want to compare against one another
heuristic = 'max_weight'    # pull out the top weights across any component
# heuristic = 'max_variation'    # pull out the weights that vary the most across components

# pull out the top weights
# collect one long-format table per mode and concatenate once at the end
# (avoids repeatedly copying a growing DataFrame and concatenating with an empty frame)
mode_frames = []
for mode in modes[:3]:
    # median weight of every element in every component, across bootstraps and replicates
    median_weights = ds[f"{mode}_weights"].median(dim=['bootstrap', 'replicate'])
    if heuristic == 'max_weight':
        # weights sorted by maximum across all components
        idx = np.abs(median_weights).max(dim='component').argsort()
    elif heuristic == 'max_variation':
        # weights sorted by standard deviation across components
        idx = median_weights.std(dim='component').argsort()
    else:
        # fail loudly instead of silently reusing an `idx` left over from another cell
        raise ValueError(f"Unknown heuristic {heuristic!r}; expected 'max_weight' or 'max_variation'")
    # argsort is ascending: take the last `max_elements` positions and reverse them (largest first)
    df = ds.sel({mode: ds[mode][idx[-max_elements:].data[::-1]]})[f"{mode}_weights"].to_series().reset_index()
    # harmonize column names so the tables of all modes can be stacked
    df = df.rename(columns={mode: 'Label', f"{mode}_weights": 'Weights'})
    df['mode'] = mode
    mode_frames.append(df)
combined_df = pd.concat(mode_frames, axis=0)
# pull out just components of interest
combined_df = combined_df[combined_df.component.isin(viz_components)].reset_index(drop=True)

# plot combined data
# squeeze=False keeps `axes` two-dimensional even when only one component is listed
fig, axes = plt.subplots(
    len(viz_components), 3, figsize=(15, len(viz_components)*2), sharex='col', sharey='col', squeeze=False
)
last_row = len(viz_components) - 1
for i, comp in enumerate(viz_components):
    for j, mode in enumerate(modes[:3]):
        plot_df = combined_df[combined_df.component.eq(comp) & combined_df['mode'].eq(mode)]
        sns.barplot(plot_df, x='Label', y='Weights', errorbar='sd', color=sns.color_palette()[j+1], ax=axes[i][j])
        if not i:
            # column titles on the first row only
            axes[i][j].set(title=mode)
        if not j:
            # component label on the first column only
            axes[i][j].set(ylabel=f"Component {comp}\n\nWeights")
        if i == last_row:
            # tick labels and axis label on the bottom row only (x axes are shared per column)
            axes[i][j].tick_params(axis='x', labelrotation=90);
            axes[i][j].set(xlabel=mode);
