In [0]:
# soundfile is imported as `sf` below and used to read the clean WAVs and the tar members.
# NOTE(review): unpinned install — consider pinning a version for reproducibility.
!pip install -q soundfile

In [0]:
# DCUnet.pytorch provides `utils`, `models.unet.Unet` and `models.layers.istft.ISTFT`,
# imported in the next cell after the clone directory is appended to sys.path.
!git clone -q https://github.com/chanil1218/DCUnet.pytorch

In [0]:
# Clean speech chunks; presumably extracts the einschlafen2/ tree that is globbed as CLEAN — confirm.
!wget -q "HIDDEN/einschlafen-chunks-10s/einschlafen2-chunks-10000.tar" -O - | tar xf -

In [0]:
# Impulse responses used by test_sample (glob "AIR_1_4/*wav"); presumably extracts into AIR_1_4/.
# Unlike the download above this one is not run with -q.
!wget "HIDDEN/einschlafen-chunks-10s/AIR_1_4.tar" -O - | tar xf -

In [0]:
# Google Drive holds the noisy training tars (globbed as TARS) and receives the per-epoch checkpoints.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import os
import sys
if 'DCUnet.pytorch' not in sys.path: sys.path.append('DCUnet.pytorch')

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

from scipy.io import wavfile
import librosa
import tqdm

import glob
import utils
from models.unet import Unet
from models.layers.istft import ISTFT
#from se_dataset import AudioDataset
from torch.utils.data import DataLoader


# STFT/ISTFT settings shared by train() and eval_samples().
# Command-line options of a script version of this training (--model_dir, --restore_file,
# --batch_size, --num_epochs) are not parsed here; the BATCH_SIZE, N_EPOCHS, START_EPOCH and
# NET_NAME constants defined further down play that role.

#n_fft, hop_length = 400, 160
n_fft, hop_length = 1024, 256
window = torch.hann_window(n_fft).cuda()


def stft(x):
    """STFT of real audio `x` as a real tensor whose last axis (size 2) holds (real, imag).

    `x` is expected to be a batch of waveforms (batch, samples), as passed by train() and
    eval_samples(). Recent PyTorch versions refuse torch.stft on real input unless
    `return_complex` is given, so a complex result is requested and converted with
    torch.view_as_real. That conversion yields the same layout as the legacy real-valued output
    that callers index with [..., 0] / [..., 1]. PyTorch builds that predate `return_complex`
    reject the keyword with TypeError and fall back to the legacy call.
    """
    try:
        spec = torch.stft(x, n_fft, hop_length, window=window, return_complex=True)
    except TypeError:
        return torch.stft(x, n_fft, hop_length, window=window)
    return torch.view_as_real(spec)


istft = ISTFT(n_fft, hop_length, window='hanning').cuda()

def wSDRLoss(mixed, clean, clean_est, eps=2e-7):
    """Weighted SDR loss between estimated and reference clean speech (lower is better).

    Works on time-domain signals, so the spectral estimate must be turned back into audio
    with a differentiable istft before calling this. All inputs are batched waveforms of
    shape (N x T).

    Two negative cosine similarities are combined per example:
      * speech term: clean vs clean_est
      * noise term:  (mixed - clean) vs (mixed - clean_est)
    The speech term is weighted by the clean signal's share of the total (clean + noise)
    energy and the noise term by the remainder, so each example's value lies roughly in [-1, 1].

    Args:
        mixed: the network input waveforms.
        clean: the target clean waveforms.
        clean_est: the network's estimate of `clean`.
        eps: small constant added to the denominators.

    Returns:
        Scalar tensor: the loss averaged over the batch.
    """
    def neg_cosine(reference, estimate):
        # Modified SDR: -<ref, est> / (||ref|| * ||est||), one value per batch row.
        # (The original SDR would be <x, x'>**2 / <x', x'>; normalising by both norms
        # maximises correlation without rewarding a louder output.)
        inner = torch.sum(reference * estimate, dim=1)
        scale = torch.norm(reference, p=2, dim=1) * torch.norm(estimate, p=2, dim=1)
        return -(inner / (scale + eps))

    noise = mixed - clean
    noise_est = mixed - clean_est

    clean_energy = torch.sum(clean ** 2, dim=1)
    weight = clean_energy / (clean_energy + torch.sum(noise ** 2, dim=1) + eps)

    per_example = weight * neg_cosine(clean, clean_est) + (1 - weight) * neg_cosine(noise, noise_est)
    return torch.mean(per_example)


In [0]:
import glob
import tarfile
import soundfile as sf
import random
import pickle

