In [None]:
import os
from pathlib import Path

# Dedicated data directory on the evaluation machine; when it is absent the
# notebook falls back to the current working directory.
_SERVER_DATA_DIR = Path("/", "home", "vsioros", "data")

if _SERVER_DATA_DIR.is_dir():
    BASE_DIR = _SERVER_DATA_DIR
    # On that machine, restrict CUDA to the first GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
    BASE_DIR = Path.cwd()

# Root folder under which every model's evaluation samples and metrics live.
RESULTS_DIR = BASE_DIR / "results"

In [None]:
from enum import Enum
from typing import Any, Generator

import librosa
import numpy as np


class EditType(Enum):
    """Kinds of edit whose results are evaluated in this notebook.

    The value of each member is a string that doubles as a directory name:
    ``DataLoader`` appends ``edit_type.value`` to the path it reads samples from.

    Attributes:
        REFINE: Refinement edit; value ``"Refine"``.
        REPLACE: Replacement edit; value ``"Replace"``.
        REWEIGHT: Reweighting edit; value ``"Reweight"``.
    """

    REFINE = "Refine"
    REPLACE = "Replace"
    REWEIGHT = "Reweight"


class ModelName(Enum):
    """Models whose generated audio is evaluated in this notebook.

    The value of each member is a string that doubles as a directory name:
    ``DataLoader`` uses ``model_name.value`` as the first path component below
    the results root.

    Attributes:
        MUSIC_GEN: Value ``"musicgen"``.
        AUFFUSION: Value ``"auffusion"``.
    """

    MUSIC_GEN = "musicgen"
    AUFFUSION = "auffusion"


class DataLoader:
    """Iterate over the (source, edited) audio pairs stored for one model and edit type.

    Expected directory layout below ``root_dir``::

        <model name>/Evaluation/<edit type>/<source prompt> - <edited prompt>/<seed>/*.wav

    Inside a seed directory, a file whose name starts with ``00`` is taken as the
    source audio and one whose name starts with ``01`` as the edited audio.
    """

    def __init__(
        self,
        edit_type: EditType,
        model_name: ModelName,
        root_dir: Path = RESULTS_DIR,
    ) -> None:
        """Initialize the DataLoader and count the available seed directories.

        Args:
            edit_type (EditType): The type of edits to load audio samples for.
            model_name (ModelName): The name of the model associated with the audio samples.
            root_dir (Path, optional): The root directory containing the audio samples. Defaults to RESULTS_DIR.

        Raises:
            FileNotFoundError: If the data directory for this model/edit type does not exist.
        """
        self.edit_type = edit_type
        self.data_dir = root_dir / model_name.value / "Evaluation" / edit_type.value
        self.count = sum(1 for _ in self._seed_dirs())

    def _seed_dirs(self) -> Generator[tuple[Path, Path], None, None]:
        """Yield every ``(prompt_dir, seed_dir)`` pair in a deterministic (sorted) order."""
        # Path.iterdir() yields entries in arbitrary order, so sort for reproducibility.
        for prompt_dir in sorted(p for p in self.data_dir.iterdir() if p.is_dir()):
            for seed_dir in sorted(p for p in prompt_dir.iterdir() if p.is_dir()):
                yield prompt_dir, seed_dir

    def __len__(self) -> int:
        """Get the total number of seed directories.

        Returns:
            int: The number of seed directories found when the loader was created.
        """
        return self.count

    def __iter__(self) -> Generator[tuple[int, dict[str, dict[str, Any]]], None, None]:
        """Load the audio samples of every seed directory.

        Yields:
            tuple[int, dict[str, dict[str, Any]]]: The seed (the seed directory's name
            parsed as an integer) and a dictionary with up to two entries, ``"source"``
            and ``"edited"``. Each entry holds the ``"prompt"``, the audio ``"data"`` and
            its sampling rate ``"sr"``. An entry is missing when the matching file is absent.

        Raises:
            ValueError: If a prompt directory name does not split into exactly two
                parts on ``" - "``, or a seed directory name is not an integer.
        """
        for prompt_dir, seed_dir in self._seed_dirs():
            # Use `.name`, not `.stem`: `.stem` would cut a prompt at its last "."
            # because it treats the remainder as a file suffix.
            source_prompt, edited_prompt = prompt_dir.name.split(" - ")

            audios = {}
            for audio_file in seed_dir.glob("*.wav"):
                # sr=None keeps the file's native sampling rate.
                audio, sampling_rate = librosa.load(audio_file, sr=None)

                if audio_file.stem.startswith("00"):
                    audios["source"] = {
                        "prompt": source_prompt,
                        "data": audio,
                        "sr": sampling_rate,
                    }
                elif audio_file.stem.startswith("01"):
                    audios["edited"] = {
                        "prompt": edited_prompt,
                        "data": audio,
                        "sr": sampling_rate,
                    }

            yield int(seed_dir.name), audios

