<a href="https://colab.research.google.com/github/rmnrnm/TNS-Pull-Request-Practice/blob/master/demo_riffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# riffusion colab demo

Run [riffusion](https://www.riffusion.com/about) inference in a Colab notebook on a GPU runtime

Riffusion project by [Seth Forsgren](https://twitter.com/sethforsgren) and [Hayk Martiros](https://github.com/hmartiro), colab notebook by [Jasper Gilley](https://twitter.com/0xjasper)

Feel free to DM Jasper on Twitter if you have any problems with the notebook

Some cool prompt ideas can be found at https://ai-art-wiki.com/wiki/Riffusion#Prompts

In [None]:
# Check that a GPU is attached; the pipeline below is loaded with device="cuda".
!nvidia-smi

In [None]:
#@title Clone the inference repo
# Later cells (`pip install -e .`, `requirements.txt`, `seed_images/...`) rely on
# the working directory being the root of this cloned repository.
!git clone https://github.com/riffusion/riffusion.git
%cd riffusion

In [None]:
# Install the cloned riffusion package in editable mode so `import riffusion` resolves.
%pip install -e .

In [None]:
from riffusion.streamlit import util

In [None]:
#@title Install requirements (you may need to restart the kernel after this)
# Install the dependencies listed in the cloned repo's requirements.txt.
# NOTE(review): the previous cell already imports riffusion.streamlit.util; if that
# import fails on a fresh runtime, run this cell first and restart the kernel.
!pip install -r requirements.txt


In [None]:
import os 
import io
import numpy as np
import dataclasses
import IPython.display as ipd
from PIL import Image
import pydub
from random import randint

from riffusion.streamlit import util
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.datatypes import InferenceInput, PromptInput

In [None]:
# Load the Riffusion pipeline onto the GPU once; the helpers below all reuse this
# module-level `pipeline` object.
pipeline = util.load_riffusion_checkpoint(
    checkpoint=util.DEFAULT_CHECKPOINT,
    # Skip the traced UNet so generation can use a variable image width.
    no_traced_unet=True,
    device="cuda",
)


In [None]:
def make_audio_from_spectrogram(image):
    """
    Reconstruct audio from a spectrogram image.

    The image is decoded on the GPU using a 0-10000 Hz frequency range.

    Args:
        image: Spectrogram image, e.g. the one returned by
            ``create_spectrogram_from_prompt``.

    Returns:
        Whatever ``util.audio_segment_from_spectrogram_image`` returns
        (presumably a ``pydub.AudioSegment`` — confirm against riffusion).
    """
    return util.audio_segment_from_spectrogram_image(
        image=image,
        params=SpectrogramParams(min_frequency=0, max_frequency=10000),
        device="cuda",
    )

def create_spectrogram_from_prompt(prompt, negative_prompt, guidance=7, seed=42, width=512):
    """
    Generate a spectrogram image from a text prompt, display it inline, and return it.

    Args:
        prompt: Text prompt describing the desired audio.
        negative_prompt: Text the generation should steer away from.
        guidance: How strongly the model follows the text prompt; forwarded
            unchanged to ``util.run_txt2img``.
        seed: Random seed for the txt2img run.
        width: Image width in pixels; the height is fixed at 512.

    Returns:
        The image returned by ``util.run_txt2img``.
    """
    image = util.run_txt2img(
        prompt=prompt,
        num_inference_steps=30,
        # Forward the caller's value. Hard-coding 7 here made the `guidance`
        # argument a silent no-op.
        guidance=guidance,
        negative_prompt=negative_prompt,
        seed=seed,
        width=width,
        height=512,
        checkpoint="riffusion/riffusion-model-v1",
        device="cuda",
    )

    ipd.display(image)

    return image


  
def render(prompt_a, prompt_b, num_interpolation_steps=5, seed_a=None, seeb_b=None, negative_prompt_a=False, negative_prompt_b=False, denoising=.75, seed_b=None):
    """
    Interpolate between two prompts in the latent space and return the audio.

    This tool takes two endpoints and generates a long-form interpolation
    between them that traverses the latent space, using the method described at
    https://www.riffusion.com/about. A seed image (``seed_images/og_beat.png``)
    sets the beat and tempo of the generated audio. Usually either the seed or
    the prompt is changed, not both at once; changing the seed browses
    variations of the same prompt. For example, try going from "church bells"
    to "jazz" with 10 steps and 0.75 denoising. This generates a 50-second clip
    at 5 seconds per step. Then vary the seeds or denoising for variations.

    Args:
        prompt_a: Prompt at the start of the interpolation (alpha = 0).
        prompt_b: Prompt at the end of the interpolation (alpha = 1).
        num_interpolation_steps: Number of clips to generate and concatenate;
            must be at least 1.
        seed_a: Seed for prompt A; a random seed in [1, 9999] is drawn when None.
        seeb_b: Seed for prompt B (misspelled name kept for backward
            compatibility); a random seed in [1, 9999] is drawn when None.
        negative_prompt_a: Optional negative prompt for A; ignored when falsy.
        negative_prompt_b: Optional negative prompt for B; ignored when falsy.
        denoising: Denoising strength applied to both endpoints.
        seed_b: Correctly spelled alias for ``seeb_b``; takes precedence when
            not None.

    Returns:
        Tuple ``(audio_bytes, concat_segment)``: the MP3-encoded audio as an
        ``io.BytesIO`` rewound to position 0, and the concatenated
        ``pydub.AudioSegment``.

    Raises:
        ValueError: If ``num_interpolation_steps`` is less than 1.
    """
    # Without at least one step there is no segment to concatenate.
    if num_interpolation_steps < 1:
        raise ValueError(
            f"num_interpolation_steps must be >= 1, got {num_interpolation_steps}"
        )
    if seed_b is not None:
        seeb_b = seed_b

    extension = "mp3"
    num_inference_steps = 30
    guidance = 7  # How much the model listens to the text prompt
    init_image_name = "og_beat"
    # Exponent shaping the interpolation curve; 1 keeps the alphas linear.
    alpha_power = 1

    # Map alphas to [-1, 1], apply a symmetric power curve, then map back to [0, 1].
    alphas = np.linspace(0, 1, num_interpolation_steps)
    alphas_shifted = alphas * 2 - 1
    alphas = (np.abs(alphas_shifted) ** alpha_power * np.sign(alphas_shifted) + 1) / 2

    if seed_a is None:
        seed_a = randint(1, 9999)
    if seeb_b is None:
        seeb_b = randint(1, 9999)

    prompt_input_a = PromptInput(
        guidance=guidance,
        **get_prompt_inputs(prompt_a, negative_prompt_a, seed=seed_a, denoising_default=denoising),
    )
    prompt_input_b = PromptInput(
        guidance=guidance,
        **get_prompt_inputs(prompt_b, negative_prompt_b, seed=seeb_b, denoising_default=denoising),
    )

    # The seed image fixes beat/tempo. Convert it inside `with` so the file
    # handle is released right away.
    init_image_path = os.path.join("seed_images", f"{init_image_name}.png")
    with Image.open(init_image_path) as seed_image:
        init_image = seed_image.convert("RGB")

    # TODO(hayk): Move this code into a shared place and add to riffusion.cli
    audio_bytes_list: list[io.BytesIO] = []
    for i, alpha in enumerate(alphas):
        inputs = InferenceInput(
            alpha=float(alpha),
            num_inference_steps=num_inference_steps,
            # Keep the seed-image id in sync with the image actually loaded above.
            seed_image_id=init_image_name,
            start=prompt_input_a,
            end=prompt_input_b,
        )

        if i == 0:
            print("Example input JSON")
            print(dataclasses.asdict(inputs))

        _, audio_bytes = run_interpolation(
            pipeline,
            inputs=inputs,
            init_image=init_image,
            extension=extension,
        )
        audio_bytes_list.append(audio_bytes)

    # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
    audio_segments = [pydub.AudioSegment.from_file(clip) for clip in audio_bytes_list]
    concat_segment = audio_segments[0]
    for segment in audio_segments[1:]:
        concat_segment = concat_segment.append(segment, crossfade=0)

    audio_bytes = io.BytesIO()
    concat_segment.export(audio_bytes, format=extension)
    audio_bytes.seek(0)

    print(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
    return audio_bytes, concat_segment
    

def get_prompt_inputs(prompt, negative_prompt, seed, denoising_default: float = 0.5):
    """
    Build the keyword arguments for one ``PromptInput`` endpoint.

    The ``negative_prompt`` key is included only when ``negative_prompt`` is
    truthy, so ``None``, ``False`` and ``""`` all mean "no negative prompt".

    Returns:
        Dict with keys ``prompt``, optionally ``negative_prompt``, then ``seed``
        and ``denoising``.
    """
    inputs = {"prompt": prompt}
    if negative_prompt:
        inputs["negative_prompt"] = negative_prompt
    inputs.update(seed=seed, denoising=denoising_default)
    return inputs


def run_interpolation(
    pipeline,
    inputs: InferenceInput,
    init_image: Image.Image,
    extension: str = "mp3",
    device: str = "cuda",
):
    """
    Run one riffusion interpolation step and encode the result as audio.

    No caching is done here: every call runs ``pipeline.riffuse`` again.

    Args:
        pipeline: Loaded riffusion pipeline exposing ``riffuse``.
        inputs: Interpolation request (endpoints, alpha, step count, seed image id).
        init_image: Seed image passed to ``riffuse`` as the init image.
        extension: Audio output format passed to the spectrogram-to-audio converter.
        device: Device used for the spectrogram-to-audio reconstruction.

    Returns:
        Tuple ``(image, audio_bytes)``: the generated spectrogram image and the
        value returned by ``util.audio_bytes_from_spectrogram_image``.
    """
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=None,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct from image to audio
    audio_bytes = util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
        output_format=extension,
    )

    return image, audio_bytes

In [None]:
#@title Run with Colab interface

prompt= 'dog barking'#@param {type:"string"}
negative_prompt = ""#@param {type:"string"}

# Generate a spectrogram for the form's prompt (the helper also displays it inline).
image = create_spectrogram_from_prompt(prompt, negative_prompt, guidance=20)

# Convert the spectrogram to audio; as the cell's last expression, the returned
# segment is what the notebook displays.
make_audio_from_spectrogram(image)

In [None]:
# Interpolate from "piano" to "oboe" in 5 steps. Keep both the encoded audio
# bytes and the concatenated segment for later cells.
audio_bytes, concat_segment = render(
    prompt_a='piano',
    prompt_b='oboe',
    num_interpolation_steps=5,
)

In [None]:
# Leave the segment as the cell's final expression so the notebook shows its rich display.
concat_segment