<a href="https://colab.research.google.com/github/garfield-gray/Optimization/blob/main/Convex/TotalLeastSquaresMonifold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preparation and Test

In [12]:
!pip uninstall -y jax jaxlib pymanopt
!pip install jax==0.3.25 jaxlib==0.3.25
!pip install pymanopt==2.2.0

import pymanopt


Found existing installation: jax 0.4.30
Uninstalling jax-0.4.30:
  Successfully uninstalled jax-0.4.30
Found existing installation: jaxlib 0.4.30
Uninstalling jaxlib-0.4.30:
  Successfully uninstalled jaxlib-0.4.30
Found existing installation: pymanopt 2.2.0
Uninstalling pymanopt-2.2.0:
  Successfully uninstalled pymanopt-2.2.0
Collecting jax==0.3.25
  Using cached jax-0.3.25.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
[31mERROR: Could not find a version that satisfies the requirement jaxlib==0.3.25 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29, 0.4.30)[0m[31m
[0m[31mERROR: No matching distribution found for jaxlib==0.3.25[0m[31m
[0mCollecting pymanopt==2.2.0
  Using cached pymanopt-2.2.0-py3-none-any.whl (71 kB)
Installing collected packages: pymanopt
Successfully installed pymanopt-2.2.0


In [13]:
!pip install jax
import jax

Collecting jax
  Using cached jax-0.4.30-py3-none-any.whl (2.0 MB)
Collecting jaxlib<=0.4.30,>=0.4.27 (from jax)
  Using cached jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
Installing collected packages: jaxlib, jax
Successfully installed jax-0.4.30 jaxlib-0.4.30


In [14]:
import jax.numpy as jnp
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX

dim = 3
manifold = pymanopt.manifolds.Sphere(dim)

matrix = jax.random.uniform(key, shape=(dim,dim))
# matrix = anp.random.normal(size=(dim, dim))
matrix = 0.5 * (matrix + matrix.T)

@pymanopt.function.autograd(manifold)
def cost(point):
    return -point @ matrix @ point

problem = pymanopt.Problem(manifold, cost)

optimizer = pymanopt.optimizers.SteepestDescent()
result = optimizer.run(problem)

eigenvalues, eigenvectors = jnp.linalg.eig(matrix)


dominant_eigenvector = eigenvectors[:, eigenvalues.real.argmax()]

print("Dominant eigenvector:", dominant_eigenvector)
print("Pymanopt solution:", result.point)

Optimizing...
Iteration    Cost                       Gradient norm     
---------    -----------------------    --------------    
   1         -3.8244095444679260e-01    1.38542077e+00    
   2         -1.4024827480316162e+00    7.87703283e-01    
   3         -1.5149493217468262e+00    9.32544544e-02    
   4         -1.5163631439208984e+00    1.85166409e-02    
   5         -1.5164027214050293e+00    1.04554883e-02    
   6         -1.5164198875427246e+00    1.78467593e-03    
   7         -1.5164202451705933e+00    4.07885959e-04    
   8         -1.5164203643798828e+00    4.01647558e-04    
   9         -1.5164203643798828e+00    1.46429843e-04    
Terminated - min step_size reached after 9 iterations, 0.02 seconds.

Dominant eigenvector: [0.70809877+0.j 0.49849728+0.j 0.5000967 +0.j]
Pymanopt solution: [0.70811704 0.49850561 0.5000624 ]


# Manopt first function

In [15]:
import jax.numpy as jnp
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX

dim = 3
manifold = pymanopt.manifolds.Sphere(dim)

matrix = jax.random.uniform(key, shape=(dim,dim))

matrix = 0.5 * (matrix + matrix.T)

@pymanopt.function.autograd(manifold)
def cost(point):
    return -point @ matrix @ point

problem = pymanopt.Problem(manifold, cost)

optimizer = pymanopt.optimizers.SteepestDescent()
result = optimizer.run(problem)

eigenvalues, eigenvectors = jnp.linalg.eig(matrix)


dominant_eigenvector = eigenvectors[:, eigenvalues.real.argmax()]

print("Dominant eigenvector:", dominant_eigenvector)
print("Pymanopt solution:", result.point)

Optimizing...
Iteration    Cost                       Gradient norm     
---------    -----------------------    --------------    
   1         -1.1613118648529053e+00    1.26149272e+00    
   2         -1.4093770980834961e+00    7.65685852e-01    
   3         -1.5139479637145996e+00    1.21127525e-01    
   4         -1.5161575078964233e+00    3.99360485e-02    
   5         -1.5164178609848022e+00    4.22774009e-03    
   6         -1.5164191722869873e+00    2.84098695e-03    
   7         -1.5164201259613037e+00    5.65828538e-04    
   8         -1.5164201259613037e+00    9.90259813e-04    
Terminated - min step_size reached after 8 iterations, 0.03 seconds.

Dominant eigenvector: [0.70809877+0.j 0.49849728+0.j 0.5000967 +0.j]
Pymanopt solution: [0.70821929 0.49855837 0.49986498]


In [31]:
import jax.numpy as jnp
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX

dim = 3
r = 2
manifold = pymanopt.manifolds.Stiefel(dim, r)

matrix = jax.random.uniform(key, shape=(dim,dim))

matrix = 0.5 * (matrix + matrix.T)

@pymanopt.function.autograd(manifold)
def cost(X):
    return -np.trace(X.T @ matrix @ X)

problem = pymanopt.Problem(manifold, cost)

optimizer = pymanopt.optimizers.SteepestDescent()
result = optimizer.run(problem)

eigenvalues, eigenvectors = jnp.linalg.eig(matrix)


dominant_eigenvector = eigenvectors[:, eigenvalues.real.argmax()]

print("Dominant eigenvector:", dominant_eigenvector)
print("Pymanopt solution:", result.point)

Optimizing...
Iteration    Cost                       Gradient norm     
---------    -----------------------    --------------    
   1         -8.0173504352569580e-01    1.61086772e+00    
   2         -1.3878233432769775e+00    6.10348757e-01    
   3         -1.4215236902236938e+00    8.00967034e-01    
   4         -1.5121933221817017e+00    3.14564272e-01    
   5         -1.5192764997482300e+00    4.94227964e-01    
   6         -1.5418734550476074e+00    2.95756777e-01    
   7         -1.5473333597183228e+00    2.55893395e-01    
   8         -1.5562674999237061e+00    6.92585286e-02    
   9         -1.5571060180664062e+00    4.83043447e-02    
  10         -1.5575225353240967e+00    2.35111169e-02    
  11         -1.5576767921447754e+00    2.50217941e-02    
  12         -1.5577307939529419e+00    2.00211858e-02    
  13         -1.5577446222305298e+00    1.93545209e-02    
  14         -1.5577851533889771e+00    9.14661724e-03    
  15         -1.5577924251556396e+00    7.

# Final Code:)