## Install

In [None]:
!pip install mimikit

## Imports and Source Code for Redundance Rate

In [4]:
from mimikit.freqnet import FreqNet
from mimikit.data import Database
from mimikit.utils import audio, signal
from mimikit import NeptuneConnector
import torch
import numpy as np
from random import randint
%matplotlib inline
import matplotlib.pyplot as plt
# default (width, height) of the figures drawn in this notebook
plt.rcParams.update({'figure.figsize': (20, 6)})


# functions we need to compute the redundance rate


def cosine_similarity(X, Y, eps=1e-10):
    """
    pairwise cosine similarity between the rows of X and the rows of Y, safe for zero vectors.

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    D_xy : (*, N, M), D_xy[..., n, m] = <X_n, Y_m> / max(|X_n| * |Y_m|, eps)

    Notes:
    ------
    torch.nn.CosineSimilarity only computes the diagonal of D_xy (as in cosine_sim(output, target)),
    hence this helper. Flooring the product of norms with `eps` makes a zero vector score 0
    instead of producing nan.
    """
    floor = eps if isinstance(eps, torch.Tensor) else torch.tensor(eps).to(X)

    row_norms_x = torch.norm(X, p=2, dim=-1, keepdim=True)                        # (*, N, 1)
    row_norms_y = torch.norm(Y, p=2, dim=-1, keepdim=True).transpose(-2, -1)      # (*, 1, M)

    return (X @ Y.transpose(-2, -1)) / torch.maximum(row_norms_x * row_norms_y, floor)

def angular_distance(X, Y, eps=1e-10):
    """
    angular distance: a proper distance metric built on the cosine similarity,
    see https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity

    The angle between two rows is divided by pi when X or Y contains a negative value,
    and by pi / 2 otherwise, so that the distance lies in [0, 1] in both cases.

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    D_xy : (*, N, M)
    """
    eps_t = eps if isinstance(eps, torch.Tensor) else torch.tensor(eps).to(X)

    cos_theta = cosine_similarity(X, Y, eps_t)

    # torch.acos returns nan for values pushed just outside [-1, 1] by rounding,
    # so clamp first... see https://github.com/pytorch/pytorch/issues/8069
    lower, upper = -1 + eps_t / 2, 1 - eps_t / 2
    theta = torch.acos(torch.clamp(cos_theta, min=lower, max=upper))

    all_non_negative = not (torch.any(X < 0) or torch.any(Y < 0))
    scale = 2 if all_non_negative else 1
    half_pi = torch.acos(torch.zeros(1)).item()

    return scale * theta / (2 * half_pi)


def nearest_neighbor(X, Y):
    """
    for every row of X, finds the row of Y with the smallest angular distance.

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    returns (dists, nn), both (*, N): the smallest distance and the index in Y reaching it
    """
    closest = torch.min(angular_distance(X, Y), dim=-1)
    return closest.values, closest.indices


def torch_frame(x, frame_size, hop_length):
    """
    helper to reshape an array into frames

    Cuts the last dimension of `x` into windows of `frame_size` consecutive elements,
    starting every `hop_length` elements. The result is a view on `x` (no copy).

    Shapes:
    -------
    x : (*, N)
    returns : (*, n_frames, frame_size) with n_frames = 1 + (N - frame_size) // hop_length,
              or n_frames = 0 when N < frame_size (no complete frame fits)
    """
    if x.size(-1) < frame_size:
        # not even one full frame: return an empty framing instead of reading
        # past the end of the row (or failing on a negative frame count)
        return x.new_empty((*x.shape[:-1], 0, frame_size))
    # unfold floors the frame count and respects x's actual strides, so it is
    # correct for 1-D, batched and non-contiguous inputs alike
    return x.unfold(-1, frame_size, hop_length)


