In [2]:
import os
import tensorflow as tf
import soundfile as sf
import numpy as np
from backbones.unet import UNet
from config import Config
from dataset.preprocess_utils import set_loudness, create_rir_conds
from IPython.display import Audio, display
from pydub import AudioSegment

# Select the compute device.
# NOTE: CUDA reads CUDA_VISIBLE_DEVICES when the driver is initialised, which
# `tf.config.list_physical_devices` has to do in order to enumerate GPUs, so
# assigning the variable afterwards is not a reliable way to hide GPUs from
# this process. The env var is kept so subprocesses spawned from this kernel
# inherit the choice; the in-process restriction uses tf.config instead.
physical_devices = tf.config.list_physical_devices("GPU")
print("Physical_devices:", physical_devices)

if len(physical_devices) > 0:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
    try:
        # Expose only the first GPU to TensorFlow. Otherwise TF would also
        # reserve almost all memory on every other GPU, because memory growth
        # is only enabled for this one. Then let its memory grow on demand.
        tf.config.set_visible_devices([physical_devices[0]], "GPU")
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    except RuntimeError as err:
        # Raised once the TF runtime is initialised (e.g. when this cell is
        # re-run in a live kernel); the earlier configuration stays in effect.
        print("GPU configuration unchanged:", err)
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Load config parameters
params = Config()
model_params = params.model
pre_params = params.data
train_params = params.train

# Initialize RI UNet
ri_unet = UNet(model_params)

# Restore trained weights. The checkpoint directory can be overridden with the
# CDIFF_CKPT_DIR environment variable instead of editing this cell.
ckpt = tf.train.Checkpoint(model=ri_unet)
ckpt_path = os.environ.get(
    "CDIFF_CKPT_DIR",
    '/media/datadisk/dimos/drums_dereverb/saved_models/CDiff_RI_gmd_pre_5e-5/checkpoints',
)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=1)
if not ckpt_manager.latest_checkpoint:
    # Fail loudly: continuing would run the whole demo with randomly
    # initialised weights and produce meaningless "dereverberated" audio.
    raise FileNotFoundError(
        f"No checkpoint found in {ckpt_path!r}. Check path for model weights."
    )
# expect_partial() silences warnings about checkpointed values that this
# inference-only model does not restore (presumably training-only state).
ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
print('Model weights restored!')


# Example track to dereverberate (override with the CDIFF_EXAMPLE_WAV env var).
file_path = os.environ.get(
    "CDIFF_EXAMPLE_WAV",
    '/media/datadisk/dimos/drums_dereverb/data/clean_drums/ANiMAL-Rockshow.wav',
)

# Load the audio file as mono float32 and check it matches the model's sample rate.
try:
    audio_ex, sr = sf.read(file_path)
    if audio_ex.ndim > 1:
        # Multi-channel data comes back as (frames, channels): average the
        # channels to get a mono signal.
        audio_ex = np.mean(audio_ex, axis=1)
    audio_ex = audio_ex.astype(np.float32)

    if sr != pre_params.sr:
        raise ValueError(f"Sample rate mismatch. Expected {pre_params.sr}, got {sr}.")

    print('Audio file loaded succesfully')
except Exception as e:
    print(f"Error reading file {file_path}: {e}")
    # Re-raise: every later cell needs `audio_ex`. Swallowing the error would
    # only resurface as a NameError further down or, worse, silently reuse a
    # stale `audio_ex` left over from a previous run of this kernel.
    raise

Physical_devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Model weights restored!
Audio file loaded succesfully


In [3]:
# Keep only the first 30 seconds of the track to bound the cost of running the
# reverse diffusion on every segment below.
EXCERPT_SECONDS = 30
audio_ex = audio_ex[:int(EXCERPT_SECONDS * pre_params.sr)]

# Seed NumPy's global RNG so the sampled room (and anything create_rir_conds may
# draw from np.random -- not visible from here) is identical on every
# Restart & Run All. Set RIR_SEED = None to draw a new random room each run.
RIR_SEED = 0
if RIR_SEED is not None:
    np.random.seed(RIR_SEED)

# Sample T60 and room dimensions. room_dim_r holds flattened (low, high) pairs,
# one pair per room dimension: [low0, high0, low1, high1, low2, high2].
t60 = np.random.uniform(pre_params.t60_r[0], pre_params.t60_r[1])
room_dim = np.array([
    np.random.uniform(pre_params.room_dim_r[2 * n], pre_params.room_dim_r[2 * n + 1])
    for n in range(3)
])

