# About

This notebook demonstrates that the Video Swin Transformer produces identical results whether it is imported from the `keras-cv` or the `torchvision` library. The `keras-cv` version of Video Swin is implemented in `keras 3`, which makes it able to run on multiple backends, i.e. `tensorflow`, `torch`, and `jax`.

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import warnings

In [2]:
# Environment switches, set here ahead of the `import keras` cell further down.
# KERAS_BACKEND accepts 'torch', 'tensorflow', or 'jax'.
os.environ.update(
    {
        "KERAS_BACKEND": "torch",
        "TF_CPP_MIN_LOG_LEVEL": "3",
    }
)
# Silence all Python warnings for the rest of the session.
warnings.simplefilter("ignore")

In [3]:
!git clone --branch video_swin https://github.com/innat/keras-cv.git
%cd keras-cv
!pip install -q -e .

Cloning into 'keras-cv'...
remote: Enumerating objects: 13766, done.[K
remote: Counting objects: 100% (1903/1903), done.[K
remote: Compressing objects: 100% (760/760), done.[K
remote: Total 13766 (delta 1325), reused 1617 (delta 1127), pack-reused 11863[K
Receiving objects: 100% (13766/13766), 25.64 MiB | 27.10 MiB/s, done.
Resolving deltas: 100% (9776/9776), done.
/kaggle/working/keras-cv


# KerasCV: Video Swin : Pretrained: ImageNet 1K

In [4]:
import keras
from keras import ops
from keras_cv.models import VideoSwinBackbone
from keras_cv.models import VideoClassifier

keras.__version__

'3.0.5'

In [5]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5

def vswin_tiny():
    """Build the tiny Video Swin backbone and load its pretrained weights.

    Returns:
        The ``VideoSwinBackbone`` instance after ``load_weights`` has read
        'videoswin_tiny_kinetics400.weights.h5' (path relative to the
        current working directory).
    """
    tiny_config = dict(
        input_shape=(32, 224, 224, 3),
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=False,
    )
    model = VideoSwinBackbone(**tiny_config)
    model.load_weights('videoswin_tiny_kinetics400.weights.h5')
    return model

--2024-03-31 16:57:37--  https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/5153e756-236b-41e7-a602-ab854a57034f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165737Z&X-Amz-Expires=300&X-Amz-Signature=6188cd48f4cffee2ddbc4c5a3c8e4701e15588c12e3446ed5b8a52a002072164&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_tiny_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]
--2024-03-31 16:57:37--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973

In [6]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5

def vswin_small():
    """Build the small Video Swin backbone and load its pretrained weights.

    Returns:
        The ``VideoSwinBackbone`` instance after ``load_weights`` has read
        'videoswin_small_kinetics400.weights.h5' (path relative to the
        current working directory).
    """
    small_config = dict(
        input_shape=(32, 224, 224, 3),
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=False,
    )
    model = VideoSwinBackbone(**small_config)
    model.load_weights('videoswin_small_kinetics400.weights.h5')
    return model

--2024-03-31 16:57:39--  https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/898b24c6-f517-4b01-872b-8f19acd2c54d?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165739Z&X-Amz-Expires=300&X-Amz-Signature=ebd091cee3f64c57654966b81170827c3c667d8bab72a6b4969a987561926f0f&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_small_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]
--2024-03-31 16:57:39--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/6976969

In [7]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5

def vswin_base():
    """Build the base Video Swin backbone and load its pretrained weights.

    Returns:
        The ``VideoSwinBackbone`` instance after ``load_weights`` has read
        'videoswin_base_kinetics400.weights.h5' (path relative to the
        current working directory).
    """
    base_config = dict(
        input_shape=(32, 224, 224, 3),
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        include_rescaling=False,
    )
    model = VideoSwinBackbone(**base_config)
    model.load_weights('videoswin_base_kinetics400.weights.h5')
    return model

--2024-03-31 16:57:41--  https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/d5a7b9f0-78b7-4151-b1d3-ddba5c66c7c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165742Z&X-Amz-Expires=300&X-Amz-Signature=4dbdc130edd48a081675524290e07e70aa8d48bcd5862bb80d8f15b72903abdd&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400.weights.h5&response-content-type=application%2Foctet-stream [following]
--2024-03-31 16:57:42--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973

In [8]:
# Instantiate the three Keras backbones in tiny -> small -> base order,
# then print the layer summary of the first (tiny) one.
keras_models = [build() for build in (vswin_tiny, vswin_small, vswin_base)]
keras_models[0].summary()

# TorchVision: Video Swin : Pretrained: ImageNet 1K

In [9]:
import torch
import torchvision
from torchinfo import summary
from torchvision.models.video import Swin3D_T_Weights, Swin3D_S_Weights, Swin3D_B_Weights

In [10]:
def exclude_top(model):
    """Return an eval-mode ``torch.nn.Sequential`` without the last two children.

    The direct child modules of ``model`` are taken in registration order, the
    final two are dropped, and the remainder is chained into a new Sequential.

    NOTE: only child modules are chained; any functional steps that
    ``model.forward`` performs between them are not reproduced.
    """
    kept_layers = list(model.children())[:-2]
    trunk = torch.nn.Sequential(*kept_layers)
    trunk.eval()
    return trunk

def torch_vswin_tiny():
    """Load torchvision's ``swin3d_t`` and return it with its top removed.

    The model is created with ``Swin3D_T_Weights.KINETICS400_V1``, switched to
    eval mode, and passed through ``exclude_top``.
    """
    pretrained = torchvision.models.video.swin3d_t(
        weights=Swin3D_T_Weights.KINETICS400_V1
    )
    pretrained.eval()
    return exclude_top(pretrained)

def torch_vswin_small():
    """Load torchvision's ``swin3d_s`` and return it with its top removed.

    The model is created with ``Swin3D_S_Weights.KINETICS400_V1``, switched to
    eval mode, and passed through ``exclude_top``.
    """
    pretrained = torchvision.models.video.swin3d_s(
        weights=Swin3D_S_Weights.KINETICS400_V1
    )
    pretrained.eval()
    return exclude_top(pretrained)

def torch_vswin_base():
    """Load torchvision's ``swin3d_b`` and return it with its top removed.

    The model is created with ``Swin3D_B_Weights.KINETICS400_V1``, switched to
    eval mode, and passed through ``exclude_top``.
    """
    pretrained = torchvision.models.video.swin3d_b(
        weights=Swin3D_B_Weights.KINETICS400_V1
    )
    pretrained.eval()
    return exclude_top(pretrained)

In [11]:
# Instantiate the three torchvision backbones in tiny -> small -> base order,
# then show the torchinfo summary of the first (tiny) one.
torch_models = [
    make() for make in (torch_vswin_tiny, torch_vswin_small, torch_vswin_base)
]
summary(torch_models[0], input_size=(1, 3, 32, 224, 224))

Downloading: "https://download.pytorch.org/models/swin3d_t-7615ae03.pth" to /root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth
100%|██████████| 122M/122M [00:00<00:00, 131MB/s]
Downloading: "https://download.pytorch.org/models/swin3d_s-da41c237.pth" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth
100%|██████████| 218M/218M [00:06<00:00, 37.7MB/s]
Downloading: "https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth" to /root/.cache/torch/hub/checkpoints/swin3d_b_1k-24f7c7c6.pth
100%|██████████| 364M/364M [00:02<00:00, 135MB/s]


Layer (type:depth-idx)                                  Output Shape              Param #
Sequential                                              [1, 16, 7, 7, 768]        --
├─PatchEmbed3d: 1-1                                     [1, 16, 56, 56, 96]       --
│    └─Conv3d: 2-1                                      [1, 96, 16, 56, 56]       9,312
│    └─LayerNorm: 2-2                                   [1, 16, 56, 56, 96]       192
├─Dropout: 1-2                                          [1, 16, 56, 56, 96]       --
├─Sequential: 1-3                                       [1, 16, 7, 7, 768]        --
│    └─Sequential: 2-3                                  [1, 16, 56, 56, 96]       --
│    │    └─SwinTransformerBlock: 3-1                   [1, 16, 56, 56, 96]       119,445
│    │    └─SwinTransformerBlock: 3-2                   [1, 16, 56, 56, 96]       119,445
│    └─PatchMerging: 2-4                                [1, 16, 28, 28, 192]      --
│    │    └─LayerNorm: 3-3                    

# Inference

In [12]:
# One shared random clip for both frameworks. A seeded generator makes the
# comparison input identical on every Restart & Run All.
common_input = (
    np.random.default_rng(seed=0)
    .normal(0, 1, (1, 32, 224, 224, 3))
    .astype('float32')
)
keras_input = ops.array(common_input)
# Move the last axis (size 3) to position 1, matching the (1, 3, 32, 224, 224)
# layout the torch models are summarised with.
torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))
print(keras_input.shape, torch_input.shape)

torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])


In [13]:
def logit_checking(keras_model, torch_model):
    """Run both models on the shared input and assert their outputs match.

    Reads the notebook-level tensors ``keras_input`` and ``torch_input``.

    Args:
        keras_model: Keras model, called as ``keras_model(keras_input)``.
        torch_model: torch module, called as ``torch_model(torch_input)``.

    Raises:
        AssertionError: raised by ``np.testing.assert_allclose`` when the two
            outputs differ beyond ``rtol=1e-4`` / ``atol=1e-4``.
    """
    # Inference only: skip building an autograd graph for the forward passes.
    with torch.no_grad():
        keras_predict = keras_model(keras_input)
        torch_predict = torch_model(torch_input)
    print(keras_predict.shape, torch_predict.shape)
    np.testing.assert_allclose(
        # Backend-agnostic conversion: `.detach().numpy()` exists only on torch
        # tensors (and only works for CPU ones), whereas this works for every
        # Keras backend.
        ops.convert_to_numpy(keras_predict),
        torch_predict.detach().cpu().numpy(),
        rtol=1e-4,
        atol=1e-4,
    )

In [14]:
# Compare each Keras backbone against its torchvision counterpart, pairwise.
for keras_backbone, torch_backbone in zip(keras_models, torch_models):
    logit_checking(keras_backbone, torch_backbone)

