In [1]:
import numpy as np
import scipy.optimize
import scipy.stats
import random
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

In [16]:
# Current code (working version): interactive chi-squared goodness-of-fit demo for linear and quadratic models.

# Functions to generate chi-squared values for a data series.
def lin_chisq(params, x_data, y_data, y_err):
    """Chi-squared of the straight line ``a + b*x`` against data with 1-sigma errors.

    Parameters
    ----------
    params : sequence of 2 floats
        ``(a, b)``: intercept and slope.
    x_data, y_data, y_err : array-like
        Abscissae, measured values and their 1-sigma uncertainties (equal
        length). Python lists are accepted as well as NumPy arrays.

    Returns
    -------
    float
        ``sum(((y_data - model) / y_err) ** 2)``.
    """
    a_sol, b_sol = params
    x_data = np.asarray(x_data, dtype=float)
    model = a_sol + b_sol * x_data
    residuals = (np.asarray(y_data, dtype=float) - model) / np.asarray(y_err, dtype=float)
    return np.sum(residuals ** 2)


def quad_chisq(params, x_data, y_data, y_err):
    """Chi-squared of the parabola ``a + b*x + c*x**2`` against data with 1-sigma errors.

    Parameters
    ----------
    params : sequence of 3 floats
        ``(a, b, c)``: constant, linear and quadratic coefficients.
    x_data, y_data, y_err : array-like
        Abscissae, measured values and their 1-sigma uncertainties (equal
        length). Python lists are accepted as well as NumPy arrays.

    Returns
    -------
    float
        ``sum(((y_data - model) / y_err) ** 2)``.
    """
    a_sol, b_sol, c_sol = params
    x_data = np.asarray(x_data, dtype=float)
    model = a_sol + b_sol * x_data + c_sol * x_data ** 2
    residuals = (np.asarray(y_data, dtype=float) - model) / np.asarray(y_err, dtype=float)
    return np.sum(residuals ** 2)


def _assess_fit(chisq, n_points, n_params):
    """Return ``(dof, p_value, verdict)`` for a chi-squared fit.

    The p-value is P(chi2_dof >= chisq), i.e. the chi-squared survival
    function. When ``dof <= 0`` (at least as many parameters as data points)
    the p-value is undefined: NaN is returned together with an explicit
    verdict, instead of letting NaN fall through every threshold and be
    reported as "an unacceptable fit".
    """
    dof = n_points - n_params
    if dof <= 0:
        return dof, float('nan'), "an undetermined fit (no degrees of freedom)"

    p_value = scipy.stats.chi2.sf(chisq, dof)
    # Thresholds are checked from the top; the first match wins.
    if p_value > 0.9:
        verdict = "a suspiciously good fit"  # scatter smaller than the error bars imply
    elif p_value > 0.1:
        verdict = "an acceptable fit"
    elif p_value > 0.05:
        verdict = "a marginally acceptable fit"
    elif p_value > 0.01:
        verdict = "a marginally unacceptable fit"
    else:
        verdict = "an unacceptable fit"
    return dof, p_value, verdict


def lin_chisq_fit(xval, yval, yerr):
    """Fit ``a + b*x`` by minimising chi-squared, then assess the fit quality.

    Parameters
    ----------
    xval, yval, yerr : array-like
        Abscissae, measured values and 1-sigma uncertainties (equal length).

    Returns
    -------
    tuple
        ``(params, fit_line, chisq, dof, p_value, fit_acceptable)``:
        ``params`` is ``[a, b]``; ``fit_line`` the fitted model evaluated at
        ``xval``; ``dof = len(xval) - 2``; ``p_value`` the chi-squared
        survival probability (NaN when ``dof <= 0``); ``fit_acceptable`` a
        verbal verdict.
    """
    xval, yval, yerr = (np.asarray(v, dtype=float) for v in (xval, yval, yerr))
    # Start at the fixed generating parameters. Chi-squared is quadratic in
    # (a, b), so the minimiser converges from any reasonable start.
    initial = np.array([0.19, 0.8])
    fit = scipy.optimize.minimize(lin_chisq, initial, args=(xval, yval, yerr))
    a_soln, b_soln = fit.x
    fit_line = a_soln + b_soln * xval
    chisq = lin_chisq(fit.x, xval, yval, yerr)
    dof, p_value, fit_acceptable = _assess_fit(chisq, len(xval), initial.size)
    return fit.x, fit_line, chisq, dof, p_value, fit_acceptable


