# Trivialization

<a target="_blank" href="https://colab.research.google.com/github/numqi/numqi/blob/main/docs/foundation/manifold/trivialization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

[arxiv-link](https://arxiv.org/abs/1909.09501) Trivializations for Gradient-Based Optimization on Manifolds

[arxiv-link](https://arxiv.org/abs/2203.04794) Geometric Optimisation on Manifolds with Applications to Deep Learning

**Trivialization**: given a manifold $\mathcal{M}$, a trivialization is a surjective map from a Euclidean space onto the manifold:

$$ \phi:\mathbb{R}^n\to \mathcal{M} $$

A constrained optimization problem over some manifold $\mathcal{M}$

$$ \min_{x\in\mathcal{M}} f(x) $$

can be converted to an unconstrained optimization problem via a trivialization:

$$ \min_{\theta\in\mathbb{R}^n} f(\phi(\theta)) $$
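The following minimal NumPy sketch makes the conversion concrete. It does not use `numqi`. It minimizes the toy objective $f(x)=x+1/x$ over $\mathbb{R}_+$ through the trivialization $\phi(\theta)=\exp(\theta)$, using hand-written gradient descent on $\theta$. The objective, step size and iteration count are illustrative choices. Because $f$ is unbounded below on all of $\mathbb{R}$, the constraint genuinely matters.

```python
import numpy as np

# Objective on the manifold R_+ = {x > 0}. On all of R it is unbounded below
# (f -> -inf as x -> 0 from the left), so the constraint x > 0 matters.
def f(x):
    return x + 1.0 / x

# Trivialization phi: R -> R_+, here the exponential map.
def phi(theta):
    return np.exp(theta)

# Gradient of the pulled-back objective f(phi(theta)) via the chain rule:
# f'(x) * phi'(theta) = (1 - 1/x**2) * x
def grad_theta(theta):
    x = phi(theta)
    return (1.0 - 1.0 / x**2) * x

theta = 1.5  # arbitrary unconstrained starting point
lr = 0.1
for _ in range(200):
    theta -= lr * grad_theta(theta)  # plain gradient descent, no projection needed

x_opt = phi(theta)
print(x_opt, f(x_opt))  # approaches the constrained minimizer x = 1, f = 2
```

Every iterate $\phi(\theta)$ is feasible by construction. The optimizer therefore never needs a projection or a barrier term.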

In this notebook, we will list some common manifolds and their trivializations.

## Math notation

1. $\mathbb{R}$: real numbers
2. $\mathbb{R}_+$: positive real numbers
3. $\mathbb{R}^d$: $d$-dimensional real vectors
4. $x\in\mathbb{R}^d,x> 0$: $x$ is a $d$-dimensional real vector and all elements are positive
5. $x\in\mathbb{R}^{d},x\succeq 0$: $x$ is a $d$-dimensional real vector and all elements are non-negative
6. $\mathbb{R}^{m\times n}$: $m\times n$ real matrices
7. $x\in\mathbb{R}^{m\times m},x\succ 0$: $x$ is an $m\times m$ real symmetric positive definite matrix (all eigenvalues are positive)
8. $x\in\mathbb{R}^{m\times m},x\succeq 0$: $x$ is an $m\times m$ real symmetric positive semi-definite matrix (all eigenvalues are non-negative)

In [None]:
import numpy as np
import torch

try:
    import numqi
except ImportError:
    %pip install numqi
    import numqi


## Positive real number

$$ \mathbb{R}_+ = \{x\in\mathbb{R}:x>0\} $$

One trivialization is the SoftPlus function:

$$ \phi(\theta) = \log(1+\exp(\theta)):\mathbb{R}\to\mathbb{R}_+ $$

Another trivialization is the Exponential function:

$$ \phi(\theta) = \exp(\theta):\mathbb{R}\to\mathbb{R}_+ $$
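Here is a minimal NumPy sketch of both maps. It is independent of how `numqi` implements them. `np.logaddexp(0, theta)` evaluates SoftPlus without overflow for large $\theta$. The helper name `softplus_inverse` is hypothetical. It computes the inverse $\theta=\log(e^x-1)$, which is handy for initializing $\theta$ from a chosen positive value.

```python
import numpy as np

def softplus(theta):
    # log(1 + exp(theta)) computed stably as logaddexp(0, theta)
    return np.logaddexp(0.0, theta)

def softplus_inverse(x):
    # Inverse map R_+ -> R: theta = log(exp(x) - 1), using expm1 for accuracy
    return np.log(np.expm1(x))

theta = np.array([-50.0, -1.0, 0.0, 1.0, 50.0])
x_softplus = softplus(theta)
x_exp = np.exp(theta)
print('softplus:', x_softplus)
print('exp:', x_exp)

# Both images are strictly positive; softplus grows only linearly for large theta.
# Round trip: recover theta from a target positive value (useful for initialization).
target = np.array([0.1, 1.0, 5.0])
print(softplus(softplus_inverse(target)))  # ~ [0.1, 1.0, 5.0]
```

Both maps are bijections onto $\mathbb{R}_+$. They differ in their growth: the exponential grows much faster than SoftPlus for large $\theta$, and this affects gradient scales during optimization.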

In [None]:
manifold = numqi.manifold.PositiveReal(batch_size=5, method='softplus')
point = manifold().detach().numpy() #random point
print('softplus:', point)

manifold = numqi.manifold.PositiveReal(batch_size=5, method='exp')
point = manifold().detach().numpy() #random point
print('exp:', point)

## Discrete Probability Simplex

$$ \Delta^{n-1}_+ = \{x\in\mathbb{R}^n:x_i>0,x_1+x_2+\cdots +x_n = 1\} $$


Trivializations can be composed: let $g$ be any trivialization of $\mathbb{R}_+$; then

$$ \phi(\theta)=(x_1,x_2,\cdots,x_n):\mathrm{dom}(g)^n\to\Delta_+^{n-1} $$

with

$$ x_i = \frac{g(\theta_i)}{\sum_j g(\theta_j)} $$

gives a trivialization of $\Delta_+^{n-1}$. Specifically, the SoftMax function corresponds to $g(\theta) = \exp(\theta)$.
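Below is a NumPy sketch of this composition. The helper names `simplex_map` and `softmax` are illustrative and are not `numqi` functions. The sketch checks that $g=\exp$ reproduces SoftMax and that the image lies on the simplex. It also shows that the map is not injective.

```python
import numpy as np

def simplex_map(theta, g):
    # Compose a trivialization g of R_+ with normalization: x_i = g(theta_i) / sum_j g(theta_j)
    y = g(theta)
    return y / np.sum(y, axis=-1, keepdims=True)

def softmax(theta):
    # g = exp, shifted by max(theta) for numerical stability (the shift cancels in the ratio)
    z = np.exp(theta - np.max(theta, axis=-1, keepdims=True))
    return z / np.sum(z, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
theta = rng.normal(size=5)
x_exp = simplex_map(theta, np.exp)
x_softplus = simplex_map(theta, lambda t: np.logaddexp(0.0, t))
print(x_exp, x_exp.sum())
print(x_softplus, x_softplus.sum())
print(np.allclose(x_exp, softmax(theta)))  # True: g = exp gives SoftMax

# The map is surjective but not injective: with g = exp, shifting all theta_i
# by the same constant leaves x unchanged.
print(np.allclose(softmax(theta + 3.0), softmax(theta)))  # True
```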

In [None]:
manifold = numqi.manifold.DiscreteProbability(5)
point = manifold().detach().numpy()
print('point:', point)
print('sum(point):', np.sum(point))

## Sphere

[wiki-link](https://en.wikipedia.org/wiki/N-sphere)

The sphere $S^n$ is defined as

$$ S^n = \{x\in\mathbb{R}^{n+1}:\lVert x\rVert_2=1\} $$

The following quotient map is a trivialization of the sphere:

$$\phi(\theta)=\frac{\theta}{\lVert \theta\rVert}: \mathbb{R}^{n+1}\to S^n$$

*PS*: the map is undefined at the origin $\theta=0$, and its gradient blows up near it, so initialization needs some care. In practice we simply hope the optimization never approaches the origin, and we reinitialize if it does (such cases seem rare with random initialization).
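Here is a NumPy sketch of the quotient map. The helper name `sphere_map` is illustrative. The sketch also computes the Jacobian $(I-xx^T)/\lVert\theta\rVert$, which shows why iterates near the origin are problematic: the Jacobian scales like $1/\lVert\theta\rVert$.

```python
import numpy as np

def sphere_map(theta):
    # Quotient map R^{n+1} \ {0} -> S^n: divide by the Euclidean norm
    norm = np.linalg.norm(theta)
    if norm == 0.0:
        raise ValueError('theta = 0 has no image on the sphere; reinitialize')
    return theta / norm

rng = np.random.default_rng(0)
theta = rng.normal(size=6)  # n + 1 = 6, so the point lies on S^5
x = sphere_map(theta)
print(x, np.linalg.norm(x))  # norm is 1 up to rounding

# Positive rescaling of theta gives the same point, so the map is many-to-one.
print(np.allclose(sphere_map(2.5 * theta), x))  # True

# Jacobian of theta / ||theta|| is (I - x x^T) / ||theta||: it blows up as theta -> 0.
J = (np.eye(theta.size) - np.outer(x, x)) / np.linalg.norm(theta)
print(np.allclose(J @ x, 0.0))  # True: the radial direction is in the kernel
```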

In [None]:
dim = 5
manifold = numqi.manifold.Sphere(dim, method='quotient')
point = manifold().detach().numpy()
print('point:', point)
print('norm:', np.linalg.norm(point))

## Stiefel Manifold

[wiki-link](https://en.wikipedia.org/wiki/Stiefel_manifold)

The Stiefel manifold $\mathrm{St}(n,r)$ is defined as

$$ \mathrm{St}(n,r) = \{X\in\mathbb{R}^{n\times r}:X^TX=I_r\} $$

The QR decomposition gives a trivialization of the Stiefel manifold:

$$ \phi(\theta)=Q: \mathbb{R}^{n\times r}\to \mathrm{St}(n,r) $$

where $\theta=QR$ is the (reduced) QR decomposition of $\theta$, and $Q\in\mathbb{R}^{n\times r}$ has orthonormal columns. Choosing the signs so that $R$ has a positive diagonal makes $Q$ unique whenever $\theta$ has full column rank.

PS: if the rank of $\theta$ is smaller than $r$, the QR decomposition no longer determines $Q$ uniquely, and the map is not continuous at such points. As with the sphere, the only remedies are to hope the optimization never reaches these singular points or to reinitialize when it does (such cases seem rare with random initialization).
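Here is a NumPy sketch of the QR trivialization with the sign convention $\mathrm{diag}(R)>0$. The helper name `stiefel_qr` is illustrative, and `numqi` may handle signs differently. Right-multiplying $\theta$ by an upper-triangular matrix with positive diagonal leaves $Q$ unchanged, which shows that the map is many-to-one.

```python
import numpy as np

def stiefel_qr(theta):
    # Reduced QR: theta (n x r) = Q (n x r) @ R (r x r)
    Q, R = np.linalg.qr(theta, mode='reduced')
    # Flip column signs so that diag(R) > 0; this makes Q unique for full-rank theta
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0  # rank-deficient theta: no unique choice exists anyway
    return Q * signs  # multiplies column j of Q by signs[j]

rng = np.random.default_rng(0)
n, r = 5, 3
theta = rng.normal(size=(n, r))
X = stiefel_qr(theta)
print(np.allclose(X.T @ X, np.eye(r)))  # True: columns are orthonormal

# theta @ T = Q (R T) with R T upper triangular and positive on the diagonal,
# so the same Q comes out.
T = np.triu(rng.normal(size=(r, r)), k=1) + np.diag(rng.uniform(0.5, 2.0, size=r))
print(np.allclose(stiefel_qr(theta @ T), X))  # True
```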

TODO [doi-link](https://doi.org/10.1109/ICASSP39728.2021.9414157) A Global Cayley Parametrization of Stiefel Manifold for Direct Utilization of Optimization Mechanisms Over Vector Spaces

In [None]:
manifold = numqi.manifold.Stiefel(dim=5, rank=3, method='qr')
point = manifold().detach().numpy()
print('point:', point, sep='\n')
print('\nX^T X:', point.T @ point, sep='\n')

## Symmetric and Hermitian Matrices

[wiki-link/symmetric-matrix](https://en.wikipedia.org/wiki/Symmetric_matrix)

[wiki-link/Hermitian-matrix](https://en.wikipedia.org/wiki/Hermitian_matrix)

The set of symmetric matrices $\mathrm{Sym}^n$ is defined as

$$ \mathrm{Sym}^n = \{X\in\mathbb{R}^{n\times n}:X=X^T\} $$

The set of Hermitian matrices $\mathrm{Herm}^n$ is defined as

$$ \mathrm{Herm}^n = \{X\in\mathbb{C}^{n\times n}:X=X^\dagger\} $$

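Both of them are (real) vector spaces, so we can choose a basis $\{E_i\}$. For example, for Hermitian matrices the Gell-Mann matrices (see [tutorial/gellmann](../../gellmann)) together with the identity matrix form a basis. The trivialization is then just the expansion in this basis, i.e. the inverse of vectorization:

$$\phi(\theta)=\sum_i\theta_iE_i $$

with $\theta\in\mathbb{R}^{n(n+1)/2}$ for $\mathrm{Sym}^n$ and $\theta\in\mathbb{R}^{n^2}$ for $\mathrm{Herm}^n$.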
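Here is a NumPy sketch of $\phi(\theta)=\sum_i\theta_iE_i$ for $\mathrm{Herm}^n$. It uses a simpler basis than the Gell-Mann matrices: $E_{jj}$, $E_{jk}+E_{kj}$ and $i(E_{jk}-E_{kj})$ for $j<k$. The function names are hypothetical. Dropping the imaginary-part coefficients gives the real symmetric case $\mathrm{Sym}^n$.

```python
import numpy as np

def hermitian_from_vector(theta, n):
    # phi(theta) = sum_i theta_i E_i with basis E_jj, (E_jk + E_kj), i(E_jk - E_kj), j < k
    theta = np.asarray(theta, dtype=np.float64)
    assert theta.shape == (n * n,)
    iu = np.triu_indices(n, k=1)
    m = iu[0].size  # n(n-1)/2 strictly upper-triangular entries
    X = np.zeros((n, n), dtype=np.complex128)
    X[np.diag_indices(n)] = theta[:n]                 # real diagonal
    X[iu] = theta[n:n + m] + 1j * theta[n + m:]       # upper triangle: real + imaginary parts
    return X + np.triu(X, k=1).conj().T               # fill the lower triangle by Hermiticity

def vector_from_hermitian(X):
    # Inverse map (vectorization): read the coefficients back off the matrix
    n = X.shape[0]
    iu = np.triu_indices(n, k=1)
    return np.concatenate([np.diag(X).real, X[iu].real, X[iu].imag])

rng = np.random.default_rng(0)
n = 3
theta = rng.normal(size=n * n)
X = hermitian_from_vector(theta, n)
print(np.allclose(X, X.conj().T))                    # True: X is Hermitian
print(np.allclose(vector_from_hermitian(X), theta))  # True: the map is a bijection
```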

In [None]:
manifold_sym = numqi.manifold.SymmetricMatrix(3)
point = manifold_sym().detach().numpy()
print('point:', point, sep='\n')

In [None]:
manifold_herm = numqi.manifold.SymmetricMatrix(3, dtype=torch.complex128)
point = manifold_herm().detach().numpy()
print('point:', point, sep='\n')

## Rank-$r$ Positive Semi-Definite Matrices

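The set of real rank-$r$ positive semi-definite matrices $\mathrm{Sym}^{(n,r)}_+$ is defined as

$$ \mathrm{Sym}^{(n,r)}_+ = \{X\in\mathbb{R}^{n\times n}:X\succeq 0,\mathrm{rank}(X)=r\} $$

To make the set bounded, `numqi` additionally requires the trace of the matrix to be 1. The trivialization is the inverse of the Cholesky decomposition:

$$ \phi(\theta)=g(\theta)g(\theta)^T: \mathrm{dom}(g)\to \mathrm{Sym}^{(n,r)}_+ $$

where $g$ is a trivialization of the set $L^{(n,r)}_+$ of $n\times r$ lower-triangular matrices with positive diagonal, i.e. $\mathrm{image}(g)=L^{(n,r)}_+$ with

$$ L^{(n,r)}_+=\{X\in\mathbb{R}^{n\times r}: X_{ii}>0,\ X_{ij}=0\ \text{for}\ j>i\}$$

The top $r\times r$ block of $g(\theta)$ is triangular with a positive diagonal. Hence $g(\theta)$ has full column rank, and $g(\theta)g(\theta)^T$ has rank exactly $r$. The trace-1 constraint can then be imposed by dividing by $\mathrm{tr}(g(\theta)g(\theta)^T)$.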
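Here is a NumPy sketch of the construction. It makes two assumptions, which may differ from what `numqi` does internally. First, $g$ applies SoftPlus to the diagonal and keeps the strictly lower entries. Second, the trace constraint is imposed by dividing by the trace. The helper name `trace1_psd_cholesky` is hypothetical. Entries of $\theta$ above the diagonal are ignored, so this parametrization is slightly redundant.

```python
import numpy as np

def trace1_psd_cholesky(theta):
    # theta: unconstrained (n x r) real matrix with r <= n; entries above the diagonal are ignored
    n, r = theta.shape
    assert r <= n
    L = np.tril(theta)  # lower-triangular n x r (new array, theta is untouched)
    idx = np.arange(r)
    L[idx, idx] = np.logaddexp(0.0, theta[idx, idx])  # softplus -> positive diagonal
    X = L @ L.T  # positive semi-definite with rank exactly r
    return X / np.trace(X)  # normalize to unit trace

rng = np.random.default_rng(0)
n, r = 4, 2
X = trace1_psd_cholesky(rng.normal(size=(n, r)))
eig = np.linalg.eigvalsh(X)  # ascending order
print(eig)          # n - r eigenvalues ~ 0, r eigenvalues > 0
print(np.trace(X))  # 1.0 up to rounding
```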

In [None]:
manifold = numqi.manifold.Trace1PSD(dim=4, rank=2, method='cholesky')
point = manifold().detach().numpy()
print(point)
print('eigenvalue:', np.linalg.eigvalsh(point)) # rank 2 in dim 4: two eigenvalues are positive, the other two are zero up to machine precision
print('trace(point):', np.trace(point))


Similarly, we can define the set of trace-1 complex rank-$r$ positive semi-definite matrices $\mathrm{Herm}^{(n,r)}_+$

$$ \mathrm{Herm}^{(n,r)}_+ = \{X\in\mathbb{C}^{n\times n}:X\succeq 0,\mathrm{rank}(X)=r,\mathrm{trace}(X)=1\} $$

In [None]:
manifold = numqi.manifold.Trace1PSD(dim=3, rank=2, method='cholesky', dtype=torch.complex128)
point = manifold().detach().numpy()
print(point)
print('eigenvalue:', np.linalg.eigvalsh(point))
print('trace(point):', np.trace(point))

## Special Orthogonal Group

[wiki/orthogonal-group](https://en.wikipedia.org/wiki/Orthogonal_group)

The special orthogonal group $\mathrm{SO}(n)$ is defined as

$$ \mathrm{SO}(n) = \{X\in\mathbb{R}^{n\times n}:X^TX=I_n,\det(X)=1\} $$

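Possible trivializations include the Cayley transform ([wiki-link](https://en.wikipedia.org/wiki/Cayley_transform)) and the [matrix exponential](https://en.wikipedia.org/wiki/Matrix_exponential), both applied to a skew-symmetric matrix $A=-A^T$. The matrix exponential is surjective onto $\mathrm{SO}(n)$. The Cayley transform misses the rotations that have $-1$ as an eigenvalue.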
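Here is a NumPy-only sketch of both maps. The helper names `skew`, `so_exp` and `so_cayley` are illustrative. The exponential is computed by diagonalizing the Hermitian matrix $-iA$, which avoids a dependency on `scipy.linalg.expm`. The sketch is not necessarily how `numqi` computes it.

```python
import numpy as np

def skew(theta):
    # Map an unconstrained n x n matrix to a skew-symmetric one: A = -A^T
    return theta - theta.T

def so_exp(A):
    # Matrix exponential of a real skew-symmetric A via the Hermitian matrix -iA:
    # -iA = V diag(w) V^H  =>  exp(A) = V diag(exp(i w)) V^H
    w, V = np.linalg.eigh(-1j * A)
    return ((V * np.exp(1j * w)) @ V.conj().T).real

def so_cayley(A):
    # Cayley transform (I - A)^{-1} (I + A); I - A is always invertible for skew-symmetric A
    I = np.eye(A.shape[0])
    return np.linalg.solve(I - A, I + A)

rng = np.random.default_rng(0)
n = 4
A = skew(rng.normal(size=(n, n)))
for name, Q in [('exp', so_exp(A)), ('cayley', so_cayley(A))]:
    print(name, np.allclose(Q.T @ Q, np.eye(n)), np.linalg.det(Q))  # True, ~1.0
```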

Similarly, the special unitary group $\mathrm{SU}(n)$ is defined as

$$ \mathrm{SU}(n) = \{X\in\mathbb{C}^{n\times n}:X^\dagger X=I_n,\det(X)=1\} $$
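For $\mathrm{SU}(n)$, one possible trivialization is $U=\exp(iH)$, with $H$ a traceless Hermitian matrix built from an unconstrained complex $\theta$. It works because $\det\exp(iH)=\exp(i\,\mathrm{tr}H)=1$. Below is a NumPy sketch under that assumption. The helper name `su_exp` is hypothetical, and `numqi` may parametrize $\mathrm{SU}(n)$ differently.

```python
import numpy as np

def su_exp(theta):
    # theta: unconstrained complex n x n matrix -> traceless Hermitian H -> U = exp(iH)
    n = theta.shape[0]
    H = (theta + theta.conj().T) / 2                 # Hermitian part
    H = H - (np.trace(H).real / n) * np.eye(n)       # traceless, so det(U) = exp(i tr H) = 1
    w, V = np.linalg.eigh(H)
    return (V * np.exp(1j * w)) @ V.conj().T         # V diag(exp(i w)) V^H

rng = np.random.default_rng(0)
n = 3
theta = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
U = su_exp(theta)
print(np.allclose(U.conj().T @ U, np.eye(n)))  # True: unitary
print(np.linalg.det(U))                        # ~ 1 + 0j
```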

In [None]:
manifold_so = numqi.manifold.SpecialOrthogonal(dim=4)
point = manifold_so().detach().numpy()
print('point:', point, sep='\n')
print('point.T @ point:', point.T @ point, sep='\n')

In [None]:
manifold_su = numqi.manifold.SpecialOrthogonal(dim=3, dtype=torch.complex128)
point = manifold_su().detach().numpy()
print('point:', point, sep='\n')
print('point.H @ point:', point.T.conj() @ point, sep='\n')

## Connection with Quantum Information

TODO

1. pure quantum states
2. Hamiltonian
3. density matrix
4. quantum gate
5. quantum channel


In [None]:
fig,ax = numqi.manifold.plot_qobject_trivialization_map()

In [None]:
fig,ax = numqi.manifold.plot.plot_cha_trivialization_map()

In [None]:
fig,ax = numqi.manifold.plot.plot_tensor_rank_sigmar_trivialization_map()

In [None]:
fig,ax = numqi.manifold.plot.plot_pureb_trivialization_map()

In [None]:
fig,ax = numqi.manifold.plot.plot_uda_trivialization_map()

In [None]:
fig,ax = numqi.manifold.plot.plot_udp_trivialization_map()