In [None]:
# export
from snkrfinder.imports import *
from snkrfinder.core import *
from snkrfinder.data import *
from snkrfinder.model import *
from snkrfinder.cvae import *
from snkrfinder.widgets import *

#from ipywidgets import widgets
#from ipywidgets import HBox,VBox,widgets,Button,Checkbox,Dropdown,Layout,Box,Output,Label,FileUpload
# from fastai.vision.widgets import *  # in imports
#from ipywidgets import Tab #fastai didn't include Tab
import seaborn as sns


In [None]:
#hide
from nbdev.showdoc import *

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#hide
print(f"fastai v{fastai.__version__}")

fastai v2.2.5


In [None]:
os.chdir(L_ROOT)

In [None]:
# this should go into a utils or cfg module
HOME = get_home()

### Sneaker Finder 2.0 aka fast-sneak

In [None]:
# ts = [display(img.to_thumb(200,200))]+[display(i.to_thumb(100,100)) for i in images]

# nnc = carousel(ts, width='1200px')

# out_pl = HBox[display(img.to_thumb(200,200)), carousel]

In [None]:
def pack_featurenets(model_list):
    "Return a dict mapping each name in `model_list` to a feature net built with `create_cnn_featurenet(name, to_cuda=False)`."
    featurenets = {}
    for arch_name in model_list:
        featurenets[arch_name] = create_cnn_featurenet(arch_name, to_cuda=False)
    return featurenets
    

MODELS = pack_featurenets(['mobilenet_v2','resnet18'])

In [None]:
# %matplotlib inline

# Turns an uploaded image (anything PILImage.create accepts) into a PILImage
# resized to IMG_SIZE by padding with the border mode.
load_pipe    = Pipeline([PILImage.create,
                         FeatsResize(size=IMG_SIZE, method='pad', pad_mode='border')] )

# Converts the loaded image to a float tensor normalized with the ImageNet
# statistics; cuda=False is passed to Normalize.from_stats.
prep_tf_pipe = Pipeline([ToTensor(),
                         IntToFloatTensor(),
                         Normalize.from_stats(*imagenet_stats,cuda=False)])




# #set defaults 
# neighs = knns['small']
# umap = reducers['small']
# DEFAULT GLOBALS to start
# These module-level names (im_sz, model, df, knns, reducers) are the shared
# state that the dashboard callbacks below read and rebind via `global`.
im_sz = 'small'
model = MODELS['mobilenet_v2']
# Feature table for the default model, read from a pickle under data/.
filename = f"zappos-50k-{model.name}-features_sort_3"
df = pd.read_pickle(f"data/{filename}.pkl")

num_neighs = 5
# Load (not save) the previously stored kNN objects; they are indexed by
# image-size key further down (knns[im_sz]).
# NOTE(review): pickles execute code on load -- only use trusted files here.
filename = f"data/{model.name}-knn{num_neighs}Xsize.pkl"
knns = load_pickle(filename)

# Load the previously stored UMAP reducers, likewise indexed by image-size key.
filename = f"data/{model.name}-umapXsize.pkl"
reducers = load_pickle(filename)   


In [None]:
def plot_umap(data,im_sz,mname):
    """Scatter a random subsample of `data` in UMAP space, coloured by "Category".

    `data` must provide the columns "umap-one", "umap-two" and "Category".
    The sampled fraction is read from the module-level `sld_sampfrac` slider.
    `im_sz` (a key of `IMG_SIZES`) and `mname` only appear in the title.
    Returns the Axes that was drawn on.
    """
    category_order = ['Sneakers', 'Shoes', 'Boots','Slippers']

    _, axes = plt.subplots()
    colours = sns.color_palette("hls", 4)
    subsample = data.sample(frac=sld_sampfrac.value)
    sns.scatterplot(data=subsample,
                    x="umap-one", y="umap-two",
                    hue="Category", hue_order=category_order,
                    palette=colours,
                    legend="full", alpha=0.3,
                    ax=axes)
    axes.set_aspect('equal', 'datalim')
    title = f'UMAP projection of {mname} embedded UT-Zappos data (sz={IMG_SIZES[im_sz]})'
    axes.set_title(title, fontsize=12)
    return axes

def on_click_find_similar(change):
    """Click handler for `btn_run` -- this is the 'go' signal.

    Re-runs the UMAP projection and the neighbour search for the current
    image size.  `change` is whatever ipywidgets passes to the click callback;
    it is not used.
    """
    # update_knn_reducer() itself finishes by calling find_similar(), so a
    # second find_similar() here would embed the image and query the kNN
    # index twice for every click.
    update_knn_reducer(im_sz)