try:
    TARS, CLEAN = pickle.load(open('tars-clean.pk', 'rb'))
except IOError:
    TARS = glob.glob("/content/drive/My Drive/einschlafen-10s-AIRs/**/*.tar", recursive=True)
    TARS.sort()
    CLEAN = glob.glob("einschlafen2/**/*.wav", recursive=True)
    CLEAN.sort()
    pickle.dump((TARS, CLEAN), open('tars-clean.pk', 'wb'))

SAMPLE_LEN = 3
FIRST_FRAME = 16_000
LAST_FRAME = SAMPLE_LEN * 16_000 + FIRST_FRAME

TARS = TARS[:500]

class MyAudioDataset(torch.utils.data.IterableDataset):
    """Stream (noisy, clean) waveform pairs from per-chunk tar archives.

    A tar named ``<episode>-<chunk>.tar`` holds several noisy/reverberant renderings whose clean
    target is ``einschlafen2/<episode>/<episode>-<chunk>``. Both signals are cut to samples
    FIRST_FRAME:LAST_FRAME and returned as float32 tensors, so every pair from one tar shares
    the same clean target.

    Tars are read ``tars_per_refill`` at a time and the pooled pairs are shuffled before being
    yielded, so consecutive examples (and hence batches) mix several clean targets.

    Args:
        tars: paths of the tar archives, consumed in order.
        clean_files: kept for reference only; the clean path is derived from each tar name.
        tars_per_refill: number of tars pooled (and shuffled together) per refill.
    """

    def __init__(self, tars, clean_files, tars_per_refill=3):
        self.tars = tars
        self.clean_files = clean_files
        self.tars_per_refill = tars_per_refill

    def __iter__(self):
        # The pool is local to this generator: abandoning an iteration part-way (e.g. after a
        # few batches) must not leak its leftover pairs into the next iteration/epoch.
        tar_iter = iter(self.tars)
        while True:
            pool = []
            exhausted = False
            for _ in range(self.tars_per_refill):
                try:
                    tar = next(tar_iter)
                except StopIteration:
                    exhausted = True
                    break
                try:
                    self.refill(tar, pool)
                except Exception as e:
                    # Best effort: report and skip unreadable tars (pairs appended before the
                    # error are kept).
                    print("Error reading", tar, e)
            random.shuffle(pool)
            if exhausted:
                yield from pool
                return
            if not pool:
                # A whole refill round produced nothing although tars remain.
                raise RuntimeError("Could not read any tars")
            yield from pool

    def refill(self, tar, res):
        """Append one (noisy, clean) pair per audio member of `tar` to `res`, in random order.

        Raises AssertionError if a file is not sampled at 16 kHz; errors from reading the files
        propagate to the caller.
        """
        episode, chunk = tar.split("/")[-1].split("-", 1)
        # chunk[:-4] drops the ".tar" suffix, leaving the clean file's name.
        clean = f"einschlafen2/{episode}/{episode}-{chunk[:-4]}"
        clean_data, clean_sr = sf.read(clean)
        assert clean_sr == 16_000
        clean_data = torch.from_numpy(clean_data[FIRST_FRAME:LAST_FRAME]).type(torch.FloatTensor)
        with tarfile.open(tar) as tarf:
            members = tarf.getmembers()
            random.shuffle(members)
            for member in members:
                cur_f = tarf.extractfile(member)
                if cur_f is None:
                    # extractfile() returns None for members that are neither regular files
                    # nor links (e.g. directory entries); they carry no audio.
                    continue
                with cur_f:
                    noisy_data, noisy_sr = sf.read(cur_f)
                assert noisy_sr == 16_000
                noisy_data = torch.from_numpy(noisy_data[FIRST_FRAME:LAST_FRAME]).type(torch.FloatTensor)
                res.append((noisy_data, clean_data))

In [0]:
# Choose between the preset config shipped with DCUnet.pytorch (exp/unet16.json) and the
# smaller custom stack below (5 encoder + 5 decoder layers).
USE_PRESET_UNET16 = False

if USE_PRESET_UNET16:
    params = utils.Params("DCUnet.pytorch/exp/unet16.json")
    net = Unet(params.model).cuda()
