## Imports

In [None]:
import os
import sys
import glob
from typing import Dict, Any, List, Union, Optional
from pathlib import Path
from tqdm import tqdm
import pickle

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import librosa

from datasets import load_dataset
import torch
import torchaudio
from transformers import MusicgenForConditionalGeneration, AutoProcessor

In [None]:
# Make two directories importable: the parent of the current working
# directory, and the ToolsISMIR2024/syntheory directory beneath that parent.
_parent_dir = os.path.dirname(os.path.abspath(""))
sys.path.extend([_parent_dir, _parent_dir + "/ToolsISMIR2024/syntheory"])

In [None]:
from ToolsISMIR2024.syntheory.embeddings.models import (
    Model, load_musicgen_model, load_audio,
    mfcc, melspectrogram, chroma_cqt, concat_features,
)
from ToolsISMIR2024.syntheory.embeddings.extract_embeddings import audio_file_to_embedding_np_array

## Data

In [None]:
# Load the dataset at a pinned revision so results are reproducible, then
# carve a stratified 80/20 split out of its "train" split.
dataset_id = "neerajaabhyankar/hindustani-raag-small"
hrs_full = load_dataset(dataset_id, revision="0dfb021e54e0e7489b90a47e23ef15f34fa740ec")
# Seeded, shuffled split stratified on "label"; the resulting "test" split is
# used as the validation set (train-val split).
hrs = hrs_full["train"].train_test_split(seed=42, shuffle=True, train_size=0.8, test_size=0.2, stratify_by_column="label") # train-val split
# Drop the reference to the unsplit dataset; only `hrs` is used from here on.
del hrs_full
# Last path component of the dataset id (the part after the "/").
dataset_name = dataset_id.split("/")[-1]

In [None]:
# Sample rate (Hz) this notebook assumes for the dataset's raw audio arrays.
# It is the source rate handed to resample_audio and the divisor in the
# minimum-duration check.
# NOTE(review): confirm this matches the dataset's actual sampling rate.
SR = 48000

## Embedding Model

In [None]:
# model_config = {
#     "model_name": "MELSPEC",
#     "model_type": "MELSPEC",
#     "minimum_duration_in_sec": 4,
# }
# processor, model = None, None

In [None]:
# Configuration read by the embedding-extraction code in this notebook:
#   - "model_type" is looked up by name in the `Model` enum,
#   - "minimum_duration_in_sec" is the cut-off below which training clips are
#     skipped,
#   - "extract_from_layer" (optional, commented out) selects a single layer;
#     when absent, all layers are returned.
# "model_name" is not read anywhere in the visible part of this notebook.
model_config = {
    "model_name": "MUSICGEN_DECODER_LM_L",
    "model_type": "MUSICGEN_DECODER_LM_L",
    "minimum_duration_in_sec": 4,
    # "extract_from_layer": None,
}
# Load the processor/model pair matching the model type configured above.
processor, model = load_musicgen_model(Model.MUSICGEN_DECODER_LM_L)

In [None]:
# Run on the Apple-silicon GPU (MPS) when that backend is available; otherwise
# stay on the CPU so the notebook still runs on machines without MPS instead
# of failing on the device move.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)

## Preprocess

In [None]:
# # reload
# import importlib
# importlib.reload(sys.modules["ToolsISMIR2024.syntheory.embeddings.extract_embeddings"])
# importlib.reload(sys.modules["transformers"])
# from transformers import MusicgenForConditionalGeneration

In [None]:
def resample_audio(y: np.ndarray, orig_sr: int, target_sr: int, duration: Optional[float] = None) -> np.ndarray:
    """Resample a clip, mix it down to mono and peak-normalise it.

    Args:
        y: Audio samples. Either 1-D (mono) or multi-channel with time on the
            last axis; the mono mix-down below averages over axis 0.
        orig_sr: Sampling rate of ``y`` in Hz.
        target_sr: Sampling rate to resample to, in Hz.
        duration: If given, only the first ``duration`` seconds of ``y`` are
            kept before resampling.

    Returns:
        A 1-D array at ``target_sr`` whose largest absolute sample is 1
        (left untouched when the signal is all zeros or empty).
    """
    # Truncate along the time (last) axis. Slicing axis 0 would cut channels
    # instead of samples for multi-channel input; for 1-D input the two are
    # identical.
    if duration is not None:
        y = y[..., : int(duration * orig_sr)]

    # Resample (librosa operates along the last axis by default), then
    # average the channels down to a single mono signal.
    audio = librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr)
    if audio.ndim == 1:
        audio = audio[np.newaxis]
    audio = audio.mean(axis=0)

    # Peak-normalise; skip silent or empty clips to avoid dividing by zero
    # (and to avoid calling max() on an empty array).
    norm_factor = np.abs(audio).max() if audio.size else 0
    if norm_factor > 0:
        audio /= norm_factor

    return audio.flatten()

