In [1]:
# Report the attached NVIDIA GPU together with its driver and CUDA versions.
!nvidia-smi

Sun Apr  2 12:44:48 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
!git clone https://github.com/hmartiro/riffusion-inference
%cd riffusion-inference

fatal: destination path 'riffusion-inference' already exists and is not an empty directory.
/content/riffusion-inference


In [3]:
!pip install -r requirements.txt
!pip install gradio
!pip install Pillow==9.1.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting accelerate
  Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 KB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting argh
  Downloading argh-0.28.1-py3-none-any.whl (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 KB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dacite
  Downloading dacite-1.8.0-py3-none-any.whl (14 kB)
Collecting demucs
  Downloading demucs-4.0.0.tar.gz (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting diffusers>=0.9.0
  Downloading diffusers-0.14.0-py3-none-any.whl (737 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m737.4/737.4 KB[0m [31m62.2 MB/s[0m eta [3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradio
  Downloading gradio-3.24.1-py3-none-any.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
Collecting semantic-version
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)
Collecting python-multipart
  Downloading python_multipart-0.0.6-py3-none-any.whl (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.7/45.7 KB[0m [31m694.5 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy
  Downloading ffmpy-0.3.0.tar.gz (4.8 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting orjson
  Downloading orjson-3.8.9-

In [2]:
import io
import sys
from os import getcwd, sep
from pathlib import Path
from typing import Optional, Tuple, Union

import PIL

# Append "<cwd>/riffusion" to the module search path.
# NOTE(review): the imports below use the `riffusion.` package prefix, which needs
# the directory *containing* `riffusion/` on the path rather than this one —
# confirm this line is still required.
sys.path.append(getcwd() + sep + "riffusion")

from riffusion.datatypes import InferenceInput, InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.server import SEED_IMAGES_DIR



# Module-level handle to the loaded pipeline; it is only bound once startup() runs.
pipeline: RiffusionPipeline


def startup(
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> None:
    """
    Loads a checkpoint and stores the resulting pipeline in the global ``pipeline``.

    Parameters:
        checkpoint (str): identifier of the pretrained model to load.
        no_traced_unet (bool): when True, the traced unet is not requested;
        the flag is inverted before being handed to the loader.
        device (str): "cuda", "cpu", or "mps".

    Returns:
        None
    """
    global pipeline

    # The loader takes the positive form of the flag.
    use_traced_unet = not no_traced_unet

    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=use_traced_unet,
        device=device,
    )

def _load_seed_image(image_id, label: str):
    """
    Opens ``<image_id>.png`` from ``SEED_IMAGES_DIR`` as an RGB image.

    Parameters:
        image_id: name of the image file, without the ``.png`` suffix.
        label (str): word used in the error message ("seed" or "mask").

    Returns:
        The opened image, converted to "RGB".

    Raises:
        FileNotFoundError: if the expected file does not exist.
    """
    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so import it
    # explicitly instead of relying on another module having done so already.
    from PIL import Image

    image_path = Path(SEED_IMAGES_DIR, f"{image_id}.png")
    if not image_path.is_file():
        raise FileNotFoundError(f"Invalid {label} image: {image_id}")
    return Image.open(str(image_path)).convert("RGB")


def compute(inputs: InferenceInput) -> Tuple[memoryview, float]:
    """
    Function from the riffusion server :func:`compute_request`, adapted to be
    called directly: failures are raised instead of being returned as
    ``(message, status_code)`` pairs, which a caller unpacking
    ``(audio, seconds)`` would otherwise mistake for a valid result.

    The global ``pipeline`` must have been initialized with :func:`startup`
    before this is called.

    Parameters:
        inputs (:py:class:`InferenceInput`): The inputs for the request.

    Returns:
        tuple: ``(mp3_data, duration_seconds)`` where ``mp3_data`` is a
        bytes-like view of the MP3-encoded audio and ``duration_seconds`` is
        the duration reported by the generated audio segment.

    Raises:
        FileNotFoundError: if the seed image, or the mask image when one is
        requested, is not present in ``SEED_IMAGES_DIR``.
    """
    # Load the seed image by ID
    init_image = _load_seed_image(inputs.seed_image_id, "seed")

    # Load the mask image by ID (optional: skipped when no mask ID is given)
    mask_image = None
    if inputs.mask_image_id:
        mask_image = _load_seed_image(inputs.mask_image_id, "mask")

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance a bit to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))

    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

    # Export audio to MP3 bytes. getbuffer() exposes the whole buffer regardless
    # of the current stream position, so no rewind is needed before returning it.
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")

    return mp3_bytes.getbuffer(), segment.duration_seconds

In [None]:
# Load the default checkpoint on the GPU; this binds the global `pipeline` that compute() uses.
startup(device="cuda")

In [12]:
import base64
import json
import gradio as gr
from riffusion.datatypes import InferenceInput, PromptInput, InferenceOutput

# Shared seed counter; muse() reads it and advances it across calls.
seed: int = 1


def muse(prompt_start_input, prompt_end_input, alpha_input, num_steps: int = 50):
    """
    Generates audio for a pair of prompts and saves it as ``output.mp3``.

    The start prompt uses the current global ``seed``. When the two prompts
    differ, the global ``seed`` is incremented by one before it is used for the
    end prompt, and the incremented value persists for later calls. When the
    prompts are equal, both use the same seed and it is left unchanged.

    Parameters:
        prompt_start_input: prompt for the start of the request.
        prompt_end_input: prompt for the end of the request.
        alpha_input: value forwarded as ``alpha`` of the inference input.
        num_steps (int): value forwarded as ``num_inference_steps``.

    Returns:
        str: the name of the written file, ``"output.mp3"``.
    """
    global seed

    prompts_differ: bool = prompt_start_input != prompt_end_input

    start_prompt: PromptInput = PromptInput(prompt=prompt_start_input, seed=seed)

    # Only a real transition between two different prompts advances the seed.
    if prompts_differ:
        seed += 1

    end_prompt: PromptInput = PromptInput(prompt=prompt_end_input, seed=seed)

    request: InferenceInput = InferenceInput(
        start=start_prompt,
        end=end_prompt,
        alpha=alpha_input,
        num_inference_steps=num_steps,
    )

    # The reported duration is not needed here; only the encoded audio is kept.
    audio, _duration = compute(inputs=request)

    out_name = "output.mp3"
    with open(out_name, "wb") as handle:
        handle.write(audio)
    return out_name


In [16]:
from IPython.display import Audio
from IPython.display import display

# Form fields (`# @param`): the two text prompts and the `alpha` value passed to muse().
promptStart = "smooth jazz" # @param {type: "string"}
promptEnd = "piano" # @param {type: "string"}
alpha = 0.23 # @param {type:"slider", min:0.0, max:1, step:0.01}

# muse() writes the generated audio to disk and returns the file name.
fileName = muse(promptStart, promptEnd, alpha)
# Build an audio player for that file (autoplay requested) and show it in the cell output.
wn = Audio(fileName, autoplay=True)
display(wn)

  0%|          | 0/38 [00:00<?, ?it/s]