## Prepare the Environment

In [None]:
import os
from pathlib import Path

# Local defaults: keep results next to the notebook and use the small (fastest) checkpoint.
BASE_DIR = Path.cwd()
MODEL_NAME = "facebook/musicgen-small"

# When the shared data directory exists, store results there, switch to the large
# checkpoint and expose only the first GPU to CUDA.
_shared_data_dir = Path("/") / "home" / "vsioros" / "data"
if _shared_data_dir.is_dir():
    BASE_DIR = _shared_data_dir
    MODEL_NAME = "facebook/musicgen-large"

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

## Load the Model

The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id to the checkpoint size you wish to load. We default to the small checkpoint, which is the fastest of the three but has the lowest audio quality. The environment cell above switches to the large checkpoint when the shared data directory is present:

In [None]:
from typing import Any

import numpy as np
import torch
from numpy.typing import NDArray
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class ModelProxy:
    """Thin wrapper around a Hugging Face MusicGen checkpoint and its processor.

    Loads the processor and model for ``model_name``, places the model on the first CUDA
    device when one is available (CPU otherwise), and exposes text encoding, sampled
    generation with classifier-free guidance, and a few config shortcuts.
    """

    def __init__(self, model_name: str, guidance_scale: float = 3.0):
        self.model_name = model_name
        # Forwarded as ``guidance_scale`` to every ``generate`` call.
        self.guidance_scale = guidance_scale

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def generate(self, inputs: dict[str, Any], max_new_tokens: int = 512) -> NDArray[np.floating]:
        """Sample audio for already-encoded text prompts.

        Args:
            inputs: mapping of model inputs (e.g. the ``BatchEncoding`` returned by
                :meth:`encode`, or a plain ``dict``); tensor values are moved to the
                model's device, other values are passed through unchanged.
            max_new_tokens: number of audio tokens to generate.

        Returns:
            The generated output as a NumPy array on the CPU, with all size-1 axes squeezed.
        """
        # Move tensors ourselves so plain dicts work too (a dict has no ``.to`` method).
        model_inputs = {
            name: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for name, value in inputs.items()
        }

        return (
            self.model.generate(
                **model_inputs,
                do_sample=True,
                guidance_scale=self.guidance_scale,
                max_new_tokens=max_new_tokens,
            )
            .cpu()
            .numpy()
            .squeeze()
        )

    def encode(self, prompts: list[str]) -> dict[str, Any]:
        """Tokenize ``prompts`` into padded PyTorch tensors via the processor."""
        return self.processor(text=prompts, padding=True, return_tensors="pt")

    def decode(self, encoded_token: NDArray[np.float64]) -> str:
        """Decode a token (id) back to text via the processor."""
        return self.processor.decode(encoded_token)

    @property
    def sampling_rate(self) -> float:
        """Audio sampling rate reported by the audio-encoder config."""
        return self.model.config.audio_encoder.sampling_rate

    @property
    def frame_rate(self) -> float:
        """Audio-token frame rate reported by the audio-encoder config."""
        return self.model.config.audio_encoder.frame_rate

    @property
    def decoder_layers(self) -> list[torch.nn.Module]:
        """The decoder's layer container (``model.decoder.model.decoder.layers``)."""
        return self.model.decoder.model.decoder.layers

In [None]:
# Shared model wrapper used by the rest of the notebook (checkpoint chosen in the setup cell).
model_proxy = ModelProxy(model_name=MODEL_NAME)

## Prompt-to-Prompt

In [None]:
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.utils import (
    logging,
)

# Module logger from `transformers.utils.logging`. It is not referenced by the code visible
# in this notebook; presumably kept from the Hugging Face attention code adapted below.
logger = logging.get_logger(__name__)


def register_attention_control(model_proxy, controller):
    """Patch MusicGen's decoder attention modules so `controller` sees every attention map.

    For each child of `model_proxy.decoder_layers` whose class is `MusicgenDecoderLayer`, the
    sub-modules named "self_attn" and "encoder_attn" are patched (when their class is
    `MusicgenAttention`) by assigning a new `forward` built by `ca_forward`. The new forward
    calls `controller(attn_weights, is_cross_attention, attention_type)` right after the
    softmax and continues with whatever the controller returns. Finally
    `controller.num_att_layers` is set to the number of patched modules.

    Every call re-assigns `forward` with a fresh closure, so calling this again with another
    controller replaces the previous hook.

    Args:
        model_proxy: object exposing `decoder_layers` (e.g. `ModelProxy`).
        controller: callable taking `(attn_weights, is_cross, attention_type)` and returning
            the attention probabilities to use.
    """

    def ca_forward(self, attention_type):
        """Return a replacement `forward` closed over the attention module `self`.

        `attention_type` is "self" or "cross" (chosen by the caller below) and is passed to
        the controller unchanged.
        """

        def forward(
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            """Input shape: Batch x Time x Channel

            NOTE(review): this re-implements the attention forward of the installed
            `transformers` MusicGen module with a controller hook inserted after the softmax;
            the signature and the returned 3-tuple must match what `MusicgenDecoderLayer`
            passes and expects — re-check when upgrading `transformers`.
            """
            # if key_value_states are provided this layer is used as a cross-attention layer
            # for the decoder
            is_cross_attention = key_value_states is not None

            bsz, tgt_len, _ = hidden_states.size()

            # get query proj
            query_states = self.q_proj(hidden_states) * self.scaling
            # get key, value proj
            # `past_key_value[0].shape[2] == key_value_states.shape[1]`
            # is checking that the `sequence_length` of the `past_key_value` is the same as
            # the provided `key_value_states` to support prefix tuning
            if (
                is_cross_attention
                and past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]
            ):
                # reuse k,v, cross_attentions
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            elif is_cross_attention:
                # cross_attentions
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            elif past_key_value is not None:
                # reuse k, v, self_attention
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                # self_attention
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

            if self.is_decoder:
                # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
                # Further calls to cross_attention layer can then reuse all cross-attention
                # key/value_states (first "if" case)
                # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
                # all previous decoder key/value_states. Further calls to uni-directional self-attention
                # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
                # if encoder bi-directional self-attention `past_key_value` is always `None`
                past_key_value = (key_states, value_states)

            # Fold heads into the batch axis so attention can be computed with `bmm`.
            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
            key_states = key_states.reshape(*proj_shape)
            value_states = value_states.reshape(*proj_shape)

            src_len = key_states.size(1)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                    f" {attn_weights.size()}",
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}",
                    )
                attn_weights = (
                    attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                )
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            # Prompt-to-prompt hook: hand the (bsz * num_heads, tgt_len, src_len) probabilities
            # to the controller; everything below uses whatever tensor it returns.
            attn_weights = controller(attn_weights, is_cross_attention, attention_type)

            if layer_head_mask is not None:
                if layer_head_mask.size() != (self.num_heads,):
                    raise ValueError(
                        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                        f" {layer_head_mask.size()}",
                    )
                attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                    bsz,
                    self.num_heads,
                    tgt_len,
                    src_len,
                )
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            if output_attentions:
                # this operation is a bit awkward, but it's required to
                # make sure that attn_weights keeps its gradient.
                # In order to do so, attn_weights have to be reshaped
                # twice and have to be reused in the following
                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
            else:
                attn_weights_reshaped = None

            attn_probs = nn.functional.dropout(
                attn_weights,
                p=self.dropout,
                training=self.training,
            )

            attn_output = torch.bmm(attn_probs, value_states)

            if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                    f" {attn_output.size()}",
                )

            # Unfold heads and merge them back into the channel axis.
            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
            attn_output = attn_output.transpose(1, 2)

            # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
            # partitioned across GPUs when using tensor-parallelism.
            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

            attn_output = self.out_proj(attn_output)

            return attn_output, attn_weights_reshaped, past_key_value

        return forward

    def register_recr(net_, count, attention_type):
        """Patch `net_` if it is a `MusicgenAttention` module.

        Returns `count + 1` when `net_` was patched, otherwise `count` unchanged. Despite the
        name, this does not recurse into child modules.
        """
        if net_.__class__.__name__ == "MusicgenAttention":
            net_.forward = ca_forward(net_, attention_type)
            return count + 1

        return count

    cross_att_count = 0
    sub_nets = model_proxy.decoder_layers.named_children()
    for net in sub_nets:
        # `net` is a (name, module) pair; only decoder layers hold attention modules to patch.
        if net[1].__class__.__name__ != "MusicgenDecoderLayer":
            continue

        for subnet in net[1].named_children():
            # Map the child's attribute name to the attention type reported to the controller.
            attention_type = None
            if subnet[0] == "encoder_attn":
                attention_type = "cross"
            elif subnet[0] == "self_attn":
                attention_type = "self"

            if attention_type is not None:
                cross_att_count += register_recr(subnet[1], 0, attention_type)

    # Despite its name, this counts every patched module (self- and cross-attention).
    controller.num_att_layers = cross_att_count

