In [1]:
import base64
import io
import os
import time
import cv2
import numpy as np
from PIL import Image

import torch

The Encoder 
============
This part simply encodes each of the frame to a 512*1 vector

In [2]:
import torchvision

resnetbase = torchvision.models.resnet50(pretrained=True, progress=True)

In [3]:
#The paper uses conv4_block3 which should be layer3, bottleneck 2 layer for us
import torch.nn as nn
import torch.nn.functional as F

class ResNet50Bottom(nn.Module):
    def __init__(self, original_model):
        super(ResNet50Bottom, self).__init__()
        self.rnet=nn.Sequential(*list(original_model.children())[:-4])
        self.left=nn.Sequential(*list(original_model.children())[-4][:3])
        
    def forward(self, x):
        x = self.rnet(x)
        x = self.left(x)
        return x

resencode = ResNet50Bottom(resnetbase)
src = torch.randn(10, 3, 100, 100)
out = resencode(src)
print(out.shape)

torch.Size([10, 1024, 7, 7])


In [4]:
import math
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)



In [5]:
def pairwise_l2_distance(a, b):
  """Computes pairwise distances between all rows of a and b."""
  norm_a = torch.sum(torch.square(a), 1)
  norm_a = torch.reshape(norm_a, (-1, 1))
  norm_b = torch.sum(torch.square(b), 1)
  norm_b = torch.reshape(norm_b, (1, -1))
  dist = torch.maximum(norm_a
                      - 2.0 * torch.matmul(a,torch.transpose(b,0,1))                          + norm_b, torch.tensor(0.0))
  return dist

#returns 1*num_frame*num_frame
def _get_sims(embs):
    """Calculates self-similarity between sequence of embeddings."""
    dist = pairwise_l2_distance(embs, embs)
    sims = -1.0 * dist
    sims = sims.unsqueeze(0)
    return sims

def get_sims(embs, temperature = 13.544):
    batch_size = embs.shape[0]
    seq_len = embs.shape[1]
    embs = torch.reshape(embs, (batch_size, seq_len, -1))

    simsarr=[]
    for i in range(batch_size):
        simsarr.append(_get_sims(embs[i,:,:]).unsqueeze(0))
    
    sims = torch.vstack(simsarr)
    sims /= temperature
    sims = F.softmax(sims, dim=-1)
    return sims
        

In [6]:
encoder = nn.TransformerEncoderLayer(d_model=512, nhead=8)
src = torch.rand(10,100,512)
pos_encoder = PositionalEncoding(512, 0.1)
out = pos_encoder(src)
print(out.shape)
out = encoder(src)
print(out.shape)

torch.Size([10, 100, 512])
torch.Size([10, 100, 512])


In [7]:
#after this we will have an output with 1026 channels * height * width * frames, this is a 4D output. each 3-D object represents one frames encodings. Now we will use a 3-D convolution to encode relations between adjacent frames.

class RepNet(nn.Module):
    def __init__(self, num_frames):
        super(RepNet, self).__init__()
        self.num_frames = num_frames
        self.resnetBase = resencode

        self.Conv3D = nn.Conv3d(in_channels = 1024,
                                out_channels = 512,
                                kernel_size = 3,
                                padding = 3,
                                dilation = 3)
        self.bn1 = nn.BatchNorm3d(512)
        #get_sims
        
        self.conv3x3_1 = nn.Conv2d(in_channels = 1,
                                 out_channels = 32,
                                 kernel_size = 3,
                                 padding = 1)
        #relu after this, now we have batch * num_frames * num_frames*32
        #flatten with dim batch * 32*num_frames * num_frames

        self.conv3x3_2 = nn.Conv2d(in_channels = 32,
                                 out_channels = 512,
                                 kernel_size = 3,
                                 padding = 1)
        self.pos_encoder = PositionalEncoding(512, 0.1)
        trans_encoder_layer = nn.TransformerEncoderLayer(d_model = 512,
                                                nhead = 4,
                                                dim_feedforward = 512,
                                                dropout = 0.1,
                                                activation = 'relu')
        self.trans_encoder=nn.TransformerEncoder(trans_encoder_layer,5)
        #what is the output shape of transformer
        #same as the input, exactly
        #fc layers

        #period prediction
        self.fc1_1 = nn.Linear(self.num_frames*512, 512)
        self.fc1_2 = nn.Linear(512, 1)

        
        #within period module
        self.fc2_1 = nn.Linear(self.num_frames*512, 512)
        self.fc2_2 = nn.Linear(512, 1)

    def forward(self, x):
        batch_size = x.shape[0]
        x = torch.reshape(x, (-1, 3, x.shape[3], x.shape[4]))
        x = self.resnetBase(x)
        x = torch.reshape(x, 
                    (batch_size,-1,x.shape[1],x.shape[2],x.shape[3]))
        x = torch.transpose(x, 1, 2)
        #batch*channel*frame*height*width
        x = self.Conv3D(x)
        x = self.bn1(x)
        x,_ = torch.max(x, 4)
        x,_ = torch.max(x, 3)
        
        final_embs = x
        
        x = torch.transpose(x, 1, 2)
        x = get_sims(x)
        x = F.relu(self.conv3x3_1(x))
        x = F.relu(self.conv3x3_2(x))
        x = torch.reshape(x, (batch_size, 512, -1))
        x = torch.transpose(x, 1, 2)
        x = self.pos_encoder(x)
        x = self.trans_encoder(x)
        x = torch.reshape(x, (x.shape[0],self.num_frames, -1))
        
        print(x.shape)

        y1 = F.relu(self.fc1_1(x))
        y1 = F.relu(self.fc1_2(y1))

        y2 = F.relu(self.fc2_1(x))
        y2 = F.relu(self.fc2_2(y2))
        return y1, y2, final_embs

print("done")

done


In [8]:
#batch_size*num_frame*c*h*w

src = torch.randn(2,64, 3, 112, 112)    
model = RepNet(64)
o1, o2, e = model(src)
print(o1.shape)
print(o2.shape)
print(e.shape)

torch.Size([2, 64, 32768])
torch.Size([2, 64, 1])
torch.Size([2, 64, 1])
torch.Size([2, 512, 64])
