In [20]:
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import mpmath
import numba as nb
import scipy.optimize as opt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.image as mpimg

from tqdm import tqdm
from scipy import signal
from scipy.special import hyp2f1

from matplotlib.ticker import LogLocator, LogFormatterMathtext, NullFormatter
from matplotlib.colors import LogNorm

# Color palettes
light_palette = ['#504B43', '#4caf50', '#948d99']
dark_palette = ['#285DB1', '#AC3127', '#c1bbb0']
color_clusters = ['#648fff', '#ffb000', '#948d99']

# Figure settings
fig_formats = ['.pdf', '.eps', '.tiff']
cms = 0.393701

sns.set_theme(
    rc={
        'figure.figsize': figsizes["2 columns"],
        'figure.dpi': 200,
        'savefig.dpi': 300
    },
    font="Helvetica Neue",
    font_scale=1.3,
    style="ticks"
)

plt.rcParams.update({
    'legend.edgecolor': 'k',
    'legend.facecolor': 'w',
    'legend.frameon': True,
    'legend.framealpha': 1,
    'legend.fancybox': False,
    'legend.fontsize': 12,
    'axes.linewidth': 1.5,
    'axes.edgecolor': 'k',
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.labelsize': 14,
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"\usepackage{amsmath, helvet} \renewcommand{\familydefault}{\sfdefault}"
})

%load_ext autoreload


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
def upper_incomplete_gamma(a, z):
    """
    Upper incomplete Gamma function Gamma(a, z) = integral_z^inf t**(a-1) * exp(-t) dt.

    Evaluated with mpmath, so the shape parameter may be negative and the result is
    not limited to the double-precision range.

    Parameters:
    a : float
        Shape parameter; may be negative.
    z : float
        Lower limit of the integral.

    Returns:
    mpmath number
        Gamma(a, z) as returned by mpmath.gammainc (wrap in float() for a Python float).
    """
    lower_limit, upper_limit = z, mpmath.inf
    return mpmath.gammainc(a, lower_limit, upper_limit)

def pareto3exp(x, alpha, beta, k):
    """
    Normalised Pareto-like density with an exponential cutoff,

        P(x) = N * (1 + x/k)**(-1 - alpha) * exp(-beta * x),

    where N = exp(-z) / (k * z**alpha * Gamma(-alpha, z)) with z = k * beta, so that
    P integrates to one on [0, inf).

    Parameters:
    x (float or array-like): The point(s) at which to evaluate the density.
    alpha (float): Shape (power-law) parameter, must be positive.
    beta (float): Exponential cutoff parameter, must be positive.
    k (float): Scale parameter, must be positive.

    Returns:
    numpy.ndarray of float: P(x), same shape as ``x`` (0-d array for scalar ``x``).

    Raises:
    ValueError: if any of alpha, beta, k is not positive.
    """
    if alpha <= 0 or beta <= 0 or k <= 0:
        raise ValueError("All parameters (alpha, beta, k) must be positive.")

    x = np.asarray(x, dtype=float)

    # Normalisation evaluated entirely in mpmath. In double precision exp(-z) underflows
    # to 0 for z greater than about 745, whereas the mpmath Gamma term does not, which
    # made the whole density 0 (and its log -inf during fitting). The ratio itself is
    # well-scaled, so it is converted to float only once, at the end.
    z = mpmath.mpf(k) * beta
    exp_integral = z ** alpha * upper_incomplete_gamma(-alpha, z)
    prefactor = float(mpmath.exp(-z) / (exp_integral * k))

    result = prefactor * (1 + x / k) ** (-1 - alpha) * np.exp(-beta * x)
    return np.array(result, dtype=float)

In [22]:
def k_eff(B0, D0, B1, D1, b0, b1, d1, Delta):
    """
    Effective scale parameter (B0 + D0) * Delta / (B1 + D1 + 2 * d1).

    ``b0`` and ``b1`` are accepted for signature compatibility but do not enter
    the expression.
    """
    numerator = Delta * (B0 + D0)
    denominator = (B1 + D1) + 2 * d1
    return numerator / denominator

