In [2]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pylab as pltx
import seaborn as sns
import pystan
import pickle
import scipy.cluster.hierarchy as sch
import copy
import os

# 🔥 Fix: Add utils/ directory where vb_stan.py and sub_fun.py are located
sys.path.append("../utils")  # Ensure utils is in module search path

import vb_stan as vbfun
import sub_fun as sf

# Bring the names defined in ./data_file.py (current directory) into this
# notebook's namespace. A regular star-import is tried first; if Python cannot
# resolve the module, the file is executed directly as a fallback.
try:
    from data_file import *
except ModuleNotFoundError as import_err:
    # Show the reason: this is also raised when data_file itself imports a
    # module that is not installed.
    print(f"❌ Could not import data_file as module ({import_err}), trying another way...")
    try:
        # Context manager so the file handle is closed even if exec() fails.
        with open('./data_file.py') as data_file_src:
            exec(data_file_src.read())  # Execute the script
    except Exception as exec_err:
        # Best-effort: keep the notebook running, but report why loading failed
        # instead of hiding the error.
        print(f"❌ Could not import data_file.py: {exec_err!r}")


import os

# 1. Read config_mode.txt
# Line 1 holds the data mode (string), line 2 holds the setting (integer).
config_file = "config_mode.txt"

# Defaults are assigned up front so `data` and `setting` always exist.
# Previously a missing config file left both undefined and the folder
# selection that follows failed with a NameError.
data = "original"
setting = 1

if os.path.exists(config_file):
    with open(config_file, "r") as f:
        lines = f.read().splitlines()
    # A missing or blank line keeps the default for that field.
    if len(lines) > 0 and lines[0].strip():
        data = lines[0].strip()
    if len(lines) > 1 and lines[1].strip():
        # A non-numeric setting still raises ValueError, as before.
        setting = int(lines[1])
else:
    print(f"{config_file} not found; using defaults data={data!r}, setting={setting}")


# 2. Set plots folder based on config
# Each supported (data, setting) combination writes to its own folder.
plot_folders = {
    ("original", 1): "./plots1/",
    ("original", 2): "./plots2/",
    ("new", 2): "./plots3/",
}
if (data, setting) not in plot_folders:
    # Previously an unsupported combination left `figfol` undefined and the
    # notebook failed below with an uninformative NameError.
    raise ValueError(
        f"Unsupported configuration data={data!r}, setting={setting!r}; "
        f"expected one of {sorted(plot_folders)}"
    )
figfol = plot_folders[(data, setting)]

if not os.path.exists(figfol):
    # exist_ok guards against the folder appearing between the check and the call.
    os.makedirs(figfol, exist_ok=True)
    print(f"✅ Created folder: {figfol}")
else:
    print(f"📂 Folder already exists: {figfol}")

# Data sub-folder under ../data/: "data_op" for the original data mode,
# "data_new" for every other mode.
data_path = "data_op" if data == "original" else "data_new"
    
# Update plot settings
# Reset matplotlib to its defaults first, then layer the notebook-wide
# overrides on top.
plt.rcParams.update(plt.rcParamsDefault)

# Resolution and bold text
params = {
    'figure.dpi': 200,
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
}
# One shared font size for legend, axis labels, titles and tick labels
params.update({
    size_key: 12
    for size_key in (
        'legend.fontsize',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
    )
})
plt.rcParams.update(params)

%matplotlib inline




📂 Folder already exists: ./plots1/


In [None]:
# Call the output from the fitted model:
import glob
import pickle
import os


# Folder holding the pickled model outputs
folname = '../src/'
#fname_best = '18_52_model_nb_cvtest.pkl' #original data ndc
#fname_best = '30_68_model_nb_cvtest.pkl' #new data nc 
# NOTE(review): the active file name is the same as the one labelled
# "new data nc" above while its own label says "original data dc" — confirm
# which run this pickle belongs to.
fname_best = '30_68_model_nb_cvtest.pkl' #original data dc 

# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project.