def find_similar():
    """Show the nearest database neighbours of the most recently uploaded image.

    Reads the module-level `knns`, `model`, `im_sz` and `df`; embeds the last
    entry of `btn_upload.data` with `model`; queries the kNN object stored for
    the current image size; and renders the query thumbnail next to a carousel
    of the neighbours (each labelled with its distance) into `out_nn_imgs`.
    When nothing has been uploaded yet a hint is printed there instead.
    """
    # The dropdown observers also end up here, possibly before any upload;
    # indexing the empty upload list would raise IndexError in the callback.
    if not btn_upload.data:
        out_nn_imgs.clear_output()
        with out_nn_imgs:
            print("Upload an image first, then click 'Find similar sneaks!'")
        return

    neighs = knns[im_sz]

    # load the image and compute its feature vector
    im = btn_upload.data[-1]
    img = load_pipe(im)
    tensor_im = prep_tf_pipe(img)
    feats = get_convnet_feature(model, tensor_im)

    # find the neighbors
    distance, nn_index = neighs.kneighbors(feats.numpy(), return_distance=True)
    dist = distance.tolist()[0]
    # fix path to the database...
    # the returned indices are used as positional row numbers into df
    neighbors = df.iloc[nn_index.tolist()[0]].copy()

    images = [PILImage.create(D_ROOT/DBS['zappos']/f) for f in neighbors.path]

    # one tile per neighbour: the image with its distance underneath
    ts = [VBox([widget(nbr_im, max_width="292px"), Label(f"d={d:.03f}")])
          for nbr_im, d in zip(images, dist)]
    target_im = img.to_thumb(200,200)

    car_nn = carousel(ts, width='1200px')

    out_nn_imgs.clear_output()
    with out_nn_imgs:
        display(HBox([widget(target_im, max_width="500px"), car_nn]))

    #lbl_neighs.value = f'distances: {dist}


def update_knn_reducer(size):
    """Switch the active image size, redraw the UMAP plot and refresh the neighbours.

    Sets the global `im_sz` to `size`, projects the `features_<abbr>` column
    of `df` with the reducer stored for that size, plots the result into
    `out_umap`, and finally calls `find_similar()`.

    Note: everything is recomputed immediately on every call (the UMAP
    transform over the whole table is probably the slow part), not only when
    `btn_run` is clicked.
    """
    global im_sz
    im_sz = size

    umap = reducers[im_sz]

    features = f"features_{SIZE_ABBR[im_sz]}"
    data = df[['Category',features]].copy()

    # stack the per-row feature arrays into one 2-D matrix for the reducer
    db_feats = np.vstack(data[features].values)
    # this is probably the bottleneck...
    embedding = umap.transform(db_feats)
    data['umap-one'] = embedding[:,0]
    data['umap-two'] = embedding[:,1]

    out_umap.clear_output()
    with out_umap:
        plot_umap(data,size,model.name)
        # pyplot.show() accepts only the keyword `block`; the Axes must not
        # be passed to it.
        plt.show()

    find_similar()

def update_model(model_name,size):
    """Switch the active feature model and refresh the dashboard for `size`.

    Looks `model_name` up in `MODELS`, loads that model's kNN objects, UMAP
    reducers and feature table from `data/`, rebinds the globals `model`,
    `knns`, `reducers` and `df`, then calls `update_knn_reducer(size)`.

    Everything is loaded into locals first and the globals are swapped in one
    step, so a missing or unreadable file cannot leave `model` pointing at the
    new network while `knns`/`reducers`/`df` still belong to the old one.
    Raises KeyError if `model_name` is not a key of `MODELS`.
    """
    global model,knns,reducers,df
    new_model = MODELS[model_name]

    num_neighs = 5
    if model_name!=new_model.name :  print(f"dammit, '{model_name}'!='{new_model.name}'")

    # load the stored knns, umap reducers and feature table for this model
    # NOTE(review): pickles execute code on load -- only use trusted files here.
    new_knns = load_pickle(f"data/{new_model.name}-knn{num_neighs}Xsize.pkl")
    new_reducers = load_pickle(f"data/{new_model.name}-umapXsize.pkl")

    filename = f"zappos-50k-{new_model.name}-features_sort_3"
    new_df = pd.read_pickle(f"data/{filename}.pkl")

    # commit only after every load succeeded
    model, knns, reducers, df = new_model, new_knns, new_reducers, new_df

    update_knn_reducer(size)

#Events
def dd_im_size_eh(change):
    "Observer for the image-size dropdown: forwards the newly selected size to `update_knn_reducer`."
    selected_size = change.new
    update_knn_reducer(selected_size)
    