In [23]:
def log_pdf(data, nbins=20):
    """
    Empirical probability density of the strictly positive entries of ``data`` on
    logarithmically spaced bins.

    Parameters:
    data (array-like): samples; zero and negative values are discarded.
    nbins (int): number of bin edges (i.e. nbins - 1 bins).

    Returns:
    x_plot: geometric bin centres of the non-empty bins.
    y_plot: density of the non-empty bins.
    pdf: density of every bin, empty ones included.
    bins: the bin edges.
    All four are empty arrays when fewer than 3 positive samples are available or all
    positive samples are equal (no log-spaced binning is possible then).
    """
    empty = (np.array([]), np.array([]), np.array([]), np.array([]))

    values = np.asarray(data, dtype=float)
    positive_data = values[values > 0]
    if positive_data.size < 3:
        return empty

    lo, hi = positive_data.min(), positive_data.max()
    if lo == hi:
        return empty

    bins = np.logspace(np.log10(lo), np.log10(hi), nbins)
    # 10**log10(v) can differ from v by one ulp; pin the outer edges so np.histogram
    # never drops the smallest or largest sample.
    bins[0], bins[-1] = lo, hi

    hist, _ = np.histogram(positive_data, bins=bins)
    prob = hist / hist.sum()
    pdf = prob / np.diff(bins)

    centres = np.sqrt(bins[1:] * bins[:-1])
    nonzero = pdf > 0
    return centres[nonzero], pdf[nonzero], pdf, bins

def fit_log_generalized_gamma(x, y):
    """
    Fit y = C * x**(a - 1) * exp(-b * x) by least squares in log space.

    With X = log(x) the model reads
        log(y) = log(C) + (a - 1) * X - b * exp(X),
    which is fitted to log(y) with bounds a in [-5, 5], b >= 0, C >= 1e-9.

    Parameters:
    x, y (array-like): abscissae and strictly positive density values.

    Returns:
    numpy.ndarray: best-fit [a, b, C], or [nan, nan, nan] when fewer than three points
    are given or the optimiser fails to converge.
    """
    def log_generalized_gamma(X, a, b, C):
        return np.log(C) + (-1 + a) * X - b * np.exp(X)

    nan_params = np.array([np.nan, np.nan, np.nan])
    if len(x) < 3:
        return nan_params

    X = np.log(x)
    Y = np.log(y)
    try:
        popt, _ = opt.curve_fit(log_generalized_gamma, X, Y, p0=[1, 1, 1],
                                bounds=([-5, 0, 1e-9], [5, np.inf, np.inf]))
    except RuntimeError:
        # curve_fit raises RuntimeError when the solver does not converge; report
        # "no fit" with NaNs, the signal callers already test with np.isnan.
        return nan_params
    return popt
    
def fit_log_pareto3exp(x, y):
    """
    Fit the normalised Pareto-exponential density ``pareto3exp`` to (x, y) by least
    squares on log(y) as a function of log(x).

    Parameters:
    x, y (array-like): abscissae and strictly positive density values.

    Returns:
    numpy.ndarray: best-fit [alpha, beta, k], or [nan, nan, nan] when fewer than three
    points are given or the optimiser fails to converge.
    """
    def log_pareto3exp(X, alpha, beta, k):
        return np.log(pareto3exp(np.exp(X), alpha, beta, k))

    nan_params = np.array([np.nan, np.nan, np.nan])
    if len(x) < 3:
        return nan_params

    X = np.log(x)
    Y = np.log(y)
    try:
        popt, _ = opt.curve_fit(
            log_pareto3exp, X, Y,
            p0=[0.5, 0.0001, 1],
            # strictly positive lower bounds: pareto3exp raises ValueError for params <= 0
            bounds=([1e-9, 1e-9, 1e-9], [5, np.inf, np.inf]),
        )
    except RuntimeError:
        # Non-convergence: report "no fit" with NaNs instead of aborting the caller.
        return nan_params
    return popt

