In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import yaml
import json
import math
import os
import torch
import numpy as np 
from torch.utils.data import Dataset

from text import text_to_sequence
from utils.tools import pad_1D, pad_2D


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
class Args:
    """Attribute container holding the notebook's run settings.

    An instance is passed to ``get_model`` as ``args`` further down, and the
    three ``*_config`` attributes are the YAML files read by the
    config-loading cell. Presumably this mirrors the command-line arguments of
    the training script -- verify against the script.
    """

    # Step value read as ``args.restore_step``; the commented-out training
    # loop starts counting at ``restore_step + 1``.
    restore_step = 0

    # Paths of the YAML configuration files for the BC2013 setup.
    preprocess_config = "/work/tc046/tc046/lordzuko/work/SpeakingStyle/config/BC2013/preprocess.yaml"
    model_config = "/work/tc046/tc046/lordzuko/work/SpeakingStyle/config/BC2013/model.yaml"
    train_config = "/work/tc046/tc046/lordzuko/work/SpeakingStyle/config/BC2013/train.yaml"


args = Args()

In [4]:
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is opened in a ``with`` block so its handle is closed as soon as
    parsing finishes, instead of being left open until garbage collection.
    """
    with open(path, "r") as f:
        # FullLoader is kept so parsing behaves exactly as before; the files
        # read here are the project's own configuration files.
        return yaml.load(f, Loader=yaml.FullLoader)


preprocess_config = _load_yaml_config(args.preprocess_config)
model_config = _load_yaml_config(args.model_config)
train_config = _load_yaml_config(args.train_config)

# Bundled in (preprocess, model, train) order; this tuple is what gets passed
# to get_model as `configs`.
configs = preprocess_config, model_config, train_config
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [5]:
## dataset.py

class Dataset(torch.utils.data.Dataset):
    """Dataset of preprocessed utterances used to build training batches.

    Each item is a dict with the utterance id, speaker id, phoneme sequence,
    raw text, the ``mel`` array, phoneme-level pitch/energy, duration,
    frame-level reference pitch/energy and, optionally, a speaker embedding
    (see ``__getitem__``).

    ``collate_fn`` does not return a single batch: it receives one group of
    samples from the ``DataLoader``, optionally sorts it by text length, and
    splits it into a list of batches of ``batch_size`` samples each.

    The base class is spelled out as ``torch.utils.data.Dataset`` because this
    class reuses the name ``Dataset``. Writing ``class Dataset(Dataset)`` makes
    a re-executed notebook cell inherit from the previous definition of this
    very class rather than from the PyTorch base.
    """

    def __init__(
        self, filename, preprocess_config, model_config, train_config, sort=False, drop_last=False
    ):
        """Read the metadata file and the speaker map.

        Args:
            filename: Metadata file name, resolved relative to
                ``preprocess_config["path"]["preprocessed_path"]``.
            preprocess_config: Mapping providing ``dataset``,
                ``path.preprocessed_path``, ``preprocessing.text.text_cleaners``
                and ``preprocessing.speaker_embedder``.
            model_config: Mapping providing ``multi_speaker``.
            train_config: Mapping providing ``optimizer.batch_size``.
            sort: When true, ``collate_fn`` orders each group by descending
                text length before splitting it into batches.
            drop_last: When true, ``collate_fn`` discards the leftover samples
                of a group that do not fill a whole batch.
        """
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]
        # Speaker embeddings are only loaded for multi-speaker models that
        # have a speaker embedder configured.
        self.load_spker_embed = model_config["multi_speaker"] \
            and preprocess_config["preprocessing"]["speaker_embedder"] != 'none'

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of utterances listed in the metadata file."""
        return len(self.text)

    def _load_feature(self, folder, feature, speaker, basename):
        """Load ``<preprocessed_path>/<folder>/<speaker>-<feature>-<basename>.npy``."""
        path = os.path.join(
            self.preprocessed_path,
            folder,
            "{}-{}-{}.npy".format(speaker, feature, basename),
        )
        return np.load(path)

    def __getitem__(self, idx):
        """Load every stored feature of utterance ``idx`` and return them as a dict.

        Raises whatever ``np.load`` raises when a feature file is missing, and
        ``KeyError`` when the speaker is absent from ``speakers.json``.
        """
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        # text_to_sequence is a project helper; presumably it maps the phoneme
        # string to integer symbol ids using the configured cleaners.
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        mel = self._load_feature("mel", "mel", speaker, basename)
        pitch = self._load_feature("pitch", "pitch", speaker, basename)  # Phoneme Level
        ref_pitch = self._load_feature("pitch_frame", "pitch", speaker, basename)  # Frame Level
        energy = self._load_feature("energy", "energy", speaker, basename)  # Phoneme Level
        ref_energy = self._load_feature("energy_frame", "energy", speaker, basename)  # Frame Level
        duration = self._load_feature("duration", "duration", speaker, basename)

        # The speaker embedding is stored per speaker, not per utterance.
        spker_embed = None
        if self.load_spker_embed:
            spker_embed = np.load(os.path.join(
                self.preprocessed_path,
                "spker_embed",
                "{}-spker_embed.npy".format(speaker),
            ))

        return {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
            "spker_embed": spker_embed,
            "ref_pitch": ref_pitch,
            "ref_energy": ref_energy,
        }

    def process_meta(self, filename):
        """Parse the ``name|speaker|text|raw_text`` metadata file.

        Args:
            filename: File name relative to ``self.preprocessed_path``.

        Returns:
            Four parallel lists: names, speakers, texts, raw texts.

        Raises:
            ValueError: If a non-blank line does not have exactly four
                ``|``-separated fields. The message names the file and line.
        """
        name, speaker, text, raw_text = [], [], [], []
        meta_path = os.path.join(self.preprocessed_path, filename)
        with open(meta_path, "r", encoding="utf-8") as f:
            # Iterate lazily instead of materialising the file with readlines().
            for line_no, line in enumerate(f, start=1):
                line = line.strip("\n")
                # A blank (e.g. trailing) line carries no utterance; skip it
                # rather than failing on the unpack below.
                if not line.strip():
                    continue
                fields = line.split("|")
                if len(fields) != 4:
                    raise ValueError(
                        "{}:{}: expected 4 '|'-separated fields, got {}".format(
                            meta_path, line_no, len(fields)
                        )
                    )
                n, s, t, r = fields
                name.append(n)
                speaker.append(s)
                text.append(t)
                raw_text.append(r)
        return name, speaker, text, raw_text

    def reprocess(self, data, idxs):
        """Assemble the samples ``data[i] for i in idxs`` into one padded batch.

        Returns a tuple in this order: ids, raw_texts, speakers, texts,
        text_lens, max text length, mels, mel_lens, max mel length, pitches,
        energies, durations, spker_embeds, ref_pitches, ref_energies.
        """
        batch = [data[idx] for idx in idxs]

        ids = [sample["id"] for sample in batch]
        speakers = [sample["speaker"] for sample in batch]
        texts = [sample["text"] for sample in batch]
        raw_texts = [sample["raw_text"] for sample in batch]
        mels = [sample["mel"] for sample in batch]
        pitches = [sample["pitch"] for sample in batch]
        ref_pitches = [sample["ref_pitch"] for sample in batch]
        energies = [sample["energy"] for sample in batch]
        ref_energies = [sample["ref_energy"] for sample in batch]
        durations = [sample["duration"] for sample in batch]
        spker_embeds = np.concatenate(np.array([sample["spker_embed"] for sample in batch]), axis=0) \
            if self.load_spker_embed else None

        # Lengths are taken before padding; the first axis is the length axis.
        text_lens = np.array([text.shape[0] for text in texts])
        mel_lens = np.array([mel.shape[0] for mel in mels])

        speakers = np.array(speakers)
        # pad_1D / pad_2D are project helpers; presumably they pad every item
        # to the longest one in the batch -- confirm in utils.tools.
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        ref_pitches = pad_1D(ref_pitches)
        energies = pad_1D(energies)
        ref_energies = pad_1D(ref_energies)
        durations = pad_1D(durations)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            text_lens,
            max(text_lens),
            mels,
            mel_lens,
            max(mel_lens),
            pitches,
            energies,
            durations,
            spker_embeds,
            ref_pitches,
            ref_energies,
        )

    def collate_fn(self, data):
        """Split one group of samples into a list of padded batches.

        Args:
            data: List of sample dicts as produced by ``__getitem__``.

        Returns:
            A list of ``reprocess`` tuples, one per batch. Full batches hold
            ``self.batch_size`` samples; the smaller remainder is appended
            last unless ``self.drop_last`` is true.
        """
        if self.sort:
            # Descending text length, so similar lengths share a batch.
            len_arr = np.array([d["text"].shape[0] for d in data])
            order = np.argsort(-len_arr)
        else:
            order = np.arange(len(data))

        # Number of samples that fit into whole batches.
        n_full = len(order) - (len(order) % self.batch_size)
        batches = order[:n_full].reshape((-1, self.batch_size)).tolist()
        tail = order[n_full:]
        if not self.drop_last and len(tail) > 0:
            batches.append(tail.tolist())

        return [self.reprocess(data, idxs) for idxs in batches]


