In [1]:
%cd ../../..


/Users/bezha/PycharmProjects/TripleStreams


In [2]:
# Imports and process-wide runtime configuration for the notebook.
# NOTE(review): np, plt, tqdm and DataLoader are not used in the visible cells — confirm
# they are needed elsewhere before removing them.
import numpy as np
import yaml
from matplotlib import pyplot as plt
import logging
# Only WARNING and above from matplotlib's logger are emitted (hides its DEBUG/INFO chatter).
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
# Remove MPLDEBUG from this process's environment if present (no error when it is absent).
# NOTE(review): the reason is not visible here, and this runs after matplotlib has already
# been imported — confirm it still has the intended effect.
os.environ.pop("MPLDEBUG", None)
import tqdm
import torch
from torch.utils.data import DataLoader
# Project-local packages; presumably importable because the working directory was moved
# to the repository root by the first cell.
from model import FlexControlTripleStreamsVAE
from data import get_flexcontrol_triplestream_dataset



Could not import fluidsynth. AUDIO rendering will not work.
Holoviews not installed. Please install holoviews to be able to generate heatmaps.


## Load Dataset

### Note: The first run takes a while because the dataset has to be loaded; it is then cached, so later runs are much faster.

In [3]:
# Training/evaluation configuration of the run inspected in this notebook.
# Use a context manager so the config file is closed deterministically instead of
# leaking the handle returned by a bare open().
with open('eval/Post-Training/ControlConfiguration1/config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# When True, the loader is asked for a downsampled subset (downsampled_size=2000) for
# faster iteration; set to False to load the full validation split.
# NOTE(review): the exact meaning of downsampled_size is defined by the loader — the
# resulting dataset here has 1956 samples (see the Sample slider max), not exactly 2000.
is_testing = True

dataset = get_flexcontrol_triplestream_dataset(
    config=config,
    subset_tag="validation",
    use_cached=True,
    downsampled_size=2000 if is_testing else None,
    print_logs=False,  # set to True to print dataset loading logs
)

## Load Model

In [4]:
from model import load_model

# Restore the trained VAE for evaluation. Per the log printed below
# ("Using config from model file"), the model configuration is read from the checkpoint.
model = load_model(
    model_class=FlexControlTripleStreamsVAE,
    model_path='eval/Post-Training/ControlConfiguration1/step_274196.pth',
    is_evaluating=True,
)

✅ Using config from model file
[None] [None, None, None, None, None]
🎉 Successfully loaded FlexControlTripleStreamsVAE


## Run Inference

In [12]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from data.triple_streams.triple_stream_data_utils import create_multitab_from_HVO_Sequences, compile_into_list_of_hvo_seqs
from bokeh.io import save
from IPython.display import HTML, display
import torch
import os

# Maps the input groove and each of the three generated streams to a one-element list of
# (presumably MIDI) note numbers, counting up from 36.
# NOTE(review): not referenced in the rest of this cell — confirm it is still needed.
drum_mapping = {"Input Groove": [36]}
drum_mapping.update({f"Stream {idx}": [36 + idx] for idx in range(1, 4)})

@interact_manual(
    sample=widgets.IntSlider(min=0, max=len(dataset)-1, step=1, value=0, description='Sample'),
    param1=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='Structural Sim'),
    param2=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='Tot Hit'),
    param3=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='Step Den'),
    param4=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='s1 rel den'),
    param5=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='s2 rel den'),
    param6=widgets.FloatSlider(min=0, max=1, step=0.01, value=0, description='s3 rel den'),
    # `fixed` passes the value through unchanged, so no widget is created for it.
    html_path=fixed("generated_sample.html"),
)
def generate_function(sample, param1, param2, param3, param4, param5, param6,
                      html_path="generated_sample.html"):
    """Generate streams for one dataset sample, save the plot as an HTML file and display it.

    Args:
        sample: index into ``dataset`` of the input groove to condition on.
        param1: encoder control value (slider "Structural Sim").
        param2, param3: first two decoder control values (sliders "Tot Hit", "Step Den").
        param4, param5, param6: relative densities for streams 1-3; rescaled to sum to 1
            before being passed to the decoder (left unscaled when all three are 0).
        html_path: file the standalone Bokeh HTML document is written to; a relative
            path is resolved against the current working directory.

    Returns:
        None. The rendered plot is displayed in the interact output area.
    """
    from html import escape
    from bokeh.resources import CDN

    print(f"Generating with parameters: Sample={sample}, Controls=[{param1}, {param2}, {param3}, {param4}, {param5}, {param6}]")

    model.eval()
    with torch.no_grad():
        # Add a leading batch dimension of size 1 to the selected groove and controls.
        input_groove = dataset.input_grooves[sample].unsqueeze(0)
        encoding_control_tokens = torch.tensor([param1]).unsqueeze(0)
        # Normalise the three stream densities so they sum to 1; avoid dividing by 0
        # when all three sliders are at 0.
        sum_dens = param4 + param5 + param6
        sum_dens = 1.0 if sum_dens == 0 else sum_dens
        decoding_control_tokens = torch.tensor(
            [param2, param3, param4 / sum_dens, param5 / sum_dens, param6 / sum_dens]
        ).unsqueeze(0)

        hvo, _ = model.predict(
            flat_hvo_groove=input_groove,
            encoding_control_tokens=encoding_control_tokens,
            decoding_control_tokens=decoding_control_tokens,
        )

    hvo_sequence_list = compile_into_list_of_hvo_seqs(
        input_hvos=input_groove,
        output_hvos=hvo,
        metadatas=[dataset.metadata[sample]],
    )

    # NOTE(review): assumes create_multitab_from_HVO_Sequences returns a Bokeh layout
    # (e.g. Tabs) that bokeh.io.save can serialise — confirm.
    tabs = create_multitab_from_HVO_Sequences(
        hvos=hvo_sequence_list[0]
    )

    # Save a standalone HTML page; passing resources/title explicitly avoids Bokeh's
    # "no resources/title supplied" fallbacks. save() returns the written filename.
    saved_path = save(tabs, filename=html_path, resources=CDN, title=f"Sample {sample}")

    # Show the saved page inside an iframe via `srcdoc`: its <script> tags then load and
    # execute in document order, independent of how the notebook front end handles
    # scripts embedded in HTML outputs.
    with open(saved_path, encoding="utf-8") as page_file:
        page = page_file.read()
    display(HTML(
        f'<iframe srcdoc="{escape(page)}" style="width: 100%; height: 600px; border: none;"></iframe>'
    ))

interactive(children=(IntSlider(value=0, description='Sample', max=1955), FloatSlider(value=0.0, description='…