/
RealStiefel.py
398 lines (323 loc) · 12 KB
/
RealStiefel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
from __future__ import division
from .NullRangeManifold import NullRangeManifold
import numpy.linalg as la
import numpy as np
from numpy import trace, zeros_like, bmat, zeros
from numpy.random import randn
from scipy.linalg import expm, expm_frechet, logm, null_space
from scipy.optimize import minimize
from .tools import vech, unvech, vecah, unvecah, asym
# Python 2/3 compatibility: make ``xrange`` available under Python 3.
# A NameError probe is used instead of inspecting ``__builtins__``, which is
# a dict (not the builtins module) when this file is imported as a module;
# ``hasattr`` on that dict is always False, so the old check would have
# shadowed Python 2's real ``xrange`` with ``range``.
try:
    xrange
except NameError:
    xrange = range
def _calc_dim(n, d):
    """Return the sizes associated with the Stiefel manifold St(n, d).

    Returns a tuple ``(dim, codim, ambient)``:
    dim     -- manifold dimension, d*(n-d) + d*(d-1)/2
    codim   -- number of independent constraints in Y.T @ Y = I, d*(d+1)/2
    ambient -- size of the ambient space of n x d matrices, n*d
    """
    ambient = n * d
    codim = d * (d + 1) // 2
    dim = d * (n - d) + (d * (d - 1)) // 2
    return dim, codim, ambient
