<a href="https://colab.research.google.com/github/olaviinha/NeuralTextToMusic/blob/main/mubert_txt2music.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">Mubert <font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">text2music</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralTextToMusic" target="_blank"><font color="#999" size="4">Github</font></a>

This is a modified version of Mubert [Text-to-Music notebook](https://colab.research.google.com/github/ferluht/Mubert-Text-to-Music/blob/main/Mubert_Text_to_Music.ipynb). Original resides at [MubertAI Github](https://github.com/MubertAI/Mubert-Text-to-Music). It uses [Mubert](https://mubert.com/) [API](https://mubert2.docs.apiary.io/#) to generate tracks based on text prompt.

### Tips
- At the time of writing this, Mubert API endpoint used by this notebook does not appear extremely reliable; one minute it may work like a dream, the next minute every generation attempt will result in an error.

- If you fill in the `use_tags` field, the prompt(s) will be ignored and a single track with the provided tags is generated instead. In other words, use **either** prompt(s) **or** tags; they cannot be used simultaneously.

- You may generate multiple tracks with multiple prompts in one go by using semicolon (`;`) as a prompt separator.

- When tags are used, a minimum of 2 tags is required. Separate tags with a comma (`,`). You will see a list of available tags after running the Setup cell.

- If you have drive mounted and enter an `output_dir`, the notebook will attempt to auto-save all files there (prone to failures). Enter path relative to your Google Drive root, e.g. `music/mubert` if your Drive contains a directory called _music_, with a subdirectory called _mubert_.

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.<br>
#@markdown <small>Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>

# Setup switches read by the rest of this cell.
# force_setup: consulted when deciding whether to fetch inhagcutils.
force_setup = False
# repositories: git URLs to clone and pip-install; empty here, so the clone loop is skipped.
repositories = []
# pip_packages / apt_packages: extra package names interpolated into the install commands.
pip_packages = ''
apt_packages = ''
mount_drive = True #@param {type:"boolean"}
# NOTE(review): the space in "#@ param" looks like a deliberate way to hide this field from the form — confirm.
skip_setup = False #@ param {type:"boolean"}

# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  if apt_packages != '':
    !apt-get update && apt-get install {apt_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive
# drive_root always ends with a slash because later code concatenates
# relative paths directly onto it.
if mount_drive is True:
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # Space-free alias for "My Drive"; lexists also detects a dangling link
  # left behind by an earlier run, which isdir would miss.
  if not os.path.lexists('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Assigned unconditionally so the name exists even when Drive was already
  # mounted and linked (e.g. after a kernel restart).
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  # Local stand-in directory so path handling works without Drive.
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

# Clone and install any extra repositories (none are configured in this notebook).
# fix_path and path_leaf are not defined in this file — presumably provided by the inhagcutils star import.
if len(repositories) > 0 and skip_setup == False:
  for repo in repositories:
    %cd /content/
    # Presumably the repo's last URL component, with ".git" removed, under /content/ — confirm against path_leaf/fix_path.
    install_dir = fix_path('/content/'+path_leaf(repo).replace('.git', ''))
    # Make sure the clone URL carries the .git suffix.
    repo = repo if '.git' in repo else repo+'.git'
    !git clone {repo}
    # Editable install when the repo ships packaging metadata.
    if os.path.isfile(install_dir+'setup.py') or os.path.isfile(install_dir+'setup.cfg'):
      !pip install -e ./{install_dir}
    if os.path.isfile(install_dir+'requirements.txt'):
      !pip install -r {install_dir}/requirements.txt

# With exactly one repository, make its directory the working directory.
# NOTE(review): install_dir is only assigned inside the loop above, so this would raise NameError if skip_setup were True.
if len(repositories) == 1:
  %cd {install_dir}

# Scratch directory; it becomes dir_out in the generation cell when output_dir is empty.
dir_tmp = '/content/tmp/'
create_dirs([dir_tmp])

import time, sys
from datetime import timedelta





import subprocess
import time

# Install the Python dependencies one command at a time. Failures are reported
# but do not abort the cell, so the remaining installs still get a chance.
print("Setting up environment...")
start_time = time.time()
all_process = [
    ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
    ['pip', 'install', '-U', 'sentence-transformers'],
    ['pip', 'install', 'httpx'],
]
for process in all_process:
    # stderr is folded into the captured stream so the reason for a failed
    # install can be shown below instead of being lost.
    completed = subprocess.run(process, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    running = completed.stdout.decode('utf-8', errors='replace')
    if completed.returncode != 0:
        print(f"WARNING: `{' '.join(process)}` failed with exit code {completed.returncode}:")
        # The tail of pip's output is where the actual error normally is.
        print(running[-2000:])

end_time = time.time()
print(f"Environment set up in {end_time-start_time:.0f} seconds")






import numpy as np
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used to embed both the tag vocabulary and the user prompts.
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Comma-separated vocabulary of tags that prompts are matched against.
# Kept on a single physical line: a quoted literal may not contain a raw line break.
mubert_tags_string = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'
# Tag vocabulary as a NumPy array (so it can be indexed with an index array)
# and the embeddings the model produces for those tags.
mubert_tags = np.array(mubert_tags_string.split(','))
mubert_tags_embeddings = minilm.encode(mubert_tags)

from IPython.display import Audio, display
import httpx
import json

def get_track_by_prompt(prompt, pat, duration, maxit=20, autoplay=False, loop=False, intensity="High", format="MP3"):
  """Request a track for a free-text prompt and show a player once it is ready.

  Args:
    prompt: Text sent as the "text" parameter of the record request.
    pat: Access token sent with every request.
    duration: Requested track length, forwarded unchanged.
    maxit: Maximum number of status polls (roughly one per second).
    autoplay: Forwarded to the IPython Audio widget.
    loop: When true the request asks for mode "loop", otherwise "track".
    intensity: Forwarded as the "intensity" request parameter.
    format: Forwarded as the "format" request parameter.

  Returns:
    (job_status, trackurl): the last polled task status code and the download
    link from the record response. A status of 2 is what the polling loop
    treats as finished; any other value means polling gave up first.

  Raises:
    AssertionError: If the record response status is not 1; the message is the
      error text supplied in the response.
  """
  job_status = 1
  mode = "loop" if loop else "track"
  # NOTE(review): this path is "TTMRecordTrack" while get_track_by_tags posts to
  # "RecordTrackTTM" — confirm against the Mubert API which one is intended.
  r = httpx.post('https://api-b2b.mubert.com/v2/TTMRecordTrack', 
      json={
          "method":"RecordTrackTTM",
          "params": {
              "pat": pat, 
              "duration": duration,
              "format": format,
              "bitrate": 192,
              "intensity": intensity,
              "text": prompt,
              "mode": mode,
          }
      })

  rdata = json.loads(r.text)
  assert rdata['status'] == 1, rdata['error']['text']
  trackurl = rdata['data']['tasks'][0]['download_link']

  # Poll the task status until it reports code 2 or the attempts run out.
  for _ in range(maxit):
    r = httpx.post('https://api-b2b.mubert.com/v2/TrackStatus', 
        json={
          "method":"TrackStatus",
          "params":
            {
                "pat": pat
            }
        })
    rdata = json.loads(r.text)
    job_status = rdata['data']['tasks'][0]['task_status_code']
    if job_status == 2:
      # Short grace period before handing the link to the player.
      time.sleep(5)
      display(Audio(trackurl, autoplay=autoplay))
      break
    time.sleep(1)

  # Same return shape as get_track_by_tags, so callers can treat both alike.
  return job_status, trackurl

def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False, intensity="High", format="MP3"):
  """Request a track for a list of tags and show a player once it is ready.

  Sends a RecordTrackTTM request, then polls TrackStatus up to `maxit` times
  (about one second apart). When the first listed task reports status code 2,
  an Audio widget for the download link is displayed and polling stops.

  Returns:
    (job_status, trackurl): the last polled status code (1 if never polled)
    and the download link from the record response.

  Raises:
    AssertionError: If the record response status is not 1.
  """
  record_params = {
      "pat": pat,
      "duration": duration,
      "format": format,
      "bitrate": 192,
      "intensity": intensity,
      "tags": tags,
      "mode": "loop" if loop else "track",
  }
  record_reply = httpx.post(
      'https://api-b2b.mubert.com/v2/RecordTrackTTM',
      json={"method": "RecordTrackTTM", "params": record_params},
  )
  record_data = json.loads(record_reply.text)
  assert record_data['status'] == 1, record_data['error']['text']
  trackurl = record_data['data']['tasks'][0]['download_link']

  job_status = 1
  polls_left = maxit
  while polls_left > 0:
    polls_left -= 1
    status_reply = httpx.post(
        'https://api-b2b.mubert.com/v2/TrackStatus',
        json={"method": "TrackStatus", "params": {"pat": pat}},
    )
    status_data = json.loads(status_reply.text)
    job_status = status_data['data']['tasks'][0]['task_status_code']
    if job_status != 2:
      time.sleep(1)
      continue
    # Finished: wait a little before handing the link to the player.
    time.sleep(5)
    display(Audio(trackurl, autoplay=autoplay))
    break

  return job_status, trackurl

def find_similar(em, embeddings, method='cosine'):
    """Score every row of `embeddings` by its distance to the vector `em`.

    Args:
        em: Query vector.
        embeddings: Iterable of reference vectors with the same length as `em`.
        method: 'cosine' for cosine distance (1 - cosine similarity) or
            'norm' for the Euclidean norm of the difference.

    Returns:
        (scores, order): `scores` holds one distance per reference vector and
        `order` holds the indices that sort `scores` ascending, i.e. the
        closest reference comes first.

    Raises:
        ValueError: If `method` is neither 'cosine' nor 'norm'. Previously an
            unknown method silently produced empty results.
    """
    if method not in ('cosine', 'norm'):
        raise ValueError(f"unknown method {method!r}; expected 'cosine' or 'norm'")

    if method == 'cosine':
        # The query norm does not depend on the reference, so compute it once.
        em_norm = np.linalg.norm(em)
        scores = [1 - np.dot(ref, em)/(np.linalg.norm(ref)*em_norm) for ref in embeddings]
    else:
        scores = [np.linalg.norm(ref - em) for ref in embeddings]

    scores = np.array(scores)
    return scores, np.argsort(scores)

def get_tags_for_prompts(prompts, top_n=3, debug=False):
    """Pick the `top_n` closest Mubert tags for each prompt.

    Every prompt is embedded with the module-level model and compared against
    the tag embeddings via find_similar (default cosine distance).

    Returns:
        A list of (prompt, [tag, ...]) tuples in the order of `prompts`.
        With `debug` set, the chosen tags and their similarity scores are
        printed as well.
    """
    matches = []
    for idx, embedding in enumerate(minilm.encode(prompts)):
        distances, order = find_similar(embedding, mubert_tags_embeddings)
        best = order[:top_n]
        chosen = mubert_tags[best]
        if debug:
            # Distance back to similarity, for display only.
            similarity = 1 - distances[best]
            print(f"Prompt: {prompts[idx]}\nTags: {', '.join(chosen)}\nScores: {similarity}\n\n\n")
        matches.append((prompts[idx], list(chosen)))
    return matches













# email = "email@domain.com" #@param {type:"string"}
# Obtain an access token (pat) for the Mubert API using a throwaway address.
# gen_id is not defined in this file — presumably an inhagcutils helper that returns a random string.
email = f"{gen_id()}@{gen_id()}.com"

access_request = {
    "method": "GetServiceAccess",
    "params": {
        "email": email,
        "license": "ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5",
        "token": "4951f6428e83172a4f39de05d5b3ab10d58560b8",
        "mode": "loop",
    },
}
r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', json=access_request)

rdata = json.loads(r.text)
assert rdata['status'] == 1, "probably incorrect e-mail"
# The token is reused by every record/status request in this session.
pat = rdata['data']['pat']
print(f'Got token: {pat}')



































# Wipe the cell output accumulated during setup before printing the summary.
output.clear()
# !nvidia-smi
# op and c are not defined in this file — presumably inhagcutils print helpers.
op(c.ok, 'Setup finished.')

def divide_chunks(l, n):
    """Lazily yield consecutive slices of `l`, each at most `n` items long."""
    yield from (l[start:start + n] for start in range(0, len(l), n))

# List the selectable tags alphabetically, eight per printed line.
available_tags = sorted(mubert_tags_string.split(','))
available_tag_chunks = [*divide_chunks(available_tags, 8)]
print()
op(c.title, 'Available tags (for use_tags field):')
for tag_row in available_tag_chunks:
  print(', '.join(tag_row))

In [None]:
#@title # Generate tracks
# One or more prompts; a semicolon separates multiple prompts.
prompt = "grandma's blueberries; good ol country folk ballad; enter the void" #@param {type:"string"}

#@markdown <small>If you use tags, prompt(s) will be ignored. At least two tags are required.</small>
# Comma-separated tags; leave empty to generate from the prompt(s) instead.
use_tags = "" #@param {type:"string"}

# Track length, e.g. "1:00"; converted to seconds further down in this cell.
duration = "1:00" #@param {type:"string"}
music_intensity = "high" #@param ["low", "medium", "high"]
# NOTE: this name shadows the built-in format() for the rest of the notebook.
format = "mp3" #@param ["mp3", "flac", "wav"]
# Directory relative to the Drive root for auto-saving; empty disables auto-save.
output_dir = "" #@param {type:"string"}



# NOTE(review): "#@ param" (with a space) presumably hides this switch from the form — confirm.
end_session_when_done = False #@ param {type: "boolean"}


# Per-run identifier that ends up in the saved file names.
uniq_id = gen_id()
intensity = music_intensity

# A non-empty use_tags field overrides the prompt(s) entirely.
tags_used = use_tags != ''
if tags_used:
  prompts = ['tags-only']
elif ';' in prompt:
  prompts = [part.strip() for part in prompt.split(';')]
else:
  prompts = [prompt]

# Convert the duration field to whole seconds.
# Accepted forms: "M:SS", "XmYs", "Xm", "Ys" or a bare number of seconds.
duration = str(duration).lower()
if ':' in duration:
  m, s = duration.split(':')
  duration = int(m) * 60 + int(s)
elif 'm' in duration:
  m, s = duration.split('m')
  # "2m" leaves nothing after the 'm'; treat the missing seconds as zero.
  # (Strip the unit while s is still a string, then convert.)
  s = s.replace('s', '').strip()
  duration = (int(m) * 60) + (int(s) if s else 0)
else:
  duration = int(duration.replace('s', ''))

# Output
# Without an output_dir nothing is auto-saved and dir_out just points at the
# scratch directory.
if output_dir == '':
  dir_out = dir_tmp
else:
  # makedirs creates missing parent directories too, so a nested path such as
  # music/mubert works even when "music" does not exist yet; exist_ok makes
  # re-runs harmless.
  os.makedirs(drive_root+output_dir, exist_ok=True)
  dir_out = drive_root+fix_path(output_dir)
  
timer_start = time.time()
total = len(prompts)

# -- DO THINGS --
# Build [label, tag list] entries: one manual entry when tags were typed in,
# otherwise one per prompt chosen by embedding similarity.
if tags_used is not True:
  tags = get_tags_for_prompts(prompts)
else:
  use_tags = [t.strip() for t in use_tags.split(',')] if ',' in use_tags else [use_tags]
  tags = [[prompts[0], use_tags]]

retry_urls = []

for i, tag in enumerate(tags, 1):
  
  used_prompt = tag[0]
  used_tags = tag[1]
  
  if tags_used is True:
    op(c.title, str(i)+'/'+str(total)+' Tags: '+', '.join(used_tags), time=True)
    thing = ', '.join(used_tags)
  else:
    op(c.title, str(i)+'/'+str(total)+' Prompt: '+used_prompt, time=True)
    op(c.okb, 'Tags: '+', '.join(used_tags), time=True)
    thing = used_prompt
  print()
  try:
    
    status, url = get_track_by_tags(used_tags, pat, duration, maxit=30, autoplay=False, intensity=intensity, format=format)
    if output_dir != '':
      time.sleep(1)
      if tags_used is True:
        file_out = dir_out+slug('_'.join(used_tags))+'__'+uniq_id+'-'+str(i)+'.'+format
      else:
        file_out = dir_out+slug(used_prompt)+'__'+uniq_id+'-'+str(i)+'.'+format
      print()
      if status == 2:
        !wget -q -O {file_out} {url}
        if os.path.isfile(file_out) and os.path.getsize(file_out) > 0:
          op(c.ok, 'Saved as', file_out.replace(drive_root, ''), time=True)
        elif os.path.isfile(file_out) and os.path.getsize(file_out) == 0:
          os.remove(file_out)
          op(c.fail, 'Received a file of 0 bytes, skipping auto-save for now.', time=True)
          # op(c.fail, 'Received a file of 0 bytes: file not auto-saved.')
          if total == 1: op(c.fail, 'You can try to download the file manually from this link:', url)
          retry_urls.append([thing, file_out, url])
        else:
          op(c.fail, 'ERROR saving file', file_out.replace(drive_root, ''), time=True)
          # op(c.fail, 'You can try to download the file manually from this link:', url)
      else:
        op(c.fail, 'Generation is taking too long, skipping auto-save for now.', time=True)
        if total==1: op(c.fail, 'You can try to download the file manually from this link:', url)
        retry_urls.append([thing, file_out, url])
    print()
  except Exception as e:
    print(str(e))
  print()


if len(retry_urls) > 0 and total > 1:
  op(c.warn, '- - - - - - - - - - - - - - - - - - - - - -')
  op(c.warn, 'Checking if skipped tracks are ready yet...')
  print()
  time.sleep(3)
  for retry_url in retry_urls:
    thing, file_out, url = retry_url
    op(c.title, thing, time=True)
    print()
    !wget -q -O {file_out} {url}
    if os.path.isfile(file_out) and os.path.getsize(file_out) > 0:
      display(Audio(file_out, autoplay=False))
      op(c.ok, 'Saved as', file_out.replace(drive_root, ''), time=True)
    elif os.path.isfile(file_out) and os.path.getsize(file_out) == 0:
      os.remove(file_out)
      op(c.fail, 'Received a file of 0 bytes again, file still not saved.', time=True)
      op(c.fail, 'You can try to download the file manually from this link:', url, time=True)
    else:
      op(c.fail, 'ERROR saving file', file_out.replace(drive_root, ''), time=True)
      op(c.fail, 'You can try to download the file manually from this link:', url, time=True)
    print()

# -- END THINGS --

timer_end = time.time()

print()
# Report the wall-clock time elapsed since timer_start.
op(c.okb, 'Elapsed', timedelta(seconds=timer_end-timer_start), time=True)
op(c.ok, 'FIN.')

# end_session is not defined in this file — presumably an inhagcutils helper; confirm what it terminates.
if end_session_when_done is True: end_session()