In [None]:
import pandas as pd
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats
from scipy.optimize import minimize
from scipy.stats import lognorm, median_abs_deviation, wasserstein_distance, anderson, shapiro
from statsmodels.stats.diagnostic import lilliefors



In [None]:


def wasserstein_flow_lognormal(actual_loss):
    """Fit a log-normal distribution to the observed losses by quantile matching.

    The log-scale location and shape are first estimated robustly (median and
    normal-scaled median absolute deviation of ``log(loss)``) and then refined
    with Nelder-Mead so that the log-normal quantiles at the mid-point
    plotting positions ``(i - 0.5) / n`` are as close as possible, in the
    sum-of-squares sense, to the sorted observations.

    The optimised parameters are printed as a side effect.

    Parameters
    ----------
    actual_loss : array-like of float
        Observed loss values (list, tuple, Series or ndarray; treated as 1-D).

    Returns
    -------
    numpy.ndarray
        ``n`` quantile values of the fitted log-normal, in ascending order
        (they follow the ascending plotting positions, not the epoch order).

    Raises
    ------
    ValueError
        If ``actual_loss`` is empty.
    """
    # Coerce up front so plain Python sequences work; `list + 1e-10` would
    # otherwise raise a TypeError below.
    actual_loss = np.asarray(actual_loss, dtype=float)
    n = len(actual_loss)
    if n == 0:
        raise ValueError("actual_loss must contain at least one value")

    # The small offset keeps log() finite when a loss is exactly zero.
    log_loss = np.log(actual_loss + 1e-10)

    # Robust starting point for the optimiser.
    mu = np.median(log_loss)
    sigma = median_abs_deviation(log_loss, scale='normal')
    # A MAD of exactly 0 (e.g. constant losses) would start the search on the
    # infeasible sigma == 0 boundary, where the objective is inf; nudge it
    # into the feasible region instead.
    sigma_start = 1e-3 if sigma <= 0 else sigma

    sorted_actual = np.sort(actual_loss)
    q_ = np.linspace(0.5 / n, 1 - 0.5 / n, n)

    def objective_(params):
        mu_opt, sigma_opt = params
        # Ensure sigma is positive to avoid math errors in lognorm.ppf
        if sigma_opt <= 0:
            return np.inf
        ideal = stats.lognorm.ppf(q_, s=sigma_opt, scale=np.exp(mu_opt))
        return np.sum((sorted_actual - ideal) ** 2)

    result_ = minimize(objective_, [mu, sigma_start], method='Nelder-Mead')
    mu_opt_, sigma_opt_ = result_.x
    print(f'Optimized Lognormal Params -> mu: {mu_opt_:.4f}, sigma: {sigma_opt_:.4f}')

    # Return the idealized loss values from the fitted distribution
    return stats.lognorm.ppf(q_, s=sigma_opt_, scale=np.exp(mu_opt_))

def wasserstein_l2(dist_a, dist_b):
    """Return the root-mean-square gap between the sorted values of two samples.

    Both inputs are sorted independently, paired element by element, and the
    square root of the mean squared difference is returned. The inputs must
    be broadcast-compatible after sorting (in this file they always have the
    same length).
    """
    gaps = np.sort(dist_a) - np.sort(dist_b)
    return np.sqrt(np.mean(np.square(gaps)))

def calculate_wasserstein_and_ori_trend(actual_training_loss, validation_loss, ideal_training_loss):
    """Track the pairwise distances and the ORI over growing prefixes of the series.

    For every prefix length ``k = 1 .. len(actual_training_loss)`` the first
    ``k`` entries of each series are compared with ``wasserstein_l2``.

    The ORI for a prefix is ``max(0, 1 - d(ideal, train) / d(ideal, valid))``.
    It is forced to 0 when both the mean and the median of the training
    prefix exceed those of the validation prefix, or when
    ``d(ideal, valid)`` is exactly 0.

    NOTE(review): the windows are positional prefixes. If
    ``ideal_training_loss`` is supplied in ascending order (as
    ``wasserstein_flow_lognormal`` returns it), the ideal prefix holds its
    ``k`` smallest values - confirm this pairing is intended.

    Returns
    -------
    tuple of four lists, each with one entry per prefix length:
        train-vs-valid distances, ideal-vs-train distances,
        ideal-vs-valid distances, ORI values.
    """
    train_valid_dists, ideal_train_dists, ideal_valid_dists, ori_values = [], [], [], []

    for k in range(1, len(actual_training_loss) + 1):
        train_part = actual_training_loss[:k]
        valid_part = validation_loss[:k]
        ideal_part = ideal_training_loss[:k]

        dist_tv = wasserstein_l2(train_part, valid_part)
        dist_it = wasserstein_l2(ideal_part, train_part)
        dist_iv = wasserstein_l2(ideal_part, valid_part)

        train_valid_dists.append(dist_tv)
        ideal_train_dists.append(dist_it)
        ideal_valid_dists.append(dist_iv)

        # Training loss sitting above validation loss on both the mean and
        # the median scores 0, as does a zero denominator.
        train_above_valid = (
            np.mean(train_part) > np.mean(valid_part)
            and np.median(train_part) > np.median(valid_part)
        )
        if train_above_valid or dist_iv == 0:
            ori_values.append(0.0)
        else:
            ori_values.append(max(0.0, 1 - dist_it / dist_iv))

    return train_valid_dists, ideal_train_dists, ideal_valid_dists, ori_values

