Skip to content

Commit

Permalink
add transform data functionality (#26)
Browse files Browse the repository at this point in the history
Co-authored-by: Sanjay Mohan <>
  • Loading branch information
sanjmohan committed May 14, 2020
1 parent 6bd0821 commit d4d854d
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 76 deletions.
27 changes: 26 additions & 1 deletion README.md
Expand Up @@ -26,13 +26,38 @@ UMAP can use a precomputed distance matrix instead of finding the nearest neighb
embedding = umap(distances, n_components; metric=:precomputed)
```

## Fitting a UMAP model to a dataset and transforming new data

### Constructing a model
To construct a model to use for embedding new data, use the constructor:
```jl
model = UMAP_(X, n_components; <kwargs>)
```
where the constructor takes the same keyword arguments (kwargs) as `umap`. The returned object has the following fields:
```jl
model.graph # The graph of fuzzy simplicial set membership strengths of each point in the dataset
model.embedding # The embedding of the dataset
model.data # A reference to the original dataset
model.knns # A matrix of indices of nearest neighbors of points in the dataset,
# as determined on the original manifold (may be approximate)
model.dists # The distances of the neighbors indicated by model.knns
```

### Embedding new data
To transform new data into the existing embedding of a UMAP model, use the `transform` function:
```jl
Q_embedding = transform(model, Q; <kwargs>)
```
where `Q` is a matrix of new query data to embed into the existing embedding, and `model` is the object obtained from the `UMAP_` call above. `Q` must come from a space of the same dimensionality as `model.data` (ie `X` in the `UMAP_` call above).

The remaining keyword arguments (kwargs) are the same as for above functions.

## Implementation Details
There are two main steps involved in UMAP: building a weighted graph with edges connecting points to their nearest neighbors, and optimizing the low-dimensional embedding of that graph. The first step is accomplished either by an exact kNN search (for datasets with `< 4096` points) or by the approximate kNN search algorithm, [NNDescent](https://github.com/dillondaudert/NearestNeighborDescent.jl). This step is also usually the most costly.

The low-dimensional embedding is initialized (by default) with the eigenvectors of the normalized Laplacian of the kNN graph. These are found using ARPACK (via [Arpack.jl](https://github.com/JuliaLinearAlgebra/Arpack.jl)).

## Current Limitations
- **No transform**: Only one-time embeddings are possible at the moment. That is to say, it isn't possible to "fit" UMAP to a dataset and then use it to "transform" new data
- **Input data types**: Only data points that are represented by vectors of numbers (passed in as a matrix) are valid inputs. This is mostly due to a lack of support for other formats in [NNDescent](https://github.com/dillondaudert/NearestNeighborDescent.jl). Support for e.g. string datasets is possible in the future
- **Sequential**: This implementation does not take advantage of any parallelism

Expand Down
2 changes: 1 addition & 1 deletion src/UMAP.jl
Expand Up @@ -11,6 +11,6 @@ include("utils.jl")
include("embeddings.jl")
include("umap_.jl")

export umap, UMAP_
export umap, UMAP_, transform

end # module
63 changes: 46 additions & 17 deletions src/embeddings.jl
Expand Up @@ -19,6 +19,21 @@ function initialize_embedding(graph::AbstractMatrix{T}, n_components, ::Val{:ran
return [20 .* rand(T, n_components) .- 10 for _ in 1:size(graph, 1)]
end

"""
initialize_embedding(graph::AbstractMatrix{<:Real}, ref_embedding::AbstractMatrix{T<:AbstractFloat}) -> embedding
Initialize an embedding of points corresponding to the columns of the `graph`, by taking weighted means of
the columns of `ref_embedding`, where weights are values from the rows of the `graph`.
The resulting embedding will have shape `(size(ref_embedding, 1), size(graph, 2))`, where `size(ref_embedding, 1)`
is the number of components (dimensions) of the `reference embedding`, and `size(graph, 2)` is the number of
samples in the resulting embedding. Its elements will have type T.
"""
function initialize_embedding(graph::AbstractMatrix{<:Real}, ref_embedding::AbstractMatrix{T})::Vector{Vector{T}} where {T<:AbstractFloat}
embed = (ref_embedding * graph) ./ (sum(graph, dims=1) .+ eps(T))
return Vector{T}[eachcol(embed)...]
end

"""
spectral_layout(graph, embed_dim) -> embedding
Expand Down Expand Up @@ -46,29 +61,39 @@ function spectral_layout(graph::SparseMatrixCSC{T},
end

"""
optimize_embedding(graph, embedding, n_epochs, initial_alpha, min_dist, spread, gamma, neg_sample_rate) -> embedding
optimize_embedding(graph, query_embedding, ref_embedding, n_epochs, initial_alpha, min_dist, spread, gamma, neg_sample_rate, _a=nothing, _b=nothing; move_ref=false) -> embedding
Optimize an embedding by minimizing the fuzzy set cross entropy between the high and low dimensional simplicial sets using stochastic gradient descent.
Optimize "query" samples with respect to "reference" samples.
# Arguments
- `graph`: a sparse matrix of shape (n_samples, n_samples)
- `embedding`: a vector of length (n_samples,) of vectors representing the embedded data points
- `query_embedding`: a vector of length (n_samples,) of vectors representing the embedded data points to be optimized ("query" samples)
- `ref_embedding`: a vector of length (n_samples,) of vectors representing the embedded data points to optimize against ("reference" samples)
- `n_epochs`: the number of training epochs for optimization
- `initial_alpha`: the initial learning rate
- `gamma`: the repulsive strength of negative samples
- `neg_sample_rate::Integer`: the number of negative samples per positive sample
- `neg_sample_rate`: the number of negative samples per positive sample
- `_a`: this controls the embedding. If the actual argument is `nothing`, this is determined automatically by `min_dist` and `spread`.
- `_b`: this controls the embedding. If the actual argument is `nothing`, this is determined automatically by `min_dist` and `spread`.
# Keyword Arguments
- `move_ref::Bool = false`: if true, also improve the embeddings in `ref_embedding`, else fix them and only improve embeddings in `query_embedding`.
"""
function optimize_embedding(graph,
embedding,
query_embedding,
ref_embedding,
n_epochs,
initial_alpha,
min_dist,
spread,
gamma,
neg_sample_rate,
_a=nothing,
_b=nothing)
_b=nothing;
move_ref::Bool=false)
a, b = fit_ab(min_dist, spread, _a, _b)
self_reference = query_embedding === ref_embedding

alpha = initial_alpha
for e in 1:n_epochs
Expand All @@ -77,34 +102,38 @@ function optimize_embedding(graph,
j = rowvals(graph)[ind]
p = nonzeros(graph)[ind]
if rand() <= p
sdist = evaluate(SqEuclidean(), embedding[i], embedding[j])
sdist = evaluate(SqEuclidean(), query_embedding[i], ref_embedding[j])
if sdist > 0
delta = (-2 * a * b * sdist^(b-1))/(1 + a*sdist^b)
else
delta = 0
end
@simd for d in eachindex(embedding[i])
grad = clamp(delta * (embedding[i][d] - embedding[j][d]), -4, 4)
embedding[i][d] += alpha * grad
embedding[j][d] -= alpha * grad
@simd for d in eachindex(query_embedding[i])
grad = clamp(delta * (query_embedding[i][d] - ref_embedding[j][d]), -4, 4)
query_embedding[i][d] += alpha * grad
if move_ref
ref_embedding[j][d] -= alpha * grad
end
end

for _ in 1:neg_sample_rate
k = rand(1:size(graph, 2))
i != k || continue
sdist = evaluate(SqEuclidean(), embedding[i], embedding[k])
k = rand(eachindex(ref_embedding))
if i == k && self_reference
continue
end
sdist = evaluate(SqEuclidean(), query_embedding[i], ref_embedding[k])
if sdist > 0
delta = (2 * gamma * b) / ((1//1000 + sdist)*(1 + a*sdist^b))
else
delta = 0
end
@simd for d in eachindex(embedding[i])
@simd for d in eachindex(query_embedding[i])
if delta > 0
grad = clamp(delta * (embedding[i][d] - embedding[k][d]), -4, 4)
grad = clamp(delta * (query_embedding[i][d] - ref_embedding[k][d]), -4, 4)
else
grad = 4
end
embedding[i][d] += alpha * grad
query_embedding[i][d] += alpha * grad
end
end

Expand All @@ -114,5 +143,5 @@ function optimize_embedding(graph,
alpha = initial_alpha*(1 - e//n_epochs)
end

return embedding
return query_embedding
end

0 comments on commit d4d854d

Please sign in to comment.