In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import datetime
import os
from tqdm.notebook import tqdm
import random
import numpy as np
from librosa.util import find_files
from librosa.core import load,stft,resample,istft

from IPython.display import display, Audio

import parameter as C
import myutils as ut
import network

# --- Configuration -----------------------------------------------------------
# STFT path taken from the shared parameter module (not referenced again in
# the visible part of this script; kept for any later notebook cells).
fftpath = C.PATH_FFT
# Checkpoint whose weights are loaded for inference.
model_path = "./model/model/epoch80.pt"

PATH_FFT = "./stft_data"
SPEECH_PATH = "./speech_data"
NOISE_PATH = "./noise_data"

# Collect the .npy arrays that are later loaded with np.load and fed to stft.
speechlist = find_files(SPEECH_PATH, ext="npy")
noiselist = find_files(NOISE_PATH, ext="npy")

# Fail early with an explicit message: with an empty list the script would
# otherwise die further down with an unrelated-looking IndexError/ValueError.
if not speechlist:
    raise FileNotFoundError(f"no .npy speech files found under {SPEECH_PATH!r}")
if not noiselist:
    raise FileNotFoundError(f"no .npy noise files found under {NOISE_PATH!r}")

# Randomise which files end up in the mixture on each run.
random.shuffle(speechlist)
random.shuffle(noiselist)

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(PATH_FFT, exist_ok=True)
noise_num = len(noiselist)

# --- Speech data ---------------------------------------------------------------
# Concatenate the STFTs of successive speech files along axis 1 (frames) until
# there are enough frames for BATCH_SIZE + 1 patches of PATCH_LENGTH frames.
required_frames = C.PATCH_LENGTH * (C.BATCH_SIZE + 1)

i = 0
spec = stft(np.load(speechlist[i]), n_fft=C.FFT_SIZE, hop_length=C.H, win_length=C.FFT_SIZE)
speech_chunks = [spec]
fulllen = spec.shape[1]

while fulllen < required_frames:
    i += 1
    if i >= len(speechlist):
        # Same exception type the bare list lookup would raise, but with a
        # message that says what actually went wrong.
        raise IndexError(
            f"speech files provide only {fulllen} STFT frames; "
            f"{required_frames} are required"
        )
    conc = stft(np.load(speechlist[i]), n_fft=C.FFT_SIZE, hop_length=C.H, win_length=C.FFT_SIZE)
    speech_chunks.append(conc)
    fulllen += conc.shape[1]

# Join once at the end instead of re-copying the growing array per file.
spec = np.concatenate(speech_chunks, axis=1)
# Crop along the frame axis (axis 1); slicing axis 0 would cut frequency bins.
speech_spec = spec[:, :required_frames]

# --- Noise data ----------------------------------------------------------------
# Chain randomly chosen noise files along axis 1 (frames). Every file appended
# inside the loop is followed by an all-zero gap of 1..120 frames.
required_frames = C.PATCH_LENGTH * (C.BATCH_SIZE + 1)

spec = stft(np.load(random.choice(noiselist)), n_fft=C.FFT_SIZE, hop_length=C.H, win_length=C.FFT_SIZE)
noise_chunks = [spec]
fulllen = spec.shape[1]

while fulllen < required_frames:
    conc = stft(np.load(random.choice(noiselist)), n_fft=C.FFT_SIZE, hop_length=C.H, win_length=C.FFT_SIZE)
    space = np.zeros([conc.shape[0], random.randint(1, 120)])
    noise_chunks += [conc, space]
    fulllen += conc.shape[1] + space.shape[1]

# Join once at the end instead of re-copying the growing array per file.
spec = np.concatenate(noise_chunks, axis=1)
# Crop along the frame axis (axis 1); slicing axis 0 would cut frequency bins.
noise_spec = spec[:, :required_frames]

# --- Data mixer ------------------------------------------------------------------
# Trim both spectrograms to exactly BATCH_SIZE patches and add them to obtain
# the noisy mixture.
usable_frames = C.PATCH_LENGTH * C.BATCH_SIZE
speech_spec = speech_spec[:, :usable_frames]
noise_spec = noise_spec[:, :usable_frames]
mix_spec = speech_spec + noise_spec


def _preview(caption, spectrogram):
    """Print *caption*, resynthesise *spectrogram* with istft, show an audio
    player for the result and return the waveform."""
    print(caption)
    waveform = istft(spectrogram, hop_length=C.H, win_length=C.FFT_SIZE)
    display(Audio(waveform, rate=16000))
    return waveform


# Captions: "mixture before processing", "speech before processing",
# "keystroke noise before processing".
listen_mix = _preview("処理前ミックス", mix_spec)
listen_speech = _preview("処理前話し声", speech_spec)
listen_noise = _preview("処理前打鍵音", noise_spec)

# From here on the mixture is handled as separate magnitude and phase parts.
mix_mag = np.abs(mix_spec)
# Peak-normalise the magnitudes. Skip the division for an all-zero (silent)
# mixture, which would otherwise turn every value into NaN.
peak = np.max(mix_mag)
if peak > 0:
    mix_mag /= peak
# Unit-modulus phase factors, multiplied back onto the magnitudes later.
mix_phase = np.exp(1.j * np.angle(mix_spec))

# Cut the magnitude spectrogram into BATCH_SIZE consecutive patches of
# PATCH_LENGTH frames. astype() already returns a fresh array, so no extra
# clone of the tensor is needed before stacking.
listen_list = [
    torch.from_numpy(
        mix_mag[:, C.PATCH_LENGTH * k:C.PATCH_LENGTH * (k + 1)].astype(np.float32)
    )
    for k in range(C.BATCH_SIZE)
]

# Stack the patches along a new leading (batch) dimension.
tensor_data = torch.stack(listen_list)

# --- Inference -------------------------------------------------------------------
model = network.UnetConv2()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# tensor_data lives on the CPU, so the weights must as well.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Switch layers such as Dropout/BatchNorm to inference behaviour.
model.eval()

# No gradients are needed for inference; this avoids building the autograd graph.
with torch.no_grad():
    mask = model(tensor_data)
    # Hard mask: zero every mask value below the 0.90 threshold.
    mask[mask < 0.90] = 0

    # h keeps what the mask passes, q keeps the complementary part.
    h = (tensor_data * mask).detach().cpu().numpy()
    q = (tensor_data * (1 - mask)).detach().cpu().numpy()

def _stitch_patches(patches):
    """Join the first BATCH_SIZE patches of *patches* along axis 1.

    Axis 1 of each 2-D patch is the frame axis the patches were cut from, so
    this restores one continuous spectrogram. A single concatenate call
    replaces the loop that re-copied the growing array for every patch.
    """
    return np.concatenate(
        [patches[k, :, :] for k in range(C.BATCH_SIZE)], axis=1
    )


# Masked part of the mixture, resynthesised with the mixture phase.
output = _stitch_patches(h)
print("処理後音声")  # caption: "audio after processing"
denoise = istft(output * mix_phase, hop_length=C.H, win_length=C.FFT_SIZE)
display(Audio(denoise, rate=16000))

# Complementary part (1 - mask), resynthesised the same way.
output = _stitch_patches(q)
print("処理後音声(打鍵)")  # caption: "audio after processing (keystrokes)"
denoise = istft(output * mix_phase, hop_length=C.H, win_length=C.FFT_SIZE)
display(Audio(denoise, rate=16000))