## Melody Accuracy

In [None]:
from sklearn.metrics import accuracy_score


def extract_pitch_classes(audio: np.ndarray, sr: int, hop_length: int = 512) -> np.ndarray:
    """Extract the dominant pitch class of every frame of an audio signal.

    The strongest spectral peak of each frame (as reported by librosa's
    ``piptrack``) is converted from Hz to the nearest MIDI note and folded into
    one octave.

    Args:
        audio (np.ndarray): Input audio signal (assumed mono, i.e. 1-D).
        sr (int): Sampling rate of the audio signal.
        hop_length (int): Hop length for computing the pitch.

    Returns:
        np.ndarray: One integer per frame: the pitch class (0 = C, ..., 11 = B) of the
        strongest peak, or -1 for frames in which no peak was detected.
    """
    # piptrack returns two (n_bins, n_frames) arrays: peak frequencies and peak magnitudes
    pitches, magnitudes = librosa.core.piptrack(y=audio, sr=sr, hop_length=hop_length)

    # Frequency (Hz) of the strongest peak in each frame; 0 where the frame has no peak
    strongest_bins = np.argmax(magnitudes, axis=0)
    frame_pitches = pitches[strongest_bins, np.arange(pitches.shape[1])]

    pitch_classes = np.full(frame_pitches.shape, -1, dtype=int)
    voiced = frame_pitches > 0

    # Hz -> MIDI note number (A4 = 440 Hz = 69), then modulo 12 for the pitch class
    midi_notes = 69.0 + 12.0 * np.log2(frame_pitches[voiced] / 440.0)
    pitch_classes[voiced] = np.round(midi_notes).astype(int) % 12

    return pitch_classes


def melody_accuracy(source_audio: np.ndarray, generated_audio: np.ndarray, sr: int) -> float:
    """Calculate the accuracy of generated melody compared to the input melody.

    The per-frame pitch classes of both signals are compared over the frames they
    have in common, so clips of slightly different duration can still be scored.

    Args:
        source_audio (np.ndarray): Input audio.
        generated_audio (np.ndarray): Generated audio.
        sr: (int): The sampling rate of the provided audio files.

    Returns:
        float: Fraction of overlapping frames whose pitch class matches, or NaN when
        the two signals share no frames.
    """
    # Extract pitch classes from both melodies
    input_pitch_classes = extract_pitch_classes(source_audio, sr)
    generated_pitch_classes = extract_pitch_classes(generated_audio, sr)

    # accuracy_score requires equally long inputs, so keep only the overlap
    n_frames = min(len(input_pitch_classes), len(generated_pitch_classes))
    if n_frames == 0:
        return float("nan")

    # Calculate melody accuracy
    return float(
        accuracy_score(input_pitch_classes[:n_frames], generated_pitch_classes[:n_frames]),
    )

## Dynamics Correlation

In [None]:
import numpy as np
from scipy.stats import pearsonr


def extract_dynamics(audio: np.ndarray, sr: int) -> np.ndarray:
    """Compute a tempo-normalised onset-strength curve for an audio signal.

    Args:
        audio (np.ndarray): The audio samples.
        sr: (int): The sampling rate of the provided audio.

    Returns:
        np.ndarray: The onset-strength envelope divided by the estimated tempo.
    """
    strength_envelope = librosa.onset.onset_strength(y=audio, sr=sr)

    # beat_track returns (tempo, beat positions); only the tempo is needed here
    beat_info = librosa.beat.beat_track(onset_envelope=strength_envelope, sr=sr)
    estimated_tempo = beat_info[0]

    return strength_envelope / estimated_tempo


