In [8]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as T
import torchaudio.functional as F
from IPython.display import display, Audio
import librosa
import numpy as np
import pandas as pd
import random
import os
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

In [23]:
aud_dir = '../../LibriVox_Kaggle/'
bg_dir = '../../LibriVox_Kaggle/BGnoise/'
rir_dir = '../../RIR/MIT_IR_Survey/Audio/'
train_csv_file = 'only_audioFname_train.csv'
test_csv_file = 'only_audioFname_test.csv'

bg_files = os.listdir(bg_dir)
rir_files = os.listdir(rir_dir)[1::]

SAMPLE_RATE = 16000

N_FFT = 1024
WIN_LEN = 1024
HOP_LEN = 256

spectrogram = T.Spectrogram(n_fft=N_FFT, 
                            win_length=WIN_LEN, 
                            hop_length=HOP_LEN, 
                            center=True, 
                            pad_mode="reflect", 
                            power = 2.0)

In [12]:
def resample_audio(audio, sr):
    resampled_audio = F.resample(audio, sr, SAMPLE_RATE)
    return resampled_audio

def stereo_to_mono(audio):
    new_audio = torch.mean(audio, dim=0).unsqueeze(0)
    return new_audio


def load_audio(aud_fname):
    
    raw_wav, sampleRate = torchaudio.load(aud_fname)
    if raw_wav.shape[0] == 2:
        raw_wav = stereo_to_mono(raw_wav)
    if sampleRate != SAMPLE_RATE:
        raw_wav = resample_audio(raw_wav, sampleRate)
    return raw_wav


def add_noise(audio, rir, noise_wav, snr):
    echo_audio = F.fftconvolve(audio, rir)[:,0:audio.shape[1]]
    noisy_audio = F.add_noise(echo_audio, noise_wav[:,0:audio.shape[1]], torch.Tensor([snr]))
    return noisy_audio

def random_second_choice(audio):
    duration = (int)(audio.shape[1]/SAMPLE_RATE)
    random_sec = random.choice([i for i in range(0, duration-1)])
    return random_sec

In [13]:
def get_data(filename):

    rir_fname = os.path.join(rir_dir,random.choice(rir_files))
    bg_fname = os.path.join(bg_dir, random.choice(bg_files))
    snr_choice = random.choice([5,10,20])

    wav = load_audio(filename)
    rir_ = load_audio(rir_fname)
    bg = load_audio(bg_fname)

    rand_wav_sec = random_second_choice(wav)
    rand_bg_sec = random_second_choice(bg)

    wav_sec = wav[:,rand_wav_sec*SAMPLE_RATE:(rand_wav_sec+1)*SAMPLE_RATE]
    bg_sec = bg[:,rand_bg_sec*SAMPLE_RATE:(rand_bg_sec+1)*SAMPLE_RATE]

    noisy_audio = add_noise(wav_sec, rir_, bg_sec, snr_choice)

    return noisy_audio, bg_sec


In [44]:
spec_inst = T.Spectrogram(
    n_fft=400,
    win_length=None,
    hop_length=100,
    power=None
)

inv_spec_inst = T.InverseSpectrogram(
    n_fft=400,
    win_length=None,
    hop_length=100
)

def get_spectrogram(audio):
    spec = T.Spectrogram(power=None)(audio)
    return spec

def get_audio_from_spectrogram(spec):
    audio = T.InverseSpectrogram()(spec)
    return audio

In [69]:
def get_spec_image(audio):
    
    spec = spectrogram(audio.squeeze()).numpy()
    fig, axs = plt.subplots()
    plt.figure(figsize=(10,4))
    img = axs.imshow(librosa.power_to_db(spec), interpolation="nearest", origin="lower", aspect="auto", cmap="viridis")
    axs.axis('off')
    fig.canvas.draw()

    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    pil_image = Image.fromarray(data)
    plt.close(fig)
    return pil_image

In [86]:
sample_aud = '../../LibriVox_Kaggle/achtgesichterambiwasse/achtgesichterambiwasse_0009.wav'

noisy, noise = get_data(sample_aud)

noisy_spec = get_spectrogram(noisy)
noisy_aud = get_audio_from_spectrogram(noisy_spec)

dada = torch.view_as_real(noisy_spec)
reshpd = torch.reshape(dada, (1,2,201,81))

In [87]:
class audioDataset(Dataset):

    def __init__(self, audio_csvfile, aud_dir):
        self.audio_df = pd.read_csv(audio_csvfile)
        self.aud_dir = aud_dir

    def __len__(self):
        return len(self.audio_df)
    
    def __getitem__(self, index):
        audio_path = os.path.join(self.aud_dir, self.audio_df.iloc[index, 0])

        audio_in, label = get_data(audio_path)
        
        audio_spec = torch.view_as_real(get_spectrogram(audio_in))
        label_spec = torch.view_as_real(get_spectrogram(label))
        audio_spec = torch.reshape(audio_spec,(1,2,201,81))
        label_spec = torch.reshape(label_spec, (1,2,201,81))

        return audio_spec, label_spec

In [88]:
aud_dir = '../../LibriVox_Kaggle/'
train_csv_file = 'only_audioFname_train.csv'
test_csv_file = 'only_audioFname_test.csv'

train_dataset = audioDataset('only_audioFname_train.csv', aud_dir)
test_dataset = audioDataset('only_audioFname_test.csv', aud_dir)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [89]:
audio_spec, labs = next(iter(train_dataloader))
audio_spec.shape, labs.shape

(torch.Size([32, 1, 2, 201, 81]), torch.Size([32, 1, 2, 201, 81]))

In [91]:
labs.squeeze().shape

torch.Size([32, 2, 201, 81])

In [92]:
torch.manual_seed(13)
torch.cuda.manual_seed(13)

class speechRemoval00(nn.Module):
    
    def __init__(self):
        super(speechRemoval00, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU()
            )
        
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 2, kernel_size=3, padding=1),
            nn.ReLU()
            )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)

        return x


In [95]:
torch.manual_seed(13)
torch.cuda.manual_seed(13)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = speechRemoval00().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [96]:
epochs = 10

for epoch in range(0,epochs):

    loss_ten = torch.Tensor([])
    for data in train_dataloader:
        #print('pass')
        model.train()
        inputs, labels = data
        inputs = inputs.squeeze().to(device)
        labels = labels.squeeze().to(device)
        
        # Forward pass
        outputs = model(inputs)
        
        # Compute loss
        loss = loss_fn(outputs, inputs)

        # BP and optim
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ten = torch.cat((loss_ten,torch.Tensor([loss.item()])),0)
    
    print(f"Epoch [{epoch + 1}/{epochs}] Loss: {torch.mean(loss_ten)}")


IndexError: Cannot choose from an empty sequence