In [None]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import os
import tqdm
import scipy.signal as sig
import time
import numpy as np
import soundfile as sf

In [None]:
class WavFolderDataset(Dataset):
    """Dataset of room impulse responses (RIRs) read from folders of ``.wav`` files.

    Every item is a ``(channels, maxlen)`` tensor: the file is loaded with
    ``torchaudio.load``, resampled to ``sample_rate`` when its own rate
    differs, and then zero-padded or truncated along time to ``maxlen``
    samples.

    Args:
        main_path: Directory that contains the per-mode sub-folders.
        mode: Key of ``_MODE_FOLDERS`` selecting which sub-folders are read.
        sample_rate: Target sample rate. A falsy value disables resampling.

    Raises:
        ValueError: If ``mode`` is not one of the known modes.
    """

    # Sub-folders of ``main_path`` read for each mode. Modes without the "o"
    # prefix are cumulative (each adds one folder to the previous mode);
    # "o"-prefixed modes read only the single folder they name.
    _MODE_FOLDERS = {
        'b1': ['b1_gp'],
        'd1': ['b1_gp', 'd1_original'],
        'd2': ['b1_gp', 'd1_original', 'd2_capsules'],
        'd3': ['b1_gp', 'd1_original', 'd2_capsules', 'd3_beamforming'],
        'd4': ['b1_gp', 'd1_original', 'd2_capsules', 'd3_beamforming', 'd4_permute'],
        'ob1': ['b1_gp'],
        'od1': ['d1_original'],
        'od2': ['d2_capsules'],
        'od3': ['d3_beamforming'],
        'od4': ['d4_permute'],
    }

    def __init__(self, main_path, mode, sample_rate=48000):
        self.maxlen = 60000  # maximum IR length in samples
        self.main_path = main_path
        self.mode = mode
        self.sample_rate = sample_rate

        if mode not in self._MODE_FOLDERS:
            # Fail here with a clear message instead of printing and then
            # crashing on an unbound local a few lines later.
            raise ValueError(
                f'unspecified RIR dataset! got mode={mode!r}, '
                f'expected one of {sorted(self._MODE_FOLDERS)}'
            )

        wav_files = []
        for folder in self._MODE_FOLDERS[mode]:
            folder_path = os.path.join(main_path, folder)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping stable between runs.
            for fname in sorted(os.listdir(folder_path)):
                # Compare the real extension so side-car files such as
                # "a.wav.reapeaks" are not mistaken for audio.
                if os.path.splitext(fname)[1].lower() == '.wav':
                    wav_files.append(os.path.join(folder_path, fname))
        self.wav_files = wav_files

    def __len__(self):
        """Return the number of ``.wav`` files that were found."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return it with exactly ``maxlen`` samples per channel."""
        path = self.wav_files[idx]
        waveform, sr = torchaudio.load(path)
        if self.sample_rate and sr != self.sample_rate:
            print('resampling RIR.')
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        zeros_to_pad = self.maxlen - waveform.shape[1]
        if zeros_to_pad > 0:
            # Size the padding from the waveform itself (channel count, dtype,
            # device) rather than assuming two channels.
            padding = waveform.new_zeros((waveform.shape[0], zeros_to_pad))
            return torch.cat((waveform, padding), dim=1)
        return waveform[:, :self.maxlen]


In [None]:
# Root folder of the RIR dataset. Set the PARIRSET_PATH environment variable to
# run the notebook on another machine; the default is the original location.
main_path = os.environ.get('PARIRSET_PATH', '/media/diskA/enric/parirset')

In [None]:
def power(signal):
    """Return the mean of the squared elements of ``signal`` as a 0-dim tensor."""
    squared = signal.pow(2)
    return squared.mean()
    
def conv_torch(test_track, x):
    """Convolve every audio item with its impulse response and match input power.

    The whole batch is convolved in a single ``fftconvolve`` call (it operates
    on the last dimension and treats the leading ones as batch dimensions), so
    there is no Python loop and the channel count is not fixed to two.

    Args:
        test_track: Tensor indexed as (minibatch, channels, time) with the dry
            audio.
        x: Tensor indexed as (minibatch, channels, time) with the impulse
            responses. Its batch size decides how many items are processed.

    Returns:
        Tensor indexed as (minibatch, channels, time). ``'same'`` mode keeps
        the time length of ``test_track``. Each item is rescaled so that its
        mean power equals the mean power of the corresponding ``test_track``
        item; an item whose convolution is entirely zero stays zero instead of
        turning into NaN.
    """
    # Only the first x.shape[0] tracks are used, one per impulse response.
    test_track = test_track[:x.shape[0]]

    mix = torchaudio.functional.fftconvolve(test_track, x, mode='same')

    # Per-item mean power over channels and time, kept broadcastable to mix.
    ref_power = test_track.pow(2).mean(dim=(-2, -1), keepdim=True)
    mix_power = mix.pow(2).mean(dim=(-2, -1), keepdim=True)

    # Guard the division: a silent result has zero power and gets zero gain.
    silent = mix_power == 0
    gain = torch.sqrt(ref_power / mix_power.masked_fill(silent, 1.0))
    gain = gain.masked_fill(silent, 0.0)

    return mix * gain  # keep power the same as input audio

