# HMDB10 action classification
Simplified task by selecting 10 classes from [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). (Kinetics-400 is way too large)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [4]:
import sys

sys.path.append("..")

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from einops import rearrange
import torch.nn.functional as F
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from ndswin import SwinClassifier

In [5]:
import warnings

warnings.simplefilter("ignore", UserWarning)

In [6]:
def accuracy(pred, y):
    """Return the top-1 accuracy of ``pred`` against ``y`` as a percentage.

    The predicted label of each row is the argmax of ``pred`` along dim 1; the
    result is the share of positions where it equals ``y``, scaled to [0, 100].
    """
    hits = pred.argmax(dim=1) == y
    hit_ratio = hits.sum().item() / hits.numel()
    return hit_ratio * 100

### Download HMDB
Download and extract the `.rar` dataset from [serre-lab.clps.brown.edu](https://serre-lab.clps.brown.edu).

In [7]:
# !mkdir -p hmdb51/videos
# !wget --no-check-certificate -O hmdb51/videos/hmdb51_org.rar http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar
# !wget --no-check-certificate -O hmdb51/test_train_splits.rar https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar
# !unrar x hmdb51/videos/hmdb51_org.rar hmdb51/ -o+
# !find hmdb51/ -name "*.rar" -exec unrar x {} hmdb51/videos \;
# !unrar x hmdb51/test_train_splits.rar hmdb51/ -o+

### Dataset (torchvision)
Using the `torchvision.datasets` implementation of HMDB51, with 16 frame clips and some data augmentation. Setting `frame_rate=5` is important. Then, only some classes are kept to simplify the classification problem.

In [8]:
def resize(x):
    """Move channels first and resize every frame to 80x80.

    ``x`` is laid out as (t, h, w, c). It is rearranged to (c, t, h, w) and the
    two trailing axes are resized to (80, 80) with ``F.interpolate``'s default
    mode.
    """
    channels_first = rearrange(x, "t h w c -> c t h w")
    target_hw = (80, 80)
    return F.interpolate(channels_first, target_hw)


def rescale(x):
    """Divide ``x`` by 255, mapping values in 0..255 onto 0..1."""
    scaled = x / 255.0
    return scaled


def noise_over_time(std: float = 0.05):
    """Build a transform that adds clipped Gaussian noise to a video tensor.

    Args:
        std: Standard deviation of the zero-mean noise added to every element.

    Returns:
        A callable ``x -> noisy_x``. Noise is drawn independently for every
        element (and therefore for every frame) and the result is clipped to
        [0, 1]. The input tensor is left untouched; a new tensor is returned.
    """

    def _transform(x):
        # Computed out of place so the caller's tensor is never modified.
        return torch.clip(x + torch.randn_like(x) * std, 0, 1)

    return _transform


def minmax(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Works on any array type exposing ``min()``/``max()`` and elementwise
    arithmetic (torch tensors and numpy arrays are both used in this notebook).
    A constant input, where max equals min, yields all zeros instead of the
    NaNs a 0/0 division would produce.
    """
    lo = x.min()
    span = x.max() - lo
    # A flat input (e.g. an all-black frame) would otherwise divide by zero.
    scale = span if span != 0 else 1
    return (x - lo) / scale


# Training pipeline: divide by 255, resize, then a random horizontal flip and a
# random 64x64 crop as augmentation.
train_transform = transforms.Compose(
    [
        rescale,
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=(64, 64)),
        # noise_over_time(),
    ]
)
# Validation pipeline: same preprocessing, but with a deterministic centre crop.
val_transform = transforms.Compose(
    [
        rescale,
        resize,
        transforms.CenterCrop(size=(64, 64)),
    ]
)

# traindata_full = datasets.HMDB51(
#     "hmdb51/videos",
#     annotation_path="hmdb51/testTrainMulti_7030_splits",
#     frames_per_clip=16,
#     frame_rate=5,
#     step_between_clips=2,
#     train=True,
#     transform=train_transform,
# )
# valdata_full = datasets.HMDB51(
#     "hmdb51/videos",
#     annotation_path="hmdb51/testTrainMulti_7030_splits",
#     frames_per_clip=16,
#     frame_rate=5,
#     step_between_clips=2,
#     train=False,
#     transform=val_transform,
# )

In [25]:
# filter class labels: HMDB51 label id -> name of the ten classes that are kept
hmdb10_classes = {
    46: "talk",
    11: "eat",
    23: "laugh",
    4: "clap",
    27: "punch",
    30: "ride_bike",
    31: "ride_horse",
    32: "run",
    49: "walk",
    45: "sword_exercise",
}
# HMDB51 label id -> contiguous index 0..9 (declaration order of the dict above)
hmdb10_class_map = {hmdb_id: idx for idx, hmdb_id in enumerate(hmdb10_classes)}
# contiguous index 0..9 -> HMDB51 label id (inverse of hmdb10_class_map)
hmdb10_class_map_inv = dict(enumerate(hmdb10_classes))


def remap_hmdb10_labels(batch_labels: torch.Tensor) -> torch.Tensor:
    """Translate HMDB51 label ids into contiguous HMDB10 indices.

    Every label is converted to ``int`` and looked up in ``hmdb10_class_map``;
    a label that is not in the map raises ``KeyError``. The result is a new
    ``torch.long`` tensor.
    """
    remapped = []
    for label in batch_labels:
        remapped.append(hmdb10_class_map[int(label)])
    return torch.tensor(remapped, dtype=torch.long)


def invert_hmdb10_labels(batch_labels: torch.Tensor) -> torch.Tensor:
    """Translate contiguous HMDB10 indices back into HMDB51 label ids.

    Every index is converted to ``int`` and looked up in
    ``hmdb10_class_map_inv``; an index that is not in the map raises
    ``KeyError``. The result is a new ``torch.long`` tensor.
    """
    lookup = hmdb10_class_map_inv
    original_ids = [lookup[int(idx)] for idx in batch_labels]
    return torch.tensor(original_ids, dtype=torch.long)


# traindata = Subset(traindata_full, [i for i in tqdm(range(len(traindata_full))) if traindata_full[i][2] in hmdb10_classes])
# valdata = Subset(valdata_full, [i for i in tqdm(range(len(valdata_full))) if valdata_full[i][2] in hmdb10_classes])

# NOTE(review): with the dataset construction commented out, `traindata_full` and
# `traindata` are only bound by the pickle-loading cell further down. On a fresh
# kernel that cell has to run before this line, otherwise it raises NameError.
print(f"Subsampled HMDB51 from {len(traindata_full)} to {len(traindata)}")

Subsampled HMDB51 from 7520 to 1976


In [10]:
import pickle

# Restore the cached dataset objects. Each file is opened through a context
# manager so its handle is closed as soon as the object has been loaded.
# NOTE: pickle.load can execute arbitrary code - only load caches you created
# yourself.
with open("traindata_full.pickle", "rb") as _fh:
    traindata_full = pickle.load(_fh)
with open("valdata_full.pickle", "rb") as _fh:
    valdata_full = pickle.load(_fh)
with open("traindata.pickle", "rb") as _fh:
    traindata = pickle.load(_fh)
with open("valdata.pickle", "rb") as _fh:
    valdata = pickle.load(_fh)

In [12]:
from matplotlib import animation
from IPython.display import HTML


# Preview five random training clips as a looping animation.
idxs = np.random.randint(0, len(traindata), 5)
# Fetch every sample exactly once and read both the clip (item 0) and the label
# (item 2) from it, instead of indexing the dataset twice per clip.
# (Presumably each lookup decodes the video and re-applies the transform.)
samples = [traindata[i] for i in idxs]
videos = np.stack([sample[0].numpy() for sample in samples], 0)
titles = [sample[2] for sample in samples]

fig, ax = plt.subplots(1, 5, figsize=(15, 4), layout="tight")
fig.subplots_adjust(wspace=0, hspace=0)

# Draw the first frame of every clip; the image artists are kept so the
# animation can swap their pixel data in place.
ims = []
for i in range(5):
    # videos[i, :, t] is channels-first; imshow expects channels last.
    image_ = np.transpose(videos[i, :, 0], (1, 2, 0))
    im = ax[i].imshow(image_)
    ax[i].axis("off")
    ax[i].set_title(hmdb10_classes[titles[i]])
    ims.append(im)


def update(frame):
    """Show frame ``frame`` of every clip and return the updated artists."""
    for i in range(5):
        image_ = np.transpose(videos[i, :, frame], (1, 2, 0))
        ims[i].set_array(image_)
    return ims


ani = animation.FuncAnimation(
    fig, update, frames=videos.shape[2], interval=150, blit=True
)

# Close the figure; only the rendered video below is displayed.
plt.close(fig)

HTML(ani.to_html5_video())

## Video Swin Transformer
3D Swin classifier with 10 classes. It uses an architecture similar to the [CIFAR notebook](./cifar.ipynb), but with an additional dimension for time (`space=3`). As in the original Video Swin implementation, patch merging is not performed on the time dimension.

In [None]:
# Run configuration: target device and the epoch budget of both training phases.
device = "cuda"
PRETRAIN_N_EPOCHS = 20
N_EPOCHS = 20

# 3D Swin classifier over 16x64x64 (T, H, W) clips with two stages.
model = SwinClassifier(
    # input geometry
    space=3,
    resolution=(16, 64, 64),
    patch_size=(1, 4, 4),
    window_size=(16, 8, 8),
    # capacity
    dim=192,
    depth=(2, 4),
    num_heads=(6, 12),
    # architecture switches
    use_abs_pe=False,
    use_conv=True,
    merge_mask=[False, True, True],  # no merging on time
    # classification head
    num_classes=10,
    head_drop_p=0.5,
)

print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

Parameters: 9.5M


## Masked video autoencoder
To make video classification easier, the Swin encoder is first trained for video reconstruction in a (naive) MAE fashion on the entire HMDB dataset. Random frames are masked out in the input. The `Autoencoder` class uses `SwinClassifier` as the encoder, and a stack of `SwinLayer`s + `ConvTranspose` for the decoder.

In [18]:
def mask_spacetime_patch(x, prob: float = 0.4, patch_size=(1, 64, 64)):
    """Zero out random space-time patches of a batch of videos, in place.

    Args:
        x: Tensor of shape (B, C, T, H, W). It is modified in place, so pass a
            clone when the unmasked values are still needed.
        prob: Probability with which each patch is masked.
        patch_size: Patch extent ``(pt, ph, pw)`` along (T, H, W). T, H and W
            do not have to be multiples of the patch size.

    Returns:
        ``(x, mask)``: the masked input and a boolean tensor of the same shape
        that is True wherever ``x`` was set to zero. The mask is shared across
        channels.
    """
    B, C, T, H, W = x.shape
    pt, ph, pw = patch_size

    # Ceil-divide so a trailing partial patch is covered as well; the expanded
    # mask is cropped back to (T, H, W) below. Floor division would leave the
    # mask smaller than the input whenever a size is not a multiple of its patch.
    n_t, n_h, n_w = -(-T // pt), -(-H // ph), -(-W // pw)

    mask = (torch.rand(B, n_t, n_h, n_w, device=x.device) < prob).float()
    mask = mask.repeat_interleave(pt, dim=1)
    mask = mask.repeat_interleave(ph, dim=2)
    mask = mask.repeat_interleave(pw, dim=3)
    mask = mask[:, :T, :H, :W]
    # Same mask for every channel.
    mask = mask.unsqueeze(1).repeat(1, C, 1, 1, 1).bool()

    x[mask] = 0.0
    return x, mask


MASK_FN = mask_spacetime_patch

In [19]:
from ndswin.layers import SwinLayer
from einops.layers.torch import Rearrange


class Autoencoder(nn.Module):
    """Swin encoder/decoder used for masked reconstruction pre-training.

    The encoder is an existing ``SwinClassifier`` that is stored by reference
    (not copied), so training the autoencoder also trains that model. Its
    ``forward_features`` output is decoded by one stage per encoder layer, each
    consisting of a 1x1 convolution that halves the channels, a ``SwinLayer``,
    a GELU and a transposed convolution that upsamples the axes selected by
    ``merge_mask``. A final transposed convolution with kernel = stride =
    ``patch_size`` maps back to ``in_channels``.

    Args:
        encoder: Model providing ``forward_features``, ``grid_sizes`` and
            ``num_layers``.
        dim: Base channel width; the decoder input is expected to have
            ``dim * 2**encoder.num_layers`` channels.
        patch_size: Kernel and stride of the final un-patching convolution.
        window_size: Attention window passed to every decoder ``SwinLayer``.
        space: Number of spatial axes, 2 (images) or 3 (videos).
        in_channels: Number of channels of the reconstructed output.
        merge_mask: One flag per axis; axes flagged True are upsampled by 2 in
            every decoder stage. ``None`` upsamples every axis (presumably the
            counterpart of the encoder default - confirm against ndswin).
    """

    def __init__(
        self,
        encoder: SwinClassifier,
        dim,
        patch_size,
        window_size,
        space=2,
        in_channels=3,
        merge_mask=None,
    ):
        super().__init__()
        assert space in [2, 3]
        self.space = space

        self.encoder = encoder

        if merge_mask is None:
            # The declared default used to fail on ``merge_mask[i]`` below.
            merge_mask = [True] * space

        # Bottleneck resolution: the encoder's last grid size.
        self.bn_res = self.encoder.grid_sizes[-1]
        grid_size = self.bn_res
        Conv = nn.Conv2d if space == 2 else nn.Conv3d
        ConvTranspose = nn.ConvTranspose2d if space == 2 else nn.ConvTranspose3d
        # no upconv on time
        up_kernel = [2 if merge_mask[i] else 1 for i in range(space)]
        decoder = []
        dim = dim * 2**self.encoder.num_layers
        for _ in range(self.encoder.num_layers):
            dim = dim // 2
            decoder.append(
                nn.Sequential(
                    Conv(2 * dim, dim, kernel_size=[1] * space),
                    # SwinLayer works channels-last; convolutions channels-first.
                    Rearrange("b c ... -> b ... c"),
                    SwinLayer(
                        space=space,
                        dim=dim,
                        depth=2,
                        num_heads=8,
                        grid_size=grid_size,
                        window_size=window_size,
                    ),
                    Rearrange("b ... c -> b c ..."),
                    nn.GELU(),
                    ConvTranspose(dim, dim, kernel_size=up_kernel, stride=up_kernel),
                )
            )
            # The next stage operates on the upsampled grid.
            grid_size = [grid_size[i] * up_kernel[i] for i in range(space)]
        self.decoder = nn.Sequential(*decoder)
        self.unpatch = ConvTranspose(
            dim, in_channels, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Encode ``x``, decode the features and return the reconstruction."""
        z = self.encoder.forward_features(x)
        # Encoder features are channels-last; the decoder expects channels-first.
        z = rearrange(z, "b ... c -> b c ...")
        x = self.decoder(z)
        return self.unpatch(x)


# Autoencoder wrapped around the classifier defined above; the geometry
# arguments repeat the values used to build `model`.
ae = Autoencoder(
    encoder=model,
    space=3,
    dim=192,
    in_channels=3,
    patch_size=(1, 4, 4),
    window_size=(16, 8, 8),
    merge_mask=[False, True, True],
)

In [20]:
# Optimiser and per-epoch scheduler for the reconstruction pre-training.
# NOTE(review): the third positional argument of CosineAnnealingLR is `eta_min`.
# It equals the base learning rate here (1e-3), so the cosine schedule keeps the
# learning rate constant - confirm this is intended.
pretrain_opt = torch.optim.Adam(ae.parameters(), lr=1e-3, weight_decay=1e-8)
pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    pretrain_opt,
    T_max=PRETRAIN_N_EPOCHS,
    eta_min=1e-3,
)

In [21]:
# Pre-training runs on the full datasets (not the 10-class subsets); only the
# training loader is shuffled.
pretrain_trainloader = DataLoader(
    traindata_full,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
pretrain_valloader = DataLoader(
    valdata_full,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [22]:
# Masked-reconstruction pre-training of the autoencoder.
ae = ae.to(device)

results = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for e in range(1, PRETRAIN_N_EPOCHS + 1):
    # train
    train_loss = 0.0
    ae.train()
    for sample in tqdm(pretrain_trainloader, "training"):
        x = sample[0].to(device)
        # MASK_FN zeroes its argument in place, hence the clone.
        masked_x, mask = MASK_FN(x.clone())
        # loss only on masked reconstruction
        loss = F.mse_loss(ae(masked_x)[mask], x[mask])
        pretrain_opt.zero_grad()
        loss.backward()
        pretrain_opt.step()
        train_loss += loss.item()

    pretrain_sched.step()

    # eval (first epoch, then every fifth)
    val_str = ""
    if (e % 5) == 0 or e == 1:
        val_loss = 0.0
        ae.eval()
        with torch.no_grad():
            for sample in tqdm(pretrain_valloader, "evaluate"):
                x = sample[0].to(device)
                masked_x, mask = MASK_FN(x.clone())
                # .item() keeps the running sum a Python float, matching the
                # training loss, instead of a tensor on the device.
                val_loss += F.mse_loss(ae(masked_x)[mask], x[mask]).item()
            results["val_loss"].append(val_loss / len(pretrain_valloader))
            val_str = f", val/loss: {results['val_loss'][-1]:.3f}"

    results["train_loss"].append(train_loss / len(pretrain_trainloader))
    # Pad the epoch number to the width of this loop's own epoch budget.
    e_s = str(e).zfill(len(str(PRETRAIN_N_EPOCHS)))
    print(f"[{e_s}] train/loss: {results['train_loss'][-1]:.3f}{val_str}")

training: 100%|██████████| 235/235 [02:47<00:00,  1.40it/s]
evaluate: 100%|██████████| 93/93 [00:29<00:00,  3.19it/s]


[01] train/loss: 0.059, val/loss: 0.033, 


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[02] train/loss: 0.030


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[03] train/loss: 0.024


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[04] train/loss: 0.019


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]
evaluate: 100%|██████████| 93/93 [00:28<00:00,  3.24it/s]


[05] train/loss: 0.018, val/loss: 0.017, 


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[06] train/loss: 0.017


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[07] train/loss: 0.016


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[08] train/loss: 0.021


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[09] train/loss: 0.016


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]
evaluate: 100%|██████████| 93/93 [00:28<00:00,  3.21it/s]


[10] train/loss: 0.015, val/loss: 0.016, 


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[11] train/loss: 0.015


training: 100%|██████████| 235/235 [02:44<00:00,  1.42it/s]


[12] train/loss: 0.016


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[13] train/loss: 0.014


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[14] train/loss: 0.013


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]
evaluate: 100%|██████████| 93/93 [00:29<00:00,  3.19it/s]