def repeat_rate(x, frame_size, hop_length):
    """
    frames x and compute repeat-rate per frame

    The repeat rate of a frame is 1 - (n_unique - 1) / (frame_size - 1). It is 0 when all
    elements of the frame are distinct, and 1 when the frame repeats a single value.

    Shapes:
    -------
    x : (*, N)
    returns : (*, n_frames) float tensor on the CPU (n_frames as produced by torch_frame)

    Raises:
    -------
    ValueError if frame_size < 2 (the rate would be 0 / 0)
    """
    if frame_size < 2:
        raise ValueError("frame_size must be >= 2, got %r" % (frame_size,))
    framed = torch_frame(x, frame_size, hop_length)
    # count distinct values per frame without a Python loop: once a frame is sorted,
    # every position where the value changes starts a new distinct value
    ordered = framed.sort(dim=-1).values
    uniques = 1 + (ordered[..., 1:] != ordered[..., :-1]).sum(dim=-1)   # (*, n_frames)
    # kept on the CPU: callers in this notebook call `.numpy()` directly on the result
    return (1 - (uniques - 1) / (frame_size - 1)).cpu()


![Redundance rate](imgs/redundance-rate.png)

## Setup DB & model

In [None]:
import os

project_name = "experiment-2/"


# function to filter experiment (must return True or False)
# `exp_row` is a dict with the column names as keys and the values are the exp's values (no .y!)
def is_match(exp_row):
    return "Lachenmann" in exp_row["DB"] or "Schweifel" in exp_row["DB"]


nc = NeptuneConnector(user="k-tonal", setup={"project": project_name})
prj = nc.get_project("project")

# the filtered experiment list is kept in the notebook's globals, so re-running
# this cell does not list the project's experiments again
if "exps" not in globals():
    exps = []
    for exp in prj.get_experiments():
        # fetch each experiment's logs once: they are needed both by the filter
        # and in the (exp, log) pairs kept for the loop below
        log = exp.get_logs()
        if is_match({k: v.y for k, v in log.items()}):
            exps.append((exp, log))

dbs, mdls = {}, {}
for exp, log in exps:
    db_name = log["db-name"].y
    nc.setup[exp.id + "-db"] = log["db-id"].y
    nc.setup[exp.id] = project_name + exp.id
    # download the database and the checkpoints only when they are not on disk yet
    if not os.path.exists(db_name):
        nc.download_database(exp.id + "-db", db_name)
    if not os.path.exists(exp.id + "/states"):
        nc.download_experiment(exp.id, artifacts="states/")

    dbs[exp.id] = Database(db_name)
    iloc = log["db-iloc"].y

    def get_model(id, epoch, iloc=iloc):
        """
        loads the FreqNet checkpoint `<id>/states/epoch=<epoch>.ckpt` with `dbs[id].fft`
        as data object, or `dbs[id].fft.get(dbs[id].metadata.iloc[...])` when `iloc` is set.

        The default `iloc=iloc` freezes the current experiment's value when the function is
        defined; a plain closure would see the value of the last loop iteration instead.
        """
        if iloc:
            # NOTE(review): `iloc` is read from the experiment logs and passed to eval();
            # only run this on projects whose logs you trust.
            data = dbs[id].fft.get(dbs[id].metadata.iloc[eval(iloc)])
        else:
            data = dbs[id].fft
        return FreqNet.load_from_checkpoint(id + "/states/epoch=%i.ckpt" % epoch,
                                            data_object=data)

    mdls[exp.id] = get_model

mdls

### Load a Model

In [None]:
# experiment and checkpoint to inspect
exp_id = "EX2-200"
epoch = 99

db = dbs[exp_id]
model = mdls[exp_id](exp_id, epoch)

# show the fft parameters, the shape of the model's data and its hyper-parameters
db.fft.attrs, model.data.shape, model.hparams

## Generate single output

In [None]:
prompt_length = 64
n_steps = 2048


# prompt index: randint's upper bound is inclusive, so bound it such that the true
# continuation used as reference (`wrt`, the n_steps frames right after the prompt)
# lies entirely inside the data; otherwise `wrt` may be truncated or even empty
last_start = max(0, model.data.shape[0] - prompt_length - n_steps)
i = randint(0, last_start)


output = model.generate(model.data[i:i+prompt_length], time_domain=False, n_steps=n_steps).squeeze(0)
wrt = torch.from_numpy(model.data[i+prompt_length:i+prompt_length+n_steps]).to(output).unsqueeze(0)