# Cross-validation summary of the selected fit; `uid` and `m_seed` from this
# file determine the names of the remaining pickles.
with open(os.path.join(folname, fname_best), 'rb') as f:
    (
        holdout_mask, llpd, n_test, l, m_seed, sp_mean, sp_var, h_prop,
        uid, nsample_o, Yte_fit, cv_test,
    ) = pickle.load(f)

# File names derived from uid and m_seed
fname_ot = os.path.join(folname, f"{uid}_{m_seed}_model_nb.pkl")
sample_fname = os.path.join(folname, f"{uid}_{m_seed}_sample_model_nb_cvtest.pkl")

# Fitted model results
with open(fname_ot, 'rb') as f:
    results = pickle.load(f)

# Sampled model output
with open(sample_fname, 'rb') as f:
    (Yte_sample, Yte_cv) = pickle.load(f)

# Parameter means extracted from the fit, keyed by parameter name
parma_mean = dict(vbfun.vb_extract_mean(results))

  params = OrderedDict([(name, np.nan * np.empty(shape)) for name, shape in param_shapes.items()])


In [4]:
import pandas as pd
import numpy as np

# Load taxonomy file
tax_name = pd.read_csv(f'../data/{data_path}/species_tax.csv')
tax_name = tax_name.rename(columns={'Unnamed: 0': 'OTU'})
# Unconditionally drops the first data row (original note: "Drop header row if
# repeated"). NOTE(review): confirm the file really repeats its header there.
# .copy() makes the result an independent frame before it is modified below.
tax_name = tax_name[1:].copy()
# Id = OTU label with its first three characters removed
tax_name.insert(0, 'Id', tax_name['OTU'].str[3:])
# Rename the second column (by position) to 'Label'. The column index is
# replaced as a whole instead of writing into `tax_name.columns.values`:
# mutating an Index buffer in place fails when that array is read-only
# (pandas copy-on-write) and can leave the column lookup cache stale.
tax_name.columns = [tax_name.columns[0], 'Label'] + list(tax_name.columns[2:])
tax_name.to_csv('node_otu.csv', index=False)
tax_name['Id'] = tax_name['Id'].astype(np.int64)
# Missing taxonomy entries become the placeholder string 'Empty'
tax_name = tax_name.replace(pd.NA, 'Empty').replace(np.nan, 'Empty')

# Name each row after its right-most non-empty entry among the first eight
# columns: "<first letter of that column's name, lower-cased>_<entry>".
temx = tax_name.iloc[:, :8].replace('Empty', '')
level_names = temx.columns
species_name = []
for row_values in temx.to_numpy():
    # Scan from the last column towards the first and stop at the first hit
    for level, entry in zip(level_names[::-1], row_values[::-1]):
        if len(entry) > 0:
            species_name.append(level[0].lower() + '_' + entry)
            break
tax_name['Name'] = np.array(species_name)

# Per-column total of Y. `Y` is not defined in this notebook; presumably it
# comes from data_file — verify.
tax_name['Abundance'] = Y.sum(axis=0)

# -------------------------------
# Use different logic based on `data` flag
# -------------------------------
# A category is kept only if it has more than this many members; all other
# categories are pooled into 'Other'.
min_ecr_members = 10

if data == "original":
    # ECR comes from the amended annotation table. Only the columns at
    # positions 1 and 12 are used; the merge and rename below rely on them
    # being 'Label' and 'Ecologically_relevant_classification_aggregated'.
    try:
        tem = pd.read_csv(f'../data/{data_path}/species_tax_anot.amended.csv').iloc[:, [1, 12]]
    except FileNotFoundError as err:
        # Only the file read is guarded, and the original error stays chained.
        raise FileNotFoundError("species_tax_anot.amended.csv not found. Cannot assign ECR.") from err
    # Inner join: rows whose Label is absent from the annotation are dropped.
    tax_name = tax_name.merge(tem, on='Label')
    tax_name = tax_name.rename(columns={"Ecologically_relevant_classification_aggregated": "ECR"})

