In [1]:
import torch

In [2]:
from model import LipNet

In [3]:
# Checkpoint to evaluate; the file name appears to record the loss / WER / CER of the run.
PATH = './weights/LipNet_unseen_loss_0.44562849402427673_wer_0.1332580699113564_cer_0.06796452465503355.pt'

# Build the model and restore its weights onto the CPU. The call is left as the
# cell's last expression so the key-matching report is displayed.
lipnet = LipNet()
lipnet.load_state_dict(
    torch.load(PATH, map_location='cpu')
)

<All keys matched successfully>

In [4]:
# List the parameter names stored in the checkpoint file (re-read from disk onto the CPU).
torch.load(
    PATH,
    map_location='cpu',
).keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'gru1.weight_ih_l0', 'gru1.weight_hh_l0', 'gru1.bias_ih_l0', 'gru1.bias_hh_l0', 'gru1.weight_ih_l0_reverse', 'gru1.weight_hh_l0_reverse', 'gru1.bias_ih_l0_reverse', 'gru1.bias_hh_l0_reverse', 'gru2.weight_ih_l0', 'gru2.weight_hh_l0', 'gru2.bias_ih_l0', 'gru2.bias_hh_l0', 'gru2.weight_ih_l0_reverse', 'gru2.weight_hh_l0_reverse', 'gru2.bias_ih_l0_reverse', 'gru2.bias_hh_l0_reverse', 'FC.weight', 'FC.bias'])

In [5]:
from torch.utils.data import Dataset
import os 
from preprocessing import TokenConv, get_frames_pkl, load_align, HorizontalFlip, padding
import cv2

class LipDataset(Dataset):
    """Dataset of pre-extracted frame pickles paired with alignment files.

    Directory layout, as built from the paths joined in ``__init__``::

        <dataset_path>/<phase>/frames/<speaker>/<name>.pkl
        <dataset_path>/<phase>/alignments/<speaker>/<name>.align

    A clip is indexed only when both its ``.pkl`` frame file and the matching
    ``.align`` file exist.

    Args:
        dataset_path: Root directory of the dataset.
        vid_pad: Length handed to ``padding`` for the frame sequence.
        align_pad: Length handed to ``padding`` for the encoded alignment.
        phase: Split sub-directory to read. When it equals ``"train"``,
            ``HorizontalFlip`` is applied to the frames in ``__getitem__``.
    """

    FRAME_EXT = ".pkl"
    ALIGN_EXT = ".align"
    FRAME_SIZE = (128, 64)  # dsize given to cv2.resize, i.e. (width, height)

    def __init__(self, dataset_path, vid_pad=75, align_pad=40, phase="train") -> None:
        super().__init__()
        self.align_path = os.path.join(dataset_path, phase, "alignments")
        self.frames_path = os.path.join(dataset_path, phase, "frames")
        self.vid_pad = vid_pad
        self.align_pad = align_pad
        self.phase = phase
        self.ctccoder = TokenConv()

        # (speaker, clip name) pairs, in the order os.walk yields them.
        self.data = []
        for path, subdirs, files in os.walk(self.frames_path):
            if subdirs:  # only leaf directories are treated as speaker folders
                continue

            spk = os.path.basename(path)  # speaker name = last path component

            for file in files:
                # Match the extension exactly: a substring test would also accept
                # names such as "x.pkl.bak", and splitting on the first dot would
                # truncate clip names that themselves contain dots.
                fname, ext = os.path.splitext(file)
                if ext != self.FRAME_EXT:
                    continue

                align_file = os.path.join(self.align_path, spk, fname + self.ALIGN_EXT)
                if os.path.exists(align_file):  # keep only clips that have an alignment
                    self.data.append((spk, fname))
        print("Dataset loaded successfully!")

    def __len__(self):
        """Return the number of indexed (speaker, clip) pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Load, resize, encode and pad one clip.

        Returns:
            A 4-tuple ``(vid, align, vid_len, align_len)``:
            ``vid`` is the padded frame sequence as a ``torch.Tensor`` divided
            by 255.0; ``align`` is the padded encoded alignment as a
            ``torch.Tensor``; ``vid_len`` and ``align_len`` are the lengths
            measured *before* padding.
        """
        speaker, fname = self.data[index]
        frames_path = os.path.join(self.frames_path, speaker, fname + self.FRAME_EXT)
        align_path = os.path.join(self.align_path, speaker, fname + self.ALIGN_EXT)

        # NOTE(review): frames come from .pkl files via get_frames_pkl; if that
        # helper unpickles, only point this dataset at trusted data.
        vid = get_frames_pkl(frames_path)
        for i, frame in enumerate(vid):
            vid[i] = cv2.resize(frame, self.FRAME_SIZE)
        align = self.ctccoder.encode(load_align(align_path))

        if self.phase == "train":
            vid = HorizontalFlip(vid)

        # Record the true lengths before padding changes them.
        vid_len = len(vid)
        align_len = len(align)
        vid = padding(vid, self.vid_pad)
        align = padding(align, self.align_pad)

        return (
            torch.Tensor(vid) / 255.0,  # normalization
            torch.Tensor(align),
            vid_len,
            align_len,
        )


In [6]:
# Index the clips of the `test` split found under ./dataset.
dataset = LipDataset(dataset_path="./dataset", phase='test')

Dataset loaded successfully!


In [7]:
# Show the (speaker, clip name) pairs that were indexed for this split.
dataset.data

[('s1', 'pbac2p'),
 ('s1', 'pgwe6n'),
 ('s1', 'prap7a'),
 ('s1', 'bwwn6n'),
 ('s1', 'praj1s'),
 ('s1', 'lbakzn'),
 ('s1', 'prbd1s'),
 ('s1', 'pgbk9a')]

In [8]:
from torch.utils.data import DataLoader
from preprocessing import TokenConv

# Iterate the dataset one clip per batch, in dataset order.
loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=False,
)

In [10]:
# Greedy decoding over the loader: print the reference text next to the
# model's prediction for every clip.
#
# NOTE(review): the reference labels are also passed through `ctc_decode`.
# The printed references read "gren" / "son", which suggests doubled letters
# are being collapsed in the ground truth -- confirm, and switch to a plain
# (non-CTC) label decoder if `TokenConv` provides one.
ctcdecoder = TokenConv()
lipnet.eval()
with torch.no_grad():  # inference only: do not record the autograd graph
    for vid, align, vid_len, align_len in loader:
        y = lipnet(vid).log_softmax(-1)
        y = torch.argmax(y, dim=-1)  # highest-scoring class index along the last dim
        for tru, pre in zip(align.tolist(), y.tolist()):
            true_txt = ctcdecoder.ctc_decode(tru)
            pred_txt = ctcdecoder.ctc_decode(pre)
            print("True: ", true_txt)
            print("Pred: ", pred_txt)
            print('-'*10)

True:  place blue at c two please
Pred:  bin blue in d twro again
----------
True:  place gren with e six now
Pred:  bin brue in j six soon
----------
True:  place red at sp p seven again
Pred:  bin red with d nine now
----------
True:  bin white sp with n six now
Pred:  bin white with y six again
----------
True:  place red at j one son
Pred:  bin white by d two soon
----------
True:  lay blue at k zero now
Pred:  place re at t twro soon
----------
True:  place red by d one son
Pred:  bin red by b six soon
----------
True:  place gren by k nine again
Pred:  bin brue by p nine now
----------