def quad_chisq_fit(xval, yval, yerr):
    """Fit ``a + b*x + c*x**2`` by minimising chi-squared, then assess the fit quality.

    Parameters
    ----------
    xval, yval, yerr : array-like
        Abscissae, measured values and 1-sigma uncertainties (equal length).

    Returns
    -------
    tuple
        ``(params, fit_line, chisq, dof, p_value, fit_acceptable)``:
        ``params`` is ``[a, b, c]``; ``fit_line`` the fitted model evaluated
        at ``xval``; ``dof = len(xval) - 3``; ``p_value`` the chi-squared
        survival probability (NaN when ``dof <= 0``, e.g. with 3 points);
        ``fit_acceptable`` a verbal verdict.
    """
    xval, yval, yerr = (np.asarray(v, dtype=float) for v in (xval, yval, yerr))
    # Start at the fixed generating parameters. Chi-squared is quadratic in
    # (a, b, c), so the minimiser converges from any reasonable start.
    initial = np.array([0.5, 0.5, -0.6])
    fit = scipy.optimize.minimize(quad_chisq, initial, args=(xval, yval, yerr))
    a_soln, b_soln, c_soln = fit.x
    fit_line = a_soln + b_soln * xval + c_soln * xval ** 2
    chisq = quad_chisq(fit.x, xval, yval, yerr)
    dof, p_value, fit_acceptable = _assess_fit(chisq, len(xval), initial.size)
    return fit.x, fit_line, chisq, dof, p_value, fit_acceptable

def generate_data(num_points, errortype, errorsize, uniform_spacing, fixedparams, modeltype, rng=None):
    """Simulate one dataset from a linear or quadratic model and fit both models to it.

    Parameters
    ----------
    num_points : int
        Number of data points.
    errortype : {"Homoscedastic", "Heteroscedastic"}
        Homoscedastic: every point has the same 1-sigma error. Heteroscedastic:
        per-point errors are drawn with a 25% spread around the nominal size.
    errorsize : {"Small", "Medium", "Large", "Huge"}
        Nominal 1-sigma error: 0.02, 0.1, 0.25 or 0.5 respectively.
    uniform_spacing : bool
        True: x evenly spaced on [-1, 1]. False: x drawn uniformly on [-1, 1].
    fixedparams : bool
        True: use the fixed generating coefficients (linear ``0.19 + 0.8x``,
        quadratic ``0.5 + 0.5x - 0.6x**2``). False: draw each coefficient
        uniformly from [-1, 1].
    modeltype : {"Linear model", "Quadratic model"}
        Generating model.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        existing callers behave as before; pass a seeded generator for
        reproducible datasets.

    Returns
    -------
    dict
        Data points, the generating curve and smooth linear/quadratic fit
        curves on a 1000-point grid, plus the chi-squared statistics of both
        fits (keys as in the return statement below).

    Raises
    ------
    ValueError
        If errortype, errorsize or modeltype is not one of the listed options.
    """
    if rng is None:
        rng = np.random
    polyval = np.polynomial.polynomial.polyval  # polyval(x, (c0, c1, ...)) = c0 + c1*x + ...

    error_size_map = {"Small": 0.02, "Medium": 0.1, "Large": 0.25, "Huge": 0.5}
    error_type_map = {"Homoscedastic": 0, "Heteroscedastic": 1}
    if errorsize not in error_size_map:
        raise ValueError(f"unknown errorsize {errorsize!r}")
    if errortype not in error_type_map:
        raise ValueError(f"unknown errortype {errortype!r}")
    if modeltype not in ("Linear model", "Quadratic model"):
        raise ValueError(f"unknown modeltype {modeltype!r}")
    error_size = error_size_map[errorsize]
    error_type = error_type_map[errortype]

    if uniform_spacing:
        x = np.linspace(-1, 1, num_points)
    else:
        x = rng.uniform(-1, 1, num_points)
    newx = np.linspace(-1, 1, 1000)

    # Per-point 1-sigma error bars. abs() keeps them positive: a rare negative
    # heteroscedastic draw would otherwise make normal() and plt.errorbar raise.
    y_error = np.abs(rng.normal(error_size, 0.25 * error_type * error_size, size=x.size))

    # Polynomial coefficients in increasing order of power.
    if modeltype == "Linear model":
        coeffs = (0.19, 0.8) if fixedparams else (rng.uniform(-1, 1), rng.uniform(-1, 1))
    else:
        coeffs = ((0.5, 0.5, -0.6) if fixedparams
                  else (rng.uniform(-1, 1), rng.uniform(-1, 1), rng.uniform(-1, 1)))

    # Gaussian noise is added to y with exactly the quoted error bars, so the
    # chi-squared of the true model follows chi2(num_points). Perturbing x
    # instead would scale the scatter by the model's slope/curvature and make
    # the error bars inconsistent with the actual scatter.
    y = (polyval(x, coeffs) + rng.normal(0.0, y_error)).tolist()
    genfn = polyval(newx, coeffs)

    # *_chi2 names avoid shadowing the module-level lin_chisq / quad_chisq functions.
    lin_params, _, lin_chi2, lin_dof, lin_p_value, lin_fit_acceptable = lin_chisq_fit(x, y, y_error)
    quad_params, _, quad_chi2, quad_dof, quad_p_value, quad_fit_acceptable = quad_chisq_fit(x, y, y_error)

    # Evaluate the fits on the dense grid for smooth plotting.
    lin_fit_line_smooth = polyval(newx, lin_params)
    quad_fit_line_smooth = polyval(newx, quad_params)

    return {
        'x': x,
        'y': y,
        'y_error': y_error,
        'num_points': num_points,
        'errortype': errortype,
        'errorsize': errorsize,
        'uniform_spacing': uniform_spacing,
        'genfn': genfn,
        'newx': newx,
        'fixedparams': fixedparams,
        'lin_fit_line': lin_fit_line_smooth,
        'lin_chisq': lin_chi2,
        'lin_dof': lin_dof,
        'lin_p_value': lin_p_value,
        'lin_p_value_formatted': f'{lin_p_value:.4f}',
        'lin_fit_acceptable': lin_fit_acceptable,
        'quad_fit_line': quad_fit_line_smooth,
        'quad_chisq': quad_chi2,
        'quad_dof': quad_dof,
        'quad_p_value': quad_p_value,
        'quad_p_value_formatted': f'{quad_p_value:.4f}',
        'quad_fit_acceptable': quad_fit_acceptable,
        'model': modeltype
    }

