-
Notifications
You must be signed in to change notification settings - Fork 104
/
fitting.py
385 lines (328 loc) · 12 KB
/
fitting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
"""Tools for computing distances between and fitting tensor networks."""
from autoray import dag, do
from .contraction import contract_strategy
from ..utils import check_opt
def tensor_network_distance(
    tnA,
    tnB,
    xAA=None,
    xAB=None,
    xBB=None,
    method="auto",
    normalized=False,
    **contract_opts,
):
    r"""Compute the Frobenius norm distance between two tensor networks:

    .. math::

        D(A, B)
        = | A - B |_{\mathrm{fro}}
        = \mathrm{Tr} [(A - B)^{\dagger}(A - B)]^{1/2}
        = ( \langle A | A \rangle - 2 \mathrm{Re} \langle A | B \rangle
        + \langle B | B \rangle ) ^{1/2}

    which should have matching outer indices. Note the default approach to
    computing the norm is precision limited to about ``eps**0.5`` where
    ``eps`` is the precision of the data type, e.g. ``1e-8`` for float64.
    This is due to the subtraction in the above expression.

    Parameters
    ----------
    tnA : TensorNetwork or Tensor
        The first tensor network operator.
    tnB : TensorNetwork or Tensor
        The second tensor network operator.
    xAA : None or scalar
        The value of ``A.H @ A`` if you already know it (or it doesn't matter).
    xAB : None or scalar
        The value of ``A.H @ B`` if you already know it (or it doesn't matter).
    xBB : None or scalar
        The value of ``B.H @ B`` if you already know it (or it doesn't matter).
    method : {'auto', 'overlap', 'dense'}, optional
        How to compute the distance. If ``'overlap'``, the default, the
        distance will be computed as the sum of overlaps, without explicitly
        forming the dense operators. If ``'dense'``, the operators will be
        directly formed and the norm computed, which can be quicker when the
        exterior dimensions are small. If ``'auto'``, the dense method will
        be used if the total operator (outer) size is ``<= 2**16``.
    normalized : bool, optional
        If ``True``, then normalize the distance by the norm of the two
        operators, i.e. ``2 * D(A, B) / (|A| + |B|)``. The resulting distance
        lies between 0 and 2 and is more useful for assessing convergence.
    contract_opts
        Supplied to :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`.

    Returns
    -------
    D : float

    Raises
    ------
    ValueError
        If the two networks do not share the same set of outer indices.
    """
    check_opt("method", method, ("auto", "dense", "overlap"))

    tnA = tnA.as_network()
    tnB = tnB.as_network()

    outer = tnA.outer_inds()
    if set(tnB.outer_inds()) != set(outer):
        raise ValueError(
            "Can only compute distance between tensor "
            "networks with matching outer indices."
        )

    if method == "auto":
        # heuristic: if the total outer dimension is small, forming the
        # dense operators is cheap and numerically more direct
        method = "dense" if tnA.inds_size(outer) <= 2**16 else "overlap"

    if method == "dense":
        # fully contract each network down to a single tensor - the overlap
        # computations below then operate on these dense vectorizations
        tnA = tnA.contract(..., output_inds=outer, preserve_tensor=True)
        tnB = tnB.contract(..., output_inds=outer, preserve_tensor=True)

    # compute whichever of the three overlaps were not supplied
    if xAA is None:
        xAA = (tnA | tnA.H).contract(..., **contract_opts)  # <A|A>
    if xAB is None:
        xAB = (tnA | tnB.H).contract(..., **contract_opts)  # <A|B>
    if xBB is None:
        xBB = (tnB | tnB.H).contract(..., **contract_opts)  # <B|B>

    # |A - B| = sqrt(<A|A> - 2 Re<A|B> + <B|B>), abs guards tiny negatives
    dist = do("abs", xAA - 2 * do("real", xAB) + xBB) ** 0.5

    if normalized:
        # rescale into [0, 2] using the mean norm of the two operators
        dist *= 2 / (do("abs", xAA) ** 0.5 + do("abs", xBB) ** 0.5)

    return dist
