In [1]:
import dataclasses
import io
import typing as T
from pathlib import Path
import sys
sys.path.append('code/riffusion')
import numpy as np
import pydub
import streamlit as st
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def run_interpolation(
    inputs: InferenceInput,
    init_image: Image.Image,
    device: str = "cuda",
    pipeline: T.Any = None,
) -> T.Tuple[Image.Image, io.BytesIO]:
    """
    Run one riffusion interpolation step and reconstruct it as audio.

    This function does not cache anything itself. Model reuse depends on
    ``streamlit_util.load_riffusion_checkpoint``. When this notebook runs
    outside ``streamlit run``, that loader appears to reload the checkpoint on
    every call: the cell output shows "Fetching 15 files" before each step.
    To load the model once, pass an already-loaded pipeline via ``pipeline``.

    Args:
        inputs: Interpolation request forwarded to ``pipeline.riffuse``.
        init_image: Seed spectrogram image passed as ``init_image`` to riffuse.
        device: Device string (e.g. "cuda"). Used for checkpoint loading and
            audio reconstruction.
        pipeline: Optional pre-loaded riffusion pipeline. When None (the
            default, matching the original behavior), one is loaded with
            ``no_traced_unet=True``.

    Returns:
        A tuple of (generated spectrogram image, MP3-encoded audio bytes).
    """
    if pipeline is None:
        pipeline = streamlit_util.load_riffusion_checkpoint(
            device=device,
            # No trace so we can have variable width
            no_traced_unet=True,
        )

    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=None,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct from image to audio
    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=params,
        device=device,
        output_format="mp3",
    )

    return image, audio_bytes

In [8]:
# Both interpolation endpoints share the same seed and sampling settings;
# only the text prompt differs, so build them from one helper instead of
# copy-pasting the PromptInput construction.
seed = 42
denoising = 0.75
guidance = 7.0


def make_prompt_input(prompt: str) -> PromptInput:
    """Build a PromptInput for `prompt` using the shared seed/denoising/guidance."""
    return PromptInput(
        prompt=prompt,
        seed=seed,
        denoising=denoising,
        guidance=guidance,
    )


prompt_input_a = make_prompt_input('basic prompt a')
prompt_input_b = make_prompt_input('basic prompt b')

In [7]:
# Evenly spaced interpolation weights: 0.0 is pure prompt A, 1.0 is pure prompt B.
num_interpolation_steps = 4
alphas = [weight for weight in np.linspace(0, 1, num_interpolation_steps)]
alphas_str = ", ".join("{:.2f}".format(weight) for weight in alphas)

In [5]:
# Location of the riffusion checkout inside this container; change this one
# constant when running the notebook elsewhere.
RIFFUSION_DIR = Path('/opt/ml/input/code/riffusion')
SEED_IMAGE_PATH = RIFFUSION_DIR / 'seed_images' / 'og_beat.png'

# convert() returns an independent image, so it remains usable after the
# `with` block closes the underlying file.
with Image.open(SEED_IMAGE_PATH) as seed_image:
    init_image = seed_image.convert("RGB")
device = 'cuda'

In [9]:
# Render one spectrogram image and one audio clip per interpolation weight.
num_inference_steps = 50
image_list: T.List[Image.Image] = []
audio_bytes_list: T.List[io.BytesIO] = []
for alpha in alphas:
    inputs = InferenceInput(
        # alphas holds numpy floats; pass a plain Python float.
        alpha=float(alpha),
        num_inference_steps=num_inference_steps,
        # Should name the same seed image loaded as init_image (og_beat.png).
        seed_image_id="og_beat",
        start=prompt_input_a,
        end=prompt_input_b,
    )
    image, audio_bytes = run_interpolation(
        inputs=inputs,
        init_image=init_image,
        device=device,
    )
    image_list.append(image)
    audio_bytes_list.append(audio_bytes)

2023-01-15 15:02:20.852 
  command:

    streamlit run /opt/conda/envs/riffusion/lib/python3.9/site-packages/ipykernel_launcher.py [ARGUMENTS]
Fetching 15 files: 100%|██████████| 15/15 [00:00<00:00, 15542.13it/s]
You have passed a non-standard module <function RiffusionPipeline.load_checkpoint.<locals>.<lambda> at 0x7f76cf476a60>. We cannot verify whether it has the correct type
100%|██████████| 38/38 [00:03<00:00, 12.07it/s]
Fetching 15 files: 100%|██████████| 15/15 [00:00<00:00, 13688.98it/s]
You have passed a non-standard module <function RiffusionPipeline.load_checkpoint.<locals>.<lambda> at 0x7f76310bc310>. We cannot verify whether it has the correct type
100%|██████████| 38/38 [00:02<00:00, 13.25it/s]
Fetching 15 files: 100%|██████████| 15/15 [00:00<00:00, 16202.57it/s]
You have passed a non-standard module <function RiffusionPipeline.load_checkpoint.<locals>.<lambda> at 0x7f76310bcdc0>. We cannot verify whether it has the correct type
100%|██████████| 38/38 [00:02<00:00, 13.77it

In [10]:
# Decode each MP3 clip and splice the clips back to back (no crossfade) into one track.
audio_segments = [
    pydub.AudioSegment.from_file(clip_bytes) for clip_bytes in audio_bytes_list
]
concat_segment = audio_segments[0]
for next_segment in audio_segments[1:]:
    concat_segment = concat_segment.append(next_segment, crossfade=0)

# Re-encode the full track into an in-memory MP3 buffer, rewound so it can be read.
audio_bytes = io.BytesIO()
concat_segment.export(audio_bytes, format="mp3")
audio_bytes.seek(0);

0

In [18]:
# `audio_bytes` holds MP3 data (it was exported with format="mp3"), so writing
# it verbatim produced MP3 content under a .wav name. Export the concatenated
# track as real WAV instead. Mode "xb" keeps the original no-overwrite
# behavior: delete the file before re-running this cell.
output_path = Path('/opt/ml/input/code/riffusion/test.wav')
with open(output_path, mode='xb') as f:
    concat_segment.export(f, format="wav")