In [334]:
# Project bootstrap, executed before any project-local import.
# NOTE(review): presumably puts the repository root on sys.path so that the
# `langaugedetection` package can be imported — confirm in setup.py.
from setup import setup_project_root
setup_project_root()

In [335]:
import shutil
from pathlib import Path


import pandas as pd
import torch


from langaugedetection.data.length_df_tools import select_array
from langaugedetection.data.spectrogram import parse


def select_language_time(
    languages,
    window,
    data_root="/om2/user/moshepol/prosody/data/raw_audio",
):
    """
    Collects, for every requested language, the entries of that
    language's length table that match the given time window.

    For each language the file ``<data_root>/<lang>/custom/length.csv``
    is read and handed to ``select_array`` together with ``window``.

    Args:
        languages: iterable of language codes (e.g. "en", "it").
        window: time-window label forwarded unchanged to ``select_array``
            (the notebook uses strings such as "5.5 - 6.0").
        data_root: directory holding one sub-folder per language, given
            as a str or a pathlib.Path. Defaults to the location that
            was previously hard-coded.

    Returns:
        A dictionary with
        key : language
        item : whatever ``select_array`` returns for that language
               (presumably a list of audio files — confirm in length_df_tools)

    Raises:
        FileNotFoundError: raised by pandas when a language has no length.csv.
    """
    root = Path(data_root)
    lang_to_options = {}

    for lang in languages:
        df = pd.read_csv(root / lang / "custom" / "length.csv")
        lang_to_options[lang] = select_array(df, window)

        print(f"Finished reading: {lang}")

    return lang_to_options


