In [1]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pylab as pltx
import seaborn as sns
import pystan
import pickle
import scipy.cluster.hierarchy as sch
import copy
import os

# 🔥 Fix: Add utils/ directory where vb_stan.py and sub_fun.py are located
sys.path.append("../utils")  # Ensure utils is in module search path

import vb_stan as vbfun
import sub_fun as sf

# import ./data_file.py file from current directory
try:
    from data_file import *
except ModuleNotFoundError as e:
    print(f"❌ Could not import data_file as module, trying another way...")
    try:
        exec(open('./data_file.py').read())  # Execute the script
    except Exception as e:
        print(f"❌ Could not import data_file.py")


import os

# 1. Read config_mode.txt
config_file = "config_mode.txt"
if os.path.exists(config_file):
    with open(config_file, "r") as f:
        lines = f.read().splitlines()
        data = lines[0].strip() if len(lines) > 0 else "original"
        setting = int(lines[1]) if len(lines) > 1 else 1


# 2. Set plots folder based on config
if data == "original" and setting == 1: 
    figfol = "./plots1/"
elif data == "original" and setting == 2:
    figfol = "./plots2/"
elif data == "new" and setting == 2: 
    figfol = "./plots3/"
elif data == "new" and setting == 1: 
    figfol = "./plots4/"

if not os.path.exists(figfol):
    os.makedirs(figfol)
    print(f"✅ Created folder: {figfol}")
else:
    print(f"📂 Folder already exists: {figfol}")

if data == "original": 
    data_path = "data_op"
else: 
    data_path = "data_new"
    
# Update plot settings
plt.rcParams.update(plt.rcParamsDefault)
params = {
    'legend.fontsize': 12,
    'font.weight': 'bold',
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'axes.labelweight': 'bold',
    'xtick.labelsize': 12,
    'figure.dpi': 200,
    'ytick.labelsize': 12
}
plt.rcParams.update(params)

%matplotlib inline


Satellite data loaded: (180, 10)




📂 Folder already exists: ./plots4/


In [2]:
# Call the output from the fitted model:
import glob
import pickle
import os


# Define folder path
folname = '../src/'
#fname_best = '18_52_model_nb_cvtest.pkl' #original data ndc
#fname_best = '30_61_model_nb_cvtest.pkl' #new data nc 
#fname_best = '30_68_model_nb_cvtest.pkl' #original data dc
fname_best = '30_17_model_nb_cvtest.pkl'
# Load fitted model output
#with open(os.path.join(folname, fname_best), 'rb') as f:
#    [holdout_mask, llpd, n_test, l, m_seed, sp_mean, sp_var, h_prop, uid, nsample_o, Yte_fit, cv_test] = pickle.load(f)

with open(os.path.join(folname, fname_best), 'rb') as f:
    [holdout_mask, _1, _2, _3, l,m_seed,sp_mean,\
                 sp_var, h_prop, uid, mtype,\
                 Yte_fit, cv_test, Y, muest, Yte_cv, _4, _5] = pickle.load(f)

# Construct filenames using uid and m_seed
fname_ot = os.path.join(folname, f"{uid}_{m_seed}_model_nb.pkl")
#sample_fname = os.path.join(folname, f"{uid}_{m_seed}_sample_model_nb_cvtest.pkl")


# Load results
with open(fname_ot, 'rb') as f:
    results = pickle.load(f)

#Load sample model output
#with open(sample_fname, 'rb') as f:
#   [Yte_sample, Yte_cv] = pickle.load(f)


parma_mean  = dict(vbfun.vb_extract_mean(results))

In [3]:
# Load species taxonomy
tax_name = pd.read_csv(f'../data/{data_path}/species_tax.csv')

# Clean up
tax_name = tax_name.rename(columns={'Unnamed: 0': 'Label'})
tax_name['Label'] = tax_name['Label'].str.strip()  # 🔑 Remove any whitespace
tax_name = tax_name.replace(pd.NA, 'Empty').replace(np.nan, 'Empty')
tax_name.to_csv('node_otu.csv', index = False) 
species_name = []
# Extract most specific taxonomy
temx = tax_name.iloc[:, :8].replace('Empty', '')

for i in range(temx.shape[0]):
    a = temx.iloc[i,:].values
    for j in range(a.shape[0]-1,-1,-1):
        if len(a[j]) > 0:
            species_name.append(temx.columns[j][0].lower()+'_'+ a[j])
            break;           
species_name = np.array(species_name)  
tax_name['Name'] = species_name

tax_name.loc[:Y.shape[1]-1, 'Abundance'] = Y.sum(axis=0)

tax_name['Id'] = tax_name['Label'].str.extract(r'(\d+)').astype(int)

tax_name = tax_name.replace(np.nan,'Empty')
#tem = pd.read_csv(f'../data/{data_path}/species_tax_anot.amended.csv')[['name', 'erc']]
tem = pd.read_csv(f'../data/{data_path}/otu_new_annotation(1).csv')[['name', 'erc']]
tax_name = tax_name.merge(tem, left_on='Label', right_on='name')


In [4]:
selected_species = list(np.unique(tax_name['erc'].astype(str)))


In [5]:
#selected_species.remove('Other')
selected_species_index = tax_name['erc'].isin(selected_species).values

In [7]:
tax_name = tax_name.iloc[:cov_mat.shape[0]]

In [8]:
## Species- species interactiion matrix estimate 
cov_mat = np.matmul(parma_mean['L_sp'],parma_mean['L_i'].T)
cov_mat  = cov_mat  #cov_mat.max() - cov_mat
cov_mat = (cov_mat + cov_mat.T)/2
np.fill_diagonal(cov_mat,0)

