<a href="https://colab.research.google.com/github/karchkha/MusicLDM/blob/main/MusicLDM_pub.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<font face="Trebuchet MS" size="6">MusicLDM<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Text-to-Music</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="" target="_blank"><font color="#999" size="4">Github</font></a>

Generate audio from a text prompt using [MusicLDM](https://github.com/RetroCirce/MusicLDM).

This notebook uses both the [Diffusers MusicLDM pipeline](https://huggingface.co/ucsd-reach/musicldm) (for text-to-music generation) and the native [MusicLDM Python API](https://github.com/RetroCirce/MusicLDM) (for text-to-music, music-to-music, and style transfer). This makes the first setup slow, but subsequent use fast.


</font>
This colab is based on: https://colab.research.google.com/github/olaviinha/NeuralTextToAudio/blob/main/AudioLDM_pub.ipynb

## Instructions and tips



### <big>__Notebook usage__</big>
#### __General__
- `mount_drive` is optional but highly recommended, as it enables you to auto-save all generated WAV files as well as used checkpoints directly to your Google Drive (and thus sync to your computer in near real-time, if you have Google Drive installed). Should you opt not to mount Google Drive, directory _faux_drive_ (`/content/faux_drive`) found in the Files browser of the Colab runtime works as if it was your _My Drive_. You may use it to upload/download files via Colab's own Files browser pretending it's your Google Drive.
- All directory and file paths should be relative to your Google Drive root (My Drive). E.g. `output_dir` value should be `Music/AI-Generated-Sounds` if you have a directory called _Music_ in your Drive, containing a subdirectory called _AI-Generated-Sounds_. All paths are case-sensitive.
- MusicLDM has only one publicly available checkpoint. It can be run on a standard GPU.
- `local_models_dir` (optional) saves the used checkpoints to your Google Drive and/or loads them from there if they are already available. Using it saves significant setup time the next time you run the notebook.
- `output_dir` is where the generated WAV files will be saved.
- `batch` will just repeat whatever you're generating that many times.
- If `seed` is set to `0` (zero), a random seed will be used.
- You may use a `;` (semicolon) in the `prompt` field as a separator, in which case a separate music file will be generated for each semicolon-separated prompt in a single run.
<br><br>
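
As a rough sketch of how the `prompt`, `batch`, and `seed` fields interact: the helper names `expand_prompts` and `resolve_seed` are hypothetical and exist only for this illustration. They loosely mirror the behavior described above rather than reproduce the notebook's code, and surrounding whitespace is stripped here for tidiness.

```python
import random

def expand_prompts(prompt, batch=1):
    """Split a semicolon-separated prompt and repeat the list `batch` times."""
    prompts = [p.strip() for p in prompt.split(";") if p.strip()]
    return prompts * max(batch, 1)

def resolve_seed(seed):
    """Return the seed unchanged, or draw a random one when it is 0."""
    return seed if seed != 0 else random.randint(1, 2**31 - 1)

# Two prompts with batch=2 give four generations in total.
for text in expand_prompts("song with guitar solo playing; flying car", batch=2):
    print(resolve_seed(0), text)
```

Each printed line corresponds to one WAV file that a run would produce.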


#### __Text-to-Music generation__
- Enter text in the `prompt` field.
- Remove everything from the `init_audio_file` field.
- Set `style_strength` to zero.
<br><br>

#### __Music-to-Music generation__
- Enter a file path in the `init_audio_file` field.
- Remove all text from the `prompt` field.
- Set `style_strength` to zero.
<br><br>

#### __Style Transfer__
- Enter a file path in the `init_audio_file` field.
- Describe the style you want in the `prompt` field.
- Set `style_strength` to a value greater than zero.
<br><br>
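
The three modes above are distinguished by just two fields. A minimal sketch of that decision, assuming the rules are exactly as listed (the function name `select_mode` and the path `Music/loop.wav` are hypothetical, used only for illustration):

```python
def select_mode(init_audio_file, style_strength):
    """Pick the generation mode from the two form fields."""
    if init_audio_file == "":
        return "text-to-music"
    if style_strength > 0:
        return "style-transfer"
    return "music-to-music"

# One example per mode.
for args in [("", 0), ("Music/loop.wav", 0), ("Music/loop.wav", 0.5)]:
    print(args, "->", select_mode(*args))
```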


# Setup

# <font color="#FF0000">After the first run of this cell, please use skip_setup=True for faster setup, unless you disconnect and reconnect the runtime.
</font>

In [None]:
##@title #Setup
#@markdown <small>Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>

# Print colors
class c:
  title = '\033[96m'
  ok = '\033[92m'
  okb = '\033[94m'
  warn = '\033[93m'
  fail = '\033[31m'
  endc = '\033[0m'
  bold = '\033[1m'
  dark = '\33[90m'
  u = '\033[4m'

def op(typex, msg, value='', replaceable=False, time=False):
  if time == True:
    stamp = timestamp(human_readable=True)
    typex = c.dark+stamp+' '+typex
  if value != '':
    print(typex+msg+c.endc, end=' ')
    if replaceable == True:
      print(value, end='\r')
    else:
      print(value)
  else:
    if replaceable == True:
      print(typex+msg+c.endc, end='\r')
    else:
      print(typex+msg+c.endc)


quick_setup = False #@ param {type:"boolean"}

use_diffusers = False #@param {type:"boolean"}
use_github = True if quick_setup == False else False

force_setup = False
repositories = ["https://github.com/karchkha/MusicLDM.git"] #['https://github.com/RetroCirce/MusicLDM'] #  ['--single-branch --branch main https://github.com/RetroCirce/MusicLDM']
apt_packages = 'ffmpeg'
mount_drive = True #@param {type:"boolean"}
skip_setup = False #@param {type:"boolean"}
local_models_dir = "musicldm_model_local" #@param {type:"string"}
use_checkpoint = "musicldm-ckpt"
download_example_wav = True #@param {type:"boolean"}

if '-text-' in use_checkpoint:
  use_diffusers = False

pip_packages = 'transformers diffusers accelerate' if use_diffusers == True else ''

if quick_setup == True:
  op(c.title, 'Performing quick setup')

if use_github == False:
  repositories = []
  apt_packages = ''
  if local_models_dir != '':
    op(c.fail, '!! local_models_dir is ignored on quick setup')
  local_models_dir = ''

use_ckpt = use_checkpoint #+".ckpt"

import os
import yaml
from google.colab import output, files
import warnings
warnings.filterwarnings('ignore')
%cd /content

if skip_setup == False:
  if pip_packages != '':
    !pip -q install {pip_packages}
  if apt_packages != '':
    !apt-get update && apt-get install {apt_packages}

  # fix issue Unexpected key(s) in state_dict: "cond_stage_model.model.text_branch.embeddings.position_ids".
  # !pip install transformers==4.29.0

if use_diffusers == True:
  # pass
  import torch
  from diffusers import MusicLDMPipeline #AudioLDMPipeline
  from transformers import AutoProcessor, ClapModel

  # make Space compatible with CPU duplicates
  if torch.cuda.is_available():
      device = "cuda"
      torch_dtype = torch.float16
  else:
      device = "cpu"
      torch_dtype = torch.float32

  # load the diffusers pipeline
  repo_id = "ucsd-reach/musicldm"

  pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
  # pipe.unet = torch.compile(pipe.unet)

  # CLAP model (only required for automatic scoring)
  clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
  processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
  generator = torch.Generator(device)

import sys, time, ntpath, string, random, librosa, librosa.display, IPython, shutil, math, psutil, datetime, requests, pytz
import numpy as np
import soundfile as sf
from datetime import timedelta

def gen_id(type='short'):
  id = ''
  if type == 'timestamp':
    id = timestamp()
  if type == 'short':
    id = requests.get('https://api.inha.asia/k/?type=short').text
  if type == 'long':
    id = requests.get('https://api.inha.asia/k').text
  return id

def timestamp(no_slash=False, human_readable=False, helsinki_time=True, date_only=False):
  if helsinki_time == True:
    dt = datetime.datetime.now(pytz.timezone('Europe/Helsinki'))
  else:
    dt = datetime.datetime.now()
  if no_slash == True:
    dt = dt.strftime("%Y%m%d%H%M%S")
  else:
    if human_readable == True:
      dt = dt.strftime("%Y-%m-%d %H:%M:%S")
    else:
      if date_only == True:
        dt = dt.strftime("%Y-%m-%d")
      else:
        dt = dt.strftime("%Y-%m-%d_%H%M%S")
  return dt

def fix_path(path, add_slash=False):
  if path.endswith('/'):
    path = path #path[:-1]
  if not path.endswith('/'):
    path = path+"/"
  if path.startswith('/') and add_slash == True:
    path = path[1:]
  return path

def path_leaf(path):
  head, tail = ntpath.split(path)
  return tail or ntpath.basename(head)

def path_dir(path):
  return path.replace(path_leaf(path), '')

def path_ext(path, only_ext=False):
  filename, extension = os.path.splitext(path)
  if only_ext == True:
    extension = extension[1:]
  return extension

def basename(path):
  filename = os.path.basename(path).strip()#.replace(" ", "_")
  filebase = os.path.splitext(filename)[0]
  return filebase

def slug(s):
  valid_chars = "-_. %s%s" % (string.ascii_letters, string.digits)
  file = ''.join(c for c in s if c in valid_chars)
  file = file.replace(' ','_')
  return file

def audio_player(input, sr=44100, limit_duration=2):
  if type(input) != np.ndarray:
    input, sr = librosa.load(input, sr=None, mono=False)
  if limit_duration > 0:
    last_sample = math.floor(limit_duration*60*sr)
    if input.shape[-1] > last_sample:
      # Trim along the time (last) axis only; works for mono and multi-channel arrays
      input = input[..., :last_sample]
      op(c.warn, 'WARN! Playback of below audio player is limited to first '+str(limit_duration)+' minutes to prevent Colab from crashing.\n')
  IPython.display.display(IPython.display.Audio(input, rate=sr))

# Mount Drive
if mount_drive == True:
  if os.path.isdir('/content/mydrive'):
    drive_root = '/content/mydrive/'
  else:
    if not os.path.isdir('/content/drive'):
      from google.colab import drive
      drive.mount('/content/drive')
      drive_root = '/content/drive/My Drive/'
    if not os.path.isdir('/content/mydrive'):
      os.symlink('/content/drive/My Drive', '/content/mydrive')
      drive_root = '/content/mydrive/'
    drive_root_set = True
else:
  drive_root = '/content/faux_drive/'
  if not os.path.isdir(drive_root):
    os.mkdir(drive_root)


if mount_drive == False:
  local_models_dir = ''

if len(repositories) > 0 and skip_setup == False:
  for repo in repositories:
    %cd /content/
    install_dir = fix_path('/content/'+path_leaf(repo).replace('.git', ''))
    repo = repo if '.git' in repo else repo+'.git'
    !git clone {repo}
    if os.path.isfile(install_dir+'setup.py') or os.path.isfile(install_dir+'setup.cfg'):
      !pip install -e {install_dir}
    if os.path.isfile(install_dir+'requirements.txt'):
      !pip install -r {install_dir}/requirements.txt
    elif os.path.isfile(install_dir+'colab_requirements.txt'):
      !pip uninstall torchdata torchtext  -y  # remove unused libraries that conflict with the required ones
      !pip install -r {install_dir}/colab_requirements.txt
      !pip install --upgrade numba
    elif os.path.isfile(install_dir+'musicldm_environment.yml'):
      with open(f"{install_dir}/musicldm_environment.yml", 'r') as file:
          data = yaml.safe_load(file)

      dependencies = data.get('dependencies', [])
      pip_dependencies = []
      for dep in dependencies:
          if isinstance(dep, dict) and 'pip' in dep:
              pip_dependencies.extend(dep['pip'])

      with open(f'{install_dir}requirements.txt', 'w') as file:
          for dep in pip_dependencies:
              file.write(dep + '\n')

      !pip install -r {install_dir}/requirements.txt

      !rm {install_dir}/requirements.txt

else:
  install_dir = "/content/MusicLDM"

if len(repositories) == 1:
  %cd {install_dir}

dir_tmp = '/content/tmp/'
if not os.path.isdir(dir_tmp): os.mkdir(dir_tmp)

use_ckpt_path = os.path.expanduser('~')+'/.cache/musicldm/'

if not os.path.isdir(use_ckpt_path):
  os.makedirs(use_ckpt_path)

#### Model checkpoint is given part by part:
ckpt_urls = {
    "hifigan-ckpt.ckpt" : "https://zenodo.org/records/10643148/files/hifigan-ckpt.ckpt",
    "vae-ckpt.ckpt" : "https://zenodo.org/records/10643148/files/vae-ckpt.ckpt",
    "clap-ckpt.pt" : "https://zenodo.org/records/10643148/files/clap-ckpt.pt",
    "musicldm-ckpt.ckpt" : "https://zenodo.org/records/10643148/files/musicldm-ckpt.ckpt",
}

if all(os.path.isfile(use_ckpt_path + ckpt_name) for ckpt_name in ckpt_urls.keys()):
  op(c.ok, 'Checkpoint found:', use_ckpt_path)
else:
  if local_models_dir != '':
    models_dir = drive_root+fix_path(local_models_dir)
    if not os.path.isdir(models_dir):
      os.makedirs(models_dir)
    # for ckpt_url in ckpt_urls:
    #   use_ckpt = ckpt_url.split('files/')[1].split('?')[0]
    if all(os.path.isfile(models_dir + ckpt_name) for ckpt_name in ckpt_urls.keys()):
      op(c.title, 'Fetching local ckpt:', models_dir.replace(drive_root, '')+use_ckpt)
      for ckpt_name in ckpt_urls.keys():
          shutil.copy(os.path.join(models_dir, ckpt_name), use_ckpt_path)

      op(c.ok, 'Done.')
    else:
      op(c.warn, 'Downloading '+use_ckpt+' to ', models_dir.replace(drive_root, ''))

      %cd {models_dir}

      # !wget {ckpt_url} -O {models_dir}{use_ckpt}
      for name, ckpt_url in ckpt_urls.items():
        if not os.path.isfile(models_dir + name):
          !wget {ckpt_url}

        if not os.path.isfile(use_ckpt_path + name):
          shutil.copy(models_dir+name, use_ckpt_path)

      %cd {install_dir}

      op(c.ok, 'Done.')
  else:
    # for ckpt_url in ckpt_urls:
    #   use_ckpt = ckpt_url.split('files/')[1].split('?')[0]
    if quick_setup == False:
      models_dir = use_ckpt_path
      op(c.warn, 'Downloading', use_ckpt)
      # !wget {ckpt_url} -O {models_dir}{use_ckpt}

      for name, ckpt_url in ckpt_urls.items():
        if not os.path.isfile(models_dir + name):
          !wget {ckpt_url} -O {models_dir}{name}

      op(c.ok, 'Done.')
    else:
      op(c.warn, 'Skipping MusicLDM checkpoints...')

if download_example_wav and skip_setup==False:
  %cd {models_dir}
  if os.path.isfile("369920__mrthenoronha__cartoon-game-theme-loop.wav"):
    print("Example file is already in the google drive folder")
  else:
    !wget https://freesound.org/people/Mrthenoronha/sounds/369920/download/369920__mrthenoronha__cartoon-game-theme-loop.wav
  %cd {install_dir}
   
if use_github:
  _ckpt_path = use_ckpt_path+use_ckpt
  op(c.title, 'Build model', _ckpt_path)
  sys.path.append('/content/MusicLDM/interface/src')


  # from interface import text_to_audio, style_transfer, super_resolution_and_inpainting, build_model, latent_diffusion


##################################################### define functions here ######################################
%cd interface
from pytorch_lightning import seed_everything
import torch
from latent_diffusion.models.musicldm import MusicLDM, DDPM, DDIMSampler

from pipeline import text_to_audio, audio_to_audio, style_transfer, round_to_multiple


def build_model(
    ckpt_path=None,
    config=None,
    model_name="musicldm"
    ):
    print("Load MusicLDM: %s" % model_name)

    # if(ckpt_path is None):
    #     ckpt_path = get_metadata()[model_name]["path"]

    # if(not os.path.exists(ckpt_path)):
    #     download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

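    if config is not None:
        assert type(config) is str
        # Read the YAML config and close the file handle afterwards
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        # config = default_audioldm_config(model_name)
        # No default config is available, so fail early with a clear message
        raise ValueError("build_model requires a path to a YAML config file")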

    # Use text as condition instead of using waveform during training
    # config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # Set the corresponding checkpoint paths in the config:
    config["model"]["params"]["ckpt_path"] = ckpt_path+"musicldm-ckpt.ckpt"
    config['model']['params']['first_stage_config']['params']['reload_from_ckpt'] = ckpt_path+'vae-ckpt.ckpt'
    config['model']['params']['cond_stage_config']['params']['pretrained_path'] = ckpt_path+'clap-ckpt.pt'
    config['model']['params']['first_stage_config']['params']['ddconfig']['hifigan_ckpt'] = ckpt_path+'hifigan-ckpt.ckpt'


    # No normalization here
    # latent_diffusion = LatentDiffusion(**config["model"]["params"])
    latent_diffusion = MusicLDM(**config["model"]["params"])


    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"

    # Do not use unconditional_prob because we are in eval mode:
    latent_diffusion.cond_stage_model.unconditional_prob = 0.0

    return latent_diffusion

##################### Initialize model #####################


audioldm = build_model(ckpt_path=use_ckpt_path, config = "musicldm.yaml", model_name=use_checkpoint)  # _ckpt_path


##################### text to Audio with diffusers ########################

def text2audioDiffusers(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
  waveforms = pipe(
      text,
      audio_length_in_s=duration,
      guidance_scale=guidance_scale,
      negative_prompt=negative_prompt,
      num_waveforms_per_prompt=n_candidates if n_candidates else 1,
      generator=generator.manual_seed(int(random_seed)),
  )["audios"]

  if waveforms.shape[0] > 1:
      waveform = score_waveforms(text, waveforms)
  else:
      waveform = waveforms[0]
  return waveform

def score_waveforms(text, waveforms):
  inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
  inputs = {key: inputs[key].to(device) for key in inputs}
  with torch.no_grad():
      logits_per_text = clap_model(**inputs).logits_per_text  # this is the audio-text similarity score
      probs = logits_per_text.softmax(dim=-1)  # we can take the softmax to get the label probabilities
      most_probable = torch.argmax(probs)  # and now select the most likely audio waveform
  waveform = waveforms[most_probable]
  return waveform


prompt_list = []
used_durations = []

output.clear()
# !nvidia-smi
print()
op(c.title, 'Using:', use_ckpt, time=True)
op(c.ok, 'Setup finished.', time=True)
print()
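
The setup cell above looks for the four checkpoint files in the local cache first, then in `local_models_dir` on Drive, and downloads them only as a last resort. The sketch below illustrates that lookup order under simplifying assumptions: the helper name `locate_checkpoints` is hypothetical, it decides per file (the cell itself checks whether all files are present at once), and it only reports which files would need downloading instead of fetching them.

```python
import os
import shutil
import tempfile

def locate_checkpoints(names, cache_dir, drive_dir=None):
    """Report where each checkpoint file would come from.

    Order: local cache first, then the Drive copy (which is copied into
    the cache); anything else is reported as needing a download.
    """
    os.makedirs(cache_dir, exist_ok=True)
    sources = {}
    for name in names:
        cached = os.path.join(cache_dir, name)
        if os.path.isfile(cached):
            sources[name] = "cache"
        elif drive_dir and os.path.isfile(os.path.join(drive_dir, name)):
            shutil.copy(os.path.join(drive_dir, name), cached)
            sources[name] = "drive"
        else:
            sources[name] = "download"
    return sources

# Demo with temporary directories standing in for the cache and Drive.
with tempfile.TemporaryDirectory() as tmp:
    cache, drive = os.path.join(tmp, "cache"), os.path.join(tmp, "drive")
    os.makedirs(drive)
    open(os.path.join(drive, "vae-ckpt.ckpt"), "w").close()
    print(locate_checkpoints(["vae-ckpt.ckpt", "clap-ckpt.pt"], cache, drive))
```

Running the lookup a second time would report `vae-ckpt.ckpt` as coming from the cache, which is why later setups are faster.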


In [5]:
# @title # Generate audio

prompt ="song with guitar solo playing; flying car"  #@param {type:"string"}
negative_prompt = "high quality" #@param {type:"string"}

output_dir = "output" #@param {type:"string"}
duration = 10 #@param {type:"slider", min:2.5, max:120, step:2.5}
guidance_scale = 2.5 #@param {type:"slider", min:2, max:5, step:0.5}
seed = 0 #@param {type:"integer"}
candidates = 3 #@param {type:"slider", min:2, max:5, step:1}
batch = 1 #@param {type:"integer"}
# use_diffusers = False


#@markdown <br>

#@markdown <br> Use "musicldm_model_local/369920__mrthenoronha__cartoon-game-theme-loop.wav" in the init_audio_file field if you downloaded the example file during setup.

#@markdown #### <b>Style Transfer & Audio-to-Audio</b> settings – Ignore these settings if you just want to generate audio by text prompt.
init_audio_file = "" #@param {type:"string"}
style_strength = 0 #@param {type:"slider", min:0, max:1, step:0.05}

#@markdown <br>

display_initial_output = True
display_players = True
settings_in_filename = False


lowpass_cutoff = 200
hipass_cutoff = 11000
what_to_do = None
superresolution = False
trunc = 150

if what_to_do == 'Audio-to-audio-generation': action = 'audio2audio'
if what_to_do == 'Super-resolution': action = 'superres'
if what_to_do == 'Style Transfer': action = 'style'
if what_to_do == 'Inpaint': action = 'inpaint'

clean_every = 10 if duration <= 10 else 5
ddim_steps = 200
og_seed = seed
og_duration = duration
uniq_id = gen_id()
sr = 16000

# Prompt/input
if ';' in prompt:
  inputs = prompt.split(';')
elif prompt == 'prompt_list':
  inputs = prompt_list
  total_secs = batch * len(inputs) * duration
  total_per_cycle = clean_every * duration
  if total_per_cycle > 60:
    display_players = False
    print()
    op(c.warn, 'Audio players are disabled for this execution to keep Colab running smoothly.')
    print()
else:
  inputs = [prompt]

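# Strip surrounding whitespace from plain-text prompts (list entries carry their own parameters)
if inputs and isinstance(inputs[0], str):
  inputs = [x.strip() for x in inputs]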

# Output
if output_dir == '':
  if mount_drive is True:
    dir_out = dir_tmp
  if mount_drive is False:
    dir_out = drive_root+'generated-audio/'
    if not os.path.isdir(dir_out):
      os.mkdir(dir_out)
else:
  if not os.path.isdir(drive_root+output_dir):
    os.makedirs(drive_root+output_dir)
  dir_out = drive_root+fix_path(output_dir)


og_dir_out = dir_out
if batch == 0: batch = 1
inputs = inputs * batch

timer_start = time.time()
total = len(inputs)
action = 'generate'
init_path = None

op(c.title, 'Run ID:', uniq_id, time=True)

if duration not in used_durations:
  op(c.warn, 'This generation takes a short while to start. Next generations will start instantly.', time=True)
  print()
  used_durations.append(duration)

for i, input in enumerate(inputs, 1):

  dir_out = og_dir_out

  if not i % clean_every+1:
    # output.clear()
    op(c.warn, 'Cell output is cleared every now and then  to keep Colab running smoothly.', time=True)
    op(c.warn, 'You can find all audio files from directory', dir_out.replace(drive_root, ''), time=True)
    print()
    op(c.title, 'Run ID:', uniq_id, time=True)

  prompt = input
  predefined_file_out = ''

  if isinstance(input, list):
    op(c.warn, 'Prompt-specific parameters found, ignoring form values.', time=True)
    prompt = input[0]
    seed = int(input[1])
    og_seed = seed
    guidance_scale = float(input[2])
    outd = input[3]
    if len(input) > 4:
      predefined_file_out = input[4]
    if outd != '':
      if not os.path.isdir(dir_out+outd):
        os.mkdir(dir_out+outd)
      dir_out = dir_out+outd+'/'

  ndx_info = str(i)+'/'+str(total)+' '
  print()

  if os.path.isfile(dir_out+predefined_file_out):
    op(c.warn, ndx_info+'Already exists, skipping', predefined_file_out)
    continue

  if init_audio_file != '':
    if os.path.isfile(drive_root+init_audio_file):
      init_path = drive_root+init_audio_file
      if superresolution is True:
        action = 'superres'
      elif style_strength > 0:
        init_filename = path_leaf(init_path)
        op(c.title, ndx_info+'Styling audio:', init_path.replace(drive_root, ''))
        op(c.title, 'With prompt:', prompt)
        action = 'style'
      else:
        op(c.title, ndx_info+'Audio-to-audio generation:', init_path.replace(drive_root, ''), time=True)
        # op(c.title, 'With prompt:', prompt, time=True)
        prompt = None
        action = 'audio2audio'
      # Trim duration if init duration is shorter than given duration
      init_y, init_sr = librosa.load(init_path, sr=None, mono=True)
      init_duration = librosa.get_duration(y = init_y, sr = init_sr)
      duration = round_to_multiple(init_duration, 2.5) if init_duration < og_duration else duration

    else:
      op(c.fail, ndx_info+'Init audio file not found!')
      sys.exit('Make sure init_audio_file is a valid audio file and a valid file path relative to your My Drive.')
  else:
    op(c.title, ndx_info+'Generating audio:', prompt)

    if isinstance(input, list):
      print('File:', path_leaf(predefined_file_out))
      print('Using seed:', seed)
      print('Using guidance scale:', guidance_scale)

  if og_seed == 0: seed = int(time.time()) - 1229904000 + random.randint(11111111, 99999999)


  addn = str(seed)+'_'+str(guidance_scale)+'_' if settings_in_filename == True else ''
  fo_head = dir_out+uniq_id+'_'+str(i).zfill(3)+'__'+addn

  if action == 'generate':
    if predefined_file_out != '':
      file_out = dir_out+predefined_file_out
    else:
      file_out = fo_head+slug(prompt)[:trunc]+'.wav'

    if use_diffusers == True:
      generated_audio = text2audioDiffusers(prompt, negative_prompt, duration, guidance_scale, seed, candidates)
      sf.write(file_out, generated_audio.T, sr, subtype='PCM_24')
      print("Saved file as:", file_out)
    else:
      # audioldm.logger_save_dir = dir_out
      # audioldm.logger_project = ""
      # audioldm.logger_version = ""
      file_out = uniq_id+'_'+str(i).zfill(3)+'__'+addn+slug(prompt)[:trunc]

      generated_audio = text_to_audio(audioldm,
                                      text = prompt,
                                      original_audio_file_path = None,
                                      seed = seed,
                                      duration=duration,
                                      guidance_scale=guidance_scale,
                                      ddim_steps=ddim_steps,
                                      n_candidate_gen_per_text=int(candidates),
                                      out_dir = dir_out,
                                      name = file_out
                                    )
      print("Saved file as:", generated_audio)

  ################################## audio2audio #############################################
  elif action == 'audio2audio':
    file_out = uniq_id+'_'+str(i).zfill(3)+'__'+addn+basename(init_path)+"_reconstucted" #+'.wav'

    generated_audio = audio_to_audio( audioldm,
                                      text = 'placeholder',
                                      original_audio_file_path = init_path,
                                      seed=seed,
                                      ddim_steps=ddim_steps,
                                      duration=duration,
                                      batchsize=1,
                                      guidance_scale=guidance_scale,
                                      n_candidate_gen_per_text=candidates,
                                      out_dir = dir_out,
                                      name = file_out)

    print("Saved file as:", generated_audio)
  # elif action == 'superres':
#     file_out = fo_head+basename(init_path)+'.wav'
#     y, sr = librosa.load(init_path, sr=None)
#     duration = librosa.get_duration(y = y, sr=sr)
#     if duration > 30: duration = 30
#     generated_audio = superres(None, duration, init_path, guidance_scale, seed, candidates, ddim_steps)

  elif action == 'style':
    file_out = uniq_id+'_'+str(i).zfill(3)+'__'+addn+basename(init_path)+"_style_transfered_as_"+slug(prompt)[:trunc]

    generated_audio = style_transfer( audioldm,
                                      text = prompt,
                                      original_audio_file_path = init_path,
                                      transfer_strength = style_strength,
                                      seed=seed,
                                      duration=duration,
                                      batchsize=1,
                                      guidance_scale=guidance_scale,
                                      ddim_steps=ddim_steps,
                                      config="musicldm.yaml",
                                      out_dir = dir_out,
                                      name = file_out

                                  )
    print("Saved file as:", generated_audio)

  else:
    op(c.fail, 'Something went wrong.')
    sys.exit()

  if display_players and display_initial_output:
    print()
    op(c.title, 'AudioLDM output:')
    print()
    audio_player(generated_audio, sr)
    print()



INFO:lightning_fabric.utilities.seed:Global seed set to 545731821


2024-02-10 19:32:08 Run ID: rikgar

1/2 Generating audio: song with guitar solo playing
Waveform save path:  /content/faux_drive/output/
Generate audio using text song with guitar solo playing
Waveform save path:  /content/faux_drive/output/
Plotting: Switched to EMA weights
Use ddim sampler
Data shape for DDIM sampling is (3, 8, 256, 16), eta 1.0
Running DDIM Sampling with 200 timesteps


DDIM Sampler: 100%|██████████| 200/200 [01:04<00:00,  3.09it/s]


Similarity between generated audio and text tensor([0.4551, 0.5268, 0.4529], device='cuda:0')
Choose the following indexes: [1]
Plotting: Restored training weights
Saved file as: /content/faux_drive/output/rikgar_001__song_with_guitar_solo_playing.wav

AudioLDM output:



INFO:lightning_fabric.utilities.seed:Global seed set to 520971754




2/2 Generating audio:  flying car
Waveform save path:  /content/faux_drive/output/
Generate audio using text  flying car
Waveform save path:  /content/faux_drive/output/
Plotting: Switched to EMA weights
Use ddim sampler
Data shape for DDIM sampling is (3, 8, 256, 16), eta 1.0
Running DDIM Sampling with 200 timesteps


DDIM Sampler: 100%|██████████| 200/200 [01:04<00:00,  3.11it/s]


Similarity between generated audio and text tensor([0.2236, 0.3118, 0.1772], device='cuda:0')
Choose the following indexes: [1]
Plotting: Restored training weights
Saved file as: /content/faux_drive/output/rikgar_002___flying_car.wav

AudioLDM output:





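The `Similarity between generated audio and text` lines in the log above suggest that, with `candidates = 3`, each prompt yields three waveforms and the one with the highest score is kept. A minimal sketch of that selection step, assuming it is a plain argmax over the scores (the helper name `pick_best_candidate` is hypothetical; the native pipeline's actual implementation is not shown in this notebook):

```python
def pick_best_candidate(similarities):
    """Return the index of the candidate with the highest similarity score."""
    return max(range(len(similarities)), key=lambda i: similarities[i])

# Scores copied from the two runs logged above.
print(pick_best_candidate([0.4551, 0.5268, 0.4529]))  # prints 1
print(pick_best_candidate([0.2236, 0.3118, 0.1772]))  # prints 1
```

Both results match the `Choose the following indexes: [1]` lines in the log.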