In [None]:
from abc import ABC, abstractmethod
from collections import defaultdict


class BaseController(ABC):
    """Base class for hooks that inspect or edit MusicGen attention probabilities.

    An instance is called from every patched attention ``forward`` with the post-softmax
    weights of shape ``(2 * batch_size * num_heads, tgt_len, src_len)``. It tracks which
    attention call within the current generation step is running (``cur_att_layer``,
    1-based, ranging over ``1..num_att_layers``) and how many steps have completed
    (``cur_step``). It then hands the first half of the weights, reshaped to
    ``(batch_size, num_heads, tgt_len, src_len)``, to :meth:`replace_self_attention` or
    :meth:`replace_cross_attention`.

    :meth:`reset` must be called and ``batch_size`` / ``max_new_tokens`` /
    ``num_att_layers`` set before generation starts.
    """

    def reset(self) -> None:
        """Clear per-generation configuration and counters."""
        self.num_att_layers = 0
        self.batch_size = -1
        self.max_new_tokens = -1

        self.cur_att_layer = 0
        self.cur_step = 0

    def __call__(self, attn_weights, is_cross, attention_type) -> torch.Tensor:
        """Route the conditional half of ``attn_weights`` through the replace hooks.

        Args:
            attn_weights: post-softmax weights, ``(2 * batch_size * num_heads, tgt, src)``.
                Modified in place.
            is_cross: whether the call comes from a cross-attention module.
            attention_type: "self" or "cross", forwarded to :meth:`replace_cross_attention`.

        Returns:
            ``attn_weights`` with its first half replaced by the edited maps.
        """
        # 1-based index of this attention call within the current generation step.
        self.cur_att_layer = self.cur_att_layer + 1

        # Exclude unconditional inputs: the first half is treated as the conditional batch.
        half = attn_weights.shape[0] // 2
        attn = attn_weights[:half]

        # Reshape according to batch size: (batch_size, num_heads, tgt_len, src_len).
        num_heads = attn.shape[0] // self.batch_size
        attn = attn.reshape(self.batch_size, num_heads, *attn.shape[1:])

        if is_cross:
            attn = self.replace_cross_attention(attn, is_cross, attention_type)
        else:
            attn = self.replace_self_attention(attn)

        attn_weights[:half] = attn.reshape(self.batch_size * num_heads, *attn.shape[2:])

        # Wrap only after the call has been processed, so the last attention call of a step
        # sees cur_att_layer == num_att_layers (not 0) and the same cur_step as its siblings.
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step = self.cur_step + 1

        return attn_weights

    @abstractmethod
    def replace_self_attention(self, attn):
        """Return the (possibly edited) self-attention maps for the conditional batch."""
        raise NotImplementedError

    @abstractmethod
    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        """Return the (possibly edited) cross-attention maps for the conditional batch."""
        raise NotImplementedError


class EmptyController(BaseController):
    """No-op controller: every attention map is handed back exactly as received."""

    def replace_self_attention(self, attn):
        """Leave self-attention untouched."""
        unchanged = attn
        return unchanged

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        """Leave cross-attention untouched."""
        unchanged = attn_weights
        return unchanged


