In [1]:
# Automatically re-import edited project modules (densetcn, resnet, preprocess,
# feature) before each cell runs, so changes to those .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import math

from collections import OrderedDict

# autoreload is enabled once, in the first cell (loading it again here only
# printed an "already loaded" notice).

from densetcn import *
from resnet import *
from preprocess import *
from feature import *

  from .autonotebook import tqdm as notebook_tqdm


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
class Lipreading(nn.Module):
    # configs/lrw_resnet18_dctcn_boundary.json
    """Visual feature extractor of the LRW ResNet-18 + DC-TCN lipreading model.

    Only the front part of the model is built: a 3D convolutional stem
    (``frontend3D``) followed by a ResNet trunk (``trunk``). The DenseTCN
    backend is deliberately NOT instantiated, so ``forward`` always returns
    per-frame features and the module's ``state_dict`` only contains
    ``frontend3D.*`` and ``trunk.*`` entries.

    Args:
        modality: Stored on the instance; not otherwise used here.
        hidden_dim: Accepted for config compatibility; unused.
        backbone_type: Stored on the instance. A ResNet trunk is always built;
            ``'shufflenet'`` is rejected because no ShuffleNet trunk exists.
        num_classes: Accepted for config compatibility; unused (no classifier).
        relu_type: Activation of the 3D stem, also passed to ``ResNet``.
            One of ``'relu'``, ``'prelu'``, ``'swish'``.
        tcn_options: Accepted for config compatibility; unused.
        densetcn_options: DenseTCN settings, stored (as a copy) on
            ``self.densetcn_options`` for reference. ``None`` selects the
            settings of the boundary config.
        width_mult: Accepted for config compatibility; unused.
        use_boundary: If True, ``forward`` concatenates ``boundaries`` onto
            the trunk features along the last dimension.
        extract_feats: Stored on the instance; currently has no effect because
            the backend is not built.

    Raises:
        ValueError: If ``relu_type`` is not one of the supported activations.
        NotImplementedError: If ``backbone_type == 'shufflenet'``.
    """

    # Activations accepted for the 3D front-end stem.
    _FRONTEND_RELU_TYPES = ('relu', 'prelu', 'swish')

    @staticmethod
    def _default_densetcn_options():
        """Return a fresh dict with the DenseTCN settings of the boundary config."""
        return {
            "block_config": [3, 3, 3, 3],
            "growth_rate_set": [384, 384, 384, 384],
            "kernel_size_set": [3, 5, 7],
            "dilation_size_set": [1, 2, 5],
            "reduced_size": 512,
            "squeeze_excitation": True,
            "dropout": 0.2,
        }

    def __init__(self,
                 modality='video',
                 hidden_dim=256,
                 backbone_type='resnet',
                 num_classes=500,
                 relu_type='swish',
                 tcn_options=None,
                 densetcn_options=None,
                 width_mult=1.0,
                 use_boundary=True,
                 extract_feats=True
                 ):
        super(Lipreading, self).__init__()
        # Validate before building anything, so bad configs fail with a clear
        # message instead of an UnboundLocalError / AttributeError later on.
        if relu_type not in self._FRONTEND_RELU_TYPES:
            raise ValueError(
                f"relu_type must be one of {self._FRONTEND_RELU_TYPES}, got {relu_type!r}")
        if backbone_type == 'shufflenet':
            raise NotImplementedError(
                "backbone_type='shufflenet' is not supported: only a ResNet trunk is built")

        self.extract_feats = extract_feats
        self.backbone_type = backbone_type
        self.modality = modality
        self.use_boundary = use_boundary
        # Copied so the caller's dict (or a shared default) is never aliased.
        self.densetcn_options = (self._default_densetcn_options()
                                 if densetcn_options is None else dict(densetcn_options))

        self.frontend_nout = 64
        # Input size the (unbuilt) DenseTCN backend would take; the boundary
        # feature would add one more channel.
        self.backend_out = 512
        # Trunk is created before frontend3D (as originally) so module order,
        # state_dict key order and random-init order are unchanged.
        self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)

        # -- frontend3D
        if relu_type == 'relu':
            frontend_relu = nn.ReLU(True)
        elif relu_type == 'prelu':
            frontend_relu = nn.PReLU(self.frontend_nout)
        else:  # 'swish' -- the only remaining value after validation above
            frontend_relu = Swish()

        # Temporal kernel 5 / stride 1 / padding 2 and a pooling with temporal
        # kernel 1 keep the number of frames unchanged; spatial size is reduced.
        self.frontend3D = nn.Sequential(
                    nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
                    nn.BatchNorm3d(self.frontend_nout),
                    frontend_relu,
                    nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))

        # The DenseTCN backend (self.tcn) is intentionally not built: this
        # module is used purely as a feature extractor.

        # -- initialize
        self._initialize_weights_randomly()

    def forward(self, x, lengths, boundaries=None):
        """Compute per-frame visual features.

        Args:
            x: 5-D tensor unpacked as (B, C, T, H, W); C must be 1 because the
                stem's Conv3d has a single input channel.
            lengths: Unused; kept for interface compatibility with the full model.
            boundaries: Tensor concatenated onto the features along the last
                dim when ``use_boundary`` is True; presumably (B, T, 1) --
                TODO confirm against the data pipeline.

        Returns:
            Tensor of shape (B, T', F): per-frame trunk features (plus the
            boundary channel(s) when ``use_boundary`` is True).

        Raises:
            ValueError: If ``use_boundary`` is True and ``boundaries`` is None.
        """
        if self.use_boundary and boundaries is None:
            raise ValueError("boundaries must be provided when use_boundary=True")

        B, C, T, H, W = x.size()
        x = self.frontend3D(x)
        Tnew = x.shape[2]  # frame count after the stem: B x C2 x Tnew x H' x W'
        # Presumably folds time into batch so the 2D trunk sees single frames.
        x = threeD_to_2D_tensor(x)
        x = self.trunk(x)
        x = x.view(B, Tnew, x.size(1))

        # -- duration
        if self.use_boundary:
            x = torch.cat([x, boundaries], dim=-1)
        return x

    def _initialize_weights_randomly(self):
        """Re-initialise all submodules in place.

        * Conv1d/2d/3d: weights ~ N(0, sqrt(2 / n)) with
          n = prod(kernel_size) * out_channels; biases zeroed.
        * BatchNorm1d/2d/3d (affine only): weight = 1, bias = 0.
        * Linear: weights ~ N(0, sqrt(2 / in_features)); bias left as is.
        """
        def std(n):
            return math.sqrt(2.0 / float(n))

        conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

        for m in self.modules():
            if isinstance(m, conv_types):
                n = np.prod(m.kernel_size) * m.out_channels
                m.weight.data.normal_(0, std(n))
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, bn_types):
                # Non-affine BatchNorm has weight/bias set to None.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Linear):
                n = float(m.weight.data[0].nelement())
                m.weight.data.normal_(0, std(n))  # in place; no reassignment needed

In [4]:
# Build the feature extractor with the settings that match the checkpoint
# loaded below (video input, ResNet trunk, boundary features appended).
m = Lipreading(modality='video', backbone_type='resnet', use_boundary=True)

In [7]:
# Load the full LRW checkpoint and save only the weights this extractor uses.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
CHECKPOINT_PATH = "lrw_resnet18_dctcn_video_boundary.pth.tar"
EXTRACTOR_PATH = "extractor_lrw_resnet18_dctcn_video_boundary.pt"

# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
pkl = torch.load(CHECKPOINT_PATH, map_location="cpu")
state_dict = pkl["model_state_dict"]

# strict=False tolerates checkpoint entries for modules the extractor does not
# build, but a *missing* key would silently leave a layer randomly initialised,
# so fail loudly in that case and report anything that was ignored.
load_result = m.load_state_dict(state_dict, strict=False)
assert not load_result.missing_keys, f"missing weights: {load_result.missing_keys}"
if load_result.unexpected_keys:
    print(f"ignored {len(load_result.unexpected_keys)} checkpoint keys not used by the extractor")

torch.save(m.state_dict(), EXTRACTOR_PATH)
load_result

<All keys matched successfully>