## Prepare the Environment

In [None]:
import os
from pathlib import Path

# Defaults for a local run: write results next to the notebook and use the
# small (fastest) MusicGen checkpoint.
BASE_DIR = Path.cwd()
MODEL_NAME = "facebook/musicgen-small"

# On the machine that has the dedicated data directory, store results there,
# switch to the large checkpoint and pin the process to the first GPU.
# (This must run before torch is imported to affect CUDA device selection.)
if Path("/", "home", "vsioros", "data").is_dir():
    BASE_DIR = Path("/", "home", "vsioros", "data")
    MODEL_NAME = "facebook/musicgen-large"

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

## Load the Model

The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id (`MODEL_NAME` above) to the checkpoint size you wish to load. We default to the small checkpoint, which is the fastest of the three but has the lowest audio quality; on the machine that has the `/home/vsioros/data` directory, the large checkpoint is selected instead:

In [None]:
from transformers import MusicgenForConditionalGeneration
from transformers import AutoProcessor
import torch
from typing import Any
from numpy.typing import NDArray
import numpy as np


class ModelProxy(object):
    """Bundle a MusicGen checkpoint, its processor and the device it runs on.

    Args:
        model_name: Hugging Face Hub repo id, e.g. ``"facebook/musicgen-small"``.
        guidance_scale: classifier-free guidance scale forwarded to ``generate``.
    """

    def __init__(self, model_name: str, guidance_scale: float = 3.0):
        self.model_name = model_name
        self.guidance_scale = guidance_scale

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(self.device)

    # NOTE: annotated with `np.floating` rather than `np.float_`; the latter was
    # removed in NumPy 2.0 and would make this class fail at definition time.
    def generate(self, inputs: dict[str, Any], max_new_tokens: int = 512) -> NDArray[np.floating]:
        """Sample audio for already-encoded prompts.

        Args:
            inputs: the processor output returned by ``encode``; it must support
                ``.to(device)`` and unpack into keyword arguments.
            max_new_tokens: number of decoder tokens to sample.

        Returns:
            The generated waveforms as a NumPy array with every size-1 axis
            squeezed away (so a single prompt also loses its batch axis).
        """
        audio = self.model.generate(
            **inputs.to(self.device),
            do_sample=True,
            guidance_scale=self.guidance_scale,
            max_new_tokens=max_new_tokens,
        )
        return audio.cpu().numpy().squeeze()

    def encode(self, prompts: list[str]) -> dict[str, Any]:
        """Tokenize ``prompts`` (padded to a common length) into PyTorch tensors."""
        return self.processor(text=prompts, padding=True, return_tensors="pt")

    def decode(self, encoded_token: Any) -> str:
        """Decode token id(s) back into text with the model's processor."""
        return self.processor.decode(encoded_token)

    @property
    def sampling_rate(self) -> float:
        """Audio sampling rate of the model's audio encoder."""
        return self.model.config.audio_encoder.sampling_rate

    @property
    def frame_rate(self) -> float:
        """Number of decoder tokens per second of generated audio."""
        return self.model.config.audio_encoder.frame_rate

    @property
    def decoder_layers(self) -> list[torch.nn.Module]:
        """The transformer layers of the MusicGen decoder (patched by the attention hooks)."""
        return self.model.decoder.model.decoder.layers

In [None]:
# Load the checkpoint selected in the environment cell; this global is used by
# the prompt-to-prompt helpers and evaluation functions below.
model_proxy = ModelProxy(MODEL_NAME)

## Prompt-to-Prompt

In [None]:
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.utils import (
    logging,
)

# Module-level transformers logger.
# NOTE(review): not referenced anywhere in the visible code; possibly a leftover
# from the copied transformers attention implementation.
logger = logging.get_logger(__name__)


def register_attention_control(model_proxy, controller):
    """Monkey-patch MusicGen's decoder attention modules to route through ``controller``.

    For every ``MusicgenDecoderLayer`` child of ``model_proxy.decoder_layers``,
    its ``self_attn`` and ``encoder_attn`` children (when they are
    ``MusicgenAttention`` modules) get their ``forward`` replaced by a
    re-implementation that calls
    ``controller(attn_weights, is_cross_attention, attention_type)`` right after
    the softmax and uses the returned tensor for the rest of the layer.
    Finally ``controller.num_att_layers`` is set to the number of patched modules.

    Calling this again re-assigns ``forward`` on the same modules, so the most
    recently registered ``controller`` is the one that gets called.

    NOTE(review): the replacement ``forward`` mirrors one specific version of
    transformers' ``MusicgenAttention`` (its arguments and 3-tuple return);
    newer library versions may call attention with a different signature —
    confirm against the installed version.
    """

    def ca_forward(self, attention_type):
        """Build a replacement ``forward`` closed over attention module ``self``.

        ``attention_type`` ("self" or "cross") is forwarded to the controller.
        """

        def forward(
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            """Input shape: Batch x Time x Channel

            Returns ``(attn_output, attn_weights_reshaped, past_key_value)``; the
            attention probabilities pass through ``controller`` after the softmax.
            """
            # if key_value_states are provided this layer is used as a cross-attention layer
            # for the decoder
            is_cross_attention = key_value_states is not None

            bsz, tgt_len, _ = hidden_states.size()

            # get query proj
            query_states = self.q_proj(hidden_states) * self.scaling
            # get key, value proj
            # `past_key_value[0].shape[2] == key_value_states.shape[1]`
            # is checking that the `sequence_length` of the `past_key_value` is the same as
            # the provided `key_value_states` to support prefix tuning
            if (
                is_cross_attention
                and past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]
            ):
                # reuse k,v, cross_attentions
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            elif is_cross_attention:
                # cross_attentions
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            elif past_key_value is not None:
                # reuse k, v, self_attention
                # (the key/value length grows by one step per call because of this concat)
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                # self_attention
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

            if self.is_decoder:
                # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
                # Further calls to cross_attention layer can then reuse all cross-attention
                # key/value_states (first "if" case)
                # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
                # all previous decoder key/value_states. Further calls to uni-directional self-attention
                # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
                # if encoder bi-directional self-attention `past_key_value` is always `None`
                past_key_value = (key_states, value_states)

            # Flatten (batch, heads) into one leading axis for batched matmuls.
            proj_shape = (bsz * self.num_heads, -1, self.head_dim)
            query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
            key_states = key_states.reshape(*proj_shape)
            value_states = value_states.reshape(*proj_shape)

            src_len = key_states.size(1)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                    f" {attn_weights.size()}",
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}",
                    )
                attn_weights = (
                    attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                )
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            # Prompt-to-prompt hook: the controller sees (and may rewrite) the
            # post-softmax probabilities of shape (bsz * num_heads, tgt_len, src_len);
            # whatever it returns is used for the rest of this layer.
            attn_weights = controller(attn_weights, is_cross_attention, attention_type)

            if layer_head_mask is not None:
                if layer_head_mask.size() != (self.num_heads,):
                    raise ValueError(
                        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                        f" {layer_head_mask.size()}",
                    )
                attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                    bsz, self.num_heads, tgt_len, src_len
                )
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            if output_attentions:
                # this operation is a bit awkward, but it's required to
                # make sure that attn_weights keeps its gradient.
                # In order to do so, attn_weights have to be reshaped
                # twice and have to be reused in the following
                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
            else:
                attn_weights_reshaped = None

            attn_probs = nn.functional.dropout(
                attn_weights, p=self.dropout, training=self.training
            )

            attn_output = torch.bmm(attn_probs, value_states)

            if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                    f" {attn_output.size()}",
                )

            attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
            attn_output = attn_output.transpose(1, 2)

            # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
            # partitioned across GPUs when using tensor-parallelism.
            attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

            attn_output = self.out_proj(attn_output)

            return attn_output, attn_weights_reshaped, past_key_value

        return forward

    def register_recr(net_, count, attention_type):
        """Patch ``net_`` if it is a ``MusicgenAttention`` and return the updated count.

        Despite the name, this does not recurse into ``net_``'s children.
        """
        if net_.__class__.__name__ == "MusicgenAttention":
            net_.forward = ca_forward(net_, attention_type)
            return count + 1

        return count

    cross_att_count = 0
    sub_nets = model_proxy.decoder_layers.named_children()
    for net in sub_nets:
        if net[1].__class__.__name__ != "MusicgenDecoderLayer":
            continue

        # Children are visited in registration order, so within each decoder layer
        # the patched modules are called in that same order during a forward pass.
        for subnet in net[1].named_children():
            attention_type = None
            if subnet[0] == "encoder_attn":
                attention_type = "cross"
            elif subnet[0] == "self_attn":
                attention_type = "self"

            if attention_type is not None:
                cross_att_count += register_recr(subnet[1], 0, attention_type)

    # Total number of patched modules (self + cross); the controller uses it to
    # detect the end of one decoding step.
    controller.num_att_layers = cross_att_count

