Skip to content

Conversation

@yoshikisd
Copy link
Collaborator

@yoshikisd yoshikisd commented Jun 12, 2025

When running the transmission_RPI.py example script, an error is raised when calling model.compare(dataset); see failure log at the end. This bug seems to originate from the shape of sim_data in the following line

sim_data = self.forward(*inputs).detach().cpu().numpy()[0]

Specifically, in the case of a FancyPtycho (and presumably other ptychography) model, the shape of sim_data will be (N, M). However, if an RPI model is used, the shape of sim_data becomes (N,), which causes an invalid shape error to get thrown when plotting an image.

The bugfix solves this issue by only adding the [0] whenever the shape of the sim_data is bigger than 2 dimensions.

Failure log

cdtools/src/cdtools/models/base.py:922: UserWarning: Attempting to set identical low and high xlims makes transformation singular; automatically expanding.
  self.slider = Slider(axslider, 'Pattern #', 0, len(dataset)-1, valstep=1, valfmt="%d")
Traceback (most recent call last):
  File "/homes/dayne/repositories/cdtools_yoshikisd/cdtools/examples/transmission_RPI.py", line 46, in <module>
    model.compare(dataset)
  File "/homes/dayne/repositories/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 940, in compare
    update(0)
  File "/homes/dayne/repositories/cdtools_yoshikisd/cdtools/src/cdtools/models/base.py", line 891, in update
    sim = axes[0].imshow(sim_data)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/dayne/.conda/envs/cdtools_testing/lib/python3.12/site-packages/matplotlib/__init__.py", line 1473, in inner
    return func(
           ^^^^^
  File "/homes/dayne/.conda/envs/cdtools_testing/lib/python3.12/site-packages/matplotlib/axes/_axes.py", line 5895, in imshow
    im.set_data(X)
  File "/homes/dayne/.conda/envs/cdtools_testing/lib/python3.12/site-packages/matplotlib/image.py", line 729, in set_data
    self._A = self._normalize_image_array(A)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/homes/dayne/.conda/envs/cdtools_testing/lib/python3.12/site-packages/matplotlib/image.py", line 697, in _normalize_image_array
    raise TypeError(f"Invalid shape {A.shape} for image data")
TypeError: Invalid shape (1000,) for image data

@yoshikisd yoshikisd added the bug Something isn't working label Jun 12, 2025
@yoshikisd yoshikisd changed the title [Bugfix] Get model.compare(dataset) working for RPI [Bugfix] model.compare(dataset) does not work for the RPI test script Jun 12, 2025
Copy link
Collaborator

@allevitan allevitan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for fixing this - the change looks good to me.

@allevitan allevitan merged commit eb19ce3 into cdtools-developers:master Jun 12, 2025
6 checks passed
@yoshikisd yoshikisd deleted the bugfix/rpi_compare branch June 12, 2025 15:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants