In [1]:
################## IMPORT LIBRARIES ##################
import soundfile as sf
from IPython.display import Audio
import numpy as np
import sys
import importlib
import random 
import pandas as pd
pd.options.mode.copy_on_write = True
import time
from os.path import join as pjoin
from acoustics.bands import third
import scipy.signal as sig
from IPython.display import Audio
# from masp import shoebox_room_sim as srs
from scipy.io import wavfile
#import mat73
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (4, 3)
import torch 

In [2]:
################## IMPORT MY MODULES ##################
# Make the project sources importable. The path is relative to the notebook's working
# directory. The guard keeps re-runs of this cell from appending duplicate entries to sys.path.
if '../src' not in sys.path:
    sys.path.append('../src')

import helpers as hlp
import evaluation
import dataset as ds
import trainer
import models

# Pick up edits made to the project modules without restarting the kernel (same order as before).
# NOTE(review): a module that does `from helpers import ...` keeps the old bindings unless it is
# reloaded after `helpers` — confirm the dependency order if stale functions show up.
for _module in (evaluation, hlp, ds, trainer, models):
    importlib.reload(_module)

  from torchaudio.backend.common import AudioMetaData


<module 'models' from '/home/ubuntu/joanna/reverb-match-cond-u-net/notebooks/../src/models.py'>

In [16]:

def infer_long_audio(audio_tensor, emb1, emb2, model, config, frame_duration=98304, sample_rate=48000, overlap=0.5,
                     fade_duration=0.2):
    """Run ``model.autoencoder`` over an arbitrarily long signal frame by frame and stitch the results.

    The signal is cut into frames of ``frame_duration`` samples whose starts are
    ``hop = frame_duration * (1 - overlap)`` samples apart. Each frame goes through four steps:
    - it is peak-normalised with ``hlp.torch_normalize_max_abs``;
    - it is passed through ``model.autoencoder(frame, emb1, emb2)``;
    - the result is rescaled by the returned scale;
    - the result is weighted with half-Hann fade ramps at its edges.
    The weighted frames are overlap-added, and the sum is divided by the accumulated weights (not by
    a frame count). The fades therefore blend neighbouring frames without changing the overall gain.

    Args:
        audio_tensor: 2-D tensor indexed as ``[channel, sample]`` (``shape[1]`` is used as the length).
        emb1, emb2: conditioning embeddings, forwarded unchanged to ``model.autoencoder``.
        model: object exposing ``autoencoder(frame, emb1, emb2)``; frames get a leading batch dim of 1.
        config: mapping providing ``config["device"]``.
        frame_duration: frame length in *samples* (despite the name).
        sample_rate: sampling rate in Hz; only used to convert ``fade_duration`` to samples.
        overlap: fraction of a frame shared with the next frame, in ``[0, 1)``.
        fade_duration: length of the fade-in / fade-out ramps in seconds, clamped to the overlap length.

    Returns:
        Tensor shaped like ``audio_tensor`` on ``config["device"]``, computed under ``torch.no_grad()``.

    Raises:
        ValueError: if ``frame_duration`` is below one sample or ``overlap`` is outside ``[0, 1)``.
    """
    device = config["device"]

    # Frame geometry, all in samples.
    frame_length = int(frame_duration)
    if frame_length < 1:
        raise ValueError(f"frame_duration must be at least one sample, got {frame_duration!r}")
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must lie in [0, 1), got {overlap!r}")
    hop_length = max(1, int(frame_length * (1 - overlap)))  # distance between consecutive frame starts
    overlap_length = frame_length - hop_length               # samples shared by neighbouring frames
    # A fade longer than the overlap would down-weight samples that no neighbouring frame covers.
    fade_length = max(0, min(int(fade_duration * sample_rate), overlap_length))

    total_samples = audio_tensor.shape[1]
    num_frames = (total_samples + hop_length - 1) // hop_length

    # A periodic Hann window of length 2*F split in half gives complementary rising/falling ramps.
    hann_window = torch.hann_window(2 * fade_length, device=device)
    fade_in = hann_window[:fade_length]
    fade_out = hann_window[fade_length:]

    output_audio = torch.zeros_like(audio_tensor, device=device)
    # Sum of the per-sample weights contributed by all frames; broadcasts over channels.
    weight_sum = torch.zeros(total_samples, device=device)

    with torch.no_grad():  # pure inference: do not keep an autograd graph alive for every frame
        for i in range(num_frames):
            start = i * hop_length
            end = min(start + frame_length, total_samples)
            n_valid = end - start

            frame = audio_tensor[:, start:end].to(device)
            if n_valid < frame_length:
                # Zero-pad tail frames up to the fixed model input length.
                frame = torch.nn.functional.pad(frame, (0, frame_length - n_valid))

            # Peak-normalise for the model, then restore the original level.
            # NOTE(review): behaviour for an all-zero frame depends on hlp.torch_normalize_max_abs — confirm.
            frame, scale = hlp.torch_normalize_max_abs(frame, out=True)
            processed_frame = model.autoencoder(frame.unsqueeze(0), emb1, emb2).squeeze(0) * scale

            # Per-sample weight: fade in unless first frame, fade out unless last frame.
            weight = torch.ones(frame_length, device=device)
            if fade_length > 0:  # guard: weight[-0:] would select the whole frame
                if i > 0:
                    weight[:fade_length] *= fade_in
                if i < num_frames - 1:
                    weight[-fade_length:] *= fade_out

            # Weighted overlap-add of the valid (non-padded) part only.
            output_audio[:, start:end] += processed_frame[:, :n_valid] * weight[:n_valid]
            weight_sum[start:end] += weight[:n_valid]

    # Normalise by the accumulated weights; the clamp only guards against division by zero.
    output_audio /= weight_sum.clamp(min=1e-8)
    return output_audio


