In [43]:
%matplotlib inline

In [15]:
import jax
import numpy as np
import seaborn as sns
import jax.numpy as jnp

import matplotlib.pyplot as plt
import ipywidgets as widgets

from IPython.display import display
from ipywidgets import VBox, HBox, interactive_output

from sklearn.utils import gen_batches
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
from jax import random, grad, lax, make_jaxpr, jit, vmap, value_and_grad

np.random.seed(69)

In [50]:
def model(w, b, X):
    """Linear predictor: the dot product of ``X`` with ``w``, shifted by ``b``.

    Args:
        w: weight vector (one entry per column of ``X``).
        b: bias added to every prediction.
        X: feature matrix, one row per sample.

    Returns:
        ``jnp.dot(X, w) + b`` — one prediction per row of ``X``.
    """
    linear_term = jnp.dot(X, w)
    return linear_term + b

@jit
def loss_fn(w, b, X, y):
    """Root-mean-squared error of ``model(w, b, X)`` against targets ``y``.

    Note: the derivative of ``sqrt`` is unbounded at 0, so an exact fit
    (all residuals zero) would produce NaN gradients.
    """
    residuals = y - model(w, b, X)
    mean_squared = jnp.mean(residuals ** 2)
    return jnp.sqrt(mean_squared)


@jit
def update(w, b, X, y, lr=0.01):
    """Take one gradient-descent step on ``loss_fn`` w.r.t. ``w`` and ``b``.

    The step is compiled with ``jit`` so the loss, both gradients and both
    parameter updates run as one XLA computation instead of several
    Python-dispatched ops (the notebook calls this up to thousands of times).
    ``lr`` is passed as a regular traced argument, so changing it does not
    trigger recompilation.

    Args:
        w: weight vector.
        b: bias.
        X: feature matrix.
        y: targets.
        lr: learning rate (step size).

    Returns:
        ``(w_new, b_new, loss_val)`` where ``loss_val`` is the loss at the
        *incoming* parameters, i.e. before this step is applied.
    """
    # Under jit this transform is built once per trace, not on every call.
    loss_val, (grad_w, grad_b) = value_and_grad(loss_fn, argnums=(0, 1))(w, b, X, y)
    return w - lr * grad_w, b - lr * grad_b, loss_val


def generate_data(num_points, noise_level, random_state: int = 69):
    """Create a synthetic single-feature regression problem.

    Thin wrapper around sklearn's ``make_regression`` with exactly one
    (informative) feature.

    Args:
        num_points: number of samples to generate.
        noise_level: standard deviation of the noise passed as ``noise``.
        random_state: seed forwarded to ``make_regression``.

    Returns:
        The ``(X, y)`` pair produced by ``make_regression``.
    """
    options = dict(
        n_samples=num_points,
        n_features=1,
        n_informative=1,
        noise=noise_level,
        random_state=random_state,
    )
    return make_regression(**options)


def _train_linear_model(w, b, dataset, target, iterations, lr):
    """Run ``iterations`` gradient-descent steps with ``update``.

    The loop is expressed with ``lax.scan`` so JAX traces ``update`` once and
    executes the whole loop as a single compiled computation, rather than
    dispatching each step from Python (the iterations slider goes to 10,000).

    Returns:
        ``(w, b, losses)`` — the final parameters and a NumPy array holding the
        loss reported by each call to ``update`` (length ``iterations``).
    """
    def step(params, _):
        w_cur, b_cur = params
        w_next, b_next, loss_val = update(w_cur, b_cur, dataset, target, lr=lr)
        return (w_next, b_next), loss_val

    # `length` must be a concrete Python int: it fixes the loop trip count.
    (w_final, b_final), losses = lax.scan(step, (w, b), xs=None, length=int(iterations))
    return w_final, b_final, np.asarray(losses)


