In [13]:
import torch
import torch.nn.functional as F
import os
import os.path
import shutil
import numpy as np
import soundfile as sf

from pathlib import PurePath
from torch import nn
from torch.utils.data import DataLoader, random_split
from asteroid.data import TimitDataset
from asteroid.data.utils import CachedWavSet, RandomMixtureSet, FixedMixtureSet
from tqdm import tqdm

from torch import optim
from pytorch_lightning import Trainer, seed_everything, loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from asteroid_filterbanks import make_enc_dec
from asteroid_filterbanks.transforms import mag
from asteroid.engine import System
from asteroid.losses import singlesrc_neg_sisdr

from asteroid import DCUNet, DCCRNet, DPRNNTasNet, ConvTasNet

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Constants and utils

In [4]:
BATCH_SIZE       = 32     # could be more on cluster, test if larger one work
SAMPLE_RATE      = 8000   # as agreed upon
CROP_LEN         = 24000  # average track len in TIMIT
SEED             = 42     # magic number :)    

# directory with train noises (n116-n120)
DRONE_NOISE_DIR = '../../../datasets/noises-train-drones-8khz'
# fixed SNRs for validation set
TRAIN_SNRS = [-25, -20, -15, -10, -5]

def sisdr_loss_wrapper(est_target, target):
    return singlesrc_neg_sisdr(est_target.squeeze(1), target).mean()

def train_val_split(ds, val_fraction=0.1, random_seed=SEED):
    assert val_fraction > 0 and val_fraction < 0.5
    len_train = int(len(ds) * (1 - val_fraction))
    len_val = len(ds) - len_train
    return random_split(ds, [len_train, len_val], generator=torch.Generator().manual_seed(random_seed))

## Prepare the data

### Resample TIMIT dataset 

In [5]:
TIMIT_DIR = PurePath('../../../datasets/TIMIT')
TIMIT_DIR_8kHZ = PurePath('/import/vision-eddydata/dm005_tmp/TIMIT_8kHZ')

In [5]:
os.makedirs(TIMIT_DIR_8kHZ, exist_ok=True)
shutil.copyfile(TIMIT_DIR / 'train_data.csv', TIMIT_DIR_8kHZ / 'train_data.csv')
shutil.copyfile(TIMIT_DIR / 'test_data.csv', TIMIT_DIR_8kHZ / 'test_data.csv')

data_dir_in = TIMIT_DIR / 'data'
data_dir_out = TIMIT_DIR_8kHZ / 'data'

def resample(ds, dir_in, dir_out, message='Resampling'):
    dl = DataLoader(ds, num_workers=10)
    for wav, path in tqdm(dl, message):
        path = PurePath(path[0])
        out_path = dir_out / path.relative_to(dir_in)
        os.makedirs(out_path.parent, exist_ok=True)
        sf.write(file=out_path, data=wav[0].numpy(), samplerate=SAMPLE_RATE)

timit_train = TimitDataset(TIMIT_DIR, subset='train', sample_rate=SAMPLE_RATE)
resample(timit_train, data_dir_in, data_dir_out, 'Resampling training data')

timit_test = TimitDataset(TIMIT_DIR, subset='test', sample_rate=SAMPLE_RATE)
resample(timit_test, data_dir_in, data_dir_out, 'Resampling test data')

Resampling training data: 100%|██████████| 4620/4620 [02:37<00:00, 29.33it/s]
Resampling test data: 100%|██████████| 1680/1680 [00:58<00:00, 28.86it/s]


### Load and split the data

In [4]:
# This is for training with on-the-fly random noise mixture with random SNR from the range.
# Is expected to produce a more robust result: checking this now on EECS server

# Reproducibility - fix all random seeds
# seed_everything(SEED)

# # Load noises, resample and save into the memory
# noises = CachedWavSet(DRONE_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)

# # Load clean data and split it into train and val
# timit = TimitDataset(TIMIT_DIR_8kHZ, subset='train', sample_rate=SAMPLE_RATE, with_path=False)
# timit_train, timit_val = train_val_split(timit, val_fraction=0.1, random_seed=SEED)

# # Training data mixes crops randomly on the fly with random SNR in range (effectively infinite training data)
# # `repeat_factor=20` means that the dataset contains 20 copies of itself - it is the easiest way to make the epoch longer
# timit_train = RandomMixtureSet(timit_train, noises, random_seed=SEED, snr_range=(-25, -5),
#                                crop_length=CROP_LEN, repeat_factor=30)

# # Validation data is fixed (for stability): mix every clean clip with all the noises in the folder
# # Argument `mixtures_per_clean` regulates with how many different noise files each clean file will be mixed
# timit_val = FixedMixtureSet(timit_val, noises, snrs=TRAIN_SNRS, random_seed=SEED,
#                             mixtures_per_clean=5, crop_length=CROP_LEN)

Precaching audio: 100%|██████████| 5/5 [00:00<00:00, 108.51it/s]


In [6]:
# Reproducibility - fix all random seeds
seed_everything(SEED)

