# Imports and Constants

In [3]:
import jukemirlib
from pydub import AudioSegment
from tqdm import tqdm
import numpy as np
import os
import IPython.display as ipd

In [None]:
# Sample rate (Hz) every audio file is resampled to before segmentation.
JUKEBOX_SR = 44_100
# 2**20 == 1048576; not referenced by the cells visible here
# (presumably a context-window size in samples -- TODO confirm).
CTX_WINDOW_LENGTH = 2 ** 20

# Functions

In [None]:
def extract_jukebox_samples(file_path, segment_length_s=10, cutoff='pad', overlap=0.1):
    """Cut an audio file into peak-normalized mono segments at JUKEBOX_SR.

    The file is loaded with pydub, resampled to ``JUKEBOX_SR`` and sliced into
    windows of ``segment_length_s`` seconds. Consecutive windows start
    ``(1 - overlap) * segment_length_s`` seconds apart. Each window is
    down-mixed to mono by averaging its channels and divided by its maximum
    absolute value (all-zero windows are left unchanged).

    Args:
        file_path: Path of the audio file to load.
        segment_length_s: Length of every window, in seconds.
        cutoff: Policy for a window shorter than the target length (the end
            of the file):
            ``'pad'``   - zero-pad it to full length, unless it is shorter
                          than half the target length, in which case it is
                          dropped and iteration stops;
            ``'leave'`` - keep it at its shorter length;
            ``'crop'``  - drop it and stop.
            Any other value behaves like ``'leave'``.
        overlap: Fraction of a window that is shared with the next window.

    Returns:
        A list of 1-D float32 numpy arrays, one per kept window.
    """
    audio = AudioSegment.from_file(file_path)
    audio = audio.set_frame_rate(JUKEBOX_SR)

    # pydub indexes and measures audio in milliseconds.
    segment_length_ms = segment_length_s * 1000
    step = int((1 - overlap) * segment_length_ms)
    sample_len = int(JUKEBOX_SR * segment_length_s)

    segment_array = []
    for start_time in range(0, len(audio), step):
        chunk = audio[start_time:start_time + segment_length_ms]

        # Interleaved PCM -> (frames, channels) -> mono by channel average.
        # float32 (not float16) so wide PCM values neither overflow nor lose
        # precision before normalization.
        samples = np.array(chunk.get_array_of_samples(), dtype=np.float32)
        segment = samples.reshape((-1, chunk.channels)).mean(axis=1)

        # Peak-normalize; skip silent windows to avoid dividing by zero.
        peak = np.abs(segment).max()
        if peak > 0:
            segment /= peak

        if len(segment) < sample_len:
            if cutoff == 'crop':
                break
            if cutoff == 'pad':
                # Pad at most 50% of the signal, dispose of the rest.
                if len(segment) < 0.5 * sample_len:
                    break
                # np.pad appends zeros and keeps the segment's dtype.
                segment = np.pad(segment, (0, sample_len - len(segment)))
            # 'leave': the short, already normalized array is kept as it is.
        segment_array.append(segment)
    return segment_array

In [None]:
def batch_extract_jukebox(audio_samples, meanpool=False, mult_factor=100):
    """Embed a batch of audio samples and pool the result into ``mult_factor`` chunks.

    Layer 36 is requested from ``jukemirlib.extract``; its output is split
    into ``mult_factor`` pieces along axis 1, every piece is averaged over
    that axis, and the averages are stacked vertically.

    Args:
        audio_samples: Audio input forwarded unchanged to ``jukemirlib.extract``.
        meanpool: Forwarded to ``jukemirlib.extract``.
        mult_factor: Number of pieces to split axis 1 into; must not exceed
            1722 (presumably the number of frames along axis 1 for the
            segment length used in this notebook -- TODO confirm).

    Returns:
        The vertically stacked per-piece means.
    """
    assert mult_factor <= 1722
    layer_outputs = jukemirlib.extract(audio_samples, layers=[36], meanpool=meanpool)
    activations = layer_outputs[36]

    pooled_pieces = []
    for piece in np.array_split(activations, mult_factor, axis=1):
        pooled_pieces.append(np.mean(piece, axis=1))
    return np.vstack(pooled_pieces)

