### Imports

In [20]:
import argparse
import copy
import os
import logging
import secrets
import numpy
import copy
import gc
import math
from datetime import timedelta

from ipywidgets import IntProgress
from IPython.display import display
from IPython.display import Audio

import time

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from complexPyTorch.complexFunctions import complex_relu

import auraloss

# Audio
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

from torio.io import CodecConfig

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms
import torchvision.models as TVM

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

print(torch.__version__)
print(torchaudio.__version__)


2.3.0+cu121
2.3.0+cu121


### Configuration

In [21]:
base_dataset_directory = '/home/jacob/cv-corpus-17.0-2024-03-15/en'
noisy_dataset_directory = '/home/jacob/noisy-commonvoice/en'
models_dir = '/home/jacob/denoise-models'

### Load datasets

In [22]:
common_voice_dataset = torchaudio.datasets.COMMONVOICE(root=base_dataset_directory)
common_voice_noisy_dataset = torchaudio.datasets.COMMONVOICE(root=noisy_dataset_directory)

In [23]:
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    

### Load datasets and create train / test splits. The same seed is used for splitting noisy and clear datasets so the files match up.

In [24]:
device="cuda"

clear_loader = DataLoader(
    common_voice_dataset,
    batch_size=1)

noisy_loader = DataLoader(
    common_voice_noisy_dataset,
    batch_size=1)

split_generator_0 = torch.Generator().manual_seed(314)
noisy_train, noisy_test = random_split(noisy_loader.dataset, [0.9, 0.1], generator=split_generator_0)

split_generator_1 = torch.Generator().manual_seed(314)
clear_train, clear_test = random_split(clear_loader.dataset, [0.9, 0.1], generator=split_generator_1)

# noisy_1 = next(iter(noisy_train))
# clear_1 = next(iter(clear_train))

# Audio(noisy_1[0].squeeze(), rate=48000)




In [5]:
# Audio(clear_1[0].squeeze(), rate=48000)

In [120]:
### Create a model

sample_rate = 48000

sample_batch_ms = 50
hidden_size_ms = 200

samples_per_batch = int((sample_batch_ms / 1000) * sample_rate)
samples_per_hidden = int((hidden_size_ms / 1000) * sample_rate)

gc.collect()
torch.cuda.empty_cache()

class ComplexRelu(nn.Module):
    def __init__(self):
        super(ComplexRelu, self).__init__()
             
    def forward(self, x):
        x = complex_relu(x)
        return x

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, device, dtype):
        super(ConvBlock, self).__init__()
        
        self.sequential = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
            nn.BatchNorm1d(num_features=out_channels, device=device, dtype=dtype),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
            nn.BatchNorm1d(num_features=out_channels, device=device, dtype=dtype),
            nn.ReLU(inplace=True),
        )
        
    def forward(self, x):
        return self.sequential(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels, device, dtype):
        super(Down, self).__init__()
        
        self.sequential = nn.Sequential(
            nn.MaxPool1d(kernel_size=2, stride=2),
            ConvBlock(in_channels=in_channels, out_channels=out_channels, device=device, dtype=dtype)
        )
        
    def forward(self, x):
        return self.sequential(x)
    
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, device, dtype):
        super(Up, self).__init__()
        
        self.transpose = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
        self.convBlock = ConvBlock(in_channels=in_channels, out_channels=out_channels, device=device, dtype=dtype)
        
    def forward(self, x1, x2):
        # print(f"x1 size: {x1.size()}")
        
        x1 = self.transpose(x1)
        
        diff = x2.size()[1] - x1.size()[1]  # Calculate difference correctly
        # print(f"Transposed x1: {x1.size()} x2: {x2.size()} diff: {diff}")
        
        # Pad x1 if necessary
        x1 = nnF.pad(x1, (diff // 2, diff - diff // 2))
        # print(f"Padded x1: {x1.size()}")
    
        # Concatenate along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        # print(f"Concatenated x: {x.size()}")
        
        x = self.convBlock(x)
        
        # print(f"ConvBlock output: \n{x.size()}")
        
        return x
    
class OutLayer(nn.Module):
    def __init__(self, in_channels, out_channels, device, dtype):
        super(OutLayer, self).__init__()
        
        self.sequential = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='linear', align_corners=True),
            ConvBlock(in_channels=in_channels, out_channels=out_channels, device=device, dtype=dtype),            
        )
        
    def forward(self, x):
        return self.sequential(x)