def _draw_sad_panel(ax, n, pdf, color, xlabel, model_params, fit_params):
    """Draw one log-log panel: empirical PDF, analytic Pareto-exponential curve and parameter annotations."""
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.scatter(n, pdf, color=color, alpha=1, edgecolor='k', lw=0.5, s=15)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("PDF")

    alpha, beta, k = model_params
    alpha_f, beta_f, k_f = fit_params
    ax.plot(n, pareto3exp(n, alpha, beta, k), color=color, alpha=0.95, ls='-', label='model', zorder=-1)

    text = f'$\\alpha={alpha:.2f}$\n$\\beta={beta:.2e}$\n$k={k:.2f}$'
    ax.text(0.1, 0.2, text, transform=ax.transAxes, fontsize=8, verticalalignment='center')
    text = f'$\\alpha_f={alpha_f:.2f}$\n$\\beta_f={beta_f:.2e}$\n$k_f={k_f:.2f}$'
    ax.text(0.45, 0.2, text, transform=ax.transAxes, fontsize=8, verticalalignment='center')
    # The drawn curve is the analytic prediction; the fitted values are only printed.
    ax.text(0.7, 0.8, "NOT A FIT", transform=ax.transAxes, fontsize=8, verticalalignment='bottom')


def plotSAD( b0, d0, b1, d1, t0, gamma0, t1, gamma1, t_max, n_x, PDF_x, n_y, PDF_y, popt_x, popt_y, Delta=0, chain=0):
    """
    Plot the DNA (metaG) and RNA (metaT) abundance PDFs next to the analytic
    Pareto-exponential prediction, annotate analytic and freely fitted parameters,
    and save the figure as a PDF under ``langevin/figures``.

    Parameters:
    b0, d0, b1, d1: rates entering the analytic metaG parameters (return_alphaG/_betaG/_kG).
    t0, gamma1: rates rescaling beta and k for metaT (return_betaT/_kT).
    gamma0, t1, t_max, Delta, chain: used only in the title and the output filename.
    n_x, PDF_x, n_y, PDF_y: binned PDFs as returned by log_pdf.
    popt_x, popt_y: generalized-gamma fit results; currently not drawn, kept for
        interface compatibility.

    Returns:
    (fig, axes)
    """
    width, height = figsizes["2 columns"]
    fig, axes = plt.subplots(1, 2, figsize=(width, 0.5 * height), constrained_layout=True)

    title = f"Chain {chain} - $b_0={b0:.4f}$, $d_0={d0:.4f}$, $b_1={b1:.4f}$, $d_1={d1:.4f}$, $t_0={t0:.4f}$, $\\gamma_0={gamma0:.4f}$, $t_1={t1:.4f}$, $\\gamma_1={gamma1:.4f}$, $t_{{max}}={t_max}$, $\\Delta={Delta:1.1f}$"
    fig.suptitle(title, fontsize=7)

    # metaG (DNA): analytic parameters; free fit restricted to n > 0.1.
    alphaG = return_alphaG(b0, d0, b1, d1)
    betaG = return_betaG(b0, d0, b1, d1)
    kG = return_kG(b0, d0, b1, d1)
    mask_x = n_x > 1e-1
    alphaG_fit, betaG_fit, kG_fit = fit_log_pareto3exp(n_x[mask_x], PDF_x[mask_x])
    print(f'alphaG: {alphaG:.2f}, betaG: {betaG:.2e}, kG: {kG:.2f}')
    print(f'alphaG_fit: {alphaG_fit:.2f}, betaG_fit: {betaG_fit:.2e}, kG_fit: {kG_fit:.2f}')
    _draw_sad_panel(axes[0], n_x, PDF_x, light_palette[0], "DNA",
                    (alphaG, betaG, kG), (alphaG_fit, betaG_fit, kG_fit))

    # metaT (RNA): same exponent, beta scaled by gamma1/t0 and k by t0/gamma1;
    # free fit restricted to n > 1.
    alphaT = alphaG
    betaT = return_betaT(b0, d0, b1, d1, t0, gamma1)
    kT = return_kT(b0, d0, b1, d1, t0, gamma1)
    mask_y = n_y > 1e-0
    alphaT_fit, betaT_fit, kT_fit = fit_log_pareto3exp(n_y[mask_y], PDF_y[mask_y])
    _draw_sad_panel(axes[1], n_y, PDF_y, light_palette[1], "RNA",
                    (alphaT, betaT, kT), (alphaT_fit, betaT_fit, kT_fit))

    os.makedirs('langevin/figures', exist_ok=True)
    figname = f'langevin/figures/langevin_chain_{chain}_b0_{b0:.4f}_d0_{d0:.4f}_b1_{b1:.4f}_d1_{d1:.4f}_t0_{t0:.4f}_gamma0_{gamma0:.4f}_t1_{t1:.4f}_gamma1_{gamma1:.4f}_tmax_{t_max}_Delta_{Delta:1.1f}.pdf'
    fig.savefig(figname, dpi=200)
    plt.show()

    return fig, axes