In [None]:
def extract_musicgen_emb(
    audio: np.ndarray,
    processor: AutoProcessor,
    model: Union[MusicgenForConditionalGeneration],
    model_config: Dict[str, Any],
):
    """Extract MusicGen representations for a single audio clip.

    The clip is resampled from the notebook-level ``SR`` to the sampling rate
    of the model's audio encoder, packed by ``processor`` together with an
    empty text prompt, and then run either through the audio-encoder layers
    or through the full model, depending on ``model_config["model_type"]``.

    Args:
        audio: Raw audio samples at ``SR`` Hz.
        processor: Processor used to build the model inputs.
        model: The MusicGen model; inputs are moved to ``model.device``.
        model_config: Settings read here:
            ``"model_type"`` (required) - name of a ``Model`` member;
            ``"extract_from_layer"`` (default ``None``) - index of a single
            layer to return, or ``None`` to stack all layers;
            ``"decoder_hidden_states"`` (default ``True``) - return decoder
            hidden states if true, decoder attentions otherwise;
            ``"meanpool"`` (default ``True``) - average over axis 2 (audio
            encoder), axis 1 (hidden states) or axes (2, 3) (attentions)
            before returning.

    Returns:
        A NumPy array. For the decoder variants with no layer selected, the
        per-layer results are stacked along a new first axis, so indexing
        with ``[-1]`` yields the last layer.

    Raises:
        ValueError: If the model type is not one of the variants handled
            here.
    """
    # parse config
    layer_idx = model_config.get("extract_from_layer", None)
    want_hidden_states = model_config.get("decoder_hidden_states", True)
    meanpool = model_config.get("meanpool", True)
    model_type = Model[model_config["model_type"]]

    def _to_numpy(tensor, pool_dims):
        # Optionally mean-pool, drop singleton axes, and bring the result back
        # to host memory. `.cpu()` is required before `.numpy()` whenever the
        # model lives on an accelerator.
        if meanpool:
            tensor = tensor.mean(dim=pool_dims)
        return tensor.squeeze().detach().cpu().numpy()

    # set up inputs
    sampling_rate = model.config.audio_encoder.sampling_rate  # MusicGen uses 32000 Hz
    audio = resample_audio(audio, SR, sampling_rate)
    print("audio resampled")

    # process inputs for model
    inputs = processor(
        audio=audio,
        text="",
        sampling_rate=sampling_rate,
        padding=True,
        return_tensors="pt",
    )
    # Keep inputs on the same device as the model, whatever that device is.
    device = model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    if device.type == "mps":
        # Preserved from the original MPS path: force float32 audio input.
        inputs["input_values"] = inputs["input_values"].to(torch.float32)
    print("inputs prepared")

    if model_type == Model.MUSICGEN_AUDIO_ENCODER:
        # Push the waveform through the audio encoder's layers one by one.
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            x = inputs["input_values"]
            for layer in model.get_audio_encoder().encoder.layers:
                x = layer(x)
        return _to_numpy(x, 2)

    if model_type in [
        Model.MUSICGEN_DECODER_LM_S,
        Model.MUSICGEN_DECODER_LM_M,
        Model.MUSICGEN_DECODER_LM_L,
    ]:
        # Ask the model only for the output that will actually be returned:
        # collecting every layer's attention maps when only hidden states are
        # wanted (or vice versa) wastes a large amount of device memory.
        with torch.no_grad():
            out = model(
                **inputs,
                output_attentions=not want_hidden_states,
                output_hidden_states=bool(want_hidden_states),
            )
        if want_hidden_states:
            per_layer, pool_dims = out.decoder_hidden_states, 1
        else:
            per_layer, pool_dims = out.decoder_attentions, (2, 3)

        if layer_idx is None:
            return np.stack([_to_numpy(t, pool_dims) for t in per_layer])
        return _to_numpy(per_layer[layer_idx], pool_dims)

    raise ValueError(f"Invalid model: {model_type}")

In [None]:
# Notebook display of the train/test split object built above.
hrs

In [None]:
# Directory (relative to the working directory) where one pickled embedding
# per audio clip is cached as "<clip file name>.pkl".
embedding_dir = "musicgen-embeddings-hindustani-raag-small"

