# Setup

In [None]:
# @title Check Runtime Compatibility
import torch

# Check for NVIDIA GPU
def check_nvidia_gpu():
    """Report whether PyTorch can see a CUDA device.

    Prints the name of CUDA device 0 when CUDA is available; otherwise prints
    a notice that no NVIDIA GPU was detected.

    Returns:
        bool: True when ``torch.cuda.is_available()`` is true, else False.
    """
    # Guard clause: bail out early when CUDA is unavailable.
    if not torch.cuda.is_available():
        print("No NVIDIA GPU detected.")
        return False

    device_label = torch.cuda.get_device_name(0)
    print(f"NVIDIA GPU detected: {device_label}")
    return True


# Global flag holding the result of the GPU check.
correct_runtime = check_nvidia_gpu()

if not correct_runtime:
    print("Use a compatible runtime with an NVIDIA GPU.")


Free Colab (or T4 GPU) Step 1:

In [None]:
#@markdown Install Stable Audio Tools
!pip install stable-audio-tools

L4 GPU or better — Step 1 (run this instead of the Free Colab / T4 cell above):

In [None]:
#@markdown Install Stable Audio Tools for an L4 GPU Runtime on colab or greater - installs flas_attn
!pip install flash_attn
!pip install stable-audio-tools

Setup Step 3 (Step 2 is restarting the runtime after Step 1 finishes):

In [None]:
# @markdown after restarting from the previous step, run this cell to install LoRAW and then it will require one more reestart.
!git clone https://github.com/NeuralNotW0rk/LoRAW.git
%cd LoRAW
!pip install .
%cd ..

In [None]:
# @title Connect to Google Drive
# Mount Google Drive at /content/drive so later cells can read paths under it
# (the dataset cell's default path lives under /content/drive/MyDrive).
from google.colab import drive
drive.mount('/content/drive')
# Printed after drive.mount() returns.
print("Google Drive connected.")


# Download Model and Configure settings

In [None]:
# @title Download Model from Hugging Face (optionally, you can use your own uploaded model or one from your Google Drive)
import os
import shutil

# Hugging Face token
# @markdown **IMPORTANT!** Make sure your Hugging Face token has the permission "Read access to contents of all public gated repos you can access" enabled in the Access Token Permissions if you are downloading directly from the official Stable Audio base model repo.
# The token is a secret: clear this field before saving or sharing a copy of the notebook.
hf_token_text = "" # @param {type:"string"}
# Directory that receives the copies of model.ckpt and model_config.json.
download_dir = "/content/base_model" # @param {type:"string"}
# Hugging Face repo id that the two files are downloaded from.
repo_name = "stabilityai/stable-audio-open-1.0" # @param {type:"string"}

# Function to handle downloading the model from Hugging Face
def handle_model_loading(hf_token, download_dir, repo_name):
    """Download ``model.ckpt`` and ``model_config.json`` and copy them locally.

    Both files are fetched from ``repo_name`` with ``hf_hub_download`` and then
    copied into ``download_dir`` under their original file names. Every
    failure is reported with ``print`` rather than raised, so the cell keeps
    running.

    Args:
        hf_token: Hugging Face access token. ``None``, empty, or
            whitespace-only values are treated as "no token supplied".
        download_dir: Directory that receives the copies; created if missing.
        repo_name: Hugging Face repo id, e.g. ``"org/model"``.

    Returns:
        None in every case; progress and errors go to stdout.
    """
    # Tolerate None and stray whitespace from a copy-pasted token.
    hf_token = (hf_token or "").strip()
    if not hf_token:
        print("Please provide your Hugging Face API token.")
        return

    # Imported lazily so the missing-token message does not require the package.
    from huggingface_hub import hf_hub_download, HfApi

    # Validate the token itself. Access to a gated repo is only exercised by
    # the downloads below, which report their own errors.
    try:
        HfApi().whoami(token=hf_token)
    except Exception as e:
        print(f"Invalid or insufficient Hugging Face API token: {e}")
        return

    # Create a directory for saving the downloaded files
    os.makedirs(download_dir, exist_ok=True)

    try:
        # `token=` replaces the deprecated `use_auth_token=` keyword.
        model_file = hf_hub_download(
            repo_id=repo_name,
            filename="model.ckpt",
            token=hf_token
        )
        print(f"Downloaded model checkpoint to {model_file}")

        config_file = hf_hub_download(
            repo_id=repo_name,
            filename="model_config.json",
            token=hf_token
        )
        print(f"Downloaded config file to {config_file}")
        print(f"Copying to {download_dir}")
        # hf_hub_download returns a path inside the Hugging Face cache; copy
        # both files to the user-chosen directory under stable names.
        shutil.copy(model_file, os.path.join(download_dir, "model.ckpt"))
        shutil.copy(config_file, os.path.join(download_dir, "model_config.json"))

        print(f"Copied model and config files to {download_dir}.")
    except Exception as e:
        print(f"Error downloading or copying files from Hugging Face: {e}")

