HEATMAP FOR THE GEOCHEMICAL COVARIATES: NO-DIRECT-COUPLING CASE

In [14]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pylab as pltx
import seaborn as sns
import pystan
import pickle
import scipy.cluster.hierarchy as sch
import copy
import os

# 🔥 Fix: Add utils/ directory where vb_stan.py and sub_fun.py are located
sys.path.append("../utils")  # Ensure utils is in module search path

import vb_stan as vbfun
import sub_fun as sf

# Load project data from ./data_file.py (later cells use names it presumably
# defines, e.g. `Y` — confirm). Prefer a normal import; if the module is not
# on sys.path, execute the file directly in this namespace instead.
try:
    from data_file import *
except ModuleNotFoundError as import_err:
    print(f"❌ Could not import data_file as module ({import_err}), trying another way...")
    try:
        # `with` closes the file handle (the previous open().read() leaked it).
        with open('./data_file.py', encoding='utf-8') as data_fh:
            exec(data_fh.read())  # Execute the script
    except Exception as exec_err:
        # Best-effort: report the cause but keep the notebook running.
        print(f"❌ Could not import data_file.py: {exec_err!r}")


import os

# 1. Read config_mode.txt: line 1 selects the dataset ("original" or "new"),
#    line 2 the setting (1 or 2). Missing file/lines fall back to the defaults
#    below, so `data` and `setting` are always defined.
config_file = "config_mode.txt"
data = "original"
setting = 1
if os.path.exists(config_file):
    with open(config_file, "r") as f:
        lines = f.read().splitlines()
    if len(lines) > 0 and lines[0].strip():
        data = lines[0].strip()
    if len(lines) > 1 and lines[1].strip():
        setting = int(lines[1])


# 2. Set plots folder based on config
_PLOT_FOLDERS = {
    ("original", 1): "./plots1/",
    ("original", 2): "./plots2/",
    ("new", 2): "./plots3/",
    ("new", 1): "./plots4/",
}
try:
    figfol = _PLOT_FOLDERS[(data, setting)]
except KeyError:
    # Fail here with a clear message instead of a NameError on `figfol` below.
    raise ValueError(
        f"Unsupported configuration in {config_file}: data={data!r}, "
        f"setting={setting!r}; expected one of {sorted(_PLOT_FOLDERS)}"
    ) from None

if not os.path.exists(figfol):
    os.makedirs(figfol)
    print(f"✅ Created folder: {figfol}")
else:
    print(f"📂 Folder already exists: {figfol}")

# Input data sub-folder under ../data/ for the selected dataset.
data_path = "data_op" if data == "original" else "data_new"
    
# Start from matplotlib's defaults, then apply the notebook style:
# 12-pt legend/label/title/tick text, bold fonts and 200-dpi figures.
plt.rcParams.update(plt.rcParamsDefault)
_text_size_keys = (
    'legend.fontsize',
    'axes.labelsize',
    'axes.titlesize',
    'xtick.labelsize',
    'ytick.labelsize',
)
params = dict.fromkeys(_text_size_keys, 12)
params['font.weight'] = 'bold'
params['axes.labelweight'] = 'bold'
params['figure.dpi'] = 200
plt.rcParams.update(params)

%matplotlib inline


📂 Folder already exists: ./plots3/


In [15]:
# Call the output from the fitted model:
import glob
import pickle
import os


# Location of the fitted-model pickles; keep exactly one `fname_best` active
# ("ndc"/"dc" presumably = no direct coupling / direct coupling).
folname = '../src/'
#fname_best = '18_52_model_nb_cvtest.pkl'  # original data, ndc
fname_best = '30_61_model_nb_cvtest.pkl'   # new data, ndc
#fname_best = '30_68_model_nb_cvtest.pkl'  # original data, dc
#fname_best = '30_17_model_nb_cvtest.pkl'  # new data, dc