In [None]:
def conv_scipy(test_track, x):
    """Convolve every audio item with its impulse response using SciPy.

    Args:
        test_track: Array or tensor indexed as (minibatch, channels, time)
            with the dry audio.
        x: Array or tensor indexed as (minibatch, channels, time) with the
            impulse responses. Its batch size decides how many items are
            processed.

    Returns:
        ``numpy.ndarray`` indexed as (minibatch, channels, time) holding the
        full convolution along the time axis (length ``N + M - 1``).
    """
    def _as_numpy(data):
        # Tensors may live on a GPU or carry autograd state; bring them to
        # host memory before handing them to SciPy.
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    rirs = _as_numpy(x)
    # Only the first x.shape[0] tracks are used, one per impulse response.
    dry = _as_numpy(test_track)[:rirs.shape[0]]

    # axes=-1 convolves along time only; batch and channel axes are matched
    # one-to-one, so every (item, channel) pair is convolved and returned.
    return sig.fftconvolve(dry, rirs, axes=-1)
# TRY SCIPY CONV
# Time the SciPy convolution of the test track with every batch of RIRs.
test_track, _ = torchaudio.load('test_track.wav')
BS = 5
dataset = WavFolderDataset(main_path, 'd1')
loader = DataLoader(dataset, batch_size=BS, shuffle=False, drop_last=True)

# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
t1 = time.perf_counter()
# One copy of the track per batch item, converted to NumPy for SciPy.
test_track = test_track.repeat(BS, 1, 1).numpy()

for x in tqdm.tqdm(loader):
    _ = conv_scipy(test_track, x)
t2 = time.perf_counter()
print(f'scipy convolution: {t2 - t1:.2f} s')

In [None]:
# MUCH FASTER PYTORCH CONV
test_track, _ = torchaudio.load('test_track.wav')

# Use the second GPU as in the original run when it exists; otherwise fall
# back to the first GPU or the CPU so the cell still runs on other machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

t1 = time.perf_counter()
test_track = test_track.to(device).repeat(BS, 1, 1)

for x in tqdm.tqdm(loader):
    x = x.to(device)
    out2 = conv_torch(test_track, x)

# CUDA kernels are launched asynchronously; wait for them to finish so the
# measured interval covers the actual computation.
if device.type == "cuda":
    torch.cuda.synchronize(device)
t2 = time.perf_counter()
print(f'torch convolution on {device}: {t2 - t1:.2f} s')

In [None]:
# Disabled experiment, kept for reference. It was wrapped in a triple-quoted
# string that was never closed, which made this cell a SyntaxError.
# NOTE(review): `conv_torch_batch` is not defined in the visible part of this
# notebook; it has to be provided before this cell is re-enabled.
#
# test_track, _ = torchaudio.load('test_track.wav')
#
# t1 = time.time()
# test_track = test_track.to("cuda:1")
# test_track = test_track.repeat(BS,1,1)
#
# for x in tqdm.tqdm(loader):
#     x = x.to("cuda:1")
#     out3 = conv_torch_batch(test_track, x)
# t2 = time.time()

In [None]:
from IPython.display import Audio 

In [None]:
# Play item 0 of the last convolved batch at the rate the dataset resamples its
# RIRs to, instead of repeating the literal 48000.
# NOTE(review): assumes test_track.wav has that same sample rate - confirm.
Audio(out2[0].cpu().numpy(), rate=dataset.sample_rate)

In [None]:
# Play item 3 of the last convolved batch at the rate the dataset resamples its
# RIRs to, instead of repeating the literal 48000.
# NOTE(review): assumes test_track.wav has that same sample rate - confirm.
Audio(out2[3].cpu().numpy(), rate=dataset.sample_rate)

In [None]:
# Play item 4 of the last convolved batch at the rate the dataset resamples its
# RIRs to, instead of repeating the literal 48000.
# NOTE(review): assumes test_track.wav has that same sample rate - confirm.
Audio(out2[4].cpu().numpy(), rate=dataset.sample_rate)

In [None]:
# Play item 0 of the last convolved batch at the rate the dataset resamples its
# RIRs to, instead of repeating the literal 48000.
# NOTE(review): this is the same index as the first playback cell - possibly a
# different item (e.g. 1 or 2) was intended; confirm.
Audio(out2[0].cpu().numpy(), rate=dataset.sample_rate)