class TextDataset(Dataset):
    """Dataset that serves text plus reference features, without training targets.

    Items are tuples ``(basename, speaker_id, phone, raw_text, mel,
    spker_embed, ref_pitch, ref_energy)``; no phoneme-level pitch, energy or
    duration arrays are loaded. ``collate_fn`` returns a single padded batch.

    NOTE: in this file the name ``Dataset`` has already been rebound to the
    training dataset class defined earlier, so that class is the immediate
    parent here. ``__init__``, ``__len__``, ``__getitem__``, ``process_meta``
    and ``collate_fn`` are all overridden; the inherited ``reprocess`` is not
    called by this class.
    """

    def __init__(self, filepath, preprocess_config, model_config):
        """Read the metadata file and the speaker map.

        Args:
            filepath: Path of the metadata file. It is opened as given and is
                not joined with the preprocessed path.
            preprocess_config: Mapping providing ``path.preprocessed_path``,
                ``preprocessing.text.text_cleaners`` and
                ``preprocessing.speaker_embedder``.
            model_config: Mapping providing ``multi_speaker``.
        """
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        # Speaker embeddings are only loaded for multi-speaker models that
        # have a speaker embedder configured.
        self.load_spker_embed = model_config["multi_speaker"] \
            and preprocess_config["preprocessing"]["speaker_embedder"] != 'none'

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filepath
        )
        with open(
            os.path.join(
                preprocess_config["path"]["preprocessed_path"], "speakers.json"
            )
        ) as f:
            self.speaker_map = json.load(f)

    def __len__(self):
        """Return the number of utterances listed in the metadata file."""
        return len(self.text)

    def _load_feature(self, folder, feature, speaker, basename):
        """Load ``<preprocessed_path>/<folder>/<speaker>-<feature>-<basename>.npy``."""
        path = os.path.join(
            self.preprocessed_path,
            folder,
            "{}-{}-{}.npy".format(speaker, feature, basename),
        )
        return np.load(path)

    def __getitem__(self, idx):
        """Return the text and reference features of utterance ``idx`` as a tuple.

        Raises whatever ``np.load`` raises when a feature file is missing, and
        ``KeyError`` when the speaker is absent from ``speakers.json``.
        """
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        # text_to_sequence is a project helper; presumably it maps the phoneme
        # string to integer symbol ids using the configured cleaners.
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        mel = self._load_feature("mel", "mel", speaker, basename)
        ref_pitch = self._load_feature("pitch_frame", "pitch", speaker, basename)  # Frame Level
        ref_energy = self._load_feature("energy_frame", "energy", speaker, basename)  # Frame Level

        # The speaker embedding is stored per speaker, not per utterance.
        spker_embed = None
        if self.load_spker_embed:
            spker_embed = np.load(os.path.join(
                self.preprocessed_path,
                "spker_embed",
                "{}-spker_embed.npy".format(speaker),
            ))

        return (basename, speaker_id, phone, raw_text, mel, spker_embed, ref_pitch, ref_energy)

    def process_meta(self, filename):
        """Parse the ``name|speaker|text|raw_text`` metadata file at ``filename``.

        Returns:
            Four parallel lists: names, speakers, texts, raw texts.

        Raises:
            ValueError: If a non-blank line does not have exactly four
                ``|``-separated fields. The message names the file and line.
        """
        name, speaker, text, raw_text = [], [], [], []
        with open(filename, "r", encoding="utf-8") as f:
            # Iterate lazily instead of materialising the file with readlines().
            for line_no, line in enumerate(f, start=1):
                line = line.strip("\n")
                # A blank (e.g. trailing) line carries no utterance; skip it
                # rather than failing on the unpack below.
                if not line.strip():
                    continue
                fields = line.split("|")
                if len(fields) != 4:
                    raise ValueError(
                        "{}:{}: expected 4 '|'-separated fields, got {}".format(
                            filename, line_no, len(fields)
                        )
                    )
                n, s, t, r = fields
                name.append(n)
                speaker.append(s)
                text.append(t)
                raw_text.append(r)
        return name, speaker, text, raw_text

    def collate_fn(self, data):
        """Assemble a list of ``__getitem__`` tuples into one padded batch.

        Returns a tuple in this order: ids, raw_texts, speakers, texts,
        text_lens, max text length, mels, mel_lens, max mel length,
        spker_embeds, ref_pitches, ref_energies.
        """
        ids = [d[0] for d in data]
        speakers = np.array([d[1] for d in data])
        texts = [d[2] for d in data]
        raw_texts = [d[3] for d in data]
        mels = [d[4] for d in data]
        spker_embeds = np.concatenate(np.array([d[5] for d in data]), axis=0) \
            if self.load_spker_embed else None
        ref_pitches = [d[6] for d in data]
        ref_energies = [d[7] for d in data]

        # Lengths are taken before padding; the first axis is the length axis.
        text_lens = np.array([text.shape[0] for text in texts])
        mel_lens = np.array([mel.shape[0] for mel in mels])

        # pad_1D / pad_2D are project helpers; presumably they pad every item
        # to the longest one in the batch -- confirm in utils.tools.
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        ref_pitches = pad_1D(ref_pitches)
        ref_energies = pad_1D(ref_energies)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            text_lens,
            max(text_lens),
            mels,
            mel_lens,
            max(mel_lens),
            spker_embeds,
            ref_pitches,
            ref_energies,
        )


