In [1]:
import torch
import numpy
import matplotlib.pyplot

# source : https://github.com/mrdbourke/pytorch-deep-learning/blob/main/helper_functions.py

def _model_device_and_dtype(model: torch.nn.Module):
    """Return the ``(device, dtype)`` that inputs to ``model`` should use.

    Taken from the model's first parameter, or its first buffer if it has no
    parameters. Falls back to CPU / float32 for a module with no tensors, and
    uses float32 whenever that first tensor is not floating point.
    """
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    if reference is None:
        return torch.device("cpu"), torch.float32
    dtype = reference.dtype if reference.is_floating_point() else torch.float32
    return reference.device, dtype


def _logits_to_labels(logits: torch.Tensor) -> torch.Tensor:
    """Convert raw model outputs into class labels.

    An output with more than one column is treated as multi-class logits and
    reduced with argmax over dim 1 (softmax is skipped because it never changes
    the argmax). Anything else is treated as a single binary logit and mapped to
    0/1 with ``round(sigmoid(logit))``.
    """
    if logits.ndim > 1 and logits.shape[1] > 1:
        return logits.argmax(dim=1)
    return torch.round(torch.sigmoid(logits))


def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor, resolution: int = 1000):
    """Plot the decision regions of ``model`` together with the samples ``X``.

    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)

    The function builds a ``(resolution + 1) x (resolution + 1)`` grid. The grid
    spans the range of the first two columns of ``X``, padded by 0.1 on each side.
    ``model`` classifies every grid point, and the predicted labels are drawn with
    ``contourf``. ``X`` is then scattered on top, coloured by ``y``. Everything is
    drawn on the current Matplotlib figure; nothing is shown or returned.

    Args:
        model: classifier accepting ``(N, 2)`` inputs. An output with more than
            one column is read as multi-class logits; otherwise it is read as a
            single binary logit.
        X: samples; only columns 0 and 1 are used.
        y: labels for ``X``, used only to colour the scatter points.
        resolution: number of grid intervals along each axis.

    The model is evaluated on the device and dtype of its own parameters and is
    not moved. Each submodule's train/eval flag is restored afterwards, even if
    the forward pass raises.
    """
    # NumPy and Matplotlib need host memory; detach in case X/y still track gradients.
    X = X.detach().cpu()
    y = y.detach().cpu()

    # Grid bounds: data range of the first two features plus a 0.1 margin.
    x_min, x_max = float(X[:, 0].min()) - 0.1, float(X[:, 0].max()) + 0.1
    y_min, y_max = float(X[:, 1].min()) - 0.1, float(X[:, 1].max()) + 0.1
    xx, yy = numpy.meshgrid(
        numpy.linspace(x_min, x_max, resolution + 1),
        numpy.linspace(y_min, y_max, resolution + 1),
    )

    # One (x, y) row per grid point. Match the model's device/dtype so the model
    # itself never has to be relocated.
    device, dtype = _model_device_and_dtype(model)
    grid_points = torch.from_numpy(numpy.column_stack((xx.ravel(), yy.ravel())))
    grid_points = grid_points.to(device=device, dtype=dtype)

    # eval() clears the training flag on every submodule. Remember each flag
    # individually so the caller's (possibly mixed) modes come back unchanged.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.inference_mode():
            logits = model(grid_points)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Choose binary vs multi-class from the width of the model output rather
    # than from the labels present in y, which may not cover every class.
    y_pred = _logits_to_labels(logits).reshape(xx.shape).cpu().numpy()

    matplotlib.pyplot.contourf(xx, yy, y_pred, cmap=matplotlib.pyplot.cm.RdYlBu, alpha=0.7)
    matplotlib.pyplot.scatter(X[:, 0].numpy(), X[:, 1].numpy(), c=y.numpy(), s=40, cmap=matplotlib.pyplot.cm.RdYlBu)
    matplotlib.pyplot.xlim(xx.min(), xx.max())
    matplotlib.pyplot.ylim(yy.min(), yy.max())