<a href="https://colab.research.google.com/github/buganart/descriptor-transformer/blob/main/descriptor_model_predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown Before starting, please save a copy of the notebook to your drive by clicking `File -> Save a copy in drive`

In [None]:
#@markdown Check the GPU (it should be a Tesla V100)
!nvidia-smi -L
import os
print(f"We have {os.cpu_count()} CPU cores.")

In [None]:
#@markdown Mount google drive
from google.colab import drive
from google.colab import output
drive.mount('/content/drive')

from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    raise RuntimeError(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the Google Drive web UI\n"
        "\t2. Right-click the shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

def clear_on_success(msg="Ok!"):
    if _exit_code == 0:
        output.clear()
        print(msg)

In [None]:
#@markdown Install wandb and log in
%pip install wandb
output.clear()
import wandb
from pathlib import Path
wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login
output.clear()
print("ok!")

# Description

This notebook generates music (.wav) from runs in the wandb project "demiurge/descriptor_model". The models are trained with [train.ipynb](https://github.com/buganart/descriptor-transformer/blob/main/descriptor_model_train.ipynb). The user specifies a **test_data_path** pointing to a folder of sound files (.wav). The notebook extracts descriptors (.json) for each sound file, predicts the descriptors that follow, and converts the predicted descriptors back into .wav files. The generated sound files are predictions of plausible continuations of the input files.

To generate these predictive sound files, the notebook performs the following steps:


1.   Process the input music files in **test_data_path** and the music descriptor database in **audio_dir** into descriptors, unless they have already been processed. The processed descriptors are saved in a "processed_descriptors" folder in the same path; if they already exist, this step is skipped. **hop_length** and **sampling rate (sr)** are the parameters for converting music to descriptors.
2.   Load a trained descriptor model from the wandb project "demiurge/descriptor_model". Set **resume_run_id** and the saved checkpoint of that run will be downloaded. The model loaded from the checkpoint predicts the next **prediction_length** descriptors.
3.   Query the music descriptor database in **audio_dir** with the predicted descriptors. Each predicted descriptor is replaced by its nearest descriptor in the database under a distance function such as Euclidean, cosine or Minkowski distance (the QUERY FUNCTION section uses Euclidean distance via `torch.cdist` with `p=2`).
4.   Map each matched database descriptor back to the music segment it was extracted from, then merge those segments into the generated music file. **crossfade** is a parameter of the merging step. The generated music files are saved in **output_dir**.
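
Step 3 is a nearest-neighbour lookup. The plain-Python sketch below is only an illustration of why the choice of distance function matters, not the notebook's implementation: the two-feature descriptor rows are made up, and `nearest_index` is a hypothetical helper. The three distance functions have the same meaning as `euclidean`, `cosine` and `minkowski` from `scipy.spatial.distance`, which the notebook imports. On this toy data each metric picks a different database row.

```python
import math
from typing import Callable, Sequence

def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    # straight-line distance between two descriptor vectors
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def minkowski(u: Sequence[float], v: Sequence[float], p: float = 2.0) -> float:
    # p=2 is Euclidean distance, p=1 is Manhattan (city-block) distance
    return sum(abs(a - b) ** p for a, b in zip(u, v)) ** (1.0 / p)

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    # 1 - cosine similarity: compares direction and ignores overall magnitude
    dot = sum(a * b for a, b in zip(u, v))
    return 1.0 - dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nearest_index(query, database, distance: Callable) -> int:
    # hypothetical helper: position of the database row closest to the query
    return min(range(len(database)), key=lambda i: distance(query, database[i]))

database = [[1.0, 2.0], [10.0, 10.0], [4.0, 0.0]]   # made-up descriptor rows
predicted = [4.0, 4.0]                               # made-up predicted descriptor

print(nearest_index(predicted, database, euclidean))                          # 0
print(nearest_index(predicted, database, cosine))                             # 1
print(nearest_index(predicted, database, lambda u, v: minkowski(u, v, p=1)))  # 2
```

Because the descriptor features live on very different scales (a spectral centroid in Hz versus a spectral flatness between 0 and 1), the metric and any normalization applied before the query strongly influence which segments are selected.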





In [None]:
#@title CONFIGURATION

#@markdown Directories can be found via the file explorer in the left menu by navigating to `drive` and then to the desired folder.
#@markdown Then right-click the folder and select `Copy path`.

#@markdown ### #descriptor model input

#@markdown The descriptor extractor produces a .json file containing *spectral centroid/spectral flatness/fundamental frequency/spectral rolloff/RMS* data from the .wav files in test_data_path below. The model will predict **prediction_length** descriptors that follow each test descriptor file.
#@markdown - This is the **Prediction DB** containing data for the model to generate next descriptors.
#@markdown - The model predicts the next **prediction_length** descriptors given **window_size** (specified in the model) descriptors
#@markdown - If test_data_path is a music directory, descriptors will be extracted from **test_data_path** and saved in **output_dir**
# test_data_path = "/content/drive/My Drive/AUDIO DATABASE/MUSIC TRANSFORMER/Transformer Corpus/" #@param {type:"string"}
# test_data_path = "/content/drive/My Drive/AUDIO DATABASE/MUSIC TRANSFORMER/sample_descriptor_files" #@param {type:"string"}
test_data_path = "/content/drive/My Drive/AUDIO DATABASE/TESTING/" #@param {type:"string"}

#@markdown ### #descriptor database

#@markdown - The path to the .wav file database used to generate the descriptor database
#@markdown - This is the **RAW generated audio DB**, which is used only by the query and playback engine.
#@markdown - The descriptors predicted by the model need to be converted back to music. The files in this folder are used to build a database mapping descriptors to sounds, which is used to convert descriptors back to music.
audio_dir = "/content/drive/My Drive/AUDIO DATABASE/TESTING/" #@param {type:"string"}

#@markdown - Descriptors will be extracted from **audio_dir** above, but if you provide an **input_db_filename**, that file will be used instead
# input_db_filename = f"/content/drive/My Drive/Descriptor Model/robertos_output.json" #@param {type:"string"}
# input_db_filename = "/content/drive/My Drive/AUDIO DATABASE/TESTING/output_descriptor_database.json" #@param {type:"string"}
input_db_filename = "" #@param {type:"string"}

#@markdown ### #resumption of previous runs
#@markdown The run ID below selects the trained model; its saved checkpoint will be downloaded from wandb.
#@markdown - The IDs can be found on wandb. Each ID is 8 characters long and may contain lowercase letters a-z and digits (for example **1t212ycn**)

#@markdown Resume a previous run 
resume_run_id = "lny7atep" #@param {type:"string"}

#@markdown ### #descriptors / sound parameter
#@markdown - the number of predicted descriptors after the **test_data**
prediction_length = 40 #@param {type:"integer"}

#@markdown - wav parameters: hop length (in samples), sampling rate (in Hz), crossfade (in milliseconds)
hop_length = 1024 #@param {type:"integer"}
sr = 44100 #@param {type:"integer"}
crossfade = 22 #@param {type:"integer"}

#@markdown ### #save location
#@markdown - the path to save all generated files
output_dir = f"/content/drive/My Drive/Descriptor Model/OUTPUTS/{resume_run_id}" #@param {type:"string"}
# #@markdown name of generated files
# #@markdown - the file storing generated descriptors from the model
# generated_descriptor_filename = "AUDIOS_output.json" #@param {type:"string"}
# #@markdown - the file storing closest match query descriptors based on generated descriptors
# query_descriptor_filename = "query_output.json" #@param {type:"string"}
# #@markdown - the final wav file from combining music source represented by the query descriptors
# final_wav_filename = "output.wav" #@param {type:"string"}


hop_length = int(hop_length)
sr = int(sr)
crossfade = int(crossfade)

import re
from pathlib import Path
from argparse import Namespace

def check_wandb_id(run_id):
    if run_id and not re.match(r"^[\da-z]{8}$", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            f"Got \"{run_id}\""
        )

check_wandb_id(resume_run_id)

output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

#remove existing files from previous runs (directories are kept)
output_dir_files = list(output_dir.rglob("*.*"))
for i in output_dir_files:
    if i.is_file():
        i.unlink()


colab_config = {
    "resume_run_id": resume_run_id,
    "test_data_path": test_data_path,
    "prediction_length": prediction_length,
    "output_dir": output_dir,
}

for k, v in colab_config.items():
    print(f"=> {k:20}: {v}")

config = Namespace(**colab_config)
config.seed = 1234
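
The configuration above asks the model for **prediction_length** new descriptors given **window_size** input descriptors. A common way for a windowed sequence model to do this is autoregressively: predict one step, append it to the window, and repeat. Whether `model.predict` in `desc` works this way is an assumption that this notebook does not confirm. The plain-Python sketch below uses a hypothetical `predict_autoregressive` helper and a toy `mean_step` function in place of the trained model.

```python
from typing import Callable, List

Vector = List[float]

def predict_autoregressive(history: List[Vector], window_size: int,
                           prediction_length: int,
                           step: Callable[[List[Vector]], Vector]) -> List[Vector]:
    # hypothetical helper: the model only ever sees the last window_size descriptors
    if len(history) < window_size:
        raise ValueError("input has fewer descriptors than window_size")
    window = list(history[-window_size:])
    predictions: List[Vector] = []
    for _ in range(prediction_length):
        next_vec = step(window)            # predict one descriptor from the window
        predictions.append(next_vec)
        window = window[1:] + [next_vec]   # slide: drop the oldest, append the prediction
    return predictions

def mean_step(window: List[Vector]) -> Vector:
    # toy stand-in for the trained model: per-feature average of the window
    return [sum(column) / len(window) for column in zip(*window)]

history = [[0.0, 1.0], [2.0, 1.0], [4.0, 1.0]]   # made-up descriptor frames
preds = predict_autoregressive(history, window_size=2, prediction_length=3, step=mean_step)
print(preds)   # [[3.0, 1.0], [3.5, 1.0], [3.25, 1.0]]
```

The length check reflects the notebook's own warning that too large a window_size yields no usable samples; under this scheme, later predictions are built on earlier predictions, so errors can accumulate as prediction_length grows.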

In [None]:
#@markdown Install dependencies and import helper functions
%pip install --upgrade git+https://github.com/buganart/descriptor-transformer.git#egg=desc
from desc.train_function import get_resume_run_config, init_wandb_run, setup_model, setup_datamodule
from desc.helper_function import save_descriptor_as_json, dir2descriptor, save_json, get_dataframe_from_json

%pip install --upgrade librosa
import librosa

import numpy as np
import json
import os, os.path
from IPython.display import HTML, display
import time
import shutil

import pandas as pd
from numba import jit, cuda 
from scipy.spatial.distance import cosine, minkowski, euclidean
import torch


%pip install pydub
%pip install ffmpeg

from pydub import AudioSegment
from pydub.playback import play


def progress(value, max=100):
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

clear_on_success()

## WAV TO DESCRIPTOR
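
The cell in this section concatenates the per-file descriptor dictionaries returned by `dir2descriptor` into one column-oriented database, which is saved as `AUDIOS_database.json`. The sketch below reproduces that merge on hypothetical data. It assumes, as the cell's use of `+` suggests, that each descriptor value is a Python list. `_id` (the source .wav path) and `_sample` (the sample offset) are the columns the query and playback cells read; `centroid` is an illustrative feature name.

```python
from typing import Dict, List, Tuple

Descriptor = Dict[str, list]

def merge_descriptors(per_file: List[Tuple[str, Descriptor]]) -> Descriptor:
    # concatenate every column across files, preserving file order
    data: Descriptor = {}
    for _filename, descriptor in per_file:
        for key, values in descriptor.items():
            data.setdefault(key, []).extend(values)
    return data

# hypothetical output of dir2descriptor for two files
per_file = [
    ("a.wav", {"_id": ["a.wav", "a.wav"], "_sample": [0, 1024], "centroid": [510.0, 620.0]}),
    ("b.wav", {"_id": ["b.wav"], "_sample": [0], "centroid": [300.0]}),
]
db = merge_descriptors(per_file)
print(db["_id"], db["_sample"])   # ['a.wav', 'a.wav', 'b.wav'] [0, 1024, 0]
```

Row `k` of every column then describes the same audio frame, which is what lets the query step look up `_id` and `_sample` by row position.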







In [None]:
#process input descriptor database if needed
if not input_db_filename:
    save_path = output_dir
    db_descriptors = dir2descriptor(audio_dir, hop=hop_length, sr=sr)

    #combine descriptors from multiple files
    data_dict = {}
    for filename, descriptor in db_descriptors:
        for element in descriptor:
            if element in data_dict:
                data_dict[element] = data_dict[element] + descriptor[element]
            else:
                data_dict[element] = descriptor[element]
    
    #replace empty input_db_filename by savefile name
    input_db_filename = Path(save_path) / "AUDIOS_database.json"
    save_json(input_db_filename, data_dict)

## DESCRIPTOR MODEL GENERATOR
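
The generator cell maps the model output back to descriptor units with `prediction * testdatamodule.dataset_std + testdatamodule.dataset_mean`. The sketch below shows, on made-up numbers, that this is the exact inverse of per-feature standardization. It assumes `dataset_mean` and `dataset_std` are per-feature statistics of the data, which the notebook does not spell out; the helper names are hypothetical.

```python
import statistics
from typing import List, Tuple

Row = List[float]

def fit_standardizer(rows: List[Row]) -> Tuple[Row, Row]:
    # per-feature (column-wise) mean and population standard deviation
    columns = list(zip(*rows))
    mean = [statistics.fmean(c) for c in columns]
    std = [statistics.pstdev(c) for c in columns]
    return mean, std

def normalize(row: Row, mean: Row, std: Row) -> Row:
    # z-score: how many standard deviations each feature is from its mean
    return [(x - m) / s for x, m, s in zip(row, mean, std)]

def unnormalize(row: Row, mean: Row, std: Row) -> Row:
    # inverse of normalize; same form as `prediction * dataset_std + dataset_mean`
    return [z * s + m for z, m, s in zip(row, mean, std)]

rows = [[1.0, 2.0], [3.0, 6.0]]        # made-up descriptor frames, two features
mean, std = fit_standardizer(rows)     # mean [2.0, 4.0], std [1.0, 2.0]
z = normalize(rows[1], mean, std)      # [1.0, 1.0]
restored = unnormalize(z, mean, std)   # [3.0, 6.0]
print(mean, std, z, restored)
```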


In [None]:
config = get_resume_run_config(resume_run_id)
config.resume_run_id = resume_run_id
config.audio_db_dir = test_data_path
# Check window_size: if it is too large, no descriptor samples will be extracted.
#print(config.window_size)

run = init_wandb_run(config, run_dir="./", mode="offline")
model,_ = setup_model(config, run)
model.eval()
#construct test_data
testdatamodule = setup_datamodule(config, run, isTrain=False)
test_dataloader = testdatamodule.test_dataloader()
test_data, fileindex = next(iter(test_dataloader))

prediction = model.predict(test_data, prediction_length)

# un-normalize the output (invert the dataset standardization)
prediction = prediction * testdatamodule.dataset_std + testdatamodule.dataset_mean

generated_dir = output_dir / "generated_descriptors"
generated_dir.mkdir(parents=True, exist_ok=True)

save_descriptor_as_json(generated_dir, prediction, fileindex, testdatamodule, resume_run_id)
print("ok!")

## QUERY FUNCTION
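
The query cell computes pairwise distances with `torch.cdist` in batches of 4096 predicted descriptors, because the full (input_len, db_len) distance matrix does not fit in memory. Batching over the query rows cannot change the result, since each row's nearest neighbour depends only on that row. The pure-Python sketch below checks this on random toy data; `argmin_rows` and `batched_argmin` are hypothetical helpers, and a batch size of 4 stands in for 4096.

```python
import math
import random
from typing import List, Sequence

def argmin_rows(queries: Sequence[Sequence[float]], database) -> List[int]:
    # index of the nearest database row (Euclidean) for every query row
    return [min(range(len(database)), key=lambda j: math.dist(q, database[j]))
            for q in queries]

def batched_argmin(queries, database, batch_size: int) -> List[int]:
    # process the query rows in slices, like the notebook's batch loop
    results: List[int] = []
    for start in range(0, len(queries), batch_size):
        results.extend(argmin_rows(queries[start:start + batch_size], database))
    return results

random.seed(0)
database = [[random.random() for _ in range(5)] for _ in range(50)]
queries = [[random.random() for _ in range(5)] for _ in range(23)]

full = argmin_rows(queries, database)
batched = batched_argmin(queries, database, batch_size=4)
print(full == batched, len(batched))   # True 23
```

Stepping `range(0, len(queries), batch_size)` also avoids the empty trailing batch that `range(int(input_len / batch_size) + 1)` produces when input_len is an exact multiple of the batch size.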


In [None]:
query_dir = output_dir / "query_descriptors"
query_dir.mkdir(parents=True, exist_ok=True)
print("query_dir:", query_dir)

# import df1 (UnaGAN output)
input_db_filename = Path(input_db_filename)
df1 = get_dataframe_from_json(input_db_filename)

# import df2 (Descriptor GAN output)
generated_file_list = generated_dir.rglob("*.*")
generated_dataframe_list = []
for filepath in generated_file_list:
    df2 = get_dataframe_from_json(filepath)
    generated_dataframe_list.append((filepath, df2))

In [None]:
# nearest-neighbour query, computed in batches on the GPU
for filepath, df2 in generated_dataframe_list:
    #record runtime
    current_time = time.time()
    dict_key1 = list(df2.columns)[0]
    input_len = len(df2[dict_key1])
    column_list = list(df2.columns)
    input_array = torch.tensor(df2.loc[:, column_list].to_numpy(dtype=np.float32)).cuda()
    db = torch.tensor(df1.loc[:, column_list].to_numpy(dtype=np.float32)).cuda()


    # not enough RAM for a distance array of shape (input_len, db_len), so query in batches
    batch_size = 4096
    results_all = []
    for start in range(0, input_len, batch_size):
        batch = input_array[start:start + batch_size]
        dist = torch.cdist(batch, db, p=2)  # Euclidean distances, shape (batch, db_len)
        results = torch.argmin(dist, dim=1).cpu().numpy()  # position of the nearest db row
        results_all.append(results)

    results_all = np.concatenate(results_all).flatten()

    # .iloc selects by position, matching the row positions returned by argmin
    id_array = df1["_id"].iloc[results_all]
    sample_array = df1["_sample"].iloc[results_all]

    data={
        "_id": id_array.tolist(), 
        "_sample": sample_array.tolist()
    }
    print("finished - saving as JSON now")

    
    savefile = query_dir / (str(filepath.stem) + ".json")
    with open(savefile, 'w') as outfile:
        json.dump(data, outfile, indent=2)

    print("descriptors are replaced by query descriptors in database. save file path: ", savefile)

    #record runtime
    step_time = time.time() - current_time
    print("time used:", step_time)

## PLAYBACK ENGINE
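
The playback cell converts each queried `_sample` offset into milliseconds (`hop = (hop_length / sr) * 1000`, about 23.22 ms with the defaults), takes one hop-long slice per descriptor, and joins the slices with pydub's `append(..., crossfade=...)`. The plain-Python sketch below, a hypothetical `segment_plan` helper, lists the millisecond window taken from each source file and estimates the output length. It assumes a slice without crossfade is `[startpos, startpos + hop]`, that slices with crossfade start `crossfade / 2` ms early, and that every slice lies fully inside its source file; the sample offsets are made up.

```python
from typing import List, Tuple

def segment_plan(samples: List[int], hop_length: int = 1024, sr: int = 44100,
                 crossfade: int = 22) -> Tuple[List[Tuple[float, float, int]], float]:
    # hypothetical helper mirroring the playback loop's arithmetic (all times in ms)
    hop = hop_length / sr * 1000                   # one descriptor frame: ~23.22 ms by default
    plan: List[Tuple[float, float, int]] = []
    total_ms = 0.0
    for i, sample in enumerate(samples):
        startpos = int(sample / hop_length * hop)  # sample offset -> ms, as in the notebook
        if i == 0 or startpos < crossfade:
            start, fade = float(startpos), 0       # no crossfade lead-in
        else:
            start, fade = startpos - crossfade / 2, crossfade
        end = startpos + hop
        plan.append((start, end, fade))
        # appending with a crossfade overlaps `fade` ms with the audio already joined
        total_ms += (end - start) - fade
    return plan, total_ms

# made-up `_sample` offsets for four queried descriptors
plan, total_ms = segment_plan([0, 10240, 512, 4096])
for start, end, fade in plan:
    print(f"{start:8.2f} -> {end:8.2f} ms, crossfade {fade} ms")
print(f"approximate output length: {total_ms:.2f} ms")
```

pydub slices on whole frames and reports lengths rounded to whole milliseconds, so real output lengths differ slightly from this estimate. pydub also raises a `ValueError` when the crossfade is longer than either segment, which is why the 22 ms default must stay below one hop.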



In [None]:
wav_dir = output_dir / "wav_output"
wav_dir.mkdir(parents=True, exist_ok=True)
print("wav_dir:", wav_dir)

query_file_list = query_dir.rglob("*.*")
query_dataframe_list = []
for filepath in query_file_list:
    to_play = get_dataframe_from_json(filepath)
    query_dataframe_list.append((filepath, to_play))

In [None]:
for filepath, to_play in query_dataframe_list:
    output_filename = wav_dir / (str(filepath.stem) + ".wav")
    
    # output_filename = output_dir / final_wav_filename

    if os.path.exists(output_filename):
        os.remove(output_filename)

    no_samples = len(to_play["_sample"])
    out = display(progress(0, no_samples), display_id = True)

    concat = AudioSegment.from_wav(to_play["_id"][0])
    hop = (hop_length / sr) * 1000
    startpos = int((float(to_play["_sample"][0]) / hop_length) * hop)

    concat = concat[startpos:startpos + hop]

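    for x in range(1, no_samples):
        to_concat = AudioSegment.from_wav(to_play["_id"][x])
        startpos = int((float(to_play["_sample"][x]) / hop_length) * hop)
        if startpos < crossfade:
            # too close to the file start for a crossfade lead-in: take one hop, no crossfade
            to_concat = to_concat[startpos:startpos + hop]
            thiscrossfade = 0
        else:
            # start half a crossfade early so the crossfade region is centered on startpos
            to_concat = to_concat[startpos - (crossfade / 2):startpos + hop]
            thiscrossfade = crossfade
        out.update(progress(x + 1, no_samples))

        # pydub's crossfade is in milliseconds
        concat = concat.append(to_concat, crossfade=thiscrossfade)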

    concat.export(output_filename, format = "wav")