In [None]:
from IPython.display import display, HTML
# Widen the notebook's main container to 75% of the page via injected CSS.
display(HTML("<style>.container { width:75% !important; }</style>"))

import mimikit as mmk
import ipywidgets as W
import dataclasses as dtc
from typing import Optional, Callable, Any, Tuple, List, Dict
# NOTE(review): `display` is already imported on the first line of this cell; this re-import is redundant.
from IPython.display import display
# Short alias for mimikit's UI helpers (Param, FileWidget, pw2_widget, yesno_widget) used in the cells below.
UI = mmk.ui

# Default freqnet training configuration; the ConfigView cells below bind widgets
# to its `data`, `training` and `network` sections.
cfg = mmk.TrainFreqnetConfig()

In [None]:
def EnumWidget(
        description,
        options,
        container,
        value_type=str,
        selected_index=0
):
    """Turn ``container`` into a single-choice selector made of toggle buttons.

    An HTML ``description`` label is followed by one ``W.ToggleButton`` per entry
    of ``options``. Exactly one button is selected at any time. Selecting a button
    deselects the others. Clicking the already-selected button keeps it selected;
    otherwise the widget would show no selection while ``container.value`` still
    held the old choice.

    Parameters
    ----------
    description : str
        HTML shown before the buttons.
    options : sequence
        Button captions (used as ``ToggleButton`` descriptions); the selected
        caption is also the raw value of the choice.
    container : ipywidgets Box
        Box whose ``children`` are replaced. It also receives two plain attributes:
        ``value`` (the current caption converted with ``value_type``) and
        ``selected_index`` (the index of the current choice).
    value_type : callable, default ``str``
        Converter applied to the selected caption to build ``container.value``;
        with ``str`` the caption is stored as-is.
    selected_index : int, default 0
        Index of the initially selected option.

    Returns
    -------
    The same ``container``, so the call can be the last expression of a cell.
    """
    def convert(caption):
        # `str` captions are stored untouched; other types go through the converter.
        return caption if value_type is str else value_type(caption)

    label = W.HTML(value=description,
                   layout=W.Layout(
                       min_width="max-content", align_self="baseline", margin="0 4px 0 4px", padding="0"),
                   disabled=True)
    children = [W.ToggleButton(value=False, description=opt, layout=W.Layout(width="inherit")) for opt in options]
    container.children = (label, *children)
    container.value = convert(options[selected_index])
    container.selected_index = selected_index

    for i, child in enumerate(children):
        def observer(ev, c=child, index=i):
            if ev["new"]:
                # Record the new choice before deselecting the others, so their
                # observers see that they are no longer the selected index.
                container.selected_index = index
                container.value = convert(c.description)
                c.button_style = "success"
                for other in children:
                    if other.value and other is not c:
                        other.value = False
            else:
                c.button_style = ""
                if container.selected_index == index:
                    # The user un-toggled the current choice: re-select it so the
                    # widget never ends up with no selection.
                    c.value = True

        child.observe(observer, "value")
    # Triggers the observer above, which styles the button and sets the state.
    children[selected_index].value = True
    return container

# Demo: an integer-valued selector laid out in a row spanning 95% of the cell width.
EnumWidget(
    description="<b>test: </b>",
    options=["1", "2", "8", "12345", "8987987987"],
    container=W.HBox(layout=W.Layout(width="95%")),
    value_type=int,
)