In [None]:
from abc import ABC, abstractmethod
from collections import defaultdict


class BaseController(ABC):
    """Base class for callables that intercept MusicGen attention probabilities.

    An instance is called once per patched attention layer per forward pass. It
    keeps a 1-based layer counter (``cur_att_layer``, running from 1 to
    ``num_att_layers`` within a step) and a decoding-step counter (``cur_step``).
    It keeps only the first half of the flattened batch axis (the conditional
    samples; the second half is assumed to be the classifier-free-guidance
    unconditional copy) and lets subclasses rewrite it via
    ``replace_self_attention`` / ``replace_cross_attention``.
    """

    def reset(self):
        """Reset per-run state; callers then set batch_size, max_new_tokens and num_att_layers."""
        self.num_att_layers = 0
        self.batch_size = -1
        self.max_new_tokens = -1

        self.cur_att_layer = 0
        self.cur_step = 0

    def __call__(self, attn_weights, is_cross, attention_type) -> torch.Tensor:
        """Apply the controller to one layer's attention probabilities.

        Args:
            attn_weights: post-softmax weights of shape
                ``(bsz * num_heads, tgt_len, src_len)``; modified in place.
            is_cross: whether the calling layer is a cross-attention layer.
            attention_type: ``"self"`` or ``"cross"`` label of the calling module.

        Returns:
            ``attn_weights`` with its first half replaced by the subclass output.
        """
        # Wrap *before* incrementing so that, within a step, the layer index runs
        # 1..num_att_layers. The last layer then sees cur_att_layer ==
        # num_att_layers, which is what layer-index filters (decoder-layer sets,
        # cutoffs, lerps, importance rankings) expect.
        if 0 < self.num_att_layers <= self.cur_att_layer:
            self.cur_att_layer = 0
            self.cur_step = self.cur_step + 1
        self.cur_att_layer = self.cur_att_layer + 1

        # Exclude unconditional inputs (assumed to be the second half of the batch).
        half = attn_weights.shape[0] // 2
        attn = attn_weights[:half]

        # Split the leading axis into (batch_size, per_sample); per_sample is
        # presumably the number of heads.
        per_sample = attn.shape[0] // self.batch_size
        attn = attn.reshape(self.batch_size, per_sample, *attn.shape[1:])

        if is_cross:
            attn = self.replace_cross_attention(attn, is_cross, attention_type)
        else:
            attn = self.replace_self_attention(attn)

        attn_weights[:half] = attn.reshape(self.batch_size * per_sample, *attn.shape[2:])

        return attn_weights

    @abstractmethod
    def replace_self_attention(self, attn) -> torch.Tensor:
        """Return the (possibly edited) self-attention of shape ``(batch_size, heads, tgt, src)``."""
        raise NotImplementedError

    @abstractmethod
    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> torch.Tensor:
        """Return the (possibly edited) cross-attention of shape ``(batch_size, heads, tgt, src)``."""
        raise NotImplementedError


class EmptyController(BaseController):
    """No-op controller: every attention map is handed back unchanged."""

    def replace_self_attention(self, attn):
        unchanged = attn
        return unchanged

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        unchanged = attn_weights
        return unchanged


class AttentionStore(BaseController):
    """Controller that records attention maps without modifying them.

    Every call appends the conditional-half attention of the current layer to
    ``self.features["self"]`` or ``self.features["cross"]``. The ``get_*``
    helpers regroup the recordings per decoding step and rank layers by how much
    attention the edited samples (all samples except index 0) receive.
    """

    def reset(self):
        super().reset()

        self.features = defaultdict(list)

    def get_self_attention(self) -> torch.Tensor:
        """Stack recorded self-attention maps, averaged over their last (key) axis.

        The key length of self-attention grows with the KV cache, so maps from
        different steps only stack after that axis is averaged away.

        Returns:
            Tensor of shape ``(max_new_tokens, layers_per_step, *map_shape[:-1])``.
            Assumes exactly ``max_new_tokens`` steps were recorded.
        """
        tensors = [tensor.mean(dim=-1) for tensor in self.features["self"]]
        tensors = torch.stack(tensors)
        return tensors.view(self.max_new_tokens, -1, *tensors.shape[1:])

    def get_cross_attention(self) -> torch.Tensor:
        """Stack recorded cross-attention maps into ``(max_new_tokens, layers_per_step, *map_shape)``.

        Assumes exactly ``max_new_tokens`` steps were recorded.
        """
        tensors = torch.stack(self.features["cross"])
        return tensors.view(self.max_new_tokens, -1, *tensors[0].shape)

    @staticmethod
    def _rank_layers(scores: torch.Tensor) -> torch.Tensor:
        """Rank layers per sample, most important first.

        Args:
            scores: ``(layers, samples)`` mean attention scores.

        Returns:
            ``(samples, layers)`` tensor of 0-based layer positions.
        """
        # Min-Max scaling to normalize values between 0 and 1 for each column (sample)
        min_values = scores.min(dim=0).values
        max_values = scores.max(dim=0).values
        value_range = max_values - min_values
        # A constant column would otherwise divide by zero and yield NaNs.
        value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
        normalized_scores = (scores - min_values) / value_range

        sorted_indices = torch.argsort(normalized_scores, descending=True, dim=0)
        # Transpose (not `view`) so that each row really is one sample's ranking;
        # a reshape would interleave samples whenever there is more than one.
        return sorted_indices.transpose(0, 1)

    def get_self_attention_importance(self) -> torch.Tensor:
        """Rank self-attention layers per edited sample by mean attention."""
        scores = self.get_self_attention()
        # Drop sample 0 (the unedited source), average over steps, heads and queries.
        scores = scores[:, :, 1:, :, :].mean(dim=(0, 3, 4))

        # self-attention layers are called first and thusly hold indices 1, 3, 5 etc.
        return 2 * self._rank_layers(scores) + 1

    def get_cross_attention_importance(self, word_piece_index) -> torch.Tensor:
        """Rank cross-attention layers per edited sample by attention on one text token.

        Args:
            word_piece_index: position of the text token (an int is expected).
        """
        scores = self.get_cross_attention()
        # Drop sample 0, select the token, average over steps, heads and queries.
        scores = scores[:, :, 1:, :, :, word_piece_index].mean(dim=(0, 3, 4))

        # cross-attention layers are called second and thusly hold indices 2, 4, 6 etc.
        return 2 * (self._rank_layers(scores) + 1)

    def get_aggregate_cross_attention(self) -> torch.Tensor:
        """Mean of all recorded cross-attention maps."""
        return torch.mean(torch.stack(self.features["cross"]), dim=0)

    def replace_self_attention(self, attn) -> torch.Tensor:
        self.features["self"].append(attn)

        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type) -> torch.Tensor:
        self.features["cross"].append(attn_weights)

        return attn_weights

