In [1]:
%load_ext autoreload
%autoreload 2

# change directory to the root of the project
import os

os.chdir("..")

In [None]:
from pathlib import Path
import dotenv
from omegaconf import OmegaConf
from vqpiano.data.dataset import FullSongPianorollDataset
from vqpiano.models.factory import model_factory

from vqpiano.models.ae import EncoderDecoder
from vqpiano.models.utils import load_ckpt_state_dict

# Configs for the model and the full-song pianoroll dataset.
model_config = OmegaConf.load("config/latent_diff/model_token.yaml")
dataset_config = OmegaConf.load("config/latent_diff/dataset_pop80k_k_fullsong.yaml")

# Checkpoint to restore. An earlier checkpoint, kept for reference:
#   wandb/run-20250328_110212-x8rqu2qc/files/checkpoints/epoch=2-step=530000.safetensors
checkpoint_path = Path("wandb/run-20250404_013005-i41ffa2m/files/checkpoints/epoch=4-step=1000000.ckpt")

model_ = model_factory(model_config.model)
# Check that the factory built an EncoderDecoder (this also narrows the type for type checkers).
assert isinstance(model_, EncoderDecoder)
model = model_

# The loaded state dict is passed straight through so no extra copy of the weights stays referenced.
model.load_state_dict(load_ckpt_state_dict(checkpoint_path, unwrap_lightning=True))

ds = FullSongPianorollDataset(Path(dataset_config.path), props=["pianoroll"])
# Assign to `_` so the notebook does not echo the module repr.
_ = model.eval()

In [None]:
# Export the ground-truth pianoroll of dataset item 6 for listening.
gt_item = ds[6]
gt_item["pianoroll"].to_midi("ignore/output/gt.mid")

In [None]:
# Inspect the pianoroll exported in the previous cell. `pr` is not defined until a later cell,
# so referencing it here would raise NameError on a fresh kernel; display the dataset item directly.
ds[6]["pianoroll"]

In [None]:
import torch
from vqpiano.models.representation import SymbolicRepresentation
import matplotlib.pyplot as plt


# Alternative example song (its bar indices refer to that song):
# song, bar_idx1, bar_idx2 = '@Animenzzz/5Ggnzs2hP3s/0_365', 28, 63
song, bar_idx1, bar_idx2 = "@Animenzzz/3KpFbty0t_8/0_347", 29, 65
# NOTE(review): bar_idx1/bar_idx2 are picked for `song`, but `pr` below is read from ds[0]
# (the get_song line is commented out) — confirm these indices are meaningful for ds[0].
# pr = ds.ds.get_song(song).read_pianoroll("pianoroll", frames_per_beat=8)
pr = ds[0]["pianoroll"]

# Build the model input from the per-bar pianorolls yielded by `iter_over_bars_pr`.
gt = SymbolicRepresentation.from_pianorolls(list(pr.iter_over_bars_pr()))

with torch.no_grad():
    model: EncoderDecoder
    latent = model.encode(gt)

# Cosine-similarity matrix between the rows of `latent` (one row per bar; later cells index it by bar).
# F.normalize clamps the norm at a small eps, so an all-zero row gives zeros instead of NaNs.
latent_normalized = torch.nn.functional.normalize(latent, dim=1)
cosine_similarity = latent_normalized @ latent_normalized.T

fig, ax = plt.subplots()
im = ax.imshow(cosine_similarity.cpu(), vmin=0.1, vmax=0.3)
fig.colorbar(im, ax=ax)
ax.set(title="Cosine similarity between bar latents", xlabel="bar", ylabel="bar")
plt.show()


In [None]:
# Smooth the bar latents: each iteration moves every latent a fraction `rate` toward the mean of
# its `k` nearest neighbours (Euclidean distance, excluding itself). Every row is updated from the
# previous iteration's values, i.e. the update is synchronous.
# NOTE: this rebinds the global `latent`; later cells use the smoothed latents, and re-running
# this cell smooths them further.
k = 4
rate = 0.1
n_iters = 5

