In [4]:
from mimikit.freqnet import FreqNet
from mimikit.data import Database
from mimikit.utils import audio, signal
from mimikit import NeptuneConnector
import torch
import numpy as np
from random import randint
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (20, 6)  # wide default figures for the time-series plots below

def cosine_similarity(X, Y, eps=1e-10):
    """
    Pairwise cosine similarity between every row of X and every row of Y.

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    returns : (*, N, M)

    Notes:
    ------
    torch.nn.CosineSimilarity only yields the matched pairs (the diagonal of
    the full matrix, as in cosine_sim(output, target)); this returns all N x M
    similarities. The product of norms is floored at ``eps`` so rows that are
    all zeros give a similarity of 0 instead of nan.
    """
    floor = eps if isinstance(eps, torch.Tensor) else torch.tensor(eps).to(X)

    x_norm = torch.norm(X, p=2, dim=-1).unsqueeze(-1)   # (*, N, 1)
    y_norm = torch.norm(Y, p=2, dim=-1).unsqueeze(-2)   # (*, 1, M)
    numerator = X @ Y.transpose(-2, -1)                  # (*, N, M)

    return numerator / torch.maximum(x_norm * y_norm, floor)

def angular_distance(X, Y, eps=1e-10):
    """
    Pairwise angular distance between the rows of X and the rows of Y.

    This is a proper distance metric derived from the cosine similarity, see
    https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
    It is acos(cos_sim) / pi, doubled when neither X nor Y has a negative
    entry (all angles are then at most pi/2), so the result lies in [0, 1].

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    returns : (*, N, M)
    """
    if not isinstance(eps, torch.Tensor):
        eps = torch.tensor(eps).to(X)

    # acos is undefined outside [-1, 1] and rounding can push cos_sim just past
    # those bounds; see https://github.com/pytorch/pytorch/issues/8069
    lower, upper = -1 + eps / 2, 1 - eps / 2

    all_non_negative = not (torch.any(X < 0) or torch.any(Y < 0))
    scale = 2 if all_non_negative else 1

    angle = torch.acos(torch.clamp(cosine_similarity(X, Y, eps), min=lower, max=upper))

    pi = torch.acos(torch.zeros(1)).item() * 2
    return scale * angle / pi


def nearest_neighbor(X, Y):
    """
    For every row of X, find the closest row of Y under ``angular_distance``.

    Shapes:
    -------
    X : (*, N, D)
    Y : (*, M, D)
    returns : (dists, nn), both (*, N); ``dists`` holds the smallest distance
              and ``nn`` the index (along M) of the row of Y achieving it
    """
    closest = torch.min(angular_distance(X, Y), dim=-1)
    return closest.values, closest.indices


def torch_frame(x, frame_size, hop_length):
    """
    Slide a window of ``frame_size`` values with step ``hop_length`` over the
    last dimension of ``x``.

    Only complete frames are returned: there are
    ``1 + (N - frame_size) // hop_length`` of them, with ``N = x.size(-1)``.

    Shapes:
    -------
    x : (*, N)
    returns : (*, n_frames, frame_size)

    Notes:
    ------
    The result is a view of ``x`` (no copy). ``Tensor.unfold`` follows the
    actual strides of ``x``, so it is correct for non-contiguous inputs
    (e.g. slices) and for 1-D inputs, where hand-written ``as_strided``
    strides of ``(N, hop_length, 1)`` are not.

    Raises:
    -------
    ValueError
        if ``frame_size`` or ``hop_length`` is smaller than 1, or if
        ``frame_size`` exceeds ``N`` (not even one complete frame fits).
    """
    N = x.size(-1)
    if frame_size < 1 or hop_length < 1:
        raise ValueError("frame_size and hop_length must be >= 1, got %s and %s"
                         % (frame_size, hop_length))
    if frame_size > N:
        raise ValueError("frame_size (%s) is larger than the last dimension of x (%s)"
                         % (frame_size, N))
    return x.unfold(-1, frame_size, hop_length)


def repeat_rate(x, frame_size, hop_length):
    """
    Frame ``x`` along its last dimension and compute a repeat rate per frame.

    For a frame of ``frame_size`` values containing ``u`` distinct values the
    repeat rate is ``1 - (u - 1) / (frame_size - 1)``. It is 1.0 when every
    value in the frame is identical and 0.0 when all values are distinct.

    Parameters
    ----------
    x : torch.Tensor, shape (*, N)
        sequence(s) of discrete values (e.g. nearest-neighbour indices)
    frame_size : int
        number of values per frame; must be >= 2 (the rate is 0/0 for 1)
    hop_length : int
        step between the starts of consecutive frames

    Returns
    -------
    torch.Tensor, shape (*, n_frames)
        floating-point repeat rate of every frame

    Raises
    ------
    ValueError
        if ``frame_size < 2``
    """
    if frame_size < 2:
        raise ValueError("frame_size must be >= 2, got %s" % frame_size)
    framed = torch_frame(x, frame_size, hop_length)
    # count distinct values per frame without a Python loop: after sorting a
    # frame, every change between neighbouring entries starts a new value
    ordered = torch.sort(framed, dim=-1).values
    n_unique = 1 + (ordered[..., 1:] != ordered[..., :-1]).sum(dim=-1)
    return 1 - (n_unique - 1) / (frame_size - 1)