[15] train/loss: 0.013, val/loss: 0.013, 


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[16] train/loss: 0.013


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[17] train/loss: 0.012


training: 100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


[18] train/loss: 0.013


training: 100%|██████████| 235/235 [02:45<00:00,  1.42it/s]


[19] train/loss: 0.012


training: 100%|██████████| 235/235 [02:45<00:00,  1.42it/s]
evaluate: 100%|██████████| 93/93 [00:31<00:00,  2.96it/s]

[20] train/loss: 0.012, val/loss: 0.013, 





### Reconstruction
Visualizing the bottleneck latent and the reconstructed video. The videos come from the __validation set__, and are reconstructed in a single forward pass.

In [23]:
from matplotlib import animation
from IPython.display import HTML

# Reconstruct five random validation clips on the CPU and animate ground truth,
# latent and reconstruction side by side.
ae = ae.cpu()
# Make inference independent of whichever mode the last training epoch left
# the model in.
ae.eval()

idxs = np.random.randint(0, len(valdata_full), 5)
with torch.no_grad():
    gt_videos = torch.stack([valdata_full[i][0] for i in idxs], 0)
    # MASK_FN zeroes its argument in place, hence the clone.
    masked_videos, mask = MASK_FN(gt_videos.clone())
    # 3D latent space (meanpool)
    lat_videos = ae.encoder.forward_features(masked_videos).mean(-1).numpy()
    # reconstruction
    ae_videos = ae(masked_videos).numpy()
    # Keep the visible (unmasked) pixels of the input; only masked regions show
    # the network output.
    ae_videos[~mask] = masked_videos[~mask]

n_frames = gt_videos.shape[2]
# Number of latent frames; mapping by ratio works whether or not the encoder
# reduces the temporal resolution (assumes axis 1 of the latent is time).
n_lat_frames = lat_videos.shape[1]

fig, ax = plt.subplots(3, 5, figsize=(15, 9), layout="tight")
fig.subplots_adjust(wspace=0, hspace=0)

ax[0, 0].set_ylabel("Ground truth", fontsize=12, labelpad=10)
ax[1, 0].set_ylabel("Latent", fontsize=12, labelpad=10)
ax[2, 0].set_ylabel("Reconstructed", fontsize=12, labelpad=10)

# Three artists per clip, stored as [gt, latent, reconstruction, gt, ...].
ims = []
for i in range(5):
    image_ = np.transpose(gt_videos[i, :, 0], (1, 2, 0))
    image_lat_ = lat_videos[i, 0]
    image_p_ = np.transpose(ae_videos[i, :, 0], (1, 2, 0))
    image_ = minmax(image_)
    image_lat_ = minmax(image_lat_)
    image_p_ = minmax(image_p_)
    im = ax[0, i].imshow(image_)
    imlat = ax[1, i].imshow(image_lat_)
    imp = ax[2, i].imshow(image_p_)
    for j in [0, 1, 2]:
        ax[j, i].get_xaxis().set_ticks([])
        ax[j, i].get_yaxis().set_ticks([])
    ims.append(im)
    ims.append(imlat)
    ims.append(imp)