In [17]:
################## LOAD TRAINING RESULTS AND CONFIG  ##################

# Root folder holding one sub-folder per experiment tag.
datapath = "/home/ubuntu/Data/RESULTS-reverb-match-cond-u-net/"

# Previously evaluated runs, kept for reference:
# exp_tag="runs-exp-28-03-2024"
# train_tag="02-04-2024--14-14_many-to-many_stft_1"
# train_tag="20-05-2024--22-48_c_wunet_logmel+wave_0.8_0.2"   (with exp_tag="runs-exp-20-05-2024")

# Run under evaluation: experiment tag + training tag select the results folder.
exp_tag, train_tag = "runs-exp-20-05-2024", "10-06-2024--15-02_c_wunet_stft+wave_0.8_0.2"

config, train_results = trainer.load_train_results(datapath, exp_tag, train_tag, configtype="yaml")

# Re-target the loaded training config at the test split, so the dataset instance below
# gives easy access to the test RIRs. p_noise=0 presumably disables added noise — verify in dataset.py.
test_overrides = {
    "split": "test",
    "df_metadata": "/home/ubuntu/joanna/reverb-match-cond-u-net/dataset-metadata/nonoise_48khz_guestxr.csv",
    "p_noise": 0,
}
for key, value in test_overrides.items():
    config[key] = value

dataset = ds.DatasetReverbTransfer(config)


  train_results=torch.load(pjoin(datapath,exp_tag,train_tag,"checkpointbest.pt"),map_location=config["device"])


In [18]:
################## LOAD MODELS AND TRAINING WEIGHTS  ##################

model = trainer.load_chosen_model(config, config["modeltype"])
# This notebook only runs inference: switch the module to eval mode (e.g. dropout off,
# batch-norm uses running statistics). Harmless if load_chosen_model already did this.
model.eval()
model.load_state_dict(train_results["model_state_dict"])

# # for older results (28-03-2024)
# model.autoencoder.load_state_dict(train_results["model_waveunet_state_dict"])
# model.conditioning_network.load_state_dict(train_results["model_reverbenc_state_dict"])

<All keys matched successfully>

In [23]:
################## LOAD FILES AND CONVOLVE WITH RIRS ##################

TARGET_FS = 48000  # sampling rate (Hz) that speech and RIRs are brought to before convolution/playback

# load audios
# fs, s1=wavfile.read("/home/ubuntu/joanna/demo-wunet/v1p1.wav")
# fs, s2=wavfile.read("/home/ubuntu/joanna/demo-wunet/v2p2.wav")

# Keep one rate per file. With a single shared `fs`, the first signal would be resampled
# with the second file's rate whenever the two files differ.
fs_s1, s1 = wavfile.read("/home/ubuntu/joanna/demo-wunet/0/0_1_d0.wav")
fs_s2, s2 = wavfile.read("/home/ubuntu/joanna/demo-wunet/0/1_0_d0.wav")
# NOTE(review): the convolutions below assume mono (1-D) files — confirm.