torch.Size([1, 16, 7, 7, 768]) torch.Size([1, 16, 7, 7, 768])
torch.Size([1, 16, 7, 7, 768]) torch.Size([1, 16, 7, 7, 768])
torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])


In [15]:
import gc

# The models compared above stay alive as long as these notebook-level lists
# reference them; drop the references so gc.collect() can reclaim the memory
# before the next set of weights is loaded.
keras_models = torch_models = None
gc.collect()

40

# Keras: Video Swin Base - Pretrained: ImageNet 22K

In [16]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5

def vswin_base():
    """Build the base Video Swin backbone and load the ImageNet-22K variant weights.

    NOTE: this redefinition replaces the earlier ``vswin_base`` in the
    notebook namespace; the two differ only in the checkpoint file loaded.

    Returns:
        The ``VideoSwinBackbone`` instance after ``load_weights`` has read
        'videoswin_base_kinetics400_imagenet22k.weights.h5' (path relative to
        the current working directory).
    """
    base_config = dict(
        input_shape=(32, 224, 224, 3),
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        include_rescaling=False,
    )
    model = VideoSwinBackbone(**base_config)
    model.load_weights('videoswin_base_kinetics400_imagenet22k.weights.h5')
    return model

--2024-03-31 16:59:25--  https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/697696973/75b53567-f9ae-4739-87c1-0d5d9d423f25?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240331%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240331T165925Z&X-Amz-Expires=300&X-Amz-Signature=b3146437a5644138b963a8e376f8cc066e3ff2c0b4bb7a05e1f705e930095453&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=697696973&response-content-disposition=attachment%3B%20filename%3Dvideoswin_base_kinetics400_imagenet22k.weights.h5&response-content-type=application%2Foctet-stream [following]
--2024-03-31 16:59:25--  https://objects.githubusercontent.com/github-production-releas

In [17]:
# NOTE: despite the plural name, `keras_models` is rebound here to a single
# backbone (it held a list of three models earlier in the notebook).
keras_models = vswin_base()

In [18]:
import torchvision
from torchvision.models.video import Swin3D_B_Weights

# Build swin3d_b with the KINETICS400_IMAGENET22K_V1 weights, put it in eval
# mode, and strip its last two children via exclude_top.
torch_model = exclude_top(
    torchvision.models.video.swin3d_b(
        weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1
    ).eval()
)

Downloading: "https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth" to /root/.cache/torch/hub/checkpoints/swin3d_b_22k-7c6ae6fa.pth
100%|██████████| 364M/364M [00:19<00:00, 19.7MB/s]


In [19]:
# Same output comparison as before, now for the ImageNet-22K base pair.
logit_checking(keras_models, torch_model)

torch.Size([1, 16, 7, 7, 1024]) torch.Size([1, 16, 7, 7, 1024])
