In [1]:
import numpy as np 
import pandas as pd
import os
import warnings

os.environ["KERAS_BACKEND"] = "torch"
warnings.simplefilter(action="ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
!git clone https://github.com/innat/VideoSwin.git
%cd VideoSwin
!pip install -q -e . 

Cloning into 'VideoSwin'...
remote: Enumerating objects: 729, done.[K
remote: Counting objects: 100% (195/195), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 729 (delta 151), reused 150 (delta 119), pack-reused 534[K
Receiving objects: 100% (729/729), 3.53 MiB | 28.72 MiB/s, done.
Resolving deltas: 100% (433/433), done.
/kaggle/working/VideoSwin


In [3]:
import torch
import keras
from keras import ops

keras.__version__, torch.__version__

('3.0.5', '2.1.2+cpu')

# Utility

In [4]:
def logit_checking(models, inputs, name):
    keras_model, torch_model = models
    keras_input, torch_input = inputs

    # forward pass
    keras_predict = keras_model(keras_input)
    torch_predict = torch_model(torch_input)
    
    print('Model: ', name)
    print('Output shape: ', keras_predict.shape, torch_predict.shape)
    print('keras logits: ', keras_predict[0, :5])
    print('torch logits: ', torch_predict[0, :5], end='\n')
    
    np.testing.assert_allclose(
        keras_predict.detach().numpy(),
        torch_predict.detach().numpy(),
        1e-5, 1e-5
    )

    np.testing.assert_allclose(
        keras_predict.detach().numpy(),
        torch_predict.detach().numpy(),
        1e-6, 1e-6
    )

In [5]:
common_input = np.random.normal(0, 1, (1, 32, 224, 224, 3)).astype('float32')
keras_input = ops.array(common_input)
torch_input = torch.from_numpy(common_input.transpose(0, 4, 1, 2, 3))
print(keras_input.shape, torch_input.shape)

torch.Size([1, 32, 224, 224, 3]) torch.Size([1, 3, 32, 224, 224])


# Keras: Video Swin Tiny

In [6]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q

In [7]:
from videoswin.model import VideoSwinT

keras_model = VideoSwinT(
    input_shape=(32, 224, 224, 3),
    include_rescaling=False,
    num_classes=400,
    activation=None
)

keras_model.load_weights(
    'videoswin_tiny_kinetics400_classifier.weights.h5'
)

# TorchVision: Video Swin Tiny

In [8]:
import torchvision
from torchvision.models.video import Swin3D_T_Weights

torch_model = torchvision.models.video.swin3d_t(
    weights=Swin3D_T_Weights.KINETICS400_V1
).eval()

Downloading: "https://download.pytorch.org/models/swin3d_t-7615ae03.pth" to /root/.cache/torch/hub/checkpoints/swin3d_t-7615ae03.pth
100%|██████████| 122M/122M [00:03<00:00, 36.0MB/s]


In [9]:
logit_checking(
    [keras_model, torch_model], [keras_input, torch_input], name='VideoSwinTiny'
)

Model:  VideoSwinTiny
Output shape:  torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([-0.1836,  1.2517,  1.0862, -0.3655, -1.4410], grad_fn=<SliceBackward0>)
torch logits:  tensor([-0.1836,  1.2517,  1.0862, -0.3655, -1.4410], grad_fn=<SliceBackward0>)


# Keras: Video Swin Small

In [10]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400_classifier.weights.h5 -q

In [11]:
from videoswin.model import VideoSwinS

keras_model = VideoSwinS(
    input_shape=(32, 224, 224, 3),
    include_rescaling=False,
    num_classes=400,
    activation=None
)

keras_model.load_weights(
    'videoswin_small_kinetics400_classifier.weights.h5'
)

# TorchVision: Video Swin Small

In [12]:
import torchvision
from torchvision.models.video import Swin3D_S_Weights

torch_model = torchvision.models.video.swin3d_s(
    weights=Swin3D_S_Weights.KINETICS400_V1
).eval()

Downloading: "https://download.pytorch.org/models/swin3d_s-da41c237.pth" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth
100%|██████████| 218M/218M [00:05<00:00, 38.5MB/s]


In [13]:
logit_checking(
    [keras_model, torch_model], [keras_input, torch_input], name='VideoSwinSmall'
)

Model:  VideoSwinSmall
Output shape:  torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([ 0.6722,  1.1854,  0.9514, -0.4893, -1.8892], grad_fn=<SliceBackward0>)
torch logits:  tensor([ 0.6722,  1.1854,  0.9514, -0.4893, -1.8892], grad_fn=<SliceBackward0>)


# Keras: Video Swin Base [ImageNet 22K]

In [14]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k_classifier.weights.h5 -q

In [15]:
from videoswin.model import VideoSwinB

keras_model = VideoSwinB(
    input_shape=(32, 224, 224, 3),
    include_rescaling=False,
    num_classes=400,
    activation=None
)

keras_model.load_weights(
    'videoswin_base_kinetics400_imagenet22k_classifier.weights.h5'
)

# TorchVision: Video Swin Base [ImageNet 22K]

In [16]:
import torchvision
from torchvision.models.video import Swin3D_B_Weights

torch_model = torchvision.models.video.swin3d_b(
    weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1
).eval()

Downloading: "https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth" to /root/.cache/torch/hub/checkpoints/swin3d_b_22k-7c6ae6fa.pth
100%|██████████| 364M/364M [00:14<00:00, 25.5MB/s]


In [17]:
logit_checking(
    [keras_model, torch_model], [keras_input, torch_input], name='VideoSwinBaseImageNet22K'
)

Model:  VideoSwinBaseImageNet22K
Output shape:  torch.Size([1, 400]) torch.Size([1, 400])
keras logits:  tensor([ 0.3086,  0.7657,  1.4416, -1.0855, -1.4904], grad_fn=<SliceBackward0>)
torch logits:  tensor([ 0.3086,  0.7657,  1.4416, -1.0855, -1.4904], grad_fn=<SliceBackward0>)