In [None]:
class BaseEditController(BaseController):
    """Edit controller that pins every sample's self-attention to sample 0 (the source).

    Subclasses only need to define how cross-attention is edited.
    """

    def replace_self_attention(self, attn):
        source = attn[0]
        # Broadcast view of the source map, one copy per sample in the batch.
        return source.unsqueeze(0).expand(attn.shape[0], *source.shape)


class RandomController(BaseEditController):
    """Baseline: overwrite the edited samples' cross-attention with standard-normal noise."""

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        noise = torch.randn_like(attn_weights[0])
        # Same noise tensor is broadcast to every edited sample.
        attn_weights[1:] = noise

        return attn_weights


class IgnoreWordController(BaseEditController):
    """Zero the edited samples' cross-attention on the given text-token positions.

    Args:
        indices: text-token positions whose attention is removed.
    """

    def __init__(self, indices: list[int]):
        super().__init__()

        self.indices = indices

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        edited = attn_weights[1:]  # view of every sample except the source
        edited[:, :, :, self.indices] = 0

        return attn_weights


class ReplaceWordController(BaseEditController):
    """Blend the source word's attention into the replacement word's tokens.

    Args:
        indices: ``[source_token_positions, target_token_positions]``.
        blend: weight of the source attention in the blend (0 keeps the edit as is).
    """

    def __init__(self, indices: list[list[int]], blend: float = 0.5):
        super().__init__()

        self.source_indices, self.target_indices = indices
        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        # Average the source sample's attention over the source word's tokens,
        # then broadcast that average across every target token.
        source_mean = attn_weights[0, :, :, self.source_indices].mean(dim=-1, keepdim=True)
        source_mean = source_mean.expand(-1, -1, len(self.target_indices))

        edited = attn_weights[1:, :, :, self.target_indices]
        attn_weights[1:, :, :, self.target_indices] = (
            (1 - self.blend) * edited + self.blend * source_mean
        )

        return attn_weights


class RefineController(BaseEditController):
    """Blend source attention into the edited samples for tokens shared by both prompts.

    Args:
        indices: ``[source_positions, target_positions]`` of equal length; position
            ``k`` of the source is blended into position ``k`` of the target.
        blend: weight of the source attention in the blend.
    """

    def __init__(self, indices: list[list[int]], blend: float = 0.5):
        super().__init__()

        self.source_indices, self.target_indices = indices
        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        edited = attn_weights[1:, :, :, self.target_indices]
        source = attn_weights[0, :, :, self.source_indices]
        attn_weights[1:, :, :, self.target_indices] = (1 - self.blend) * edited + self.blend * source

        return attn_weights


class ReweightWordController(BaseEditController):
    """Emphasise some text tokens by damping attention on all other tokens.

    The edited samples first copy the source sample's cross-attention; then the
    attention on every token *not* in ``indices`` is divided by ``weight``.
    """

    def __init__(self, indices: list[int], weight: float = 5):
        super().__init__()

        self.indices = indices
        self.weight = weight

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        attn_weights[1:] = attn_weights[0]

        emphasised = set(self.indices)
        damped = [pos for pos in range(attn_weights.shape[-1]) if pos not in emphasised]
        attn_weights[1:, :, :, damped] /= self.weight

        return attn_weights


class ReplaceController(BaseEditController):
    """Linearly blend the whole source cross-attention map into every edited sample.

    Args:
        blend: weight of the source map (0 keeps the edited attention unchanged).
    """

    def __init__(self, blend: float = 0.5):
        super().__init__()

        self.blend = blend

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        edited, source = attn_weights[1:], attn_weights[0]
        attn_weights[1:] = (1 - self.blend) * edited + self.blend * source

        return attn_weights

In [None]:
from typing import Any


class ControllerModifier(BaseController):
    """Decorator-style controller that wraps another controller.

    Attribute reads not found on the modifier, and *all* attribute writes except
    ``controller`` itself, are forwarded to the wrapped controller. The step and
    layer counters are therefore shared with it.

    NOTE(review): because writes are forwarded, a modifier's own parameters
    (e.g. ``offset``) are stored on the innermost controller, so two nested
    modifiers of the same type would overwrite each other's parameters.
    """

    def __init__(self, controller: BaseController) -> None:
        super().__init__()

        self.controller = controller

    def reset(self):
        """Reset the wrapped controller, including any subclass-specific state.

        Delegating to ``self.controller.reset()`` (rather than only resetting the
        shared counters) ensures e.g. a wrapped ``AttentionStore`` also clears its
        recorded features between runs.
        """
        self.controller.reset()

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails. If `controller` itself is missing
        # (e.g. on a half-constructed copy), fail cleanly instead of recursing.
        if name == "controller":
            raise AttributeError(name)

        return getattr(self.controller, name)

    def __setattr__(self, name: str, value: Any):
        if name == "controller":
            return super().__setattr__(name, value)

        return setattr(self.controller, name, value)


class OffsetControllerModifier(ControllerModifier):
    """Delay the wrapped controller until a fraction of the decoding steps has passed.

    While ``cur_step < round(offset * max_new_tokens)`` attention passes through
    unchanged; afterwards the wrapped controller is applied.

    Args:
        controller: controller to delay.
        offset: fraction of decoding steps to skip, in ``[0, 1]``.

    Raises:
        ValueError: if ``offset`` is outside ``[0, 1]``.
    """

    def __init__(self, controller: BaseController, offset: float = 0.0) -> None:
        super().__init__(controller)

        # Explicit check instead of `assert`, so validation survives `python -O`.
        if not 0.0 <= offset <= 1.0:
            raise ValueError(f"offset must be within [0, 1], got {offset}")

        self.offset = offset

    def replace_self_attention(self, attn):
        if self.cur_step < round(self.offset * self.max_new_tokens):
            return attn

        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        if self.cur_step < round(self.offset * self.max_new_tokens):
            return attn_weights

        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)


class AttentionHeadControllerModifier(ControllerModifier):
    """Restrict the wrapped controller's cross-attention edits to selected heads.

    Self-attention is forwarded to the wrapped controller unchanged.
    """

    def __init__(self, controller: BaseController, attention_head_indices: list[int]) -> None:
        super().__init__(controller)

        self.attention_head_indices = attention_head_indices

    def replace_self_attention(self, attn):
        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        heads = self.attention_head_indices
        # With a list of heads this is advanced indexing (a copy): edit the copy,
        # then write the edited heads back into the full tensor.
        edited_heads = self.controller.replace_cross_attention(
            attn_weights[:, heads, :, :], is_cross, attention_type
        )
        attn_weights[:, heads, :, :] = edited_heads

        return attn_weights