latent = latent.clone()
# Each bar can have at most (n_bars - 1) neighbours; with fewer than 2 bars there is nothing to average.
n_neighbors = min(k, latent.shape[0] - 1)
if n_neighbors > 0:
    for _ in range(n_iters):
        diff = latent[:, None, :] - latent[None, :, :]
        euclidean_distance = diff.norm(dim=2, p=2)
        # Exclude each bar from its own neighbourhood explicitly. Relying on argsort to put the
        # self-distance first breaks when identical (repeated) bars also sit at distance 0.
        euclidean_distance.fill_diagonal_(float("inf"))
        nearest_idx = euclidean_distance.topk(n_neighbors, dim=1, largest=False).indices  # (n_bars, n_neighbors)
        neighbor_mean = latent[nearest_idx].mean(dim=1)
        latent = latent + rate * (neighbor_mean - latent)

# Similarity of the smoothed latents (computed once, after smoothing).
latent_normalized = torch.nn.functional.normalize(latent, dim=1)
cosine_similarity = latent_normalized @ latent_normalized.T

fig, ax = plt.subplots()
im = ax.imshow(cosine_similarity.cpu(), vmin=0.1, vmax=0.3)
fig.colorbar(im, ax=ax)
ax.set(title="Cosine similarity after kNN smoothing", xlabel="bar", ylabel="bar")
plt.show()


In [None]:
# Decode the first 10 rows of `latent` autoregressively and export them as MIDI.
# NOTE(review): the second argument (4) presumably is the number of prompt bars, as in
# reconstruct_autoregressive below, and 21 presumably the lowest MIDI pitch (cf. min_pitch=21
# elsewhere) — confirm against EncoderDecoder.decode_autoregressive / to_midi.
# no_grad ensures sampling does not build an autograd graph, in case the method does not disable it itself.
with torch.no_grad():
    reconst2 = model.decode_autoregressive(latent[:10], 4)
reconst2.to_midi(21, "ignore/output/reconst2.mid")

In [None]:
# Pairwise Euclidean distances between the rows of `latent`. They are negated for display so that
# closer pairs get higher values in the colour map.
pairwise_diff = latent.unsqueeze(1) - latent.unsqueeze(0)
euclidean_distance = pairwise_diff.norm(dim=2, p=2)

fig, ax = plt.subplots()
image = ax.imshow(-euclidean_distance)
fig.colorbar(image, ax=ax)
plt.show()


In [None]:
# Interpolate linearly between the latents of bars `bar_idx1` and `bar_idx2`. Each interpolated
# latent is decoded as a continuation of the N_PROMPT_BARS ground-truth bars preceding `bar_idx1`.
FRAMES_PER_BAR = 32  # matches the 32-frame empty bars built in reconstruct_autoregressive
N_PROMPT_BARS = 4
lerp_n = 10

assert bar_idx1 >= N_PROMPT_BARS, f"need {N_PROMPT_BARS} bars before bar_idx1={bar_idx1}"
assert max(bar_idx1, bar_idx2) < len(latent), f"bar indices out of range for {len(latent)} bars"

z1 = latent[bar_idx1]
z2 = latent[bar_idx2]
# Weights 0, 1/(n-1), ..., 1; max(..., 1) keeps lerp_n == 1 from dividing by zero (single weight 0).
weights = [i / max(lerp_n - 1, 1) for i in range(lerp_n)]
z = torch.stack([z1.lerp(z2, w) for w in weights])  # (lerp_n, dim)

prompt_frames = N_PROMPT_BARS * FRAMES_PER_BAR
prompt_end = bar_idx1 * FRAMES_PER_BAR
prompt = SymbolicRepresentation.from_pianorolls([pr.slice(prompt_end - prompt_frames, prompt_end)])

results = []
with torch.no_grad():
    for z_i in z:
        generated = model.decoder.sample(
            duration=prompt.duration + model.target_duration, prompt=prompt, condition=z_i.unsqueeze(0)
        )
        # Keep the first bar's worth of frames right after the prompt (min_pitch=21 as used elsewhere).
        results.append(generated.to_pianoroll(min_pitch=21).slice(prompt_frames, prompt_frames + FRAMES_PER_BAR))

# Combine all interpolation steps into one pianoroll with the pianoroll's `|=` operator.
result = results[0]
for step_pr in results[1:]:
    result |= step_pr

fig, ax = plt.subplots()
ax.imshow(result.to_img_tensor())
ax.set(title=f"Interpolation bar {bar_idx1} -> bar {bar_idx2} ({lerp_n} steps)")
plt.show()

