In [1]:
# mount drive
from google.colab import drive
drive.mount('/content/drive/')

# cd to project root
%cd /content/drive/MyDrive/riffusion

# install environment
!curl -L https://repo.anaconda.com/miniconda/Miniconda3-py39_4.12.0-Linux-x86_64.sh -o miniconda.sh
!chmod +x miniconda.sh
!sh miniconda.sh -b -p /content/miniconda
!/content/miniconda/bin/pip install -r requirements.txt
!/content/miniconda/bin/pip install --upgrade ipython ipykernel

Mounted at /content/drive/
[Errno 2] No such file or directory: '/content/drive/MyDrive/riffusion'
/content
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 73.0M  100 73.0M    0     0   103M      0 --:--:-- --:--:-- --:--:--  103M
PREFIX=/content/miniconda
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ | done
Solving environment: - \ | / - \ done

## Package Plan ##

  environment location: /content/miniconda

  added / updated specs:
    - _libgcc_mutex==0.1=main
    - _openmp_mutex==4.5=1_gnu
    - brotlipy==0.7.0=py39h27cfd23_1003
    - ca-certificates==2022.3.29=h06a4308_1
    - certifi==2021.10.8=py39h06a4308_2
    - cffi==1.15.0=py39hd667e15_1
    - charset-normalizer==2.0.4=pyhd3eb1b0_0
    - colorama==0.4.4=pyhd3eb1b0_0
    - conda-content-trust==0.1.1=pyhd3eb1b0_0
    - conda-package-handling==1.8.1=py39h7f8727e_0
 

In [2]:
# install additional dependencies
! pip install dacite flask_cors flask_ngrok pyngrok

Collecting dacite
  Downloading dacite-1.9.2-py3-none-any.whl.metadata (17 kB)
Collecting flask_cors
  Downloading flask_cors-6.0.1-py3-none-any.whl.metadata (5.3 kB)
Collecting flask_ngrok
  Downloading flask_ngrok-0.0.25-py3-none-any.whl.metadata (1.8 kB)
Collecting pyngrok
  Downloading pyngrok-7.3.0-py3-none-any.whl.metadata (8.1 kB)
Downloading dacite-1.9.2-py3-none-any.whl (16 kB)
Downloading flask_cors-6.0.1-py3-none-any.whl (13 kB)
Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)
Downloading pyngrok-7.3.0-py3-none-any.whl (25 kB)
Installing collected packages: pyngrok, dacite, flask_ngrok, flask_cors
Successfully installed dacite-1.9.2 flask_cors-6.0.1 flask_ngrok-0.0.25 pyngrok-7.3.0


In [3]:
# where script saved under
%cd /content/drive/MyDrive/Training-Free-StyleID

/content/drive/MyDrive/Training-Free-StyleID


In [5]:
"""
Flask server that serves the riffusion model as an API.
"""

import dataclasses
import io
import json
import logging
import time
import typing as T
from pathlib import Path

import dacite
import flask
import PIL
import torch
from flask_cors import CORS
from pyngrok import ngrok

# Fix CUDA linear algebra backend to avoid cusolver errors
torch.backends.cuda.preferred_linalg_library('magma')

# NOTE original riffusion pipeline
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.datatypes import InferenceInput, InferenceOutput

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

# from flask_ngrok import run_with_ngrok
NGROK_AUTH_TOKEN = "32MmrpMI4sZN558sIugyRuhDgDg_5AdY64F9xihYgNZZfyHJL"

# Flask app with CORS
app = flask.Flask(__name__)
CORS(app)


# Create a logger object
logger = logging.getLogger("my_server")
logger.setLevel(logging.DEBUG)

# Log at the INFO level to both stdout and disk
logging.basicConfig(level=logging.INFO)
logging.getLogger().addHandler(logging.FileHandler("server.log"))

# Create a file handler to write logs to a file
file_handler = logging.FileHandler("server.log")
file_handler.setLevel(logging.DEBUG)

# set format
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')

# initalise file handler
file_handler.setFormatter(formatter)

# initalise console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Add handlers to the logger
logger.addHandler(file_handler)

# Global variable for the model pipeline
PIPELINE: T.Optional[RiffusionPipeline] = None