In [6]:
## train.py

import argparse
import os

import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_vocoder, get_param_num, get_named_param, get_model
from utils.tools import get_configs_of, to_device, log, synth_one_sample
# from model import FastSpeech2Loss

# Get dataset
# sort=True orders each group by descending text length inside collate_fn;
# drop_last=True discards the samples of a group that do not fill a whole batch.
dataset = Dataset(
    "train.txt", preprocess_config, model_config, train_config, sort=True, drop_last=True
)
batch_size = train_config["optimizer"]["batch_size"]
# The DataLoader hands collate_fn `batch_size * group_size` samples at a time,
# and collate_fn splits them into batches of `batch_size`. With a group size of
# 1 there would be nothing to sort across batches.
group_size = 4  # Set this larger than 1 to enable sorting in Dataset
# Sanity check: the dataset must hold more samples than a single group.
assert batch_size * group_size < len(dataset)
# Each item yielded by `loader` is therefore a LIST of batches, not one batch.
loader = DataLoader(
    dataset,
    batch_size=batch_size * group_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    num_workers=12,
)

In [7]:
# Prepare model
# get_model is a project helper (utils.model); here it returns the model and
# its optimizer. `ignore_layers` comes straight from the training config.
model, optimizer = get_model(
    args, configs, device, train=True, ignore_layers=train_config["ignore_layers"])
