In [1]:
# Location of the LMDB environment that LMDBDataset (defined below) opens.
# NOTE(review): hard-coded absolute path; adjust for the machine running this notebook.
LMDB_FILEPATH = "/mnt/data_ssd/lmdb/seefood_test_data_mobilenet_v2"

## Imports and Utilities

In [144]:
import base64
import pickle
from io import BytesIO

import lmdb
import matplotlib as mpl
import numpy as np
import pandas as pd
import torch
import umap
import umap.plot
from bokeh.plotting import figure, show, ColumnDataSource
from matplotlib import cm
from PIL import Image
from skimage import img_as_ubyte

In [118]:
# Switch plot output to the notebook so the figure shown at the end renders inline.
# NOTE(review): presumably a re-export of Bokeh's output_notebook — confirm against the installed umap-learn.
umap.plot.output_notebook()

In [95]:
def to_png(arr):
    """Encode an image array as PNG and return the encoded bytes.

    Args:
        arr: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        bytes: The PNG-encoded image.
    """
    with BytesIO() as buffer:
        Image.fromarray(arr).save(buffer, format='png')
        # Read the payload before the buffer is closed on leaving the block.
        return buffer.getvalue()

In [11]:
class Features:
    """Picklable record holding one image's feature tensor and its target.

    The tensor is kept as raw float32 bytes together with its shape and is
    rebuilt on demand by :meth:`get_features`.

    Args:
        image_path: Path of the source image; stored unchanged.
        features: ``torch.Tensor`` with the image features. It is detached,
            moved to the CPU and cast to float32 before being serialised.
        target: Scalar exposing ``.round()`` and ``.item()`` (for example a
            0-d tensor); stored rounded, as a plain Python number.
    """

    def __init__(self, image_path, features, target):
        self.shape = features.shape
        # get_features() always decodes the buffer as float32, so normalise the
        # tensor here; otherwise any other dtype would not survive the round trip.
        # detach()/cpu() make tensors that track gradients or live on an
        # accelerator convertible to NumPy.
        self.features = features.detach().cpu().to(torch.float32).numpy().tobytes()
        self.image_path = image_path
        self.target = target.round().item()

    def get_features(self):
        """Rebuild the feature tensor from the stored bytes.

        Returns:
            torch.Tensor: A float32 tensor with the original shape. It owns its
            memory, so callers may modify it without affecting this record.
        """
        flat = np.frombuffer(self.features, dtype=np.float32)
        # frombuffer returns a read-only view over the bytes object, which
        # torch.from_numpy does not support (it warns, and writes through the
        # tensor are undefined). Copy to get writable memory.
        return torch.from_numpy(flat.reshape(self.shape).copy())

class LMDBDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads pickled feature records from an LMDB file.

    Records are looked up under zero-padded, 8-digit ASCII keys derived from
    the integer index (``0`` -> ``b"00000000"``).

    Args:
        lmdb_filename: Path handed to ``lmdb.open``.
    """

    def __init__(self, lmdb_filename):
        # Opened read-only, with locking, read-ahead and memory
        # initialisation disabled.
        self.env = lmdb.open(
            lmdb_filename,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        print(self.env.stat())
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()["entries"]

    def __getitem__(self, index):
        """Return ``(image_path, features, log1p(target))`` for ``index``.

        Raises:
            IndexError: If no record is stored under the key for ``index``.
        """
        key = f"{index:08}".encode("ascii")
        with self.env.begin(write=False) as txn:
            buf = txn.get(key)

        if buf is None:
            # Without this check pickle.loads(None) fails with an opaque TypeError.
            raise IndexError(f"no record stored under key {key!r}")

        # NOTE: pickle.loads can execute arbitrary code; only open LMDB files
        # that come from a trusted source.
        features = pickle.loads(buf)
        return features.image_path, features.get_features(), np.log1p(features.target)

    def __len__(self):
        """Number of entries reported by LMDB when the dataset was opened."""
        return self.length

## Load Data

In [12]:
# Single-process loader (num_workers=0) that walks the records in index order
# (shuffle=False) and groups them into batches of 70,000.
dataloader = torch.utils.data.DataLoader(
    dataset=LMDBDataset(LMDB_FILEPATH),
    batch_size=70_000,
    shuffle=False,
    num_workers=0,
)

{'psize': 4096, 'depth': 3, 'branch_pages': 14, 'leaf_pages': 2793, 'overflow_pages': 865748, 'entries': 432874}


In [14]:
# Take only the first batch. Each item is (image_path, features, log1p(target)),
# so `sample` is the batched form of those three fields, in that order.
sample = next(iter(dataloader))

## Create Embeddings using UMAP

In [9]:
# UMAP is stochastic: without a fixed seed the embedding (and so the plot)
# changes on every run. Pin random_state so Restart & Run All is repeatable.
# NOTE(review): the UMAP docs say a fixed seed trades away some parallelism — confirm the runtime is acceptable.
reducer = umap.UMAP(random_state=42)

In [10]:
# Fit UMAP on the features of the sampled batch (second field of `sample`).
# The fitted object's `embedding_` supplies the x/y coordinates plotted below.
mapper = reducer.fit(sample[1])

## Plot Embeddings

In [137]:
# Build one image URL per sampled path: drop the first 11 characters of the
# path and append the rest to the local server address.
# NOTE(review): the 11-character cut presumably strips a fixed directory prefix — confirm.
imgs = pd.Series([f"http://localhost:8080/{path[11:]}" for path in sample[0]])

In [221]:
# HTML hover template passed to figure(tooltips=...). `@imgs` and `@calories`
# refer to the ColumnDataSource columns of the same names defined below.
TOOLTIPS = """
    <div>
        <div>
            <img
                src="@imgs" height="224" alt="@imgs" width="224"
            ></img>
        </div>
        <div>@calories</div>
    </div>
"""

In [222]:
# The dataset returns log1p(target); expm1 undoes that transform. Convert the
# batch to a NumPy array first, then apply the NumPy ufunc.
total_calories = np.expm1(sample[2].numpy())

In [227]:
# Map each value to a "#rrggbb" string: normalise the values, look them up in
# the `hot` colormap, scale the RGBA rows to 0-255 and drop the alpha channel.
colors = [
    "#%02x%02x%02x" % tuple(int(channel) for channel in rgba[:3])
    for rgba in 255 * mpl.cm.hot(mpl.colors.Normalize()(total_calories))
]

In [228]:
# Column store for the plot: embedding coordinates plus the per-point image
# URL, colour and calorie value referenced by the glyph and the hover tooltip.
source = ColumnDataSource(
    data={
        "x": mapper.embedding_[:, 0],
        "y": mapper.embedding_[:, 1],
        "imgs": imgs,
        "colors": colors,
        "calories": total_calories,
    }
)

In [230]:
# Interactive scatter of the UMAP embedding, coloured by the `colors` column,
# with the image/calorie hover defined in TOOLTIPS.
# `width`/`height` replace `plot_width`/`plot_height`, which were removed in
# Bokeh 3.0. NOTE(review): earlier releases are believed to accept them too — confirm.
p = figure(width=900, height=900, tooltips=TOOLTIPS)
p.scatter('x', 'y', size=2, fill_color="colors", line_color=None, fill_alpha=0.5, source=source)
show(p)