In [1]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class


def _mean_error_rates(decoder, targets, predictions):
    """Return the batch-mean ``(wer, cer)`` of ``predictions`` vs ``targets``."""
    pairs = list(zip(targets, predictions))
    wer = np.mean([decoder.wer(true, pred) for true, pred in pairs])
    cer = np.mean([decoder.cer(true, pred) for true, pred in pairs])
    return wer, cer


def train(config):
    """Train a QuartzNet CTC model described by ``config`` and log to wandb.

    Builds the per-sample and per-batch transform pipelines, the datasets and
    loaders, the model, optimizer and cosine LR schedule, then runs
    ``config.train.epochs`` epochs (default 10). Every epoch is followed by a
    validation pass; the model ``state_dict`` is saved (and uploaded with
    ``wandb.save``) whenever the mean validation WER improves.

    Args:
        config: Attribute-style mapping (e.g. an ``EasyDict``) with the
            ``dataset``, ``bpe``, ``train``, ``wandb`` and ``model`` sections
            read below.

    Returns:
        None.
    """
    fix_seeds(seed=config.train.get('seed', 42))
    # Resolve the dataset implementation by name as a submodule of `data`.
    dataset_module = importlib.import_module(
        f'.{config.dataset.name}', data.__name__)
    bpe = prepare_bpe(config)

    # Shared by the train and validation pipelines; computed once so the two
    # cannot drift apart.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sample_rate = config.dataset.get('sample_rate', 16000)
    normalize = config.dataset.get('normalize', None)
    epochs = config.train.get('epochs', 10)
    log_interval = config.wandb.get('log_interval', 5000)
    checkpoint_dir = config.train.get('checkpoint_path', 'checkpoints')

    # Per-sample training transforms: text preprocessing and BPE encoding,
    # then waveform augmentations, each applied with probability 0.5.
    transforms_train = Compose([
        TextPreprocess(),
        ToNumpy(),
        BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
        AudioSqueeze(),
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    ])

    # Per-batch training transforms, applied after collation.
    batch_transforms_train = Compose([
        ToGpu(device),
        NormalizedMelSpectrogram(
            sample_rate=sample_rate,
            n_mels=config.model.feat_in,
            normalize=normalize,
        ).to(device),
        MaskSpectrogram(
            probability=0.5,
            time_mask_max_percentage=0.05,
            frequency_mask_max_percentage=0.15,
        ),
        AddLengths(),
        Pad(),
    ])

    # Validation uses the same pipelines without augmentation or masking.
    transforms_val = Compose([
        TextPreprocess(),
        ToNumpy(),
        BPEtexts(bpe=bpe),
        AudioSqueeze(),
    ])

    batch_transforms_val = Compose([
        ToGpu(device),
        NormalizedMelSpectrogram(
            sample_rate=sample_rate,
            n_mels=config.model.feat_in,
            normalize=normalize,
        ).to(device),
        AddLengths(),
        Pad(),
    ])

    # load datasets
    train_dataset = dataset_module.get_dataset(
        config, transforms=transforms_train, part='train')
    val_dataset = dataset_module.get_dataset(
        config, transforms=transforms_val, part='val')
    # Batches are collated without padding; presumably the `Pad()` batch
    # transform above pads them afterwards -- TODO confirm.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train.get('batch_size', 1),
        collate_fn=no_pad_collate,
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=1, collate_fn=no_pad_collate)

    # Look the architecture up by name. A missing or unknown name falls back
    # to the default architecture object; the previous getattr default was the
    # *string* '_quartznet5x5_config', which is not a usable model config.
    default_model_name = '_quartznet5x5_config'
    model_name = config.model.get('name', default_model_name)
    model_config = getattr(quartznet_configs, model_name, None)
    if model_config is None:
        model_config = getattr(quartznet_configs, default_model_name)
    model = QuartzNet(
        model_config=model_config,
        **remove_from_dict(config.model, ['name'])
    )

    print(model)
    optimizer = torch.optim.Adam(
        model.parameters(), **config.train.get('optimizer', {}))
    steps_per_epoch = len(train_dataloader)
    # The scheduler is stepped once per batch, so the cosine period is the
    # total number of optimizer steps.
    num_steps = steps_per_epoch * epochs
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_steps)

    if config.train.get('from_checkpoint', None) is not None:
        model.load_weights(config.train.from_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    decoder = GreedyDecoder(bpe=bpe)

    prev_wer = 1000
    # Defined up front so the validation log below always has a step, even if
    # the training loader yields no batches.
    step = 0
    wandb.init(project=config.wandb.project, config=config)
    wandb.watch(model, log="all", log_freq=log_interval)
    for epoch_idx in tqdm(range(epochs)):
        # train:
        model.train()
        for batch_idx, batch in enumerate(train_dataloader):
            batch = batch_transforms_train(batch)

            optimizer.zero_grad()
            logits = model(batch['audio'])
            # Length of each sequence after the model's temporal striding.
            output_length = torch.ceil(
                batch['input_lengths'].float() / model.stride).int()
            # nn.CTCLoss takes (time, batch, classes) log-probabilities;
            # logits are presumably (batch, classes, time) -- TODO confirm.
            loss = criterion(
                logits.permute(2, 0, 1).log_softmax(dim=2),
                batch['text'], output_length, batch['target_lengths'])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.train.get('clip_grad_norm', 15))
            optimizer.step()
            lr_scheduler.step()

            # Number of samples seen before this batch (wandb x-axis).
            step = (epoch_idx * steps_per_epoch + batch_idx) \
                * train_dataloader.batch_size

            if batch_idx % log_interval == 0:
                target_strings = decoder.convert_to_strings(batch['text'])
                decoded_output = decoder.decode(
                    logits.permute(0, 2, 1).softmax(dim=2))
                wer, cer = _mean_error_rates(
                    decoder, target_strings, decoded_output)
                wandb.log({
                    "train_loss": loss.item(),
                    "train_wer": wer,
                    "train_cer": cer,
                    # Materialise the rows: a zip object is a one-shot
                    # iterator, not the list of rows a table is built from.
                    "train_samples": wandb.Table(
                        columns=['gt_text', 'pred_text'],
                        data=[list(row) for row in
                              zip(target_strings, decoded_output)]
                    )
                }, step=step)

        # validate:
        model.eval()
        val_stats = defaultdict(list)
        for batch in val_dataloader:
            batch = batch_transforms_val(batch)
            with torch.no_grad():
                logits = model(batch['audio'])
                output_length = torch.ceil(
                    batch['input_lengths'].float() / model.stride).int()
                loss = criterion(
                    logits.permute(2, 0, 1).log_softmax(dim=2),
                    batch['text'], output_length, batch['target_lengths'])

            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            wer, cer = _mean_error_rates(
                decoder, target_strings, decoded_output)
            val_stats['val_loss'].append(loss.item())
            val_stats['wer'].append(wer)
            val_stats['cer'].append(cer)
        # Collapse the per-batch lists into epoch means.
        val_stats = {key: np.mean(values) for key, values in val_stats.items()}
        # Only the last validation batch is shown in the samples table.
        val_stats['val_samples'] = wandb.Table(
            columns=['gt_text', 'pred_text'],
            data=[list(row) for row in zip(target_strings, decoded_output)])
        wandb.log(val_stats, step=step)

        # save model, TODO: save optimizer:
        if val_stats['wer'] < prev_wer:
            os.makedirs(checkpoint_dir, exist_ok=True)
            prev_wer = val_stats['wer']
            checkpoint_file = os.path.join(
                checkpoint_dir, f'model_{epoch_idx}_{prev_wer}.pth')
            torch.save(model.state_dict(), checkpoint_file)
            wandb.save(checkpoint_file)

In [2]:

# Command-line interface: the only option is the path to the YAML config.
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument(
    '--config',
    default='configs/train_LJSpeech.yaml',
    help='path to config file',
)
# Parse an empty argument list so the notebook kernel's own argv is ignored
# and every option takes its default value.
args = parser.parse_args([])
# Load the YAML file and wrap it for attribute-style access.
with open(args.config) as f:
    config = edict(yaml.safe_load(f))

In [3]:
# Notebook inspection only: list the attributes of the parsed namespace.
dir(args)

['__class__',
 '__contains__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_get_args',
 '_get_kwargs',
 'config']

In [4]:
# Notebook inspection only: show the config path that was parsed.
args.config

'configs/train_LJSpeech.yaml'

In [5]:
# Flat (cell-level) version of the training script: build the pipelines,
# datasets, model and optimizer, then train and validate for the configured
# number of epochs, logging to wandb and checkpointing on WER improvement.
fix_seeds(seed=config.train.get('seed', 42))
# Resolve the dataset implementation by name as a submodule of `data`.
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Shared by the train and validation pipelines; computed once so the two
# cannot drift apart.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = config.dataset.get('sample_rate', 16000)
normalize = config.dataset.get('normalize', None)
epochs = config.train.get('epochs', 10)
log_interval = config.wandb.get('log_interval', 5000)
checkpoint_dir = config.train.get('checkpoint_path', 'checkpoints')

# Per-sample training transforms: text preprocessing and BPE encoding, then
# waveform augmentations, each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

# Per-batch training transforms, applied after collation.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Validation uses the same pipelines without augmentation or masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze(),
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    AddLengths(),
    Pad(),
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# Batches are collated without padding; presumably the `Pad()` batch transform
# above pads them afterwards -- TODO confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. A missing or unknown name falls back to
# the default architecture object; the previous getattr default was the
# *string* '_quartznet5x5_config', which is not a usable model config.
default_model_name = '_quartznet5x5_config'
model_name = config.model.get('name', default_model_name)
model_config = getattr(quartznet_configs, model_name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, default_model_name)
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