# Run the download using the form values entered above.
handle_model_loading(hf_token_text, download_dir, repo_name)

# @markdown When you are done here, the "download_dir" will be where your model.ckpt and model_config.json are located


In [None]:
# @title Set Training Paths and Options
#@markdown Training Config

#@markdown note: if you downloaded the model using the downloader with the default settings then these default paths will work for ckpt_path and model_config

# NOTE(review): of the settings below, the later cells visible in this notebook
# only read ckpt_path, model_config, checkpoint_every, batch_size,
# accum_batches and learning_rate. use_lora is re-assigned in the training
# cell, and pretransform_ckpt_path, save_dir, lora_ckpt_path, relora_every,
# quantize and precision are not referenced by any later visible cell.

# Path to the unwrapped model checkpoint file for a local model
ckpt_path = "/content/base_model/model.ckpt" # @param {type:"string"}
# Path to the model config file for a local model
model_config = "/content/base_model/model_config.json" # @param {type:"string"}

# Hugging Face repository name for a Stable Audio Tools model
# Will prioritize model.safetensors over model.ckpt in the repo
# Optional, used in place of model-config and ckpt-path when using pre-trained model checkpoints on Hugging Face
# pretrained_name = "" # @param {type:"string"}

# Used in various model types such as latent diffusion models to load a pre-trained autoencoder.
# Requires an unwrapped model checkpoint.
pretransform_ckpt_path = "" # @param {type:"string"}

# The directory in which to save the model checkpoints
save_dir = "/trained" # @param {type:"string"}

# The number of steps between saved checkpoints.
# Default: 10000 (NOTE(review): presumably the training script's default; this form defaults to 30)
checkpoint_every = 30 # @param {type:"integer"}

# Number of samples per-GPU during training. Should be set as large as your GPU VRAM will allow.
# Default: 8
batch_size = 8 # @param {type:"integer"}

# Enables and sets the number of batches for gradient batch accumulation. Useful for increasing effective batch size when training on smaller GPUs.
accum_batches = 2 # @param {type:"integer"}

# Learning rate; the "Update Parameters in model_config.json" cell writes it into the config's "lora" section as "lr"
learning_rate = 1e-4 # @param {type:"number"}

# Set to true to enable lora usage
use_lora = True # @param {type:"boolean"}

# A pre-trained lora checkpoint to continue from
lora_ckpt_path = "" # @param {type:"string"}

# Enables ReLoRA training if set, the number of steps between full-rank updates
relora_every = 1000 # @param {type:"integer"}
# @markdown Quantize is currently broken
# Set to true to enable 4-bit quantization of base model for QLoRA training
quantize = False # @param {type:"boolean"}

# @markdown Advanced
# Number of GPUs per-node to use for training
# Default: 1
# num_gpus = 1 # @param {type:"integer"}

# Number of GPU nodes being used for training
# Default: 1
# num_nodes = 1 # @param {type:"integer"}

# Multi-GPU strategy for distributed training. Setting to deepspeed will enable DeepSpeed ZeRO Stage 2.
# Default: ddp if --num_gpus > 1, else None
# strategy = "ddp" # @param {type:"string"}

# Floating-point precision to use during training
# Default: 16
precision = 16 # @param {type:"integer"}

# Number of CPU workers used by the data loader
# num_workers = 8 # @param {type:"integer"}

# RNG seed for PyTorch, helps with deterministic training
# seed = -1 # @param {type:"integer"}



In [None]:
# @title Update Parameters in model_config.json
# @markdown Defaults are fine here. The default sample_size (276480 samples) is roughly 6.2 seconds of audio (assuming a 44.1 kHz model), which suits training on free Colab runtimes.
import json

# Form values written into the model config by this cell.
sample_size = 276480 # @param {type:"number"}
component_whitelist = ["transformer"] # @param {type:"raw"}
multiplier = 1.0 # @param {type:"number"}
rank = 16 # @param {type:"integer"}
alpha = 16 # @param {type:"integer"}
dropout = 0.0 # @param {type:"number"}
module_dropout = 0.0 # @param {type:"number"}

# Read the current model config from disk.
with open(model_config, 'r') as f:
    config = json.load(f)

