In [1]:
# Base
import itertools
from glob import glob
from tqdm import tqdm
import math
import textgrid
import random

# ML
import torch
import torch.nn.functional as F
from torch.utils.data import DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

# Local
from utils.misc import dict_to_object, plot_specgram, plot_waveform
from supervoice.audio import spectogram, load_mono_audio
from supervoice.model_duration import DurationPredictor
from supervoice.tokenizer import Tokenizer
from train_config import config

In [2]:
# Discover the aligned TextGrid files, report how many were found,
# then parse only the first ten of them
files = glob("datasets/vctk-aligned/**/*.TextGrid")
print(len(files))
files = list(map(textgrid.TextGrid.fromFile, files[:10]))

88145


In [3]:
# Tokenizer built from the training config. It is shared by the cells below:
# extract_data reads its silence_token, TextGridDataset calls it on the token
# list, and the logging loop indexes tokenizer.tokens to print symbols.
tokenizer = Tokenizer(config)

# Data extractor
def extract_data(src):
    """Convert an aligned TextGrid into parallel token and duration lists.

    Iterates over the intervals of tier ``src[1]``. Each interval contributes
    its ``mark`` as the token (an empty mark is replaced by
    ``tokenizer.silence_token``) and a duration expressed in 0.01-second
    steps, computed from the gap between the interval's ``maxTime`` and the
    previous interval's ``maxTime`` (starting from 0) and rounded down.

    Args:
        src: Indexable object whose element ``1`` is an iterable of intervals
            exposing ``maxTime`` and ``mark`` attributes.

    Returns:
        Tuple ``(output_tokens, output_durations)`` of equal-length lists.
        Both are empty when the tier has no intervals.
    """

    # Prepare
    token_duration = 0.01
    time = 0
    output_tokens = []
    output_durations = []

    # Iterate over tokens
    for t in src[1]:

        # Resolve duration as the number of whole steps since the previous boundary
        ends = t.maxTime
        output_durations.append(math.floor((ends - time) / token_duration))
        time = ends

        # Resolve token: unlabeled intervals are treated as silence
        output_tokens.append(t.mark if t.mark != '' else tokenizer.silence_token)

    # An empty tier has no edges to trim (indexing [0] below would raise)
    if not output_tokens:
        return output_tokens, output_durations

    # Trim leading/trailing silence down to a single step.
    # The previous `while` loop here had a no-op body and would spin forever
    # whenever the first token was 'SIL'; the clamps below are the whole trim.
    # NOTE(review): the comparison uses the literal 'SIL', while empty marks are
    # mapped to tokenizer.silence_token above. If those differ, substituted
    # silences are never trimmed - confirm which behaviour the checkpoint expects.
    if output_tokens[0] == 'SIL' and output_durations[0] > 1:
        output_durations[0] = 1
    if output_tokens[-1] == 'SIL' and output_durations[-1] > 1:
        output_durations[-1] = 1

    # Outputs
    return output_tokens, output_durations
    
class TextGridDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of parsed TextGrid objects.

    Each item is a ``(tokens, durations, mask)`` triple. ``mask`` is a boolean
    tensor flagging one random contiguous span whose relative length is drawn
    uniformly from [0.3, 0.7] of the sequence.
    """

    def __init__(self, files):
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Tokens and per-token durations for the selected file
        raw_tokens, raw_durations = extract_data(self.files[index])
        tokens = tokenizer(raw_tokens)
        durations = torch.Tensor(raw_durations)
        length = len(durations)

        # Draw the relative span length first, then an offset that keeps it in range
        mask_len = random.uniform(0.3, 0.7)
        mask_offset = random.uniform(0, 1 - mask_len)
        span_start = math.floor(mask_offset * length)
        span_end = math.floor((mask_offset + mask_len) * length)

        # True inside the span, False elsewhere
        mask = torch.zeros(length, dtype=torch.bool)
        mask[span_start:span_end] = True

        return tokens, durations, mask


In [5]:

# Dataset and loader over the parsed TextGrid files
device = "cpu"
dataset = TextGridDataset(files)
dataloader = DataLoader(dataset=dataset, batch_size=1)

# Model on the target device, plus its optimizer
model = DurationPredictor(config).to(device)
optim = torch.optim.AdamW(params=model.parameters(), lr=0.0002, betas=[0.8, 0.99])

# Restore weights from the saved checkpoint (tensors are mapped to CPU on load)
checkpoint = torch.load('./output/duration_pre.pt', map_location="cpu")
model.load_state_dict(checkpoint['model'])


<All keys matched successfully>

In [6]:
# Inspect predictions for up to the first ten dataset items.
# This cell only reads predictions and the loss, so gradient tracking is
# disabled to avoid building autograd graphs for every forward pass.
# NOTE(review): the model is left in its current train/eval mode; if
# DurationPredictor contains dropout, call model.eval() before this loop - confirm.
with torch.no_grad():
    # Bound by the dataset size so a shorter file list does not raise IndexError
    for i in range(min(10, len(dataset))):
        # Predict (a batch dimension of size 1 is added to every input)
        tokens, durations, mask = dataset[i]
        predicted, loss = model(
            tokens = tokens.unsqueeze(0).to(device),
            durations = durations.unsqueeze(0).to(device),
            mask = mask.unsqueeze(0).to(device),
            target = durations.unsqueeze(0).to(device)
        )
        predicted = predicted.squeeze()

        # Log: loss, then aligned rows of token symbols, predicted durations,
        # reference durations and the mask flags
        print(f'Loss: {loss.item()}')
        print(''.join(f"{tokenizer.tokens[num]:>8}" for num in tokens.tolist()))
        print(''.join(f"{num:8}" for num in predicted.tolist()))
        print(''.join(f"{int(num):8}" for num in durations.tolist()))
        print(''.join(f"{int(num):8}" for num in mask.tolist()))

Loss: 0.09282242506742477
   <SIL>      d̪      ej       m       ɐ       s       t       p       l      ej       f       ɒ       ɹ       i      tʃ       ɐ       ð       ɚ   <SIL>
       0       5       9       7       5       9       4       6       4      10      12       5       6       6      14       6       6       8       0
      90       9       4       8       6       9       5       7       5      10      13       3       5       9      12       8       8      12      29
       0       0       0       0       0       0       0       0       0       1       1       1       1       1       1       1       1       1       0
Loss: 0.19001419842243195
   <SIL>       w      iː       w       ə       ɫ       ɲ      iː       d       t       ə       s       t       ɐ      dʲ       i      d̪       ə       ɹ       ɪ      pʰ       ɒ       ɹ       t   <SIL>       b       ə       f       ɒ       ɹ       ɛ       ɲ       i      dʲ       ɪ       s       ɪ       ʒ       ə       n       ɪ       z