class AttentionStore(BaseController):
    """Controller that records attention maps without changing them.

    Every conditional attention map routed to it, shaped
    ``(batch_size, num_heads, tgt_len, src_len)`` by :class:`BaseController`, is appended to
    ``features["self"]`` or ``features["cross"]`` and returned unchanged. The getters below
    re-group the recordings per generation step and rank layers by attention strength.
    """

    def reset(self):
        super().reset()

        # Recorded maps, in call order, keyed by "self" / "cross".
        self.features = defaultdict(list)

    def get_self_attention(self) -> torch.Tensor:
        """Stack recorded self-attention maps averaged over their last (key) axis.

        Returns a tensor of shape ``(max_new_tokens, calls_per_step, *map_shape[:-1])``.
        Averaging happens before stacking, presumably because the key length grows with the
        cached past keys. NOTE(review): assumes exactly ``max_new_tokens`` steps were
        recorded, each with the same number of self-attention calls.
        """
        averaged = torch.stack([tensor.mean(dim=-1) for tensor in self.features["self"]])
        return averaged.view(self.max_new_tokens, -1, *averaged.shape[1:])

    def get_cross_attention(self) -> torch.Tensor:
        """Stack recorded cross-attention maps as ``(max_new_tokens, calls_per_step, *map)``."""
        stacked = torch.stack(self.features["cross"])
        return stacked.view(self.max_new_tokens, -1, *stacked.shape[1:])

    @staticmethod
    def _rank_layers(scores: torch.Tensor) -> torch.Tensor:
        """Rank layers per sample from a ``(n_layers, n_samples)`` score matrix.

        Returns ``(n_samples, n_layers)`` 0-based layer positions, most important first.
        """
        # Min-max scaling to normalize values between 0 and 1 for each column (sample).
        min_values = scores.min(dim=0).values
        max_values = scores.max(dim=0).values
        # Guard a zero range (all layers tie, or a single layer) against 0/0 = NaN.
        span = (max_values - min_values).clamp_min(torch.finfo(scores.dtype).tiny)
        normalized_scores = (scores - min_values) / span

        order = torch.argsort(normalized_scores, descending=True, dim=0)
        # argsort yields (n_layers, n_samples); transpose (a `view` would scramble columns).
        return order.transpose(0, 1).contiguous()

    def get_self_attention_importance(self) -> torch.Tensor:
        """Rank self-attention layers per edited sample (batch index >= 1).

        Returns ``(n_edited_samples, n_layers)`` attention-call indices (1, 3, 5, ...):
        within a decoder layer the self-attention call comes first.
        """
        scores = self.get_self_attention()[:, :, 1:, :, :].mean(dim=(0, 3, 4))
        return 2 * self._rank_layers(scores) + 1

    def get_cross_attention_importance(self, word_piece_index) -> torch.Tensor:
        """Rank cross-attention layers per edited sample by attention on one word piece.

        Returns ``(n_edited_samples, n_layers)`` attention-call indices (2, 4, 6, ...):
        within a decoder layer the cross-attention call comes second.
        """
        scores = self.get_cross_attention()[:, :, 1:, :, :, word_piece_index]
        scores = scores.mean(dim=(0, 3, 4))
        return 2 * (self._rank_layers(scores) + 1)

    def get_aggregate_cross_attention(self) -> torch.Tensor:
        """Mean of all recorded cross-attention maps."""
        return torch.mean(torch.stack(self.features["cross"]), dim=0)

    def replace_self_attention(self, attn) -> torch.Tensor:
        self.features["self"].append(attn)

        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> torch.Tensor:
        self.features["cross"].append(attn_weights)

        return attn_weights

In [None]:
class BaseEditController(BaseController):
    """Base for editing controllers: self-attention of every sample copies sample 0's."""

    def replace_self_attention(self, attn):
        """Broadcast the source sample's self-attention map to the whole batch (as a view)."""
        source_map = attn[0]
        # `expand` with an extra leading size prepends the new batch axis.
        return source_map.expand(attn.shape[0], *source_map.shape)


class RandomController(BaseEditController):
    """Baseline: overwrite edited samples' cross-attention with standard-normal noise."""

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        noise = torch.randn_like(attn_weights[0])
        # One noise map, broadcast to every edited sample (index >= 1).
        attn_weights[1:] = noise
        return attn_weights


class IgnoreWordController(BaseEditController):
    """Zero the edited samples' cross-attention on the given text-token positions.

    The remaining weights are not renormalized.
    """

    def __init__(self, indices: list[int]):
        super().__init__()

        # Text-token positions (last axis of the cross-attention map) to silence.
        self.indices = indices

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        silenced = self.indices
        attn_weights[1:, :, :, silenced] = 0
        return attn_weights


class ReplaceWordController(BaseEditController):
    """Blend the source word's averaged attention into the edited samples' target word.

    ``indices`` is ``[source_indices, target_indices]``: token positions of the word in the
    source prompt (sample 0) and of its replacement in the edited prompts.
    """

    def __init__(self, indices: list[list[int]], blend: float = 0.5):
        super().__init__()

        source, target = indices
        self.source_indices = source
        self.target_indices = target
        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        # Source sample's attention, averaged over the word pieces of the source word.
        source_mean = attn_weights[0, :, :, self.source_indices].mean(dim=-1, keepdim=True)
        source_mean = source_mean.expand(-1, -1, len(self.target_indices))

        target_part = attn_weights[1:, :, :, self.target_indices]
        attn_weights[1:, :, :, self.target_indices] = (
            (1 - self.blend) * target_part + self.blend * source_mean
        )

        return attn_weights


class RefineController(BaseEditController):
    """Blend source attention into edited samples on token pairs shared by both prompts.

    ``indices`` is ``[source_indices, target_indices]`` of equal length; position ``k`` of
    the source prompt is paired with position ``k`` of the edited prompts.
    """

    def __init__(self, indices: list[list[int]], blend: float = 0.5):
        super().__init__()

        source, target = indices
        self.source_indices = source
        self.target_indices = target
        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        edited_part = attn_weights[1:, :, :, self.target_indices]
        source_part = attn_weights[0, :, :, self.source_indices]
        attn_weights[1:, :, :, self.target_indices] = (
            (1 - self.blend) * edited_part + self.blend * source_part
        )

        return attn_weights


class ReweightWordController(BaseEditController):
    """Emphasize given tokens by scaling down attention on every other token.

    Edited samples first copy the source sample's cross-attention; then the weights of all
    positions not listed in ``indices`` are divided by ``weight``.
    """

    def __init__(self, indices: list[int], weight: float = 5):
        super().__init__()

        self.indices = indices
        self.weight = weight

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        attn_weights[1:] = attn_weights[0]

        n_tokens = attn_weights.shape[-1]
        others = list(filter(lambda position: position not in self.indices, range(n_tokens)))
        attn_weights[1:, :, :, others] /= self.weight

        return attn_weights


class ReplaceController(BaseEditController):
    """Linearly blend every edited sample's cross-attention toward the source sample's."""

    def __init__(self, blend: float = 0.5):
        super().__init__()

        # 0 keeps the edited attention, 1 copies the source attention.
        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        keep = 1 - self.blend
        attn_weights[1:] = keep * attn_weights[1:] + self.blend * attn_weights[0]
        return attn_weights

In [None]:
from typing import Any


class ControllerModifier(BaseController):
    """Decorator-style wrapper that changes *when/how* a wrapped controller edits attention.

    The per-generation runtime state (layer/step counters, batch size, token budget) lives
    on the wrapped controller, so every wrapper in a chain sees the same counters. The
    wrapper's own settings (offset, threshold, head indices, ...) stay on the wrapper, so
    nested modifiers of the same kind no longer overwrite each other's configuration.
    Attributes not found on the wrapper are still looked up on the wrapped controller.
    """

    # Runtime state owned by the innermost controller and shared along the chain.
    _SHARED_STATE = frozenset(
        {"num_att_layers", "batch_size", "max_new_tokens", "cur_att_layer", "cur_step"},
    )

    def __init__(self, controller: BaseController) -> None:
        super().__init__()

        self.controller = controller

    def reset(self) -> None:
        """Reset the wrapped controller, including any state its own ``reset`` creates."""
        self.controller.reset()

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails; `controller` missing means it was never set
        # (e.g. during copying), so fail cleanly instead of recursing.
        if name == "controller":
            raise AttributeError(name)

        return getattr(self.controller, name)

    def __setattr__(self, name: str, value: Any):
        if name in ControllerModifier._SHARED_STATE:
            return setattr(self.controller, name, value)

        return super().__setattr__(name, value)


