# VideoTransformer

[TimeSformer](https://arxiv.org/abs/2102.05095), [ViViT](https://arxiv.org/abs/2103.15691)

Welcome to the demo notebook for VideoTransformer. We'll showcase the prediction results of the pre-trained models listed above.

## Preliminaries

This section contains initial setup. Run it first.

In [None]:
!pip3 install --user torch
!pip3 install --user torchvision
!pip3 install --user matplotlib
!pip3 install --user decord
!pip3 install --user einops
!pip3 install --user scikit-image
!pip3 install --user pytorch-lightning

In [1]:
import torch
import torch.nn as nn
import numpy as np

from einops import rearrange, reduce, repeat
from IPython.display import display


In [2]:
%cd C:\Users\kenny\Desktop\Poly\22_23_sem2\project\proposed\vivit_testing\VideoTransformer-pytorch

C:\Users\kenny\Desktop\Poly\22_23_sem2\project\proposed\vivit_testing\VideoTransformer-pytorch


In [25]:
import data_transform as T
from dataset import DecordInit, load_annotation_data
from transformer import PatchEmbed, TransformerContainer, ClassificationHead
from video_transformer import ViViT

### Note
Please first download the weights and move them to the current path `./VideoTransformer-pytorch/`:
1. TimeSformer-B pre-trained on K400: https://drive.google.com/file/d/1jLkS24jkpmakPi3e5J8KH3FOPv370zvo/view?usp=sharing
2. ViViT-B pre-trained on K400: https://drive.google.com/file/d/1-JVhSN3QHKUOLkXLWXWn5drdvKn0gPll/view?usp=sharing

## Video Transformer Model

Here we load the pre-trained weights of the transformer model, either TimeSformer-B or ViViT-B.

In [4]:
…

In [5]:
def replace_state_dict(state_dict):
    """Strip the leading ``model.`` / ``cls_head.`` prefix from checkpoint keys, in place.

    Only the first matching prefix is removed, so ``'cls_head.cls_head.weight'``
    becomes ``'cls_head.weight'`` and ``'model.x'`` becomes ``'x'``. Keys that
    carry neither prefix are left unchanged instead of being blindly truncated.

    Args:
        state_dict: mutable mapping of parameter names to values; modified in place.

    Returns:
        None.
    """
    prefixes = ('model.', 'cls_head.')
    for old_key in list(state_dict.keys()):
        new_key = old_key
        for prefix in prefixes:
            if old_key.startswith(prefix):
                new_key = old_key[len(prefix):]
                break
        # Pop and re-insert every key (renamed or not) so the relative order
        # of the entries stays the same as in the input mapping.
        state_dict[new_key] = state_dict.pop(old_key)

In [6]:
def init_from_pretrain_(module, pretrained, init_module):
    """Load weights from the checkpoint ``pretrained`` into ``module``.

    Args:
        module: object exposing ``load_state_dict`` that receives the weights.
        pretrained: checkpoint location, passed straight to ``torch.load``.
        init_module: which part is being initialised; must be ``'transformer'``
            or ``'cls_head'``.

    Returns:
        Whatever ``module.load_state_dict(..., strict=False)`` returns (the
        missing / unexpected key report).

    Raises:
        TypeError: if ``init_module`` is not one of the supported names.
    """
    # Validate first so a bad name fails immediately rather than after the
    # checkpoint has been read from disk.
    if init_module not in ('transformer', 'cls_head'):
        raise TypeError(f'pretrained weights do not include the {init_module} module')

    # With CUDA available keep torch.load's default device placement;
    # otherwise remap every tensor to the CPU.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(pretrained, map_location=map_location)

    # Both supported module kinds use the same key clean-up.
    replace_state_dict(state_dict)

    # Non-strict load: keys that do not belong to `module` are reported in the
    # returned message instead of raising.
    return module.load_state_dict(state_dict, strict=False)

In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_frames = 8
frame_interval = 32
num_class = 400
arch = 'vivit'  # one of 'timesformer' or 'vivit'

if arch == 'timesformer':
    # Only ViViT is imported in the notebook's import cell, so bring
    # TimeSformer into scope here to avoid a NameError on this branch.
    # NOTE(review): assumes video_transformer exports TimeSformer alongside ViViT; confirm.
    from video_transformer import TimeSformer

    pretrain_pth = './timesformer_k400.pth'
    model = TimeSformer(num_frames=num_frames,
                        img_size=224,
                        patch_size=16,
                        embed_dims=768,
                        in_channels=3,
                        attention_type='divided_space_time',
                        return_cls_token=True)
    # The constructor call above is not given the checkpoint, so load the
    # transformer weights explicitly; otherwise the model stays uninitialised.
    msg_trans = init_from_pretrain_(model, pretrain_pth, init_module='transformer')
elif arch == 'vivit':
    pretrain_pth = './vivit_model.pth'
    # Twice as many frames at half the interval: the sampled window
    # (num_frames * frame_interval) stays the same length.
    num_frames = num_frames * 2
    frame_interval = frame_interval // 2
    # The checkpoint path is handed to the constructor.
    # NOTE(review): presumably ViViT loads the Kinetics weights itself, which
    # is why no explicit transformer load follows; confirm in video_transformer.
    model = ViViT(num_frames=num_frames,
                  img_size=224,
                  patch_size=16,
                  embed_dims=768,
                  in_channels=3,
                  attention_type='fact_encoder',
                  return_cls_token=True,
                  weights_from="kinetics",
                  pretrain_pth=pretrain_pth)
else:
    raise TypeError(f'not supported arch type {arch}, chosen in (timesformer, vivit)')

# The classification head is always initialised from the same checkpoint file.
cls_head = ClassificationHead(num_classes=num_class, in_channels=768)
msg_cls = init_from_pretrain_(cls_head, pretrain_pth, init_module='cls_head')

# Inference only: switch to eval mode and move both modules to the target device.
model.eval()
cls_head.eval()
model = model.to(device)
cls_head = cls_head.to(device)
print(f'load model finished, the missing key of cls is:{msg_cls[0]}')

_IncompatibleKeys(missing_keys=[], unexpected_keys=['cls_head.weight', 'cls_head.bias'])
load model finished, the missing key of cls is:[]


## Data preprocess

Here we display the demo video and transform it into the input format expected by the model.

In [30]:
from IPython.display import display, HTML

# Location of the demo clip shown below.
video_path = './demo/YABnJL_bDzw.mp4'

# Build a small HTML5 player for the clip and render it inline.
video_tag = '<video controls width="480" height="480" src="{}">animation</video>'
html_str = '\n' + video_tag.format(video_path) + '\n'
display(HTML(html_str))

In [31]:
# Prepare the spatial preprocessing pipeline and the temporal window sampler.
mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)
data_transform = T.Compose([
        T.Resize(scale_range=(-1, 256)),
        T.ThreeCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean, std)
        ])
