In [None]:
import pathlib

import cached_conv as cc
import gin
import torch

from msprior.attention import Prior
from msprior.dataset import SequenceDataset

# Global inference setup: switch on cc's cached-conv mode and disable autograd
# for the whole notebook (nothing below trains).
cc.use_cached_conv(True)
torch.set_grad_enabled(False)

run = "runs/encoder_decoder_continuous/"
db_path = "/data/antoine/rave2vec_jax/preprocessed"

# Use the first .gin file found anywhere under the run directory.
# Raises IndexError if the run directory contains no .gin file.
config = pathlib.Path(run).rglob("*.gin")
config = str(list(config)[0])

gin.clear_config()
gin.parse_config_file(config)

model = Prior().eval()

# Best-effort restore of the first "*best*.ckpt" under the run directory.
# Only Exception is caught (not KeyboardInterrupt / SystemExit) and the cause
# is reported, so a missing file, a missing "state_dict" key or a state-dict
# mismatch can be told apart. On failure the model keeps its initial weights.
try:
    ckpt = pathlib.Path(run).rglob("*best*.ckpt")
    ckpt = str(list(ckpt)[0])
    print(f"loading {ckpt}")
    ckpt = torch.load(ckpt, map_location="cpu")["state_dict"]
    model.load_state_dict(ckpt)
    print("checkpoint restored")
except Exception as err:
    print(f"could not restore checkpoint: {err!r}")

# Set SEQ_LEN to 512 in the gin configuration before building the dataset.
gin.parse_config("SEQ_LEN=512")
dataset = SequenceDataset(db_path)

# Load the first TorchScript export (*.ts) found directly inside db_path.
rave = pathlib.Path(db_path).glob("*.ts")
rave = str(list(rave)[0])
rave = torch.jit.load(rave).eval()

In [None]:
from IPython.display import Audio, display

# Fetch one dataset example and decode its decoder inputs with the RAVE model
# so it can be listened to as a reference.
data_idx = 18282
batch = dataset[data_idx]

# Transpose the decoder inputs and add a leading axis of size 1 before decoding.
decoder_inputs = batch["decoder_inputs"]
z = torch.from_numpy(decoder_inputs.T[None])

audio = rave.decode(z)

# Flatten the decoded tensor to a 1-D waveform for playback.
reference_waveform = audio.reshape(-1).numpy()
display(Audio(reference_waveform, rate=rave.sr))

In [None]:
from tqdm import tqdm, trange

# RESET
# Call reset() on every submodule whose qualified name ends with "attention".
for name, module in model.named_modules():
    if name.endswith("attention"):
        module.reset()

# Prompt: the first 256 steps of the decoder inputs, with a leading axis added.
samples = torch.from_numpy(batch["decoder_inputs"][None, :256])

# Conditioning defaults to None so that a failed lookup can never leave these
# names undefined (NameError) or bound to stale values from an earlier run.
semantic = None
semantic_next = None
if "only" not in run:
    try:
        # First 64 encoder steps condition the prompt; the remainder is fed
        # one step at a time inside the sampling loop below.
        semantic = torch.from_numpy(batch["encoder_inputs"][None, :64])
        semantic_next = torch.from_numpy(batch["encoder_inputs"][None, 64:])
    except Exception as err:
        # Best-effort: fall back to unconditioned sampling, but say why.
        semantic = None
        semantic_next = None
        print(f"no usable encoder inputs, sampling without conditioning: {err!r}")

temperature = 0.85

# Run the model over the whole prompt and keep only the last sampled step.
current_sample = model.sample(samples, semantic, temperature)[0][:, -1:]
samples = torch.cat([samples, current_sample], 1)

# Generate 255 more steps, feeding back the previous sample each time.
for t in trange(255):
    # The conditioning index advances once every 4 generated steps.
    s_index = t // 4
    if semantic is not None:
        current_semantic = semantic_next[:, s_index:s_index + 1]
    else:
        current_semantic = None

    current_sample = model.sample(current_sample, current_semantic,
                                  temperature)[0]
    samples = torch.cat([samples, current_sample], 1)

# Swap the last two axes before decoding, then flatten to a 1-D waveform.
audio_gen = rave.decode(samples.permute(0, 2, 1)).reshape(-1).numpy()
display(Audio(audio_gen, rate=rave.sr))

In [None]:
import librosa as li
import matplotlib.pyplot as plt
import numpy as np


def get_melspec(x):
    """Return the log1p-compressed mel spectrogram of a signal.

    `x` is converted to a NumPy array and flattened to 1-D, passed to
    librosa's `melspectrogram` at the sample rate `rave.sr`, and the absolute
    value of the result is compressed with `log1p`.
    """
    waveform = np.asarray(x).reshape(-1)
    mel = li.feature.melspectrogram(y=waveform, sr=rave.sr)
    return np.log1p(abs(mel))


# Two stacked panels: mel spectrogram of the decoded dataset example (top) and
# of the generated audio (bottom), drawn with identical display options.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

matshow_options = {"aspect": "auto", "origin": "lower"}
axes[0].matshow(get_melspec(audio), **matshow_options)
axes[1].matshow(get_melspec(audio_gen), **matshow_options)

In [None]:
# Flatten the decoded reference audio and stack it with the generated audio
# into a two-row array (row 0: reference, row 1: generated).
# NOTE(review): presumably played back as two channels - confirm against the
# IPython.display.Audio documentation.
audio_np = np.asarray(audio).reshape(-1)
audio_stack = np.stack((audio_np, audio_gen), axis=0)

display(Audio(audio_stack, rate=rave.sr))

In [None]:
import pickle

from udls import AudioExample

# Read the raw record for the example selected earlier (data_idx) from the
# dataset's underlying store and expose its fields as a dict.
with dataset.env.begin() as txn:
    ae = AudioExample(txn.get(dataset.keys[data_idx]))
ae = ae.as_dict()

# Derive the k-means path from db_path so it always matches the location the
# dataset was loaded from, instead of repeating the directory literal.
# NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
kmeans_path = pathlib.Path(db_path) / "kmeans"
with open(kmeans_path, "rb") as kmeans_file:
    kmeans = pickle.load(kmeans_file)

clusters = kmeans.cluster_centers_

tokens = ae["semantic_indices"]
features = ae["semantic_features"]
# Replace each token index by the corresponding cluster centre (lookup on axis 0).
tokenized_features = np.take(clusters, tokens, 0)

# Top: stored semantic features. Bottom: the cluster centres selected by the tokens.
fig, axes = plt.subplots(2, 1, figsize=(8, 8))
axes[0].matshow(features, aspect="auto")
axes[1].matshow(tokenized_features, aspect="auto")

In [None]:
# Element-wise absolute difference between the stored features and the cluster
# centres picked by their tokens (both transposed), shown with a colour scale.
abs_diff = abs(features.T - tokenized_features.T)
plt.matshow(abs_diff, origin="lower")
plt.colorbar()