# LoRA settings to merge in. Keyword order is kept because it decides the
# order of any newly added keys in the saved JSON.
lora_defaults = dict(
    component_whitelist=component_whitelist,
    multiplier=multiplier,
    rank=rank,
    alpha=alpha,
    dropout=dropout,
    module_dropout=module_dropout,
    lr=learning_rate,
)

# Create the "lora" section when absent, then overlay the values above;
# keys already present in that section but not listed here are preserved.
config.setdefault("lora", {}).update(lora_defaults)

# Overwrite the top-level sample size.
config["sample_size"] = sample_size

# Write the result back to the same file, in place.
with open(model_config, 'w') as f:
    json.dump(config, f, indent=2)

print(f"Updated lora parameters and sample size in {model_config}")


In [None]:
# @title Create Dataset Config files
#@markdown This will create the dataset_config.json and metadata.py files required for training
import json
import os

# Dataset folder param
dataset_folder = "/content/drive/MyDrive/YOUR_DATASET_PATH" # @param {type:"string"}

# Random crop param
random_crop = False # @param {type:"boolean"}
#@markdown advanced users can edit this cell to enhance the metadata collection process
dataset_configs_folder = "/content/dataset_configs"
os.makedirs(dataset_configs_folder, exist_ok=True)

# Create dataset config JSON: a single "audio_dir" dataset rooted at
# dataset_folder, with prompts supplied by the generated metadata.py module.
dataset_config = {
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "fine_tune",
            "path": dataset_folder,
            "custom_metadata_module": f"{dataset_configs_folder}/metadata.py"
        }
    ],
    "random_crop": random_crop
}

dataset_config_path = os.path.join(dataset_configs_folder, "dataset_config.json")

with open(dataset_config_path, "w") as f:
    json.dump(dataset_config, f, indent=4)

print(f"Created dataset config file at: {dataset_config_path}")

# Define the content for metadata.py.
# dataset_folder is embedded with !r so it becomes a correctly escaped Python
# string literal; pasting it raw between double quotes would produce broken or
# wrong source for a folder name containing a quote or a backslash.
metadata_py_content = f"""
def get_custom_metadata(info, audio):
    # Remove the dataset folder path from the root of the path
    relative_path = info['path'].replace({dataset_folder!r}, "")
    # Split the remaining path into components and join them with commas, reversing the order
    prompt = ",".join(reversed(relative_path.strip("/").replace("-", " ").replace("_", " ").split("/")))
    return {{"prompt": prompt}}
"""

# Path to the metadata.py file
metadata_py_path = os.path.join(dataset_configs_folder, "metadata.py")

# Write the content to metadata.py
with open(metadata_py_path, 'w') as file:
    file.write(metadata_py_content)

print(f"Created metadata.py at {metadata_py_path}")


# Train

In [None]:
# @title Run Training Command
# Set the paths
loRAW_dir = "/content/LoRAW"
dataset_config_path = dataset_configs_folder+"/dataset_config.json"
model_config_path = model_config
pretrained_ckpt_path = ckpt_path

print(f"Dataset config path: {dataset_config_path}")
print(f"Model config path: {model_config_path}")
print(f"Pretrained ckpt path: {pretrained_ckpt_path}")
# Additional flags
use_lora = True  # @param {type:"boolean"}
%cd {loRAW_dir}
# Construct the training command
training_command = f"python train.py --dataset-config {dataset_config_path} --model-config {model_config_path} --pretrained-ckpt-path {pretrained_ckpt_path} --batch-size {batch_size} --checkpoint-every {checkpoint_every} --accum-batches {accum_batches}"

if use_lora:
    training_command += " --use-lora true"

# Print the command (for debugging purposes)
print(f"Training command: {training_command}")

# Run the training command
!{training_command}

# Use the LoRA in Gradio

In [None]:

#@markdown Gradio Settings:
# If true, a publicly shareable link will be created for the Gradio demo
# share = False # @param {type:"boolean"}

# Used together to set a login for the Gradio demo
username = "guest" # @param {type:"string"}
password = "guest" # @param {type:"string"}
# /content/LoRAW/lightning_logs/YOUR_CHECKPOINTS_FOLDER
lora_checkpoint_path = "/content/LoRAW/lightning_logs/YOUR_FINE_TUNE_ID_HERE/checkpoints" # @param {type:"string"}
# If true, the model weights to half-precision
model_half = False # @param {type:"boolean"}

loRAW_dir = "/content/LoRAW"

%cd {loRAW_dir}
# Construct the training command
gradio_command = f"python run_gradio.py --model-config {model_config_path} --ckpt-path {pretrained_ckpt_path} --username {username} --password {password} --lora-dir {lora_checkpoint_path}"
print(gradio_command)
# Print the command (for debugging purposes)
!{gradio_command}