In [23]:
import umap
import numpy as np

class UMAPModel:
    """Small convenience wrapper around ``umap.UMAP``.

    The wrapper stores the hyperparameters at construction time and only
    builds the underlying estimator when :meth:`fit` is called.

    Parameters
    ----------
    n_neighbors : int, default 15
        Forwarded to ``umap.UMAP(n_neighbors=...)``.
    min_dist : float, default 0.1
        Forwarded to ``umap.UMAP(min_dist=...)``.
    n_components : int, default 2
        Dimension of the embedding space, forwarded to ``umap.UMAP``.
    random_state : int, RandomState or None, default None
        Forwarded to ``umap.UMAP``; set it to make the embedding reproducible.

    Attributes
    ----------
    umap_model : umap.UMAP or None
        The fitted estimator, or ``None`` until :meth:`fit` has succeeded.
    """

    def __init__(self, n_neighbors=15, min_dist=0.1, n_components=2, random_state=None):
        self.umap_model = None  # populated by fit()
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist
        self.n_components = n_components
        self.random_state = random_state

    def _require_fitted(self):
        """Return the fitted estimator, or raise ``ValueError`` if fit() was not called."""
        if self.umap_model is None:
            raise ValueError("UMAP model not initialized. Call fit() first.")
        return self.umap_model

    def fit(self, X):
        """Fit a fresh UMAP estimator on ``X`` and return ``self``.

        ``X`` is passed to ``umap.UMAP.fit`` unchanged. Returning ``self``
        allows chaining such as ``UMAPModel().fit(X).transform(Z)``.
        """
        model = umap.UMAP(
            n_neighbors=self.n_neighbors,
            min_dist=self.min_dist,
            n_components=self.n_components,
            random_state=self.random_state,
        )
        model.fit(X)
        # Assign only after a successful fit so a failed fit never leaves an
        # unfitted estimator behind that transform() would then try to use.
        self.umap_model = model
        return self

    def transform(self, X):
        """Project points ``X`` into the learned embedding space.

        Raises
        ------
        ValueError
            If :meth:`fit` has not been called successfully.
        """
        return self._require_fitted().transform(X)

    def reverse_transform(self, Y):
        """Map embedded points ``Y`` back to the original feature space.

        Delegates to ``umap.UMAP.inverse_transform``. The result is a
        reconstruction, not an exact inverse of :meth:`transform`.

        Raises
        ------
        ValueError
            If :meth:`fit` has not been called successfully.
        """
        return self._require_fitted().inverse_transform(Y)


In [24]:
# Synthetic stand-in for real latent vectors: 100 samples x 8 features, uniform on [0, 1).
# NOTE(review): unseeded, so the recorded outputs in this notebook will not reproduce on re-run.
latent_space = np.random.rand(100, 8)

In [25]:
# Build the wrapper with its default hyperparameters (n_neighbors=15, min_dist=0.1).
umapper = UMAPModel()

In [26]:
# Fit on the synthetic data; transform()/reverse_transform() raise ValueError until this has run.
umapper.fit(latent_space)

In [27]:
# One fresh 8-feature sample (not part of the fitted data) to round-trip through the model.
# NOTE(review): also unseeded.
random_vec = np.random.rand(1, 8)

In [28]:
# Project the new sample into the fitted embedding space.
embedding = umapper.transform(random_vec)

In [29]:
# Map the embedded point back to feature space (a reconstruction of random_vec).
inverse = umapper.reverse_transform(embedding)

In [30]:
# Sanity check: the reconstruction should have the same shape as the input sample.
inverse.shape

(1, 8)

In [31]:
# Shape of the original sample, for comparison with the reconstruction's shape.
random_vec.shape

(1, 8)

In [32]:
# Compare the reconstruction to the original sample.
# NOTE(review): random_vec is drawn from [0, 1), so atol=.8 is a very loose tolerance;
# a True result here says little about reconstruction quality.
np.allclose(inverse, random_vec, atol=.8)

True

In [33]:
# Show the original sample followed by its reconstruction for a side-by-side comparison.
print(random_vec)
print(inverse)

[[0.4797351  0.11144287 0.71489721 0.07081483 0.59334856 0.69778806
  0.18588996 0.24288225]]
[[0.69606674 0.24897045 0.49031407 0.12620339 0.27919522 0.46984434
  0.48226428 0.28736612]]


In [55]:
import numpy as np
from umap import UMAP


