In [1]:
import torch
import os
import os.path
import shutil
import numpy as np
import soundfile as sf

from pathlib import PurePath
from torch.utils.data import DataLoader, random_split
from asteroid.data import TimitDataset, TimitCleanDataset, RandomMixtureDataset
from tqdm import tqdm

from torch import optim
from pytorch_lightning import Trainer, seed_everything, loggers as pl_loggers
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from asteroid_filterbanks.transforms import mag
from asteroid.engine import System
from asteroid.losses import singlesrc_neg_sisdr

from asteroid import DCUNet, DCCRNet

%load_ext autoreload
%autoreload 2

## Constants and utils

In [2]:
# --- Batching and audio format ---
# Batch size; a larger value may fit on the cluster (not tested yet).
BATCH_SIZE = 64
# Sample rate, as agreed upon.
SAMPLE_RATE = 8000
# Crop length: the average track length in TIMIT.
# (Earlier alternative: CROP_LEN = 8192 - slightly more than a second, guaranteed to be
# less than the shortest clip in TIMIT.)
CROP_LEN = 24000
# Seed passed to every seeded call in this notebook.
SEED = 42

# --- Locations ---
# Directory where the fixed mixtures are cached.
TIMIT_CACHE_DIR = '/import/vision-eddydata/dm005_tmp/mixed_wavs_asteroid2'
# Directory with the training noises (n116-n120).
DRONE_NOISE_DIR = '../../../datasets/noises-train-drones'

# Fixed SNRs passed as `snrs` when the cached mixture dataset is built.
TRAIN_SNRS = [-25, -20, -15, -10, -5]

def sisdr_loss_wrapper(est_target, target):
    """Return the mean of `singlesrc_neg_sisdr` over its output.

    Dimension 1 of `est_target` is squeezed away before the estimate and
    `target` are handed to `singlesrc_neg_sisdr`; the values it returns are
    then averaged into a single value.
    """
    # assumes est_target is (batch, 1, time) and target is (batch, time) - TODO confirm
    estimate = est_target.squeeze(1)
    per_item_loss = singlesrc_neg_sisdr(estimate, target)
    return per_item_loss.mean()

def train_val_split(ds, val_fraction=0.1, random_seed=SEED):
    """Randomly split `ds` into a training part and a validation part.

    Args:
        ds: dataset to split; must support `len()`.
        val_fraction: share of the items that goes to validation; must lie
            strictly between 0 and 0.5.
        random_seed: seed of the generator that drives the split.

    Returns:
        Whatever `random_split` returns for the two computed lengths:
        the training subset first, the validation subset second.
    """
    assert 0 < val_fraction < 0.5
    n_total = len(ds)
    # The training size is truncated to an int; validation takes the remainder,
    # so the two lengths always add up to len(ds).
    n_train = int(n_total * (1 - val_fraction))
    n_val = n_total - n_train
    rng = torch.Generator().manual_seed(random_seed)
    return random_split(ds, [n_train, n_val], generator=rng)

## Prepare the data

### Resample TIMIT dataset 

In [3]:
# Root of the original TIMIT dataset (a relative path).
TIMIT_DIR = PurePath('../../../datasets/TIMIT')
# Root of the copy of TIMIT that this notebook writes out at SAMPLE_RATE.
TIMIT_DIR_8kHZ = PurePath('/import/vision-eddydata/dm005_tmp/TIMIT_8kHZ')

In [9]:
# Create the destination root and copy the train/test CSV files into it.
os.makedirs(TIMIT_DIR_8kHZ, exist_ok=True)
for csv_name in ('train_data.csv', 'test_data.csv'):
    shutil.copyfile(TIMIT_DIR / csv_name, TIMIT_DIR_8kHZ / csv_name)

# `data` subdirectories of the source and destination roots.
data_dir_in, data_dir_out = TIMIT_DIR / 'data', TIMIT_DIR_8kHZ / 'data'

def resample(ds, dir_in, dir_out, message='Resampling'):
    """Write every item of `ds` under `dir_out`, mirroring its path below `dir_in`.

    This function does not change the audio itself: it saves the waveforms
    exactly as `ds` yields them and records SAMPLE_RATE as the file's sample
    rate, so `ds` is expected to already deliver audio at that rate.

    Args:
        ds: dataset yielding `(wav, path)` pairs.
        dir_in: directory the yielded paths are made relative to.
        dir_out: directory below which the files are written.
        message: label shown on the progress bar.
    """
    loader = DataLoader(ds, num_workers=10)
    for batch_wav, batch_path in tqdm(loader, message):
        # DataLoader's default batch size is 1, so index 0 is the only item.
        src_path = PurePath(batch_path[0])
        dst_path = dir_out / src_path.relative_to(dir_in)
        os.makedirs(dst_path.parent, exist_ok=True)
        sf.write(file=dst_path, data=batch_wav[0].numpy(), samplerate=SAMPLE_RATE)

# Training subset: load at SAMPLE_RATE, then write the files below `data_dir_out`.
timit_train = TimitCleanDataset(TIMIT_DIR, subset='train', sample_rate=SAMPLE_RATE)
resample(ds=timit_train, dir_in=data_dir_in, dir_out=data_dir_out,
         message='Resampling training data')

# Test subset: same procedure.
timit_test = TimitCleanDataset(TIMIT_DIR, subset='test', sample_rate=SAMPLE_RATE)
resample(ds=timit_test, dir_in=data_dir_in, dir_out=data_dir_out,
         message='Resampling test data')