def _read_pickle(path):
    """Unpickle and return the object stored at *path*.

    Only for trusted, locally produced files: unpickling can execute code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Cross-validation output of the selected run. An alternative pickle layout
# was also loaded here at some point:
#   [holdout_mask, _1, _2, _3, l, m_seed, sp_mean, sp_var, h_prop, uid, mtype,
#    Yte_fit, cv_test, Y, muest, Yte_cv, _4, _5]
(holdout_mask, llpd, n_test, l, m_seed, sp_mean, sp_var, h_prop, uid,
 nsample_o, Yte_fit, cv_test) = _read_pickle(os.path.join(folname, fname_best))

# Companion files of the same run are named with the "<uid>_<m_seed>" prefix.
fname_ot = os.path.join(folname, f"{uid}_{m_seed}_model_nb.pkl")
sample_fname = os.path.join(folname, f"{uid}_{m_seed}_sample_model_nb_cvtest.pkl")

results = _read_pickle(fname_ot)
Yte_sample, Yte_cv = _read_pickle(sample_fname)

# Mean estimates keyed by parameter name (e.g. 'L_sp', 'A_geo'), extracted by
# the project helper in vb_stan.
parma_mean = dict(vbfun.vb_extract_mean(results))

  params = OrderedDict([(name, np.nan * np.empty(shape)) for name, shape in param_shapes.items()])


In [16]:
import pandas as pd
import numpy as np



# -------------------------------
# Build `tax_name`: one row per OTU with its most specific taxonomic name
# ('Name'), total abundance ('Abundance') and ecological classification
# ('ECR' for the original data, 'erc' for the new data).
# -------------------------------


def _most_specific_taxon(taxonomy):
    """Return one label per row of *taxonomy*, built from its last non-empty cell.

    Columns are scanned right to left; the label is the lower-cased first
    letter of the matching column name + '_' + the cell text. Empty cells must
    already be ''. Raises ValueError for a row with no non-empty cell, since
    skipping it would shift every later name onto the wrong OTU.
    """
    columns = list(taxonomy.columns)[::-1]
    names = []
    for position, row in enumerate(taxonomy.to_numpy(dtype=object)):
        for column, value in zip(columns, row[::-1]):
            text = str(value)
            if text:
                names.append(str(column)[0].lower() + '_' + text)
                break
        else:
            raise ValueError(f"Taxonomy row {position} has no non-empty field")
    return np.array(names)


if data == "original":
    # Load taxonomy file
    tax_name = pd.read_csv(f'../data/{data_path}/species_tax.csv')
    tax_name = tax_name.rename(columns={'Unnamed: 0': 'OTU'})
    # Drop header row if repeated; copy so the insert below works on an
    # independent frame rather than a slice of the original.
    tax_name = tax_name[1:].copy()
    tax_name.insert(0, 'Id', tax_name['OTU'].str[3:])  # OTU code without its 3-char prefix
    # Rename via `rename` instead of assigning into `columns.values`, which
    # mutates the Index in place behind pandas' lookup tables.
    tax_name = tax_name.rename(columns={'OTU': 'Label'})
    tax_name.to_csv('node_otu.csv', index=False)
    tax_name[['Id']] = tax_name[['Id']].astype(np.int64)
    tax_name = tax_name.replace(pd.NA, 'Empty').replace(np.nan, 'Empty')

    # Most specific taxonomy level (first 8 columns) as Name
    tax_name['Name'] = _most_specific_taxon(tax_name.iloc[:, :8].replace('Empty', ''))

    # Total abundance per OTU. NOTE(review): assumes the columns of `Y` (from
    # data_file) follow the row order of `tax_name` — confirm.
    tax_name['Abundance'] = Y.sum(axis=0)
    tem = pd.read_csv(f'../data/{data_path}/species_tax_anot.amended.csv').iloc[:, [1, 12]]
    n_before_merge = len(tax_name)
    tax_name = tax_name.merge(tem, on='Label')  # inner join: unmatched OTUs are dropped
    if len(tax_name) != n_before_merge:
        print(f"⚠️ Annotation merge changed the number of OTUs: {n_before_merge} -> {len(tax_name)}")
    tax_name = tax_name.rename(columns={"Ecologically_relevant_classification_aggregated": "ECR"})

    # Group rare ECR categories (10 or fewer OTUs) into 'Other'
    ind_var = tax_name['ECR'].values
    vals, counts = np.unique(ind_var, return_counts=True)
    rare_vals = vals[counts <= 10]
    tax_name['ECR'] = np.where(np.isin(ind_var, rare_vals), 'Other', ind_var)
    print("✅ Taxonomic annotation completed.")
    print("Top ECR categories:\n", tax_name['ECR'].value_counts().head())
elif data == "new":
    # Load species taxonomy; the unnamed first column holds the OTU label
    tax_name = pd.read_csv(f'../data/{data_path}/species_tax.csv')
    tax_name = tax_name.rename(columns={'Unnamed: 0': 'Label'})
    tax_name['Label'] = tax_name['Label'].str.strip()  # labels must match the annotation file exactly
    tax_name = tax_name.replace(pd.NA, 'Empty').replace(np.nan, 'Empty')
    tax_name.to_csv('node_otu.csv', index=False)

    # Most specific taxonomy level (first 8 columns) as Name
    tax_name['Name'] = _most_specific_taxon(tax_name.iloc[:, :8].replace('Empty', ''))

    # Abundance is filled only for the first Y.shape[1] rows (presumably the
    # OTUs present in `Y`); remaining rows get NaN, turned into 'Empty' below.
    tax_name.loc[:Y.shape[1] - 1, 'Abundance'] = Y.sum(axis=0)
    tax_name['Id'] = tax_name['Label'].str.extract(r'(\d+)', expand=False).astype(int)
    tax_name = tax_name.replace(np.nan, 'Empty')

    tem = pd.read_csv(f'../data/{data_path}/otu_new_annotation(1).csv')[['name', 'erc']]
    n_before_merge = len(tax_name)
    tax_name = tax_name.merge(tem, left_on='Label', right_on='name')  # inner join
    if len(tax_name) != n_before_merge:
        print(f"⚠️ Annotation merge changed the number of OTUs: {n_before_merge} -> {len(tax_name)}")
    # ✅ Show summary
    print("✅ Taxonomic annotation completed.")
    print("Top ECR categories:\n", tax_name['erc'].value_counts().head())
else:
    raise ValueError(f"Unknown data mode {data!r}; expected 'original' or 'new'")



✅ Taxonomic annotation completed.
Top ECR categories:
 Flavobacteriales                 108
SAR11 clade                      101
Marinimicrobia (SAR406 clade)     82
Gamma-other                       75
Chloroplast (unclassified)        63
Name: erc, dtype: int64


In [17]:
# Select every ecological class present (removing 'Other' is an option that
# was left disabled) and build a boolean row mask over `tax_name` for them.
if data != "original":
    ecr_series = tax_name['erc']
    selected_species = list(np.unique(ecr_series.astype(str)))
else:
    ecr_series = tax_name['ECR']
    selected_species = list(np.unique(ecr_series))
selected_species_index = ecr_series.isin(selected_species).values
    

In [18]:
# Display the parameter names whose mean estimates were extracted by
# vbfun.vb_extract_mean (the 'L_sp' and 'A_geo' entries are used next).
parma_mean.keys()

dict_keys(['C0', 'A_geo', 'L_sp', 'L_i', 'A_s', 'A_b', 'A_m', 'A_d', 'tau', 'phi'])

In [19]:
# Species x covariate matrix from the fitted mean loadings: rows are paired
# with taxonomy rows below, columns correspond to the rows of A_geo (the
# geochemical covariates).
cov_mat = np.matmul(parma_mean['L_sp'], parma_mean['A_geo'].T)

if data == "new":
    # The taxonomy table can hold more rows than the model has species
    # (Abundance was only filled for the first Y.shape[1] rows); keep the
    # first cov_mat.shape[0] rows so both tables line up by position.
    tax_name = tax_name.iloc[:cov_mat.shape[0]]

# Rows of cov_mat and tax_name are paired purely by position from here on;
# fail clearly now instead of with an opaque length error later.
if len(tax_name) != cov_mat.shape[0]:
    raise ValueError(
        f"tax_name has {len(tax_name)} rows but cov_mat has {cov_mat.shape[0]} rows; "
        "cannot attach taxonomy labels to the covariate matrix."
    )

if data == "original":
    selected_species_index = tax_name['ECR'].isin(selected_species).values
else:
    selected_species_index = tax_name['erc'].isin(selected_species).values


In [21]:
# Tabulate cov_mat (one row per species) and attach each species' ecological
# class as an 'ecr' column; the bare name at the end displays the table.
_ecr_source = "ECR" if data == "original" else "erc"
datframe = pd.DataFrame(cov_mat)
datframe["ecr"] = tax_name[_ecr_source].values
datframe

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,ecr
0,-3.709919,0.441534,-1.843652,0.259526,-0.621473,2.198524,-0.231371,1.178850,-0.718751,0.170559,Verrucomicrobiota
1,5.637025,0.403506,1.547831,-0.351452,-0.672991,-1.478874,0.926042,0.287345,-0.597356,-0.083949,Flavobacteriales
2,-3.804167,0.114468,-1.691990,-0.476263,-0.611059,3.413498,-0.161296,1.558965,-1.344586,0.441754,Gamma-other
3,-1.655438,-0.425480,-3.475665,0.418488,2.320011,-2.309842,0.005184,0.129052,0.853060,-0.305093,Chloroplast (unclassified)
4,-2.234914,-0.501028,-4.837134,0.271498,2.612486,-2.804389,0.072202,0.830153,1.609132,-0.544629,Puniceispirillales (SAR116 clade)
...,...,...,...,...,...,...,...,...,...,...,...
1071,-2.687384,0.950909,-3.347218,1.188903,-0.759426,-1.041389,0.304689,1.258727,0.124697,-0.089685,SAR86 clade
1072,-4.977252,-0.834797,1.746771,-0.351799,1.519845,-2.192408,-2.168740,-3.417560,2.630264,-1.261861,Oceanospirillales
1073,8.617831,-0.725060,0.191279,0.673926,-1.387640,-2.079884,1.336120,0.715977,0.882692,-0.980807,Nitrosopumilus
1074,0.971710,-0.015931,2.739015,0.394707,1.648411,0.007113,-1.299660,-1.916586,1.406760,-0.959849,Chloroplast (unclassified)


In [23]:
# Display the distinct ecological classes attached to the rows of `datframe`.
datframe["ecr"].unique()

array(['Verrucomicrobiota', 'Flavobacteriales', 'Gamma-other',
       'Chloroplast (unclassified)', 'Puniceispirillales (SAR116 clade)',
       'Poseidoniales (MGII Archaea)', 'Marinimicrobia (SAR406 clade)',
       'Cytophagales', 'Alpha-other', 'Marine Actinobacteria',
       'CPR bacteria', 'SAR86 clade', 'Bacteria-other',
       'Rhodospirillales', 'Bacteroidota-other', 'SAR11 clade', 'Unkown',
       'Oceanospirillales', 'Sphingomonadales', 'Pseudomonadales',
       'Nitrosopumilus', 'Archaea-other', 'Alteromonadales',
       'SAR324 clade', 'Ectothiorhodospiraceae (Gammaproteobacteria)',
       'Synechococcus', 'Roseobacter clade', 'SAR202 clade',
       'MGIII Archaea', 'Prochlorococcus', 'Planctomycetota',
       'Bdellovibrionota', 'Cyanobacteria-other', 'Myxococcaceae',
       'Nitrosomonadales', 'Desulfovibrio', 'Nitrospinota',
       'SAR92 clade (Gammaproteobacteria)', 'Nitrosococcales',
       'Burkholderiales', 'Nitrospirales'], dtype=object)

In [None]:
# DISABLED: optional filter restricting `datframe` to a hand-picked list of
# ecological classes. The code is wrapped in a string literal, so it never
# runs; remove the surrounding triple quotes to enable it.
# NOTE(review): 'Other' is only created for data == "original" (rare classes
# are grouped into it there), and for the "new" dataset some classes (e.g.
# 'CPR bacteria', 'Unkown') are not in this list and would be dropped —
# confirm the list before enabling.
"""
target_ecrs = [
    'Prochlorococcus', 'SAR86 clade', 'Rhodospirillales', 'SAR202 clade',
    'SAR324 clade', 'Other', 'Poseidoniales (MGII Archaea)',
    'Marinimicrobia (SAR406 clade)', 'Flavobacteriales', 'SAR11 clade',
    'Puniceispirillales (SAR116 clade)', 'Cytophagales', 'Synechococcus',
    'Roseobacter clade', 'Marine Actinobacteria', 'Oceanospirillales',
    'Bacteria-other', 'Verrucomicrobiota', 'Alpha-other', 'Nitrosopumilus',
    'Gamma-other', 'Chloroplast (unclassified)', 'Sphingomonadales',
    'Alteromonadales', 'Pseudomonadales'
]