class OffsetControllerModifier(ControllerModifier):
    """Hold back the wrapped controller's edits for the first part of the generation.

    Before step ``round(offset * max_new_tokens)`` attention maps pass through unchanged;
    from that step on they are handed to the wrapped controller.
    """

    def __init__(self, controller: BaseController, offset: float = 0.0) -> None:
        super().__init__(controller)

        assert 0.0 <= offset <= 1.0

        self.offset = offset

    def _edits_enabled(self) -> bool:
        start_step = round(self.offset * self.max_new_tokens)
        return self.cur_step >= start_step

    def replace_self_attention(self, attn):
        if self._edits_enabled():
            return self.controller.replace_self_attention(attn)
        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        if self._edits_enabled():
            return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)
        return attn_weights


class AttentionHeadControllerModifier(ControllerModifier):
    """Restrict the wrapped controller's cross-attention edits to selected heads.

    Self-attention is forwarded to the wrapped controller as-is.
    """

    def __init__(self, controller: BaseController, attention_head_indices: list[int]) -> None:
        super().__init__(controller)

        self.attention_head_indices = attention_head_indices

    def replace_self_attention(self, attn):
        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        heads = self.attention_head_indices
        selected = attn_weights[:, heads, :, :]
        edited = self.controller.replace_cross_attention(selected, is_cross, attention_type)
        # Write the edited heads back into the full map.
        attn_weights[:, heads, :, :] = edited
        return attn_weights


class SelfAttentionLerpControllerModifier(ControllerModifier):
    """Blend edited samples' self-attention toward the source sample's, more with depth.

    The blend weight is ``cur_att_layer / num_att_layers``. Cross-attention is delegated to
    the wrapped controller; the wrapped controller's self-attention hook is not used.
    """

    def __init__(self, controller: BaseController) -> None:
        super().__init__(controller)

    def replace_self_attention(self, attn):
        weight = self.cur_att_layer / self.num_att_layers
        # In-place update of the edited samples (index >= 1).
        attn[1:] = (1 - weight) * attn[1:] + weight * attn[0]
        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> None:
        delegate = self.controller.replace_cross_attention
        return delegate(attn_weights, is_cross, attention_type)


class SelfAttentionCutoffControllerModifier(ControllerModifier):
    """Apply the wrapped self-attention edit only up to a fraction of the attention calls.

    Self-attention is edited while ``cur_att_layer <= floor(threshold * num_att_layers)``;
    cross-attention is always delegated.
    """

    def __init__(self, controller: BaseController, threshold: float = 0.75) -> None:
        super().__init__(controller)

        assert 0.0 <= threshold <= 1.0

        self.threshold = threshold

    def _within_cutoff(self) -> bool:
        last_edited = np.floor(self.threshold * self.num_att_layers)
        return self.cur_att_layer <= last_edited

    def replace_self_attention(self, attn):
        if not self._within_cutoff():
            return attn
        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> None:
        delegate = self.controller.replace_cross_attention
        return delegate(attn_weights, is_cross, attention_type)


class AttentionLerpControllerModifier(ControllerModifier):
    """Interpolate between original and edited attention with a depth-dependent weight.

    The weight is ``cur_att_layer / num_att_layers``: early calls stay close to the original
    maps, later calls move toward the wrapped controller's output.
    """

    def __init__(self, controller: BaseController) -> None:
        super().__init__(controller)

    def _depth_weight(self) -> float:
        return self.cur_att_layer / self.num_att_layers

    def replace_self_attention(self, attn):
        weight = self._depth_weight()
        # Take the original share before the wrapped controller (which may edit in place) runs.
        original_share = (1 - weight) * attn
        edited = self.controller.replace_self_attention(attn)
        return original_share + weight * edited

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> None:
        weight = self._depth_weight()
        original_share = (1 - weight) * attn_weights
        edited = self.controller.replace_cross_attention(
            attn_weights.clone(),
            is_cross,
            attention_type,
        )
        return original_share + weight * edited


class AttentionCutoffControllerModifier(ControllerModifier):
    """Apply the wrapped controller only up to a fraction of the attention calls.

    Both self- and cross-attention are edited while
    ``cur_att_layer <= floor(threshold * num_att_layers)`` and passed through afterwards.
    """

    def __init__(self, controller: BaseController, threshold: float = 0.75) -> None:
        super().__init__(controller)

        assert 0.0 <= threshold <= 1.0

        self.threshold = threshold

    def _within_cutoff(self) -> bool:
        last_edited = np.floor(self.threshold * self.num_att_layers)
        return self.cur_att_layer <= last_edited

    def replace_self_attention(self, attn):
        if not self._within_cutoff():
            return attn
        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> None:
        if not self._within_cutoff():
            return attn_weights
        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)


class DecoderLayerControllerModifier(ControllerModifier):
    """Apply the wrapped controller only at selected attention-call indices.

    ``decoder_layer_indices`` holds values of ``cur_att_layer`` (the running attention-call
    counter within a generation step) at which edits are forwarded; every other call passes
    its attention map through unchanged.
    """

    def __init__(self, controller: BaseController, decoder_layer_indices: set[int]) -> None:
        super().__init__(controller)

        self.decoder_layer_indices = decoder_layer_indices

    def replace_self_attention(self, attn):
        if self.cur_att_layer not in self.decoder_layer_indices:
            return attn
        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        if self.cur_att_layer not in self.decoder_layer_indices:
            return attn_weights
        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)


class StrengthControllerModifier(ControllerModifier):
    """Scale the effect of the wrapped controller by a constant ``strength``.

    Output is ``(1 - strength) * original + strength * edited``; the wrapped controller
    always receives a clone, so the original maps are not modified by it.
    """

    def __init__(self, controller: BaseController, strength: float = 1.0) -> None:
        super().__init__(controller)

        self.strength = strength

    def replace_self_attention(self, attn):
        original_share = (1 - self.strength) * attn
        edited = self.controller.replace_self_attention(attn.clone())
        return original_share + self.strength * edited

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> None:
        original_share = (1 - self.strength) * attn_weights
        edited = self.controller.replace_cross_attention(
            attn_weights.clone(),
            is_cross,
            attention_type,
        )
        return original_share + self.strength * edited

In [None]:
import re
from datetime import datetime

import numpy as np
from scipy.io.wavfile import write

# Drops "Controller" and everything after it from a class name,
# e.g. "ReplaceWordController" -> "ReplaceWord".
CONTROLLER_REG = re.compile(r"Controller.*")
# Zero-width split points before each capital letter except at the start (CamelCase -> snake).
SNAKE_CASE_REG = re.compile(r"(?<!^)(?=[A-Z])")