In [None]:
# This function extracts Jukebox (layer 36) embeddings, one per segment, from an audio file.

def extract_jukebox_embeddings(file_path, segment_length_s=10, cutoff='pad', overlap=0.1, meanpool_bool=True):
    """Segment an audio file and return one Jukebox embedding per segment.

    The file is loaded with pydub, resampled to ``JUKEBOX_SR`` and sliced into
    windows of ``segment_length_s`` seconds whose start points are
    ``(1 - overlap) * segment_length_s`` seconds apart. Each window is
    down-mixed to mono by averaging its channels, divided by its maximum
    absolute value (all-zero windows are left unchanged) and passed on its
    own to ``jukemirlib.extract`` for layer 36.

    Args:
        file_path: Path of the audio file to load.
        segment_length_s: Length of every window, in seconds.
        cutoff: Policy for a window shorter than the target length (the end
            of the file):
            ``'pad'``   - zero-pad it to full length, unless it is shorter
                          than half the target length, in which case it is
                          dropped and iteration stops;
            ``'leave'`` - embed it at its shorter length;
            ``'crop'``  - drop it and stop.
            Any other value behaves like ``'leave'``.
        overlap: Fraction of a window that is shared with the next window.
        meanpool_bool: Forwarded to ``jukemirlib.extract`` as ``meanpool``.

    Returns:
        A list with the layer-36 output of ``jukemirlib.extract`` for every
        kept window, in file order.
    """
    audio = AudioSegment.from_file(file_path)
    audio = audio.set_frame_rate(JUKEBOX_SR)

    # pydub indexes and measures audio in milliseconds.
    segment_length_ms = segment_length_s * 1000
    step = int((1 - overlap) * segment_length_ms)
    sample_len = int(JUKEBOX_SR * segment_length_s)

    emb_array = []
    for start_time in range(0, len(audio), step):
        chunk = audio[start_time:start_time + segment_length_ms]

        # Interleaved PCM -> (frames, channels) -> mono by channel average.
        samples = np.array(chunk.get_array_of_samples(), dtype=np.float32)
        segment = samples.reshape((-1, chunk.channels)).mean(axis=1)

        # Peak-normalize; skip silent windows to avoid dividing by zero.
        peak = np.abs(segment).max()
        if peak > 0:
            segment /= peak

        if len(segment) < sample_len:
            if cutoff == 'crop':
                break
            if cutoff == 'pad':
                # Pad at most 50% of the signal, dispose of the rest.
                if len(segment) < 0.5 * sample_len:
                    break
                # np.pad appends zeros and keeps the segment's dtype.
                segment = np.pad(segment, (0, sample_len - len(segment)))
            # 'leave': the short, already normalized array is embedded as it
            # is (it must stay a numpy array, not a pydub AudioSegment).

        emb_array.append(jukemirlib.extract(audio=segment, layers=[36], meanpool=meanpool_bool)[36])
    return emb_array

# Get Files

In [None]:
TRAIN_PATH = "/home/cvillela/dataland/data/hanwha/training/source/"
VAL_PATH = "/home/cvillela/dataland/data/hanwha/validation/source/"

# Recursively collect every file whose name ends in '.wav', walking the
# training tree first and the validation tree second.
file_paths = [
    os.path.join(folder, entry)
    for root in (TRAIN_PATH, VAL_PATH)
    for folder, _subfolders, entries in os.walk(root)
    for entry in entries
    if entry.endswith('.wav')
]
print(f"Listed {len(file_paths)} files")

# Extract Embeddings