def dynamics_correlation(
    source_audio: np.ndarray,
    edited_audio: np.ndarray,
    sr: int,
) -> float:
    """Compute macro dynamics correlation between two audio files.

    The dynamics curves of both signals are compared over the frames they have in
    common, so clips of slightly different duration can still be scored.

    Args:
        source_audio (np.ndarray): The source audio file.
        edited_audio (np.ndarray): The edited audio file.
        sr (int): The sampling rate of the provided audio files.

    Returns:
        float: Pearson correlation of the two dynamics curves, or NaN when fewer
        than two frames overlap (the correlation is undefined in that case).
    """
    source_dynamics = extract_dynamics(source_audio, sr)
    edited_dynamics = extract_dynamics(edited_audio, sr)

    # pearsonr requires equally long inputs with at least two points
    n_frames = min(len(source_dynamics), len(edited_dynamics))
    if n_frames < 2:
        return float("nan")

    macro_correlation, _ = pearsonr(source_dynamics[:n_frames], edited_dynamics[:n_frames])

    return macro_correlation

In [None]:
# TODO: MICRO CORRELATIONS

## Rhythm F1

In [None]:
from typing import List, Tuple
import numpy as np

# NumPy 1.24 removed the long-deprecated `np.float` / `np.int` aliases. Restore them
# here, before madmom is imported below; presumably madmom still references them
# (TODO confirm against the installed madmom version).
# NOTE: this monkey-patches the global numpy module for the whole kernel session.
np.float = float
np.int = int

from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor


def extract_beat_timestamps(activation: np.ndarray) -> np.ndarray:
    """Decode beat positions from a beat activation function.

    Args:
        activation (np.ndarray): Frame-wise beat activations to decode.

    Returns:
        np.ndarray: The output of the DBN beat tracker for these activations.
    """
    # The tracker is configured for activations sampled at 100 frames per second
    beat_tracker = DBNBeatTrackingProcessor(fps=100)
    return beat_tracker(activation)


def check_alignment(
    timestamps_1: List[float],
    timestamps_2: List[float],
    threshold: float = 0.07,
) -> List[Tuple[float, float]]:
    """Check alignment between two sets of timestamps.

    Each timestamp of the first set is paired with the first not-yet-used timestamp
    of the second set that lies strictly within ``threshold`` of it. The matching is
    one-to-one: a timestamp of the second set is never paired twice, so the number
    of aligned pairs cannot exceed the size of either input.

    Args:
        timestamps_1 (List[float]): Timestamps from the first audio file.
        timestamps_2 (List[float]): Timestamps from the second audio file.
        threshold (float, optional): Maximum allowed difference for alignment. Defaults to 0.07.

    Returns:
        List[Tuple[float, float]]: List of aligned timestamps pairs.
    """
    aligned = []
    used_indices = set()  # positions in timestamps_2 that already have a partner
    for ts1 in timestamps_1:
        for index, ts2 in enumerate(timestamps_2):
            if index not in used_indices and abs(ts1 - ts2) < threshold:
                aligned.append((ts1, ts2))
                used_indices.add(index)
                break
    return aligned


def calculate_f1_score(
    aligned_timestamps: List[Tuple[float, float]],
    timestamps_1: List[float],
    timestamps_2: List[float],
) -> float:
    """Calculate the F1 score based on alignment of timestamps.

    Precision is the share of ``timestamps_1`` that was aligned, recall the share of
    ``timestamps_2`` that was aligned.

    Args:
        aligned_timestamps (List[Tuple[float, float]]): List of aligned timestamp pairs.
        timestamps_1 (List[float]): Timestamps from the first audio file.
        timestamps_2 (List[float]): Timestamps from the second audio file.

    Returns:
        float: F1 score; 0.0 when either set of timestamps is empty or nothing aligned.
    """
    # No beats in one of the signals: precision or recall would divide by zero.
    # len() is used (not truthiness) so numpy arrays are handled as well.
    if len(timestamps_1) == 0 or len(timestamps_2) == 0:
        return 0.0

    precision = len(aligned_timestamps) / len(timestamps_1)
    recall = len(aligned_timestamps) / len(timestamps_2)

    # Nothing aligned: the harmonic mean would divide by zero
    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)