def plotCORR(datax, datay, b0, d0, b1, d1, t0, gamma0, t1, gamma1, t_max, Delta, chain=0):
    """
    Log-log scatter of paired DNA (x) and RNA (y) abundances annotated with the
    squared Pearson correlation, saved as a PDF under ``figures/``.

    Parameters:
    datax, datay (numpy arrays): paired abundances of equal length.
    b0 ... Delta, chain: used only to build the output filename.

    Returns:
    (fig, ax)
    """
    fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True)
    sns.scatterplot(x=datax, y=datay, s=5, lw=0.5, edgecolor='k', alpha=0.5, ax=ax, color='orange')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('DNA')
    ax.set_ylabel('RNA')
    ax.grid(True)

    # Squared Pearson correlation of the raw (not log-transformed) values.
    R2 = np.corrcoef(datax, datay)[0, 1] ** 2
    ax.text(0.6, 0.9, f'$R^2={R2:.2}$', transform=ax.transAxes, fontsize=8, verticalalignment='center',
            bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.2'))

    os.makedirs('figures', exist_ok=True)
    figname = f'figures/scatter_langevin_chain_{chain}_b0_{b0:.4f}_d0_{d0:.4f}_b1_{b1:.4f}_d1_{d1:.4f}_t0_{t0:.4f}_gamma0_{gamma0:.4f}_t1_{t1:.4f}_gamma1_{gamma1:.4f}_tmax_{t_max}_Delta_{Delta:1.1f}_corr.pdf'
    fig.savefig(figname, dpi=200)
    return fig, ax
    
    
def return_alphaG(b0, d0, b1, d1):
    """metaG power-law exponent: alpha_G = 4 (d0*b1 - b0*d1) / (b1 + d1)**2."""
    numerator = 4.0 * (d0 * b1 - b0 * d1)
    return numerator / (b1 + d1) ** 2
def return_kG(b0, d0, b1, d1):
    """metaG scale parameter: k_G = (d0 + b0) / (d1 + b1)."""
    inflow, outflow = d0 + b0, d1 + b1
    return inflow / outflow
def return_betaG(b0, d0, b1, d1):
    """metaG exponential cut-off: beta_G = 2 (d1 - b1) / (b1 + d1)."""
    imbalance = d1 - b1
    return 2 * imbalance / (b1 + d1)

def return_kT(b0, d0, b1, d1, t0, gamma1):
    """metaT scale parameter: k_T = k_G * t0 / gamma1."""
    return return_kG(b0, d0, b1, d1) * t0 / gamma1
def return_betaT(b0, d0, b1, d1, t0, gamma1):
    """metaT exponential cut-off: beta_T = beta_G * gamma1 / t0."""
    return return_betaG(b0, d0, b1, d1) * gamma1 / t0
    
    

In [24]:
# Simulation parameters: they select the .npz files loaded below (via the filename
# pattern) and set the analytic model curves drawn in the figure.
r0 = 100.0       # enters the metaT rescaling (return_kT/return_betaT) and the line y = r0/gamma1 * x
r1 = 0.0
gamma1 = 100.01  # enters the metaT rescaling together with r0
b0, b1 = 9.5, 1.0
d0, d1 = 10.5, 1.001
Delta = 0.0
Tmax = 200_000_000
dt = 0.01
chain = 0

In [None]:
# Collect every output file saved for this parameter set (the trailing `*` presumably
# matches the chain index — confirm against the simulation script). Sorting makes the
# pooling order, and hence the concatenated arrays, independent of the filesystem.
pattern = f'data/{r0:.4f}_{r1:.4f}_{gamma1:.4f}_{b0:.4f}_{b1:.4f}_{d0:.4f}_{d1:.4f}_{Delta:.4f}_{Tmax}_{dt:.4f}_*.npz'
matching_files = sorted(glob.glob(pattern))
if not matching_files:
    raise FileNotFoundError(f'No simulation output matches {pattern!r}')

# np.load on an .npz returns a lazily-read NpzFile (arrays are read on key access).
datasets = [np.load(file) for file in matching_files]

In [None]:
# Pool the 'x_uncorrelated' / 'y_uncorrelated' samples of all loaded files into flat 1-D arrays.
x_uncorrelated, y_uncorrelated = (
    np.concatenate([npz[key] for npz in datasets]).flatten()
    for key in ('x_uncorrelated', 'y_uncorrelated')
)

In [None]:
filter_thsh = 1e-1
# Keep only the (x, y) pairs in which both abundances reach the threshold; these feed
# the 2-D density inset. The 1-D PDFs below are built from the unfiltered samples.
keep = (x_uncorrelated >= filter_thsh) & (y_uncorrelated >= filter_thsh)
x_uncorrelated_filtered = x_uncorrelated[keep]
y_uncorrelated_filtered = y_uncorrelated[keep]

n_x, PDF_x, _, _ = log_pdf(x_uncorrelated, nbins=20)
popt_x = fit_log_generalized_gamma(n_x, PDF_x)

n_y, PDF_y, _, _ = log_pdf(y_uncorrelated, nbins=25)
popt_y = fit_log_generalized_gamma(n_y, PDF_y)

In [None]:
height = 3.24931369
phi = (1 + 5**0.5) / 2  # golden ratio

# ax0 is only a by-product of creating the figure; every panel is added via fig.add_axes.
fig, ax0 = plt.subplots(figsize=(height*phi, height*1.3), constrained_layout=False)

lw = 3.5  # line width of the model curves

# Main panel: metaG and metaT abundance distributions with the analytic predictions.
ax = fig.add_axes([-2.03*0.78 / phi, 0.1, 0.78*1.1, 0.78])
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim(1e-1, 1e5)
ax.set_xlabel(r"unigene density")
ax.set_ylabel(r"abundance distribution")
for log_axis in (ax.xaxis, ax.yaxis):
    log_axis.set_major_locator(LogLocator(base=10.0, numticks=10))
    log_axis.set_major_formatter(LogFormatterMathtext())
    log_axis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=10))