# s1, fs=sf.read("/home/ubuntu/Data/VCTK/wav48_silence_trimmed/p232/p232_003_mic2.flac")
# s2, fs=sf.read("/home/ubuntu/Data/VCTK/wav48_silence_trimmed/p234/p234_004_mic2.flac")

# get float values (integer PCM keeps its raw amplitude range; the later steps peak-normalise)
s1 = s1.astype('float32')
s2 = s2.astype('float32')

# resample speech to the working rate
s1 = sig.resample_poly(s1, TARGET_FS, fs_s1)
s2 = sig.resample_poly(s2, TARGET_FS, fs_s2)

# pick impulse responses from the data set (index 2 of the pairs matching the RT60 criterion)
idxs = dataset.get_idx_with_rt60diff(0.8, 0.9)
assert len(idxs) > 2, "fewer than 3 RIR pairs satisfy the requested RT60 difference"
df_r1_info = dataset.get_info(idxs[2], id="style")
df_r2_info = dataset.get_info(idxs[2], id="content")

# load impulse responses, each with its own rate
fs_r1, r1 = wavfile.read(df_r1_info["ir_file_path"])
fs_r2, r2 = wavfile.read(df_r2_info["ir_file_path"])
# get float values
r1 = r1.astype('float32')
r2 = r2.astype('float32')
# The RIRs must share the speech rate before convolving; resample any file stored at another rate.
if fs_r1 != TARGET_FS:
    r1 = sig.resample_poly(r1, TARGET_FS, fs_r1)
if fs_r2 != TARGET_FS:
    r2 = sig.resample_poly(r2, TARGET_FS, fs_r2)

# everything is now at TARGET_FS; keep `fs` defined for any later code that reads it
fs = TARGET_FS

# create reverberant versions of the speech
s1r1_np = sig.fftconvolve(s1, r1, 'full', 0)
s2r2_np = sig.fftconvolve(s2, r2, 'full', 0)

# playback (the second label previously read "s1r2" although the clip is speech 2 in room 2)
display("s1r1")
display(Audio(s1r1_np, rate=TARGET_FS))
display("s2r2")
display(Audio(s2r2_np, rate=TARGET_FS))


's1r1'

's1r2'

In [24]:
################## GET 2 SEC SAMPLE TO EXTRACT STYLE  ##################

# move to float32 tensors of size (1 x L); float32 avoids feeding float64 data to the model
s1r1 = torch.tensor(s1r1_np, dtype=torch.float32).unsqueeze(0)
s2r2 = torch.tensor(s2r2_np, dtype=torch.float32).unsqueeze(0)

# get a non-silent excerpt (presumably dataset.sig_len samples long) of each reverberant signal
s1_ref = hlp.get_nonsilent_frame(s1r1, dataset.sig_len)
s2_ref = hlp.get_nonsilent_frame(s2r2, dataset.sig_len)

# peak-normalise the excerpts used as style references
s1r1_ref = hlp.torch_normalize_max_abs(s1_ref)
s2r2_ref = hlp.torch_normalize_max_abs(s2_ref)

# playback: both players play the normalised references
# (the second one previously played the un-normalised s2_ref under the "s2r2_ref" label)
display("s1r1_ref")
display(Audio(s1r1_ref, rate=48e3))

display("s2r2_ref")
display(Audio(s2r2_ref, rate=48e3))



's1r1_ref'

's2r2_ref'

In [25]:
# get reverb encodings from both rooms, then move each speech signal into the other room
with torch.no_grad():  # inference only: do not build autograd graphs
    emb_r1 = model.conditioning_network(s1r1_ref.unsqueeze(0).to(config["device"]))
    emb_r2 = model.conditioning_network(s2r2_ref.unsqueeze(0).to(config["device"]))

    # The short normalised references only provide the room embeddings. The content to transform
    # is the FULL reverberant signal, which infer_long_audio processes frame by frame.
    # Previously the short references themselves were passed here.
    pred_s1r2 = infer_long_audio(s1r1.float(), emb_r1, emb_r2, model, config)
    pred_s2r1 = infer_long_audio(s2r2.float(), emb_r2, emb_r1, model, config)

display("pred_s1r2")
display(Audio(pred_s1r2.squeeze(0).detach().cpu(), rate=48e3))
display("pred_s2r1")
display(Audio(pred_s2r1.squeeze(0).detach().cpu(), rate=48e3))


'pred_s1r2'

'pred_s2r1'