# Root folder for results; created eagerly when this cell runs.
RESULTS_DIR = BASE_DIR.joinpath("results", "musicgen")
RESULTS_DIR.mkdir(exist_ok=True, parents=True)


def _get_controller_subpath(controller: BaseController) -> Path:
    """Turn a controller's class name into a snake_case folder name.

    E.g. ``ReplaceWordController`` -> ``replace_word``,
    ``DecoderLayerControllerModifier`` -> ``decoder_layer``.
    """
    stem = CONTROLLER_REG.sub("", controller.__class__.__name__)
    return Path(SNAKE_CASE_REG.sub("_", stem).lower())


def get_controller_subpath(controller: BaseController) -> Path:
    """Build the relative results folder that encodes a controller's configuration.

    Supported controllers:
        * ``DecoderLayerControllerModifier``: ``decoder_layer/<indices>/<wrapped subpath>``
          where ``<indices>`` concatenates the sorted layer indices as two-digit numbers.
        * ``ReplaceController``: ``replace/<blend>``.
        * ``ReplaceWordController``: ``replace_word/<source>-<target>/<blend>`` where
          ``<source>``/``<target>`` are the token indices joined with "_".

    Raises:
        NotImplementedError: for any other controller type.
    """
    if isinstance(controller, DecoderLayerControllerModifier):
        # Sort so the folder name does not depend on set iteration order.
        layers = "".join(f"{index:02d}" for index in sorted(controller.decoder_layer_indices))
        return (
            _get_controller_subpath(controller)
            / layers
            / get_controller_subpath(controller.controller)
        )
    if isinstance(controller, ReplaceController):
        return _get_controller_subpath(controller) / f"{controller.blend:.2f}"
    if isinstance(controller, ReplaceWordController):
        # ReplaceWordController stores source/target token indices; it has no `index` field.
        source = "_".join(f"{index:02d}" for index in controller.source_indices)
        target = "_".join(f"{index:02d}" for index in controller.target_indices)
        return (
            _get_controller_subpath(controller)
            / f"{source}-{target}"
            / f"{controller.blend:.2f}"
        )
    raise NotImplementedError(controller.__class__.__name__)


def construct_output_folder_path(controller: BaseController) -> Path:
    """Return ``RESULTS_DIR/<controller subpath>/<YYYY_MM_DD>`` for today's run."""
    subpath = get_controller_subpath(controller)
    day = datetime.now().strftime("%Y_%m_%d")
    return RESULTS_DIR.joinpath(subpath, day)


def get_tokens(prompts: list[str]):
    """Decode every token id of each encoded prompt back to its text piece.

    Returns one list of decoded strings per prompt, including padding/special tokens that
    the processor adds.
    """
    input_ids = model_proxy.encode(prompts)["input_ids"]

    return [
        [model_proxy.decode(token_id) for token_id in input_ids[row]]
        for row in range(len(prompts))
    ]


def get_replacement_indices(prompts: list[str], word_a: str, word_b: str) -> list[list[int]]:
    """Locate the token positions of ``word_a`` in prompts[0] and ``word_b`` in prompts[1].

    A word matches the first contiguous run of decoded tokens whose concatenation equals
    the word with its whitespace removed (so multi-word phrases work as well).

    Returns:
        ``[indices_of_word_a, indices_of_word_b]``.

    Raises:
        NotImplementedError: if the two prompts have a different number of words.
        ValueError: if a word cannot be reassembled from contiguous tokens of its prompt.
    """
    # Validate before running the tokenizer.
    if len(prompts[0].split()) != len(prompts[1].split()):
        raise NotImplementedError(f"Different prompt lengths ({prompts})")

    token_groups = get_tokens(prompts)

    indices = []
    for word, tokens in zip([word_a, word_b], token_groups):
        word_indices = _find_word_token_indices(word, tokens)
        if not word_indices:
            # Empty indices would later average over nothing and produce NaN attention.
            raise ValueError(f"Could not locate {word!r} in tokens {tokens}")
        indices.append(word_indices)

    return indices


def _find_word_token_indices(word: str, tokens: list[str]) -> list[int]:
    """Return positions of the first contiguous token run that spells ``word``, else []."""
    # Decoded single tokens carry no separating spaces, so compare against a spaceless word.
    target = "".join(word.split())
    for start in range(len(tokens)):
        matched, run = "", []
        for position in range(start, len(tokens)):
            candidate = matched + tokens[position]
            if not target.startswith(candidate):
                break
            matched = candidate
            run.append(position)
            if matched == target:
                return run
    return []


def get_reweight_word_indices(prompts: list[str], word: str) -> list[int]:
    """Return the token positions of ``word`` in the first prompt.

    Delegates to ``get_replacement_indices`` with the same word for both prompts and keeps
    only the first prompt's indices (the annotation previously claimed a tuple).
    """
    source_indices, _ = get_replacement_indices(prompts, word, word)
    return source_indices


def get_refine_word_indices(prompts: list[str]) -> list[list[int]]:
    """Pair up identical decoded tokens of the first two prompts.

    Tokens starting with "<" (presumably special tokens such as ``</s>`` / ``<pad>``) are
    skipped. Every matching pair is kept, so a token repeated in both prompts yields all
    combinations.

    Returns:
        ``[source_indices, target_indices]`` where ``source_indices[k]`` in prompts[0] and
        ``target_indices[k]`` in prompts[1] decode to the same token. Both lists are empty
        when nothing matches, so callers can always unpack two values.
    """
    token_groups = get_tokens(prompts)

    source_indices, target_indices = [], []
    for i, token_i in enumerate(token_groups[0]):
        # If token_i is not special, an equal token_j cannot be special either.
        if token_i.startswith("<"):
            continue
        for j, token_j in enumerate(token_groups[1]):
            if token_i == token_j:
                source_indices.append(i)
                target_indices.append(j)

    return [source_indices, target_indices]


def get_ignore_indices(prompts: list[str]) -> tuple[list[str], list[int]]:
    """Resolve an ``<IGNORE>`` marker into token positions to silence.

    ``prompts[1]`` is ``prompts[0]`` with one word replaced by ``<IGNORE>``. Returns the
    source prompt duplicated (the pair to generate) and the token positions of the ignored
    word within it.

    Raises:
        NotImplementedError: if the prompts have a different number of words.
        ValueError: if ``prompts[1]`` contains no ``<IGNORE>`` word.
    """
    source_words = prompts[0].split()
    marked_words = prompts[1].split()
    if len(source_words) != len(marked_words):
        raise NotImplementedError("Different prompt lengths")

    ignored_word = source_words[marked_words.index("<IGNORE>")]

    word_indices = get_replacement_indices([prompts[0]] * 2, ignored_word, ignored_word)[0]
    return [prompts[0], prompts[0]], word_indices