elif data == "new":
    # Use Order column for ECR
    if 'Order' not in tax_name.columns:
        raise ValueError("Column 'Order' not found in tax_name.")
    tax_name['ECR'] = tax_name['Order']

else:
    raise ValueError(f"Unknown data mode: {data}")

# Group rare ECR categories into 'Other'. Both data modes share this single
# rule; it was previously written twice in two different forms.
ecr_values = tax_name['ECR'].values
vals, counts = np.unique(ecr_values, return_counts=True)
common_vals = vals[counts > min_ecr_members]
tax_name['ECR'] = np.where(np.isin(ecr_values, common_vals), ecr_values, 'Other')

# ✅ Summary: confirm completion and print the five most frequent ECR categories
ecr_counts = tax_name['ECR'].value_counts()
print("✅ Taxonomic annotation completed.")
print("Top ECR categories:\n", ecr_counts.head())


✅ Taxonomic annotation completed.
Top ECR categories:
 SAR11 clade                   232
Alteromonadales               120
Flavobacteriales              101
Other                          84
Chloroplast (unclassified)     80
Name: ECR, dtype: int64


In [5]:
# All distinct ECR categories are selected (removing 'Other' is left disabled
# below), so the boolean mask marks the rows whose category is in that list.
ecr_column = tax_name['ECR']
selected_species = list(np.unique(ecr_column))
#selected_species.remove('Other')
selected_species_index = ecr_column.isin(selected_species).to_numpy()
#species_col_dict = dict(zip(selected_species,distinct_colp[:len(selected_species)]))

In [6]:
## Species-species interaction matrix estimate
# Product of the L_sp and L_i parameter means, symmetrised by averaging with
# its transpose, with the diagonal set to zero.
# (A disabled variant used `cov_mat.max() - cov_mat` at this point.)
cov_mat = parma_mean['L_sp'] @ parma_mean['L_i'].T
cov_mat = (cov_mat + cov_mat.T) / 2
np.fill_diagonal(cov_mat, 0)

# Restrict rows and columns to the selected species
selected_species_index = tax_name['ECR'].isin(selected_species).values
dist_pos = cov_mat[selected_species_index][:, selected_species_index].copy()
dist_neg = cov_mat[selected_species_index][:, selected_species_index].copy()

# Per row: dist_pos keeps only the 5 largest entries, dist_neg only the 5
# smallest; every other entry is zeroed.
for row_idx in range(dist_pos.shape[0]):
    ascending = np.argsort(dist_pos[row_idx])
    dist_pos[row_idx, ascending[:-5]] = 0.
    dist_neg[row_idx, ascending[5:]] = 0.

# Flip the sign of the retained smallest entries
dist_neg = np.multiply(dist_neg, -1.0)
    

In [7]:
# Show which parameter names are available in the dictionary built from
# vbfun.vb_extract_mean(results)
parma_mean.keys()

dict_keys(['C0', 'C_geo', 'L_sp', 'L_i', 'A_s', 'A_b', 'A_m', 'tau', 'phi'])

In [9]:
# Table of the C_geo parameter means with each row tagged by the ECR category
# found at the same position in tax_name.
datframe = pd.DataFrame(parma_mean["C_geo"]).assign(ecr=tax_name["ECR"].values)
datframe