class SelfAttentionLerpControllerModifier(ControllerModifier):
    """Softly pin self-attention to the source, more strongly in later layers.

    For self-attention, the edited samples are lerped towards sample 0 with weight
    ``cur_att_layer / num_att_layers`` (the wrapped controller's self-attention
    edit is not used). Cross-attention is delegated to the wrapped controller.
    """

    def replace_self_attention(self, attn):
        depth = self.cur_att_layer / self.num_att_layers

        attn[1:] = (1 - depth) * attn[1:] + depth * attn[0]

        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)


class SelfAttentionCutoffControllerModifier(ControllerModifier):
    """Apply the wrapped self-attention edit only to the first layers of each step.

    Self-attention is edited while ``cur_att_layer <= floor(threshold * num_att_layers)``
    and passed through unchanged afterwards; cross-attention is always delegated.

    Raises:
        ValueError: if ``threshold`` is outside ``[0, 1]``.
    """

    def __init__(self, controller: BaseController, threshold: float = 0.75) -> None:
        super().__init__(controller)

        # Explicit check instead of `assert`, so validation survives `python -O`.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be within [0, 1], got {threshold}")

        self.threshold = threshold

    def replace_self_attention(self, attn):
        if self.cur_att_layer <= np.floor(self.threshold * self.num_att_layers):
            return self.controller.replace_self_attention(attn)

        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)


class AttentionLerpControllerModifier(ControllerModifier):
    """Fade the wrapped controller's edits out with depth.

    Both attention types are mixed as
    ``depth * original + (1 - depth) * edited`` with
    ``depth = cur_att_layer / num_att_layers``.
    """

    def replace_self_attention(self, attn):
        depth = self.cur_att_layer / self.num_att_layers

        # Scale the original first, before the wrapped controller may touch `attn`.
        kept = depth * attn
        edited = self.controller.replace_self_attention(attn)

        return kept + (1 - depth) * edited

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        depth = self.cur_att_layer / self.num_att_layers

        kept = depth * attn_weights
        # Pass a clone: cross-attention controllers edit their input in place.
        edited = self.controller.replace_cross_attention(
            attn_weights.clone(), is_cross, attention_type
        )

        return kept + (1 - depth) * edited


class AttentionCutoffControllerModifier(ControllerModifier):
    """Apply the wrapped controller only to the first layers of each step.

    Both self- and cross-attention are edited while
    ``cur_att_layer <= floor(threshold * num_att_layers)`` and passed through
    unchanged afterwards.

    Raises:
        ValueError: if ``threshold`` is outside ``[0, 1]``.
    """

    def __init__(self, controller: BaseController, threshold: float = 0.75) -> None:
        super().__init__(controller)

        # Explicit check instead of `assert`, so validation survives `python -O`.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be within [0, 1], got {threshold}")

        self.threshold = threshold

    def replace_self_attention(self, attn):
        if self.cur_att_layer <= np.floor(self.threshold * self.num_att_layers):
            return self.controller.replace_self_attention(attn)

        return attn

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        if self.cur_att_layer <= np.floor(self.threshold * self.num_att_layers):
            return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)

        return attn_weights


class DecoderLayerControllerModifier(ControllerModifier):
    """Apply the wrapped controller only on the listed attention-call indices.

    Args:
        decoder_layer_indices: values of ``cur_att_layer`` at which the wrapped
            controller runs; every other layer is passed through unchanged.
    """

    def __init__(self, controller: BaseController, decoder_layer_indices: set[int]) -> None:
        super().__init__(controller)

        self.decoder_layer_indices = decoder_layer_indices

    def replace_self_attention(self, attn):
        if self.cur_att_layer not in self.decoder_layer_indices:
            return attn

        return self.controller.replace_self_attention(attn)

    def replace_cross_attention(self, attn_weights, is_cross, attention_type):
        if self.cur_att_layer not in self.decoder_layer_indices:
            return attn_weights

        return self.controller.replace_cross_attention(attn_weights, is_cross, attention_type)

In [None]:
from datetime import datetime
from scipy.io.wavfile import write
import pandas as pd
import re
import numpy as np

# Strips the "Controller..." suffix from class names, e.g. "ReplaceWordController" -> "ReplaceWord".
CONTROLLER_REG = re.compile(r"Controller.*")
# Zero-width match before every capital letter except the first (CamelCase split points).
SNAKE_CASE_REG = re.compile(r"(?<!^)(?=[A-Z])")

# Root folder for generated audio; created eagerly when this cell runs.
RESULTS_DIR = BASE_DIR.joinpath("results", "musicgen")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)


def _get_controller_subpath(controller: BaseController) -> Path:
    """Return the snake_case folder name derived from ``controller``'s class name.

    E.g. ``ReplaceWordController`` -> ``replace_word``.
    """
    stem = CONTROLLER_REG.sub("", controller.__class__.__name__)
    return Path(SNAKE_CASE_REG.sub("_", stem).lower())


def get_controller_subpath(controller: BaseController) -> Path:
    """Build the relative results sub-directory that identifies ``controller``.

    Modifiers contribute their own segment (plus parameters, where they have any)
    followed by the sub-path of the controller they wrap, so nested
    configurations map to nested directories.

    Raises:
        NotImplementedError: for controller types without a naming scheme.
    """

    def format_indices(indices) -> str:
        # Zero-padded, concatenated indices, e.g. [3, 12] -> "0312".
        return "".join(f"{index:02d}" for index in indices)

    if isinstance(controller, DecoderLayerControllerModifier):
        return (
            _get_controller_subpath(controller)
            # Sorted so that the folder name does not depend on set iteration order.
            / format_indices(sorted(controller.decoder_layer_indices))
            / get_controller_subpath(controller.controller)
        )
    if isinstance(controller, ControllerModifier):
        # Any other modifier (e.g. the soft-blending lerp used by Dataset).
        return _get_controller_subpath(controller) / get_controller_subpath(controller.controller)
    if isinstance(controller, ReplaceController):
        return _get_controller_subpath(controller) / f"{controller.blend:.2f}"
    if isinstance(controller, ReplaceWordController):
        # ReplaceWordController has source/target token indices, not a single index.
        return (
            _get_controller_subpath(controller)
            / f"{format_indices(controller.source_indices)}_{format_indices(controller.target_indices)}"
            / f"{controller.blend:.2f}"
        )
    if isinstance(controller, IgnoreWordController):
        return _get_controller_subpath(controller) / format_indices(controller.indices)

    raise NotImplementedError(controller.__class__.__name__)


def construct_output_folder_path(controller: BaseController) -> Path:
    """Return ``RESULTS_DIR/<controller sub-path>/<YYYY_MM_DD>`` for today's date."""
    controller_part = get_controller_subpath(controller)
    today = datetime.now().strftime("%Y_%m_%d")
    return RESULTS_DIR.joinpath(controller_part, today)


def get_tokens(prompts: list[str]):
    """Tokenize ``prompts`` and decode every token id individually.

    Returns:
        One list of decoded token strings per prompt (including padding/special tokens).
    """
    input_ids = model_proxy.encode(prompts)["input_ids"]

    return [
        [model_proxy.decode(token_id) for token_id in input_ids[row]]
        for row in range(len(prompts))
    ]


def get_replacement_indices(prompts: list[str], word_a: str, word_b: str) -> list[list[int]]:
    """Locate ``word_a`` in ``prompts[0]`` and ``word_b`` in ``prompts[1]`` at token level.

    Each prompt is tokenized and every token decoded individually; a word is
    matched to the first contiguous run of tokens whose decoded text
    concatenates to exactly that word.

    Returns:
        ``[positions_in_prompt_0, positions_in_prompt_1]``.

    Raises:
        NotImplementedError: if the two prompts have a different number of words.
    """
    token_groups = get_tokens(prompts)

    if len(prompts[0].split()) != len(prompts[1].split()):
        raise NotImplementedError(f"Different prompt lengths ({prompts})")

    return [
        _find_word_token_indices(word, tokens)
        for word, tokens in zip([word_a, word_b], token_groups)
    ]


