In [1]:
import os
import torchaudio 
import torch
from tqdm import tqdm
import numpy as np
import sudo_rm_rf.groupcomm_sudormrf_v2 as sudormrf_gc_v2
import sudo_rm_rf.causal_improved_sudormrf_v3 as causal_improved_sudormrf

In [2]:
# Sampling-rate configuration shared by the processing cell below:
# the separation models operate at `fs_model`, and the recordings are
# assumed to be captured at `fs_recordings`. The resampler is built once
# (on the GPU) so its filter kernel is reused for every file.
fs_model, fs_recordings = 16000, 48000
resampler = torchaudio.transforms.Resample(
    orig_freq=fs_recordings,
    new_freq=fs_model,
).cuda()

In [3]:
# Batch-enhance every recording in the listening-situation dataset with a
# pretrained SuDoRM-RF model, writing the results next to the originals.
# Relies on `fs_model`, `fs_recordings` and `resampler` from the config cell.

_EPS = 1e-8  # guards the normalisation / energy-rescale divisions


def _build_model(modelname):
    """Instantiate the architecture selected by the model directory name.

    Directory names whose last '_'-separated token is 'causal' get the causal
    improved SuDoRM-RF; everything else gets the group-communication variant.
    Both share the same hyper-parameters.
    """
    model_kwargs = dict(
        in_audio_channels=2,
        out_channels=512,
        in_channels=256,
        num_blocks=16,
        upsampling_depth=5,
        enc_kernel_size=21,
        enc_num_basis=512,
        num_sources=1)
    if modelname.split('_')[-1] == 'causal':
        return causal_improved_sudormrf.CausalSuDORMRF(**model_kwargs)
    return sudormrf_gc_v2.GroupCommSudoRmRf(**model_kwargs)


def _find_checkpoint(model_dir):
    """Return the path of the lexicographically first visible file in model_dir.

    Hidden entries (e.g. '.ipynb_checkpoints', '.DS_Store') and sub-directories
    are ignored so they cannot be picked up as the checkpoint.
    NOTE(review): lexicographic order ranks 'epoch_10' before 'epoch_9'; if a
    directory ever holds several epochs, confirm which one should be used.

    Raises:
        FileNotFoundError: if model_dir contains no candidate file.
    """
    candidates = sorted(
        name for name in os.listdir(model_dir)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(model_dir, name)))
    if not candidates:
        raise FileNotFoundError('No checkpoint file found in ' + str(model_dir))
    return os.path.join(model_dir, candidates[0])


def _resample_to_model_rate(mixture, fs):
    """Resample a waveform from its actual rate `fs` to `fs_model`.

    The pre-built `resampler` (fs_recordings -> fs_model) is used when the file
    matches the expected recording rate; any other rate is resampled from the
    file's true rate instead of being silently treated as `fs_recordings`.
    """
    if fs == fs_recordings:
        return resampler(mixture)
    if fs != fs_model:
        return torchaudio.functional.resample(mixture, fs, fs_model)
    return mixture


def _enhance(model, mixture):
    """Run the model on one waveform and restore the input's energy.

    The input is standardised with a single mean/std over all channels, passed
    through the model with a batch dimension added, and the output is rescaled
    so its total energy equals that of `mixture`. The std and the output energy
    are clamped to _EPS so silent input/output yields silence instead of NaN.
    """
    ini_nrg = torch.sum(mixture ** 2)
    normalized = (mixture - torch.mean(mixture)) / torch.std(mixture).clamp_min(_EPS)
    processed = model(normalized.unsqueeze(0))
    out_nrg = torch.sum(processed ** 2).clamp_min(_EPS)
    return processed / torch.sqrt(out_nrg / ini_nrg)


def process_recordings(model_path, main_path='/home/ubuntu/Data/ha_listening_situations'):
    """Enhance all KU recordings with the pretrained model stored in model_path.

    Args:
        model_path: directory holding the checkpoint. Its base name selects the
            architecture (suffix '_causal' -> causal model) and names the
            output folder 'processed_<name>'.
        main_path: dataset root containing 'recordings/ku_recordings'.

    Side effects:
        Writes each enhanced file to
        <main_path>/processed_<name>/ku_processed/<subdir>/<filename> at
        `fs_model`, mirroring the input sub-directory layout. Requires CUDA.
    """
    print('Processing: '+str(model_path))
    # normpath drops a trailing separator so basename() is never empty
    model_dir = os.path.normpath(model_path)
    modelname = os.path.basename(model_dir)
    model = _build_model(modelname)

    checkpoint_path = _find_checkpoint(model_dir)
    # Load on CPU first so the checkpoint does not need the GPU it was saved from
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model = model.cuda()
    model.eval()

    recordings_path = os.path.join(main_path, 'recordings', 'ku_recordings')
    processed_path = os.path.join(main_path, 'processed_' + modelname, 'ku_processed')

    subdirs = sorted(
        name for name in os.listdir(recordings_path)
        if os.path.isdir(os.path.join(recordings_path, name)))

    with torch.inference_mode(True):
        for subdir in tqdm(subdirs):
            in_dir = os.path.join(recordings_path, subdir)
            out_dir = os.path.join(processed_path, subdir)
            os.makedirs(out_dir, exist_ok=True)
            for filename in sorted(os.listdir(in_dir)):
                in_file = os.path.join(in_dir, filename)
                # Skip sub-directories and hidden files torchaudio cannot decode
                if filename.startswith('.') or not os.path.isfile(in_file):
                    continue
                mixture, fs = torchaudio.load(in_file)
                mixture = _resample_to_model_rate(mixture.cuda(), fs)
                processed = _enhance(model, mixture)
                # assumes model output is (batch, channels, time) — TODO confirm
                torchaudio.save(os.path.join(out_dir, filename),
                                processed[0].cpu(), sample_rate=fs_model)
    print('Done processing '+str(checkpoint_path))
    print(' . ')

In [4]:

# Enhance the recordings with each pretrained model, in order.
# NOTE(review): the saved output below lists four models (m1, m3, m4, m5)
# under /home/ubuntu/Data/enric_models and does not match these calls;
# re-run the cell to refresh it.
for pretrained_dir in ('pretrained_models/m1_alldata_normal',
                       'pretrained_models/m4_alldata_normal_causal'):
    process_recordings(pretrained_dir)

Processing: /home/ubuntu/Data/enric_models/m1_alldata_normal


100%|███████████████████████████████████████████| 32/32 [01:47<00:00,  3.36s/it]


Done processing /home/ubuntu/Data/enric_models/m1_alldata_normal/gc_sudo_epoch_25
 . 
Processing: /home/ubuntu/Data/enric_models/m3_alldata_mild


100%|███████████████████████████████████████████| 32/32 [01:55<00:00,  3.61s/it]


Done processing /home/ubuntu/Data/enric_models/m3_alldata_mild/gc_sudo_epoch_25
 . 
Processing: /home/ubuntu/Data/enric_models/m4_alldata_normal_causal


100%|███████████████████████████████████████████| 32/32 [00:44<00:00,  1.39s/it]


Done processing /home/ubuntu/Data/enric_models/m4_alldata_normal_causal/gc_sudo_epoch_25
 . 
Processing: /home/ubuntu/Data/enric_models/m5_alldata_mild_causal


100%|███████████████████████████████████████████| 32/32 [00:43<00:00,  1.36s/it]

Done processing /home/ubuntu/Data/enric_models/m5_alldata_mild_causal/gc_sudo_epoch_24
 . 