In [None]:
# for ii, row in tqdm(enumerate(hrs["train"])):
#     path_key = row["audio"]["path"].split("/")[-1]
#     if os.path.exists(f"{embedding_dir}/{path_key}.pkl"):
#         continue
#     inputs = row["audio"]["array"]
#     if len(inputs) / SR < model_config["minimum_duration_in_sec"]:
#         continue
#     break

In [None]:
# # set up inputs
# sampling_rate = model.config.audio_encoder.sampling_rate  # MusicGen uses 32000 Hz
# audio = resample_audio(inputs, SR, sampling_rate)
# inputs2 = processor(
#     audio=audio,
#     text="",
#     sampling_rate=sampling_rate,
#     padding=True,
#     return_tensors="pt",
# )
# inputs2 = {k: v.to("mps") for k, v in inputs2.items()}
# inputs2["input_values"] = inputs2["input_values"].to(torch.float32)
# out = model(**inputs2, output_attentions=True, output_hidden_states=True)

In [None]:
# clear cache
# torch.mps.empty_cache()

In [None]:
# os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

In [None]:
# Compute and cache one embedding per training clip. Clips that already have
# a cached file, or that are shorter than the configured minimum duration,
# are skipped.
os.makedirs(embedding_dir, exist_ok=True)  # first run: the cache dir may not exist yet
for row in tqdm(hrs["train"]):  # iterate the split directly so tqdm can show a total
    path_key = row["audio"]["path"].split("/")[-1]
    out_path = f"{embedding_dir}/{path_key}.pkl"
    if os.path.exists(out_path):
        continue
    inputs = row["audio"]["array"]
    if len(inputs) / SR < model_config["minimum_duration_in_sec"]:
        continue
    # [-1] keeps the last entry along the first axis of the returned array
    # (the last layer when all layers are stacked).
    emb = extract_musicgen_emb(inputs, processor, model, model_config)[-1]
    # Use a context manager so the file is flushed and closed deterministically.
    with open(out_path, "wb") as f:
        pickle.dump(emb, f)

In [None]:
# Compute and cache one embedding per held-out clip, skipping clips that
# already have a cached file.
# NOTE(review): unlike the training loop, no minimum-duration filter is
# applied here - confirm that evaluating on short clips is intended.
os.makedirs(embedding_dir, exist_ok=True)  # first run: the cache dir may not exist yet
for row in tqdm(hrs["test"]):  # iterate the split directly so tqdm can show a total
    path_key = row["audio"]["path"].split("/")[-1]
    out_path = f"{embedding_dir}/{path_key}.pkl"
    if os.path.exists(out_path):
        continue
    inputs = row["audio"]["array"]
    # [-1] keeps the last entry along the first axis of the returned array
    # (the last layer when all layers are stacked).
    emb = extract_musicgen_emb(inputs, processor, model, model_config)[-1]
    # Use a context manager so the file is flushed and closed deterministically.
    with open(out_path, "wb") as f:
        pickle.dump(emb, f)

In [None]:
# def get_embedding(examples):
#     input_arrays = [x["array"] for x in examples["audio"]]
#     embeddings = [get_embedding_from_model_using_config(inputs, model_config, processor, model)[-1] for inputs in input_arrays]
#     examples["embeddings"] = embeddings
#     return examples

# hrs = hrs.map(get_embedding, batched=True, batch_size=8)

## Training

#### dataloader

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Batches of 8 drawn straight from the dataset splits: the training loader is
# reshuffled on every pass, the validation loader keeps dataset order.
# NOTE(review): neither loader is consumed in the visible part of this
# notebook.
train_loader = DataLoader(hrs["train"], batch_size=8, shuffle=True)
val_loader = DataLoader(hrs["test"], batch_size=8, shuffle=False)

#### train a simple MLP

In [None]:
import torch
from typing import List

In [None]:
# from probe.probes import SimpleMLP