def save_fit_loss_to_csv(fit_loss_training, actual_training_loss, output_dir):
    """Write the actual and fitted training losses side by side to a CSV file.

    ``output_dir`` is created if needed. The file
    ``fit_loss_training_data.csv`` gets one row per entry with a 1-based
    ``epoch`` column, both loss series and their element-wise difference
    (actual minus fitted).

    Returns
    -------
    str
        Path of the CSV file that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    actual = np.asarray(actual_training_loss)
    fitted = np.asarray(fit_loss_training)
    table = pd.DataFrame({
        "epoch": np.arange(1, len(fit_loss_training) + 1),
        "actual_training_loss": actual,
        "fit_loss_training": fitted,
        "difference": actual - fitted,
    })

    out_path = os.path.join(output_dir, "fit_loss_training_data.csv")
    table.to_csv(out_path, index=False)
    print(f" Saved fit_loss_training CSV to: {out_path}")
    return out_path


# --- 1. LOAD TRAINING HISTORY ---
# Pick the most recently modified "results*" directory in the working
# directory and read its training_history.csv into `data_metrics`.
try:
    # glob also matches regular files (e.g. "results.txt"); only directories
    # can contain the history file.
    results_dirs = [path for path in glob.glob("results*") if os.path.isdir(path)]
    if not results_dirs:
        raise FileNotFoundError("No 'results*' directories found.")
    latest_dir = max(results_dirs, key=os.path.getmtime)
    history_path = os.path.join(latest_dir, "training_history.csv")
    if not os.path.exists(history_path):
        raise FileNotFoundError(f"training_history.csv not found in '{latest_dir}'")
    print(f"Loading training history from: {history_path}")
    data_metrics = pd.read_csv(history_path)

except FileNotFoundError as e:
    print(f"ERROR: {e}")
    # Re-raise instead of calling exit(): inside a Jupyter kernel exit()
    # requests a kernel shutdown (discarding the session) rather than cleanly
    # aborting the cell, so the statements below could still run and fail
    # with a confusing NameError on `data_metrics`.
    raise


# --- 2. ANALYSIS ---
# Per-epoch loss columns of the loaded history, as NumPy arrays.
loss_training, loss_validating = (
    data_metrics[column].to_numpy() for column in ('loss', 'val_loss')
)

print("\n--- Starting Overfitting Analysis ---")
# Log-normal quantile fit of the training loss, persisted next to the history.
fit_loss_training = wasserstein_flow_lognormal(loss_training)
save_fit_loss_to_csv(fit_loss_training, loss_training, latest_dir)

# Full-series distances between each pair of loss distributions.
d_it = wasserstein_l2(loss_training, fit_loss_training)
d_iv = wasserstein_l2(loss_validating, fit_loss_training)
d_vt = wasserstein_l2(loss_training, loss_validating)

# Colour-coded report (ANSI escape sequences), one caption + value per pair.
print('\n\033[1;31;45m Wasserstein distances:\033[0m\n ')
print("\n".join(
    f"{caption}\n{colour} Wasserstein Distance :(new range) \033[0m {distance!s}"
    for caption, colour, distance in (
        ('between training loss and validating loss:', "\033[1;31;45m", d_vt),
        ('between training loss and ideal log normal', "\033[1;31;40m", d_it),
        ('between ideal log normal and validating loss', "\033[1;31;47m", d_iv),
    )
))

# Plain-text summary of the same three distances, rounded to 4 decimals.
print('\n--- Wasserstein Distances (Final) ---')
print("\n".join(
    f"{label}: {distance:.4f}"
    for label, distance in (
        ("Training vs. Ideal (d_it)", d_it),
        ("Validation vs. Ideal (d_iv)", d_iv),
        ("Training vs. Validation (d_vt)", d_vt),
    )
))

# Prefix-by-prefix distance and ORI trends; the last ORI entry covers the
# full series and is reported as the final ORI after plotting.
w1, w2, w3, ori_trend = calculate_wasserstein_and_ori_trend(
    loss_training, loss_validating, fit_loss_training
)
print('ori trend:', ori_trend)

# Persist the ORI trend (1-based epoch index) alongside the other outputs.
trends_df = pd.DataFrame({
    'epoch': np.arange(1, len(ori_trend) + 1),
    'ori_trend': ori_trend,
})
output_csv_path = os.path.join(latest_dir, 'overfitting_analysis_trends.csv')
trends_df.to_csv(output_csv_path, index=False)
# --- 3. PLOTTING ---
print("\n--- Generating Analysis Plot ---")


def _finish_panel(ax, title, xlabel, ylabel, legend=True):
    """Apply the title, axis labels, optional legend and dotted grid shared by all panels."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if legend:
        ax.legend(loc='best')
    ax.grid(True, which='both', linestyle=':', linewidth=0.5)


