In [36]:
# --- Core Libraries ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import math

# --- PyTorch ---
import torch
import torch.nn as nn

# --- Interactive Widgets & Progress Bar ---
import ipywidgets as widgets
from ipywidgets import Layout, interact_manual # FIX: Added interact_manual here
from tqdm.notebook import tqdm

# --- CUSTOM MODULE IMPORTS ---
# Project-local modules: DOE table generation, the torque network + RPM grid, and the
# ranking metrics / sort used by the cells below.
try:
    from ffact import generate_factorial_table
    from wave_nn import ASTorqueModel, RPM_GRID, STATIC_DIM
    from ranking_functions import get_max_torque, get_avg_torque, get_smoothness, quicksort_3d
    print("Successfully imported custom modules (wave_nn, ranking_functions, ffact).")
except ImportError as e:
    print(f"ERROR: Could not import custom modules: {e}")
    print("Please ensure wave_nn.py, ranking_functions.py, and ffact.py are in the same directory as this notebook.")
    # Re-raise: every later cell needs these names. Swallowing the error would let
    # "Run All" continue and fail further down with a confusing NameError instead
    # of stopping here at the real cause.
    raise

# --- GLOBAL CONFIGURATION ---
# Run inference on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"PyTorch will use device: {DEVICE}")

# Hard (min, max) bounds for each DOE parameter. They set the slider limits and are
# used to validate single-prediction inputs. Insertion order matters: the
# single-prediction check pairs these entries positionally with the
# (sec. header, header, runner, plenum) input values.
MAX_RANGES = dict(
    sec_header_len=(10.0, 14.0),
    header_len=(400.0, 560.0),
    runner_len=(6.0, 15.0),
    plenum_vol=(1.0, 5.0),
)

Successfully imported custom modules (wave_nn, ranking_functions, ffact).
PyTorch will use device: cuda


In [37]:
# This global variable will store the results from the DOE run
torque_curves_data = None

