In [None]:
import os
import random
from typing import List, Dict, Tuple

import librosa
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pyrubberband as pyrb
import soundfile as sf
from scipy.stats import lognorm, uniform

import steme.audio as audio
import steme.dataset as dataset
import steme.utils as utils

In [None]:
import IPython.display as ipd

In [None]:
# Root of the augmented GTZAN dataset (audio/ and annotations/tempo/ live under it).
# Set the GTZAN_AUGMENTED_PATH environment variable to point elsewhere instead of editing
# this machine-specific default.
DATASET_PATH = os.environ.get(
    "GTZAN_AUGMENTED_PATH", "/home/gigibs/Documents/datasets/gtzan_augmented_log"
)

In [None]:
# Load the three reference datasets: dataset object, track ids and per-track tempi (BPM).
gtzan, tracks, tempi = dataset.gtzan_data()
giant_steps, gs_tracks, gs_tempi = dataset.giant_steps_data()
ballroom, b_tracks, b_tempi = dataset.ballroom_data()

# Target tempo distribution (presumably lognormal samples around 70 BPM — confirm in steme.dataset).
dist_low = dataset.lognormal70()

# NOTE(review): the meaning of (25, 40, 190) is defined by steme.dataset.variables_non_linear — confirm.
theta = dataset.variables_non_linear(25, 40, 190)
# Histogram bin edges: every other theta value strictly between 30 and 370 BPM.
log_bins = theta[(theta > 30) & (theta < 370)][::2]
# linear_bins = np.arange(30, 350, 10)

In [None]:
# Bin edges used by the dataset-overview histograms.
bins = log_bins

In [None]:
# One line per collection: "<name> size: <number of entries>".
print("\n".join(
    f"{name} size: {len(collection)}"
    for name, collection in (
        ("gtzan", tracks),
        ("giant_steps", gs_tracks),
        ("ballroom", b_tracks),
        ("lognorm @ 70", dist_low),
    )
))

In [None]:
# One histogram panel per dataset. The titles are built from the loaded data so the
# track counts stay correct if the loaders change (they used to be hardcoded).
fig, ax = plt.subplots(1, 3, figsize=(10, 5))

