<a href="https://colab.research.google.com/github/donlap/stat424/blob/main/6_Gradient_Descent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Part 1: Visualization of Optimization Algorithms

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from IPython.display import HTML
import time
from tqdm import tqdm

def plot_loss(loss_fn, X_data, y_data,
              xlim=(-2.0, 7.0), ylim=(-1.0, 4.0),
              title="Loss Landscape", target_params=None,
              show_plot=True, *, resolution=100, levels=30):
    """
    Draw a labelled contour plot of a two-parameter loss surface.

    The loss is evaluated on a ``resolution`` x ``resolution`` grid of
    ``(w, b)`` values spanning ``xlim`` x ``ylim``, and the contour lines
    are annotated with their numerical loss values.

    Args:
        loss_fn: Callable ``loss_fn(params, X_data, y_data)`` where
            ``params`` is a dict ``{'w': w, 'b': b}`` of scalars. It is
            mapped with ``jax.vmap``, so it must be JAX-traceable.
        X_data, y_data: Data passed through to ``loss_fn`` unchanged.
        xlim: ``(min, max)`` range for the weight ``w`` (x-axis).
        ylim: ``(min, max)`` range for the bias ``b`` (y-axis).
        title: Axes title.
        target_params: Optional dict with keys ``'w'`` and ``'b'``; when
            provided (and non-empty) it is marked with a red star and a
            legend entry.
        show_plot: If True, call ``plt.show()`` and return None; otherwise
            return ``(fig, ax)`` so the caller can draw on top.
        resolution: Number of grid points along each axis (default 100).
        levels: Number of contour levels passed to ``ax.contour``
            (default 30).

    Returns:
        None when ``show_plot`` is True, else the ``(fig, ax)`` pair.
    """
    w_vals = jnp.linspace(xlim[0], xlim[1], resolution)
    b_vals = jnp.linspace(ylim[0], ylim[1], resolution)
    W, B = jnp.meshgrid(w_vals, b_vals)

    def grid_loss(w, b):
        return loss_fn({'w': w, 'b': b}, X_data, y_data)

    # Outer vmap walks the grid rows, inner vmap walks the columns, so each
    # grid cell receives scalar (w, b) and yields one scalar loss.
    Z = jax.vmap(jax.vmap(grid_loss, in_axes=(0, 0)), in_axes=(0, 0))(W, B)

    fig, ax = plt.subplots(figsize=(9, 6))

    cp = ax.contour(W, B, Z, levels=levels, cmap='viridis', linewidths=1.5)
    ax.clabel(cp, inline=True, fontsize=8, fmt='%.1f')

    if target_params:
        ax.plot(target_params['w'], target_params['b'], 'r*', markersize=15,
                label='Global Min')
        ax.legend(loc='upper right')

    ax.set_title(title)
    ax.set_xlabel('Weight (w)')
    ax.set_ylabel('Bias (b)')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, linestyle='--', alpha=0.5)

    if show_plot:
        plt.show()
        return None
    return fig, ax


