Labels: bug (Something isn't working)
Description
Bug description
I am trying to swap out `Adam_optimize` for `SGD_optimize` in the `fancy_ptycho.py` example script to make sure that my refactoring of `SGD_optimize` at least preserves its original behavior. For background on why I'm refactoring it, see the linked comment on PR #17.
The script runs one epoch, reports a `nan` loss, and then crashes. The failure log is shown at the bottom.
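Since the goal is to confirm that a refactored `SGD_optimize` reproduces the original, one simple check is to record the per-epoch losses from both versions under identical settings and compare them epoch by epoch. The sketch below is a hypothetical helper (`compare_loss_histories` is not part of cdtools); it assumes only that each optimizer yields one scalar loss per epoch, as the loop in the reproduction script suggests. It treats two `nan` values at the same epoch as a match, because `nan == nan` is `False` in Python and a naive comparison would flag the current, already-broken behavior as a regression.

```python
import math
from typing import Iterable, List, Optional


def compare_loss_histories(
    reference: Iterable[float],
    candidate: Iterable[float],
    rel_tol: float = 1e-6,
    abs_tol: float = 0.0,
) -> Optional[int]:
    """Return the index of the first epoch where two loss histories disagree,
    or None if they match.

    Hypothetical helper: `reference` would hold the losses from the original
    SGD_optimize and `candidate` those from the refactored version.
    """
    ref: List[float] = list(reference)
    cand: List[float] = list(candidate)
    for i, (a, b) in enumerate(zip(ref, cand)):
        # nan != nan, so compare NaN-ness explicitly instead of by value.
        if math.isnan(a) or math.isnan(b):
            if not (math.isnan(a) and math.isnan(b)):
                return i
            continue
        if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
            return i
    # A history that stops early (e.g. because of a crash) is also a mismatch.
    if len(ref) != len(cand):
        return min(len(ref), len(cand))
    return None
```

For example, `compare_loss_histories([1.0, 0.5], [1.0, 0.5000001])` returns `None`, while `compare_loss_histories([1.0, float('nan')], [1.0, 0.4])` returns `1`.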
Script to reproduce
On the master branch of cdtools, run the following script:

```python
import cdtools
from matplotlib import pyplot as plt

filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    oversampling=2,
    probe_support_radius=120,
    propagation_distance=5e-3,
    units='mm',
    obj_view_crop=-50,
)

device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

# SGD_optimize is used here in place of the Adam_optimize call in fancy_ptycho.py
for loss in model.SGD_optimize(50, dataset):
    print(model.report())
```

OS and package info
- Python 3.12.2
- PyTorch 2.5.1
- CUDA 12.4
Failure log
```
Epoch 1 completed in 3.27 s with loss nan
Traceback (most recent call last):
  File "/cdtools_yoshikisd/cdtools/examples/fancy_ptycho.py", line 25, in <module>
    for loss in model.SGD_optimize(50, dataset):#, lr=0.02, batch_size=10):
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 492, in AD_optimize
    raise res
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 457, in target
    result_queue.put(run_epoch(stop_event))
                     ^^^^^^^^^^^^^^^^^^^^^
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 417, in run_epoch
    loss += optimizer.step(closure).detach().cpu().numpy()
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/.conda/envs/cdtools_testing/lib/python3.12/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/.conda/envs/cdtools_testing/lib/python3.12/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.conda/envs/cdtools_testing/lib/python3.12/site-packages/torch/optim/sgd.py", line 112, in step
    loss = closure()
           ^^^^^^^^^
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 395, in closure
    sim_patterns = self.forward(*inp)
                   ^^^^^^^^^^^^^^^^^^
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 94, in forward
    return self.measurement(self.forward_propagator(self.interaction(*args)))
                                                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/cdtools_yoshikisd/cdtools/src/cdtools/models/fancy_ptycho.py", line 502, in interaction
    exit_waves = self.probe_norm * tools.interactions.ptycho_2D_sinc(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cdtools_yoshikisd/cdtools/src/cdtools/tools/interactions/interactions.py", line 472, in ptycho_2D_sinc
    output = shifted_probe * selections[...,None,:,:]
             ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (278) must match the size of tensor b (0) at non-singleton dimension 3
```
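In this log the loss is already `nan` after epoch 1, but the exception only surfaces afterward, once the next epoch starts. While narrowing this down, it may help to stop at the first non-finite loss instead of continuing into the shape error. Below is a minimal sketch of a wrapper around the loss generator; it assumes only that `SGD_optimize` yields one scalar loss per epoch, as in the reproduction script, and `stop_on_nonfinite` and `NonFiniteLossError` are hypothetical names, not part of cdtools.

```python
import math
from typing import Iterable, Iterator


class NonFiniteLossError(RuntimeError):
    """Raised when an optimizer reports a NaN or infinite loss."""


def stop_on_nonfinite(losses: Iterable[float]) -> Iterator[float]:
    """Pass losses through unchanged, but raise as soon as one is not finite.

    Hypothetical wrapper for a loss generator such as
    model.SGD_optimize(50, dataset). Because it raises before asking for the
    next value, the wrapped generator is never advanced past the bad epoch.
    """
    for epoch, loss in enumerate(losses, start=1):
        value = float(loss)  # works for Python floats and NumPy scalars
        if not math.isfinite(value):
            raise NonFiniteLossError(f"loss became {value} at epoch {epoch}")
        yield value
```

In the reproduction script, the loop would become `for loss in stop_on_nonfinite(model.SGD_optimize(50, dataset)):`, which should stop right after epoch 1 with an error naming the epoch, rather than starting epoch 2.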