In [1]:
import torch.nn as nn
import torchvision.models.video as models
import albumentations as alb
import torch
import numpy as np

# Compute device, pinned to the CPU. To select CUDA automatically when it is
# available, use the commented-out line instead:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cpu")


class VideoEncoder(nn.Module):
    """Encode the first clip of a video into a fixed-size feature vector.

    A pretrained R3D-18 backbone turns the first ``clip_len`` frames of a
    video into a feature vector; the backbone's replaced ``fc`` layer followed
    by batch normalisation projects it to ``encoding_size`` dimensions.

    Args:
        encoding_size: Dimensionality of the returned encoding.
        clip_len: Number of leading frames that make up the encoded clip.

    Raises:
        ValueError: If ``clip_len`` is smaller than 1.

    NOTE(review): ``forward`` always produces a batch of one clip, and
    ``BatchNorm1d`` rejects a single sample per channel in training mode, so
    the module presumably has to be in ``eval()`` mode when called -- confirm
    against the training code.
    """

    def __init__(self, encoding_size: int, clip_len: int = 25) -> None:
        super().__init__()
        if clip_len < 1:
            raise ValueError(f"clip_len must be at least 1, got {clip_len}")
        self.clip_len = clip_len

        self.base_network = models.r3d_18(pretrained=True)
        # Replace the pretrained classifier head with a projection to the
        # requested encoding size.
        self.base_network.fc = nn.Linear(
            self.base_network.fc.in_features, encoding_size
        )
        self.bn = nn.BatchNorm1d(encoding_size, momentum=0.01)

        # Per-frame preprocessing, built once instead of on every forward call:
        # resize, centre-crop to 112x112, then per-channel normalisation.
        self._transform = alb.Compose([
            alb.Resize(height=128, width=171, p=1.0),
            alb.CenterCrop(height=112, width=112, p=1.0),
            alb.Normalize(
                mean=[0.43216, 0.394666, 0.37645],
                std=[0.22803, 0.22145, 0.216989],
                p=1.0,
            ),
        ])

        self.init_weights()

    def init_weights(self) -> None:
        """Initialise the projection layer: N(0, 0.02) weights, zero bias."""
        nn.init.normal_(self.base_network.fc.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.base_network.fc.bias)

    def _backbone_features(self, clip):
        """Run every backbone stage except ``fc`` without tracking gradients."""
        net = self.base_network
        with torch.no_grad():
            x = net.stem(clip)
            for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
                x = stage(x)
            x = net.avgpool(x)
            return torch.flatten(x, 1)

    def forward(self, vidFile):
        """Encode the first ``clip_len`` frames read from ``vidFile``.

        Args:
            vidFile: An opened capture object exposing ``isOpened()`` and
                ``read()`` in the style of ``cv2.VideoCapture``; ``read()``
                must return ``(ok, frame)`` with ``frame`` an H x W x 3 array
                in BGR channel order.

        Returns:
            A tensor of shape ``(1, encoding_size)``, or ``None`` when the
            capture is not opened or ends before ``clip_len`` frames were
            read. Frames after the first clip are not consumed.
        """
        if not vidFile.isOpened():
            print('Error while trying to read video. Please check path again')
            return None

        frames = []
        while len(frames) < self.clip_len:
            ret, frame = vidFile.read()
            if not ret:
                # A failed read means no more frames will arrive; bail out
                # instead of polling the capture forever.
                print('Video ended before a full clip could be read')
                return None
            # Reverse the channel axis (BGR -> RGB); the copy makes the array
            # contiguous again for the image transforms.
            rgb = np.ascontiguousarray(frame[:, :, ::-1])
            frames.append(self._transform(image=rgb)['image'])

        # (T, H, W, C) -> (1, C, T, H, W), the layout the 3D convolutions take.
        clip = np.stack(frames).transpose(3, 0, 1, 2)[np.newaxis]
        clip = torch.tensor(
            clip,
            dtype=torch.float32,
            device=self.base_network.fc.weight.device,
        )

        # The backbone is evaluated without gradients; the projection and the
        # batch norm are applied exactly once, outside of no_grad.
        features = self._backbone_features(clip)
        return self.bn(self.base_network.fc(features))