## Dependencies

In [None]:
import torch
from torch import nn

import pytorch_lightning as pl

import torchaudio.transforms as T
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from IPython.display import Audio, display
from data.datasets import TIMITDataset, PhonemeLabeler
from utils.utils import provide_reproducibility

import wandb

## Reproducibility and device selection

In [None]:
# Fix random seeds through the project helper (seed 42) before anything stochastic runs.
provide_reproducibility(42)

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Datasets

In [None]:
import os

# Root of the TIMIT corpus. Set the TIMIT_DIR environment variable to point elsewhere
# instead of editing this cell, e.g. on ubuntu:
#   TIMIT_DIR=/media/maxim/Programming/voice_datasets/timit/TIMIT_2/data
dir_name = os.environ.get('TIMIT_DIR', r'C:\Data\TIMIT\data')

In [None]:
# TIMIT phone codes (upper-cased) grouped into the two target classes.
# TIMIT's devoiced schwa is "ax-h"; there is no "ah-h" code in the corpus.
vowel_labels = ['IY', 'IH', 'EH', 'EY', 'AE', 'AA', 'AW', 'AY', 'AH', 'AO', 'OY', 'OW', 'UH', 'UW', 'UX', 'ER', 'AX',
                'IX', 'AXR', 'AX-H']
# NOTE(review): 'DH' (voiced "th") and the glides/semivowels (L, R, W, Y, HH, HV, EL) are in
# neither list, so they are excluded from phone_labels -- confirm this is intentional.
consonant_labels = ['B', 'D', 'G', 'P', 'T', 'K', 'DX', 'Q', 'JH', 'CH', 'S', 'SH', 'Z', 'ZH', 'F', 'TH', 'V', 'M', 'N',
                    'NG', 'EM', 'EN', 'ENG', 'NX']

phoneme_classes = {
    'vowels': vowel_labels,
    'consonants': consonant_labels
}
phone_labels = vowel_labels + consonant_labels

TIMIT_DESCRIPTION_PATH = '../../data/timit_description.csv'


def _make_timit_dataset(usage):
    """Build one TIMIT split ('train' / 'test') with this notebook's shared settings.

    A fresh PhonemeLabeler is created for every split, exactly as when each split
    was constructed by hand.
    """
    return TIMITDataset(usage=usage, root_dir=dir_name,
                        phone_codes=phone_labels, padding=16000,
                        phoneme_labeler=PhonemeLabeler(phoneme_classes, '.'),
                        description_file_path=TIMIT_DESCRIPTION_PATH)


timit_dataset_train = _make_timit_dataset('train')
timit_dataset_test = _make_timit_dataset('test')

# Sample rate of the raw recordings, read from the first training item.
timit_framerate = timit_dataset_train[0].frame_rate

In [None]:
# Class names in one-hot column order: index 0 = consonants, index 1 = vowels.
# Must stay in sync with label_to_index.
labels = ['consonants', 'vowels']
num_of_classes = len(labels)

## Resampling transform

In [None]:
# Downsample the raw TIMIT audio (recorded at `timit_framerate`) to this rate in Hz.
resample_rate = 8000
transform = T.Resample(orig_freq=timit_framerate, new_freq=resample_rate)
# Second, independent instance with identical settings.
# NOTE(review): presumably kept on the CPU while `transform` is handed to the
# preprocessor callback together with `device` -- confirm it is still needed.
transform_cpu = T.Resample(orig_freq=timit_framerate, new_freq=resample_rate)

In [None]:
def label_to_index(phone):
    """Encode a phoneme-class name as a two-element integer one-hot tensor.

    'consonants' maps to [1, 0]; every other value (expected: 'vowels') maps to [0, 1].
    Column order matches ``labels`` (consonants first).

    NOTE(review): nn.BCELoss expects float targets; confirm the cast happens downstream.
    """
    one_hot = [1, 0] if phone == 'consonants' else [0, 1]
    return torch.tensor(one_hot)


def index_to_label(index):
    """Return the phoneme-class name stored at position ``index`` of ``labels``."""
    class_name = labels[index]
    return class_name


def pad_sequence(batch):
    """Zero-pad a list of 2-D tensors along their last dimension to a common length.

    Each item is transposed so that its last dimension (presumably time) comes first,
    because ``torch.nn.utils.rnn.pad_sequence`` pads along the leading dimension. The
    padded stack is permuted back, giving (batch, dim0, max_len).
    """
    leading_last = [item.t() for item in batch]
    stacked = torch.nn.utils.rnn.pad_sequence(leading_last, batch_first=True, padding_value=0.)
    return stacked.permute(0, 2, 1)


def collate_fn(batch):
    """Collate dataset items into a (padded waveforms, stacked one-hot targets) pair.

    Only the first two fields of each item (waveform, class label) are used; any
    further fields are ignored. Waveforms are zero-padded with ``pad_sequence`` and
    labels are one-hot encoded with ``label_to_index`` before stacking.
    """
    pairs = [(waveform, label) for waveform, label, *_ in batch]
    waveforms = [waveform for waveform, _ in pairs]
    one_hot_targets = [label_to_index(label) for _, label in pairs]
    return pad_sequence(waveforms), torch.stack(one_hot_targets)


batch_size = 256

# Load in the main process on both CPU and GPU.
num_workers = 0
# `device` is a torch.device, and comparing it to the string "cuda" is always False.
# Check `.type` so that page-locked (pinned) host memory is actually used on GPU machines.
pin_memory = device.type == "cuda"

train_loader = torch.utils.data.DataLoader(
    timit_dataset_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    timit_dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Fetch a single training batch, e.g. for inspecting shapes.
train_features, train_labels = next(iter(train_loader))

In [None]:
from models.m_models import M3
from models.phoneme_recognizer import PhonemeRecognizer, AudioPreprocessorCallback

model_name = ''

# Model hyperparameters. No trailing commas: `x = 1,` would bind the tuple (1,), not 1.
n_input = 1
n_output = num_of_classes
stride = 4
n_channel = 256
# NOTE(review): `optimizer` is only recorded in config_params; it is not passed to the model here.
optimizer = 'adadelta'
lr = 3e-1

# Acoustic-model kwargs, defined once and reused for the wandb run config.
model_params = dict(
    n_input=n_input,
    n_output=n_output,
    stride=stride,
    n_channel=n_channel,
)
config_params = dict(**model_params, optimizer=optimizer, lr=lr)

model = PhonemeRecognizer(
    num_classes=num_of_classes,
    acoustic_model=M3,
    model_params=model_params,
    loss_criterion=nn.BCELoss(),
    lr=lr,
)
preprocessor_callback = AudioPreprocessorCallback(transform=transform, device=device)

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Lightning W&B logger writing local files under ../../logs; log_graph=True requests graph logging.
logger = WandbLogger(save_dir='../../logs', name=model_name, log_graph=True)

# Start a wandb run in the 'phoneme_recognizer' project with `config_params` as its config.
# NOTE(review): the run is finished when this `with` block exits; confirm that any training
# which logs through `logger` happens while this run is still active, otherwise WandbLogger
# may start a separate run.
with wandb.init(project='phoneme_recognizer', config=config_params):
    # Display name is "(<run id>) <model_name>".
    wandb.run.name = f'({wandb.run.id}) {model_name}'
    # Track `model` with log='all' (gradients and parameters) and log its graph.
    wandb.watch(model, log='all', log_graph=True)
