In [None]:
%pip install plotly
import plotly.graph_objects as go
import ipywidgets as widgets
from math import sqrt
import numpy as np
import time

# If Using Google Colab, enable these extensions
# from google.colab import output
# output.enable_custom_widget_manager()

## MSE Loss: A Deep Dive into Gradient Descent

Ever wondered how machines learn to predict things? Let's explore one of the coolest techniques: gradient descent with Mean Squared Error (MSE) loss. It's like teaching a computer to play "Hot and Cold" with data!

### What's the Big Idea?

Imagine you're trying to draw the best straight line through a bunch of scattered points. That's linear regression in a nutshell. We use two magic numbers: 'm' (the slope) and 'c' (where the line crosses the y-axis). Our goal? Find the perfect 'm' and 'c' that make our line fit the points best.

### How Does It Work?

1. **Start with a Guess**: We pick starting values for 'm' and 'c' (in the code below, the defaults are m = 0.7 and c = 2.2).
2. **Check How Bad We Are**: We use MSE to measure how far off our line is from the real points.
3. **Take a Step in the Right Direction**: We adjust 'm' and 'c' a tiny bit to make our line fit better.
4. **Repeat Until We're Happy**: Keep tweaking until the line barely improves anymore.

### Cool Features in Our Simulation

- **Tolerance**: It's like saying, "If we're not getting much better, let's call it a day."
- **Dynamic Learning**: If we're suddenly doing worse, we slow down our adjustments.
- **Gradient Clipping**: Imagine putting guardrails on a mountain road – it keeps us from making crazy changes.

### The Secret Sauce: The Loss Surface

Picture a 3D landscape where height represents how bad our guess is. Our goal is to find the lowest point in this landscape. For MSE, it looks like a smooth bowl – there's only one bottom to find!

### Why Not Just Solve It Directly?

Sometimes we can! It's called using "normal equations." But imagine trying to solve a million-piece puzzle all at once – sometimes it's easier to tackle it bit by bit, like gradient descent does.

### When Does This Really Shine?

- **Big Data**: When you've got tons of points, gradient descent is often faster.
- **Tricky Data**: Sometimes the math for solving it directly just doesn't work out. Gradient descent doesn't care – it'll keep chugging along.

So there you have it! We're teaching computers to play a super-advanced version of "Hot and Cold" to find the best line through our data. Pretty cool, right?

In [None]:
## Loss function
def mse(y_true, y_pred):
    """Return the mean squared error between targets and predictions.

    The two sequences are paired element by element, and the sum of the
    squared differences is divided by ``len(y_true)``.
    """
    squared_errors = [(target - guess) ** 2 for target, guess in zip(y_true, y_pred)]
    return sum(squared_errors) / len(y_true)

## Gradient function
def grad_mse(t, p):
    """Return the derivative of ``(t - p) ** 2`` with respect to the prediction ``p``."""
    residual = t - p
    return -2 * residual

def get_y_pred(model, x_data):
    """Call ``model.forward`` on each input in order and return the results as a list."""
    predictions = []
    for sample in x_data:
        predictions.append(model.forward(sample))
    return predictions

def gradient_descent_step(model, x, t, p, grad_loss, lr, clip_value):
    """Apply one clipped gradient-descent update to ``model`` in place.

    For every instance attribute of ``model`` that has an entry in
    ``model.grad(x)``, the chain-rule gradient
    ``grad_loss(t, p) * model.grad(x)[name]`` is clipped to
    ``[-clip_value, clip_value]`` and the attribute is moved by ``-lr`` times
    that amount. Attributes without a gradient entry are left untouched.

    Args:
        model: Object whose parameters are instance attributes and whose
            ``grad(x)`` returns a mapping from attribute name to the partial
            derivative of the prediction with respect to that attribute.
        x: Input sample.
        t: Target value for ``x``.
        p: Prediction for ``x``.
        grad_loss: Callable ``(t, p)`` giving the derivative of the loss with
            respect to the prediction.
        lr: Step size.
        clip_value: Bound applied to each parameter's gradient before the step.
    """
    model_grad = model.grad(x)
    # The loss derivative is the same for every parameter, so evaluate it once.
    loss_grad = grad_loss(t, p)
    # Iterate over a snapshot: the loop reassigns these attributes via setattr,
    # and walking the live instance dict while writing to it is fragile.
    for name, value in list(vars(model).items()):
        if name not in model_grad:
            # No partial derivative supplied, so this is not a trainable parameter.
            continue
        grad = np.clip(loss_grad * model_grad[name], -clip_value, clip_value)
        setattr(model, name, value - lr * grad)

def create_figure(model, x_data, y_data):
    """Build the 2D regression view as a plotly ``FigureWidget``.

    The figure holds three traces in this order: the data points, the model's
    current line, and an empty gray line trace named "Error Lines"
    (presumably there so the legend has an entry for the residual segments).
    One gray vertical segment per sample joins the observed value to the
    prediction, and an annotation near the top shows the mean squared error.
    """
    x_low, x_high = min(x_data), max(x_data)
    y_low, y_high = min(y_data), max(y_data)
    predictions = get_y_pred(model, x_data)

    points_trace = go.Scatter(
        x=x_data, y=y_data, mode="markers", name="Data", marker=dict(opacity=0.5)
    )
    line_trace = go.Scatter(x=x_data, y=predictions, mode="lines", name="Regression Line")
    legend_trace = go.Scatter(
        x=[None], y=[None], mode="lines", line=dict(color="gray"), name="Error Lines"
    )

    # One vertical segment per sample, from the observed y to the predicted y.
    residual_shapes = []
    for idx, x_val in enumerate(x_data):
        segment = dict(
            type="line",
            xref="x",
            yref="y",
            x0=x_val,
            y0=y_data[idx],
            x1=x_val,
            y1=predictions[idx],
            line=dict(color="gray", width=1),
        )
        residual_shapes.append(segment)

    error_label = dict(
        text=f"Mean Squared Error: {mse(y_data, predictions):.2f}",
        x=0.5,
        y=0.98,
        xref="paper",
        yref="paper",
        showarrow=False,
    )

    layout = go.Layout(
        title="Interactive Linear Regression",
        xaxis_title="Input",
        yaxis_title="Output",
        # Pad both axes by 2 units around the data.
        xaxis_range=[x_low - 2, x_high + 2],
        yaxis_range=[y_low - 2, y_high + 2],
        xaxis=dict(dtick=5),
        yaxis=dict(dtick=5),
        autosize=False,
        width=16 * 50,
        height=9 * 50,
        shapes=residual_shapes,
        annotations=[error_label],
    )
    return go.FigureWidget(data=[points_trace, line_trace, legend_trace], layout=layout)

def create_loss_vs_param_figure(model, x_data, y_data):
    """Build the 3D MSE loss surface over (m, c) with a marker at the model's parameters.

    The loss is evaluated on a 50 x 50 grid spanning [-25, 25] for both
    parameters by temporarily assigning each grid point to ``model.m`` and
    ``model.c``. The model's own values are put back afterwards, so the
    "Current Parameters" marker reflects the parameters the caller passed in
    and the model leaves this function unchanged.

    Args:
        model: Object with ``m`` and ``c`` attributes and a ``forward`` method.
        x_data: Inputs used to evaluate the loss.
        y_data: Targets aligned with ``x_data``.

    Returns:
        A ``FigureWidget`` whose trace 0 is the surface and trace 1 the marker.
    """
    param_range = np.linspace(-25, 25, 50)
    m_values, c_values = np.meshgrid(param_range, param_range)
    losses = np.zeros_like(m_values)

    # The sweep below overwrites the model's parameters; remember the caller's
    # values so they are restored even if a loss evaluation raises.
    initial_m, initial_c = model.m, model.c
    rows, cols = m_values.shape
    try:
        for i in range(rows):
            for j in range(cols):
                model.m = m_values[i, j]
                model.c = c_values[i, j]
                losses[i, j] = mse(y_data, get_y_pred(model, x_data))
    finally:
        model.m, model.c = initial_m, initial_c

    surface = go.Surface(x=m_values, y=c_values, z=losses, colorscale="Viridis")
    scatter = go.Scatter3d(
        x=[model.m],
        y=[model.c],
        z=[mse(y_data, get_y_pred(model, x_data))],
        mode="markers",
        name="Current Parameters",
        marker=dict(color="rgba(255, 0, 0, 1)", size=5),
    )

    fig = go.FigureWidget(
        data=[surface, scatter],
        layout=go.Layout(
            title="Loss vs Model Parameters",
            scene=dict(
                xaxis=dict(title="m", showticklabels=False),
                yaxis=dict(title="c", showticklabels=False),
                zaxis=dict(title="Loss", showticklabels=False),
            ),
            autosize=False,
            width=16 * 75,
            height=9 * 75,
            margin=dict(l=0, r=0, b=0, t=0),
        ),
    )
    return fig

def plot(x_data, y_data, model):
    """Assemble the interactive MSE gradient-descent demo.

    Builds the 2D regression figure, the 3D loss surface, and the input
    widgets (learning rate, iteration cap, tolerance, clip value), then wires
    the "Start Simulation" button to a training loop that updates ``model``
    in place and redraws both figures after every pass over the data.

    Args:
        x_data: Inputs.
        y_data: Targets aligned with ``x_data``.
        model: Object exposing ``forward``, ``grad`` and ``m`` / ``c`` attributes.

    Returns:
        An ``ipywidgets.HBox`` with the 2D figure and controls on the left and
        the loss surface on the right.
    """
    fig = create_figure(model, x_data, y_data)
    loss_vs_param_fig = create_loss_vs_param_figure(model, x_data, y_data)
    learning_rate_slider = widgets.FloatLogSlider(
        value=0.001, base=10, min=-4, max=1, step=0.1, description="Learning Rate:"
    )
    start_button = widgets.Button(description="Start Simulation")
    max_iterations = widgets.IntText(value=5, description="Max Iterations:")
    tol = widgets.FloatText(value=1e-6, description="Tolerance:")
    clip_value = widgets.FloatText(value=1.0, description="Clip Value:")

    def update_plot(change):
        """Redraw both figures from the model's current parameters and return the loss."""
        y_pred = get_y_pred(model, x_data)
        loss = mse(y_data, y_pred)
        with fig.batch_update():
            # Trace 1 is the regression line.
            fig.data[1].y = y_pred
            # NOTE(review): these segments use width 2 while create_figure
            # draws the initial ones at width 1 - confirm this is intended.
            fig.layout.shapes = [
                dict(
                    type="line",
                    xref="x",
                    yref="y",
                    x0=x,
                    y0=y_data[i],
                    x1=x,
                    y1=y_pred[i],
                    line=dict(
                        color="gray",
                        width=2,
                    ),
                )
                for i, x in enumerate(x_data)
            ]
            fig.layout.annotations[0].text = f"Mean Squared Error: {loss:.2f}"

        # Move the 3D marker in a single batched update so it never shows a
        # mix of old and new coordinates; the loss computed above is reused.
        with loss_vs_param_fig.batch_update():
            scatter = loss_vs_param_fig.data[1]
            scatter.x = [model.m]
            scatter.y = [model.c]
            scatter.z = [loss]
        return loss

    def start_simulation(change):
        """Run gradient descent using the values currently entered in the widgets."""
        lr = learning_rate_slider.value
        max_iter = max_iterations.value
        tolerance = tol.value
        clip_val = clip_value.value
        previous_loss = float('inf')
        # NOTE(review): lr stays fixed for the whole run; the notebook text
        # mentions slowing down when the loss gets worse, which is not
        # implemented in this function.
        for iteration in range(max_iter):
            # Predictions are taken once per pass, so each sample's error is
            # measured with the parameters from the start of the pass even
            # though the model is updated after every sample.
            y_pred = get_y_pred(model, x_data)
            for x, t, p in zip(x_data, y_data, y_pred):
                gradient_descent_step(model, x, t, p, grad_mse, lr, clip_val)
            current_loss = update_plot(change)
            print(f"Iteration {iteration}: m={model.m}, c={model.c}, loss={current_loss}")
            # Stop once the loss changes by less than the tolerance between passes.
            if abs(previous_loss - current_loss) < tolerance:
                break
            previous_loss = current_loss
            # Short pause between passes; presumably so the redraws are visible.
            time.sleep(0.1)

    start_button.on_click(start_simulation)

    controls = widgets.VBox(
        [fig, learning_rate_slider, start_button, max_iterations, tol, clip_value],
        layout=widgets.Layout(border='solid 2px gray', padding='10px'),
    )
    surface_panel = widgets.VBox(
        [loss_vs_param_fig],
        layout=widgets.Layout(margin='0 0 0 -50px'),
    )
    return widgets.HBox([controls, surface_panel])

class SimpleLinearRegression:
    """Straight-line model ``y = m * x + c`` with slope ``m`` and intercept ``c``."""

    def __init__(self, m=0.7, c=2.2):
        """Store the starting slope and intercept."""
        self.m, self.c = m, c

    def grad(self, x):
        """Return the partial derivatives of ``forward(x)`` with respect to ``m`` and ``c``."""
        return dict(m=x, c=1)

    def forward(self, x):
        """Return the model's prediction for input ``x``."""
        scaled = self.m * x
        return scaled + self.c

# Generate a noisy linear dataset: y = 2x + 1 plus Gaussian noise (mean 0, std 1).
# The seed is fixed so the same points are produced on every run.
np.random.seed(42)
# 20 evenly spaced inputs covering [0, 10].
x_data = np.linspace(0, 10, 20)
y_data = 2 * x_data + 1 + np.random.normal(0, 1, x_data.shape)

# Model with its default starting parameters, wired into the interactive demo.
model = SimpleLinearRegression()
plot_widget = plot(x_data, y_data, model)
# Bare expression as the last line of the notebook cell, so the widget is displayed.
plot_widget

## MAE Loss Function in Linear Regression

This simulation explores gradient descent for optimizing linear regression parameters by minimizing the Mean Absolute Error (MAE) loss. Unlike Mean Squared Error (MSE), MAE uses absolute error values, reducing sensitivity to outliers.

### Key Concepts

1. **MAE Loss Function**: Average of absolute differences between predicted and true values.
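2. **Gradient Calculation**: For MAE, the gradient of the absolute error with respect to the prediction is -1 if the true value exceeds the predicted value, and 1 otherwise. (The derivative is undefined when the two are equal; the code uses 1 in that case.)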
3. **Linear Regression Parameters**:
   - m: Slope
   - c: Intercept
   - Model: y = m * x + c

4. **Function Surface**:
   - Represents loss as a function of m and c
   - MAE surface has a characteristic "V" shape

### Gradient Descent for MAE

The algorithm updates parameters by moving opposite to the loss function's gradient. Gradient clipping prevents excessive updates.

### Visualization Features

1. **Loss Contours**:
   - Contour lines drawn on the surface, and projected onto the m–c plane, trace levels of constant loss
   - Reflect the "V" shape of the MAE surface
   - Aid in visualizing the gradient descent path

2. **Gradient Descent Path**:
   - Red dot shows optimization progress
   - Moves along contours towards minimum loss

### MAE vs. MSE

- MAE surface has linear sections forming a "V" shape
- MSE surface is smooth and convex

This simulation demonstrates gradient descent with MAE, showcasing convergence to optimal parameters and comparing it with normal equations. The visualization of the loss surface and parameter updates provides insights into gradient descent's behavior and advantages in various scenarios.

In [None]:
## Loss function
def mae(y_true, y_pred):
    """Return the mean absolute error between targets and predictions.

    The two sequences are paired element by element, and the sum of the
    absolute differences is divided by ``len(y_true)``.
    """
    absolute_errors = [abs(target - guess) for target, guess in zip(y_true, y_pred)]
    return sum(absolute_errors) / len(y_true)

## Gradient function
def grad_mae(t, p):
    """Return the slope of ``abs(t - p)`` with respect to the prediction ``p``.

    Gives -1 when the target is above the prediction and 1 in every other
    case, including when the two are equal.
    """
    if t > p:
        return -1
    return 1

def get_y_pred(model, x_data):
    """Return a list with the model's prediction for each input, in input order."""
    outputs = []
    for value in x_data:
        outputs.append(model.forward(value))
    return outputs

def gradient_descent_step(model, x, t, p, grad_loss, lr, clip_value):
    """Apply one clipped gradient-descent update to ``model`` in place.

    For every instance attribute of ``model`` that has an entry in
    ``model.grad(x)``, the chain-rule gradient
    ``grad_loss(t, p) * model.grad(x)[name]`` is clipped to
    ``[-clip_value, clip_value]`` and the attribute is moved by ``-lr`` times
    that amount. Attributes without a gradient entry are left untouched.

    Args:
        model: Object whose parameters are instance attributes and whose
            ``grad(x)`` returns a mapping from attribute name to the partial
            derivative of the prediction with respect to that attribute.
        x: Input sample.
        t: Target value for ``x``.
        p: Prediction for ``x``.
        grad_loss: Callable ``(t, p)`` giving the derivative of the loss with
            respect to the prediction.
        lr: Step size.
        clip_value: Bound applied to each parameter's gradient before the step.
    """
    model_grad = model.grad(x)
    # The loss derivative is the same for every parameter, so evaluate it once.
    loss_grad = grad_loss(t, p)
    # Iterate over a snapshot: the loop reassigns these attributes via setattr,
    # and walking the live instance dict while writing to it is fragile.
    for name, value in list(vars(model).items()):
        if name not in model_grad:
            # No partial derivative supplied, so this is not a trainable parameter.
            continue
        grad = np.clip(loss_grad * model_grad[name], -clip_value, clip_value)
        setattr(model, name, value - lr * grad)

def create_figure(model, x_data, y_data):
    """Build the 2D regression view as a plotly ``FigureWidget``.

    The figure holds three traces in this order: the data points, the model's
    current line, and an empty gray line trace named "Error Lines"
    (presumably there so the legend has an entry for the residual segments).
    One gray vertical segment per sample joins the observed value to the
    prediction, and an annotation near the top shows the mean absolute error.
    """
    x_low, x_high = min(x_data), max(x_data)
    y_low, y_high = min(y_data), max(y_data)
    predictions = get_y_pred(model, x_data)

    points_trace = go.Scatter(
        x=x_data, y=y_data, mode="markers", name="Data", marker=dict(opacity=0.5)
    )
    line_trace = go.Scatter(x=x_data, y=predictions, mode="lines", name="Regression Line")
    legend_trace = go.Scatter(
        x=[None], y=[None], mode="lines", line=dict(color="gray"), name="Error Lines"
    )

    # One vertical segment per sample, from the observed y to the predicted y.
    residual_shapes = []
    for idx, x_val in enumerate(x_data):
        segment = dict(
            type="line",
            xref="x",
            yref="y",
            x0=x_val,
            y0=y_data[idx],
            x1=x_val,
            y1=predictions[idx],
            line=dict(color="gray", width=1),
        )
        residual_shapes.append(segment)

    error_label = dict(
        text=f"Mean Absolute Error: {mae(y_data, predictions):.2f}",
        x=0.5,
        y=0.98,
        xref="paper",
        yref="paper",
        showarrow=False,
    )

    layout = go.Layout(
        title="Interactive Linear Regression",
        xaxis_title="Input",
        yaxis_title="Output",
        # Pad both axes by 2 units around the data.
        xaxis_range=[x_low - 2, x_high + 2],
        yaxis_range=[y_low - 2, y_high + 2],
        xaxis=dict(dtick=5),
        yaxis=dict(dtick=5),
        autosize=False,
        width=16 * 50,
        height=9 * 50,
        shapes=residual_shapes,
        annotations=[error_label],
    )
    return go.FigureWidget(data=[points_trace, line_trace, legend_trace], layout=layout)

def create_loss_vs_param_figure(model, x_data, y_data):
    """Build the 3D MAE loss surface over (m, c) with a marker at the model's parameters.

    The loss is evaluated on a 50 x 50 grid spanning [-25, 25] for both
    parameters by temporarily assigning each grid point to ``model.m`` and
    ``model.c``. The model's own values are put back afterwards, so the
    "Current Parameters" marker reflects the parameters the caller passed in
    and the model leaves this function unchanged. The surface is drawn with
    constant-loss contour lines that are also projected along the z axis.

    Args:
        model: Object with ``m`` and ``c`` attributes and a ``forward`` method.
        x_data: Inputs used to evaluate the loss.
        y_data: Targets aligned with ``x_data``.

    Returns:
        A ``FigureWidget`` whose trace 0 is the surface and trace 1 the marker.
    """
    param_range = np.linspace(-25, 25, 50)
    m_values, c_values = np.meshgrid(param_range, param_range)
    losses = np.zeros_like(m_values)

    # The sweep below overwrites the model's parameters; remember the caller's
    # values so they are restored even if a loss evaluation raises.
    initial_m, initial_c = model.m, model.c
    rows, cols = m_values.shape
    try:
        for i in range(rows):
            for j in range(cols):
                model.m = m_values[i, j]
                model.c = c_values[i, j]
                losses[i, j] = mae(y_data, get_y_pred(model, x_data))
    finally:
        model.m, model.c = initial_m, initial_c

    surface = go.Surface(
        x=m_values,
        y=c_values,
        z=losses,
        colorscale="Viridis",
        contours={"z": {"show": True, "usecolormap": True, "highlightcolor": "limegreen", "project": {"z": True}}},
    )
    scatter = go.Scatter3d(
        x=[model.m],
        y=[model.c],
        z=[mae(y_data, get_y_pred(model, x_data))],
        mode="markers",
        name="Current Parameters",
        marker=dict(color="rgba(255, 0, 0, 1)", size=5),
    )

    fig = go.FigureWidget(
        data=[surface, scatter],
        layout=go.Layout(
            title="Loss vs Model Parameters",
            scene=dict(
                xaxis=dict(title="m", showticklabels=False),
                yaxis=dict(title="c", showticklabels=False),
                zaxis=dict(title="Loss", showticklabels=False),
            ),
            autosize=False,
            width=16 * 75,
            height=9 * 75,
            margin=dict(l=0, r=0, b=0, t=0),
        ),
    )
    return fig

def plot(x_data, y_data, model):
    """Assemble the interactive MAE gradient-descent demo.

    Builds the 2D regression figure, the 3D loss surface, and the input
    widgets (learning rate, iteration cap, tolerance, clip value), then wires
    the "Start Simulation" button to a training loop that updates ``model``
    in place and redraws both figures after every pass over the data.

    Args:
        x_data: Inputs.
        y_data: Targets aligned with ``x_data``.
        model: Object exposing ``forward``, ``grad`` and ``m`` / ``c`` attributes.

    Returns:
        An ``ipywidgets.HBox`` with the 2D figure and controls on the left and
        the loss surface on the right.
    """
    fig = create_figure(model, x_data, y_data)
    loss_vs_param_fig = create_loss_vs_param_figure(model, x_data, y_data)
    learning_rate_slider = widgets.FloatLogSlider(
        value=0.001, base=10, min=-4, max=1, step=0.1, description="Learning Rate:"
    )
    start_button = widgets.Button(description="Start Simulation")
    max_iterations = widgets.IntText(value=5, description="Max Iterations:")
    tol = widgets.FloatText(value=1e-6, description="Tolerance:")
    clip_value = widgets.FloatText(value=1.0, description="Clip Value:")

    def update_plot(change):
        """Redraw both figures from the model's current parameters and return the loss."""
        y_pred = get_y_pred(model, x_data)
        loss = mae(y_data, y_pred)
        with fig.batch_update():
            # Trace 1 is the regression line.
            fig.data[1].y = y_pred
            # NOTE(review): these segments use width 2 while create_figure
            # draws the initial ones at width 1 - confirm this is intended.
            fig.layout.shapes = [
                dict(
                    type="line",
                    xref="x",
                    yref="y",
                    x0=x,
                    y0=y_data[i],
                    x1=x,
                    y1=y_pred[i],
                    line=dict(
                        color="gray",
                        width=2,
                    ),
                )
                for i, x in enumerate(x_data)
            ]
            fig.layout.annotations[0].text = f"Mean Absolute Error: {loss:.2f}"

        # Move the 3D marker in a single batched update so it never shows a
        # mix of old and new coordinates; the loss computed above is reused.
        with loss_vs_param_fig.batch_update():
            scatter = loss_vs_param_fig.data[1]
            scatter.x = [model.m]
            scatter.y = [model.c]
            scatter.z = [loss]
        return loss

    def start_simulation(change):
        """Run gradient descent using the values currently entered in the widgets."""
        lr = learning_rate_slider.value
        max_iter = max_iterations.value
        tolerance = tol.value
        clip_val = clip_value.value
        previous_loss = float('inf')
        for iteration in range(max_iter):
            # Predictions are taken once per pass, so each sample's error sign
            # is measured with the parameters from the start of the pass even
            # though the model is updated after every sample.
            y_pred = get_y_pred(model, x_data)
            for x, t, p in zip(x_data, y_data, y_pred):
                gradient_descent_step(model, x, t, p, grad_mae, lr, clip_val)
            current_loss = update_plot(change)
            print(f"Iteration {iteration}: m={model.m}, c={model.c}, loss={current_loss}")
            # Stop once the loss changes by less than the tolerance between passes.
            if abs(previous_loss - current_loss) < tolerance:
                break
            previous_loss = current_loss
            # Short pause between passes; presumably so the redraws are visible.
            time.sleep(0.1)

    start_button.on_click(start_simulation)

    controls = widgets.VBox(
        [fig, learning_rate_slider, start_button, max_iterations, tol, clip_value],
        layout=widgets.Layout(border='solid 2px gray', padding='10px'),
    )
    surface_panel = widgets.VBox(
        [loss_vs_param_fig],
        layout=widgets.Layout(margin='0 0 0 -50px'),
    )
    return widgets.HBox([controls, surface_panel])

class SimpleLinearRegression:
    """Line model ``y = m * x + c``; ``m`` is the slope and ``c`` the intercept."""

    def __init__(self, m=0.7, c=2.2):
        """Keep the given slope and intercept as the starting parameters."""
        self.m, self.c = m, c

    def grad(self, x):
        """Return d(forward)/dm and d(forward)/dc at ``x``, keyed by parameter name."""
        return dict(m=x, c=1)

    def forward(self, x):
        """Evaluate the line at ``x``."""
        slope_term = self.m * x
        return slope_term + self.c

# Generate a noisy linear dataset: y = 2x + 1 plus Gaussian noise (mean 0, std 1).
# The seed is fixed so the same points are produced on every run.
np.random.seed(42)
# 20 evenly spaced inputs covering [0, 10].
x_data = np.linspace(0, 10, 20)
y_data = 2 * x_data + 1 + np.random.normal(0, 1, x_data.shape)

# Model with its default starting parameters, wired into the interactive demo.
model = SimpleLinearRegression()
plot_widget = plot(x_data, y_data, model)
# Bare expression as the last line of the notebook cell, so the widget is displayed.
plot_widget