In [None]:
import os


In [None]:
# !pip install matplotlib

In [1]:
import torch
import numpy as np
import pandas as pd
import librosa as lr
import matplotlib.pyplot as plt

from librosa import display as lrd
import IPython.display as ipd

from torch.utils.data import DataLoader, ConcatDataset, random_split
from asteroid.data import TimitDataset
from tqdm import tqdm

from asteroid.data.utils import find_audio_files, cut_or_pad

%load_ext autoreload
%autoreload 2

In [2]:
def show_wav(wav, sr=16000):
    """Plot a waveform and render an inline audio player for it.

    Args:
        wav: Either an array of audio samples, or a path (``str`` or
            ``os.PathLike``) to an audio file, which is loaded with
            ``lr.load``.
        sr: Sample rate used for plotting and playback when ``wav`` is an
            array. When ``wav`` is a path this value is replaced by the
            rate returned by ``lr.load``.
    """
    # isinstance (instead of an exact type comparison) also accepts str
    # subclasses and pathlib.Path objects.
    if isinstance(wav, (str, os.PathLike)):
        wav, sr = lr.load(wav)

    # Newer librosa releases dropped display.waveplot in favour of
    # display.waveshow; fall back to waveplot only when waveshow is absent.
    plot_waveform = getattr(lrd, "waveshow", None) or lrd.waveplot
    plot_waveform(wav, sr=sr)
    plt.show()
    ipd.display(ipd.Audio(wav, rate=sr))

In [7]:
# TIMIT_CACHE_DIR = '/import/vision-eddydata/dm005_tmp/mixed_wavs_asteroid'
# Single root shared by every dataset location below, so that moving the
# notebook to another machine means editing one line only.
DATASETS_ROOT = '/jmain01/home/JAD007/txk02/aaa18-txk02/Datasets'

TIMIT_CACHE_DIR = f'{DATASETS_ROOT}/mixed_wavs_asteroid'
TIMIT_TRAIN_DIR = f'{DATASETS_ROOT}/TIMIT'
ENV_NOISE_DIR = f'{DATASETS_ROOT}/noises-train'
DRONE_NOISE_DIR = f'{DATASETS_ROOT}/noises-test-drones'

In [4]:
# SNR grids in 5-unit steps: training covers -25..15, testing covers -30..5.
train_snrs = list(range(-25, 20, 5))
test_snrs = list(range(-30, 10, 5))

In [5]:
# Training dataset built from TIMIT_TRAIN_DIR and ENV_NOISE_DIR at the
# training SNRs; TIMIT_CACHE_DIR is passed as the cache directory.
timit_train_misc = TimitDataset.load_with_cache(
    TIMIT_TRAIN_DIR,
    ENV_NOISE_DIR,
    cache_dir=TIMIT_CACHE_DIR,
    dset_name='train-misc',
    subset='train',
    snrs=train_snrs,
    track_duration=48000,
    root_seed=42,
    prefetch_mixtures=False,
)

Preparing datasets: 100%|██████████| 9/9 [04:26<00:00, 29.60s/it]


In [None]:
# timit_train_drones = TimitDataset.load_with_cache(
#     '../../../datasets/TIMIT', '../../../datasets/noises-train-drones',
#     cache_dir=TIMIT_CACHE_DIR, snrs=train_snrs, root_seed=42, prefetch_mixtures=False,
#     mixtures_per_clean=5, dset_name='train-drones',
#     subset='train', track_duration=48000)

In [None]:
# Test dataset built from the 'test' subset with the noises under
# DRONE_NOISE_DIR at the test SNRs; it shares the cache directory used for
# the training set.
timit_test_drones = TimitDataset.load_with_cache(
    TIMIT_TRAIN_DIR,
    DRONE_NOISE_DIR,
    cache_dir=TIMIT_CACHE_DIR,
    dset_name='test-drones',
    subset='test',
    snrs=test_snrs,
    root_seed=68,
)

In [None]:
def train_val_split(ds, val_fraction=0.1, random_seed=42):
    """Randomly partition ``ds`` into a training and a validation subset.

    ``int(len(ds) * (1 - val_fraction))`` items go to the training subset and
    all remaining items go to the validation subset. The shuffle is driven by
    a fresh ``torch.Generator`` seeded with ``random_seed``, so repeated calls
    with the same arguments give the same partition.

    Args:
        ds: Dataset to split; must support ``len``.
        val_fraction: Share of items reserved for validation; must lie
            strictly between 0 and 0.5.
        random_seed: Seed for the generator handed to ``random_split``.

    Returns:
        The ``random_split`` result for the lengths ``[n_train, n_val]``.

    Raises:
        AssertionError: If ``val_fraction`` is outside the open interval
            (0, 0.5).
    """
    assert 0 < val_fraction < 0.5

    n_total = len(ds)
    n_train = int(n_total * (1 - val_fraction))
    n_val = n_total - n_train

    rng = torch.Generator()
    rng.manual_seed(random_seed)
    return random_split(ds, [n_train, n_val], generator=rng)