audio(output.squeeze().numpy().T, hop_length=db.fft.attrs["hop_length"])

## Compute the redundance rate (RR) over time at multiple levels for a single output

In [None]:
# index of the closest frame of the true continuation, for every generated frame
with torch.no_grad():
    neighbs = nearest_neighbor(output[:, prompt_length:], wrt)[1]


# several levels of locality: one hop_length shared by several frame sizes
frame_size = (8, 64, 128)
hop_length = 2


# redundance rate per frame size, one curve each
for size in frame_size:
    with torch.no_grad():
        rate = repeat_rate(neighbs, size, hop_length)
    plt.plot(rate.squeeze().cpu().numpy(), label="frame size = " + str(size))

plt.ylim([-0.1, 1.1])
plt.legend()
plt.ylabel('Redundance Rate')
plt.xlabel('Time')
plt.title('Local Redundance Rate over Time')
None

## Initialize Outputs and scores

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader

#######################################

# we take one prompt every "stride" index :
stride = 16

n_steps = 200
prompt_length = 64

batch_size = 300

# params for RR :

frame_sizes = (8, 64, 128)
hop_length = 2

#######################################

# prompt start indices: prompt k starts at all_indices[k] == k * stride
all_indices = np.arange(0, model.data.shape[0] - prompt_length, stride)
# number of frames generated so far per prompt (weights the running-average score update further down)
current_step = np.zeros_like(all_indices)

loader = DataLoader(all_indices, shuffle=False, batch_size=batch_size, drop_last=False)

# one row per prompt, one column per frame size
scores = np.ones((all_indices.shape[0], len(frame_sizes)))
# last model.receptive_field() frames of every output, kept to resume generation later
last_outputs = []

# compute

for indices in tqdm(loader):
    n_prompts = indices.shape[0]
    first, last = int(indices[0]), int(indices[-1])
    rows = indices // stride  # positions of this batch's prompts in all_indices / scores

    prompts = torch.from_numpy(model.data[first:last + prompt_length])
    # reference windows start right after each prompt; min() shifts the start back near the
    # end of the data so that n_prompts strided windows of n_steps frames still fit
    wrt_start = min(first + prompt_length, model.data.shape[0] - stride * n_prompts - n_steps)
    wrts = torch.from_numpy(model.data[wrt_start:last + prompt_length + n_steps])

    with torch.no_grad():
        # view the contiguous slices as (n_prompts, length, D) windows `stride` frames apart
        D = prompts.size(-1)
        prompts = torch.as_strided(prompts, size=(n_prompts, prompt_length, D),
                                   stride=(stride * D, D, 1))
        wrts = torch.as_strided(wrts, size=(n_prompts, n_steps, D),
                                stride=(stride * D, D, 1))

    outputs = model.generate(prompts, time_domain=False, n_steps=n_steps).squeeze(0)
    last_outputs.append(outputs[:, -model.receptive_field():].clone())
    with torch.no_grad():
        _, neighbs = nearest_neighbor(outputs[:, prompt_length:], wrts)

    for i, fs in enumerate(frame_sizes):
        with torch.no_grad():
            r = repeat_rate(neighbs, fs, hop_length).mean(dim=-1)
        # store the score of EVERY frame size (column i), not only the last one
        scores[rows, i] = r.squeeze().numpy()

    current_step[rows] = n_steps

with torch.no_grad():
    last_outputs = torch.cat(last_outputs)

### Plot

In [None]:
plt.figure(figsize=(32, 8))
for col, size in enumerate(frame_sizes):
    plt.plot(all_indices, scores[:, col], label="frame size = " + str(size))

plt.legend()
plt.ylim([-0.1, 1.1])
plt.ylabel('Mean Local Redundance Rate')
plt.xlabel('Prompt Index')
plt.title("Output's Scores")

# overall mean over all prompts and frame sizes
scores.mean()

## Generate further and compute scores for the best indices