fig, axs = plt.subplots(1, 4, figsize=(20, 4.5))

# 1-based epoch axis, matching the 'epoch' column of the CSV files written by
# this analysis (plotting without x values would start the axis at 0).
epochs = np.arange(1, len(loss_training) + 1)

# Plot 1: Loss Curves
ax = axs[0]
ax.plot(epochs, loss_training, label='Training', linewidth=1.5, color='red')
ax.plot(epochs, loss_validating, label='Validation', linewidth=1.5, color='green')
# The fitted values are log-normal quantiles in ascending order, not a time
# series, so they are drawn largest-first to read as a decaying loss curve
# next to the actual training loss.
ax.plot(epochs, np.sort(fit_loss_training)[::-1], label='Ideal Training',
        linewidth=1.5, color='blue', linestyle='--')
_finish_panel(ax, 'Loss Curves', 'Epoch', 'Loss Value')

# Plot 2: Density Plot
ax = axs[1]
sns.kdeplot(loss_training, color='red', label='Training', ax=ax)
sns.kdeplot(loss_validating, color='green', label='Validation', ax=ax)
sns.kdeplot(fit_loss_training, color='blue', label='Ideal Training', linestyle='--', ax=ax)
_finish_panel(ax, 'Density of Loss Values', 'Loss Value', 'Density')

# Plot 3: Wasserstein Distance Trends
ax = axs[2]
ax.plot(epochs, w1, color='red', linewidth=2.0, linestyle='-', label='WD(Train, Valid)')
ax.plot(epochs, w2, color='black', linewidth=2.0, linestyle='--', label='WD(Train, Ideal)')
ax.plot(epochs, w3, color='green', linewidth=2.0, linestyle='-', label='WD(Valid, Ideal)')
_finish_panel(ax, 'Wasserstein Distance Trends', 'Epoch', 'Distance')

# Plot 4: ORI Trend
ax = axs[3]
ax.plot(epochs, ori_trend, color='#732BF5', linewidth=2.0, linestyle='-')
# ORI is clamped to be >= 0 and cannot exceed 1; pad the limits slightly so
# lines sitting on either bound stay visible.
ax.set_ylim(-0.1, 1.1)
_finish_panel(ax, 'Overfitting Robustness Index (ORI) Trend', 'Epoch', 'ORI Value', legend=False)

# Lay out once (the original called tight_layout twice with the same rect),
# then save before show() so the written file is the fully rendered figure.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
output_plot_path = os.path.join(latest_dir, 'DistilRoBERTa_imdb_overfitting_analysis.png')
fig.savefig(output_plot_path, dpi=250)
print(f"\nAnalysis plot saved successfully to: {output_plot_path}")
plt.show()

# The last trend entry is computed over the full series, i.e. the final ORI.
new_ORI = ori_trend[-1]
print(f"Calculated ORI: {new_ORI:.4f}")