def rhythm_f1(source_audio: np.ndarray, generated_audio: np.ndarray) -> float:
    """Calculate Rhythm F1 score between two audio samples.

    Read more [here](https://madmom.readthedocs.io/en/v0.15/modules/features/beats.html#madmom.features.beats.DBNBeatTrackingProcessor).

    Args:
        source_audio (np.ndarray): Array containing the first audio sample.
        generated_audio (np.ndarray): Array containing the second audio sample.

    Returns:
        float: Rhythm F1 score.
    """
    # NOTE(review): the arrays are handed to madmom without their sampling rate;
    # confirm that madmom interprets them at the rate they were loaded with.

    # Step 1: Build the beat-activation network once and reuse it for both signals
    beat_processor = RNNBeatProcessor()

    # Step 2: Extract beat timestamps
    timestamps_1 = extract_beat_timestamps(beat_processor(source_audio))
    timestamps_2 = extract_beat_timestamps(beat_processor(generated_audio))

    # Step 3: Check alignment
    aligned_timestamps = check_alignment(timestamps_1, timestamps_2)

    # Step 4: Calculate F1 score
    return calculate_f1_score(aligned_timestamps, timestamps_1, timestamps_2)

## CLAP Score

In [None]:
from transformers import AutoProcessor, ClapModel

# The model weights and the matching processor are loaded from the same checkpoint.
CLAP_CHECKPOINT = "laion/clap-htsat-unfused"

clap_processor = AutoProcessor.from_pretrained(CLAP_CHECKPOINT)
clap_model = ClapModel.from_pretrained(CLAP_CHECKPOINT)

In [None]:
import numpy as np
import torch.nn.functional as F


def clap(
    prompt: str,
    audios: np.ndarray,
    orig_sr: int,
    sr: int = 48000,
) -> tuple[float, float]:
    """Compute CLAP audio-audio and text-audio cosine similarities.

    Args:
        prompt (str): The input prompt.
        audios (np.ndarray): Two audio signals (a sequence of 1-D arrays or a 2-D array);
            index 0 and index 1 are the two clips being compared.
        orig_sr (int): Original sampling rate of audio.
        sr (int, optional): Target sampling rate. Defaults to 48000.

    Returns:
        tuple[float, float]: The cosine similarity between the embeddings of the two
        audio signals, and the cosine similarity between the prompt embedding and the
        embedding of the audio at index 1.
    """
    # Only `torch.nn.functional` is imported at module level, so bring torch in locally
    import torch

    # Resample audios. A list (not np.stack) is used so that clips of different
    # length are accepted; the processor pads them (padding=True below).
    if orig_sr is not None and orig_sr != sr:
        audios = [librosa.resample(audio, orig_sr=orig_sr, target_sr=sr) for audio in audios]

    inputs = clap_processor(
        text=prompt,
        audios=audios,
        return_tensors="pt",
        sampling_rate=sr,
        padding=True,
    )

    # Process prompt and audios; inference only, so no autograd graph is needed
    with torch.no_grad():
        prompt_features = clap_model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        # inputs["attention_mask"] belongs to the tokenised prompt, so it is not
        # forwarded to the audio encoder
        audio_features = clap_model.get_audio_features(
            input_features=inputs["input_features"],
        )

    # Calculate cosine similarity between audios
    audio_audio_similarity = F.cosine_similarity(audio_features[0], audio_features[1], dim=0)

    # Calculate cosine similarity between the prompt and the audio at index 1
    text_audio_similarity = F.cosine_similarity(prompt_features[0], audio_features[1], dim=0)

    return audio_audio_similarity.item(), text_audio_similarity.item()

In [None]:
import pandas as pd
from tqdm.auto import tqdm