def run_and_display(prompts, output_folder, controller=None, seed=0, audio_length=10.0):
    """Generate audio for ``prompts`` with ``controller`` hooked into the model's attention.

    The token budget is ``audio_length * model_proxy.frame_rate`` rounded to
    the nearest power of two (in log2 space). When ``seed`` is not None, the
    torch RNG is seeded first. When ``controller`` is None, an
    ``EmptyController`` is used. ``output_folder`` is accepted but not used
    by this function.

    Returns:
        The array produced by ``model_proxy.generate`` for the encoded prompts.
    """
    token_budget = 2 ** round(np.log2(audio_length * model_proxy.frame_rate))

    if seed is not None:
        torch.manual_seed(seed)

    active_controller = EmptyController() if controller is None else controller

    # Configure the controller for this batch before registering its hooks.
    active_controller.reset()
    active_controller.batch_size = len(prompts)
    active_controller.max_new_tokens = token_budget
    register_attention_control(model_proxy, active_controller)

    encoded_prompts = model_proxy.encode(prompts)
    return model_proxy.generate(encoded_prompts, max_new_tokens=token_budget)

### ISMIR Evaluation

We evaluate edits along the following axes:

- Instrument Replacement: In these prompts, one instrument or sound source is replaced with another instrument or sound. For example, replacing "blues ensemble with guitar and drums" with "country ensemble with guitar and drums", or "acoustic guitar solo" with "electric guitar solo".

- Mood/Tonal Change: These prompts involve changing the mood or tonality of the music. For instance, transforming a happy violin solo into a sad violin solo, or converting a major chord pop song into a minor chord pop song.

- Genre Shift: These prompts involve shifting the genre or style of the music. For example, transitioning from a rock riff on electric guitar to a metal riff on electric guitar, or changing a jazz beat with saxophone into a hip-hop beat with saxophone.

- Melodic Transformation: These edits involve altering the melodic content of the music. This can include changes in melodic contour, intervals, motifs, and melodies. For example, transforming a melodic line from ascending to descending or changing the melodic intervals to create a different melodic feel.

- Harmonic Modification: These edits involve modifying the harmonic structure of the music. This can include changes in chord progressions, harmonic rhythm, harmonic density, and harmonic tension. For instance, altering the chord progression from a standard I-IV-V to a more complex progression or introducing chromaticism to the harmony.

- Form/Structure Variation: These edits involve variations in the overall form or structure of the music. This can include changes in sectional arrangement, repetitions, transitions, and developmental processes. For example, restructuring a piece by adding or removing sections, or altering the order of musical events to create a different narrative flow.

We generated a large set of prompt pairs with ChatGPT, using the following text prompts:

#### Replace

```txt
Generate a list of quadruples containing:
1. The edit category,
2. The original word in the prompt,
3. The replacement word,
4. The prompt pair consisting of the source and edited prompt. 

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
(
   "Genre Shift",
   "blues",
   "country",
   ("blues ensemble with guitar and drums", "country ensemble with guitar and drums"),
),

Edit Categories:

- Instrument Replacement: Replace one instrument or sound source with another.
- Mood/Tonal Change: Change the mood or tonality of the music.
- Genre Shift: Shift the genre or style of the music.
- Melodic Transformation: Alter the melodic content of the music.
- Harmonic Modification: Modify the harmonic structure of the music.
- Form/Structure Variation: Vary the overall form or structure of the music.
```

#### Refine

```txt
Generate a list of tuples containing:
1. The edit category,
2. The prompt pair consisting of the source and edited prompt. 

The source prompt should strictly be a substring of the edited prompt. The edited prompt should only add details.

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
("piano melody", "jazz piano melody with improvisation"),

Edit Categories:

- Instrument Enhancement: Enhancing the sound by adding effects or layering additional sounds without replacing the instrument.
- Mood/Tonal Enhancement: Modifying tonality or mood through adjustments like reverb or EQ settings.
- Genre Fusion: Combining elements from different genres while preserving the original essence.
- Melodic Embellishment: Adding ornamentations or variations to enrich the melody's expressiveness.
- Harmonic Enrichment: Enriching the harmonic structure by adding chords or layers for a fuller sound.
- Form/Structure Expansion: Elaborating on the form or structure by adding new sections or transitions for complexity.
```

#### Reweight

```txt
Generate a list of tuples containing:
1. The edit category,
2. The target word
2. The prompt pair consisting of the source and edited prompt. 

The source prompt and the edited prompt should be identical. The target word must be included in both prompts, indicating which aspect of the described audio should be intensified or diminished.

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
("happy", ("happy acoustic guitar solo", "happy acoustic guitar solo"))
Edit Categories:

Instrument Reinforcement: Enhancing the sound of specific instruments or adding additional layers to enrich the overall texture without replacing them entirely.
Mood/Tonal Brightening: Adjusting tonality or mood through changes like increasing brightness or introducing uplifting effects to evoke a happier atmosphere.
Genre Blend: Mixing elements from various genres while maintaining the song's core identity to create a fusion that embodies a happier vibe.
Melodic Flourish: Introducing embellishments or variations in the melody to make it more lively, optimistic, and joyful.
Harmonic Enlivening: Enriching the harmonic structure by adding chords or harmonic layers that convey a sense of positivity and energy.
Form/Structure Expansion: Expanding the song's structure or adding new sections to build anticipation, create contrasts, and enhance the overall uplifting mood.
```

