In [1]:
!nvidia-smi

Fri Apr 14 18:27:47 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   60C    P8    12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

## Please restart runtime if [PIL] is already installed
  - Should only have to do it once

In [2]:
# Riffusion dependencies
!pip install pillow==9.1.0
!pip install accelerate
!pip install argh
!pip install dacite
!pip install demucs
!pip install diffusers>=0.9.0
!pip install flask
!pip install flask_cors
!pip install numpy
!pip install plotly
!pip install pydub
!pip install pysoundfile
!pip install scipy
!pip install soundfile
!pip install sox
!pip install streamlit>=1.10.0
!pip install torch
!pip install torchaudio
!pip install torchvision
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in

In [3]:
!git clone https://github.com/hmartiro/riffusion-inference
%cd riffusion-inference

Cloning into 'riffusion-inference'...
remote: Enumerating objects: 793, done.[K
remote: Counting objects: 100% (432/432), done.[K
remote: Compressing objects: 100% (133/133), done.[K
remote: Total 793 (delta 343), reused 304 (delta 298), pack-reused 361[K
Receiving objects: 100% (793/793), 8.29 MiB | 28.48 MiB/s, done.
Resolving deltas: 100% (489/489), done.
/content/riffusion-inference


In [4]:
from pathlib import Path
from typing import Union, Optional
from os import getcwd, sep
import io
import sys

import PIL

sys.path.append(getcwd() + sep + "riffusion")

from riffusion.datatypes import InferenceInput, InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.server import SEED_IMAGES_DIR



pipeline: RiffusionPipeline


def startup(
        checkpoint: str = "riffusion/riffusion-model-v1",
        no_traced_unet: bool = False,
        device: str = "cuda",
            ) -> None:
    """
    Initializes the pipeline.
    
    Parameters:
        checkpoint (str): a string of where to load a pretrained
        model.
        no_traced_unet (bool): whether to not use a traced unet
        for speedups.
        device (str): "cuda", "cpu", or "mps".

    Returns:
        None
    """
    global pipeline
    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )

def compute(inputs: InferenceInput) -> InferenceOutput:
    """
    Function from the riffusion server :func:`compute_request`.
    
    Parameters:
        inputs (:py:class:`InferenceInput`): The inputs for the request.
    
    Returns:
        str: an :py:class:`InferenceOutput`.
    """
    # Load the seed image by ID
    init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")

    if not init_image_path.is_file():
        return f"Invalid seed image: {inputs.seed_image_id}", 400
    init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

    # Load the mask image by ID
    mask_image: Optional[PIL.Image.Image] = None
    if inputs.mask_image_id:
        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
        if not mask_image_path.is_file():
            return f"Invalid mask image: {inputs.mask_image_id}", 400
        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance a bit to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))

    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )
    return segment

In [5]:
startup(device="cuda")

Downloading (…)ain/model_index.json:   0%|          | 0.00/543 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading (…)_encoder/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/284 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

Downloading (…)7ae/unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Downloading (…)87ae/vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

You have passed a non-standard module <function RiffusionPipeline.load_checkpoint.<locals>.<lambda> at 0x7fb9cdefb160>. We cannot verify whether it has the correct type
  deprecate(


Downloading unet_traced.pt:   0%|          | 0.00/1.72G [00:00<?, ?B/s]

In [6]:
import base64
import json
from riffusion.datatypes import InferenceInput, PromptInput, InferenceOutput

seed: int = 1

def muse(prompt_start_input, prompt_end_input, alpha_input, num_steps: int = 50):    
    segment = museHelp(prompt_start_input, prompt_end_input, alpha_input, num_steps)

    # Export audio to MP3 bytes
    wav_bytes = io.BytesIO()
    segment.export(wav_bytes, format="wav")
    wav_bytes.seek(0)

    with open("output.mp3", "wb") as f:
        f.write(wav_bytes)
    return "output.mp3"

def museHelp(prompt_start_input, prompt_end_input, alpha_input, num_steps: int = 50):
    global seed
    transition: bool = prompt_start_input != prompt_end_input
    promptStart: PromptInput = PromptInput(
        prompt=prompt_start_input,
        seed=seed,
    )
    seed += 1 if transition else 0
    promptEnd: PromptInput = PromptInput(
        prompt=prompt_end_input,
        seed=seed,
    )
    inputTotal: InferenceInput = InferenceInput(
        start=promptStart,
        end=promptEnd,
        alpha=alpha_input,
        num_inference_steps=num_steps,
    )
    
    segment = compute(inputs=inputTotal)
    return segment


In [None]:
!pip install sounddevice

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sounddevice
  Downloading sounddevice-0.4.6-py3-none-any.whl (31 kB)
Installing collected packages: sounddevice
Successfully installed sounddevice-0.4.6


In [None]:
import sounddevice as sd

In [7]:
from IPython.display import Audio
from IPython.display import display

promptStart = "classical piano" # @param {type: "string"}
promptEnd = "heavy metal" # @param {type: "string"}
alpha = 0 # @param {type:"slider", min:0.0, max:1, step:0.01}

segment  = museHelp(promptStart, promptEnd, alpha)
#filename = muse(promptStart, promptEnd, alpha)
wn = Audio(data=segment.get_array_of_samples(), rate=segment.frame_rate, autoplay=True)
display(wn)

  0%|          | 0/38 [00:00<?, ?it/s]

In [15]:
def convert(d: dict)-> bytes:
  return bytes(str(d), "utf-8")

In [27]:
import socket, os
import pydub

server_ip: str = '3.13.191.225' # @param {type: "string"}
port: int = 18567 # @param {type: "number"}
CHUNK: int = 1024
sample_rate: int = 44100
client_socket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
socket_address = (server_ip,port)
client_socket.connect(socket_address)
client: dict = {"client": "music", "threads": "1", "sessionNum": "0"}
client_socket.sendall(bytes(str(client), "utf-8"))

In [None]:
def deconvert(message: bytes) -> Union[str, dict]:
   strMessage: str = message.decode("utf-8")
   try:
      dictMessage: dict = ast.literal_eval(strMessage)
      return dictMessage
   except ValueError:
      return strMessage
   return strMessage

In [None]:
client_socket.recv

In [1]:
def sendAudio(segment: pydub.AudioSegment):
  data = segment.get_array_of_samples()
  for i in range(0, len(data) + CHUNK, CHUNK):
    mini = data[i:min(i + CHUNK, len(data) - 1)]
    message: dict = {"data": mini, "sample_rate": segment.frame_rate, "end": False}
    client_socket.sendall(convert(message))
  client_socket.sendall({"end": True})

NameError: ignored

In [None]:
def receive():
  while True:
    messages = client_socket.recv(50)
    inputs = deconvert(messages)
    if type(inputs) == str:
      print("AAH")
      skip = False
    else:
      music = museHelp(inputs["promptStart"], inputs["promptEnd"], inputs["alpha"])
      sendAudio(music)