In [None]:
#| default_exp create_umap_layout

In [None]:
#| export

import numpy as np
import umap
import daft

  import pkg_resources


# Create UMAP layout(s)

In [None]:
#| export

def umap_2d(embeddings: np.ndarray, n_neighbors: int = 15, min_dist: float = 0.1,
                metric: str = "cosine", random_state: int = 42) -> np.ndarray:
    """
    Project embeddings onto a 2-D UMAP layout.

    Args:
        embeddings: np.ndarray, shape (n_samples, n_features)
        n_neighbors: int, default=15. Forwarded to `umap.UMAP`.
        min_dist: float, default=0.1. Forwarded to `umap.UMAP`.
        metric: str, default="cosine". Forwarded to `umap.UMAP`.
        random_state: int, default=42. Forwarded to `umap.UMAP`.

    Returns:
        The array returned by `fit_transform`: one (x, y) row per input row,
        i.e. shape (n_samples, 2).
    """
    # The function name promises two dimensions, so request them explicitly
    # rather than relying on the library's default for `n_components`.
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    return reducer.fit_transform(embeddings)

In [None]:
#| hide

# Smoke test: run umap_2d on 5 random 10-dimensional vectors.
# The input comes from NumPy's global RNG with no seed set in this notebook, so
# the coordinates shown in the saved output will differ between runs.
# NOTE(review): n_neighbors=15 is larger than the 5 samples passed here —
# presumably the source of the warning in the saved output; confirm.
umap_2d(np.random.rand(5, 10), n_neighbors=15, min_dist=0.1, metric="cosine", random_state=42)

  warn(


array([[-1.093321 , -4.5273743],
       [-2.5225453, -4.822378 ],
       [-1.7975024, -4.022668 ],
       [-1.2994211, -5.5270386],
       [-2.1556873, -5.7061405]], dtype=float32)

In [None]:
#| export

def create_umap_col(df: daft.DataFrame, n_neighbors: int = 15, min_dist: float = 0.1,
                metric: str = "cosine", random_state: int = 42,
                col_namespace: str = "umap") -> daft.DataFrame:
    """
    Add a 2-D UMAP layout column computed from the "embeddings" column of `df`.

    Args:
        df: daft.DataFrame containing an "embeddings" column. The column's
            Arrow type must expose `list_size` (a fixed-size list), since that
            value is used as the number of features per row.
        n_neighbors: int, default=15. Forwarded to `umap_2d`.
        min_dist: float, default=0.1. Forwarded to `umap_2d`.
        metric: str, default="cosine". Forwarded to `umap_2d`.
        random_state: int, default=42. Forwarded to `umap_2d`.
        col_namespace: str, default="umap". Prefix of the new column, which is
            named "{col_namespace}_xys".

    Returns:
        The result of `df.with_column(...)` with the layout column added.
    """
    embeds = df.select("embeddings").to_arrow()["embeddings"]
    # Indexing an Arrow table by column name yields a ChunkedArray, which has
    # no `.values`; merge the chunks into one array first. The duck-typed check
    # keeps this working if a plain Array is returned instead.
    if hasattr(embeds, "combine_chunks"):
        embeds = embeds.combine_chunks()
    shape = len(embeds), embeds.type.list_size
    # `.values` is the flat child buffer of the fixed-size list; reshape it back
    # into (n_rows, list_size).
    embeds = embeds.values.to_numpy().reshape(shape)
    umap_xys = umap_2d(embeds, n_neighbors=n_neighbors, min_dist=min_dist,
                       metric=metric, random_state=random_state)
    umap_xys = daft.Series.from_numpy(umap_xys)
    # NOTE(review): with_column is handed a Series here — confirm the installed
    # daft version accepts a Series (rather than only an Expression).
    return df.with_column(f"{col_namespace}_xys", umap_xys)

In [None]:
# | hide
# Imported here rather than in the exported imports cell: this cell is hidden
# and only drives the notebook-to-module export.
import nbdev

# Run nbdev's export step (writes the `#| export` cells out as library code).
nbdev.nbdev_export()