In [None]:
def get_wrts(model, indices, n_steps):
    """
    stacks one reference window of `n_steps` frames of `model.data` per index.

    A window starts at its index, or at `len(model.data) - n_steps` when the index is too
    close to the end, so that windows keep the full length. Returns a tensor of shape
    (len(indices), n_steps, *model.data.shape[1:]).
    """
    data = model.data
    latest_start = data.shape[0] - n_steps
    windows = []
    for start in indices:
        windows.append(data[min(start, latest_start):start + n_steps])
    return torch.from_numpy(np.stack(windows))

def gen_and_score(model, prompts, indices, n_steps, frame_sizes=(8, 16, 32), hop_length=2):
    """
    continues `prompts` for `n_steps` frames and scores the newly generated frames.

    Every new frame is matched to its nearest neighbor in the reference windows returned by
    get_wrts(model, indices, n_steps); the mean repeat rate of those neighbor indices is
    computed for each frame size.

    Returns
    -------
    this_scores : np.ndarray of shape (n_prompts, len(frame_sizes))
    last_outputs : the last model.receptive_field() frames of each output (cloned)
    """
    outputs = model.generate(prompts, time_domain=False, n_steps=n_steps).squeeze(0)
    tails = outputs[:, -model.receptive_field():].clone()
    references = get_wrts(model, indices, n_steps)
    with torch.no_grad():
        neighbs = nearest_neighbor(outputs[:, prompts.size(1):], references)[1]

    this_scores = np.ones((prompts.shape[0], len(frame_sizes)))
    for col, size in enumerate(frame_sizes):
        with torch.no_grad():
            mean_rate = repeat_rate(neighbs, size, hop_length).mean(dim=-1)
        this_scores[:, col] = mean_rate.squeeze().numpy()

    return this_scores, tails

### PLAY WITH THESE : ######

n_steps = 500
n_times = 10
# index of the frame_size to use for threshold :
frame_sizes_level = 2
# only consider indices with scores below :
candidates_threshold = .75
# how many of the lowest-scoring candidates to extend at each round :
n_best = 128

############################

for _ in tqdm(range(n_times)):

    # threshold: positions (rows of `scores`) still below the threshold
    candidates = np.flatnonzero(scores[:, frame_sizes_level] < candidates_threshold)

    # `candidates` holds positions, so test its size: np.any() would read a lone
    # candidate at position 0 as "no candidates"
    if candidates.size == 0:
        break

    # n bests (lowest scores in the first column)
    idx = candidates[np.argsort(scores[candidates, 0])[:n_best]]

    # score with the same frame sizes and hop length that filled `scores` initially;
    # gen_and_score's own defaults differ and would mix frame sizes within a column
    new_scores, new_outs = gen_and_score(model, last_outputs[idx], all_indices[idx], n_steps,
                                         frame_sizes=frame_sizes, hop_length=hop_length)
    # update scores: running average weighted by the number of frames generated so far
    weights = current_step[idx][:, None]
    scores[idx] = (scores[idx] * weights + new_scores * n_steps) / (weights + n_steps)
    current_step[idx] += n_steps
    last_outputs[idx] = new_outs

"200 best indices are ", scores[:, frame_sizes_level].argsort()[:200]

### Plot and listen

In [None]:
plt.figure(figsize=(32, 8))
for i, fs in enumerate(frame_sizes):
    plt.plot(all_indices[:], scores[:, i], label="frame size = "+str(fs))

plt.legend()
axes = plt.gca()
axes.set_ylim([-0.1, 1.1])
plt.ylabel('Mean Local Redundance Rate')
plt.xlabel('Prompt Index')
plt.title("Output's Scores")

##### PICK AN INDEX #####
index = 123
# number of frames to generate for listening
listen_steps = 5000
########################

# row `index` of scores / current_step belongs to the prompt starting at all_indices[index]
best_i = all_indices[index]
print("prompt_index", best_i, "with scores", scores[index], "generated", current_step[index], "steps")

# use the same prompt length as the one the scores were computed with
prompt = model.data[best_i:best_i + prompt_length]
out = model.generate(prompt, time_domain=True, n_steps=listen_steps).squeeze().numpy()

audio(out, hop_length=db.fft.attrs["hop_length"])