In [1]:
from docopt import docopt

import sys
import gc
import platform
from os.path import dirname, join
from tqdm import tqdm, trange
from datetime import datetime

# The deepvoice3 model
from deepvoice3_pytorch import frontend, builder
import audio
import lrschedule

import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
from torch.utils.data.sampler import Sampler
import numpy as np
from numba import jit

from nnmnkwii.datasets import FileSourceDataset, FileDataSource
from os.path import join, expanduser
import random

# import librosa.display
import matplotlib
# hopefully this solves the error
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import sys
import os
from tensorboardX import SummaryWriter
from matplotlib import cm
from warnings import warn
from hparams import hparams, hparams_debug_string
import cv2

# Training progress counters. ``train`` declares both as ``global`` and
# increments them; ``save_alignment`` reads ``global_step`` for plot captions.
global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    # Explicitly leave the cuDNN benchmark mode switched off.
    cudnn.benchmark = False

# Text frontend module, resolved lazily from ``hparams.frontend`` the first
# time ``TextDataSource.collect_features`` runs.
_frontend = None  # to be set later


def _pad(seq, max_len, constant_values=0):
    return np.pad(seq, (0, max_len - len(seq)),
                  mode='constant', constant_values=constant_values)


def _pad_2d(x, max_len, b_pad=0):
    x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
               mode="constant", constant_values=0)
    return x

def _pad_4d(x, max_len):
    idx = x.shape[0]
    y = np.zeros((max_len,48,96,3))
    y[:idx] = x
    return y