In [None]:
# Replace samples: (edit category, word in the source prompt, word in the
# edited prompt, (source prompt, edited prompt)).
# Both words are passed to get_replacement_indices, which searches for them
# token by token in the respective prompt, so each word must literally occur
# in its prompt. Otherwise no tokens are found and the replacement edit has
# no target.
# NOTE: the evaluation loop skips prompt pairs whose output files are already
# recorded in results.pkl. Results produced with earlier (non-matching) words
# must be removed to be regenerated.
replace_group = [
    (
        "Instrument Replacement",
        "violin",
        "cello",
        ("beautiful violin solo", "beautiful cello solo"),
    ),
    (
        "Mood/Tonal Change",
        "uplifting",
        "melancholic",
        ("uplifting piano melody", "melancholic piano melody"),
    ),
    (
        "Genre Shift",
        "rock",
        "metal",
        ("rock riff on electric guitar", "metal riff on electric guitar"),
    ),
    (
        "Melodic Transformation",
        "ascending",
        "descending",
        ("ascending melodic line", "descending melodic line"),
    ),
    (
        "Harmonic Modification",
        "standard",
        "jazzier",
        ("standard chord progression", "jazzier chord progression"),
    ),
    (
        "Form/Structure Variation",
        "adding",
        "removing",
        (
            "restructuring a piece by adding sections",
            "restructuring a piece by removing sections",
        ),
    ),
    (
        "Instrument Replacement",
        "guitar",
        "piano",
        ("guitar solo", "piano solo"),
    ),
    (
        "Mood/Tonal Change",
        "major",
        "minor",
        ("major chord pop song", "minor chord pop song"),
    ),
    (
        "Genre Shift",
        "jazz",
        "hip-hop",
        ("jazz beat with saxophone", "hip-hop beat with saxophone"),
    ),
    (
        "Melodic Transformation",
        "repeating",
        "varied",
        ("repeating motif", "varied motif"),
    ),
    (
        "Harmonic Modification",
        "standard",
        "chromatic",
        ("standard chord progression", "chromatic chord progression"),
    ),
    (
        "Form/Structure Variation",
        "repeating",
        "transitional",
        ("repeating sections", "transitional sections"),
    ),
    (
        "Instrument Replacement",
        "drum",
        "synthesizer",
        ("drum solo", "synthesizer solo"),
    ),
    (
        "Mood/Tonal Change",
        "dark",
        "ethereal",
        ("dark ambient track", "ethereal ambient track"),
    ),
    (
        "Genre Shift",
        "pop",
        "reggae",
        ("pop song with catchy hooks", "reggae song with catchy hooks"),
    ),
    (
        "Melodic Transformation",
        "intervals",
        "sequences",
        ("melodic intervals", "melodic sequences"),
    ),
    (
        "Harmonic Modification",
        "typical",
        "jazzier",
        ("typical chord progression", "jazzier chord progression"),
    ),
    (
        "Form/Structure Variation",
        "introductory",
        "concluding",
        ("introductory section", "concluding section"),
    ),
    (
        "Instrument Replacement",
        "trumpet",
        "flute",
        ("trumpet solo", "flute solo"),
    ),
    (
        "Mood/Tonal Change",
        "uplifting",
        "haunting",
        ("uplifting guitar melody", "haunting guitar melody"),
    ),
]

In [None]:
# Refine samples: (edit category, (source prompt, edited prompt)). The edited
# prompt keeps the source prompt and only adds detail around it. The entries
# are listed as flat (category, source, edited) triples and regrouped below.
refine_group = [
    (category, (source, edited))
    for category, source, edited in [
        (
            "Instrument Enhancement",
            "piano melody",
            "jazz piano melody with added chorus effect for depth and warmth",
        ),
        (
            "Mood/Tonal Enhancement",
            "guitar riff",
            "ethereal guitar riff with shimmering reverb and atmospheric delay",
        ),
        (
            "Genre Fusion",
            "hip-hop beat",
            "trap-infused hip-hop beat with electronic synth arpeggios and 808 bass",
        ),
        (
            "Melodic Embellishment",
            "vocal line",
            "soulful vocal line with intricate melismatic runs and emotive vibrato",
        ),
        (
            "Harmonic Enrichment",
            "chord progression",
            "lush chord progression with added ninth and eleventh extensions for richness",
        ),
        (
            "Form/Structure Expansion",
            "bridge section",
            "extended bridge section with modulating key centers and layered counterpoint",
        ),
        (
            "Instrument Enhancement",
            "drum groove",
            "dynamic drum groove with layered percussion and enhanced stereo imaging",
        ),
        (
            "Mood/Tonal Enhancement",
            "ambient pad",
            "serene ambient pad with subtle modulated filters and soft side-chain compression",
        ),
        (
            "Genre Fusion",
            "jazz saxophone solo",
            "fusion jazz saxophone solo with electronic glitch effects and syncopated beats",
        ),
        (
            "Melodic Embellishment",
            "flute melody",
            "flute melody with cascading runs and delicate trills for added expressiveness",
        ),
        (
            "Harmonic Enrichment",
            "bassline",
            "deep bassline with walking chromatic lines and extended harmonic sequences",
        ),
        (
            "Form/Structure Expansion",
            "chorus",
            "expanded chorus with layered harmonies and intricate rhythmic variations",
        ),
        (
            "Instrument Enhancement",
            "synth lead",
            "bright synth lead with added modulation effects and stereo widening for depth",
        ),
        (
            "Mood/Tonal Enhancement",
            "piano chords",
            "soothing piano chords with gentle reverb and subtle tape saturation for warmth",
        ),
        (
            "Genre Fusion",
            "reggae rhythm",
            "reggae rhythm with dubstep-inspired bass drops and electronic glitches",
        ),
        (
            "Melodic Embellishment",
            "violin solo",
            "expressive violin solo with emotive slides and delicate pizzicato accents",
        ),
        (
            "Harmonic Enrichment",
            "guitar strumming",
            "dynamic guitar strumming with extended chord voicings and added suspended notes",
        ),
        (
            "Form/Structure Expansion",
            "pre-chorus",
            "extended pre-chorus with building tension and additional instrumental layers",
        ),
        (
            "Instrument Enhancement",
            "drum fill",
            "energetic drum fill with layered percussion and added room reverb for spaciousness",
        ),
        (
            "Mood/Tonal Enhancement",
            "synth pad",
            "dreamy synth pad with evolving filter sweeps and atmospheric delays",
        ),
    ]
]

In [None]:
# Reweight samples: (edit category, target word, (prompt, prompt)). Source and
# edited prompts are identical, so each entry below lists the prompt once and
# it is duplicated into the pair.
reweight_group = [
    (category, word, (prompt, prompt))
    for category, word, prompt in [
        ("Instrument Reinforcement", "drums", "dynamic drums in the chorus"),
        ("Mood/Tonal Brightening", "bright", "bright piano melody"),
        ("Genre Blend", "pop", "pop rock guitar riff"),
        ("Melodic Flourish", "optimistic", "optimistic flute melody"),
        ("Harmonic Enlivening", "major", "uplifting major chord progression"),
        ("Form/Structure Expansion", "chorus", "extended chorus with layered vocals"),
        ("Instrument Reinforcement", "bass", "thumping bassline"),
        ("Mood/Tonal Brightening", "cheerful", "cheerful brass section"),
        ("Genre Blend", "funk", "funk-infused guitar riff"),
        ("Melodic Flourish", "joyful", "joyful synth melody"),
        ("Harmonic Enlivening", "seventh", "vibrant seventh chord progression"),
        (
            "Form/Structure Expansion",
            "bridge",
            "extended bridge section with energetic build-up",
        ),
        ("Instrument Reinforcement", "guitar", "powerful guitar solo"),
        ("Mood/Tonal Brightening", "uplifting", "uplifting strings arrangement"),
        ("Genre Blend", "reggae", "reggae-inspired drum groove"),
        ("Melodic Flourish", "hopeful", "hopeful vocal melody"),
        ("Harmonic Enlivening", "major", "bright major chord progression"),
        (
            "Form/Structure Expansion",
            "pre-chorus",
            "extended pre-chorus with dynamic instrumentation",
        ),
        ("Instrument Reinforcement", "synth", "lush synth pads"),
        ("Mood/Tonal Brightening", "vibrant", "vibrant brass section"),
    ]
]

