In [None]:
import os
from pathlib import Path

# Outputs are rooted at the current working directory by default. When the
# author's data directory exists, use it instead and expose only GPU 0 to CUDA.
BASE_DIR = Path.cwd()

if Path("/", "home", "vsioros", "data").is_dir():
    BASE_DIR = Path("/", "home", "vsioros", "data")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import sys
sys.path.append("..")

import torch
import os
import numpy as np
from huggingface_hub import snapshot_download
from auffusion.prompt2prompt.pipeline_prompt2prompt import Prompt2PromptPipeline
from auffusion.prompt2prompt.ptp_utils import (
    AttentionControlEdit,
    AttentionReplace,
    AttentionRefine,
    AttentionReweight,
    get_equalizer,
)
from auffusion.converter import denormalize_spectrogram, Generator

In [None]:
# Model / runtime configuration shared by the cells below.
sampling_rate = 16000  # sample rate (Hz) written into the output WAV files
device = "cuda"  # target device for the vocoder and the pipeline
dtype = torch.float16  # precision used when loading the vocoder and the pipeline
# Hugging Face Hub repo id, or a local directory containing the same layout.
pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"

In [None]:
# Keep a local model directory as-is; otherwise treat the value as a Hub repo id
# and replace it with the path of the downloaded snapshot.
pretrained_model_name_or_path = (
    pretrained_model_name_or_path
    if os.path.isdir(pretrained_model_name_or_path)
    else snapshot_download(pretrained_model_name_or_path)
)

In [None]:
# Load the vocoder from the model's "vocoder" subfolder and move it to the
# configured device and precision.
vocoder = Generator.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="vocoder",
).to(device=device, dtype=dtype)

In [None]:
# Load the Prompt2Prompt diffusion pipeline in the configured precision and
# move it to the target device.
pipe = Prompt2PromptPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=dtype,
    low_cpu_mem_usage=False,
).to(device)

In [None]:
from scipy.io.wavfile import write
from typing import Optional