def _find_word_token_indices(word: str, tokens: list[str]) -> list[int]:
    """Return the positions of the first contiguous token run that spells ``word``.

    Falls back to a greedy prefix scan over all tokens when no exact run exists.
    """
    for start, token in enumerate(tokens):
        # An empty decoded token is a prefix of every word; never start a run on one.
        if not token or not word.startswith(token):
            continue

        spelled, run = "", []
        for position in range(start, len(tokens)):
            candidate = spelled + tokens[position]
            if not word.startswith(candidate):
                break
            spelled = candidate
            run.append(position)
            if spelled == word:
                return run

    # Fallback: greedy prefix accumulation (may be non-contiguous).
    indices, substring = [], []
    for position, token in enumerate(tokens):
        if word.startswith("".join([*substring, token])):
            substring.append(token)
            indices.append(position)
    return indices


def get_reweight_word_indices(prompts: list[str], word: str) -> list[int]:
    """Return the token positions of ``word`` in ``prompts[0]``.

    Raises:
        NotImplementedError: if the two prompts have a different number of words.
    """
    return get_replacement_indices(prompts, word, word)[0]


def get_refine_word_indices(prompts: list[str]) -> list[tuple[int, ...]]:
    """Pair up identical tokens between ``prompts[0]`` and ``prompts[1]``.

    Special tokens (decoded text starting with ``"<"``) are skipped. Each token of
    the first prompt is paired with the earliest still-unmatched equal token of
    the second prompt. Repeated words are therefore aligned occurrence by
    occurrence instead of every occurrence being paired with every other one,
    which would make several sources write the same target position.

    Returns:
        ``[source_positions, target_positions]`` as two equally long tuples, or an
        empty list when the prompts share no token.
    """
    token_groups = get_tokens(prompts)
    tokens_a, tokens_b = token_groups[0], token_groups[1]

    matched_b = set()
    indices = []
    for i, token_i in enumerate(tokens_a):
        if token_i.startswith("<"):
            continue

        for j, token_j in enumerate(tokens_b):
            if j in matched_b or token_j.startswith("<"):
                continue

            if token_i == token_j:
                indices.append((i, j))
                matched_b.add(j)
                break

    return list(zip(*indices))


def get_ignore_indices(prompts: list[str]) -> tuple[list[str], list[int]]:
    """Resolve an ignore edit into a duplicated source prompt and the ignored word's tokens.

    ``prompts[1]`` marks the word to ignore with ``<IGNORE>`` at the same word
    position as in ``prompts[0]``.

    Returns:
        ``([source, source], token_positions_of_ignored_word)``.

    Raises:
        NotImplementedError: if the prompts have a different number of words.
        ValueError: if ``prompts[1]`` contains no ``<IGNORE>`` marker.
    """
    source = prompts[0]
    source_words, marked_words = source.split(), prompts[1].split()
    if len(source_words) != len(marked_words):
        raise NotImplementedError("Different prompt lengths")

    ignored_word = source_words[marked_words.index("<IGNORE>")]

    token_positions = get_replacement_indices([source, source], ignored_word, ignored_word)[0]
    return [source, source], token_positions


def run_and_display(prompts, controller=None, seed=0, audio_length=10.0, save_results=False):
    """Generate audio for ``prompts`` with ``controller`` hooked into MusicGen's attention.

    Args:
        prompts: text prompts; index 0 is the source, the rest are edits.
        controller: attention controller; defaults to a no-op ``EmptyController``.
        seed: torch RNG seed, or ``None`` to leave the RNG untouched.
        audio_length: requested duration in seconds; the number of decoder steps
            is rounded to the nearest power of two.
        save_results: also write one WAV per prompt under the results folder.

    Returns:
        The generated waveforms as returned by ``model_proxy.generate``.
    """
    max_new_tokens = 2 ** round(np.log2(audio_length * model_proxy.frame_rate))

    if seed is not None:
        torch.manual_seed(seed)

    if controller is None:
        controller = EmptyController()

    controller.reset()
    controller.batch_size = len(prompts)
    controller.max_new_tokens = max_new_tokens

    # Re-patch the attention layers so they call *this* controller.
    register_attention_control(model_proxy, controller)

    inputs = model_proxy.encode(prompts)
    audio_values = model_proxy.generate(inputs, max_new_tokens=max_new_tokens)

    if save_results:
        _save_audio(prompts, audio_values, construct_output_folder_path(controller))

    return audio_values


def _save_audio(prompts, audio_values, output_folder: Path) -> None:
    """Write one WAV file per prompt into ``output_folder``."""
    output_folder.mkdir(parents=True, exist_ok=True)

    # `ModelProxy.generate` squeezes its output, so a single prompt yields an
    # array without a batch axis; restore it so zip pairs prompts with samples.
    if len(prompts) == 1:
        audio_values = audio_values[np.newaxis]

    occurrences = {}
    for prompt, audio_value in zip(prompts, audio_values):
        # A "/" in the prompt would otherwise be treated as a directory separator.
        stem = prompt.replace("/", "_")
        count = occurrences.get(stem, 0)
        occurrences[stem] = count + 1
        # Identical prompts (e.g. ignore edits reuse the source prompt) must not
        # overwrite each other's files.
        if count:
            stem = f"{stem} ({count})"

        write(output_folder / f"{stem}.wav", rate=model_proxy.sampling_rate, data=audio_value)

### Comparing self-attention layers

In [None]:
from transformers import ClapModel

# CLAP text/audio embedding model used for the similarity metrics below.
_CLAP_CHECKPOINT = "laion/clap-htsat-unfused"

clap_model = ClapModel.from_pretrained(_CLAP_CHECKPOINT)
clap_processor = AutoProcessor.from_pretrained(_CLAP_CHECKPOINT)

In [None]:
import librosa
import numpy as np
import torch.nn.functional as F


def cosine_similarity(prompt, audios, sr=48000):
    """CLAP similarities between the source audio, the edited audio and the prompt.

    Args:
        prompt: text prompt whose first embedding is compared with ``audios[1]``.
        audios: at least two waveforms at ``model_proxy.sampling_rate``; index 0 is
            the source, index 1 the edited sample.
        sr: rate the audio is resampled to for CLAP (the 48 kHz default presumably
            matches the CLAP feature extractor — confirm for other checkpoints).

    Returns:
        ``(audio_audio_similarity, text_audio_similarity)`` as Python floats.
    """
    # Resample audios
    audios = np.stack(
        [
            librosa.resample(audio, orig_sr=model_proxy.sampling_rate, target_sr=sr)
            for audio in audios
        ]
    )

    inputs = clap_processor(
        text=prompt, audios=audios, return_tensors="pt", sampling_rate=sr, padding=True
    )

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        prompt_features = clap_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        # The text attention mask does not describe the audio input; pass the
        # audio-specific `is_longer` flag instead when the processor returns it.
        audio_features = clap_model.get_audio_features(
            input_features=inputs["input_features"], is_longer=inputs.get("is_longer")
        )

    # Calculate cosine similarity between the source (0) and edited (1) audio
    audio_audio_similarity = F.cosine_similarity(audio_features[0], audio_features[1], dim=0)

    # Calculate cosine similarity between the first prompt embedding and the edited audio (1)
    text_audio_similarity = F.cosine_similarity(prompt_features[0], audio_features[1], dim=0)

    return audio_audio_similarity.item(), text_audio_similarity.item()