datframe= datframe[datframe["ecr"].isin(target_ecrs)]
"""


In [25]:
# Give the positional covariate columns of `datframe` readable names (the last
# column stays 'ecr'). NOTE(review): assumes these names follow the row order
# of A_geo used to build cov_mat — confirm against the data preparation.
_COVARIATE_NAMES = {
    "new": ["Temperature", "Salinity", "Oxygen", "NO2", "PO4", "NO2NO3", "Si",
            "SST", "ChlorophyllA", "Carbon.total"],
    "original": ["Temperature", "Salinity", "Oxygen", "Nitrates", "NO2", "PO4",
                 "NO2NO3", "Si", "SST"],
}
if data in _COVARIATE_NAMES:
    new_column_names = _COVARIATE_NAMES[data] + ["ecr"]
    # Explain a mismatch instead of surfacing pandas' generic length error.
    if len(new_column_names) != datframe.shape[1]:
        raise ValueError(
            f"Expected {len(new_column_names)} columns for data={data!r} "
            f"but datframe has {datframe.shape[1]}; check the covariate name list."
        )
    datframe.columns = new_column_names

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of geochemical covariates (rows) x ecological classes (columns),
# averaging the per-species values of `datframe` within each class.

# Average species sharing a class, then transpose so covariates are rows.
# Grouping on the 'ecr' column replaces `groupby(..., axis=1)`, which is
# deprecated since pandas 2.1; result columns are the sorted unique classes.
heatmap_data = datframe.groupby('ecr').mean().T.rename_axis(columns=None)

# Plot the heatmap
plt.figure(figsize=(16, 8))
sns.heatmap(heatmap_data, cmap='bwr', center=0)

plt.xlabel("")
plt.ylabel("Geochemical Covariates")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(figfol, "heatmap_geochemical_taxa_nocoupling.png"), dpi=300)
# Show before closing: closing first left nothing for plt.show() to display.
plt.show()
plt.close()