# Widgets for interactive controls.
# 'initial' description width stops ipywidgets from truncating long labels
# such as "Number of datasets:" to the default fixed label width.
_label_style = {'description_width': 'initial'}

num_points_widget = widgets.IntSlider(min=3, max=100, step=1, value=10, description='Data points:', style=_label_style)
num_datasets_widget = widgets.IntSlider(min=1, max=100, step=1, value=1, description='Number of datasets:', style=_label_style)
errortype_widget = widgets.Dropdown(options=["Homoscedastic", "Heteroscedastic"], value="Homoscedastic", description='Error type:', style=_label_style)
errorsize_widget = widgets.Dropdown(options=["Small", "Medium", "Large", "Huge"], value="Medium", description='Error size:', style=_label_style)
pointspacing_widget = widgets.Checkbox(value=True, description='Uniform spacing')
fixedparams_widget = widgets.Checkbox(value=True, description='Fixed parameters')
modeltype_widget = widgets.Dropdown(options=["Linear model", "Quadratic model"], value="Quadratic model", description='Model type:', style=_label_style)
showgenmodel_widget = widgets.Checkbox(value=True, description='Show generating model')
showlinearfit_widget = widgets.Checkbox(value=False, description='Show linear fit')
showquadfit_widget = widgets.Checkbox(value=False, description='Show quadratic fit')
sort_options = widgets.Dropdown(options=["None", "Chi-squared (Linear)", "P-value (Linear)", "Chi-squared (Quadratic)", "P-value (Quadratic)"], value="None", description='Sort by:', style=_label_style)
# Option labels carry chi-squared values and p-values for both fits, so the
# selector needs more than the default width to show them.
dataset_selector = widgets.Dropdown(description='Select dataset:', style=_label_style, layout=widgets.Layout(width='90%'))
regenerate_button = widgets.Button(description="Generate new data")
output = widgets.Output()

# Results of the current batch. Entry i of every list belongs to dataset i.
# 'chisqs', 'pvals' and 'fit_acceptability' hold {'linear': ..., 'quadratic': ...} dicts;
# 'data' holds the dicts returned by generate_data.
data_storage = {'data': [], 'chisqs': [], 'pvals': [], 'indexes': [], 'fit_acceptability': []}