class RealStiefel(NullRangeManifold):
    """Real Stiefel manifold with a two-parameter metric.

    A point is an n x d matrix Y with ``Y.T @ Y = I``. Tangent vectors at Y
    are n x d matrices eta with ``Y.T @ eta + eta.T @ Y = 0``.
    The metric is defined by the two entries of ``alpha``::

        <eta1, eta2>_Y = alpha[0] * tr(eta1.T @ eta2)
                         + (alpha[1] - alpha[0]) * tr(eta1.T @ Y @ Y.T @ eta2)

    Parameters
    ----------
    n, d : number of rows and columns of a manifold point
    alpha : array_like of size 2, alpha > 0. Defaults to ``[1, .5]``.
    log_stats : if True, ``log`` returns ``(tangent, stats)`` where
        ``stats`` is a list of ``(key, value)`` pairs describing the
        optimizer result.
    log_method : optimizer used by ``log``. None means 'trust-krylov';
        otherwise one of 'trust-ncg', 'trust-krylov' or 'l-bfgs-b'
        (case-insensitive).
    """

    def __init__(self, n, d, alpha=None, log_stats=False,
                 log_method=None):
        # NOTE(review): presumably consumed by NullRangeManifold - confirm.
        self._point_layout = 1
        self.n = n
        self.d = d
        self._name = "Real_stiefel manifold n=%d d=%d alpha=%s" % (
            self.n, self.d, str(alpha))
        self._dimension, self._codim, _ = _calc_dim(n, d)
        if alpha is None:
            self.alpha = np.array([1, .5])
        else:
            # Accept lists/tuples as well: the metric code relies on
            # elementwise array arithmetic such as ``1/self.alpha``.
            self.alpha = np.asarray(alpha, dtype=float)
        self.log_stats = log_stats
        if log_method is None:
            self.log_method = 'trust-krylov'
        elif log_method.lower() in ['trust-ncg', 'trust-krylov', 'l-bfgs-b']:
            self.log_method = log_method.lower()
        else:
            raise ValueError(
                'log method must be one of trust-ncg, trust-krylov'
                ' or l-bfgs-b')
        # Optional gradient tolerance forwarded to scipy.optimize.minimize
        # as ``options={'gtol': ...}`` by ``log``; None keeps scipy's default.
        self.log_gtol = None

    def inner(self, X, eta1, eta2=None):
        """Inner product (Riemannian metric) on the tangent space at X.

        Tangent vectors are n x d matrices. If ``eta2`` is None, the
        squared metric norm of ``eta1`` is returned.
        """
        alf = self.alpha
        if eta2 is None:
            eta2 = eta1
        return alf[0]*trace(eta1.T @ eta2) + (alf[1]-alf[0]) *\
            trace((eta1.T @ X) @ (X.T @ eta2))

    def __str__(self):
        return self._name

    def base_inner_ambient(self, eta1, eta2):
        """Frobenius inner product on the ambient space of n x d matrices."""
        return trace(eta1.T @ eta2)

    def base_inner_E_J(self, a1, a2):
        """Frobenius inner product on the range of J (d x d matrices)."""
        return trace(a1 @ a2.T)

    def g(self, X, eta):
        """Metric operator: alpha0*eta + (alpha1-alpha0) * X X^T eta."""
        alf = self.alpha
        return alf[0]*eta + (alf[1]-alf[0]) *\
            X @ (X.T @ eta)

    def g_inv(self, X, ambient):
        """Inverse of ``g``: (1/alpha0)*U + (1/alpha1-1/alpha0) * X X^T U."""
        ialp = 1/self.alpha
        return ialp[0]*ambient + (ialp[1]-ialp[0]) * X @ (X.T @ ambient)

    def J(self, X, eta):
        """Constraint map eta^T X + X^T eta; tangent vectors are its kernel."""
        return eta.T @ X + X.T @ eta

    def Jst(self, X, a):
        """Adjoint of J for a symmetric d x d matrix a: 2 X a."""
        return 2*X@a

    def g_inv_Jst(self, X, a):
        """g_inv(Jst(a)) in closed form: (2/alpha1) X a."""
        return (2/self.alpha[1])*X@a

    def D_g(self, X, xi, eta):
        """Directional derivative of g(X, eta) in X along xi."""
        alf = self.alpha
        return (alf[1]-alf[0]) * (xi @ (X.T @ eta) + X @ (xi.T @ eta))

    def christoffel_form(self, X, xi, eta):
        """Christoffel form at X for tangent vectors xi, eta.

        The result is symmetric in xi and eta. Products are associated so
        that only d x d intermediates are formed (never n x n).
        """
        ret = xi @ (X.T @ eta) + eta @ (X.T @ xi)
        ret += X @ (xi.T @ eta + eta.T @ xi)
        ret -= xi @ (eta.T @ X) + eta @ (xi.T @ X)
        return 0.5*(self.alpha[1]-self.alpha[0]) * ret

    def D_J(self, X, xi, eta):
        """Directional derivative of J(X, eta) in X along xi."""
        return eta.T @ xi + xi.T @ eta

    def D_Jst(self, X, xi, a):
        """Directional derivative of Jst(X, a) in X along xi."""
        return 2*xi@a

    def D_g_inv_Jst(self, X, xi, a):
        """Directional derivative of g_inv_Jst(X, a) in X along xi."""
        return (2/self.alpha[1])*xi@a

    def contract_D_g(self, X, xi, eta):
        """Return (alpha1-alpha0) * (eta xi^T + xi eta^T) X.

        Associated as eta (xi^T X) + xi (eta^T X) to avoid forming
        n x n intermediates.
        """
        alf = self.alpha
        return (alf[1] - alf[0])*(eta @ (xi.T @ X) + xi @ (eta.T @ X))

    def st(self, mat):
        """The split transpose: transpose if real, Hermitian transpose if
        complex. This class is real, so it is the plain transpose.
        """
        return mat.T

    def J_g_inv_Jst(self, X, a):
        """J(g_inv(Jst(a))) in closed form; equals 4/alpha1 * a for
        symmetric a when X has orthonormal columns.
        """
        return 4/self.alpha[1]*a

    def solve_J_g_inv_Jst(self, X, b):
        """Solve J(g_inv(Jst(a))) = b for a in closed form.

        This is the inverse of ``J_g_inv_Jst``. NOTE(review): presumably
        overrides an iterative (CG) solve in the base class - confirm.
        """
        return self.alpha[1]/4*b

    def proj(self, X, U):
        """Project an ambient n x d matrix U onto the tangent space at X."""
        UTX = U.T @ X
        return U - 0.5*X @ (UTX + UTX.T)

    def proj_g_inv(self, X, U):
        """Closed form of proj(X, g_inv(X, U))."""
        ialp = 1/self.alpha
        ret = ialp[0] * U
        ret += 0.5*(ialp[1]-2*ialp[0]) * X @ (X.T @ U)
        ret -= 0.5*ialp[1]*X @ (U.T @ X)
        return ret

    def zerovec(self, X):
        """Zero tangent vector at X."""
        return zeros_like(X)

    def egrad2rgrad(self, X, U):
        """Riemannian gradient from the Euclidean gradient U."""
        return self.proj_g_inv(X, U)

    def rhess02_alt(self, X, xi, eta, egrad, ehess_val):
        """Riemannian Hessian as a bilinear form evaluated at (xi, eta).

        ``egrad`` is the Euclidean gradient. ``ehess_val`` is the Euclidean
        Hessian term; it is combined with a scalar trace, so presumably it
        is the scalar value ehess(xi, eta). TODO confirm with callers.
        """
        alpha = self.alpha
        etaxiy = xi @ (eta.T@X) + eta@(xi.T@X)
        egcoef = 0.5*X @ (xi.T@eta + eta.T@xi)
        ft = (alpha[0]-alpha[1])/alpha[0]
        egcoef += ft*(etaxiy - X@(X.T@etaxiy))
        return ehess_val - trace(egcoef @ egrad.T)

    def ehess2rhess(self, X, egrad, ehess, H):
        """Convert the Euclidean into the Riemannian Hessian.

        ehess is the Euclidean Hessian applied to H (ambient space),
        egrad is the Euclidean gradient (ambient space).
        Formula:
        proj_g_inv of (ehess - correction terms built from egrad and H)
        """
        alp = self.alpha
        egrady = egrad.T @ X
        grad_part = 0.5*H@(egrady+egrady.T)
        # egrad @ (X^T H) keeps the intermediate d x d instead of n x n.
        egyxi = egrad @ (X.T @ H)
        xiproj = H - X@(X.T@H)
        grad_part += (1-alp[1]/alp[0])*(egyxi-X@(X.T@egyxi))
        grad_part += (1-alp[1]/alp[0])*X@(egrad.T@xiproj)
        return self.proj_g_inv(X, ehess-grad_part)

    def retr(self, X, eta):
        """Retraction: the polar factor of X + eta.

        Computed from a thin SVD, X + eta = u s vh, by returning u @ vh,
        a matrix with orthonormal columns.
        """
        u, _, vh = la.svd(X+eta, full_matrices=False)
        return u @ vh

    def norm(self, X, eta):
        """Norm of a tangent vector induced by the metric ``inner``."""
        return np.sqrt(self.inner(X, eta, eta))

    def rand(self):
        """Random point: the Q factor of a thin QR decomposition of an
        n x d matrix with standard normal entries.
        """
        O, _ = la.qr(randn(
            self.n, self.d))
        return O

    def randvec(self, X):
        """Random tangent vector at X with unit metric norm."""
        U = self.proj(X, randn(*X.shape))
        U = U / self.norm(X, U)
        return U

    def _rand_ambient(self):
        return randn(self.n, self.d)

    def _rand_range_J(self):
        # J(X, eta) is symmetric, so draw a random symmetric d x d matrix.
        u = randn(self.d, self.d)
        return u + u.T

    def _vec(self, E):
        return E.reshape(-1)

    def _unvec(self, vec):
        return vec.reshape(self.n, self.d)

    def _vec_range_J(self, a):
        return vech(a)

    def _unvec_range_J(self, vec):
        return unvech(vec)

    def exp(self, Y, eta):
        """Geodesic; the formula involves matrices of size 2d.

        Parameters
        ----------
        Y : a manifold point
        eta : tangent vector

        Returns
        ----------
        gamma(1), where gamma(t) is the geodesic at Y in direction eta
        """
        K = eta - Y @ (Y.T @ eta)
        Yp, R = la.qr(K)
        alf = self.alpha[1]/self.alpha[0]
        A = Y.T @eta
        x_mat = bmat([[2*alf*A, -R.T],
                      [R, zeros((self.d, self.d))]])
        return np.array(
            bmat([Y, Yp]) @ expm(x_mat)[:, :self.d] @ expm((1-2*alf)*A))

    def exp_alt(self, Y, eta):
        """Geodesic, alternative formula that avoids the QR step of ``exp``.
        """
        alf = self.alpha[1]/self.alpha[0]
        A = Y.T @ eta
        e_mat = bmat([[(2*alf-1)*A, -eta.T@eta - 2*(1-alf)*A@A],
                      [np.eye(self.d), A]])
        return np.array(
            (bmat([Y, eta]) @ expm(e_mat))[:, :self.d] @ expm((1-2*alf)*A))

    def dist(self, X, Y):
        """Distance estimate: metric norm of the numerically computed
        ``log(X, Y)``.
        """
        lg = self.log(X, Y, show_steps=False, init_type=1)
        if self.log_stats:
            # In this mode log returns (tangent, stats); keep the tangent.
            lg = lg[0]
        return self.norm(X, lg)

    def log(self, Y, Y1, show_steps=False, init_type=1):
        """Inverse of exp.

        Parameters
        ----------
        Y : a manifold point (base point)
        Y1 : a manifold point (target)
        show_steps : if True, print the jacobian norm and objective value
            at every optimizer iteration
        init_type : 0 starts the optimizer from the zero vector; any other
            value starts from the projection of Y1 - Y onto the tangent
            space at Y

        Returns
        ----------
        eta such that self.exp(Y, eta) approximates Y1.
        If ``self.log_stats`` is True, returns ``(eta, stats)`` with
        ``stats`` a list of ``(key, value)`` pairs from the optimizer.
        If the optimizer raises, the zero vector is returned and the
        exception is recorded in ``stats`` under the key 'exception'.

        Algorithm: eta is parametrized as Y @ A + Q @ R, with A a d x d
        matrix (packed via vecah/unvecah, presumably antisymmetric) and Q an
        orthonormal basis of the part of span([Y, Y1]) orthogonal to Y.
        scipy.optimize.minimize (``self.log_method``) minimizes
        -tr(Y1^T exp(Y, eta)). For points with orthonormal columns this
        equals (||exp(Y, eta) - Y1||_F^2 - 2d) / 2, where _F is the
        Frobenius norm in R^{n x d}. The jacobian is computed with
        expm_frechet; Hessian-vector products use a forward finite
        difference of the jacobian.
        """
        alf = self.alpha[1]/self.alpha[0]
        d = self.d
        adim = (d*(d-1))//2

        def getQ():
            """Orthonormal basis of the part of span([Y, Y1]) orthogonal
            to Y (an n x 0 array when the spans coincide).
            """
            u, s, _ = np.linalg.svd(
                np.concatenate([Y, Y1], axis=1), full_matrices=False)
            rank = (s > 1e-14).sum()
            # The leading left singular vectors already form an orthonormal
            # basis of span([Y, Y1]); mixing them with a block of the right
            # singular vectors could drop directions if that block is
            # singular.
            good = u[:, :rank]
            qs = null_space(Y.T@good)
            basis = good@qs
            if basis.shape[1] == 0:
                return basis
            Q, _ = np.linalg.qr(basis)
            return Q

        Q = getQ()
        k = Q.shape[1]
        if k == 0:
            # Y1 and Y have the same column span: use a matrix logarithm.
            A = logm(Y.T @ Y1)
            if self.log_stats:
                return Y@A, [('success', True), ('message', 'aligment')]
            return Y@A

        def vec(A, R):
            """Pack A (d x d) and R (k x d) into one optimization vector."""
            return np.concatenate(
                [vecah(A), R.reshape(-1)])

        def unvec(avec):
            """Inverse of ``vec``."""
            return unvecah(avec[:adim]), avec[adim:].reshape(k, d)

        def objective(v):
            """-tr(Y1^T exp(Y, Y@A + Q@R)) for (A, R) = unvec(v)."""
            A, R = unvec(v)
            ex2 = expm(
                np.array(
                    bmat([[2*alf*A, -R.T], [R, np.zeros((k, k))]])))
            M = ex2[:d, :d]
            N = ex2[d:, :d]
            return -np.trace(Y1.T@(Y@M+Q@N)@expm((1-2*alf)*A))

        def jac(v):
            """Gradient of ``objective`` with respect to v."""
            A, R = unvec(v)
            ex1 = expm((1-2*alf)*A)
            mat = np.array(bmat([[2*alf*A, -R.T], [R, np.zeros((k, k))]]))
            E = np.array(bmat(
                [[ex1@Y1.T@Y, ex1@Y1.T@Q],
                 [np.zeros_like(R), np.zeros((k, k))]]))
            ex2, fe2 = expm_frechet(mat, E)
            M = ex2[:d, :d]
            N = ex2[d:, :d]
            YMQN = (Y@M+Q@N)
            partA = asym(
                (1-2*alf)*expm_frechet((1-2*alf)*A, Y1.T@YMQN)[1])
            partA += 2*alf*asym(fe2[:d, :d])
            partR = -(fe2[:d, d:].T - fe2[d:, :d])
            return vec(partA, partR)

        def hessp(v, xi):
            """Hessian-vector product by forward difference of ``jac``."""
            dlt = 1e-8
            return (jac(v+dlt*xi) - jac(v))/dlt

        def conv_to_tan(A, R):
            return Y@A + Q@R

        eta0 = self.proj(Y, Y1-Y)
        A0 = asym(Y.T@eta0)
        R0 = Q.T@eta0 - (Q.T@Y)@(Y.T@eta0)
        if init_type != 0:
            x0 = vec(A0, R0)
        else:
            x0 = np.zeros(adim + self.d*k)

        def printxk(xk):
            print(la.norm(jac(xk)), objective(xk))

        callback = printxk if show_steps else None

        # Fallback result (zero tangent vector) used if the optimizer raises.
        res = {'fun': np.nan, 'x': np.zeros_like(x0),
               'success': False,
               'message': 'minimizer exception'}
        try:
            kwargs = {'method': self.log_method, 'jac': jac,
                      'callback': callback}
            if self.log_method.startswith('trust'):
                # Trust-region methods also need Hessian-vector products.
                kwargs['hessp'] = hessp
            if self.log_gtol is not None:
                kwargs['options'] = {'gtol': self.log_gtol}
            res = minimize(objective, x0, **kwargs)
        except Exception as err:
            # Best effort: keep the fallback result, but record the cause
            # instead of silently discarding it.
            res['exception'] = repr(err)
        stat = [(a, res[a]) for a in res.keys() if a not in ['x', 'jac']]
        A1, R1 = unvec(res['x'])
        if self.log_stats:
            return conv_to_tan(A1, R1), stat
        return conv_to_tan(A1, R1)