In [None]:
from typing import Any, Iterable, Iterator


def transform_samples(
    edit: str,
    samples: Iterable[tuple[Any, ...]],
) -> Iterator[dict[str, Any]]:
    """Convert raw sample tuples into records with named fields.

    Each sample has the shape ``(category, *additional, prompts)``:

    - the first element is the edit category;
    - the last element is the ``(source_prompt, edited_prompt)`` pair;
    - anything in between (e.g. the words to replace or reweight) is
      collected into the ``"Additional"`` list.

    Args:
        edit: Edit name. It is title-cased for the ``"Edit"`` field
            (``"replace"`` -> ``"Replace"``).
        samples: Sample tuples as described above.

    Yields:
        One dict per sample with the keys ``"Edit"``, ``"Category"``,
        ``"Source Prompt"``, ``"Edited Prompt"`` and ``"Additional"``.
    """
    for edit_category, *additional, prompts in samples:
        yield {
            "Edit": edit.title(),
            "Category": edit_category,
            "Source Prompt": prompts[0],
            "Edited Prompt": prompts[1],
            "Additional": additional,
        }


class Dataset:
    """Iterable of evaluation jobs built from named groups of prompt samples.

    Each keyword argument maps an edit name to a list of sample tuples in the
    format accepted by ``transform_samples``. That function title-cases the
    name, so ``replace``, ``refine`` and ``reweight`` select the Replace,
    Refine and Reweight controllers respectively. Any other name raises
    ``NotImplementedError`` during iteration.
    """

    # Seeds generated per sample; shared by __iter__ and __len__ so they agree.
    NUM_SEEDS = 5

    def __init__(self, **groups: list[tuple[Any, ...]]) -> None:
        self.groups = groups

    @staticmethod
    def _build_controller(
        edit: str, prompts: list[str], additional: list[Any]
    ) -> "BaseController":
        """Return a new attention controller for one job of the given ``edit`` type.

        Raises:
            NotImplementedError: if ``edit`` is not Replace, Refine or Reweight.
        """
        if edit == "Replace":
            indices = get_replacement_indices(prompts, *additional)
            base_controller = ReplaceWordController(indices, 1)
        elif edit == "Refine":
            indices = get_refine_word_indices(prompts)
            base_controller = RefineController(indices, 1)
        elif edit == "Reweight":
            indices = get_reweight_word_indices(prompts, *additional)
            base_controller = ReweightWordController(indices, 2)
        else:
            raise NotImplementedError(f"Unsupported edit type: {edit!r}")

        return SelfAttentionLerpControllerModifier(base_controller)

    def __iter__(self) -> Iterator[tuple[str, str, str, str, int, "BaseController"]]:
        """Yield ``(edit, category, source_prompt, edited_prompt, seed, controller)``.

        One tuple is produced for every sample and every seed in
        ``range(NUM_SEEDS)``. A separate controller instance is built for each
        tuple.
        """
        for group_name, group in self.groups.items():
            for sample in transform_samples(group_name, group):
                edit = sample["Edit"]
                source_prompt = sample["Source Prompt"]
                edited_prompt = sample["Edited Prompt"]
                prompts = [source_prompt, edited_prompt]

                for seed in range(self.NUM_SEEDS):
                    yield (
                        edit,
                        sample["Category"],
                        source_prompt,
                        edited_prompt,
                        seed,
                        self._build_controller(edit, prompts, sample["Additional"]),
                    )

    def __len__(self) -> int:
        """Return the number of jobs ``__iter__`` yields."""
        return self.NUM_SEEDS * sum(len(group) for group in self.groups.values())

In [None]:
# Keyword names become the edit names ("replace" -> "Replace", ...) via transform_samples.
dataset = Dataset(
    **{
        "replace": replace_group,
        "refine": refine_group,
        "reweight": reweight_group,
    }
)

In [None]:
import warnings

import pandas as pd
from tqdm.auto import tqdm

EVAL_DIR = RESULTS_DIR / "Evaluation"
RESULTS_PATH = EVAL_DIR / "results.pkl"
# Results are pickled to this sibling file first and then moved over
# RESULTS_PATH. An interruption while pickling therefore leaves the
# previously saved results.pkl intact instead of a truncated file.
results_tmp_path = RESULTS_PATH.with_name(RESULTS_PATH.name + ".tmp")

# Load existing results if available so an interrupted run can resume.
df = pd.DataFrame()
if RESULTS_PATH.is_file():
    df = pd.read_pickle(RESULTS_PATH)

results = df.to_dict(orient="records")
# Output files recorded by earlier runs. Jobs with both files listed are skipped.
existing_paths = set(df.get("Source Path", [])).union(set(df.get("Edited Path", [])))

for edit, category, source_prompt, edited_prompt, seed, controller in tqdm(dataset):
    torch.cuda.empty_cache()

    output_folder = EVAL_DIR / edit / f"{source_prompt} - {edited_prompt}" / f"{seed:02d}"

    source_filepath = output_folder / f"00 - {source_prompt}.wav"
    edited_filepath = output_folder / f"01 - {edited_prompt}.wav"

    if source_filepath in existing_paths and edited_filepath in existing_paths:
        continue

    output_folder.mkdir(parents=True, exist_ok=True)

    try:
        audio_values = run_and_display(
            [source_prompt, edited_prompt],
            output_folder,
            controller,
            seed=seed,
        )

        write(source_filepath, rate=model_proxy.sampling_rate, data=audio_values[0])
        write(edited_filepath, rate=model_proxy.sampling_rate, data=audio_values[1])
    except KeyboardInterrupt:
        raise
    except Exception as e:
        # Best effort: report the failing job and move on. The seed is included
        # so failures of the same prompt pair under different seeds give
        # distinct messages (identical warnings from one line may be shown only once).
        warnings.warn(f"{edit} '{source_prompt}' -> '{edited_prompt}' (seed {seed}) [{e}]")
        continue

    results.append(
        {
            "Source Path": source_filepath,
            "Edited Path": edited_filepath,
            "Edit": edit,
            "Category": category,
            "Source Prompt": source_prompt,
            "Edited Prompt": edited_prompt,
        },
    )

    # Persist after every job so progress survives interruptions.
    pd.DataFrame(results).to_pickle(results_tmp_path)
    results_tmp_path.replace(RESULTS_PATH)

In [None]:
df = pd.read_pickle(filepath_or_buffer=RESULTS_PATH)  # reload the saved evaluation results

In [None]:
df.head(n=5)