In [None]:
# Split the training data with train_val_split's defaults
# (val_fraction=0.1, random_seed=42).
train_set, val_set = train_val_split(timit_train_misc)

In [None]:
# Loader settings shared by the training and validation loaders.
BATCH_SIZE = 32
NUM_WORKERS = 10

# Training batches are reshuffled; the final partial batch is dropped.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, drop_last=True,
)

# Validation keeps dataset order.
# NOTE(review): drop_last=True also discards the final partial validation
# batch -- confirm that this is intended.
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, drop_last=True,
)

#train_loader_test = DataLoader(train_set, shuffle=True, batch_size=8, num_workers=NUM_WORKERS, drop_last=True)

In [None]:
from torch import optim
from torch_lr_finder import LRFinder
from pytorch_lightning import Trainer, loggers as pl_loggers
from asteroid_filterbanks.transforms import mag
from asteroid.engine import System
from asteroid.losses import singlesrc_neg_sisdr

from asteroid import DCUNet, DCCRNet

def sisdr_loss_wrapper(est_target, target):
    """Return the mean of ``singlesrc_neg_sisdr`` over its output.

    Dimension 1 of ``est_target`` is squeezed before the loss is computed.
    """
    squeezed_estimate = est_target.squeeze(1)
    loss_values = singlesrc_neg_sisdr(squeezed_estimate, target)
    return loss_values.mean()


In [None]:
# "DCUNet-20" variant with fix_length_mode="trim".
dcunet20 = DCUNet(
    "DCUNet-20",
    fix_length_mode="trim",
)
# NOTE(review): a OneCycleLR schedule is attached to this optimizer in a
# later cell, so the tiny lr given here is only the constructor value.
dcunet20_opt = optim.Adam(
    dcunet20.parameters(),
    weight_decay=1e-6,
    lr=1e-7,
)

In [None]:
#lr_finder = LRFinder(dcunet20, dcunet20_opt, sisdr_loss_wrapper, device="cuda")
#lr_finder.range_test(train_loader_test, end_lr=10, num_iter=100)
#lr_finder.plot()
#lr_finder.reset()

In [None]:
# One-cycle schedule sized for 10 epochs of len(train_loader) steps each,
# peaking at a learning rate of 0.03.
dcunet20_sched = optim.lr_scheduler.OneCycleLR(
    dcunet20_opt,
    max_lr=0.03,
    epochs=10,
    steps_per_epoch=len(train_loader),
)
# 'interval': 'step' requests a scheduler update per step, not per epoch.
scheduler = {
    'scheduler': dcunet20_sched,
    'interval': 'step',
}

In [None]:
# Bundle model, optimizer, loss, both loaders and the scheduler config into
# the System object that is handed to the Trainer.
system = System(
    dcunet20,
    dcunet20_opt,
    sisdr_loss_wrapper,
    train_loader,
    val_loader,
    scheduler,
)

In [None]:
# TensorBoard run: logs/TIMIT-drones-DCUNet-20-onecycle/v3.
logger = pl_loggers.TensorBoardLogger(
    'logs',
    name='TIMIT-drones-DCUNet-20-onecycle',
    version='v3',
)
# gpus=-1 with the 'dp' accelerator; max_epochs matches the 10 epochs the
# OneCycleLR schedule was sized for.
trainer = Trainer(
    logger=logger,
    max_epochs=10,
    gpus=-1,
    accelerator='dp',
)
trainer.fit(system)

In [None]:
# Serialize the trained DCUNet-20 model and write the result to disk.
dcunet20_serialized = dcunet20.serialize()
torch.save(
    obj=dcunet20_serialized,
    f='dcunet_20_onecycle_v3.pt',
)

## Training UNetGAN

In [None]:
import pytorch_lightning as pl
import torch.nn.functional as F
from collections import OrderedDict

from asteroid.masknn import UNetGANGenerator, UNetGANDiscriminator

def _unsqueeze_to_3d(x):
    """Normalize shape of `x` to [batch, n_chan, time]."""
    if x.ndim == 1:
        return x.reshape(1, 1, -1)
    elif x.ndim == 2:
        return x.unsqueeze(1)
    else:
        return x