def plot_alignment(alignment, path, title=None, info=None):
    """Render an attention alignment matrix as a PNG file.

    Args:
        alignment: 2-D array shown with ``imshow``; rows go on the y axis
            ("Encoder timestep"), columns on the x axis ("Decoder timestep").
        path: Destination of the PNG image.
        title: Optional figure title.
        info: Optional extra text appended below the x-axis label.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(alignment, aspect='auto', origin='lower',
                      interpolation='none')
    fig.colorbar(image, ax=ax)

    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel = xlabel + '\n\n' + info
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    if title is not None:
        ax.set_title(title)

    fig.tight_layout()
    fig.savefig(path, format='png')
    # Close this figure explicitly so repeated calls do not accumulate figures.
    plt.close(fig)


In [2]:

class TextDataSource(FileDataSource):
    """Text column of ``<data_root>/train.txt`` converted to symbol sequences.

    Each metadata line is ``|``-separated with 5 columns (single speaker) or
    6 columns (multi speaker, speaker id last); the text is column 3.
    """

    def __init__(self, data_root, speaker_id=None):
        self.data_root = data_root
        self.speaker_ids = None
        self.multi_speaker = False
        # When set, only utterances of this speaker are kept.
        self.speaker_id = speaker_id

    def collect_files(self):
        """Read the metadata file and return the texts.

        Returns:
            A list of texts, or ``(texts, speaker_ids)`` when the metadata is
            multi-speaker and no ``speaker_id`` filter was given.
        """
        meta = join(self.data_root, "train.txt")
        with open(meta, "rb") as f:
            lines = f.readlines()

        n_columns = len(lines[0].decode("utf-8").split("|"))
        assert n_columns == 5 or n_columns == 6
        self.multi_speaker = n_columns == 6

        texts = [line.decode("utf-8").split("|")[3] for line in lines]
        if not self.multi_speaker:
            return texts

        speaker_ids = [int(line.decode("utf-8").split("|")[-1]) for line in lines]
        if self.speaker_id is None:
            return texts, speaker_ids

        # A multi-speaker corpus restricted to one speaker is handled as a
        # single-speaker dataset from here on.
        keep = np.array(speaker_ids) == self.speaker_id
        texts = list(np.array(texts)[keep])
        self.multi_speaker = False
        return texts

    def collect_features(self, *args):
        """Convert one text (plus speaker id, if multi-speaker) to an int32 sequence."""
        global _frontend
        if self.multi_speaker:
            text, speaker_id = args
        else:
            text = args[0]

        if _frontend is None:
            _frontend = getattr(frontend, hparams.frontend)
        seq = _frontend.text_to_sequence(text, p=hparams.replace_pronunciation_prob)

        if platform.system() == "Windows" and hasattr(hparams, 'gc_probability'):
            _frontend = None  # memory leaking prevention in Windows
            if np.random.rand() < hparams.gc_probability:
                gc.collect()  # garbage collection enforced
                print("GC done")

        encoded = np.asarray(seq, dtype=np.int32)
        if self.multi_speaker:
            return encoded, int(speaker_id)
        return encoded


class _NPYDataSource(FileDataSource):
    """``.npy`` feature files listed in column ``col`` of ``<data_root>/train.txt``.

    Column 2 of every metadata line is read as the utterance's frame count
    and stored in ``frame_lengths``.
    """

    def __init__(self, data_root, col, speaker_id=None):
        self.data_root = data_root
        self.col = col
        self.frame_lengths = []
        self.speaker_id = speaker_id

    def collect_files(self):
        """Return the feature file paths, filtered by ``speaker_id`` if requested."""
        meta = join(self.data_root, "train.txt")
        with open(meta, "rb") as f:
            lines = f.readlines()

        first_row = lines[0].decode("utf-8").split("|")
        assert len(first_row) == 5 or len(first_row) == 6
        multi_speaker = len(first_row) == 6

        self.frame_lengths = [
            int(line.decode("utf-8").split("|")[2]) for line in lines]

        relative_paths = [line.decode("utf-8").split("|")[self.col] for line in lines]
        paths = [join(self.data_root, name) for name in relative_paths]

        if multi_speaker and self.speaker_id is not None:
            # Use a multi-speaker corpus as a single-speaker one: keep only the
            # rows that belong to the requested speaker.
            speaker_ids = [int(line.decode("utf-8").split("|")[-1]) for line in lines]
            keep = np.array(speaker_ids) == self.speaker_id
            paths = list(np.array(paths)[keep])
            kept_lengths = list(np.array(self.frame_lengths)[keep])
            # numpy.int64 values are cast back to plain int
            self.frame_lengths = [int(n) for n in kept_lengths]

        return paths

    def collect_features(self, path):
        """Load one feature array from ``path``."""
        return np.load(path)

class MelSpecDataSource(_NPYDataSource):
    """Feature files taken from column 1 of ``train.txt``."""

    def __init__(self, data_root, speaker_id=None):
        super().__init__(data_root, col=1, speaker_id=speaker_id)


class LinearSpecDataSource(_NPYDataSource):
    """Feature files taken from column 0 of ``train.txt``."""

    def __init__(self, data_root, speaker_id=None):
        super().__init__(data_root, col=0, speaker_id=speaker_id)

class VisualDataSource(FileDataSource):
    """Video files listed in column 4 of ``<data_root>/train.txt``.

    Each metadata line is ``|``-separated with 5 columns (single speaker) or
    6 columns (multi speaker, speaker id last).
    """

    def __init__(self, data_root, speaker_id=None):
        self.data_root = data_root
        self.speaker_ids = None
        self.multi_speaker = False
        # If not None, filter by speaker_id
        self.speaker_id = speaker_id

    def collect_files(self):
        """Read the metadata file and return the video file names.

        Returns:
            A list of video file names (relative to ``data_root``), or
            ``(video_files, speaker_ids)`` when the metadata is multi-speaker
            and no ``speaker_id`` filter was given.
        """
        meta = join(self.data_root, "train.txt")
        with open(meta, "rb") as f:
            lines = f.readlines()
        l = lines[0].decode("utf-8").split("|")
        assert len(l) == 5 or len(l) == 6
        self.multi_speaker = len(l) == 6
        video_file = list(map(lambda l: l.decode("utf-8").split("|")[4].strip(), lines))
        if self.multi_speaker:
            speaker_ids = list(map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines))
            # Filter by speaker_id
            # using multi-speaker dataset as a single speaker dataset
            if self.speaker_id is not None:
                indices = np.array(speaker_ids) == self.speaker_id
                # Filter the video list itself; the previous code referenced an
                # undefined ``texts`` variable here and raised.
                video_file = list(np.array(video_file)[indices])
                self.multi_speaker = False
                return video_file

            # NOTE(review): collect_features below takes only a path; confirm
            # how the caller passes the speaker id in this mode.
            return video_file, speaker_ids
        else:
            return video_file

    def collect_features(self, path):
        """Decode one video into a float array scaled to ``[0, 1]``.

        Args:
            path: Video file name relative to ``data_root``.

        Returns:
            Array of shape ``(n_frames + 1, ...)``: the decoded frames divided
            by 255, followed by one all-white "stop" frame.

        Raises:
            IOError: If no frame could be read from the file.
        """
        path = os.path.join(self.data_root, path)
        cap = cv2.VideoCapture(path)
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video = []
            while len(video) < frame_count:
                ret, frame = cap.read()
                if not ret:
                    # The reported frame count can exceed what is decodable;
                    # stop instead of storing the invalid frame.
                    break
                video.append(frame)
        finally:
            cap.release()

        if not video:
            raise IOError("Could not read any frames from video: {}".format(path))

        last = video[-1]
        video.append(np.ones(last.shape, dtype=last.dtype) * 255)  # stop frame, all white
        x = np.asarray(video) / 255.
        return x
        
        
class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
    """Partially randmoized sampler

    1. Sort by lengths
    2. Pick a small patch and randomize it
    3. Permutate mini-batchs
    """

    def __init__(self, lengths, batch_size=16, batch_group_size=None,
                 permutate=True):
        self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths))
        self.batch_size = batch_size
        if batch_group_size is None:
            batch_group_size = min(batch_size * 32, len(self.lengths))
            if batch_group_size % batch_size != 0:
                batch_group_size -= batch_group_size % batch_size

        self.batch_group_size = batch_group_size
        assert batch_group_size % batch_size == 0
        self.permutate = permutate

    def __iter__(self):
        indices = self.sorted_indices.clone()
        batch_group_size = self.batch_group_size
        s, e = 0, 0
        for i in range(len(indices) // batch_group_size):
            s = i * batch_group_size
            e = s + batch_group_size
            random.shuffle(indices[s:e])

        # Permutate batches
        if self.permutate:
            perm = np.arange(len(indices[:e]) // self.batch_size)
            random.shuffle(perm)
            indices[:e] = indices[:e].view(-1, self.batch_size)[perm, :].view(-1)

        # Handle last elements
        s += batch_group_size
        if s < len(indices):
            random.shuffle(indices[s:])

        return iter(indices)

    def __len__(self):
        return len(self.sorted_indices)


class PyTorchDataset(object):
    """Zips the text, mel, linear-spectrogram and video datasets together.

    Args:
        X: Text dataset; ``X.file_data_source.multi_speaker`` tells whether
            its items are ``(text, speaker_id)`` pairs.
        Mel: Mel-spectrogram dataset.
        Y: Linear-spectrogram dataset.
        V: Video dataset.

    Items are ``(text, mel, linear, video)`` for single-speaker data and
    ``(text, mel, linear, video, speaker_id)`` for multi-speaker data.
    """

    def __init__(self, X, Mel, Y, V):
        self.X = X
        self.Mel = Mel
        self.Y = Y
        self.V = V
        # alias
        self.multi_speaker = X.file_data_source.multi_speaker

    def __getitem__(self, idx):
        if self.multi_speaker:
            text, speaker_id = self.X[idx]
            # Keep the video at index 3 and put the speaker id last, so the
            # item is the 5-tuple that ``collate_fn`` treats as multi-speaker.
            return text, self.Mel[idx], self.Y[idx], self.V[idx], speaker_id
        else:
            return self.X[idx], self.Mel[idx], self.Y[idx], self.V[idx]

    def __len__(self):
        return len(self.X)


def sequence_mask(sequence_length, max_len=None):
    """Build a float mask marking the valid positions of each sequence.

    Args:
        sequence_length: 1-D integer tensor of per-sequence lengths.
        max_len: Width of the mask; defaults to the largest length.

    Returns:
        Float tensor of shape ``(batch, max_len)`` on the same device as
        ``sequence_length``; entry ``[b, t]`` is 1 when ``t < length[b]``
        and 0 otherwise.
    """
    if max_len is None:
        max_len = sequence_length.max()
    max_len = int(max_len)
    # Create the position range on the device of the lengths; ``.cuda()``
    # would always pick the current default GPU.
    positions = torch.arange(0, max_len, device=sequence_length.device).long()
    mask = positions.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask.float()


class MaskedL1Loss(nn.Module):
    """L1 loss averaged over the unmasked elements only."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss(reduction="sum")

    def forward(self, input, target, lengths=None, mask=None, max_len=None):
        """Return the summed absolute error divided by the number of kept elements.

        Either ``mask`` (broadcastable to ``input`` via ``expand_as``) or
        ``lengths`` (turned into a mask with ``sequence_mask``) is required;
        a ``RuntimeError`` is raised when both are missing.
        """
        if mask is None:
            if lengths is None:
                raise RuntimeError("Should provide either lengths or mask")
            # (B, T, 1)
            mask = sequence_mask(lengths, max_len).unsqueeze(-1)

        # (B, T, D)
        expanded = mask.expand_as(input)
        absolute_error = self.criterion(input * expanded, target * expanded)
        return absolute_error / expanded.sum()


