<a href="https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Finetune_Dance_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dance Diffusion finetune

Licensed under the MIT License

Copyright (c) 2022 Zach Evans

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.


# Set Up

In [1]:
#@title Check GPU Status
import subprocess
simple_nvidia_smi_display = True #@param {type:"boolean"}

def _nvidia_smi(*args):
    """Run ``nvidia-smi`` with ``args`` and return its stdout decoded as UTF-8.

    If the ``nvidia-smi`` executable cannot be found, return an explanatory
    message instead of raising, so this status-only cell reports the problem
    rather than crashing with a traceback.
    """
    try:
        completed = subprocess.run(['nvidia-smi', *args], stdout=subprocess.PIPE)
    except FileNotFoundError:
        return 'nvidia-smi not found: this runtime does not appear to have an NVIDIA GPU driver.'
    return completed.stdout.decode('utf-8')

if simple_nvidia_smi_display:
    # `-L` prints one line per visible GPU (name + UUID).
    nvidiasmi_output = _nvidia_smi('-L')
    print(nvidiasmi_output)
else:
    # Full nvidia-smi status table.
    nvidiasmi_output = _nvidia_smi()
    print(nvidiasmi_output)
    # NOTE(review): per the nvidia-smi manual, `-e 0` *sets* the ECC mode of GPU 0 to
    # disabled (root + reboot required) rather than merely reporting it — confirm intended.
    nvidiasmi_ecc_note = _nvidia_smi('-i', '0', '-e', '0')
    print(nvidiasmi_ecc_note)

GPU 0: Tesla T4 (UUID: GPU-00ddb01c-2415-da34-a7a6-f166e6f38a0c)



In [2]:
#@title Prepare folders
import subprocess, os, sys, ipykernel

def gitclone(url, targetdir=None):
    """Clone the git repository at ``url`` and print git's captured stdout.

    When ``targetdir`` is truthy the clone is written there; otherwise git
    chooses the destination itself. Only stdout is captured.
    """
    clone_cmd = ['git', 'clone', url]
    if targetdir:
        clone_cmd.append(targetdir)
    completed = subprocess.run(clone_cmd, stdout=subprocess.PIPE)
    print(completed.stdout.decode('utf-8'))

def pipi(modulestr):
    """Install ``modulestr`` with ``pip`` and echo whatever pip wrote to stdout."""
    pip_cmd = ['pip', 'install']
    pip_cmd.append(modulestr)
    captured = subprocess.run(pip_cmd, stdout=subprocess.PIPE).stdout
    print(captured.decode('utf-8'))