def update(frame):
    """Show frame ``frame`` of every row/clip and return the updated artists."""
    # Latent frame that corresponds to this video frame.
    lat_frame = frame * n_lat_frames // n_frames
    for i in range(5):
        image_ = np.transpose(gt_videos[i, :, frame], (1, 2, 0))
        image_lat_ = lat_videos[i, lat_frame]
        image_p_ = np.transpose(ae_videos[i, :, frame], (1, 2, 0))
        image_ = minmax(image_)
        image_lat_ = minmax(image_lat_)
        image_p_ = minmax(image_p_)
        ims[3 * i].set_array(image_)
        ims[3 * i + 1].set_array(image_lat_)
        ims[3 * i + 2].set_array(image_p_)
    return ims


ani = animation.FuncAnimation(
    fig, update, frames=n_frames, interval=150, blit=True
)

plt.close(fig)

HTML(ani.to_html5_video())

## Training
Training with a batch size of 64 on 16x64x64 videos (THW) takes around 1 hour on a single A100 and uses ~30GB of memory.

In [26]:
# Optimiser and per-epoch scheduler for the classification fine-tuning.
# NOTE(review): the third positional argument of CosineAnnealingLR is `eta_min`.
# It equals the base learning rate here (1e-4), so the cosine schedule keeps the
# learning rate constant - confirm this is intended.
opt = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt,
    T_max=N_EPOCHS,
    eta_min=1e-4,
)