class UNet1d(nn.Module):
    def __init__(self, in_channels, device, dtype):
        super(UNet1d, self).__init__()
        
        layer_sizes = [64, 128, 256, 512, 1024]
        
        self.first_layer = ConvBlock(in_channels=in_channels, out_channels=layer_sizes[0], device=device, dtype=dtype)
        
        self.down_layers = [
            Down(in_channels=layer_sizes[i], out_channels=layer_sizes[i+1], device=device, dtype=dtype)
            for i in range(len(layer_sizes) - 1)
        ]
        
        self.down_layers_module_list = nn.ModuleList(self.down_layers)
        
        self.up_layers = [
            Up(in_channels=layer_sizes[-(i+1)], out_channels=layer_sizes[-(i+2)], device=device, dtype=dtype)
            for i in range(len(layer_sizes) - 2)
        ]
        
        self.up_layers_module_list = nn.ModuleList(self.up_layers)
        
        self.last_layer = OutLayer(in_channels=layer_sizes[1], out_channels=in_channels, device=device, dtype=dtype)

    
    def forward(self, x):
        x = x.unsqueeze(0).unsqueeze(0)
        
        # print(f"Unsqueezed input: \n{x}")        
        
        x = self.first_layer(x)
        
        # print(f"First layer output: \n{x}")
        
        down_outputs = []
        for down_layer in self.down_layers:
            x = down_layer(x)
            # print(f"Down layer output: \n{x.size()}")
            down_outputs.append(x)
            
        down_outputs_reversed = list(reversed(down_outputs))
        
        up_outputs = []
        for (i, up_layer) in enumerate(self.up_layers):
            x = up_layer(x, down_outputs_reversed[i + 1])
            # print(f"Up layer output: \n{x.size()}")
            up_outputs.append(x)
        
        x = self.last_layer(x)
        
        # print(f"Last layer output: \n{x.size()}")
        
        x = x.squeeze(0).squeeze(0)
        
        # print(f"Squeezed output: \n{x.size()}")

        return x


with torch.cuda.device(0):
    torch.cuda.empty_cache()
    dtype=torch.float32
    
    sequence_model = UNet1d(in_channels=1, device=device, dtype=dtype)
    
    loss_fn = nn.L1Loss()
    # loss_fn = auraloss.time.SNRLoss()
    # loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss()

    optimizer = torch.optim.SGD(params=sequence_model.parameters(), lr=0.01)

    print(sequence_model)
    


