## Exploratory data analysis

This code performs an exploratory data analysis of the metrics calculated from the extinction events. The analyzed metrics include:

* Number of new extinctions (new_ext)

* Bray–Curtis dissimilarity (BC_diss)

* Keystoneness (K_s)

* Time to stability after extinctions (ext_ts)

In [None]:
# | eval: false

# Load data
import pandas as pd
import os

# Section: Generate-paths
exp_dir = "/mnt/data/sur/users/mrivera/Train-sims/4379fd40-9f0a"
tgt_dir, data_path = (
    os.path.join(exp_dir, leaf) for leaf in ("GNN-targets", "parameters-sims.tsv")
)

# Load-data: simulation parameter table, then the simulation ids for each
# community size (20 and 100 species)
data = pd.read_csv(data_path, sep="\t")
ids20 = data.loc[data['n_species'] == 20, 'id']
ids100 = data.loc[data['n_species'] == 100, 'id']

In [None]:

# | eval: false

import pyarrow.feather as ft
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

# Input directories for this experiment: per-simulation target tables
# (tgt_<id>.feather) and raw ODE outputs (O_<id>.feather), both read by read_data.
TGT_DIR, RAW_ODES_DIR = (
    '/mnt/data/sur/users/mrivera/Train-sims/4379fd40-9f0a/GNN-targets',
    '/mnt/data/sur/users/mrivera/Train-sims/4379fd40-9f0a/raw-ODEs',
)

def read_data(id):
    """Load one simulation's target metrics and its relative frequencies.

    Parameters
    ----------
    id
        Simulation identifier, used to build the file names
        ``tgt_<id>.feather`` (in TGT_DIR) and ``O_<id>.feather`` (in RAW_ODES_DIR).

    Returns
    -------
    tuple of numpy.ndarray
        ``(new_ext, BC_diss, K_s, rel)``: the ``new_ext`` column as-is, the
        ``BC_diss`` and ``K_s`` columns rounded to 5 decimals, and the last
        column of the ODE table divided by its sum (all zeros when that sum
        is 0).
    """
    targets = ft.read_table(os.path.join(TGT_DIR, f'tgt_{id}.feather'))
    n_ext = targets['new_ext'].to_numpy()
    bray, keyst = (np.round(targets[col].to_numpy(), 5) for col in ('BC_diss', 'K_s'))

    # Last column of the ODE output -- presumably the abundances at the final
    # time point (TODO confirm against the ODE writer).
    odes = ft.read_table(os.path.join(RAW_ODES_DIR, f'O_{id}.feather'))
    abundances = odes.column(-1).to_numpy()
    total = abundances.sum()
    if total == 0:
        # Avoid 0/0 when every species is extinct.
        rel = np.zeros_like(abundances)
    else:
        rel = abundances / total
    return n_ext, bray, keyst, rel


def par_dat(ids, max_workers=None):
    """Read many simulations in parallel and pool their vectors.

    Parameters
    ----------
    ids : iterable
        Simulation identifiers; each one is passed to ``read_data``. A pandas
        Series, list or any other iterable is accepted.
    max_workers : int, optional
        Number of worker processes. By default it is the number of CPUs this
        process may run on (``os.sched_getaffinity``, which follows the SLURM
        CPU allocation). Where that call is unavailable (e.g. macOS, Windows),
        it falls back to ``os.cpu_count()``.

    Returns
    -------
    dict
        Keys ``'extinctions'``, ``'bray_curtis'``, ``'keystone'`` and
        ``'relative'``. Each maps to the concatenation, in ``ids`` order, of
        the corresponding vectors returned by ``read_data``. When ``ids`` is
        empty, every value is an empty array.
    """
    keys = ['extinctions', 'bray_curtis', 'keystone', 'relative']
    # Materialize once: ``ids`` may be a pandas Series or a one-shot iterator.
    ids = list(ids)
    if not ids:
        # zip(*[]) yields nothing to unpack, so return empty columns instead.
        return {k: np.empty(0) for k in keys}
    if max_workers is None:
        if hasattr(os, 'sched_getaffinity'):
            max_workers = len(os.sched_getaffinity(0))  # Respects SLURM allocation
        else:
            max_workers = os.cpu_count() or 1
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # executor.map preserves input order, so the pooled vectors follow ids.
        results = list(executor.map(read_data, ids))
    # results[i] is (ext, Bc, Ks, rel) for ids[i]; regroup per metric.
    per_metric = zip(*results)
    return {k: np.concatenate(v) for k, v in zip(keys, per_metric)}
     
        
# Pool the per-simulation vectors for each community size; each array has
# one entry per row of each simulation's target table, concatenated over simulations.
data20, data100 = (par_dat(ids=sim_ids) for sim_ids in (ids20, ids100))


In [None]:
from multiprocessing import Pool
from scipy.stats import pearsonr
from scipy.stats import spearmanr
import pyarrow.feather as ft
import numpy as np 

def read_data_old(id):
    """Legacy loader: return the target columns of one simulation as pandas Series.

    Reads ``tgt_<id>.feather`` from ``tgt_dir`` and returns
    ``(new_ext, BC_diss, K_s)``. ``BC_diss`` and ``K_s`` are rounded to 5 decimals.
    """
    table = ft.read_table(os.path.join(tgt_dir, f'tgt_{id}.feather'))
    n_ext = table['new_ext'].to_pandas()
    bray = table['BC_diss'].to_pandas().round(5)
    keyst = table['K_s'].to_pandas().round(5)
    return n_ext, bray, keyst

# Legacy parallel loader, kept for testing only
def par_dat_old(ids):
    """Load ``read_data_old`` results for ``ids`` with a pool of 8 processes.

    The work only runs when this module's ``__name__`` is ``'__main__'``
    (e.g. inside the notebook). Otherwise three empty lists are returned.
    When it runs, each returned value is ``np.array`` applied to the tuple of
    per-simulation Series. This gives a 2-D array when all simulations have
    the same number of rows (assumed -- TODO confirm).
    """
    if __name__ != '__main__':
        return [], [], []
    with Pool(processes=8) as pool:
        results = pool.map(read_data_old, ids)
    # Regroup the per-simulation (ext, Bc, Ks) tuples into one array per metric.
    ext, Bc, Ks = (np.array(column) for column in zip(*results))
    return ext, Bc, Ks

# Quick check on the 100-species runs: Pearson r between new extinctions and Bray-Curtis
id100 = data.loc[data['n_species'] == 100, 'id']
ext100, Bc100, Ks100 = par_dat_old(ids=id100)
x_ext, y_bc = ext100.flatten(), Bc100.flatten()
corr, pvalue = pearsonr(x_ext, y_bc)

# Relative frequencies distribution

Plot the distribution of species' relative frequencies for the 20- and 100-species simulations.

In [None]:
# | eval: false

import matplotlib.pyplot as plt
import numpy as np

# srun --partition=interactive --time=00:40:00 --cpus-per-task=8 --mem=20G --pty bash
# Bin both datasets on one shared set of 100 equal-width bins that spans the
# pooled data range, so the two curves are plotted at the same x positions.
lo = min(data20['relative'].min(), data100['relative'].min())
hi = max(data20['relative'].max(), data100['relative'].max())
ct1, edges = np.histogram(data20['relative'], bins=100, range=(lo, hi))
ct2, _ = np.histogram(data100['relative'], bins=100, range=(lo, hi))
# x position of each count: the centre of its bin
b1 = (edges[:-1] + edges[1:]) / 2

# Apply log1p so empty bins stay at 0 instead of -inf
ct20_log = np.log1p(ct1)
ct100_log = np.log1p(ct2)

# Plot distribution
plt.figure(figsize=(12, 5))
plt.plot(b1, ct20_log, marker='o', linestyle='-', color='blue', label='20 species')
plt.plot(b1, ct100_log, marker='o', linestyle='-', color='red', label='100 species')
plt.xlabel('Relative frequencies')
plt.ylabel('log(Frequency + 1)')
plt.title('Distribution of relative frequencies')
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig('/mnt/data/sur/users/mrivera/Plots/RelFreq-distr.png', dpi=300)
plt.close()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Display the saved relative-frequency figure inline
img = mpimg.imread('/mnt/data/sur/users/mrivera/Plots/RelFreq-distr.png')
ax = plt.figure(figsize=(10, 6)).add_subplot()
ax.imshow(img)
ax.axis('off')
plt.show()

## Extinctions distribution
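We compare the distribution of the number of new extinctions between the 20- and 100-species simulations using pie charts; values that account for 5% or less of the observations are grouped into an "other" slice.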

In [None]:
# | eval: false

import matplotlib.pyplot as plt
import seaborn as sns

# Tally how often each number of new extinctions occurs, per community size
(labels20, counts20), (labels100, counts100) = (
    np.unique(pooled['extinctions'], return_counts=True) for pooled in (data20, data100)
)

def pie_dat(labels, counts, threshold=0.05):
    """Collapse rare categories into a single 'other' slice for a pie chart.

    Parameters
    ----------
    labels : array_like
        Category values (e.g. the unique values from ``np.unique``).
    counts : array_like
        Number of observations for each label, aligned with ``labels``.
    threshold : float, optional
        Minimum share of the total count (exclusive) for a category to keep
        its own slice. Defaults to 0.05 (5%).

    Returns
    -------
    pie_counts, pie_labels : numpy.ndarray
        Slice sizes and string labels. Kept categories come first, in input
        order. If any category falls at or below ``threshold``, their summed
        count follows under the label 'other'. When nothing is rare, no
        'other' slice is added, because a zero-size wedge would still be
        drawn with a "0.0%" label.
    """
    labels = np.asarray(labels)
    counts = np.asarray(counts)
    share = counts / counts.sum()
    keep = share > threshold
    # Labels become strings so a possible 'other' entry fits the same array.
    pie_labels = labels[keep].astype(str)
    pie_counts = counts[keep]
    if (~keep).any():
        pie_labels = np.append(pie_labels, 'other')
        pie_counts = np.append(pie_counts, counts[~keep].sum())
    return pie_counts, pie_labels

# Generate pie chart data
pie_20, pie_labels20 = pie_dat(labels = labels20, counts = counts20)
pie_100, pie_labels100 = pie_dat(labels = labels100, counts = counts100)

# Draw one pie per community size side by side.
# plt.subplots creates a fresh figure, so no plt.clf() is needed beforehand;
# calling it would only clear or create a stray figure that is never closed.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
colors = sns.color_palette('pastel')

for ax, slice_counts, slice_labels, n_sp in (
    (ax1, pie_20, pie_labels20, 20),
    (ax2, pie_100, pie_labels100, 100),
):
    ax.pie(slice_counts, labels=slice_labels, autopct='%1.1f%%', colors=colors,
           wedgeprops={'edgecolor': 'white', 'linewidth': 2},
           startangle=90,
           textprops={'fontsize': 11, 'weight': 'bold'},
           pctdistance=0.85,
           explode=[0.05] * len(slice_counts)  # Slightly separate all slices
           )
    ax.set_title(f'Number of extinctions in {n_sp} species data', fontweight='bold')

plt.tight_layout()
plt.savefig('/mnt/data/sur/users/mrivera/Plots/ext-pie.png', dpi = 300)  # Saves as PNG
# Release the figure, as the other plotting cells do.
plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Display the saved extinction pie charts inline
img = mpimg.imread('/mnt/data/sur/users/mrivera/Plots/ext-pie.png')
ax = plt.figure(figsize=(10, 6)).add_subplot()
ax.imshow(img)
ax.axis('off')
plt.show()

## Heatmap of relative frequency vs. number of extinctions

In [None]:
# | eval: false

# 2D histograms of relative frequency vs. number of extinctions (20 and 100 species)
import numpy as np
import matplotlib.pyplot as plt


def _log_hist2d_panel(fig, ax, x, y, title, bins=(10, 10)):
    """Draw a log1p-scaled 2D histogram of (x, y) on ``ax`` and attach a colorbar.

    Returns ``(H, xedges, yedges)`` exactly as produced by ``np.histogram2d``
    (x varies along axis 0 of ``H``).
    """
    H, xedges, yedges = np.histogram2d(x, y, bins=bins)
    im = ax.imshow(
        # histogram2d puts x along axis 0, but imshow draws axis 0 as rows (y),
        # so transpose. log1p keeps empty bins at 0 rather than -inf.
        np.log1p(H.T),
        origin='lower',
        aspect='auto',
        extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
        cmap='coolwarm',
        interpolation='nearest',
    )
    fig.colorbar(im, ax=ax, label='log(Frequency + 1)')
    ax.set_xlabel('Relative Frequency')
    ax.set_ylabel('Number of Extinctions')
    ax.set_title(title)
    return H, xedges, yedges


# plt.subplots creates a fresh figure; no plt.clf() needed beforehand.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
H, xedges, yedges = _log_hist2d_panel(
    fig, ax1, data20['relative'], data20['extinctions'], '20 Species'
)
H100, xedges100, yedges100 = _log_hist2d_panel(
    fig, ax2, data100['relative'], data100['extinctions'], '100 Species'
)

plt.tight_layout()
plt.savefig('/mnt/data/sur/users/mrivera/Plots/exts-distr.png', dpi=300)
plt.close()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Display the saved heatmaps inline
img = mpimg.imread('/mnt/data/sur/users/mrivera/Plots/exts-distr.png')
ax = plt.figure(figsize=(10, 6)).add_subplot()
ax.imshow(img)
ax.axis('off')
plt.show()

# Correlation between variables


### Pearson and Spearman correlations

In [None]:

# | eval: false

import numpy as np 
from scipy.stats import pearsonr
from scipy.stats import spearmanr
import itertools

def pearson_cor(var1, var2, name1, name2, specs):
    corr, pvalue = pearsonr(var1, var2)
    return f'>> Pearson correlation of {name1} with {name2}, is correlated with r={corr} for {specs} species. Statistical significance: pval={pvalue:.2e}\n'
    

def spearman_cor(var1, var2, name1, name2, specs):
    corr, pvalue = spearmanr(var1, var2)
    return f'>> Spearman correlation of {name1} with {name2}, is correlated with r={corr} for {specs} species. Statistical significance: pval={pvalue:.2e}\n'

def run_correlation(data, n_species):
    # Create list for lines
    lines_pearson = []
    lines_spearman = []
    # Generate combinations
    keys = ['extinctions', 'bray_curtis', 'keystone', 'relative']
    pairs = list(itertools.combinations(keys, 2))           
    for p in pairs:
        name1, name2 = p[0], p[1]                           # Variable names
        var1, var2 = data20[p[0]], data20[p[1]]             # Data
        lines_pearson.append( pearson_cor(var1, var2, name1, name2, n_species) )
        lines_spearman.append( spearman_cor(var1, var2, name1, name2, n_species) )
    # Write files
    with open(f'/mnt/data/sur/users/mrivera/Logs/pearson{n_species}_correlation.txt', 'w') as f:
        f.writelines(lines_pearson)
    with open(f'/mnt/data/sur/users/mrivera/Logs/spearman{n_species}_correlation.txt', 'w') as f:
        f.writelines(lines_spearman)

# Write correlation logs for each community size
run_correlation(data=data20, n_species=20)
run_correlation(data=data100, n_species=100)

## Run with 20 species

In [None]:
# Show the Pearson, then Spearman, correlation reports for 20 species
for log_path in ('/mnt/data/sur/users/mrivera/Logs/pearson20_correlation.txt',
                 '/mnt/data/sur/users/mrivera/Logs/spearman20_correlation.txt'):
    with open(log_path, 'r') as f:
        content = f.read()
    print(content)

## Run with 100 species

In [None]:
# Show the Pearson, then Spearman, correlation reports for 100 species
for log_path in ('/mnt/data/sur/users/mrivera/Logs/pearson100_correlation.txt',
                 '/mnt/data/sur/users/mrivera/Logs/spearman100_correlation.txt'):
    with open(log_path, 'r') as f:
        content = f.read()
    print(content)



* extinctions is correlated with bray_curtis with r=-0.02 for 20 species. Statistical significance: pval=8.17e-08

* extinctions is correlated with keystone with r=-0.01 for 20 species. Statistical significance: pval=2.46e-07

* extinctions is correlated with relative with r=-0.13 for 20 species. Statistical significance: pval=0.00e+00

* bray_curtis is correlated with keystone with r=1.00 for 20 species. Statistical significance: pval=0.00e+00

* bray_curtis is correlated with relative with r=0.08 for 20 species. Statistical significance: pval=6.66e-194

* keystone is correlated with relative with r=0.08 for 20 species. Statistical significance: pval=1.28e-173

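## Run with 100 species

> **Note:** the values below are identical to the 20-species results because `run_correlation` reads the global `data20` instead of its `data` argument. Rerun the correlation cell with that corrected to obtain the actual 100-species values.
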
* extinctions is correlated with bray_curtis with r=-0.02 for 100 species. Statistical significance: pval=8.17e-08

* extinctions is correlated with keystone with r=-0.01 for 100 species. Statistical significance: pval=2.46e-07

* extinctions is correlated with relative with r=-0.13 for 100 species. Statistical significance: pval=0.00e+00

* bray_curtis is correlated with keystone with r=1.00 for 100 species. Statistical significance: pval=0.00e+00

* bray_curtis is correlated with relative with r=0.08 for 100 species. Statistical significance: pval=6.66e-194

* keystone is correlated with relative with r=0.08 for 100 species. Statistical significance: pval=1.28e-173

## Plot correlations


In [None]:
# | eval: false

import seaborn as sns
import pandas as pd

# data20 is a dict of equal-length arrays; wrap it in a DataFrame for seaborn
metric_cols = ['extinctions', 'bray_curtis', 'keystone', 'relative']
data20_df = pd.DataFrame(data20)

# Pairwise relationships between the four pooled metrics (20 species)
pairplot = sns.pairplot(data20_df[metric_cols])
pairplot.savefig('/mnt/data/sur/users/mrivera/Plots/pairplot.png', dpi=300)
plt.close()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Display the saved pair plot inline
img = mpimg.imread('/mnt/data/sur/users/mrivera/Plots/pairplot.png')
ax = plt.figure(figsize=(10, 6)).add_subplot()
ax.imshow(img)
ax.axis('off')
plt.show()

# Distribution of keystoneness and Bray–Curtis dissimilarity

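To choose the number of bins we would normally use the Freedman–Diaconis rule, whose bin width is $h = 2\,\mathrm{IQR}\,n^{-1/3}$. Here, however, the interquartile range (IQR) is zero, so the rule would give a zero bin width and cannot be applied. We therefore use Scott’s rule, $h = 3.5\,\hat{\sigma}\,n^{-1/3}$, and take the number of bins as $\lceil (\max x - \min x)/h \rceil$.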

In [None]:
# | eval: false

import numpy as np
import matplotlib.pyplot as plt

# Calculate optimal number of bins
def bin_calc(x):
    std = np.std(x)
    n = len(x)
    h_scott = 3.5 * std / (n ** (1/3))
    bins_scott = int(np.ceil((x.max() - x.min()) / h_scott))
    return bins_scott

# Keystoneness: round to 5 decimals and size the bins with Scott's rule.
# The pooled arrays come from the data20 / data100 dicts built by par_dat.
x1 = np.round(data20['keystone'], 5)
bins1 = bin_calc(x1)
x2 = np.round(data100['keystone'], 5)
bins2 = bin_calc(x2)

# Generate the histograms (plt.subplots creates a fresh figure; no clf needed)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.hist(x1, bins=bins1, density=True, histtype='step', color='blue', label='S20')  # 'step' draws only the outline
ax1.hist(x2, bins=bins2, density=True, histtype='step', color='red', label='S100')  # 'step' draws only the outline
ax1.set_xlabel('Keystoness')
# density=True normalises each histogram to unit area, so the y axis is a density
ax1.set_ylabel('Density')
ax1.set_title('Keystoness distribution')
ax1.legend()

# Bray-Curtis dissimilarity: same procedure
x1 = np.round(data20['bray_curtis'], 5)
x2 = np.round(data100['bray_curtis'], 5)
bins1 = bin_calc(x1)
bins2 = bin_calc(x2)
ax2.hist(x1, bins=bins1, density=True, histtype='step', color='blue', label='S20')  # 'step' draws only the outline
ax2.hist(x2, bins=bins2, density=True, histtype='step', color='red', label='S100')  # 'step' draws only the outline
ax2.set_xlabel('Bray-Curtis')
ax2.set_ylabel('Density')
ax2.set_title('Bray-Curtis distribution')
ax2.legend()

plt.tight_layout()
plt.savefig('/mnt/data/sur/users/mrivera/Plots/both-distr.png', dpi = 300)  # Saves as PNG
plt.close(fig)

In [None]:
# | eval: false

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Display the saved keystoneness / Bray-Curtis histograms inline
img = mpimg.imread('/mnt/data/sur/users/mrivera/Plots/both-distr.png')
ax = plt.figure(figsize=(10, 6)).add_subplot()
ax.imshow(img)
ax.axis('off')
plt.show()