In [27]:
# Loaders over the 10-class subsets; only the training loader is shuffled.
trainloader = DataLoader(
    traindata,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
valloader = DataLoader(
    valdata,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [28]:
# Supervised training of the classifier on the 10-class subset.
model = model.to(device)

results = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for e in range(1, N_EPOCHS + 1):
    # train
    train_loss = 0.0
    train_acc = 0.0
    model.train()
    for sample in tqdm(trainloader, "training"):
        x = sample[0].to(device)
        # Labels are HMDB51 ids; remap them to 0..9 (the helper returns a CPU
        # tensor, hence the transfer).
        y = remap_hmdb10_labels(sample[2]).to(device)
        pred = F.log_softmax(model(x), dim=1)
        loss = F.nll_loss(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()
        train_acc += accuracy(pred, y)

    sched.step()

    # eval (first epoch, then every second)
    val_str = ""
    if (e % 2) == 0 or e == 1:
        val_loss = 0.0
        val_acc = 0.0
        model.eval()
        with torch.no_grad():
            for sample in tqdm(valloader, "evaluate"):
                x = sample[0].to(device)
                y = remap_hmdb10_labels(sample[2]).to(device)
                pred = F.log_softmax(model(x), dim=1)
                val_loss += F.nll_loss(pred, y).item()
                val_acc += accuracy(pred, y)
            results["val_loss"].append(val_loss / len(valloader))
            results["val_acc"].append(val_acc / len(valloader))
            val_str = (
                f", val/loss: {results['val_loss'][-1]:.2f}, "
                f"val/acc: {results['val_acc'][-1]:.2f}%"
            )

    # Metrics are averages of the per-batch values.
    results["train_loss"].append(train_loss / len(trainloader))
    results["train_acc"].append(train_acc / len(trainloader))
    e_s = str(e).zfill(len(str(N_EPOCHS)))
    # val_str is concatenated rather than passed as a second print argument,
    # which would insert a space before its leading comma.
    print(
        f"[{e_s}] train/loss: {results['train_loss'][-1]:.2f}, "
        f"train/acc: {results['train_acc'][-1]:.2f}%"
        f"{val_str}"
    )

training: 100%|██████████| 31/31 [00:39<00:00,  1.29s/it]
evaluate: 100%|██████████| 13/13 [00:11<00:00,  1.14it/s]


[01] train/loss: 1.85, train/acc: 34.21% , val/loss: 2.04, val/acc: 32.69%


training: 100%|██████████| 31/31 [00:38<00:00,  1.25s/it]
evaluate: 100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


[02] train/loss: 1.62, train/acc: 44.34% , val/loss: 1.98, val/acc: 36.90%


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]


[03] train/loss: 1.53, train/acc: 47.26% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:11<00:00,  1.18it/s]


[04] train/loss: 1.40, train/acc: 54.79% , val/loss: 1.95, val/acc: 30.41%


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]


