In [39]:
import torch
import torch.utils.data as utils
from librosa.core import istft,load,stft

from IPython.display import display,Audio
from tqdm.notebook import tqdm
import glob
import numpy as np
import os
import random
import re

import model as mm
import utils as ut
from parameter import Parameter

In [33]:
# Pull every tunable setting from the shared project configuration object.
p = Parameter()

# --- dataset / loader settings ---
datasets_save_dir = p.datasets_path
split = p.datasets_split  # fractions, in test / val / train order
batch_size = p.batch_size

# --- audio / STFT settings ---
sample_rate = p.sample_rate
audio_len = p.audio_len  # segment length in samples (used to slice waveforms)
fft_size = p.fft_size
hop_length = p.hop_length

# --- model settings ---
num_layer = p.num_layer
model_dir_path = p.model_path

# --- raw audio sources ---
clean_speech_dir = p.target_path
noise_dir = p.noise_path

# Query CUDA once and reuse the answer for both the device and the log line.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print("CUDA is available:", cuda_available)

CUDA is available: True


In [27]:
def length_fitting(data,audio_len):
    """Return ``data`` cut or cyclically repeated to exactly ``audio_len`` items.

    Args:
        data: Sequence or array; fitting is done along axis 0.
        audio_len: Target length along axis 0.

    Returns:
        ``data[:audio_len]`` when ``data`` is already long enough; otherwise
        ``data`` repeated end-to-end along axis 0 and truncated to
        ``audio_len``.

    Raises:
        ValueError: If ``data`` is empty and ``audio_len`` is positive.
            Repeating an empty input can never reach the target length
            (the previous doubling loop would spin forever in this case).
    """
    current_len = len(data)
    if current_len >= audio_len:
        return data[:audio_len]
    if current_len == 0:
        raise ValueError("cannot extend empty data to length %d" % audio_len)

    # Ceiling division: the smallest number of copies that covers audio_len.
    repeats = -(-audio_len // current_len)
    return np.concatenate([data] * repeats, 0)[:audio_len]

# Gather the clean-speech and noise file paths, then randomise both lists.
# NOTE: the shuffle is unseeded, so the order (and therefore which file a
# fixed index selects later) differs from run to run.
c_files, n_files = ut.take_path(clean_speech_dir), ut.take_path(noise_dir)

for _path_list in (c_files, n_files):
    random.shuffle(_path_list)

num_c_files, num_n_files = len(c_files), len(n_files)

In [37]:
# Checkpoint file name; it is combined with model_dir_path when the weights are loaded.
model_path = 'model_layer5_20201219_163036.pt'

In [54]:
# Index of the clean-speech / noise file pair to audition.
i=4

# Load at the native rate first; re-decode with resampling only when the
# file's rate differs from the configured sample_rate.
c_data, sr_c = load(c_files[i], sr=None)
n_data, sr_n = load(n_files[i], sr=None)

if sr_c != sample_rate:
    c_data, _ = load(c_files[i], sr=sample_rate)

if sr_n != sample_rate:
    n_data, _ = load(n_files[i], sr=sample_rate)

# Number of whole audio_len-sample segments in the clean recording.
step = len(c_data) // audio_len

# The noise has to cover every segment that will be processed. Fitting it to
# a single audio_len (as before) left an empty noise slice for every segment
# after the first one.
n_data = length_fitting(n_data, max(step, 1) * audio_len)

if len(c_data) < audio_len:
    print("音声データが短すぎます。")

else:
    # Build the model: its input dimensions come from the STFT of one segment.
    c_p = c_data[:audio_len]
    c_p_stft=stft(c_p, n_fft=fft_size, hop_length=hop_length)
    f, t = c_p_stft.shape

    # Read the layer count from the "layer<N>" part of the checkpoint name.
    # Taking the first digit of the whole file name breaks for N >= 10.
    # If the name has no such part, the currently configured num_layer is kept.
    layer_match = re.search(r"layer(\d+)", model_path)
    if layer_match is not None:
        num_layer = int(layer_match.group(1))

    model = mm.Net(t,f, num_layer)
    # os.path.join works whether or not model_dir_path ends with a separator.
    # map_location="cpu" because the model is never moved off the CPU here.
    state_dict = torch.load(os.path.join(model_dir_path, model_path), map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # `seg` (not `i`) so the file index chosen above is not overwritten.
        for seg in tqdm(range(step),leave=False,desc='[AUDIO Process..]'):
            c_p = c_data[seg*audio_len : (seg+1)*audio_len]
            n_p = n_data[seg*audio_len : (seg+1)*audio_len]

            c_p_stft=stft(c_p, n_fft=fft_size, hop_length=hop_length)
            n_p_stft=stft(n_p, n_fft=fft_size, hop_length=hop_length)

            # Noisy mixture in the STFT domain (complex, keeps the phase).
            addnoise_stft=c_p_stft+n_p_stft

            # Feed the magnitude spectrogram to the model.
            addnoise_tensor=torch.from_numpy(np.abs(addnoise_stft).astype(np.float32))
            print(addnoise_tensor.shape)

            # Fill every row of the batch with this segment. The old loop
            # indexed `ret` with the segment counter instead of the batch
            # counter, so it wrote one row only and would raise IndexError
            # once the segment index reached batch_size.
            # NOTE(review): the batch_size-row input shape is kept as-is;
            # confirm whether mm.Net also accepts a single-row batch.
            ret = addnoise_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
            print(ret.shape)

            mask_batch = model(ret)
            print(mask_batch.shape)

            # All rows are identical, so the first row's mask is used.
            mask = mask_batch[0,:,:].detach().to('cpu').numpy()
            print(mask.shape)

            # Apply the mask to the complex mixture and go back to waveforms.
            audio = addnoise_stft * mask
            print(audio.shape)

            audio =istft(audio, hop_length=hop_length)
            addnoise =istft(addnoise_stft, hop_length=hop_length)

            display(Audio(audio,rate = sample_rate))
            display(Audio(addnoise,rate = sample_rate))

HBox(children=(HTML(value='[AUDIO Process..]'), FloatProgress(value=0.0, max=1.0), HTML(value='')))

torch.Size([257, 129])
torch.Size([30, 257, 129])
torch.Size([30, 257, 129])
(257, 129)
(257, 129)


In [3]:
# Load every pre-computed sample: one .npz per sample holding "speech" and
# "addnoise" arrays.
_datasets_path = glob.glob(datasets_save_dir+"/*.npz")
speech_list = []
addnoise_list = []

for file in tqdm(_datasets_path):
    # np.load on an .npz returns an archive object that stays open until it is
    # closed; the context manager releases the handle after each file.
    with np.load(file) as d:
        # astype() already returns a fresh array, so no extra clone is needed.
        speech=torch.from_numpy(d["speech"].astype(np.float32))
        addnoise=torch.from_numpy(d["addnoise"].astype(np.float32))

    speech_list.append(speech)
    addnoise_list.append(addnoise)

num_data = len(speech_list)
# Use the largest multiple of 100 that does not exceed num_data.
# (round(x, -2) rounds halves to even, so the old two-step rounding produced
# e.g. 52800 for 52950 samples and 0 for 150 samples.)
num_usedata = (num_data // 100) * 100
if num_usedata == 0:
    raise ValueError(
        "need at least 100 samples in %s, found %d" % (datasets_save_dir, num_data)
    )

tensor_speech = torch.stack(speech_list[:num_usedata])
tensor_addnoise = torch.stack(addnoise_list[:num_usedata])

print("Available data :", num_data)
print("Use data :", num_usedata)

mydataset = utils.TensorDataset(tensor_speech,tensor_addnoise)
data_num = tensor_speech.shape[0]
# split holds the test / val / train fractions. Train takes whatever is left
# after test and val so the three sizes always add up to data_num;
# truncating all three independently can lose a sample to float rounding and
# make random_split raise.
num_test = int(data_num * split[0])
num_val = int(data_num * split[1])
data_split = [num_test, num_val, data_num - num_test - num_val]
test_dataset,val_dataset,train_dataset = utils.random_split(mydataset,data_split)

train_loader = utils.DataLoader(train_dataset,batch_size=batch_size,num_workers=os.cpu_count(),pin_memory=True,shuffle=True)
val_loader = utils.DataLoader(val_dataset,batch_size=batch_size,num_workers=os.cpu_count(),pin_memory=True,shuffle=True)
test_loader = utils.DataLoader(test_dataset,batch_size=batch_size,num_workers=os.cpu_count(),pin_memory=True,shuffle=True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=52912.0), HTML(value='')))


Available data : 52912
Use data : 52900


In [4]:
# Show every .pt checkpoint file found directly under model_dir_path.
print(glob.glob(model_dir_path+"/*.pt"))

['D:/yamamoto/modeldir\\model_layer5_20201219_163036.pt', 'D:/yamamoto/modeldir\\model_layer5_20201219_163036_Epoch10.pt', 'D:/yamamoto/modeldir\\model_layer5_20201219_163036_Epoch20.pt']


In [8]:
# Checkpoint to evaluate on one batch from the test loader.
model_path = 'model_layer5_20201219_163036.pt'

# model: dimensions 1 and 2 of the stacked input tensor size the network.
feat = tensor_addnoise.shape[1]
sequence = tensor_addnoise.shape[2]
model = mm.Net(sequence, feat, num_layer)
# os.path.join works whether or not model_dir_path ends with a separator.
# map_location="cpu" because the model is never moved off the CPU here.
state_dict = torch.load(os.path.join(model_dir_path, model_path), map_location="cpu")
model.load_state_dict(state_dict)

it = iter(test_loader)
speech,addnoise = next(it)

model.eval()

# Inference only: skip building an autograd graph for the whole batch.
with torch.no_grad():
    mask = model(addnoise.float())
    # Apply the predicted mask to the noisy input.
    h_hat = mask * addnoise

In [17]:
# Display the shape of h_hat (the mask multiplied by the noisy batch).
h_hat.shape

torch.Size([30, 257, 129])

In [19]:
# Convert both batches to NumPy once, instead of converting and copying the
# whole batch twice on every loop iteration.
h_hat_np = h_hat.detach().to('cpu').numpy()
addnoise_np = addnoise.detach().to('cpu').numpy()

# NOTE(review): `addnoise` was cast to float32 when the dataset was loaded, so
# istft receives real-valued spectrograms here (no phase) - confirm that this
# reconstruction is what is intended.
# `sample_idx` is used so the global `i` is not overwritten by this loop.
for sample_idx in range(h_hat_np.shape[0]):
    print("Index",sample_idx+1)
    audio=istft(h_hat_np[sample_idx,:,:], hop_length=hop_length)
    addnoiseaudio=istft(addnoise_np[sample_idx,:,:], hop_length=hop_length)

    # Noisy input first, then the masked result.
    display(Audio(addnoiseaudio,rate = sample_rate))
    display(Audio(audio,rate = sample_rate))

Index 1


Index 2


Index 3


Index 4


Index 5


Index 6


Index 7


Index 8


Index 9


Index 10


Index 11


Index 12


Index 13


Index 14


Index 15


Index 16


Index 17


Index 18


Index 19


Index 20


Index 21


Index 22


Index 23


Index 24


Index 25


Index 26


Index 27


Index 28


Index 29


Index 30