class UMAPWrapper:
    """UMAP estimator that also remembers the training rows it embedded.

    After :meth:`fit_transform`, the wrapper keeps the training embedding and
    an index -> original-row mapping, so an embedded point can be traced back
    to the training vector it came from.

    Parameters
    ----------
    n_neighbors, min_dist, n_components, random_state
        Forwarded unchanged to ``UMAP(...)``.

    Attributes
    ----------
    umap : UMAP
        The underlying estimator.
    embedding_ : array or None
        Embedding returned by the last :meth:`fit_transform`, else ``None``.
    index_map : dict or None
        Maps row index ``i`` to ``X[i]`` of the data passed to
        :meth:`fit_transform`, else ``None``.
    """

    def __init__(self, n_neighbors=15, min_dist=0.1, n_components=2, random_state=None):
        self.umap = UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            n_components=n_components,
            random_state=random_state,
        )
        self.index_map = None
        self.embedding_ = None

    def fit_transform(self, X):
        """Fit UMAP on ``X``, remember the rows of ``X``, and return the embedding.

        Inputs without a ``shape`` attribute (e.g. nested lists) are converted
        with ``np.asarray`` first; inputs that already expose ``shape`` are
        passed through untouched.
        """
        if not hasattr(X, "shape"):
            X = np.asarray(X)
        embedding = self.umap.fit_transform(X)
        rows = {i: X[i] for i in range(X.shape[0])}
        # Publish both pieces of state together so the wrapper is never left
        # half-fitted if building the row mapping fails.
        self.embedding_ = embedding
        self.index_map = rows
        return self.embedding_

    def inverse_transform(self, embedding=None):
        """Map embedded points back to feature space.

        Parameters
        ----------
        embedding : array-like, optional
            Points in the embedding space. When omitted, the stored training
            embedding is inverted (the original zero-argument behaviour).

        Raises
        ------
        ValueError
            If the wrapper has not been fit yet.
        """
        if self.embedding_ is None:
            raise ValueError("UMAPWrapper has not been fit yet")
        if embedding is None:
            embedding = self.embedding_
        return self.umap.inverse_transform(embedding)

    def transform(self, X):
        """Project new points ``X`` into the fitted embedding space.

        Raises
        ------
        ValueError
            If the wrapper has not been fit yet.
        """
        if self.embedding_ is None:
            raise ValueError("UMAPWrapper has not been fit yet")
        return self.umap.transform(X)

    def get_original_vector(self, index):
        """Return the training row stored under ``index``.

        Raises
        ------
        ValueError
            If the wrapper has not been fit yet.
        KeyError
            If ``index`` is not a stored row index.
        """
        if self.index_map is None:
            raise ValueError("UMAPWrapper has not been fit yet")
        return self.index_map[index]

    def get_vector_from_embedding(self, embedding):
        """Find the training point whose embedding is closest to ``embedding``.

        Distance is Euclidean in the embedding space.

        Returns
        -------
        tuple
            ``(index, embedded_point, original_row)`` of the nearest training
            sample.

        Raises
        ------
        ValueError
            If the wrapper has not been fit yet.
        """
        if self.index_map is None:
            raise ValueError("UMAPWrapper has not been fit yet")
        # asarray lets callers pass a plain list/tuple as the query point.
        query = np.asarray(embedding)
        dists = np.linalg.norm(self.embedding_ - query, axis=1)
        index = np.argmin(dists)
        return index, self.embedding_[index], self.index_map[index]


In [56]:
# Build the index-tracking wrapper with its default hyperparameters.
umap_wrapper = UMAPWrapper()

In [57]:
# Fit on latent_space (defined in an earlier cell) and display the resulting embedding.
# NOTE(review): this dumps the whole array; consider showing only a slice such as [:5].
umap_wrapper.fit_transform(latent_space)

array([[2.1704125 , 5.87141   ],
       [0.6727694 , 5.859256  ],
       [2.7765727 , 3.6924806 ],
       [3.2505517 , 3.5533834 ],
       [1.7328081 , 5.3222904 ],
       [5.0537004 , 6.273857  ],
       [1.2288592 , 6.8608766 ],
       [2.0472836 , 5.1473317 ],
       [2.5902667 , 7.2860103 ],
       [2.101192  , 4.1712694 ],
       [1.8941754 , 6.4707294 ],
       [2.2534497 , 3.9728942 ],
       [1.4951508 , 5.0140705 ],
       [1.6394484 , 6.73142   ],
       [1.9728084 , 7.1558995 ],
       [2.6985154 , 6.5972548 ],
       [2.3424666 , 6.0553966 ],
       [3.3605783 , 6.853383  ],
       [4.2480083 , 7.5495744 ],
       [3.9663868 , 3.650904  ],
       [3.5873086 , 3.6235406 ],
       [3.5373967 , 7.7228045 ],
       [0.9750319 , 6.733413  ],
       [3.1375666 , 5.346933  ],
       [3.2241592 , 7.3273497 ],
       [2.2926    , 4.829034  ],
       [4.4211445 , 6.062355  ],
       [2.9910848 , 3.3319569 ],
       [5.30639   , 4.6951847 ],
       [3.8653345 , 4.6063294 ],
       [1.

In [58]:
# Look up the original 8-feature training row stored under index 5.
umap_wrapper.get_original_vector(5)

array([0.7619319 , 0.44991535, 0.54017935, 0.01770059, 0.51626347,
       0.75899629, 0.40786152, 0.62512568])

In [60]:
# Nearest-neighbour lookup in embedding space. The query is the first row of the embedding
# displayed above, so index 0 is the expected match.
umap_wrapper.get_vector_from_embedding(np.array([2.1704125 , 5.87141   ]))

(0,
 array([2.1704125, 5.87141  ], dtype=float32),
 array([0.60052374, 0.87028757, 0.50094329, 0.3488188 , 0.89264969,
        0.41704913, 0.92565311, 0.99018395]))