In [None]:
def pw2_widget(description, initial_value=4):
    """Build a "power of two" stepper: a label, a minus button, a text field and a plus button.

    ``+`` doubles the integer in the text field and ``-`` halves it (floor
    division). Both results are clamped to at least 1. Without the clamp,
    halving 1 gives 0, and ``+`` could never leave 0 again.

    Parameters
    ----------
    description : str
        Text of the leading label.
    initial_value : int, default 4
        Starting value of the text field.

    Returns
    -------
    W.HBox
        The row of widgets. Its ``observe`` is rebound to the text field's
        ``observe``, so observers of ``"value"`` receive the field's text (a
        string); callers must convert it, e.g. with ``int()``.
    """
    label = W.Label(value=description)
    value = W.Text(value=str(initial_value), layout=W.Layout(width="75px"))
    plus = W.Button(icon="plus", layout=W.Layout(width="50px"))
    minus = W.Button(icon="minus", layout=W.Layout(width="50px"))

    def _set(number):
        # Never go below 1: 0 is a fixed point of doubling.
        value.value = str(max(1, number))

    # NOTE: int() raises ValueError if the user typed non-numeric text.
    plus.on_click(lambda clk: _set(int(value.value) * 2))
    minus.on_click(lambda clk: _set(int(value.value) // 2))
    box = W.HBox([label, minus, value, plus])
    # bind value state to box state
    box.observe = value.observe
    return box

In [None]:
def yesno_widget(initial_value, description, ):
    """Build a yes/no switch: a label followed by a "yes" and a "no" toggle button.

    Exactly one of the two buttons is selected at any time. Selecting one
    deselects the other. Clicking the already-selected one keeps it selected;
    otherwise clicking a selected "yes" would leave both buttons off while
    "yes" kept its green style.

    Parameters
    ----------
    initial_value : bool
        Whether "yes" starts selected (coerced with ``bool``).
    description : str
        Text of the leading label.

    Returns
    -------
    W.HBox
        The row of widgets. Its ``observe`` is rebound to the "yes" button's
        ``observe``, so observers of ``"value"`` see ``True`` for yes and
        ``False`` for no.
    """
    initial_value = bool(initial_value)
    yes = W.ToggleButton(
        value=initial_value,
        description="yes",
        button_style="success" if initial_value else "",
        layout=W.Layout(width="75px")
    )
    no = W.ToggleButton(
        value=not initial_value,
        description="no",
        button_style="" if initial_value else "danger",
        layout=W.Layout(width="75px")
    )
    desc = W.Label(value=description)

    def toggle_yes(ev):
        if ev["new"]:
            yes.button_style = "success"
            no.value = False
        else:
            yes.button_style = ""
            if not no.value:
                # "yes" was un-toggled while "no" is off: re-select it.
                yes.value = True

    def toggle_no(ev):
        if ev["new"]:
            no.button_style = "danger"
            yes.value = False
        else:
            no.button_style = ""
            if not yes.value:
                # "no" was un-toggled while "yes" is off: re-select it.
                no.value = True

    yes.observe(toggle_yes, "value")
    no.observe(toggle_no, "value")

    box = W.HBox([desc, yes, no])
    box.observe = yes.observe
    return box

In [None]:
# Layout shared by the option rows below: a flex row spread across 75% of the width.
options_style = {
    "display": "flex",
    "margin": "4px",
    "justify_content": "space-between",
    "align_content": "stretch",
    "flex_flow": "row",
    "width": "75%",
}

# Extra style passed to the sample-rate ToggleButtons (currently empty).
buttons_style = {}

# Data section: audio sources, sample rate, STFT parameters and the database path,
# shown as a single-pane "Data" accordion.
data_acc = mmk.ConfigView(
    cfg.data,
    UI.Param(name="sources",
             widget=UI.FileWidget().widget),
    UI.Param(name="sr",
             widget=W.ToggleButtons(description="Sample Rate: ", options=[16000, 22050, 32000, 44100], index=1,
                                    style=buttons_style,
                                    layout=options_style)),
    UI.Param(name="n_fft",
             widget=W.ToggleButtons(description="FFT Size: ", options=[512, 1024, 2048, 4096], index=2,
                                    layout=options_style)),
    # The widget picks an overlap factor; `compute` maps it to n_fft // factor.
    # NOTE(review): reads conf.n_fft when compute runs — confirm ConfigView
    # re-evaluates hop_length when "FFT Size" changes afterwards.
    UI.Param(name="hop_length",
             widget=W.ToggleButtons(description="Overlap: ", options=[1, 2, 4, 8], index=2, layout=options_style),
             compute=lambda conf, ov: conf.n_fft // ov),
    # Only the first three letters are kept ("magnitude" -> "mag", "polar" -> "pol").
    UI.Param(name="coordinate",
             widget=W.ToggleButtons(description="Coordinates: ", options=["magnitude", "polar"], index=0,
                                    layout=options_style),
             compute=lambda conf, co: co[:3]),
    UI.Param(name="db_path",
             widget=W.Text(value="train.h5", description="Db Path: ", layout=options_style)),
).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
            titles=("Data",), selected_index=0)
display(data_acc)

In [None]:
# Training section: batch size, training-monitor switch and the two `betas` coefficients,
# shown as a single-pane "Training" accordion.
train_acc = mmk.ConfigView(
    cfg.training,
    UI.Param(name="batch_size",
             widget=UI.pw2_widget("Batch Size: "),
             compute=lambda conf, v: int(v)),
    UI.Param(name="MONITOR_TRAINING",
             widget=UI.yesno_widget(True, "Monitor Training: ")),
    # Both sliders write the same `betas` pair: each replaces its own component
    # and keeps the other one from the current config.
    # FloatLogSlider min/max are exponents: the range is 2**-.75 (~0.59) to 2**0 = 1.
    # NOTE(review): both sliders start at .9 regardless of cfg.training.betas — confirm defaults match.
    UI.Param(name="betas",
             widget=W.FloatLogSlider(description="Beta 1", value=.9, min=-.75, max=0., step=.001, base=2),
             compute=lambda conf, ev: (ev, conf.betas[1])),
    UI.Param(name="betas",
             widget=W.FloatLogSlider(description="Beta 2", value=.9, min=-.75, max=0., step=.001, base=2),
             compute=lambda conf, ev: (conf.betas[0], ev)),
).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
            titles=("Training",), selected_index=0)

display(train_acc)

In [None]:
# Network section: a nested view over `cfg.network.core` (blocks layout and layer width),
# shown as a single-pane "Network" accordion.
net_acc = mmk.ConfigView(
    cfg.network,
    UI.Param(name="core",
             widget=mmk.ConfigView(
                 cfg.network.core,
                 # Edited as comma-separated integers, e.g. "4, 4, 4". Blank or
                 # whitespace-only entries (such as a trailing comma) are ignored.
                 UI.Param(name="blocks",
                          widget=W.Text(value=", ".join(map(str, cfg.network.core.blocks)),
                                        description="N layers per block", layout=options_style,
                                        style={'description_width': 'auto'}),
                          compute=lambda c, v: tuple(int(s) for s in v.split(",") if s.strip())),
                 UI.Param(name="dims_dilated",
                          widget=UI.pw2_widget("Number of units per layer: "),
                          compute=lambda c, v: int(v)),
             ).as_widget(W.VBox, )),
).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
            titles=("Network",), selected_index=0, layout=dict(width="100%"))

