### Imports

In [18]:
import argparse
import copy
import os
import logging
import secrets
import numpy
import copy
from datetime import timedelta

from ipywidgets import IntProgress
from IPython.display import display
from IPython.display import Audio

import time

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from complexPyTorch.complexFunctions import complex_relu

import auraloss

# Audio
import torchaudio
from torio.io import CodecConfig

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

# Show the installed torch and torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")


2.3.0+cu121
2.3.0+cu121


### Configuration

In [2]:
# Dataset and model locations.
# Each path can be overridden through an environment variable so the notebook
# is not tied to a single machine; when the variable is unset the original
# hard-coded location is used.
base_dataset_directory = os.environ.get(
    'DENOISE_BASE_DATASET_DIR', '/home/jacob/cv-corpus-17.0-2024-03-15/en')
noisy_dataset_directory = os.environ.get(
    'DENOISE_NOISY_DATASET_DIR', '/home/jacob/noisy-commonvoice/en')
# Directory that trained models are written to.
models_dir = os.environ.get(
    'DENOISE_MODELS_DIR', '/home/jacob/denoise-models')

### Load datasets

In [3]:
# Build the clean dataset first, then the noisy one, from their root directories.
common_voice_dataset, common_voice_noisy_dataset = (
    torchaudio.datasets.COMMONVOICE(root=dataset_root)
    for dataset_root in (base_dataset_directory, noisy_dataset_directory)
)

### Wrap datasets in loaders and create train / test splits

The noisy and clean datasets are split with identically seeded generators so that corresponding files end up in the same split and stay paired.

In [4]:
device="cuda"

# The two splits below are paired purely by using the same seed, which only
# yields the same permutation when both datasets have the same length.
# Fail loudly here instead of silently training on mismatched noisy/clean pairs.
if len(common_voice_dataset) != len(common_voice_noisy_dataset):
    raise ValueError(
        "clean and noisy datasets differ in length "
        f"({len(common_voice_dataset)} vs {len(common_voice_noisy_dataset)}); "
        "seeded splits would not line up"
    )

clear_loader = DataLoader(
    common_voice_dataset,
    batch_size=1)

noisy_loader = DataLoader(
    common_voice_noisy_dataset,
    batch_size=1)

# Shared split parameters: one seed and one set of fractions for both datasets.
split_seed = 314
split_fractions = [0.9, 0.1]

split_generator_0 = torch.Generator().manual_seed(split_seed)
noisy_train, noisy_test = random_split(noisy_loader.dataset, split_fractions, generator=split_generator_0)

split_generator_1 = torch.Generator().manual_seed(split_seed)
clear_train, clear_test = random_split(clear_loader.dataset, split_fractions, generator=split_generator_1)

# noisy_1 = next(iter(noisy_train))
# clear_1 = next(iter(clear_train))

# Audio(noisy_1[0].squeeze(), rate=48000)




In [5]:
# Audio(clear_1[0].squeeze(), rate=48000)

In [26]:
### Create a model

# Audio sample rate in Hz that the model is built for.
sample_rate = 48000

# Length of one model input window, in milliseconds.
sample_batch_ms = 200

# Number of samples in one window. Integer arithmetic is used on purpose:
# going through a float (ms / 1000) can land just below the exact value and
# then truncate one sample short for some rate / duration combinations.
samples_per_batch = (sample_rate * sample_batch_ms) // 1000

print(samples_per_batch)


class ComplexRelu(nn.Module):
    """Layer form of ``complex_relu`` so it can be placed inside a module stack."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``complex_relu`` applied to ``x``."""
        return complex_relu(x)
        

with torch.cuda.device(0):
    torch.cuda.empty_cache()
    dtype = torch.float32

    # Layer widths: the full window, then 80 %, 60 % and 40 % of it.
    width_full = samples_per_batch
    width_80 = int(samples_per_batch * 0.8)
    width_60 = int(samples_per_batch * 0.6)
    width_40 = int(samples_per_batch * 0.4)

    # (in_features, out_features) of the linear layers before the ReLU ...
    narrowing_shapes = [
        (width_full, width_full),
        (width_full, width_80),
        (width_80, width_60),
        (width_60, width_40),
    ]
    # ... and of the linear layers after it.
    widening_shapes = [
        (width_40, width_60),
        (width_60, width_80),
        (width_80, width_full),
        (width_full, width_full),
    ]

    # Layers are created in the same order as they appear in the network.
    layers = [
        nn.Linear(in_features=n_in, out_features=n_out, device=device, dtype=dtype)
        for n_in, n_out in narrowing_shapes
    ]
    # The single non-linearity; ComplexRelu() is the alternative tried here.
    layers.append(nn.ReLU())
    layers.extend(
        nn.Linear(in_features=n_in, out_features=n_out, device=device, dtype=dtype)
        for n_in, n_out in widening_shapes
    )

    sequence_model = nn.Sequential(*layers)

    # Alternatives that were tried: nn.L1Loss(), auraloss.freq.SumAndDifferenceSTFTLoss()
    loss_fn = auraloss.time.SNRLoss()

    optimizer = optim.SGD(sequence_model.parameters(), lr=0.001)

    print(sequence_model)
    


9600
Sequential(
  (0): Linear(in_features=9600, out_features=9600, bias=True)
  (1): Linear(in_features=9600, out_features=7680, bias=True)
  (2): Linear(in_features=7680, out_features=5760, bias=True)
  (3): Linear(in_features=5760, out_features=3840, bias=True)
  (4): ReLU()
  (5): Linear(in_features=3840, out_features=5760, bias=True)
  (6): Linear(in_features=5760, out_features=7680, bias=True)
  (7): Linear(in_features=7680, out_features=9600, bias=True)
  (8): Linear(in_features=9600, out_features=9600, bias=True)
)