def animate_optimization(loss_fn, history, X_data, y_data,
                         xlim=(-2.0, 7.0), ylim=(-1.0, 4.0),
                         title="Optimization Path", target_params=None):
    """
    Animate an optimizer's (w, b) trajectory over the loss contours.

    The background comes from ``plot_loss`` (called with
    ``show_plot=False``). A red line traces the points visited so far and a
    red dot marks the point of the current frame.

    Args:
        loss_fn: Callable ``loss_fn(params, X_data, y_data)`` forwarded to
            ``plot_loss``; ``params`` is a dict with keys ``'w'`` and ``'b'``.
        history: Sequence of visited parameter points, convertible to a 2-D
            array with one row per step, ``w`` in column 0 and ``b`` in
            column 1.
        X_data, y_data: Data forwarded to ``loss_fn``.
        xlim, ylim: Plot ranges for ``w`` and ``b``.
        title: Figure title.
        target_params: Optional dict with ``'w'`` and ``'b'`` marking the
            known minimum.

    Returns:
        An ``IPython.display.HTML`` object containing the JS animation.
    """
    fig, ax = plot_loss(loss_fn, X_data, y_data,
                        xlim=xlim, ylim=ylim,
                        title=title, target_params=target_params,
                        show_plot=False)

    line, = ax.plot([], [], 'r-', alpha=0.8, lw=2, label='Trajectory')
    point, = ax.plot([], [], 'ro', markersize=8, markeredgecolor='black',
                     label='Current State')

    # plot_loss builds its legend (if any) before these two artists exist,
    # so rebuild it now to include the trajectory and current-state entries.
    ax.legend(loc='upper right')

    # Keep the path as a host-side NumPy array: each frame slices it, and
    # NumPy slicing avoids dispatching a JAX op per animation frame.
    path = np.asarray(history)

    def init():
        line.set_data([], [])
        point.set_data([], [])
        return line, point

    def update(frame):
        visited = path[:frame + 1]
        line.set_data(visited[:, 0], visited[:, 1])
        point.set_data([path[frame, 0]], [path[frame, 1]])
        return line, point

    anim = FuncAnimation(fig, update, frames=len(path),
                         init_func=init, blit=True, interval=50)

    # Close this specific figure so it is not additionally displayed as a
    # static image; only the returned HTML animation should be shown.
    plt.close(fig)
    return HTML(anim.to_jshtml())

## Part 2: Neural Networks in JAX

In [None]:
!pip install -q equinox

In [None]:
import equinox as eqx
import optax  # Standard optimizer library for JAX

# We use tensorflow/keras just to download the data easily
from tensorflow.keras.datasets import cifar10

# --- 1. DATA PREPARATION ---
def load_data(subset_size=5000, test_size=1000):
    """
    Load a small CIFAR-10 subset with pixel values divided by 255.

    Args:
        subset_size: Number of leading training examples to keep
            (default 5000).
        test_size: Number of leading test examples to keep (default 1000).

    Returns:
        ``(x_train, y_train, x_test, y_test)``: float32 image arrays
        (presumably in [0, 1], assuming the raw pixels are 8-bit) and
        flattened label arrays.
    """
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Slice *before* converting. Converting the full dataset first allocates
    # a large float32 buffer, and slicing afterwards returns a NumPy view
    # that keeps that entire buffer alive for as long as the subset lives.
    x_train = x_train[:subset_size].astype('float32') / 255.0
    x_test = x_test[:test_size].astype('float32') / 255.0

    # Flatten labels to 1-D (flatten() also returns an independent copy).
    y_train = y_train[:subset_size].flatten()
    y_test = y_test[:test_size].flatten()

    return x_train, y_train, x_test, y_test

X_train, y_train, X_test, y_test = load_data()

# Human-readable CIFAR-10 class names, presumably ordered by integer label.
class_names = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

train_shape = X_train.shape
print(f"Training Shape: {train_shape}")

In [30]:
# Loss Function (Cross Entropy)
def loss_fn(model, x, y):
    """
    Return the mean softmax cross-entropy of ``model`` over a batch.

    ``model`` is mapped over the leading axis of ``x`` with ``jax.vmap`` to
    produce per-example logits; ``y`` holds the integer class labels.
    """
    batched_logits = jax.vmap(model)(x)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        batched_logits, y)
    return per_example_loss.mean()

# Accuracy Helper
def compute_accuracy(model, x, y):
    """
    Return the fraction of examples whose arg-max prediction equals ``y``.

    ``model`` is mapped over the leading axis of ``x`` with ``jax.vmap``;
    the predicted class is the index of the largest logit.
    """
    batched_model = jax.vmap(model)
    predicted = batched_model(x).argmax(axis=-1)
    hits = predicted == y
    return hits.mean()