In [None]:
# default_exp widgets.feats

In [None]:
# export
from snkrfinder.imports import *
from snkrfinder.core import *
from snkrfinder.data.munge import *
from snkrfinder.data.load import *
from snkrfinder.model.core import *
from snkrfinder.model.transfer import *
from snkrfinder.model.cvae import *

#from ipywidgets import widgets
#from ipywidgets import HBox,VBox,widgets,Button,Checkbox,Dropdown,Layout,Box,Output,Label,FileUpload
# from fastai.vision.widgets import *  # in imports
#from ipywidgets import Tab #fastai didn't include Tab
import seaborn as sns

In [None]:
#hide
from nbdev.showdoc import *

# first snkrfinder.widgets.core



## OVERVIEW: model module- MobileNet_v2 feature extractor

This project was started while I was an Insight Data Science fellow. It grew out of my interest in building data-driven tools for the fashion/retail space, where I had most recently been working. The original, over-scoped idea was a shoe design tool that could quickly generate initial sneaker designs from a few chosen examples and some text descriptors. The designs would be constrained by the "latent space" defined (discovered?) by a database of shoe images. However, given the 3-week sprint allowed for development, I pared the tool down to a simple "aesthetic" recommender for sneakers. It uses the same idea of an embedding space defined by the database of shoe images.

Widgets:

These are literally the 2.0 version of SneakerFinder, suitable to import and run as a simple Voilà notebook page.


## Part 2: create tools out of widgets... i.e. make SneakerFinder 2.0 in the fastai framework

In [None]:
#hide
# record the fastai version this notebook was executed with
print(fastai.__version__)

2.2.7


In [None]:
# show the working directory before and after moving to L_ROOT; later cells read
# relative paths such as f"data/{filename}.pkl", which presumably live under L_ROOT
print(Path().cwd())
os.chdir(L_ROOT)
print(Path().cwd())

/home/ergonyc/Projects/Project2.0/snkrfinder/nbs
/home/ergonyc/Projects/Project2.0/snkr-finder


In [None]:
# this should go into a utils or cfg module
# HOME: value returned by the project helper get_home()
HOME = get_home()

### get the decapitated featurenet e.g. mobilnet_v2 or resnet

### get the data

In [None]:
#hide
# load the precomputed, sorted feature table for the whole zappos database
filename = ZAPPOS_FEATS_ALL_SORT # "zappos-50k-mobilenetv2-features_sort_3"
df = pd.read_pickle(f"data/{filename}.pkl")


In [None]:
#hide
# root folder of the zappos images; df.path values are relative to this
images_path = D_ROOT/DBS['zappos']


### SANITY CHECK: 

We just want to check that we can extract single-image features that match those we just calculated.

In [None]:
# since we made sure our indices match up with our "classes", this lookup should be easy
query_image = "Shoes/Sneakers and Athletic Shoes/Nike/7716996.288224.jpg"

# show the row for the query image (its index should equal its classes_md value)
df.loc[df.path==query_image,['path','classes_md']]



Unnamed: 0,path,classes_md
27079,Shoes/Sneakers and Athletic Shoes/Nike/7716996.288224.jpg,27079


The DataBlock performed a number of processing steps to prepare the images for embedding into the MobileNet_v2 space (a 2*1280 vector). Because we pooled over space with both _average_ and _max_ pooling, the features have twice the dimensions.



I have made functional wrappers as well as a _fastai_ `pipeline` [02_models.ipynb], and confirmed that they are equivalent.  

For example _by hand_ pipeline with the MobileNet V2 is:

```python
# get net, prep image
mnet1 = get_mnetV2_feature_net(to_cuda=False)
t_image1 = load_and_prep_sneaker(image_path,size=IMG_SIZE,to_cuda=False)
    
```

versus the fastai based objects I defined:

```python
# FASTAI: get net, prep image, get feats 

mnet2 = create_cnn_featurenet(torchvision.models.mobilenet_v2,to_cuda=True)
t_image2 = load_and_prep_tf_pipe() # Pipeline
        
```

In [None]:
    
# "by hand" pipeline: MobileNet_v2 feature net + functional image prep
mnet1 = get_mnetV2_feature_net()
query_t1 = load_and_prep_sneaker(images_path/QUERY_IM)
test_feats1 = get_convnet_feature(mnet1,query_t1)

In [None]:
# fastai-based pipeline: generic cnn featurenet + Pipeline-based image prep
mnet2 = create_cnn_featurenet('mobilenet_v2')
query_t2 = load_and_prep_tf_pipe(images_path/QUERY_IM)
test_feats2 = get_convnet_feature(mnet2,query_t2)


# compare the two pipelines' features
# NOTE(review): .max() of the signed difference only bounds one side; (a-b).abs().max() would be stricter
test_feats1.mean(),test_feats2.mean(),(test_feats1-test_feats2).max(),
#PILImage.create((query_t1-query_t2).squeeze())

(TensorImage(1.2784), TensorBase(1.2784), TensorImage(0.))

Now I have the "embeddings" of the database in the mobileNet_v2 output space.  I can do a logistic regression on these vectors (should be identical to mapping these 1000 vectors to 4 categories (Part 3)) but I can also use an approximate KNN in this space to run the SneakerFinder tool.


## k-Nearest Neighbors: a proxy for "similar"

There really isn't a ground truth for similarity of aesthetic preference, so I'll start with a simple "gut" test: inspecting neighbors in our feature space. Remember that the goal of all this is to find some shoes that someone will like, and we are using "similar" as the approximation of human preference.

Personally, I like Jordans so I chose this as my `query_image`: <img alt="Sample Jordan" width="450" src="/home/ergonyc/.fastai/data/ut-zap50k-images/Shoes/Sneakers and Athletic Shoes/Nike/7716996.288224.jpg">

Here are the functions that will do it:

```python
        feats = get_mnet_feature(mnetv2,t_image,to_cuda=False)
        reducer = get_umap_reducer(latents)
        neighs = NearestNeighbors(n_neighbors=num_neighs) 
        neigh_images = query_neighs(q_feat, myneighs, data, root_path, show = True)
        plot_sneak_neighs(neigh_images)
```    

## Widgets: preamble. load the data


In [None]:
# hide

# Build one CPU feature extractor per architecture we have database features for.
MODELS = {}
for modelnm in ('mobilenet_v2', 'resnet18'):
    model = create_cnn_featurenet(modelnm, to_cuda=False)
    MODELS[modelnm] = model

# Only the last model's feature table is read. Reading each one in turn just
# overwrote `df` before it was ever used. Afterwards `modelnm`, `model`,
# `filename` and `df` all refer to resnet18, as in the step-by-step version.
filename = f"zappos-50k-{modelnm}-features_sort_3"
df = pd.read_pickle(f"data/{filename}.pkl")

In [None]:
def pack_featurenets(model_list, to_cuda=False):
    """Build a ``{name: featurenet}`` dict for every architecture name in `model_list`.

    Each value is `create_cnn_featurenet(name, to_cuda=to_cuda)`. `to_cuda`
    defaults to False (CPU), which is what this helper always used before.
    """
    return {m: create_cnn_featurenet(m, to_cuda=to_cuda) for m in model_list}