def _plot_fit(dataset, target, w, b, losses, num_line_points):
    """Plot the data with the fitted line (left) and the loss curve (right)."""
    x_line = jnp.linspace(dataset.min(), dataset.max(), num_line_points).reshape(-1, 1)
    y_line = model(w, b, x_line)

    fig, ax = plt.subplots(1, 2, figsize=(14, 6))

    ax[0].scatter(dataset, target, color='orange', label='data', s=10)
    ax[0].plot(x_line, y_line, color='blue', label='regression line')
    ax[0].set_title('Data and fitted regression line')
    ax[0].set_xlabel('X')
    ax[0].set_ylabel('y')
    ax[0].legend()

    ax[1].plot(np.arange(len(losses)), losses, color='purple', label="loss")
    ax[1].set_title('Training loss')
    ax[1].set_xlabel('Iterations')
    ax[1].set_ylabel('Loss')
    ax[1].legend()

    plt.show()


def show(
        number_of_samples: int,
        noise: float = 0.01,
        iterations: int = 100,
        lr: float = 0.01,
        seed: int = 69
):
    """Generate data, fit a 1-feature linear model by gradient descent, and plot it.

    Args:
        number_of_samples: number of synthetic samples (also the number of
            points used to draw the regression line).
        noise: noise level forwarded to ``generate_data``.
        iterations: number of gradient-descent steps.
        lr: learning rate for ``update``.
        seed: seeds both the dataset and the random initial coefficients.
    """
    dataset, target = generate_data(
        num_points=number_of_samples,
        noise_level=noise,
        random_state=seed
    )
    dataset = jnp.array(dataset).astype(jnp.float32)
    target = jnp.array(target).astype(jnp.float32)

    # Random initial coefficients, derived from `seed` so each setting is reproducible.
    key, subkey = random.split(random.PRNGKey(seed))
    w = random.normal(key, (1, ), dtype=jnp.float32)
    b = random.normal(subkey, dtype=jnp.float32)

    w, b, losses = _train_linear_model(w, b, dataset, target, iterations, lr)
    _plot_fit(dataset, target, w, b, losses, number_of_samples)

In [51]:
def _control_options():
    """Keyword arguments shared by every control below.

    ``continuous_update=False`` commits a value only once the user finishes
    adjusting it; the style sizes the description label to its text.
    """
    return {
        'continuous_update': False,
        'style': {'description_width': 'initial'},
    }


### Data parameters

NUMBER_OF_SAMPLES_SLIDER = widgets.IntSlider(
    description='Number of samples',
    value=300, min=10, max=5_000, step=10,
    **_control_options(),
)

NOISE_SLIDER = widgets.FloatSlider(
    description='Noise',
    value=20.0, min=0.1, max=1000.0, step=0.05,
    **_control_options(),
)

### Model parameters

ITERATIONS_SLIDER = widgets.IntSlider(
    description='Iterations',
    value=300, min=10, max=10_000, step=10,
    **_control_options(),
)

LEARNING_RATE_FLOAT = widgets.BoundedFloatText(
    description='Learning rate',
    value=0.1, min=1e-6, max=10.0, step=0.001,
    **_control_options(),
)

SEED_ITERATION = widgets.IntSlider(
    description='Seed',
    value=69, min=0, max=2_000, step=1,
    **_control_options(),
)

# Re-run `show` whenever a control changes; keys are `show`'s parameter names.
poly_fit_reg_out = interactive_output(
    show,
    dict(
        number_of_samples=NUMBER_OF_SAMPLES_SLIDER,
        noise=NOISE_SLIDER,
        iterations=ITERATIONS_SLIDER,
        lr=LEARNING_RATE_FLOAT,
        seed=SEED_ITERATION,
    ),
)

_data_controls = VBox(
    [NUMBER_OF_SAMPLES_SLIDER, NOISE_SLIDER],
    layout=widgets.Layout(width='30%'),
)
_model_controls = VBox(
    [ITERATIONS_SLIDER, LEARNING_RATE_FLOAT, SEED_ITERATION],
    layout=widgets.Layout(width='30%'),
)

display(
    HBox([_data_controls, _model_controls], layout=widgets.Layout(width='100%')),
    poly_fit_reg_out,
)

HBox(children=(VBox(children=(IntSlider(value=300, continuous_update=False, description='Number of samples', m…

Output()