# Wrap for multi-GPU data parallelism; the underlying model is then reachable
# as `model.module`.
model = nn.DataParallel(model)
# Presumably the total parameter count (see the print label) -- confirm in utils.model.
num_param = get_param_num(model)
# Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
print("Number of FastSpeech2 Parameters:", num_param)

Number of FastSpeech2 Parameters: 50477403


In [8]:
# Load vocoder
# get_vocoder is a project helper (utils.model). The returned object is what
# the commented-out training loop passes to synth_one_sample.
vocoder = get_vocoder(model_config, device)

Removing weight norm...


In [9]:
# Fetch a single group from the loader and move its first batch to `device`.
# The nested-loop form with `break` only left the inner loop, so the outer loop
# still iterated over the whole loader just to keep one batch.
# The first group always holds full batches, because the loader cell asserts
# that the dataset is larger than one group.
batchs = next(iter(loader))
batch = to_device(batchs[0], device)

In [10]:
# Forward pass on the single batch fetched above. The first two batch entries
# (ids and raw_texts, per Dataset.reprocess) are skipped; everything from
# `speakers` onwards is passed positionally to the model.
output = model(*(batch[2:]))

594


In [None]:
# for batchs in loader:
#     for batch in batchs:
#         batch = to_device(batch, device)

#         # Forward
#         output = model(*(batch[2:]))
#         print(output)
#         break

