In [1]:
import glob
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage import io, transform
from torch.nn import init
from torch.utils import data

# Silence every warning category for the rest of the session
import warnings
warnings.simplefilter("ignore")

# Switch matplotlib into interactive mode
plt.ion()

In [2]:
# To read the images in numerical order
import re
numbers = re.compile(r'(\d+)')
def numericalSort(value):
    parts = numbers.split(value)
    parts[1::2] = map(int, parts[1::2])
    return parts

In [3]:
def get_video_frames(path):
    """Decode every frame of the video at ``path`` into a single NumPy array.

    Args:
        path: path to a video file readable by scikit-video.

    Returns:
        ``np.ndarray`` holding the decoded frames stacked along a new leading
        axis (presumably ``(T, H, W, C)`` -- confirm against skvideo's reader).
    """
    # skvideo is used here but was never imported anywhere in the notebook,
    # so the original raised NameError on the first call.
    import skvideo.io

    return np.array(list(skvideo.io.vreader(path)))

In [4]:
# Gather each sample's lips video (.mp4), mask (.npy) and mixed spectrogram
# (.png) from its sub-folder, in natural numeric order. Set the
# TEST_FOLD_OUTPUT_DIR environment variable to use data on another machine.
OUTPUT_DIR = os.environ.get('TEST_FOLD_OUTPUT_DIR',
                            '/Users/manideepkolla/Downloads/test_fold/output')
lips_filelist = sorted(glob.glob(os.path.join(OUTPUT_DIR, '*', '*.mp4')), key=numericalSort)
masks_filelist = sorted(glob.glob(os.path.join(OUTPUT_DIR, '*', '*.npy')), key=numericalSort)
spects_filelist = sorted(glob.glob(os.path.join(OUTPUT_DIR, '*', '*.png')), key=numericalSort)

# The three lists are sorted independently and later paired by index, so a
# missing or extra file in any folder would silently misalign samples.
assert len(lips_filelist) == len(masks_filelist) == len(spects_filelist), (
    'file count mismatch: {} mp4 / {} npy / {} png'.format(
        len(lips_filelist), len(masks_filelist), len(spects_filelist)))

In [5]:
# Sub-folders of the output directory, in natural numeric order
# (override the location with the TEST_FOLD_OUTPUT_DIR environment variable).
folders_list = sorted(
    glob.glob(os.path.join(
        os.environ.get('TEST_FOLD_OUTPUT_DIR',
                       '/Users/manideepkolla/Downloads/test_fold/output'),
        '*')),
    key=numericalSort)

In [6]:
# Display the folders found under the output directory
folders_list

['/Users/manideepkolla/Downloads/test_fold/output/5549779787693549159_00001_0',
 '/Users/manideepkolla/Downloads/test_fold/output/5549779787693549159_00001_1',
 '/Users/manideepkolla/Downloads/test_fold/output/5549779787693549159_00001_2',
 '/Users/manideepkolla/Downloads/test_fold/output/5549779787693549159_00001_3',
 '/Users/manideepkolla/Downloads/test_fold/output/5549779787693549159_00001_4']

In [7]:
# Dataset definition (wrapped by data.DataLoader further down)

class Dataset(data.Dataset):
    """Characterizes a dataset for PyTorch.

    Sample ``i`` is built from the ``i``-th entry of three parallel file lists:
    a lips video, a mask array and a mixed spectrogram image.
    """

    def __init__(self, lips_filelist, masks_filelist, spects_filelist):
        """Store the three parallel file lists.

        Args:
            lips_filelist: paths to the lips video files.
            masks_filelist: paths to ``.npy`` mask files.
            spects_filelist: paths to spectrogram image files.

        Raises:
            ValueError: if the three lists have different lengths.
        """
        sizes = (len(lips_filelist), len(masks_filelist), len(spects_filelist))
        # Samples are paired purely by index, so unequal lists would either
        # misalign every later sample or raise IndexError mid-epoch.
        if len(set(sizes)) != 1:
            raise ValueError(
                'lips/masks/spects file lists must have equal length, '
                'got {} / {} / {}'.format(*sizes))
        self.lips_filelist = lips_filelist
        self.masks_filelist = masks_filelist
        self.spects_filelist = spects_filelist

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.lips_filelist)

    def __getitem__(self, index):
        """Load and return sample ``index`` as ``(lips, spect, mask)``.

        Returns:
            lips: frames returned by ``get_video_frames`` for the lips video.
            spect: the spectrogram image as returned by ``io.imread``.
            mask: the array loaded with ``np.load`` from the mask file.
        """
        lips_filename = self.lips_filelist[index]
        mask_filename = self.masks_filelist[index]
        spect_filename = self.spects_filelist[index]

        # Read the lips video (the original referenced an undefined
        # `lip_filename` here and raised NameError for every sample).
        lips = get_video_frames(lips_filename)

        # Read mask
        mask = np.load(mask_filename)

        # Read mixed spectrogram
        spect = io.imread(spect_filename)

        return lips, spect, mask

In [9]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Optional: torch.backends.cudnn.benchmark = True  (autotunes conv kernels)

# DataLoader settings and training length
params = dict(batch_size=64, shuffle=True, num_workers=6)
max_epochs = 100