In [None]:
from sklearn.metrics import accuracy_score


def extract_pitch_classes(audio, sr=model_proxy.sampling_rate, hop_length=512):
    """Return, for every STFT frame, the frequency-bin index with the largest piptrack magnitude.

    NOTE(review): despite the name, these are frequency-bin indices, not 12-tone
    pitch classes.
    """
    magnitudes = librosa.core.piptrack(y=audio, sr=sr, hop_length=hop_length)[1]

    return magnitudes.argmax(axis=0)


def calculate_melody_accuracy(input_melody, generated_melody):
    """Fraction of frames whose dominant frequency bin matches between the two waveforms.

    Both inputs must yield the same number of frames (i.e. have equal length).
    """
    reference = extract_pitch_classes(input_melody)
    candidate = extract_pitch_classes(generated_melody)

    return accuracy_score(reference, candidate)

In [None]:
def calculate_beat_consistency_score(audio, sr=model_proxy.sampling_rate):
    """Coefficient of variation of inter-beat intervals (lower = steadier beat).

    Returns:
        ``std(IBI) / mean(IBI)``, or NaN when fewer than two beats are detected.
    """
    # Beat detection
    _, beat_frames = librosa.beat.beat_track(y=audio, sr=sr)

    # Calculate Inter-Beat Intervals (IBIs)
    ibis = np.diff(librosa.frames_to_time(beat_frames, sr=sr))

    # Fewer than two beats: no interval to measure. Return NaN explicitly instead
    # of taking the mean of an empty array (which warns and yields NaN anyway).
    if ibis.size == 0:
        return np.nan

    # Calculate Beat Consistency Score (coefficient of variation)
    return np.std(ibis) / np.mean(ibis)

In [None]:
def calculate_snr(y):
    """Heuristic "SNR" in dB: total signal energy over summed spectral flatness.

    NOTE(review): the denominator is a sum of spectral-flatness values, not a noise
    power, so this is a relative heuristic rather than a true signal-to-noise ratio.
    """
    flatness_total = np.sum(librosa.feature.spectral_flatness(y=y))
    energy = np.sum(y**2)

    return 10 * np.log10(energy / flatness_total)

In [None]:
from skimage.metrics import structural_similarity as ssim


def calculate_ssi(audios):
    """Structural similarity between the dB spectrograms of ``audios[0]`` and ``audios[1]``.

    Both spectrograms are min-max scaled to [0, 1] before comparison, so the two
    waveforms must have the same length.
    """
    specgram_1 = _normalized_db_spectrogram(audios[0])
    specgram_2 = _normalized_db_spectrogram(audios[1])

    # Calculate Structural Similarity Index (SSI)
    ssi_index, _ = ssim(specgram_1, specgram_2, data_range=1.0, full=True)

    return ssi_index


def _normalized_db_spectrogram(audio):
    """Return the dB-scaled STFT magnitude of ``audio`` min-max scaled to [0, 1]."""
    specgram = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)

    low, high = np.min(specgram), np.max(specgram)
    # A constant spectrogram (e.g. silence) would otherwise divide by zero.
    if high == low:
        return np.zeros_like(specgram)

    return (specgram - low) / (high - low)

In [None]:
from typing import Iterable


class Dataset(object):
    """Expand (edit type, prompt pair) samples into runnable edit experiments.

    Every entry is ``(edit_type, (source_prompt, edited_prompt), controller, seed)``.
    Each sample yields one entry per controller built for it, per seed.

    * Edit types containing ``"Ignore"``: ``get_ignore_indices`` returns the
      prompt pair to use and the indices for a single ``IgnoreWordController``.
    * Edit types containing ``"Replace"``: the first word that differs between
      the two prompts is swapped via ``get_replacement_indices``. One
      ``ReplaceWordController`` is built per blend value in
      ``np.arange(0.3, 0.8, 0.2)`` (about 0.3, 0.5, 0.7).
    * With ``soft_blending=True`` each controller is wrapped in
      ``SelfAttentionLerpControllerModifier``.

    Raises:
        ValueError: a "Replace" sample has no differing word among the words
            both prompts share (identical prompts, or one prompt is a prefix of
            the other).
        NotImplementedError: the edit type is neither "Ignore" nor "Replace".
    """

    def __init__(
        self,
        samples: list[tuple[str, tuple[str, str]]],
        soft_blending: bool = False,
        seeds: Optional[list[int]] = None,
    ) -> None:
        if seeds is None:
            seeds = [0, 1]

        self.entries = []
        for edit_type, prompts in samples:
            if "Ignore" in edit_type:
                # The helper hands back the prompt pair to use together with
                # the indices the controller operates on.
                prompts, indices = get_ignore_indices(prompts)
                controllers = [IgnoreWordController(indices)]
            elif "Replace" in edit_type:
                words_a, words_b = prompts[0].split(), prompts[1].split()
                # Only the first differing word is replaced.
                index = next(
                    (
                        i
                        for i, (word_a, word_b) in enumerate(zip(words_a, words_b))
                        if word_a != word_b
                    ),
                    None,
                )
                if index is None:
                    # Without this check `words_a[None]` fails with an opaque TypeError.
                    raise ValueError(
                        f"{edit_type!r} sample has no differing word to replace: {prompts!r}"
                    )
                indices = get_replacement_indices(prompts, words_a[index], words_b[index])
                controllers = [
                    ReplaceWordController(indices, blend) for blend in np.arange(0.3, 0.8, 0.2)
                ]
            else:
                raise NotImplementedError(f"{edit_type} is not supported")

            if soft_blending:
                controllers = [
                    SelfAttentionLerpControllerModifier(controller) for controller in controllers
                ]

            for controller in controllers:
                for seed in seeds:
                    self.entries.append((edit_type, (prompts[0], prompts[1]), controller, seed))

    def __iter__(self) -> Iterable[tuple[str, tuple[str, str], "BaseController", int]]:
        """Yield ``(edit_type, (source_prompt, edited_prompt), controller, seed)`` entries."""
        yield from self.entries

    def __len__(self) -> int:
        return len(self.entries)

In [None]:
import pickle


class CheckpointManager(object):
    """Persist keyword data to a pickle file so a long-running study can resume.

    NOTE: ``load`` unpickles the file, which can execute arbitrary code. Only
    point this at checkpoints written by this notebook.
    """

    def __init__(self, filepath: Path) -> None:
        self.filepath = filepath

    def load(self) -> dict[str, Any]:
        """Return the saved keyword data, or an empty dict if no checkpoint file exists."""
        if not self.filepath.is_file():
            return {}
        with self.filepath.open("rb") as file:
            return pickle.load(file)

    def dump(self, **data: Any) -> None:
        """Save the keyword arguments as a dict, replacing any previous checkpoint.

        The pickle is first written to a sibling ``<name>.tmp`` file and then
        moved over the checkpoint with ``Path.replace``. An interruption
        mid-write therefore leaves the previous checkpoint intact instead of a
        truncated, unloadable file.
        """
        self.filepath.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.filepath.with_name(self.filepath.name + ".tmp")
        with tmp_path.open("wb") as file:
            pickle.dump(data, file)
        # os.replace semantics: atomic on POSIX when both paths share a filesystem.
        tmp_path.replace(self.filepath)