In [None]:
# class FastSpeech2Loss(nn.Module):
#     """ FastSpeech2 Loss """

#     def __init__(self, preprocess_config, model_config):
#         super(FastSpeech2Loss, self).__init__()
#         self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
#             "feature"
#         ]
#         self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
#             "feature"
#         ]
#         self.mse_loss = nn.MSELoss()
#         self.mae_loss = nn.L1Loss()

#     def forward(self, inputs, predictions):
#         (
#             mel_targets,
#             _,
#             _,
#             pitch_targets,
#             energy_targets,
#             duration_targets,
#         ) = inputs[6:]
#         (
#             mel_predictions,
#             postnet_mel_predictions,
#             pitch_predictions,
#             energy_predictions,
#             log_duration_predictions,
#             _,
#             src_masks,
#             mel_masks,
#             _,
#             _,
#         ) = predictions
#         src_masks = ~src_masks
#         mel_masks = ~mel_masks
#         log_duration_targets = torch.log(duration_targets.float() + 1)
#         mel_targets = mel_targets[:, : mel_masks.shape[1], :]
#         mel_masks = mel_masks[:, :mel_masks.shape[1]]

#         log_duration_targets.requires_grad = False
#         pitch_targets.requires_grad = False
#         energy_targets.requires_grad = False
#         mel_targets.requires_grad = False

#         if self.pitch_feature_level == "phoneme_level":
#             pitch_predictions = pitch_predictions.masked_select(src_masks)
#             pitch_targets = pitch_targets.masked_select(src_masks)
#         elif self.pitch_feature_level == "frame_level":
#             pitch_predictions = pitch_predictions.masked_select(mel_masks)
#             pitch_targets = pitch_targets.masked_select(mel_masks)

#         if self.energy_feature_level == "phoneme_level":
#             energy_predictions = energy_predictions.masked_select(src_masks)
#             energy_targets = energy_targets.masked_select(src_masks)
#         if self.energy_feature_level == "frame_level":
#             energy_predictions = energy_predictions.masked_select(mel_masks)
#             energy_targets = energy_targets.masked_select(mel_masks)

#         log_duration_predictions = log_duration_predictions.masked_select(src_masks)
#         log_duration_targets = log_duration_targets.masked_select(src_masks)

#         mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
#         postnet_mel_predictions = postnet_mel_predictions.masked_select(
#             mel_masks.unsqueeze(-1)
#         )
#         mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))

#         mel_loss = self.mae_loss(mel_predictions, mel_targets)
#         postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)

#         pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
#         energy_loss = self.mse_loss(energy_predictions, energy_targets)
#         duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)

#         total_loss = (
#             mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
#         )

#         return (
#             total_loss,
#             mel_loss,
#             postnet_mel_loss,
#             pitch_loss,
#             energy_loss,
#             duration_loss,
#         )