print(model)
optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
steps_per_epoch = len(train_dataloader)
# The scheduler is stepped once per batch, so the cosine period is the total
# number of optimizer steps.
num_steps = steps_per_epoch * epochs
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
decoder = GreedyDecoder(bpe=bpe)

prev_wer = 1000
# Defined up front so the validation log below always has a step, even if the
# training loader yields no batches.
step = 0
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=log_interval)
for epoch_idx in tqdm(range(epochs)):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        # The leftover debug `print(batch)` / `return` was removed here: a
        # `return` outside a function is a SyntaxError for the whole cell.
        batch = batch_transforms_train(batch)

        optimizer.zero_grad()
        logits = model(batch['audio'])
        # Length of each sequence after the model's temporal striding.
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # nn.CTCLoss takes (time, batch, classes) log-probabilities; logits
        # are presumably (batch, classes, time) -- TODO confirm.
        loss = criterion(
            logits.permute(2, 0, 1).log_softmax(dim=2),
            batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        lr_scheduler.step()

        # Number of samples seen before this batch (wandb x-axis).
        step = (epoch_idx * steps_per_epoch + batch_idx) \
            * train_dataloader.batch_size

        if batch_idx % log_interval == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            sample_pairs = list(zip(target_strings, decoded_output))
            wer = np.mean([decoder.wer(true, pred)
                           for true, pred in sample_pairs])
            cer = np.mean([decoder.cer(true, pred)
                           for true, pred in sample_pairs])
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                # Rows are materialised: a zip object is a one-shot iterator,
                # not the list of rows a table is built from.
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=[list(row) for row in sample_pairs]
                )
            }, step=step)

    # validate:
    model.eval()
    val_stats = defaultdict(list)
    for batch_idx, batch in enumerate(val_dataloader):
        batch = batch_transforms_val(batch)
        with torch.no_grad():
            logits = model(batch['audio'])
            output_length = torch.ceil(
                batch['input_lengths'].float() / model.stride).int()
            loss = criterion(
                logits.permute(2, 0, 1).log_softmax(dim=2),
                batch['text'], output_length, batch['target_lengths'])

        target_strings = decoder.convert_to_strings(batch['text'])
        decoded_output = decoder.decode(
            logits.permute(0, 2, 1).softmax(dim=2))
        sample_pairs = list(zip(target_strings, decoded_output))
        wer = np.mean([decoder.wer(true, pred)
                       for true, pred in sample_pairs])
        cer = np.mean([decoder.cer(true, pred)
                       for true, pred in sample_pairs])
        val_stats['val_loss'].append(loss.item())
        val_stats['wer'].append(wer)
        val_stats['cer'].append(cer)
    # Collapse the per-batch lists into epoch means (values only; no keys are
    # added or removed while iterating).
    for k, v in val_stats.items():
        val_stats[k] = np.mean(v)
    # Only the last validation batch is shown in the samples table.
    val_stats['val_samples'] = wandb.Table(
        columns=['gt_text', 'pred_text'],
        data=[list(row) for row in zip(target_strings, decoded_output)])
    wandb.log(val_stats, step=step)

    # save model, TODO: save optimizer:
    if val_stats['wer'] < prev_wer:
        os.makedirs(checkpoint_dir, exist_ok=True)
        prev_wer = val_stats['wer']
        checkpoint_file = os.path.join(
            checkpoint_dir, f'model_{epoch_idx}_{prev_wer}.pth')
        torch.save(model.state_dict(), checkpoint_file)
        wandb.save(checkpoint_file)

In [6]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

In [7]:

# Command-line interface: the only option is the path to the YAML config.
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument(
    '--config',
    default='configs/train_LJSpeech.yaml',
    help='path to config file',
)
# Parse an empty argument list so the notebook kernel's own argv is ignored
# and every option takes its default value.
args = parser.parse_args([])
# Load the YAML file and wrap it for attribute-style access.
with open(args.config) as f:
    config = edict(yaml.safe_load(f))

In [8]:
# Notebook inspection only: echo the loaded configuration.
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [9]:
# Build everything needed for training: transform pipelines, datasets and
# loaders, the model, optimizer, LR schedule, loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
# Resolve the dataset implementation by name as a submodule of `data`.
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Shared by the train and validation pipelines; computed once so the two
# cannot drift apart.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = config.dataset.get('sample_rate', 16000)
normalize = config.dataset.get('normalize', None)

# Per-sample training transforms: text preprocessing and BPE encoding, then
# waveform augmentations, each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

# Per-batch training transforms, applied after collation.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Validation uses the same pipelines without augmentation or masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze(),
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    AddLengths(),
    Pad(),
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# Batches are collated without padding; presumably the `Pad()` batch transform
# above pads them afterwards -- TODO confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. A missing or unknown name falls back to
# the default architecture object; the previous getattr default was the
# *string* '_quartznet5x5_config', which is not a usable model config.
default_model_name = '_quartznet5x5_config'
model_name = config.model.get('name', default_model_name)
model_config = getattr(quartznet_configs, model_name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, default_model_name)
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

print(model)
optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# Cosine period: one scheduler step per batch over all epochs.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best" WER; any real validation WER is expected to beat it.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

[]

In [10]:
# Build everything needed for training: transform pipelines, datasets and
# loaders, the model, optimizer, LR schedule, loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
# Resolve the dataset implementation by name as a submodule of `data`.
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Shared by the train and validation pipelines; computed once so the two
# cannot drift apart.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = config.dataset.get('sample_rate', 16000)
normalize = config.dataset.get('normalize', None)

# Per-sample training transforms: text preprocessing and BPE encoding, then
# waveform augmentations, each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

# Per-batch training transforms, applied after collation.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Validation uses the same pipelines without augmentation or masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze(),
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    AddLengths(),
    Pad(),
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# Batches are collated without padding; presumably the `Pad()` batch transform
# above pads them afterwards -- TODO confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. A missing or unknown name falls back to
# the default architecture object; the previous getattr default was the
# *string* '_quartznet5x5_config', which is not a usable model config.
default_model_name = '_quartznet5x5_config'
model_name = config.model.get('name', default_model_name)
model_config = getattr(quartznet_configs, model_name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, default_model_name)
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Dump the printable module structure; the model must be converted to text
# first because a file's write() rejects non-string objects.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# Cosine period: one scheduler step per batch over all epochs.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best" WER; any real validation WER is expected to beat it.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

In [11]:

# Command-line interface: the only option is the path to the YAML config.
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument(
    '--config',
    default='configs/train_LJSpeech.yaml',
    help='path to config file',
)
# Parse an empty argument list so the notebook kernel's own argv is ignored
# and every option takes its default value.
args = parser.parse_args([])
# Load the YAML file and wrap it for attribute-style access.
with open(args.config) as f:
    config = edict(yaml.safe_load(f))

In [12]:
def write_to_file(str_w, file_name = 'sth.txt', mode = 'w'):
    """Write ``str_w`` to ``file_name``.

    Args:
        str_w: Content to write. Anything that is not already ``str``,
            ``bytes`` or ``bytearray`` is converted with ``str()`` first, so
            arbitrary objects are dumped via their printable form.
        file_name: Destination path.
        mode: Mode handed to ``open``; ``'w'`` truncates, ``'a'`` appends.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    if not isinstance(str_w, (str, bytes, bytearray)):
        str_w = str(str_w)
    # `mode` must be forwarded: open() defaults to read-only, in which case
    # the write below fails.
    with open(file_name, mode) as f:
        f.write(str_w)

In [13]:
# Notebook inspection only: echo the loaded configuration.
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [14]:
# Build everything needed for training: transform pipelines, datasets and
# loaders, the model, optimizer, LR schedule, loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
# Resolve the dataset implementation by name as a submodule of `data`.
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Shared by the train and validation pipelines; computed once so the two
# cannot drift apart.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = config.dataset.get('sample_rate', 16000)
normalize = config.dataset.get('normalize', None)

# Per-sample training transforms: text preprocessing and BPE encoding, then
# waveform augmentations, each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

# Per-batch training transforms, applied after collation.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Validation uses the same pipelines without augmentation or masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze(),
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    AddLengths(),
    Pad(),
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# Batches are collated without padding; presumably the `Pad()` batch transform
# above pads them afterwards -- TODO confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. A missing or unknown name falls back to
# the default architecture object; the previous getattr default was the
# *string* '_quartznet5x5_config', which is not a usable model config.
default_model_name = '_quartznet5x5_config'
model_name = config.model.get('name', default_model_name)
model_config = getattr(quartznet_configs, model_name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, default_model_name)
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Dump the printable module structure; the model must be converted to text
# first because a file's write() rejects non-string objects.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# Cosine period: one scheduler step per batch over all epochs.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best" WER; any real validation WER is expected to beat it.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

In [15]:
def write_to_file(str_w, file_name = 'sth.txt', mode = 'w'):
    """Write ``str_w`` to ``file_name``.

    Args:
        str_w: Content to write. Anything that is not already ``str``,
            ``bytes`` or ``bytearray`` is converted with ``str()`` first, so
            arbitrary objects are dumped via their printable form instead of
            making ``write`` raise ``TypeError``.
        file_name: Destination path.
        mode: Mode handed to ``open``; ``'w'`` truncates, ``'a'`` appends.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    if not isinstance(str_w, (str, bytes, bytearray)):
        str_w = str(str_w)
    with open(file_name, mode) as f:
        f.write(str_w)

In [16]:
# Notebook inspection only: echo the loaded configuration.
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'vocab_size': 120, 'feat_in': 64}}

In [17]:
# Build everything needed for training: transform pipelines, datasets and
# loaders, the model, optimizer, LR schedule, loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
# Resolve the dataset implementation by name as a submodule of `data`.
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Shared by the train and validation pipelines; computed once so the two
# cannot drift apart.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = config.dataset.get('sample_rate', 16000)
normalize = config.dataset.get('normalize', None)

