In [21]:
%pip install plotly
import plotly.graph_objects as go
import ipywidgets as widgets
from math import sqrt
import time

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [22]:
## Loss functions

def avg(y_true, y_pred):
    """Mean signed error (bias): the average of ``t - p`` over paired samples.

    Positive values mean the model under-predicts on average. Raises
    ZeroDivisionError when ``y_true`` is empty.
    """
    signed_errors = (target - guess for target, guess in zip(y_true, y_pred))
    return sum(signed_errors) / len(y_true)

def mse(y_true, y_pred):
    """Mean squared error: the average of ``(t - p) ** 2`` over paired samples.

    Raises ZeroDivisionError when ``y_true`` is empty.
    """
    squared_errors = ((target - guess) ** 2 for target, guess in zip(y_true, y_pred))
    return sum(squared_errors) / len(y_true)

def rmse(y_true, y_pred):
    """Root mean squared error: ``sqrt`` of the mean of ``(t - p) ** 2``.

    Raises ZeroDivisionError when ``y_true`` is empty.
    """
    squared_errors = ((target - guess) ** 2 for target, guess in zip(y_true, y_pred))
    mean_squared = sum(squared_errors) / len(y_true)
    return sqrt(mean_squared)

def mae(y_true, y_pred):
    """Mean absolute error: the average of ``|t - p|`` over paired samples.

    Raises ZeroDivisionError when ``y_true`` is empty.
    """
    residuals = (target - guess for target, guess in zip(y_true, y_pred))
    return sum(map(abs, residuals)) / len(y_true)

## Gradient functions
# Suppose a model has a parameter z and we want to minimize a loss L with respect to z, given input data x.
# By the chain rule, dL/dz = dL/dp * dp/dz, where p is the model's output (its prediction).
# The functions below implement the per-sample dL/dp for each loss function (the 1/N averaging factor is dropped,
# since updates are applied one sample at a time); dp/dz is provided by each model class's `grad` method.

def grad_avg(t, p):
    """Per-sample derivative of the signed error ``t - p`` with respect to ``p``.

    The error falls by exactly one for every unit increase in ``p``, so the
    derivative is the constant -1, independent of ``t`` and ``p``.
    """
    d_error_d_pred = -1
    return d_error_d_pred

def grad_mse(t, p):
    """Per-sample derivative of the squared error ``(t - p) ** 2`` with respect to ``p``.

    Equals ``-2 * (t - p)``: negative when the prediction is below the target,
    so a gradient-descent step raises ``p``.
    """
    residual = t - p
    return -2 * residual