# Create reverberant (lossy) and dry versions of the excerpt.
lossy_ex, dry_ex = create_rir_conds(t60, room_dim, pre_params.min_distance_to_wall, pre_params.sr, audio_ex)

# Loudness-normalize both versions to the configured target.
lossy_ex = set_loudness(lossy_ex, pre_params.sr, LUFS=pre_params.lufs).astype(np.float32)
dry_ex = set_loudness(dry_ex, pre_params.sr, LUFS=pre_params.lufs).astype(np.float32)


def _float_to_pcm16(signal):
    """Convert a float signal to int16 PCM.

    Samples are clipped to [-1, 1] first so that out-of-range values saturate
    instead of overflowing (and flipping sign) in the int16 cast.
    """
    return (np.clip(signal, -1.0, 1.0) * 32767).astype(np.int16)


def _pcm16_to_mp3(pcm, path):
    """Wrap mono int16 PCM in a pydub AudioSegment, export it to `path` as
    128 kbps MP3 and return the segment."""
    segment = AudioSegment(
        pcm.tobytes(),
        frame_rate=pre_params.sr,
        sample_width=pcm.dtype.itemsize,
        channels=1,
    )
    segment.export(path, format="mp3", bitrate="128k")
    return segment


dry_ex_ = _float_to_pcm16(dry_ex)
lossy_ex_ = _float_to_pcm16(lossy_ex)

# Directory to save audio files
output_dir = "selected_examples/notebooks"
os.makedirs(output_dir, exist_ok=True)

dry_file_path = os.path.join(output_dir, "dry_input.mp3")
dry_audio = _pcm16_to_mp3(dry_ex_, dry_file_path)
# Report the configured target rather than a hard-coded number.
print(f'Input Audio (Ground Truth) normalized {pre_params.lufs} LUFS')
display(Audio(dry_file_path, embed=True))

reverb_file_path = os.path.join(output_dir, "reverberant_audio.mp3")
reverb_audio = _pcm16_to_mp3(lossy_ex_, reverb_file_path)
print(f'\nReverberant Audio normalized {pre_params.lufs} LUFS')
display(Audio(reverb_file_path, embed=True))

Input Audio (Ground Truth) normalized -28 LUFS



Reverberant Audio normalized -28 LUFS


In [4]:
def segment_audio(audio, overlap=0.5, segment_samples=None):
    """Slice a 1-D signal into fixed-length, overlapping segments for diffusion.

    Args:
        audio: 1-D array of samples.
        overlap: Fraction of each segment shared with the next one. The hop
            between segment starts is ``int((1 - overlap) * segment_samples)``.
        segment_samples: Segment length in samples. Defaults to
            ``int(pre_params.dur * pre_params.sr)``.

    Returns:
        Array of shape (n_segments, segment_samples). Only whole segments are
        returned, so trailing samples that do not fill a segment are dropped.
        If ``audio`` is shorter than one segment, the result has zero rows.

    Raises:
        ValueError: if the resulting hop is smaller than one sample
            (e.g. ``overlap >= 1``).
    """
    if segment_samples is None:
        # int(): dur may be fractional, and range()/slicing need integers
        # (reconstruct_from_segments already casts the same product to int).
        segment_samples = int(pre_params.dur * pre_params.sr)
    step_size = int((1 - overlap) * segment_samples)
    if step_size < 1:
        raise ValueError(
            f"overlap={overlap} gives a hop of {step_size} samples; it must be >= 1"
        )

    audio = np.asarray(audio)
    starts = range(0, len(audio) - segment_samples + 1, step_size)
    segments = [audio[start:start + segment_samples] for start in starts]
    if not segments:
        # Keep the 2-D (n_segments, segment_samples) layout even when empty.
        return np.empty((0, segment_samples), dtype=audio.dtype)
    return np.array(segments)

