/
qgt_onthefly.py
235 lines (183 loc) · 7.17 KB
/
qgt_onthefly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
from functools import partial
from textwrap import dedent
import jax
from jax import numpy as jnp
from flax import struct
import netket.jax as nkjax
from netket.utils.types import PyTree
from netket.utils import warn_deprecation
from .qgt_onthefly_logic import mat_vec_factory, mat_vec_chunked_factory
from ..linear_operator import LinearOperator, Uninitialized
def check_valid_vector_type(x, target):
    """
    Raise a TypeError if ``x`` is complex where ``target`` is real.

    Multiplying the QGT of real parameters by a complex vector is not
    supported by QGTOnTheFly. Callers in this module cast the vector to the
    parameters' dtype right after this check, so the imaginary part would be
    discarded anyway.

    The check is done leaf by leaf, so ``x`` and ``target`` must have
    matching pytree structures.

    Args:
        x: The vector (pytree) that is multiplied by / solved against the QGT.
        target: The reference pytree (the ansatz parameters).

    Raises:
        TypeError: If any leaf of ``x`` is complex while the corresponding
            leaf of ``target`` is real.
    """

    def check(x_leaf, target_leaf):
        # Only "complex vector vs. real parameters" is rejected. A real vector
        # against complex parameters is fine: it is safely up-cast later.
        if jnp.iscomplexobj(x_leaf) and not jnp.iscomplexobj(target_leaf):
            raise TypeError(
                dedent(
                    """
                    Cannot multiply the (real part of the) QGT by a complex vector.
                    You should either take the real part of the vector, or perform
                    the multiplication against the real and imaginary part of the
                    vector separately and then recomposing the two.
                    """
                )
            )

    # jax.tree_multimap was deprecated and later removed from JAX;
    # jax.tree_util.tree_map accepts several trees and is the replacement.
    jax.tree_util.tree_map(check, x, target)
def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
    """
    Build the lazy, on-the-fly representation of the S matrix (QGT) of a
    variational state.

    The returned :class:`QGTOnTheFlyT` never stores S explicitly. Products
    with a vector are evaluated from the state's model, parameters, model
    state and already-drawn samples. A dense matrix can be obtained with
    :code:`to_dense`.

    Args:
        vstate: The variational state. If omitted (``None``), a
            :func:`functools.partial` of this function with ``kwargs`` bound
            is returned, so it can be completed later with just the state.
        **kwargs: Forwarded to the :class:`QGTOnTheFlyT` constructor. The
            deprecated ``centered`` option is removed (with a deprecation
            warning) before forwarding.

    Returns:
        A :class:`QGTOnTheFlyT`, or a partial constructor when ``vstate`` is
        ``None``.

    Raises:
        TypeError: If ``vstate`` is an ``ExactState``.
    """
    # Deferred construction: bind the options now, supply the state later.
    if vstate is None:
        return partial(QGTOnTheFly, **kwargs)

    if "centered" in kwargs:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )
        del kwargs["centered"]

    # TODO: Find a better way to handle this case
    from netket.vqs import ExactState

    if isinstance(vstate, ExactState):
        raise TypeError("Only QGTJacobianPyTree works with ExactState.")

    # Collapse all leading axes so the samples form a 2D array, keeping the
    # last axis intact.
    flat_samples = vstate.samples
    if jnp.ndim(flat_samples) != 2:
        flat_samples = flat_samples.reshape((-1, flat_samples.shape[-1]))

    # Chunking is only worthwhile when the chunk is smaller than the batch.
    chunk_size = vstate.chunk_size
    use_chunks = chunk_size is not None and chunk_size < flat_samples.shape[0]

    if use_chunks:
        flat_samples, _ = nkjax.chunk(flat_samples, chunk_size)
        build_mat_vec = mat_vec_chunked_factory
    else:
        build_mat_vec = mat_vec_factory

    params = vstate.parameters
    mat_vec = build_mat_vec(
        forward_fn=vstate._apply_fun,
        params=params,
        model_state=vstate.model_state,
        samples=flat_samples,
    )

    return QGTOnTheFlyT(
        _mat_vec=mat_vec,
        _params=params,
        _chunking=use_chunks,
        **kwargs,
    )