# Per-sample training transforms: text preprocessing and BPE encoding, then
# waveform augmentations, each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

# Per-batch training transforms, applied after collation.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Validation uses the same pipelines without augmentation or masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze(),
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=sample_rate,
        n_mels=config.model.feat_in,
        normalize=normalize,
    ).to(device),
    AddLengths(),
    Pad(),
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# Batches are collated without padding; presumably the `Pad()` batch transform
# above pads them afterwards -- TODO confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. `.get` is used because the `model` section
# may not carry a 'name' key, in which case attribute access would raise. A
# missing or unknown name falls back to the default architecture object; the
# previous getattr default was the *string* '_quartznet5x5_config', which is
# not a usable model config.
default_model_name = '_quartznet5x5_config'
model_name = config.model.get('name', default_model_name)
model_config = getattr(quartznet_configs, model_name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, default_model_name)
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Dump the printable module structure; the model must be converted to text
# first because a file's write() rejects non-string objects.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# Cosine period: one scheduler step per batch over all epochs.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best" WER; any real validation WER is expected to beat it.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

In [18]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

In [19]:

# Load the experiment configuration from YAML.
# A notebook has no real command line, so an explicit empty argument list is
# parsed; this selects the --config default. (The previous "" only worked
# because an empty string happens to be an empty iterable of arguments.)
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument('--config', default='configs/train_LJSpeech.yaml',
                    help='path to config file')
args = parser.parse_args([])
with open(args.config, 'r') as f:
    # EasyDict gives attribute-style access (config.train.seed, ...).
    config = edict(yaml.safe_load(f))

In [20]:
def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Write the string form of ``str_w`` to a file.

    Args:
        str_w: Object to write. It is converted with ``str()`` first, so
            non-string objects (e.g. a model whose repr should be dumped)
            are accepted instead of raising ``TypeError`` in ``write()``.
        file_name: Path of the file to open.
        mode: Mode passed to ``open()``; ``'w'`` overwrites, ``'a'`` appends.
    """
    with open(file_name, mode) as f:
        f.write(str(str_w))

In [21]:
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [22]:
# Assemble everything the training-loop cells use: seeded RNGs, the BPE
# tokenizer, per-sample and per-batch transforms, datasets and loaders, the
# QuartzNet model, optimizer, LR scheduler, CTC loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Resolve the target device once instead of repeating the availability check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-sample transforms handed to the training dataset; the three
# audiomentations steps are each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # AddLengths()
])

# Batch-level transforms; the training loop calls this on every collated batch.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15
    ),
    AddLengths(),
    Pad()
])

# Validation counterparts: no BPE dropout, no audio augmentation, no masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze()
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # for LJspeech
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    AddLengths(),
    Pad()
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# NOTE(review): the training loader neither shuffles nor uses
# config.train.num_workers -- confirm that this is intentional.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate)

val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. The default belongs on the *name*: the
# earlier getattr(..., '_quartznet5x5_config') fallback would have passed the
# literal string to QuartzNet instead of a config object. An unknown name now
# fails right here with AttributeError.
model_config_name = config.model.get('name', '_quartznet5x5_config')
model = QuartzNet(
    model_config=getattr(quartznet_configs, model_config_name),
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Stringify explicitly: file.write() accepts str only, not the model object.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# One scheduler step per batch: the cosine schedule spans the whole run.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

model = model.to(device)

# CTC loss with index 0 as the blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best WER" sentinel for the (currently disabled) checkpointing code.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

In [23]:
def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Dump ``str(str_w)`` into ``file_name``.

    Args:
        str_w: Any object; its ``str()`` form is what gets written.
        file_name: Destination path.
        mode: ``open()`` mode, ``'w'`` (overwrite) unless told otherwise.
    """
    with open(file_name, mode) as out_file:
        text = str(str_w)
        out_file.write(text)

In [24]:
# Assemble everything the training-loop cells use: seeded RNGs, the BPE
# tokenizer, per-sample and per-batch transforms, datasets and loaders, the
# QuartzNet model, optimizer, LR scheduler, CTC loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Resolve the target device once instead of repeating the availability check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-sample transforms handed to the training dataset; the three
# audiomentations steps are each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # AddLengths()
])

# Batch-level transforms; the training loop calls this on every collated batch.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15
    ),
    AddLengths(),
    Pad()
])

# Validation counterparts: no BPE dropout, no audio augmentation, no masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze()
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # for LJspeech
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    AddLengths(),
    Pad()
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# NOTE(review): the training loader neither shuffles nor uses
# config.train.num_workers -- confirm that this is intentional.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate)

val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. The default belongs on the *name*: the
# earlier getattr(..., '_quartznet5x5_config') fallback would have passed the
# literal string to QuartzNet instead of a config object. An unknown name now
# fails right here with AttributeError.
model_config_name = config.model.get('name', '_quartznet5x5_config')
model = QuartzNet(
    model_config=getattr(quartznet_configs, model_config_name),
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Stringify explicitly so the dump does not depend on write_to_file doing it.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# One scheduler step per batch: the cosine schedule spans the whole run.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

model = model.to(device)

# CTC loss with index 0 as the blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best WER" sentinel for the (currently disabled) checkpointing code.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

In [25]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

# Load the experiment configuration from YAML.
# A notebook has no real command line, so an explicit empty argument list is
# parsed; this selects the --config default. (The previous "" only worked
# because an empty string happens to be an empty iterable of arguments.)
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument('--config', default='configs/train_LJSpeech.yaml',
                    help='path to config file')
args = parser.parse_args([])
with open(args.config, 'r') as f:
    # EasyDict gives attribute-style access (config.train.seed, ...).
    config = edict(yaml.safe_load(f))

def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Dump ``str(str_w)`` into ``file_name``.

    Args:
        str_w: Any object; its ``str()`` form is what gets written.
        file_name: Destination path.
        mode: ``open()`` mode, ``'w'`` (overwrite) unless told otherwise.
    """
    with open(file_name, mode) as out_file:
        text = str(str_w)
        out_file.write(text)

In [26]:
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [27]:
# Assemble everything the training-loop cells use: seeded RNGs, the BPE
# tokenizer, per-sample and per-batch transforms, datasets and loaders, the
# QuartzNet model, optimizer, LR scheduler, CTC loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Resolve the target device once instead of repeating the availability check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-sample transforms handed to the training dataset; the three
# audiomentations steps are each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # AddLengths()
])

# Batch-level transforms; the training loop calls this on every collated batch.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15
    ),
    AddLengths(),
    Pad()
])

# Validation counterparts: no BPE dropout, no audio augmentation, no masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze()
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # for LJspeech
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    AddLengths(),
    Pad()
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# NOTE(review): the training loader neither shuffles nor uses
# config.train.num_workers -- confirm that this is intentional.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate)

val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. The default belongs on the *name*: the
# earlier getattr(..., '_quartznet5x5_config') fallback would have passed the
# literal string to QuartzNet instead of a config object. An unknown name now
# fails right here with AttributeError.
model_config_name = config.model.get('name', '_quartznet5x5_config')
model = QuartzNet(
    model_config=getattr(quartznet_configs, model_config_name),
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Stringify explicitly so the dump does not depend on write_to_file doing it.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# One scheduler step per batch: the cosine schedule spans the whole run.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

model = model.to(device)

# CTC loss with index 0 as the blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best WER" sentinel for the (currently disabled) checkpointing code.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

[]

In [28]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

# Load the experiment configuration from YAML.
# A notebook has no real command line, so an explicit empty argument list is
# parsed; this selects the --config default. (The previous "" only worked
# because an empty string happens to be an empty iterable of arguments.)
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument('--config', default='configs/train_LJSpeech.yaml',
                    help='path to config file')
args = parser.parse_args([])
with open(args.config, 'r') as f:
    # EasyDict gives attribute-style access (config.train.seed, ...).
    config = edict(yaml.safe_load(f))

def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Dump ``str(str_w)`` into ``file_name``.

    Args:
        str_w: Any object; its ``str()`` form is what gets written.
        file_name: Destination path.
        mode: ``open()`` mode, ``'w'`` (overwrite) unless told otherwise.
    """
    with open(file_name, mode) as out_file:
        text = str(str_w)
        out_file.write(text)

In [29]:
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [30]:
# Assemble everything the training-loop cells use: seeded RNGs, the BPE
# tokenizer, per-sample and per-batch transforms, datasets and loaders, the
# QuartzNet model, optimizer, LR scheduler, CTC loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Resolve the target device once instead of repeating the availability check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-sample transforms handed to the training dataset; the three
# audiomentations steps are each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # AddLengths()
])

# Batch-level transforms; the training loop calls this on every collated batch.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15
    ),
    AddLengths(),
    Pad()
])

# Validation counterparts: no BPE dropout, no audio augmentation, no masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze()
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # for LJspeech
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    AddLengths(),
    Pad()
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# NOTE(review): the training loader neither shuffles nor uses
# config.train.num_workers -- confirm that this is intentional.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate)

val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. The default belongs on the *name*: the
# earlier getattr(..., '_quartznet5x5_config') fallback would have passed the
# literal string to QuartzNet instead of a config object. An unknown name now
# fails right here with AttributeError.
model_config_name = config.model.get('name', '_quartznet5x5_config')
model = QuartzNet(
    model_config=getattr(quartznet_configs, model_config_name),
    **remove_from_dict(config.model, ['name'])
)

print(model)
# Stringify explicitly so the dump does not depend on write_to_file doing it.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# One scheduler step per batch: the cosine schedule spans the whole run.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

model = model.to(device)

