# Metric learning training set preparation

This notebook prepares a training set from the Mixing Secrets collection for training the timbre metric. Run `../download.ipynb` first.

Copyright 2020 InterDigital R&D and Télécom Paris.  
Author: Ondřej Cífka

In [1]:
import collections
import concurrent.futures as cf
import glob
import hashlib
import json
import os
import regex as re
import shutil
import sys
import traceback

import essentia
import essentia.standard as estd
import numpy as np
import soundfile as sf
from tqdm.auto import tqdm

In [66]:
INPUT_DIR = '../download'  # root directory; each entry directly below it is treated as one song
OUTPUT_DIR = 'wav_16kHz'  # directory the extracted WAV segments are written to
SR_IN = 44100  # sample rate (Hz) assumed for the loaded audio; Essentia default
SR = 16000  # sample rate (Hz) the segments are resampled to before writing
MAX_FILES_PER_SONG = 12  # at most this many (randomly chosen) audio files are kept per song
MAX_SEGMENTS_PER_FILE = 6  # at most this many segments are extracted from each audio file

In [67]:
# Regex fragments; a normalized track name matching any of them is excluded
TRACK_BLACKLIST = [
    r'over\b',
    'overhead',
    'room',
]
# Single alternation group built from all the fragments above
TRACK_BLACKLIST_RE = re.compile('({})'.format('|'.join(TRACK_BLACKLIST)))

In [68]:
def normalize_track_name(name):
    """Reduce a track file name to lower-case words made of letters only.

    Every non-letter character becomes a space, a space is inserted at each
    lower-case/upper-case boundary (splitting camelCase), runs of whitespace
    are collapsed, and the result is stripped and lower-cased.

    Args:
        name: Raw track name (e.g. a file name without its extension).

    Returns:
        The normalized name; may be the empty string if `name` has no letters.
    """
    letters_only = re.sub(r'[^\p{L}]', ' ', name)
    camel_split = re.sub(r'(\p{Ll})(\p{Lu})', r'\1 \2', letters_only)
    single_spaced = re.sub(r'\s+', ' ', camel_split)
    return single_spaced.strip().lower()

In [69]:
# Pick up to MAX_FILES_PER_SONG non-blacklisted audio files from every song.
audio_paths = []
excluded_names = set()
included_names = set()

for song_path in glob.glob(os.path.join(INPUT_DIR, '*')):
    # Per-song RNG seeded from a hash of the song name, so the selection
    # for one song does not depend on which other songs are present
    song_digest = hashlib.sha512(os.path.basename(song_path).encode()).digest()
    song_rng = np.random.default_rng(seed=int.from_bytes(song_digest, 'big'))

    song_audio_paths = []
    for path in glob.glob(os.path.join(song_path, '**', '*.*'), recursive=True):
        stem, ext = os.path.splitext(os.path.basename(path))
        if ext not in ('.wav', '.flac'):
            continue
        name = normalize_track_name(stem)
        if name and not TRACK_BLACKLIST_RE.search(name):
            included_names.add(name)
            song_audio_paths.append(path)
        else:
            # Blacklisted name, or nothing left after normalization
            excluded_names.add(name)

    # Sort before shuffling so the result is independent of glob order
    song_audio_paths.sort()
    song_rng.shuffle(song_audio_paths)
    audio_paths += song_audio_paths[:MAX_FILES_PER_SONG]
audio_paths.sort()