ax.yaxis.set_minor_formatter(NullFormatter())

# metaG: empirical PDF (diamonds) and analytic curve.
alphaG = return_alphaG(b0, d0, b1, d1)
betaG = return_betaG(b0, d0, b1, d1)
kG = return_kG(b0, d0, b1, d1)
ax.scatter(n_x, PDF_x, color=light_palette[0], alpha=1, edgecolor='k', lw=1., s=50, marker='d')
x_plot = np.logspace(np.log10(n_x.min()), np.log10(1.6*n_x.max()), 100)
y_plot = pareto3exp(x_plot, alphaG, betaG, kG)
ax.plot(x_plot, y_plot, color=light_palette[0], alpha=0.95, ls='-', label='model', zorder=-1, lw=lw)
ax.set_xticks([1e0, 1e2, 1e4])
ax.set_yticks([1e-1, 1e-3, 1e-5, 1e-7])

# metaT: empirical PDF (circles) and analytic curve (same exponent, rescaled beta and k).
ax.scatter(n_y, PDF_y, color=light_palette[1], alpha=1, edgecolor='k', lw=1., s=50, marker='o')
alphaT = alphaG
betaT = return_betaT(b0, d0, b1, d1, r0, gamma1)
kT = return_kT(b0, d0, b1, d1, r0, gamma1)
print(f'alphaG: {alphaG:.2f}, betaG: {betaG:.2e}, kG: {kG:.2f}')
print(f'alphaT: {alphaT:.2f}, betaT: {betaT:.2e}, kT: {kT:.2f}')
x_plot = np.logspace(np.log10(n_y.min()), np.log10(1.5*n_y.max()), 100)
y_plot = pareto3exp(x_plot, alphaT, betaT, kT)
ax.plot(x_plot, y_plot, color=light_palette[1], alpha=0.95, ls='-', label='model', zorder=-1, lw=lw)