def collate_fn(batch):
    """Create a padded batch from a list of dataset items.

    Each item is ``(text, mel, linear, video)`` or, for multi-speaker data,
    ``(text, mel, linear, video, speaker_id)``.

    Returns:
        Tuple ``(x_batch, (input_lengths, vid_input_lengths), mel_batch,
        y_batch, v_batch, (text_positions, frame_positions, vid_positions),
        done, target_lengths, speaker_ids)``. ``speaker_ids`` is ``None`` for
        single-speaker batches.
    """
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step
    multi_speaker = len(batch[0]) == 5

    # Lengths
    input_lengths = [len(x[0]) for x in batch]
    max_input_len = max(input_lengths)

    target_lengths = [len(x[1]) for x in batch]

    # Round the target length up to a multiple of r and of downsample_step.
    max_target_len = max(target_lengths)
    if max_target_len % r != 0:
        max_target_len += r - max_target_len % r
        assert max_target_len % r == 0
    if max_target_len % downsample_step != 0:
        max_target_len += downsample_step - max_target_len % downsample_step
        assert max_target_len % downsample_step == 0

    # Set 0 for zero beginning padding
    # imitates initial decoder states
    b_pad = r
    max_target_len += b_pad * downsample_step

    # np.int64 instead of the ``np.int`` alias, which newer NumPy releases removed.
    a = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
    x_batch = torch.LongTensor(a)

    input_lengths = torch.LongTensor(input_lengths)
    target_lengths = torch.LongTensor(target_lengths)

    b = np.array([_pad_2d(x[1], max_target_len, b_pad=b_pad) for x in batch],
                 dtype=np.float32)
    mel_batch = torch.FloatTensor(b)

    c = np.array([_pad_2d(x[2], max_target_len, b_pad=b_pad) for x in batch],
                 dtype=np.float32)
    y_batch = torch.FloatTensor(c)

    # Video: pad every clip to the longest one in the batch.
    vid_lengths = [x[3].shape[0] for x in batch]
    # Take the maximum as a plain int before converting the lengths to a tensor.
    max_vid_len = max(vid_lengths)
    vid_input_lengths = torch.LongTensor(vid_lengths)
    d = np.array([_pad_4d(x[3], max_vid_len) for x in batch],
                 dtype=np.float32)
    v_batch = torch.FloatTensor(d)
    vid_positions = np.array([_pad(np.arange(1, x[3].shape[0] + 1), max_vid_len)
                              for x in batch], dtype=np.int64)
    vid_positions = torch.LongTensor(vid_positions)

    # text positions (1-based, 0 marks padding)
    text_positions = np.array([_pad(np.arange(1, len(x[0]) + 1), max_input_len)
                               for x in batch], dtype=np.int64)
    text_positions = torch.LongTensor(text_positions)

    max_decoder_target_len = max_target_len // r // downsample_step

    # frame positions
    s, e = 1, max_decoder_target_len + 1
    # NOTE: needs clone to suppress RuntimeError in the data loader...
    # ref: https://github.com/pytorch/pytorch/issues/10756
    frame_positions = torch.arange(s, e).long().unsqueeze(0).expand(
        len(batch), max_decoder_target_len).clone()

    # done flags: 0 while decoding, 1 from the last real decoder step onwards
    done = np.array([_pad(np.zeros(len(x[1]) // r // downsample_step - 1),
                          max_decoder_target_len, constant_values=1)
                     for x in batch])
    done = torch.FloatTensor(done).unsqueeze(-1)

    if multi_speaker:
        # Index 3 holds the video; the speaker id is the fifth element.
        speaker_ids = torch.LongTensor([x[4] for x in batch])
    else:
        speaker_ids = None

    return x_batch, (input_lengths, vid_input_lengths), mel_batch, y_batch, v_batch, \
        (text_positions, frame_positions, vid_positions), done, target_lengths, speaker_ids


def time_string():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d %H:%M')


def save_alignment(path, attn, title=None):
    """Plot the transposed attention matrix ``attn`` to ``path``.

    Args:
        path: Destination PNG file.
        attn: 2-D attention matrix; it is transposed before plotting.
        title: Optional plot title. Defaults to ``None`` so callers that only
            pass ``path`` and ``attn`` keep working.

    The caption below the plot records ``hparams.builder``, the current time
    and the module-level ``global_step``.
    """
    plot_alignment(attn.T, path, title, info="{}, {}, step={}".format(
        hparams.builder, time_string(), global_step))


def prepare_spec_image(spectrogram):
    """Turn a spectrogram into a uint8 image coloured with the magma colormap.

    The values are min-max scaled to ``[0, 1]``, flipped along axis 1 and
    transposed before the colormap is applied.
    """
    lowest = np.min(spectrogram)
    highest = np.max(spectrogram)
    # [0, 1]
    scaled = (spectrogram - lowest) / (highest - lowest)
    flipped = np.flip(scaled, axis=1)  # flip against freq axis
    colored = cm.magma(flipped.T) * 255
    return np.uint8(colored)



In [3]:
def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker):
    """Synthesize a fixed set of sentences and save alignments and audio.

    Outputs go to ``<checkpoint_dir>/eval``: one alignment PNG and one WAV
    file per sentence and speaker.

    Args:
        global_step: Step number embedded in the output file names.
        writer: Summary writer; currently unused because TensorBoard logging
            is disabled in this function.
        device: Device the evaluation copy of the model is moved to.
        model: Model whose weights are copied into a fresh evaluation model.
        checkpoint_dir: Base directory for the ``eval`` output folder.
        ismultispeaker: If true, the hard-coded speaker ids 0, 1 and 10 are
            synthesized; otherwise a single run without speaker id.
    """
    # hard coded
    texts = [
        "Scientists at the CERN laboratory say they have discovered a new particle.",
        "There's a way to measure the acute emotional intelligence that has never gone out of style.",
        "President Trump met with other leaders at the Group of 20 conference.",
        "Generative adversarial network or variational auto-encoder.",
        "Please call Stella.",
        "Some have accepted this as a miracle without any physical explanation.",
    ]
    import synthesis
    synthesis._frontend = _frontend

    eval_output_dir = join(checkpoint_dir, "eval")
    os.makedirs(eval_output_dir, exist_ok=True)

    # Prepare model for evaluation
    model_eval = build_model().to(device)
    model_eval.load_state_dict(model.state_dict())

    # hard coded
    speaker_ids = [0, 1, 10] if ismultispeaker else [None]
    for speaker_id in speaker_ids:
        speaker_str = "multispeaker{}".format(speaker_id) if speaker_id is not None else "single"

        for idx, text in enumerate(texts):
            signal, alignment, _, mel = synthesis.tts(
                model_eval, text, p=0, speaker_id=speaker_id, fast=True)
            # Peak-normalize; skip the division for an all-zero signal.
            peak = np.max(np.abs(signal))
            if peak > 0:
                signal /= peak

            # Alignment
            path = join(eval_output_dir, "step{:09d}_text{}_{}_alignment.png".format(
                global_step, idx, speaker_str))
            # save_alignment takes a title argument; pass None for an untitled plot.
            save_alignment(path, alignment, None)

            # Audio
            path = join(eval_output_dir, "step{:09d}_text{}_{}_predicted.wav".format(
                global_step, idx, speaker_str))
            audio.save_wav(signal, path)

            # TensorBoard logging (alignment image, mel image, audio) is
            # currently disabled.