def train_test_val_split(lang_dict, partition=None):
    """
    Returns the end indices of the train / test / validation slices,
    sized so that every language can supply the same number of files:
    80% train
    10% test
    10% validation

    Args:
        lang_dict: dictionary mapping a language to its sequence of files.
        partition: optional number of files to use per language. When
            omitted it is derived from the smallest language, rounded
            down to a multiple of 1,000. Pass e.g. ``partition=3_000``
            explicitly for a quick run on a subset (this replaces the
            former hard-coded debugging override).

    Returns:
        (train_end, test_end, val_end): cumulative indices, so the slices
        are [0:train_end], [train_end:test_end] and [test_end:val_end].
        An empty ``lang_dict`` yields (0, 0, 0).
    """
    if partition is None:
        # The smallest language bounds what every language can contribute.
        smallest = min((len(files) for files in lang_dict.values()), default=0)
        partition = (smallest // 1_000) * 1_000

    # Integer arithmetic keeps the boundaries exact (no float rounding).
    train = partition * 8 // 10
    test = partition // 10

    return train, train + test, train + 2 * test


def check_path(base):
    """
    Guarantees that `base` is an empty directory, ready to
    receive the spectrograms.

    An existing directory at `base` is deleted together with its
    contents and then created again, along with any missing parents.
    """
    target = Path(base)
    already_there = target.is_dir()

    if already_there:
        # Throw away whatever a previous run left behind
        shutil.rmtree(base)

    target.mkdir(parents=True, exist_ok=True)


def split_dataset(files, train, test, val):
    """
    Cuts `files` into three consecutive slices.

    `train`, `test` and `val` are the end indices of the slices:
    [0, train) is training, [train, test) is testing and
    [test, val) is validation. Returned in that order.
    """
    bounds = (0, train, test, val)

    # Pair each boundary with the next one: (0, train), (train, test), (test, val)
    return tuple(files[start:stop] for start, stop in zip(bounds, bounds[1:]))


def save_files(spectrograms, path, batch_size, i):
    """
    Saves one batch of spectrograms with torch.save as
    ``<path>/batch_<i, zero-padded to 5 digits>_<batch_size>.pt``

    Args:
        spectrograms: object to serialise (the notebook passes a batch tensor).
        path: destination directory, as a str or a pathlib.Path.
        batch_size: number recorded at the end of the file name.
        i: non-negative batch index.
    """
    # zfill left-pads the index with zeros up to five characters.
    name = f"batch_{str(i).zfill(5)}_{batch_size}.pt"

    # Joining through Path accepts both str and Path directories and does
    # not care whether `path` already ends with a separator.
    torch.save(spectrograms, Path(path) / name)


In [336]:
import torch
from torch.utils.data import Dataset
import torchaudio

class AudioFileDataset(Dataset):
    """
    Dataset that serves audio files as fixed-length waveforms.

    Every item is loaded with torchaudio, resampled to the target sample
    rate when the file's own rate differs, and zero-padded at the end of
    its last dimension up to `length` samples.
    """

    def __init__(self, audio_dir, length=88_000, sr=16_000):
        """
        Args:
            audio_dir: iterable of audio file paths. Despite the name it
                is not a directory; it is copied into a list.
            length: number of samples every returned waveform is padded to.
            sr: target sample rate.
        """
        self.files = list(audio_dir)
        self.target_sr = sr
        self.length = length
        # One Resample transform per source sample rate, built on first use.
        self.samplers = {}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Loads file `idx`, resamples it if necessary and pads it to `self.length`."""
        waveform, sr = torchaudio.load(self.files[idx])

        if sr != self.target_sr:
            if sr not in self.samplers:
                self.samplers[sr] = torchaudio.transforms.Resample(sr, self.target_sr)

            waveform = self.samplers[sr](waveform)

        return self.pad_waveform(waveform)

    def pad_waveform(self, waveform):
        """
        Zero-pads the last dimension of `waveform` on the right up to
        `self.length`.

        Raises:
            TypeError: if the waveform has more than `self.length` samples.
                The exception type is unchanged so existing handlers keep
                working; only the message is new.
        """
        length = waveform.shape[-1]

        if length > self.length:
            raise TypeError(
                f"waveform has {length} samples, which exceeds the maximum of {self.length}"
            )

        return torch.nn.functional.pad(waveform, (0, self.length - length))


In [337]:
# Languages to process and the clip-length window to select.
# NOTE(review): the window label is presumably in seconds and must match a
# label that select_array understands — confirm in length_df_tools.
languages = ["en", "it", "es", "de"]
window = "5.5 - 6.0"

# language -> selection returned by select_array for that window
lang_dict = select_language_time(languages, window)

Finished reading: en
Finished reading: it
Finished reading: es
Finished reading: de


In [338]:
# AudioFileDataset already pads every waveform to a common length, so collating only needs to stack the tensors.

def collate_fn(batch):
    '''
    Collates a list of per-file waveform tensors into a single batch
    tensor by stacking them along a new leading dimension; every tensor
    in `batch` must have the same shape.
    '''
    return torch.stack(batch)

In [339]:
# Only the first 10,000 (at most) English entries are converted in this run.
audio_files = lang_dict['en'][0:10_000]
dataset = AudioFileDataset(audio_files)

In [340]:
from torch.utils.data import DataLoader

# Serves the dataset in batches of up to 256 waveforms, loaded by 8 worker
# processes and stacked by collate_fn. Shuffling is not enabled, so batches
# follow the order of `audio_files`.
loader = DataLoader(
    dataset, 
    batch_size=256, 
    collate_fn=collate_fn,
    num_workers=8
)

In [341]:
def compute_spectrogram_batch(batch, n_fft=2048, hop_length=512, device=None):
    """
    Converts a batch of waveforms into power spectrograms in decibels.

    Args:
        batch: waveform tensor; a singleton dimension 1 is squeezed away
            before the STFT (assumes [B, 1, T] input — TODO confirm
            against the collate function in use).
        n_fft: FFT size passed to torch.stft.
        hop_length: hop between frames passed to torch.stft.
        device: device the STFT is computed on. When omitted, "cuda" is
            used if a GPU is available and "cpu" otherwise.

    Returns:
        CPU tensor holding 10 * log10(|STFT|^2), with the power clamped
        to at least 1e-10, i.e. a floor of -100 dB.
    """
    if device is None:
        # Previously hard-coded to 'cuda', which crashed on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    batch = batch.to(device)

    # NOTE(review): no `window` is passed, so torch.stft uses a rectangular
    # window. Left unchanged so new output matches batches already saved.
    specs = torch.stft(
        batch.squeeze(1),
        n_fft=n_fft,
        hop_length=hop_length,
        return_complex=True
    )

    power = specs.abs() ** 2
    # Clamp before the logarithm so that silent bins do not become -inf.
    db = 10 * torch.log10(torch.clamp(power, min=1e-10))

    return db.cpu()  # [B, freq_bins, time_frames]

In [342]:
# Output folder layout: <raw_audio>/<lang>/spect/range_<window>/<split>/
lang = 'en'
placement = 'train'
use = "range_" + window.replace(".", "_").replace(" ", "")

base = f"/om2/user/moshepol/prosody/data/raw_audio/{lang}/spect/{use}/{placement}/"
print(base)

# Destructive: check_path deletes an existing folder at `base` before recreating it.
check_path(base)

/om2/user/moshepol/prosody/data/raw_audio/en/spect/range_5_5-6_0/train/


In [343]:
# Running batch index; it numbers the saved files starting at 0.
i = 0

for batch_waveforms in loader:
    specs = compute_spectrogram_batch(batch_waveforms)  # [B, F, T]

    # len(specs) is the actual size of this batch and is recorded in the file name.
    save_files(specs, base, len(specs), i)
    i += 1

    # Progress is reported 1-based, i.e. the number of batches written so far.
    print(f'Completed Round: {i}')

Completed Round: 1
Completed Round: 2
Completed Round: 3
Completed Round: 4
Completed Round: 5
Completed Round: 6
Completed Round: 7
Completed Round: 8
Completed Round: 9
Completed Round: 10
Completed Round: 11
Completed Round: 12
Completed Round: 13
Completed Round: 14
Completed Round: 15
Completed Round: 16
Completed Round: 17
Completed Round: 18
Completed Round: 19
Completed Round: 20
Completed Round: 21
Completed Round: 22
Completed Round: 23
Completed Round: 24
Completed Round: 25
Completed Round: 26
Completed Round: 27
Completed Round: 28
Completed Round: 29
Completed Round: 30
Completed Round: 31
Completed Round: 32
Completed Round: 33
Completed Round: 34
Completed Round: 35
Completed Round: 36
Completed Round: 37
Completed Round: 38
Completed Round: 39
Completed Round: 40
