In [23]:
import sys
sys.path.append("..")

import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as ipw
import torch
import torch.nn.functional as F

from lib.dataset_wrapper import Dataset
from lib.nn.jerk_loss import compute_jerk_loss
from lib.notebooks import show_ema

In [24]:
# Dataset names offered in the dataset dropdown built further down.
DATASETS_NAME = ["pb2007", "fsew0", "msak0"]
# Data modality requested via Dataset.get_items_data; "ema" presumably means
# electromagnetic-articulography trajectories — confirm against lib.dataset_wrapper.
ART_TYPE = "ema"

In [36]:
def compute_jerk_loss(art_seqs, seqs_mask=None):
    """Mean squared jerk (third finite difference) of trajectories along dim=-2.

    Reference implementation: differentiates with three successive first-order
    ``torch.diff`` calls.

    Args:
        art_seqs: tensor whose second-to-last axis is time, e.g. ``(seq, dim)``
            or ``(batch, seq, dim)`` (assumed layout — TODO confirm with callers).
        seqs_mask: optional boolean time mask. Either batched ``(batch, seq)``
            (time on axis 1, as before) or unbatched ``(seq,)``. The first 3
            time steps are dropped so it lines up with the jerk tensor, which is
            3 steps shorter; a jerk sample is kept iff the mask is True at its
            last frame (assumes right-padded sequences — confirm).

    Returns:
        0-dim tensor: mean of the squared jerk over the selected elements
        (NaN if nothing is selected, e.g. sequences shorter than 4 steps).
    """
    speed = torch.diff(art_seqs, dim=-2)
    acc = torch.diff(speed, dim=-2)
    jerk = torch.diff(acc, dim=-2)
    if seqs_mask is not None:
        # An unbatched (seq,) mask has no axis 1, so `[:, 3:]` would raise.
        time_mask = seqs_mask[3:] if seqs_mask.ndim == 1 else seqs_mask[:, 3:]
        jerk = jerk[time_mask]
    jerk_loss = (jerk ** 2).mean()
    return jerk_loss

def compute_jerk_loss_2(art_seqs, seqs_mask=None):
    """Mean squared jerk along dim=-2, using a single ``torch.diff(n=3)`` call.

    Args:
        art_seqs: tensor whose second-to-last axis is time, e.g. ``(seq, dim)``
            or ``(batch, seq, dim)`` (assumed layout — TODO confirm with callers).
        seqs_mask: optional boolean time mask, batched ``(batch, seq)`` (time on
            axis 1, as before) or unbatched ``(seq,)``. Its first 3 steps are
            dropped to match the 3-step-shorter jerk tensor.

    Returns:
        0-dim tensor: mean of the squared jerk over the selected elements
        (NaN if nothing is selected).
    """
    jerk = torch.diff(art_seqs, n=3, dim=-2)
    if seqs_mask is not None:
        # An unbatched (seq,) mask has no axis 1, so `[:, 3:]` would raise.
        time_mask = seqs_mask[3:] if seqs_mask.ndim == 1 else seqs_mask[:, 3:]
        jerk = jerk[time_mask]
    jerk_loss = (jerk ** 2).mean()
    return jerk_loss

In [37]:
def show_dataset(dataset_name):
    """Display an item picker that prints the jerk loss of the chosen item.

    Loads ``dataset_name`` through the project ``Dataset`` wrapper, fetches its
    ``ART_TYPE`` data with silences cut, then for the selected item prints the
    result of ``compute_jerk_loss`` followed by ``compute_jerk_loss_2``.
    """
    dataset = Dataset(dataset_name)
    item_names = dataset.get_items_list()
    art_by_item = dataset.get_items_data(ART_TYPE, cut_silences=True)

    def print_item_jerk_losses(item_name):
        art_tensor = torch.FloatTensor(art_by_item[item_name])
        # Same input through both implementations, in this order.
        for loss_fn in (compute_jerk_loss, compute_jerk_loss_2):
            print(loss_fn(art_tensor))

    display(ipw.interactive(print_item_jerk_losses, item_name=item_names))

# Outer widget: choosing a dataset calls show_dataset, which displays its own item picker.
display(ipw.interactive(show_dataset, dataset_name=DATASETS_NAME))

interactive(children=(Dropdown(description='dataset_name', options=('pb2007', 'fsew0', 'msak0'), value='pb2007…

In [33]:
def test_relu(ceil, slope):
    """Plot ``slope * relu(x - ceil)`` against the identity line on [-10, 10].

    Args:
        ceil: x position where the ReLU starts rising (the kink).
        slope: gain applied to the ReLU output.
    """
    fig, ax = plt.subplots(subplot_kw={"aspect": "equal"})
    # Dense float grid instead of 20 integer samples: with a fractional `ceil`
    # the integer grid drew the kink at the wrong x. Also includes x = 10.
    x = torch.linspace(-10, 10, 401)
    y = F.relu(x - ceil) * slope
    ax.plot(x.numpy(), x.numpy(), label="identity")
    ax.plot(x.numpy(), y.numpy(), label="slope * relu(x - ceil)")
    ax.set(xlabel="x", ylabel="y", title="Shifted and scaled ReLU")
    ax.legend()
    plt.show()

# Slider ranges for test_relu; a float bound in each tuple makes ipywidgets build FloatSliders.
ipw.interactive(test_relu, ceil=(0, 2.), slope=(0., 2.))

interactive(children=(FloatSlider(value=1.0, description='ceil', max=2.0), FloatSlider(value=1.0, description=…