result.to_midi("ignore/output/lerp.mid")


In [None]:
# Display the combined interpolation pianoroll via its own `show()` method (provided by the vqpiano pianoroll class).
result.show()

In [None]:
# Export the first 30 bars (32 frames each) of the current `pr` as ground truth.
# NOTE(review): this overwrites the gt.mid written by earlier cells.
n_gt_frames = 30 * 32
gt_clip = pr.slice(0, n_gt_frames)
gt_clip.to_midi("ignore/output/gt.mid")

In [None]:
def reconstruct_autoregressive(
    model: EncoderDecoder,
    latents: torch.Tensor,
    n_prompt_bars: int,
    given_prompt_bars: list[SymbolicRepresentation] | None = None,
    frames_per_bar: int = 32,
):
    """Decode bar latents one bar at a time, each conditioned on the bars decoded before it.

    For each ``latents[i]``, the previous ``n_prompt_bars`` bars are concatenated into a prompt,
    and ``model.decoder.sample`` extends it by ``model.target_duration`` conditioned on
    ``latents[i].unsqueeze(0)``. The part after the prompt is appended as the next bar.

    Args:
        model: Encoder-decoder whose ``decoder.sample`` produces each continuation.
        latents: One latent per bar to decode, indexed along the first dimension.
        n_prompt_bars: Number of preceding bars fed to the decoder as prompt.
        given_prompt_bars: If None, the first iterations receive empty bars as prompts, so the
            output feels like the beginning of a piece. To generate bars from the middle of a
            piece, pass the ``n_prompt_bars`` bars that precede the first latent.
        frames_per_bar: Number of frames in each empty padding bar (only used when
            ``given_prompt_bars`` is None).

    Returns:
        The concatenation of the decoded bars. The empty padding bars are dropped, but when
        ``given_prompt_bars`` is passed those prompt bars are included at the start of the result.

    Raises:
        AssertionError: if ``len(given_prompt_bars) != n_prompt_bars``.
    """
    if given_prompt_bars is None:
        bars = []
        for _bar_idx in range(n_prompt_bars):
            bar = SymbolicRepresentation(device=latents.device)
            for _frame_idx in range(frames_per_bar):
                bar.add_frame()
            bars.append(bar)
    else:
        assert len(given_prompt_bars) == n_prompt_bars, f"{len(given_prompt_bars)} != {n_prompt_bars}"
        # Copy so that appending generated bars does not mutate the caller's list.
        bars = given_prompt_bars.copy()

    # Inference only: disable autograd so no graph is accumulated across bars.
    with torch.no_grad():
        for i in range(len(latents)):
            # bars[i : i + n_prompt_bars] are the n_prompt_bars most recent bars (sliding window).
            prompt = SymbolicRepresentation.cat(bars[i : i + n_prompt_bars])

            prediction = model.decoder.sample(
                duration=prompt.duration + model.target_duration, prompt=prompt, condition=latents[i].unsqueeze(0)
            )
            assert prediction.duration == prompt.duration + model.target_duration
            # Keep only what was generated after the prompt.
            bars.append(prediction[:, prompt.length :])

    if given_prompt_bars is None:
        # Remove the empty padding bars.
        return SymbolicRepresentation.cat(bars[n_prompt_bars:])
    return SymbolicRepresentation.cat(bars)


# Reconstruct the first 30 bars of dataset item 3715 from their per-bar latents.
# Dataset items are dicts keyed by prop name (cf. ds[i]["pianoroll"] elsewhere), so the
# pianoroll must be selected before slicing. 30 bars * 32 frames per bar, matching the
# 32-frame bars used elsewhere in this notebook.
pr = ds[3715]["pianoroll"].slice(0, 30 * 32)

pr.to_midi("ignore/output/gt.mid")

gt = SymbolicRepresentation.from_pianorolls(list(pr.iter_over_bars_pr()))

with torch.no_grad():
    latent = model.encode(gt)

    # Offset added to every latent; zero by default. Uncomment the next line to probe dimension 14.
    latent_bias = torch.zeros_like(latent)
    # latent_bias[:,14] = 1
    # Reconstruction is inference-only, so it also runs without autograd.
    res = reconstruct_autoregressive(model, latent + latent_bias, 4)

res.to_midi(21, "ignore/output/reconst1.mid")