In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse
import soundfile as sf
from IPython.display import Audio
import time

In [None]:
# Local imports from your project structure
from model.gru_audio_model import RNN, GRUAudioConfig
from audioDataLoader.mulaw import mu_law_encode, mu_law_decode

from utils.utils import multi_linspace, steps, plot_condition_tensor

In [None]:

# --- Generation settings: edit these to change what is synthesised ---
run_directory = "./output/20250805_162729"  # directory of the saved run
top_n = 5              # sample from the top N most likely outputs
temperature = 1.0      # controls the randomness of predictions
length_seconds = 2.0   # length of the audio to generate, in seconds

# Number of samples to generate = samples per second * duration.
sample_rate = 16000
generation_length = int(sample_rate * length_seconds)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')



In [None]:
# -------     Load model     -----------#
# Rebuilds the network from the run's saved config and restores the weights
# from the run's last checkpoint.

config_path = os.path.join(run_directory, "config.pt")
checkpoint_path = os.path.join(run_directory, "checkpoints", "last_checkpoint.pt")

# Fail early with a readable message if the run layout is not what we expect.
assert os.path.exists(run_directory), f"Run directory not found: {run_directory}"
assert os.path.exists(config_path), f"Config file not found: {config_path}"
assert os.path.exists(checkpoint_path), f"Checkpoint file not found: {checkpoint_path}"

# NOTE(security): weights_only=False unpickles arbitrary Python objects, which
# can execute code. Only point run_directory at runs you produced or trust.
# (Presumably required because the file stores a pickled config object -- confirm.)
# map_location matches the checkpoint load below so that any tensors stored in
# the config can also be opened on a machine without the original device.
saved_configs = torch.load(config_path, map_location=device, weights_only=False)
model_config = saved_configs["model_config"]

model = RNN(model_config).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode for layers that behave differently in training

print("Model successfully loaded from checkpoint.")