else:
    # Each coder row is [in_channels, out_channels, kernel_size, stride, padding] (see
    # "__coder_keys"). Decoder in_channels of 128 presumably include the concatenated skip
    # connection from the mirrored encoder — confirm in models.unet.
    MODEL = {
        "leaky_slope" : 0.1,
        "ratio_mask" : "BDT",
        "encoders" : [
            [1, 32, [7, 5], [2, 2], [3, 2]],
            [32, 64, [7, 5], [2, 2], [3, 2]],
            [64, 64, [5, 3], [2, 2], [2, 1]],
            [64, 64, [5, 3], [2, 2], [2, 1]],
            [64, 64, [5, 3], [2, 1], [2, 1]]
        ],
        "decoders" : [
            [64, 64, [5, 3], [2, 1], [2, 1]],
            [128, 64, [5, 3], [2, 2], [2, 1]],
            [128, 64, [5, 3], [2, 2], [2, 1]],
            [128, 32, [7, 5], [2, 2], [3, 2]],
            [64, 1, [7, 5], [2, 2], [3, 2]]
        ],
        "__coder_keys" : [
            "in_channels", "out_channels", "kernel_size", "stride", "padding"
        ]
    }
    net = Unet(MODEL).cuda()

print("Model has", sum(p.numel() for p in net.parameters() if p.requires_grad)/1e6, "M params")

START_EPOCH = 28      # first epoch to train; > 0 resumes from the checkpoint of epoch START_EPOCH - 1
N_EPOCHS = 50         # number of epochs train() runs
NET_NAME = 'net-2-'   # checkpoint file-name prefix on Drive
if START_EPOCH > 0:
    resume_path = f'/content/drive/My Drive/{NET_NAME}{START_EPOCH-1}.pth.tar'
    if not os.path.exists(resume_path):
        raise FileNotFoundError(
            f"START_EPOCH={START_EPOCH} needs checkpoint {resume_path!r}, which does not exist "
            "(is Drive mounted? set START_EPOCH = 0 to train from scratch)")
    checkpoint = torch.load(resume_path)
    net.load_state_dict(checkpoint)

BATCH_SIZE = 32

# MyAudioDataset streams the examples itself (IterableDataset), so the loader only groups them
# into batches of BATCH_SIZE; everything else uses DataLoader defaults (default collate,
# num_workers=0).
train_dataset = MyAudioDataset(TARS, CLEAN)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)

