# **Overview**

This notebook provides a widget-driven environment for configuring, tuning, and simulating a nonlinear Model Predictive Control (MPC) system based on pre-trained Radial Basis Function Neural Network (RBFNN) models.

The workflow supports:
- loading the RBFNN plant model and the RBFNN prediction model used by the MPC from .mat files
- interactive initial steady-state selection from datasets or manual input
- MPC tuning via widgets for the weights (Q, R, Wmin, Wmax) and the control and prediction horizons
- case-study configuration including setpoints, constraints, and feature toggles
- closed-loop simulation with real-time progress feedback
- structured visualization of outputs, inputs, setpoints, and constraints
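
The tuning weights enter the controller objective. `MPC_function` itself is not part of this section, so the sketch below assumes a common soft-constrained quadratic form: `Q` weights the setpoint-tracking error, `R` weights the input moves, and `Wmin`/`Wmax` weight squared violations of the output lower/upper bounds. The function name `mpc_cost_sketch` and the exact form of the cost are illustrative assumptions, not the notebook's implementation.

```python
import numpy as np


def mpc_cost_sketch(y_pred, sp, du, q, r, wmin, wmax, y_lb, y_ub):
    """Illustrative soft-constrained MPC cost (assumed form, not MPC_function).

    y_pred : (ph, ny) predicted outputs over the prediction horizon
    sp     : (ny,) setpoints
    du     : (ch, nu) input moves over the control horizon
    q, wmin, wmax, y_lb, y_ub : (ny,) per-output weights and bounds
    r      : (nu,) per-input move weights
    """
    y_pred = np.atleast_2d(np.asarray(y_pred, dtype=float))
    du = np.atleast_2d(np.asarray(du, dtype=float))

    # Q: weighted squared tracking error, summed over horizon and outputs
    track = np.sum(np.asarray(q) * (y_pred - np.asarray(sp)) ** 2)
    # R: weighted squared input moves (MV suppression)
    moves = np.sum(np.asarray(r) * du ** 2)
    # Wmin/Wmax: squared bound violations, zero while y stays inside [y_lb, y_ub]
    below = np.maximum(0.0, np.asarray(y_lb) - y_pred)
    above = np.maximum(0.0, y_pred - np.asarray(y_ub))
    soft = np.sum(np.asarray(wmin) * below ** 2) + np.sum(np.asarray(wmax) * above ** 2)
    return float(track + moves + soft)


J = mpc_cost_sketch(
    y_pred=[[1.0], [1.5]], sp=[1.0], du=[[0.5]],
    q=[1.0], r=[0.1], wmin=[1e4], wmax=[1e4], y_lb=[0.0], y_ub=[1.2],
)
print(J)
```

In this toy case the 0.3 overshoot above `y_ub` dominates the cost (1e4 × 0.09 = 900), which is why large `Wmin`/`Wmax` defaults make the bounds behave almost like hard constraints.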

The notebook is designed to operate downstream of the RBFNN training pipeline, enabling rapid benchmarking of control performance under different operating scenarios without modifying controller internals.
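
Because the notebook consumes the training pipeline's .mat outputs, the plant file must contain the fields that `build_plant_function_from_mat` reads: `nu`, `ny`, `lpin`, `xmin`, `xmax`, `ymin`, `ymax`, `mv_ss`, `centers`, `weights` and `width`. The sketch below writes a toy file with those fields using `scipy.io.savemat`. The shapes are inferred from how the code uses them (`centers` is `n_centers × (nu·lpin)`, `weights` is `n_centers × ny`); the numeric values are random placeholders, not a trained model, and the file name `toy_plant.mat` is hypothetical.

```python
import os
import tempfile

import numpy as np
from scipy.io import loadmat, savemat

nu, ny, lpin, n_centers = 2, 1, 3, 4
rng = np.random.default_rng(0)

toy_plant = {
    "nu": nu, "ny": ny, "lpin": lpin,
    "xmin": np.zeros(nu), "xmax": np.ones(nu) * 10.0,    # input scaling range
    "ymin": np.zeros(ny), "ymax": np.ones(ny) * 100.0,   # output scaling range
    "mv_ss": np.array([5.0, 5.0]),                        # default steady-state MVs
    "centers": rng.uniform(-1.0, 1.0, size=(n_centers, nu * lpin)),
    "weights": rng.normal(size=(n_centers, ny)),
    "width": np.array(0.8),                               # scalar or one width per center
}

path = os.path.join(tempfile.mkdtemp(), "toy_plant.mat")
savemat(path, toy_plant)

# Same options as load_mat_file_plant in the notebook
plant = loadmat(path, squeeze_me=True, struct_as_record=False)
print({k: np.shape(plant[k]) for k in ("centers", "weights", "mv_ss", "width")})
```

With `squeeze_me=True`, singleton dimensions are dropped, so for `ny = 1` the `weights` array comes back 1-D and `width` comes back as a plain float.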

Implemented within **Greece4.0: Network of Excellence for Digital Transformation Technologies in the Greek Manufacturing Industry**.

**Code repository:**  
https://github.com/ntua-unit-of-control-and-informatics/Predictive-Control-System-Greece4.0/tree/main

## **Code**

# =============================================================================
# Predictive control notebook UI
# =============================================================================

## ----------------------------------------------------------------------------
##                                  IMPORT PACKAGES
## ----------------------------------------------------------------------------

!pip install casadi
import casadi as ca

import os, re, math
import time as tm
import builtins
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.io
from scipy.io import loadmat

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

try:
    from google.colab import files
except Exception:
    files = None  # Not running in Colab

# -----------------------------------------------------------------------------
# Style
# -----------------------------------------------------------------------------
display(HTML("""
<style>
.widget-button.btn-info {
    background-color: #006666 !important;
    border-color: #006666 !important;
    color: #ffffff !important;
    font-weight: 700 !important;
}
.widget-button.btn-info:hover,
.widget-button.btn-info:focus {
    background-color: #005555 !important;
    border-color: #005555 !important;
    color: #ffffff !important;
}
</style>
"""))


# -----------------------------------------------------------------------------
# Small HTML banner helper (used for sim + plot messages)
# -----------------------------------------------------------------------------
def banner_html(text: str, kind: str = "info") -> str:
    styles = {
        "ok":   ("#e8f5e9", "#2e7d32", "#c8e6c9"),
        "info": ("#e3f2fd", "#1565c0", "#bbdefb"),
        "warn": ("#fff8e1", "#8d6e00", "#ffecb3"),
        "err":  ("#ffebee", "#c62828", "#ffcdd2"),
    }
    bg, fg, br = styles.get(kind, styles["info"])
    safe = (str(text).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;"))
    return f"""
    <div style="
        background:{bg};
        color:{fg};
        border:1px solid {br};
        padding:6px 10px;
        border-radius:10px;
        margin:6px 0;
        max-width:720px;
        font-family:Arial, sans-serif;
        font-size:13px;
        line-height:1.2;">
      {safe}
    </div>
    """


# =============================================================================
# 1) Upload files (Colab)
# =============================================================================
def show_file_uploader():
    """
    Colab helper: uploads files into the runtime's current working directory and
    records their contents in the global dict `uploaded` (filename -> bytes).
    """
    if files is None:
        display(widgets.HTML("<b style='color:#b00'>File upload requires Google Colab.</b>"))
        return

    upload_btn = widgets.Button(
        description="Upload data/models",
        button_style="info",
        layout=widgets.Layout(width="260px")
    )
    upload_status_html = widgets.HTML("")
    upload_out = widgets.Output()

    def handle_upload_click(_):
        with upload_out:
            clear_output()
            upload_status_html.value = "<i>Select files to upload…</i>"
            new_files = files.upload()

            if not new_files:
                upload_status_html.value = "<b style='color:#A2142F'>No files uploaded.</b>"
                return

            if "uploaded" in globals():
                globals()["uploaded"].update(new_files)
            else:
                globals()["uploaded"] = new_files

            upload_status_html.value = f"<b style='color:#006666'>Imported {len(new_files)} file(s).</b>"
            print("Files available:")
            for k in globals()["uploaded"].keys():
                print(" •", k)

    upload_btn.on_click(handle_upload_click)

    display(widgets.VBox([upload_btn, upload_status_html]), upload_out)

# =============================================================================
# 2) Model loaders (MPC + Plant)
# =============================================================================

def load_mat_file_mpc(filepath: str):
    # IMPORTANT: do NOT squeeze -> keeps MATLAB-like shapes so MPC_function indexing works
    return loadmat(filepath, squeeze_me=False, struct_as_record=False)

def load_mat_file_plant(filepath: str):
    return loadmat(filepath, squeeze_me=True, struct_as_record=False)

def build_model_loader(
    *,
    title: str,
    folder: str = "/content",
    exts=(".mat",),
    loader_fn=load_mat_file_plant,
    allow_subfolders: bool = False,
    initial_filter: str = "",
    post_load_callback=None,
):
    out = widgets.Output()
    status = widgets.HTML("")
    loaded = {"path": None, "obj": None}

    folder_w = widgets.Text(
        value=folder,
        description="Folder:",
        layout=widgets.Layout(width="650px")
    )

    filter_w = widgets.Text(
        value=initial_filter,
        description="Filter:",
        placeholder="optional substring, e.g. 'mpc' or 'plant'",
        layout=widgets.Layout(width="400px")
    )

    file_dd = widgets.Dropdown(
        options=[],
        description="File:",
        layout=widgets.Layout(width="650px")
    )

    refresh_btn = widgets.Button(description="Refresh list")
    load_btn = widgets.Button(description=f"Load {title}", button_style="primary")

    def _list_files():
        base = Path(folder_w.value).expanduser()
        if not base.exists():
            return []
        paths = [p for p in (base.rglob("*") if allow_subfolders else base.glob("*")) if p.is_file()]

        ok = []
        fsub = (filter_w.value or "").strip().lower()
        extset = set(e.lower() for e in (exts or ()))
        for p in paths:
            if extset and p.suffix.lower() not in extset:
                continue
            if fsub and fsub not in p.name.lower():
                continue
            ok.append(str(p))
        return sorted(ok, key=str.lower)

    def _refresh(_=None):
        options = _list_files()
        file_dd.options = options
        if options:
            file_dd.value = options[0]
        with out:
            clear_output(wait=True)
            status.value = f"<span style='color:#555'>Found {len(options)} file(s).</span>"
            display(status)

    def _load(_=None):
        with out:
            clear_output(wait=True)
            try:
                if not file_dd.value:
                    raise RuntimeError("No file selected.")
                path = str(file_dd.value)
                obj = loader_fn(path)
                loaded["path"] = path
                loaded["obj"] = obj
                status.value = f"<span style='color:green'><b>Loaded:</b> {Path(path).name}</span>"
                display(status)

                if post_load_callback is not None:
                    post_load_callback(obj, path)

            except Exception as e:
                status.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"
                display(status)

    refresh_btn.on_click(_refresh)
    load_btn.on_click(_load)

    ui = widgets.VBox([
        widgets.HTML(f"<h3 style='margin:0'>{title}</h3>"),
        folder_w,
        widgets.HBox([filter_w, refresh_btn]),
        file_dd,
        load_btn,
        out
    ])

    def get_loaded():
        return loaded["obj"], loaded["path"]

    _refresh()
    return ui, get_loaded


# =============================================================================
# 3) Build plant_fn() from loaded plant .mat
# =============================================================================

if "RBF_predictions" not in globals():
    RBF_predictions = None

plant_model_path = None
plant_model_mat = None
plant_fn = None


def build_plant_function_from_mat(RBF_plant: dict):
    def _mat_1d(x):
        return np.array(x).reshape(-1)

    nu   = int(_mat_1d(RBF_plant["nu"])[0])
    ny   = int(_mat_1d(RBF_plant["ny"])[0])
    lpin = int(_mat_1d(RBF_plant["lpin"])[0])

    xmin = _mat_1d(RBF_plant["xmin"])
    xmax = _mat_1d(RBF_plant["xmax"])
    ymin = _mat_1d(RBF_plant["ymin"])
    ymax = _mat_1d(RBF_plant["ymax"])

    mv_ss_default = _mat_1d(RBF_plant["mv_ss"]).astype(float).tolist()
    mv_ss_default = (mv_ss_default + [0.0] * max(0, nu - len(mv_ss_default)))[:nu]

    centers_np = np.array(RBF_plant["centers"], dtype=float)
    weights_np = np.array(RBF_plant["weights"], dtype=float)

    centers = ca.DM(centers_np)
    weights = ca.DM(weights_np)

    width_np = np.array(RBF_plant["width"], dtype=float).reshape(-1)
    if width_np.size == 1:
        width = ca.DM(float(width_np[0]))
    else:
        width = ca.DM(width_np).reshape((width_np.size, 1))

    past_u = [None] * nu
    pred_fn = {"fn": None}

    def _normalize_u(u_val, i):
        u_val = float(u_val)
        return (2.0 * (u_val - float(xmin[i])) / (float(xmax[i]) - float(xmin[i])) - 1.0)

    def _build_predictor():
        x = ca.SX.sym("x", 1, int(centers.size2()))
        rep = ca.repmat(x, int(centers.size1()), 1)
        dist = ca.sqrt(ca.sum2((centers - rep) ** 2))
        phi  = ca.exp(-((dist ** 2) / (width ** 2)))
        pred = (weights.T @ phi + 1) * (ca.DM(ymax - ymin) / 2) + ca.DM(ymin)
        return ca.Function("RBFpredictions", [x], [pred])

    def plant_fn_inner(inputs, t, new_mvss=None):
        nonlocal past_u

        if new_mvss is None:
            mv_ss = list(mv_ss_default)
        else:
            mv_ss = np.array(new_mvss, dtype=float).reshape(-1).tolist()
            mv_ss = (mv_ss + [0.0] * max(0, nu - len(mv_ss)))[:nu]

        if t == 0:
            past_u = []
            for i in range(nu):
                u_norm = _normalize_u(mv_ss[i], i)
                past_u.append(ca.DM(np.ones((1, lpin)) * u_norm))

            if pred_fn["fn"] is None:
                pred_fn["fn"] = _build_predictor()

        u_now = np.array(inputs, dtype=float).reshape(-1).tolist()
        u_now = (u_now + [0.0] * max(0, nu - len(u_now)))[:nu]

        for i in range(nu):
            u_norm_now = _normalize_u(u_now[i], i)
            past_u[i] = ca.horzcat(ca.DM([[u_norm_now]]), past_u[i][:, :-1])

        x_in = ca.horzcat(*past_u)
        y = pred_fn["fn"](x_in)
        return np.array(y.full()).reshape(-1)

    return plant_fn_inner
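
# -----------------------------------------------------------------------------
# Illustration (not used by the UI): a NumPy re-statement of the CasADi
# predictor above, handy for spot-checking plant_fn outside CasADi. The helper
# names are hypothetical. It mirrors the same steps:
#   1) scale each input to [-1, 1] with xmin/xmax,
#   2) keep lpin lagged values per input, newest first, and concatenate the
#      per-input buffers (input 1 lags, then input 2 lags, ...) into one row,
#   3) Gaussian RBF layer: phi_j = exp(-||x - c_j||^2 / width_j^2),
#   4) linear output layer, then unscale from [-1, 1] with ymin/ymax.
# -----------------------------------------------------------------------------
import numpy as np


def shift_lag_buffer(buffer, u_norm_now):
    """Prepend the newest scaled input and drop the oldest (1-D buffer of length lpin)."""
    buffer = np.asarray(buffer, dtype=float)
    return np.concatenate(([float(u_norm_now)], buffer[:-1]))


def rbf_predict_numpy(x_row, centers, weights, width, ymin, ymax):
    """x_row: (n_features,), centers: (n_centers, n_features), weights: (n_centers, ny)."""
    x_row = np.asarray(x_row, dtype=float).reshape(1, -1)
    centers = np.asarray(centers, dtype=float)
    # Accept 1-D weights too (what squeeze_me=True yields when ny == 1)
    weights = np.asarray(weights, dtype=float).reshape(centers.shape[0], -1)
    width = np.asarray(width, dtype=float).reshape(-1)     # scalar or one per center
    d2 = np.sum((centers - x_row) ** 2, axis=1)             # squared distances, (n_centers,)
    phi = np.exp(-d2 / width ** 2)
    y_norm = weights.T @ phi                                 # (ny,)
    ymin = np.asarray(ymin, dtype=float).reshape(-1)
    ymax = np.asarray(ymax, dtype=float).reshape(-1)
    return (y_norm + 1.0) * (ymax - ymin) / 2.0 + ymin

# Note: when the same input is held for lpin steps, every lag holds the same
# scaled value, which is why plant_fn(u0, 0, u0) (buffers reset to u0, then u0
# shifted in) returns the model's steady-state output for u0.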


def handle_plant_loaded(plant_mat, plant_path):
    global plant_model_path, plant_model_mat, plant_fn

    plant_model_path = plant_path
    plant_model_mat = plant_mat

    plant_fn = build_plant_function_from_mat(plant_mat)

    try:
        mv_ss = np.array(plant_mat["mv_ss"]).squeeze()
        plant_fn(mv_ss, t=0, new_mvss=mv_ss)
        print("Plant built + initialized.")
    except Exception as e:
        print("Plant built but init skipped:", e)


mpc_loader_ui, get_loaded_mpc_model = build_model_loader(
    title="MPC model",
    folder="/content",
    exts=(".mat",),
    loader_fn=load_mat_file_mpc,
    initial_filter="mpc",
    allow_subfolders=False
)

plant_loader_ui, get_loaded_plant_model = build_model_loader(
    title="Plant model",
    folder="/content",
    exts=(".mat",),
    loader_fn=load_mat_file_plant,
    initial_filter="plant",
    allow_subfolders=False,
    post_load_callback=handle_plant_loaded
)


# =============================================================================
# 4) Steady-state UI
# =============================================================================

num_inputs = None
num_outputs = None
time_units = None
input_labels = None
output_labels = None
input_units = None
output_units = None

u0_steady_state = None
y0_steady_state = None

input_columns = []
output_columns = []

_ss_df = None
_ss_selected_row_index = None
_ss_selected_inputs_list = None


def ss_make_table(rows, headers, title, font_size=40):
    if "create_html_table" in globals() and callable(globals()["create_html_table"]):
        return globals()["create_html_table"](rows, headers, title, font_size)

    head = "".join(f"<th style='padding:6px;border:1px solid #ddd'>{h}</th>" for h in headers)
    body = "".join(
        "<tr>" + "".join(f"<td style='padding:6px;border:1px solid #ddd'>{c}</td>" for c in r) + "</tr>"
        for r in rows
    )
    return (
        f"<h4 style='margin:8px 0'>{title}</h4>"
        f"<table style='border-collapse:collapse;width:100%'><tr>{head}</tr>{body}</table>"
    )


def list_dataset_files(folder: str):
    try:
        files_ = os.listdir(folder)
    except Exception:
        return []
    return sorted([f for f in files_ if f.lower().endswith((".csv", ".xlsx", ".xls"))], key=str.lower)


def read_dataset_file(path: str) -> pd.DataFrame:
    if path.lower().endswith(".csv"):
        return pd.read_csv(path)
    if path.lower().endswith((".xlsx", ".xls")):
        return pd.read_excel(path)
    raise ValueError("Unsupported file type")


def clean_dataset_columns(df: pd.DataFrame) -> pd.DataFrame:
    df = df.dropna(axis=1, how="all")
    bad = [c for c in df.columns if str(c).strip().lower().startswith("unnamed")]
    if bad:
        df = df.drop(columns=bad)
    return df


def find_columns_containing(df: pd.DataFrame, key: str):
    key = key.lower()
    return [c for c in df.columns if key in str(c).lower()]


def parse_label_and_units(col_name: str):
    s = str(col_name).strip()
    m_units = re.search(r"\(([^()]*)\)\s*$", s)
    units = m_units.group(1).strip() if m_units else "-"
    if m_units:
        s = s[:m_units.start()].strip()
    s = re.sub(r"^(Input|Output)\s*\d+\s*:\s*", "", s, flags=re.IGNORECASE).strip()
    return s, units


def detect_time_units(df: pd.DataFrame) -> str:
    for c in df.columns:
        if "time" in str(c).lower():
            m = re.search(r"\(([^()]*)\)", str(c))
            if m:
                return m.group(1).strip()
    return "samples"


def build_io_metadata(df: pd.DataFrame):
    global num_inputs, num_outputs, input_columns, output_columns
    global input_labels, output_labels, input_units, output_units, time_units

    time_units = detect_time_units(df)

    input_columns = find_columns_containing(df, "input")
    output_columns = find_columns_containing(df, "output")

    num_inputs = len(input_columns)
    num_outputs = len(output_columns)

    input_labels, input_units = [], []
    for c in input_columns:
        lab, unit = parse_label_and_units(c)
        input_labels.append(lab if lab else str(c))
        input_units.append(unit)

    output_labels, output_units = [], []
    for c in output_columns:
        lab, unit = parse_label_and_units(c)
        output_labels.append(lab if lab else str(c))
        output_units.append(unit)


def scrollable_table_html(df: pd.DataFrame, height_px: int = 260) -> str:
    view = df.copy()
    for c in view.columns:
        if pd.api.types.is_numeric_dtype(view[c]):
            view[c] = view[c].map(lambda x: f"{x:.6g}" if pd.notnull(x) else "")
    return f"""
    <div style="max-height:{height_px}px; overflow:auto; border:1px solid #ddd; padding:6px;">
      {view.to_html(index=True)}
    </div>
    """


SS_TITLE_HTML = widgets.HTML("<h3 style='margin:0'>Initial steady state</h3>")

SS_FOLDER_TEXT = widgets.Text(
    value="/content",
    description="Data folder:",
    layout=widgets.Layout(width="650px")
)

SS_REFRESH_BUTTON = widgets.Button(description="Refresh list")
SS_FILE_DD = widgets.Dropdown(options=[], description="Dataset:", layout=widgets.Layout(width="650px"))
SS_LOAD_BUTTON = widgets.Button(description="Load dataset", button_style="primary")

SS_ROW_SLIDER = widgets.IntSlider(value=0, min=0, max=0, step=1, description="Row:", continuous_update=False)
SS_USE_ROW_BUTTON = widgets.Button(description="Use this row", button_style="info")

SS_ROW_OUT = widgets.Output()
SS_TABLE_OUT = widgets.Output()

SS_DEFINE_INPUTS_TEXT = widgets.Text(
    value="",
    description="Inputs (separated by commas):",
    layout=widgets.Layout(width="720px")
)
SS_DEFINE_INPUTS_TEXT.style.description_width = "260px"

SS_COMPUTE_TYPED_BUTTON = widgets.Button(
    description="Compute from typed inputs",
    button_style="warning",
    layout=widgets.Layout(width="260px")
)

SS_STATUS_HTML = widgets.HTML("")
SS_STEADY_TABLE_HTML = widgets.HTML("")


def ss_refresh_files(_=None):
    folder = SS_FOLDER_TEXT.value.strip() or "/content"
    files_ = list_dataset_files(folder)
    SS_FILE_DD.options = [os.path.join(folder, f) for f in files_]
    if SS_FILE_DD.options:
        SS_FILE_DD.value = SS_FILE_DD.options[0]
    SS_STATUS_HTML.value = f"<span style='color:#555'>Found {len(files_)} dataset file(s).</span>"


def ss_on_load_clicked(_=None):
    global _ss_df
    try:
        path = SS_FILE_DD.value
        if not path:
            raise RuntimeError("No dataset selected.")

        df = read_dataset_file(path)
        df = clean_dataset_columns(df)

        _ss_df = df
        build_io_metadata(df)

        SS_ROW_SLIDER.max = max(0, len(df) - 1)
        SS_ROW_SLIDER.value = 0

        with SS_TABLE_OUT:
            clear_output(wait=True)
            display(HTML("<b>Dataset preview</b>"))
            display(HTML(scrollable_table_html(df, height_px=260)))

        SS_STATUS_HTML.value = "<span style='color:green'><b>Dataset loaded.</b></span>"
        ss_on_row_change({"name": "value"})

    except Exception as e:
        SS_STATUS_HTML.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"


def ss_compute_and_show_steady_state(u_list):
    global u0_steady_state, y0_steady_state

    if plant_fn is None or not callable(plant_fn):
        raise RuntimeError("Plant not loaded. Load the plant .mat first so plant_fn exists.")

    u0 = [float(v) for v in np.array(u_list, dtype=float).reshape(-1).tolist()]
    y0 = plant_fn(u0, 0, u0)

    u0_steady_state = u0
    y0_steady_state = np.array(y0).reshape(-1).astype(float).tolist()

    rows = []

    for i in range(len(u0_steady_state)):
        lab = input_labels[i] if input_labels and i < len(input_labels) else f"Input {i+1}"
        unit = input_units[i] if input_units and i < len(input_units) else "-"
        rows.append([lab, f"{u0_steady_state[i]:.4f}", unit])

    for i in range(len(y0_steady_state)):
        lab = output_labels[i] if output_labels and i < len(output_labels) else f"Output {i+1}"
        unit = output_units[i] if output_units and i < len(output_units) else "-"
        rows.append([lab, f"{y0_steady_state[i]:.4f}", unit])

    SS_STEADY_TABLE_HTML.value = ss_make_table(
        rows,
        ["Signal", "Value", "Units"],
        "Initial steady state"
    )


def ss_on_row_change(change):
    if change.get("name") != "value":
        return

    with SS_ROW_OUT:
        clear_output(wait=True)
        if _ss_df is None:
            display(HTML("Load a dataset first."))
            return

        r = int(SS_ROW_SLIDER.value)
        display(HTML(f"<b>Selected row:</b> {r}"))

        show_cols = []
        if input_columns:
            show_cols += input_columns
        if output_columns:
            show_cols += output_columns

        if show_cols:
            display(_ss_df.loc[[r], show_cols])
        else:
            display(_ss_df.loc[[r]])


def ss_on_use_row_clicked(_=None):
    global _ss_selected_row_index, _ss_selected_inputs_list
    try:
        if _ss_df is None:
            raise RuntimeError("Load a dataset first.")
        if not input_columns:
            raise RuntimeError("No input columns detected in dataset.")

        r = int(SS_ROW_SLIDER.value)
        vals = _ss_df.loc[r, input_columns]
        u_list = [float(x) for x in np.array(vals, dtype=float).tolist()]

        _ss_selected_row_index = r
        _ss_selected_inputs_list = u_list

        SS_DEFINE_INPUTS_TEXT.value = ", ".join([f"{v:.2f}" for v in u_list])

        ss_compute_and_show_steady_state(u_list)
        SS_STATUS_HTML.value = f"<span style='color:green'><b>Row {r} selected.</b> Steady state computed.</span>"

    except Exception as e:
        SS_STATUS_HTML.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"


def ss_on_compute_typed_clicked(_=None):
    try:
        if num_inputs is None:
            raise RuntimeError("Load a dataset first (to infer num_inputs).")

        vals = [v.strip() for v in SS_DEFINE_INPUTS_TEXT.value.split(",") if v.strip()]
        u_list = [float(v) for v in vals]

        if len(u_list) != int(num_inputs):
            raise RuntimeError(f"Typed inputs length ({len(u_list)}) must equal num_inputs ({num_inputs}).")

        SS_DEFINE_INPUTS_TEXT.value = ", ".join([f"{v:.2f}" for v in u_list])

        ss_compute_and_show_steady_state(u_list)
        SS_STATUS_HTML.value = "<span style='color:green'><b>Typed inputs used.</b> Steady state computed.</span>"

    except Exception as e:
        SS_STATUS_HTML.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"


SS_REFRESH_BUTTON.on_click(ss_refresh_files)
SS_LOAD_BUTTON.on_click(ss_on_load_clicked)
SS_ROW_SLIDER.observe(ss_on_row_change, names="value")
SS_USE_ROW_BUTTON.on_click(ss_on_use_row_clicked)
SS_COMPUTE_TYPED_BUTTON.on_click(ss_on_compute_typed_clicked)

ss_refresh_files()

steady_state_ui = widgets.VBox([
    SS_TITLE_HTML,
    widgets.HBox([SS_FOLDER_TEXT, SS_REFRESH_BUTTON]),
    widgets.HBox([SS_FILE_DD, SS_LOAD_BUTTON]),
    widgets.HTML("<b>Option A (from dataset):</b> pick a row → “Use this row” computes steady state"),
    widgets.HBox([SS_ROW_SLIDER, SS_USE_ROW_BUTTON]),
    SS_ROW_OUT,
    SS_TABLE_OUT,
    widgets.HTML("<hr style='margin:10px 0'>"),
    widgets.HTML("<b>Option B (manual):</b> type inputs → compute steady state"),
    widgets.HBox([SS_DEFINE_INPUTS_TEXT, SS_COMPUTE_TYPED_BUTTON], layout=widgets.Layout(width="100%")),
    SS_STATUS_HTML,
    SS_STEADY_TABLE_HTML
])

# =============================================================================
# 5) MPC tuning UI
# =============================================================================

q = None
r = None
Wmin = None
Wmax = None
ch = None
ph = None


tuning_out = widgets.Output()
tuning_status = widgets.HTML("")

create_tuning_sliders_btn = widgets.Button(
    description="Create tuning sliders",
    button_style="info"
)

apply_tuning_btn = widgets.Button(
    description="Tune MPC",
    button_style="warning"
)


control_horizon_input = widgets.BoundedIntText(
    value=10, min=1, max=500,
    description="Control horizon (ch):",
    layout=widgets.Layout(width="280px")
)
control_horizon_input.style.description_width = "initial"

prediction_horizon_input = widgets.BoundedIntText(
    value=20, min=1, max=1000,
    description="Prediction horizon (ph):",
    layout=widgets.Layout(width="280px")
)
prediction_horizon_input.style.description_width = "initial"


q_sliders = []
r_sliders = []
wmin_sliders = []
wmax_sliders = []

q_box = widgets.VBox([])
r_box = widgets.VBox([])
wmin_box = widgets.VBox([])
wmax_box = widgets.VBox([])


TUNING_DEFAULTS = {
    "Q":    {"min": 0.0, "max": 100.0, "step": 0.1, "value": 1.0},
    "R":    {"min": 0.0, "max": 100.0, "step": 0.1, "value": 0.1},
    "Wmin": {"min": 0.0, "max": 1e6,   "step": 100.0, "value": 1e4},
    "Wmax": {"min": 0.0, "max": 1e6,   "step": 100.0, "value": 1e4},
}


# Tuning-row layout. Kept under separate names because section 6 later
# redefines LABEL_W/VAL_W, which would otherwise silently change the widths of
# these rows when the sliders are built.
TUNING_LABEL_W = "480px"
TUNING_VAL_W = "120px"
TUNING_SLIDER_W = "300px"


def make_tuning_row(label_text, *, min, max, step, value):
    label = widgets.HTML(
        f"""
        <div style="
            display:flex;
            align-items:center;
            height:32px;
            margin-top:6px;
            white-space:normal;
            line-height:1.2;">
            {label_text}
        </div>
        """,
        layout=widgets.Layout(width=TUNING_LABEL_W)
    )

    val = widgets.FloatText(
        value=float(value),
        layout=widgets.Layout(width=TUNING_VAL_W, margin="0 10px 0 10px"),
        style={"description_width": "0px"}
    )

    sld = widgets.FloatSlider(
        min=float(min),
        max=float(max),
        step=float(step),
        value=float(value),
        continuous_update=False,
        readout=True,
        layout=widgets.Layout(width=TUNING_SLIDER_W)
    )

    widgets.link((val, "value"), (sld, "value"))

    return sld, val, widgets.HBox([label, val, sld], layout=widgets.Layout(align_items="center"))


def build_tuning_sliders():
    global q_sliders, r_sliders, wmin_sliders, wmax_sliders

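    if num_inputs is None or num_outputs is None:
        raise RuntimeError("Load a dataset first (to infer num_inputs and num_outputs).")

    nu = int(num_inputs)
    ny = int(num_outputs)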

    q_rows = []
    q_sliders = []
    for i in range(ny):
        sld, _, row = make_tuning_row(
            f"Q component for {output_name(i)}:",
            **TUNING_DEFAULTS["Q"]
        )
        q_sliders.append(sld)
        q_rows.append(row)

    r_rows = []
    r_sliders = []
    for i in range(nu):
        sld, _, row = make_tuning_row(
            f"R component for {input_name(i)}:",
            **TUNING_DEFAULTS["R"]
        )
        r_sliders.append(sld)
        r_rows.append(row)

    wmin_rows = []
    wmin_sliders = []
    for i in range(ny):
        sld, _, row = make_tuning_row(
            f"Wmin penalty for {output_name(i)}:",
            **TUNING_DEFAULTS["Wmin"]
        )
        wmin_sliders.append(sld)
        wmin_rows.append(row)

    wmax_rows = []
    wmax_sliders = []
    for i in range(ny):
        sld, _, row = make_tuning_row(
            f"Wmax penalty for {output_name(i)}:",
            **TUNING_DEFAULTS["Wmax"]
        )
        wmax_sliders.append(sld)
        wmax_rows.append(row)

    q_box.children = q_rows
    r_box.children = r_rows
    wmin_box.children = wmin_rows
    wmax_box.children = wmax_rows


def on_create_tuning_sliders(_=None):
    with tuning_out:
        clear_output(wait=True)
        try:
            build_tuning_sliders()
            tuning_status.value = (
                f"<span style='color:green'><b>Sliders created.</b> "
                f"num_inputs={num_inputs}, num_outputs={num_outputs}</span>"
            )
        except Exception as e:
            tuning_status.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"


def on_apply_tuning(_=None):
    global q, r, Wmin, Wmax, ch, ph

    with tuning_out:
        clear_output(wait=True)
        try:
            if not q_sliders:
                raise RuntimeError("Create sliders first.")

            q = [s.value for s in q_sliders]
            r = [s.value for s in r_sliders]
            Wmin = [s.value for s in wmin_sliders]
            Wmax = [s.value for s in wmax_sliders]
            ch = int(control_horizon_input.value)
            ph = int(prediction_horizon_input.value)

            if ch > ph:
                raise RuntimeError("Require ch ≤ ph.")

            tuning_status.value = (
                "<span style='color:green'><b>MPC tuned.</b><br>"
                "Globals set: q, r, Wmin, Wmax, ch, ph</span>"
            )

        except Exception as e:
            tuning_status.value = f"<span style='color:#b00'><b>Error:</b> {e}</span>"


create_tuning_sliders_btn.on_click(on_create_tuning_sliders)
apply_tuning_btn.on_click(on_apply_tuning)


mpc_tuning_ui = widgets.VBox([
    widgets.HTML("<h3>MPC tuning</h3>"),

    widgets.HBox([create_tuning_sliders_btn, apply_tuning_btn]),
    tuning_status,

    widgets.HBox([control_horizon_input, prediction_horizon_input]),

    widgets.HTML("<hr>"),
    widgets.HTML("<b>Q – Setpoint tracking</b>"),
    q_box,

    widgets.HTML("<hr>"),
    widgets.HTML("<b>R – MV suppression</b>"),
    r_box,

    widgets.HTML("<hr>"),
    widgets.HTML("<b>Wmin – Lower-bound penalties</b>"),
    wmin_box,

    widgets.HTML("<hr>"),
    widgets.HTML("<b>Wmax – Upper-bound penalties</b>"),
    wmax_box,

    tuning_out
])
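
# -----------------------------------------------------------------------------
# Illustration (not used by the UI): one common reading of ch <= ph, assumed
# here because MPC_function is not shown in this section. The optimizer picks
# ch input values; after the last one, the input is held constant until the
# end of the prediction horizon. The helper name is hypothetical.
# -----------------------------------------------------------------------------
import numpy as np


def expand_control_moves(u_moves, ph):
    """Expand a (ch, nu) input sequence to (ph, nu) by repeating the last row.

    For a single input, pass a column such as [[u0], [u1], ...].
    """
    u_moves = np.atleast_2d(np.asarray(u_moves, dtype=float))
    ch = u_moves.shape[0]
    if not 1 <= ch <= ph:
        raise ValueError("Require 1 <= ch <= ph.")
    tail = np.repeat(u_moves[-1:], ph - ch, axis=0)  # (ph - ch, nu), may be empty
    return np.vstack([u_moves, tail])

# Usage: expand_control_moves([[1.0], [2.0]], 4) holds the second move for the
# remaining two steps, giving the column [1.0, 2.0, 2.0, 2.0].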


# =============================================================================
# 6) Case study UI
# =============================================================================

# Outputs of this UI (used by simulation)
case_params = None
sample_time = None
n_steps = None

# -----------------------------------------------------------------------------
# Core widgets
# -----------------------------------------------------------------------------

case_out = widgets.Output()
case_status = widgets.HTML("")

create_case_widgets_btn = widgets.Button(
    description="Create case study widgets",
    button_style="info",
    layout=widgets.Layout(width="260px")
)

submit_case_values_btn = widgets.Button(
    description="Submit values",
    button_style="warning",
    layout=widgets.Layout(width="180px")
)

# Created only after clicking "Create"
case_total_time = None
case_ts = None
case_sp_change_time = None

# Containers
case_sp_box = widgets.VBox([])
case_bounds_box = widgets.VBox([])
case_mv_bounds_box = widgets.VBox([])
case_toggles_box = widgets.VBox([])

# Internal widget lists
case_sp_pct, case_sp_val = [], []
case_lb_pct, case_lb_val = [], []
case_ub_pct, case_ub_val = [], []
case_umin_pct, case_umin_val = [], []
case_umax_pct, case_umax_val = [], []

case_toggle_sp, case_toggle_lb, case_toggle_ub = [], [], []

# Cached steady-state values
case_u_ss = None
case_y_ss = None

# Guard against recursive updates
case_sync_guard = {"busy": False}

# Layout constants
LABEL_W = "450px"
VAL_W = "200px"
PCT_W = "300px"

# -----------------------------------------------------------------------------
# Naming helpers
# -----------------------------------------------------------------------------

def output_name(i: int) -> str:
    if output_columns and i < len(output_columns):
        return str(output_columns[i])
    return f"Output {i+1}"

def input_name(i: int) -> str:
    if input_columns and i < len(input_columns):
        return str(input_columns[i])
    return f"Input {i+1}"

# -----------------------------------------------------------------------------
# Percent <-> value mapping
# -----------------------------------------------------------------------------

def _scale(base: float) -> float:
    base = float(base)
    return abs(base) if abs(base) > 1e-12 else 1.0

def value_from_pct(base: float, pct: float) -> float:
    s = _scale(base)
    return base * (1.0 + pct / 100.0) if abs(base) > 1e-12 else base + (pct / 100.0) * s

def pct_from_value(base: float, value: float) -> float:
    s = _scale(base)
    return 100.0 * (value / base - 1.0) if abs(base) > 1e-12 else 100.0 * ((value - base) / s)

def fmt4(x: float) -> float:
    return float(f"{float(x):.4f}")
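
# Worked example of the Δ% mapping above (illustrative numbers):
#   base = 50.0,  pct = +10    -> value_from_pct = 50 * 1.10         ≈ 55.0
#   base = 50.0,  value = 40.0 -> pct_from_value = 100 * (40/50 - 1) ≈ -20.0
#   base = 0.0 (|base| <= 1e-12): the scale falls back to 1.0, so Δ% acts as an
#   absolute offset: pct = +10 -> value = 0 + 0.10 * 1.0 = 0.1
#   base = -50.0, pct = +10    -> value = -50 * 1.10 ≈ -55.0, i.e. a positive Δ%
#   moves a negative base further from zero, because the mapping scales base.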

# -----------------------------------------------------------------------------
# Row factory: label + value + % slider
# -----------------------------------------------------------------------------

def make_pct_value_row(*, label_html: str, base_value: float, default_pct: float):
    label = widgets.HTML(
        f"<div style='white-space:normal; line-height:1.2'>{label_html}</div>",
        layout=widgets.Layout(width=LABEL_W)
    )

    val = widgets.FloatText(
        value=fmt4(value_from_pct(base_value, default_pct)),
        description="Value:",
        layout=widgets.Layout(width=VAL_W, padding="0 0 0 15px"),
        style={"description_width": "initial"}
    )

    pct = widgets.FloatSlider(
        min=-200.0, max=200.0, step=1.0,
        value=default_pct,
        description="Δ%",
        continuous_update=False,
        layout=widgets.Layout(width=PCT_W),
        style={"description_width": "26px"}
    )

    def on_pct_change(change):
        # Slider moved: update the value box (the guard stops the reverse handler from firing)
        if case_sync_guard["busy"]:
            return
        case_sync_guard["busy"] = True
        try:
            val.value = fmt4(value_from_pct(base_value, pct.value))
        finally:
            case_sync_guard["busy"] = False

    def on_val_change(change):
        # Value typed: round it to 4 decimals and move the slider (clipped to ±200 %)
        if case_sync_guard["busy"]:
            return
        case_sync_guard["busy"] = True
        try:
            val.value = fmt4(val.value)
            pct.value = max(-200.0, min(200.0, pct_from_value(base_value, val.value)))
        finally:
            case_sync_guard["busy"] = False

    pct.observe(on_pct_change, names="value")
    val.observe(on_val_change, names="value")

    return pct, val, widgets.HBox([label, val, pct], layout=widgets.Layout(align_items="center"))
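
# -----------------------------------------------------------------------------
# Sketch (illustrative only, not used by the UI): the re-entrancy guard pattern
# from make_pct_value_row, stripped of ipywidgets so it can run on its own.
# Two values are kept in sync; each handler returns early while the other one
# is writing, which breaks the slider -> text -> slider feedback loop, and the
# flag is reset in a `finally` block so one failing update cannot leave the
# guard stuck. SketchObservable and sketch_link_pct_value are hypothetical
# names; the Δ% mapping here assumes a nonzero base scaled by |base|.
# -----------------------------------------------------------------------------

class SketchObservable:
    """Minimal stand-in for a widget trait that notifies observers on change."""

    def __init__(self, value):
        self._value = value
        self._observers = []

    def observe(self, fn):
        self._observers.append(fn)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        if new == self._value:
            return  # no notification when the value does not change
        self._value = new
        for fn in self._observers:
            fn(new)

def sketch_link_pct_value(pct, val, base):
    guard = {"busy": False}

    def on_pct(new_pct):
        if guard["busy"]:
            return
        guard["busy"] = True
        try:
            val.value = base + (new_pct / 100.0) * abs(base)
        finally:
            guard["busy"] = False

    def on_val(new_val):
        if guard["busy"]:
            return
        guard["busy"] = True
        try:
            pct.value = 100.0 * (new_val - base) / abs(base)
        finally:
            guard["busy"] = False

    pct.observe(on_pct)
    val.observe(on_val)

# Usage: pct, val = SketchObservable(0.0), SketchObservable(10.0)
#        sketch_link_pct_value(pct, val, base=10.0); pct.value = 50.0  # val.value becomes 15.0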

def make_toggle(label: str):
    t = widgets.ToggleButtons(
        options=[("✓ Enabled", 1), ("✗ Disabled", 0)],
        value=0,
        description=label,
        layout=widgets.Layout(width="900px")
    )
    t.style.description_width = "initial"
    return t

# -----------------------------------------------------------------------------
# Build widgets from steady-state
# -----------------------------------------------------------------------------

def build_case_widgets():
    global case_u_ss, case_y_ss
    global case_sp_pct, case_sp_val, case_lb_pct, case_lb_val
    global case_ub_pct, case_ub_val, case_umin_pct, case_umin_val
    global case_umax_pct, case_umax_val
    global case_toggle_sp, case_toggle_lb, case_toggle_ub

    if u0_steady_state is None or y0_steady_state is None:
        raise RuntimeError("Compute steady state first.")

    nu = int(num_inputs)
    ny = int(num_outputs)

    case_u_ss = list(u0_steady_state)
    case_y_ss = list(y0_steady_state[:ny])

    # Setpoints
    case_sp_pct, case_sp_val = [], []
    case_sp_box.children = []
    for i in range(ny):
        pct, val, row = make_pct_value_row(
            label_html=f"<b>Set-point for</b> {output_name(i)}",
            base_value=case_y_ss[i],
            default_pct=0.0
        )
        case_sp_pct.append(pct)
        case_sp_val.append(val)
        case_sp_box.children += (row,)

    # Output bounds
    case_lb_pct, case_lb_val, case_ub_pct, case_ub_val = [], [], [], []
    case_bounds_box.children = []
    for i in range(ny):
        lb = make_pct_value_row(
            label_html=f"<b>Lower Bound for</b> {output_name(i)}",
            base_value=case_y_ss[i],
            default_pct=-10.0
        )
        ub = make_pct_value_row(
            label_html=f"<b>Upper Bound for</b> {output_name(i)}",
            base_value=case_y_ss[i],
            default_pct=10.0
        )
        case_lb_pct.append(lb[0]); case_lb_val.append(lb[1])
        case_ub_pct.append(ub[0]); case_ub_val.append(ub[1])
        case_bounds_box.children += (widgets.VBox([lb[2], ub[2]]),)

    # MV bounds
    case_umin_pct, case_umin_val, case_umax_pct, case_umax_val = [], [], [], []
    case_mv_bounds_box.children = []
    for i in range(nu):
        umin = make_pct_value_row(
            label_html=f"<b>Umin for</b> {input_name(i)}",
            base_value=case_u_ss[i],
            default_pct=-50.0
        )
        umax = make_pct_value_row(
            label_html=f"<b>Umax for</b> {input_name(i)}",
            base_value=case_u_ss[i],
            default_pct=50.0
        )
        case_umin_pct.append(umin[0]); case_umin_val.append(umin[1])
        case_umax_pct.append(umax[0]); case_umax_val.append(umax[1])
        case_mv_bounds_box.children += (widgets.VBox([umin[2], umax[2]]),)

    # Toggles
    case_toggle_sp, case_toggle_lb, case_toggle_ub = [], [], []
    case_toggles_box.children = []
    for i in range(ny):
        sp = make_toggle(f"Set-point tracking for {output_name(i)}:")
        lb = make_toggle(f"Lower bounds for {output_name(i)}:")
        ub = make_toggle(f"Upper bounds for {output_name(i)}:")
        case_toggle_sp.append(sp)
        case_toggle_lb.append(lb)
        case_toggle_ub.append(ub)
        case_toggles_box.children += (widgets.VBox([sp, lb, ub, widgets.HTML("<hr>")]),)

# -----------------------------------------------------------------------------
# Click handlers
# -----------------------------------------------------------------------------

def on_create_case_widgets_clicked(_=None):
    global case_total_time, case_ts, case_sp_change_time

    with case_out:
        clear_output(wait=True)

        if not time_units:
            raise RuntimeError("Time units not defined.")

        case_total_time = widgets.FloatText(
            description=f"Simulation Time ({time_units}):",
            value=60.0,
            layout=widgets.Layout(width="420px")
        )
        case_total_time.style.description_width = "initial"

        case_ts = widgets.FloatText(
            description=f"Sampling Time ({time_units}):",
            value=1.0,
            layout=widgets.Layout(width="420px")
        )
        case_ts.style.description_width = "initial"

        case_sp_change_time = widgets.FloatText(
            description=f"Time to Apply SP ({time_units}):",
            value=0.0,
            layout=widgets.Layout(width="420px")
        )
        case_sp_change_time.style.description_width = "initial"


        build_case_widgets()

        display(widgets.VBox([
            widgets.HTML("<h3>Case Study Configuration</h3>"),
            widgets.HTML("<hr>"),
            widgets.HBox([case_total_time, case_ts]),
            case_sp_change_time,
            widgets.HTML("<hr><h4>Setpoints</h4>"),
            case_sp_box,
            widgets.HTML("<hr><h4>Output bounds</h4>"),
            case_bounds_box,
            widgets.HTML("<hr><h4>MV bounds</h4>"),
            case_mv_bounds_box,
            widgets.HTML("<hr><h4>Feature toggles</h4>"),
            case_toggles_box,
        ]))

        case_status.value = "<b>Case study widgets created.</b>"

def on_submit_case_values_clicked(_=None):
    global case_params, sample_time, n_steps

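    if case_ts is None:
        # Errors raised inside button callbacks may not reach the notebook output,
        # so report problems in the status area instead.
        case_status.value = "<span style='color:red'><b>Create case study widgets first.</b></span>"
        return

    ts = float(case_ts.value)
    if ts <= 0:
        case_status.value = "<span style='color:red'><b>Sampling time must be positive.</b></span>"
        return

    sample_time = ts
    n_steps = int(np.floor(float(case_total_time.value) / sample_time))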

    case_params = {
        "SP": [v.value for v in case_sp_val],
        "LB": [v.value for v in case_lb_val],
        "UB": [v.value for v in case_ub_val],
        "umin": [v.value for v in case_umin_val],
        "umax": [v.value for v in case_umax_val],
        "SP_change_step": int(round(case_sp_change_time.value / sample_time)),
    }

    for i in range(int(num_outputs)):
        case_params[f"toggle{i+1}1"] = case_toggle_sp[i].value
        case_params[f"toggle{i+1}2"] = case_toggle_lb[i].value
        case_params[f"toggle{i+1}3"] = case_toggle_ub[i].value

    case_status.value = "<span style='color:green'><b>Values submitted successfully.</b></span>"
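
# -----------------------------------------------------------------------------
# Sketch (illustrative only): turning the case-study times into step counts.
# on_submit_case_values_clicked uses floor(total / Ts) for the number of steps
# and round(t_SP / Ts) for the set-point step. In binary floating point
# 0.3 / 0.1 evaluates to 2.9999999999999996, so a plain floor can drop a step;
# adding a small tolerance avoids that. The function name and the 1e-9
# tolerance are assumptions, not part of the original notebook.
# -----------------------------------------------------------------------------

import math

def sketch_discretize_case(total_time: float, ts: float, sp_time: float, tol: float = 1e-9):
    """Return (n_steps, sp_change_step) for a simulation of length total_time."""
    if ts <= 0:
        raise ValueError("Sampling time must be positive.")
    n_steps = int(math.floor(total_time / ts + tol))  # tolerance guards against 2.999... -> 2
    sp_step = int(round(sp_time / ts))                 # nearest sampling instant
    return n_steps, sp_step

# Usage: sketch_discretize_case(60.0, 1.0, 10.0) -> (60, 10)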

# -----------------------------------------------------------------------------
# Wiring
# -----------------------------------------------------------------------------

create_case_widgets_btn.on_click(on_create_case_widgets_clicked)
submit_case_values_btn.on_click(on_submit_case_values_clicked)

case_study_ui = widgets.VBox([
    widgets.HBox([create_case_widgets_btn, submit_case_values_btn]),
    case_status,
    case_out
])

# =============================================================================
# 7) Simulation
# =============================================================================
simulation_messages_out = widgets.Output()
plot_out = widgets.Output()

# --- Plot layout dropdowns ---

output_plot_cols_dropdown = widgets.Dropdown(
    options=[],
    value=None,
    description="Output columns:",
    style={"description_width": "initial"}
)

input_plot_cols_dropdown = widgets.Dropdown(
    options=[],
    value=None,
    description="Input columns:",
    style={"description_width": "initial"}
)

output_plot_pref_dropdown = widgets.Dropdown(
    options=[
        "Plot each output in a separate subplot",
        "Plot all outputs in one subplot"
    ],
    value="Plot each output in a separate subplot",
    description="Output layout:",
    style={"description_width": "initial"}
)

input_plot_pref_dropdown = widgets.Dropdown(
    options=[
        "Plot each input in a separate subplot",
        "Plot all inputs in one subplot"
    ],
    value="Plot each input in a separate subplot",
    description="Input layout:",
    style={"description_width": "initial"}
)

def sim_msg(text: str, kind: str = "info"):
    with simulation_messages_out:
        display(widgets.HTML(banner_html(text, kind)))

def _resolve_plant_callable():
    if "plant_fn" in globals() and callable(globals()["plant_fn"]):
        return globals()["plant_fn"]
    if "PLANT" in globals() and callable(globals()["PLANT"]):
        return globals()["PLANT"]
    return None

def _resolve_mpc_model():
    # Prefer the model returned by the loader; fall back to a local RBF_mpc.mat file (legacy workflow)
    obj, _path = get_loaded_mpc_model()
    if obj is not None:
        return obj
    try:
        return scipy.io.loadmat("RBF_mpc.mat")
    except Exception:
        return None
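
# -----------------------------------------------------------------------------
# Sketch (illustrative only): run_closed_loop below temporarily replaces
# builtins.print so that the "Formulation of Optimization Problem Complete
# after:" message printed by MPC_function appears as a banner instead of raw
# text, and restores the real print in a `finally` block. The same idea can be
# packaged as a reusable context manager. sketch_intercept_print is a
# hypothetical helper; the notebook does not call it.
# -----------------------------------------------------------------------------

import builtins
from contextlib import contextmanager

@contextmanager
def sketch_intercept_print(marker: str, on_match):
    real_print = builtins.print

    def _interceptor(*args, **kwargs):
        msg = " ".join(str(a) for a in args)
        if marker in msg:
            on_match(msg)                # route matching messages to the handler
        else:
            real_print(*args, **kwargs)  # everything else prints as usual

    builtins.print = _interceptor
    try:
        yield
    finally:
        builtins.print = real_print      # always restore, even if the body raises

# Usage (assuming sim_msg from this notebook):
#   with sketch_intercept_print("Complete after:", lambda m: sim_msg(m, "ok")):
#       ...run the simulation loop...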

def run_closed_loop(_=None):
    """
    Run the closed-loop simulation using:
      - case_params (from the case study UI)
      - u0_steady_state, y0_steady_state (from the steady-state UI)
      - plant_fn or PLANT (from the plant loader)
      - MPC_function (the CasADi MPC builder/solver defined in this notebook)
      - the tuning globals q, r, Wmin, Wmax, ch, ph
    Results are stored back into case_params as histories:
    T, output{i}, input{i} and SP_history{i} (i = 1, 2, ...).
    """
    global case_params, variable_dict, Ts, N

    simulation_messages_out.clear_output()
    plot_out.clear_output()

    with simulation_messages_out:
        try:
            if case_params is None:
                raise RuntimeError("Case study parameters not set. Use the case study UI.")
            if u0_steady_state is None or y0_steady_state is None:
                raise RuntimeError("Initial steady state not set. Use the steady-state UI.")
            plant = _resolve_plant_callable()
            if plant is None:
                raise RuntimeError("Plant model not loaded/built. Load the plant model first.")
            if "MPC_function" not in globals() or not callable(MPC_function):
                raise RuntimeError("MPC_function is not defined/callable.")

            for name in ["q", "r", "Wmin", "Wmax", "ch", "ph"]:
                if name not in globals() or globals()[name] is None:
                    raise RuntimeError(f"{name} is not set. Tune MPC first.")

            if sample_time is None or n_steps is None:
                raise RuntimeError("Ts/N not set. Save case study parameters first.")

        except RuntimeError as e:
            sim_msg(str(e), "err")
            return

        # Mirror old globals so any downstream plotting code that still expects them works
        variable_dict = case_params
        Ts = sample_time
        N = n_steps

        original_SP = case_params["SP"].copy()

        # Init
        u0 = [float(v) for v in list(u0_steady_state)]
        y0 = np.array(y0_steady_state).reshape(-1).astype(float).tolist()

        y0 = (y0 + [0.0] * max(0, int(num_outputs) - len(y0)))[:int(num_outputs)]
        u0 = (u0 + [0.0] * max(0, int(num_inputs) - len(u0)))[:int(num_inputs)]

        case_params["T"] = [0.0]
        for i in range(int(num_outputs)):
            case_params[f"output{i+1}"] = [float(y0[i])]
        for i in range(int(num_inputs)):
            case_params[f"input{i+1}"] = [float(u0[i])]

        # Start with the set-point equal to the initial output; the submitted set-points
        # are applied at SP_change_step. Keep a per-output history for plotting.
        case_params["SP"] = [float(case_params[f"output{i+1}"][0]) for i in range(int(num_outputs))]
        for i in range(int(num_outputs)):
            case_params[f"SP_history{i+1}"] = [case_params["SP"][i]]

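        # Build the parameter vector in the order MPC_function expects:
        # (y, SP) per output, (LB, UB) per output, (umin, umax) per input, 3 toggles per output,
        # q, r, (Wmin, Wmax) per output, then ch and ph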
        parameters = []
        for i in range(int(num_outputs)):
            parameters.append(case_params[f"output{i+1}"][0])
            parameters.append(case_params["SP"][i])

        for i in range(int(num_outputs)):
            parameters.append(case_params["LB"][i])
            parameters.append(case_params["UB"][i])

        for i in range(len(case_params["umin"])):
            parameters.append(case_params["umin"][i])
            parameters.append(case_params["umax"][i])

        for i in range(int(num_outputs)):
            for j in range(1, 4):
                parameters.append(case_params[f"toggle{i+1}{j}"])

        parameters.extend(q)
        parameters.extend(r)

        for i in range(int(num_outputs)):
            parameters.append(Wmin[i])
            parameters.append(Wmax[i])

        parameters.append(ch)
        parameters.append(ph)

        mpc_model = _resolve_mpc_model()
        if mpc_model is None:
            sim_msg("Could not load MPC model (from loader or RBF_mpc.mat).", "err")
            return

        sim_msg("Parameters initialized.", "ok")
        sim_msg("Starting closed-loop simulation…", "info")
        sim_msg("First MPC call can be slow (CasADi setup).", "warn")

        f = widgets.IntProgress(min=0, max=int(N), value=0, description="Progress:", layout=widgets.Layout(width="740px"))
        f.style = {"description_width": "initial"}
        f.bar_style = "info"
        pct_label = widgets.HTML("<span style='font-family:Arial, sans-serif; font-size:13px; margin-left:8px;'>0%</span>")
        display(widgets.HBox([f, pct_label]))

        real_print = builtins.print

        def pretty_print_interceptor(*args, **kwargs):
            msg = " ".join(str(a) for a in args)
            if "Formulation of Optimization Problem Complete after:" in msg:
                sim_msg(f"MPC initialized. {msg}", "ok")
            else:
                real_print(*args, **kwargs)

        builtins.print = pretty_print_interceptor

        start_time = tm.time()
        try:
            for k in range(int(N)):
                if k == case_params["SP_change_step"]:
                    case_params["SP"] = original_SP.copy()

                # Update measured y and SP in parameter vector
                for i in range(int(num_outputs)):
                    parameters[2 * i] = case_params[f"output{i+1}"][-1]
                    parameters[2 * i + 1] = case_params["SP"][i]

                F = MPC_function(parameters, case_params["T"][k], mpc_model, u0)
                outputs = plant(F, case_params["T"][-1], u0)

                for i in range(int(num_outputs)):
                    case_params[f"output{i+1}"].append(outputs[i])

                for i in range(int(num_inputs)):
                    # CasADi DM -> float
                    case_params[f"input{i+1}"].append(float(np.array(F[i].full()).item()))

                case_params["T"].append(case_params["T"][-1] + float(Ts))

                for i in range(int(num_outputs)):
                    case_params[f"SP_history{i+1}"].append(case_params["SP"][i])

                f.value = k + 1
                pct = int(round(100.0 * (f.value / max(1, int(N)))))
                pct_label.value = f"<span style='font-family:Arial, sans-serif; font-size:13px; margin-left:8px;'>{pct}%</span>"
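        finally:
            builtins.print = real_print
            # Restore the submitted set-points so a repeated run starts from them,
            # even if SP_change_step was never reached
            case_params["SP"] = original_SP.copy()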

        elapsed = tm.time() - start_time
        sim_msg(f"Simulation finished in {elapsed:.2f} s. Results stored.", "ok")

        # Configure plot dropdowns after sim
        valid_output_options = [i for i in [1, 2, 3, 4] if i <= int(num_outputs)]
        valid_input_options = [i for i in [1, 2, 3, 4] if i <= int(num_inputs)]

        output_plot_cols_dropdown.options = valid_output_options
        output_plot_cols_dropdown.value = min(2, valid_output_options[-1]) if valid_output_options else None

        input_plot_cols_dropdown.options = valid_input_options
        input_plot_cols_dropdown.value = min(2, valid_input_options[-1]) if valid_input_options else None

        has_out_units = (output_units is not None and len(output_units) > 0)
        has_in_units = (input_units is not None and len(input_units) > 0)

        if has_out_units:
            uo = [u for u in output_units if str(u).strip() and str(u).strip() != "-"]
            output_plot_pref_dropdown.options = ["Plot each output in a separate subplot"] + (
                ["Plot all outputs in one subplot"] if len(set(uo)) == 1 and len(uo) > 0 else []
            )
        else:
            output_plot_pref_dropdown.options = ["Plot each output in a separate subplot"]
        output_plot_pref_dropdown.value = "Plot each output in a separate subplot"

        if has_in_units:
            ui = [u for u in input_units if str(u).strip() and str(u).strip() != "-"]
            input_plot_pref_dropdown.options = ["Plot each input in a separate subplot"] + (
                ["Plot all inputs in one subplot"] if len(set(ui)) == 1 and len(ui) > 0 else []
            )
        else:
            input_plot_pref_dropdown.options = ["Plot each input in a separate subplot"]
        input_plot_pref_dropdown.value = "Plot each input in a separate subplot"
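
# -----------------------------------------------------------------------------
# Sketch (illustrative only): the ordering of the parameter list that
# run_closed_loop passes to MPC_function, written as a pure function.
# Assumes len(q) == ny, len(r) == nu and one (SP, LB, UB) toggle triple per
# output, which gives 10*ny + 3*nu + 2 entries in total. This list does not
# carry the past inputs and correction terms that appear in MPC_function's
# symbolic P; those are presumably filled in inside MPC_function.
# sketch_build_parameters is a hypothetical name.
# -----------------------------------------------------------------------------

def sketch_build_parameters(y, sp, lb, ub, umin, umax, toggles, q, r, wmin, wmax, ch, ph):
    params = []
    for yi, spi in zip(y, sp):        # measured output and set-point (2*ny)
        params += [yi, spi]
    for lo, hi in zip(lb, ub):        # output lower/upper bounds (2*ny)
        params += [lo, hi]
    for lo, hi in zip(umin, umax):    # MV lower/upper bounds (2*nu)
        params += [lo, hi]
    for triple in toggles:            # SP / LB / UB toggles per output (3*ny)
        params += list(triple)
    params += list(q)                 # Q diagonal (ny)
    params += list(r)                 # R diagonal (nu)
    for lo, hi in zip(wmin, wmax):    # soft-constraint penalty weights (2*ny)
        params += [lo, hi]
    params += [ch, ph]                # control and prediction horizons (last two)
    return params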


run_simulation_btn = widgets.Button(
    description="Run closed-loop simulation",
    button_style="info",
    layout=widgets.Layout(width="300px")
)
run_simulation_btn.on_click(run_closed_loop)
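
# -----------------------------------------------------------------------------
# Sketch (illustrative only): the bookkeeping of run_closed_loop with toy
# stand-ins, so the history layout can be checked without CasADi. The set-point
# is held at the initial output until sp_change_step, every history gets one
# entry per step (N + 1 samples in total, aligned with T), and the SP history
# therefore passes the len(sp) == len(T) check used in plot_results. The
# first-order plant y+ = a*y + b*u and the deadbeat law u = (sp - a*y)/b are
# assumptions standing in for plant_fn and MPC_function; sketch_closed_loop is
# a hypothetical name.
# -----------------------------------------------------------------------------

def sketch_closed_loop(y0, u0, sp_target, sp_change_step, n_steps, ts, a=0.8, b=0.5):
    T, Y, U = [0.0], [float(y0)], [float(u0)]
    sp = float(y0)                   # hold SP at the initial output
    sp_hist = [sp]
    for k in range(n_steps):
        if k == sp_change_step:
            sp = float(sp_target)    # apply the requested set-point
        u = (sp - a * Y[-1]) / b     # toy controller (stands in for MPC_function)
        y = a * Y[-1] + b * u        # toy plant (stands in for plant_fn)
        Y.append(y)
        U.append(u)
        T.append(T[-1] + ts)
        sp_hist.append(sp)
    return T, Y, U, sp_hist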

# =============================================================================
# 8) Plotting UI
# =============================================================================

# Ensure plot_out exists
if "plot_out" not in globals():
    plot_out = widgets.Output()

# Message area
PLOT_MSG_OUT = widgets.Output()

def _plot_banner_html(text: str, kind: str = "info") -> str:
    styles = {
        "ok":   ("#e8f5e9", "#2e7d32", "#c8e6c9"),
        "info": ("#e3f2fd", "#1565c0", "#bbdefb"),
        "warn": ("#fff8e1", "#8d6e00", "#ffecb3"),
        "err":  ("#ffebee", "#c62828", "#ffcdd2"),
    }
    bg, fg, br = styles.get(kind, styles["info"])
    safe = (str(text).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;"))
    return f"""
    <div style="
        background:{bg};
        color:{fg};
        border:1px solid {br};
        padding:6px 10px;
        border-radius:10px;
        margin:6px 0;
        max-width:720px;
        font-family:Arial, sans-serif;
        font-size:13px;
        line-height:1.2;">
      {safe}
    </div>
    """

def plot_msg(text: str, kind: str = "info"):
    with PLOT_MSG_OUT:
        clear_output(wait=True)
        display(widgets.HTML(_plot_banner_html(text, kind)))

# Dropdown widths
output_plot_cols_dropdown.layout = widgets.Layout(width="240px")
input_plot_cols_dropdown.layout  = widgets.Layout(width="240px")
output_plot_pref_dropdown.layout = widgets.Layout(width="520px")
input_plot_pref_dropdown.layout  = widgets.Layout(width="520px")

def _series_to_1d_float(series_list):
    out = []
    for v in series_list:
        if hasattr(v, "full"):          # CasADi DM
            v = np.array(v.full()).reshape(-1)
        else:
            v = np.array(v).reshape(-1)
        out.append(float(v.flat[0]) if v.size else np.nan)
    return np.array(out, dtype=float)
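
# -----------------------------------------------------------------------------
# Sketch (illustrative only): the subplot grid arithmetic used by plot_results
# below. With n series and c columns the figure needs ceil(n / c) rows, and the
# c * rows - n leftover axes are hidden. Clamping c to [1, n] is an extra
# safeguard assumed here (the notebook instead limits the dropdown options).
# sketch_subplot_grid is a hypothetical helper.
# -----------------------------------------------------------------------------

import math

def sketch_subplot_grid(n_series: int, n_cols: int):
    """Return (rows, cols, hidden_axes) for a grid of one subplot per series."""
    n_cols = max(1, min(int(n_cols), int(n_series)))
    n_rows = math.ceil(n_series / n_cols)
    n_hidden = n_rows * n_cols - n_series
    return n_rows, n_cols, n_hidden

# Usage: sketch_subplot_grid(3, 2) -> (2, 2, 1), i.e. one empty panel is hidden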

def plot_results():
    if "case_params" not in globals() or case_params is None or "T" not in case_params:
        plot_msg("No simulation results found. Run the closed-loop simulation first.", "err")
        return

    import matplotlib.pyplot as plt
    try:
        import seaborn as sns
    except Exception:
        sns = None

    ny = int(num_outputs)
    nu = int(num_inputs)

    T = np.array(case_params["T"], dtype=float).reshape(-1)

    output_pref = output_plot_pref_dropdown.value
    input_pref  = input_plot_pref_dropdown.value
    num_output_plot_cols = int(output_plot_cols_dropdown.value) if output_plot_cols_dropdown.value else 1
    num_input_plot_cols  = int(input_plot_cols_dropdown.value) if input_plot_cols_dropdown.value else 1

    output_rows = math.ceil(ny / num_output_plot_cols)
    input_rows  = 1 if input_pref == "Plot all inputs in one subplot" else math.ceil(nu / num_input_plot_cols)

    # ====================== OUTPUTS ======================
    if output_pref == "Plot all outputs in one subplot":
        fig_out = plt.figure(figsize=(12, 5))
        fig_out.suptitle("──── Outputs ────", fontsize=13, color="gray")
        ax = fig_out.add_subplot(1, 1, 1)
        colors = sns.color_palette("icefire", ny) if sns else [None] * ny

        for i in range(ny):
            y = _series_to_1d_float(case_params[f"output{i+1}"])
            ax.plot(T, y, linewidth=1.5, linestyle="-.", marker="o", markersize=3,
                    color=colors[i] if colors[i] is not None else None,
                    label=(output_labels[i] if "output_labels" in globals() else f"Output {i+1}"))

            sp = np.array(case_params.get(f"SP_history{i+1}", []), dtype=float).reshape(-1)
            if len(sp) == len(T):
                ax.step(T, sp, where="post", linestyle="--", linewidth=1.5,
                        color=colors[i] if colors[i] is not None else None,
                        label=f"Setpoint {i+1}")

        ax.set_xlabel(f"Time ({time_units})" if "time_units" in globals() else "Time")
        if "output_units" in globals() and output_units is not None and len(output_units) > 0:
            ax.set_ylabel(f"({output_units[0]})")
        ax.legend(loc="upper left")
        plt.tight_layout()
        plt.show()

    else:
        fig_out = plt.figure(figsize=(12, 5 * output_rows))
        fig_out.suptitle("──── Outputs ────", fontsize=13, color="gray")
        axs_out = fig_out.subplots(output_rows, num_output_plot_cols)
        axs_out = np.array(axs_out).flatten()
        colors = sns.color_palette("icefire", ny) if sns else [None] * ny

        for i in range(ny):
            ax = axs_out[i]
            choose_color = colors[i % len(colors)] if colors[i % len(colors)] is not None else None

            y = _series_to_1d_float(case_params[f"output{i+1}"])
            ax.plot(T, y, linewidth=1.5, marker="o", markersize=3,
                    color=choose_color, markerfacecolor=choose_color,
                    label="Measured Output")

            sp = np.array(case_params.get(f"SP_history{i+1}", []), dtype=float).reshape(-1)
            if len(sp) == len(T):
                ax.step(T, sp, where="post", linestyle="--", linewidth=1.5,
                        color=choose_color, label="Setpoint Value")

            if "LB" in case_params and "UB" in case_params:
                ax.fill_between(
                    T,
                    np.full(len(T), float(case_params["LB"][i])),
                    np.full(len(T), float(case_params["UB"][i])),
                    color="lightgray", alpha=0.4, label="Bounds"
                )

            title_txt = output_labels[i] if "output_labels" in globals() else f"Output {i+1}"
            ax.set_title(title_txt, fontsize=11)
            ax.set_xlabel(f"Time ({time_units})" if "time_units" in globals() else "Time")
            if "output_units" in globals() and output_units is not None and len(output_units) > i:
                ax.set_ylabel(f"({output_units[i]})")
            ax.legend(loc="upper left")

        for j in range(ny, len(axs_out)):
            axs_out[j].set_visible(False)

        plt.tight_layout()
        plt.show()

    # ====================== INPUTS ======================
    if input_pref == "Plot all inputs in one subplot":
        fig_in = plt.figure(figsize=(12, 5))
        fig_in.suptitle("──── Inputs ────", fontsize=13, color="gray")
        ax = fig_in.add_subplot(1, 1, 1)
        colors = sns.color_palette("viridis", nu) if sns else [None] * nu

        for i in range(nu):
            u = _series_to_1d_float(case_params[f"input{i+1}"])
            ax.plot(T, u, linewidth=1.5, linestyle="-.", marker="o", markersize=3,
                    color=colors[i] if colors[i] is not None else None,
                    label=(input_labels[i] if "input_labels" in globals() else f"Input {i+1}"))

        if "umin" in case_params and "umax" in case_params:
            ax.fill_between(T, min(case_params["umin"]), max(case_params["umax"]),
                            color="lightgray", alpha=0.4, label="Bounds")

        ax.set_xlabel(f"Time ({time_units})" if "time_units" in globals() else "Time")
        if "input_units" in globals() and input_units is not None and len(input_units) > 0:
            ax.set_ylabel(f"({input_units[0]})")
        ax.legend(loc="upper left")
        plt.tight_layout()
        plt.show()

    else:
        fig_in = plt.figure(figsize=(12, 5 * input_rows))
        fig_in.suptitle("──── Inputs ────", fontsize=13, color="gray")
        axs_in = fig_in.subplots(input_rows, num_input_plot_cols)
        axs_in = np.array(axs_in).flatten()
        colors = sns.color_palette("viridis", nu) if sns else [None] * nu

        for i in range(nu):
            ax = axs_in[i]
            choose_color = colors[i % len(colors)] if colors[i % len(colors)] is not None else None

            u = _series_to_1d_float(case_params[f"input{i+1}"])
            ax.plot(T, u, linewidth=1.5, marker="o", markersize=3,
                    color=choose_color, markerfacecolor=choose_color,
                    label=(input_labels[i] if "input_labels" in globals() else f"Input {i+1}"))

            if "umin" in case_params and "umax" in case_params:
                # Shade this input's own bounds rather than the envelope of all inputs
                ax.fill_between(T, float(case_params["umin"][i]), float(case_params["umax"][i]),
                                color="lightgray", alpha=0.4, label="Bounds")

            title_txt = input_labels[i] if "input_labels" in globals() else f"Input {i+1}"
            ax.set_title(title_txt, fontsize=11)
            ax.set_xlabel(f"Time ({time_units})" if "time_units" in globals() else "Time")
            if "input_units" in globals() and input_units is not None and len(input_units) > i:
                ax.set_ylabel(f"({input_units[i]})")
            ax.legend(loc="upper left")

        for j in range(nu, len(axs_in)):
            axs_in[j].set_visible(False)

        plt.tight_layout()
        plt.show()

# Buttons
PLOT_BTN = widgets.Button(description="Plot Now", button_style="primary")
PLOT_BTN.layout = widgets.Layout(width="220px")

CLEAR_BTN = widgets.Button(description="Clear Plots", button_style="danger")
CLEAR_BTN.layout = widgets.Layout(width="180px")

def _plot_clicked(_):
    plot_out.clear_output(wait=False)
    import matplotlib.pyplot as plt
    plt.close('all')
    with plot_out:
        plot_results()

def _clear_clicked(_):
    plot_out.clear_output(wait=False)
    PLOT_MSG_OUT.clear_output(wait=False)
    try:
        import matplotlib.pyplot as plt
        plt.close("all")
    except Exception:
        pass
    plot_msg("Cleared plots.", "info")

PLOT_BTN.on_click(_plot_clicked)
CLEAR_BTN.on_click(_clear_clicked)

# UI layout
PLOT_TITLE = widgets.HTML("<h3 style='margin:0'>Plot simulation results</h3>")
PLOT_SUBTITLE = widgets.HTML(
    "<div style='color:#555; font-family:Arial; font-size:13px;'>"
    "Choose layout and press <b>Plot Now</b>. Use <b>Clear Plots</b> to remove everything immediately.</div>"
)

plot_ui = widgets.VBox([
    PLOT_TITLE,
    PLOT_SUBTITLE,
    widgets.HTML("<hr style='margin:8px 0'>"),
    PLOT_MSG_OUT,
    widgets.HTML("<b style='font-family:Arial'>Outputs</b>"),
    widgets.HBox([output_plot_cols_dropdown, output_plot_pref_dropdown], layout=widgets.Layout(gap="12px")),
    widgets.HTML("<hr style='margin:10px 0'>"),
    widgets.HTML("<b style='font-family:Arial'>Inputs</b>"),
    widgets.HBox([input_plot_cols_dropdown, input_plot_pref_dropdown], layout=widgets.Layout(gap="12px")),
    widgets.HTML("<hr style='margin:10px 0'>"),
    widgets.HBox([PLOT_BTN, CLEAR_BTN], layout=widgets.Layout(gap="10px")),
    plot_out
])

# =============================================================================
# 9) MPC controller
# =============================================================================
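
# -----------------------------------------------------------------------------
# Sketch (illustrative only): start offset of each block in the CasADi
# parameter vector P declared in MPC_function below. In that function nf,
# nerrors and ntuning are the offsets of the past-input, correction-term and
# tuning-weight blocks, i.e. the "past_u", "errors" and "q" entries here.
# n_past corresponds to centers.shape[1], which presumably equals nu * lpin,
# since the input scaling subtracts an nu*lpin-long vector from the regressor.
# sketch_parameter_offsets is a hypothetical helper.
# -----------------------------------------------------------------------------

def sketch_parameter_offsets(ny: int, nu: int, n_past: int) -> dict:
    blocks = [
        ("y_and_sp", 2 * ny),   # measured output and set-point per output
        ("y_bounds", 2 * ny),   # lower/upper bound per output
        ("u_bounds", 2 * nu),   # umin/umax per manipulated variable
        ("toggles", 3 * ny),    # SP / LB / UB toggles per output
        ("past_u", n_past),     # RBF regressor of past inputs
        ("errors", ny),         # closed-loop correction terms
        ("q", ny),              # Q diagonal
        ("r", nu),              # R diagonal
        ("w_bounds", 2 * ny),   # soft-constraint penalty weights
    ]
    offsets, pos = {}, 0
    for name, size in blocks:
        offsets[name] = pos
        pos += size
    offsets["total"] = pos      # equals the length of P
    return offsets

# Usage: with ny=2, nu=1, lpin=3: sketch_parameter_offsets(2, 1, 3)["past_u"] == 7*2 + 2*1 == 16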

vars_store = None
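
# -----------------------------------------------------------------------------
# Sketch (illustrative only): a NumPy evaluation of the RBF prediction that
# MPC_function builds symbolically with CasADi. The regressor is scaled to
# [-1, 1] with xmin/xmax (expanded per past value, like min_params/max_params),
# Gaussian activations phi_j = exp(-||x_n - c_j||^2 / width^2) are formed for
# every center, and the linear layer output is mapped back to physical units:
#     y = (W^T phi + 1) * (ymax - ymin) / 2 + ymin
# Assumed shapes: centers (L, n_x), weights (L, ny), scalar width, and
# xmin/xmax of length n_x. sketch_rbf_predict is a hypothetical helper that
# could be used to spot-check a loaded model outside CasADi.
# -----------------------------------------------------------------------------

import numpy as np

def sketch_rbf_predict(x, centers, weights, width, xmin, xmax, ymin, ymax):
    x = np.asarray(x, dtype=float).reshape(1, -1)
    xmin, xmax = np.asarray(xmin, dtype=float), np.asarray(xmax, dtype=float)
    xn = 2.0 * (x - xmin) / (xmax - xmin) - 1.0                        # scale regressor to [-1, 1]
    d2 = np.sum((np.asarray(centers, dtype=float) - xn) ** 2, axis=1)  # squared distance to each center
    w = float(np.asarray(width, dtype=float).reshape(-1)[0])           # width may load as a 1x1 array
    phi = np.exp(-d2 / w ** 2)                                         # Gaussian activations, shape (L,)
    y_norm = np.asarray(weights, dtype=float).T @ phi                  # normalized outputs, shape (ny,)
    ymin, ymax = np.asarray(ymin, dtype=float), np.asarray(ymax, dtype=float)
    return (y_norm + 1.0) * (ymax - ymin) / 2.0 + ymin                 # back to physical units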

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MPC CONTROLLER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
def MPC_function(parameters, t, predictive_model, new_mvss=None):
    global vars_store # Use one global variable to avoid pollution of the global namespace

    if t == 0:

        # Start timer
        start_time = tm.time()

        # Load Predictive Model
        nu = int(predictive_model['nu'][0][0])  # number of inputs used in RBF training
        ny = int(predictive_model['ny'][0][0])  # number of outputs used in RBF training
        if new_mvss is None:
            mv_ss = predictive_model['mv_ss'][0]  # initial input steady-state values
        else:
            mv_ss = new_mvss

        lpin = int(predictive_model['lpin'][0][0])  # number of past values required for each input
        centers = predictive_model['centers'] # positions of centers in RBF hidden layer
        weights = predictive_model['weights'] # connecting weights in RBF
        ymin = predictive_model['ymin'][0]  # minimum output values observed during training (used for normalization)
        ymax = predictive_model['ymax'][0]  # maximum output values observed during training (used for normalization)
        xmin = predictive_model['xmin'][0]  # minimum input values observed during training (used for normalization)
        xmax = predictive_model['xmax'][0]  # maximum input values observed during training (used for normalization)
        width = predictive_model['width']  # width (spread) of the Gaussian basis functions

        # Define the parameter vector P
        P = ca.SX.sym('P', 2*ny + 2*ny + 2*nu + 3*ny + centers.shape[1] + ny + ny + nu + 2*ny)
        # P elements include:
        # - measured value and set-point value for each controlled output (2*ny)
        # - minimum and maximum value for each controlled output (2*ny)
        # - minimum and maximum value for each manipulated variable (2*nu)
        # - 3 toggles for each controlled output to define control objectives:
        #   a) SP tracking, b) LB soft constraint, c) UB soft constraint (3*ny)
        # - past input values for RBF state (updated at each time instance) (centers.shape[1])
        # - corrective terms to minimize prediction error (using closed-loop feedback) (ny)
        # - diagonal elements of penalty weighting matrix Q for SP tracking (ny)
        # - diagonal elements of penalty weighting matrix R for MV move suppression (nu)
        # - penalty weights for minimum and maximum soft constraints on outputs (2*ny)

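        # Offsets into P: the past-input block starts after the 2*ny + 2*ny + 2*nu + 3*ny = 7*ny + 2*nu
        # entries listed above; the correction terms follow it, then the tuning weights.
        nf = 7*ny + 2*nu
        nerrors = nf + centers.shape[1]
        ntuning = nerrors + ny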

        # MPC horizons, stored in the last two elements of the user-supplied parameter vector.
        # Cast to int because they size CasADi symbols and loop ranges.
        ch = int(parameters[-2])  # control horizon
        ph = int(parameters[-1])  # prediction horizon

        # Create empty dictionary to store variables
        vars_store = {}
        # Initialize lists inside vars_store
        vars_store['u'] = []; vars_store['du'] = []
        vars_store['past_u'] = []; vars_store['y'] = []
        vars_store['emin'] = []; vars_store['emax'] = []
        vars_store['RBF_fun'] = None; vars_store['solver'] = None; vars_store['previous_solution'] = None
        vars_store['lbg'] = []; vars_store['ubg'] = []
        vars_store['lbx'] = []; vars_store['ubx'] = []
        vars_store['ph'] = ph; vars_store['ch'] = ch; vars_store['nf'] = nf

        # Create symbolic variables dynamically for control inputs (u, du, past_u)
        for i in range(nu):
            vars_store['u'].append(ca.SX.sym(f'u{i+1}', ch))
            vars_store['du'].append(ca.SX.sym(f'du{i+1}', ch))
            vars_store['past_u'].append(mv_ss[i] * np.ones((1, lpin)))

        # Create symbolic variables dynamically for output (y, emin, emax)
        for i in range(ny):
            vars_store['y'].append(ca.SX.sym(f'y{i+1}', ph + 1))
            vars_store['emin'].append(ca.SX.sym(f'emin{i+1}', ph + 1))
            vars_store['emax'].append(ca.SX.sym(f'emax{i+1}', ph + 1))

        # RBF Predictive Model
        x = ca.SX.sym('x', 1, centers.shape[1])

        min_params = []; max_params = []
        for i in range(nu):
            for j in range(lpin):
                min_params.append(xmin[i])
                max_params.append(xmax[i])
        min_params = np.array(min_params).reshape(1, -1); max_params = np.array(max_params).reshape(1, -1)

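        # Gaussian RBF network: scale the regressor to [-1, 1], evaluate exp(-||x_scaled - c||^2 / width^2)
        # for every center, combine with the output weights and map back from [-1, 1] to output units
        x_scaled = 2 * (x - min_params) / (max_params - min_params) - 1
        sq_dist = ca.sum2((centers - ca.repmat(x_scaled, centers.shape[0], 1))**2)
        pred = (weights.T @ ca.exp(-sq_dist / width**2) + 1) * (ymax - ymin) / 2 + ymin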
        vars_store['RBF_fun'] = ca.Function('RBFfun', [x], [pred])

        # Current RBF prediction
        current_predictions = vars_store['RBF_fun'](ca.horzcat(*vars_store['past_u']))

        # Dynamically assign y1[0], ysp1, y2[0], ysp2, etc.
        for i in range(ny):
            vars_store['y'][i][0] = P[2*i]
            vars_store[f'ysp{i+1}'] = P[2*i + 1]

        # Dynamically assign bounds for controlled variables ymin_ud1, ymax_ud1, etc.
        offset = 2 * ny
        for i in range(ny):
            vars_store[f'ymin_ud{i+1}'] = P[offset + 2*i]
            vars_store[f'ymax_ud{i+1}'] = P[offset + 2*i + 1]

        # Dynamically assign bounds for manipulated variables umin1, umax1, etc.
        offset += 2 * ny
        for i in range(nu):
            vars_store[f'umin{i+1}'] = P[offset + 2*i]
            vars_store[f'umax{i+1}'] = P[offset + 2*i + 1]

        # Dynamically assign toggle values for each y variable
        offset += 2 * nu
        for i in range(ny):
            for j in range(3):  # Three toggles per y variable
                vars_store[f'toggle{i+1}{j+1}'] = P[offset + 3*i + j]

        # Initialize vector of constraints
        g = []

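        # Input trajectory over the control horizon: u[k] = last applied input + du[0] + ... + du[k]
        # (same convention as the umin/umax constraint rows below, so every move affects the predictions)
        for k in range(ch):
            for i in range(nu):
                vars_store['u'][i][k] = ca.sum1(vars_store['du'][i][:k+1]) + P[nf + i*lpin]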

        # Build the RBF regressor x at each prediction step k: for every input, the lpin most
        # recent values, newest first. Beyond the control horizon the input is held at u[ch-1];
        # values from before the current time are taken from the past-input block of P.
        for k in range(ph):
          x_elements = []
          for i in range(nu):
            for t in range(k, k - lpin, -1):
              if t >= 0:
                x_elements.append(vars_store['u'][i][min(t, ch - 1)])
              else:
                x_elements.append(P[nf + i*lpin - t - 1])
          x[:] = ca.vertcat(*x_elements)

          # Predicted controlled variables: RBF model output plus the bias correction
          y_model = vars_store['RBF_fun'](x)
          for i in range(ny):
            vars_store['y'][i][k+1] = y_model[i] + P[nerrors + i]

          # Constraint rows: y[k] (left unbounded) and the slacks emin[k], emax[k]
          # (bounded through lbg/ubg, which keeps emin, emax >= 0)
          for i in range(ny):
            g.append(vars_store['y'][i][k])     # prediction (unconstrained row)
            g.append(vars_store['emin'][i][k])  # lower-bound slack
            g.append(vars_store['emax'][i][k])  # upper-bound slack

        # Bound last elements of emin, emax as well
        for i in range(ny):
            g.append(vars_store['emin'][i][-1])  # Append the corresponding emin variables
            g.append(vars_store['emax'][i][-1])  # Append the corresponding emax variables

        ## Tuning weights
        # Q and R are diagonal, positive weighting matrices for the output deviations
        # from the setpoints and for the manipulated-input increments

        # Q: (ny x ny), ny = number of outputs; the SP-tracking toggle switches each weight on or off
        SPtoggle = np.array([vars_store[f'toggle{i+1}1'] for i in range(ny)])
        q = [P[ntuning + i] for i in range(ny)]
        Q = np.diag(q) @ np.diag(SPtoggle)
        if not ((len(Q) == len(Q[0])) and (len(Q) == ny)):
          raise ValueError("Matrix Q must be square (ny x ny). Current shape is {}x{}".format(len(Q), len(Q[0])))

        # R: (nu x nu), nu = number of manipulated variables
        r = [P[ntuning + ny + i] for i in range(nu)]
        R = np.diag(r)
        if not ((len(R) == len(R[0])) and (len(R) == nu)):
          raise ValueError("Matrix R must be square (nu x nu). Current shape is {}x{}".format(len(R), len(R[0])))

        # Penalty weights for soft constraints
        Wmin = [P[ntuning + ny + nu + i] for i in range(ny)]
        Wmax = [P[ntuning + ny + nu + ny + i] for i in range(ny)]
        for i in range(ny):
          vars_store[f'Wmin{i+1}'] = Wmin[i]*vars_store[f'toggle{i+1}{2}']
          vars_store[f'Wmax{i+1}'] = Wmax[i]*vars_store[f'toggle{i+1}{3}']

        ## Objective Function
        J=0 # Initialize J

        for k in range(ph):
          if k <= ch-1:
            cv =  np.array([vars_store['y'][i][k+1] - vars_store[f'ysp{i+1}']  for i in range(ny)])
            sc =  np.array([vars_store[f'Wmin{i+1}'] * vars_store['emin'][i][k+1] +
                  vars_store[f'Wmax{i+1}'] * vars_store['emax'][i][k+1] for i in range(ny)])
            mv =  np.array([vars_store['du'][i][k]  for i in range(nu)])

            J += ca.mtimes(ca.mtimes(ca.horzcat(*cv),Q), ca.vertcat(*cv)) + sum(sc) + ca.mtimes(ca.mtimes(ca.horzcat(*mv),R), ca.vertcat(*mv))

            for i in range(nu):
              g.append(ca.sum1(vars_store['du'][i][:k+1]) + P[nf+i*lpin] - vars_store[f'umin{i+1}'])  # umin <= u => u - umin >= 0
              g.append(vars_store[f'umax{i+1}'] - ca.sum1(vars_store['du'][i][:k+1]) - P[nf+i*lpin])  # u <= umax => umax - u >= 0
            for j in range(ny):
              g.append(vars_store['y'][j][k+1] + vars_store['emin'][j][k+1] - vars_store[f'ymin_ud{j+1}'])  # y >= ymin - emin => y + emin - ymin >= 0
              g.append(vars_store['emax'][j][k+1] + vars_store[f'ymax_ud{j+1}'] - vars_store['y'][j][k+1] )  # y <= ymax + emax => ymax + emax - y >= 0

          else:
            cv =  np.array([vars_store['y'][i][k+1] - vars_store[f'ysp{i+1}']  for i in range(ny)])
            sc =  np.array([vars_store[f'Wmin{i+1}'] * vars_store['emin'][i][k+1] +
                  vars_store[f'Wmax{i+1}'] * vars_store['emax'][i][k+1] for i in range(ny)])

            J += ca.mtimes(ca.mtimes(ca.horzcat(*cv),Q), ca.vertcat(*cv)) + sum(sc)
            for j in range(ny):
              g.append(vars_store['y'][j][k+1] + vars_store['emin'][j][k+1] - vars_store[f'ymin_ud{j+1}'])  # y >= ymin - emin => y + emin - ymin >= 0
              g.append(vars_store['emax'][j][k+1] + vars_store[f'ymax_ud{j+1}'] - vars_store['y'][j][k+1] )  # y <= ymax + emax => ymax + emax - y >= 0

        ## Constraints
        # Lower Bounds (all g constraints have zero lower bounds, except for the prediction rows y[k],
        # which are left unconstrained; the user-defined output bounds enter through the soft-constraint rows)
        lbg = np.zeros(len(g)); lbx = -ca.inf*np.ones((nu*ch)+2*ny*(ph+1))
        for j in range(ph):
          for i in range(ny):
            lbg[3*i+3*ny*j] = -np.inf # predictions of internal model are unconstrained

        # Upper Bounds (all g constraints have infinite upper bounds; the first elements of the
        # slack variables are fixed to zero below)
        ubg = ca.inf*np.ones(len(g)); ubx = ca.inf*np.ones((nu*ch)+2*ny*(ph+1))

        # Set first element of emin and emax equal to 0 (not used during optimization)
        for i in range(ny):
          ubg[3*i+1] = 0
          ubg[3*i+2] = 0

        # Decision Variables
        decision_variables = ca.vertcat(
          *[vars_store['du'][i] for i in range(nu)],
          *[vars_store['emin'][i] for i in range(ny)],
          *[vars_store['emax'][i] for i in range(ny)])

        # Create the solver
        prob = {'f': J, 'x': decision_variables, 'g': ca.vertcat(*g), 'p': P}

        opts = {
            'ipopt.print_level': 0,
            'print_time': 0,
            'ipopt.sb': 'yes',
            'ipopt.tol': 1e-4,     # looser than the IPOPT default (1e-8) for faster solves
            'ipopt.max_iter': 100  # cap on IPOPT iterations per control step
        }

        solver = ca.nlpsol('solver', 'ipopt', prob, opts)
        vars_store['solver'] = solver; vars_store['lbg'] = lbg; vars_store['ubg'] = ubg; vars_store['lbx'] = lbx; vars_store['ubx'] = ubx

        # Call the solver
        past_inputs = ca.horzcat(*vars_store['past_u'])

        # Bias corrections: measured output minus the model prediction from the current past inputs
        for i in range(ny):
            vars_store[f'corrections{i+1}'] = parameters[2*i] - current_predictions[i]

        p = ca.horzcat(*parameters[:nf],past_inputs,*[vars_store[f'corrections{i+1}'] for i in range(ny)],*parameters[nf:-2])
        sol = solver(lbx=lbx, ubx=ubx, lbg=lbg, ubg=ubg, p=p)
        vars_store['previous_solution'] = sol['x']

        # Select first control action for each MV
        du = [sol['x'][i * ch].full() for i in range(nu)]

        # Update past inputs
        for i in range(nu):
          vars_store['past_u'][i] = ca.horzcat(du[i]+vars_store['past_u'][i][:,0],vars_store['past_u'][i][:,:-1])

        # Controller outputs: current value of each manipulated variable
        F = [vars_store['past_u'][i][:,0] for i in range(nu)]
        # End timer
        end_time = tm.time()
        # Calculate elapsed time
        elapsed_time = end_time - start_time
        msg = f"Formulation of Optimization Problem Complete after: {elapsed_time:.2f} seconds"
        if "MPC_INIT_REPORTER" in globals() and callable(globals()["MPC_INIT_REPORTER"]):
            globals()["MPC_INIT_REPORTER"](msg)
        else:
            print(msg)

    else:

        # Current model prediction from the past-input buffer (used for the bias correction)
        past_inputs = ca.horzcat(*vars_store['past_u'])

        current_predictions = vars_store['RBF_fun'](past_inputs)

        # Bias corrections: measured output minus the model prediction from the current past inputs
        for i in range(len(vars_store['y'])):
            vars_store[f'corrections{i+1}'] = parameters[2*i] - current_predictions[i]

        # Call the solver
        p = ca.horzcat(*parameters[:vars_store['nf']],past_inputs,*[vars_store[f'corrections{i+1}'] for i in range(len(vars_store['y']))],*parameters[vars_store['nf']:-2])
        sol = vars_store['solver'](lbx=vars_store['lbx'], ubx=vars_store['ubx'], lbg=vars_store['lbg'], ubg=vars_store['ubg'], p=p, x0=vars_store['previous_solution'])
        vars_store['previous_solution'] = sol['x']

        # Select first control action for each MV
        du = [sol['x'][i * vars_store['ch']].full() for i in range(len(vars_store['u']))]

        for i in range(len(vars_store['u'])):
          vars_store['past_u'][i] = ca.horzcat(du[i]+vars_store['past_u'][i][:,0],vars_store['past_u'][i][:,:-1])

        # Controller outputs: current value of each manipulated variable
        F = [vars_store['past_u'][i][:,0] for i in range(len(vars_store['u']))]

    return F


Collecting casadi
  Downloading casadi-3.7.2-cp312-none-manylinux2014_x86_64.whl.metadata (2.2 kB)
Downloading casadi-3.7.2-cp312-none-manylinux2014_x86_64.whl (75.6 MB)
Installing collected packages: casadi
Successfully installed casadi-3.7.2
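
The controller code above turns the input moves `du` into RBF regressors over the two horizons. The same bookkeeping can be checked outside CasADi with a small NumPy sketch for a single input. It assumes `ch <= ph`, that the first move is applied at the first prediction step (the convention of the `umin`/`umax` constraint rows), and that the newest value comes first in the past-input buffer. The function names are hypothetical.

```python
import numpy as np


def input_trajectory(u_last: float, du: np.ndarray, ph: int) -> np.ndarray:
    """Input values u[0..ph-1] built from the moves du[0..ch-1] (assumes ch <= ph)."""
    u_ch = u_last + np.cumsum(du)              # u[k] = u_last + du[0] + ... + du[k]
    hold = np.full(ph - len(du), u_ch[-1])     # u[k] = u[ch-1] for k >= ch
    return np.concatenate([u_ch, hold])


def regressor(u_future: np.ndarray, u_past: np.ndarray, k: int) -> np.ndarray:
    """The lpin most recent inputs at prediction step k, newest first.

    u_past[0] is the most recently applied input, u_past[1] the one before, etc.
    """
    lpin = len(u_past)
    values = []
    for t in range(k, k - lpin, -1):
        # future/current inputs for t >= 0, stored past inputs for t < 0
        values.append(u_future[t] if t >= 0 else u_past[-t - 1])
    return np.array(values)


u_past = np.array([1.0, 1.0, 1.0])      # lpin = 3, plant held at u = 1
u = input_trajectory(u_last=u_past[0], du=np.array([0.5, 0.25]), ph=5)
print(u)                                # [1.5  1.75 1.75 1.75 1.75]
print(regressor(u, u_past, k=0))        # [1.5 1.  1. ]
print(regressor(u, u_past, k=3))        # [1.75 1.75 1.75]
```

Beyond the control horizon the input is simply held, so only `ch` moves per input appear among the decision variables.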


## **Import Files**

Use the file uploader to load the following files before running the workflow:

- **Dataset file** (`.csv`, `.xlsx`):  
  Time-series data containing the input and output signals (such as the data used for training and validation); in this notebook it is used to select the initial steady-state operating point.

- **Trained MPC model** (`.mat`):  
  The RBFNN-based predictive model exported from the training notebook (**Predictive_RBFNNs_for_Control_Systems_Greece4_0.ipynb**), which will be used by the MPC controller.

- **Plant model** (`.mat`):  
  A separate RBFNN model representing the plant dynamics, used for closed-loop simulation, exported from the training notebook (**Predictive_RBFNNs_for_Control_Systems_Greece4_0.ipynb**).
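
Before loading a model, it can help to confirm which variables a `.mat` file actually contains. Below is a minimal sketch using `scipy.io.loadmat`. The variable names written to the demo file (`centers`, `weights`, `width`) only echo the names used in the controller code. They are an assumption, not the documented export format of the training notebook.

```python
import os
import tempfile

import numpy as np
from scipy.io import loadmat, savemat


def list_mat_variables(path: str) -> dict:
    """Return {name: shape} for the user variables stored in a .mat file."""
    data = loadmat(path)
    # loadmat adds metadata entries such as '__header__'; skip them
    return {k: np.shape(v) for k, v in data.items() if not k.startswith("__")}


# Build a throwaway file with hypothetical RBF variables to demonstrate the check
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example_model.mat")
    savemat(path, {"centers": np.zeros((10, 6)), "weights": np.zeros((10, 2)), "width": 0.5})
    shapes = list_mat_variables(path)

print(shapes)
```

Note that MATLAB-style storage turns scalars into `(1, 1)` arrays, so loaded scalars usually need indexing (for example `[0, 0]`) before use.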

In [2]:
show_file_uploader()

## **Load Models**

Select the **MPC model** and the **plant model** (`.mat` files) from the uploaded files.
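
Each RBFNN model is evaluated as a Gaussian RBF network, following the expression the controller assigns to `pred`. The regressor is min-max scaled to [-1, 1], every center contributes `exp(-||z - c_j||^2 / width^2)`, the activations are combined with the output weights, and the result is mapped back from [-1, 1] to output units. Below is a NumPy sketch of that forward pass, assuming a scalar (or per-center) `width` and one-dimensional scaling bounds. The function name and the toy parameters are hypothetical.

```python
import numpy as np


def rbf_predict(x, centers, weights, width, x_min, x_max, y_min, y_max):
    """Evaluate a Gaussian RBF network with [-1, 1] input/output scaling.

    x           : (n_in,) regressor in physical units
    centers     : (n_centers, n_in) centers in scaled units
    weights     : (n_centers, ny) output-layer weights
    width       : scalar or (n_centers,) Gaussian width
    x_min/x_max : (n_in,) input scaling bounds; y_min/y_max : (ny,) output bounds
    """
    z = 2.0 * (np.asarray(x) - x_min) / (x_max - x_min) - 1.0   # scale to [-1, 1]
    sq_dist = np.sum((centers - z) ** 2, axis=1)                 # ||z - c_j||^2
    phi = np.exp(-sq_dist / np.asarray(width) ** 2)              # Gaussian activations
    y_scaled = weights.T @ phi                                   # (ny,) in [-1, 1] units
    return (y_scaled + 1.0) * (y_max - y_min) / 2.0 + y_min      # back to physical units


# Toy model: one input, one output, a single center in the middle of the input range
centers = np.array([[0.0]])
weights = np.array([[1.0]])
y = rbf_predict([5.0], centers, weights, 1.0,
                x_min=np.array([0.0]), x_max=np.array([10.0]),
                y_min=np.array([0.0]), y_max=np.array([2.0]))
print(y)  # [2.]
```

The controller builds the same expression symbolically with CasADi so that IPOPT can use its derivatives. A NumPy version like this is only useful for inspecting a loaded model offline.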

In [3]:
display(widgets.VBox([mpc_loader_ui, plant_loader_ui]))

## **Define Initial Steady State**

Define the initial steady-state operating point either by selecting a row from a dataset or by manually entering input values.
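
In the controller, the selected steady state fills the past-input buffer. All `lpin` lagged values of each manipulated variable are set to its steady-state value (`mv_ss[i] * np.ones((1, lpin))`), which represents a plant that has been held at that input for at least `lpin` steps. Below is a pandas sketch of picking one dataset row and building that buffer. The column names and the helper name are hypothetical.

```python
import numpy as np
import pandas as pd


def steady_state_buffer(df: pd.DataFrame, row: int, input_cols: list, lpin: int) -> list:
    """Take the input values from one dataset row and repeat each one lpin times.

    Mirrors the controller's initialization past_u[i] = mv_ss[i] * ones((1, lpin)).
    """
    mv_ss = df.loc[df.index[row], input_cols].to_numpy(dtype=float)
    return [value * np.ones((1, lpin)) for value in mv_ss]


# Hypothetical dataset with two manipulated variables and one output
df = pd.DataFrame({"u1": [1.0, 2.0, 3.0], "u2": [10.0, 20.0, 30.0], "y1": [0.1, 0.2, 0.3]})
past_u = steady_state_buffer(df, row=1, input_cols=["u1", "u2"], lpin=4)
print(np.hstack(past_u))  # [[ 2.  2.  2.  2. 20. 20. 20. 20.]]
```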


In [4]:
display(steady_state_ui)

## **Tune MPC Parameters**

Configure the MPC tuning parameters: the prediction and control horizons, and the weights for output tracking, input move suppression, and bound violation.
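
Inside the controller, these weights and the case-study toggles combine into the objective. `Q = diag(q) @ diag(SP toggles)`, and the soft-constraint penalties are `Wmin * toggle_LB` and `Wmax * toggle_UB`. Each prediction step adds the tracking term and the weighted slacks, and within the control horizon it also adds the move term `du' R du`. Below is a NumPy sketch of one such stage, assuming non-negative weights. Because the slacks are penalized linearly and bounded below by zero, the sketch sets each slack to the bound violation, which is the cheapest feasible choice for a fixed prediction `y`. The function name is hypothetical.

```python
import numpy as np


def stage_cost(y, y_sp, du, y_min, y_max, q, r, w_min, w_max, toggles):
    """One term of the MPC objective for a predicted output y and input move du.

    toggles: (ny, 3) array of 0/1 flags per output -> [SP tracking, LB soft, UB soft].
    Slacks are the smallest values with y + e_min >= y_min, y - e_max <= y_max,
    e_min >= 0 and e_max >= 0, i.e. the bound violations.
    """
    y, y_sp, du = map(np.asarray, (y, y_sp, du))
    toggles = np.asarray(toggles, dtype=float)
    Q = np.diag(q) @ np.diag(toggles[:, 0])          # tracking weights, switched by toggle 1
    R = np.diag(r)
    e_min = np.maximum(0.0, np.asarray(y_min) - y)   # lower-bound violation
    e_max = np.maximum(0.0, y - np.asarray(y_max))   # upper-bound violation
    err = y - y_sp
    soft = np.sum(np.asarray(w_min) * toggles[:, 1] * e_min
                  + np.asarray(w_max) * toggles[:, 2] * e_max)
    return float(err @ Q @ err + soft + du @ R @ du)


# Two outputs: the first is tracked, the second only has its upper bound enabled
J = stage_cost(y=[1.0, 5.0], y_sp=[2.0, 0.0], du=[0.5],
               y_min=[0.0, 0.0], y_max=[10.0, 4.0],
               q=[1.0, 1.0], r=[2.0], w_min=[10.0, 10.0], w_max=[10.0, 10.0],
               toggles=[[1, 0, 0], [0, 0, 1]])
print(J)  # 11.5
```

Large `Wmin`/`Wmax` values make the soft bounds behave almost like hard ones. Because the slacks exist, the NLP stays feasible even when a bound cannot be met.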

In [5]:
display(mpc_tuning_ui)

## **Define Case Study**


Define the simulation scenario by specifying the setpoints, the input and output bounds, and the toggles that enable or disable each control objective (setpoint tracking and the lower and upper soft output constraints).
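
The setpoints, bounds and toggles defined here are packed, together with the tuning weights, into the controller's parameter vector `P`. The block order follows the comments in the controller code. The minimal sketch below reproduces the offsets `nf`, `nerrors` and `ntuning`. It assumes that block order and `centers.shape[1] == nu * lpin`, meaning one block of `lpin` lagged values per input, as in `min_params`. The helper name is hypothetical.

```python
def parameter_layout(ny: int, nu: int, lpin: int) -> dict:
    """Return (start, stop) index ranges of each block of the MPC parameter vector P.

    Hypothetical helper: the block order mirrors the comments in the controller code.
    Assumes the RBF regressor has nu * lpin entries (lpin past values per input).
    """
    sizes = [
        ("y_and_sp", 2 * ny),       # current output and setpoint, interleaved
        ("y_bounds", 2 * ny),       # ymin_ud, ymax_ud per output
        ("u_bounds", 2 * nu),       # umin, umax per input
        ("toggles", 3 * ny),        # SP tracking, LB soft, UB soft per output
        ("past_u", nu * lpin),      # lagged inputs feeding the RBF model
        ("corrections", ny),        # bias (measurement minus prediction)
        ("Q_diag", ny),             # setpoint-tracking weights
        ("R_diag", nu),             # move-suppression weights
        ("W_min", ny),              # lower soft-constraint penalties
        ("W_max", ny),              # upper soft-constraint penalties
    ]
    layout, start = {}, 0
    for name, size in sizes:
        layout[name] = (start, start + size)
        start += size
    layout["total"] = start
    return layout


lay = parameter_layout(ny=2, nu=2, lpin=3)
nf = lay["past_u"][0]             # 7*ny + 2*nu
nerrors = lay["corrections"][0]   # nf + nu*lpin
ntuning = lay["Q_diag"][0]        # nerrors + ny
print(nf, nerrors, ntuning, lay["total"])  # 18 24 26 34
```

The user-supplied `parameters` vector omits the `past_u` and `corrections` blocks, which the controller splices in when it builds `p` before each solve. It also carries the control and prediction horizons as its last two entries.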


In [6]:
display(case_study_ui)

## **Run Closed-Loop Simulation**

Run the closed-loop MPC simulation using the configured models, steady state, tuning parameters, and case study settings.
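
Each closed-loop iteration repeats the bookkeeping in the `else` branch of the controller. First, the bias correction is recomputed as the measured output minus the model prediction from the current past-input buffer. Then the NLP is solved, only the first move of each input is applied, and the buffer is shifted by one step. Below is a NumPy sketch of the correction and shift steps. The optimizer is replaced by fixed, illustrative moves so the bookkeeping can be checked in isolation, and the function names are hypothetical.

```python
import numpy as np


def bias_correction(y_measured, y_model):
    """Offset added to every future prediction (measurement minus model output)."""
    return np.asarray(y_measured, dtype=float) - np.asarray(y_model, dtype=float)


def apply_first_move(past_u: np.ndarray, du_first: float) -> np.ndarray:
    """Shift a (1, lpin) past-input buffer: prepend newest value + du, drop the oldest."""
    new_value = past_u[0, 0] + du_first
    return np.hstack([[[new_value]], past_u[:, :-1]])


past_u = np.array([[1.0, 1.0, 1.0]])          # newest value first, lpin = 3
correction = bias_correction([2.5], [2.0])
print(correction)                             # [0.5]
for du_first in (0.5, -0.25):                 # stand-ins for the optimizer's first moves
    past_u = apply_first_move(past_u, du_first)
print(past_u)                                 # [[1.25 1.5  1.  ]]
```

Re-estimating the bias at every step is what lets the controller compensate for the mismatch between the MPC model and the plant model.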


In [7]:
display(run_simulation_btn)
display(simulation_messages_out)

## **Plot Results**

Visualize the closed-loop trajectories of outputs and inputs with configurable subplot layouts.
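
Below is a minimal Matplotlib sketch of a configurable subplot grid. Each signal gets its own subplot, and setpoints and bounds are drawn as dashed reference lines. The signal names, data and helper name are hypothetical. All signals are drawn as steps, on the assumption that values are held between samples.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_signals(t, signals, ncols=2, references=None):
    """Plot each named signal in its own subplot on an nrows x ncols grid.

    signals    : {name: 1-D array}
    references : optional {name: {label: constant value}} for setpoints or bounds
    """
    references = references or {}
    nrows = math.ceil(len(signals) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3 * nrows), squeeze=False)
    for ax, (name, values) in zip(axes.flat, signals.items()):
        ax.step(t, values, where="post", label=name)
        for label, level in references.get(name, {}).items():
            ax.axhline(level, linestyle="--", label=label)
        ax.set_title(name)
        ax.set_xlabel("time step")
        ax.legend()
    for ax in list(axes.flat)[len(signals):]:
        ax.set_visible(False)  # hide unused grid cells
    fig.tight_layout()
    return fig


t = np.arange(20)
fig = plot_signals(t, {"y1": np.tanh(t / 5), "u1": np.minimum(t, 5) / 5, "u2": np.ones(20)},
                   ncols=2, references={"y1": {"setpoint": 1.0, "upper bound": 1.1}})
print(len(fig.axes))  # 4
```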


In [8]:
display(plot_ui)
