In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1,2'

import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torchaudio
import librosa
import time
import pickle
from pesq import pesq
import numpy as np
import numpy as npb
from tqdm import tqdm
from istft import ISTFT
from aia_trans import aia_complex_trans_mag, aia_complex_trans_ri, dual_aia_trans_merge_crm
import soundfile as sf
import multiprocessing
from collections import OrderedDict

In [2]:
class Enhance:
    def __init__(self, args):
        self.model = dual_aia_trans_merge_crm()
        checkpoint = torch.load(args['Model_path'])

        state_dict = checkpoint['model_state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v   
        self.model.load_state_dict(new_state_dict)

        self.model.eval()
        self.model.cuda()
        self.istft = ISTFT(filter_length=320, hop_length=160, window='hanning')
        self.fs = args['fs']

    def enhance(self, noise_file_path, clean_file_path):
        with torch.no_grad():
            feat_wav, _ = sf.read(noise_file_path)
            c = np.sqrt(len(feat_wav) / np.sum((feat_wav ** 2.0)))
            feat_wav = feat_wav * c
            wav_len = len(feat_wav)
            frame_num = int(np.ceil((wav_len - 320 + 320) / 160 + 1))
            fake_wav_len = (frame_num - 1) * 160 + 320 - 320
            left_sample = fake_wav_len - wav_len
            feat_wav = torch.FloatTensor(np.concatenate((feat_wav, np.zeros([left_sample])), axis=0))
            feat_x = torch.stft(feat_wav.unsqueeze(dim=0), n_fft=320, hop_length=160, win_length=320,
                                window=torch.hann_window(320)).permute(0, 3, 2, 1)
            noisy_phase = torch.atan2(feat_x[:, -1, :, :], feat_x[:, 0, :, :])
            feat_x_mag = (torch.norm(feat_x, dim=1)) ** 0.5
            feat_x = torch.stack((feat_x_mag * torch.cos(noisy_phase), feat_x_mag * torch.sin(noisy_phase)), dim=1)
            esti_x = self.model(feat_x.cuda())
            esti_mag, esti_phase = torch.norm(esti_x, dim=1), torch.atan2(esti_x[:, -1, :, :],
                                                                            esti_x[:, 0, :, :])
            esti_mag = esti_mag ** 2
            esti_com = torch.stack((esti_mag * torch.cos(esti_phase), esti_mag * torch.sin(esti_phase)), dim=1)
            esti_com = esti_com.cpu()
            esti_utt = self.istft(esti_com).squeeze().detach().numpy()
            esti_utt = esti_utt[:wav_len]
            esti_utt = esti_utt / c
            sf.write(clean_file_path, esti_utt, args['fs'])

In [3]:
###
with open('/workspace/SE_2022/train_noise_by_type.pkl', 'rb') as f:
    train_noise = pickle.load(f)

with open('/workspace/SE_2022/val_noise_by_type.pkl', 'rb') as f:
    val_noise = pickle.load(f)

with open('/workspace/SE_2022/train_map.pkl', 'rb') as f:
    noise_clean_map = pickle.load(f)

In [4]:

with open('/workspace/SE_2022/test_noise_by_type.pkl', 'rb') as f:
    test_noise = pickle.load(f)

In [5]:
# data_type = 'blower'  


# noise_path_list = []
# for path in (val_noise[data_type]):
#     noise_path = '/workspace/SE/data/train/' + path
#     noise_path_list.append(noise_path)

# args = {}
# args['Model_path'] = f'/workspace/model_new_best/{data_type}.pth.tar'
# args['fs'] = 16000

# enhance_model = Enhance(args)

In [6]:
output_dir = '/workspace/output_data/new_best/test'

for data_type in tqdm(list(test_noise.keys())):
    noise_path_list = []
    for path in (test_noise[data_type]):
        noise_path = '/workspace/SE/data/test/' + path
        noise_path_list.append(noise_path)
        args = {}
        if data_type == 'grinding':
            args['Model_path'] = '/workspace/grinding.pth.tar'          
        else:
            args['Model_path'] = f'/workspace/model_new_best/{data_type}.pth.tar'
        args['fs'] = 16000

        enhance_model = Enhance(args)        


    for noise_file_path in noise_path_list:
        if noise_file_path.endswith('.flac'):

            number = noise_file_path.split('_')[1]
            clean_file_name = 'vocal_'+ number + '.flac'
            clean_file_path = os.path.join(output_dir, clean_file_name)
            try:
                enhance_model.enhance(noise_file_path, clean_file_path)
            except:
                print('Error:', noise_file_path)        

100%|██████████| 21/21 [09:22<00:00, 26.80s/it]


### Prediction

### Prediction

In [42]:
output_dir = '/workspace/output_data/new_best/val'

for noise_file_path in tqdm(noise_path_list):
    if noise_file_path.endswith('.flac'):
        if  noise_file_path == '/workspace/SE/data/train/mixed_02522_blower.flac':
            continue

        number = noise_file_path.split('_')[1]
        clean_file_name = 'vocal_'+ number + '.flac'
        clean_file_path = os.path.join(output_dir, clean_file_name)
        try:
            enhance_model.enhance(noise_file_path, clean_file_path)
        except:
            print('Error:', noise_file_path)

100%|██████████| 166/166 [00:50<00:00,  3.27it/s]


### Evaluation

In [43]:
import os
import time
from unittest import result
from tqdm import tqdm
# import tqdm
# import tqdm.asyncio
import numpy as np
import soundfile as sf
import pickle
import multiprocessing

import torch
import torchaudio

from pesq import pesq

def get_pesq(ref_file_path, deg_file_path):
    clean, _ = sf.read(ref_file_path)
    noise, rate = sf.read(deg_file_path)
    try:
        return pesq(16000, clean, noise, 'wb')
    except:
        return 0

In [44]:
result_dict = {}

output_dir = '/workspace/output_data/new_best/val'
true_dir = '/workspace/SE/data/train'

# for data_type in tqdm(list(val_noise.keys())):
noise_path_list = []
for path in (val_noise[data_type]):
    noise_path = '/workspace/SE/data/train/' + path
    noise_path_list.append(noise_path)

args = []
for noise_file_path in noise_path_list:
    if noise_file_path.endswith('.flac'):
        if  noise_file_path == '/workspace/SE/data/train/mixed_02522_blower.flac':
            continue

        number = noise_file_path.split('_')[1]
        clean_file_name = 'vocal_'+ number + '.flac'
        clean_file_path = os.path.join(output_dir, clean_file_name)
        true_file_path = os.path.join(true_dir, clean_file_name)
        args.append((true_file_path, clean_file_path))
pool = multiprocessing.Pool(processes=60)
results =pool.starmap_async(get_pesq, args)
score = np.array(results.get())
pool.close()
pool.join()
result_dict[data_type] = [score.sum(), len(noise_path_list)]

In [45]:
result_dict

{'blower': [474.3190757036209, 166]}

In [96]:
for name, value in result_dict.items():
    print(f'{name} average score: {value[0]/value[1]}')

air_conditioner average score: 2.882402257026295
blower average score: 2.7714439228356604
car_horn average score: 2.8317565077988704
children_playing average score: 3.0695928377613053
cleaner average score: 2.4125387690024462
dog_bark average score: 3.3749859715953017
drilling average score: 2.5962354975087303
engine_idling average score: 2.860425354747161
fan average score: 2.4821400877958286
garbage_truck average score: 3.1284913851554137
grinding average score: 2.2864224569099707
gun_shot average score: 2.7649630994763053
jackhammer average score: 2.360905850005007
market average score: 1.948142236973866
music average score: 2.380946582736391
rainy average score: 3.2969092337671153
siren average score: 3.543262479547969
street_music average score: 2.607016674535615
traffic average score: 2.535739555503383
train average score: 2.1933772362858415
silence average score: 4.598791903637825


In [92]:
result_dict

{'air_conditioner': [863.5438184738159, 283],
 'blower': [455.8149915933609, 166],
 'car_horn': [795.7235786914825, 281],
 'children_playing': [899.303745508194, 281],
 'cleaner': [389.9822087287903, 167],
 'dog_bark': [569.8560967445374, 165],
 'drilling': [726.9459393024445, 280],
 'engine_idling': [835.0588532686234, 281],
 'fan': [407.46159875392914, 167],
 'garbage_truck': [523.5132936239243, 166],
 'grinding': [356.7650239467621, 164],
 'gun_shot': [782.4845571517944, 283],
 'jackhammer': [392.4669420719147, 167],
 'market': [321.21466636657715, 166],
 'music': [404.53263568878174, 165],
 'rainy': [542.7403056621552, 167],
 'siren': [588.2848097085953, 167],
 'street_music': [773.07677090168, 280],
 'traffic': [423.7632557153702, 165],
 'train': [378.3397938013077, 166],
 'silence': [216.14321947097778, 47]}

In [None]:
args = []
for flac_name in os.listdir(deg_root):
    if flac_name.endswith('.flac'):
        if flac_name.split('_')[0] == 'mixed':
            number = flac_name.split('_')[1]
            clean_file_name = 'vocal_'+ number + '.flac'
            noise_file_path = os.path.join(deg_root, flac_name)
            clean_file_path = os.path.join(ref_root, clean_file_name)
            args.append((noise_file_path, clean_file_path))
pool = multiprocessing.Pool(processes=60)
results =pool.starmap_async(get_pesq, args)
score = results.get()
pool.close()
pool.join()        