### Imports

In [1]:
import argparse
import copy
import os
import logging
import secrets
import numpy
import copy
import gc
import math
from datetime import timedelta

from ipywidgets import IntProgress
from IPython.display import display
from IPython.display import Audio

import time

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from complexPyTorch.complexFunctions import complex_relu

import auraloss

# Audio
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

from torio.io import CodecConfig

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

print(torch.__version__)
print(torchaudio.__version__)


2.3.0+cu121
2.3.0+cu121


### Configuration

In [2]:
base_dataset_directory = '/home/jacob/cv-corpus-17.0-2024-03-15/en'
noisy_dataset_directory = '/home/jacob/noisy-commonvoice/en'
models_dir = '/home/jacob/denoise-models'

### Load datasets

In [3]:
common_voice_dataset = torchaudio.datasets.COMMONVOICE(root=base_dataset_directory)
common_voice_noisy_dataset = torchaudio.datasets.COMMONVOICE(root=noisy_dataset_directory)

In [4]:
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    

### Load datasets and create train / test splits. The same seed is used for splitting noisy and clear datasets so the files match up.

In [5]:
device="cuda"

clear_loader = DataLoader(
    common_voice_dataset,
    batch_size=1)

noisy_loader = DataLoader(
    common_voice_noisy_dataset,
    batch_size=1)

split_generator_0 = torch.Generator().manual_seed(314)
noisy_train, noisy_test = random_split(noisy_loader.dataset, [0.9, 0.1], generator=split_generator_0)

split_generator_1 = torch.Generator().manual_seed(314)
clear_train, clear_test = random_split(clear_loader.dataset, [0.9, 0.1], generator=split_generator_1)

# noisy_1 = next(iter(noisy_train))
# clear_1 = next(iter(clear_train))

# Audio(noisy_1[0].squeeze(), rate=48000)




In [5]:
# Audio(clear_1[0].squeeze(), rate=48000)

In [49]:
### Create a model

sample_rate = 48000

sample_batch_ms = 50
hidden_size_ms = 200

samples_per_batch = int((sample_batch_ms / 1000) * sample_rate)
samples_per_hidden = int((hidden_size_ms / 1000) * sample_rate)

gc.collect()
torch.cuda.empty_cache()

class ComplexRelu(nn.Module):
    def __init__(self):
        super(ComplexRelu, self).__init__()
             
    def forward(self, x):
        x = complex_relu(x)
        return x