def save_states(global_step, writer, mel_outputs, linear_outputs, attn, attn_f, mel, y,
                input_lengths, checkpoint_dir=None):
    """Dump alignment plots and a predicted waveform for one batch item.

    The item inspected is index 1 of the batch (index 0 if the batch holds a
    single item). Text attention (``attn``) and video attention (``attn_f``)
    are plotted per layer and averaged over layers when they are 4-D.
    ``writer`` is not used; TensorBoard logging is disabled here, but the
    spectrogram images are still computed.
    """
    print("Save intermediate states at step {}".format(global_step))

    idx = min(1, len(input_lengths) - 1)
    input_length = input_lengths[idx]

    def _as_numpy(tensor):
        return tensor.cpu().data.numpy()

    def _plot_attention(attention, title, layer_file_fmt, average_file_fmt):
        # Only multi-hop attention, i.e. a 4-D tensor, is plotted.
        if attention is None or attention.dim() != 4:
            return

        for layer_no, layer_attention in enumerate(attention, start=1):
            alignment = _as_numpy(layer_attention[idx])
            try:
                alignment_dir = join(
                    checkpoint_dir, "alignment_layer{}".format(layer_no))
                os.makedirs(alignment_dir, exist_ok=True)
                path = join(alignment_dir,
                            layer_file_fmt.format(global_step, layer_no))
                save_alignment(path, alignment, title)
            except Exception as e:
                warn(str(e))

        # Save averaged alignment
        alignment_dir = join(checkpoint_dir, "alignment_ave")
        os.makedirs(alignment_dir, exist_ok=True)
        path = join(alignment_dir, average_file_fmt.format(global_step))
        alignment = _as_numpy(attention.mean(0)[idx])
        save_alignment(path, alignment, title)

    # Text alignment
    _plot_attention(attn, "Text",
                    "step{:09d}_layer_{}_alignment.png",
                    "step{:09d}_layer_alignment.png")
    # Video alignment
    _plot_attention(attn_f, "Video",
                    "step{:09d}_layer_{}_alignment_v.png",
                    "step{:09d}_layer_alignment_v.png")

    # Predicted mel spectrogram
    if mel_outputs is not None:
        predicted_mel = _as_numpy(mel_outputs[idx])
        predicted_mel_image = prepare_spec_image(audio._denormalize(predicted_mel))

    # Predicted spectrogram
    if linear_outputs is not None:
        predicted_linear = _as_numpy(linear_outputs[idx])
        predicted_linear_image = prepare_spec_image(audio._denormalize(predicted_linear))

        # Predicted audio signal
        signal = audio.inv_spectrogram(predicted_linear.T)
        signal /= np.max(np.abs(signal))
        path = join(checkpoint_dir, "step{:09d}_predicted.wav".format(
            global_step))
        audio.save_wav(signal, path)

    # Target mel spectrogram
    if mel_outputs is not None:
        target_mel = _as_numpy(mel[idx])
        target_mel_image = prepare_spec_image(audio._denormalize(target_mel))

    # Target spectrogram
    if linear_outputs is not None:
        target_linear = _as_numpy(y[idx])
        target_linear_image = prepare_spec_image(audio._denormalize(target_linear))