# Function to regenerate data
def regenerate_data(*args):
    """Replace every stored dataset with a freshly simulated batch.

    Reads the generation settings from the control widgets and calls
    generate_data once per requested dataset. Each result is recorded in
    data_storage together with its linear/quadratic chi-squared, p-value and
    verdict, so that entry i of every list belongs to dataset i. The dataset
    selector is then refreshed, which in turn redraws the chart. ``*args``
    absorbs whatever the widget/button callback passes.
    """
    with output:
        clear_output(wait=True)

        # Start every stored series from a new, empty list.
        for key in ('data', 'chisqs', 'pvals', 'indexes', 'fit_acceptability'):
            data_storage[key] = []

        how_many = num_datasets_widget.value
        settings = {
            'num_points': num_points_widget.value,
            'errortype': errortype_widget.value,
            'errorsize': errorsize_widget.value,
            'uniform_spacing': pointspacing_widget.value,
            'fixedparams': fixedparams_widget.value,
            'modeltype': modeltype_widget.value,
        }

        for idx in range(how_many):
            result = generate_data(**settings)
            data_storage['data'].append(result)
            data_storage['chisqs'].append(
                {'linear': result['lin_chisq'], 'quadratic': result['quad_chisq']})
            data_storage['pvals'].append(
                {'linear': result['lin_p_value'], 'quadratic': result['quad_p_value']})
            data_storage['indexes'].append(idx)
            data_storage['fit_acceptability'].append(
                {'linear': result['lin_fit_acceptable'], 'quadratic': result['quad_fit_acceptable']})

        update_dataset_selector()

# Dataset selector: option building and sorting.
# (dataset index, label) pairs in the order currently shown by dataset_selector.
# update_chart uses them to map the selected label back to a dataset index.
sorted_options = []

# "Sort by" choice -> (data_storage key, fit name) whose value is the sort key.
_SORT_KEYS = {
    "Chi-squared (Linear)": ('chisqs', 'linear'),
    "P-value (Linear)": ('pvals', 'linear'),
    "Chi-squared (Quadratic)": ('chisqs', 'quadratic'),
    "P-value (Quadratic)": ('pvals', 'quadratic'),
}


def update_dataset_selector():
    """Rebuild the dataset dropdown from data_storage, honouring "Sort by", then redraw the chart.

    The number of entries is taken from data_storage itself rather than from
    the "Number of datasets" slider. This makes the function safe to call
    before any data exists (e.g. changing "Sort by" first); the old code
    indexed past the end of the empty storage in that case.
    """
    global sorted_options

    n_stored = len(data_storage['data'])
    labels = [
        f'Dataset {i+1}: Linear χ²={data_storage["chisqs"][i]["linear"]:.2f}, P-value={data_storage["pvals"][i]["linear"]:.4f} | '
        f'Quadratic χ²={data_storage["chisqs"][i]["quadratic"]:.2f}, P-value={data_storage["pvals"][i]["quadratic"]:.4f}'
        for i in range(n_stored)
    ]

    sort_spec = _SORT_KEYS.get(sort_options.value)
    if sort_spec is None:  # "None": keep generation order
        order = list(data_storage['indexes'])
    else:
        store_key, fit = sort_spec
        order = sorted(data_storage['indexes'], key=lambda i: data_storage[store_key][i][fit])

    # sorted_options must be updated before the options: assigning options can
    # fire the selector's value observer, which reads sorted_options.
    sorted_options = [(i, labels[i]) for i in order]
    dataset_selector.options = [label for _, label in sorted_options]

    if sorted_options:
        dataset_selector.value = dataset_selector.options[0]

    # NOTE(review): the value observer may already have redrawn the chart;
    # this explicit call guarantees a redraw even when the value did not change.
    update_chart()

