In [1]:
# Confirm a CUDA GPU is attached; startup() below loads the model with device="cuda".
!nvidia-smi

Sun Apr 16 22:50:02 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

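## Restart the runtime after installing Pillow
  - The next cell replaces the preinstalled Pillow (8.4.0 on this runtime) with Pillow 9.1.0; restart the runtime afterwards so the new version is actually loaded.
  - This should only be needed once per runtime.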

In [2]:
# Riffusion dependencies
!pip install pillow==9.1.0
!pip install accelerate
!pip install argh
!pip install dacite
!pip install demucs
!pip install diffusers>=0.9.0
!pip install flask
!pip install flask_cors
!pip install numpy
!pip install plotly
!pip install pydub
!pip install pysoundfile
!pip install scipy
!pip install soundfile
!pip install sox
!pip install streamlit>=1.10.0
!pip install torch
!pip install torchaudio
!pip install torchvision
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pillow==9.1.0
  Downloading Pillow-9.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.3/4.3 MB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pillow
  Attempting uninstall: pillow
    Found existing installation: Pillow 8.4.0
    Uninstalling Pillow-8.4.0:
      Successfully uninstalled Pillow-8.4.0
Successfully installed pillow-9.1.0


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting accelerate
  Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 kB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.18.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting argh
  Downloading argh-0.28.1-py3-none-any.whl (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: argh
Successfully installed argh-0.28.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dacite
  Downloading dacite-1.8.0-py3-none-any.whl (14 kB)
Installing collected packages: dacite
Successfully installed dacite-1.8.0
Looking in 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flask_cors
  Downloading Flask_Cors-3.0.10-py2.py3-none-any.whl (14 kB)
Installing collected packages: flask_cors
Successfully installed flask_cors-3.0.10
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pysoundfile
  Downloading PySoundFile-0.9.0.post1-py2.py3-none-any.whl (24 kB)
Inst

In [2]:
# Fetch the Riffusion source; the `riffusion` package imported below comes from this checkout.
# NOTE(review): the clone is unpinned (no tag/commit), and re-running this cell after
# the %cd below would clone a second copy inside the first.
!git clone https://github.com/hmartiro/riffusion-inference
%cd riffusion-inference

Cloning into 'riffusion-inference'...
remote: Enumerating objects: 793, done.[K
remote: Counting objects: 100% (370/370), done.[K
remote: Compressing objects: 100% (112/112), done.[K
remote: Total 793 (delta 296), reused 263 (delta 257), pack-reused 423[K
Receiving objects: 100% (793/793), 8.29 MiB | 18.22 MiB/s, done.
Resolving deltas: 100% (491/491), done.
/content/riffusion-inference


In [3]:
from pathlib import Path
from typing import Union, Optional
from os import getcwd, sep
import io
import sys

import PIL

# Also put the package directory itself (<cwd>/riffusion) on sys.path — presumably so
# modules inside it that import each other by bare name resolve. TODO confirm it is needed.
sys.path.append(getcwd() + sep + "riffusion")

from riffusion.datatypes import InferenceInput, InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.server import SEED_IMAGES_DIR



# Module-level pipeline: declared here, assigned by startup(), read by compute().
pipeline: RiffusionPipeline


def startup(
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> None:
    """
    Load a Riffusion checkpoint into the module-level ``pipeline``.

    Parameters:
        checkpoint (str): model identifier or path handed to
            ``RiffusionPipeline.load_checkpoint``.
        no_traced_unet (bool): set to True to skip the traced UNet; by
            default the traced UNet is requested.
        device (str): torch device string, e.g. "cuda", "cpu", or "mps".

    Returns:
        None. The loaded pipeline is stored in the global ``pipeline``.
    """
    global pipeline
    use_traced = not no_traced_unet
    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint, use_traced_unet=use_traced, device=device
    )

def compute(inputs: InferenceInput) -> "pydub.AudioSegment":
    """
    Generate a spectrogram with the pipeline and convert it to audio.

    Adapted from the riffusion server's :func:`compute_request`. The server
    version returns a Flask-style ``(message, 400)`` tuple for bad image IDs;
    in this notebook that tuple would be concatenated or exported as if it were
    audio by the callers, so the error is raised instead.

    Parameters:
        inputs (:py:class:`InferenceInput`): The inputs for the request
            (prompts, alpha, step count, seed image ID and optional mask
            image ID).

    Returns:
        pydub.AudioSegment: the audio reconstructed from the generated
        spectrogram image.

    Raises:
        ValueError: if ``inputs.seed_image_id`` or ``inputs.mask_image_id``
            does not name a ``.png`` file in ``SEED_IMAGES_DIR``.
    """
    # A bare `import PIL` does not guarantee the `Image` submodule is loaded.
    from PIL import Image

    # Load the seed image by ID
    init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
    if not init_image_path.is_file():
        raise ValueError(f"Invalid seed image: {inputs.seed_image_id}")
    # convert() loads the pixels into a new image, so the file can be closed here.
    with Image.open(init_image_path) as seed_file:
        init_image = seed_file.convert("RGB")

    # Load the mask image by ID
    mask_image: Optional[Image.Image] = None
    if inputs.mask_image_id:
        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
        if not mask_image_path.is_file():
            raise ValueError(f"Invalid mask image: {inputs.mask_image_id}")
        with Image.open(mask_image_path) as mask_file:
            mask_image = mask_file.convert("RGB")

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance a bit to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))

    return converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

In [4]:
# Load the model onto the GPU; the first run downloads the checkpoint and traced UNet.
startup(device="cuda")

Downloading (…)ain/model_index.json:   0%|          | 0.00/543 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

Downloading pytorch_model.bin:   0%|          | 0.00/492M [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

Downloading (…)_encoder/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading (…)7ae/unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/284 [00:00<?, ?B/s]

Downloading (…)87ae/vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

You have passed a non-standard module <function RiffusionPipeline.load_checkpoint.<locals>.<lambda> at 0x7fbbf7992c10>. We cannot verify whether it has the correct type
  deprecate(


Downloading unet_traced.pt:   0%|          | 0.00/1.72G [00:00<?, ?B/s]

In [5]:
import base64
import json
from riffusion.datatypes import InferenceInput, PromptInput, InferenceOutput

# Global seed counter; museHelp() reads it for each clip and advances it afterwards.
seed: int = 1

def muse(prompt_start_input, prompt_end_input, alpha_input, num_steps: int = 50,
         output_path: str = "output.mp3"):
    """
    Generate music for a prompt pair and write it to an audio file.

    Parameters:
      prompt_start_input (str): starting prompt
      prompt_end_input (str): ending prompt
      alpha_input: currently unused; ``museHelp`` chooses the interpolation
        values itself. Kept so existing calls keep working.
      num_steps (int): number of inference steps per clip
      output_path (str): destination file. The export format is taken from
        its extension (e.g. ".mp3", ".wav"); pydub needs ffmpeg for formats
        other than WAV.

    Returns:
      str: the output file path.
    """
    # museHelp takes (start, end, num_steps); alpha_input must not be passed
    # positionally or it would be interpreted as num_steps.
    segment = museHelp(prompt_start_input, prompt_end_input, num_steps)

    # Export in the format matching the file name, then write the raw bytes.
    audio_format = Path(output_path).suffix.lstrip(".").lower() or "wav"
    buffer = io.BytesIO()
    segment.export(buffer, format=audio_format)
    with open(output_path, "wb") as f:
        f.write(buffer.getvalue())
    return output_path

def museHelp(prompt_start_input, prompt_end_input, num_steps: int = 50, *,
             num_segments: int = 4):
    """
    Generate clips that move from the start prompt toward the end prompt and
    concatenate them.

    Clip ``i`` (0-based) is rendered with ``alpha = i / num_segments``, so the
    first clip uses the start prompt only and later clips blend progressively
    toward the end prompt. The global ``seed`` is advanced after every clip,
    and additionally between the start and end prompt when they differ.

    Parameters:
      prompt_start_input (str): starting prompt
      prompt_end_input (str): ending prompt
      num_steps (int): number of inference steps per clip
      num_segments (int): keyword-only; number of clips to generate.

    Returns:
      The pydub.AudioSegment containing the music (None if num_segments <= 0).
    """
    global seed
    transition: bool = prompt_start_input != prompt_end_input
    segments = None  # accumulated audio
    for index in range(num_segments):
        alpha: float = index / num_segments  # interpolation position for this clip
        prompt_start = PromptInput(
            prompt=prompt_start_input,
            seed=seed,
        )
        if transition:
            seed += 1  # give the end prompt its own seed
        prompt_end = PromptInput(
            prompt=prompt_end_input,
            seed=seed,
        )
        request = InferenceInput(
            start=prompt_start,
            end=prompt_end,
            alpha=alpha,
            num_inference_steps=num_steps,
        )

        segment = compute(inputs=request)  # makes music
        segments = segment if segments is None else segments + segment
        seed += 1
    return segments


In [6]:
from IPython.display import Audio
from IPython.display import display

# Example of music

import numpy as np

promptStart = "classical piano" # @param {type: "string"}
promptEnd = "heavy metal" # @param {type: "string"}

segment = museHelp(promptStart, promptEnd)
# To save the clip to a file instead of playing it, use muse() defined above.

# get_array_of_samples() interleaves channels, while Audio expects one row per
# channel, so reshape to (channels, samples) before playing.
samples = np.array(segment.get_array_of_samples()).reshape(-1, segment.channels).T
wn = Audio(data=samples, rate=segment.frame_rate, autoplay=True)
display(wn)

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

In [7]:
def convert(d: dict) -> bytes:
  """
  Serialize a dict to bytes.

  The dict's ``str()`` representation is encoded as UTF-8, which is the
  format ``deconvert`` parses back with ``ast.literal_eval``.

  Parameters:
    d (dict): the dict to serialize.

  Returns:
    bytes: UTF-8 encoding of ``str(d)``.
  """
  text = str(d)
  return text.encode("utf-8")

In [26]:
import socket, os
import pydub

server_ip: str = '3.142.167.54' # @param {type: "string"}
port: int = 12157 # @param {type: "number"}
CHUNK: int = 1024  # max bytes read per recv() in receive()
sample_rate: int = 44100  # NOTE(review): not referenced by any visible cell
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_address = (server_ip, port)
try:
    client_socket.connect(socket_address)
except OSError:
    # Don't leave an unconnected socket open if the server is unreachable.
    client_socket.close()
    raise
# Initial message identifying this client to the server (presumably a role
# handshake — confirm against the server protocol).
client: dict = {"client": "music", "threads": "1", "sessionNum": "0"}
client_socket.sendall(convert(client))

In [8]:
import ast

def deconvert(message: bytes) -> Union[str, dict]:
   """
   Decode a UTF-8 message and parse it as a Python dict literal.

   Parameters:
      message (bytes): the bytes to convert.

   Returns:
      str | dict: the parsed dict. Otherwise the decoded string is returned:
      when the message is empty, is not a valid literal, or is a literal of
      another type (e.g. a number or list).

   Raises:
      UnicodeDecodeError: if ``message`` is not valid UTF-8.
   """
   text: str = message.decode("utf-8")
   if not text:  # empty message: nothing to parse
      return text
   try:
      # literal_eval only evaluates literals, so it is safe on untrusted input,
      # but malformed input can raise any of these.
      parsed = ast.literal_eval(text)
   except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
      return text
   # Callers index the result as a dict; any other literal is treated as text.
   return parsed if isinstance(parsed, dict) else text

In [23]:
def sendAudio(segment: pydub.AudioSegment):
  """
  Encode an audio segment as WAV and send it over ``client_socket``.

  Parameters:
    segment (pydub.AudioSegment): the audio to send.

  Returns:
    None

  NOTE(review): the WAV bytes are sent with no length prefix or delimiter,
  so the server has to infer where one clip ends and the next begins.
  """
  with io.BytesIO() as wav_buffer:
    segment.export(wav_buffer, format="wav")
    payload: bytes = wav_buffer.getvalue()
  client_socket.sendall(payload)

In [24]:
# Send the example clip generated above to the server.
# NOTE(review): this is sent before any request arrives — confirm the server expects it.
sendAudio(segment)

In [27]:
def receive():
  """
  Serve music requests from the server until told to stop.

  Each chunk read from ``client_socket`` is parsed with ``deconvert``. A dict
  whose ``promptStart``/``promptEnd`` are ``"endstart"``/``"endend"`` ends the
  loop. Any other dict is turned into music with ``museHelp`` and sent back
  with ``sendAudio``.

  The loop also stops when the server closes the connection or sends
  something that is not a dict. On any exception, the error is printed and
  the socket is closed.

  NOTE(review): each request is read with a single ``recv(CHUNK)`` and there
  is no message framing. A request longer than CHUNK bytes, or two requests
  arriving in one read, will not parse. Fixing that needs a framing
  convention shared with the server.
  """
  try:  # for any unforeseen issues
    while True:
      message = client_socket.recv(CHUNK)
      if not message:
        # recv() returns b"" once the server has closed its end.
        print("Server closed the connection")
        break
      inputs = deconvert(message)  # deconvert message to dict
      if not isinstance(inputs, dict):
        print(f"Unexpected message from server: {inputs!r}")
        break
      if inputs["promptStart"] == "endstart" and inputs["promptEnd"] == "endend":
        break  # end sentinel from the server
      music = museHelp(inputs["promptStart"], inputs["promptEnd"])
      sendAudio(music)
  except Exception as err:  # report anything unexpected and drop the connection
    print(err)
    client_socket.close()

In [28]:
# Blocks until the end sentinel arrives, the loop stops on a closed/bad message, or an error occurs.
receive()

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

  0%|          | 0/38 [00:00<?, ?it/s]

AAH


In [25]:
# Close the connection to the server.
client_socket.close()