# set auth token for free n_grok usage
ngrok.set_auth_token(NGROK_AUTH_TOKEN)



In [6]:
def compute_request(
    inputs: InferenceInput,
    pipeline: RiffusionPipeline,
) -> T.Union[str, T.Tuple[str, int]]:
    """
    Does all the heavy lifting of the request.

    Args:
        inputs: The input dataclass
        pipeline: The riffusion model pipeline
    """

    # Load the seed image by ID
    init_image_path = Path(f"{inputs.seed_image_path}.png")

    print("######################### input image path: ", init_image_path)

    if not init_image_path.is_file():
        return f"Invalid seed image: {inputs.seed_image_path}", 400
    init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

    # Load the mask image by ID
    mask_image: T.Optional[PIL.Image.Image] = None

    # NOTE pass mask image here
    # mask_image = PIL.Image.open("...png").convert("RGB")
    if inputs.mask_image_path:
        mask_image_path = Path(f"{inputs.mask_image_path}.png")
        if not mask_image_path.is_file():
            return f"Invalid mask image: {inputs.mask_image_path}", 400
        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance a bit to cache this object
    # Use CPU for audio processing to avoid CUDA solver issues
    converter = SpectrogramImageConverter(params=params, device="cpu")

    # NOTE 轉回 audio signal
    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

    # Export audio to MP3 bytes
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)

    # Export image to JPEG bytes
    image_bytes = io.BytesIO()
    image.save(image_bytes, exif=image.getexif(), format="JPEG")
    image_bytes.seek(0)

    # Assemble the output dataclass
    output = InferenceOutput(
        image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
        audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
        duration_s=segment.duration_seconds,
    )

    # release memory
    import gc
    del image, mask_image, init_image  # delete big tensors
    gc.collect()
    torch.cuda.empty_cache()  # free cached memory
    torch.cuda.ipc_collect()  # (optional) reclaim inter-process memory

    output_name = f"{''.join(inputs.seed_image_path.split('/')[-2:])}_to_{''.join(inputs.mask_image_path.split('/')[-2:])}"

    with open(f"{inputs.output_path}/{output_name}.json", "w") as f:
        json.dump(dataclasses.asdict(output), f, indent=2, ensure_ascii=False)

    return output

In [7]:
# Where built-in seed images are stored
# import traceback
# def run_app_background(*args, **kwargs):
#     try:
#         # Your existing Flask + ngrok code
#         global PIPELINE

#         import logging, sys
#         logging.basicConfig(
#             level=logging.DEBUG,
#             format="%(asctime)s [%(levelname)s] %(message)s",
#             handlers=[logging.StreamHandler(sys.stdout)]
#         )
#         app.logger.setLevel(logging.DEBUG)

#         app.logger.info("Loading RiffusionPipeline...")
#         PIPELINE = RiffusionPipeline.load_checkpoint(
#             checkpoint=kwargs.get("checkpoint", "riffusion/riffusion-model-v1"),
#             use_traced_unet=not kwargs.get("no_traced_unet", False),
#             device=kwargs.get("device", "cuda")
#         )
#         app.logger.info("Pipeline loaded successfully!")

#         public_url = ngrok.connect(kwargs.get("port", 5000))
#         print(f" * ngrok tunnel URL: {public_url}", flush=True)

#         app.logger.info(f"Starting Flask server on port {kwargs.get('port', 5000)}...")
#         app.run(port=kwargs.get("port", 5000), debug=kwargs.get("debug", True), use_reloader=False)

#     except Exception:
#         print("Exception in background thread:", flush=True)
#         traceback.print_exc()

def run_app(
    *,
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
    port: int = 5000,
    debug: bool = False,
):
    """
    Run a Flask API that serves the given riffusion model checkpoint
    and exposes it via ngrok.
    """
    global PIPELINE

    # Initialize the model
    PIPELINE = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )

    # Set debug mode
    app.debug = debug

    # Start ngrok tunnel
    public_url = ngrok.connect(port)
    print(f" * ngrok tunnel URL: {public_url}", flush=True)

    # Start Flask server
    app.run(port=port)