def update_chart(*args):
    """Redraw the output area for the dataset currently chosen in the selector.

    Left panel: the selected dataset with error bars, plus (depending on the
    checkboxes) the generating model and the linear/quadratic fits.
    Right panel: histograms of the linear and quadratic chi-squared values of
    all stored datasets, with dotted lines marking the selected dataset.
    ``*args`` absorbs the change event passed by widget callbacks.
    """
    with output:
        clear_output(wait=True)

        if not dataset_selector.options:
            return  # nothing generated yet

        selected_option = dataset_selector.value
        selected_index = next(
            (index for index, label in sorted_options if label == selected_option), None)
        if selected_index is None:
            # Selector value and sorted_options momentarily disagree (or no
            # value is selected); skip instead of raising StopIteration.
            return
        dataset = data_storage['data'][selected_index]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

        x = dataset['x']
        y = dataset['y']
        y_error = dataset['y_error']

        # Plot fitting lines and data points
        if showgenmodel_widget.value:
            ax1.plot(dataset['newx'], dataset['genfn'], 'k--', label='Generating Model', alpha=0.5)
        if showlinearfit_widget.value:
            ax1.plot(dataset['newx'], dataset['lin_fit_line'], 'b-', label=f'Linear Fit: χ²={dataset["lin_chisq"]:.2f}, dof={dataset["lin_dof"]}, P(χ²,dof)={dataset["lin_p_value_formatted"]}, {dataset["lin_fit_acceptable"]}')
        if showquadfit_widget.value:
            ax1.plot(dataset['newx'], dataset['quad_fit_line'], 'r-', label=f'Quadratic Fit: χ²={dataset["quad_chisq"]:.2f}, dof={dataset["quad_dof"]}, P(χ²,dof)={dataset["quad_p_value_formatted"]}, {dataset["quad_fit_acceptable"]}')

        ax1.errorbar(x, y, yerr=y_error, fmt='o', color='black')
        ax1.set_xlabel('X')
        ax1.set_ylabel('Y')
        # With every curve hidden nothing is labelled and legend() would warn.
        if ax1.get_legend_handles_labels()[0]:
            ax1.legend()
        ax1.set_title(f'Dataset {selected_index + 1}')

        # Plot histogram of chi-squared values over all stored datasets
        lin_chisqs = [stored['lin_chisq'] for stored in data_storage['data']]
        quad_chisqs = [stored['quad_chisq'] for stored in data_storage['data']]
        upper = max(max(lin_chisqs), max(quad_chisqs))
        # Histogram bin edges must increase; fall back to [0, 1] if every value is 0.
        bins = np.linspace(0, upper if upper > 0 else 1.0, 30)

        ax2.hist(lin_chisqs, bins, alpha=0.5, color='b', label='Linear Fit χ²')
        ax2.hist(quad_chisqs, bins, alpha=0.5, color='r', label='Quadratic Fit χ²')

        # Add dotted lines for the chi-squared values of the selected dataset
        ax2.axvline(dataset['lin_chisq'], color='b', linestyle=':', label='Linear Fit χ² of ' + f'Dataset {selected_index + 1}')
        ax2.axvline(dataset['quad_chisq'], color='r', linestyle=':', label='Quadratic Fit χ² of '+ f'Dataset {selected_index + 1}')

        ax2.set_xlabel('Chi-squared')
        ax2.set_ylabel('Frequency')
        ax2.legend()
        ax2.set_title('Histogram of Chi-squared Values')

        plt.tight_layout()
        plt.show()



# Link widget events to functions.
# Any change to a generation setting invalidates every stored dataset, so those
# widgets trigger a full regeneration; display-only toggles just redraw.
for _w in (num_points_widget, num_datasets_widget, errortype_widget, errorsize_widget,
           pointspacing_widget, fixedparams_widget, modeltype_widget):
    _w.observe(lambda change: regenerate_data(), names='value')
for _w in (showgenmodel_widget, showlinearfit_widget, showquadfit_widget, dataset_selector):
    _w.observe(lambda change: update_chart(), names='value')
sort_options.observe(lambda change: update_dataset_selector(), names='value')
regenerate_button.on_click(lambda button: regenerate_data())

# Display widgets and output
display(num_points_widget, num_datasets_widget, errortype_widget, errorsize_widget, pointspacing_widget, fixedparams_widget, modeltype_widget, showgenmodel_widget, showlinearfit_widget, showquadfit_widget, sort_options, dataset_selector, regenerate_button, output)

# Generate an initial batch so the selector and chart are populated right away,
# rather than showing an empty output area until the first button click.
regenerate_data()



IntSlider(value=10, description='Data points:', min=3)

IntSlider(value=1, description='Number of datasets:', min=1)

Dropdown(description='Error type:', options=('Homoscedastic', 'Heteroscedastic'), value='Homoscedastic')

Dropdown(description='Error size:', index=1, options=('Small', 'Medium', 'Large', 'Huge'), value='Medium')

Checkbox(value=True, description='Uniform spacing')

Checkbox(value=True, description='Fixed parameters')

Dropdown(description='Model type:', index=1, options=('Linear model', 'Quadratic model'), value='Quadratic mod…

Checkbox(value=True, description='Show generating model')

Checkbox(value=False, description='Show linear fit')

Checkbox(value=False, description='Show quadratic fit')

Dropdown(description='Sort by:', options=('None', 'Chi-squared (Linear)', 'P-value (Linear)', 'Chi-squared (Qu…

Dropdown(description='Select dataset:', options=(), value=None)

Button(description='Generate new data', style=ButtonStyle())

Output()