def tensor_network_fit_autodiff(
    tn,
    tn_target,
    steps=1000,
    tol=1e-9,
    autodiff_backend="autograd",
    contract_optimize="auto-hq",
    distance_method="auto",
    inplace=False,
    progbar=False,
    **kwargs,
):
    """Optimize the fit of ``tn`` with respect to ``tn_target`` using
    automatic differentation. This minimizes the norm of the difference
    between the two tensor networks, which must have matching outer indices,
    using overlaps.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to fit.
    tn_target : TensorNetwork
        The target tensor network to fit ``tn`` to.
    steps : int, optional
        The maximum number of autodiff steps.
    tol : float, optional
        The target norm distance.
    autodiff_backend : str, optional
        Which backend library to use to perform the gradient computation.
    contract_optimize : str, optional
        The contraction path optimized used to contract the overlaps.
    distance_method : {'auto', 'dense', 'overlap'}, optional
        Supplied to :func:`~quimb.tensor.tensor_core.tensor_network_distance`,
        controls how the distance is computed.
    inplace : bool, optional
        Update ``tn`` in place.
    progbar : bool, optional
        Show a live progress bar of the fitting process.
    kwargs
        Passed to :class:`~quimb.tensor.tensor_core.optimize.TNOptimizer`.

    See Also
    --------
    tensor_network_distance, tensor_network_fit_als
    """
    from .optimize import TNOptimizer
    from .tensor_core import tensor_network_distance

    # <B|B> is constant w.r.t. the fit parameters -> compute it once here
    # and pass it in so the loss doesn't recompute it every iteration
    xBB = (tn_target | tn_target.H).contract(
        ...,
        output_inds=(),
        optimize=contract_optimize,
    )

    optimizer = TNOptimizer(
        tn=tn,
        loss_fn=tensor_network_distance,
        loss_constants={"tnB": tn_target, "xBB": xBB},
        loss_kwargs={"method": distance_method, "optimize": contract_optimize},
        autodiff_backend=autodiff_backend,
        progbar=progbar,
        **kwargs,
    )
    tn_fit = optimizer.optimize(steps, tol=tol)

    if inplace:
        # copy the optimized data back into the original tensors
        for t_old, t_new in zip(tn, tn_fit):
            t_old.modify(data=t_new.data)
        return tn

    return tn_fit
def _tn_fit_als_core(
    var_tags,
    tnAA,
    tnAB,
    xBB,
    tol,
    contract_optimize,
    steps,
    enforce_pos,
    pos_smudge,
    solver="solve",
    progbar=False,
):
    """The main alternating least squares (ALS) sweep loop, shared by public
    fitting functions. Iteratively, for each tensor tagged in ``var_tags``,
    solves the local linear problem ``N_i x = W_i`` formed from the
    environments of the norm network ``tnAA`` (<A|A>) and the overlap network
    ``tnAB`` (<A|B>), writing the solution back into both (virtual) networks.

    Parameters
    ----------
    var_tags : sequence of str
        Tags identifying each tensor being optimized (one tag per tensor).
    tnAA : TensorNetwork
        The norm network, layers tagged ``'__KET__'`` and ``'__BRA__'``.
    tnAB : TensorNetwork
        The overlap network with the target.
    xBB : scalar or None
        Value of <B|B>, needed whenever the distance is computed, i.e. when
        ``tol != 0.0`` or ``progbar=True``.
    tol : float
        Convergence tolerance on the change in distance; ``0.0`` disables
        convergence checking.
    contract_optimize : str
        Contraction strategy used while forming the environments.
    steps : int
        Maximum number of full sweeps.
    enforce_pos : bool
        Whether to project each local environment to positive definite.
    pos_smudge : float
        Relative eigenvalue clipping level (and lstsq ``rcond``).
    solver : {'solve', 'lstsq'}, optional
        Local linear solver to use.
    progbar : bool, optional
        Show a live progress bar with the current distance.

    Raises
    ------
    ValueError
        If ``solver`` is not one of the supported options.
    """
    from .tensor_core import group_inds

    # shared intermediates + greedy = good reuse of contractions
    with contract_strategy(contract_optimize):
        # prepare each of the contractions we are going to repeat
        env_contractions = []
        for tg in var_tags:
            # varying tensor and conjugate in norm <A|A>
            tk = tnAA["__KET__", tg]
            tb = tnAA["__BRA__", tg]

            # get inds, and ensure any bonds come last, for linalg.solve
            lix, bix, rix = group_inds(tb, tk)
            tk.transpose_(*rix, *bix)
            tb.transpose_(*lix, *bix)

            # form TNs with 'holes', i.e. environment tensors networks
            A_tn = tnAA.select((tg,), "!all")
            y_tn = tnAB.select((tg,), "!all")

            env_contractions.append((tk, tb, lix, bix, rix, A_tn, y_tn))

        # BUGFIX: ``old_d`` is read whenever the distance is computed below,
        # which happens if *either* ``tol != 0.0`` or ``progbar`` - previously
        # it was only initialized for ``tol != 0.0``, so ``tol=0.0`` with
        # ``progbar=True`` raised a NameError on the first sweep
        if (tol != 0.0) or progbar:
            old_d = float("inf")

        if progbar:
            import tqdm

            pbar = tqdm.trange(steps)
        else:
            pbar = range(steps)

        # the main iterative sweep on each tensor, locally optimizing
        for _ in pbar:
            for tk, tb, lix, bix, rix, A_tn, y_tn in env_contractions:
                # local environment 'matrix' and target 'vector'
                Ni = A_tn.to_dense(lix, rix)
                Wi = y_tn.to_dense(rix, bix)

                if enforce_pos:
                    # clip eigenvalues (relative to the largest, el[-1], as
                    # eigh returns them ascending) to make Ni positive definite
                    el, ev = do("linalg.eigh", Ni)
                    el = do("clip", el, el[-1] * pos_smudge, None)
                    Ni_p = ev * do("reshape", el, (1, -1)) @ dag(ev)
                else:
                    Ni_p = Ni

                if solver == "solve":
                    x = do("linalg.solve", Ni_p, Wi)
                elif solver == "lstsq":
                    x = do("linalg.lstsq", Ni_p, Wi, rcond=pos_smudge)[0]
                else:
                    # BUGFIX: previously an unrecognized solver silently left
                    # ``x`` unbound, producing a confusing NameError below
                    raise ValueError(f"Unrecognized solver: {solver!r}.")

                x_r = do("reshape", x, tk.shape)
                # n.b. because we are using virtual TNs -> updates propagate
                tk.modify(data=x_r)
                tb.modify(data=do("conj", x_r))

            # assess | A - B | for convergence or printing
            if (tol != 0.0) or progbar:
                xAA = do("trace", dag(x) @ (Ni @ x))  # <A|A>
                xAB = do("trace", do("real", dag(x) @ Wi))  # <A|B>
                d = do("abs", (xAA - 2 * xAB + xBB)) ** 0.5
                if abs(d - old_d) < tol:
                    break
                old_d = d
                if progbar:
                    pbar.set_description(str(d))