def compute_ri_stft(signal):
    """Return the STFT of `signal` as stacked real/imaginary channels.

    This is the RI (real-imaginary) input representation for UNet RI, DCCRN
    and DCUNet. Window length, hop, FFT size and window function all come
    from `pre_params`. A new trailing axis of size 2 holds [real, imag], and
    the result is cast to float32.
    """
    spectrum = tf.signal.stft(
        signal,
        frame_length=pre_params.win,
        frame_step=pre_params.hop,
        fft_length=pre_params.fft,
        window_fn=pre_params.window_fn(),
    )
    real_imag_parts = [tf.math.real(spectrum), tf.math.imag(spectrum)]
    return tf.cast(tf.stack(real_imag_parts, axis=-1), tf.float32)


def reverse_diffusion(inp_ri, step_stop=0):
    """Run the reverse cold-diffusion chain on one RI spectrogram; return every step.

    Starting at t = train_params.diffusions_steps and counting down to
    step_stop + 1, the UNet is called with the current estimate and the time
    step t. Each output becomes the next call's input.

    Args:
        inp_ri: A single, unbatched RI STFT; a batch axis of size 1 is added here.
        step_stop: Exclusive lower bound for t. 0 (the default) runs the full
            reverse diffusion down to t = 1.

    Returns:
        List with one numpy array per executed step, in the order
        t = diffusions_steps, ..., step_stop + 1, each without the batch axis.
    """
    estimate = inp_ri[tf.newaxis, ...]
    base = tf.ones([1], dtype=tf.int32)

    all_diff_steps = []
    for t in range(train_params.diffusions_steps, step_stop, -1):
        estimate = ri_unet([estimate, base * t])
        # Drop only the batch axis added above. A bare tf.squeeze would also
        # remove any other size-1 axis (e.g. a single-frame segment) and change
        # the shape that is later stacked for reconstruction.
        all_diff_steps.append(estimate[0].numpy())

    return all_diff_steps


# Slice the reverberant excerpt into overlapping buffers and move each buffer
# into the RI STFT domain.
lossy_buffers = segment_audio(lossy_ex)
noisy_ri_stfts = compute_ri_stft(lossy_buffers)

# Run the full reverse diffusion on every buffer independently:
# all_diffed_buffers[p][k] is buffer p after the k-th executed diffusion step.
all_diffed_buffers = [reverse_diffusion(noisy_buffer) for noisy_buffer in noisy_ri_stfts]

print('Completed')


2025-01-13 13:40:28.596466: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8907


Completed


In [5]:
def compute_signal_from_RI_stft(ri_stft, frame_length=pre_params.win, frame_step=pre_params.hop,
                                fft_length=None, forward_window_fn=None):
    """Invert an RI (real/imaginary) STFT back to the time domain.

    Args:
        ri_stft: Tensor/array whose last axis holds [real, imag], the layout
            produced by compute_ri_stft. Leading (e.g. batch) axes are kept.
        frame_length: Analysis window length in samples.
        frame_step: Hop size in samples.
        fft_length: FFT size of the forward STFT. Defaults to ``pre_params.fft``,
            the value compute_ri_stft uses.
        forward_window_fn: Window function of the forward STFT. Defaults to
            ``pre_params.window_fn()``, the one compute_ri_stft uses.

    Returns:
        Real-valued tensor with the reconstructed signal(s).
    """
    # Match the forward transform exactly: inverting with fft_length=frame_length
    # (as before) is only correct when pre_params.fft == pre_params.win, and a
    # hard-coded Hann window only when pre_params.window_fn() is Hann.
    if fft_length is None:
        fft_length = pre_params.fft
    if forward_window_fn is None:
        forward_window_fn = pre_params.window_fn()

    ri_stft = tf.convert_to_tensor(ri_stft)
    # Rectangular (real + j*imag) complex spectrum.
    complex_spec = tf.complex(ri_stft[..., 0], ri_stft[..., 1])

    return tf.signal.inverse_stft(
        complex_spec,
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
        window_fn=tf.signal.inverse_stft_window_fn(
            frame_step, forward_window_fn=forward_window_fn),
    )

def apply_window(segments, window_fn=tf.signal.hann_window):
    """Taper every time-domain segment (one per row) with a window after the ISTFT.

    A single window of length ``segments.shape[1]`` is built with `window_fn`
    and broadcast-multiplied into each row, so neighbouring segments
    cross-fade when they are overlap-added.
    """
    segment_length = segments.shape[1]
    taper = window_fn(segment_length)
    return segments * taper[np.newaxis, :]