def logit(x, eps=1e-8):
    """Return ``log(x + eps) - log(1 - x + eps)`` element-wise."""
    log_p = torch.log(x + eps)
    log_not_p = torch.log(1 - x + eps)
    return log_p - log_not_p


def masked_mean(y, mask):
    """Mean of ``y`` over the elements selected by ``mask``.

    ``mask`` is expanded to the shape of ``y`` (B, T, D); the masked sum is
    divided by the sum of the expanded mask.
    """
    expanded = mask.expand_as(y)
    masked_total = (y * expanded).sum()
    return masked_total / expanded.sum()


def spec_loss(y_hat, y, mask, priority_bin=None, priority_w=0):
    """Compute the L1 and binary-divergence losses between two spectrograms.

    With ``hparams.masked_loss_weight`` = w > 0, every term is a blend of
    ``w`` times the masked variant and ``1 - w`` times the plain variant, and
    ``mask`` is required; with w <= 0, ``mask`` must be ``None``.

    Args:
        y_hat: Predicted spectrogram, indexed as (B, T, D).
        y: Target spectrogram with the same shape.
        mask: Mask passed to the masked loss terms, or ``None``.
        priority_bin: If set (and ``priority_w`` > 0), the first
            ``priority_bin`` bins of the last axis get an additional L1 term.
        priority_w: Blend weight of that additional term.

    Returns:
        ``(l1_loss, binary_div)``; ``binary_div`` is a zero tensor when
        ``hparams.binary_divergence_weight`` <= 0.
    """
    masked_l1 = MaskedL1Loss()
    l1 = nn.L1Loss()

    w = hparams.masked_loss_weight
    use_mask = w > 0

    def _blended_l1(prediction, target):
        if use_mask:
            return w * masked_l1(prediction, target, mask=mask) \
                + (1 - w) * l1(prediction, target)
        return l1(prediction, target)

    # L1 loss
    if use_mask:
        assert mask is not None
    else:
        assert mask is None
    l1_loss = _blended_l1(y_hat, y)

    # Priority L1 loss
    if priority_bin is not None and priority_w > 0:
        priority_loss = _blended_l1(
            y_hat[:, :, :priority_bin], y[:, :, :priority_bin])
        l1_loss = (1 - priority_w) * l1_loss + priority_w * priority_loss

    # Binary divergence loss
    if hparams.binary_divergence_weight <= 0:
        return l1_loss, y.data.new(1).zero_()

    y_hat_logits = logit(y_hat)
    z = -y * y_hat_logits + torch.log1p(torch.exp(y_hat_logits))
    if use_mask:
        binary_div = w * masked_mean(z, mask) + (1 - w) * z.mean()
    else:
        binary_div = z.mean()

    return l1_loss, binary_div


@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    """Build one guided-attention weight matrix of shape ``(max_N, max_T)``.

    For ``n < N`` and ``t < T`` the weight is
    ``1 - exp(-(n / N - t / T)**2 / (2 * g * g))``; all other entries stay 0.
    """
    weights = np.zeros((max_N, max_T), dtype=np.float32)
    two_g_squared = 2 * g * g
    for n in range(N):
        for t in range(T):
            offset = n / N - t / T
            weights[n, t] = 1 - np.exp(-offset**2 / two_g_squared)
    return weights


def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
    """Stack per-sample guided-attention matrices for a batch.

    Returns a float32 array of shape
    ``(batch, max_target_len, max(input_lengths))``; each slice is the
    transposed result of ``guided_attention`` for that sample.
    """
    batch_size = len(input_lengths)
    widest_input = input_lengths.max()
    stacked = np.zeros((batch_size, max_target_len, widest_input), dtype=np.float32)
    for sample in range(batch_size):
        weights = guided_attention(input_lengths[sample], widest_input,
                                   target_lengths[sample], max_target_len, g)
        stacked[sample] = weights.T
    return stacked