def dd_model_eh(change):
    "Observer for the model dropdown: switches to the newly selected model at the currently selected image size."
    selected_model = change.new
    current_size = dd_im_size.value
    update_model(selected_model, current_size)


# Widget definitions for the sneaker-finder dashboard.
# Call-to-action controls: upload a query image, then press the run button.
btn_run = Button(description='Find similar sneaks!',layout = Layout(width='25%', height='80px'))
btn_upload = FileUpload(layout = Layout(width='25%', height='80px'))

out_umap = Output() # receives the UMAP scatter plot drawn by update_knn_reducer()
# lbl_neighs = Label() # labels for neighbors
out_nn_imgs = Output() # receives the query thumbnail + neighbour carousel from find_similar()

# Selectors for the image size (keys of IMG_SIZES) and the feature model.
dd_im_size = Dropdown(options=IMG_SIZES.keys(),value='small',description='Image Size:' )                       
dd_model = Dropdown(options=['mobilenet_v2','resnet18'], 
                    value='mobilenet_v2',
                    disabled=False,
                    description='Model:')
                    #,layout = Layout(width='40%') ) #style=style,
    
# Fraction (0..1, despite the '%' in the label) of rows that plot_umap()
# samples for the scatter plot.  continuous_update=False: the value is only
# committed when the handle is released.
# NOTE(review): no observer is attached to this slider in this cell, so moving
# it does not redraw the plot by itself.
sld_sampfrac = widgets.FloatSlider(value=.5,
                min=0,
                max=1.0,
                step=0.05,
                description='sample %:',
                disabled=False,
                continuous_update=False,
                orientation='vertical',
                readout=True,
                readout_format='.2f',
)
              
#item_layout = widgets.Layout(margin='0 0 50px 0')
#input_widgets = widgets.HBox([dd_model, dd_im_size])

# Row holding the two dropdowns.
knn_select = HBox([dd_model, dd_im_size])


# Two tabs: neighbour images, and the UMAP plot next to its sampling slider.
tab = widgets.Tab(children=[out_nn_imgs,HBox([out_umap, sld_sampfrac]) ] )#,layout=item_layout)
tab.set_title(0, 'Dataset Exploration')
tab.set_title(1, 'UMAP Plot')

# Top row: caption, upload button, run button.
cta = HBox([widgets.Label('Find your sneaker!    '),
            btn_upload,
            btn_run])

# Status label shown as the last row of the dashboard.
console = Label()
dashboard =  VBox([ cta,
                    knn_select,
                    tab,
                  console])

# `console` is deliberately NOT re-created here: a second `console = Label()`
# would rebind the name to a widget that is not part of `dashboard`, so any
# later `console.value = ...` would never be visible.

# actions
btn_run.on_click(on_click_find_similar)
dd_im_size.observe(dd_im_size_eh, names='value')
dd_model.observe(dd_model_eh, names='value')
# dd_lat_dim.observe(dd_lat_dim_eh, names='value')
# dd_lat_dim.observe(dd_lat_dim_eh, names='value')



In [None]:
dashboard

