# Uniform Manifold Approximation and Projection (UMAP)

- UMAP이 t-SNE보다 속도가 현저히 빠름
- 일반화된 embedding 차원: 시각화 용도인 t-SNE와 다르게 UMAP은 embedding 차원의 크기에 대한 제한이 없기 때문에, 일반적인 차원 축소 알고리즘으로 적용 가능
- Global structure: 전체적인 manifold 구조를 더 잘 보존함
- 탄탄한 이론적 배경: 리만 기하학과 위상 수학에 기반 (저자에 의하면, 이론적 디테일은 모르는 게 정신건강에 이롭다고 함)
- Paper: https://arxiv.org/abs/1802.03426

## Import

In [1]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import warnings

# Install a blanket "ignore" filter so library warnings do not clutter the
# notebook output. This also hides deprecation notices.
warnings.filterwarnings('ignore')

## Configuration

In [2]:
# Matplotlib defaults: draw images with the gray colormap.
# (pyplot re-exports matplotlib's rcParams object, so this is the same setting.)
mpl.rcParams['image.cmap'] = 'gray'

# Ten-colour categorical palette (taken from an R ggplot colormap);
# entry i is the colour used for class i.
color = [
    '#6388b4',
    '#ffae34',
    '#ef6f6a',
    '#8cc2ca',
    '#55ad89',
    '#c3bc3f',
    '#bb7693',
    '#baa094',
    '#a9b5ae',
    '#767676',
]

## Load dataset

In [3]:
# Read the MNIST training table from disk into a DataFrame.
mnist = pd.read_csv(filepath_or_buffer='../data/mnist_train.csv')
# Last expression of the cell: preview the first five rows.
mnist.head(n=5)

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
# Keep the target column separately before stripping it from the features.
label = mnist['label']
# Besides the target, also discard 'Unnamed: 0' when it is present: pandas
# gives that name to a header-less CSV column, which is what a saved row
# index looks like after a round trip, and it must not be embedded as if it
# were a pixel feature. errors='ignore' keeps the cell working when the
# column does not exist ('label' is known to exist from the line above).
mnist.drop(columns=['label', 'Unnamed: 0'], inplace=True, errors='ignore')

## UMAP result

In [5]:
%%time
from umap.umap_ import UMAP

umap = UMAP(random_state=0)
mnist_umap = umap.fit_transform(mnist, label)

CPU times: user 2min 17s, sys: 1min 12s, total: 3min 29s
Wall time: 1min 54s


In [6]:
import plotly.graph_objects as go

fig = go.Figure()

# One scatter trace per class 0-9 so each gets its own colour and legend entry.
for digit in range(10):
    # Boolean mask of the samples whose label equals this class.
    is_digit = label == digit
    trace = go.Scatter(
        x=mnist_umap[:, 0][is_digit],
        y=mnist_umap[:, 1][is_digit],
        name=str(digit),
        opacity=0.6,
        mode='markers',
        marker={'color': color[digit]},
    )
    fig.add_trace(trace)

# Tie the y axis scale to the x axis (ratio 1).
equal_aspect_yaxis = {'scaleanchor': 'x', 'scaleratio': 1}
# Horizontal legend positioned at y=1.02, anchored bottom/right at x=1.
top_right_legend = {
    'orientation': 'h',
    'yanchor': 'bottom',
    'y': 1.02,
    'xanchor': 'right',
    'x': 1,
}

fig.update_layout(
    width=800,
    height=800,
    title='UMAP result',
    yaxis=equal_aspect_yaxis,
    legend=top_right_legend,
)

fig.show()