
Silenced exception makes it harder to debug custom Transforms #1098

Closed
Zhack47 opened this issue Jul 21, 2023 · 5 comments · Fixed by #1101
Labels
bug Something isn't working

Comments

@Zhack47
Contributor

Zhack47 commented Jul 21, 2023

Is there an existing issue for this?

  • I have searched the existing issues

Bug summary

In the _get_next_subject() method of the Queue class, there is a try/except statement that goes like this:

        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self._initialize_subjects_iterable()
            subject = next(self.subjects_iterable)
        except AssertionError as exception:
            if 'can only test a child process' in str(exception):
                message = (
                    'The number of workers for the data loader used to pop'
                    ' patches from the queue should be 0. Is it?'
                )
                raise RuntimeError(message) from exception
        return subject

When an AssertionError is raised and the if condition is not met, the except block neither assigns subject nor re-raises, so the return statement fails with an UnboundLocalError saying that subject was referenced before assignment. The actual exception is lost, which makes debugging harder.
In my case, the AssertionError was:
AssertionError: Output of SimulateLowResolutionTransform is 5D
which explains much better what my problem was.
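The mechanism can be reproduced without TorchIO. The following is a simplified sketch of the same pattern; get_next_item and failing_transforms are hypothetical, illustrative names, not TorchIO code. A generator stands in for the subjects iterable and raises the kind of AssertionError a failing custom transform would.

```python
def get_next_item(iterator):
    """Hypothetical, simplified copy of the pattern in Queue._get_next_subject."""
    try:
        item = next(iterator)
    except AssertionError as exception:
        # Only one specific AssertionError message is handled...
        if 'can only test a child process' in str(exception):
            raise RuntimeError('Use num_workers=0') from exception
        # ...any other AssertionError falls through silently here
    return item  # 'item' was never assigned if next() raised


def failing_transforms():
    # Stands in for a custom transform whose output check fails
    raise AssertionError('Output of SimulateLowResolutionTransform is 5D')
    yield  # unreachable; only makes this function a generator


def demo():
    try:
        get_next_item(failing_transforms())
    except UnboundLocalError as error:
        return error
    return None


error = demo()
print(type(error).__name__)  # UnboundLocalError
print(error.__context__)     # None: the AssertionError is not even chained
```

Because the UnboundLocalError is raised after the except block has finished, Python does not attach the original AssertionError as its context, so nothing in the final traceback points at the real cause.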

To remove the confusion, we could re-raise the original AssertionError when it does not match the if condition.
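A minimal sketch of the proposed fix, using hypothetical, simplified names rather than TorchIO's actual Queue code: the specific "child process" case is still translated into a RuntimeError, and every other AssertionError is re-raised. It uses a bare raise, which re-raises the same exception object that raise exception would.

```python
def get_next_item(iterator):
    """Hypothetical, simplified version of the pattern with the proposed fix."""
    try:
        item = next(iterator)
    except AssertionError as exception:
        if 'can only test a child process' in str(exception):
            message = (
                'The number of workers for the data loader used to pop'
                ' patches from the queue should be 0. Is it?'
            )
            raise RuntimeError(message) from exception
        # Proposed fix: re-raise any other AssertionError unchanged
        raise
    return item


def failing_transforms():
    # Stands in for a custom transform whose output check fails
    raise AssertionError('Output of SimulateLowResolutionTransform is 5D')
    yield  # unreachable; only makes this function a generator


def surfaced_error():
    try:
        get_next_item(failing_transforms())
    except AssertionError as error:
        return error
    return None


error = surfaced_error()
print(error)  # Output of SimulateLowResolutionTransform is 5D
```

With the re-raise in place, the caller sees the transform's own message instead of an UnboundLocalError.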

Code for reproduction

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchio.data.sampler.label import LabelSampler
from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy
from torchio import DATA, TYPE, LABEL, INTENSITY, IntensityTransform

class SimulateLowResolutionTransform(IntensityTransform):

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        super().__init__(p_per_sample)
        self.order_upsample = order_upsample
        self.order_downsample = order_downsample
        self.channels = channels
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.p_per_sample = p_per_sample
        self.data_key = data_key
        self.zoom_range = zoom_range
        self.ignore_axes = ignore_axes

    def apply_transform(self, subject):
        keys = sorted(subject.keys())
        for key in keys:
            if subject[key][TYPE] == INTENSITY:
                data = subject[key][DATA].unsqueeze(1).numpy()
                for i in range(data.shape[0]):
                    if np.random.uniform() < self.p_per_sample:
                        data[i] = augment_linear_downsampling_scipy(data[i], zoom_range=self.zoom_range,
                                                                    per_channel=self.per_channel,
                                                                    p_per_channel=self.p_per_channel,
                                                                    channels=self.channels,
                                                                    order_downsample=self.order_downsample,
                                                                    order_upsample=self.order_upsample,
                                                                    ignore_axes=self.ignore_axes)
                subject[key][DATA] = torch.tensor(data)
        return subject

if __name__ == "__main__":
    import torchio as tio
    st = SimulateLowResolutionTransform(
        zoom_range=(.05, 1.), per_channel=True, p_per_channel=1.,
        order_downsample=0, order_upsample=0, p_per_sample=1.,
        ignore_axes=None,
    )
    colin_dataset = tio.datasets.mni.Colin27()
    
    ds_train = tio.SubjectsDataset([colin_dataset], transform=st)
    sampler = LabelSampler((120, 120, 80))
    patches_queue_train = tio.Queue(ds_train, max_length=32, samples_per_volume=4, sampler=sampler,
                                    shuffle_patches=True, shuffle_subjects=True, num_workers=8)

    training_loader = DataLoader(patches_queue_train, batch_size=2, shuffle=True)
    for batch in training_loader:
        print(batch["t1"][DATA].shape)

Actual outcome

Traceback (most recent call last):
  File "/home/zhack/Documents/THESE/4Net/fournet/utils/transforms/augmentations/spatial_augments.py", line 304, in <module>
    for batch in training_loader:
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 170, in __getitem__
    self._fill()
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 229, in _fill
    subject = self._get_next_subject()
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 270, in _get_next_subject
    return subject
UnboundLocalError: local variable 'subject' referenced before assignment

Error messages

UnboundLocalError: local variable 'subject' referenced before assignment

Expected outcome

AssertionError: Output of SimulateLowResolutionTransform is 5D

System info

Platform:   Linux-5.15.0-79-generic-x86_64-with-glibc2.35
TorchIO:    0.18.78
PyTorch:    1.13.1+cu117
SimpleITK:  2.1.1.2 (ITK 5.2)
NumPy:      1.22.0
Python:     3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0]
Zhack47 added the "bug" label (Something isn't working) on Jul 21, 2023
@fepegar
Owner

fepegar commented Jul 31, 2023

Hi, @Zhack47. Can you please share a minimal example I can reproduce?

@Zhack47
Contributor Author

Zhack47 commented Aug 1, 2023

This example is more minimal and should trigger the bug (tested on the same machine as above):

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchio.data.sampler.label import LabelSampler
from torchio import DATA, TYPE, LABEL, INTENSITY, IntensityTransform

class SimulateLowResolutionTransform(IntensityTransform):
    def __init__(self):
        super().__init__(1)

    def apply_transform(self, subject):
        keys = sorted(subject.keys())
        for key in keys:
            subject[key][DATA] = subject[key][DATA].unsqueeze(0)
        return subject

if __name__ == "__main__":
    import torchio as tio
    st = SimulateLowResolutionTransform()
    colin_dataset = tio.datasets.mni.Colin27()
    ds_train = tio.SubjectsDataset([colin_dataset], transform=st)
    sampler = LabelSampler((120, 120, 80))
    patches_queue_train = tio.Queue(ds_train, max_length=32, samples_per_volume=4, sampler=sampler,
                                    shuffle_patches=True, shuffle_subjects=True, num_workers=8)

    training_loader = DataLoader(patches_queue_train, batch_size=2, shuffle=True)
    for batch in training_loader:
        print(batch["t1"][DATA].shape)

@fepegar
Owner

fepegar commented Aug 3, 2023

Thanks, @Zhack47. Good catch!
I think adding raise exception after line 330 would be enough. Do you agree?
Would you like to contribute a PR?

@Zhack47
Contributor Author

Zhack47 commented Aug 3, 2023

Thanks!
I completely agree with this solution.
I am going to make a PR to fix this!

@fepegar
Owner

fepegar commented Aug 3, 2023

Fixed in v0.19.1.