# CTC loss with index 0 as the blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best WER" sentinel for the (currently disabled) checkpointing code.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

[]

In [31]:
getattr(quartznet_configs, config.model.name, '_quartznet5x5_config')

In [32]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

# Load the experiment configuration from YAML.
# A notebook has no real command line, so an explicit empty argument list is
# parsed; this selects the --config default. (The previous "" only worked
# because an empty string happens to be an empty iterable of arguments.)
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument('--config', default='configs/train_LJSpeech.yaml',
                    help='path to config file')
args = parser.parse_args([])
with open(args.config, 'r') as f:
    # EasyDict gives attribute-style access (config.train.seed, ...).
    config = edict(yaml.safe_load(f))

def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Dump ``str(str_w)`` into ``file_name``.

    Args:
        str_w: Any object; its ``str()`` form is what gets written.
        file_name: Destination path.
        mode: ``open()`` mode, ``'w'`` (overwrite) unless told otherwise.
    """
    with open(file_name, mode) as out_file:
        text = str(str_w)
        out_file.write(text)

In [33]:
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [34]:
# Assemble everything the training-loop cells use: seeded RNGs, the BPE
# tokenizer, per-sample and per-batch transforms, datasets and loaders, the
# QuartzNet model, optimizer, LR scheduler, CTC loss, decoder and wandb run.
fix_seeds(seed=config.train.get('seed', 42))
dataset_module = importlib.import_module(
    f'.{config.dataset.name}', data.__name__)
bpe = prepare_bpe(config)

# Resolve the target device once instead of repeating the availability check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-sample transforms handed to the training dataset; the three
# audiomentations steps are each applied with probability 0.5.
transforms_train = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # AddLengths()
])

# Batch-level transforms; the training loop calls this on every collated batch.
batch_transforms_train = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15
    ),
    AddLengths(),
    Pad()
])

# Validation counterparts: no BPE dropout, no audio augmentation, no masking.
transforms_val = Compose([
    TextPreprocess(),
    ToNumpy(),
    BPEtexts(bpe=bpe),
    AudioSqueeze()
])

batch_transforms_val = Compose([
    ToGpu(device),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # for LJspeech
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to(device),
    AddLengths(),
    Pad()
])

# load datasets
train_dataset = dataset_module.get_dataset(
    config, transforms=transforms_train, part='train')
val_dataset = dataset_module.get_dataset(
    config, transforms=transforms_val, part='val')
# NOTE(review): the training loader neither shuffles nor uses
# config.train.num_workers -- confirm that this is intentional.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate)

val_dataloader = DataLoader(
    val_dataset, batch_size=1, collate_fn=no_pad_collate)

# Look the architecture up by name. The default belongs on the *name*: the
# earlier getattr(..., '_quartznet5x5_config') fallback would have passed the
# literal string to QuartzNet instead of a config object. An unknown name now
# fails right here with AttributeError.
model_config_name = config.model.get('name', '_quartznet5x5_config')
model_config = getattr(quartznet_configs, model_config_name)
print(model_config)

model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

# print(model)
# Stringify explicitly so the dump does not depend on write_to_file doing it.
write_to_file(str(model), 'model_structure.txt')

optimizer = torch.optim.Adam(
    model.parameters(), **config.train.get('optimizer', {}))
# One scheduler step per batch: the cosine schedule spans the whole run.
num_steps = len(train_dataloader) * config.train.get('epochs', 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_steps)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

if config.train.get('from_checkpoint', None) is not None:
    model.load_weights(config.train.from_checkpoint)

model = model.to(device)

# CTC loss with index 0 as the blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Initial "best WER" sentinel for the (currently disabled) checkpointing code.
prev_wer = 1000
wandb.init(project=config.wandb.project, config=config)
wandb.watch(model, log="all", log_freq=config.wandb.get(
    'log_interval', 5000))

[]

VBox(children=(Label(value='0.042 MB of 0.042 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

In [35]:
# Debugging variant of the training loop: it captures one raw collated batch
# (batch1) and the same batch after the batch-level transforms (batch2) so
# both can be inspected in later cells. The `break` below leaves the batch
# loop right after the first batch, so the optimisation and logging code that
# follows it never runs while the `break` is in place.
# NOTE(review): the epoch loop itself is not interrupted here, so every epoch
# fetches the first batch again and overwrites batch1/batch2.
for epoch_idx in tqdm(range(config.train.get('epochs', 10))):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        batch1=batch  # batch as produced by the DataLoader's collate_fn
        batch = batch_transforms_train(batch)
        batch2=batch  # batch after batch_transforms_train
        break  # debug short-circuit: nothing below in this loop body executes

        # One optimisation step: forward pass, CTC loss, backward pass,
        # gradient-norm clipping, optimizer update, LR-scheduler step.
        optimizer.zero_grad()
        logits = model(batch['audio'])
        # Scale the input lengths down by the model stride, rounding up.
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # NOTE(review): permute(2, 0, 1) presumably rearranges logits into the
        # (time, batch, classes) layout nn.CTCLoss expects -- confirm against
        # the QuartzNet output shape.
        loss = criterion(logits.permute(2, 0, 1).log_softmax(
            dim=2), batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        lr_scheduler.step()
        # warmup_scheduler.dampen()

        # Every `log_interval` batches: decode the current batch, compute the
        # mean WER/CER over it and log metrics plus text samples to wandb.
        if batch_idx % config.wandb.get('log_interval', 5000) == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            wer = np.mean([decoder.wer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            cer = np.mean([decoder.cer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            # wandb step = batches processed so far (all previous epochs plus
            # the current one) multiplied by the loader's batch size.
            step = epoch_idx * \
                len(train_dataloader) * train_dataloader.batch_size + \
                batch_idx * train_dataloader.batch_size
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=zip(target_strings, decoded_output)
                )
            }, step=step)
    # Validation and checkpoint saving are disabled (commented out) below.
    # #!
    # # validate:
    # model.eval()
    # val_stats = defaultdict(list)
    # for batch_idx, batch in enumerate(val_dataloader):
    #     batch = batch_transforms_val(batch)
    #     with torch.no_grad():
    #         logits = model(batch['audio'])
    #         output_length = torch.ceil(
    #             batch['input_lengths'].float() / model.stride).int()
    #         loss = criterion(logits.permute(2, 0, 1).log_softmax(
    #             dim=2), batch['text'], output_length, batch['target_lengths'])

    #     target_strings = decoder.convert_to_strings(batch['text'])
    #     decoded_output = decoder.decode(
    #         logits.permute(0, 2, 1).softmax(dim=2))
    #     wer = np.mean([decoder.wer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     cer = np.mean([decoder.cer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     val_stats['val_loss'].append(loss.item())
    #     val_stats['wer'].append(wer)
    #     val_stats['cer'].append(cer)
    # for k, v in val_stats.items():
    #     val_stats[k] = np.mean(v)
    # val_stats['val_samples'] = wandb.Table(
    #     columns=['gt_text', 'pred_text'], data=zip(target_strings, decoded_output))
    # wandb.log(val_stats, step=step)

    # # save model, TODO: save optimizer:
    # if val_stats['wer'] < prev_wer:
    #     os.makedirs(config.train.get(
    #         'checkpoint_path', 'checkpoints'), exist_ok=True)
    #     prev_wer = val_stats['wer']
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(config.train.get(
    #             'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth')
    #     )
    #     wandb.save(os.path.join(config.train.get(
    #         'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth'))

In [36]:
# Debugging variant of the training loop: it captures one raw collated batch
# (batch1) and the same batch after the batch-level transforms (batch2) so
# both can be inspected in later cells. The first `break` leaves the batch
# loop right after the first batch, and the `break` at the end of the epoch
# body stops after the first epoch; the optimisation and logging code in
# between never runs while the first `break` is in place.
for epoch_idx in tqdm(range(config.train.get('epochs', 10))):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        batch1=batch  # batch as produced by the DataLoader's collate_fn
        batch = batch_transforms_train(batch)
        batch2=batch  # batch after batch_transforms_train
        break  # debug short-circuit: nothing below in this loop body executes

        # One optimisation step: forward pass, CTC loss, backward pass,
        # gradient-norm clipping, optimizer update, LR-scheduler step.
        optimizer.zero_grad()
        logits = model(batch['audio'])
        # Scale the input lengths down by the model stride, rounding up.
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # NOTE(review): permute(2, 0, 1) presumably rearranges logits into the
        # (time, batch, classes) layout nn.CTCLoss expects -- confirm against
        # the QuartzNet output shape.
        loss = criterion(logits.permute(2, 0, 1).log_softmax(
            dim=2), batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        lr_scheduler.step()
        # warmup_scheduler.dampen()

        # Every `log_interval` batches: decode the current batch, compute the
        # mean WER/CER over it and log metrics plus text samples to wandb.
        if batch_idx % config.wandb.get('log_interval', 5000) == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            wer = np.mean([decoder.wer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            cer = np.mean([decoder.cer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            # wandb step = batches processed so far (all previous epochs plus
            # the current one) multiplied by the loader's batch size.
            step = epoch_idx * \
                len(train_dataloader) * train_dataloader.batch_size + \
                batch_idx * train_dataloader.batch_size
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=zip(target_strings, decoded_output)
                )
            }, step=step)
    # Validation and checkpoint saving are disabled (commented out) below.
    # #!
    # # validate:
    # model.eval()
    # val_stats = defaultdict(list)
    # for batch_idx, batch in enumerate(val_dataloader):
    #     batch = batch_transforms_val(batch)
    #     with torch.no_grad():
    #         logits = model(batch['audio'])
    #         output_length = torch.ceil(
    #             batch['input_lengths'].float() / model.stride).int()
    #         loss = criterion(logits.permute(2, 0, 1).log_softmax(
    #             dim=2), batch['text'], output_length, batch['target_lengths'])

    #     target_strings = decoder.convert_to_strings(batch['text'])
    #     decoded_output = decoder.decode(
    #         logits.permute(0, 2, 1).softmax(dim=2))
    #     wer = np.mean([decoder.wer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     cer = np.mean([decoder.cer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     val_stats['val_loss'].append(loss.item())
    #     val_stats['wer'].append(wer)
    #     val_stats['cer'].append(cer)
    # for k, v in val_stats.items():
    #     val_stats[k] = np.mean(v)
    # val_stats['val_samples'] = wandb.Table(
    #     columns=['gt_text', 'pred_text'], data=zip(target_strings, decoded_output))
    # wandb.log(val_stats, step=step)

    # # save model, TODO: save optimizer:
    # if val_stats['wer'] < prev_wer:
    #     os.makedirs(config.train.get(
    #         'checkpoint_path', 'checkpoints'), exist_ok=True)
    #     prev_wer = val_stats['wer']
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(config.train.get(
    #             'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth')
    #     )
    #     wandb.save(os.path.join(config.train.get(
    #         'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth'))
    break  # debug short-circuit: stop after the first epoch

In [37]:
batch1

{'audio': [array([ 0.00064805, -0.00185332, -0.00244806, ..., -0.00170346,
          0.00287292,  0.00232188], dtype=float32),
  array([-0.00314305,  0.00103363,  0.00466579, ...,  0.03743864,
          0.04169904,  0.        ], dtype=float32),
  array([-2.7460288e-03, -4.2517125e-03, -6.8397014e-05, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
  array([ 1.6830985e-04, -2.7138460e-05, -2.2539313e-04, ...,
          2.5602366e-04,  1.0118349e-03,  0.0000000e+00], dtype=float32),
  array([3.5219718e-04, 3.9384593e-04, 3.4842314e-04, ..., 1.4193331e-04,
         1.3546018e-04, 8.1811209e-05], dtype=float32),
  array([ 7.7085420e-03, -2.8882974e-03, -2.8774177e-05, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
  array([ 3.8510363e-05,  2.8201991e-03,  5.8321138e-03, ...,
          4.3147802e-03,  1.1419075e-02, -6.4695268e-03], dtype=float32),
  array([ 2.2583008e-03,  1.7700195e-03, -9.1552734e-05, ...,
          1.8310547

In [38]:
batch1.keys()

dict_keys(['audio', 'text', 'sample_rate'])

In [39]:
batch2.keys()

dict_keys(['audio', 'text', 'sample_rate', 'input_lengths', 'target_lengths'])

In [40]:
batch2.input_lengths

In [41]:
batch2['input_lengths']

tensor([1065,  210, 1066,  567,  895,  627,  925,  197,  833,  973,  498,  909,
         285, 1097, 1019,  581,  774,  826,  708,  516,  950,  778,  932,  867,
         978,  672, 1064,  654,  588,  763,  867,  781], device='cuda:0')

In [42]:
batch2['text']

tensor([[ 56,  12,  44,  ...,   0,   0,   0],
        [ 65,  90,  75,  ...,   0,   0,   0],
        [106,  42,  15,  ...,  16,  71,  11],
        ...,
        [ 42,   4,  25,  ...,   0,   0,   0],
        [ 65,  57,  70,  ...,   0,   0,   0],
        [ 72, 110,  11,  ...,   0,   0,   0]], device='cuda:0',
       dtype=torch.int32)

In [43]:
batch2['text'].shape

torch.Size([32, 96])

In [44]:
batch1['text'].shape

In [45]:
batch1['text']

[array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,   4,   9,  86,  57,
         82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
         17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
         13,  10,  24,  62,  80]),
 array([65, 90, 75, 53, 97, 21, 69, 54, 10, 95, 88, 63,  8, 14, 51,  9]),
 array([106,  42,  15,   6,  13,  70, 107,  43,  53,  13,  44,  71,   5,
         40,   8,   8,  26,   4,  94,  21,  47,  11,  11,  80,  11,  57,
         82,  19,  46,   8,   8,  14,  55,  15,   8,  16,  26,  11,   4,
         49,  22,  12,   7,  25,  52,  65,   4,  47,  15,  10,   5,  17,
        106,  53,  79,  18, 116,  71,  90,  17,   8,  47,  43,  46,   8,
          8,  14,  16, 104,   6,  51,  11,  58,  43,  77,   5,   6,  41,
         12,  1

In [46]:
batch1['text'][0]

array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
        62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
        79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
        57,  82,  19,  63,   8,  96,   4,  10,  17,   4,   9,  86,  57,
        82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
        17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
        13,  10,  24,  62,  80])

In [47]:
batch1['text'][0].len()

In [48]:
len(batch1['text'][0])

83

In [49]:
len(batch1['text'])

32

In [50]:
batch1['text'][0]

array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
        62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
        79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
        57,  82,  19,  63,   8,  96,   4,  10,  17,   4,   9,  86,  57,
        82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
        17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
        13,  10,  24,  62,  80])

In [51]:
train_dataset

<torch.utils.data.dataset.Subset at 0x1ee6cd9f280>

In [52]:
dir(train_dataset)

['__add__',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_is_protocol',
 'dataset',
 'indices']

In [53]:
train_dataset.data

In [54]:
train_dataset.dataset

<data.ljspeech.LJSpeechDataset at 0x1ee6cd5b2b0>

In [55]:
dir(train_dataset.dataset)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_flist',
 '_is_protocol',
 '_metadata_path',
 '_parse_filesystem',
 '_path',
 'get_text',
 'transforms']

In [56]:
dir(train_dataset.dataset.getitem())

In [57]:
dir(train_dataset.dataset)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_flist',
 '_is_protocol',
 '_metadata_path',
 '_parse_filesystem',
 '_path',
 'get_text',
 'transforms']

In [58]:
train_dataset.dataset__getitem__()

In [59]:
train_dataset.dataset.__getitem__()

In [60]:
train_dataset.dataset.__getitem__(0)

{'audio': array([ 0.00337451, -0.00338884, -0.00205553, ...,  0.00065728,
         0.00708335,  0.00721541], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,  82,
         19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,  17,
          6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,  13,
         10,  24,  62,  80]),
 'sample_rate': 22050}

In [61]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')

In [62]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)

{'audio': tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 'text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'sample_rate': 22050}

In [63]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)

{'audio': tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 'text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'sample_rate': 22050}

In [64]:
train_dataset.dataset.__getitem__(0)

{'audio': array([0.00087503, 0.00045779, 0.00011389, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  46,  13,  78,  13,  46,   5,  42,  47, 119,  56,  47,
         11,  79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,
         11,  57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,
         82,  19,  42,  15,  15,  43,  42,  12,   6,  11,  72,  53,  12,
          7,  17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,
         27,  13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [65]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)

{'audio': tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 'text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'sample_rate': 22050}

In [66]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)

{'audio': tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 'text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'sample_rate': 22050}

In [67]:
train_dataset.dataset.__getitem__(0)

{'audio': array([-6.6338282e-04, -7.4983307e-04, -6.5849890e-04, ...,
        -1.4586677e-04, -9.8007549e-05, -7.8858320e-05], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,  77,   8,   6,  57,
         82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
         17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
         13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [68]:
train_dataset.dataset.__getitem__(0)

{'audio': array([-7.3242188e-04, -7.6293945e-04, -6.4086914e-04, ...,
         7.3242188e-04,  2.1362305e-04,  6.1035156e-05], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,   4,  50,  88,  48,  49, 103,
         46,  62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,
         11,  79,  53,  50,  16,  51,   9,  52,   4,  14,  10,  17,  17,
         51,  11,  57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,
         57,  82,  19,  42,  15,  15,  43,  42,  12,   6,  11,  72,   4,
         16,  12,   7,  17,   6,  11,  87,  21,  47,  11,  79,  52,   4,
         10,   9,  43,  83,  27,  13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [69]:
train_dataset.dataset.__getitem__(0)

{'audio': array([ 3.0744888e-04,  8.4192281e-05, -2.1004396e-04, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
 'text': array([56, 12, 44,  6, 75,  4, 10,  9, 74,  5, 93, 88, 48, 49, 11,  5, 46,
        62, 13, 99, 78, 13, 46,  5, 42, 47,  4, 54, 56, 47, 11, 79, 53, 50,
        16, 51,  9, 52, 73, 10, 17, 17, 51, 11, 57, 82, 19, 63,  8, 96,  4,
        10, 17, 77, 86, 57, 82, 19, 42, 98, 43, 42, 12,  6, 11, 72, 53, 12,
         7, 17,  6, 11, 87, 21, 47, 11, 79, 52, 65, 43, 83, 27, 13, 10, 24,
        62, 80]),
 'sample_rate': 22050}

In [70]:
train_dataset.dataset.__getitem__(0)

{'audio': array([7.1615778e-04, 3.3713618e-04, 3.9294842e-05, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,  82,
         19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,  17,
          6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,  13,
         10,  24,  62,  80]),
 'sample_rate': 22050}

In [71]:
train_dataset.dataset.__getitem__(0)

{'audio': array([0.00112545, 0.00064798, 0.00023147, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  11,   6,   4,  10,  17,  77,  86,  57,
         82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
         17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
         13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [72]:
train_dataset.dataset.__getitem__(0)

{'audio': array([-0.00761147,  0.00097147,  0.0043958 , ...,  0.        ,
         0.        ,  0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49,  11,   5,
         46,  62,  13,  46,  13,  78,  13,  46,   5,  42,  47, 119,  56,
         47,  11,  79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,
         51,  11,  57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,
         57,  82,  19,  42,  98,  43,  42,  12,   6,  11,  42,  59,  53,
         12,   7,  17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,
         83,  27,  13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [73]:
train_dataset.dataset.__getitem__(0)

{'audio': array([-0.00902189,  0.00607888, -0.0015608 , ..., -0.00049304,
         0.00043237,  0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,   4,  10,   9,  43,  93,  88,  48,  49,
        103,  46,  62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,
         47,  11,  79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,
         51,  11,  57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,
         57,  82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,
          7,  17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,   4,
          5,  27,  13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [74]:
train_dataset.dataset.__getitem__(0)

{'audio': array([ 0.00728497, -0.00395672,  0.01365469, ..., -0.00145744,
         0.00241929,  0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,  82,
         19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,  17,
          6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,  13,
         10,  24,  62,  80]),
 'sample_rate': 22050}

In [75]:
train_dataset.dataset.__getitem__(0)

{'audio': array([-0.00930794, -0.0199947 , -0.01293118, ..., -0.01280719,
         0.00850524,  0.        ], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
         57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,  82,
         19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,  17,
          6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,  13,
         10,  24,  62,  80]),
 'sample_rate': 22050}

In [76]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

83

In [77]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)
len(train_dataset_2.dataset.__getitem__(0)['text'])

151

In [78]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

84

In [79]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)
len(train_dataset_2.dataset.__getitem__(0)['text'])

151

In [80]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

83

In [81]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

84

In [82]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

83

In [83]:
train_dataset.dataset.__getitem__(0)
len(train_dataset.dataset.__getitem__(0)['text'])

86

In [84]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
train_dataset_2.dataset.__getitem__(0)
# len(train_dataset_2.dataset.__getitem__(0)['text'])

{'audio': tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 'text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'sample_rate': 22050}

In [85]:
train_dataset.dataset.__getitem__(0)
# len(train_dataset.dataset.__getitem__(0)['text'])

{'audio': array([ 3.74017225e-04,  7.72984131e-05, -1.21351484e-04, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00], dtype=float32),
 'text': array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
         62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
         79,  53,  50,  16,  51,   9,  52,   4,  14,  10,  17,  17,  51,
         11,  57,  82,  19,  63,   8,  96,   4,  10,  17,  77,  86,  57,
         82,  19,  42,  98,  43,   4,  69,   6,  11,  72,   4,  16,  12,
          7,  17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,
         27,  13,  10,  24,  62,  80]),
 'sample_rate': 22050}

In [86]:
config.dataset

{'root': 'DB/LJspeech',
 'train_part': 0.95,
 'name': 'ljspeech',
 'sample_rate': 22050}

In [87]:
dir(config.dataset)

['__class__',
 '__class_getitem__',
 '__contains__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__ior__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__or__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__ror__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'items',
 'keys',
 'name',
 'pop',
 'popitem',
 'root',
 'sample_rate',
 'setdefault',
 'train_part',
 'update',
 'values']

In [88]:
config.dataset.get('sample_rate', 16000)

22050

In [89]:
config.dataset.get('sample_rate')

22050

In [90]:
print(train_dataset.dataset.__getitem__(0))
# len(train_dataset.dataset.__getitem__(0)['text'])

In [91]:
a=train_dataset.dataset.__getitem__(0)
print(a)
# len(train_dataset.dataset.__getitem__(0)['text'])
print(a['audio'])

In [92]:
a=train_dataset.dataset.__getitem__(0)
print(a)
# len(train_dataset.dataset.__getitem__(0)['text'])
print(a['audio'].shape)

In [93]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
b = train_dataset_2.dataset.__getitem__(0)
# len(train_dataset_2.dataset.__getitem__(0)['text'])
print(b)

In [94]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
b = train_dataset_2.dataset.__getitem__(0)
# len(train_dataset_2.dataset.__getitem__(0)['text'])
print(b)
print(b['audio'].shape)

In [95]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
b = train_dataset_2.dataset.__getitem__(1)
# len(train_dataset_2.dataset.__getitem__(0)['text'])
print(b)
print(b['audio'].shape)

In [96]:
# load datasets
train_dataset_2 = dataset_module.get_dataset(
    config, part='train')
b = train_dataset_2.dataset.__getitem__(0)
# len(train_dataset_2.dataset.__getitem__(0)['text'])
print(b)
print(b['audio'].shape)

In [97]:
batch1['input_lengths']

In [98]:
batch1['text'][0]

array([ 56,  12,  44,   6,  75,  65,  43,  93,  88,  48,  49, 103,  46,
        62,  13,  99,  78,  13,  46,   5,  42,  47, 119,  56,  47,  11,
        79,  53,  50,  16,  51,   9,  52,  73,  10,  17,  17,  51,  11,
        57,  82,  19,  63,   8,  96,   4,  10,  17,   4,   9,  86,  57,
        82,  19,  42,  98,  43,  42,  12,   6,  11,  72,  53,  12,   7,
        17,   6,  11,  87,  21,  47,  11,  79,  52,  65,  43,  83,  27,
        13,  10,  24,  62,  80])

In [99]:
batch1['audio'][0]

array([ 0.00064805, -0.00185332, -0.00244806, ..., -0.00170346,
        0.00287292,  0.00232188], dtype=float32)

In [100]:
len(batch1['audio'][0])

212893

In [101]:
len(batch1['audio'][0])

212893

In [102]:
len(batch1['audio'][0])
len(batch2['audio'][0])

64

In [103]:
len(batch1['audio'][0])
len(batch2['audio'])

32

In [104]:
len(batch1['audio'][0])
len(batch2['audio'][0])

64

In [105]:
print(len(batch1['audio'][0]))
print(len(batch2['audio'][0]))

In [106]:
print(len(batch1['audio']))
print(len(batch1['audio'][0]))
print(len(batch2['audio'][0]))

In [107]:
batch1['audio']

[array([ 0.00064805, -0.00185332, -0.00244806, ..., -0.00170346,
         0.00287292,  0.00232188], dtype=float32),
 array([-0.00314305,  0.00103363,  0.00466579, ...,  0.03743864,
         0.04169904,  0.        ], dtype=float32),
 array([-2.7460288e-03, -4.2517125e-03, -6.8397014e-05, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
 array([ 1.6830985e-04, -2.7138460e-05, -2.2539313e-04, ...,
         2.5602366e-04,  1.0118349e-03,  0.0000000e+00], dtype=float32),
 array([3.5219718e-04, 3.9384593e-04, 3.4842314e-04, ..., 1.4193331e-04,
        1.3546018e-04, 8.1811209e-05], dtype=float32),
 array([ 7.7085420e-03, -2.8882974e-03, -2.8774177e-05, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
 array([ 3.8510363e-05,  2.8201991e-03,  5.8321138e-03, ...,
         4.3147802e-03,  1.1419075e-02, -6.4695268e-03], dtype=float32),
 array([ 2.2583008e-03,  1.7700195e-03, -9.1552734e-05, ...,
         1.8310547e-04, -1.8310547e-04, -3.

In [108]:
len(batch1['audio'][0])

212893

In [109]:
len(batch1['audio'][0])
len(batch2['audio'][0])

64

In [110]:
len(batch1['audio'][0])
len(batch2['audio'][0])
len(batch1['text'][0])
# len(batch2['audio'][0])

83

In [111]:
len(batch1['audio'][0])
len(batch2['audio'][0])
len(batch1['text'][0])
len(batch2['text'][0])

96

In [112]:
# #! 
# Batch-level pipeline intended for validation batches.
# Steps as listed: ToGpu, NormalizedMelSpectrogram, AddLengths, Pad
# (presumably applied in this order by Compose -- TODO confirm).
# NOTE(review): the training loop in this cell calls batch_transforms_train,
# not this pipeline, and the validation loop is commented out, so this
# object is built but never applied here.
batch_transforms_val = Compose([
    # same device choice as below: CUDA when available, otherwise CPU
    ToGpu('cuda' if torch.cuda.is_available() else 'cpu'),
    NormalizedMelSpectrogram(
        sample_rate=config.dataset.get(
            'sample_rate', 16000),  # falls back to 16000 if the config has no sample_rate
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None)
    ).to('cuda' if torch.cuda.is_available() else 'cpu'),
    AddLengths(),
    Pad()
])

# Debug pass over the training loop: captures the first batch as yielded by
# the dataloader (batch1) and the same batch after batch_transforms_train
# (batch2), prints the length of the first audio entry of each, then stops.
for epoch_idx in tqdm(range(config.train.get('epochs', 10))):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        # keep a handle on the batch as it comes out of the dataloader
        batch1=batch
        print(len(batch1['audio'][0]))
        batch = batch_transforms_train(batch)
        # keep a handle on the batch after the batch-level transforms
        batch2=batch
        print(len(batch2['audio'][0]))
        # debug: leave after the first batch; the rest of this loop body is
        # never executed while this break is in place
        break

        optimizer.zero_grad()
        logits = model(batch['audio'])
        # output length = input length / model.stride, rounded up
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # presumably a CTC-style criterion taking (time, batch, classes)
        # log-probabilities -- criterion is defined outside this cell, TODO confirm
        loss = criterion(logits.permute(2, 0, 1).log_softmax(
            dim=2), batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        # clip the gradient norm (default threshold 15 when not configured)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        # the scheduler is stepped once per batch
        lr_scheduler.step()
        # warmup_scheduler.dampen()

        # every log_interval batches (default 5000): decode the current batch
        # and log loss, WER, CER and sample pairs to wandb
        if batch_idx % config.wandb.get('log_interval', 5000) == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            # per-utterance error rates averaged over the batch
            wer = np.mean([decoder.wer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            cer = np.mean([decoder.cer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            # step = (epoch_idx * batches per epoch + batch_idx) * batch_size,
            # i.e. the logging axis is expressed in samples rather than batches
            step = epoch_idx * \
                len(train_dataloader) * train_dataloader.batch_size + \
                batch_idx * train_dataloader.batch_size
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=zip(target_strings, decoded_output)
                )
            }, step=step)
    # Validation and checkpointing are disabled (commented out) for this
    # debug run; the original code is kept below unchanged.
    # #!
    # # validate:
    # model.eval()
    # val_stats = defaultdict(list)
    # for batch_idx, batch in enumerate(val_dataloader):
    #     batch = batch_transforms_val(batch)
    #     with torch.no_grad():
    #         logits = model(batch['audio'])
    #         output_length = torch.ceil(
    #             batch['input_lengths'].float() / model.stride).int()
    #         loss = criterion(logits.permute(2, 0, 1).log_softmax(
    #             dim=2), batch['text'], output_length, batch['target_lengths'])

    #     target_strings = decoder.convert_to_strings(batch['text'])
    #     decoded_output = decoder.decode(
    #         logits.permute(0, 2, 1).softmax(dim=2))
    #     wer = np.mean([decoder.wer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     cer = np.mean([decoder.cer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     val_stats['val_loss'].append(loss.item())
    #     val_stats['wer'].append(wer)
    #     val_stats['cer'].append(cer)
    # for k, v in val_stats.items():
    #     val_stats[k] = np.mean(v)
    # val_stats['val_samples'] = wandb.Table(
    #     columns=['gt_text', 'pred_text'], data=zip(target_strings, decoded_output))
    # wandb.log(val_stats, step=step)

    # # save model, TODO: save optimizer:
    # if val_stats['wer'] < prev_wer:
    #     os.makedirs(config.train.get(
    #         'checkpoint_path', 'checkpoints'), exist_ok=True)
    #     prev_wer = val_stats['wer']
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(config.train.get(
    #             'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth')
    #     )
    #     wandb.save(os.path.join(config.train.get(
    #         'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth'))
    # debug: stop after the first epoch
    break

In [113]:
# #! 
# Validation batch pipeline with the mel-spectrogram step disabled, so only
# ToGpu, AddLengths and Pad remain.
# NOTE(review): the training loop in this cell calls batch_transforms_train,
# not batch_transforms_val, so rebinding this name does not change what the
# loop below applies to its batches.
batch_transforms_val = Compose([
    # CUDA when available, otherwise CPU
    ToGpu('cuda' if torch.cuda.is_available() else 'cpu'),
    # NormalizedMelSpectrogram(
    #     sample_rate=config.dataset.get(
    #         'sample_rate', 16000),  # for LJspeech
    #     n_mels=config.model.feat_in,
    #     normalize=config.dataset.get('normalize', None)
    # ).to('cuda' if torch.cuda.is_available() else 'cpu'),
    AddLengths(),
    Pad()
])

# Debug pass over the training loop: captures the first batch as yielded by
# the dataloader (batch1) and the same batch after batch_transforms_train
# (batch2), prints the length of the first audio entry of each, then stops.
for epoch_idx in tqdm(range(config.train.get('epochs', 10))):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        # keep a handle on the batch as it comes out of the dataloader
        batch1=batch
        print(len(batch1['audio'][0]))
        batch = batch_transforms_train(batch)
        # keep a handle on the batch after the batch-level transforms
        batch2=batch
        print(len(batch2['audio'][0]))
        # debug: leave after the first batch; the rest of this loop body is
        # never executed while this break is in place
        break

        optimizer.zero_grad()
        logits = model(batch['audio'])
        # output length = input length / model.stride, rounded up
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # presumably a CTC-style criterion taking (time, batch, classes)
        # log-probabilities -- criterion is defined outside this cell, TODO confirm
        loss = criterion(logits.permute(2, 0, 1).log_softmax(
            dim=2), batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        # clip the gradient norm (default threshold 15 when not configured)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        # the scheduler is stepped once per batch
        lr_scheduler.step()
        # warmup_scheduler.dampen()

        # every log_interval batches (default 5000): decode the current batch
        # and log loss, WER, CER and sample pairs to wandb
        if batch_idx % config.wandb.get('log_interval', 5000) == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            # per-utterance error rates averaged over the batch
            wer = np.mean([decoder.wer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            cer = np.mean([decoder.cer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            # step = (epoch_idx * batches per epoch + batch_idx) * batch_size,
            # i.e. the logging axis is expressed in samples rather than batches
            step = epoch_idx * \
                len(train_dataloader) * train_dataloader.batch_size + \
                batch_idx * train_dataloader.batch_size
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=zip(target_strings, decoded_output)
                )
            }, step=step)
    # Validation and checkpointing are disabled (commented out) for this
    # debug run; the original code is kept below unchanged.
    # #!
    # # validate:
    # model.eval()
    # val_stats = defaultdict(list)
    # for batch_idx, batch in enumerate(val_dataloader):
    #     batch = batch_transforms_val(batch)
    #     with torch.no_grad():
    #         logits = model(batch['audio'])
    #         output_length = torch.ceil(
    #             batch['input_lengths'].float() / model.stride).int()
    #         loss = criterion(logits.permute(2, 0, 1).log_softmax(
    #             dim=2), batch['text'], output_length, batch['target_lengths'])

    #     target_strings = decoder.convert_to_strings(batch['text'])
    #     decoded_output = decoder.decode(
    #         logits.permute(0, 2, 1).softmax(dim=2))
    #     wer = np.mean([decoder.wer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     cer = np.mean([decoder.cer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     val_stats['val_loss'].append(loss.item())
    #     val_stats['wer'].append(wer)
    #     val_stats['cer'].append(cer)
    # for k, v in val_stats.items():
    #     val_stats[k] = np.mean(v)
    # val_stats['val_samples'] = wandb.Table(
    #     columns=['gt_text', 'pred_text'], data=zip(target_strings, decoded_output))
    # wandb.log(val_stats, step=step)

    # # save model, TODO: save optimizer:
    # if val_stats['wer'] < prev_wer:
    #     os.makedirs(config.train.get(
    #         'checkpoint_path', 'checkpoints'), exist_ok=True)
    #     prev_wer = val_stats['wer']
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(config.train.get(
    #             'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth')
    #     )
    #     wandb.save(os.path.join(config.train.get(
    #         'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth'))
    # debug: stop after the first epoch
    break

In [114]:
# #! 
# Rebinds the training batch pipeline with the mel-spectrogram step disabled,
# leaving only ToGpu, AddLengths and Pad. The loop in this cell applies
# batch_transforms_train, so this is the pipeline that produces batch2.
batch_transforms_train = Compose([
    # CUDA when available, otherwise CPU
    ToGpu('cuda' if torch.cuda.is_available() else 'cpu'),
    # NormalizedMelSpectrogram(
    #     sample_rate=config.dataset.get(
    #         'sample_rate', 16000),  # for LJspeech
    #     n_mels=config.model.feat_in,
    #     normalize=config.dataset.get('normalize', None)
    # ).to('cuda' if torch.cuda.is_available() else 'cpu'),
    AddLengths(),
    Pad()
])

# Debug pass over the training loop: captures the first batch as yielded by
# the dataloader (batch1) and the same batch after batch_transforms_train
# (batch2), prints the length of the first audio entry of each, then stops.
for epoch_idx in tqdm(range(config.train.get('epochs', 10))):
    # train:
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        # keep a handle on the batch as it comes out of the dataloader
        batch1=batch
        print(len(batch1['audio'][0]))
        batch = batch_transforms_train(batch)
        # keep a handle on the batch after the batch-level transforms
        batch2=batch
        print(len(batch2['audio'][0]))
        # debug: leave after the first batch; the rest of this loop body is
        # never executed while this break is in place
        break

        optimizer.zero_grad()
        logits = model(batch['audio'])
        # output length = input length / model.stride, rounded up
        output_length = torch.ceil(
            batch['input_lengths'].float() / model.stride).int()
        # presumably a CTC-style criterion taking (time, batch, classes)
        # log-probabilities -- criterion is defined outside this cell, TODO confirm
        loss = criterion(logits.permute(2, 0, 1).log_softmax(
            dim=2), batch['text'], output_length, batch['target_lengths'])
        loss.backward()
        # clip the gradient norm (default threshold 15 when not configured)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.train.get('clip_grad_norm', 15))
        optimizer.step()
        # the scheduler is stepped once per batch
        lr_scheduler.step()
        # warmup_scheduler.dampen()

        # every log_interval batches (default 5000): decode the current batch
        # and log loss, WER, CER and sample pairs to wandb
        if batch_idx % config.wandb.get('log_interval', 5000) == 0:
            target_strings = decoder.convert_to_strings(batch['text'])
            decoded_output = decoder.decode(
                logits.permute(0, 2, 1).softmax(dim=2))
            # per-utterance error rates averaged over the batch
            wer = np.mean([decoder.wer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            cer = np.mean([decoder.cer(true, pred)
                            for true, pred in zip(target_strings, decoded_output)])
            # step = (epoch_idx * batches per epoch + batch_idx) * batch_size,
            # i.e. the logging axis is expressed in samples rather than batches
            step = epoch_idx * \
                len(train_dataloader) * train_dataloader.batch_size + \
                batch_idx * train_dataloader.batch_size
            wandb.log({
                "train_loss": loss.item(),
                "train_wer": wer,
                "train_cer": cer,
                "train_samples": wandb.Table(
                    columns=['gt_text', 'pred_text'],
                    data=zip(target_strings, decoded_output)
                )
            }, step=step)
    # Validation and checkpointing are disabled (commented out) for this
    # debug run; the original code is kept below unchanged.
    # #!
    # # validate:
    # model.eval()
    # val_stats = defaultdict(list)
    # for batch_idx, batch in enumerate(val_dataloader):
    #     batch = batch_transforms_val(batch)
    #     with torch.no_grad():
    #         logits = model(batch['audio'])
    #         output_length = torch.ceil(
    #             batch['input_lengths'].float() / model.stride).int()
    #         loss = criterion(logits.permute(2, 0, 1).log_softmax(
    #             dim=2), batch['text'], output_length, batch['target_lengths'])

    #     target_strings = decoder.convert_to_strings(batch['text'])
    #     decoded_output = decoder.decode(
    #         logits.permute(0, 2, 1).softmax(dim=2))
    #     wer = np.mean([decoder.wer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     cer = np.mean([decoder.cer(true, pred)
    #                     for true, pred in zip(target_strings, decoded_output)])
    #     val_stats['val_loss'].append(loss.item())
    #     val_stats['wer'].append(wer)
    #     val_stats['cer'].append(cer)
    # for k, v in val_stats.items():
    #     val_stats[k] = np.mean(v)
    # val_stats['val_samples'] = wandb.Table(
    #     columns=['gt_text', 'pred_text'], data=zip(target_strings, decoded_output))
    # wandb.log(val_stats, step=step)

    # # save model, TODO: save optimizer:
    # if val_stats['wer'] < prev_wer:
    #     os.makedirs(config.train.get(
    #         'checkpoint_path', 'checkpoints'), exist_ok=True)
    #     prev_wer = val_stats['wer']
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(config.train.get(
    #             'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth')
    #     )
    #     wandb.save(os.path.join(config.train.get(
    #         'checkpoint_path', 'checkpoints'), f'model_{epoch_idx}_{prev_wer}.pth'))
    # debug: stop after the first epoch
    break

In [115]:
# Number of samples in the raw batch.
print(len(batch1['audio']))
# Length of the first audio entry before and after the batch transforms.
print(len(batch1['audio'][0]))
print(len(batch2['audio'][0]))
# Length of the longest raw audio entry in the batch.
print(max(map(len, batch1['audio'])))

In [116]:
#!/usr/bin/env python

import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib

# torchim:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
import pytorch_warmup as warmup

# data:
import data
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
    Compose, AddLengths, AudioSqueeze, TextPreprocess,
    MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
    ToGpu, Pad, NormalizedMelSpectrogram
)
import youtokentome as yttm

import torchaudio
from audiomentations import (
    TimeStretch, PitchShift, AddGaussianNoise
)
from functools import partial

# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet

# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder

import youtokentome as yttm


# TODO: wrap to trainer class

# Command-line parser with a single option: the path to the YAML config.
parser = argparse.ArgumentParser(description='Training model.')
parser.add_argument(
    '--config',
    default='configs/train_LJSpeech.yaml',
    help='path to config file',
)
# Parse an explicitly empty argument list (instead of sys.argv), so every
# option takes its default value.
args = parser.parse_args([])
# Load the YAML file and wrap it in an EasyDict, which the rest of the script
# reads with attribute access (e.g. config.dataset.name).
with open(args.config, 'r') as f:
    config = edict(yaml.safe_load(f))

def write_to_file(str_w, file_name='sth.txt', mode='w'):
    """Write the string form of ``str_w`` to a file.

    Args:
        str_w: Object to dump; it is converted with ``str()`` before writing.
        file_name: Path of the file to open.
        mode: Mode passed to ``open`` ('w' overwrites, 'a' appends, ...).
    """
    with open(file_name, mode) as out_file:
        rendered = str(str_w)
        out_file.write(rendered)

In [117]:
# Echo the loaded configuration so the notebook shows it as the cell output.
config

{'dataset': {'root': 'DB/LJspeech',
  'train_part': 0.95,
  'name': 'ljspeech',
  'sample_rate': 22050},
 'bpe': {'train': True, 'model_path': 'yttm.bpe'},
 'train': {'seed': 42,
  'num_workers': 1,
  'batch_size': 32,
  'clip_grad_norm': 15,
  'epochs': 42,
  'optimizer': {'lr': 0.0005, 'weight_decay': 0.0001}},
 'wandb': {'project': 'quartznet_ljspeech', 'log_interval': 20},
 'model': {'name': '_quartznet5x5_config', 'vocab_size': 120, 'feat_in': 64}}

In [118]:
# Look up the dataset sample rate; 16000 is returned when the key is absent.
config.dataset.get('sample_rate', 16000)

22050

In [119]:
# Seed value comes from the config, 42 when not given.
# (fix_seeds presumably seeds the RNGs in use - confirm in utils.)
fix_seeds(seed=config.train.get('seed', 42))
# Import the submodule of the `data` package whose name is set in the config
# (relative import: '.<name>' resolved against data.__name__).
dataset_module = importlib.import_module(
    '.{}'.format(config.dataset.name),
    package=data.__name__,
)
# BPE object handed to the text transforms and the decoder below.
bpe = prepare_bpe(config)

# Per-sample training pipeline: text/audio preprocessing first, followed by
# three audiomentations waveform augmentations, each configured with p=0.5.
_train_steps = [
    TextPreprocess(),  # removes punctuation from the text (may be unnecessary)
    ToNumpy(),  # converts the audio to numpy
    BPEtexts(bpe=bpe, dropout_prob=config.bpe.get('dropout_prob', 0.05)),
    AudioSqueeze(),  # removes the first dimension when it is 1, i.e. [1, ...]
]
_train_steps += [
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
]
# AddLengths() is intentionally left out of the per-sample pipeline.
transforms_train = Compose(_train_steps)

# Batch-level training pipeline, applied in this order:
# ToGpu -> NormalizedMelSpectrogram -> MaskSpectrogram -> AddLengths -> Pad.
# The device string is computed once and shared by ToGpu and the mel module.
_train_device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_transforms_train = Compose([
    ToGpu(_train_device),
    NormalizedMelSpectrogram(
        # 16000 is only the fallback when the config has no sample_rate
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None),
    ).to(_train_device),
    MaskSpectrogram(
        probability=0.5,
        time_mask_max_percentage=0.05,
        frequency_mask_max_percentage=0.15,
    ),
    AddLengths(),
    Pad(),
])

# Per-sample validation pipeline: the same preprocessing steps as training,
# without the waveform augmentations and without an explicit dropout_prob
# for BPEtexts.
_val_steps = [TextPreprocess(), ToNumpy(), BPEtexts(bpe=bpe), AudioSqueeze()]
transforms_val = Compose(_val_steps)

# Batch-level validation pipeline, applied in this order:
# ToGpu -> NormalizedMelSpectrogram -> AddLengths -> Pad (no masking).
_val_device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_transforms_val = Compose([
    ToGpu(_val_device),
    NormalizedMelSpectrogram(
        # taken from the config (set for LJspeech); 16000 is the fallback
        sample_rate=config.dataset.get('sample_rate', 16000),
        n_mels=config.model.feat_in,
        normalize=config.dataset.get('normalize', None),
    ).to(_val_device),
    AddLengths(),
    Pad(),
])

# Build the train/val datasets through the dataset module selected above,
# each with its own per-sample transform pipeline.
train_dataset = dataset_module.get_dataset(
    config,
    transforms=transforms_train,
    part='train',
)
val_dataset = dataset_module.get_dataset(
    config,
    transforms=transforms_val,
    part='val',
)
# print("!!!", config.train.get('num_workers', 4))
# Both loaders use no_pad_collate and the DataLoader defaults for everything
# else, so neither num_workers nor shuffle is set here.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train.get('batch_size', 1),
    collate_fn=no_pad_collate,
)

# Validation runs one sample per batch.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=no_pad_collate,
)

# Resolve the architecture config by the name given in the run config.
# The previous getattr default was the *string* '_quartznet5x5_config', so an
# unknown name would have handed a bare string to QuartzNet as model_config.
# Fall back to the actual default config object instead, and resolve it once
# so the printed value is exactly what the model receives.
model_config = getattr(quartznet_configs, config.model.name, None)
if model_config is None:
    model_config = getattr(quartznet_configs, '_quartznet5x5_config')
print(model_config)

# Every remaining entry of config.model (all keys except 'name') is forwarded
# to QuartzNet as a keyword argument.
model = QuartzNet(
    model_config=model_config,
    **remove_from_dict(config.model, ['name'])
)

# print(model)
# Dump the model's string representation to a file instead of printing it.
write_to_file(model, 'model_structure.txt')

# Adam over all model parameters; its keyword arguments are the 'optimizer'
# section of the train config (empty dict -> Adam defaults).
_adam_kwargs = config.train.get('optimizer', {})
optimizer = torch.optim.Adam(model.parameters(), **_adam_kwargs)
# Total number of optimizer steps: batches per epoch times number of epochs
# (10 epochs when the config does not say otherwise).
num_steps = config.train.get('epochs', 10) * len(train_dataloader)
# Cosine annealing of the learning rate with T_max set to the whole run.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_steps,
)
# warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

# Optionally restore weights when the train config names a checkpoint.
_checkpoint_path = config.train.get('from_checkpoint', None)
if _checkpoint_path is not None:
    model.load_weights(_checkpoint_path)

# Move the model to the GPU when one is available; otherwise leave it as is.
if torch.cuda.is_available():
    model = model.cuda()

# CTC loss with index 0 as the blank label, mean reduction, and infinite
# losses (and their gradients) zeroed out.
criterion = nn.CTCLoss(
    blank=0,
    reduction='mean',
    zero_infinity=True,
)
# criterion = nn.CTCLoss(blank=config.model.vocab_size)
decoder = GreedyDecoder(bpe=bpe)

# Starting value for prev_wer, set high (1000).
prev_wer = 1000
# Start the wandb run and register the model for logging; log_freq comes from
# the config and falls back to 5000.
wandb.init(project=config.wandb.project, config=config)
_watch_interval = config.wandb.get('log_interval', 5000)
wandb.watch(model, log="all", log_freq=_watch_interval)