In [18]:
import os
import sys
import torch
import torchaudio
from torchaudio import transforms
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from pydub import AudioSegment
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False

# Seed every RNG the notebook touches. The stdlib `random` module drives
# AudioDataset.random(), so it must be seeded too for reruns to pick the same clips.
seed = 69
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG of every device, CUDA included
# NOTE: PYTHONHASHSEED is only read when a Python interpreter starts, so setting
# it here does not change hashing in this kernel; it only reaches child Python
# processes launched afterwards.
os.environ['PYTHONHASHSEED'] = str(seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# derive the display name from the device actually selected above, so the two
# can never disagree
current_device = torch.cuda.get_device_name(device) if device.type == 'cuda' else 'cpu'
print(f'using device: {current_device}')

using device: GeForce GTX 1060 3GB


In [19]:
# Root directory holding the paired x/ (input) and y/ (target) clip folders.
input_path = './input/cleaned/'
# Output directory, nested under the input root.
output_path = input_path + 'output/'
# Audio sample rate in Hz.
sample_rate = 44_100

In [20]:
# Targets have the exact same file name as their inputs but live in y/, so both
# columns are built from the single x/ listing to keep every row paired. The
# listing is sorted because os.listdir order is filesystem-dependent.
file_names = sorted(os.listdir(input_path + 'x/'))
input_li = pd.Series([input_path + 'x/' + name for name in file_names], dtype=str)
target_li = pd.Series([input_path + 'y/' + name for name in file_names], dtype=str)
df = pd.DataFrame(data={'input': input_li, 'target': target_li})
# isna() can never fire on a frame built from path strings; the failure that
# actually matters is an input clip whose target file does not exist.
missing_targets = int((~df['target'].map(os.path.isfile)).sum())
print(f'missing target files: {missing_targets}')
df.head()

missing values: 0


Unnamed: 0,input,target
0,./input/cleaned/x/12 Comics You Need to See - ...,./input/cleaned/y/12 Comics You Need to See - ...
1,./input/cleaned/x/12 Comics You Need to See - ...,./input/cleaned/y/12 Comics You Need to See - ...
2,./input/cleaned/x/12 Comics You Need to See - ...,./input/cleaned/y/12 Comics You Need to See - ...
3,./input/cleaned/x/12 Comics You Need to See - ...,./input/cleaned/y/12 Comics You Need to See - ...
4,./input/cleaned/x/12 Comics You Need to See - ...,./input/cleaned/y/12 Comics You Need to See - ...


In [21]:
class AudioDataset(Dataset):
    """Paired (input, target) audio clips read from ``<root>x/`` and ``<root>y/``.

    Every file name found in ``x/`` is expected to exist under the same name in
    ``y/``; pairing is done by name rather than by listing ``y/`` separately.
    """

    def __init__(self, input_path, transform=None, random_state=None):
        """
        Args:
            input_path: root directory string ending in a path separator; it is
                concatenated directly with ``'x/'`` and ``'y/'``.
            transform: optional callable applied to both waveforms in
                ``__getitem__``.
            random_state: seed for the one-off row shuffle. ``None`` falls back to
                the notebook-level ``seed`` so existing calls behave as before.
        """
        self.input_path = input_path
        # Sort the listing: os.listdir order is filesystem-dependent, so without
        # this the seeded shuffle below would not be reproducible across machines.
        names = sorted(os.listdir(input_path + 'x/'))
        # Targets share the input's file name but live in y/, so both columns are
        # derived from the same x/ listing to guarantee row-wise pairing.
        df = pd.DataFrame(data={
            'input': pd.Series([input_path + 'x/' + n for n in names], dtype=str),
            'target': pd.Series([input_path + 'y/' + n for n in names], dtype=str),
        })
        if random_state is None:
            random_state = seed  # notebook-level seed from the setup cell
        self.df = df.sample(frac=1, random_state=random_state)
        self.transform = transform

    def __len__(self):
        """Number of (input, target) pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the ``idx``-th (input, target) pair.

        Returns:
            ``(x, y)`` waveforms as returned by ``torchaudio.load`` (the sample
            rates are discarded), each passed through ``transform`` if one was
            given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        x, _ = torchaudio.load(self.df['input'].iloc[idx])
        y, _ = torchaudio.load(self.df['target'].iloc[idx])
        if self.transform:
            x = self.transform(x)
            y = self.transform(y)
        return x, y

    def random(self):
        """Return a uniformly chosen (x, y) pair.

        Raises:
            ValueError: if the dataset is empty.
        """
        # randrange excludes its upper bound; the previous randint(0, len) is
        # inclusive and could return len(self.df), making iloc raise IndexError.
        idx = random.randrange(len(self.df))
        return self.__getitem__(idx)

In [22]:
# Build the shuffled dataset and inspect its first (input, target) pair.
data = AudioDataset(input_path=input_path)
print(f'num samples: {len(data)}')

x, y = data[0]
input_shape = x.shape
# presumably (channels, samples) as torchaudio.load returns — confirm
n_channels, n_samples = input_shape[0], input_shape[1]
flat_feats = n_samples * n_channels
print(n_channels, n_samples)
(x, y)

num samples: 30984
2 44100


(tensor([[-0.1303, -0.1345, -0.1401,  ...,  0.0357,  0.0464,  0.0559],
         [ 0.0567,  0.0561,  0.0545,  ...,  0.1372,  0.1462,  0.1501]]),
 tensor([[-0.0445, -0.0443, -0.0428,  ...,  0.0082,  0.0010, -0.0049],
         [ 0.0732,  0.0727,  0.0716,  ...,  0.0344,  0.0327,  0.0284]]))

In [23]:
# loss function
def mse(output: torch.Tensor, label: torch.Tensor):
    """Mean squared error averaged over every element of ``output - label``."""
    residual = output - label
    return residual.pow(2).mean()

In [24]:
mse(output=x, label=y)  # MSE between the first input clip and its target

tensor(0.0049)

In [25]:
# Peak signal-to-noise ratio (PSNR) metrics for images and audio.
# Reference originally cited for this cell:
# https://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf
# (tinyurl: https://tinyurl.com/yclop5na)
# NOTE(review): that paper defines an SNR *distance* between embeddings
# (variance of their difference over variance of the anchor); the classes below
# compute standard PSNR instead. Confirm which metric is intended.
# PSNR grows as the error shrinks (higher is better), so negate it if it is
# ever used as a loss to minimise.


def _psnr(output, target, peak):
    """Return ``20 * log10(peak / RMSE)`` over all elements of ``output - target``.

    Evaluates to ``inf`` when the inputs are identical (RMSE of zero).
    """
    mse = torch.mean((output - target) ** 2)
    return 20 * torch.log10(peak / torch.sqrt(mse))


class PSNR:
    """Peak Signal to Noise Ratio
    output and target have range [0, 255]"""

    def __init__(self):
        self.name = "PSNR"

    @staticmethod
    def __call__(output, target):
        return _psnr(output, target, 255.0)


class audio_PSNR:
    """Peak Signal to Noise Ratio
    output and target have range [-1, 1]"""

    def __init__(self):
        self.name = "audio PSNR"

    @staticmethod
    def __call__(output, target):
        return _psnr(output, target, 1.0)

In [26]:
class FC_Net(nn.Module):
    """Fully connected waveform-to-waveform network.

    Both linear layers act on the last dimension, so the input may be
    ``(channels, samples)`` or ``(batch, channels, samples)``. The output has
    the same shape as the input.
    """

    def __init__(self, in_features=None, hidden_features=256):
        """
        Args:
            in_features: size of the last input dimension. ``None`` falls back
                to ``num_flat_features`` of the notebook-level sample ``x``,
                which is what ``FC_Net()`` has always used.
            hidden_features: width of the hidden layer. The original sketch
                mapped in_features -> in_features. With 44100 samples that is
                ~1.9e9 weights (~7.8 GB in float32), far more than a 3 GB GPU can
                hold, so a bottleneck is used instead.
        """
        # Initialise nn.Module exactly once. Calling it again would re-create
        # its empty parameter/module registries.
        super().__init__()
        if in_features is None:
            in_features = self.num_flat_features(x)  # hidden-state dependency on global `x`
        self.input_shape = in_features
        # Register layers here, not inside forward, so their weights are
        # trainable parameters that persist between calls.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, in_features)

    def reshape(self, x):
        """Add a leading batch dimension of size 1 (not needed by ``forward``)."""
        return torch.unsqueeze(x, 0)

    def fc1(self, x):
        """Hidden layer: first linear projection followed by ReLU."""
        return F.relu(self.linear1(x))

    def melspectrogram(self, x, sample_rate=44100, n_mels=128):
        """Return the mel spectrogram of waveform ``x`` in decibels.

        Not used by ``forward``; kept as a feature-extraction helper.
        """
        # move the transform's filterbank to the input's device before applying it
        mel = transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels).to(x.device)(x)
        return transforms.AmplitudeToDB()(mel)

    def num_flat_features(self, x) -> int:
        """Product of all dimensions of ``x`` except the first."""
        num_features = 1
        for s in x.size()[1:]:  # all dims except batch dim
            num_features *= s
        return int(num_features)

    def forward(self, x):
        """Map ``x`` through hidden layer and output projection; shape is preserved."""
        return self.linear2(self.fc1(x))

In [27]:
net: nn.Module = FC_Net()  # model instance for the experiments that follow

In [28]:
# melspectogram_transform = \
#     torchaudio.transforms.MelSpectrogram(
#     sample_rate=sample_rate, n_mels=128)
# melspectogram_db_transform = torchaudio.transforms.AmplitudeToDB()

# melspectogram = melspectogram_transform(audio)
# plt.figure()
# plt.imshow(melspectogram.squeeze().numpy(), cmap='hot')
    
# melspectogram_db=melspectogram_db_transform(melspectogram)
# plt.figure()
# plt.imshow(melspectogram_db.squeeze().numpy(), cmap='hot')