@app.route("/run_inference/", methods=["POST"])
def run_inference():
    """
    Execute the riffusion model as an API.

    Inputs:
        Serialized JSON of the InferenceInput dataclass

    Returns:
        Serialized JSON of the InferenceOutput dataclass
    """
    start_time = time.time()

    # Parse the payload as JSON
    json_data = json.loads(flask.request.data)

    # Log the request
    logging.info(json_data)

    # Parse an InferenceInput dataclass from the payload
    try:
        inputs = dacite.from_dict(InferenceInput, json_data)
    except dacite.exceptions.WrongTypeError as exception:
        logging.info(json_data)
        return str(exception), 400
    except dacite.exceptions.MissingValueError as exception:
        logging.info(json_data)
        return str(exception), 400

    # NOTE
    response = compute_request(
        inputs=inputs,
        pipeline=PIPELINE,
    )

    # Log the total time
    logging.info(f"Request took {time.time() - start_time:.2f} s")

    return response


# @app.route("/run_inference/", methods=["POST"])
# def run_inference():
#     """
#     Execute the riffusion model as an API.

#     Inputs:
#         Serialized JSON of the InferenceInput dataclass

#     Returns:
#         Serialized JSON of the InferenceOutput dataclass
#     """
#     start_time = time.time()

#     # Parse the payload as JSON
#     json_data = json.loads(flask.request.data)

#     # Log the request
#     logging.info(json_data)

#     # Parse an InferenceInput dataclass from the payload
#     try:
#         inputs = dacite.from_dict(InferenceInput, json_data)
#     except dacite.exceptions.WrongTypeError as exception:
#         logging.info(json_data)
#         return str(exception), 400
#     except dacite.exceptions.MissingValueError as exception:
#         logging.info(json_data)
#         return str(exception), 400

#     # NOTE
#     response = compute_request(
#         inputs=inputs,
#         pipeline=PIPELINE,
#     )

#     # Log the total time
#     logging.info(f"Request took {time.time() - start_time:.2f} s")

#     return response

def start_server():
  run_app()