def collect_metrics(
    prompt: str,
    source_audio: np.ndarray,
    generated_audio: np.ndarray,
    sr: int,
) -> tuple[float, float, float, tuple[float, float]]:
    """Collect metrics for the provided source and generated audio.

    Args:
        prompt (str): The prompt associated with the audio samples.
        source_audio (np.ndarray): The source audio samples.
        generated_audio (np.ndarray): The generated audio samples.
        sr (int): The sampling rate of the audio samples.

    Returns:
        tuple[float, float, float, tuple[float, float]]: A tuple containing the computed metrics.
            - Melody accuracy score
            - Dynamics correlation score
            - Rhythm F1 score
            - CLAP scores as returned by ``clap`` for ``[source_audio, generated_audio]``:
              a pair (audio-to-audio similarity, text-to-audio similarity)
    """
    melody_accuracy_score = melody_accuracy(source_audio, generated_audio, sr)
    dynamics_correlation_score = dynamics_correlation(source_audio, generated_audio, sr)
    rhythm_f1_score = rhythm_f1(source_audio, generated_audio)
    # Order matters: `clap` scores the prompt against the audio at index 1 (the generated one)
    clap_score = clap(prompt, [source_audio, generated_audio], sr)

    return melody_accuracy_score, dynamics_correlation_score, rhythm_f1_score, clap_score


def construct_dataframe(data_loader: DataLoader) -> pd.DataFrame:
    """Construct a DataFrame containing metrics for audio samples.

    Every item yielded by the loader that contains both a source and an edited
    audio is scored; items missing either one are skipped.

    Args:
        data_loader (DataLoader): An instance of DataLoader that loads audio samples.

    Returns:
        pd.DataFrame: One row per scored sample, sorted by source prompt and seed.
        The frame is empty (but has all columns) when nothing could be scored.
    """
    columns = [
        "Source Prompt",
        "Edited Prompt",
        "Seed",
        "Melody Accuracy",
        "Dynamics Correlation",
        "Rhythm F1",
        "A2A Similarity",
        "T2A Similarity",
    ]

    results = []
    for seed, samples in tqdm(data_loader, position=2, leave=False, desc="Evaluating dataset"):
        source_prompt = samples.get("source", {}).get("prompt")
        edited_prompt = samples.get("edited", {}).get("prompt")
        source_audio = samples.get("source", {}).get("data")
        generated_audio = samples.get("edited", {}).get("data")
        sr = samples.get("source", {}).get("sr")  # Assuming source and edited have the same sr

        if source_audio is None or generated_audio is None:
            continue

        # The text-to-audio similarity is measured on the edited audio, so it is
        # scored against the edited prompt rather than the source prompt.
        melody_accuracy_score, dynamics_correlation_score, rhythm_f1_score, clap_score = (
            collect_metrics(
                edited_prompt,
                source_audio,
                generated_audio,
                sr,
            )
        )

        # Collect one record per sample; the DataFrame is built once at the end
        results.append(
            {
                "Source Prompt": source_prompt,
                "Edited Prompt": edited_prompt,
                "Seed": seed,
                "Melody Accuracy": melody_accuracy_score,
                "Dynamics Correlation": dynamics_correlation_score,
                "Rhythm F1": rhythm_f1_score,
                "A2A Similarity": clap_score[0],
                "T2A Similarity": clap_score[1],
            },
        )

    # Explicit columns keep sort_values from raising KeyError on an empty result list
    return pd.DataFrame(results, columns=columns).sort_values(by=["Source Prompt", "Seed"])

In [None]:
# Score every (model, edit type) combination and store the resulting table next
# to the samples it was computed from.
for model_name in tqdm(ModelName):
    for edit_type in tqdm(EditType, position=1, leave=False):
        df = construct_dataframe(DataLoader(edit_type, model_name))
        DATAFRAME_PATH = RESULTS_DIR.joinpath(
            model_name.value,
            "Evaluation",
            edit_type.value,
            "metrics.pkl",
        )
        df.to_pickle(DATAFRAME_PATH)

In [None]:
# Read back the stored per-combination tables and tag each with its model and
# edit type before stacking them into one frame.
dfs = []
for model_name in ModelName:
    for edit_type in EditType:
        DATAFRAME_PATH = RESULTS_DIR.joinpath(
            model_name.value,
            "Evaluation",
            edit_type.value,
            "metrics.pkl",
        )
        df = pd.read_pickle(DATAFRAME_PATH).assign(
            Model=model_name.value,
            Edit=edit_type.value,
        )
        dfs.append(df)

df = pd.concat(dfs)

In [None]:
# Metric columns to average for every (model, edit type) combination
metrics = [
    "Melody Accuracy",
    "Dynamics Correlation",
    "Rhythm F1",
    "A2A Similarity",
    "T2A Similarity",
]

# Last expression of the cell, so the summary table is displayed
df.groupby(["Model", "Edit"])[metrics].mean()