In [None]:
import os
import numpy as np
from PIL import Image
from torch.utils import data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm


class VideoEncoder(nn.Module):
    """Encode a clip frame by frame with a ResNet-152 backbone.

    Each frame goes through the ResNet trunk (run without gradient
    tracking), then through a replacement ``fc`` layer and a 1-D batch
    norm. The per-frame encodings are concatenated along the feature
    axis, so the result has one row per sample.
    """

    def __init__(self, encoding_size: int):
        """Build the pretrained backbone and swap in a new ``fc`` head.

        Args:
            encoding_size: Number of output features of the new ``fc``
                layer, i.e. the size of each per-frame encoding.
        """
        super().__init__()

        backbone = models.resnet152(pretrained=True)
        # Replace the original classifier with a projection to encoding_size.
        backbone.fc = nn.Linear(backbone.fc.in_features, encoding_size)
        self.base_network = backbone
        self.bn = nn.BatchNorm1d(encoding_size, momentum=0.01)
        self.init_weights()

    def init_weights(self):
        """Initialise the new ``fc`` head: N(0, 0.02) weights, zero bias."""
        head = self.base_network.fc
        head.weight.data.normal_(0.0, 0.02)
        head.bias.data.fill_(0)

    def forward(self, x_3d):
        """Encode every time step of ``x_3d`` and flatten per sample.

        Args:
            x_3d: 5-D tensor whose dim 0 is the sample axis and dim 1 is
                the time axis; the remaining three dims are what the
                backbone's first convolution consumes for one frame.

        Returns:
            Tensor of shape ``(samples, time_steps * encoding_size)``,
            with the encodings of consecutive time steps laid out one
            after another for each sample.
        """
        backbone = self.base_network
        # Trunk stages in the order they are applied; ``fc`` is kept out
        # of this tuple because it must run with gradients enabled.
        trunk = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            backbone.avgpool,
        )

        frame_encodings = []
        for step in range(x_3d.size(1)):
            # No autograd graph is recorded for the trunk, so only ``fc``
            # and ``bn`` receive gradients.
            # NOTE(review): no_grad does not put the trunk's BatchNorm
            # layers in eval mode; in training mode their running stats
            # presumably still update -- confirm this is intended.
            with torch.no_grad():
                features = x_3d[:, step, :, :, :]
                for stage in trunk:
                    features = stage(features)
                features = torch.flatten(features, 1)

            frame_encodings.append(self.bn(backbone.fc(features)))

        # Stack on dim 1 to get (sample, time, encoding), then merge the
        # time and encoding axes into a single feature axis.
        stacked = torch.stack(frame_encodings, dim=1)
        return torch.flatten(stacked, 1)