def grad_rmse(t, p):
    """Not supported: RMSE has no per-sample derivative of the form dL/dp(t, p).

    RMSE = sqrt(mean((t_i - p_i) ** 2)) *is* differentiable wherever it is
    non-zero, with dL/dp_i = -(t_i - p_i) / (N * RMSE). That expression depends
    on the sample count N and on the residuals of every sample. This function
    only receives a single (t, p) pair, so it cannot compute it.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "RMSE gradient depends on all samples (N and the overall RMSE), "
        "so it cannot be computed from a single (t, p) pair"
    )

def grad_mae(t, p):
    """Per-sample (sub)gradient of the absolute error ``|t - p|`` with respect to ``p``.

    Returns:
        -1 when the prediction is below the target (t > p),
        +1 when it is above (t < p),
         0 when they are equal.

    ``|t - p|`` has a kink at t == p. 0 is the conventional subgradient there
    (as in ``sign``), so an exact fit is not pushed away by a spurious step.
    """
    if t > p:
        return -1
    if t < p:
        return 1
    return 0


In [33]:
def get_y_pred(model, x_data):
    """Return ``model.forward(x)`` for every ``x`` in ``x_data``, as a list in input order."""
    return list(map(model.forward, x_data))


def gradient_descent_step(model, x, t, p, grad_loss, lr):
    """Apply one gradient-descent update to every parameter of ``model`` for one sample.

    Every instance attribute of ``model`` is treated as a parameter z. Each is
    updated by the chain rule as ``z <- z - lr * dL/dp * dp/dz``, where
    ``dL/dp = grad_loss(t, p)`` and ``dp/dz = model.grad(x)[name]``.

    Args:
        model: object whose instance attributes are all trainable parameters.
            Its ``grad(x)`` returns a mapping from each attribute name to dp/dz.
        x: input of the sample.
        t: target value of the sample.
        p: the model's prediction for ``x``. It should reflect the current
            parameters, otherwise the step follows a stale gradient.
        grad_loss: per-sample loss derivative dL/dp as a function of ``(t, p)``.
        lr: learning rate (step size).

    Raises:
        KeyError: if ``model.grad(x)`` has no entry for some attribute.
        Anything ``grad_loss`` raises (e.g. ``grad_rmse`` always raises
        NotImplementedError).
    """
    # dL/dp does not depend on which parameter is being updated; compute it once.
    dloss_dpred = grad_loss(t, p)
    # Evaluate dp/dz before any update so all parameters move simultaneously.
    dpred_dparams = model.grad(x)
    # Iterate over a snapshot because setattr below writes into the same __dict__.
    for name, value in list(vars(model).items()):
        setattr(model, name, value - lr * (dloss_dpred * dpred_dparams[name]))


def create_figure(model, x_data, y_data, loss_fn_dropdown):
    """Create the main figure: data points, model predictions and residual lines.

    Args:
        model: object with a ``forward(x)`` method; predictions come from ``get_y_pred``.
        x_data: sequence of inputs.
        y_data: sequence of targets, paired with ``x_data``.
        loss_fn_dropdown: dropdown whose ``value[0]`` is the loss function shown in
            the annotation and whose ``label`` names it.

    Returns:
        go.FigureWidget with three traces: data markers, the prediction line, and a
        dummy "Error Lines" trace. The residual segments themselves are layout shapes.

    Raises:
        ValueError: if ``x_data`` or ``y_data`` is empty (``min``/``max``).
    """
    # Axis limits: pad the data range by 2 units on every side.
    x_min, x_max = min(x_data), max(x_data)
    y_min, y_max = min(y_data), max(y_data)
    y_pred = get_y_pred(model, x_data)
    loss = loss_fn_dropdown.value[0](y_data, y_pred)

    # One vertical segment per sample, from the target up/down to the prediction.
    # Width 2 matches the shapes redrawn by plot()'s update callback, so the lines
    # no longer change thickness after the first update.
    error_lines = [
        dict(
            type="line",
            xref="x",
            yref="y",
            x0=x,
            y0=y,
            x1=x,
            y1=pred,
            line=dict(color="gray", width=2),
        )
        for x, y, pred in zip(x_data, y_data, y_pred)
    ]

    return go.FigureWidget(
        data=[
            go.Scatter(x=x_data, y=y_data, mode="markers", name="Data", marker=dict(opacity=0.5)),
            go.Scatter(x=x_data, y=y_pred, mode="lines", name="Regression Line"),
            # Dummy trace so "Error Lines" gets a legend entry; the segments
            # themselves are drawn as layout shapes.
            go.Scatter(x=[None], y=[None], mode="lines", line=dict(color="gray"), name="Error Lines"),
        ],
        layout=go.Layout(
            title="Interactive Linear Regression",
            xaxis_title="Input",
            yaxis_title="Output",
            xaxis_range=[x_min - 2, x_max + 2],
            yaxis_range=[y_min - 2, y_max + 2],
            xaxis=dict(dtick=5),  # tick every 5 units
            yaxis=dict(dtick=5),
            autosize=False,  # keep the fixed pixel size below
            width=16 * 50,  # 800 x 450 px: a 16:9 aspect ratio
            height=9 * 50,
            shapes=error_lines,
            annotations=[
                dict(
                    text=f"{loss_fn_dropdown.label}: {loss:.2f}",
                    x=0.5,
                    y=0.98,
                    xref="paper",
                    yref="paper",
                    showarrow=False,
                )
            ],
        ),
    )


def create_loss_vs_param_figure(model, loss):
    """Create the figure that tracks the loss against the model's parameters.

    The model's instance attributes are taken as its parameters.
    - One parameter: a 2-D scatter of parameter vs loss.
    - Two parameters: a 3-D scatter of both parameters vs loss.
    - Any other count: an empty figure with an explanatory annotation.
    The starting point is drawn as a single large red marker.

    Args:
        model: object whose instance attributes are its parameters.
        loss: loss value at the model's current parameters.
    """
    params = list(vars(model).items())
    title = "Loss vs Model Parameters"

    if len(params) == 2:
        (name_a, value_a), (name_b, value_b) = params
        start_point = go.Scatter3d(
            x=[value_a],
            y=[value_b],
            z=[loss],
            mode="markers",
            name="Loss",
            marker=dict(color=["rgba(255, 0, 0, 1)"], size=[10]),
        )
        scene = dict(
            xaxis=dict(title=name_a, dtick=5, range=[-25, 25]),
            yaxis=dict(title=name_b, dtick=5, range=[-25, 25]),
            zaxis=dict(title="Loss", dtick=5, range=[-25, 25]),
        )
        layout = go.Layout(
            title=title,
            scene=scene,
            autosize=False,
            width=16 * 50,
            height=9 * 50,
        )
        return go.FigureWidget(data=[start_point], layout=layout)

    if len(params) == 1:
        ((name, value),) = params
        start_point = go.Scatter(
            x=[value],
            y=[loss],
            mode="markers",
            name="Loss",
            marker=dict(color=["rgba(255, 0, 0, 1)"], size=[10]),
        )
        layout = go.Layout(
            title=title,
            xaxis_title=name + "*",
            yaxis_title="Loss",
            xaxis=dict(dtick=5),
            yaxis=dict(dtick=5),
            xaxis_range=[-25, 25],
            yaxis_range=[-25, 25],
            autosize=False,
            width=9 * 50,
            height=9 * 50,
        )
        return go.FigureWidget(data=[start_point], layout=layout)

    unsupported_note = dict(
        text="More than 2 model parameters not supported",
        showarrow=False,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
    )
    layout = go.Layout(
        title=title,
        autosize=False,
        width=16 * 50,
        height=9 * 50,
        annotations=[unsupported_note],
    )
    return go.FigureWidget(data=[], layout=layout)


def create_loss_fn_dropdown():
    """Create the dropdown used to choose the loss function.

    Each option value is a ``[loss_fn, rgba_color]`` pair.
    - ``loss_fn(y_true, y_pred)`` computes the loss.
    - ``rgba_color`` is the colour ``plot`` gives to past points in the
      loss-vs-parameter figure.
    "Average Error" is selected initially.
    """
    options = {
        "Average Error": [avg, "rgba(0, 0, 255, 0.8)"],  # Blue
        "Mean Squared Error": [mse, "rgba(255, 0, 255, 0.8)"],  # Magenta
        "Root Mean Squared Error": [rmse, "rgba(0, 255, 0, 0.8)"],  # Lime Green
        "Mean Absolute Error": [mae, "rgba(255, 165, 0, 0.8)"],  # Orange
    }
    return widgets.Dropdown(
        options=options,
        # Take the default from the options themselves so the two cannot drift apart.
        value=options["Average Error"],
        description="Losses:",
    )


def plot(x_data, y_data, model):
    """Assemble the interactive gradient-descent dashboard for ``model``.

    Layout: a left column holds the loss-function dropdown, the data/fit figure,
    a learning-rate slider and a "Start Simulation" button. The
    loss-vs-parameters figure sits next to it.

    * Changing the dropdown recomputes the loss with the newly selected function
      and records it in both figures.
    * Clicking the button runs one pass of per-sample gradient descent over the
      data. Both figures are redrawn after every step, with a 1 s pause so the
      motion is visible.

    Args:
        x_data: sequence of inputs.
        y_data: sequence of targets, paired with ``x_data``.
        model: object with ``forward(x)`` and ``grad(x)`` whose instance
            attributes are its trainable parameters. The simulation mutates it
            in place.

    Returns:
        widgets.HBox containing all controls and figures. It is displayed when
        it is the last expression of a cell.
    """
    loss_fn_dropdown = create_loss_fn_dropdown()
    fig = create_figure(model, x_data, y_data, loss_fn_dropdown)
    initial_loss = loss_fn_dropdown.value[0](y_data, get_y_pred(model, x_data))
    loss_vs_param_fig = create_loss_vs_param_figure(model, initial_loss)
    learning_rate_slider = widgets.FloatLogSlider(
        value=0.1, base=10, min=-4, max=1, step=0.1, description="Learning Rate:"
    )
    start_button = widgets.Button(description="Start Simulation")

    def update_plot(change):
        """Redraw both figures for the model's current parameters.

        ``change`` is the traitlets change dict when called by the dropdown
        observer, or the clicked Button when called from ``start_simulation``.
        """
        y_pred = get_y_pred(model, x_data)
        loss = loss_fn_dropdown.value[0](y_data, y_pred)
        with fig.batch_update():
            fig.data[1].y = y_pred  # regression line
            # Residual segments from each target to its prediction.
            fig.layout.shapes = [
                dict(
                    type="line",
                    xref="x",
                    yref="y",
                    x0=x,
                    y0=y_data[i],
                    x1=x,
                    y1=y_pred[i],
                    line=dict(color="gray", width=2),
                )
                for i, x in enumerate(x_data)
            ]
            fig.layout.annotations[0].text = f"{loss_fn_dropdown.label}: {loss:.2f}"

        params = list(vars(model).values())
        if len(params) not in (1, 2):
            # create_loss_vs_param_figure adds no trace for this many parameters,
            # so there is nothing to extend (data[0] would raise IndexError).
            return

        trace = loss_vs_param_fig.data[0]
        if len(params) == 1:
            trace.x += (params[0],)
            trace.y += (loss,)
        else:
            trace.x += (params[0],)
            trace.y += (params[1],)
            trace.z += (loss,)

        # The newest point is always drawn large and red. The previous "current"
        # point shrinks and takes the colour of the loss function that produced
        # it. On a dropdown change, that is the *old* dropdown value.
        colors = list(trace.marker.color)
        sizes = list(trace.marker.size)
        from_dropdown = isinstance(change, dict) and "owner" in change and change["owner"] == loss_fn_dropdown
        colors[-1] = change["old"][1] if from_dropdown else loss_fn_dropdown.value[1]
        sizes[-1] = 6
        trace.marker.color = colors + ["rgba(255, 0, 0, 1)"]
        trace.marker.size = sizes + [10]

    def start_simulation(button):
        """Run one pass of per-sample gradient descent over the data, animating each step."""
        lr = learning_rate_slider.value
        # The derivative dL/dp is found by naming convention: grad_<loss function name>.
        grad_loss = globals()["grad_" + loss_fn_dropdown.value[0].__name__]
        for x, t in zip(x_data, y_data):
            # Predict with the *current* parameters. Predictions computed once
            # before the loop go stale after the first update, and later steps
            # would then follow the wrong gradient.
            p = model.forward(x)
            gradient_descent_step(model, x, t, p, grad_loss, lr)
            update_plot(button)
            time.sleep(1)

    start_button.on_click(start_simulation)
    loss_fn_dropdown.observe(update_plot, names="value")

    controls = widgets.VBox([loss_fn_dropdown, fig, learning_rate_slider, start_button])
    return widgets.HBox([controls, loss_vs_param_fig])

In [34]:
class ConstantModel:
    """Model that ignores its input and always predicts its single parameter ``h``.

    ``plot`` and ``gradient_descent_step`` treat every instance attribute as a
    trainable parameter, so ``h`` must remain the only attribute.
    """

    def __init__(self, h=2.0):
        self.h = h

    def forward(self, x):
        """Predict ``h`` regardless of ``x``."""
        return self.h

    def grad(self, x):
        """Return dp/dz per parameter. The prediction *is* ``h``, so dp/dh = 1."""
        return dict(h=1)

# Four toy points. A constant model can only learn a single horizontal level h.
x_data = list(range(1, 5))
y_data = [6, 8, 1, 3]

model = ConstantModel()
plot(x_data, y_data, model)

HBox(children=(VBox(children=(Dropdown(description='Losses:', options={'Average Error': [<function avg at 0x7f…

In [35]:
class SimpleLinearRegression:
    """Straight-line model ``p = m * x + c`` with trainable slope ``m`` and intercept ``c``.

    The attributes are assigned in the order m, c. That order is the order in
    which ``vars(model)`` lists them, and the loss-vs-parameter figure uses it
    for its axes.
    """

    def __init__(self, m=0.7, c=2.2):
        self.m = m
        self.c = c

    def forward(self, x):
        """Predict ``m * x + c``."""
        return self.m * x + self.c

    def grad(self, x):
        """Partial derivatives of the prediction: dp/dm = x and dp/dc = 1."""
        return {"c": 1, "m": x}

# Initial data points: ten samples lying exactly on the line y = x + 1
x_data = list(range(1, 11))
y_data = [x + 1 for x in x_data]

model = SimpleLinearRegression()
plot(x_data, y_data, model)

HBox(children=(VBox(children=(Dropdown(description='Losses:', options={'Average Error': [<function avg at 0x7f…