try:
    # --- Load Normalization Statistics ---
    # S_* standardize the static DOE inputs before inference; Y_* undo the
    # standardization of the network's torque output.
    S_mean = np.load("S_mean.npy")
    S_std = np.load("S_std.npy")
    Y_mean = np.load("Y_mean.npy")
    Y_std = np.load("Y_std.npy")

    # --- Restore Trained Network ---
    model = ASTorqueModel().to(DEVICE)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code.
    state_dict = torch.load("wave.nn", map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # --- Prepare Base RPM Tensor for Inference ---
    # Standardize the RPM grid by its own mean/std and add two leading singleton
    # dimensions (for a 1-D RPM_GRID this gives shape (1, 1, len(RPM_GRID))).
    # This one tensor is reused for every prediction.
    rpm_norm_base = torch.tensor(
        ((RPM_GRID - RPM_GRID.mean()) / RPM_GRID.std()).astype(np.float32),
        device=DEVICE,
    ).unsqueeze(0).unsqueeze(0)

    print("Model and normalization data loaded successfully.")

except (FileNotFoundError, RuntimeError) as e:
    print(f"ERROR loading model or data files: {e}")
    print("Please ensure wave.nn and all .npy files are present and that the model architecture in wave_nn.py matches the saved file.")
    # Re-raise so execution stops here. Later cells need model, rpm_norm_base and
    # the statistics, and would otherwise fail with a misleading NameError.
    raise

Model and normalization data loaded successfully.


In [38]:
# Shared description style and a single full-width Layout reused by all four range sliders.
style = {'description_width': 'initial'}
layout = Layout(width='100%')

# --- WIDGETS FOR DOE PARAMETERS ---
# Each DOE variable gets a (lower, upper) range slider bounded by MAX_RANGES and a
# companion text box holding the grid step used to build the factorial table.
sh_slider = widgets.FloatRangeSlider(
    value=[10.5, 12.5],
    min=MAX_RANGES['sec_header_len'][0],
    max=MAX_RANGES['sec_header_len'][1],
    step=0.1,
    description='Sec. Header Len (in)',
    style=style,
    layout=layout,
    readout_format='.1f',
)
sh_step = widgets.FloatText(value=0.5, description='Step', layout=Layout(width='150px'))

ph_slider = widgets.FloatRangeSlider(
    value=[420, 480],
    min=MAX_RANGES['header_len'][0],
    max=MAX_RANGES['header_len'][1],
    step=1.0,
    description='Header Len (mm)',
    style=style,
    layout=layout,
    readout_format='.0f',
)
ph_step = widgets.FloatText(value=20, description='Step', layout=Layout(width='150px'))

ir_slider = widgets.FloatRangeSlider(
    value=[8.0, 12.0],
    min=MAX_RANGES['runner_len'][0],
    max=MAX_RANGES['runner_len'][1],
    step=0.1,
    description='Runner Len (in)',
    style=style,
    layout=layout,
    readout_format='.1f',
)
ir_step = widgets.FloatText(value=1.0, description='Step', layout=Layout(width='150px'))

pv_slider = widgets.FloatRangeSlider(
    value=[2.0, 4.0],
    min=MAX_RANGES['plenum_vol'][0],
    max=MAX_RANGES['plenum_vol'][1],
    step=0.1,
    description='Plenum Vol (L)',
    style=style,
    layout=layout,
    readout_format='.1f',
)
pv_step = widgets.FloatText(value=0.5, description='Step', layout=Layout(width='150px'))

# --- UI LAYOUT ---
# One row per variable: range slider on the left, step box on the right.
ui = widgets.VBox([
    widgets.HBox([range_widget, step_widget])
    for range_widget, step_widget in (
        (sh_slider, sh_step),
        (ph_slider, ph_step),
        (ir_slider, ir_step),
        (pv_slider, pv_step),
    )
])

# Dedicated output area for the DOE run. Some front ends (e.g. JupyterLab) do not
# show text printed from widget callbacks unless an Output widget captures it. This
# matches how the ranking and single-prediction sections handle their output.
doe_output = widgets.Output()


def run_doe_solver(b):
    """
    Build the full-factorial DOE table from the UI and predict a torque curve for every case.

    The result is stored in the global ``torque_curves_data`` as an array of shape
    ``(len(RPM_GRID), STATIC_DIM + 2 + 4, n_cases)``. Along axis 1 the columns are:
    - the DOE parameters (one row of generate_factorial_table's output);
    - RPM;
    - predicted torque;
    - four zero-filled columns reserved for the ranking metrics computed later.

    All messages and the progress bar go to ``doe_output``, which is cleared on each run.

    Args:
        b: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    global torque_curves_data

    with doe_output:
        doe_output.clear_output(wait=True)

        # A non-positive (or NaN) step cannot describe a grid between the slider
        # bounds, so reject it before calling the table generator.
        range_step_pairs = ((sh_slider, sh_step), (ph_slider, ph_step), (ir_slider, ir_step), (pv_slider, pv_step))
        bad_steps = [slider.description for slider, step in range_step_pairs if not step.value > 0]
        if bad_steps:
            print(f"Error: step must be a positive number for: {', '.join(bad_steps)}")
            return

        # Prevent overlapping runs from repeated clicks; re-enabled in `finally`.
        run_button.disabled = True
        try:
            # --- Generate DOE table using imported function ---
            print("Generating DOE table...")
            # generate_factorial_table takes (lower, upper, step) for each of the four variables, in order.
            doe_table = np.asarray(generate_factorial_table(
                sh_slider.value[0], sh_slider.value[1], sh_step.value,
                ph_slider.value[0], ph_slider.value[1], ph_step.value,
                ir_slider.value[0], ir_slider.value[1], ir_step.value,
                pv_slider.value[0], pv_slider.value[1], pv_step.value
            ))

            n_cases = doe_table.shape[0]
            if n_cases == 0:
                print("Warning: DOE table is empty. Check your range and step values.")
                return
            print(f"Generated {n_cases} cases.")

            # --- Initialize result array ---
            SEQ_LEN = len(RPM_GRID)
            result_curves = np.zeros((SEQ_LEN, STATIC_DIM + 2, n_cases), dtype=np.float32)
            rpm_column = RPM_GRID.reshape(-1, 1)  # identical for every case; built once

            # --- Run inference loop ---
            with tqdm(total=n_cases, desc="Solving DOE cases") as pbar:
                for i, params in enumerate(doe_table):
                    stat_norm = (params - S_mean) / S_std
                    stat_tensor = torch.tensor(stat_norm, dtype=torch.float32, device=DEVICE).unsqueeze(0)

                    with torch.no_grad():
                        pred_norm = model(rpm_norm_base, stat_tensor).squeeze(0).cpu().numpy()

                    # Undo the output standardization.
                    pred_torque = pred_norm * Y_std + Y_mean

                    # One row per RPM point: the case's parameters repeated, then RPM, then torque.
                    result_curves[:, :, i] = np.hstack([
                        np.tile(params, (SEQ_LEN, 1)),
                        rpm_column,
                        pred_torque.reshape(-1, 1)
                    ])
                    pbar.update(1)

            # Pad axis 1 with four zero columns for the ranking metrics filled in later.
            torque_curves_data = np.pad(result_curves, pad_width=((0, 0), (0, 4), (0, 0)), mode='constant')

            print("\nDOE solving complete. Results are stored.")
            print(f"Shape of results array: {torque_curves_data.shape}")

        except Exception as e:
            print(f"An unexpected error occurred during the DOE run: {e}")
        finally:
            run_button.disabled = False

# --- BUTTON AND OUTPUT ---
run_button = widgets.Button(description="Run DOE Solver", button_style='success', icon='cogs')
run_button.on_click(run_doe_solver)

display(ui, run_button, doe_output)

VBox(children=(HBox(children=(FloatRangeSlider(value=(10.5, 12.5), description='Sec. Header Len (in)', layout=…

Button(button_style='success', description='Run DOE Solver', icon='cogs', style=ButtonStyle())

In [39]:
# --- WIDGETS FOR RANKING ---
# Shared description style and full-width layout for the ranking controls.
style = {'description_width': 'initial'}
layout = Layout(width='100%')

# Relative importance of each ranking metric. They are combined as a weighted sum;
# nothing rescales the weights, so they need not add up to 1.
max_tq_weight_slider = widgets.FloatSlider(
    value=0.05, min=0, max=1.0, step=0.05,
    description='Weight: Max Torque',
    style=style, layout=layout,
)
avg_tq_weight_slider = widgets.FloatSlider(
    value=0.70, min=0, max=1.0, step=0.05,
    description='Weight: Avg Torque (7-11k RPM)',
    style=style, layout=layout,
)
smoothness_weight_slider = widgets.FloatSlider(
    value=0.25, min=0, max=1.0, step=0.05,
    description='Weight: Smoothness',
    style=style, layout=layout,
)

# How many of the best-ranked curves to draw.
num_curves_slider = widgets.IntSlider(
    value=5, min=1, max=100, step=1,
    description='Top Curves to Plot:',
    style=style, layout=layout,
)

# The button that triggers ranking and plotting, and the output area that captures
# everything the ranking callback prints or plots.
rank_button = widgets.Button(
    description="Rank and Plot Results",
    button_style='primary',
    icon='sort',
)
ranking_output = widgets.Output()

def run_ranking_on_click(b):
    """
    Score every stored DOE case with the slider weights, sort, and plot the top curves.

    Steps:
    1. For each case, columns 6-8 of a copy of the zero-padded results array receive
       the max torque, the average torque between 7000 and 11000 RPM, and the
       smoothness metric from ``ranking_functions``.
    2. Each metric is divided by its maximum over all cases.
    3. Smoothness is inverted (``1 - value``) so a lower raw value scores higher.
    4. Column 9 receives the weighted sum of the three metrics.
    5. The array is ordered with ``quicksort_3d`` and the first
       ``num_curves_slider.value`` cases are plotted as Rank 1, 2, ...

    All text, progress bars and the figure go to ``ranking_output``, which is
    cleared before each run. The global ``torque_curves_data`` is not modified.

    Args:
        b: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    with ranking_output:
        # wait=True defers clearing until new output arrives, which avoids flicker.
        ranking_output.clear_output(wait=True)

        if torque_curves_data is None:
            print("Error: Please run the DOE solver in Part 1 first.")
            return

        print("Starting ranking and plotting process...")
        rank_button.disabled = True

        try:
            # --- Metric calculation ---
            # Work on a copy so every ranking starts from the raw DOE results.
            local_torque_curves = torque_curves_data.copy()
            n_cases = local_torque_curves.shape[2]
            with tqdm(total=n_cases, desc="Calculating ranking parameters") as pbar:
                for i in range(n_cases):
                    curve_slice = local_torque_curves[:, :, i]
                    temp = np.concatenate((get_max_torque(curve_slice), get_avg_torque(curve_slice, 7000, 11000), get_smoothness(curve_slice)), axis=1)
                    local_torque_curves[:, 6:9, i] = temp
                    pbar.update(1)

            # --- Normalization ---
            # Divide each metric by its maximum across all cases. If a metric's
            # maximum is zero, leave it unscaled: dividing by zero would turn every
            # score into NaN and make the sort meaningless.
            for col in (6, 7, 8):
                peak = np.max(local_torque_curves[:, col, :])
                if peak != 0:
                    local_torque_curves[:, col, :] /= peak
            # Invert smoothness so that a lower raw smoothness value scores higher.
            local_torque_curves[:, 8, :] = 1 - local_torque_curves[:, 8, :]

            # --- Weighted score ---
            # Metrics are read from row 0 and broadcast to every row of column 9
            # (presumably they are constant along the RPM axis; TODO confirm).
            w_max, w_avg, w_smooth = max_tq_weight_slider.value, avg_tq_weight_slider.value, smoothness_weight_slider.value
            local_torque_curves[:, 9, :] = (w_max * local_torque_curves[0, 6, :] + w_avg * local_torque_curves[0, 7, :] + w_smooth * local_torque_curves[0, 8, :])

            # --- Sorting ---
            # Progress estimate ~ n*log2(n). max(1, ...) keeps the bar non-empty for a single case.
            sort_total = n_cases * max(1, math.ceil(np.log2(n_cases)))
            with tqdm(total=sort_total, desc="Sorting curves") as pbar:
                sorted_curves = quicksort_3d(local_torque_curves, progress_bar=pbar)

            # --- Plotting Logic ---
            # Never claim more curves than actually exist.
            n_plotted = min(num_curves_slider.value, n_cases)
            plt.figure(figsize=(14, 8))
            labels = []
            for i in range(n_plotted):
                params, rank_score = sorted_curves[0, 0:4, i], sorted_curves[0, 9, i]
                label_text = f"Rank {i+1} (Score: {rank_score:.3f}): SH={params[0]:.2f}, H={params[1]:.1f}, R={params[2]:.2f}, PV={params[3]:.2f}"
                labels.append(label_text)
                plt.plot(sorted_curves[:, 4, i], sorted_curves[:, 5, i], alpha=0.9)
            plt.title(f"Top {n_plotted} Torque Curves (Ranked by Custom Weights)")
            plt.xlabel("RPM")
            plt.ylabel("Torque (Nm)")
            plt.legend(labels, bbox_to_anchor=(1.02, 1), loc='upper left', fontsize='small')
            plt.grid(True, which='both', linestyle='--')
            plt.tight_layout()
            # Rendered inside ranking_output because we are within its context.
            plt.show()
            print("Process complete.")

        except Exception as e:
            print(f"An unexpected error occurred: {e}")
        finally:
            rank_button.disabled = False

# --- ATTACH EVENT HANDLER AND DISPLAY UI ---
# Reset the button's registered click callbacks, then wire up the ranking handler.
rank_button._click_handlers.callbacks = []
rank_button.on_click(run_ranking_on_click)

# Controls stacked vertically: three weights, the curve count, then the button.
# The dedicated output area is displayed beneath them.
ranking_controls = widgets.VBox([
    max_tq_weight_slider,
    avg_tq_weight_slider,
    smoothness_weight_slider,
    num_curves_slider,
    rank_button,
])
display(ranking_controls, ranking_output)


VBox(children=(FloatSlider(value=0.05, description='Weight: Max Torque', layout=Layout(width='100%'), max=1.0,…

Output()

In [40]:
# --- WIDGETS FOR SINGLE PREDICTION ---
style = {'description_width': 'initial'}

# The four design parameters, entered in the same order as the MAX_RANGES keys.
sh_input = widgets.FloatText(
    value=12.0,
    description='Sec. Header (in):',
    style=style,
)
ph_input = widgets.FloatText(
    value=450.0,
    description='Header Len (mm):',
    style=style,
)
ir_input = widgets.FloatText(
    value=10.0,
    description='Runner Len (in):',
    style=style,
)
pv_input = widgets.FloatText(
    value=3.0,
    description='Plenum Vol (L):',
    style=style,
)
filename_input = widgets.Text(
    value='single_torque_curve.csv',
    description='Save Filename:',
    style=style,
)

# Saving starts disabled; it is enabled only after a successful prediction.
single_pred_button = widgets.Button(
    description="Predict Torque Curve",
    button_style='success',
    icon='calculator',
)
save_csv_button = widgets.Button(
    description="Save to CSV",
    button_style='info',
    icon='save',
    disabled=True,
)

# Output area shared by the predict and save callbacks, plus the most recent
# predicted curve (None until a prediction succeeds).
single_pred_output = widgets.Output()
latest_single_curve = None

def run_single_prediction(b):
    """
    Predict the torque curve for the four parameters typed into the input boxes.

    Each input is first checked against MAX_RANGES; every out-of-range value is
    reported. On success the curve is stored in the global ``latest_single_curve``
    (a DataFrame with ``RPM`` and ``Torque_Nm`` columns) and plotted, and the
    "Save to CSV" button is enabled. All output goes to ``single_pred_output``,
    which is cleared at the start of each run.

    Args:
        b: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    global latest_single_curve
    with single_pred_output:
        single_pred_output.clear_output(wait=True)
        print("Starting prediction...")
        single_pred_button.disabled = True
        save_csv_button.disabled = True
        # Drop the previous curve so an invalid or failed run cannot leave stale
        # data behind for save_to_csv.
        latest_single_curve = None

        try:
            params = np.array([sh_input.value, ph_input.value, ir_input.value, pv_input.value])

            # Pair each value with its MAX_RANGES entry positionally (dict insertion order).
            valid = True
            for value, (p_name, (min_val, max_val)) in zip(params, MAX_RANGES.items()):
                # Written as `not (...)` so NaN inputs also fail validation.
                if not (min_val <= value <= max_val):
                    print(f"VALIDATION FAILED: '{p_name}' ({value}) is out of range ({min_val} to {max_val}).")
                    valid = False
            if not valid:
                return

            stat_norm = (params - S_mean) / S_std
            stat_tensor = torch.tensor(stat_norm, dtype=torch.float32, device=DEVICE).unsqueeze(0)

            with torch.no_grad():
                pred_norm = model(rpm_norm_base, stat_tensor).squeeze(0).cpu().numpy()

            # Undo the output standardization and flatten to 1-D, as the DOE solver
            # does with reshape(-1, 1). The DataFrame then gets one torque value per
            # RPM point even if the model output keeps a singleton dimension.
            pred_torque = (pred_norm * Y_std + Y_mean).reshape(-1)
            latest_single_curve = pd.DataFrame({'RPM': RPM_GRID.astype(int), 'Torque_Nm': pred_torque})

            plt.figure(figsize=(10, 6))
            plt.plot(latest_single_curve['RPM'], latest_single_curve['Torque_Nm'], marker='o', linestyle='-')
            plt.title(f"Predicted Torque Curve for:\nSH={params[0]:.2f}, H={params[1]:.1f}, R={params[2]:.2f}, PV={params[3]:.2f}")
            plt.xlabel("RPM")
            plt.ylabel("Torque (Nm)")
            plt.grid(True)
            plt.show()  # Rendered inside single_pred_output because we are within its context.

            print("Prediction complete. You can now save the data to CSV.")
            save_csv_button.disabled = False

        except Exception as e:
            print(f"An unexpected error occurred during prediction: {e}")
        finally:
            single_pred_button.disabled = False

def save_to_csv(b):
    """
    Write ``latest_single_curve`` to the CSV file named in ``filename_input``.

    Surrounding whitespace is stripped from the name. ``.csv`` is appended unless
    the name already ends with it (case-insensitive), and an empty name is
    rejected. The file has no index column, and floating-point values are written
    with two decimals. Status messages are appended to ``single_pred_output``; it
    is not cleared, so the prediction plot stays visible.

    Args:
        b: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    with single_pred_output:
        print("\nAttempting to save data...")
        try:
            if latest_single_curve is None:
                print("No curve data to save. Please run a prediction first.")
                return

            filename = filename_input.value.strip()
            if not filename:
                # Otherwise this would silently create a hidden file named '.csv'.
                print("Please enter a filename before saving.")
                return
            # Case-insensitive, so 'curve.CSV' is not turned into 'curve.CSV.csv'.
            if not filename.lower().endswith('.csv'):
                filename += '.csv'

            latest_single_curve.to_csv(filename, index=False, float_format='%.2f')
            print(f"Successfully saved torque curve to '{os.path.abspath(filename)}'")
        except Exception as e:
            print(f"An unexpected error occurred during save: {e}")

# --- ATTACH EVENT HANDLERS ---
# For each button: reset its registered click callbacks, then attach its handler.
single_pred_button._click_handlers.callbacks = []
single_pred_button.on_click(run_single_prediction)

save_csv_button._click_handlers.callbacks = []
save_csv_button.on_click(save_to_csv)

# --- DISPLAY UI ---
# Left column: parameter inputs. Right column: predict button, filename box and
# save button. The shared output area is displayed underneath both.
input_box = widgets.VBox([sh_input, ph_input, ir_input, pv_input])
button_box = widgets.VBox([single_pred_button, filename_input, save_csv_button])
display(
    widgets.HBox([input_box, button_box]),
    single_pred_output,
)


HBox(children=(VBox(children=(FloatText(value=12.0, description='Sec. Header (in):', style=DescriptionStyle(de…

Output()