[05] train/loss: 1.22, train/acc: 60.85% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:11<00:00,  1.12it/s]


[06] train/loss: 1.06, train/acc: 65.65% , val/loss: 1.87, val/acc: 44.85%


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]


[07] train/loss: 0.91, train/acc: 72.00% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


[08] train/loss: 0.79, train/acc: 75.24% , val/loss: 2.00, val/acc: 39.94%


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]


[09] train/loss: 0.71, train/acc: 77.15% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


[10] train/loss: 0.66, train/acc: 79.69% , val/loss: 2.13, val/acc: 41.50%


training: 100%|██████████| 31/31 [00:38<00:00,  1.23s/it]


[11] train/loss: 0.58, train/acc: 81.31% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:11<00:00,  1.12it/s]


[12] train/loss: 0.53, train/acc: 84.22% , val/loss: 2.33, val/acc: 44.15%


training: 100%|██████████| 31/31 [00:38<00:00,  1.23s/it]


[13] train/loss: 0.50, train/acc: 84.48% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


[14] train/loss: 0.48, train/acc: 84.84% , val/loss: 2.61, val/acc: 39.81%


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]


[15] train/loss: 0.42, train/acc: 86.42% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.24s/it]
evaluate: 100%|██████████| 13/13 [00:10<00:00,  1.20it/s]