def pipie(modulestr):
    """Install ``modulestr`` in editable mode (``pip install -e``) and print pip's stdout.

    ``modulestr`` is handed to ``pip install -e`` unchanged (e.g. a local project path).
    """
    # Editable installs are a pip feature; `git` has no `install` subcommand.
    res = subprocess.run(['pip', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(res)

def wget(url, outputdir):
    """Download ``url`` into directory ``outputdir`` (``wget -P``) and print wget's stdout."""
    wget_cmd = ['wget', url, '-P', f'{outputdir}']
    output_bytes = subprocess.run(wget_cmd, stdout=subprocess.PIPE).stdout
    print(output_bytes.decode('utf-8'))

try:
    from google.colab import drive
    print("Google Colab detected. Using Google Drive.")
    is_colab = True
    google_drive = True #@param {type:"boolean"}
    #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:
    save_models_to_google_drive = True #@param {type:"boolean"}
except ImportError:
    # Only a missing google.colab module means "not running on Colab"; any other
    # error (or a KeyboardInterrupt) must surface instead of being swallowed.
    is_colab = False
    google_drive = False
    save_models_to_google_drive = False
    print("Google Colab not detected.")

# root_path is the base folder for all notebook I/O: a folder on the mounted
# Drive, Colab's /content, or the current working directory off-Colab.
if is_colab:
    if google_drive:
        drive.mount('/content/drive')
        root_path = '/content/drive/MyDrive/AI/Bass_Diffusion'
    else:
        root_path = '/content'
else:
    root_path = os.getcwd()

import os
def createPath(filepath):
    """Ensure directory ``filepath`` exists, creating missing parents as needed.

    Calling it on an already existing directory is a no-op.
    """
    os.makedirs(name=filepath, mode=0o777, exist_ok=True)

# Working folders for input and generated audio, both under root_path.
initDirPath, outDirPath = (f'{root_path}/{leaf}' for leaf in ('init_audio', 'audio_out'))
createPath(initDirPath)
createPath(outDirPath)

# Model checkpoints live under root_path, except on Colab when Drive mounting
# and Drive checkpoint saving are not both enabled; then the VM-local
# /content/models is used instead.
if not is_colab or (google_drive and save_models_to_google_drive):
    model_path = f'{root_path}/models'
else:
    model_path = '/content/models'
createPath(model_path)

#@markdown Click here if you'd like to save the diffusion model checkpoint file to your [Weights & Biases](https://wandb.ai/site) account:
save_models_to_wandb = True #@param {type:"boolean"}
# CLI flag string, presumably meant to be appended to the training command;
# empty when W&B checkpoint saving is off.
# NOTE(review): none of the training commands visible in this notebook use it — confirm it is still needed.
save_wandb_str = '--save-wandb all' if save_models_to_wandb else ''
if save_models_to_wandb:
    print('Saving model checkpoints in wandb')

Google Colab detected. Using Google Drive.
Mounted at /content/drive


In [3]:
#@title Install dependencies
!git clone https://github.com/harmonai-org/sample-generator
!pip install /content/sample-generator

Cloning into 'sample-generator'...
remote: Enumerating objects: 413, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 413 (delta 11), reused 18 (delta 7), pack-reused 389 (from 1)[K
Receiving objects: 100% (413/413), 59.57 MiB | 33.44 MiB/s, done.
Resolving deltas: 100% (236/236), done.
Processing ./sample-generator
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting prefigure (from sample-generator==1.0.0)
  Downloading prefigure-0.0.10-py3-none-any.whl.metadata (8.3 kB)
Collecting pytorch_lightning (from sample-generator==1.0.0)
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting argparse (from prefigure->sample-generator==1.0.0)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting configparser (from prefigure->sample-generator==1.0.0)
  Downloading configparser-7.1.0-py3-none-any.whl.metadata (5.4 kB)
Collecting gradio (from prefigure

# Train

In [1]:
#@markdown Log in to [Weights & Biases](https://wandb.ai/) for run tracking
# Explicit form of the `!wandb login` shell escape; the CLI prompts for an API
# key when none is stored yet.
get_ipython().system('wandb login')

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzcheukhei86[0m ([33mzcheukhei86-the-university-of-hong-kong[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
#@markdown Name for the finetune project, used as the W&B project name, as well as the directory for the saved checkpoints
NAME="dd-basses" #@param {type:"string"}

#@markdown Path to the directory of audio data to use for fine-tuning
TRAINING_DIR="/content/drive/MyDrive/AItrainingdata_hiphop" #@param {type:"string"}

#@markdown Path to the checkpoint to fine-tune
CKPT_PATH="/content/drive/MyDrive/AI/models/gwf-440k.ckpt" #@param {type:"string"}

#@markdown Directory path for saving the fine-tuned outputs
OUTPUT_DIR="/content/drive/MyDrive/AI/models/DanceDiffusion/finetune" #@param {type:"string"}

#@markdown Number of training steps between demos
DEMO_EVERY=250 #@param {type:"number"}

#@markdown Number of training steps between saving model checkpoints
CHECKPOINT_EVERY=500 #@param {type:"number"}

#@markdown Sample rate to train at
SAMPLE_RATE=44100 #@param {type:"number"}

#@markdown Number of audio samples per training sample
SAMPLE_SIZE=65536 #@param {type:"number"}

#@markdown If true, the audio samples provided will be randomly cropped to SAMPLE_SIZE samples
#@markdown
#@markdown Turn off if you want to ensure the training data always starts at the beginning of the audio files (good for things like drum one-shots)
RANDOM_CROP=False #@param {type:"boolean"}

#@markdown Batch size to fine-tune (make it as high as it can go for your GPU)
BATCH_SIZE=2 #@param {type:"number"}

#@markdown Accumulate gradients over n batches, useful for training on one GPU.
#@markdown
#@markdown Effective batch size is BATCH_SIZE * ACCUM_BATCHES.
#@markdown
#@markdown Also increases the time between demos and saved checkpoints
ACCUM_BATCHES=4 #@param {type:"number"}

random_crop_str = f"--random-crop True" if RANDOM_CROP else ""

# Escape spaces in paths
CKPT_PATH = CKPT_PATH.replace(f" ", f"\ ")
OUTPUT_DIR = f"{OUTPUT_DIR}/{NAME}".replace(f" ", f"\ ")

%cd /content/sample-generator/

ckpt_path_str = f"--ckpt-path {CKPT_PATH}" if not CKPT_PATH =="" else ""

!python3 /content/sample-generator/train_uncond.py $ckpt_path_str \
                                                          --name "dd-basses" \
                                                          --training-dir "/content/drive/MyDrive/AI/models/DanceDiffusion/finetune" \
                                                          --sample-size 65536 \
                                                          --accum-batches 4 \
                                                          --sample-rate 44100 \
                                                          --batch-size 2 \
                                                          --demo-every 250 \
                                                          --checkpoint-every 500 \
                                                          --num-workers 2 \
                                                          --num-gpus 1 \
                                                          $random_crop_str \
                                                          --save-path "/content/drive/MyDrive/AI/models/DanceDiffusion/finetune"

/content/sample-generator
Using device: cuda
Random crop: False
Traceback (most recent call last):
  File "/content/sample-generator/train_uncond.py", line 228, in <module>
    main()
  File "/content/sample-generator/train_uncond.py", line 199, in main
    train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True,
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 376, in __init__
    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/sampler.py", line 164, in __init__
    raise ValueError(
ValueError: num_samples should be a positive integer value, but got num_samples=0


In [None]:
#@markdown Name for the finetune project, used as the W&B project name, as well as the directory for the saved checkpoints
NAME="dd-basses" #@param {type:"string"}

#@markdown Path to the directory of audio data to use for fine-tuning
TRAINING_DIR="/content/drive/MyDrive/AItrainingdata_hiphop" #@param {type:"string"}

#@markdown Path to the checkpoint to fine-tune
CKPT_PATH="/content/drive/MyDrive/AI/models/gwf-440k.ckpt" #@param {type:"string"}

#@markdown Directory path for saving the fine-tuned outputs
OUTPUT_DIR="/content/drive/MyDrive/AI/models/DanceDiffusion/finetune" #@param {type:"string"}

#@markdown Number of training steps between demos
DEMO_EVERY=250 #@param {type:"number"}

#@markdown Number of training steps between saving model checkpoints
CHECKPOINT_EVERY=500 #@param {type:"number"}

#@markdown Sample rate to train at
SAMPLE_RATE=44100 #@param {type:"number"}

#@markdown Number of audio samples per training sample
SAMPLE_SIZE=65536 #@param {type:"number"}

#@markdown If true, the audio samples provided will be randomly cropped to SAMPLE_SIZE samples
#@markdown
#@markdown Turn off if you want to ensure the training data always starts at the beginning of the audio files (good for things like drum one-shots)
RANDOM_CROP=False #@param {type:"boolean"}

#@markdown Batch size to fine-tune (make it as high as it can go for your GPU)
BATCH_SIZE=2 #@param {type:"number"}

#@markdown Accumulate gradients over n batches, useful for training on one GPU.
#@markdown
#@markdown Effective batch size is BATCH_SIZE * ACCUM_BATCHES.
#@markdown
#@markdown Also increases the time between demos and saved checkpoints
ACCUM_BATCHES=4 #@param {type:"number"}

random_crop_str = f"--random-crop True" if RANDOM_CROP else ""

# Escape spaces in paths
CKPT_PATH = CKPT_PATH.replace(f" ", f"\ ")
OUTPUT_DIR = f"{OUTPUT_DIR}/{NAME}".replace(f" ", f"\ ")

%cd /content/sample-generator/

ckpt_path_str = f"--ckpt-path {CKPT_PATH}" if not CKPT_PATH =="" else ""

!python3 /content/sample-generator/train_uncond.py $ckpt_path_str \
                                                          --name "dd-basses" \
                                                          --training-dir "/content/drive/MyDrive/AItrainingdata_hiphop" \
                                                          --sample-size 65536 \
                                                          --accum-batches 4 \
                                                          --sample-rate 44100 \
                                                          --batch-size 2 \
                                                          --demo-every 250 \
                                                          --checkpoint-every 500 \
                                                          --num-workers 2 \
                                                          --num-gpus 1 \
                                                          $random_crop_str \
                                                          --save-path "/content/drive/MyDrive/AI/models/DanceDiffusion/finetune"

/content/sample-generator
Using device: cuda
Random crop: False
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mzcheukhei86[0m ([33mzcheukhei86-the-university-of-hong-kong[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.19.7
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20250307_054949-lhmg00o2[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mrestful-waterfall-5[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/zcheukhei86-the-university-of-hong-kong/dd-basses[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/zcheukhei86-the-university-of-hong-kong/dd-basses/runs/lhmg00o2[0m
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=F