In [70]:
def process_file(path):
    """Extract randomly chosen, mostly non-silent 8-second segments from one file.

    The audio is loaded with Essentia and scanned for silence in
    non-overlapping frames of SR_IN // 10 samples. Candidate start positions
    are shuffled with an RNG seeded from the file's base name, trimmed to an
    even count and capped at MAX_SEGMENTS_PER_FILE. Each selected segment is
    resampled from SR_IN to SR and written as a WAV file under OUTPUT_DIR.

    Args:
        path: Path of the source audio file, located under INPUT_DIR.

    Returns:
        A list with one dict per written segment, with keys 'path' (relative
        to OUTPUT_DIR), 'track_name' (first path component below INPUT_DIR),
        'src_path' (relative to INPUT_DIR) and 'src_offset' (seconds).
        Empty if fewer than two candidate positions were found.
    """
    # Use filename as seed
    name_digest = hashlib.sha512(os.path.basename(path).encode()).digest()
    rng = np.random.default_rng(seed=int.from_bytes(name_digest, 'big'))

    audio = estd.EasyLoader(filename=path)()
    frame_size = SR_IN // 10
    silence_alg = estd.SilenceRate(thresholds=[essentia.db2lin(-60 / 2)])
    # NOTE(review): FrameGenerator is used with its default frame alignment;
    # presumably frame k starts at sample k * frame_size - confirm against the
    # startFromZero default.
    frames = estd.FrameGenerator(audio, frameSize=frame_size, hopSize=frame_size)
    # One silence value per frame, flattened to 1-D
    silence = np.array([silence_alg(frame) for frame in frames]).reshape(-1)

    # Find 8-second segments with < 50% silence and starting with a non-silent frame
    frames_per_segment = 8 * SR_IN // frame_size
    silence_cumsum = np.pad(np.cumsum(silence), (1, 0))
    # silence_sums[j] is the sum of silence[j + 1 : j + frames_per_segment + 1]
    silence_sums = (silence_cumsum[frames_per_segment + 1:-1]
                    - silence_cumsum[1:-frames_per_segment - 1])
    mostly_sound = silence_sums < frames_per_segment // 2
    starts_with_sound = silence[:-frames_per_segment - 1] == 0
    # Frame indices -> sample offsets
    candidates = np.where(mostly_sound & starts_with_sound)[0] * frame_size

    rng.shuffle(candidates)
    # Keep an even number of segments (dropping the last one if necessary)
    n_usable = len(candidates) - len(candidates) % 2
    if n_usable < 2:
        return []
    candidates = candidates[:n_usable][:MAX_SEGMENTS_PER_FILE]

    # Parts of the output name and metadata shared by all segments of this file
    rel_src_path = os.path.relpath(path, INPUT_DIR)
    flat_name = rel_src_path.replace(os.path.sep, '_')
    out_stem = os.path.splitext(os.path.join(OUTPUT_DIR, flat_name))[0]
    song_name = rel_src_path.split(os.path.sep)[0]
    # NOTE(review): this width is derived from MAX_SEGMENTS_PER_FILE, but the
    # value being zero-padded below is a sample offset, not a segment index.
    i_len = len(str(MAX_SEGMENTS_PER_FILE - 1))

    meta = []
    for offset in candidates:
        out_path = f'{out_stem}.{str(offset).zfill(i_len)}.wav'
        segment_audio = audio[offset:offset + 8 * SR_IN]
        resampled = estd.Resample(inputSampleRate=SR_IN, outputSampleRate=SR)(segment_audio)
        sf.write(out_path, resampled, samplerate=SR)

        meta.append({
            'path': os.path.relpath(out_path, OUTPUT_DIR),
            'track_name': song_name,
            'src_path': rel_src_path,
            'src_offset': offset / SR_IN
        })
    return meta

In [72]:
# Extract segments from all selected files in parallel and flatten the
# per-file metadata lists into a single list.
# exist_ok=True keeps the cell re-runnable (Restart & Run All) when the output
# directory is already there; segments are written to deterministic paths, so
# a re-run overwrites them.
os.makedirs(OUTPUT_DIR, exist_ok=True)
with cf.ProcessPoolExecutor(16) as pool:
    # pool.map yields results in the order of audio_paths
    results = [item
               for items in tqdm(pool.map(process_file, audio_paths, chunksize=4), total=len(audio_paths))
               for item in items]

HBox(children=(FloatProgress(value=0.0, max=4874.0), HTML(value='')))




In [73]:
# Save the metadata of all extracted segments as a single JSON list
with open('metadata_single.json', 'w') as metadata_file:
    json.dump(results, metadata_file)

In [74]:
# Index the segment metadata by source file and by song
# (the 'track_name' field holds the song name)
results_by_file = collections.defaultdict(list)
results_by_song = collections.defaultdict(list)
for segment in results:
    results_by_file[segment['src_path']].append(segment)
    results_by_song[segment['track_name']].append(segment)

In [75]:
# Generate triplets: anchor, positive example, negative example
# Positive examples are from the same file
# Negative examples are from different songs
# Each output line holds the three segment paths separated by tabs.

rng = np.random.default_rng(seed=0)

with open('triplets', 'w') as f:
    # Iterate over the files in sorted order so the RNG draws are reproducible
    for src_path, items in sorted(results_by_file.items()):
        # Randomly pair up this file's segments as (anchor, positive);
        # an unpaired last segment is dropped by zip
        items = list(items)
        rng.shuffle(items)

        for anchor, positive in zip(items[::2], items[1::2]):
            # 'track_name' holds the song name: pick any other song, then one of its segments
            negative_track_name = rng.choice([tn for tn in results_by_song if tn != anchor['track_name']])
            negative = rng.choice(results_by_song[negative_track_name])
            print(*(os.path.join(OUTPUT_DIR, x['path']) for x in [anchor, positive, negative]),
                  sep='\t', file=f)

In [76]:
# Shuffle the triplets and split them into train / validation / test sets.
# A seeded in-Python shuffle (rather than the shell's unseeded `shuf`) keeps the
# split identical across runs and does not depend on GNU-specific head/tail options.
with open('triplets') as f:
    triplet_lines = f.readlines()
np.random.default_rng(seed=0).shuffle(triplet_lines)

# The last 200 lines form the test set, the 200 before them the validation set,
# and everything else the training set. With fewer than 400 triplets the
# training set is empty and validation/test stay disjoint.
n_val = 200
n_test = 200
n_train = max(len(triplet_lines) - n_val - n_test, 0)
triplet_splits = {
    'triplets_shuf': triplet_lines,
    'triplets_train': triplet_lines[:n_train],
    'triplets_val': triplet_lines[n_train:n_train + n_val],
    'triplets_test': triplet_lines[n_train + n_val:],
}
for split_path, split_lines in triplet_splits.items():
    with open(split_path, 'w') as f:
        f.writelines(split_lines)

In [77]:
# Sanity check: print the number of lines in each triplet file
!wc -l triplets*

   7781 triplets
   7781 triplets_shuf
    200 triplets_test
   7381 triplets_train
    200 triplets_val
  23343 total