[16] train/loss: 0.40, train/acc: 88.28% , val/loss: 2.70, val/acc: 39.94%


training: 100%|██████████| 31/31 [00:38<00:00,  1.25s/it]


[17] train/loss: 0.36, train/acc: 89.01% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.25s/it]
evaluate: 100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


[18] train/loss: 0.36, train/acc: 88.41% , val/loss: 2.69, val/acc: 44.15%


training: 100%|██████████| 31/31 [00:38<00:00,  1.23s/it]


[19] train/loss: 0.36, train/acc: 88.93% 


training: 100%|██████████| 31/31 [00:38<00:00,  1.25s/it]
evaluate: 100%|██████████| 13/13 [00:12<00:00,  1.06it/s]

[20] train/loss: 0.31, train/acc: 90.42% , val/loss: 2.66, val/acc: 38.15%





## Qualitative model predictions

In [32]:
from matplotlib import animation
from IPython.display import HTML


# Classify five random validation clips and animate them with the predicted
# class as title.
idxs = np.random.randint(0, len(valdata), 5)
videos = torch.stack([valdata[i][0] for i in idxs], 0)
# Inference: switch to eval mode so dropout is disabled, and skip autograd
# bookkeeping.
model.eval()
with torch.no_grad():
    pred_c = model(videos.to(device, dtype=torch.float32)).cpu().argmax(-1)
# Map the 0..9 predictions back to HMDB51 ids, the keys of `hmdb10_classes`.
pred_c = invert_hmdb10_labels(pred_c)

fig, ax = plt.subplots(1, 5, figsize=(15, 4), layout="tight")
fig.subplots_adjust(wspace=0, hspace=0)

ims = []
for i in range(5):
    # videos[i, :, t] is channels-first; imshow expects channels last.
    image_ = np.transpose(videos[i, :, 0], (1, 2, 0))
    im = ax[i].imshow(image_)
    ax[i].axis("off")
    ax[i].set_title(hmdb10_classes[pred_c[i].item()])
    ims.append(im)


def update(frame):
    """Show frame ``frame`` of every clip and return the updated artists."""
    for i in range(5):
        image_ = np.transpose(videos[i, :, frame], (1, 2, 0))
        ims[i].set_array(image_)
    return ims


ani = animation.FuncAnimation(
    fig, update, frames=videos.shape[2], interval=100, blit=True
)

plt.close(fig)

HTML(ani.to_html5_video())