# Inset: joint log-binned density of the filtered (x, y) pairs with the line y = (r0/gamma1) x.
ax1 = ax.inset_axes([0.1, 0.1, 0.35, 0.48])  # [left, bottom, width, height] in axes fraction
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlim(filter_thsh * 0.8, 3e4)
ax1.set_ylim(filter_thsh * 0.8, 3e4)

x = np.array(x_uncorrelated_filtered)
y = np.array(y_uncorrelated_filtered)
both_positive = (x > 0) & (y > 0)
x, y = x[both_positive], y[both_positive]

n_bins = 100
x_min, x_max = x.min(), x.max()
y_min, y_max = y.min(), y.max()

# Logarithmically spaced edges on both axes.
x_edges, y_edges = (
    np.logspace(np.log10(lo), np.log10(hi), n_bins)
    for lo, hi in ((x_min, x_max), (y_min, y_max))
)
H, xedges, yedges = np.histogram2d(x, y, bins=[x_edges, y_edges])

# Normalised density: counts / (total count * bin area); log bins have unequal areas.
dx, dy = np.diff(xedges), np.diff(yedges)
area = np.outer(dx, dy)
density = H / (area * H.sum())

x_plot = np.logspace(np.log10(xedges[1]), np.log10(xedges[-1]), 100)
y_plot = x_plot * r0 / gamma1
ax1.plot(x_plot, y_plot, color='k', alpha=0.5, ls='--', lw=lw)
ax1.text(0.5, 0.5, r"$y = \frac{r}{\gamma} x$", transform=ax1.transAxes, fontsize=14, verticalalignment='center')
mesh = ax1.pcolormesh(xedges, yedges, density.T, norm=LogNorm(), cmap='magma')

for log_axis in (ax1.xaxis, ax1.yaxis):
    log_axis.set_major_locator(LogLocator(base=10.0, numticks=10))
    log_axis.set_major_formatter(LogFormatterMathtext())
    log_axis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=10))
ax1.yaxis.set_minor_formatter(NullFormatter())
ax1.set_xticks([1e0, 1e2, 1e4])
ax1.set_yticks([1e0, 1e2, 1e4])

# Curve labels on the main panel. `bbox_props` was never defined before use in this
# notebook (NameError on a fresh kernel); it is a transparent, borderless rounded box.
bbox_props = dict(boxstyle="round", facecolor='none', edgecolor='none', linewidth=1)

ax.text(5e0, 1e-1, "metaG", usetex=True, fontsize=16, ha='left', va='bottom', bbox=bbox_props, color=light_palette[0])
ax.text(7e1, 2e-8, "metaG\n density " + r"$x$", usetex=True, fontsize=12, ha='left', va='center', bbox=bbox_props, color='k')
ax.text(3e2, 1e-3, "metaT", usetex=True, fontsize=16, ha='left', va='bottom', bbox=bbox_props, color=light_palette[1])
ax.text(3e-1, 6e-4, "metaT\n density " + r"$y$", usetex=True, fontsize=12, ha='left', va='center', bbox=bbox_props, color='k')




