# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
import matplotlib.ticker as ticker
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d

from scipy.stats import t

# Global figure styling: Arial everywhere, and plain ASCII hyphen-minus for
# negative tick labels instead of the Unicode minus glyph.
plt.rcParams.update({
    'font.family': 'Arial',
    'axes.unicode_minus': False,
})



# Source workbook; each sheet is processed as one group by the plotting loop.
# To analyse another variable, swap in one of the commented-out paths.
xls_new = pd.ExcelFile(r'NDVI_all.xlsx')
# xls_new = pd.ExcelFile(r'LAI_all.xlsx')
# xls_new = pd.ExcelFile(r'GPP_all.xlsx')

# Parse every sheet into a DataFrame, keeping the workbook's sheet order.
# NOTE(review): xls_new is never closed; consider `with pd.ExcelFile(...)`
# once no other cell relies on the open handle.
dfs_new = list(map(xls_new.parse, xls_new.sheet_names))

# Define the exponential decay function with log in the exponent
def exp_log_simple(x, a, b, c):
    return a * np.exp(b * x) + c * np.log(x)


# Per-sheet styling, indexed by sheet position (first sheet -> first entry).
colors = ["#4878D0", "#6ACC64", "#F98181"]
legend_labels = ["Pre-2010", "Post-2010", "All"]

# Categorical x tick labels "0".."7", one per plotted column.
x_labels = [str(position) for position in range(8)]

def _exp_log_jacobian(x, a, b, c):
    """Partial derivatives of ``exp_log_simple`` w.r.t. (a, b, c) at each x.

    Returns an array of shape (len(x), 3) whose column k is d f / d p_k.
    Used to propagate the fit covariance onto the fitted curve (delta method).
    """
    x = np.asarray(x, dtype=float)
    exp_term = np.exp(b * x)
    return np.column_stack([exp_term, a * x * exp_term, np.log(x)])


# `ylims` (one (ymin, ymax) pair per sheet) is not defined in this cell.
# Use it when an earlier cell provides it; otherwise keep matplotlib's
# autoscaling instead of failing with NameError.
_ylims = globals().get("ylims")

for i, df in enumerate(dfs_new, start=1):
    # Difference of the last 8 columns relative to the 9th-from-last
    # (baseline) column, computed row by row. axis=0 is required: the
    # default axis='columns' would align the baseline Series' row index with
    # the column labels and give NaN / misaligned values.
    df = df.iloc[:, -8:].sub(df.iloc[:, -9], axis=0)

    # Per-column summary statistics for the error-bar plot (all 8 columns).
    means = df.mean()
    q1 = df.quantile(0.25)
    q3 = df.quantile(0.75)

    # Fit only on columns 2..8 (positional), mapped to x = 1..7.
    x_vals = np.arange(1, 8)
    y_vals = df.iloc[:, 1:].mean().values

    # Fit the model to the column means.
    popt, pcov = curve_fit(exp_log_simple, x_vals, y_vals, p0=(1, -0.1, 1), maxfev=10000)

    # Fitted curve on a dense grid.
    x_fit = np.linspace(1, 7, 100)
    y_fit = exp_log_simple(x_fit, *popt)

    # Pointwise 95% confidence band of the fitted curve via the delta method:
    # Var[f(x)] = J(x) @ pcov @ J(x).T, with a Student-t critical value since
    # only len(x_vals) - len(popt) residual degrees of freedom remain.
    # (Shifting every parameter by +/-1.96*SE simultaneously ignores the
    # parameter covariance and is not a confidence interval.)
    # curve_fit returns a non-finite pcov when it cannot estimate it; then
    # no band is drawn.
    y_fit_lower = y_fit_upper = None
    if np.all(np.isfinite(pcov)):
        dof = max(len(x_vals) - len(popt), 1)
        t_crit = t.ppf(0.975, dof)
        jac = _exp_log_jacobian(x_fit, *popt)
        y_var = np.einsum("ij,jk,ik->i", jac, pcov, jac)
        # Clip tiny negative variances caused by round-off before sqrt.
        y_se = np.sqrt(np.clip(y_var, 0.0, None))
        y_fit_lower = y_fit - t_crit * y_se
        y_fit_upper = y_fit + t_crit * y_se

    # Coefficient of determination of the fit on the column means.
    y_pred = exp_log_simple(x_vals, *popt)
    ss_res = np.sum((y_vals - y_pred) ** 2)
    ss_tot = np.sum((y_vals - np.mean(y_vals)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan

    print(r2)

    fig, ax = plt.subplots(figsize=(5, 4))

    # Asymmetric IQR whiskers around the mean. For skewed data the mean can
    # fall outside [Q1, Q3]; matplotlib rejects negative yerr, so clip at 0.
    lower_error = (means - q1).clip(lower=0)
    upper_error = (q3 - means).clip(lower=0)
    asymmetric_error = [lower_error, upper_error]

    # Zero reference line.
    ax.axhline(y=0, color='black', linestyle='-.', label="", linewidth=1.5)

    ax.errorbar(x_labels, means, yerr=asymmetric_error, fmt='o',
                color=colors[i-1], linewidth=3, capsize=0, label=legend_labels[i-1],
                zorder=1, markersize=10, markerfacecolor='white', elinewidth=2)

    # Fitted curve. Numeric x = 1..7 lines up with category positions
    # "1".."7" because the categorical labels are "0".."7" in order.
    ax.plot(x_fit, y_fit, label='Fitted', color='#B1759A', linewidth=2)
    if y_fit_lower is not None:
        ax.fill_between(x_fit, y_fit_lower, y_fit_upper, color='#B1759A', alpha=0.2)

    # Fitted equation and R^2 as mathtext. Raw f-strings keep the LaTeX
    # backslashes (\cdot, \log) from being parsed as string escapes.
    equation_text = (
        rf"$y = {popt[0]:.3f} \cdot e^{{{popt[1]:.3f} \cdot x}} + ({popt[2]:.3f}) \cdot \log(x)$"
        "\n"
        rf"$R^2 = {r2:.2f}$"
    )
    ax.text(0.1, max(y_vals) - 0.1, equation_text, fontname="Arial", fontsize=13,
            verticalalignment='top')

#     equation_text_1 = f"N = {len(df)}"
#     ax.text(0.04, 0.90, equation_text_1, transform=ax.transAxes,
#         bbox=dict(boxstyle="round,pad=0.6", edgecolor="#333333", facecolor="#ffffff"),
#         fontname="Arial", fontsize=16, color="#333333")

    legend = ax.legend(loc='upper right',
                       ncol=1,
                       fontsize='small',
                       labelspacing=0.5,
                       borderpad=0.5,
                       columnspacing=1,
                       frameon=False,
                       edgecolor='none',
                       handletextpad=0.5)

    for label_text in legend.get_texts():
        label_text.set_color('black')
        label_text.set_fontsize(16)

    if _ylims is not None:
        ax.set_ylim(_ylims[i-1])

    plt.xlabel('Year', fontsize=18, weight='semibold')
    plt.ylabel('dNDVI', fontsize=18, weight='semibold')

    plt.xticks(fontsize=16, weight='normal')
    plt.yticks(fontproperties='Arial', fontsize=16, weight='normal')

    plt.savefig(f"TrackLine_NDVI_{i}.jpg", bbox_inches='tight', dpi=600)

    plt.show()
