In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse
import soundfile as sf
from IPython.display import Audio
import time
from pathlib import Path 

In [None]:
# Local imports from your project structure
from model.gru_audio_model import RNN, GRUAudioConfig
from utils.utils import multi_linspace, steps, plot_condition_tensor

from inference import run_inference

In [None]:
# Point run_directory at the folder that Train.ipynb created for its run.

# Earlier known-good run, kept for reference: './output/20250811_135640'
run_directory = str(Path('./output/20250821_135117_pistons_1024.100_4.48'))  # path to the directory of the saved run

top_n = 10            # sample from the top N most likely outputs
temperature = 1       # controls the randomness of predictions
length_seconds = 4.0  # length of the audio to generate, in seconds

sample_rate = 16000
generation_length = int(length_seconds * sample_rate)  # number of samples to generate

# Device selection. Inference is deliberately pinned to the CPU; flip
# force_cpu to False to use the first CUDA device when one is available.
# (Previously the CUDA choice was computed and then unconditionally
# overwritten by the string 'cpu', which hid the override from the reader.)
force_cpu = True
use_cuda = (not force_cpu) and torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

In [None]:
# -------     Load model     -----------#

# Locate the two artefacts this notebook needs inside the run directory:
# the saved configuration and the checkpoint to restore.
config_path = os.path.join(run_directory, "config.pt")
checkpoint_dir = os.path.join(run_directory, "checkpoints")
checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pt")  # alternative: "checkpoint_40.pt"

# Stop early with a readable message if any expected path is missing.
assert os.path.exists(run_directory), f"Run directory not found: {run_directory}"
assert os.path.exists(config_path), f"Config file not found: {config_path}"
assert os.path.exists(checkpoint_path), f"Checkpoint file not found: {checkpoint_path}"

# NOTE(review): weights_only=False performs a full unpickle of config.pt,
# so only load run directories that come from a trusted source.
saved_configs = torch.load(config_path, weights_only=False)
model_config = saved_configs["model_config"]

# Build the network from the saved configuration and place it on the
# target device before restoring the trained weights.
model = RNN(model_config)
model = model.to(device)

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch the module to evaluation mode for inference

print("Model successfully loaded from checkpoint.")
print(f"Using device = {device}")

In [None]:

# Conditioning tensor: one row per sample to generate, one column per
# conditioning parameter declared in the model config. Starts at all zeros.
num_cond_params = model_config.cond_size
cond_shape = (generation_length, num_cond_params)
cond_seq = torch.zeros(cond_shape)

# instID columns -- examples kept for runs whose conditioning includes them
# cond_seq[:, 0] = torch.FloatTensor(steps(np.array([0,1,0,0,0,0,0,0,0]), generation_length))
# cond_seq[:, 1] = torch.FloatTensor(steps(np.array([0,0,1,0,0,0,0,0,0]), generation_length))
# cond_seq[:, 2] = torch.FloatTensor(steps(np.array([0,0,0,1,0,0,0,0,0]), generation_length))
# cond_seq[:, 3] = torch.FloatTensor(steps(np.array([0,0,0,0,1,0,0,0,0]), generation_length))
# cond_seq[:, 4] = torch.FloatTensor(steps(np.array([0,0,0,0,0,1,0,0,0]), generation_length))
# cond_seq[:, 5] = torch.FloatTensor(steps(np.array([0,0,0,0,0,0,1,0,0]), generation_length))
# cond_seq[:, 6] = torch.FloatTensor(steps(np.array([0,0,0,0,0,0,0,1,0]), generation_length))

# param_1: held at a constant value for the whole clip
cond_seq[:, 0].fill_(0.9)



In [None]:
# Parameter layout for the nsynth.64.76_sm model (for reference when
# switching run_directory to such a run):
#   Param1 - instID
#   Param2 - a (amplitude)
#   Param3 - p (pitch in [0,1], representing midi [64, 76])

# Use the configured sample_rate rather than a hard-coded 16000 so the plot
# stays consistent if the configuration changes.
# NOTE(review): assumes the second argument is the sample rate -- confirm
# against utils.plot_condition_tensor.
plot_condition_tensor(cond_seq, sample_rate)

In [None]:
# Warm-up signal: a quiet sine tone plus a little uniform noise.
warmup_len = 256        # length of the warm-up signal, in samples
warmup_freq_hz = 220.0  # frequency of the warm-up tone
warmup_amplitude = 0.2
noise_amplitude = 0.1

# Time axis in seconds at the audio sample rate. The previous
# linspace(0, 1, warmup_len) spread the 256 samples over a full second, so the
# "220 Hz" sine was effectively sampled at ~255 Hz and aliased to a different
# pitch when interpreted at sample_rate.
t = torch.arange(warmup_len, dtype=torch.float32) / sample_rate
warmup_sequence = torch.sin(2 * np.pi * warmup_freq_hz * t) * warmup_amplitude
# Uniform noise in [-noise_amplitude, noise_amplitude).
noise = (torch.rand_like(warmup_sequence) - 0.5) * 2 * noise_amplitude
warmup_sequence = warmup_sequence + noise


# Time the generation with a monotonic clock so the measurement is not
# affected by system clock adjustments.
start_time = time.monotonic()
generated_audio = run_inference(
    model=model,
    cond_seq=cond_seq,
    warmup_sequence=warmup_sequence,
    top_n=top_n,
    temperature=temperature
)
elapsed_time = time.monotonic() - start_time
print(f"Time to generate: {elapsed_time:.2f}")

In [None]:
# Draw the generated waveform on a wide figure using explicit
# figure/axes handles.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(generated_audio)
ax.set_title("Generated Audio Waveform")
ax.set_xlabel("Sample")
ax.set_ylabel("Amplitude")
ax.grid()

plt.show()

In [None]:
# Play back at the configured sample_rate (instead of a hard-coded 16000) so
# the playback speed and pitch stay correct if the configuration changes.
Audio(generated_audio, rate=sample_rate)

In [None]:
#---------------------------------------------------  pitch glide  --------------------#

In [None]:
# Fresh conditioning tensor for the glide experiment. The column count is
# read from the model config before it is used, so this cell no longer
# depends on num_cond_params left over from an earlier cell (the tensor was
# previously allocated twice, the first time before that assignment).
num_cond_params = model_config.cond_size
cond_seq = torch.zeros(generation_length, num_cond_params)

# instID
# cond_seq[:, 0] = 0
# cond_seq[:, 1] = 0
# cond_seq[:, 2] = 0   #bugs
# cond_seq[:, 3] = 0
# cond_seq[:, 4] = 0   #pistons
# cond_seq[:, 5] = 0
# cond_seq[:, 6] = 1

# param_1: time-varying trajectory over the whole clip.
# NOTE(review): the pairs are presumably (position, value) breakpoints --
# confirm against utils.multi_linspace.
cond_seq[:, 0] = torch.FloatTensor(multi_linspace([(0, 1), (.5, .5), (1, 1)], generation_length))

# Use the configured sample_rate rather than a hard-coded 16000.
plot_condition_tensor(cond_seq, sample_rate)

# Reuses warmup_sequence and the sampling settings from the earlier cells.
generated_audio_glide = run_inference(
    model=model,
    cond_seq=cond_seq,
    warmup_sequence=warmup_sequence,
    top_n=top_n,
    temperature=temperature
)


fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(generated_audio_glide)
ax.set_title("Generated Audio Waveform")
ax.set_xlabel("Sample")
ax.set_ylabel("Amplitude")
ax.grid()

plt.show()
# Last expression of the cell, so the audio player is rendered as its output.
Audio(generated_audio_glide, rate=sample_rate)