# Load noises, resample and save into the memory
noises = CachedWavSet(DRONE_NOISE_DIR, sample_rate=SAMPLE_RATE, precache=True)

timit = TimitDataset(TIMIT_DIR_8kHZ, subset='train', sample_rate=SAMPLE_RATE, with_path=False)
timit_mix = FixedMixtureSet(timit, noises, snrs=TRAIN_SNRS, random_seed=SEED,
                            mixtures_per_clean=5, crop_length=CROP_LEN)

timit_train, timit_val = train_val_split(timit_mix, val_fraction=0.1, random_seed=SEED)

Precaching audio: 100%|██████████| 5/5 [00:00<00:00, 17.93it/s]


In [7]:
NUM_WORKERS = 5
train_loader = DataLoader(timit_train, shuffle=True, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS, drop_last=True)
val_loader = DataLoader(timit_val, batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS, drop_last=True)

In [8]:
batch = next(iter(val_loader))

In [11]:
mix, clean = batch

In [16]:
mix = mix.unsqueeze(1)
mix.shape

torch.Size([32, 1, 24000])

In [32]:
encoder, decoder = make_enc_dec('stft', n_filters=256, kernel_size=256, stride=128, sample_rate=8000)

In [33]:
mix_stft = encoder(mix).unsqueeze(1)

In [34]:
mix_stft.shape

torch.Size([32, 1, 258, 186])

In [35]:
mix_stft = mag(mix_stft)
mix_stft.shape

torch.Size([32, 1, 129, 186])

In [51]:
F.pad(mix_stft, (3,3,0,0), mode='replicate').shape

torch.Size([32, 1, 129, 192])

In [52]:
F.unfold(F.pad(mix_stft, (3, 3, 0, 0), mode='replicate'), (129, 7)).shape

torch.Size([32, 903, 186])

## Set up the model, optimizer and scheduler

In [10]:
# some random parameters, does it look sensible?
LR = 1e-3
REDUCE_LR_PATIENCE = 3
EARLY_STOP_PATIENCE = 5
MAX_EPOCHS = 100

# the model here should be constructed in the script accordingly to the passed config (including the model type)
# most of the models accept `sample_rate` parameter for encoders, which is important (default is 16000, override)
#model = DCUNet("DCUNet-20", fix_length_mode="trim", sample_rate=SAMPLE_RATE)
model = DCCRNet("DCCRN-CL", sample_rate=SAMPLE_RATE)

# Glorot initialization of the model's parameters (important, mentioned in the paper)
lrelu_gain = nn.init.calculate_gain('leaky_relu', 0.01)
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=lrelu_gain)
        if m.bias is not None:
            m.bias.data.fill_(0.01)
    
model = model.apply(init_weights)

In [11]:
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=REDUCE_LR_PATIENCE)
early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOP_PATIENCE)
checkpoint = ModelCheckpoint(
    filename='{epoch:02d}-{val_loss:.2f}',
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    verbose=True
)

# Probably we also need to subclass `System`, in order to log the target metrics on the validation set (PESQ/STOI)
system = System(model, optimizer, sisdr_loss_wrapper, train_loader, val_loader, scheduler)

In [12]:
# log dir and model name are also part of the config, of course
LOG_DIR = 'logs'
logger = pl_loggers.TensorBoardLogger(LOG_DIR, name='TIMIT-drones-DCCRN-proper', version=1)

# choose the proper accelerator for JADE, probably `ddp` (also, `auto_select_gpus=True` might be useful)
trainer = Trainer(max_epochs=MAX_EPOCHS, gpus=-1, accelerator='dp',
                  logger=logger, callbacks=[early_stopping, checkpoint], deterministic=True)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


## Train

In [None]:
trainer.fit(system)


  | Name  | Type    | Params
----------------------------------
0 | model | DCCRNet | 3.7 M 
----------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0, global step 3247: val_loss reached -3.99637 (best -3.99637), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=00-val_loss=-4.00.ckpt" as top 5


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1, global step 6495: val_loss reached -4.92332 (best -4.92332), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=01-val_loss=-4.92.ckpt" as top 5


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2, global step 9743: val_loss reached -5.12564 (best -5.12564), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=02-val_loss=-5.13.ckpt" as top 5


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3, global step 12991: val_loss reached -5.83977 (best -5.83977), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=03-val_loss=-5.84.ckpt" as top 5


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4, global step 16239: val_loss reached -6.24204 (best -6.24204), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=04-val_loss=-6.24.ckpt" as top 5


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5, global step 19487: val_loss reached -6.44075 (best -6.44075), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=05-val_loss=-6.44.ckpt" as top 5
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 6, global step 22735: val_loss reached -6.67306 (best -6.67306), saving model to "logs/TIMIT-drones-DCCRN-proper/version_1/checkpoints/epoch=06-val_loss=-6.67.ckpt" as top 5
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.

In [1]:
torch.save(model.serialize(), 'dccrn_proper_v1.pt')

NameError: name 'torch' is not defined

## Save the models fully with all the parameters

In [12]:
torch.save(model.serialize(), 'dccrn_random_v1.pt')