In [None]:
def run_inference(model, cond_seq, warmup_sequence, top_n=3, temperature=1.0):
    """
    Autoregressively generate one audio sample per row of a conditioning sequence.

    The model's hidden state is first primed with ``warmup_sequence`` (paired with
    the first conditioning vector), then each generated sample is fed back as the
    next input together with the conditioning vector for that step.

    Args:
        model: The trained RNN model. Must provide ``init_hidden(batch_size=...)``
            and be callable as ``model(input, hidden, batch_size=...)`` returning
            ``(logits, hidden)``.
        cond_seq (torch.Tensor): The sequence of conditioning parameters.
            Shape: (seq_len, num_cond_params). One sample is generated per row.
        warmup_sequence (torch.Tensor): A raw audio sequence to warm up the model's
            hidden state. Shape: (warmup_len,), with warmup_len >= 1.
        top_n (int): The number of top predictions to sample from. Values larger
            than the number of output classes are clamped to that number.
        temperature (float): Controls the randomness of predictions. Higher is
            more random. Must be > 0.

    Returns:
        np.array: The generated audio waveform (length ``len(cond_seq)``).

    Raises:
        ValueError: If ``temperature <= 0``, ``top_n < 1`` or ``warmup_sequence``
            is empty.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n}")
    if len(warmup_sequence) == 0:
        raise ValueError("warmup_sequence must contain at least one sample")
    if len(cond_seq) == 0:
        # Nothing to generate; avoid indexing cond_seq[0] below.
        return np.array([])

    device = next(model.parameters()).device
    print("Starting inference...")

    # Move the conditioning data once instead of one row per generated sample.
    cond_seq = cond_seq.to(device)

    generated_samples = []
    # Neither phase needs gradients; keeping warm-up inside no_grad avoids
    # building an autograd graph through the hidden state.
    with torch.no_grad():
        # --- 1. Warm-up Phase ---
        print("Warming up model hidden state...")
        warmup_encoded = mu_law_encode(warmup_sequence, quantization_channels=256)
        # Scale the 0..255 mu-law indices to the 0..1 range used as model input.
        warmup_input_audio = (warmup_encoded.float() / 255.0).to(device)

        first_cond_vec = cond_seq[0].unsqueeze(0).repeat(len(warmup_input_audio), 1)
        warmup_full_input = torch.cat([warmup_input_audio.unsqueeze(-1), first_cond_vec], dim=-1)

        hidden = model.init_hidden(batch_size=1)
        # Stop one short of the end: the final warm-up sample is the first input
        # of the generation loop, so feeding it here as well would show it to
        # the model twice.
        for i in range(len(warmup_full_input) - 1):
            _, hidden = model(warmup_full_input[i].unsqueeze(0), hidden, batch_size=1)

        next_input_audio = warmup_input_audio[-1].unsqueeze(0)

        # --- 2. Generation Phase ---
        print(f"Generating {len(cond_seq)} audio samples...")
        for i in range(len(cond_seq)):
            current_cond_vec = cond_seq[i].unsqueeze(0)
            next_input_full = torch.cat([next_input_audio.unsqueeze(-1), current_cond_vec], dim=-1)

            logits, hidden = model(next_input_full, hidden, batch_size=1)

            # Temperature-scale, keep the k best classes, and sample among them.
            logits = logits.div(temperature).squeeze()
            k = min(top_n, logits.shape[-1])  # topk raises if k exceeds the class count
            top_n_logits, top_n_indices = torch.topk(logits, k)
            top_n_probs = F.softmax(top_n_logits, dim=-1)
            sampled_relative_idx = torch.multinomial(top_n_probs, 1).squeeze()
            sampled_mu_law_index = top_n_indices[sampled_relative_idx]

            new_audio_sample = mu_law_decode(sampled_mu_law_index, quantization_channels=256)
            generated_samples.append(new_audio_sample.item())

            # Feed the sample back in the same encoded/scaled form as the warm-up input.
            next_input_audio = (mu_law_encode(new_audio_sample.unsqueeze(0), 256).float() / 255.0).to(device)

    print("Inference complete.")
    return np.array(generated_samples)

In [None]:

# Conditioning matrix: one row per generated sample, one column per
# conditioning parameter of the model.
num_cond_params = model_config.cond_size
cond_seq = torch.zeros((generation_length, num_cond_params))

# Column 1: curve built by multi_linspace from (position, value) pairs
# (presumably piecewise-linear between them -- see utils.multi_linspace).
envelope_points = [(0, .3), (.5, 1), (1, .3)]
envelope_column = torch.FloatTensor(multi_linspace(envelope_points, generation_length))

# Column 2: eight levels in [0, 1], expanded to generation_length by steps().
# A plain ramp, torch.linspace(0, 1, generation_length), is a simple alternative.
step_levels = np.array([0, 2, 4, 5, 7, 9, 11, 12]) / 12.
stepped_column = torch.FloatTensor(steps(step_levels, generation_length))

cond_seq[:, 0] = 1.0              # column 0 held constant for the whole clip
cond_seq[:, 1] = envelope_column
cond_seq[:, 2] = stepped_column


In [None]:
# Visualise the conditioning curves; pass the configured sample rate rather
# than a literal so the plot stays in sync with the generation settings.
plot_condition_tensor(cond_seq, sample_rate)

In [None]:
# Short sine burst used to prime the model's hidden state before generation.
warmup_len = 32
# Time stamps (in seconds) of consecutive samples at `sample_rate`. Spreading
# the 32 points over a full second (linspace(0, 1, 32)) would sample the 220 Hz
# sine at ~31 Hz and alias it to a different frequency.
t = torch.arange(warmup_len) / sample_rate
warmup_sequence = torch.sin(2 * np.pi * 220.0 * t)

# Time the whole generation call (monotonic clock, unaffected by system clock changes).
start_time = time.monotonic()
generated_audio = run_inference(
    model=model,
    cond_seq=cond_seq,
    warmup_sequence=warmup_sequence,
    top_n=top_n,
    temperature=temperature
)
elapsed_time = time.monotonic() - start_time

In [None]:
# Wall-clock duration of the run_inference call, in seconds (difference of two
# time.monotonic() readings); left as a bare expression so the notebook displays it.
elapsed_time

In [None]:
# Draw the generated waveform on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(generated_audio)
ax.set_title("Generated Audio Waveform")
ax.set_xlabel("Sample")
ax.set_ylabel("Amplitude")
ax.grid()
# To keep a copy on disk, call fig.savefig(<path>) before showing the figure.

plt.show()

In [None]:
# Inline audio player for the result; use the configured sample rate so that
# playback speed matches the rate the clip length was computed with.
Audio(generated_audio, rate=sample_rate)