@struct.dataclass
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.

    The :func:`QGTOnTheFly` factory in this module constructs instances of
    this class. Linear systems ``S x = y`` are solved through ``_solve``,
    which delegates to a caller-supplied ``solve_fun``.
    """

    # Mat-vec callable, invoked as `_mat_vec(vec, diag_shift)` by
    # onthefly_mat_treevec with `vec` already cast to the parameters' dtype.
    _mat_vec: Callable[[PyTree, float], PyTree] = Uninitialized
    """The S matrix-vector product as generated by mat_vec_factory.
    It's a jax.Partial, so can be used as pytree_node."""

    # Parameter pytree of the ansatz. Read only for its structure, size and
    # dtypes (tree_size, tree_ravel, tree_cast and the dtype check).
    _params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz).
    Only used as a shape placeholder."""

    # Whether the chunked mat-vec is used. Declared pytree_node=False (static),
    # which lets _to_dense branch on it with a plain Python `if` under jax.jit.
    _chunking: bool = struct.field(pytree_node=False, default=False)
    """Wether the implementation with chunks is used which currently does not support vmapping over it"""

    def __matmul__(self, y):
        """Return ``self @ y``, where ``y`` is a parameter-shaped pytree or a flat 1D vector."""
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs) -> PyTree:
        """
        Solve ``S x = y`` with ``solve_fun``, starting from ``x0``.

        Delegates to the module-level ``_solve`` and returns the
        ``(out, info)`` pair it produces. Extra ``**kwargs`` are accepted
        for interface compatibility but are not forwarded.
        """
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        return _to_dense(self)

    def __repr__(self):
        # `diag_shift` is not declared in this class; presumably it is a field
        # of LinearOperator — confirm.
        return f"QGTOnTheFly(diag_shift={self.diag_shift})"
@jax.jit
def onthefly_mat_treevec(
    S: "QGTOnTheFlyT", vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product ``S @ vec``.

    Args:
        S: The lazy QGT.
        vec: Either a pytree with the same structure as the parameters of
            ``S``, or a flat (ravelled) 1D array with one entry per parameter.

    Returns:
        The product, in the same format as ``vec``: a pytree for pytree
        input, a flat 1D array for array input.

    Raises:
        ValueError: If ``vec`` is an array that is not 1D, or whose size
            differs from the number of parameters.
        TypeError: Propagated from ``check_valid_vector_type`` when the dtype
            of ``vec`` is incompatible with the parameters' dtype.
    """
    # Anything exposing `ndim` is an array; everything else is a pytree.
    ravel_result = hasattr(vec, "ndim")

    if ravel_result:
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")

        n_params = nkjax.tree_size(S._params)
        if vec.size != n_params:
            # Shapes are static under jit, so this check runs at trace time.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )

        # Reshape the flat vector into the parameter pytree structure.
        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)

    check_valid_vector_type(vec, S._params)
    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    # Return the result in the same (flat) format the caller used.
    if ravel_result:
        res, _ = nkjax.tree_ravel(res)
    return res
@partial(jax.jit, static_argnums=1)
def _solve(
    self: QGTOnTheFlyT, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs
) -> PyTree:
    """
    Solve ``S x = y`` for the lazy QGT ``self`` using ``solve_fun``.

    Args:
        self: The lazy QGT, passed to ``solve_fun`` as the linear operator.
        solve_fun: Callable invoked as ``solve_fun(self, y, x0=x0)`` that must
            return an ``(out, info)`` pair. It is marked static for ``jax.jit``
            because a plain Python callable is not a valid JAX argument type.
            As a consequence it must be hashable, and a new callable triggers
            recompilation.
        y: Right-hand side pytree. It is cast to the parameters' dtype.
        x0: Initial guess. If ``None``, a zero pytree shaped like ``y`` is used.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The ``(out, info)`` pair returned by ``solve_fun``.

    Raises:
        TypeError: Propagated from ``check_valid_vector_type`` when the dtype
            of ``y`` is incompatible with the parameters' dtype.
    """
    check_valid_vector_type(y, self._params)

    y = nkjax.tree_cast(y, self._params)

    # we could cache this...
    if x0 is None:
        x0 = jax.tree_map(jnp.zeros_like, y)

    out, info = solve_fun(self, y, x0=x0)
    return out, info
@jax.jit
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize the lazy QGT as a dense ``(n_params, n_params)`` matrix.

    The operator is applied to every vector of the canonical basis. The
    product ``self @ e_i`` is the i-th column of S, so stacking the products
    along the leading axis yields S transposed.

    Returns:
        The dense matrix. The stacked result is transposed when it is complex;
        a real result is returned as stacked.
    """
    n_params = nkjax.tree_size(self._params)
    basis = jnp.eye(n_params)

    def column(e_i):
        return self @ e_i

    if not self._chunking:
        dense = jax.vmap(column, in_axes=0)(basis)
    else:
        # The linear_call used by the chunked mat-vec has no JAX batching
        # rule, so vmap is not an option. Scanning over the basis vectors
        # works and also keeps the memory footprint lower.
        _, dense = jax.lax.scan(lambda carry, e_i: (carry, column(e_i)), None, basis)

    return dense.T if nkjax.is_complex(dense) else dense