In [4]:
def train(device, model, data_loader, optimizer, writer,
          init_lr=0.002,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None,
          clip_thresh=1.0,
          train_seq2seq=True, train_postnet=True, batch_size=4):
    """Run the training loop until ``global_epoch`` reaches ``nepochs``.

    Args:
        device: Device the batch tensors are moved to.
        model: Model to train; the seq2seq part, the postnet, or both are
            trained depending on the flags below.
        data_loader: Yields batches in the layout produced by ``collate_fn``.
        optimizer: Optimizer whose learning rate is overwritten each step when
            ``hparams.lr_schedule`` is set.
        writer: Summary writer; unused because scalar logging is disabled.
        init_lr: Base learning rate passed to the schedule.
        checkpoint_dir: Directory for intermediate states and checkpoints.
        checkpoint_interval: Save every this many steps; ``None`` disables
            saving.
        nepochs: Number of epochs to train for.
        clip_thresh: Gradient-norm clipping threshold; <= 0 disables clipping.
        train_seq2seq: Train the seq2seq part.
        train_postnet: Train the postnet part.
        batch_size: Forwarded to the model when both parts are trained.

    Raises:
        RuntimeError: If a text or decoder sequence reaches
            ``hparams.max_positions``.

    Updates the module-level ``global_step`` and ``global_epoch`` counters.
    """
    linear_dim = model.linear_dim
    r = hparams.outputs_per_step
    downsample_step = hparams.downsample_step
    current_lr = init_lr
    print(device)
    binary_criterion = nn.BCELoss()

    assert train_seq2seq or train_postnet

    # Resolve the schedule function once instead of on every step.
    lr_schedule_f = None
    if hparams.lr_schedule is not None:
        lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)

    global global_step, global_epoch
    while global_epoch < nepochs:
        running_loss = 0.
        for step, (x, input_audvid_lengths, mel, y, v, positions, done, target_lengths,
                   speaker_ids) \
                in tqdm(enumerate(data_loader)):
            model.train()
            ismultispeaker = speaker_ids is not None
            # Learning rate schedule
            if lr_schedule_f is not None:
                current_lr = lr_schedule_f(
                    init_lr, global_step, **hparams.lr_schedule_kwargs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
            optimizer.zero_grad()

            # Used for Position encoding
            text_positions, frame_positions, vid_positions = positions
            input_lengths, video_length = input_audvid_lengths
            # Downsample mel spectrogram
            if downsample_step > 1:
                mel = mel[:, 0::downsample_step, :].contiguous()

            # Lengths
            input_lengths = input_lengths.long().numpy()
            video_length = video_length.long().numpy()
            decoder_lengths = target_lengths.long().numpy() // r // downsample_step

            max_seq_len = max(input_lengths.max(), decoder_lengths.max())
            if max_seq_len >= hparams.max_positions:
                raise RuntimeError(
                    """max_seq_len ({}) >= max_posision ({})
Input text or decoder targget length exceeded the maximum length.
Please set a larger value for ``max_position`` in hyper parameters.""".format(
                        max_seq_len, hparams.max_positions))

            # Transform data to CUDA device
            if train_seq2seq:
                x = x.to(device)
                text_positions = text_positions.to(device)
                frame_positions = frame_positions.to(device)
            if train_postnet:
                y = y.to(device)
            v = v.to(device)
            vid_positions = vid_positions.to(device)
            mel, done = mel.to(device), done.to(device)
            target_lengths = target_lengths.to(device)
            speaker_ids = speaker_ids.to(device) if ismultispeaker else None

            # Create mask if we use masked loss
            if hparams.masked_loss_weight > 0:
                # decoder output domain mask
                decoder_target_mask = sequence_mask(
                    target_lengths // (r * downsample_step),
                    max_len=mel.size(1)).unsqueeze(-1)
                if downsample_step > 1:
                    # spectrogram-domain mask
                    target_mask = sequence_mask(
                        target_lengths, max_len=y.size(1)).unsqueeze(-1)
                else:
                    target_mask = decoder_target_mask
                # shift mask
                decoder_target_mask = decoder_target_mask[:, r:, :]
                target_mask = target_mask[:, r:, :]
            else:
                decoder_target_mask, target_mask = None, None

            # Apply model
            if train_seq2seq and train_postnet:
                mel_outputs, linear_outputs, attn, done_hat, attn_f = model(
                    x, mel, speaker_ids=speaker_ids,
                    text_positions=text_positions, frame_positions=frame_positions, vid_positions=vid_positions,
                    input_lengths=input_lengths, faces=v, video_lengths=video_length, batch_size=batch_size)
            elif train_seq2seq:
                assert speaker_ids is None
                mel_outputs, attn, done_hat, _ = model.seq2seq(
                    x, mel,
                    text_positions=text_positions, frame_positions=frame_positions,
                    input_lengths=input_lengths)
                # reshape
                mel_outputs = mel_outputs.view(len(mel), -1, mel.size(-1))
                linear_outputs = None
                # No video attention is taken from this call; bind the name so
                # the guided-attention and checkpoint code below can test it.
                attn_f = None
            elif train_postnet:
                assert speaker_ids is None
                linear_outputs = model.postnet(mel)
                mel_outputs, attn, done_hat, attn_f = None, None, None, None

            # Losses
            w = hparams.binary_divergence_weight

            # mel:
            if train_seq2seq:
                mel_l1_loss, mel_binary_div = spec_loss(
                    mel_outputs[:, :-r, :], mel[:, r:, :], decoder_target_mask)
                mel_loss = (1 - w) * mel_l1_loss + w * mel_binary_div

            # done:
            if train_seq2seq:
                done_loss = binary_criterion(done_hat, done)

            # linear:
            if train_postnet:
                n_priority_freq = int(hparams.priority_freq / (hparams.sample_rate * 0.5) * linear_dim)
                linear_l1_loss, linear_binary_div = spec_loss(
                    linear_outputs[:, :-r, :], y[:, r:, :], target_mask,
                    priority_bin=n_priority_freq,
                    priority_w=hparams.priority_freq_weight)
                linear_loss = (1 - w) * linear_l1_loss + w * linear_binary_div

            # Combine losses
            if train_seq2seq and train_postnet:
                loss = mel_loss + linear_loss + done_loss
            elif train_seq2seq:
                loss = mel_loss + done_loss
            elif train_postnet:
                loss = linear_loss

            # attention
            if train_seq2seq and hparams.use_guided_attention:
                soft_mask = guided_attentions(input_lengths, decoder_lengths,
                                              attn.size(-2),
                                              g=hparams.guided_attention_sigma)
                soft_mask = torch.from_numpy(soft_mask).to(device)
                attn_loss = (attn * soft_mask).mean()
                loss += attn_loss

                # The video attention exists only when the full model ran.
                if attn_f is not None:
                    soft_mask_f = guided_attentions(video_length, decoder_lengths,
                                                    attn_f.size(-2),
                                                    g=hparams.guided_attention_sigma)
                    soft_mask_f = torch.from_numpy(soft_mask_f).to(device)
                    attn_loss_f = (attn_f * soft_mask_f).mean()
                    loss += attn_loss_f

            # Skip saving entirely when no checkpoint interval was configured.
            if checkpoint_interval is not None and global_step > 0 \
                    and global_step % checkpoint_interval == 0:
                save_states(
                    global_step, writer, mel_outputs, linear_outputs, attn, attn_f,
                    mel, y, input_lengths, checkpoint_dir)
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch,
                    train_seq2seq, train_postnet)

            # Periodic evaluation via eval_model is currently disabled.

            # Update
            loss.backward()
            if clip_thresh > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.get_trainable_parameters(), clip_thresh)
            optimizer.step()

            # Per-step scalar logging through ``writer`` is currently disabled.

            global_step += 1
            running_loss += loss.item()

        averaged_loss = running_loss / (len(data_loader))
        print("Loss: {}".format(averaged_loss))

        global_epoch += 1