class UNetGAN(pl.LightningModule):
    """Lightning module that trains a UNetGAN generator against its discriminator.

    Two optimizers are configured (index 0: generator, index 1:
    discriminator) and ``training_step`` dispatches on ``optimizer_idx``.

    Args:
        mse_weight: Weight of the MSE term in the generator loss.
        lr_g: Learning rate of the generator's Adam optimizer.
        lr_d: Learning rate of the discriminator's Adam optimizer.
        **kwargs: Extra keyword arguments; not read anywhere in this class.
    """

    def __init__(
        self,
        mse_weight: float = 20,
        lr_g: float = 1e-3,
        lr_d: float = 1e-3,
        **kwargs
    ):
        super().__init__()
        # Makes the constructor arguments available through self.hparams.
        self.save_hyperparameters()

        # networks
        self.generator = UNetGANGenerator()
        self.discriminator = UNetGANDiscriminator()

    def forward(self, z):
        """Run the generator on ``z`` after normalizing it to 3 dimensions."""
        return self.generator(_unsqueeze_to_3d(z))

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator outputs and labels."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: Pair ``(mix, clean)``; both are normalized to 3 dimensions.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator update, 1 for the
                discriminator update.

        Returns:
            A dict holding the loss under the key ``'loss'``, or ``None`` for
            any other ``optimizer_idx``.
        """
        mix, clean = batch
        mix = _unsqueeze_to_3d(mix)
        clean = _unsqueeze_to_3d(clean)

        if optimizer_idx == 0:
            return self._generator_step(mix, clean)
        if optimizer_idx == 1:
            return self._discriminator_step(mix, clean)

    def _generator_step(self, mix, clean):
        """Generator loss: negated BCE against 'fake' labels plus weighted MSE."""
        enh = self.generator(mix)
        disc_vals = self.discriminator(mix, enh)

        # zeros_like yields labels with the dtype and device of disc_vals.
        fake = torch.zeros_like(disc_vals)

        mse_loss = F.mse_loss(clean, enh)
        # The BCE against the 'fake' labels is negated, so minimizing g_loss
        # pushes the discriminator output for `enh` away from 'fake'.
        adv_loss = -self.adversarial_loss(disc_vals, fake)

        g_loss = adv_loss + self.hparams.mse_weight * mse_loss

        # self.log replaces the deprecated 'log' / 'progress_bar' keys of the
        # returned dict.
        self.log('g_loss', g_loss, prog_bar=True)
        return {'loss': g_loss}

    def _discriminator_step(self, mix, clean):
        """Discriminator loss: BCE on clean ('valid') plus BCE on enhanced ('fake')."""
        # Detached so this loss does not backpropagate into the generator.
        enh = self.generator(mix).detach()
        clean_disc_vals = self.discriminator(mix, clean)
        enh_disc_vals = self.discriminator(mix, enh)

        valid = torch.ones_like(clean_disc_vals)
        fake = torch.zeros_like(enh_disc_vals)

        real_loss = self.adversarial_loss(clean_disc_vals, valid)
        fake_loss = self.adversarial_loss(enh_disc_vals, fake)

        d_loss = real_loss + fake_loss

        self.log('d_loss', d_loss, prog_bar=True)
        return {'loss': d_loss}

    def configure_optimizers(self):
        """Return ``[opt_g, opt_d]`` (Adam for each network) and an empty scheduler list."""
        lr_g = self.hparams.lr_g
        lr_d = self.hparams.lr_d

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
        return [opt_g, opt_d], []

#     def on_epoch_end(self):
#         z = self.validation_z.type_as(self.generator.model[0].weight)

#         # log sampled images
#         sample_imgs = self(z)
#         grid = torchvision.utils.make_grid(sample_imgs)
#         self.logger.experiment.add_image('generated_images', grid, self.current_epoch)

In [None]:
# Same learning rate for both networks; mse_weight keeps its default.
unetgan_module = UNetGAN(
    lr_g=2e-4,
    lr_d=2e-4,
)

In [None]:
# Continue UNetGAN training from the stored checkpoint on GPUs 1 and 3 with
# the 'ddp' accelerator.
# NOTE(review): no logger is passed here, so the Trainer default is used.
trainer = Trainer(
    resume_from_checkpoint='logs/UNetGAN-misc-continue/v1/checkpoints/epoch=68-step=322712.ckpt',
    max_epochs=100,
    gpus=[1,3],
    accelerator='ddp',
)
trainer.fit(unetgan_module, train_loader, val_loader)

In [None]:
# TensorBoard run: logs/TIMIT-misc-DCCRN/v2.
logger = pl_loggers.TensorBoardLogger(
    'logs',
    name='TIMIT-misc-DCCRN',
    version='v2',
)
# "DCCRN-CL" model trained with Adam and the SI-SDR loss wrapper; no
# scheduler is passed to System here.
model2 = DCCRNet("DCCRN-CL")
optimizer2 = optim.Adam(
    model2.parameters(),
    lr=1e-3,
)
system2 = System(
    model2,
    optimizer2,
    sisdr_loss_wrapper,
    train_loader,
    val_loader,
)

In [None]:
# Train the DCCRN system for up to 30 epochs with gpus=1.
trainer = Trainer(
    logger=logger,
    max_epochs=30,
    gpus=1,
)
trainer.fit(system2)

In [None]:
# Serialize the trained DCCRN model; the result is written to disk in the
# next cell.
dccrn_serialized = model2.serialize()

In [None]:
# Persist the serialized DCCRN model.
torch.save(
    obj=dccrn_serialized,
    f='dccrn_v2.pt',
)

In [None]:
from asteroid.masknn.wavenet import UNetGANGenerator

In [None]:
# Fresh generator built with default arguments; used below only to count
# its parameters.
gen = UNetGANGenerator()

In [None]:
# Total number of elements across all generator parameters.
sum(map(torch.numel, gen.parameters()))
