In [1]:
!pip install decord -q

In [2]:
!git clone --branch video_swin https://github.com/innat/keras-cv.git
%cd keras-cv
!pip install -q -e .

Cloning into 'keras-cv'...
remote: Enumerating objects: 13782, done.[K
remote: Counting objects: 100% (1919/1919), done.[K
remote: Compressing objects: 100% (769/769), done.[K
remote: Total 13782 (delta 1337), reused 1628 (delta 1134), pack-reused 11863[K
Receiving objects: 100% (13782/13782), 25.65 MiB | 27.53 MiB/s, done.
Resolving deltas: 100% (9788/9788), done.
/kaggle/working/keras-cv


In [3]:
import os, warnings
os.environ["KERAS_BACKEND"] = "torch" 
warnings.simplefilter(action="ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [4]:
import numpy as np
import pandas as pd
import os, sys
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from decord import VideoReader
from decord import cpu, gpu
from torch.utils.data import Dataset, DataLoader

import keras
from keras import ops
from keras_cv.models import VideoSwinBackbone
from keras_cv.models import VideoClassifier

keras.__version__, torch.__version__

('3.0.5', '2.1.2')

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7d9cfdd7eb60>

# Data Set

In [6]:
!wget https://raw.githubusercontent.com/innat/VideoSwin/main/data/kinetics_400_labels.csv -q
labels = pd.read_csv('kinetics_400_labels.csv')
labels.head()

Unnamed: 0,id,name
0,0,abseiling
1,1,air drumming
2,2,answering questions
3,3,applauding
4,4,applying cream


In [7]:
id2label = dict(zip(labels.id.tolist(), labels.name.tolist()))
label2id = dict(zip(labels.name.tolist(), labels.id.tolist()))

In [8]:
def process_data(text_file_path, data_folder_path, n=10):
    video_paths = []
    labels = []
    string_labels = []

    # Get all video filenames from the data folder
    all_files_in_data_folder = [
        f for f in os.listdir(data_folder_path) 
        if os.path.isfile(os.path.join(data_folder_path, f))
    ]
    with open(text_file_path, 'r') as f:
        for line in tqdm(f):
            parts = line.strip().split()
            if len(parts) == 2:
                filename, label = parts
                search_string = filename[-n:]
                matching_file = next(
                    (
                        f for f in all_files_in_data_folder 
                        if f.endswith(search_string)
                    ), None
                )
                if matching_file:
                    abs_path = os.path.join(data_folder_path, matching_file)
                    video_paths.append(abs_path)
                    labels.append(int(label))
                    string_labels.append(id2label[int(label)])
                    
    df = pd.DataFrame({
        'video_path': video_paths,
        'label': labels,
        'string_label': string_labels
    })
    return df

# Example usage:
text_file_path = "/kaggle/input/k4testset/kinetics400_val_list_videos.txt"
data_folder_path = "/kaggle/input/k4testset/videos_val"
df = process_data(text_file_path, data_folder_path)
print(df.shape)
df.head()

19796it [00:30, 659.60it/s]

(19796, 3)





Unnamed: 0,video_path,label,string_label
0,/kaggle/input/k4testset/videos_val/jf7RDuUTrsQ...,325,somersaulting
1,/kaggle/input/k4testset/videos_val/JTlatknwOrY...,233,playing harmonica
2,/kaggle/input/k4testset/videos_val/8UxlDNur-Z0...,262,pushing cart
3,/kaggle/input/k4testset/videos_val/y9r115bgfNk...,320,sniffing
4,/kaggle/input/k4testset/videos_val/ZnIDviwA8CE...,244,playing saxophone


# Data Loader

To build the dataloader, we will be using [mmaction](https://mmaction2.readthedocs.io/en/latest/index.html) recipe. 

In [9]:
class VideoInit:
    def transform(self, results):
        container = VideoReader(results['filename'])
        results['total_frames'] = len(container)
        results['video_reader'] = container
        results['avg_fps'] = container.get_avg_fps()
        results['start_index'] = 0
        return results

In [10]:
class VideoSample:
    """Sample frames from the video.

    Required keys are "total_frames", "start_index" , added or modified keys
    are "frame_inds", "frame_interval" and "num_clips".

    Args:
        clip_len (int): Frames of each sampled output clip.
        frame_interval (int): Temporal interval of adjacent sampled frames.
            Default: 1.
        num_clips (int): Number of clips to be sampled. Default: 1.
        temporal_jitter (bool): Whether to apply temporal jittering.
            Default: False.
        twice_sample (bool): Whether to use twice sample when testing.
            If set to True, it will sample frames with and without fixed shift,
            which is commonly used for testing in TSM model. Default: False.
        out_of_bound_opt (str): The way to deal with out of bounds frame
            indexes. Available options are 'loop', 'repeat_last'.
            Default: 'loop'.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        start_index (None): This argument is deprecated and moved to dataset
            class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),
            see this: https://github.com/open-mmlab/mmaction2/pull/89.
    """

    def __init__(
        self,
        clip_len,
        frame_interval=1,
        num_clips=1,
        temporal_jitter=False,
        twice_sample=False,
        out_of_bound_opt='loop',
        test_mode=False,
        start_index=None,
        frame_uniform=False,
        multiview=1
    ):
        self.clip_len = clip_len
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.temporal_jitter = temporal_jitter
        self.twice_sample = twice_sample
        self.out_of_bound_opt = out_of_bound_opt
        self.test_mode = test_mode
        self.frame_uniform = frame_uniform
        self.multiview=multiview
        assert self.out_of_bound_opt in ['loop', 'repeat_last']

        if start_index is not None:
            warnings.warn(
                'No longer support "start_index" in "SampleFrames", '
                'it should be set in dataset class, see this pr: '
                'https://github.com/open-mmlab/mmaction2/pull/89'
            )

    def _get_train_clips(self, num_frames):
        """Get clip offsets in train mode.

        It will calculate the average interval for selected frames,
        and randomly shift them within offsets between [0, avg_interval].
        If the total number of frames is smaller than clips num or origin
        frames length, it will return all zero indices.

        Args:
            num_frames (int): Total number of frame in the video.

        Returns:
            np.ndarray: Sampled frame indices in train mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips

        if avg_interval > 0:
            base_offsets = np.arange(self.num_clips) * avg_interval
            clip_offsets = base_offsets + np.random.randint(
                avg_interval, size=self.num_clips)
        elif num_frames > max(self.num_clips, ori_clip_len):
            clip_offsets = np.sort(
                np.random.randint(
                    num_frames - ori_clip_len + 1, size=self.num_clips))
        elif avg_interval == 0:
            ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
            clip_offsets = np.around(np.arange(self.num_clips) * ratio)
        else:
            clip_offsets = np.zeros((self.num_clips, ), dtype=np.int32)

        return clip_offsets

    def _get_test_clips(self, num_frames):
        """Get clip offsets in test mode.

        Calculate the average interval for selected frames, and shift them
        fixedly by avg_interval/2. If set twice_sample True, it will sample
        frames together without fixed shift. If the total number of frames is
        not enough, it will return all zero indices.

        Args:
            num_frames (int): Total number of frame in the video.

        Returns:
            np.ndarray: Sampled frame indices in test mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)
        if num_frames > ori_clip_len - 1:
            base_offsets = np.arange(self.num_clips) * avg_interval
            clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32)
            if self.twice_sample:
                clip_offsets = np.concatenate([clip_offsets, base_offsets])
        else:
            clip_offsets = np.zeros((self.num_clips, ), dtype=np.int32)
        return clip_offsets

    def _sample_clips(self, num_frames):
        """Choose clip offsets for the video in a given mode.

        Args:
            num_frames (int): Total number of frame in the video.

        Returns:
            np.ndarray: Sampled frame indices.
        """
        if self.test_mode:
            clip_offsets = self._get_test_clips(num_frames)
        else:
            if self.multiview == 1:
                clip_offsets = self._get_train_clips(num_frames)
            else:
                clip_offsets = np.concatenate(
                    [
                        self._get_train_clips(num_frames)  
                        for _ in range(self.multiview)
                    ]
                )
        return clip_offsets

    def get_seq_frames(self, num_frames):
        seg_size = float(num_frames - 1) / self.clip_len
        seq = []
        for i in range(self.clip_len):
            start = int(np.round(seg_size * i))
            end = int(np.round(seg_size * (i + 1)))
            if not self.test_mode:
                seq.append(random.randint(start, end))
            else:
                seq.append((start + end) // 2)

        return np.array(seq)

    def transform(self, results):
        """Perform the SampleFrames loading.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        total_frames = results['total_frames']
        if self.frame_uniform:  # sthv2 sampling strategy
            assert results['start_index'] == 0
            frame_inds = self.get_seq_frames(total_frames)
        else:
            clip_offsets = self._sample_clips(total_frames)
            frame_inds = clip_offsets[:, None] + np.arange(
                self.clip_len)[None, :] * self.frame_interval
            frame_inds = np.concatenate(frame_inds)

            if self.temporal_jitter:
                perframe_offsets = np.random.randint(
                    self.frame_interval, size=len(frame_inds))
                frame_inds += perframe_offsets

            frame_inds = frame_inds.reshape((-1, self.clip_len))
            if self.out_of_bound_opt == 'loop':
                frame_inds = np.mod(frame_inds, total_frames)
            elif self.out_of_bound_opt == 'repeat_last':
                safe_inds = frame_inds < total_frames
                unsafe_inds = 1 - safe_inds
                last_ind = np.max(safe_inds * frame_inds, axis=1)
                new_inds = (safe_inds * frame_inds + (unsafe_inds.T * last_ind).T)
                frame_inds = new_inds
            else:
                raise ValueError('Illegal out_of_bound option.')

            start_index = results['start_index']
            frame_inds = np.concatenate(frame_inds) + start_index

        results['frame_inds'] = frame_inds.astype(np.int32)
        results['clip_len'] = self.clip_len
        results['frame_interval'] = self.frame_interval
        results['num_clips'] = self.num_clips
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'clip_len={self.clip_len}, '
                    f'frame_interval={self.frame_interval}, '
                    f'num_clips={self.num_clips}, '
                    f'temporal_jitter={self.temporal_jitter}, '
                    f'twice_sample={self.twice_sample}, '
                    f'out_of_bound_opt={self.out_of_bound_opt}, '
                    f'test_mode={self.test_mode})')
        return repr_str

In [11]:
class VideoDecode:
    def transform(self, results):
        frame_inds = results['frame_inds']
        container = results['video_reader']
        imgs = container.get_batch(frame_inds).asnumpy()
        imgs = list(imgs)
        results['video_reader'] = None
        del container
        results['imgs'] = imgs
        results['img_shape'] = imgs[0].shape[:2]
        return results

In [12]:
def _scale_size(
    size,
    scale,
):
    if isinstance(scale, (float, int)):
        scale = (scale, scale)
    w, h = size
    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)

def rescale_size(
    old_size,
    scale,
    return_scale=False
):
    w, h = old_size
    if isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f'Invalid scale {scale}, must be positive.')
        scale_factor = scale
    elif isinstance(scale, tuple):
        max_long_edge = max(scale)
        max_short_edge = min(scale)
        scale_factor = min(
            max_long_edge / max(h, w),
            max_short_edge / min(h, w)
        )
    else:
        raise TypeError(
            f'Scale must be a number or tuple of int, but got {type(scale)}'
        )

    new_size = _scale_size((w, h), scale_factor)

    if return_scale:
        return new_size, scale_factor
    else:
        return new_size

class VideoResize:
    def __init__(self, r_size):
        self.r_size = (np.inf, r_size)

    def transform(self, results):
        img_h, img_w = results['img_shape']
        new_w, new_h = rescale_size((img_w, img_h), self.r_size)

        imgs = [
            cv2.resize(img, (new_w, new_h))
            for img in results['imgs']
        ]
        results['imgs'] = imgs
        results['img_shape'] = imgs[0].shape[:2]
        return results

In [13]:
class VideoCrop:
    def __init__(self, c_size):
        self.c_size = c_size

    def transform(self, results):
        img_h, img_w = results['img_shape']
        center_x, center_y = img_w // 2, img_h // 2
        x1, x2 = center_x - self.c_size // 2, center_x + self.c_size // 2
        y1, y2 = center_y - self.c_size // 2, center_y + self.c_size // 2
        imgs = [img[y1:y2, x1:x2] for img in results['imgs']]
        results['imgs'] = imgs
        results['img_shape'] = imgs[0].shape[:2]
        return results

In [14]:
class VideoFormat:
    def transform(self, results):
        num_clips = results['num_clips']
        clip_len = results['clip_len']
        imgs = results['imgs']
        # [num_clips*clip_len, H, W, C]
        imgs = np.array(imgs)
        # [num_clips, clip_len, H, W, C]
        imgs = imgs.reshape((num_clips, clip_len) + imgs.shape[1:])
        results['imgs'] = imgs
        return results

In [15]:
item = dict()
item['filename'] = '/kaggle/input/k4testset/videos_val/--07WQ2iBlw.mp4'
v_init   = VideoInit().transform(item)
v_sample = VideoSample(clip_len=16, num_clips=1, test_mode=True).transform(v_init)
v_decode = VideoDecode().transform(v_sample)
v_resize = VideoResize(r_size=256).transform(v_decode)
v_crop   = VideoCrop(c_size=224).transform(v_resize)
v_format = VideoFormat().transform(v_crop)
print(v_format.keys())
v_format['imgs'].shape

dict_keys(['filename', 'total_frames', 'video_reader', 'avg_fps', 'start_index', 'frame_inds', 'clip_len', 'frame_interval', 'num_clips', 'imgs', 'img_shape'])


(1, 16, 224, 224, 3)

[inference-config-video-swin](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/db018fb8896251711791386bbd2127562fd8d6a6/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py#L45-L61)

```python
test_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=32,
        frame_interval=2,
        num_clips=4,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='ThreeCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
```

We will skip `ThreeCrop` and `Flip` at the moment.

In [16]:
num_classes=400
batch_size=16
num_clips=4
frame_rate=2 
input_frame=32
h_crop_size=w_crop_size=224

In [17]:
class VideoDataset(Dataset):
    def __init__(self, dataframe, clip_len=1, frame_sample_rate=8):
        self.dataframe = dataframe
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate

    def __len__(self):
        return len(self.dataframe)

    def get_frames(self, video_path):
        item = dict()
        item['filename'] = video_path
        v_init   = VideoInit().transform(item)
        v_sample = VideoSample(
            clip_len=input_frame, 
            num_clips=num_clips, 
            frame_interval=frame_rate,
            test_mode=True
        ).transform(v_init)
        v_decode = VideoDecode().transform(v_sample)
        v_resize = VideoResize(r_size=256).transform(v_decode)
        v_crop   = VideoCrop(c_size=224).transform(v_resize)
        v_format = VideoFormat().transform(v_crop)
        frames = v_format['imgs']
        return frames

    def __getitem__(self, idx):
        video_path = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        video = self.get_frames(video_path)
        return torch.tensor(video).to(torch.float32), torch.tensor(label).to(torch.float32)

In [18]:
dataset = VideoDataset(
    dataframe=df,  
    clip_len=num_clips, frame_sample_rate=frame_rate
)
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    pin_memory=True, 
)

In [19]:
for i, (videos, labels) in enumerate(dataloader):
    print(videos.shape, labels.shape)
    if i == 2 :
        break

torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])
torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])
torch.Size([16, 4, 32, 224, 224, 3]) torch.Size([16])


# Model

In [20]:
!wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q

def vswin_tiny():
    backbone=VideoSwinBackbone(
        input_shape=(32, 224, 224, 3), 
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        include_rescaling=True, 
    )
    keras_model = VideoClassifier(
        backbone=backbone,
        num_classes=num_classes,
        activation=None,
        pooling='avg',
    )
    keras_model.load_weights(
        'videoswin_tiny_kinetics400_classifier.weights.h5'
    )
    return keras_model

In [21]:
model = vswin_tiny()
model.summary()

# Training API

In [22]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [23]:
if torch.cuda.is_available():
    model.cuda().eval()
else:
    model.eval()

In [24]:
acc1_meter, acc5_meter = AverageMeter(), AverageMeter()
log_print_freq = 50

In [25]:
pbar = tqdm(enumerate(dataloader), total=len(dataloader), ncols=80, leave=True)

with torch.no_grad():
    for idx, (image, label) in pbar:
        label_id = label
        label_id = label_id.reshape(-1)
        
        b, n, t, h, w, c = image.size() # batch, clip, time-dim, channel, height, width
        tot_similarity = torch.zeros((b,num_classes)).cuda()
        
        for i in range(n):
            image_input = image[:, i, :, :, :, :] # [b,t,h,w,c]
            label_id = label_id.cuda(non_blocking=True)
            image_input = image_input.cuda(non_blocking=True)
            output = model(image_input)
            similarity = output.view(b, -1).softmax(dim=-1)
            tot_similarity += similarity
            
        values_1, indices_1 = tot_similarity.topk(1, dim=-1)
        values_5, indices_5 = tot_similarity.topk(5, dim=-1)
        acc1, acc5 = 0, 0
        
        for i in range(b):
            if indices_1[i] == label_id[i]:
                acc1 += 1
            if label_id[i] in indices_5[i]:
                acc5 += 1
                
        acc1_meter.update(float(acc1) / b * 100, b)
        acc5_meter.update(float(acc5) / b * 100, b)
        
        if idx % log_print_freq == 0:
            pbar.set_postfix(
                Acc1=f"{acc1_meter.avg:.3f}", Acc5=f"{acc5_meter.avg:.3f}"
            )

100%|███████████| 1238/1238 [3:59:55<00:00, 11.63s/it, Acc1=77.690, Acc5=93.297]
