In [309]:
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import os
import argparse
import random
import torch
import torchaudio
import torchvision.io
import pyloudnorm as pyln
import glob
from matplotlib import pyplot as plt
from IPython.display import Video
from IPython.display import Audio
from torchvision.transforms import v2
from torch.utils.data import Dataset
from torch import nn

In [460]:
class LipsyncDataset(Dataset):
    """Audio to animated lip viseme dataset.

    Each row of the backing table is one clip with an ``audio`` column
    (waveform samples) and a ``visemes`` column (one viseme index per video
    frame). Viseme indices are positions in ``viseme_labels``.
    """

    viseme_labels = ['Ah', 'D', 'Ee', 'F', 'L', 'M', 'Neutral', 'Oh', 'R', 'S', 'Uh', 'Woo']

    # Label -> [row, column] of the matching sprite on the sprite sheet.
    visemes = {
        'Ah': [0, 1],
        'D': [0, 3],
        'Ee': [0, 2],
        'F': [1, 3],
        'L': [1, 1],
        'M': [2, 0],
        'Neutral': [1, 0],
        'Oh': [1, 2],
        'R': [2, 3],
        'S': [2, 1],
        'Uh': [2, 2],
        'Woo': [0, 0],
    }

    def __init__(self, parquet_file, transform=None, samplerate=16000, table=None):
        """
        Args:
            parquet_file: Path of the parquet file to load. Ignored when
                ``table`` is given.
            transform: Optional callable applied to every sample dict.
            samplerate: Audio sample rate in Hz, used for playback and video export.
            table: Optional already-loaded pandas DataFrame with ``audio`` and
                ``visemes`` columns; avoids re-reading the parquet file.
        """
        if table is not None:
            self.table = table
        else:
            self.table = pq.read_table(parquet_file).to_pandas()
        self.transform = transform
        self.rate = samplerate

    def __len__(self):
        return self.table.shape[0]

    def __getitem__(self, idx):
        """Return ``{'audio': Tensor, 'visemes': Tensor}`` for the clip at position ``idx``.

        Raises:
            IndexError: if ``idx`` is negative or past the end of the table.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if idx < 0 or idx >= len(self):
            raise IndexError(f'index {idx} out of range for dataset of length {len(self)}')
        # Positional lookup: __len__ counts rows, so indexing must not depend on
        # the DataFrame's index labels (which differ after filtering/splitting).
        # np.array(...) makes a writable float32 copy so the tensors never alias
        # the table's storage.
        audio = np.array(self.table['audio'].iloc[idx], dtype=np.float32)
        visemes = np.array(self.table['visemes'].iloc[idx], dtype=np.float32)
        sample = {
            'audio': torch.from_numpy(audio),
            'visemes': torch.from_numpy(visemes),
        }
        if self.transform:
            sample = self.transform(sample)
        return sample

    def display_audio(self, idx):
        """Show audio output (only for untransformed audio)"""
        sample = self[idx]
        return Audio(sample['audio'], rate=self.rate)

    def make_video(self, audio, vis_indexes, out_path='out.mp4'):
        """Make video given visemes (only for untransformed audio).

        Args:
            audio: 1-D waveform (tensor or array) sampled at ``self.rate``.
            vis_indexes: One viseme index per video frame; float values are
                truncated to integers before the label lookup.
            out_path: File the video is written to.

        Returns:
            The path of the written video file.
        """
        cell = 256  # edge length in pixels of one sprite after resizing
        sheet = torchvision.io.decode_image('../images/demolipssheet_bg.png', 'RGB')
        # Resize the sheet to 4x4 cells and move channels last to match write_video.
        sprites = v2.Resize((cell * 4, cell * 4))(sheet).permute([1, 2, 0])
        # Samples store visemes as floats; list indexing needs plain ints.
        indexes = torch.as_tensor(vis_indexes).to(torch.long).tolist()
        # Make copy of audio in stereo contiguous format for writing to video
        mono = torch.as_tensor(audio, dtype=torch.float32).reshape(1, -1)
        a = np.ascontiguousarray(mono.expand(2, -1).numpy())
        v = torch.zeros(len(indexes), cell, cell, 3, dtype=torch.uint8)
        for i, vi in enumerate(indexes):
            row, col = self.visemes[self.viseme_labels[vi]]
            v[i] = sprites[row * cell:(row + 1) * cell, col * cell:(col + 1) * cell, :]
        torchvision.io.write_video(out_path, v, fps=30, audio_array=a, audio_fps=self.rate, audio_codec='aac')
        return out_path

    def display_video(self, idx):
        """Render the clip at ``idx`` to a video file and return an inline player for it."""
        sample = self[idx]
        fname = self.make_video(sample['audio'], sample['visemes'])
        return Video(fname)

class AudioMFCC(nn.Module):
    '''Analyze audio to MFCC, with a log-RMS volume row prepended.

    The output ``audio`` entry is ``torch.cat((volume_row, mfcc))``, i.e. one
    more row than ``num_mels``; ``visemes`` is passed through unchanged.
    '''
    def __init__(self, audio_rate=16000, num_mels=13):
        super().__init__()
        window_time = 25e-3 # seconds
        hop_time = 10e-3 # seconds
        # Stored on the instance: __call__ needs both to slice the waveform
        # into the same frames the MFCC transform uses.
        self.window_length = round(window_time * audio_rate)
        self.hop_length = round(hop_time * audio_rate)
        melkwargs = {
            "n_fft": self.window_length,
            "win_length": self.window_length,
            "hop_length": self.hop_length,
        }
        self.mfcc = torchaudio.transforms.MFCC(sample_rate=audio_rate, n_mfcc=num_mels, melkwargs=melkwargs)

    def __call__(self, sample):
        # NOTE(review): assumes sample['audio'] is a 1-D waveform so the MFCC
        # result is (coefficients, frames) - confirm against the dataset.
        waveform = sample['audio']
        a = self.mfcc(waveform)
        samples = waveform.numpy()
        vols = []
        for i in range(a.shape[1]):
            start = i * self.hop_length
            w = samples[start:start + self.window_length]
            # The final frame can start at or past the end of the waveform;
            # treat an empty window as silence instead of producing NaN.
            rms = np.sqrt(np.mean(w ** 2)) if w.size else 0.0
            vols.append(float(np.log(1e-10 + rms)))
        # Match the MFCC dtype so the concatenation does not promote the features.
        tv = torch.tensor(vols, dtype=a.dtype).reshape(1, -1)
        v = sample['visemes']
        return {
            'audio': torch.cat((tv, a)),
            'visemes': v,
        }

class AddDerivatives(nn.Module):
    """Extend audio features with derivatives"""
    # NOTE(review): placeholder only - no __init__ or __call__/forward is
    # defined here, so this class does not transform samples yet.

class Upsample(nn.Module):
    '''Resample the viseme track to a new frame rate; audio is passed through.'''
    def __init__(self, old_fps=30, new_fps=100):
        super().__init__()
        self.transform_viseme = nn.Upsample(scale_factor=new_fps / old_fps, mode='nearest-exact')

    def __call__(self, sample):
        # nn.Upsample works on float (batch, channel, length) input, so wrap the
        # flat viseme track before resampling and flatten it again afterwards.
        batched = sample['visemes'].reshape(1, 1, -1).float()
        resampled = self.transform_viseme(batched).reshape(-1)
        return dict(audio=sample['audio'], visemes=resampled)

class PadVisemes(nn.Module):
    '''Append one repeated trailing viseme frame when the track is shorter than the audio.'''
    def __call__(self, sample):
        features = sample['audio']
        track = sample['visemes']
        # Nothing to do when the viseme track already covers the audio frames.
        if track.shape[-1] >= features.shape[-1]:
            return {'audio': features, 'visemes': track}
        # Grow by exactly one frame, duplicating the last viseme.
        padded = torch.empty(track.shape[0] + 1)
        padded[:-1] = track
        padded[-1] = track[-1]
        return {'audio': features, 'visemes': padded}

class RandomChunk(nn.Module):
    '''Extract fixed size block from random position.

    Samples shorter than ``size`` frames are placed at a random offset inside
    a zero-filled audio block; the viseme track is filled with its last value.
    '''
    def __init__(self, size=100, seed=1234):
        super().__init__()
        self.size = size
        self.rng = np.random.default_rng(seed)

    def __call__(self, sample):
        a = sample['audio']
        v = sample['visemes']
        frames = a.shape[-1]
        if frames < self.size:
            # Sample is too short: pad it out to the chunk size.
            aa = a.new_zeros((a.shape[0], self.size))
            # v[-1] should be neutral viseme
            vv = v.new_full((self.size,), v[-1].item())
            # endpoint=True so the sample may also sit flush against the end.
            offset = int(self.rng.integers(0, self.size - frames, endpoint=True))
            aa[:, offset:offset + frames] = a
            # Never write past the audio span, even if the viseme track is
            # a frame longer than the audio.
            n = min(v.shape[0], frames)
            vv[offset:offset + n] = v[:n]
        else:
            # endpoint=True: a sample of exactly `size` frames has the single
            # valid offset 0, which an exclusive upper bound would reject.
            offset = int(self.rng.integers(0, frames - self.size, endpoint=True))
            aa = a[:, offset:offset + self.size]
            vv = v[offset:offset + self.size]
        return {
            'audio': aa,
            'visemes': vv,
        }


In [465]:
# Build the full preprocessing pipeline, then inspect AudioMFCC on the first clip.
# NOTE(review): `AddPower` and `table` are not defined in the visible cells -
# presumably leftover kernel state; confirm before a Restart & Run All.
# NOTE(review): `transform` is built here but the dataset below is created with
# transform=None, so the pipeline is not applied in this cell.
transform = nn.Sequential(Upsample(), AudioMFCC(), PadVisemes(), AddPower(), RandomChunk(size=100, seed=1))
dataset = LipsyncDataset('../data/lipsync.parquet', table=table, transform=None)
s = dataset[0]
# Last feature row of the first clip, across all frames.
AudioMFCC()(s)['audio'][-1,:]

tensor([-1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05, -1.5946e-05,
        -1.5946e-05, -1.5946e-05, -1.5946e-05, -6.3605e+00,  1.3083e+01,
         1.5476e+01,  4.1104e+00,  9.8713e+00,  9.5293e+00,  2.1981e+01,
         1.7255e+01,  1.1071e+01,  3.8165e-01,  1.0367e+01,  9.2018e+00,
        -1.2079e+00,  2.9037e-02,  1.4811e+01,  9.8525e+00, -3.1102e+00,
        -1.4362e+01, -4.6245e+00,  1.8234e+01,  3.8173e+01,  2.6664e+01,
         3.4285e+01,  6.4059e+00, -1.8324e+01, -3.8552e+00, -1.4689e+01,
        -1.8086e+01, -1.6705e+01, -1.0913e+01, -1.5653e+01, -2.1476e+01,
        -2.3182e+01, -2.6040e+01, -3.2820e+01, -3.5

In [232]:
# Resample from 30fps to 100fps
# Needs to be float, nearest-exact
# NOTE(review): `vv` is not defined in the visible cells; the displayed result is
# 3-D, so it is presumably a float (1, 1, N) tensor - confirm where it comes from.
v100 = torch.nn.Upsample(scale_factor=100/30, mode='nearest-exact')(vv)
v100

tensor([[[1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 5.,
          6., 6., 6., 7., 7., 7.]]])

In [505]:
# Dataset without any transform, built from the already-loaded `table`.
# NOTE(review): `table` is not defined in the visible cells - confirm its origin.
dataset = LipsyncDataset('../data/lipsync.parquet', table=table, )# earlier variant: transform=Upsample()
len(dataset)

664

TypeError: Dataset() takes no arguments

In [509]:
# Run the whole preprocessing pipeline on the second clip and show the result.
# NOTE(review): `AddPower` is not defined in the visible cells - confirm its origin.
nn.Sequential(Upsample(), AudioMFCC(), PadVisemes(), AddPower(), RandomChunk(size=100, seed=1))(dataset[1])

{'audio': tensor([[  -2.4561,   -2.5161,   -2.7308,  ...,   -1.8477,   -1.9166,
            -2.0721],
         [-261.3439, -245.6447, -234.3624,  ..., -185.3834, -191.8687,
          -192.9063],
         [ 142.4426,  109.1230,  111.5967,  ...,   82.6618,   73.3223,
            77.1488],
         ...,
         [  -3.7173,  -11.5020,   -7.3655,  ...,  -31.2948,  -29.1820,
           -23.5247],
         [ -22.9988,  -13.0663,   -5.8015,  ...,   -9.7735,   -5.0283,
            -5.7599],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000]]),
 'visemes': tensor([8., 8., 8., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 9., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 9., 9., 9.,
         9., 9., 9., 9., 9., 9., 9., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
 

In [236]:
# Inline audio player for the first clip.
dataset.display_audio(0)

  v = torch.Tensor(self.table['visemes'][idx]).reshape((1, 1, -1)).to(dtype=torch.float)


In [136]:
# Render the first clip's visemes to a video file and show it inline.
dataset.display_video(0)

In [495]:
# Prototype: run a [+, +, -, -] step kernel along every row of a toy
# two-row signal to see how the filter responds to level changes.
x = np.array([
    [0]*3 + [1]*10 + [2]*10 + [0]*3,
    [1]*3 + [0]*10 + [2]*10 + [1]*3,
])
print(x.shape)
f = np.array([1.0, 1.0, -1.0, -1.0]) / 2
# One 'same'-mode convolution per row keeps the output the same shape as x.
y = np.stack([np.convolve(row, f, mode='same') for row in x])
y, x.shape, y.shape

(2, 26)


(array([[ 0. ,  0. ,  0.5,  1. ,  0.5,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,
          0. ,  0.5,  1. ,  0.5,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,
         -1. , -2. , -1. ,  0. ],
        [ 1. ,  0.5, -0.5, -1. , -0.5,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,
          0. ,  1. ,  2. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,
         -0.5, -1. , -0.5, -0.5]]),
 (2, 26),
 (2, 26))

In [504]:
# Seeded generator passed to random_split; [1] selects the second (7-element) subset.
generator1 = torch.Generator().manual_seed(42)
list(torch.utils.data.random_split(range(10), [3, 7], generator=generator1)[1])

[8, 4, 5, 0, 9, 3, 7]

In [514]:
# Scratch tensor for the dtype/reshape checks in the following cells.
# NOTE(review): no seed is set in this cell, so the values presumably differ per run.
x = torch.randn(2, 3, 4)

In [520]:
# Float -> long cast; the displayed values show truncation toward zero.
x.to(torch.long)

tensor([[[ 0, -1,  1,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]],

        [[ 0,  0, -1,  0],
         [ 0,  0,  0, -2],
         [ 0,  0,  1,  0]]])

In [516]:
# Collapse the last two dimensions: (2, 3, 4) -> (2, 12).
x.reshape(2, 12)

tensor([[-0.5458, -1.8158,  1.8652, -0.2104,  0.5011, -0.7595,  0.9068,  0.1185,
         -0.0562,  0.4933,  0.0621, -0.2162],
        [-0.4685,  0.5074, -1.1797,  0.1185,  0.8213, -0.9577, -0.4403, -2.6315,
         -0.0661,  0.6884,  1.5279,  0.6679]])

In [517]:
# Cast to torch.float; the displayed values match x.reshape(2, 12) above.
x.to(torch.float)

tensor([[[-0.5458, -1.8158,  1.8652, -0.2104],
         [ 0.5011, -0.7595,  0.9068,  0.1185],
         [-0.0562,  0.4933,  0.0621, -0.2162]],

        [[-0.4685,  0.5074, -1.1797,  0.1185],
         [ 0.8213, -0.9577, -0.4403, -2.6315],
         [-0.0661,  0.6884,  1.5279,  0.6679]]])