class DenoisingModel(nn.Module):
    def __init__(self, device, dtype):
        super(DenoisingModel, self).__init__()
        
        n_fft = 400
        self.spectrogram = T.Spectrogram(n_fft=n_fft, power=2, wkwargs={'device': device})
        self.griffin = T.GriffinLim(n_fft=n_fft, wkwargs={'device': device})

        # spectrogram_size = n_fft / 2 + 1
        
        # self.inverseSpectrogram = T.InverseSpectrogram(wkwargs = {'device': device})
        # self.conv1_out_channels = 16

        # self.c1 = nn.Conv2d(in_channels=1, out_channels=self.conv1_out_channels, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype)

        # self.c1 = nn.Conv2d(in_channels=1, out_channels=self.conv1_out_channels, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype)
        # self.c2 = nn.Conv2d(in_channels=self.conv1_out_channels, out_channels=32, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype)
        # self.c3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=3, padding=1, device=device, dtype=dtype)
        # self.c4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=3, padding=1, device=device, dtype=dtype)

        # self.pool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.dropout = nn.Dropout(p=0.3)

        print(f"samples_per_batch: {samples_per_batch}, samples_per_hidden: {samples_per_hidden}")

        self.lstm = nn.LSTM(input_size=201, hidden_size=201 * 5, num_layers=3, dropout=0.3, device=device, dtype=dtype)

        # dense_size = 3000
        
        self.fc1 = nn.Linear(in_features = 5 * 2613, out_features = 6000, device=device, dtype=dtype)
        self.fc2 = nn.Linear(in_features = 6000, out_features = 5000, device=device, dtype=dtype)
        self.fc3 = nn.Linear(in_features = 5000, out_features = 4000, device=device, dtype=dtype)
        self.fc4 = nn.Linear(in_features = 4000, out_features = 201 * 13, device=device, dtype=dtype)
        
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        
        # self.normy = nn.BatchNorm1d(65, device=device)

        # self.fc1 = nn.Linear(in_features = 2048, out_features = 6000, device=device, dtype=dtype)
        # self.fc2 = nn.Linear(in_features = 6000, out_features = 8000, device=device, dtype=dtype)
        # self.fc3 = nn.Linear(in_features = 8000, out_features = 151 * 65, device=device, dtype=dtype)
        # self.fc4 = nn.Linear(in_features = 151 * 65, out_features = 151 * 65, device=device, dtype=dtype)

        
        # self.transformer = nn.Transformer(
        #     d_model=input_dim,
        #     nhead=nhead,
        #     num_encoder_layers=num_encoder_layers,
        #     num_decoder_layers=num_decoder_layers,
        #     dim_feedforward=dim_feedforward,
        #     batch_first=True,
        #     device=device,
        #     dtype=dtype,
        # )
        
        

    def forward(self, x):
        with torch.enable_grad():
            x.requires_grad = True
            x = x.to(device)

            # print(f"000 {x}")

            # print(f"1 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")        
            x = self.spectrogram(x).unsqueeze(0)
            x = x.permute(2, 0, 1)

            # print(f"x.max(): {x.max()}")            
            # print(f"100 {x}")
            # print(f"X MAX: {x.max()}")
            # x = x / 2048.0
            # print(f"X MAX after: {x.max()}")

            # print(f"2 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            # x = x.unsqueeze(0)
            # print(f"3 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            
            # x = self.c1(x)
            # print(f"4 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            # 
            # x = torch.relu(x)
            # 
            # print(f"5 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            
            # print(f"6 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            # x = self.c2(x)
            # x = self.dropout(x)

            # x = self.c3(x)
            # x = self.dropout(x)
            
            # x = self.c4(x)
            
            x, _ = self.lstm(x)
            
            x = self.dropout(x)

            # x = self.pool(x)

            # print(f"200 {x}")

            # print(f"7 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            
            x = x.view(-1, 201 * 65)

            # print(x)
            
            # print(f"8 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            x = self.fc1(x)
            x = self.relu(x)
            x = self.dropout(x)

            # print(f"9 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            x = self.fc2(x)
            x = self.softmax(x)
            x = self.dropout(x)

            # print(f"10 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            x = self.fc3(x)
            x = self.dropout(x)

            # print(f"11 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            x = self.fc4(x)
            x = self.dropout(x)

            # x = self.dropout(x)

            # print(f"300 {x}")
            
            # print(f"12 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            x = x.view(201, 13)

            # x = self.normy(x)

            # print(f"400 {x}")

            # print(f"13 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")

            x = x - x.min() + 0.001
            x = x / 2.0
            # x = torch.clamp(x, min=1e-10)
            # x = x * 100.0

            # print(f"x.max(): {x.max()}")
            # print(f"x min: {x.min()}")
            # print(f"500 {x}") 
            # print(f"x min: {x.min()}")
            x = self.griffin(x)

            # print(f"600 {x}")

            # print(f"9 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")

        
            # x = self.inverseSpectrogram(x)
            # print(x)
            # print(f"9 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
            # x = x.to(device)
            # print(f"10 size of x: {x.size()} dtype: {x.dtype} has_nans: {torch.isnan(x).any()}")
        return x        