In [None]:
# --- Segmentation settings ---
sample_duration = 5    # segment length in seconds (passed as segment_length_s)
overlap = 0.1          # fraction of a segment shared with the next one
cutoff = 'pad'         # end-of-file policy: 'pad', 'leave' or 'crop'

# --- Embedding settings ---
meanpool_bool = False  # forwarded to jukemirlib.extract as `meanpool`
mult_factor = 100      # number of pooled chunks in batch_extract_jukebox
batch_size = 4         # segments embedded per batch_extract_jukebox call

## Batched

In [None]:
# Batched extraction: segments from all files are queued, embedded
# `batch_size` at a time, and the resulting embeddings are written to
# numbered .npy shards.
i = 0             # index of the most recently written shard
emb_list = []     # batch embeddings accumulated for the current shard
sample_list = []  # extracted segments that have not been embedded yet
max_batches_per_shard = 1000  # write a shard once more batches than this are pending


def _save_shard(batch_embs, shard_idx):
    """Stack the pending batch embeddings and save them as shard `shard_idx`."""
    np.save(f"/home/cvillela/dataland/data/hanwha/embeddings/jukebox_m{mult_factor}_{shard_idx}.npy", np.vstack(batch_embs))


for f in tqdm(file_paths):
    sample_list.extend(
        extract_jukebox_samples(f, segment_length_s=sample_duration, overlap=overlap, cutoff=cutoff)
    )

    # Embed full batches, taking segments in the order they were extracted.
    while len(sample_list) >= batch_size:
        curr_batch = sample_list[:batch_size]
        del sample_list[:batch_size]
        emb_list.append(batch_extract_jukebox(curr_batch, meanpool=meanpool_bool, mult_factor=mult_factor))

    # Flush a full shard and keep going, so that every file is processed and
    # later shards are written as well.
    if len(emb_list) > max_batches_per_shard:
        i += 1
        _save_shard(emb_list, i)
        emb_list = []

# Embed the segments that did not fill a whole batch ...
if sample_list:
    emb_list.append(batch_extract_jukebox(sample_list, meanpool=meanpool_bool, mult_factor=mult_factor))
    sample_list = []

# ... and write whatever is still pending, even if no segments were left over.
if emb_list:
    i += 1
    _save_shard(emb_list, i)
    emb_list = []

## One by one

In [None]:
# One-by-one extraction: every segment of every file is embedded on its own;
# incomplete end-of-file segments are dropped ('crop').
emb_list = []
for f in tqdm(file_paths):
    file_embs = extract_jukebox_embeddings(
        f,
        segment_length_s=sample_duration,
        overlap=overlap,
        cutoff='crop',
        meanpool_bool=meanpool_bool,
    )
    emb_list.extend(file_embs)

In [None]:
# Stack the collected per-segment embeddings vertically into one array and
# display its shape as the cell output.
embs = np.vstack(emb_list)
embs.shape

In [None]:
# Stack the collected embeddings and write them to a single .npy file.
# np.save needs a file name: given only the directory path (trailing "/") it
# appends ".npy" and writes a hidden file literally named ".npy".
# NOTE(review): the file name below is a suggestion -- rename as needed.
embs = np.vstack(emb_list)
np.save("/home/cvillela/dataland/data/hanwha/embeddings/jukebox_one_by_one.npy", embs)

In [None]:
csv_dir = "/home/cvillela/dataland/umapper/data/hanwha/"
# The six search_<k>.npy arrays (k = 1..6) located in csv_dir.
file_paths = [f"{csv_dir}search_{k}.npy" for k in range(1, 7)]

In [None]:
import numpy as np

In [None]:
# Convert every .npy array listed in file_paths into a comma-separated .csv
# written next to it.
import os

for f in file_paths:
    arr = np.load(f)
    # splitext strips only the final extension, so a dot elsewhere in the
    # path (e.g. in a directory name) does not truncate the output name.
    np.savetxt(os.path.splitext(f)[0] + '.csv', arr, delimiter=",")