VBox(children=(HBox(children=(Label(value='Find your sneaker!    '), FileUpload(value={}, description='Upload'…

In [None]:
# import ipywidgets as widgets
from ipywidgets import FloatSlider, interact
# from fastai2.vision.all import *
# from fastai2.vision.widgets import *
from IPython.display import clear_output, Javascript #,display
# warnings.filterwarnings("ignore",category=matplotlib.cbook.mplDeprecation)

# Shared `style=` argument for the ToggleButton/ToggleButtons widgets below.
style = {'description_width': 'initial'}

# ANSI escape sequences used to colour/style the text printed by the dashboards.
RED = '\033[31m'
BLUE = '\033[94m'
GREEN = '\033[92m'
BOLD   = '\033[1m'
ITALIC = '\033[3m'
RESET  = '\033[0m'   # switches all of the above attributes off again

In [None]:
def dashboard_info():
    """GUI for first accordion window: 'info' dashboard.

    Displays a 'System' button and an Output area.  Clicking the button
    prints the versions of fastai, nbdev, fastprogress, fastpages, python,
    torchvision and torch, plus CUDA availability and version, into that area.
    Packages that cannot be imported are reported as '<name> not found'.
    """
    import importlib
    import torchvision

    def _version_of(*module_names, missing):
        "Return `__version__` of the first importable module in `module_names`, else `missing`."
        for module_name in module_names:
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                continue
            return getattr(module, '__version__', 'unknown')
        return missing

    # This notebook runs on the `fastai` package (it prints fastai.__version__
    # near the top); `fastai2` is kept only as a fallback so that environments
    # still using that package name report a version too.
    fastver = _version_of('fastai', 'fastai2', missing='fastai not found')
    fastprog = _version_of('fastprogress', missing='fastprogress not found')
    fastp = _version_of('fastpages', missing='fastpages not found')
    nbd = _version_of('nbdev', missing='nbdev not found')

    print (BOLD +  RED + '>> fastGUI\n')
    button = widgets.Button(description='System', button_style='success')
    display(button)

    out = widgets.Output()
    display(out)

    def on_button_clicked_info(b):
        with out:
            clear_output()
            print(BOLD + BLUE + "fastai version: " + RESET + ITALIC + str(fastver))
            print(BOLD + BLUE + "nbdev version: " + RESET + ITALIC + str(nbd))
            print(BOLD + BLUE + "fastprogress version: " + RESET + ITALIC + str(fastprog))
            print(BOLD + BLUE + "fastpages version: " + RESET + ITALIC + str(fastp) + '\n')
            print(BOLD + BLUE + "python version: " + RESET + ITALIC + str(sys.version))
            print(BOLD + BLUE + "torchvision: " + RESET + ITALIC + str(torchvision.__version__))
            print(BOLD + BLUE + "torch version: " + RESET + ITALIC + str(torch.__version__))
            print(BOLD + BLUE + "\nCuda: " + RESET + ITALIC + str(torch.cuda.is_available()))
            print(BOLD + BLUE + "cuda version: " + RESET + ITALIC + str(torch.version.cuda))

    button.on_click(on_button_clicked_info)
    

In [None]:
def dashboard_dataset():
    """GUI for second accordion window: 'dataset' dashboard.

    Shows a row of dataset toggle buttons (stored on `dashboard_dataset.datas`,
    nothing selected initially), an 'Explore' button and an Output area.
    Clicking 'Explore' clears the area and runs `ds_choice()` followed by
    `show_db2()` inside it.
    """
    dataset_labels = ['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY']
    dashboard_dataset.datas = widgets.ToggleButtons(options=dataset_labels,
                                                    description='Choose',
                                                    value=None,
                                                    disabled=False,
                                                    button_style='info',
                                                    tooltips=[''],
                                                    style=style)
    display(dashboard_dataset.datas)

    explore_btn = widgets.Button(description='Explore', button_style='success')
    display(explore_btn)
    explore_out = widgets.Output()
    display(explore_out)

    def on_button_explore(b):
        # render the dataset overview into the Output area below the button
        with explore_out:
            clear_output()
            ds_choice()
            show_db2()

    explore_btn.on_click(on_button_explore)

In [None]:
#Helpers for dashboard two - dataset
def ds_choice():
    """Helper for dataset choices.

    Fetches (via `untar_data`) the dataset currently selected in
    `dashboard_dataset.datas` and stores the returned path on
    `ds_choice.source` for the other helpers to read.

    Raises:
        ValueError: if no dataset is selected yet or the selection is unknown.
            (Previously this case silently left `ds_choice.source` unset or
            stale, and the failure only surfaced later in another helper.)
    """
    # UI label -> attribute name on `URLs`.
    # NOTE(review): 'PETS' maps to URLs.DOGS exactly as before; the helpers
    # that consume `ds_choice.source` read a `train/` sub-folder, so confirm
    # the folder layout before pointing this at a different archive.
    url_attr_by_label = {
        'PETS': 'DOGS',
        'CIFAR': 'CIFAR',
        'IMAGENETTE_160': 'IMAGENETTE_160',
        'IMAGEWOOF_160': 'IMAGEWOOF_160',
        'MNIST_TINY': 'MNIST_TINY',
    }
    choice = dashboard_dataset.datas.value
    if choice not in url_attr_by_label:
        raise ValueError(f"Choose a dataset first (got {choice!r}); "
                         f"expected one of {list(url_attr_by_label)}")
    ds_choice.source = untar_data(getattr(URLs, url_attr_by_label[choice]))

def plt_classes():
    """Helper for plotting classes in folder.

    Resolves the selected dataset via `ds_choice()`, prints its name and the
    number of entries found under `<source>/train`, and draws a bar chart
    with the number of files per entry.  Also sets `Path.BASE_PATH` to the
    dataset root.
    """
    ds_choice()
    print(BOLD + BLUE + "Dataset: " + RESET + BOLD + RED + str(dashboard_dataset.datas.value))
    Path.BASE_PATH = ds_choice.source
    train_source = (ds_choice.source/'train/').ls().items
    print(BOLD + BLUE + "\n" + "No of classes: " + RESET + BOLD + RED + str(len(train_source)))

    num_l = []
    class_l = []
    for name in train_source:
        fol = (ds_choice.source/name).ls()
        # Label each bar with the folder's own name.  Splitting the full path
        # on 'train' picks the wrong piece whenever 'train' also occurs
        # earlier in the path (e.g. inside the home directory name).
        class_l.append(name.name)
        num_l.append(len(fol))

    y_pos = np.arange(len(train_source))

    plt.figure(figsize=(7,7))
    # Newer matplotlib releases renamed the bundled 'seaborn' style to
    # 'seaborn-v0_8'; use whichever name this installation provides.
    plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
    plt.bar(y_pos, num_l, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan'])
    plt.xticks(y_pos, class_l, rotation=90)
    plt.ylabel('Images')
    plt.title('Images per Class')
    plt.show()

def display_images():
    """Helper for displaying images from folder.

    For every entry under `<source>/train` (where `source` is
    `ds_choice.source`), prints its location and file count and shows its
    first five files (in sorted order) side by side in one figure row.
    """
    columns = 5
    rows = 1
    train_source = (ds_choice.source/'train/').ls().items
    for name in train_source:
        fol = (ds_choice.source/name).ls().sorted()
        fol_disp = fol[0:columns*rows]
        filename = fol_disp.items
        fol_tensor = [tensor(Image.open(o)) for o in fol_disp]
        print(BOLD + BLUE + "Loc: " + RESET + str(name) + " " + BOLD + BLUE + "Number of Images: " + RESET +
              BOLD + RED + str(len(fol)))

        fig = plt.figure(figsize=(15,15))
        # One subplot per image.  The previous nested loop re-added the same
        # five subplots columns*rows times (and shadowed its own loop index),
        # and indexing fol_tensor[0] up front failed on an empty folder.
        for idx, img in enumerate(fol_tensor):
            ax = fig.add_subplot(rows, columns, idx+1)
            ax.set_title("ax:"+str(filename[idx]))  # set title
            ax.tick_params(bottom="on", left="on")
            ax.set_xticks([])
            ax.imshow(img)
        plt.show()
def browse_images():
    """Slider-driven browser over the image files under `<source>/train`.

    Resolves the dataset via `ds_choice()`, then shows an integer slider; the
    image at the selected index is plotted and its path is remembered on
    `browse_images.img` for the augmentation tab.  If no image files are
    found a message is printed and no slider is created.
    """
    print(BOLD + BLUE + "Use slider to choose image" + RESET)
    ds_choice()
    items = get_image_files(ds_choice.source/'train/')
    n = len(items)
    if n == 0:
        # interact(i=(0, -1)) would ask for a slider with max < min
        print("No images found under " + str(ds_choice.source/'train/'))
        return

    def view_image(i):
        plt.imshow(Image.open(items[i]), cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training: %s' % items[i])
        browse_images.img = items[i]
        plt.show()
    interact(view_image, i=(0,n-1))

def show_db2():
    """Lay out the dataset overview.

    Renders the class bar chart, the per-class sample images and the image
    browser into three Output widgets and displays them as: left column
    (40% wide) with chart above browser, sample images to the right.
    """
    o_classes, o_disp_imgs, o_browse_imgs = (widgets.Output() for _ in range(3))
    renderers = ((o_classes, plt_classes),
                 (o_disp_imgs, display_images),
                 (o_browse_imgs, browse_images))
    for target, render in renderers:
        with target:
            render()
    left_column = VBox([o_classes, o_browse_imgs],layout = Layout(width='40%'))
    display(HBox([left_column, o_disp_imgs]))

In [None]:
# these aren't currently integrated... they just hold the "params"



# orphan. 
# aug_params() -> aug_dash_choice()

def aug_dash_choice():
    """Augmention parameter display helper.

    Shows, for each augmentation, its toggle (from `aug_params`) above a row
    with its parameter sliders (from `aug`), followed by a 'View' button that
    prints the currently browsed image path (`browse_images.img`).
    The toggles on `aug_params` must already exist.
    """
    # The sliders are attributes that only exist once aug() has run; build
    # them here so this function does not fail with AttributeError on
    # `aug.b_max` when nothing else has called aug() yet.
    aug()

    button_aug_dash = widgets.Button(description='View', button_style='success')

    # (toggle, sliders) per augmentation, in display order
    groups = [
        (aug_params.hh, [aug.b_max, aug.b_pval, aug.b_asp, aug.b_len, aug.b_ht]),  # Erase
        (aug_params.cc, [aug.b1_max, aug.b1_pval, aug.b1_draw]),                   # Contrast
        (aug_params.dd, [aug.b2_max, aug.b2_pval]),                                # Rotate
        (aug_params.ee, [aug.b3_mag, aug.b3_pval]),                                # Warp
        (aug_params.ff, [aug.b4_max, aug.b4_pval]),                                # Bright
        (aug_params.gg, [aug.b5_pval, aug.b5_draw]),                               # DihedralFlip
        (aug_params.ii, [aug.b6_zoom, aug.b6_pval]),                               # Zoom
    ]
    items = [widgets.VBox([toggle, widgets.HBox(sliders)]) for toggle, sliders in groups]

    dia = Box(items, layout=Layout(
                    display='flex',
                    flex_flow='column',
                    flex_grow=0,
                    flex_wrap='wrap',
                    border='solid 1px',
                    align_items='flex-start',
                    align_content='flex-start',
                    justify_content='space-between',
                    width='flex'
                    ))
    display(dia)
    display(button_aug_dash)
    aug_dash_out = widgets.Output()
    display(aug_dash_out)
    def on_button_aug_dash(b):
        with aug_dash_out:
            clear_output()
            print(browse_images.img)
    button_aug_dash.on_click(on_button_aug_dash)

    
    
# aug_params() -> aug_dash_choice()
def aug_params():
    """If augmentations is choosen show available parameters.

    Creates one toggle per augmentation (stored as attributes `hh`, `cc`,
    `dd`, `ee`, `ff`, `gg`, `ii` on this function), displays them in a row
    with a 'Confirm' button, and on confirm renders `aug_dash_choice()` into
    an Output area below.
    """
    print(BOLD + BLUE + "Choose Augmentation Parameters: ")
    confirm_btn = widgets.Button(description='Confirm', button_style='success')

    # attribute name on aug_params -> button caption
    toggle_specs = (('hh', 'Erase'), ('cc', 'Contrast'), ('dd', 'Rotate'), ('ee', 'Warp'),
                    ('ff', 'Bright'), ('gg', 'DihedralFlip'), ('ii', 'Zoom'))
    toggles = []
    for attr_name, caption in toggle_specs:
        toggle = widgets.ToggleButton(value=False, description=caption,
                                      button_style='info', style=style)
        setattr(aug_params, attr_name, toggle)
        toggles.append(toggle)

    display(widgets.HBox(toggles))
    display(confirm_btn)
    o_aug_params = widgets.Output()
    display(o_aug_params)

    def on_button_params_click(b):
        with o_aug_params:
            clear_output()
            aug_dash_choice()

    confirm_btn.on_click(on_button_params_click)

    

def aug():
    """Aug choice helper.

    (Re-)creates the parameter sliders of every augmentation and stores them
    as attributes on this function (`aug.b_max`, `aug.b1_pval`, ...).  A
    group's sliders are interactive when the matching `aug_params` toggle is
    on and disabled otherwise.  Also sets `aug.erase_code` to a status string.
    The toggles on `aug_params` must exist before this is called.
    """
    def _slider(enabled, description, **limits):
        "Horizontal FloatSlider; `limits` carries min/max/step/value exactly as given."
        return FloatSlider(description=description, orientation='horizontal',
                           disabled=not enabled, **limits)

    # NOTE(review): several disabled variants use different ranges/labels than
    # their enabled counterparts (e.g. Warp 'p' 0..10, Zoom 'p' step 1).  They
    # are kept exactly as they were; confirm whether that is intended.

    #Erase
    if aug_params.hh.value:
        aug.b_max = _slider(True, 'max count', min=0, max=50, step=1, value=0)
        aug.b_pval = _slider(True, r"$p$", min=0, max=1, step=0.1, value=0)
        aug.b_asp = _slider(True, r'$aspect$', min=0.1, max=5, step=0.1, value=0.3)
        aug.b_len = _slider(True, r'$sl$', min=0.1, max=5, step=0.1, value=0.3)
        aug.b_ht = _slider(True, r'$sh$', min=0.1, max=5, step=0.1, value=0.3)
        aug.erase_code = 'this is ERASE on'
    else:
        aug.b_max = _slider(False, 'max count', min=0, max=10, step=1, value=0)
        aug.b_pval = _slider(False, 'p', min=0, max=1, step=0.1, value=0)
        aug.b_asp = _slider(False, 'aspect', min=0.1, max=1.7, value=0.3)
        aug.b_len = _slider(False, 'length', min=0.1, max=1.7, value=0.3)
        aug.b_ht = _slider(False, 'height', min=0.1, max=1.7, value=0.3)
        aug.erase_code = 'this is ERASE OFF'

    #Contrast
    if aug_params.cc.value:
        aug.b1_max = _slider(True, 'max light', min=0, max=0.9, step=0.1, value=0.2)
        aug.b1_pval = _slider(True, 'p', min=0, max=1.0, step=0.05, value=0.75)
        aug.b1_draw = _slider(True, 'draw', min=0, max=100, step=1, value=1)
    else:
        aug.b1_max = _slider(False, 'max light', min=0, max=0.9, step=0.1, value=.1)
        aug.b1_pval = _slider(False, 'p', min=0, max=1.0, step=0.05, value=0.75)
        aug.b1_draw = _slider(False, 'draw', min=0, max=100, step=1, value=1)

    #Rotate
    rotate_on = bool(aug_params.dd.value)
    aug.b2_max = _slider(rotate_on, 'max degree', min=0, max=10, step=1, value=10)
    aug.b2_pval = _slider(rotate_on, 'p', min=0, max=1, step=0.1, value=0.9)

    #Warp
    if aug_params.ee.value:
        aug.b3_mag = _slider(True, 'magnitude', min=0, max=10, step=1, value=1)
        aug.b3_pval = _slider(True, 'p', min=0, max=1, step=0.1, value=0)
    else:
        aug.b3_mag = _slider(False, 'magnitude', min=0, max=10, step=1, value=0)
        aug.b3_pval = _slider(False, 'p', min=0, max=10, step=1, value=0)

    #Bright
    if aug_params.ff.value:
        aug.b4_max = _slider(True, 'max light', min=0, max=10, step=1, value=0)
        aug.b4_pval = _slider(True, 'p', min=0, max=1, step=0.1, value=0)
    else:
        aug.b4_max = _slider(False, 'max_light', min=0, max=10, step=1, value=0)
        aug.b4_pval = _slider(False, 'p', min=0, max=1, step=0.1, value=0)

    #DihedralFlip
    dihedral_on = bool(aug_params.gg.value)
    aug.b5_pval = _slider(dihedral_on, 'p', min=0, max=1, step=0.1)
    # This slider used to be captioned 'p' as well, which put two identical
    # 'p' captions next to each other; the Contrast draw slider is 'draw'.
    aug.b5_draw = _slider(dihedral_on, 'draw', min=0, max=7, step=1)

    #Zoom
    if aug_params.ii.value:
        aug.b6_zoom = _slider(True, 'max_zoom', min=1, max=5, step=0.1)
        aug.b6_pval = _slider(True, 'p', min=0, max=1, step=0.1)
    else:
        aug.b6_zoom = _slider(False, 'max_zoom', min=1, max=5, step=0.1)
        aug.b6_pval = _slider(False, 'p', min=0, max=1, step=1)
            
    

In [None]:


# currently just doing this augment_show()-> aug_dash()->  aug_img() -> show_image_viola()
def aug_dash():
    """GUI for augmentation dashboard.

    Displays three option rows (padding mode, resize method, target size),
    whose selectors are stored on `aug_dash.pad`, `aug_dash.rzm` and
    `aug_dash.res`, then a row of augmentation toggles (stored on
    `aug_params`), and finally the confirm button created by `aug_img()`.
    """
    def _option_row(caption, colour, options, default):
        "Return (disabled caption button, ToggleButtons selector) sharing one button style."
        caption_btn = widgets.Button(description=caption, disabled=True, button_style=colour)
        selector = widgets.ToggleButtons(value=default, options=options, description='',
                                         button_style=colour, style=style, layout=Layout(width='auto'))
        return caption_btn, selector

    tg, aug_dash.pad = _option_row('Pad', 'info', ['Zeros', 'Reflection', 'Border'], 'Reflection')
    th, aug_dash.rzm = _option_row('ResizeMethod', 'warning', ['Squish', 'Pad', 'Crop'], 'Squish')
    ti, aug_dash.res = _option_row('Resize', 'primary', ['28', '64', '128', '194', '254'], '128')

    # --------- same toggles as aug_params() creates ----------------  #
    toggle_specs = (('hh', 'Erase'), ('cc', 'Contrast'), ('dd', 'Rotate'), ('ee', 'Warp'),
                    ('ff', 'Bright'), ('gg', 'DihedralFlip'), ('ii', 'Zoom'))
    toggles = []
    for attr_name, caption in toggle_specs:
        toggle = widgets.ToggleButton(value=False, description=caption,
                                      button_style='info', style=style)
        setattr(aug_params, attr_name, toggle)
        toggles.append(toggle)
    qq = widgets.HBox(toggles)
    # ----------------------------------------------------------------  #

    option_rows = [widgets.HBox([tg, aug_dash.pad]),
                   widgets.HBox([th, aug_dash.rzm]),
                   widgets.HBox([ti, aug_dash.res])]
    display(widgets.VBox(option_rows))
    print(BOLD + BLUE + "Choose Augmentation Parameters: ")

    display(qq)
    aug_img()

def show_image_viola(im, **kwargs):
    """Show_image helper for viewing images in Voila.

    Accepts a tensor-like object (anything with `data`, `cpu` and `permute`),
    a numpy array, or anything `array()` can convert, turns it into a PIL
    image and displays it.  `kwargs` are accepted and ignored.  Returns None.
    """
    # Handle pytorch axis order
    if hasattrs(im, ('data','cpu','permute')):
        im = im.data.cpu()
        if im.shape[0]<5: im=im.permute(1,2,0)
        # hand PIL a numpy array rather than the tensor itself
        im = im.numpy()
    elif not isinstance(im,np.ndarray): im=array(im)
    # Handle 1-channel images
    if im.shape[-1]==1: im=im[...,0]
    if im.ndim == 3 and im.shape[-1] == 3:
        img = Image.fromarray(im, 'RGB')
    else:
        # 2-D (single channel) or other layouts: forcing 'RGB' would not match
        # the data, so let PIL pick the mode from the array.
        img = Image.fromarray(im)
    display(img)

# currently just doing this augment_show()-> aug_dash()->  aug_img() -> show_image_viola()
def aug_img():
    """Show a 'Confirm' button that previews the browsed image after resizing.

    On click, the image at `browse_images.img` is loaded and resized with the
    size, resize method and padding mode currently selected on `aug_dash`,
    and the result is shown in the Output area below the button.
    """
    aug_img_b = widgets.Button(description='Confirm', button_style='success')
    display(aug_img_b)
    aug_img_out = widgets.Output()
    display(aug_img_out)

    def aug_img_(b):
        with aug_img_out:
            clear_output()
            img_path = browse_images.img
            # PIL reports size as (width, height); the file is closed again
            # as soon as the size has been read.
            with Image.open(img_path) as imgt:
                w1, h1 = imgt.size
            pil_img = PILImage(PILImage.create(img_path).resize((w1,h1))) #flip
            print(BOLD + BLUE + 'Size:' + RED + aug_dash.res.value + BLUE + ' ResizeMode:' + RED +
                  aug_dash.rzm.value + BLUE + ' Padding:' + RED + aug_dash.pad.value + RESET)
            resize_methods = {'Pad': ResizeMethod.Pad,
                              'Squish': ResizeMethod.Squish,
                              'Crop': ResizeMethod.Crop}
            pad_modes = {'Zeros': PadMode.Zeros,
                         'Border': PadMode.Border,
                         'Reflection': PadMode.Reflection}
            method = resize_methods[aug_dash.rzm.value]
            pad = pad_modes[aug_dash.pad.value]
            rsz = Resize(int(aug_dash.res.value), method=method, pad_mode=pad)
            # show_image_viola() displays the image itself and returns None;
            # wrapping it in display() additionally rendered that None.
            show_image_viola(rsz(pil_img))
    aug_img_b.on_click(aug_img_)

In [None]:
# currently just doing this augment_show()-> aug_dash()->  aug_img() -> show_image_viola()
def augment_show():
    """Entry point of the augmentation tab.

    Displays an 'Augmentations' button; on click, shows the currently browsed
    image (path and picture) on the left and the `aug_dash()` controls on the
    right, inside an Output area below the button.
    """
    aug_button = widgets.Button(description='Augmentations', button_style='success')
    display(aug_button)
    aug_out = widgets.Output()
    display(aug_out)

    def on_aug_button(b):
        with aug_out:
            clear_output()
            preview_pane = widgets.Output()
            controls_pane = widgets.Output()
            with preview_pane:
                print(browse_images.img)
                display(Image.open(browse_images.img))
            with controls_pane:
                aug_dash()
            display(HBox([preview_pane, controls_pane]))

    aug_button.on_click(on_aug_button)

In [None]:
def display_ui():
    """ Display tabs for visual display

    Builds the 'Info', 'Data' and 'Augmentation' dashboards, each inside its
    own Output widget, puts them into a Tab stored on `display_ui.tab`, and
    displays it.
    """
    # tab title -> function that renders that tab's content
    sections = (('Info', dashboard_info),
                ('Data', dashboard_dataset),
                ('Augmentation', augment_show))

    panes = [widgets.Output() for _ in sections]
    for pane, (_, build) in zip(panes, sections):
        with pane:
            clear_output()
            build()

    display_ui.tab = widgets.Tab(children = panes)
    for position, (title, _) in enumerate(sections):
        display_ui.tab.set_title(position, title)
    display(display_ui.tab)

In [None]:


display_ui()



Tab(children=(Output(), Output(), Output()), _titles={'0': 'Info', '1': 'Data', '2': 'Augmentation'})