# Invisible helper axes for the reaction schematic (panel A); also carries panel letters.
ax3 = fig.add_axes([-4*0.78 / phi, 0.1, 0.78 / phi, 0.78])

# Panel letters: (A) next to the schematic, (B) on the main panel, (C) below the schematic.
letter_style = dict(fontsize=16, verticalalignment='bottom', ha='center')
ax3.text(-0.15*phi, 1.01, '(A)', transform=ax3.transAxes, usetex=False, **letter_style)
ax.text(-0.2, 1.01, '(B)', transform=ax.transAxes, **letter_style)
ax3.text(-0.15*phi, -0.22, '(C)', transform=ax3.transAxes, **letter_style)

# Hide background, ticks and frame so only text and icons remain.
ax3.set_facecolor('none')
ax3.set_xticks([])
ax3.set_yticks([])
for side in ('top', 'right', 'left', 'bottom'):
    ax3.spines[side].set_visible(False)

fontsize = 14
# One row per reaction: (y, vertical alignment, process name, LaTeX reaction, icon file, icon y).
reactions = [
    (0.9, 'bottom', 'duplication', r'$X \xrightarrow{r_1} 2X$', 'figures/biorender/X_XX.png', 0.905),
    (0.7, 'bottom', 'influx', r'$\emptyset \xrightarrow{r_0} X$', 'figures/biorender/0_X.png', 0.725),
    (0.5, 'bottom', 'deletion', r'$X \xrightarrow{d_1+\frac{d_0}{x}} \emptyset$', 'figures/biorender/X_0.png', 0.525),
    (0.25, 'center', 'transcription', r'$X \xrightarrow{r} X + Y$', 'figures/biorender/X_XY.png', 0.25),
    (0.05, 'center', 'degradation', r'$Y \xrightarrow{\gamma_1} \emptyset$', 'figures/biorender/Y_0.png', 0.05),
]

xpos = 0.2  # process names, left-aligned
for ypos, va, name, _, _, _ in reactions:
    ax3.text(xpos, ypos, name, transform=ax3.transAxes, fontsize=fontsize, verticalalignment=va, horizontalalignment='left', usetex=False)

xpos = 0.125  # reaction formulas, right-aligned
for ypos, va, _, formula, _, _ in reactions:
    ax3.text(xpos, ypos, formula, transform=ax3.transAxes, fontsize=fontsize, verticalalignment=va, horizontalalignment='right', usetex=True)

zoom = 0.025
for _, _, _, _, icon_path, icon_y in reactions:
    img = mpimg.imread(icon_path)
    imagebox = OffsetImage(img, zoom=zoom)
    ab = AnnotationBbox(imagebox, (0.9, icon_y), frameon=False)  # xy in ax3 data coordinates
    ax3.add_artist(ab)

# Position of an optional vertical separator on ax3 (axes fractions); not drawn at present.
custom_x = 0.5
custom_ymin = 0.1
custom_ymax = 0.9


# Panel C (right): analytic scale, exponent and cut-off parameters for metaG and metaT.
ax2 = fig.add_axes([-2.2*0.78 / phi+0.0225, -0.4, 0.78, 0.3])
ax2.set_facecolor('none')
ax2.set_xticks([])
ax2.set_yticks([])
ax2.axis('off')

column_labels = ['', r'\textbf{metaG}', r'\textbf{metaT}']
table_data = [
    [r'\textit{scale parameter}', r'$k_G=\frac{r_0+d_0}{r_1+d_1}$', r'$k_T=\frac{r}{\gamma_1}k_G$'],
    [r'\textit{power-law exponent}', r'$\alpha_G=4 \frac{d_0 r_1 - r_0 d_1}{(r_1+d_1)^2}$', r'$\alpha_T=\alpha_G$'],
    [r'\textit{exponential cut-off}', r'$\beta_G=2 \frac{d_1-r_1}{d_1+r_1}$', r'$\beta_T=\frac{\gamma_1}{r}\beta_G$']
]