In [10]:
# Wrap the training file lists in a Dataset and batch it with the
# DataLoader settings from `params`.
training_set = Dataset(
    lips_filelist=lips_filelist,
    masks_filelist=masks_filelist,
    spects_filelist=spects_filelist,
)
training_generator = data.DataLoader(dataset=training_set, **params)

In [None]:
# Loop over epochs
for epoch in range(max_epochs):
    # Training. Dataset.__getitem__ returns (lips, spect, mask), so every batch
    # unpacks into three tensors; the original two-name unpacking raised
    # "too many values to unpack".
    for local_lips, local_spects, local_masks in training_generator:
        # Transfer to GPU
        local_lips = local_lips.to(device)
        local_spects = local_spects.to(device)
        local_masks = local_masks.to(device)

        # Model computations
        pass  # TODO: forward pass, loss, backward and optimizer step

    # Validation, without gradient tracking.
    # NOTE(review): `validation_generator` is not defined in any visible cell;
    # build it (e.g. a second Dataset + DataLoader) before running this cell.
    with torch.no_grad():
        for local_lips, local_spects, local_masks in validation_generator:
            # Transfer to GPU
            local_lips = local_lips.to(device)
            local_spects = local_spects.to(device)
            local_masks = local_masks.to(device)

            # Model computations
            pass  # TODO: forward pass and validation metrics

In [None]:
class LipNet(nn.Module):
    """LipNet-style network: a 3D-conv front end followed by two bidirectional GRUs.

    Args:
        opt: options object. This class reads ``opt.dropout`` (dropout
            probability) and ``opt.rnn_size`` (GRU hidden size per direction).
        vocab_size: number of output symbols. The final layer emits
            ``vocab_size + 1`` logits per time step (presumably the extra one
            is a CTC blank -- confirm against the loss that is used).
    """

    def __init__(self, opt, vocab_size):
        super(LipNet, self).__init__()
        self.opt = opt
        # C*H*W of the conv output fed to gru1. 3x6 is what 50x100 input frames
        # reduce to after the stride-2 first conv and three 2x spatial pools.
        conv_features = 96 * 3 * 6
        self.conv = nn.Sequential(
            nn.Conv3d(3, 32, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Dropout3d(opt.dropout),
            nn.Conv3d(32, 64, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2)),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Dropout3d(opt.dropout),
            nn.Conv3d(64, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Dropout3d(opt.dropout)
        )
        # (T, B, C*H*W) -> (T, B, 2*rnn_size)
        self.gru1 = nn.GRU(conv_features, opt.rnn_size, 1, bidirectional=True)
        self.drp1 = nn.Dropout(opt.dropout)
        # (T, B, 2*rnn_size) -> (T, B, 2*rnn_size)
        self.gru2 = nn.GRU(opt.rnn_size * 2, opt.rnn_size, 1, bidirectional=True)
        self.drp2 = nn.Dropout(opt.dropout)
        # (T, B, 2*rnn_size) -> (T, B, vocab_size + 1)
        self.pred = nn.Linear(opt.rnn_size * 2, vocab_size + 1)

        # Initialisations
        for m in self.conv.modules():
            if isinstance(m, nn.Conv3d):
                init.kaiming_normal_(m.weight, nonlinearity='relu')
                init.constant_(m.bias, 0)

        init.kaiming_normal_(self.pred.weight, nonlinearity='sigmoid')
        init.constant_(self.pred.bias, 0)

        for gru in (self.gru1, self.gru2):
            # Glorot-style uniform bound computed from THIS GRU's own input
            # size. The original used 96*3*6 for both GRUs, which shrank gru2's
            # bound even though gru2's input is only 2 * rnn_size wide.
            stdv = math.sqrt(2 / (gru.input_size + opt.rnn_size))
            bound = math.sqrt(3) * stdv
            # The three gate matrices are stacked along dim 0; each gate's
            # rnn_size-row slice is initialised separately, forward then reverse.
            for i in range(0, opt.rnn_size * 3, opt.rnn_size):
                gate = slice(i, i + opt.rnn_size)
                for suffix in ('', '_reverse'):
                    init.uniform_(getattr(gru, 'weight_ih_l0' + suffix)[gate], -bound, bound)
                    init.orthogonal_(getattr(gru, 'weight_hh_l0' + suffix)[gate])
                    init.constant_(getattr(gru, 'bias_ih_l0' + suffix)[gate], 0)

    def forward(self, x):
        """Map a batch of clips to per-time-step logits.

        Args:
            x: tensor of shape (B, 3, T, H, W). H and W must reduce to 3x6
                after ``self.conv`` (e.g. 50x100) to match gru1's input size.

        Returns:
            Tensor of shape (T, B, vocab_size + 1); time-major because the
            GRUs are built without ``batch_first``.
        """
        x = self.conv(x)  # (B, C, T, H', W'); T is preserved by the conv stack
        x = x.permute(2, 0, 1, 3, 4).contiguous()  # (T, B, C, H', W')
        x = x.view(x.size(0), x.size(1), -1)  # (T, B, C*H'*W')
        x, _ = self.gru1(x)
        x = self.drp1(x)
        x, _ = self.gru2(x)
        x = self.drp2(x)
        x = self.pred(x)

        return x