In [27]:
%%time
### Train

def _pad_to_window(chunk):
    """Zero-pad a 1-D chunk on the right so it is exactly `samples_per_batch` long."""
    return nnF.pad(chunk, (0, samples_per_batch - chunk.size(0)))


files_processed = 0

# Wall-clock training budget: stop starting new files after 10 hours.
train_budget_seconds = 10 * 60 * 60

t0 = time.time()

with torch.cuda.device(0):
    # ---- Training -------------------------------------------------------
    # Each dataset item is indexed as item[0] = waveform, item[1] = sample rate.
    # zip() keeps the noisy and clean splits in lock-step and stops cleanly
    # when either one runs out.
    sequence_model.train()
    for noisy_complete, clear_complete in zip(noisy_train, clear_train):
        if time.time() - t0 >= train_budget_seconds:
            break

        # The model's window size is derived from `sample_rate`; skip
        # recordings at any other rate.
        if noisy_complete[1] != sample_rate:
            continue

        files_processed += 1

        noisy = noisy_complete[0].squeeze()
        clear = clear_complete[0].squeeze()

        # Walk both recordings window by window. zip() stops at the shorter
        # recording instead of raising IndexError if their lengths differ.
        chunk_pairs = zip(torch.split(noisy, samples_per_batch),
                          torch.split(clear, samples_per_batch))
        for noisy_batch, clear_batch in chunk_pairs:
            # The last window of a file is usually short; pad it to full size.
            noisy_batch = _pad_to_window(noisy_batch.to(device))
            clear_batch = _pad_to_window(clear_batch.to(device))

            prediction = sequence_model(noisy_batch)
            # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
            loss = loss_fn(prediction, clear_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if files_processed % 10 == 0:
            elapsed_str = str(timedelta(seconds=time.time() - t0))
            # `loss` is the loss of the last window of the current file.
            print(f"Loss: {loss}, elapsed: {elapsed_str}")

    sequence_model.eval()

    # ---- Listening check ------------------------------------------------
    # Use the first pair recorded at the expected sample rate.
    # NOTE(review): this draws from the *training* split; noisy_test /
    # clear_test are never used in this notebook. Confirm whether the preview
    # should come from the held-out split instead.
    noisy = None
    clear = None
    for noisy_complete, clear_complete in zip(noisy_train, clear_train):
        if noisy_complete[1] == sample_rate:
            noisy = noisy_complete[0].squeeze()
            clear = clear_complete[0].squeeze()
            break
    if noisy is None:
        raise RuntimeError(f"no recording with sample rate {sample_rate} found")

    # Inference only: no autograd graph is needed, and the window outputs are
    # concatenated once at the end rather than re-concatenated on every step.
    with torch.no_grad():
        predicted_chunks = [
            sequence_model(_pad_to_window(noisy_batch.to(device)))
            # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
            for noisy_batch in torch.split(noisy, samples_per_batch)
        ]
    prediction_reconstructed = torch.cat(predicted_chunks)

    # The prediction is a whole number of windows long, so it can be longer
    # than the input recording.
    print(noisy.size())
    print(prediction_reconstructed.size())
    print(clear.size())

    # Make sure the target directory exists so the trained model is not lost,
    # and build the file name without nesting quotes inside the f-string
    # (that form is a syntax error before Python 3.12).
    os.makedirs(models_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    torch.save(sequence_model, f"{models_dir}/model-{timestamp}")

    display(Audio(noisy.cpu().detach(), rate=sample_rate))
    display(Audio(prediction_reconstructed.cpu().detach(), rate=sample_rate))
    display(Audio(clear.cpu().detach(), rate=sample_rate))



    

Loss: 58.677940368652344, elapsed: 0:00:04.057745
Loss: 35.885902404785156, elapsed: 0:00:08.517181
Loss: 6.755973815917969, elapsed: 0:00:12.757986
Loss: 3.9846506118774414, elapsed: 0:00:17.002586
Loss: 12.402949333190918, elapsed: 0:00:21.685704
Loss: 0.8028571605682373, elapsed: 0:00:26.378056
Loss: 21.938875198364258, elapsed: 0:00:30.743963
Loss: 6.201815605163574, elapsed: 0:00:35.227129
Loss: 1.8485157489776611, elapsed: 0:00:39.369679
Loss: 46.22247314453125, elapsed: 0:00:43.910063
Loss: 11.239066123962402, elapsed: 0:00:48.246627
Loss: 14.225729942321777, elapsed: 0:00:52.763279
Loss: 1.998288631439209, elapsed: 0:00:56.949294
Loss: 14.207096099853516, elapsed: 0:01:01.950680
Loss: 80.0, elapsed: 0:01:06.549718
Loss: 0.0038690867368131876, elapsed: 0:01:11.079494
Loss: 33.606468200683594, elapsed: 0:01:15.645805
Loss: 15.61754322052002, elapsed: 0:01:19.959587
Loss: 21.5706729888916, elapsed: 0:01:24.093006
Loss: 19.389652252197266, elapsed: 0:01:29.288851
Loss: 54.426227569

CPU times: user 9h 52min 59s, sys: 2min 15s, total: 9h 55min 14s
Wall time: 10h 1s