In [13]:
# # Init logger
# for p in train_config["path"].values():
#     os.makedirs(p, exist_ok=True)
# train_log_path = os.path.join(train_config["path"]["log_path"], "train")
# val_log_path = os.path.join(train_config["path"]["log_path"], "val")
# os.makedirs(train_log_path, exist_ok=True)
# os.makedirs(val_log_path, exist_ok=True)
# train_logger = SummaryWriter(train_log_path)
# val_logger = SummaryWriter(val_log_path)

In [14]:
# # Training
# named_param = ['s_gamma', 's_beta']
# step = args.restore_step + 1
# epoch = 1
# grad_acc_step = train_config["optimizer"]["grad_acc_step"]
# grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
# total_step = train_config["step"]["total_step"]
# log_step = train_config["step"]["log_step"]
# save_step = train_config["step"]["save_step"]
# synth_step = train_config["step"]["synth_step"]
# val_step = train_config["step"]["val_step"]

# outer_bar = tqdm(total=total_step, desc="Training", position=0)
# outer_bar.n = args.restore_step
# outer_bar.update()


Training:   0%|          | 0/90000 [00:00<?, ?it/s]

In [15]:
# while True:
#     inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
#     for batchs in loader:
#         for batch in batchs:
#             batch = to_device(batch, device)

#             # Forward
#             output = model(*(batch[2:]))

#             # Cal Loss
#             losses = Loss(batch, output, step, get_named_param(model, named_param))
#             losses, lambdas = losses[:-2], losses[-2:]
#             total_loss = losses[0]

#             # Backward
#             total_loss = total_loss / grad_acc_step
#             total_loss.backward()
#             if step % grad_acc_step == 0:
#                 # Clipping gradients to avoid gradient explosion
#                 nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

#                 # Update weights
#                 lr = optimizer.step_and_update_lr()
#                 optimizer.zero_grad()

#             if step % log_step == 0:
#                 losses = [l.item() for l in losses]
#                 message1 = "Step {}/{}, ".format(step, total_step)
#                 message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Adv Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
#                     *losses
#                 )

#                 with open(os.path.join(train_log_path, "log.txt"), "a") as f:
#                     f.write(message1 + message2 + "\n")

#                 outer_bar.write(message1 + message2)

#                 log(train_logger, step, losses=losses, lr=lr, lambdas=lambdas)

#             if step % synth_step == 0:
#                 fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
#                     batch,
#                     output,
#                     vocoder,
#                     model_config,
#                     preprocess_config,
#                 )
#                 log(
#                     train_logger,
#                     fig=fig,
#                     tag="Training/step_{}_{}".format(step, tag),
#                 )
#                 sampling_rate = preprocess_config["preprocessing"]["audio"][
#                     "sampling_rate"
#                 ]
#                 log(
#                     train_logger,
#                     audio=wav_reconstruction,
#                     sampling_rate=sampling_rate,
#                     tag="Training/step_{}_{}_reconstructed".format(step, tag),
#                 )
#                 log(
#                     train_logger,
#                     audio=wav_prediction,
#                     sampling_rate=sampling_rate,
#                     tag="Training/step_{}_{}_synthesized".format(step, tag),
#                 )

#             if step % val_step == 0:
#                 model.eval()
#                 message = evaluate(model, step, configs, val_logger, vocoder, len(losses), named_param)
#                 with open(os.path.join(val_log_path, "log.txt"), "a") as f:
#                     f.write(message + "\n")
#                 outer_bar.write(message)

#                 model.train()

#             if step % save_step == 0:
#                 torch.save(
#                     {
#                         "model": model.module.state_dict(),
#                         "optimizer": optimizer._optimizer.state_dict(),
#                     },
#                     os.path.join(
#                         train_config["path"]["ckpt_path"],
#                         "{}.pth.tar".format(step),
#                     ),
#                 )

#             if step == total_step:
#                 quit()
#             step += 1
#             outer_bar.update(1)

#         inner_bar.update(1)
#     epoch += 1


Epoch 1:   0%|          | 0/4931 [00:00<?, ?it/s][A

RuntimeError: stack expects a non-empty TensorList