with torch.cuda.device(0):
    torch.cuda.empty_cache()
    dtype=torch.float32
    
    # sequence_model = nn.Sequential(
    #     nn.Linear(in_features=samples_per_batch, out_features=samples_per_batch, device=device, dtype=dtype),
    #     nn.Linear(in_features=samples_per_batch, out_features=int(samples_per_batch * 0.8), device=device, dtype=dtype),
    #     nn.Linear(in_features=int(samples_per_batch * 0.8), out_features=int(samples_per_batch * 0.6), device=device, dtype=dtype),
    #     nn.Linear(in_features=int(samples_per_batch * 0.6), out_features=int(samples_per_batch * 0.4), device=device, dtype=dtype),
    #     # ComplexRelu(),
    #     # nn.ReLU(),
    #     nn.Linear(in_features=int(samples_per_batch * 0.4), out_features=int(samples_per_batch * 0.6), device=device, dtype=dtype),    
    #     nn.Linear(in_features=int(samples_per_batch * 0.6), out_features=int(samples_per_batch * 0.8), device=device, dtype=dtype),    
    #     nn.Linear(in_features=int(samples_per_batch * 0.8), out_features=samples_per_batch, device=device, dtype=dtype),
    #     nn.Linear(in_features=samples_per_batch, out_features=samples_per_batch, device=device, dtype=dtype),        
    # )

    # sequence_model = nn.Sequential(
        # T.MuLawEncoding(quantization_channels=256),
        # DenoisingModel(input_dim=1, nhead=1, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, device=device, dtype=dtype),
        # T.MuLawDecoding(quantization_channels=256),
        
        # nn.Linear(in_features=samples_per_batch, out_features=samples_per_batch, device=device, dtype=dtype),
        # nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
        
        
        # nn.Linear(in_features=samples_per_batch, out_features=samples_per_batch, device=device, dtype=dtype),
        # nn.Transformer(nhead=16, num_encoder_layers=12, num_decoder_layers=12, dim_feedforward=samples_per_batch, batch_first=True, device=device, dtype=dtype),
        # nn.TransformerEncoderLayer(d_model=512, nhead=16, dim_feedforward=2048, dropout=0.1, device=device, dtype=dtype),
        # nn.Linear(in_features=samples_per_batch, out_features=samples_per_batch, device=device, dtype=dtype),        
    # )
    
    sequence_model = DenoisingModel(device=device, dtype=dtype)
    
    loss_fn = nn.L1Loss()
    # loss_fn = auraloss.time.SNRLoss()
    # loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss()

    optimizer = torch.optim.SGD(params=sequence_model.parameters(), lr=0.01)

    print(sequence_model)
    