UNet1d(
  (first_layer): ConvBlock(
    (sequential): Sequential(
      (0): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down_layers_module_list): ModuleList(
    (0): Down(
      (sequential): Sequential(
        (0): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): ConvBlock(
          (sequential): Sequential(
            (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
            (4

In [None]:
%%time
### Train

files_processed = 0

t0 = time.time()

resample_rate = 48000

with torch.cuda.device(0):
    noisy_iter = iter(noisy_train)
    clear_iter = iter(clear_train)
    while time.time() - t0 < 60 * 5:
        noisy_complete = next(noisy_iter, None)
        if noisy_complete is None:
            break
        
        noisy = noisy_complete[0].squeeze()
        clear = next(clear_iter)[0].squeeze()

        resampler = T.Resample(noisy_complete[1], resample_rate, dtype=torch.float32)
        noisy = resampler(noisy)
        clear = resampler(clear)

        files_processed += 1
        
        noisy_split = torch.split(noisy, samples_per_batch)
        clear_split = torch.split(clear, samples_per_batch)

        loss_sum = 0
        for split_idx, noisy_batch in enumerate(noisy_split):
            # time1 = time.perf_counter()

            noisy_batch = noisy_batch.to(device)
            
            noisy_pad = nn.ZeroPad1d((0, samples_per_batch - noisy_batch.size()[0]))
            noisy_batch = noisy_pad(noisy_batch)
            
            clear_batch = clear_split[split_idx].to(device)
            clear_pad = nn.ZeroPad1d((0, samples_per_batch - clear_batch.size()[0]))
            clear_batch = clear_pad(clear_batch)

            # time2 = time.perf_counter()
            
            sequence_model.train()

            # time3 = time.perf_counter()
            
            prediction = sequence_model(noisy_batch)

            # time4 = time.perf_counter()
            
            # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
            loss = loss_fn(prediction, clear_batch)

            if math.isnan(loss):
                nan_in_prediction = "Yes" if torch.isnan(prediction).any() else "No"
                print(f"ERROR: NaN loss. NaN in prediction? {nan_in_prediction}")
                raise KeyboardInterrupt

            # time5 = time.perf_counter()
            
            optimizer.zero_grad()

            # time6 = time.perf_counter()
            
            loss.backward()
            
            # time7 = time.perf_counter()
            # for name, param in sequence_model.named_parameters():
            #     if param.grad is not None:
            #         grad_mean = param.grad.mean().item()
            #         grad_std = param.grad.std().item()
            #         print(f"{name}: grad_mean = {grad_mean}, grad_std = {grad_std}")
            #     else:
            #         print(f"No gradient for {name}")
            
            optimizer.step()

            # time8 = time.perf_counter()
            
            sequence_model.eval()

            # time9 = time.perf_counter()

            loss_sum += loss
        
        if files_processed % 25 == 0:
            elapsed_str = str(timedelta(seconds=time.time() - t0))
            print(f"Loss: {(loss_sum / 25.0):.12f}\t elapsed: {elapsed_str}\tfiles_processed: {files_processed}")
            loss_sum = 0
            # 1-2: {(time2 - time1):.5f} 2-3: {(time3 - time2):.5f} 3-4: {(time4 - time3):.5f} 4-5: {(time5 - time4):.5f} 5-6: {(time6 - time5):.5f} 6-7: {(time7 - time6):.5f} 7-8: {(time8 - time7):.5f} 8-9: {(time9 - time8):.5f}

    noisy_iter = iter(noisy_train)
    clear_iter = iter(clear_train)

    keep_going = True
    while keep_going:
        noisy_complete = next(noisy_iter)
        noisy = noisy_complete[0].squeeze()
        clear = next(clear_iter)[0].squeeze()
        if noisy_complete[1] == 48000:
            keep_going = False
        else:
            keep_going = True
        
    noisy_split = torch.split(noisy, samples_per_batch)
    clear_split = torch.split(clear, samples_per_batch)

    prediction_reconstructed = None
    
    for split_idx, noisy_batch in enumerate(noisy_split):
        noisy_batch = noisy_batch.to(device)
        
        noisy_pad = nn.ZeroPad1d((0, samples_per_batch - noisy_batch.size()[0]))
        noisy_batch = noisy_pad(noisy_batch)
        
        clear_batch = clear_split[split_idx].to(device)
        clear_pad = nn.ZeroPad1d((0, samples_per_batch - clear_batch.size()[0]))
        clear_batch = clear_pad(clear_batch)

        # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
        prediction = sequence_model(noisy_batch)
        if prediction_reconstructed is not None:
            prediction_reconstructed = torch.cat((prediction_reconstructed, prediction))
        else:
            prediction_reconstructed = prediction

    print(noisy.size())
    print(prediction_reconstructed.size())
    print(clear.size())

    print(prediction_reconstructed)

    torch.save(sequence_model, models_dir + f"/model-{time.strftime("%Y%m%d-%H%M%S")}")

    display(Audio(noisy.cpu().detach(), rate=48000))
    display(Audio(prediction_reconstructed.cpu().detach(), rate=48000))
    display(Audio(clear.cpu().detach(), rate=48000))



    

Loss: 0.378890365362	 elapsed: 0:00:14.178621	files_processed: 25
Loss: 0.222374126315	 elapsed: 0:00:26.683783	files_processed: 50
Loss: 0.349628299475	 elapsed: 0:00:40.644103	files_processed: 75
Loss: 0.073013946414	 elapsed: 0:00:53.380538	files_processed: 100
Loss: 0.333045899868	 elapsed: 0:01:06.072085	files_processed: 125
Loss: 0.273125410080	 elapsed: 0:01:18.135365	files_processed: 150
Loss: 0.274757921696	 elapsed: 0:01:30.210839	files_processed: 175
Loss: 0.293407440186	 elapsed: 0:01:42.848867	files_processed: 200