panels = [
    (tempi, "orange", "gtzan", f"GTZAN ({len(tracks)} tracks)"),
    (gs_tempi, "red", "giant_steps", f"Giant Steps ({len(gs_tracks)} tracks)"),
    (b_tempi, "blue", "ballroom", f"Ballroom ({len(b_tracks)} tracks)"),
]
for panel_ax, (panel_tempi, panel_color, panel_label, panel_title) in zip(ax, panels):
    panel_ax.hist(panel_tempi, bins=bins, color=panel_color, alpha=0.7, label=panel_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("BPM")
    panel_ax.set_ylabel("# tracks")

plt.tight_layout()
# plt.savefig("datasets_tempo_distribution.svg")

## Augmenting GTZAN

## Approach 1: Time-stretching GTZAN only

In [None]:
# NOTE: despite the name, keeping every other edge of log_bins gives fewer, wider (coarser) bins.
finer_bins = log_bins[::2]

In [None]:
# GTZAN tempo histogram vs. the lognorm@70 target on the same bins.
plt.hist([tempi, dist_low], bins=finer_bins, color=["red", "orange"], label=["gtzan", "lognorm@70"])
plt.legend()

In [None]:
# np.histogram returns (counts, bin_edges); only the counts ([0]) are used below.
dist_low_hist = np.histogram(dist_low, bins=finer_bins)
gtzan_dist = np.histogram(tempi, bins=finer_bins)

In [None]:
# Per-bin difference (target - GTZAN): > 0 means the bin needs more tracks, < 0 means it has a surplus.
diff_tempi = dist_low_hist[0] - gtzan_dist[0]

In [None]:
# Top: GTZAN vs. lognorm@70 tempo histograms. Bottom: per-bin difference (target - GTZAN);
# positive bars mark bins that need extra tracks, negative bars bins with a surplus.
fig, ax = plt.subplots(2, 1)
ax[0].hist([tempi, dist_low], finer_bins, alpha=0.7, label=["gtzan", "lognorm@70"],
           color=["red", "orange"], stacked=False)
ax[0].legend()
# Span each bar across its own (log-spaced) bin so it lines up with the histogram above,
# instead of a fixed-width bar centred on the bin's upper edge.
ax[1].bar(finer_bins[:-1], diff_tempi, width=np.diff(finer_bins), align="edge",
          alpha=0.5, label="diff", color="blue")
ax[1].legend()

In [None]:
# gtzan_info = {bin: number of tracks in the bin}
# transformation_dict = {bin: transformation (tracks to add if > 0, to remove if < 0)}
# if transformation_dict[bin] <= 0 and gtzan_info >= 0, transform the track into the bin that needs it
# if transformation_dict[bin] >= 0 and gtzan_info >= 0, skip to the next bin

In [None]:
def create_transformation_dict(verbose=True, diff=None, bins=None):
    """Map each tempo bin to the difference between the target and GTZAN track counts.

    Keys have the form ``"<lower edge>, <upper edge>"`` (the same format used by
    ``create_helper_dict``). A positive value means the bin needs that many more
    tracks; a negative value means it has that many too many.

    Args:
        verbose: print a per-bin summary and the total removals/additions.
        diff: per-bin differences; defaults to the notebook-level ``diff_tempi``.
        bins: bin edges (``len(diff) + 1`` of them); defaults to ``finer_bins``.

    Returns:
        dict mapping bin key -> difference, in the order of ``bins``.
    """
    if diff is None:
        diff = diff_tempi
    if bins is None:
        bins = finer_bins

    removals = 0
    additions = 0
    transformation_dict = {}

    for idx, value in enumerate(diff):
        lower, upper = bins[idx], bins[idx + 1]
        transformation_dict[f"{lower}, {upper}"] = value

        if value < 0:
            # report the magnitude: "remove" already conveys the sign
            message = f"remove {abs(value)} samples"
            removals += abs(value)
        elif value > 0:
            message = f"add {value} samples"
            additions += value
        else:
            message = "do nothing"

        if verbose:
            print(f"{lower} - {upper}: {message}")

    if verbose:
        print(f"total removals = {removals}, total additions = {additions}")
    return transformation_dict

def reset_transformation_dict():
    """Rebuild the bin -> (target - GTZAN) track-count dict from scratch, without printing."""
    return create_transformation_dict(False)

In [None]:
def create_helper_dict(bins):
    """Create an empty track-id list for every pair of consecutive bin edges.

    Keys are ``"<lower>, <upper>"`` strings; each value is a fresh list meant to
    collect the ids of the tracks whose tempo falls in that bin, e.g.
    ``{"30, 40": ["classical.0000", "blues.0010"]}``.
    """
    return {f"{lower}, {upper}": [] for lower, upper in zip(bins[:-1], bins[1:])}

In [None]:
# Assign every GTZAN track to its tempo bin:
#   helper_dict:   bin key -> list of track ids in that bin
#   gtzan_mapping: track id -> (tempo, bin key)
helper_dict = create_helper_dict(finer_bins)
gtzan_mapping = {}
out_of_range_tracks = []

for i in tracks:
    tempo = gtzan.track(i).tempo

    # np.digitize returns 0 below the first edge and len(finer_bins) at/above the last one.
    # Such tempi belong to no bin (np.histogram likewise ignores values outside its range).
    # They used to raise a KeyError (wrapped-around key) or an IndexError.
    boundaries = np.digitize(tempo, finer_bins)
    if not 0 < boundaries < len(finer_bins):
        out_of_range_tracks.append(i)
        continue

    bin_key = f"{finer_bins[boundaries - 1]}, {finer_bins[boundaries]}"
    gtzan_mapping[i] = (tempo, bin_key)
    helper_dict[bin_key].append(i)

if out_of_range_tracks:
    print(f"skipped {len(out_of_range_tracks)} tracks with tempo outside "
          f"[{finer_bins[0]}, {finer_bins[-1]}): {out_of_range_tracks}")

In [None]:
def check_missing_tracks(transformation_dict):
    """Return the last bin, in dict order, that still needs tracks.

    Scans ``transformation_dict`` from its final entry backwards (the highest-tempo
    bin when keys are in ascending bin order). Returns the first ``(key, count)``
    pair with a positive count, or ``None`` when no bin is missing tracks.
    """
    candidates = (
        (key, count)
        for key, count in reversed(list(transformation_dict.items()))
        if count > 0
    )
    return next(candidates, None)

In [None]:
def key_boundaries(key):
    """Parse a ``"<lower>, <upper>"`` bin key back into a list of floats."""
    return list(map(float, key.split(", ")))

In [None]:
# Start from fresh per-bin differences. augmented_dict is an independent copy that receives
# the same updates as transformation_dict during augmentation.
transformation_dict = reset_transformation_dict()
augmented_dict = dict(transformation_dict)

In [None]:
# Rebalance GTZAN towards the lognorm@70 target. Walk the bins from the last (highest-tempo)
# one down, and time-stretch surplus tracks (bins with a negative entry in
# transformation_dict) into the highest-tempo bin that still needs tracks (positive entry).
# The stretched audio and its new tempo annotation are written under DATASET_PATH; the
# originals are queued in `to_remove` for deletion.
AUGMENTATION_SEED = 0  # fixed seed so re-running regenerates the same target tempi
random.seed(AUGMENTATION_SEED)

os.makedirs(os.path.join(DATASET_PATH, "audio"), exist_ok=True)
os.makedirs(os.path.join(DATASET_PATH, "annotations/tempo"), exist_ok=True)


def _write_augmented_track(track_id, stretched_audio, sample_rate, tempo):
    """Save a time-stretched track (24-bit WAV) and its tempo annotation under DATASET_PATH."""
    sf.write(
        os.path.join(DATASET_PATH, f"audio/{track_id}_augmented.wav"),
        stretched_audio,
        sample_rate,
        subtype="PCM_24",
    )
    with open(os.path.join(DATASET_PATH, f"annotations/tempo/{track_id}_augmented.bpm"), "w") as f:
        f.write(str(tempo))


to_remove = []
all_bins_filled = False
for key, val in list(transformation_dict.items())[::-1]:
    if all_bins_filled:
        break
    if val >= 0:
        continue

    print(f"augmenting tracks from {key}")
    for track_id in helper_dict[key]:
        # `val` is a snapshot taken before the loop. The live entry rises by one per moved
        # track, so stop once this bin is down to its target count instead of draining it.
        if transformation_dict[key] >= 0:
            break

        print(track_id)
        original_tempo = gtzan.track(track_id).tempo
        original_boundaries = gtzan_mapping[track_id][1]

        str_boundaries = check_missing_tracks(transformation_dict)
        if str_boundaries is None:
            # every bin has reached its target: we're done
            all_bins_filled = True
            break

        target_key = str_boundaries[0]
        if key == target_key:
            print(f"we will not transform {key} into {target_key}")
            break

        new_tempo_boundaries = key_boundaries(target_key)
        new_tempo = random.uniform(new_tempo_boundaries[0], new_tempo_boundaries[1])
        # rate > 1 speeds the audio up, < 1 slows it down
        tempo_rate = new_tempo / original_tempo

        x, fs = gtzan.track(track_id).audio
        to_remove.append(track_id)

        # No rbargs are passed, so Rubber Band's default engine is used.
        # NOTE(review): if a specific engine is wanted, in Rubber Band 3.x `-3` selects the
        # finer R3 engine and `-2` the faster R2 engine — confirm for the installed version.
        x_stretch = pyrb.time_stretch(x, fs, tempo_rate)

        transformation_dict[target_key] -= 1
        transformation_dict[original_boundaries] += 1
        augmented_dict[target_key] -= 1
        augmented_dict[original_boundaries] += 1

        _write_augmented_track(track_id, x_stretch, fs, new_tempo)

In [None]:
# Delete the originals of every time-stretched track, so each one is replaced by its
# "_augmented" copy. Files that are already gone (e.g. the cell was re-run) are skipped.
# Other OS errors are reported rather than silently swallowed. Each file is handled
# independently, so a missing .wav no longer prevents the .bpm from being removed.
for track_id in to_remove:
    for relative_path in (f"audio/{track_id}.wav", f"annotations/tempo/{track_id}.bpm"):
        file_path = os.path.join(DATASET_PATH, relative_path)
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as err:
            print(f"could not remove {file_path}: {err}")

In [None]:
import steme.loader as loader

In [None]:
# Load the dataset stored under DATASET_PATH: the remaining originals plus the "_augmented"
# copies written during augmentation.
# NOTE(review): empty dataset_name/folder presumably mean audio/ and annotations/ sit directly
# under DATASET_PATH — confirm against steme.loader.custom_dataset_loader.
gtzan_augmented = loader.custom_dataset_loader(
    path=DATASET_PATH,
    dataset_name="",
    folder="",
)

In [None]:
# Build a new list rather than calling .remove() on `track_ids`. If the loader returns its
# internal list, .remove() would mutate the dataset object, and it raises ValueError on a
# re-run (or when the id is absent).
# NOTE(review): reggae.00086 is presumably excluded because its audio/annotation is unusable —
# confirm and document the reason.
EXCLUDED_TRACKS = {"reggae.00086"}
gtzan_augmented_tracks = [
    track_id for track_id in gtzan_augmented.track_ids if track_id not in EXCLUDED_TRACKS
]
gtzan_augmented_tempi = [gtzan_augmented.track(track_id).tempo for track_id in gtzan_augmented_tracks]

In [None]:
# Tempo histogram of the augmented dataset on the bins used to rebalance it.
plt.hist(gtzan_augmented_tempi, bins=finer_bins, color="red", label="gtzan_augmented")
plt.legend()

In [None]:
# Augmented GTZAN vs. the lognorm@70 target on 10-BPM linear bins (edges 30..190 BPM).
plt.hist(
    [gtzan_augmented_tempi, dist_low], 
    bins=np.arange(30,200,10), 
    color=["blue", "orange"], 
    label=["gtzan_augmented", "lognorm@70"]
)
plt.legend()

In [None]:
# Number of tracks kept in the augmented dataset.
len(gtzan_augmented_tracks)

# Quality comparison

In [None]:
# load original track
orig_x, orig_fs = gtzan.track("blues.00002").audio

In [None]:
# listen to the original
ipd.Audio(orig_x, rate=orig_fs)

In [None]:
# load its time-stretched counterpart from the augmented dataset
aug_x, aug_fs = gtzan_augmented.track("blues.00002_augmented").audio

In [None]:
# listen to the time-stretched version
ipd.Audio(aug_x, rate=aug_fs)

In [None]:
# Spectral-flux novelty curves (n_fft=2048, hop 512) and the time in seconds of each frame.
orig_nov, _ = audio.spectral_flux(orig_x, orig_fs, n_fft=2048, hop_length=512)
orig_frame_time = librosa.frames_to_time(np.arange(len(orig_nov)),
                                    sr=orig_fs,
                                    hop_length=512)

# Only the first 30 s of the augmented audio, presumably to match the original excerpt's length.
aug_nov, _ = audio.spectral_flux(aug_x[:30*aug_fs], aug_fs, n_fft=2048, hop_length=512)
aug_frame_time = librosa.frames_to_time(np.arange(len(aug_nov)),
                                    sr=aug_fs,
                                    hop_length=512)

In [None]:
# Overlay both novelty curves on a shared time axis.
plt.plot(orig_frame_time, orig_nov, color="red", label="original audio")
plt.plot(aug_frame_time, aug_nov, color="blue", label="augmented audio")
plt.legend()

In [None]:
# Linear tempo axis: 30..349 BPM in 1-BPM steps (320 values).
linear_theta = np.arange(30,350,1)

In [None]:
# NOTE(review): the meaning of the positional args (10, "fourier") is defined by steme.audio.tempogram — confirm.
orig_T, orig_fT, orig_times = audio.tempogram(orig_x, orig_fs, 10, "fourier", linear_theta)
aug_T, aug_fT, aug_times = audio.tempogram(aug_x, aug_fs, 10, "fourier", linear_theta)

In [None]:
def plot_comparison(T, t, freqs, ttypes, subplot_titles, fig_title=None, log_axis=None):
    """
    Helper function to plot tempograms side-by-side, one subplot per tempogram.

    Args:
        T: sequence of tempogram arrays, one per subplot.
        t, freqs: sequences of axis arrays; entry ``idx`` of each is forwarded to
            ``utils._tempogram_kwargs(t[idx], freqs[idx])`` to build the ``imshow`` kwargs.
        ttypes: tempo-axis type per subplot ("log", or anything else for linear).
            A single string applies to every subplot.
        subplot_titles: one title per subplot.
        fig_title: optional figure-level title.
        log_axis: tempo values (BPM) used for the y tick labels of "log" subplots (every
            20th value, rounded). Falls back to a notebook-level ``log_axis`` if defined.

    Returns:
        (fig, ax), where ``ax`` is a 1-D array with one Axes per tempogram.

    Raises:
        ValueError: a "log" subplot is requested but no ``log_axis`` is available.
    """
    num_tempograms = len(T)
    if isinstance(ttypes, str):
        # Indexing a bare string yields single characters ("linear"[0] == "l"), so "log"
        # could never match; broadcast it to every subplot instead.
        ttypes = [ttypes] * num_tempograms

    # squeeze=False keeps `ax` indexable even for a single tempogram.
    fig, axes = plt.subplots(1, num_tempograms, figsize=(15, 5), squeeze=False)
    ax = axes[0]

    if fig_title is not None:
        fig.suptitle(fig_title, fontsize=16)

    if log_axis is None:
        log_axis = globals().get("log_axis")

    for idx in range(num_tempograms):
        kwargs = utils._tempogram_kwargs(t[idx], freqs[idx])
        ax[idx].imshow(T[idx], **kwargs)

        if ttypes[idx] == "log":
            if log_axis is None:
                raise ValueError("log_axis must be provided to label a 'log' tempogram")
            ax[idx].set_yticklabels(np.rint(log_axis[::20]).astype(int))

        ax[idx].set_xlabel("Time (s)")
        ax[idx].set_ylabel("Tempo (BPM)")
        ax[idx].title.set_text(subplot_titles[idx])
    return fig, ax

In [None]:
# Look up the annotated tempi here. orig_bpm / aug_bpm are otherwise first assigned in the
# cells below, so this cell would raise NameError on a fresh kernel.
orig_bpm = gtzan.track("blues.00002").tempo
aug_bpm = gtzan_augmented.track("blues.00002_augmented").tempo
# NOTE(review): these positional arguments pass (fT, times) as plot_comparison's (t, freqs);
# confirm that matches what utils._tempogram_kwargs expects.
plot_comparison(
    [orig_T, aug_T],
    [orig_fT, aug_fT],
    [orig_times, aug_times],
    subplot_titles=[f"orig {orig_bpm}", f"aug {aug_bpm}"],
    ttypes="linear",
)

In [None]:
# annotated tempo of the original track
orig_bpm = gtzan.track("blues.00002").tempo
#utils.plot_tempogram(orig_T, orig_fT, orig_times, title=f"Original audio ({orig_bpm} BPM))")

In [None]:
# annotated (post-stretch) tempo of the augmented track, shown in its tempogram title
aug_bpm = gtzan_augmented.track("blues.00002_augmented").tempo
utils.plot_tempogram(aug_T, aug_fT, aug_times, title=f"Augmented ({aug_bpm} BPM)")

## Approach 2: Augmentation in the tempogram domain

In [None]:
# first looking at the linear scenario

In [None]:
# Recompute the original's linear tempogram (same arguments as the earlier cell).
orig_T, orig_fT, orig_times = audio.tempogram(orig_x, orig_fs, 10, "fourier", linear_theta)

In [None]:
#utils.plot_tempogram(orig_T, orig_fT, orig_times, title="Original audio")

In [None]:
# Tempogram over twice the tempo range: 30..669 BPM, i.e. 640 tempo values versus 320 in linear_theta.
larger_orig_T, larger_orig_fT, larger_orig_times = audio.tempogram(orig_x, orig_fs, 10, "fourier", np.arange(30, 670))

In [None]:
#utils.plot_tempogram(larger_orig_T, larger_orig_fT, larger_orig_times, title="Original audio")

In [None]:
# shape check before downsampling the larger tempogram
aug_T.shape, larger_orig_T.shape

In [None]:
larger_orig_T.shape[1]

In [None]:
# Halve the tempo resolution of `larger_orig_T` by averaging each pair of adjacent rows,
# producing a tempogram with the same shape as `orig_T` (the row count is taken from orig_T
# instead of hardcoding 320).
# NOTE(review): assuming rows follow the theta arrays passed to audio.tempogram, row k here
# averages 30+2k and 31+2k BPM, while row k of orig_T is 30+k BPM. That is not a pure
# tempo-halving map (which would pair 2*(30+k) BPM with 30+k) — confirm the intent.
n_rows = orig_T.shape[0]
row_pairs = larger_orig_T[: 2 * n_rows, :].reshape(n_rows, 2, -1)
raw_aug = np.zeros(orig_T.shape)
raw_aug[:] = row_pairs.mean(axis=1)
    

In [None]:
# every other tempo row of the larger tempogram
larger_orig_T[::2, :].shape

In [None]:
# NOTE: overwrites the pair-averaged raw_aug from the earlier cell with plain row subsampling.
raw_aug = larger_orig_T[::2, :].copy()

In [None]:
# Tempogram-domain "augmentation" vs. the tempogram of the time-stretched audio.
# NOTE(review): raw_aug's rows come from larger_orig_T[::2] (30, 32, ... BPM if rows follow
# np.arange(30, 670)) but are plotted against orig_fT (30, 31, ... BPM) — confirm this relabelling.
plot_comparison([raw_aug, aug_T], [orig_fT, aug_fT], [orig_times, aug_times], subplot_titles=["tempogram_aug", "audio_aug"], ttypes="linear")

In [None]:
utils.plot_tempogram(aug_T, aug_fT, aug_times, title=f"Augmented ({aug_bpm} BPM)")

In [None]:
utils.plot_tempogram(orig_T, orig_fT, orig_times, title=f"Original audio ({orig_bpm} BPM)")

In [None]:
# scratch check: average of the first two rows of larger_orig_T
idx = 0
tmp = np.mean(larger_orig_T[idx:idx+2, :], axis=0)

In [None]:
tmp.shape

In [None]:
tempi_array = np.asarray(tempi)

In [None]:
# GTZAN tempo distribution on 50 automatic bins
plt.hist(tempi_array, bins=50)

In [None]:
# GTZAN with every tempo also added at half its value, vs. the lognorm@70 target
plt.hist(
    [np.append(tempi_array,tempi_array/2), dist_low], 
    bins=50, 
    color=["red", "orange"], 
    label=["gtzan + gtzan/2", "lognorm@70"]
)
plt.legend()

In [None]:
np.append(tempi_array,tempi_array/2)