<a href="https://colab.research.google.com/github/jasper-zheng/streamable-stable-audio-open/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Streaming Stable Audio Open 1.0's Autoencoder  

This notebook streams the autoencoder of the pre-trained [Stable Audio Open 1.0](https://huggingface.co/stabilityai/stable-audio-open-1.0) model using cached convolution, for real-time continuous inference. It also exports the autoencoder to TorchScript so that it can be used with [nn~](https://github.com/acids-ircam/nn_tilde) in MaxMSP/PureData.  

Author: Jasper Shuoyang Zheng
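
Cached convolution, mentioned in the introduction, can be illustrated without any dependencies. The sketch below is not the `cached_conv` library itself. It is a toy 1-D causal filter in which `causal_conv`, `signal`, and `kernel` are hypothetical names chosen for illustration. It shows the general idea, assuming the usual formulation: each layer keeps the last `kernel_size - 1` input samples and uses them as left padding for the next buffer, so chunk-by-chunk output matches a single pass over the whole signal.

```python
# Illustrative only: a dependency-free 1-D causal filter,
# not the cached_conv library used in this notebook.

def causal_conv(signal, kernel, cache=None):
    """Filter `signal` with `kernel`, left-padding with `cache`.

    `cache` holds the last len(kernel) - 1 input samples seen so far
    (zeros before the first chunk). Returns (output, new_cache).
    """
    pad = len(kernel) - 1
    if cache is None:
        cache = [0.0] * pad
    padded = list(cache) + list(signal)
    output = [
        sum(kernel[k] * padded[n + k] for k in range(len(kernel)))
        for n in range(len(signal))
    ]
    # Keep the tail of the padded input as context for the next chunk.
    new_cache = padded[len(padded) - pad:]
    return output, new_cache


signal = [float(i % 7) for i in range(32)]
kernel = [0.5, -1.0, 0.25, 2.0]

# Reference: process the whole signal in one pass.
full_output, _ = causal_conv(signal, kernel)

# Streaming: process 8 samples at a time, carrying the cache forward.
streamed_output, cache = [], None
chunk_size = 8
for start in range(0, len(signal), chunk_size):
    out, cache = causal_conv(signal[start:start + chunk_size], kernel, cache)
    streamed_output.extend(out)

print(streamed_output == full_output)
```

Dropping the cache (passing `cache=None` for every chunk) pads each buffer with zeros instead, which changes the first `kernel_size - 1` outputs of every chunk and produces audible discontinuities at buffer boundaries.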

## Installation (Only do this once)

In [None]:
!git clone https://github.com/jasper-zheng/streamable-stable-audio-open.git
%cd streamable-stable-audio-open
!pip install -r requirements.txt

## Import libraries

In [None]:
import sys

# Make the repository's modules (models, export) importable
base_dir = 'streamable-stable-audio-open'
sys.path.append(base_dir)

import torch
from models import get_pretrained_pretransform
from export import remove_parametrizations

import time

import librosa
from IPython.display import Audio, display

import cached_conv as cc

cc.use_cached_conv(True)

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

## Download and load pre-trained model

Before downloading the model, you need to:  
1. run `hf auth login` in a terminal to log in to your HuggingFace account,
2. go to [stable-audio-open-1.0](https://huggingface.co/stabilityai/stable-audio-open-1.0) and agree to Stability AI's License Agreement to get access to the model weights.

In [None]:
## Load the autoencoder from stable-audio-open-1.0

autoencoder, model_config = get_pretrained_pretransform("stabilityai/stable-audio-open-1.0",
                                                         model_half=True,
                                                         skip_bottleneck=True,
                                                         device=device)

print(f"sample_rate: {model_config.get('sample_rate', 'unknown')}")
print(f"latent_dim: {model_config['model']['pretransform']['config'].get('latent_dim', 'unknown')}")
print(f"downsampling_ratio: {model_config['model']['pretransform']['config'].get('downsampling_ratio', 'unknown')}")
print(f"io_channels: {model_config['model']['pretransform']['config'].get('io_channels', 'unknown')}")

autoencoder = autoencoder.to(device)
autoencoder.eval()

remove_parametrizations(autoencoder)

## Prepare audio chunks

In [None]:
# Load an example audio file and split into chunks

buffer_size = 4096

audio_path = librosa.example('fishin', hq=True)
wv, sr = librosa.load(audio_path, sr=44100, mono=False)
wv = torch.tensor(wv, device=device)[:,buffer_size*50:buffer_size*150].unsqueeze(0)  # limit length for test, then add a batch dimension
wv_chunks = [wv[:, :, i*buffer_size:(i+1)*buffer_size] for i in range(100)]

print(f'waveform shape: {wv.shape}')
print(f'number of chunks: {len(wv_chunks)}')
print(f'chunk shape: {wv_chunks[0].shape}')
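
The cell above hard-codes 100 chunks, which silently yields empty chunks if the audio is shorter than expected. As a minimal sketch, chunk boundaries can instead be derived from the signal length. `chunk_bounds` is a hypothetical helper, not part of this repository, and dropping the trailing remainder is an assumption of the sketch (zero-padding it is the other common choice).

```python
def chunk_bounds(num_samples, buffer_size):
    """Return (start, end) index pairs for every full buffer in the signal.

    A trailing remainder shorter than `buffer_size` is dropped.
    """
    num_chunks = num_samples // buffer_size
    return [(i * buffer_size, (i + 1) * buffer_size) for i in range(num_chunks)]


bounds = chunk_bounds(num_samples=4096 * 100, buffer_size=4096)
print(len(bounds), bounds[0], bounds[-1])  # 100 (0, 4096) (405504, 409600)
```

With the variables from this notebook, the same boundaries could be applied as `[wv[:, :, s:e] for s, e in chunk_bounds(wv.shape[-1], buffer_size)]`.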

## Forward pass the encoder and decoder

In [None]:
def sync_device():
    """Wait for queued GPU work to finish so timings are accurate (no-op on CPU)."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()


print(f'Running encoder, device: {device}')
# Run the audio chunks through the encoder

latent_chunks = []
with torch.no_grad():
    sync_device()
    start_time = time.perf_counter()
    for w in wv_chunks:
        latent = autoencoder.encode(w)
        latent_chunks.append(latent)
    sync_device()
    end_time = time.perf_counter()
    print(f'Encoder execution time: {end_time - start_time:.2f} seconds')


print(f'Running decoder, device: {device}')
# Run the latent chunks through the decoder
wv_recons = []
with torch.no_grad():
    sync_device()
    start_time = time.perf_counter()
    for latent in latent_chunks:
        wv_recon = autoencoder.decode(latent)
        wv_recons.append(wv_recon)
    sync_device()
    end_time = time.perf_counter()
    print(f'Decoder execution time: {end_time - start_time:.2f} seconds')

# Concatenate the decoded chunks along the time axis
wv_recon = torch.cat(wv_recons, dim=-1)
print(f'reconstructed waveform shape: {wv_recon.shape}')
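
To judge whether the timings printed above are fast enough for real-time use, compare them with the duration of the audio that was processed. The sketch below uses two hypothetical helpers, `chunk_budget_ms` and `real_time_factor`, and the 2.0 second elapsed time is a made-up placeholder, not a measured result. A real-time factor below 1 means processing is faster than playback.

```python
def chunk_budget_ms(buffer_size, sample_rate):
    """Time available to process one buffer before the next one arrives."""
    return 1000.0 * buffer_size / sample_rate


def real_time_factor(elapsed_seconds, num_chunks, buffer_size, sample_rate):
    """Processing time divided by the duration of the processed audio."""
    audio_seconds = num_chunks * buffer_size / sample_rate
    return elapsed_seconds / audio_seconds


budget = chunk_budget_ms(buffer_size=4096, sample_rate=44100)
# 2.0 s is a placeholder; substitute the measured end_time - start_time.
rtf = real_time_factor(2.0, num_chunks=100, buffer_size=4096, sample_rate=44100)

print(f'per-chunk budget: {budget:.1f} ms')
print(f'real-time factor: {rtf:.3f}')
```

With a buffer of 4096 samples at 44.1 kHz, each encode or decode call has roughly 93 ms before the next buffer is due. The encoder and decoder run back to back in a live setup, so their combined time per chunk is what has to fit in that budget.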

In [None]:
print("Original:")
display(Audio(wv.cpu().numpy()[0], rate=sr))
print("Reconstructed:")
display(Audio(wv_recon.cpu().numpy()[0], rate=sr))

## Export to TorchScript  

If you have [nn~](https://github.com/acids-ircam/nn_tilde) in MaxMSP/PureData, export the streaming autoencoder to a TorchScript file that nn~ can load:

In [None]:
!python streamable-stable-audio-open/export.py --output exported/stable-vae.ts --streaming