In [1]:
from __future__ import print_function
import torch
from GIM_encoder import GIM_Encoder
from data import get_dataloader
from decoder_architectures import SimpleV2Decoder
from helper_functions import create_log_dir, fft_magnitude, plot_two_graphs_side_by_side
from options_interpolate import get_options
from IPython.display import Audio
import numpy as np
import soundfile as sf

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import matplotlib.pyplot as plt
import random

In [2]:
def load_decoder(options, decoder_model_path, device):
    decoder = options['decoder'].to(device)
    decoder.load_state_dict(torch.load(
        decoder_model_path, map_location=device))
    return decoder


def set_dimension_equal_to(batch, b_idx, dim_idx, val):
    # batch is of shape: 171, 32, 2
    # so always change the last two points

    b, c, l = batch.shape
    batch[b_idx, dim_idx, 0] = val

def invent_latent_rnd(device):
    batch_enc = torch.rand((171, 32, 2)).to(device)

    return batch_enc


def invent_latent(device):
    # batch_enc = torch.randn((171, 32, 2)).to(device)
    batch_enc = torch.zeros((171, 32, 2)).to(device)

    # first sample: val(dim 0) := 1
    nb_dims = 32
    for dim in range(nb_dims):
        set_dimension_equal_to(batch_enc, b_idx=dim, dim_idx=dim, val=0.5)

    return batch_enc

In [3]:
def setup():
  OPTIONS = get_options()
  DEVICE = OPTIONS["device"]

  CPC_MODEL_PATH = OPTIONS["cpc_model_path"]
  DECODER_MODEL_PATH = OPTIONS["decoder_model_path"]

  ENCODER = GIM_Encoder(OPTIONS, path=CPC_MODEL_PATH)
  ENCODER.encoder.eval()

  DECODER = load_decoder(OPTIONS, DECODER_MODEL_PATH, DEVICE)
  DECODER.eval()

  train_loader, _, test_loader, _ = get_dataloader.get_dataloader(
      OPTIONS, dataset="de_boer_sounds_reshuffledv2", split_and_pad=False, train_noise=False, shuffle=True)

  return OPTIONS, DEVICE, ENCODER, DECODER

OPTIONS, DEVICE, ENCODER, DECODER = setup()

Let's use 1 GPUs!
Loading De Boer Sounds dataset...


In [4]:
def latent_space(device, **kwargs):
    batch_enc = torch.zeros((171, 32, 2)).to(device)

    for dim_idx, dim_val in kwargs.items():
        #eg: dim_1_val
        dim_idx = int(dim_idx.split("_")[-2])
        set_dimension_equal_to(batch_enc, b_idx=0, dim_idx=dim_idx, val=dim_val)

    # set_dimension_equal_to(batch_enc, b_idx=0, dim_idx=0, val=dim1_val)

    return batch_enc


def plot_latents(**kwargs):
    batch_enc_audio = latent_space(DEVICE, **kwargs)
    batch_outp = DECODER(batch_enc_audio)

    target_dir = "invented_audios"
    sr = 16000
    create_log_dir(target_dir)
    for idx, outp in enumerate(batch_outp):
        sequence = outp[0].cpu().detach().numpy()
        fft_mag = fft_magnitude(sequence)
        fft_mag = fft_mag[:100]
        plot_two_graphs_side_by_side(sequence, fft_mag, fig_size=(
            10, 4), y_lims=[(-0.1, 0.1), (0, 5)])
        
        # play audio
        return Audio(sequence, rate=sr)

        break
    return


range = (-1.1, 1.1, 0.01)
interact(plot_latents,
         dim_0_val=range,
         dim_1_val=range,
         dim_2_val=range,
         dim_3_val=range,
         dim_4_val=range,
         dim_5_val=range,
         dim_6_val=range,
         dim_7_val=range,
         dim_8_val=range,
         dim_9_val=range,
         dim_10_val=range,
         dim_11_val=range,
         dim_12_val=range,
         dim_13_val=range,
         dim_14_val=range,
         dim_15_val=range,
         dim_16_val=range,
         dim_17_val=range,
         dim_18_val=range,
         dim_19_val=range,
         dim_20_val=range,
         dim_21_val=range,
         dim_22_val=range,
         dim_23_val=range,
         dim_24_val=range,
         dim_25_val=range,
         dim_26_val=range,
         dim_27_val=range,
         dim_28_val=range,
         dim_29_val=range,
         dim_30_val=range,
         dim_31_val=range)


interactive(children=(FloatSlider(value=0.0, description='dim_0_val', max=1.1, min=-1.1, step=0.01), FloatSlid…

<function __main__.plot_latents(**kwargs)>

In [5]:


def series(dots):
    plt.figure(figsize=(10, 10))
    a, b = [], []
    for i in range(dots):
        a.append(random.randint(1, 100))
        b.append(random.randint(1, 100))
    plt.scatter(a, b, c="red")
    return

interact(series, dots=(1, 100, 1))

# colr=["red", "orange", "brown"]

None

<Figure size 720x720 with 0 Axes>

interactive(children=(IntSlider(value=50, description='dots', min=1), Output()), _dom_classes=('widget-interac…