Resampling training data: 100%|██████████| 4620/4620 [00:23<00:00, 195.48it/s]
Resampling test data: 100%|██████████| 1680/1680 [00:08<00:00, 193.00it/s]


### Load and split the data

In [4]:
# NOTE: the commented-out code below mixes the training data randomly on the fly. It is disabled because the
# start of each epoch is very slow with this approach (the cause is not yet understood), so the fixed
# dataset is used for now.

# Reproducibility - fix all random seeds
# seed_everything(SEED)

# # Load clean data and split it into train and val
# timit = TimitCleanDataset(TIMIT_DIR_8kHZ, subset='train', sample_rate=SAMPLE_RATE)
# timit_train, timit_val = train_val_split(timit, val_fraction=0.1, random_seed=SEED)

# # Training data mixes crops randomly on the fly with random SNR in range (effectively infinite training data)
# timit_train = RandomMixtureDataset(timit_train, DRONE_NOISE_DIR, random_seed=SEED, snr_range=(-25, -5),
#                                    crop_length=CROP_LEN, sample_rate=SAMPLE_RATE)

# # Validation data is fixed (for stability): mix every clean clip with all the noises in the folder
# # Pass the argument `prefetch_mixtures=False` to skip the up-front pass over the whole dataset that
# # saves the mixtures to the cache folder in advance
# # The argument `mixtures_per_clean` controls how many different noise files each clean file is mixed with
# timit_val = TimitDataset.load_with_cache(
#      timit_val, DRONE_NOISE_DIR, cache_dir=TIMIT_CACHE_DIR, snrs=TRAIN_SNRS, root_seed=SEED,
#      mixtures_per_clean=3, dset_name='valid-drones', sample_rate=SAMPLE_RATE,
#      subset='train', crop_length=CROP_LEN)

In [5]:
# Reproducibility - fix all random seeds
seed_everything(SEED)

# Build the fixed mixture dataset, using TIMIT_CACHE_DIR as the cache directory.
timit_train_drones = TimitDataset.load_with_cache(
    TIMIT_DIR_8kHZ,
    DRONE_NOISE_DIR,
    cache_dir=TIMIT_CACHE_DIR,
    dset_name='train-drones',
    subset='train',
    snrs=TRAIN_SNRS,
    mixtures_per_clean=5,
    crop_length=CROP_LEN,
    sample_rate=SAMPLE_RATE,
    root_seed=SEED,
)

# NOTE(review): the split is taken over the mixture dataset, not over the clean clips, so
# mixtures derived from the same clean clip can presumably end up in both the training and
# the validation part - confirm that this is intended.
timit_train, timit_val = train_val_split(
    timit_train_drones, val_fraction=0.1, random_seed=SEED)

Preparing datasets: 100%|██████████| 5/5 [06:22<00:00, 76.50s/it]
Load samples: 100%|██████████| 115500/115500 [05:53<00:00, 326.67it/s] 

Track lengths stats: total 2772000000, mean 24000.0, median 24000.0, min 24000, max 24000
Tracks in total: 115500
Total audio duration: 48:7:30





In [6]:
NUM_WORKERS = 5

# Settings shared by both loaders; note that `drop_last=True` applies to validation as well.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, drop_last=True)

# Only the training loader shuffles its data.
train_loader = DataLoader(timit_train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(timit_val, **loader_kwargs)

## Set up the model, optimizer and scheduler

In [7]:
# Provisional hyper-parameters, not tuned - TODO: check that they look sensible.
LR = 1e-3
REDUCE_LR_PATIENCE = 3
EARLY_STOP_PATIENCE = 10
MAX_EPOCHS = 500

# In the script, the model (including its type) should be constructed according to the passed config.
# Most of the models accept a `sample_rate` parameter for the encoders; the default is 16000,
# so it has to be overridden here.
model = DCUNet("DCUNet-20", fix_length_mode="trim", sample_rate=SAMPLE_RATE)

optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=REDUCE_LR_PATIENCE)
early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOP_PATIENCE)

# TODO: `System` probably needs to be subclassed in order to log the target metrics
# (PESQ/STOI) on the validation set.
system = System(
    model=model,
    optimizer=optimizer,
    loss_func=sisdr_loss_wrapper,
    train_loader=train_loader,
    val_loader=val_loader,
    scheduler=scheduler,
)

In [10]:
# The log directory and the model name are part of the config as well.
LOG_DIR = 'logs'
logger = pl_loggers.TensorBoardLogger(
    save_dir=LOG_DIR, name='TIMIT-drones-DCUNET-20-proper', version=1)

# TODO: choose the proper accelerator for JADE, probably `ddp`
# (`auto_select_gpus=True` might be useful too).
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    gpus=-1,
    accelerator='dp',
    logger=logger,
    callbacks=[early_stopping],
    deterministic=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


## Train

In [None]:
# Start training; `system` was built from the model, optimizer, loss, data loaders and scheduler.
trainer.fit(system)


  | Name  | Type   | Params
---------------------------------
0 | model | DCUNet | 3.5 M 
---------------------------------
3.5 M     Trainable params
0         Non-trainable params
3.5 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

## Save the model together with all its parameters

In [None]:
# Save whatever `model.serialize()` returns to a file in the current working directory.
# NOTE(review): presumably the serialized object bundles the weights with the model's
# constructor arguments - confirm against the asteroid documentation.
torch.save(model.serialize(), 'dcunet_20_proper_v1.pt')