In [None]:
# Notebook configuration (IPython magics).
# autoreload mode 2 picks up edits to imported modules without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Use the most detailed traceback format when a cell raises.
%xmode verbose

In [None]:
import arviz as az
import matplotlib.pyplot as plot
import numpy

In [None]:
from ExoRM import numpy_nn_forward, read_rm_data, unique_radius, preprocess_data, init_model, ExoRM

# Build the working table in one pass: read_rm_data() -> unique_radius() -> preprocess_data().
# The helpers come from the ExoRM package; what each step does is defined there.
data = preprocess_data(unique_radius(read_rm_data()))
data

In [None]:
# Linear-scale radius and mass columns.
_x, _y = data['radius'], data['mass']

# Central values in log10 space.
x, y = numpy.log10(_x), numpy.log10(_y)

# Map the two-sided catalogue errors into log10 space and keep the larger side
# as a single error per point.
# NOTE(review): assumes the *err2 columns hold negative offsets, so that
# value + err2 is the lower bound — confirm against the data source.
x_upper, x_lower = (numpy.log10(_x + data[col]) for col in ('pl_radeerr1', 'pl_radeerr2'))
x_err = numpy.maximum(x_upper - x, x - x_lower)

y_upper, y_lower = (numpy.log10(_y + data[col]) for col in ('pl_bmasseerr1', 'pl_bmasseerr2'))
y_err = numpy.maximum(y_upper - y, y - y_lower)

# Aliases under the names the fitting / prediction cells use.
x_obs, y_true = x, y

x, y, x_obs, y_true

In [None]:
# Create a model with ExoRM's default constructor arguments and load a previously
# saved trace from 'best_trace.nc' (path is relative to the working directory).
erm = ExoRM()
erm.load_trace('best_trace.nc')
# Display the model object as the cell output.
erm

In [None]:
# Show how many CPUs the machine reports; useful when choosing `cores`
# for the (commented-out) create_trace call in the next cell.
import os
os.cpu_count()

In [None]:
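# Uncomment to build a new trace from the data instead of using the trace loaded above.
# erm.create_trace(x_obs, x_err, y_true, y_err, cores = 4) # set cores to suit your machine (see os.cpu_count() above)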

In [None]:
# Predict at the observed log-radii (with their errors) and overlay the
# observations, the prediction and its lower/upper bounds.
y_pred, lower, upper = erm.predict_full(x_obs, x_err)

for series in (y, y_pred, lower, upper):
    plot.scatter(x, series)
plot.show()

In [None]:
# Evenly spaced log-radius grid reaching 0.1 dex beyond the data on both sides.
x_new = numpy.linspace(x.min() - 0.1, x.max() + 0.1, 1000)

y_pred, lower, upper = erm.predict_full(x_new)

# Observations first, then the prediction and its two bounds on the grid.
plot.scatter(x, y, s = 0.5)
for curve in (y_pred, lower, upper):
    plot.scatter(x_new, curve, s = 0.5)
plot.show()

In [None]:
# erm.save_trace()

In [None]:
# Second model instance; load_trace() is called without a path, so whatever
# default ExoRM uses is loaded.
erm2 = ExoRM()
erm2.load_trace()

# Wider grid than before: 0.25 dex below and 0.5 dex above the observed range.
x_new = numpy.linspace(x.min() - 0.25, x.max() + 0.5, 1000)
y_pred, lower, upper = erm2.predict_full(x_new)

plot.scatter(x, y)
for curve in (y_pred, lower, upper):
    plot.scatter(x_new, curve)
plot.show()

In [None]:
# az.summary(erm.trace, round_to=2).sort_values('ess_bulk').head(10)

In [None]:
# ArviZ leave-one-out cross-validation on the loaded trace; pointwise=True asks
# for the per-observation values as well as the summary.
az.loo(erm.trace, pointwise = True)

In [None]:
# erm.save_defaults_to_other('premade_inputs_nn_9.pkl', 'premade_trace_nn_9.nc')

In [None]:
# Third model, constructed with the 'pwlf' option and loaded from 'best_linear.nc'.
# NOTE(review): 'pwlf' presumably selects a piecewise-linear model — confirm in ExoRM.
erm3 = ExoRM('pwlf')
erm3.load_trace('best_linear.nc')
# Display the model object as the cell output.
erm3

In [None]:
# Log-radius grid reaching 0.1 dex beyond the data on both sides.
x_new = numpy.linspace(x.min() - 0.1, x.max() + 0.1, 1000)

# predict_full_linear is fed linear-scale radii (10 ** x_new); its three outputs
# are converted back with log10 so they share axes with x and y.
y_pred, lower, upper = map(numpy.log10, erm.predict_full_linear(10 ** x_new))

plot.scatter(x, y, s = 0.5, c = 'C0', label = 'true')
plot.plot(x_new, y_pred, c = 'C1', label = 'nn')
for band_edge in (lower, upper):
    plot.plot(x_new, band_edge, c = 'C2')

# Same evaluation for the 'pwlf' model.
y_pred2, lower2, upper2 = map(numpy.log10, erm3.predict_full_linear(10 ** x_new))

plot.plot(x_new, y_pred2, c = 'C3', label = 'pwlf')
for band_edge in (lower2, upper2):
    plot.plot(x_new, band_edge, c = 'C4')

plot.legend()
plot.show()

In [None]:
# Range spanned by the posterior values of '_x_true', used as the evaluation domain.
_x_true_values = erm.trace.posterior['_x_true'].values
x_min, x_max = _x_true_values.min(), _x_true_values.max()

# 1000 evenly spaced points as a column vector of shape (1000, 1).
x_grid = numpy.linspace(x_min, x_max, 1000).reshape(-1, 1)

In [None]:
# Evaluate the network on x_grid for a set of posterior draws, then differentiate
# each resulting curve twice with respect to x.
#
# Draws are taken from chain 0 only. At most 100 are used, capped at the number of
# draws actually stored so that a shorter trace does not raise an IndexError.
posterior = erm.trace.posterior
n_function_draws = min(100, posterior['w1'].sizes['draw'])

f_samples = []
for i in range(n_function_draws):
    w1, b1, w2, b2 = (
        posterior[name].isel(draw=i, chain=0).values
        for name in ('w1', 'b1', 'w2', 'b2')
    )
    f_samples.append(numpy_nn_forward(x_grid, w1, b1, w2, b2))

# numpy.gradient needs the coordinates as a 1-D array.
x_grid_flat = x_grid.flatten()

# One row per draw, one column per grid point.
f_matrix = numpy.stack([f_i.flatten() for f_i in f_samples])

# First derivative (slope) and second derivative (slope change) along the x axis,
# computed for all draws at once.
all_slopes = numpy.gradient(f_matrix, x_grid_flat, axis=1)
all_slope_changes = numpy.gradient(all_slopes, x_grid_flat, axis=1)
all_slope_changes

In [None]:
# Detect x-regions where the posterior-mean slope change is large, group
# neighbouring grid points into events, and print a per-event report.
# Reads `all_slope_changes` (draws x grid points) and `x_grid_flat` from the previous cell.


def _group_indices(indices, max_gap):
    """Split ascending indices into runs whose consecutive gaps are <= max_gap.

    Returns a list of lists of plain Python ints; empty input gives an empty list.
    Example: max_gap=1 turns [221, 222, 224] into [[221, 222], [224]];
    max_gap=2 keeps all three together.
    """
    groups = []
    for idx in (int(v) for v in indices):
        if groups and idx - groups[-1][-1] <= max_gap:
            groups[-1].append(idx)
        else:
            groups.append([idx])
    return groups


mean_slope_change = numpy.mean(all_slope_changes, axis=0)

# Per-grid-point standard deviation across draws. Computed once here and reused
# for both the global threshold and the per-event report.
std_slope_change_per_x = numpy.std(all_slope_changes, axis=0)

# Keep std_slope_change as a single scalar (mean of standard deviations across x-points)
std_slope_change_global = std_slope_change_per_x.mean()

# --- ADJUST THIS VALUE TO CONTROL THE NUMBER OF EVENTS ---
# Larger values (e.g. 2.0, 3.0, 4.0, 5.0) raise the threshold and report fewer events.
threshold_multiplier = 1.0

significant_threshold = threshold_multiplier * std_slope_change_global

significant_change_indices = numpy.where(numpy.abs(mean_slope_change) > significant_threshold)[0]

print("--- Initial Identification ---")
print("Raw indices of significant slope changes:")
print(significant_change_indices)

print(f"Max absolute mean slope change: {numpy.abs(mean_slope_change).max():.4e}")
print(f"Global standard deviation of slope changes (std_slope_change_global): {std_slope_change_global:.4e}")
print(f"Calculated significant_threshold: {significant_threshold:.4e}")
print(f"Number of individual points exceeding threshold: {len(significant_change_indices)}")


# --- Grouping Adjacent Indices into Events (allowing for small gaps) ---
# Largest index gap still treated as part of the same event.
# Try a higher threshold_multiplier before changing this.
max_gap_size = 10

event_indices = _group_indices(significant_change_indices, max_gap_size)

print("\n--- Grouped Events ---")
print(f"Grouped indices of significant slope change events (allowing gap of {max_gap_size}):")
for i, event_group in enumerate(event_indices):
    print(f"Event {i+1}: {event_group}")
print(f"Number of grouped events: {len(event_indices)}")


# --- Reporting for Each Grouped Event ---
final_events_report = []

num_samples = all_slope_changes.shape[0]
z_score_95 = 1.96  # two-sided 95% normal quantile
last_grid_index = len(x_grid_flat) - 1

for event_group in event_indices:
    # Peak of the event = grid point with the largest |mean slope change| in the group.
    peak_local_index = int(numpy.argmax(numpy.abs(mean_slope_change[event_group])))
    peak_global_index = event_group[peak_local_index]

    x_peak_of_change = x_grid_flat[peak_global_index]
    mean_at_peak = mean_slope_change[peak_global_index]

    # Spread across draws at the peak, and the standard error of the mean derived from it.
    local_std_slope_change_at_peak = std_slope_change_per_x[peak_global_index]
    sem_slope_change_at_peak = local_std_slope_change_at_peak / numpy.sqrt(num_samples)

    # Normal-approximation interval for the mean (based on the SEM, not the raw spread).
    ci_lower_at_peak = mean_at_peak - z_score_95 * sem_slope_change_at_peak
    ci_upper_at_peak = mean_at_peak + z_score_95 * sem_slope_change_at_peak

    x_interval_start = x_grid_flat[event_group[0]]
    x_interval_end = x_grid_flat[event_group[-1]]

    # Widen the interval by one grid point on each side, clamped to the grid.
    approx_x_start = x_grid_flat[max(0, event_group[0] - 1)]
    approx_x_end = x_grid_flat[min(last_grid_index, event_group[-1] + 1)]

    final_events_report.append({
        'peak_index': peak_global_index,
        'x_peak': x_peak_of_change,
        'mean_slope_change_at_peak': mean_at_peak,
        # NOTE: despite the key name this is the standard deviation across draws,
        # not the standard error; the key is kept so existing readers keep working.
        'std_error_at_peak': local_std_slope_change_at_peak,
        'ci_95': (ci_lower_at_peak, ci_upper_at_peak),
        'detected_x_interval_event': (x_interval_start, x_interval_end),
        'reported_x_interval': (approx_x_start, approx_x_end)
    })

print("\n--- Final Report for Significant Slope Change Events ---")
if not final_events_report:
    print("No significant events detected with the current threshold.")
for i, event_data in enumerate(final_events_report):
    print(f"\nEvent {i+1}:")
    print(f"  Peak X value: {event_data['x_peak']:.4f} (Index: {event_data['peak_index']})")
    print(f"  Mean Slope Change at Peak: {event_data['mean_slope_change_at_peak']:.4f}")
    print(f"  Error (Std Dev) at Peak: {event_data['std_error_at_peak']:.4f}")
    print(f"  95% CI at Peak: [{event_data['ci_95'][0]:.4f}, {event_data['ci_95'][1]:.4f}]")
    print(f"  Detected X-interval of Event (from first to last index in group): [{event_data['detected_x_interval_event'][0]:.4f}, {event_data['detected_x_interval_event'][1]:.4f}]")
    print(f"  Reported X and X' (approximate interval for paper): [{event_data['reported_x_interval'][0]:.4f}, {event_data['reported_x_interval'][1]:.4f}]")

In [None]:
# Pointwise summaries across draws (axis 0) for every grid point.
mean_slope_change = all_slope_changes.mean(axis=0)
std_slope_change = all_slope_changes.std(axis=0)

# 2.5th and 97.5th percentiles across draws: the edges of the central 95% band.
lower_bound_slope_change, upper_bound_slope_change = numpy.percentile(
    all_slope_changes, [2.5, 97.5], axis=0
)

In [None]:
# Mean slope change with its 95% percentile band, plus vertical marker lines.
band_colour = 'blue'

plot.plot(x_grid_flat, mean_slope_change, label='Mean Slope Change', color=band_colour)
plot.fill_between(
    x_grid_flat,
    lower_bound_slope_change,
    upper_bound_slope_change,
    color=band_colour,
    alpha=0.2,
    label='95% Credible Interval',
)

# NOTE(review): hard-coded x positions, presumably copied from an earlier event
# report — they will not follow a re-fitted trace; confirm before reuse.
marker_positions = [0.1917, 0.1494, 0.2245]
plot.vlines(marker_positions, ymin=-500, ymax=500)

plot.xlabel('x')
plot.ylabel('Slope Change')
plot.title('Mean Slope Change with 95% Credible Interval')
plot.legend()
plot.show()