MODELS = pack_featurenets(['mobilenet_v2','resnet18'])

In principle we can easily load different feature models... but for now we only have mobilenet_v2 and resnet18 databases calculated for all the zappos data. 

TODO:  test that the outputs of xresnet and resnet are equivalent.


```python

        modelnm = 'resnet34'
        model = create_cnn_featurenet(modelnm,to_cuda=False)
        MODELS[modelnm] = model

        modelnm = 'xresnet18'
        model = create_cnn_featurenet(modelnm,to_cuda=False)
        MODELS[modelnm] = model
        
        modelnm = 'xresnet34'
        model = create_cnn_featurenet(modelnm,to_cuda=False)
        MODELS[modelnm] = model


```

Now let's load the saved KNN models and UMAP reducers.

In [None]:
model = MODELS['mobilenet_v2']

# number of neighbours the saved KNN indices were fitted with (part of the file name)
num_neighs = 5

# load the knns and umap reducers saved earlier; both are dicts keyed by image size (e.g. 'small')
filename = f"data/{model.name}-knn{num_neighs}Xsize.pkl"
knns = load_pickle(filename)

filename = f"data/{model.name}-umapXsize.pkl"
reducers = load_pickle(filename)   

In [None]:
# feature table computed with the active model
filename = f"zappos-50k-{model.name}-features_sort_3"
df = pd.read_pickle(f"data/{filename}.pkl")

In [None]:
#hide
# query the 'small'-image KNN index with the fastai-pipeline features of QUERY_IM
test_feats = test_feats2
neighs = knns['small']
distance, nn_index = neighs.kneighbors(test_feats, return_distance=True)    

# a single query -> the first row holds its neighbour distances
dist = distance.tolist()[0] 

In [None]:
# #def get_umap_embedding(latents):
# fn = df.path.values

# features = f"features_{SIZE_ABBR['small']}"
# print(features)
# data = df[['Category',features]].copy()
# db_feats = np.vstack(data[features].values)

# type(db_feats)

# snk2vec = dict(zip(fn,db_feats))

# snk2vec[list(snk2vec.keys())[0]]

# embedding = get_umap_embedding(db_feats)
# snk2umap = dict(zip(fn,embedding))



In [None]:

# # make the paths easily accessible
# paths = df[['path','classes_sm','classes_md','classes_lg']]
# neighbors = paths.iloc[nn_index.tolist()[0]].copy()

# df.columns

## Widgets: make this into a "tool"


In [None]:
# ts = [display(img.to_thumb(200,200))]+[display(i.to_thumb(100,100)) for i in images]

# nnc = carousel(ts, width='1200px')

# out_pl = HBox[display(img.to_thumb(200,200)), carousel]

In [None]:
# key = {sz:i for (i,sz) in enumerate(IMG_SIZES)}
    
# [x for x in IMG_SIZES.values()].index(128)


# #hide
# filename = "zappos-50k-mobilenetv2-features_sort_3"
# df = pd.read_pickle(f"data/{filename}.pkl")


# mnetv2 = model


In designing these widgets we'll use the naming convention of starting each object as "type, underscore, description" e.g. `btn_upload`

Types are: `btn` - button, `out` - an output "place", `dd` - dropdown, `tab` - tab, and `lbl` - label


In [None]:
#hide

# # Cell
# @patch
# def __getitem__(self:Box, i): return self.children[i]

# # Cell
# def widget(im, *args, **layout):
#     "Convert anything that can be `display`ed by IPython into a widget"
#     o = Output(layout=merge(*args, layout))
#     with o: display(im)
#     return o

# # Cell
# def _update_children(change):
#     for o in change['owner'].children:
#         if not o.layout.flex: o.layout.flex = '0 0 auto'

# # Cell
# def carousel(children=(), **layout):
#     "A horizontally scrolling carousel"
#     def_layout = dict(overflow='scroll hidden', flex_flow='row', display='flex')
#     res = Box([], layout=merge(def_layout, layout))
#     res.observe(_update_children, names='children')
#     res.children = children
#     return res
# def _open_thumb(fn, h, w): return Image.open(fn).to_thumb(h, w).convert('RGBA')

def _open_thumb(fn, h, w):
    "Open image `fn` and return an RGBA thumbnail of at most `h` x `w` (fastai keeps its own copy private)"
    return Image.open(fn).to_thumb(h, w).convert('RGBA')

