[REF] Make progressbar description agnostic to operation
f-dangel committed Sep 1, 2023
1 parent 38a661b commit 7be48e8
Showing 2 changed files with 20 additions and 9 deletions.
19 changes: 14 additions & 5 deletions curvlinops/_base.py
@@ -79,7 +79,9 @@ def __init__(
         self._device = self._infer_device(self._params)
         self._progressbar = progressbar
 
-        self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data())
+        self._N_data = sum(
+            X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data")
+        )
 
         if check_deterministic:
             old_device = self._device
@@ -206,7 +208,7 @@ def _matvec(self, x: ndarray) -> ndarray:
         x_list = self._preprocess(x)
         out_list = [zeros_like(x) for x in x_list]
 
-        for X, y in self._loop_over_data():
+        for X, y in self._loop_over_data(desc="_matvec"):
             normalization_factor = self._get_normalization_factor(X, y)
 
             for mat_x, current in zip(out_list, self._matvec_batch(X, y, x_list)):
@@ -266,16 +268,23 @@ def _postprocess(self, x_list: List[Tensor]) -> ndarray:
         """
         return self.flatten_and_concatenate(x_list).cpu().numpy()
 
-    def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
+    def _loop_over_data(
+        self, desc: Optional[str] = None
+    ) -> Iterable[Tuple[Tensor, Tensor]]:
         """Yield batches of the data set, loaded to the correct device.
 
+        Args:
+            desc: Description for the progress bar. Will be ignored if progressbar is
+                disabled.
+
         Yields:
             Mini-batches ``(X, y)``.
         """
         data_iter = iter(self._data)
 
         if self._progressbar:
-            data_iter = tqdm(data_iter, desc="matvec")
+            desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
+            data_iter = tqdm(data_iter, desc=desc)
 
         for X, y in data_iter:
             X, y = X.to(self._device), y.to(self._device)
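The hunk above is the core of the refactor: instead of the hard-coded label "matvec", the progress-bar description is now derived from the operator's class name plus an optional per-operation suffix passed as ``desc``. A minimal, self-contained sketch of that logic (``_ExampleOperator`` is a hypothetical stand-in, not part of curvlinops):

    from typing import Iterable, Iterator, Optional

    from tqdm import tqdm


    class _ExampleOperator:
        """Hypothetical stand-in for a curvlinops linear operator."""

        def __init__(self, data: Iterable, progressbar: bool = True):
            self._data = data
            self._progressbar = progressbar

        def _loop_over_data(self, desc: Optional[str] = None) -> Iterator:
            data_iter = iter(self._data)
            if self._progressbar:
                # Label reads e.g. "_ExampleOperator._matvec" rather than a fixed "matvec".
                desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
                data_iter = tqdm(data_iter, desc=desc)
            yield from data_iter


    # The bar is labelled "_ExampleOperator.gradient_and_loss" while this loop runs.
    for batch in _ExampleOperator(range(100))._loop_over_data(desc="gradient_and_loss"):
        pass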
@@ -298,7 +307,7 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
         total_loss = tensor([0.0], device=self._device)
         total_grad = [zeros_like(p) for p in self._params]
 
-        for X, y in self._loop_over_data():
+        for X, y in self._loop_over_data(desc="gradient_and_loss"):
             loss = self._loss_func(self._model_func(X), y)
             normalization_factor = self._get_normalization_factor(X, y)
 
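For reference, the new expression maps ``desc`` values to labels as follows (a quick check with a hypothetical helper that mirrors the f-string from ``_loop_over_data``; the class name is only an example string):

    from typing import Optional


    def _progressbar_desc(cls_name: str, desc: Optional[str] = None) -> str:
        # Hypothetical helper mirroring the f-string added in _loop_over_data.
        return f"{cls_name}{'' if desc is None else f'.{desc}'}"


    assert _progressbar_desc("GGNLinearOperator") == "GGNLinearOperator"
    assert _progressbar_desc("GGNLinearOperator", "_N_data") == "GGNLinearOperator._N_data"
    assert _progressbar_desc("GGNLinearOperator", "_matvec") == "GGNLinearOperator._matvec"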
10 changes: 6 additions & 4 deletions curvlinops/jacobian.py
@@ -88,7 +88,8 @@ def _check_deterministic(self):
 
         with no_grad():
             for (X1, y1), (X2, y2) in zip(
-                self._loop_over_data(), self._loop_over_data()
+                self._loop_over_data(desc="_check_deterministic_data_pred"),
+                self._loop_over_data(desc="_check_deterministic_data_pred2"),
             ):
                 pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
                 pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
@@ -117,7 +118,7 @@ def _matvec(self, x: ndarray) -> ndarray:
             jvp(self._model_func(X), self._params, x_list, retain_graph=False)[
                 0
             ].flatten(start_dim=1)
-            for X, _ in self._loop_over_data()
+            for X, _ in self._loop_over_data(desc="_matvec")
         ]
 
         return self._postprocess(out_list)
@@ -212,7 +213,8 @@ def _check_deterministic(self):
 
         with no_grad():
             for (X1, y1), (X2, y2) in zip(
-                self._loop_over_data(), self._loop_over_data()
+                self._loop_over_data(desc="_check_deterministic_data_pred1"),
+                self._loop_over_data(desc="_check_deterministic_data_pred2"),
             ):
                 pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
                 pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
@@ -240,7 +242,7 @@ def _matvec(self, x: ndarray) -> ndarray:
         out_list = [zeros_like(p) for p in self._params]
 
         processed = 0
-        for X, _ in self._loop_over_data():
+        for X, _ in self._loop_over_data(desc="_matvec"):
             pred = self._model_func(X)
             v = x_torch[processed : processed + pred.numel()].reshape_as(pred)
             processed += pred.numel()
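The two ``_check_deterministic`` hunks zip two passes over the data set, so each iterator gets its own ``desc`` suffix and the two progress bars can be told apart. A rough sketch of the resulting behaviour with plain tqdm (the operator name in the labels is only illustrative; terminals may interleave the two bars):

    from tqdm import tqdm

    batches = list(range(3))

    # Two bars, labelled like the zipped loops in _check_deterministic.
    for b1, b2 in zip(
        tqdm(batches, desc="JacobianLinearOperator._check_deterministic_data_pred1"),
        tqdm(batches, desc="JacobianLinearOperator._check_deterministic_data_pred2"),
    ):
        pass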