In [None]:
# Edit benchmark: each sample is (edit type, (source prompt, edited prompt)).
# `Dataset` dispatches on a substring of the edit type: labels containing
# "Ignore" mark the dropped word with "<IGNORE>" in the edited prompt, and
# labels containing "Replace" swap the first word that differs between the two
# prompts. The parenthesised suffix ("Sentiment", "Chord") is only kept as the
# row label in the results.
# NOTE(review): several "Replace (Chord)" pairs change the key (E -> G,
# C minor -> A minor) rather than a chord quality — confirm the label is intended.
samples = [
    ("Ignore", ("pop song with guitar and drums", "pop song with <IGNORE> and drums")),
    ("Replace", ("pop song with guitar and drums", "pop song with synth and drums")),
    ("Replace (Sentiment)", ("happy pop song", "sad pop song")),
    ("Replace (Chord)", ("a major chord pop song", "a minor chord pop song")),
    ("Ignore", ("rock ballad with piano", "rock ballad with <IGNORE>")),
    (
        "Replace",
        ("jazz ensemble with trumpet and saxophone", "jazz ensemble with piano and saxophone"),
    ),
    ("Replace (Sentiment)", ("energetic electronic dance track", "calm electronic dance track")),
    (
        "Replace (Chord)",
        ("blues riff in E on electric guitar", "blues riff in G on electric guitar"),
    ),
    (
        "Ignore",
        (
            "acoustic folk song with banjo and harmonica",
            "acoustic folk song with <IGNORE> and harmonica",
        ),
    ),
    (
        "Replace",
        (
            "classical symphony with violins and cellos",
            "classical symphony with flutes and cellos",
        ),
    ),
    ("Replace (Sentiment)", ("upbeat indie pop anthem", "melancholic indie pop anthem")),
    ("Replace (Chord)", ("piano sonata in C minor", "piano sonata in A minor")),
    ("Ignore", ("funky bassline with slap technique", "funky <IGNORE> with slap technique")),
    (
        "Replace",
        ("latin jazz fusion with congas and bongos", "latin jazz fusion with timbales and bongos"),
    ),
    (
        "Replace (Sentiment)",
        ("motivational corporate background music", "relaxing corporate background music"),
    ),
    (
        "Replace (Chord)",
        (
            "gospel choir with dominant seventh chords",
            "gospel choir with diminished seventh chords",
        ),
    ),
    (
        "Ignore",
        (
            "ambient electronic soundscape with synthesizers",
            "ambient electronic <IGNORE> with synthesizers",
        ),
    ),
    (
        "Replace",
        (
            "orchestral film score with strings and brass",
            "orchestral film score with woodwinds and brass",
        ),
    ),
    ("Replace (Sentiment)", ("uplifting reggae vibes", "heartbreaking reggae vibes")),
    (
        "Replace (Chord)",
        ("punk rock anthem with power chords", "punk rock anthem with barre chords"),
    ),
]

# Hard (non-blended) variant with the default seeds: one entry per sample,
# per controller, per seed.
dataset = Dataset(samples)

In [None]:
from tqdm.auto import tqdm


def _score_self_attention_layers(self_attention_layer_indices, cross_attention_layer_indices):
    """Run every ``dataset`` entry with the given attention layers controlled; return the mean score.

    Each entry's controller is wrapped in ``DecoderLayerControllerModifier``
    restricted to all ``cross_attention_layer_indices`` plus
    ``self_attention_layer_indices``. The score is the mean, over all entries,
    of both the text-audio and the audio-audio cosine similarity.
    """
    columns = [
        "Edit",
        "Layers",
        "Source Prompt",
        "Editted Prompt",
        "Source Audio",
        "Editted Audio",
        "Text-Audio Cosine Similarity",
        "Audio-Audio Cosine Similarity",
    ]
    layers = ",".join(f"{x:02d}" for x in self_attention_layer_indices).strip()
    attention_layer_indices = [*cross_attention_layer_indices, *self_attention_layer_indices]

    progress_bar_c, df_list_ablation = tqdm(dataset, position=2, leave=False), []
    for edit, prompts, controller, seed in progress_bar_c:
        progress_bar_c.set_postfix({"layers": layers})

        controller = DecoderLayerControllerModifier(controller, set(attention_layer_indices))
        audio_values = run_and_display(prompts, controller, seed=seed)

        audio_audio_similarity, text_audio_similarity = cosine_similarity(
            prompts[1], audio_values
        )

        row = [edit, layers, prompts[0], prompts[1], audio_values[0], audio_values[1]]
        row.append(text_audio_similarity)
        row.append(audio_audio_similarity)

        df_list_ablation.append(pd.DataFrame([row], columns=columns))

    df = pd.concat(df_list_ablation, ignore_index=True)

    metrics = df[["Text-Audio Cosine Similarity", "Audio-Audio Cosine Similarity"]]
    # NOTE(review): `axis=None` reduces over both columns to a scalar only on pandas >= 2.0.
    return metrics.mean(axis=None)


def run_greedy_ablation_study(checkpoint_path: Optional[Path] = None):
    """Greedily grow the set of controlled self-attention layers that maximizes edit quality.

    Candidate decoder layers are taken in the order returned by
    ``controller.get_self_attention_importance()`` for the first sample (a
    heuristic ranking). Each round tries adding every remaining candidate to
    the already accepted layers, scores the whole dataset with
    ``_score_self_attention_layers``, and accepts the best candidate.

    A candidate scoring more than ``error_threshold`` below the previous
    round's best is black-listed and never retried. The search stops when the
    best candidate of a round drops more than ``error_threshold`` below the
    previous best, or when no candidates remain.

    After every accepted round, progress is pickled to ``checkpoint_path``
    (default ``RESULTS_DIR / "greedy_checkpoint.pkl"``) so an interrupted run
    resumes. The per-round ``Score``/``Indices`` table is written to
    ``RESULTS_DIR / "greedy.pkl"``.
    """
    if checkpoint_path is None:
        checkpoint_path = RESULTS_DIR / "greedy_checkpoint.pkl"

    checkpoint_manager = CheckpointManager(checkpoint_path)

    # Run the first sample once with an attention store to rank the layers.
    prompts = [samples[0][1][0], samples[0][1][1]]
    controller = AttentionStore()
    run_and_display(prompts, controller)

    # !This is a heuristic
    sorted_indices = controller.get_self_attention_importance()[0].tolist()

    # Even attention-layer indices are cross-attention and are always controlled;
    # self-attention layer i maps to attention-layer index 2 * i + 1.
    cross_attention_layer_indices = [2 * (i + 1) for i in range(len(model_proxy.decoder_layers))]
    error_threshold = 0.1

    checkpoint = checkpoint_manager.load()
    black_listed_indices = checkpoint.get("black_listed_indices", [])
    visited_indices = checkpoint.get("visited_indices", [])
    df_list = checkpoint.get("df_list", [])

    all_indices = [
        i for i in sorted_indices if i not in visited_indices and i not in black_listed_indices
    ]
    progress_bar_a = tqdm(all_indices, position=0)
    for _ in progress_bar_a:
        indices = [
            i for i in all_indices if i not in visited_indices and i not in black_listed_indices
        ]
        # Black-listing removes candidates faster than rounds are consumed, so
        # candidates can run out early; `max(scores)` would fail on an empty list.
        if not indices:
            break

        previous_max_score = df_list[-1]["Score"].item() if df_list else 0

        progress_bar_b, scores = tqdm(indices, position=1, leave=False), []
        for index in progress_bar_b:
            self_attention_layer_indices = [2 * i + 1 for i in [index, *visited_indices]]
            score = _score_self_attention_layers(
                self_attention_layer_indices, cross_attention_layer_indices
            )
            scores.append(score)

            error = abs(score - previous_max_score)
            if score < previous_max_score and error > error_threshold:
                black_listed_indices.append(index)

            progress_bar_b.set_postfix({"index": index, "score": f"{score:.3f}"})

        max_score = max(scores)

        error = abs(max_score - previous_max_score)
        if max_score < previous_max_score and error > error_threshold:
            break

        max_score_index = indices[scores.index(max_score)]
        visited_indices.append(max_score_index)
        df_list.append(
            pd.DataFrame([[max_score, visited_indices.copy()]], columns=["Score", "Indices"])
        )

        checkpoint_manager.dump(
            visited_indices=visited_indices,
            black_listed_indices=black_listed_indices,
            df_list=df_list,
        )

        progress_bar_a.set_postfix(
            {
                # Number of accepted rounds so far, including resumed ones.
                "last_checkpoint": f"{len(visited_indices):02d}",
                "current": f"{max_score:.3f}",
                "previous": f"{previous_max_score:.3f}",
                "error": f"{error * 100:.2f}%",
            }
        )

    if df_list:
        df = pd.concat(df_list, ignore_index=True)
    else:
        # No round was accepted (e.g. no candidates); `pd.concat([])` would raise.
        df = pd.DataFrame(columns=["Score", "Indices"])
    df.to_pickle(RESULTS_DIR / "greedy.pkl")

