In [18]:
import numpy as np

#A random initial value to start the iteration in the Manifold space S^n-1 which would be of the form R^n
x_cur = np.array((1, 0, 0, 0)) #n is chosen to be 4, so its a 3-sphere of unit radius embedded in R^4 Euclidean space

#Our cost function requires a symmetric nxn matrix A. One such A is initialised(random infact). 
A = np.array([[1, 2, 3, 0],[2, 1, 4, 0],[3, 4, 3, 0],[0, 0, 0, 0]])

#To keep track of the iterations
iterate = 0

'''Our step size calculation would be based on a concept called Armijo point. 
Armijo point requires the selection of scalars α > 0, β, σ ∈ (0, 1). These have also been randomly intialised for now. 
They will be tweaked later for better convergence.'''
alpha,beta,sigma = [3, 0.2, 0.05] 

#This is essentially Rayleigh quotient minimization , x->x.TAx where A=symmetric 4x4 matrix
field = lambda x: np.matmul(np.transpose(x),np.matmul(A,x))

'''The retraction which is a function that maps the manifold to itself. 
The retraction we have used takes two vectors x,y as arguements and returns (x+y)/abs(x+y) 
i.e direction of the vector x+y ''' 
def retraction(x_cur, tau):
    z = x_cur + tau
    absval = np.linalg.norm(z, 1)
    return np.divide(z, absval)

#Gradient is calculated as 2*(Ax_k − x_k x_k.T A x_k) .
def gradient(x_cur):
    expression = A*x_cur - x_cur*field(x_cur)
    return 2*expression

'''Find the smallest integer m ≥ 0 such that
f (Rx_k (aplha.beta^m eta_k )) ≤ f(x_k ) − sigma.alpha.beta^m eta_k.T eta_k'''
def armijo_step(x_cur, eta_k):
    m = 0
    while True:
        expr = np.dot(eta_k, eta_k)
        if np.all(field(retraction(x_cur, alpha*(beta**m)*eta_k))) <= np.all((field(x_cur) - (sigma*alpha*(beta**m)*expr))):
            break
        m += 1
        
    return alpha*(beta**m)

#The updation happens in this loop for 10000 iterations
while True:
    if iterate==10000:
        break
    eta_k = -1*gradient(x_cur)#eta_k is chosen as -grad f(x_k) which is gradient-related
    
    t_k = armijo_step(x_cur, eta_k)
    
    x_cur = retraction(x_cur, t_k*eta_k)
    iterate += 1

print(x_cur)

[[1.e-323 0.e+000 0.e+000 0.e+000]
 [5.e-324 0.e+000 0.e+000 0.e+000]
 [1.e+000 0.e+000 0.e+000 0.e+000]
 [0.e+000 0.e+000 0.e+000 0.e+000]]