def nn_repeat_rate(X, wrt, frame_size, hop_length):
    """
    Repeat rate of the nearest-neighbour index sequence of X with respect to wrt.

    Each row of X is replaced by the index of its nearest row in ``wrt``
    (see ``nearest_neighbor``). That index sequence is then framed and
    scored with ``repeat_rate``, so a high value means that consecutive rows
    of X keep mapping onto the same few rows of ``wrt``.
    """
    neighbour_idx = nearest_neighbor(X, wrt)[1]
    return repeat_rate(neighbour_idx, frame_size, hop_length)

![Redundancy rate](imgs/redundance-rate.png)

## Set up the database & model

In [None]:
# local names of the downloaded artifacts
db_name = "genoel-mix.h5"
path_to_db = "./" + db_name  # NOTE(review): not used in the visible cells; download_database receives db_name
path_to_model = "./models"

# Neptune runs holding the database ("db") and the trained model ("model")
nep_con = NeptuneConnector(
    user="k-tonal",
    setup={
        "db": "data-and-base-notebooks/DAT-27",
        "model": "experiment-2/EX2-13",
    },
)

# fetch only the "states/" artifacts of the model run (the checkpoints loaded below)
nep_con.download_experiment("model", destination=path_to_model, artifacts="states/")

db = nep_con.download_database("db", db_name)

db.metadata

In [None]:
epoch = 99

# Build the path with explicit separators; plain concatenation of "./models"
# and the run id would give "./modelsEX2-13/...".
# Assumes download_experiment writes the run under <destination>/<run id>/ -- TODO confirm.
run_id = nep_con.setup["model"].split("/")[-1]
path_to_ckpt = "%s/%s/states/epoch=%i.ckpt" % (path_to_model.rstrip("/"), run_id, epoch)
model = FreqNet.load_from_checkpoint(path_to_ckpt, data_object=db.fft)

## Generate a single output

In [None]:
prompt_length = 64
n_steps = 2048
# Random start such that the prompt AND the n_steps frames that follow it fit in the data.
# Those following frames serve as the reference `wrt` for the nearest-neighbour search.
# Note that randint's upper bound is inclusive.
i = randint(0, model.data.shape[0] - prompt_length - n_steps)

output = model.generate(model.data[i:i+prompt_length], time_domain=False, n_steps=n_steps).squeeze(0)
# reference continuation: the data frames right after the prompt, on output's device/dtype
wrt = torch.from_numpy(model.data[i+prompt_length:i+prompt_length+n_steps]).to(output).unsqueeze(0)

# .cpu() is a no-op on CPU and required before .numpy() if the model runs on GPU
audio(output.squeeze().cpu().numpy().T, hop_length=db.fft.attrs["hop_length"])

## Compute the redundancy rate (RR) over time at multiple levels of locality for a single output

In [None]:
# compute nearest neighbors:
# Drop the prompt frames along the time axis. Assumes time is the second-to-last axis
# of `output` (the `.T` passed to `audio` suggests (time, freq)) -- TODO confirm.
# `[..., t:, :]` is correct whether or not a batch axis is present, whereas
# `[:, t:]` would slice the feature axis of a 2-D output.

with torch.no_grad():
    _, neighbs = nearest_neighbor(output[..., prompt_length:, :], wrt)


# multiple levels of locality :

frame_size = (2, 8, 32)
hop_length = (1, 1, 1)


# compute rr and plot (one curve per (frame size, hop) pair)

for fs, hop in zip(frame_size, hop_length):
    with torch.no_grad():
        r = repeat_rate(neighbs, fs, hop)
    plt.plot(r.squeeze().cpu().numpy(), label="frame_size="+str(fs))

plt.legend()
plt.xlabel('Time')
plt.ylabel('Redundance Rate')
plt.title('Local Redundance Rate');

In [None]:
# number of prompts we will score (an upper bound, see below) :

n_prompts = 500

# params for each prompt :

prompt_length = 64
n_steps = 300

# Every prompt needs prompt_length frames plus an n_steps-long reference continuation,
# so this is the last start index whose slices are both complete.
last_start = db.fft.shape[0] - prompt_length - n_steps
assert last_start >= 0, "db.fft is shorter than prompt_length + n_steps"

# Evenly spaced starts, at most n_prompts of them (step rounded up, never 0).
# Every slice then has full length, so np.stack receives equally-shaped arrays.
step = max(1, int(np.ceil((last_start + 1) / n_prompts)))
indices = range(0, last_start + 1, step)
prompts = torch.from_numpy(np.stack([db.fft[i:i+prompt_length] for i in indices]))
wrts = torch.from_numpy(np.stack([db.fft[i+prompt_length:i+prompt_length+n_steps] for i in indices]))

# NOTE(review): all prompts are generated in a single batch -- chunk if memory is tight
outputs = model.generate(prompts, time_domain=False, n_steps=n_steps).squeeze(0)


In [None]:
# compute nearest neighbors:
# Drop the prompt frames along the time axis (second-to-last, as in the single-output cell).
# Put the references on outputs' device/dtype, like `wrt` in the single-output cell.

with torch.no_grad():
    _, neighbs = nearest_neighbor(outputs[..., prompt_length:, :], wrts.to(outputs))


# multiple levels of locality :

frame_size = (2, 8, 32)
hop_length = (1, 1, 1)


# compute rr, average it over time (one score per prompt) and plot

for fs, hop in zip(frame_size, hop_length):
    with torch.no_grad():
        r = repeat_rate(neighbs, fs, hop).mean(dim=-1)
    plt.plot(list(indices), r.squeeze().cpu().numpy(), label="frame_size="+str(fs))

plt.legend()
plt.xlabel('Prompt Index')
plt.ylabel('Mean Local Redundance Rate')
plt.title("Output's Scores");