Unnamed: 0,0,1,2,3,4,5,6,7,8,ecr
0,-0.027003,0.030176,0.028744,-0.016490,-0.022581,-0.085728,-0.181170,-0.046121,-0.003771,Prochlorococcus
1,0.003816,0.175351,0.033316,0.015619,-0.114895,-0.207067,-0.002411,-0.097559,0.037132,Prochlorococcus
2,-0.020287,0.268054,0.119312,-0.180467,-0.266538,-0.179257,-0.290694,-0.042669,-0.321770,SAR86 clade
3,0.162636,0.692321,0.019715,-0.238223,-0.602797,-0.127534,-0.082981,0.005076,-0.565583,SAR86 clade
4,-0.072499,0.050732,-0.542972,0.045001,-0.082856,-0.475566,-0.261402,0.127213,-0.004419,Rhodospirillales
...,...,...,...,...,...,...,...,...,...,...
1373,-0.290782,-0.081863,0.336774,0.081741,0.029785,0.078715,0.098732,0.065758,-0.031955,SAR11 clade
1374,0.085575,0.043922,0.005905,-0.053185,-0.069169,-0.029873,-0.047056,0.000971,-0.128529,Flavobacteriales
1375,0.342059,0.052978,-0.048701,-0.121351,-0.036186,-0.053660,-0.070604,-0.024430,-0.008285,SAR11 clade
1376,-0.091933,-0.018549,0.081919,0.075092,-0.066110,-0.022063,0.053354,-0.035622,0.006574,Alpha-other


In [13]:
# Custom names for the nine numeric columns, followed by the 'ecr' tag column.
# NOTE(review): the order must match the column order of C_geo — confirm
# against the model's covariate order.
geo_covariate_names = [
    'Temperature',
    'Salinity',
    'Oxygen',
    "Nitrates",
    "NO2",
    "PO4",
    "NO2NO3",
    "SI",
    "grad SST",
]
new_column_names = geo_covariate_names + ["ecr"]
datframe.columns = new_column_names


In [14]:
# Display the frame to check the column names assigned from new_column_names
datframe

Unnamed: 0,Temperature,Salinity,Oxygen,Nitrates,NO2,PO4,NO2NO3,SI,grad SST,ecr
0,-0.027003,0.030176,0.028744,-0.016490,-0.022581,-0.085728,-0.181170,-0.046121,-0.003771,Prochlorococcus
1,0.003816,0.175351,0.033316,0.015619,-0.114895,-0.207067,-0.002411,-0.097559,0.037132,Prochlorococcus
2,-0.020287,0.268054,0.119312,-0.180467,-0.266538,-0.179257,-0.290694,-0.042669,-0.321770,SAR86 clade
3,0.162636,0.692321,0.019715,-0.238223,-0.602797,-0.127534,-0.082981,0.005076,-0.565583,SAR86 clade
4,-0.072499,0.050732,-0.542972,0.045001,-0.082856,-0.475566,-0.261402,0.127213,-0.004419,Rhodospirillales
...,...,...,...,...,...,...,...,...,...,...
1373,-0.290782,-0.081863,0.336774,0.081741,0.029785,0.078715,0.098732,0.065758,-0.031955,SAR11 clade
1374,0.085575,0.043922,0.005905,-0.053185,-0.069169,-0.029873,-0.047056,0.000971,-0.128529,Flavobacteriales
1375,0.342059,0.052978,-0.048701,-0.121351,-0.036186,-0.053660,-0.070604,-0.024430,-0.008285,SAR11 clade
1376,-0.091933,-0.018549,0.081919,0.075092,-0.066110,-0.022063,0.053354,-0.035622,0.006574,Alpha-other


In [16]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the mean coefficient per ECR category for every covariate.
# Expects `datframe` (numeric covariate columns plus an 'ecr' column) and the
# output folder `figfol` from earlier cells.

# Average the rows that share an ECR label, then transpose so covariates are
# rows and ECR categories are columns (features x taxa). Grouping on rows
# replaces the former `groupby(..., axis=1)`, which is deprecated in recent
# pandas releases; the resulting table is the same.
heatmap_data = (
    datframe.drop(columns='ecr')
    .groupby(datframe['ecr'].values)
    .mean()
    .T
)

# Plot the heatmap on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(heatmap_data, cmap='bwr', center=0, ax=ax)

ax.set_xlabel("")
ax.set_ylabel("Geochemical Covariates")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
fig.tight_layout()
fig.savefig(figfol + "heatmap_geochemical_taxa.png", dpi=300)
# Show before closing: the original closed the figure first, so plt.show()
# had nothing left to display.
plt.show()
plt.close(fig)
