<div style="display: flex; justify-content: space-between; align-items: center;">
    <div style="text-align: left; flex: 4">
        <strong>Author:</strong> Amirhossein Heydari — 
        📧 <a href="mailto:amirhosseinheydari78@gmail.com">amirhosseinheydari78@gmail.com</a> — 
        🐙 <a href="https://github.com/mr-pylin/pytorch-workshop" target="_blank" rel="noopener">github.com/mr-pylin</a>
    </div>
    <div style="text-align: right; flex: 1;">
        <a href="https://pytorch.org/" target="_blank" rel="noopener noreferrer">
            <img src="../assets/images/pytorch/logo/pytorch-logo-dark.svg" 
                 alt="PyTorch Logo"
                 style="max-height: 48px; width: auto; background-color: #ffffff; border-radius: 8px;">
        </a>
    </div>
</div>
<hr>


**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Pre-Processing Dataset](#toc2_)    
  - [Create Artificial Dataset](#toc2_1_)    
  - [Split dataset into trainset & testset](#toc2_2_)    
  - [Normalization](#toc2_3_)    
  - [Dataset](#toc2_4_)    
  - [Dataloader](#toc2_5_)    
- [Radial Basis Function Networks](#toc3_)    
  - [Common Radial Basis Functions](#toc3_1_)    
    - [Gaussian](#toc3_1_1_)    
    - [Multiquadric](#toc3_1_2_)    
    - [Inverse Multiquadric](#toc3_1_3_)    
    - [Inverse Quadratic](#toc3_1_4_)    
    - [Thin-Plate Spline](#toc3_1_5_)    
    - [RBF Feature Mapping Visualization](#toc3_1_6_)    
  - [RBF Networks](#toc3_2_)    
    - [Single Layer Architecture](#toc3_2_1_)    
    - [Multi Layers Architecture](#toc3_2_2_)    
  - [Model Training Pipeline](#toc3_3_)    
    - [Set up model and Hyperparameters](#toc3_3_1_)    
    - [Train & Validation Loop](#toc3_3_2_)    
    - [Test Loop](#toc3_3_3_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)


In [None]:
import shutil

import matplotlib.pyplot as plt
import torch
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from torchmetrics.classification import MulticlassAccuracy

In [None]:
# turn off matplotlib's interactive mode so figures are only displayed when plt.show() is called
# this ensures consistency with .py scripts and gives full control over when plots appear
plt.ioff()

In [None]:
# choose animation display method:
# use the HTML5 video renderer when an `ffmpeg` executable is found on PATH,
# otherwise fall back to the JavaScript-based HTML renderer
display_backend = FuncAnimation.to_html5_video if shutil.which("ffmpeg") else FuncAnimation.to_jshtml

# log (an unbound FuncAnimation method; it is called later as display_backend(anim))
display_backend

In [None]:
# set a seed for deterministic results
# (the same `seed` is also passed as `random_state` to the scikit-learn helpers below)
seed = 0
torch.manual_seed(seed)

# request deterministic cuDNN algorithms and disable the cuDNN auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# check if cuda is available; fall back to the CPU otherwise
# (the model, the metrics and every batch are moved to this device later on)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# log
device

# <a id='toc2_'></a>[Pre-Processing Dataset](#toc0_)


## <a id='toc2_1_'></a>[Create Artificial Dataset](#toc0_)


In [None]:
# generate a 2D classification dataset
n_samples = 250
n_classes = 3

# both features are informative (none redundant) and every class forms a single cluster
# random_state reuses the global seed so the same dataset is generated on every run
X, y = make_classification(
    n_samples=n_samples,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=n_classes,
    n_clusters_per_class=1,
    random_state=seed,
)

## <a id='toc2_2_'></a>[Split dataset into trainset & testset](#toc0_)


In [None]:
# hold out 10% of the samples as the testset; random_state=seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=seed)

## <a id='toc2_3_'></a>[Normalization](#toc0_)


In [None]:
# min-max scaler that maps each feature to the range [-1, 1]
# it is fitted on the trainset only, so no testset statistics are used
scaler = MinMaxScaler(feature_range=(-1, 1))
scaler.fit(X_train)

# log (per-feature min/max of the trainset before scaling)
print(f"min of trainset: {X_train.min(axis=0)}")
print(f"max of trainset: {X_train.max(axis=0)}")

In [None]:
# scale both splits with the statistics fitted on the trainset
# (testset values can therefore fall slightly outside [-1, 1])
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

In [None]:
# plot the scaled dataset: circles for the trainset, squares (pixel marker ",") for the testset
# points are colored by their class label
plt.scatter(X_train[:, 0], X_train[:, 1], marker="o", c=y_train, s=25, edgecolor="k", label="trainset")
plt.scatter(X_test[:, 0], X_test[:, 1], marker=",", c=y_test, s=25, edgecolor="k", label="testset")
plt.legend()
plt.title(f"2D dataset with {n_samples} samples")
plt.xlabel("feature 1")
plt.ylabel("feature 2")

# interactive mode is off, so the figure has to be shown explicitly
plt.show()

## <a id='toc2_4_'></a>[Dataset](#toc0_)


In [None]:
# convert the NumPy arrays to tensors: float32 features and int64 class labels
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.int64)
y_test = torch.tensor(y_test, dtype=torch.int64)

# pair every sample with its label; indexing a dataset returns a (features, label) tuple
trainset = TensorDataset(X_train, y_train)
testset = TensorDataset(X_test, y_test)

## <a id='toc2_5_'></a>[Dataloader](#toc0_)


In [None]:
# number of samples per mini-batch
batch_size = 4

# the trainset is reshuffled at every epoch; the testset keeps its original order
trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)

# <a id='toc3_'></a>[Radial Basis Function Networks](#toc0_)


## <a id='toc3_1_'></a>[Common Radial Basis Functions](#toc0_)


### <a id='toc3_1_1_'></a>[Gaussian](#toc0_)

$$
\phi(x) = e^{- \frac{\|x - c\|^2}{2\sigma^2}}
$$


In [None]:
def rbf_gaussian_multi(x: torch.Tensor, centers: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Gaussian RBF activation of every sample with respect to every center.

    x       : [num_samples, features]
    centers : [num_centers, features]
    sigma   : [num_centers] or scalar
    returns : [num_samples, num_centers] RBF activations
    """

    # a 1D sigma gets a leading axis -> [1, num_centers], so it lines up with the center axis
    width = sigma.unsqueeze(0) if sigma.ndim == 1 else sigma

    # euclidean distance of every (sample, center) pair -> [num_samples, num_centers]
    # (equivalent to broadcasting x[:, None, :] - centers[None, :, :], squaring,
    #  summing over the feature axis and taking the square root)
    pairwise = torch.cdist(x, centers)

    squared = pairwise**2
    denom = 2 * width**2
    return torch.exp(-squared / denom)

### <a id='toc3_1_2_'></a>[Multiquadric](#toc0_)

$$
\phi(x) = \sqrt{1 + (\|x - c\|/\sigma)^2}
$$


In [None]:
def rbf_multiquadric_multi(x: torch.Tensor, centers: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Multiquadric RBF activation of every sample with respect to every center.

    x       : [num_samples, features]
    centers : [num_centers, features]
    sigma   : [num_centers] or scalar
    returns : [num_samples, num_centers] RBF activations
    """
    # a 1D sigma becomes [1, num_centers] so it divides along the center axis
    width = sigma.unsqueeze(0) if sigma.ndim == 1 else sigma

    # pairwise euclidean distances, expressed in units of each center's width
    scaled = torch.cdist(x, centers) / width
    return torch.sqrt(1 + scaled**2)

### <a id='toc3_1_3_'></a>[Inverse Multiquadric](#toc0_)

$$
\phi(x) = \frac{1}{\sqrt{1 + (\|x - c\|/\sigma)^2}}
$$


In [None]:
def rbf_inverse_multiquadric_multi(x: torch.Tensor, centers: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Inverse multiquadric RBF activation of every sample with respect to every center.

    x       : [num_samples, features]
    centers : [num_centers, features]
    sigma   : [num_centers] or scalar
    returns : [num_samples, num_centers] RBF activations
    """
    # a 1D sigma becomes [1, num_centers] so it divides along the center axis
    width = sigma.unsqueeze(0) if sigma.ndim == 1 else sigma

    # pairwise euclidean distances, expressed in units of each center's width
    scaled = torch.cdist(x, centers) / width

    root = torch.sqrt(1 + scaled**2)
    return 1.0 / root

### <a id='toc3_1_4_'></a>[Inverse Quadratic](#toc0_)

$$
\phi(x) = \frac{1}{1 + (\|x - c\|/\sigma)^2}
$$


In [None]:
def rbf_inverse_quadratic_multi(x: torch.Tensor, centers: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Inverse quadratic RBF activation of every sample with respect to every center.

    x       : [num_samples, features]
    centers : [num_centers, features]
    sigma   : [num_centers] or scalar
    returns : [num_samples, num_centers] RBF activations
    """
    # a 1D sigma becomes [1, num_centers] so it divides along the center axis
    width = sigma.unsqueeze(0) if sigma.ndim == 1 else sigma

    # pairwise euclidean distances, expressed in units of each center's width
    scaled = torch.cdist(x, centers) / width

    denom = 1 + scaled**2
    return 1.0 / denom

### <a id='toc3_1_5_'></a>[Thin-Plate Spline](#toc0_)

$$
\phi(x) = \|x - c\|^2 \log(\|x - c\|)
$$


In [None]:
def rbf_thin_plate_spline_multi(x: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """
    Thin-plate spline RBF activation of every sample with respect to every center.

    x       : [num_samples, features]
    centers : [num_centers, features]
    returns : [num_samples, num_centers] RBF activations
    """
    # pairwise euclidean distances, floored at 1e-10 so the logarithm never receives an exact zero
    radius = torch.cdist(x, centers).clamp(min=1e-10)
    return radius**2 * radius.log()

### <a id='toc3_1_6_'></a>[RBF Feature Mapping Visualization](#toc0_)


In [None]:
# define 2 centers for visualization
# (two centers give two RBF activations per sample, which can be drawn as a 2D scatter plot)
centers = torch.tensor([[-0.5, 0.0], [0.5, 0.0]])

# one width per center
sigma = torch.tensor([0.3, 0.3])

In [None]:
# list of kernels and their corresponding multi-center functions
# each entry is (plot title, function)
kernels = [
    ("Gaussian", rbf_gaussian_multi),
    ("Multiquadric", rbf_multiquadric_multi),
    ("Inverse Multiquadric", rbf_inverse_multiquadric_multi),
    ("Inverse Quadratic", rbf_inverse_quadratic_multi),
    ("Thin-Plate Spline", rbf_thin_plate_spline_multi),
]

# one extra column for the original (untransformed) features
n_cols = 1 + len(kernels)

In [None]:
# create figure with n_cols columns: the original features followed by one panel per kernel
fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))

# original input features
axes[0].scatter(X_train[:, 0], X_train[:, 1], marker="o", c=y_train, s=25, edgecolor="k")
axes[0].set_title("Original features")
axes[0].set_xlabel("x1")
axes[0].set_ylabel("x2")
axes[0].grid(alpha=0.3)

# RBF feature spaces (enumeration starts at 1 because axes[0] holds the original features)
for i, (name, func) in enumerate(kernels, start=1):
    # the thin-plate spline function takes no sigma argument, so it is called separately
    if name == "Thin-Plate Spline":
        phi = func(X_train, centers)
    else:
        phi = func(X_train, centers, sigma)

    # one activation column per center -> the two axes of the transformed feature space
    phi1, phi2 = phi[:, 0], phi[:, 1]
    axes[i].scatter(phi1, phi2, marker="o", c=y_train, s=25, edgecolor="k")
    axes[i].set_title(name)
    axes[i].set_xlabel("phi1")
    axes[i].set_ylabel("phi2")
    axes[i].grid(alpha=0.3)

plt.show()

## <a id='toc3_2_'></a>[RBF Networks](#toc0_)

- Radial Basis Function (RBF) networks are a family of feed-forward models used for nonlinear function approximation.
- They transform inputs using localized, distance-based activation functions such as Gaussian kernels.
- The network output is formed by combining these transformed responses, enabling smooth interpolation and efficient learning.

**RBF Parameters**:
- The centers $\mu_j$ of the radial units are typically selected using clustering methods such as K-Means, sampled directly from training data, or defined as learnable parameters during training.
- The spread $\sigma_j$ of each unit controls the width of the radial basis function and may be fixed heuristically, derived from cluster statistics, or optimized during training.
- A common rule for assigning spreads uses the maximum pairwise distance among centers:
  $$
  \sigma = \frac{d_{\text{max}}}{\sqrt{2k}}
  $$
  where
  - $d_{\text{max}}$ is the largest Euclidean distance between any two centers,
  - $k$ is the total number of radial basis units.


In [None]:
class RBFLayer(nn.Module):
    """
    Layer of radial basis units with learnable centers and widths.

    in_features  : size of each input sample
    out_features : number of RBF units (one center and one sigma per unit)
    kernel       : one of "gaussian", "multiquadric", "inverse_multiquadric",
                   "inverse_quadratic", "thin_plate_spline"

    forward maps [num_samples, in_features] -> [num_samples, out_features]
    and raises ValueError if `kernel` is not one of the names above.
    """

    def __init__(self, in_features: int, out_features: int, kernel: str = "gaussian"):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # torch.empty only allocates memory; the values are filled in by reset_parameters()
        self.centers = nn.Parameter(torch.empty(out_features, in_features))
        self.sigmas = nn.Parameter(torch.empty(out_features))
        self.kernel = kernel
        self.reset_parameters()

        # mapping kernel names to functions
        self.rbf_functions = {
            "gaussian": rbf_gaussian_multi,
            "multiquadric": rbf_multiquadric_multi,
            "inverse_multiquadric": rbf_inverse_multiquadric_multi,
            "inverse_quadratic": rbf_inverse_quadratic_multi,
            "thin_plate_spline": rbf_thin_plate_spline_multi,
        }

    def reset_parameters(self) -> None:
        """Draw centers uniformly from [-1, 1] and sigmas uniformly from [0.1, 1]."""
        nn.init.uniform_(self.centers, -1.0, 1.0)
        nn.init.uniform_(self.sigmas, 0.1, 1.0)  # avoid zero sigma

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.kernel not in self.rbf_functions:
            raise ValueError(f"Unknown kernel type {self.kernel}")

        rbf_fn = self.rbf_functions[self.kernel]

        # the thin-plate spline function has no width argument: its signature is (x, centers),
        # so passing sigmas would raise a TypeError
        # (self.sigmas stays registered as a parameter but is not used by this kernel)
        if self.kernel == "thin_plate_spline":
            return rbf_fn(input, self.centers)

        return rbf_fn(input, self.centers, self.sigmas)

### <a id='toc3_2_1_'></a>[Single Layer Architecture](#toc0_)

- An RBF network with a **single hidden layer** transforms inputs using **radial basis functions (RBFs)** centered at learned prototypes.
- Each hidden unit computes a **localized activation** based on the distance between the input and its center, most commonly via a Gaussian kernel.
- The output layer forms a **linear combination** of these localized responses, enabling efficient approximation of nonlinear mappings with simple optimization.

<div style="text-align: center; padding-top: 10px;">
    <img src="../assets/images/original/rbf/rbf-original.svg" alt="rbf-original.svg" style="min-width: 512px; width: 60%; height: auto; border-radius: 16px;">
    <p><em>Figure 1: Single-layer Radial Basis Function Network</em></p>
</div>

**Calculating the number of parameters**:

<table style="margin: 0 auto; text-align:center;">
  <thead>
    <tr>
      <th colspan="2">RBF layer parameters</th>
      <th colspan="2">Output layer parameters</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Centers (M)</td>
      <td>Widths (σ)</td>
      <td>Weights (W)</td>
      <td>Bias</td>
    </tr>
    <tr>
      <td>n × r</td>
      <td>r</td>
      <td>r × o</td>
      <td>o</td>
    </tr>
  </tbody>
  <tfoot>
    <tr>
      <td colspan="2">(n + 1) × r </td>
      <td colspan="2">(r + 1) × o</td>
    </tr>
  </tfoot>
</table>

- $n$: input dimensionality (number of features)  
- $r$: number of RBF kernels (number of basis functions)  
- $o$: number of output neurons  

---

**Input matrix**:

$$
X =
\begin{bmatrix}
x_{1}^{(1)} & x_{2}^{(1)} & \cdots & x_{n}^{(1)} \\
x_{1}^{(2)} & x_{2}^{(2)} & \cdots & x_{n}^{(2)} \\
\vdots & \vdots & \ddots & \vdots \\
x_{1}^{(m)} & x_{2}^{(m)} & \cdots & x_{n}^{(m)}
\end{bmatrix}_{m \times n}
$$

- $m$: number of samples  
- $n$: number of input features  

---

**Hidden layer (RBF activations)**:

$$
\boldsymbol{\phi}(\mathbf{X}) =
\begin{bmatrix}
1 & 1 & \cdots & 1 \\
\phi_1\!\left(\mathbf{x}^{(1)}\right) &
\phi_1\!\left(\mathbf{x}^{(2)}\right) &
\cdots &
\phi_1\!\left(\mathbf{x}^{(m)}\right) \\
\phi_2\!\left(\mathbf{x}^{(1)}\right) &
\phi_2\!\left(\mathbf{x}^{(2)}\right) &
\cdots &
\phi_2\!\left(\mathbf{x}^{(m)}\right) \\
\vdots & \vdots & \ddots & \vdots \\
\phi_r\!\left(\mathbf{x}^{(1)}\right) &
\phi_r\!\left(\mathbf{x}^{(2)}\right) &
\cdots &
\phi_r\!\left(\mathbf{x}^{(m)}\right)
\end{bmatrix}
\in \mathbb{R}^{(r+1) \times m}
$$

with Gaussian RBFs:

$$
\phi_j\!\left(\mathbf{x}^{(i)}\right) =
\exp\!\left(
-\frac{\left\lVert \mathbf{x}^{(i)} - \boldsymbol{\mu}_j \right\rVert^2}{2\sigma_j^2}
\right)
$$

- The first row corresponds to the **bias term**.
- Each RBF neuron responds strongly only to inputs **near its center**, enforcing locality.

---

**Output layer weights**:

$$
\mathbf{W} =
\begin{bmatrix}
w_{0}^{(1)} & w_{0}^{(2)} & \cdots & w_{0}^{(o)} \\
w_{1}^{(1)} & w_{1}^{(2)} & \cdots & w_{1}^{(o)} \\
\vdots & \vdots & \ddots & \vdots \\
w_{r}^{(1)} & w_{r}^{(2)} & \cdots & w_{r}^{(o)}
\end{bmatrix}
\in \mathbb{R}^{(r+1) \times o}
$$

- $o$: number of output neurons (classes or regression targets)

---

**Network output**:

$$
\mathbf{f}(\mathbf{X})
=
\boldsymbol{\phi}(\mathbf{X})^{\top}
\mathbf{W}
\in
\mathbb{R}^{m \times o}
$$

- The model is **linear in the output weights** and **nonlinear in the input space**.
- This structure allows efficient training of $W$ using closed-form or standard linear optimization methods.


In [None]:
class RBFNet(nn.Module):
    """
    Single-hidden-layer RBF network: one RBF layer followed by a linear output layer.

    in_features     : number of input features
    hidden_features : number of RBF units in the hidden layer
    out_features    : number of output neurons
    kernel          : kernel name forwarded to the RBF layer, default "gaussian"
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, kernel: str = "gaussian"):
        super().__init__()

        # hidden layer: in_features -> hidden_features RBF activations
        self.rbf = RBFLayer(in_features, hidden_features, kernel)

        # output layer: weighted sum (plus bias) of the RBF activations
        self.linear = nn.Linear(in_features=hidden_features, out_features=out_features)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        activations = self.rbf(input)
        return self.linear(activations)

### <a id='toc3_2_2_'></a>[Multi Layers Architecture](#toc0_)
 
 
- A multi-layer RBF network extends the classical single-hidden-layer design by **stacking multiple RBF-based hidden layers**.
- Each hidden layer performs a **nonlinear, localized transformation** of its input space, enabling hierarchical feature extraction.
- Deeper RBF architectures can model **more complex decision boundaries** with fewer kernels per layer compared to a shallow RBF network.

<div style="text-align: center; padding-top: 10px;">
    <img src="../assets/images/original/rbf/rbf-extended.svg" alt="rbf-extended.svg" style="min-width: 512px; width: 75%; height: auto; border-radius: 16px;">
    <p><em>Figure 2: Multi-layer Radial Basis Function Network</em></p>
</div>

**Key idea**:

- Instead of directly mapping inputs to outputs using a single RBF layer, intermediate RBF layers learn **progressively abstract representations**.
- Each layer applies RBF transformations to the activations of the previous layer, not directly to raw input data.


In [None]:
class MultiLayerRBFNet(nn.Module):
    """Stack of RBF layers applied one after another, followed by a linear output layer."""

    def __init__(
        self,
        in_features: int,
        hidden_features: list[int],
        out_features: int,
        kernels: list[str] | None = None,
    ):
        """
        in_features     : number of input features
        hidden_features : list of hidden neurons per RBF layer, e.g., [64, 32]
        out_features    : number of output neurons
        kernels         : list of kernel names per layer, default all "gaussian"
        """
        super().__init__()
        self.num_layers = len(hidden_features)
        if kernels is None:
            kernels = ["gaussian"] * self.num_layers
        assert len(kernels) == self.num_layers, "kernels list must match hidden_features length"

        # layer widths including the input: consecutive pairs are each layer's (input size, output size)
        widths = [in_features, *hidden_features]

        # create the RBF layers in order, each one consuming the previous layer's output
        self.rbf_layers = nn.ModuleList(
            [RBFLayer(n_in, n_out, kernel) for n_in, n_out, kernel in zip(widths[:-1], widths[1:], kernels)]
        )

        # final linear layer on top of the last RBF layer
        # (or directly on the input when hidden_features is empty)
        self.linear = nn.Linear(widths[-1], out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for rbf_layer in self.rbf_layers:
            hidden = rbf_layer(hidden)
        return self.linear(hidden)

## <a id='toc3_3_'></a>[Model Training Pipeline](#toc0_)

<div style="text-align: center; padding-top: 10px;">
    <img src="../assets/images/original/rbf/rbf-example.svg" alt="rbf-example.svg" style="min-width: 512px; width: 60%; height: auto; border-radius: 16px;">
    <p><em>Figure 3: A Simple Radial Basis Function Network with Three Gaussian Functions</em></p>
</div>


**Calculating the number of parameters**:

<table style="margin: 0 auto; text-align:center;">
  <thead>
    <tr>
      <th colspan="2">RBF layer parameters</th>
      <th colspan="2">Output layer parameters</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Centers (M)</td>
      <td>Widths (σ)</td>
      <td>Weights (W)</td>
      <td>Bias</td>
    </tr>
    <tr>
      <td>2 × 3</td>
      <td>3</td>
      <td>3 × 3</td>
      <td>3</td>
    </tr>
  </tbody>
  <tfoot>
    <tr>
      <td colspan="2">(2 + 1) × 3 </td>
      <td colspan="2">(3 + 1) × 3</td>
    </tr>
    <tr style="border-top: 2px solid; font-weight: bold;">
      <td colspan="4">Total Parameters: 9 + 12 = <strong>21</strong></td>
    </tr>
  </tfoot>
</table>


### <a id='toc3_3_1_'></a>[Set up model and Hyperparameters](#toc0_)


In [None]:
# input/output sizes
# trainset[0] is a (features, label) pair, so trainset[0][0] is the feature vector of the first sample
in_features = trainset[0][0].shape[0]  # automatically from dataset
hidden_features = 3                    # number of RBF kernels
out_features = n_classes               # number of output classes

In [None]:
# initialize the model (default "gaussian" kernel) and move its parameters to the selected device
model = RBFNet(in_features, hidden_features, out_features).to(device)

# log
model

In [None]:
# per-layer summary of the model for an input of one mini-batch: (batch_size, in_features)
summary(model, input_size=(batch_size, in_features))

In [None]:
# training hyperparameters
lr = 0.01        # learning rate passed to the optimizer
num_epochs = 10  # number of full passes over the trainset

In [None]:
# loss and optimizer
# the optimizer receives all model parameters: RBF centers, RBF sigmas and the linear layer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=lr)

### <a id='toc3_3_2_'></a>[Train & Validation Loop](#toc0_)


In [None]:
# store accuracy and loss at each epoch
train_acc_per_epoch = []
train_loss_per_epoch = []

# running accuracy metric; it is updated per batch and reset at the end of every epoch
train_acc = MulticlassAccuracy(num_classes=n_classes, top_k=1).to(device)

In [None]:
# store RBF layer parameters at each epoch
# (one snapshot per epoch; used afterwards to animate how the centers and widths change)
centers_history = []
sigmas_history = []

In [None]:
# training loop: one pass over the trainloader per epoch
# NOTE(review): the batch variables `x` and `y` are module-level names, so `y` rebinds the
# label array returned by make_classification earlier in the notebook
for epoch in range(num_epochs):
    model.train()

    # sum of per-sample losses over the epoch
    train_loss = 0

    for x, y in trainloader:
        # move the batch to the selected device
        x, y_true = x.to(device), y.to(device)

        # forward + loss
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        # backward
        loss.backward()

        # update parameters
        optimizer.step()

        # clear the gradients so they do not accumulate into the next iteration
        optimizer.zero_grad()

        # store loss and accuracy per iteration
        # the batch loss is weighted by the batch size (the last batch can be smaller)
        train_loss += loss.item() * len(x)
        train_acc.update(y_pred, y_true)

    # store metrics per epoch
    train_loss_per_epoch.append(train_loss / len(trainset))
    train_acc_per_epoch.append(train_acc.compute().item())
    train_acc.reset()

    # store current centers and sigmas per epoch
    # detached CPU copies, so later parameter updates do not change the stored snapshots
    centers_history.append(model.rbf.centers.detach().cpu().clone())
    sigmas_history.append(model.rbf.sigmas.detach().cpu().clone())

    # log (the epoch number is zero-padded to the width of num_epochs)
    print(
        f"epoch {epoch+1:0{len(str(num_epochs))}}/{num_epochs} -> "
        f"train[loss: {train_loss_per_epoch[-1]:7.5f} - acc: {train_acc_per_epoch[-1]*100:5.2f}%]"
    )

In [None]:
def plot_frame(epoch_idx, centers_history, sigmas_history, X_train, y_train, X_test, y_test):
    """
    Redraw the current figure for one epoch: both data splits plus one circle per RBF unit.

    epoch_idx       : index into the history lists (the title shows epoch_idx + 1)
    centers_history : per-epoch tensors of RBF centers
    sigmas_history  : per-epoch tensors of RBF sigmas, used as circle radii
    X_train, X_test : 2D features of each split (columns 0 and 1 are plotted)
    y_train, y_test : class labels, used to color the points
    """
    # start from an empty figure so artists of the previous frame do not pile up
    plt.clf()

    # parameter snapshots saved for this epoch
    epoch_centers = centers_history[epoch_idx].numpy()
    epoch_sigmas = sigmas_history[epoch_idx].numpy()

    # (features, labels, marker, legend label) for each split
    splits = (
        (X_train, y_train, "o", "trainset"),
        (X_test, y_test, ",", "testset"),
    )
    for features, labels, marker, label in splits:
        plt.scatter(features[:, 0], features[:, 1], marker=marker, c=labels, s=25, edgecolor="k", label=label)

    # one unfilled blue circle per RBF unit: centered at the unit's center, radius equal to its sigma
    axes = plt.gca()
    for position, radius in zip(epoch_centers, epoch_sigmas):
        axes.add_artist(plt.Circle(position, radius, color="b", fill=False))

    plt.title(f"Epoch {epoch_idx + 1}")
    plt.xlabel("feature 1")
    plt.ylabel("feature 2")
    plt.legend()


In [None]:
# figure that every animation frame is drawn into
fig = plt.figure(figsize=(6,6))

# one frame per stored epoch, shown for 1000 ms each
# the lambda binds the histories and datasets so only the frame index is passed to plot_frame
anim = FuncAnimation(
    fig,
    lambda i: plot_frame(i, centers_history, sigmas_history, X_train, y_train, X_test, y_test),
    frames=len(centers_history),
    interval=1000
)

# choose backend depending on ffmpeg availability
# (display_backend is an unbound FuncAnimation method, so the animation is passed explicitly)
HTML(display_backend(anim))


### <a id='toc3_3_3_'></a>[Test Loop](#toc0_)


In [None]:
# accuracy metric for the testset, kept on the same device as the model
test_acc = MulticlassAccuracy(num_classes=n_classes, top_k=1).to(device)

In [None]:
# test loop
model.eval()

# sum of per-sample losses over the testset
test_loss = 0

# predicted class and true class of every test sample
predictions = []
targets = []

# no gradients are needed for evaluation
with torch.no_grad():
    for x, y in testloader:

        # move the batch to the selected device (GPU if available, otherwise CPU)
        x, y_true = x.to(device), y.to(device)

        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        # store loss and accuracy per iteration
        # the batch loss is weighted by the batch size (the last batch can be smaller)
        test_loss += loss.item() * len(x)
        test_acc.update(y_pred, y_true)

        # predicted class = index of the largest output; each list element is a 0-d CPU tensor
        predictions.extend(y_pred.argmax(dim=1).cpu())
        targets.extend(y_true.cpu())

# log
print(f"test[loss: {test_loss / len(testset):.5f} - acc: {test_acc.compute().item()*100:5.2f}%]")