# The sampled window spans num_frames * frame_interval raw frames.
temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval)

# Sampling video frames
video_decoder = DecordInit()
v_reader = video_decoder(video_path)
total_frames = len(v_reader)
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
if end_frame_ind-start_frame_ind < num_frames:
    raise ValueError(f'the total frames of the video {video_path} is less than {num_frames}')
# Spread num_frames indices evenly over the sampled window [start, end).
# The offset start_frame_ind must be included; otherwise the frames are
# always taken from the beginning of the video regardless of the crop.
# NOTE(review): assumes temporal_sample returns bounds with end_frame_ind <= total_frames; confirm.
frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num_frames, dtype=int)
video = v_reader.get_batch(frame_indice).asnumpy()
del v_reader

display(video.shape)
video = torch.from_numpy(video).permute(0,3,1,2) # channels-last frames -> T C H W
data_transform.randomize_parameters()
video = data_transform(video)
display(video.shape)
video = video.to(device)

# The extra leading dimension of size 3 comes from the ThreeCrop transformation.

(16, 256, 454, 3)

torch.Size([3, 16, 3, 224, 224])

## Video Classification

Here we use the pre-trained video transformer to classify the input video.

In [32]:
# Predict class label
with torch.no_grad():
    # Output of the transformer for the clip; fed to the classification head.
    logits = model(video)
    display(logits.shape)
    output = cls_head(logits)
    display(output.shape)
    # Average the per-crop class scores. Using num_class instead of the
    # literal (3, 400) keeps this in sync with the model configuration and
    # works for any number of crops.
    output = output.view(-1, num_class).mean(0)

    cls_pred = output.argmax().item()

# Map the predicted index back to its class name.
class_map_path = './k400_classmap.json'
class_map = load_annotation_data(class_map_path)
for key, value in class_map.items():
    if int(value) == int(cls_pred):
        print(f'the shape of ouptut: {output.shape}, and the prediction is: {key}')
        break
else:
    # Make a missing label visible instead of finishing silently.
    print(f'no class name found in {class_map_path} for predicted index {cls_pred}')

torch.Size([3, 768])

torch.Size([3, 400])

the shape of ouptut: torch.Size([400]), and the prediction is: laughing