In [0]:
def train(lr=1e-3, gamma=0.95, log_every=30, samples_per_tar=214):
    """Train `net` on the wSDR loss with Adam and a per-epoch exponential LR decay.

    Runs epochs START_EPOCH .. START_EPOCH + N_EPOCHS - 1 over `train_data_loader` and saves
    `net.state_dict()` after every epoch to '/content/drive/My Drive/{NET_NAME}{epoch}.pth.tar'.

    Args:
        lr: learning rate of the schedule at epoch 0. The optimizer starts at
            lr * gamma ** START_EPOCH, so a resumed run continues the decay instead of jumping
            back to `lr` (assuming the checkpointed run followed this schedule from epoch 0).
        gamma: per-epoch multiplicative LR decay (ExponentialLR).
        log_every: print the epoch and loss every this many batches.
        samples_per_tar: assumed number of examples per tar; only used to size the progress bar.
    """
    from tqdm.auto import tqdm as progress_bar

    torch.set_printoptions(precision=10, profile="full")

    # Adam's moment estimates are not checkpointed, so they restart from zero when resuming.
    optimizer = optim.Adam(net.parameters(), lr=lr * gamma ** START_EPOCH)
    scheduler = ExponentialLR(optimizer, gamma)

    steps_per_epoch = int(len(TARS) * samples_per_tar / BATCH_SIZE)
    for epoch in range(START_EPOCH, START_EPOCH + N_EPOCHS):
        net.train()  # the model may have been left in another mode by earlier cells
        train_bar = progress_bar(train_data_loader, total=steps_per_epoch)
        for step, (train_mixed_cpu, train_clean_cpu) in enumerate(train_bar, start=1):
            train_mixed = train_mixed_cpu.cuda()
            train_clean = train_clean_cpu.cuda()

            # Complex spectrogram of the input, split into real/imag channels for the network.
            mixed_spec = stft(train_mixed).unsqueeze(dim=1)
            mixed_real, mixed_imag = mixed_spec[..., 0], mixed_spec[..., 1]
            out_real, out_imag = net(mixed_real, mixed_imag)

            # Back to the time domain: wSDRLoss compares waveforms.
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            loss = wSDRLoss(train_mixed, train_clean, out_audio)

            if step % log_every == 0:
                print(epoch, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(net.state_dict(), f'/content/drive/My Drive/{NET_NAME}{epoch}.pth.tar')

In [0]:
def eval_samples(mixed):
    """Run `net` on a batch of waveforms and return the estimated clean audio.

    Args:
        mixed: 2-D float tensor (batch, samples) on any device; it is moved to CUDA here.

    Returns:
        CUDA tensor with the network output converted back to `mixed.size(1)` samples per row
        (istft output squeezed on dim 1). No autograd graph is attached.
    """
    # Inference only: without no_grad every call records an autograd graph through the whole
    # network and keeps its activations alive for as long as the output is referenced.
    with torch.no_grad():
        mixed_spec = stft(mixed.cuda()).unsqueeze(dim=1)
        mixed_real, mixed_imag = mixed_spec[..., 0], mixed_spec[..., 1]
        out_real, out_imag = net(mixed_real, mixed_imag)
        out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
        out_audio = istft(out_real, out_imag, mixed.size(1))
        return torch.squeeze(out_audio, dim=1)

def eval_(max_skip_batches=5):
    """Denoise one random training example and write it out for listening.

    Writes 'mixed.wav' (network input), 'clean.wav' (target) and 'out.wav' (network output)
    to the working directory at 16 kHz.

    Args:
        max_skip_batches: a random number of batches in [0, max_skip_batches] is skipped on a
            fresh loader iterator before the example is drawn from the next batch.
    """
    loader_it = iter(train_data_loader)
    # The loader yields whole batches, so this skips batches, not examples; every skipped batch
    # still costs tar reads, hence the small bound.
    for _ in range(random.randint(0, max_skip_batches)):
        next(loader_it)
    mixed_batch, clean_batch = next(loader_it)

    idx = random.randrange(mixed_batch.shape[0])
    mixed = mixed_batch[idx:idx + 1]
    out_audio = eval_samples(mixed)
    librosa.output.write_wav('mixed.wav', mixed[0].numpy(), 16000)
    librosa.output.write_wav('clean.wav', clean_batch[idx].numpy(), 16000)
    librosa.output.write_wav('out.wav', out_audio[0].detach().cpu().numpy(), 16000)

#eval_()

In [0]:
# Group the clean chunk paths (einschlafen2/<episode>/<file>.wav) by their episode directory in a
# single pass. Matching on the exact directory name matters: a substring test would let episode
# "x1" also collect the chunks of "x10".
clean_by_episode = {}
for path in CLEAN:
    clean_by_episode.setdefault(path.split('/')[1], []).append(path)
episodes = set(clean_by_episode)

# Split each episode into runs of consecutive chunks. Using the '-'-separated fields of the file
# name, a chunk continues the previous run when field [-3] matches and its field [-2] equals the
# previous chunk's field [-1] (minus ".wav"), i.e. it starts where the previous one ended.
# Episodes are visited in sorted order so `conseq` (and e.g. `conseq[1]`) is the same every run.
conseq = []
for episode in sorted(clean_by_episode):
    chunks = sorted(clean_by_episode[episode],
                    key=lambda c: (int(c.split('-')[-3]), int(c.split('-')[-2])))
    conseq.append([chunks[0]])
    for chunk in chunks[1:]:
        prev_parts = conseq[-1][-1].split('-')
        cur_parts = chunk.split('-')
        prev_rir, prev_end = prev_parts[-3], prev_parts[-1][:-4]
        cur_rir, cur_start = cur_parts[-3], cur_parts[-2]
        if prev_rir == cur_rir and prev_end == cur_start:
            conseq[-1].append(chunk)
        else:
            conseq.append([chunk])

# Keep only runs of at least three consecutive chunks.
conseq = [c for c in conseq if len(c) > 2]

In [0]:
import scipy.signal

def test_sample(clean):
    """Reverberate `clean` with a random AIR impulse response and run `net` on the result.

    Args:
        clean: 1-D 16 kHz waveform (numpy array). After the 16000-sample trims below, at least
            SAMPLE_LEN * 16000 samples of the reverberant signal must remain.

    Returns:
        (clean_trimmed, mixed, out):
          clean_trimmed: `clean` without its first and last 16000 samples;
          mixed: full convolution of `clean` with the impulse response, with its first and last
                 16000 samples dropped (so it is longer than clean_trimmed);
          out: network output for the first nsplit * SAMPLE_LEN * 16000 samples of `mixed`,
               processed as nsplit independent SAMPLE_LEN-second segments.

    Raises:
        ValueError: if `mixed` is shorter than one segment.
    """
    segment_len = SAMPLE_LEN * 16000
    # Load the impulse response at 16 kHz so it matches the speech (a no-op for 16 kHz files);
    # keep its first 16000 samples.
    rir = librosa.core.load(random.choice(glob.glob("AIR_1_4/*wav")), sr=16000)[0][:16000]
    mixed = scipy.signal.convolve(clean, rir)[16000:-16000]

    nsplit = len(mixed) // segment_len
    if nsplit == 0:
        raise ValueError(
            f"input too short: need at least {segment_len} samples after trimming, got {len(mixed)}")
    audiolen = nsplit * segment_len
    segments = torch.from_numpy(mixed[:audiolen].reshape((nsplit, -1))).type(torch.FloatTensor)
    out = eval_samples(segments).reshape(audiolen).detach().cpu().numpy()
    return clean[16000:-16000], mixed, out

In [0]:
conseq_sample_clean, conseq_sample_mixed, conseq_sample_out = test_sample(
    np.concatenate([librosa.core.load(wav,sr=None)[0] for wav in conseq[1]])
)

In [0]:
librosa.output.write_wav("conseq-clean.wav", conseq_sample_clean*0.1,sr=16000)
librosa.output.write_wav("conseq-mixed.wav", conseq_sample_mixed*0.1,sr=16000)
librosa.output.write_wav("conseq-out.wav", conseq_sample_out*0.1,sr=16000)

In [0]:
Id.Audio(conseq_sample_mixed,rate=16000)

In [0]:
Id.Audio(conseq_sample_out,rate=16000)

In [0]:
tagesschau_clean, tagesschau_mixed, tagesschau_out = test_sample(
    librosa.core.load("tagesschau---orig.wav",sr=16000, duration=20)[0]
)

In [0]:
Id.Audio(tagesschau_mixed,rate=16000)

In [0]:
Id.Audio(tagesschau_out,rate=16000)

In [0]:
librosa.output.write_wav("tagesschau-clean.wav", tagesschau_clean*0.1,sr=16000)
librosa.output.write_wav("tagesschau-mixed.wav", tagesschau_mixed*0.1,sr=16000)
librosa.output.write_wav("tagesschau-out.wav", tagesschau_out*0.1,sr=16000)

In [0]:
import IPython.display as Id

# Play back the files written by eval_() to the working directory (clean.wav, mixed.wav, out.wav).
# They only exist if eval_() has been run; its call is commented out above.
# NOTE(review): `rate` only matters for raw array data — for a file path it is presumably ignored.
Id.Audio("clean.wav", rate=16000)

In [0]:
Id.Audio("mixed.wav", rate=16000)

In [0]:
Id.Audio("out.wav", rate=16000)

In [0]:
import gc

# Collect unreachable Python objects first so the tensors they held are returned to PyTorch's
# caching allocator; only then release the allocator's unused cached blocks.
gc.collect()
torch.cuda.empty_cache()

def pretty_size(size):
    """Format a torch.Size as its dimensions joined by ' × ' (torch.Size([2, 3]) -> '2 × 3').

    Raises AssertionError if `size` is not a torch.Size.
    """
    assert isinstance(size, torch.Size)
    return " × ".join(str(dim) for dim in size)

def dump_tensors(gpu_only=True):
    """Print the tensors tracked by the garbage collector and their total element count.

    Walks gc.get_objects(). Every torch tensor, and every other object whose `.data` is a
    tensor, is printed with its type, flags and shape.

    Args:
        gpu_only: if True, only CUDA tensors are listed and counted.

    Returns:
        The total number of elements (not bytes) of the listed tensors; it is also printed.
    """
    import gc
    total_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not gpu_only or obj.is_cuda:
                    print("%s:%s%s %s" % (type(obj).__name__,
                                          " GPU" if obj.is_cuda else "",
                                          " pinned" if obj.is_pinned() else "",
                                          pretty_size(obj.size())))
                    total_size += obj.numel()
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                # The wrapping object need not be a tensor itself, so query flags on `.data`.
                # (The old `volatile` flag no longer carries information since PyTorch 0.4.)
                tensor = obj.data
                if not gpu_only or tensor.is_cuda:
                    print("%s → %s:%s%s%s %s" % (type(obj).__name__,
                                                 type(tensor).__name__,
                                                 " GPU" if tensor.is_cuda else "",
                                                 " pinned" if tensor.is_pinned() else "",
                                                 " grad" if getattr(obj, "requires_grad", False) else "",
                                                 pretty_size(tensor.size())))
                    total_size += tensor.numel()
        except Exception:
            # Best effort: gc can return exotic or half-initialised objects whose attribute
            # access raises; skip them instead of aborting the whole dump.
            pass
    print("Total size:", total_size)
    return total_size
 
# List the CUDA tensors (default gpu_only=True) still tracked by the garbage collector.
dump_tensors()