In [5]:
def build_model():
    """Instantiate the network selected by ``hparams.builder``.

    The factory is looked up by name on the ``builder`` module and called with
    keyword arguments read from ``hparams``. The only value not taken from
    ``hparams`` is ``n_vocab``, which comes from the module-level ``_frontend``
    (so ``_frontend`` must be set before this is called).

    Returns:
        Whatever the selected factory returns.
    """
    factory = getattr(builder, hparams.builder)
    factory_kwargs = {
        "n_speakers": hparams.n_speakers,
        "speaker_embed_dim": hparams.speaker_embed_dim,
        "n_vocab": _frontend.n_vocab,
        "embed_dim": hparams.text_embed_dim,
        "mel_dim": hparams.num_mels,
        # linear_dim is derived from the FFT size: fft_size // 2 + 1.
        "linear_dim": hparams.fft_size // 2 + 1,
        "r": hparams.outputs_per_step,
        "downsample_step": hparams.downsample_step,
        "padding_idx": hparams.padding_idx,
        "dropout": hparams.dropout,
        "kernel_size": hparams.kernel_size,
        "encoder_channels": hparams.encoder_channels,
        "decoder_channels": hparams.decoder_channels,
        "converter_channels": hparams.converter_channels,
        "use_memory_mask": hparams.use_memory_mask,
        "trainable_positional_encodings": hparams.trainable_positional_encodings,
        "force_monotonic_attention": hparams.force_monotonic_attention,
        "use_decoder_state_for_postnet_input": hparams.use_decoder_state_for_postnet_input,
        "max_positions": hparams.max_positions,
        "speaker_embedding_weight_std": hparams.speaker_embedding_weight_std,
        "freeze_embedding": hparams.freeze_embedding,
        "window_ahead": hparams.window_ahead,
        "window_backward": hparams.window_backward,
        "key_projection": hparams.key_projection,
        "value_projection": hparams.value_projection,
    }
    return factory(**factory_kwargs)



In [6]:

def _load_embedding(path, model):
    """Overwrite the encoder token-embedding weights with those stored at ``path``.

    Args:
        path: Checkpoint location handed to ``_load``; the loaded object must
            contain a ``"state_dict"`` entry.
        model: Object exposing ``seq2seq.encoder.embed_tokens.weight``; its
            ``.data`` is replaced.
    """
    pretrained = _load(path)["state_dict"]["seq2seq.encoder.embed_tokens.weight"]
    embed_tokens = model.seq2seq.encoder.embed_tokens
    embed_tokens.weight.data = pretrained

# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3


def restore_parts(path, model):
    """Load into ``model`` the checkpoint entries whose names it also has.

    All matching entries are first loaded in a single call. If that raises
    ``RuntimeError`` they are retried one at a time; an entry the model rejects
    is reported with a warning and skipped, and the remaining entries are still
    loaded.

    Args:
        path: Checkpoint location handed to ``_load``; the loaded object must
            contain a ``"state_dict"`` entry.
        model: Object providing ``state_dict()`` and ``load_state_dict()``.
    """
    print("Restore part of the model from: {}".format(path))
    state = _load(path)["state_dict"]
    model_dict = model.state_dict()
    # Checkpoint entries the model has no parameter for are ignored.
    valid_state_dict = {k: v for k, v in state.items() if k in model_dict}

    try:
        model_dict.update(valid_state_dict)
        model.load_state_dict(model_dict)
    except RuntimeError as e:
        # there should be invalid size of weight(s), so load them per parameter
        print(str(e))
        # Start again from the model's current values.
        model_dict = model.state_dict()
        for k, v in valid_state_dict.items():
            previous = model_dict[k]
            model_dict[k] = v
            try:
                model.load_state_dict(model_dict)
            except RuntimeError as per_key_error:
                print(str(per_key_error))
                warn("{}: may contain invalid size of weight. skipping...".format(k))
                # Put the model's own value back. Leaving the rejected tensor
                # in model_dict would make every later load_state_dict call
                # fail as well, so valid entries would be reported as invalid.
                model_dict[k] = previous