def overlap_add(windowed_segments, step_size):
    """Reconstruct a signal by overlap-adding equally long segments.

    Segment ``i`` is summed into the output starting at sample ``i * step_size``.

    Args:
        windowed_segments: Sequence (or 2-D array/tensor) of equal-length 1-D segments.
        step_size: Hop between consecutive segment starts, in samples (>= 1).

    Returns:
        float64 numpy array of length ``step_size * (n - 1) + segment_length``,
        or an empty array when there are no segments.

    Raises:
        ValueError: if ``step_size`` is smaller than 1.
    """
    if step_size < 1:
        raise ValueError(f"step_size must be >= 1, got {step_size}")

    n_segments = len(windowed_segments)
    if n_segments == 0:
        # Nothing to add; avoids an IndexError on windowed_segments[0].
        return np.zeros(0)

    output_length = step_size * (n_segments - 1) + len(windowed_segments[0])
    reconstructed_signal = np.zeros(output_length)
    for i, segment in enumerate(windowed_segments):
        start_index = i * step_size
        reconstructed_signal[start_index:start_index + len(segment)] += segment
    return reconstructed_signal


def reconstruct_from_segments(stft_segments, overlap=0.5):
    """Turn a stack of RI STFT segments back into one continuous signal.

    This inverts the segment_audio -> compute_ri_stft pipeline. Every segment
    is inverted with compute_signal_from_RI_stft, tapered with apply_window,
    and overlap-added with the same hop segment_audio uses for `overlap`.

    Args:
        stft_segments: RI STFTs stacked along the first axis, one per segment.
        overlap: The overlap fraction that was passed to segment_audio.

    Returns:
        1-D numpy array with the reconstructed audio.
    """
    segment_samples = int(pre_params.dur * pre_params.sr)
    # Must equal segment_audio's hop, int((1 - overlap) * segment_samples).
    # int(segment_samples * overlap) only coincides with it at overlap=0.5.
    step_size = int((1 - overlap) * segment_samples)

    # inverse_stft accepts leading batch axes, so invert all segments at once.
    istft_segments = compute_signal_from_RI_stft(tf.convert_to_tensor(stft_segments)).numpy()

    # The ISTFT only rebuilds whole frames, so segments can come back shorter
    # than segment_samples. Zero-pad them so the window length and the hop
    # line up with how the signal was segmented.
    missing = segment_samples - istft_segments.shape[-1]
    if missing > 0:
        istft_segments = np.pad(istft_segments, ((0, 0), (0, missing)))

    windowed_segments = apply_window(istft_segments)

    # Overlap-add to reconstruct the full signal
    return overlap_add(windowed_segments, step_size)


# Reconstruct, save and play the output of every reverse-diffusion step.
# all_diffed_buffers[p][t] is buffer p after the t-th executed step, so for each
# step gather it from every buffer and overlap-add the buffers into one waveform.
# The step count is taken from the results themselves, so this also works when
# reverse_diffusion was run with a non-zero step_stop.
n_steps = len(all_diffed_buffers[0]) if all_diffed_buffers else 0
for t in range(n_steps):
    diff_step_ri = tf.stack([buffer_steps[t] for buffer_steps in all_diffed_buffers])
    diff_wav = reconstruct_from_segments(diff_step_ri)
    # Clip before the int16 conversion: the model output is not guaranteed to
    # stay within [-1, 1], and out-of-range floats overflow in the int16 cast
    # (typically wrapping to the opposite sign, i.e. loud clicks).
    diff_wav = (np.clip(diff_wav, -1.0, 1.0) * 32767).astype(np.int16)
    diff_audio = AudioSegment(
        diff_wav.tobytes(),
        frame_rate=pre_params.sr,
        sample_width=diff_wav.dtype.itemsize,
        channels=1,
    )
    diff_file_path = os.path.join(output_dir, "diffused_" + str(t) + '.mp3')
    diff_audio.export(diff_file_path, format="mp3", bitrate="128k")
    print('Audio Diffused step', t)
    display(Audio(diff_file_path, embed=True))

print('Input Audio (Ground Truth)')
display(Audio(dry_file_path, embed=True))

            

Audio Diffused step 0


Audio Diffused step 1


Audio Diffused step 2


Audio Diffused step 3


Audio Diffused step 4


Audio Diffused step 5


Audio Diffused step 6


Audio Diffused step 7


Audio Diffused step 8


Audio Diffused step 9


Input Audio (Ground Truth)