In [None]:
class SimpleMLP(torch.nn.Module):
    """Feed-forward classifier/regressor head with optional hidden layers.

    Dropout is applied to the input and after every hidden activation. With
    an empty ``hidden_layer_sizes`` the model reduces to dropout followed by
    a single linear layer.

    Args:
        num_features: Size of each input vector.
        hidden_layer_sizes: Width of each hidden layer, in order. May be
            empty.
        num_outputs: Size of the output vector (raw, un-normalised scores).
        dropout_p: Dropout probability shared by every dropout application.
    """

    def __init__(
        self,
        num_features: int,
        hidden_layer_sizes: List[int],
        num_outputs: int,
        dropout_p: float = 0.5,
    ) -> None:
        super().__init__()
        d = num_features

        self.num_layers = len(hidden_layer_sizes)
        # Hidden layers are registered as attributes hidden_0, hidden_1, ...
        # so that parameter names in the state dict stay "hidden_<i>.*".
        for i, ld in enumerate(hidden_layer_sizes):
            setattr(self, f"hidden_{i}", torch.nn.Linear(d, ld))
            d = ld

        self.output = torch.nn.Linear(d, num_outputs)
        self.dropout = torch.nn.Dropout(p=dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output-layer scores for a batch of input vectors."""
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = getattr(self, f"hidden_{i}")(x)
            # `torch.functional` has no `relu`; the activation lives at
            # `torch.relu` / `torch.nn.functional.relu`.
            x = torch.relu(x)
            x = self.dropout(x)
        return self.output(x)

In [None]:
# train the MLP
# Stack the two per-class embedding sets row-wise into one feature matrix and
# build the matching integer labels: 0 for every bhoop row, 1 for every yaman
# row.
# NOTE(review): bhoop_train / yaman_train are not defined in the visible part
# of this notebook - they must be created before this cell is run.
X = np.vstack([bhoop_train, yaman_train])
y = np.array([0] * len(bhoop_train) + [1] * len(yaman_train))


In [None]:
# Optimisation hyper-parameters.
learning_rate = 0.001
num_epochs = 1000
# NOTE(review): batch_size is not read by the training loop in this notebook,
# which feeds the whole training set in a single forward pass.
batch_size = 2

# Two-class MLP over the embedding features with hidden widths 64 and 32.
mlp = SimpleMLP(num_features=X.shape[1], hidden_layer_sizes=[64, 32], num_outputs=2)

# Cross-entropy on the raw output scores, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

In [None]:
# Convert data to PyTorch tensors
# Convert data to PyTorch tensors
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.long)
# Build one NumPy array first: constructing a tensor from a Python list of
# arrays converts element by element and is very slow.
X_test = torch.tensor(np.asarray([bhoop_test, yaman_test]), dtype=torch.float32)
y_test = torch.tensor([0, 1], dtype=torch.long)

# Training loop (full batch: the whole training set in every step)
for epoch in range(num_epochs):
    mlp.train()
    optimizer.zero_grad()

    # Forward pass
    outputs = mlp(X_train)
    loss = criterion(outputs, y_train)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        # Evaluate with dropout disabled and without building an autograd
        # graph; otherwise the reported eval loss is randomly perturbed by
        # dropout. Train mode is restored at the top of the next epoch.
        mlp.eval()
        with torch.no_grad():
            outputs = mlp(X_test)
            loss_eval = criterion(outputs, y_test)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Eval Loss: {loss_eval.item():.4f}')


In [None]:
# Final predictions on the held-out examples. Switch to eval mode so dropout
# is disabled (otherwise the scores change randomly from run to run) and skip
# gradient tracking, which is not needed for inference.
mlp.eval()
with torch.no_grad():
    outputs = mlp(X_test)

In [None]:
# Notebook display of the MLP's raw output scores for the held-out examples.
outputs

## Probes from the Paper

In [None]:
from probe.probes import ProbeExperiment, ProbeExperimentConfig

In [None]:
# Default probe settings. Entries left as None are treated as "not set" by
# _set_attr_if_exists and are therefore not copied into the hyper-parameters.
# An empty "hidden_layer_sizes" list means no hidden layers.
PROBE_CONFIG = {
    "model_hash": None,
    "dataset": None,
    "dataset_embeddings_label_column_name": None,
    "data_standardization": True,
    "hidden_layer_sizes": [],
    "batch_size": 64,
    "learning_rate": 1e-3,
    "dropout_p": 0.5,
    "l2_weight_decay": None,
    "max_num_epochs": None,
    "early_stopping_metric": "primary",
    "early_stopping": True,
    "early_stopping_eval_frequency": 8,
    "early_stopping_boredom": 256,
    "seed": 0,
    "num_outputs": None,
    # if this is true, all the embedding files used in test/train are loaded into RAM
    # otherwise, we load only their location in a zarr file on disk and load as needed
    "load_embeddings_in_memory": False,
}

def _set_attr_if_exists(probe_config, hparams, attr_name, default=None):
    """Copy one setting from ``probe_config`` into ``hparams`` if it is set.

    Args:
        probe_config: Source of settings. Either a mapping (looked up by key)
            or any object exposing the settings as attributes.
        hparams: Dict that receives ``hparams[attr_name]``; mutated in place.
        attr_name: Name of the setting to copy.
        default: Value used when ``probe_config`` has no such setting.

    The setting is skipped (``hparams`` left untouched) when the resolved
    value is ``None``.
    """
    from collections.abc import Mapping

    # `getattr` cannot see the keys of a plain dict, so a dict-based config
    # would always fall through to `default`; read mappings by key instead.
    if isinstance(probe_config, Mapping):
        x = probe_config.get(attr_name, default)
    else:
        x = getattr(probe_config, attr_name, default)
    if x is not None:
        hparams[attr_name] = x

In [None]:
# For each concept dataset: a list of (number of classes, label column name)
# pairs, one per target that can be probed. `start` uses the first pair of the
# selected concept.
CONCEPT_LABELS = {
    "chord_progressions": [
        (19, "chord_progression"),
        (12, "key_note_name"),
    ],
    "chords": [(4, "chord_type"), (3, "inversion"), (12, "root_note_name")],
    "scales": [(7, "mode"), (12, "root_note_name")],
    "intervals": [(12, "interval"), (12, "root_note_name")],
    "notes": [(12, "root_note_pitch_class"), (9, "octave")],
    "time_signatures": [
        (8, "time_signature"),
        (6, "time_signature_beats"),
        (3, "time_signature_subdivision"),
    ],
    "tempos": [(161, "bpm")],
}

In [None]:
def start(
    use_wandb: bool = False, random_seed: int = 0, base_path_parent="data"
) -> ProbeExperiment:
    """Configure, run and return one probing experiment.

    The model type/size/layer and the concept are fixed inside the function;
    hyper-parameters are taken from the module-level ``PROBE_CONFIG``.

    Args:
        use_wandb: Forwarded to ``ProbeExperiment``.
        random_seed: Seed stored in the experiment config.
        base_path_parent: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        The ``ProbeExperiment`` after ``obtain_data`` and ``train`` have run.
    """
    # model type: [ JUKEBOX | MUSICGEN_DECODER | MUSICGEN_AUDIO_ENCODER | MFCC | CHROMA | MELSPEC | HANDCRAFT ]
    model_type = "MUSICGEN_DECODER"
    # model size: [S | M | L]
    model_size = "L"
    # model layer: [0, ... 71]
    model_layer = None
    # concept: [notes, tempos, time_signatures, etc. ] + a specific label
    concept = "scales"

    # Optional override of the class count. PROBE_CONFIG may be a plain dict,
    # whose keys `getattr` cannot see, so read it by key in that case.
    if isinstance(PROBE_CONFIG, dict):
        num_classes = PROBE_CONFIG.get("num_classes")
    else:
        num_classes = getattr(PROBE_CONFIG, "num_classes", None)

    # set hyperparameters (only those that are actually set in the config)
    hparams = {}
    _set_attr_if_exists(PROBE_CONFIG, hparams, "data_standardization")
    _set_attr_if_exists(PROBE_CONFIG, hparams, "batch_size")
    _set_attr_if_exists(PROBE_CONFIG, hparams, "learning_rate")
    _set_attr_if_exists(PROBE_CONFIG, hparams, "dropout_p")
    _set_attr_if_exists(PROBE_CONFIG, hparams, "l2_weight_decay")
    _set_attr_if_exists(PROBE_CONFIG, hparams, "hidden_layer_sizes", [512])

    # get the concept label that is given by parent concept and the target we wish to probe
    # (the first (num_classes, column) pair listed for the concept)
    dataset_settings = CONCEPT_LABELS[concept][0]
    _num_classes, label_column_name = dataset_settings

    # allow override of number of classes if given in config directly
    num_classes = num_classes or _num_classes

    # "tempos" is probed as a single-output regression; everything else as
    # multiclass classification with one output per class.
    is_regression = concept == "tempos"
    num_outputs = 1 if is_regression else num_classes
    output_type = "regression" if is_regression else "multiclass"

    cfg = ProbeExperimentConfig(
        dataset_embeddings_label_column_name=label_column_name,
        dataset=concept,
        num_outputs=num_outputs,
        model_hash=f"{model_type}-{model_size}-{model_layer}",
        max_num_epochs=100,
        **hparams,
        seed=random_seed,
        load_embeddings_in_memory=False,
    )

    exp = ProbeExperiment(
        cfg,
        summarize_frequency=100,
        use_wandb=use_wandb,
    )
    exp.obtain_data(
        model_type=model_type,
        model_size=model_size,
        output_type=output_type,
        model_layer=model_layer,
    )
    exp.train()

    return exp