selected_species_index = tax_name['erc'].isin(selected_species).values
dist_pos = copy.copy(cov_mat[selected_species_index][:,selected_species_index])
dist_neg = copy.copy(cov_mat[selected_species_index][:,selected_species_index])
for i in range(dist_pos.shape[0]):
    tem = dist_pos[i].argsort()
    dist_pos[i,tem[:-5]] = 0. 
    dist_neg[i,tem[5:]] = 0.
    
dist_neg = dist_neg*(-1.0)
    

In [10]:
tax_name

Unnamed: 0,Label,Domain,Phylum,Class,Order,Family,Genus,OTU.rep,Name,Abundance,Id,name,erc
0,OTU1,Bacteria,Verrucomicrobia,Opitutae,Puniceicoccales,Puniceicoccaceae,Empty,AACY020016122.205.1491,o_AACY020016122.205.1491,2995.0,1,OTU1,Verrucomicrobiota
1,OTU5,Bacteria,Bacteroidetes,Flavobacteriia,Flavobacteriales,Flavobacteriaceae,Formosa,AACY020065672.95.1597,o_AACY020065672.95.1597,4992.0,5,OTU5,Flavobacteriales
2,OTU6,Bacteria,Proteobacteria,Gammaproteobacteria,Xanthomonadales,JTB255 marine benthic group,Empty,AACY020075636.150.1670,o_AACY020075636.150.1670,3575.0,6,OTU6,Gamma-other
3,OTU7,Bacteria,Cyanobacteria,Chloroplast,Empty,Empty,Empty,AACY020080403.492.1950,o_AACY020080403.492.1950,1603.0,7,OTU7,Chloroplast (unclassified)
4,OTU10,Bacteria,Proteobacteria,Alphaproteobacteria,Rickettsiales,SAR116 clade,Empty,AACY020119224.701.2163,o_AACY020119224.701.2163,2477.0,10,OTU10,Puniceispirillales (SAR116 clade)
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1071,OTU23826,Bacteria,Proteobacteria,Gammaproteobacteria,Oceanospirillales,SAR86 clade,Empty,LWDU01044435.2029.3537,o_LWDU01044435.2029.3537,4456.0,23826,OTU23826,SAR86 clade
1072,OTU23830,Bacteria,Proteobacteria,Gammaproteobacteria,Oceanospirillales,Oceanospirillaceae,Oleibacter,LWFP01001472.3412.4949,o_LWFP01001472.3412.4949,25176.0,23830,OTU23830,Oceanospirillales
1073,OTU23857,Archaea,Thaumarchaeota,Marine Group I,Unknown Order,Unknown Family,Candidatus Nitrosopelagicus,U40238.31654.33126,o_U40238.31654.33126,13702.0,23857,OTU23857,Nitrosopumilus
1074,OTU23882,Bacteria,Cyanobacteria,Chloroplast,Empty,Empty,Empty,U70724.1.1482,o_U70724.1.1482,1657.0,23882,OTU23882,Chloroplast (unclassified)


In [12]:
datframe = pd.DataFrame(parma_mean["C_geo"])
datframe["ecr"] = tax_name["erc"].values
datframe 

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,ecr
0,-0.021399,-0.002541,-0.041332,0.004537,0.046865,0.071289,0.038020,-0.000926,-0.014824,0.007989,Verrucomicrobiota
1,0.143674,0.021586,-0.012540,-0.022094,-0.105066,-0.090586,-0.028178,0.026776,0.000107,-0.010851,Flavobacteriales
2,-0.031800,0.000755,-0.018356,-0.003077,0.074416,0.178052,0.030701,-0.000764,-0.025322,-0.001424,Gamma-other
3,-0.002887,-0.013721,-0.068188,0.005015,0.059163,0.030677,0.025484,-0.000603,0.000924,0.005064,Chloroplast (unclassified)
4,-0.002652,-0.021925,-0.483265,0.001955,0.066333,0.005595,0.013430,0.024268,0.031513,-0.000140,Puniceispirillales (SAR116 clade)
...,...,...,...,...,...,...,...,...,...,...,...
1071,-0.001262,0.010871,-0.027956,0.052267,0.006804,0.020123,-0.001011,0.004420,-0.004821,-0.004525,SAR86 clade
1072,-0.032987,-0.074192,0.058638,-0.007561,-0.017639,-0.011653,-0.012687,-0.099353,0.056484,0.005299,Oceanospirillales
1073,0.163927,0.001070,0.015422,-0.007938,-0.046663,-0.085547,-0.055185,0.026022,0.002132,-0.012494,Nitrosopumilus
1074,-0.014978,-0.008897,0.049412,0.003404,-0.000996,0.004806,-0.012467,-0.028221,0.019208,-0.006225,Chloroplast (unclassified)


In [13]:
new_column_names = ["Temperature","Salinity","Oxygen","NO2","PO4","NO2NO3","Si","SST","ChlorophyllA","Carbon.total","ecr"]  # your custom names
datframe.columns = new_column_names

In [14]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Assuming your dataframe is named `datframe`

# Set the index to 'ecr' and transpose the matrix
heatmap_data = datframe.drop(columns='ecr').T  # shape will be (features x taxa)
heatmap_data.columns = datframe['ecr'].values  # set column names as taxa

# Optional: group duplicate taxa by mean (in case there are multiple rows per taxon)
heatmap_data = heatmap_data.groupby(by=heatmap_data.columns, axis=1).mean()

# Plot the heatmap
plt.figure(figsize=(16, 8))
sns.heatmap(heatmap_data, cmap='bwr', center=0)

plt.xlabel("")
plt.ylabel("Geochemical Covariates")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(figfol +"heatmap_geochemical_taxa.png", dpi=300)
plt.close()
plt.show()
