# About

This notebook demonstrates the identical results of vidoe swin transformer, imported from `keras-cv` and `torch-vision` libraries. The `keras-cv` version of video swin is implemented in `keras 3`, makes it able to run in multiple backend, i.e. `tensorflow`, `torch`, and `jax`.

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import warnings

In [2]:
os.environ["KERAS_BACKEND"] = "torch" # 'torch', 'tensorflow', 'jax'

warnings.simplefilter(action="ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [3]:
!git clone --branch video_swin https://github.com/innat/keras-cv.git
%cd keras-cv
!pip install -q -e .

Cloning into 'keras-cv'...
remote: Enumerating objects: 13735, done.[K
remote: Counting objects: 100% (1872/1872), done.[K
remote: Compressing objects: 100% (752/752), done.[K
remote: Total 13735 (delta 1297), reused 1587 (delta 1104), pack-reused 11863[K
Receiving objects: 100% (13735/13735), 25.64 MiB | 31.71 MiB/s, done.
Resolving deltas: 100% (9742/9742), done.
/kaggle/working/keras-cv


# KerasCV: Video Swin : Pretrained: ImageNet 1K

In [4]:
import keras
from keras import ops
from keras_cv.models import VideoSwinBackbone
from keras_cv.models import VideoClassifier

keras.__version__

'3.0.5'

In [5]:
def vswin_tiny():
    backbone=VideoSwinBackbone(
        input_shape=(32, 224, 224, 3), 
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=False, 
    )
    keras_model = VideoClassifier(
        backbone=backbone,
        num_classes=400,
        activation=None,
        pooling='avg',
    )
    keras_model.load_weights(
        '/kaggle/input/videoswin/keras/tiny/1/videoswin_tiny_kinetics400_classifier.weights.h5'
    )
    return keras_model

In [6]:
def vswin_small():
    backbone=VideoSwinBackbone(
        input_shape=(32, 224, 224, 3), 
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=False, 
    )
    keras_model = VideoClassifier(
        backbone=backbone,
        num_classes=400,
        activation=None,
        pooling='avg',
    )
    keras_model.load_weights(
        '/kaggle/input/videoswin/keras/small/1/videoswin_small_kinetics400_classifier.weights.h5'
    )
    return keras_model

In [7]:
def vswin_base():
    backbone=VideoSwinBackbone(
        input_shape=(32, 224, 224, 3), 
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        include_rescaling=False, 
    )
    keras_model = VideoClassifier(
        backbone=backbone,
        num_classes=400,
        activation=None,
        pooling='avg',
    )
    keras_model.load_weights(
        '/kaggle/input/videoswin/keras/base/1/videoswin_base_kinetics400_classifier.weights.h5'
    )
    return keras_model

In [8]:
keras_models = [vswin_tiny(), vswin_small(), vswin_base()]
keras_models[0].summary()

# TorchVision: Video Swin : Pretrained: ImageNet 1K

In [9]:
import torch
import torchvision
from torchinfo import summary
from torchvision.models.video import Swin3D_T_Weights, Swin3D_S_Weights, Swin3D_B_Weights

def torch_vswin_tiny():
    torch_model = torchvision.models.video.swin3d_t(
        weights=Swin3D_T_Weights.KINETICS400_V1
    ).eval()
    return torch_model

def torch_vswin_small():
    torch_model = torchvision.models.video.swin3d_s(
        weights=Swin3D_S_Weights.KINETICS400_V1
    ).eval()
    return torch_model

def torch_vswin_base():
    torch_model = torchvision.models.video.swin3d_b(
        weights=Swin3D_B_Weights.KINETICS400_V1
    ).eval()
    return torch_model

In [10]:
torch_models = [torch_vswin_tiny(), torch_vswin_small(), torch_vswin_base()]
summary(
    torch_models[0], input_size=(1, 3, 32, 224, 224)
)

Downloading: "https://download.pytorch.org/models/swin3d_t-7615ae03.pth" to /root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth
100%|██████████| 122M/122M [00:02<00:00, 54.0MB/s]
Downloading: "https://download.pytorch.org/models/swin3d_s-da41c237.pth" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth
100%|██████████| 218M/218M [00:04<00:00, 55.4MB/s]
Downloading: "https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth" to /root/.cache/torch/hub/checkpoints/swin3d_b_1k-24f7c7c6.pth
100%|██████████| 364M/364M [00:06<00:00, 57.0MB/s]


Layer (type:depth-idx)                                  Output Shape              Param #
SwinTransformer3d                                       [1, 400]                  --
├─PatchEmbed3d: 1-1                                     [1, 16, 56, 56, 96]       --
│    └─Conv3d: 2-1                                      [1, 96, 16, 56, 56]       9,312
│    └─LayerNorm: 2-2                                   [1, 16, 56, 56, 96]       192
├─Dropout: 1-2                                          [1, 16, 56, 56, 96]       --
├─Sequential: 1-3                                       [1, 16, 7, 7, 768]        --
│    └─Sequential: 2-3                                  [1, 16, 56, 56, 96]       --
│    │    └─SwinTransformerBlock: 3-1                   [1, 16, 56, 56, 96]       119,445
│    │    └─SwinTransformerBlock: 3-2                   [1, 16, 56, 56, 96]       119,445
│    └─PatchMerging: 2-4                                [1, 16, 28, 28, 192]      --
│    │    └─LayerNorm: 3-3                    

# Inference

In [11]:
common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')
keras_input = ops.array(common_input)
torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))
print(keras_input.shape, torch_input.shape)

torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])


In [12]:
def logit_checking(keras_model, torch_model):
    # forward pass
    keras_predict = keras_model(keras_input)
    torch_predict = torch_model(torch_input)
    print(keras_predict.shape, torch_predict.shape)
    print('keras logits: ', keras_predict[0, :5])
    print('torch logits: ', torch_predict[0, :5], end='\n')
    np.testing.assert_allclose(
        keras_predict.detach().numpy(),
        torch_predict.detach().numpy(),
        1e-5, 1e-5
    )
    del keras_model 
    del torch_model

In [13]:
for km, tm in zip(keras_models, torch_models):
    logit_checking(
        km, tm
    )

torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([-0.0906,  1.2267,  1.1639, -0.3530, -1.5449], grad_fn=<SliceBackward0>)
torch logits:  tensor([-0.0906,  1.2267,  1.1639, -0.3530, -1.5449], grad_fn=<SliceBackward0>)
torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([ 0.6399,  1.2136,  0.9395, -0.4962, -1.9626], grad_fn=<SliceBackward0>)
torch logits:  tensor([ 0.6399,  1.2136,  0.9395, -0.4962, -1.9626], grad_fn=<SliceBackward0>)
torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([ 1.1572,  0.0092,  0.0929, -1.8786, -2.8799], grad_fn=<SliceBackward0>)
torch logits:  tensor([ 1.1572,  0.0092,  0.0929, -1.8786, -2.8799], grad_fn=<SliceBackward0>)


In [14]:
import gc
gc.collect()

27

# Keras: Video Swin Base - Pretrained: ImageNet 22K

In [15]:
def vswin_base():
    backbone=VideoSwinBackbone(
        input_shape=(32, 224, 224, 3), 
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        include_rescaling=False, 
    )
    keras_model = VideoClassifier(
        backbone=backbone,
        num_classes=400,
        activation=None,
        pooling='avg',
    )
    keras_model.load_weights(
        '/kaggle/input/videoswin/keras/base/1/videoswin_base_kinetics400_imagenet22k_classifier.weights.h5'
    )
    return keras_model

In [16]:
keras_models = vswin_base()

In [17]:
import torchvision
from torchvision.models.video import Swin3D_B_Weights

torch_model = torchvision.models.video.swin3d_b(
    weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1
).eval()

Downloading: "https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth" to /root/.cache/torch/hub/checkpoints/swin3d_b_22k-7c6ae6fa.pth
100%|██████████| 364M/364M [00:07<00:00, 51.8MB/s]


In [18]:
logit_checking(
    keras_models, torch_model
)

torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([ 0.2773,  0.8488,  1.4034, -1.0703, -1.4610], grad_fn=<SliceBackward0>)
torch logits:  tensor([ 0.2773,  0.8488,  1.4034, -1.0703, -1.4610], grad_fn=<SliceBackward0>)