display(net_acc)

In [None]:



# Top-level view over the whole config, assembling the three section accordions
# built in the cells above (the same way the Network cell passes a nested view
# as the widget of its `core` Param).
# NOTE(review): confirm ConfigView accepts these accordion widgets for nested
# sections, and that it exposes `.widgets` as used in the display call below.
train_view = mmk.ConfigView(
    cfg,
    # ---------------------------
    UI.Param(
        name="data",
        widget=data_acc),
    # ---------------------------
    UI.Param(
        name="network",
        widget=net_acc),
    # ---------------------------
    UI.Param(
        name="training",
        widget=train_acc),
    # TODO: layout
    # TODO: config validation
)

display(*train_view.widgets)

In [None]:
# Scratch inspection: the instance attributes of a freshly built log slider.
vars(W.FloatLogSlider(value=.9, min=-.75, max=0., step=.001, base=2))

In [None]:
# Inspect the network-core section of the config (as edited through the Network widgets above).
cfg.network.core

In [None]:
"""
PRBLM: Python-config  <<=OMEGACONF=>>  YAML
1. use only primitive types in configs
    - no Features, not Modules, no functions...
2. json
"""

from omegaconf import OmegaConf

# Round-trip `cfg` through YAML (dataclass -> structured DictConfig -> YAML text
# -> DictConfig), merge the result back onto `cfg`, then convert it to a plain
# Python object so it can be compared with the original.
cg = OmegaConf.to_object(
    OmegaConf.merge(
        cfg,
        OmegaConf.create(
            OmegaConf.to_yaml(OmegaConf.structured(cfg))
        ),
    )
)
cfg, cg

In [None]:
from pydantic import BaseModel
# NOTE(review): `dataclasses` is already imported as `dtc` in the first cell.
import dataclasses as dtc

# Exploratory: combine the dataclass config with pydantic's BaseModel
# (presumably to get pydantic validation of the config fields).
# NOTE(review): mixing a plain dataclass base with BaseModel may not behave as
# intended (pydantic may not treat the dataclass fields as model fields) —
# confirm this cell actually runs; `pydantic.dataclasses` may be the intended tool.
class Model(mmk.TrainFreqnetConfig, BaseModel):
    pass

# Build the model from the current config's fields; dtc.asdict recursively converts nested dataclasses to dicts.
Model(**dtc.asdict(cfg))

In [None]:
# Launch the freqnet demo's `main` with the configuration edited above.
# NOTE(review): presumably starts training — may be long-running.
mmk.demos.freqnet.main(cfg)

In [None]:
# Inspect the attributes of the round-tripped network-core config (`cg` comes from the OmegaConf cell).
vars(cg.network.core)