def tensor_network_fit_als(
    tn,
    tn_target,
    tags=None,
    steps=100,
    tol=1e-9,
    solver="solve",
    enforce_pos=False,
    pos_smudge=None,
    tnAA=None,
    tnAB=None,
    xBB=None,
    contract_optimize="greedy",
    inplace=False,
    progbar=False,
):
    """Optimize the fit of ``tn`` with respect to ``tn_target`` using
    alternating least squares (ALS). This minimizes the norm of the difference
    between the two tensor networks, which must have matching outer indices,
    using overlaps.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to fit.
    tn_target : TensorNetwork
        The target tensor network to fit ``tn`` to.
    tags : sequence of str, optional
        If supplied, only optimize tensors matching any of given tags.
    steps : int, optional
        The maximum number of ALS steps.
    tol : float, optional
        The target norm distance.
    solver : {'solve', 'lstsq', ...}, optional
        The underlying driver function used to solve the local minimization,
        e.g. ``numpy.linalg.solve`` for ``'solve'`` with ``numpy`` backend.
    enforce_pos : bool, optional
        Whether to enforce positivity of the locally formed environments,
        which can be more stable.
    pos_smudge : float, optional
        If enforcing positivity, the level below which to clip eigenvalues
        for make the local environment positive definite.
    tnAA : TensorNetwork, optional
        If you have already formed the overlap ``tn.H & tn``, maybe
        approximately, you can supply it here. The unconjugated layer should
        have tag ``'__KET__'`` and the conjugated layer ``'__BRA__'``. Each
        tensor being optimized should have tag ``'__VAR{i}__'``.
    tnAB : TensorNetwork, optional
        If you have already formed the overlap ``tn_target.H & tn``, maybe
        approximately, you can supply it here. Each tensor being optimized
        should have tag ``'__VAR{i}__'``.
    xBB : float, optional
        If you have already know, have computed ``tn_target.H @ tn_target``,
        or it doesn't matter, you can supply the value here.
    contract_optimize : str, optional
        The contraction path optimized used to contract the local environments.
        Note ``'greedy'`` is the default in order to maximize shared work.
    inplace : bool, optional
        Update ``tn`` in place.
    progbar : bool, optional
        Show a live progress bar of the fitting process.

    Returns
    -------
    TensorNetwork

    See Also
    --------
    tensor_network_fit_autodiff, tensor_network_distance
    """
    # mark the tensors we are going to optimize
    tna = tn.copy()
    tna.add_tag("__KET__")

    if tags is None:
        to_tag = tna
    else:
        to_tag = tna.select_tensors(tags, "any")

    var_tags = []
    for i, t in enumerate(to_tag):
        var_tag = f"__VAR{i}__"
        t.add_tag(var_tag)
        var_tags.append(var_tag)

    # form the norm of the varying TN (A) and its overlap with the target (B)
    if tnAA is None:
        tnAA = tna | tna.H.retag_({"__KET__": "__BRA__"})
    if tnAB is None:
        tnAB = tna | tn_target.H

    # BUGFIX: the core sweep needs <B|B> not only for convergence checking
    # (``tol != 0.0``) but also to display the distance when ``progbar=True``
    # - previously it was left as ``None`` for ``tol=0.0, progbar=True``
    if ((tol != 0.0) or progbar) and (xBB is None):
        # <B|B>
        xBB = (tn_target | tn_target.H).contract(
            ...,
            optimize=contract_optimize,
            output_inds=(),
        )

    if pos_smudge is None:
        pos_smudge = max(tol, 1e-15)

    _tn_fit_als_core(
        var_tags=var_tags,
        tnAA=tnAA,
        tnAB=tnAB,
        xBB=xBB,
        tol=tol,
        contract_optimize=contract_optimize,
        steps=steps,
        enforce_pos=enforce_pos,
        pos_smudge=pos_smudge,
        solver=solver,
        progbar=progbar,
    )

    if not inplace:
        tn = tn.copy()

    for t1, t2 in zip(tn, tna):
        # transpose so only thing changed in original TN is data
        t2.transpose_like_(t1)
        t1.modify(data=t2.data)

    return tn