In [7]:
# Resolve the text frontend named by hparams.frontend. This replaces the
# `_frontend = None` placeholder from the first cell; build_model() reads
# `_frontend.n_vocab`, so this cell has to run before the model is built.
_frontend = getattr(frontend, hparams.frontend)


In [8]:
# Root directory handed to every data source constructed below. The location
# is machine-specific, so it can be overridden through the
# DEEPVOICE3_DATA_ROOT environment variable; the original path is the default.
data_root = os.environ.get('DEEPVOICE3_DATA_ROOT', '/ssd_scratch/cvit/anchit/training/')

In [9]:
# Forwarded unchanged as the second argument to each data source built in the
# following cells.
# NOTE(review): None presumably means "do not filter by speaker" — confirm in
# the data source classes.
speaker_id=None

In [10]:
# X: nnmnkwii FileSourceDataset over a TextDataSource rooted at data_root.
# It is the first component handed to PyTorchDataset below.
X = FileSourceDataset(TextDataSource(data_root, speaker_id))

In [11]:
# Mel: FileSourceDataset over a MelSpecDataSource rooted at data_root. Its
# underlying data source also supplies the frame_lengths used for the sampler.
Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id))

In [12]:
# Y: FileSourceDataset over a LinearSpecDataSource rooted at data_root.
# It is the third component handed to PyTorchDataset below.
Y = FileSourceDataset(LinearSpecDataSource(data_root, speaker_id))

In [13]:
# V: FileSourceDataset over a VisualDataSource rooted at data_root.
# It is the fourth component handed to PyTorchDataset below.
V = FileSourceDataset(VisualDataSource(data_root, speaker_id))

In [14]:
# frame_lengths as recorded by the mel data source; the sampler in the next
# cell is constructed from them.
frame_lengths = Mel.file_data_source.frame_lengths

In [15]:
# Sampler built from the mel frame lengths and the configured batch size; it
# is handed to the DataLoader below.
sampler = PartialyRandomizedSimilarTimeLengthSampler(
    frame_lengths,
    batch_size=hparams.batch_size,
)

In [16]:
# Bundle the four sources (X, Mel, Y, V) into the single dataset that the
# DataLoader below iterates over.
dataset = PyTorchDataset(X, Mel, Y, V)

In [17]:
# Items are drawn from `dataset` in the order chosen by `sampler`, assembled
# into batches by collate_fn, and an incomplete final batch is dropped.
data_loader = data_utils.DataLoader(
    dataset,
    batch_size=hparams.batch_size,
    num_workers=hparams.num_workers,
    sampler=sampler,
    collate_fn=collate_fn,
    pin_memory=hparams.pin_memory,
    drop_last=True,
)
 

In [18]:
# Use the GPU when CUDA is available (use_cuda is computed in the first cell),
# otherwise fall back to the CPU.
device = torch.device("cuda" if use_cuda else "cpu")


In [19]:
# Build the network from hparams and move it to the selected device.
model = build_model().to(device)

In [20]:
 # Adam over the model's trainable parameters; every hyper-parameter is read
 # from hparams.
 optimizer = optim.Adam(
     model.get_trainable_parameters(),
     lr=hparams.initial_learning_rate,
     betas=(hparams.adam_beta1, hparams.adam_beta2),
     eps=hparams.adam_eps,
     weight_decay=hparams.weight_decay,
     amsgrad=hparams.amsgrad,
 )

In [21]:
# tensorboardX writer whose log directory is the current working directory.
# NOTE(review): the writer.add_scalar calls visible in train() are all
# commented out, so this writer may currently record nothing — confirm.
log_event_path = './'
writer = SummaryWriter(log_event_path)

In [22]:
# Notebook-level values defined ahead of the training call in the next cell.
# NOTE(review): linear_dim, r, downsample_step, current_lr and
# binary_criterion are not passed to train(...) explicitly; presumably they
# are read as globals inside train() — confirm.
linear_dim = model.linear_dim
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
current_lr = 0.002
# NOTE(review): the train(...) call in the next cell takes checkpoint_interval,
# nepochs and clip_thresh from hparams rather than from the names below, and
# re-assigns train_seq2seq / train_postnet itself. checkpoint_dir is the one
# value forwarded from here, and it is None.
checkpoint_interval=None
checkpoint_dir=None
nepochs=None
clip_thresh=1.0
train_seq2seq=True
train_postnet=True
binary_criterion = nn.BCELoss()

In [23]:
    # Run training with both train_seq2seq and train_postnet enabled.
    # NOTE(review): checkpoint_dir is None here (set in the previous cell);
    # confirm that the checkpoint-saving path inside train() tolerates that.
    # The output recorded below this cell ends in a RuntimeError (tensor size
    # 256 vs 512 at dimension 2), so the cell does not currently complete.
    train_seq2seq = train_postnet = True
    train(
        device,
        model,
        data_loader,
        optimizer,
        writer,
        init_lr=hparams.initial_learning_rate,
        checkpoint_dir=checkpoint_dir,
        checkpoint_interval=hparams.checkpoint_interval,
        nepochs=hparams.nepochs,
        clip_thresh=hparams.clip_thresh,
        train_seq2seq=train_seq2seq,
        train_postnet=train_postnet,
    )

cuda


0it [00:01, ?it/s]

deepvoice3 mai
torch.Size([4, 118, 256])
torch.Size([4, 256, 131])
torch.Size([4, 118, 256])
latest edits
torch.Size([4, 118, 512])
yaha valueka projection
yaha keys projection
deepvoice3 mai
torch.Size([4, 118, 512])
torch.Size([4, 512, 98])
torch.Size([4, 118, 512])





RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 2