table = ax2.table(
    cellText=table_data,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)
table.scale(1.35, 1.9)
table.auto_set_font_size(False)
table.set_fontsize(14)

# Tint the metaG / metaT header cells with their panel colours.
table[(0, 1)].set_facecolor(mpl.colors.to_rgba(light_palette[0], alpha=0.25))
table[(0, 2)].set_facecolor(mpl.colors.to_rgba(light_palette[1], alpha=0.75))

# Give every cell (header + all data rows) the same height. The former loop used
# range(3) and left the last data row ("exponential cut-off") at its default height.
n_rows = len(table_data) + 1  # data rows + header
for row in range(n_rows):
    for col in range(len(column_labels)):
        table[(row, col)].set_height(0.25)


# Panel C (left): stationary abundance distributions of metaG and metaT.
ax4 = fig.add_axes([-4*0.78 / phi-0.1, -0.4, 0.78, 0.3])
ax4.set_facecolor('none')
ax4.set_xticks([])
ax4.set_yticks([])
ax4.axis('off')

# Table contents: one row per data type.
table_data = [
    [r'\textbf{metaG}', r"$P(x) = \mathcal{N} \left(x+k_G\right)^{-1-\alpha_G} e^{-\beta_G x}$"],
    [r'\textbf{metaT}', r"$P(y) = \mathcal{N} \left(y+k_T\right)^{-1-\alpha_T} e^{-\beta_T y}$"]
]
# Column header
column_labels = ['', r'\textit{stationary abundance distribution}']

table = ax4.table(
    cellText=table_data,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

# Styling
table.scale(1.35, 1.9)
table.auto_set_font_size(False)
table.set_fontsize(14)

for row in range(len(table_data) + 1):  # header + 2 data rows
    table[(row, 0)].set_width(0.25)  # narrow label column
    table[(row, 1)].set_width(0.8)

for col in (0, 1):
    table[(0, col)].set_height(0.25)

# Colour the row labels with the metaG / metaT panel colours.
for row, color, alpha in ((1, light_palette[0], 0.25), (2, light_palette[1], 0.75)):
    table[(row, 0)].set_facecolor(mpl.colors.to_rgba(color, alpha=alpha))

# Data rows are 1.5x taller than the header to fit the formulas.
for row in (1, 2):
    for col in (0, 1):
        table[(row, col)].set_height(0.25*3/2)



# Coloured vertical bars labelled FAST (metaT colour) and SLOW (metaG colour) next to the schematic.
ax5 = fig.add_axes([-2.9*0.78 / phi, 0.08, 0.15, 0.78])
ax5.set_facecolor('none')
ax5.set_xticks([])
ax5.set_yticks([])
ax5.axis('off')
ax5.vlines(x=0, ymin=0., ymax=0.355, color=light_palette[1], linestyle='-', linewidth=10, transform=ax5.transAxes)
ax5.vlines(x=0, ymin=0.5, ymax=1, color=light_palette[0], linestyle='-', linewidth=10, transform=ax5.transAxes)

ax5.text(0.3, 0.355/2, "FAST", usetex=True, fontsize=16, ha='center', va='center', transform=ax5.transAxes, color=light_palette[1], rotation=90)
ax5.text(0.3, 0.75, "SLOW", usetex=True, fontsize=16, ha='center', va='center', transform=ax5.transAxes, color=light_palette[0], rotation=90)

# ax0 came from plt.subplots() only to create the figure; remove the empty placeholder.
ax0.remove()

# Create the output directory if needed (savefig does not create it).
os.makedirs('../figures', exist_ok=True)
fig.savefig('../figures/fig2.png', dpi=300, bbox_inches='tight', transparent=False)