class SneakerFinder:
    "A carousel of sneaker thumbnails (`fnms`), each paired with a `Dropdown` of `opts`"
    def __init__(self, opts=(), height=128, width=256, max_n=30):
        # every thumbnail offers '<Keep>' / '<Delete>' plus any caller-supplied options
        opts = ('<Keep>', '<Delete>')+tuple(opts)
        store_attr('opts,height,width,max_n')
        self.widget = carousel(width='100%')

    def set_fnms(self, fnms):
        "Load thumbnails for (at most `max_n` of) `fnms` into the carousel"
        self.fnms = L(fnms)[:self.max_n]
        # only spin up worker processes once there are enough images to be worth it
        ims = parallel(_open_thumb, self.fnms, h=self.height, w=self.width, progress=False,
                       n_workers=min(len(self.fnms)//10,defaults.cpus))
        self.widget.children = [VBox([widget(im, height=f'{self.height}px'), Dropdown(
            options=self.opts, layout={'width': 'max-content'})]) for im in ims]

    def _ipython_display_(self): display(self.widget)

    def values(self):
        "Current dropdown value for every thumbnail, in carousel order"
        return L(self.widget.children).itemgot(1).attrgot('value')
    def delete(self):
        "Indices of thumbnails marked '<Delete>'"
        return self.values().argwhere(eq('<Delete>'))
    def change(self):
        "(index, new option) pairs for thumbnails set to anything other than '<Keep>'/'<Delete>'"
        idxs = self.values().argwhere(not_(in_(['<Delete>','<Keep>'])))
        return idxs.zipwith(self.values()[idxs])



In [None]:
# hide

# a fitted sklearn NearestNeighbors exposes its k as `n_neighbors` (there is no `.neighbors`)
print(knns['small'].n_neighbors)


caption = widgets.Label(value='The values of slider1 and slider2 are synchronized')
sliders1 = widgets.IntSlider(description='Slider 1')
slider2 = widgets.IntSlider(description='Slider 2')
# keep the two sliders' values in sync
l = widgets.link((sliders1, 'value'), (slider2, 'value'))
display(caption, sliders1, slider2)

AttributeError: 'NearestNeighbors' object has no attribute 'neighbors'

In [None]:

# DEFAULT GLOBALS to start
# module-level state read (and rebound) by the widget callbacks defined below
im_sz = 'small'
model = MODELS['mobilenet_v2']
filename = f"zappos-50k-{model.name}-features_sort_3"
df = pd.read_pickle(f"data/{filename}.pkl")


In [None]:
# export

# image loading: PILImage.create, then pad-resize to IMG_SIZE (border padding)
load_pipe    = Pipeline([PILImage.create,
                         FeatsResize(size=IMG_SIZE, method='pad', pad_mode='border')] )

# tensor prep: ToTensor -> IntToFloatTensor -> ImageNet-stats Normalize on the CPU (cuda=False)
prep_tf_pipe = Pipeline([ToTensor(),
                         IntToFloatTensor(),
                         Normalize.from_stats(*imagenet_stats,cuda=False)])


def plot_umap(data, im_sz, mname, frac=None):
    """Scatter-plot the 2-D UMAP projection in `data`, coloured by shoe category.

    data  : DataFrame with 'umap-one', 'umap-two' and 'Category' columns.
    im_sz : key into IMG_SIZES; only used in the title.
    mname : model name; only used in the title.
    frac  : fraction of rows to plot (random sample). Defaults to the current
            value of the `sld_sampfrac` slider, as before.

    Returns the matplotlib Axes.
    """
    if frac is None:
        frac = sld_sampfrac.value
    _, ax = plt.subplots()
    sns.scatterplot(
        x="umap-one",
        y="umap-two",
        hue="Category",
        hue_order = ['Sneakers', 'Shoes', 'Boots','Slippers'],
        palette=sns.color_palette("hls", 4),
        data=data.sample(frac=frac),
        legend="full",
        alpha=0.3,ax=ax
    )
    ax.set_aspect('equal', 'datalim')
    ax.set_title(f'UMAP projection of {mname} embedded UT-Zappos data (sz={IMG_SIZES[im_sz]})', fontsize=12)
    return ax

def on_click_find_similar(change):
    """`btn_run` click handler (the 'go' signal).

    Refreshes the UMAP plot and the neighbour carousel for the current `im_sz`.
    `change` (the clicked Button) is unused. `update_knn_reducer` already ends
    by calling `find_similar()`, so it is not called a second time here. Doing so
    would repeat the feature extraction, KNN query and display.
    """
    update_knn_reducer(im_sz)

def find_similar():
    """Embed the most recently uploaded image and show its nearest database neighbours.

    Uses the module-level `model`, the KNN index `knns[im_sz]` and the feature
    table `df`. The query thumbnail plus a carousel of neighbours, each labelled
    with its distance, is rendered into `out_nn_imgs`. If nothing has been
    uploaded yet, a short message is shown instead of raising IndexError. This
    function is also reached from the dropdown observers.
    """
    neighs = knns[im_sz]

    if not btn_upload.data:
        out_nn_imgs.clear_output()
        with out_nn_imgs:
            print('Upload a sneaker image first.')
        return

    # load the last uploaded image and embed it with the active model
    img = load_pipe(btn_upload.data[-1])
    tensor_im = prep_tf_pipe(img)
    feats = get_convnet_feature(model, tensor_im)

    # find the neighbours; a single query, so take row 0 of each result
    distance, nn_index = neighs.kneighbors(feats.numpy(), return_distance=True)
    dist = distance.tolist()[0]
    neighbors = df.iloc[nn_index.tolist()[0]]

    # df.path is relative to the zappos image root
    images = [PILImage.create(D_ROOT/DBS['zappos']/f) for f in neighbors.path]
    ts = [VBox([widget(nb_im, max_width="292px"), Label(f"d={d:.03f}")])
          for nb_im, d in zip(images, dist)]
    target_im = img.to_thumb(200,200)
    car_nn = carousel(ts, width='1200px')

    out_nn_imgs.clear_output()
    with out_nn_imgs:
        display(HBox([widget(target_im, max_width="500px"), car_nn]))


def update_knn_reducer(size):
    """Switch to image size `size`, re-project the database with its UMAP reducer and refresh.

    Sets the global `im_sz`. It then pushes every `features_<abbr>` vector in `df`
    through `reducers[size]`, which is the slow step. Finally it redraws the UMAP
    scatter in `out_umap` and calls `find_similar()` to refresh the neighbour
    carousel.
    """
    global im_sz
    im_sz = size

    umap = reducers[im_sz]
    features = f"features_{SIZE_ABBR[im_sz]}"
    data = df[['Category',features]].copy()

    db_feats = np.vstack(data[features].values)
    # this is probably the bottleneck...
    embedding = umap.transform(db_feats)
    data['umap-one'] = embedding[:,0]
    data['umap-two'] = embedding[:,1]

    out_umap.clear_output()
    with out_umap:
        plot_umap(data,size,model.name)
        # plt.show takes no Axes argument; passing one fed it to the backend's show()
        plt.show()

    find_similar()

def update_model(model_name, size, num_neighs=5):
    """Make `model_name` the active feature extractor and reload its artefacts.

    Loads the following, then calls `update_knn_reducer(size)`, which re-projects
    the database and refreshes the plots:
    - `data/{name}-knn{num_neighs}Xsize.pkl` (KNN indices keyed by image size)
    - `data/{name}-umapXsize.pkl` (UMAP reducers keyed by image size)
    - the matching zappos feature table

    `num_neighs` must match the value the KNN pickle was saved with. It defaults
    to 5, which was previously hard-coded.
    """
    import warnings
    global model,knns,reducers,df
    model = MODELS[model_name]

    if model_name != model.name:
        # the pickles below are keyed on model.name, so they may not be the ones picked
        warnings.warn(f"requested model '{model_name}' is registered as '{model.name}'")

    knns = load_pickle(f"data/{model.name}-knn{num_neighs}Xsize.pkl")
    reducers = load_pickle(f"data/{model.name}-umapXsize.pkl")
    df = pd.read_pickle(f"data/zappos-50k-{model.name}-features_sort_3.pkl")

    update_knn_reducer(size)

#Events
def dd_im_size_eh(change):
    "Observer for the image-size dropdown: re-project for the newly chosen size."
    new_size = change.new
    update_knn_reducer(new_size)

def dd_model_eh(change):
    "Observer for the model dropdown: swap models, keeping the selected image size."
    chosen_model, current_size = change.new, dd_im_size.value
    update_model(chosen_model, current_size)


#define my widgets
# naming convention: <type>_<description>; btn_ Button, dd_ Dropdown, out_ Output, sld_ slider
btn_run = Button(description='Find similar sneaks!',layout = Layout(width='25%', height='80px'))
btn_upload = FileUpload(layout = Layout(width='25%', height='80px'))

out_umap = Output() # UMAP scatter, drawn by update_knn_reducer
out_nn_imgs = Output() # query image + neighbour carousel, drawn by find_similar

dd_im_size = Dropdown(options=IMG_SIZES.keys(),value='small',description='Image Size:' )
dd_model = Dropdown(options=['mobilenet_v2','resnet18'],
                    value='mobilenet_v2',
                    disabled=False,
                    description='Model:')

# fraction of database rows drawn in the UMAP scatter (read by plot_umap)
sld_sampfrac = widgets.FloatSlider(value=.5,
                min=0,
                max=1.0,
                step=0.05,
                description='sample %:',
                disabled=False,
                continuous_update=False,
                orientation='vertical',
                readout=True,
                readout_format='.2f',
)

knn_select = HBox([dd_model, dd_im_size])

tab = widgets.Tab(children=[out_nn_imgs,HBox([out_umap, sld_sampfrac]) ] )
tab.set_title(0, 'Dataset Exploration')
tab.set_title(1, 'UMAP Plot')

cta = HBox([widgets.Label('Find your sneaker!    '),
            btn_upload,
            btn_run])

# status line. Create it exactly once so that `console` is the Label actually
# shown in `dashboard` (a second `Label()` here would be an orphan).
console = Label()
dashboard =  VBox([ cta,
                    knn_select,
                    tab,
                  console])

# actions
btn_run.on_click(on_click_find_similar)
dd_im_size.observe(dd_im_size_eh, names='value')
dd_model.observe(dd_model_eh, names='value')



In [None]:
dashboard

VBox(children=(HBox(children=(Label(value='Find your sneaker!    '), FileUpload(value={}, description='Upload'…

## Transfer Learning data cleaning tool



WIP

In [None]:
# Sneakers / Shoes / Boots / Slippers (the hue_order used in plot_umap)
num_categories = 4

# MobileNet_v2 with a new num_categories-way head; freeze=True presumably freezes
# the pretrained body (TODO confirm in snkrfinder.model.transfer)
transfer_mobilenet_v2(num_categories,freeze=True)

# could also make a resnet transfer in a few lines of fastai api

#resnet = create_cnn_model(models.resnet18, 4, True)

MobileNetV2(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=Tr

In [None]:
# this should infer the number of categories and automatically re-head the resnet.
# cnn_learner does exactly that: it cuts the pretrained body and builds a head sized
# to dls.c. Plain `Learner` needs a model instance and has no `pretrained` argument.
# TODO: `dls` must be built first (e.g. a DataBlock over the zappos categories).
learn = cnn_learner(dls, resnet18,
                    #splitter=mobilenet_split,cut=-1,
                    pretrained=True, metrics=error_rate)

NameError: name 'dls' is not defined

## export

In [None]:
# hide

# export the `# export` cells of every notebook into the snkrfinder package
from nbdev.export import notebook2script
notebook2script()



Converted 00_core.ipynb.
Converted 01a_zappos_data.ipynb.
Converted 01b_scraped_data.ipynb.
Converted 02a_model.ipynb.
Converted 02b_transferlearning_model.ipynb.
Converted 03b_latenttarget_cvae.ipynb.
Converted 04_widgets.ipynb.
Converted index.ipynb.


In [None]:
#hide

import ipywidgets as widgets
from ipywidgets import FloatSlider, interact
from fastai2.vision.all import *
from fastai2.vision.widgets import *
from IPython.display import display,clear_output, Javascript
# silence matplotlib deprecation warnings in this GUI section
warnings.filterwarnings("ignore",category=matplotlib.cbook.mplDeprecation)

# shared ipywidgets style: let long descriptions use their full width
style = {'description_width': 'initial'}

# ANSI escape codes used to colour / format printed console text
RED = '\033[31m'
BLUE = '\033[94m'
GREEN = '\033[92m'
BOLD   = '\033[1m'
ITALIC = '\033[3m'
RESET  = '\033[0m'

def dashboard_one():
    """GUI for first accordion window: a 'System' button that prints library versions."""
    import importlib
    import torchvision

    def _version(modname, missing):
        "Return `modname.__version__`, or `missing` if the package cannot be imported."
        try:
            return importlib.import_module(modname).__version__
        except ImportError:
            return missing

    # fastai v2 was published as `fastai2` before being renamed `fastai`; try both
    fastver = _version('fastai2', None) or _version('fastai', 'fastai not found')
    fastprog = _version('fastprogress', 'fastprogress not found')
    fastp = _version('fastpages', 'fastpages not found')
    nbd = _version('nbdev', 'nbdev not found')

    # RESET so the colour does not bleed into later output
    print (BOLD +  RED + '>> fastGUI\n' + RESET)
    button = widgets.Button(description='System', button_style='success')
    display(button)

    out = widgets.Output()
    display(out)

    def on_button_clicked_info(b):
        # (label, value, suffix) rows, printed in this order
        rows = [
            ("fastai version: ", fastver, ''),
            ("nbdev version: ", nbd, ''),
            ("fastprogress version: ", fastprog, ''),
            ("fastpages version: ", fastp, '\n'),
            ("python version: ", sys.version, ''),
            ("torchvision: ", torchvision.__version__, ''),
            ("torch version: ", torch.__version__, ''),
            ("\nCuda: ", torch.cuda.is_available(), ''),
            ("cuda version: ", torch.version.cuda, ''),
        ]
        with out:
            clear_output()
            for label, value, suffix in rows:
                print(BOLD + BLUE + label + RESET + ITALIC + str(value) + suffix + RESET)

    button.on_click(on_button_clicked_info)
def dashboard_two():
    """GUI for the second accordion window: pick a dataset, then 'Explore' it.

    The dataset selector is stored on the function itself (`dashboard_two.datas`)
    so helpers such as `ds_choice` can read the current choice.
    """
    choices = ['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY']
    selector = widgets.ToggleButtons(options=choices, description='Choose', value=None,
                                     disabled=False, button_style='info',
                                     tooltips=[''], style=style)
    dashboard_two.datas = selector
    display(selector)

    explore_btn = widgets.Button(description='Explore', button_style='success')
    display(explore_btn)
    explore_out = widgets.Output()
    display(explore_out)

    def on_button_explore(b):
        # (re)render the exploration views for the currently selected dataset
        with explore_out:
            clear_output()
            ds_choice()
            show()

    explore_btn.on_click(on_button_explore)

#Helpers for dashboard two
def ds_choice():
    """Helper for dataset choices: download/untar the dataset selected in `dashboard_two`.

    The local path is stored on `ds_choice.source`. 'PETS' maps to `URLs.DOGS`
    (a dataset laid out as train/<class>/). Raises ValueError when nothing, or an
    unknown option, is selected. Previously such a call silently left
    `ds_choice.source` missing or stale.
    """
    urls = {
        'PETS': URLs.DOGS,
        'CIFAR': URLs.CIFAR,
        'IMAGENETTE_160': URLs.IMAGENETTE_160,
        'IMAGEWOOF_160': URLs.IMAGEWOOF_160,
        'MNIST_TINY': URLs.MNIST_TINY,
    }
    choice = dashboard_two.datas.value
    if choice not in urls:
        raise ValueError(f"Choose a dataset first (got {choice!r})")
    ds_choice.source = untar_data(urls[choice])

def plt_classes():
    """Helper for plotting classes in folder: bar chart of images per class in `train/`."""
    ds_choice()
    print(BOLD + BLUE + "Dataset: " + RESET + BOLD + RED + str(dashboard_two.datas.value))
    Path.BASE_PATH = ds_choice.source
    train_source = (ds_choice.source/'train/').ls().items
    print(BOLD + BLUE + "\n" + "No of classes: " + RESET + BOLD + RED + str(len(train_source)))

    # ls() yields full paths, so the class is simply the folder name. This is more
    # robust than splitting the whole path string on 'train'.
    class_l = [folder.name for folder in train_source]
    num_l = [len(folder.ls()) for folder in train_source]
    y_pos = np.arange(len(train_source))

    # scope the seaborn style to this chart instead of changing the global style
    with plt.style.context('seaborn'):
        plt.figure(figsize=(7,7))
        plt.bar(y_pos, num_l, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan'])
        plt.xticks(y_pos, class_l, rotation=90)
        plt.ylabel('Images')
        plt.title('Images per Class')
        plt.show()

def display_images():
    """Helper for displaying images from folder: the first 5 images of every class in `train/`."""
    columns, rows = 5, 1
    train_source = (ds_choice.source/'train/').ls().items
    for name in train_source:
        fol = (ds_choice.source/name).ls().sorted()
        fol_disp = fol[0:columns*rows]
        filename = fol_disp.items
        fol_tensor = []
        for o in fol_disp:
            # copy the pixels, then close the file instead of leaking the handle
            with Image.open(o) as im:
                fol_tensor.append(tensor(im))
        print(BOLD + BLUE + "Loc: " + RESET + str(name) + " " + BOLD + BLUE + "Number of Images: " + RESET +
              BOLD + RED + str(len(fol)))

        fig = plt.figure(figsize=(15,15))
        # one subplot per image. A redundant outer loop used to redraw every image 5 times.
        for i, img in enumerate(fol_tensor):
            ax = fig.add_subplot(rows, columns, i+1)
            ax.set_title("ax:"+str(filename[i]))
            plt.tick_params(bottom=True, left=True)
            plt.xticks([])
            plt.imshow(img)
        plt.show()
def browse_images():
    """Slider-driven viewer over the `train/` images of the selected dataset.

    The path currently shown is stored on `browse_images.img` (read by `aug_show`).
    """
    print(BOLD + BLUE + "Use slider to choose image" + RESET)
    ds_choice()
    items = get_image_files(ds_choice.source/'train/')
    n = len(items)
    if n == 0:
        # interact(i=(0, -1)) would raise; say why there is nothing to browse
        print(BOLD + RED + "No images found in train/" + RESET)
        return
    def view_image(i):
        # imshow copies the pixels, so the file can be closed right away
        with Image.open(items[i]) as im:
            plt.imshow(im, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training: %s' % items[i])
        browse_images.img = items[i]
        plt.show()
    interact(view_image, i=(0,n-1))

def show():
    "Lay out the class bar chart and image browser (left column) beside the sample images."
    bar_out, samples_out, browser_out = widgets.Output(), widgets.Output(), widgets.Output()
    with bar_out:
        plt_classes()
    with samples_out:
        display_images()
    with browser_out:
        browse_images()
    left_column = VBox([bar_out, browser_out])
    display(HBox([left_column, samples_out]))

def aug_show():
    "Show an 'Augmentations' button that displays the browsed image beside the augmentation dashboard."
    btn = widgets.Button(description='Augmentations', button_style='success')
    display(btn)
    result_out = widgets.Output()
    display(result_out)

    def on_aug_button(b):
        with result_out:
            clear_output()
            image_out, dash_out = widgets.Output(), widgets.Output()
            with image_out:
                # image currently selected in browse_images
                print(browse_images.img)
                display(Image.open(browse_images.img))
            with dash_out:
                aug_dash()
            display(HBox([image_out, dash_out]))

    btn.on_click(on_aug_button)

def aug_paras():
    """If augmentations is chosen, show the available parameters.

    Creates one ToggleButton per augmentation, stored on this function:
    `aug_paras.hh` Erase, `.cc` Contrast, `.dd` Rotate, `.ee` Warp, `.ff` Bright,
    `.gg` DihedralFlip, `.ii` Zoom. A 'Confirm' button rebuilds the parameter
    sliders (`aug()`) to match the toggles and then displays them.
    """
    print(BOLD + BLUE + "Choose Augmentation Parameters: " + RESET)
    button_paras = widgets.Button(description='Confirm', button_style='success')

    toggles = [('hh', 'Erase'), ('cc', 'Contrast'), ('dd', 'Rotate'), ('ee', 'Warp'),
               ('ff', 'Bright'), ('gg', 'DihedralFlip'), ('ii', 'Zoom')]
    for attr, desc in toggles:
        setattr(aug_paras, attr, widgets.ToggleButton(value=False, description=desc,
                                                      button_style='info', style=style))

    qq = widgets.HBox([getattr(aug_paras, attr) for attr, _ in toggles])
    display(qq)
    display(button_paras)
    aug_par = widgets.Output()
    display(aug_par)
    def on_button_two_click(b):
        with aug_par:
            clear_output()
            # aug_dash_choice lays out aug.b*_ sliders, which only aug() creates;
            # rebuild them so their enabled state reflects the confirmed toggles
            aug()
            aug_dash_choice()
    button_paras.on_click(on_button_two_click)

def aug():
    """Aug choice helper: (re)build the parameter sliders for every augmentation.

    Sliders are stored on this function for `aug_dash_choice` to lay out:
    `aug.b_*` Erase, `aug.b1_*` Contrast, `aug.b2_*` Rotate, `aug.b3_*` Warp,
    `aug.b4_*` Bright, `aug.b5_*` DihedralFlip, `aug.b6_*` Zoom. A group is
    enabled only when its `aug_paras` toggle is on. Otherwise it is greyed out.
    """
    def _fs(disabled, **kw):
        "Horizontal FloatSlider, greyed out when `disabled`"
        return FloatSlider(orientation='horizontal', disabled=disabled, **kw)

    #Erase
    if aug_paras.hh.value:
        aug.b_max = _fs(False, min=0, max=50, step=1, value=0, description='max count')
        aug.b_pval = _fs(False, min=0, max=1, step=0.1, value=0, description=r"$p$")
        aug.b_asp = _fs(False, min=0.1, max=5, step=0.1, value=0.3, description=r'$aspect$')
        aug.b_len = _fs(False, min=0.1, max=5, step=0.1, value=0.3, description=r'$sl$')
        aug.b_ht = _fs(False, min=0.1, max=5, step=0.1, value=0.3, description=r'$sh$')
        aug.erase_code = 'this is ERASE on'
    else:
        aug.b_max = _fs(True, min=0, max=10, step=1, value=0, description='max count')
        aug.b_pval = _fs(True, min=0, max=1, step=0.1, value=0, description='p')
        aug.b_asp = _fs(True, min=0.1, max=1.7, value=0.3, description='aspect')
        aug.b_len = _fs(True, min=0.1, max=1.7, value=0.3, description='length')
        aug.b_ht = _fs(True, min=0.1, max=1.7, value=0.3, description='height')
        aug.erase_code = 'this is ERASE OFF'
    #Contrast
    if aug_paras.cc.value:
        aug.b1_max = _fs(False, min=0, max=0.9, step=0.1, value=0.2, description='max light')
        aug.b1_pval = _fs(False, min=0, max=1.0, step=0.05, value=0.75, description='p')
        aug.b1_draw = _fs(False, min=0, max=100, step=1, value=1, description='draw')
    else:
        aug.b1_max = _fs(True, min=0, max=0.9, step=0.1, value=0, description='max light')
        aug.b1_pval = _fs(True, min=0, max=1.0, step=0.05, value=0.75, description='p')
        aug.b1_draw = _fs(True, min=0, max=100, step=1, value=1, description='draw')
    #Rotate
    if aug_paras.dd.value:
        aug.b2_max = _fs(False, min=0, max=10, step=1, value=0, description='max degree')
        aug.b2_pval = _fs(False, min=0, max=1, step=0.1, value=0.5, description='p')
    else:
        aug.b2_max = _fs(True, min=0, max=10, step=1, value=0, description='max degree')
        aug.b2_pval = _fs(True, min=0, max=1, step=0.1, value=0, description='p')
    #Warp
    if aug_paras.ee.value:
        aug.b3_mag = _fs(False, min=0, max=10, step=1, value=0, description='magnitude')
        aug.b3_pval = _fs(False, min=0, max=1, step=0.1, value=0, description='p')
    else:
        aug.b3_mag = _fs(True, min=0, max=10, step=1, value=0, description='magnitude')
        # p is a probability: same 0..1 range as the enabled slider
        aug.b3_pval = _fs(True, min=0, max=1, step=0.1, value=0, description='p')
    #Bright
    if aug_paras.ff.value:
        aug.b4_max = _fs(False, min=0, max=10, step=1, value=0, description='max light')
        aug.b4_pval = _fs(False, min=0, max=1, step=0.1, value=0, description='p')
    else:
        aug.b4_max = _fs(True, min=0, max=10, step=1, value=0, description='max light')
        aug.b4_pval = _fs(True, min=0, max=1, step=0.1, value=0, description='p')
    #DihedralFlip (p = probability, draw = which of the 8 dihedral transforms)
    if aug_paras.gg.value:
        aug.b5_pval = _fs(False, min=0, max=1, step=0.1, description='p')
        aug.b5_draw = _fs(False, min=0, max=7, step=1, description='draw')
    else:
        aug.b5_pval = _fs(True, min=0, max=1, step=0.1, description='p')
        aug.b5_draw = _fs(True, min=0, max=7, step=1, description='draw')
    #Zoom
    if aug_paras.ii.value:
        aug.b6_zoom = _fs(False, min=1, max=5, step=0.1, description='max_zoom')
        aug.b6_pval = _fs(False, min=0, max=1, step=0.1, description='p')
    else:
        aug.b6_zoom = _fs(True, min=1, max=5, step=0.1, description='max_zoom')
        aug.b6_pval = _fs(True, min=0, max=1, step=0.1, description='p')

def aug_dash_choice():
    """Augmentation parameter display helper.

    Lays out each augmentation's ``aug_paras`` toggle above a row of its ``aug``
    sliders, stacks all of them in one bordered flex ``Box``, and adds a 'View'
    button that prints ``browse_images.img`` into an ``Output`` area when clicked.
    """
    view_button = widgets.Button(description='View', button_style='success')

    # (toggle, sliders) pairs in display order.
    sections = [
        (aug_paras.hh, [aug.b_max, aug.b_pval, aug.b_asp, aug.b_len, aug.b_ht]),  # erase
        (aug_paras.cc, [aug.b1_max, aug.b1_pval, aug.b1_draw]),                  # contrast
        (aug_paras.dd, [aug.b2_max, aug.b2_pval]),                               # rotate
        (aug_paras.ee, [aug.b3_mag, aug.b3_pval]),                               # warp
        (aug_paras.ff, [aug.b4_max, aug.b4_pval]),                               # bright
        (aug_paras.gg, [aug.b5_pval, aug.b5_draw]),                              # dihedral flip
        (aug_paras.ii, [aug.b6_zoom, aug.b6_pval]),                              # zoom
    ]
    columns = [widgets.VBox([toggle, widgets.HBox(sliders)]) for toggle, sliders in sections]

    panel_layout = Layout(display='flex', flex_flow='column', flex_grow=0, flex_wrap='wrap',
                          border='solid 1px', align_items='flex-start',
                          align_content='flex-start', justify_content='space-between',
                          width='flex')
    display(Box(columns, layout=panel_layout))
    display(view_button)

    view_out = widgets.Output()
    display(view_out)

    def _print_current_image(_button):
        with view_out:
            clear_output()
            print(browse_images.img)

    view_button.on_click(_print_current_image)

def aug_dash():
    """GUI for augmentation dashboard.

    Builds the Pad / ResizeMethod / Resize selector rows and keeps the selectors
    as ``aug_dash.pad``, ``aug_dash.rzm`` and ``aug_dash.res`` (read later by
    ``aug_img``). Also builds one on/off toggle per augmentation, stored on
    ``aug_paras`` under the attribute names the slider helpers read. Displays
    everything, then shows the ``aug_img`` confirm button.
    """
    tg = widgets.Button(description='Pad', disabled=True, button_style='info')
    aug_dash.pad = widgets.ToggleButtons(value='Reflection', options=['Zeros', 'Reflection', 'Border'], description='',
                                         button_style='info',style=style, layout=Layout(width='auto'))
    th = widgets.Button(description='ResizeMethod', disabled=True, button_style='warning')
    aug_dash.rzm = widgets.ToggleButtons(value='Squish', options=['Squish', 'Pad', 'Crop'], description='',
                                         button_style='warning', style=style, layout=Layout(width='auto'))
    ti = widgets.Button(description='Resize', disabled=True, button_style='primary')
    aug_dash.res = widgets.ToggleButtons(value='128', options=['28', '64', '128', '194', '254'], description='',
                                         button_style='primary', style=style, layout=Layout(width='auto'))

    # One toggle per augmentation, in display order; attribute names are fixed
    # because other helpers read aug_paras.hh / .cc / ... directly.
    toggle_specs = [('hh', 'Erase'), ('cc', 'Contrast'), ('dd', 'Rotate'), ('ee', 'Warp'),
                    ('ff', 'Bright'), ('gg', 'DihedralFlip'), ('ii', 'Zoom')]
    toggles = []
    for attr, label in toggle_specs:
        toggle = widgets.ToggleButton(value=False, description=label, button_style='info',
                                      style=style)
        setattr(aug_paras, attr, toggle)
        toggles.append(toggle)

    qq = widgets.HBox(toggles)

    il = widgets.HBox([tg, aug_dash.pad])
    ik = widgets.HBox([th, aug_dash.rzm])
    ij = widgets.HBox([ti, aug_dash.res])
    ir = widgets.VBox([il, ik, ij])
    display(ir)
    # RESET stops the bold/blue ANSI styling from bleeding into later output.
    print(BOLD + BLUE + "Choose Augmentation Parameters: " + RESET)
    display(qq)
    aug_img()

def show_imagee(im, **kwargs):
    """Show_image helper for viewing images in Voila.

    Displays ``im`` as a PIL image via ``display`` instead of matplotlib.

    Args:
        im: a PIL image, numpy array, or pytorch tensor. A tensor whose first
            dimension is < 5 is treated as channel-first and permuted to
            channel-last. Float data is assumed to be in [0, 1], following the
            matplotlib ``imshow`` convention, and is scaled to 8-bit.
        **kwargs: accepted for signature compatibility with ``show_image``; ignored.

    Returns:
        None. The image is emitted with ``display``.
    """
    # Handle pytorch axis order
    if hasattrs(im, ('data', 'cpu', 'permute')):
        im = im.data.cpu()
        if im.shape[0] < 5: im = im.permute(1, 2, 0)
        # PIL cannot read torch tensors, so convert to numpy.
        im = im.numpy()
    elif not isinstance(im, np.ndarray):
        im = array(im)
    # Handle 1-channel images (only squeeze a real trailing channel axis)
    if im.ndim == 3 and im.shape[-1] == 1: im = im[..., 0]
    # PIL's L/RGB/RGBA modes need 8-bit data.
    if np.issubdtype(im.dtype, np.floating):
        im = (np.clip(im, 0, 1) * 255).astype(np.uint8)
    # Let PIL infer the mode (L, LA, RGB, RGBA) from the array shape instead of
    # forcing 'RGB', which misreads greyscale and 4-channel data.
    img = Image.fromarray(np.ascontiguousarray(im))
    display(img)

def aug_img():
    """Show a 'Confirm' button that previews the current image after resizing.

    On click, reads the image path from ``browse_images.img`` and the selections
    from ``aug_dash.res``, ``aug_dash.rzm`` and ``aug_dash.pad``. It then applies
    a fastai ``Resize`` with those settings and shows the result in an
    ``Output`` area below the button.
    """
    aug_img_b = widgets.Button(description='Confirm', button_style='success')
    display(aug_img_b)
    aug_img_out = widgets.Output()
    display(aug_img_out)
    def aug_img_(b):
        with aug_img_out:
            clear_output()
            img_path = browse_images.img
            # PILImage.create opens the file itself. The previous Image.open +
            # same-size resize was a no-op that left a file handle open.
            pil_img = PILImage.create(img_path)
            print(BOLD + BLUE + 'Size:' + RED + aug_dash.res.value + BLUE + ' ResizeMode:' + RED +
                  aug_dash.rzm.value + BLUE + ' Padding:' + RED + aug_dash.pad.value + RESET)
            # Map the toggle-button labels to fastai enums.
            methods = {'Pad': ResizeMethod.Pad, 'Squish': ResizeMethod.Squish,
                       'Crop': ResizeMethod.Crop}
            pads = {'Zeros': PadMode.Zeros, 'Border': PadMode.Border,
                    'Reflection': PadMode.Reflection}
            method = methods[aug_dash.rzm.value]
            pad = pads[aug_dash.pad.value]
            rsz = Resize(int(aug_dash.res.value), method=method, pad_mode=pad)
            # show_imagee displays the image itself and returns None. Wrapping it
            # in display() would also render that None.
            show_imagee(rsz(pil_img))
    aug_img_b.on_click(aug_img_)

def display_ui():
    """Display the Info / Data / Augmentation dashboards as tabs.

    Runs each dashboard builder (``dashboard_one``, ``dashboard_two``,
    ``aug_show``) inside its own ``Output`` widget so its widgets land in the
    matching tab. The resulting ``Tab`` is stored on ``display_ui.tab`` and displayed.
    """
    out1a = widgets.Output()
    out1 = widgets.Output()
    out2 = widgets.Output()

    with out1a: #info
        clear_output()
        dashboard_one()

    with out1: #data
        clear_output()
        dashboard_two()

    with out2: #augmentation
        clear_output()
        aug_show()

    display_ui.tab = widgets.Tab(children=[out1a, out1, out2])
    for idx, title in enumerate(('Info', 'Data', 'Augmentation')):
        display_ui.tab.set_title(idx, title)
    display(display_ui.tab)

In [None]:
# hide
fnm = "junk/london-raw.csv"

df_london = pd.read_csv(fnm)

df_london.head()

# Short column aliases used by the filtering dashboard, then a 500-row random subset.
df_london = df_london.assign(
    visits=df_london["Visits (000s)"],
    spend=df_london["Spend (£m)"],
    nights=df_london["Nights (000s)"],
)
df_london = df_london.sample(500)


In [None]:
# hide
ALL = 'ALL'
def unique_sorted_values_plus_ALL(array):
    """Return the sorted distinct values of a pandas Series, prefixed with ``ALL``.

    ``ALL`` is the "no filter" option offered first in the dashboard dropdowns.
    """
    distinct = array.unique().tolist()
    return [ALL, *sorted(distinct)]

In [None]:
# hide
def colour_ge_value(value, comparison):
    """Return a CSS colour rule: red when ``value >= comparison``, otherwise black."""
    return 'color: red' if value >= comparison else 'color: black'
    

# Output areas: the styled table tab and the KDE plot tab.
output = widgets.Output()
plot_output = widgets.Output()

# Filter controls. The first option, ALL, disables that filter.
dropdown_year = widgets.Dropdown(
    options=unique_sorted_values_plus_ALL(df_london.year),
    description='Year:',
)
dropdown_purpose = widgets.Dropdown(
    options=unique_sorted_values_plus_ALL(df_london.purpose),
    description='Purpose:',
)
# Highlight threshold for the numeric columns.
bounded_num = widgets.BoundedFloatText(
    min=0,
    max=100000,
    value=5,
    step=1,
    description='Number:',
)

def common_filtering(year, purpose,num):
    """Filter ``df_london`` and refresh both dashboard tabs.

    Args:
        year: value of the ``year`` column to keep, or ``ALL`` for no year filter.
        purpose: value of the ``purpose`` column to keep, or ``ALL`` for no purpose filter.
        num: threshold. Cells in visits/spend/nights that are >= num are coloured red.
    """
    output.clear_output()
    # Boolean row mask; ALL leaves the corresponding column unfiltered.
    keep = pd.Series(True, index=df_london.index)
    if year != ALL:
        keep &= df_london.year == year
    if purpose != ALL:
        keep &= df_london.purpose == purpose
    common_filter = df_london[keep]

    with output:
        display(common_filter.style.applymap(lambda x: colour_ge_value(x, num),
                                                subset=['visits','spend', 'nights'] ) )
    # Clear the previous plot as well. Otherwise every change appends another
    # KDE figure below the old ones.
    plot_output.clear_output()
    with plot_output:
        sns.kdeplot(common_filter['visits'], shade=True)
        plt.show()
    
def dropdown_year_eventhandler(change):
    """Re-filter using the newly selected year."""
    common_filtering(year=change.new, purpose=dropdown_purpose.value, num=bounded_num.value)

def dropdown_purpose_eventhandler(change):
    """Re-filter using the newly selected purpose."""
    common_filtering(year=dropdown_year.value, purpose=change.new, num=bounded_num.value)

def bounded_num_eventhandler(change):
    """Re-filter (re-colour) using the newly entered threshold."""
    common_filtering(year=dropdown_year.value, purpose=dropdown_purpose.value, num=change.new)

# Any value change on a control refreshes the dashboard.
for _widget, _handler in ((dropdown_purpose, dropdown_purpose_eventhandler),
                          (dropdown_year, dropdown_year_eventhandler),
                          (bounded_num, bounded_num_eventhandler)):
    _widget.observe(_handler, names='value')
del _widget, _handler


# A single Layout (50px bottom margin) shared by the control row and the tabs.
item_layout = widgets.Layout(margin='0 0 50px 0')

input_widgets = widgets.HBox(
    [dropdown_year, dropdown_purpose, bounded_num], layout=item_layout
)

tab = widgets.Tab([output, plot_output], layout=item_layout)
for _idx, _title in enumerate(['Dataset Exploration', 'KDE Plot']):
    tab.set_title(_idx, _title)
del _idx, _title

dashboard = widgets.VBox([input_widgets, tab])
display(dashboard)

In [None]:
# hide


########## this is WIP
# import re
# import time
# # import matplotlib.pyplot as pltmodel
# import matplotlib.image as mpimg
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
# import plotly
# import plotly.express as px
# import plotly.figure_factory as FF


import bokeh.plotting as bplt #import figure, show, output_notebook
#from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
import bokeh
# from bokeh.palettes import Spectral10

import umap

#from scipy import spatial  #for now just brute force to find neighbors
import scipy 
#from scipy.spatial import distance

from io import BytesIO
import base64



##########################################
#  BOKEH
#
##########################################
def init_bokeh_plot(umap_df):
    """Build a Bokeh scatter of a 2-D UMAP embedding with image hover tooltips.

    Switches Bokeh output to the notebook. ``umap_df`` must provide the columns
    ``x`` and ``y`` (point position), ``db`` (colour category, mapped for the
    factors "sns" and "goat"), and ``image``, ``fname`` and ``loss`` (shown in
    the hover tooltip).

    Returns:
        The Bokeh figure (not shown; the caller decides when to show it).
    """
    bplt.output_notebook()

    tooltip_html = """
    <div>
        <div>
            <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
        </div>
        <div>
            <span style='font-size: 14px'>@fname</span>
            <span style='font-size: 14px'>@loss</span>
        </div>
    </div>
    """

    source = bokeh.models.ColumnDataSource(umap_df)
    db_colours = bokeh.models.CategoricalColorMapper(
        factors=["sns", "goat"],
        palette=bokeh.palettes.Spectral10,
    )

    fig = bplt.figure(
        title='UMAP projection VAE latent',
        plot_width=1000,
        plot_height=1000,
        tools='pan, wheel_zoom, reset',
    )
    fig.add_tools(bokeh.models.HoverTool(tooltips=tooltip_html))

    fig.circle('x', 'y', source=source,
               color={'field': 'db', 'transform': db_colours},
               line_alpha=0.6, fill_alpha=0.6, size=4)

    return fig


def embeddable_image(label):
    """Return a base64 PNG ``data:`` URI for ``label`` (an image path or image object)."""
    data_uri = image_formatter(label)
    return data_uri

def get_thumbnail(path):
    """Open the image at ``path`` and shrink it in place to fit within 64x64 (LANCZOS)."""
    max_side = 64
    thumb = Image.open(path)
    thumb.thumbnail((max_side, max_side), Image.LANCZOS)
    return thumb

def image_base64(im):
    """Encode an image as base64 PNG text.

    Args:
        im: a file path (``str`` or ``os.PathLike``), which is opened and shrunk
            with ``get_thumbnail``, or an object with a PIL-style
            ``save(fp, format)`` method, which is encoded as-is.

    Returns:
        The base64 encoding of the PNG bytes, as a ``str``.
    """
    # pathlib.Path is treated as a path too (PIL's Image.open accepts it).
    # Previously a Path fell through to ``.save`` and failed.
    if isinstance(im, (str, os.PathLike)):
        im = get_thumbnail(im)
    with BytesIO() as buffer:
        im.save(buffer, 'png')
        return base64.b64encode(buffer.getvalue()).decode()

def image_formatter(im):
    """Wrap the base64 PNG encoding of ``im`` in a ``data:image/png`` URI."""
    prefix = "data:image/png;base64,"
    return prefix + image_base64(im)



# do we need it loaded... it might be fast enough??
#@st.cache
def load_UMAP_data():
    """Load ``snk2umap.pkl`` for the current model configuration.

    The path is ``data/<model_name>-X<x_dim[0]>-Z<z_dim>/kl_weight<NNN>/snk2umap.pkl``,
    built from the module-level ``model_name`` and ``params``.

    NOTE(review): ``model_name``, ``params`` and ``ut`` are not defined anywhere
    in this chunk (leftover from a streamlit version?); confirm they exist
    before calling. Unpickling runs arbitrary code, so only load trusted files.
    """
    model_dir = f"data/{model_name}-X{params['x_dim'][0]}-Z{params['z_dim']}"
    kl_subdir = f"kl_weight{int(params['kl_weight']):03d}"
    pickle_path = os.path.join(model_dir, kl_subdir, "snk2umap.pkl")
    return ut.load_pickle(pickle_path)


def load_latent_data():
    """Assemble the UMAP / latent-vector tables used by the Bokeh explorer.

    NOTE(review): WIP. ``snk2vec`` and ``snk2loss`` are read as module-level
    globals that are not defined in this chunk (the "load df" step is missing).
    They presumably map image path (bytes, since labels are ``.decode()``-d)
    to latent vector and to loss respectively. Confirm before use. This also
    assumes ``snk2loss`` iterates in the same key order as ``snk2vec``, which
    is not checked.

    Returns:
        (umap_df, snk2vec, latents, labels, vecs, vec_tree, mids) where:
        ``umap_df`` has columns x, y, digit, image, fname, db, loss;
        ``latents`` and ``vecs`` both stack the ``snk2vec`` values in key order;
        ``labels`` is ``np.array(mids)``; ``vec_tree`` is a ``scipy.spatial.KDTree``
        over ``vecs``; ``mids`` is the list of ``snk2vec`` keys.

    Raises:
        ValueError: if the ``snk2umap`` keys do not equal the ``snk2vec`` keys,
            in the same order.
    """
    snk2umap = load_UMAP_data()

    # load df (filenames and latents...)

    mids = list(snk2vec.keys())
    vecs = np.array([snk2vec[m] for m in mids])
    vec_tree = scipy.spatial.KDTree(vecs)

    latents = np.array(list(snk2vec.values()))
    losses = np.array(list(snk2loss.values()))
    labels = np.array(mids)

    labels2 = np.array(list(snk2umap.keys()))
    embedding = np.array(list(snk2umap.values()))

    # The UMAP rows must line up with the latent vectors. An explicit check is
    # used instead of ``assert``, which is stripped under -O. Shapes are compared
    # first so mismatched lengths give a clear error rather than a numpy
    # elementwise-comparison failure.
    if labels.shape != labels2.shape or not np.all(labels == labels2):
        raise ValueError("snk2umap keys do not match snk2vec keys (same order required)")
    umap_df = pd.DataFrame(embedding, columns=('x', 'y'))

    umap_df['digit'] = [str(x.decode()) for x in labels]
    umap_df['image'] = umap_df.digit.map(embeddable_image)
    # The 3rd-from-last path component is used as the source database
    # (coloured by the "sns"/"goat" mapper in init_bokeh_plot).
    umap_df['fname'] = umap_df.digit.map(lambda x: f"{x.split('/')[-3]} {x.split('/')[-1]}")
    umap_df['db'] = umap_df.digit.map(lambda x: f"{x.split('/')[-3]}")
    umap_df['loss'] = [f"{x:.1f}" for x in losses]

    return umap_df,snk2vec,latents, labels, vecs,vec_tree,mids


#%%
