In [1]:
import numpy as np
from umap import UMAP


class UMAPWrapper:
    """Thin wrapper around ``umap.UMAP`` that remembers the data it was fit on.

    After :meth:`fit_transform`, both the original high-dimensional vectors
    (``latent_vectors``) and their UMAP coordinates (``embeddings``) are kept,
    row-aligned, so a point picked in the embedding space can be mapped back
    to the latent vector it came from.
    """

    def __init__(self, n_neighbors=15, min_dist=0.1, n_components=2, random_state=None):
        """Create the underlying UMAP model; all arguments are forwarded to ``umap.UMAP``."""
        self.umap = UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            n_components=n_components,
            random_state=random_state,
        )

        # Both are set by fit_transform(); None means "not fitted yet".
        self.latent_vectors = None  # ndarray copy of the data passed to fit_transform
        self.embeddings = None      # UMAP output; row i corresponds to latent_vectors[i]

        # Intended cache of generated audio files (see get_audio_for_embedding, not implemented yet).
        self.audio_files = {}

    def _assert_fitted(self):
        """Raise ``ValueError`` unless :meth:`fit_transform` has completed."""
        if self.embeddings is None:
            raise ValueError("UMAPWrapper has not been fit yet")

    @staticmethod
    def _calculate_closest_index(x, y):
        """Return the index of the row of ``x`` nearest (Euclidean) to the point ``y``.

        :param x: 2-D array of candidate points, one per row.
        :param y: a single point, broadcast against every row of ``x``.
        :return: the row index as a plain ``int`` (the first one on ties).
        """
        dists = np.linalg.norm(np.asarray(x) - np.asarray(y), axis=1)
        return int(np.argmin(dists))

    def fit_transform(self, x):
        """Fit UMAP on ``x``, remember the data, and return the embedding.

        ``x`` is stored as a private ndarray copy. Later index lookups are
        therefore positional (a pandas DataFrame would otherwise be indexed by
        column) and unaffected if the caller mutates their own array. The
        stored vectors and embeddings are only replaced once fitting succeeds,
        so the two always stay consistent with each other.

        NOTE(review): assumes dense array-like input; a scipy sparse matrix
        would need ``.toarray()`` first.

        :param x: array-like of shape (n_samples, n_features).
        :return: the embedding returned by ``UMAP.fit_transform``.
        """
        latent_vectors = np.array(x, copy=True)
        embeddings = self.umap.fit_transform(latent_vectors)
        # Commit state only after a successful fit.
        self.latent_vectors = latent_vectors
        self.embeddings = embeddings
        return self.embeddings

    def inverse_transform(self, y):
        """Map embedding-space points ``y`` back to the input space via ``UMAP.inverse_transform``."""
        self._assert_fitted()
        return self.umap.inverse_transform(y)

    def transform(self, x):
        """Embed new input-space points ``x`` with the already-fitted UMAP model."""
        self._assert_fitted()
        return self.umap.transform(x)

    def get_embedding_from_index(self, index):
        """Return the embedding of the ``index``-th fitted sample."""
        self._assert_fitted()
        return self.embeddings[index]

    def get_latent_vector_from_index(self, index):
        """Return the ``index``-th latent vector passed to :meth:`fit_transform`."""
        self._assert_fitted()
        return self.latent_vectors[index]

    def get_closest_embedding_to_point(self, x):
        """Find the fitted embedding nearest to ``x``, a point in embedding space.

        :param x: a single point with the embedding's dimensionality.
        :return: ``(index, embedding)`` of the nearest fitted sample.
        """
        self._assert_fitted()
        index = self._calculate_closest_index(self.embeddings, x)
        return index, self.embeddings[index]

    def get_closest_latent_vector_to_point(self, x):
        """Find the latent vector whose embedding is nearest to ``x``.

        The search happens in embedding space (``x`` has the embedding's
        dimensionality). The matching row of the original data is returned.

        :param x: a single point in embedding space.
        :return: ``(index, latent_vector)`` of the nearest fitted sample.
        """
        self._assert_fitted()
        index = self._calculate_closest_index(self.embeddings, x)
        return index, self.latent_vectors[index]

    def get_audio_for_embedding(self, embedding):
        """Return audio for the fitted sample closest to ``embedding`` (not implemented yet).

        Planned behaviour:
            1. Find the closest fitted embedding and its latent vector.
            2. If an audio file already exists for it in ``self.audio_files``,
               return it.
            3. Otherwise, use the feed-forward section of the model to
               predict the output audio, save it to a file, add the file to
               ``self.audio_files`` and return the audio.

        :param embedding: a point in embedding space.
        :raises NotImplementedError: always, until the method is implemented.
        """
        raise NotImplementedError("get_audio_for_embedding is not implemented yet")

    def plot_embeddings(self):
        """Scatter-plot the first two embedding dimensions with a click read-out.

        Clicking inside the axes shows the mouse button, the pixel position and
        the data coordinates in the lower-left corner of the axes. The click
        handler only fires with an interactive matplotlib backend
        (e.g. ``%matplotlib widget``).

        :raises ValueError: if not fitted, or if the embedding has fewer than 2 dimensions.
        """
        import matplotlib.pyplot as plt

        self._assert_fitted()
        if self.embeddings.shape[1] < 2:
            raise ValueError("plot_embeddings needs at least 2 embedding dimensions")

        fig, ax = plt.subplots(1, 1)

        ax.scatter(self.embeddings[:, 0], self.embeddings[:, 1])
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        # Anchor the read-out in axes coordinates so it is always visible.
        # Data point (0, 0) may lie outside the plotted range of the embedding.
        text = ax.text(0, 0, "", va="bottom", ha="left", transform=ax.transAxes)

        def onclick(event):
            # Outside the axes xdata/ydata are None, and formatting them with
            # %f would raise TypeError inside the callback.
            if event.inaxes is not ax:
                return
            tx = 'button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (event.button, event.x, event.y, event.xdata, event.ydata)
            text.set_text(tx)
            fig.canvas.draw_idle()  # request a redraw so the new text appears

        fig.canvas.mpl_connect('button_press_event', onclick)
        plt.show()


In [2]:
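# Uncomment to switch to an interactive backend so the click handler in UMAPWrapper.plot_embeddings fires.
# `%matplotlib notebook` only works in the classic Notebook; in JupyterLab / Notebook 7 use `%matplotlib widget` (requires ipympl).
# %matplotlib notebook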

In [3]:
# Seeded generator so the synthetic data (and everything fitted on it) is
# reproducible on Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)
latent_space = rng.random((100, 64))   # 100 synthetic latent vectors, 64 dimensions each
new_datapoint = rng.random((1, 64))    # one extra latent vector, shaped as a batch of 1

In [4]:
# Pin random_state so the embedding and the nearest-point lookups below are reproducible.
umap_wrapper = UMAPWrapper(random_state=42)

In [5]:
# Fit UMAP on the latent space and keep the low-dimensional coordinates.
embeddings = umap_wrapper.fit_transform(x=latent_space)

In [6]:
# Nearest fitted embedding to the query point (6, 6) in UMAP space.
umap_wrapper.get_closest_embedding_to_point(np.full(2, 6.0))

(39, array([2.5609782, 5.2854466], dtype=float32))

In [7]:
# Latent vector whose embedding lies closest to the query point (6, 6).
umap_wrapper.get_closest_latent_vector_to_point(x=np.full(2, 6.0))

(39,
 array([0.83248386, 0.4821484 , 0.82206747, 0.31900238, 0.73872611,
        0.8514468 , 0.58993486, 0.76728262, 0.41238704, 0.44656862,
        0.18824827, 0.38281686, 0.9513613 , 0.7088507 , 0.69058843,
        0.68325301, 0.14083225, 0.98583876, 0.50214802, 0.95951107,
        0.16413877, 0.58632032, 0.41257116, 0.86238714, 0.69411638,
        0.91183359, 0.22562311, 0.81185581, 0.53149328, 0.52437155,
        0.98465077, 0.73909676, 0.64253361, 0.51582989, 0.07783874,
        0.0341631 , 0.64668561, 0.85429321, 0.07986794, 0.9078949 ,
        0.39884967, 0.32694515, 0.67044851, 0.71304066, 0.37065939,
        0.97358948, 0.67722709, 0.93318999, 0.61003238, 0.17025172,
        0.05633042, 0.11322537, 0.89048318, 0.22992115, 0.77399235,
        0.50839098, 0.30507082, 0.66454499, 0.33711477, 0.58463837,
        0.84840726, 0.91430902, 0.92163273, 0.26289868]))

In [8]:
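# NOTE: earlier scratch version of the click-to-inspect scatter plot; UMAPWrapper.plot_embeddings
# now implements the same idea, so this cell can be deleted once that method is confirmed working.
# import matplotlib.pyplot as plt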
# import numpy as np
# %matplotlib notebook 

# fig = plt.figure();
# ax = fig.add_subplot(111)
# ax.scatter(embeddings[:,0], embeddings[:,1])

# # def onclick(event):
# #     a = ('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
# #           (event.button, event.x, event.y, event.xdata, event.ydata))
# #     ax.set_title(a)

# # cid = fig.canvas.mpl_connect('button_press_event', onclick)

# plt.show()

In [9]:
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot
from IPython.display import display

init_notebook_mode(connected=True)

# Toy data for the click-handler prototype.
x = [1, 2, 3, 4, 5]
y = [1, 4, 2, 3, 5]

# One colour per point: a single colour string cannot be recoloured per marker.
trace = go.Scatter(
    x=x,
    y=y,
    mode='markers',
    marker=dict(
        color=['blue'] * len(x),
        size=10
    )
)

# Click callbacks only fire on a FigureWidget (requires ipywidgets). A plain
# go.Figure rendered with iplot() is static HTML and never calls back into Python.
fig = go.FigureWidget(data=[trace])


def on_click(trace, points, state):
    """Recolour the clicked point(s) red and print their coordinates."""
    # Trace array properties come back as tuples, so build a new list and assign
    # it back instead of mutating elements in place.
    colors = list(trace.marker.color)
    for ind in points.point_inds:
        colors[ind] = 'red'
        print(f"Clicked on point at ({trace.x[ind]}, {trace.y[ind]})")
    trace.marker.color = colors


fig.data[0].on_click(on_click)

# Display the widget itself (not iplot) so the callback stays connected.
# NOTE(review): output printed from widget callbacks may show up in the
# JupyterLab log console rather than under this cell.
display(fig)