In [None]:
# Resumes from RESULTS_DIR / "greedy_checkpoint.pkl" when that checkpoint exists.
run_greedy_ablation_study(checkpoint_path=None)

In [None]:
def run_ablation_study(dataset, self_attention_layer_groups):
    """Evaluate every dataset edit under every group of controlled self-attention layers.

    For each ``(edit, prompts, controller, seed)`` entry of ``dataset`` and each
    group in ``self_attention_layer_groups``, the entry's own controller is
    wrapped in ``DecoderLayerControllerModifier``. The wrapper is restricted to
    every cross-attention layer (even indices ``2 * (i + 1)``) plus that
    group's attention-layer indices. The edit is then generated and scored.

    Args:
        dataset: Iterable of ``(edit, (source_prompt, edited_prompt), controller, seed)``.
        self_attention_layer_groups: Iterable of groups of attention-layer
            indices (ints) to control in addition to the cross-attention layers.

    Returns:
        A DataFrame with one row per (entry, group) pair. It holds the prompts,
        both audios, the cosine similarities, melody accuracy, the negated beat
        consistency score (so higher means steadier beats), SNR, and SSI.
    """
    columns = [
        "Edit",
        "Layers",
        "Source Prompt",
        "Editted Prompt",
        "Source Audio",
        "Editted Audio",
        "Text-Audio Cosine Similarity",
        "Audio-Audio Cosine Similarity",
        "Melody Accuracy",
        "Beat Consistency Score",
        "Signal to Noise Ratio",
        "Structural Similarity Index",
    ]
    cross_attention_layer_indices = [2 * (i + 1) for i in range(len(model_proxy.decoder_layers))]

    df_list = []
    for edit, prompts, base_controller, seed in tqdm(dataset, position=0):
        for self_attention_layers in tqdm(self_attention_layer_groups, leave=False, position=1):
            # Always wrap the entry's own controller. Rebinding the outer loop
            # variable would nest each group's modifier inside the previous
            # group's modifier.
            controller = DecoderLayerControllerModifier(
                base_controller,
                {*cross_attention_layer_indices, *self_attention_layers},
            )
            audio_values = run_and_display(prompts, controller, seed=seed)

            layers = ",".join(f"{x:02d}" for x in self_attention_layers).strip()

            audio_audio_similarity, text_audio_similarity = cosine_similarity(
                prompts[1], audio_values
            )

            row = [edit, layers, prompts[0], prompts[1], audio_values[0], audio_values[1]]
            row.append(text_audio_similarity)
            row.append(audio_audio_similarity)
            row.append(calculate_melody_accuracy(audio_values[0], audio_values[1]))
            row.append(-calculate_beat_consistency_score(audio_values[1]))
            row.append(calculate_snr(audio_values[1]))
            row.append(calculate_ssi(audio_values))

            df_list.append(pd.DataFrame([row], columns=columns))

    if not df_list:
        # Empty dataset or no layer groups: `pd.concat([])` would raise.
        return pd.DataFrame(columns=columns)
    return pd.concat(df_list, ignore_index=True)

#### Comparing individual self-attention layers

In [None]:
# One group per decoder layer, controlling only that layer's self-attention
# (odd attention-layer indices 1, 3, 5, ...).
self_attention_layer_groups = [
    [odd_index] for odd_index in range(1, 2 * len(model_proxy.decoder_layers), 2)
]
df = run_ablation_study(dataset=dataset, self_attention_layer_groups=self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "individual_hard.pkl")

#### Comparing leave-one-out groups (`n - 1` self-attention layers)

In [None]:
# Leave-one-out: for each decoder layer, control every self-attention layer except that one.
self_attention_layer_groups = [
    [
        odd_index
        for odd_index in range(1, 2 * len(model_proxy.decoder_layers), 2)
        if odd_index != 2 * skipped + 1
    ]
    for skipped in range(len(model_proxy.decoder_layers))
]
df = run_ablation_study(dataset=dataset, self_attention_layer_groups=self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "leave_one_out_hard.pkl")

#### Comparing incremental groups of self-attention layers

In [None]:
# Growing prefixes: the first 1, 2, ..., n self-attention layers.
self_attention_layer_groups = [
    list(range(1, 2 * group_size, 2))
    for group_size in range(1, len(model_proxy.decoder_layers) + 1)
]
df = run_ablation_study(dataset=dataset, self_attention_layer_groups=self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "incremental_hard.pkl")

#### Comparing individual self-attention layers (Soft-blending self-attention)

In [None]:
# Rebuild the dataset with every controller wrapped for soft self-attention blending.
dataset = Dataset(samples=samples, soft_blending=True)

In [None]:
# One group per decoder layer, controlling only that layer's self-attention.
self_attention_layer_groups = [
    [odd_index] for odd_index in range(1, 2 * len(model_proxy.decoder_layers), 2)
]
df = run_ablation_study(dataset=dataset, self_attention_layer_groups=self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "individual_soft.pkl")

#### Comparing leave-one-out groups (`n - 1` self-attention layers, soft-blending self-attention)

In [None]:
# Leave-one-out: for each decoder layer, control every self-attention layer except that one.
self_attention_layer_groups = [
    [
        odd_index
        for odd_index in range(1, 2 * len(model_proxy.decoder_layers), 2)
        if odd_index != 2 * skipped + 1
    ]
    for skipped in range(len(model_proxy.decoder_layers))
]
df = run_ablation_study(dataset=dataset, self_attention_layer_groups=self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "leave_one_out_soft.pkl")

#### Comparing incremental groups of self-attention layers (Soft-blending self-attention)

In [None]:
# Growing prefixes: the first 1, 2, ..., n self-attention layers. The count
# comes from the loaded model instead of a hard-coded 48, which was wrong for
# checkpoints with fewer decoder layers and produced out-of-range layer indices.
self_attention_layer_groups = [
    [2 * y + 1 for y in range(0, x)] for x in range(1, len(model_proxy.decoder_layers) + 1)
]
df = run_ablation_study(dataset, self_attention_layer_groups)
df.to_pickle(RESULTS_DIR / "incremental_soft.pkl")