In [9]:
# set to background thread
import threading
threading.Thread(target=start_server, daemon=True).start()



  PIPELINE = RiffusionPipeline.load_checkpoint(


In [10]:
# run inference
CUDA_DEVICE=1
START_SEED=42
END_SEED=123
DENOISING=0.2
GUIDANCE=0
ALPHA=0
STEPS=50
OUTPUT_PATH = "/content/drive/MyDrive/riffusion/Training-Free-StyleID/results/audio"
SEED_IMAGE_PATH="/content/drive/MyDrive/riffusion/results/riffusion_seed_mask_images/accordian123/1"
MASK_IMAGE_PATH="/content/drive/MyDrive/riffusion/results/riffusion_seed_mask_images/violin123/1"

# Run curl command
!CUDA_VISIBLE_DEVICES="$CUDA_DEVICE" curl -X POST http://127.0.0.1:5000/run_inference/ -H "Content-Type: application/json" -d '{"start":{"prompt":"","seed":'"$START_SEED"',"denoising":'"$DENOISING"',"guidance":'"$GUIDANCE"'},"num_inference_steps":'"$STEPS"',"seed_image_path":"'"$SEED_IMAGE_PATH"'","mask_image_path":"'"$MASK_IMAGE_PATH"'","alpha":'"$ALPHA"',"end":{"prompt":"","seed":'"$END_SEED"',"denoising":'"$DENOISING"',"guidance":'"$GUIDANCE"', "output_path": '"$OUTPUT_PATH"'}}'

curl: (7) Failed to connect to 127.0.0.1 port 5000 after 0 ms: Connection refused


In [15]:
import requests
import json
import os

data = {
    "start": {"prompt": "", "seed": START_SEED, "denoising": DENOISING, "guidance": GUIDANCE},
    "num_inference_steps": STEPS,
    "seed_image_path": SEED_IMAGE_PATH,
    "mask_image_path": MASK_IMAGE_PATH,
    "alpha": ALPHA,
    "end": {"prompt": "", "seed": END_SEED, "denoising": DENOISING, "guidance": GUIDANCE, "output_path": OUTPUT_PATH}
}

response = requests.post("http://127.0.0.1:5000/run_inference/", json=data)
try:
    response = requests.post("http://127.0.0.1:5000/run_inference/", json=data)
    logger.info(f"Response status code: {response.status_code}")
    logger.info(f"Response text: {response.text[:500]}")  # limit output to first 500 chars
except Exception as e:
    logger.error(f"Request failed: {e}")

INFO:werkzeug:127.0.0.1 - - [08/Sep/2025 03:48:53] "[31m[1mPOST /run_inference/ HTTP/1.1[0m" 400 -
INFO:werkzeug:127.0.0.1 - - [08/Sep/2025 03:48:53] "[31m[1mPOST /run_inference/ HTTP/1.1[0m" 400 -
INFO:my_server:Response status code: 400
INFO:my_server:Response text: Invalid seed image: /content/drive/MyDrive/riffusion/results/riffusion_seed_mask_images/accordian123/1


######################### input image path:  /content/drive/MyDrive/riffusion/results/riffusion_seed_mask_images/accordian123/1.png
######################### input image path:  /content/drive/MyDrive/riffusion/results/riffusion_seed_mask_images/accordian123/1.png


try pyngrok

In [8]:
! pip install pyngrok

Collecting pyngrok
  Downloading pyngrok-7.3.0-py3-none-any.whl.metadata (8.1 kB)
Downloading pyngrok-7.3.0-py3-none-any.whl (25 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.3.0


In [15]:
import typing as T
import logging
import time
import json
from flask import Flask, request, jsonify
from pyngrok import ngrok
import dacite

app = Flask(__name__)
PIPELINE = None  # will be initialized in run_app

def run_app(
    *,
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
    port: int = 5000,
    debug: bool = False,
):
    """
    Run a Flask API that serves the given riffusion model checkpoint
    and exposes it via ngrok.
    """
    global PIPELINE

    # Initialize the model
    PIPELINE = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )

    # Set debug mode
    app.debug = debug

    # Start ngrok tunnel
    public_url = ngrok.connect(port)
    print(f" * ngrok tunnel URL: {public_url}")

    # Start Flask server
    app.run(port=port)

def start_server():
  run_app()

NameError: name 'testing' is not defined

In [16]:
# set to background thread
import threading
threading.Thread(target=start_server, daemon=True).start()



  PIPELINE = RiffusionPipeline.load_checkpoint(


test run_app() with n_grok setting

In [None]:
# test app server
# from flask import Flask, request, jsonify
# from flask_ngrok import run_with_ngrok
# import argh

# def run_app():
#   app = Flask(__name__)
#   run_with_ngrok(app)  # starts ngrok when app.run() is called

#   @app.route("/hello", methods=["GET"])
#   def hello():
#       return jsonify({"msg": "Hello from Flask in Colab!"})

#   @app.route("/echo", methods=["POST"])
#   def echo():
#       data = request.json
#       return jsonify({"you_sent": data})

#   app.run()

# this line is the problem
# argh.dispatch_command(run_app)

In [8]:
# def run_app(
#     *,
#     checkpoint: str = "riffusion/riffusion-model-v1",
#     no_traced_unet: bool = False,
#     device: str = "cuda",
#     host: str = "127.0.0.1",
#     port: int = 8001,
#     debug: bool = False,
#     ssl_certificate: T.Optional[str] = None,
#     ssl_key: T.Optional[str] = None,
# ):
    """
    Run a flask API that serves the given riffusion model checkpoint.
    """
    # Initialize the model
    # global PIPELINE

    # PIPELINE = RiffusionPipeline.load_checkpoint(
    #     checkpoint=checkpoint,
    #     use_traced_unet=True,
    #     device=device,
    # )

    # TypeError: run_with_ngrok.<locals>.new_run() got an unexpected keyword argument 'debug'
    # args = dict(
    #     # debug=debug,
    #     # threaded=False,
    #     host=host,
    #     port=port,
    # )

    # if ssl_certificate:
    #     assert ssl_key is not None
    #     args["ssl_context"] = (ssl_certificate, ssl_key)

    # app.run(**args)  # type: ignore

In [9]:
run_app() # app.run(**args)

TypeError: run_with_ngrok.<locals>.new_run() got an unexpected keyword argument 'host'

In [10]:
from flask_ngrok import run_with_ngrok

# Flask app with CORS
app = flask.Flask(__name__)
CORS(app)

# run background thread (with daemon)
run_with_ngrok(app)