samples_per_batch: 2400, samples_per_hidden: 9600
DenoisingModel(
  (spectrogram): Spectrogram()
  (griffin): GriffinLim()
  (dropout): Dropout(p=0.3, inplace=False)
  (lstm): LSTM(201, 1005, num_layers=3, dropout=0.3)
  (fc1): Linear(in_features=13065, out_features=6000, bias=True)
  (fc2): Linear(in_features=6000, out_features=5000, bias=True)
  (fc3): Linear(in_features=5000, out_features=4000, bias=True)
  (fc4): Linear(in_features=4000, out_features=2613, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
)


In [50]:
%%time
### Train

files_processed = 0

t0 = time.time()

resample_rate = 48000

with torch.cuda.device(0):
    noisy_iter = iter(noisy_train)
    clear_iter = iter(clear_train)
    while time.time() - t0 < 60 * 5:
        noisy_complete = next(noisy_iter, None)
        if noisy_complete is None:
            break
        
        noisy = noisy_complete[0].squeeze()
        clear = next(clear_iter)[0].squeeze()

        resampler = T.Resample(noisy_complete[1], resample_rate, dtype=torch.float32)
        noisy = resampler(noisy)
        clear = resampler(clear)

        files_processed += 1
        
        noisy_split = torch.split(noisy, samples_per_batch)
        clear_split = torch.split(clear, samples_per_batch)

        loss_sum = 0
        for split_idx, noisy_batch in enumerate(noisy_split):
            # time1 = time.perf_counter()

            noisy_batch = noisy_batch.to(device)
            
            noisy_pad = nn.ZeroPad1d((0, samples_per_batch - noisy_batch.size()[0]))
            noisy_batch = noisy_pad(noisy_batch)
            
            clear_batch = clear_split[split_idx].to(device)
            clear_pad = nn.ZeroPad1d((0, samples_per_batch - clear_batch.size()[0]))
            clear_batch = clear_pad(clear_batch)

            # time2 = time.perf_counter()
            
            sequence_model.train()

            # time3 = time.perf_counter()
            
            prediction = sequence_model(noisy_batch)

            # time4 = time.perf_counter()
            
            # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
            loss = loss_fn(prediction, clear_batch)

            if math.isnan(loss):
                nan_in_prediction = "Yes" if torch.isnan(prediction).any() else "No"
                print(f"ERROR: NaN loss. NaN in prediction? {nan_in_prediction}")
                raise KeyboardInterrupt

            # time5 = time.perf_counter()
            
            optimizer.zero_grad()

            # time6 = time.perf_counter()
            
            loss.backward()
            
            # time7 = time.perf_counter()
            # for name, param in sequence_model.named_parameters():
            #     if param.grad is not None:
            #         grad_mean = param.grad.mean().item()
            #         grad_std = param.grad.std().item()
            #         print(f"{name}: grad_mean = {grad_mean}, grad_std = {grad_std}")
            #     else:
            #         print(f"No gradient for {name}")
            
            optimizer.step()

            # time8 = time.perf_counter()
            
            sequence_model.eval()

            # time9 = time.perf_counter()

            loss_sum += loss
        
        if files_processed % 25 == 0:
            elapsed_str = str(timedelta(seconds=time.time() - t0))
            print(f"Loss: {(loss_sum / 25.0):.12f}\t elapsed: {elapsed_str}\tfiles_processed: {files_processed}")
            loss_sum = 0
            # 1-2: {(time2 - time1):.5f} 2-3: {(time3 - time2):.5f} 3-4: {(time4 - time3):.5f} 4-5: {(time5 - time4):.5f} 5-6: {(time6 - time5):.5f} 6-7: {(time7 - time6):.5f} 7-8: {(time8 - time7):.5f} 8-9: {(time9 - time8):.5f}

    noisy_iter = iter(noisy_train)
    clear_iter = iter(clear_train)

    keep_going = True
    while keep_going:
        noisy_complete = next(noisy_iter)
        noisy = noisy_complete[0].squeeze()
        clear = next(clear_iter)[0].squeeze()
        if noisy_complete[1] == 48000:
            keep_going = False
        else:
            keep_going = True
        
    noisy_split = torch.split(noisy, samples_per_batch)
    clear_split = torch.split(clear, samples_per_batch)

    prediction_reconstructed = None
    
    for split_idx, noisy_batch in enumerate(noisy_split):
        noisy_batch = noisy_batch.to(device)
        
        noisy_pad = nn.ZeroPad1d((0, samples_per_batch - noisy_batch.size()[0]))
        noisy_batch = noisy_pad(noisy_batch)
        
        clear_batch = clear_split[split_idx].to(device)
        clear_pad = nn.ZeroPad1d((0, samples_per_batch - clear_batch.size()[0]))
        clear_batch = clear_pad(clear_batch)

        # prediction = torch.fft.ifft(sequence_model(torch.fft.fft(noisy_batch)))
        prediction = sequence_model(noisy_batch)
        if prediction_reconstructed is not None:
            prediction_reconstructed = torch.cat((prediction_reconstructed, prediction))
        else:
            prediction_reconstructed = prediction

    print(noisy.size())
    print(prediction_reconstructed.size())
    print(clear.size())

    print(prediction_reconstructed)

    torch.save(sequence_model, models_dir + f"/model-{time.strftime("%Y%m%d-%H%M%S")}")

    display(Audio(noisy.cpu().detach(), rate=48000))
    display(Audio(prediction_reconstructed.cpu().detach(), rate=48000))
    display(Audio(clear.cpu().detach(), rate=48000))



    

Loss: 0.606571555138	 elapsed: 0:01:00.374659	files_processed: 25
Loss: 0.274048954248	 elapsed: 0:02:01.792607	files_processed: 50
Loss: 0.406927019358	 elapsed: 0:03:06.667241	files_processed: 75
Loss: 0.141191363335	 elapsed: 0:04:11.026467	files_processed: 100
torch.Size([206208])
torch.Size([206400])
torch.Size([206208])
tensor([-0.0067, -0.0124, -0.0142,  ...,  0.0427,  0.0518,  0.0810],
       device='cuda:0', grad_fn=<CatBackward0>)


CPU times: user 5min 11s, sys: 1.17 s, total: 5min 12s
Wall time: 5min 3s