def run_and_display(
    prompts: list[str],
    controller: AttentionControlEdit,
    num_inference_steps: int,
    seed: Optional[int] = None,
    *,
    height: int = 256,
    width: int = 1024,
):
    """Generate one waveform per prompt with Prompt2Prompt attention control.

    Despite its name, nothing is displayed: the module-level ``pipe`` is run
    with ``controller`` attached, every generated spectrogram is denormalized
    and passed through the module-level ``vocoder``, and the results are
    stacked.

    Args:
        prompts: Text prompts; in this notebook, the source prompt followed
            by the edited prompt.
        controller: Attention controller handed to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed for a CPU ``torch.Generator``; ``None`` passes no generator.
        height: Spectrogram height requested from the pipeline.
        width: Spectrogram width requested from the pipeline.

    Returns:
        ``np.stack`` of the squeezed vocoder outputs, one entry per prompt.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    outputs = pipe(
        prompt=prompts,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        controller=controller,
        generator=generator,
        output_type="pt",
    )

    # Reorder the whole batch once instead of once per prompt. Assumes
    # ``outputs.images`` is a channels-last NumPy array (it is fed to
    # ``torch.from_numpy``) — TODO confirm against the pipeline.
    specs = outputs.images.transpose(0, 3, 1, 2)

    audio_values_list = []
    for i in range(len(prompts)):
        spec = torch.from_numpy(specs[i]).to(device, dtype)
        audio = vocoder.inference(denormalize_spectrogram(spec))
        audio_values_list.append(audio.squeeze())

    return np.stack(audio_values_list)

In [None]:
def get_tokens(prompts: list[str]):
    """Return, for every prompt, its tokens decoded one by one.

    Each prompt is encoded with ``pipe.tokenizer``; the first and last token
    ids are dropped (presumably the start/end special tokens — confirm for the
    tokenizer in use), each remaining id is decoded on its own, and ``#``
    characters are stripped from both ends of the decoded text.
    """
    tokenizer = pipe.tokenizer
    return [
        [tokenizer.decode([token_id]).strip("#") for token_id in tokenizer.encode(text)[1:-1]]
        for text in prompts
    ]


def get_replacement_indices(prompts: list[str], word_a: str, word_b: str) -> list[list[int]]:
    """Find the token positions spelling ``word_a`` in prompt 0 and ``word_b`` in prompt 1.

    Tokens come from ``get_tokens``. A token is taken whenever the tokens taken
    so far plus this one still form a prefix of the word. The taken tokens do
    not have to be adjacent, and a partial prefix is never reset.

    Raises:
        NotImplementedError: If the first two prompts differ in word count.
    """
    token_groups = get_tokens(prompts)

    source_words, edited_words = prompts[0].split(), prompts[1].split()
    if len(source_words) != len(edited_words):
        raise NotImplementedError(f"Different prompt lengths ({prompts})")

    def _match(word, tokens):
        # Greedily extend a prefix of ``word`` token by token.
        positions = []
        prefix = ""
        for position, token in enumerate(tokens):
            candidate = prefix + token
            if word.startswith(candidate):
                prefix = candidate
                positions.append(position)
        return positions

    return [_match(word, tokens) for word, tokens in zip((word_a, word_b), token_groups)]


def get_reweight_word_indices(prompts: list[str], word: str) -> list[int]:
    """Return the token positions of ``word`` in the first prompt.

    Thin wrapper over ``get_replacement_indices`` using the same word on both
    sides, so ``prompts`` needs at least two prompts with equal word counts.
    """
    return get_replacement_indices(prompts, word, word)[0]


def get_refine_word_indices(prompts: list[str]) -> list[tuple[int, ...]]:
    """Pair up identical tokens between the first two prompts.

    Every combination of equal tokens is reported, so a token occurring several
    times yields several pairs. Tokens starting with ``<`` (presumably special
    tokens) are ignored. Positions index the lists returned by ``get_tokens``.

    Returns:
        ``[source_positions, edited_positions]`` as two equally long tuples,
        or ``[]`` when no tokens match.
    """
    token_groups = get_tokens(prompts)

    pairs = []
    for i, token_i in enumerate(token_groups[0]):
        # Equal tokens share their first character, so checking token_i alone
        # also excludes special tokens on the edited side.
        if token_i.startswith("<"):
            continue
        for j, token_j in enumerate(token_groups[1]):
            if token_i == token_j:
                pairs.append((i, j))

    return list(zip(*pairs))

### ISMIR Evaluation

We evaluate edits along the following audio-editing axes:

- Instrument Replacement: One instrument or sound source in the prompt is replaced with another. For example, replacing "blues ensemble with guitar and drums" with "country ensemble with guitar and drums", or "acoustic guitar solo" with "electric guitar solo".

- Mood/Tonal Change: These prompts involve changing the mood or tonality of the music. For instance, transforming a happy violin solo into a sad violin solo, or converting a major chord pop song into a minor chord pop song.

- Genre Shift: These prompts involve shifting the genre or style of the music. For example, transitioning from a rock riff on electric guitar to a metal riff on electric guitar, or changing a jazz beat with saxophone into a hip-hop beat with saxophone.

- Melodic Transformation: These edits involve altering the melodic content of the music. This can include changes in melodic contour, intervals, motifs, and melodies. For example, transforming a melodic line from ascending to descending or changing the melodic intervals to create a different melodic feel.

- Harmonic Modification: These edits involve modifying the harmonic structure of the music. This can include changes in chord progressions, harmonic rhythm, harmonic density, and harmonic tension. For instance, altering the chord progression from a standard I-IV-V to a more complex progression or introducing chromaticism to the harmony.

- Form/Structure Variation: These edits involve variations in the overall form or structure of the music. This can include changes in sectional arrangement, repetitions, transitions, and developmental processes. For example, restructuring a piece by adding or removing sections, or altering the order of musical events to create a different narrative flow.

We generate a large number of prompt pairs with ChatGPT, using the following text prompts:

#### Replace

```txt
Generate a list of quadruples containing:
1. The edit category,
2. The original word in the prompt,
3. The replacement word,
4. The prompt pair consisting of the source and edited prompt. 

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
(
   "Genre Shift",
   "blues",
   "country",
   ("blues ensemble with guitar and drums", "country ensemble with guitar and drums"),
),

Edit Categories:

- Instrument Replacement: Replace one instrument or sound source with another.
- Mood/Tonal Change: Change the mood or tonality of the music.
- Genre Shift: Shift the genre or style of the music.
- Melodic Transformation: Alter the melodic content of the music.
- Harmonic Modification: Modify the harmonic structure of the music.
- Form/Structure Variation: Vary the overall form or structure of the music.
```

#### Refine

```txt
Generate a list of tuples containing:
1. The edit category,
2. The prompt pair consisting of the source and edited prompt. 

The source prompt should strictly be a substring of the edited prompt. The edited prompt should only add details.

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
("piano melody", "jazz piano melody with improvisation"),

Edit Categories:

- Instrument Enhancement: Enhancing the sound by adding effects or layering additional sounds without replacing the instrument.
- Mood/Tonal Enhancement: Modifying tonality or mood through adjustments like reverb or EQ settings.
- Genre Fusion: Combining elements from different genres while preserving the original essence.
- Melodic Embellishment: Adding ornamentations or variations to enrich the melody's expressiveness.
- Harmonic Enrichment: Enriching the harmonic structure by adding chords or layers for a fuller sound.
- Form/Structure Expansion: Elaborating on the form or structure by adding new sections or transitions for complexity.
```

#### Reweight

```txt
Generate a list of tuples containing:
1. The edit category,
2. The target word
2. The prompt pair consisting of the source and edited prompt. 

The source prompt and the edited prompt should be identical. The target word must be included in both prompts, indicating which aspect of the described audio should be intensified or diminished.

Ensure to include at least one entry for each of the following edit categories and output only a single Python list.

Example entry:
("happy", ("happy acoustic guitar solo", "happy acoustic guitar solo"))
Edit Categories:

Instrument Reinforcement: Enhancing the sound of specific instruments or adding additional layers to enrich the overall texture without replacing them entirely.
Mood/Tonal Brightening: Adjusting tonality or mood through changes like increasing brightness or introducing uplifting effects to evoke a happier atmosphere.
Genre Blend: Mixing elements from various genres while maintaining the song's core identity to create a fusion that embodies a happier vibe.
Melodic Flourish: Introducing embellishments or variations in the melody to make it more lively, optimistic, and joyful.
Harmonic Enlivening: Enriching the harmonic structure by adding chords or harmonic layers that convey a sense of positivity and energy.
Form/Structure Expansion: Expanding the song's structure or adding new sections to build anticipation, create contrasts, and enhance the overall uplifting mood.
```

In [None]:
# Replace samples: (category, original word, replacement word, (source, edited)).
# The two words are the words that differ between the source and edited prompt,
# so each must actually occur in its prompt. The prompts themselves (and hence
# the generated audio and output paths) are unchanged; the words are only
# carried along as metadata.
replace_group = [
    (
        "Instrument Replacement",
        "violin",
        "cello",
        ("beautiful violin solo", "beautiful cello solo"),
    ),
    (
        "Mood/Tonal Change",
        "uplifting",
        "melancholic",
        ("uplifting piano melody", "melancholic piano melody"),
    ),
    (
        "Genre Shift",
        "rock",
        "metal",
        ("rock riff on electric guitar", "metal riff on electric guitar"),
    ),
    (
        "Melodic Transformation",
        "ascending",
        "descending",
        ("ascending melodic line", "descending melodic line"),
    ),
    (
        # Intended harmonic change: I-IV-V -> ii-V-I.
        "Harmonic Modification",
        "standard",
        "jazzier",
        ("standard chord progression", "jazzier chord progression"),
    ),
    (
        "Form/Structure Variation",
        "adding",
        "removing",
        (
            "restructuring a piece by adding sections",
            "restructuring a piece by removing sections",
        ),
    ),
    (
        "Instrument Replacement",
        "guitar",
        "piano",
        ("guitar solo", "piano solo"),
    ),
    (
        "Mood/Tonal Change",
        "major",
        "minor",
        ("major chord pop song", "minor chord pop song"),
    ),
    (
        "Genre Shift",
        "jazz",
        "hip-hop",
        ("jazz beat with saxophone", "hip-hop beat with saxophone"),
    ),
    (
        # Intended melodic change: motif -> variation.
        "Melodic Transformation",
        "repeating",
        "varied",
        ("repeating motif", "varied motif"),
    ),
    (
        "Harmonic Modification",
        "standard",
        "chromatic",
        ("standard chord progression", "chromatic chord progression"),
    ),
    (
        # Intended structural change: repetitions -> transitions.
        "Form/Structure Variation",
        "repeating",
        "transitional",
        ("repeating sections", "transitional sections"),
    ),
    (
        "Instrument Replacement",
        "drum",
        "synthesizer",
        ("drum solo", "synthesizer solo"),
    ),
    (
        "Mood/Tonal Change",
        "dark",
        "ethereal",
        ("dark ambient track", "ethereal ambient track"),
    ),
    (
        "Genre Shift",
        "pop",
        "reggae",
        ("pop song with catchy hooks", "reggae song with catchy hooks"),
    ),
    (
        "Melodic Transformation",
        "intervals",
        "sequences",
        ("melodic intervals", "melodic sequences"),
    ),
    (
        # Intended harmonic change: I-vi-IV-V -> ii-V-I.
        "Harmonic Modification",
        "typical",
        "jazzier",
        ("typical chord progression", "jazzier chord progression"),
    ),
    (
        # Intended structural change: intro -> outro.
        "Form/Structure Variation",
        "introductory",
        "concluding",
        ("introductory section", "concluding section"),
    ),
    (
        "Instrument Replacement",
        "trumpet",
        "flute",
        ("trumpet solo", "flute solo"),
    ),
    (
        "Mood/Tonal Change",
        "uplifting",
        "haunting",
        ("uplifting guitar melody", "haunting guitar melody"),
    ),
]

In [None]:
# Refine samples: (category, (source, edited)). Each source prompt is a
# substring of its edited prompt; the edit only adds detail.
refine_group = [
    (category, (source, edited))
    for category, source, edited in [
        (
            "Instrument Enhancement",
            "piano melody",
            "jazz piano melody with added chorus effect for depth and warmth",
        ),
        (
            "Mood/Tonal Enhancement",
            "guitar riff",
            "ethereal guitar riff with shimmering reverb and atmospheric delay",
        ),
        (
            "Genre Fusion",
            "hip-hop beat",
            "trap-infused hip-hop beat with electronic synth arpeggios and 808 bass",
        ),
        (
            "Melodic Embellishment",
            "vocal line",
            "soulful vocal line with intricate melismatic runs and emotive vibrato",
        ),
        (
            "Harmonic Enrichment",
            "chord progression",
            "lush chord progression with added ninth and eleventh extensions for richness",
        ),
        (
            "Form/Structure Expansion",
            "bridge section",
            "extended bridge section with modulating key centers and layered counterpoint",
        ),
        (
            "Instrument Enhancement",
            "drum groove",
            "dynamic drum groove with layered percussion and enhanced stereo imaging",
        ),
        (
            "Mood/Tonal Enhancement",
            "ambient pad",
            "serene ambient pad with subtle modulated filters and soft side-chain compression",
        ),
        (
            "Genre Fusion",
            "jazz saxophone solo",
            "fusion jazz saxophone solo with electronic glitch effects and syncopated beats",
        ),
        (
            "Melodic Embellishment",
            "flute melody",
            "flute melody with cascading runs and delicate trills for added expressiveness",
        ),
        (
            "Harmonic Enrichment",
            "bassline",
            "deep bassline with walking chromatic lines and extended harmonic sequences",
        ),
        (
            "Form/Structure Expansion",
            "chorus",
            "expanded chorus with layered harmonies and intricate rhythmic variations",
        ),
        (
            "Instrument Enhancement",
            "synth lead",
            "bright synth lead with added modulation effects and stereo widening for depth",
        ),
        (
            "Mood/Tonal Enhancement",
            "piano chords",
            "soothing piano chords with gentle reverb and subtle tape saturation for warmth",
        ),
        (
            "Genre Fusion",
            "reggae rhythm",
            "reggae rhythm with dubstep-inspired bass drops and electronic glitches",
        ),
        (
            "Melodic Embellishment",
            "violin solo",
            "expressive violin solo with emotive slides and delicate pizzicato accents",
        ),
        (
            "Harmonic Enrichment",
            "guitar strumming",
            "dynamic guitar strumming with extended chord voicings and added suspended notes",
        ),
        (
            "Form/Structure Expansion",
            "pre-chorus",
            "extended pre-chorus with building tension and additional instrumental layers",
        ),
        (
            "Instrument Enhancement",
            "drum fill",
            "energetic drum fill with layered percussion and added room reverb for spaciousness",
        ),
        (
            "Mood/Tonal Enhancement",
            "synth pad",
            "dreamy synth pad with evolving filter sweeps and atmospheric delays",
        ),
    ]
]

In [None]:
# Reweight samples: (category, target word, (prompt, prompt)); both prompts are
# identical. The target is looked up by get_equalizer, which presumably matches
# whole space-separated words of the prompt (TODO confirm in ptp_utils). A
# target such as "bass" in "thumping bassline" would then select no tokens and
# leave the attention unweighted, so every target below is an exact word of its
# prompt.
# NOTE: the evaluation loop's resume check is keyed on output paths, which do
# not contain the target word; drop earlier records for the corrected entries
# ("bassline", "funk-infused", "reggae-inspired") to regenerate them.
reweight_group = [
    (
        "Instrument Reinforcement",
        "drums",
        ("dynamic drums in the chorus", "dynamic drums in the chorus"),
    ),
    (
        "Mood/Tonal Brightening",
        "bright",
        ("bright piano melody", "bright piano melody"),
    ),
    (
        "Genre Blend",
        "pop",
        ("pop rock guitar riff", "pop rock guitar riff"),
    ),
    (
        "Melodic Flourish",
        "optimistic",
        ("optimistic flute melody", "optimistic flute melody"),
    ),
    (
        "Harmonic Enlivening",
        "major",
        ("uplifting major chord progression", "uplifting major chord progression"),
    ),
    (
        "Form/Structure Expansion",
        "chorus",
        ("extended chorus with layered vocals", "extended chorus with layered vocals"),
    ),
    (
        "Instrument Reinforcement",
        "bassline",
        ("thumping bassline", "thumping bassline"),
    ),
    (
        "Mood/Tonal Brightening",
        "cheerful",
        ("cheerful brass section", "cheerful brass section"),
    ),
    (
        "Genre Blend",
        "funk-infused",
        ("funk-infused guitar riff", "funk-infused guitar riff"),
    ),
    (
        "Melodic Flourish",
        "joyful",
        ("joyful synth melody", "joyful synth melody"),
    ),
    (
        "Harmonic Enlivening",
        "seventh",
        ("vibrant seventh chord progression", "vibrant seventh chord progression"),
    ),
    (
        "Form/Structure Expansion",
        "bridge",
        (
            "extended bridge section with energetic build-up",
            "extended bridge section with energetic build-up",
        ),
    ),
    (
        "Instrument Reinforcement",
        "guitar",
        ("powerful guitar solo", "powerful guitar solo"),
    ),
    (
        "Mood/Tonal Brightening",
        "uplifting",
        ("uplifting strings arrangement", "uplifting strings arrangement"),
    ),
    (
        "Genre Blend",
        "reggae-inspired",
        ("reggae-inspired drum groove", "reggae-inspired drum groove"),
    ),
    (
        "Melodic Flourish",
        "hopeful",
        ("hopeful vocal melody", "hopeful vocal melody"),
    ),
    (
        "Harmonic Enlivening",
        "major",
        ("bright major chord progression", "bright major chord progression"),
    ),
    (
        "Form/Structure Expansion",
        "pre-chorus",
        (
            "extended pre-chorus with dynamic instrumentation",
            "extended pre-chorus with dynamic instrumentation",
        ),
    ),
    (
        "Instrument Reinforcement",
        "synth",
        ("lush synth pads", "lush synth pads"),
    ),
    (
        "Mood/Tonal Brightening",
        "vibrant",
        ("vibrant brass section", "vibrant brass section"),
    ),
]

In [None]:
from typing import Any, Iterable, Iterator


def transform_samples(
    edit: str,
    samples: Iterable[tuple[Any, ...]],
) -> Iterator[dict[str, Any]]:
    """Flatten raw dataset entries into uniform records.

    Each sample is laid out as ``(category, *additional, prompt_pair)``. The
    first element is the edit category and the last is the
    ``(source_prompt, edited_prompt)`` pair. Anything in between (e.g. the
    replaced words, or a reweight target word) is collected into
    ``"Additional"`` as a list, which is empty for refine samples.

    Args:
        edit: Edit type; title-cased in the output (``"replace"`` -> ``"Replace"``).
        samples: Raw entries such as the ``*_group`` lists above.

    Yields:
        Dicts with keys ``"Edit"``, ``"Category"``, ``"Source Prompt"``,
        ``"Edited Prompt"`` and ``"Additional"``.
    """
    edit_title = edit.title()
    for edit_category, *additional, prompts in samples:
        yield {
            "Edit": edit_title,
            "Category": edit_category,
            "Source Prompt": prompts[0],
            "Edited Prompt": prompts[1],
            "Additional": additional,
        }


class Dataset:
    """Iterable of Prompt2Prompt editing jobs built from grouped samples.

    Each keyword argument names an edit type and maps it to a list of raw
    samples in the layout accepted by ``transform_samples``. The name is
    title-cased (``replace`` -> ``Replace``) and must end up as one of
    ``Replace``, ``Refine`` or ``Reweight``. Iterating yields one job per
    sample and seed, each with its own freshly constructed attention controller.
    """

    # Seeds 0 .. NUM_SEEDS - 1 are generated for every prompt pair.
    NUM_SEEDS = 5
    # Diffusion steps the controllers are configured for; this should match
    # the num_inference_steps the pipeline is run with.
    NUM_DIFFUSION_STEPS = 50

    def __init__(self, **groups: list[tuple[Any, ...]]) -> None:
        self.groups = groups

    def __iter__(self) -> Iterator[tuple[str, str, str, str, int, AttentionControlEdit]]:
        """Yield ``(edit, category, source_prompt, edited_prompt, seed, controller)``.

        Raises:
            NotImplementedError: If a group's edit type is unsupported; raised
                when the first sample of that group is reached.
        """
        for group_name, group in self.groups.items():
            for sample in transform_samples(group_name, group):
                edit = sample["Edit"]
                source_prompt = sample["Source Prompt"]
                edited_prompt = sample["Edited Prompt"]
                prompts = [source_prompt, edited_prompt]
                additional = sample["Additional"]

                for seed in range(self.NUM_SEEDS):
                    # A new controller is built for every seed instead of being
                    # shared across runs (controllers presumably keep per-run state).
                    yield (
                        edit,
                        sample["Category"],
                        source_prompt,
                        edited_prompt,
                        seed,
                        self._make_controller(edit, prompts, additional),
                    )

    def _make_controller(
        self,
        edit: str,
        prompts: list[str],
        additional: list[Any],
    ) -> AttentionControlEdit:
        """Construct the attention controller for one (edit type, prompt pair) job."""
        if edit == "Replace":
            return AttentionReplace(
                prompts,
                self.NUM_DIFFUSION_STEPS,
                cross_replace_steps=0.1,
                self_replace_steps=0.2,
                tokenizer=pipe.tokenizer,
                device=pipe.device,
                dtype=dtype,
            )
        if edit == "Refine":
            return AttentionRefine(
                prompts,
                self.NUM_DIFFUSION_STEPS,
                cross_replace_steps=0.8,
                self_replace_steps=0.4,
                tokenizer=pipe.tokenizer,
                device=pipe.device,
                dtype=dtype,
            )
        if edit == "Reweight":
            # Equalizer over the edited prompt with weight 4 for the target
            # word(s) carried in ``additional``.
            equalizer = get_equalizer(prompts[1], additional, (4,), tokenizer=pipe.tokenizer)
            # NOTE(review): unlike Replace/Refine, no dtype is passed here —
            # confirm AttentionReweight handles the fp16 pipeline correctly.
            return AttentionReweight(
                prompts,
                self.NUM_DIFFUSION_STEPS,
                cross_replace_steps=0.8,
                self_replace_steps=0.8,
                tokenizer=pipe.tokenizer,
                device=pipe.device,
                equalizer=equalizer,
            )
        raise NotImplementedError(f"Unsupported edit type: {edit!r}")

    def __len__(self) -> int:
        """Total number of jobs: one per sample per seed."""
        return self.NUM_SEEDS * sum(len(group) for group in self.groups.values())

In [None]:
# One group per edit type. The keyword names are title-cased by
# transform_samples and select which attention controller Dataset builds.
dataset = Dataset(
    **{
        "replace": replace_group,
        "refine": refine_group,
        "reweight": reweight_group,
    }
)

In [None]:
import warnings

import pandas as pd
from tqdm.auto import tqdm

RESULTS_DIR = BASE_DIR / "results" / "auffusion"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

EVAL_DIR = RESULTS_DIR / "Evaluation"
# Create the evaluation directory up front so the results file can always be written.
EVAL_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_PATH = EVAL_DIR / "results.pkl"
# Results are saved here first and then atomically moved over RESULTS_PATH.
RESULTS_TMP_PATH = RESULTS_PATH.with_name(RESULTS_PATH.name + ".tmp")

# Resume support: reload previously saved records (a locally produced, trusted
# pickle) and skip every job whose two output paths are already recorded.
df = pd.DataFrame()
if RESULTS_PATH.is_file():
    df = pd.read_pickle(RESULTS_PATH)

results = df.to_dict(orient="records")
existing_paths = set(df.get("Source Path", [])).union(set(df.get("Edited Path", [])))

for edit, category, source_prompt, edited_prompt, seed, controller in tqdm(dataset):
    torch.cuda.empty_cache()

    output_folder = EVAL_DIR / edit / f"{source_prompt} - {edited_prompt}" / f"{seed:02d}"

    source_filepath = output_folder / f"00 - {source_prompt}.wav"
    edited_filepath = output_folder / f"01 - {edited_prompt}.wav"

    if source_filepath in existing_paths and edited_filepath in existing_paths:
        continue

    output_folder.mkdir(parents=True, exist_ok=True)

    # Best effort: a failing job is reported and skipped. KeyboardInterrupt is
    # not an Exception subclass, so interrupting still stops the loop.
    try:
        audio_values = run_and_display(
            [source_prompt, edited_prompt],
            controller,
            50,
            seed=seed,
        )

        write(source_filepath, rate=sampling_rate, data=audio_values[0])
        write(edited_filepath, rate=sampling_rate, data=audio_values[1])
    except Exception as e:
        warnings.warn(f"{edit} '{source_prompt}' -> '{edited_prompt}' (seed {seed}) [{e}]")
        continue

    results.append(
        {
            "Source Path": source_filepath,
            "Edited Path": edited_filepath,
            "Edit": edit,
            "Category": category,
            "Source Prompt": source_prompt,
            "Edited Prompt": edited_prompt,
            # New column; rows loaded from older results files hold NaN here.
            "Seed": seed,
        },
    )

    # Write to a temporary file and swap it in atomically, so an interrupt
    # during saving cannot leave a truncated pickle that breaks resuming.
    pd.DataFrame(results).to_pickle(RESULTS_TMP_PATH)
    RESULTS_TMP_PATH.replace(RESULTS_PATH)