Skip to content

Fix batch classification TypeError when using data_path (#611)#616

Open
jQuinRivero wants to merge 1 commit intomicrosoft:mainfrom
jQuinRivero:fix/batch-classification-imagefolder-error
Open

Fix batch classification TypeError when using data_path (#611)#616
jQuinRivero wants to merge 1 commit intomicrosoft:mainfrom
jQuinRivero:fix/batch-classification-imagefolder-error

Conversation

@jQuinRivero
Copy link

Fixes #611

Problem

Calling batch_image_classification(data_path=...) on any TIMM-based or ResNet-based classifier raises a TypeError because pw_data.ImageFolder does not accept a path_head keyword argument:

# Both timm_base and resnet_base had this:
dataset = pw_data.ImageFolder(
    data_path,
    transform=self.transform,
    path_head='.'        # ← ImageFolder doesn't accept this
)

Additionally, ImageFolder is the abstract base class whose __getitem__ returns None, so even if the kwarg issue were fixed, the dataloader would fail to unpack (img, img_path).

Fix

  • Replace pw_data.ImageFolder with pw_data.ClassificationImageFolder — the correct subclass that implements __getitem__ returning (img, img_path).
  • Remove the invalid path_head='.' argument (ClassificationImageFolder builds full paths via os.walk, so path_head is unnecessary).
  • Add ClassificationImageFolder to __all__ in datasets.py for export consistency.

Files changed

File Change
PytorchWildlife/data/datasets.py Added ClassificationImageFolder to __all__
PytorchWildlife/models/classification/timm_base/base_classifier.py ImageFolderClassificationImageFolder, removed path_head
PytorchWildlife/models/classification/resnet_base/base_classifier.py Same fix

Safety

  • The det_results code path (using DetectionCrops) is untouchedpath_head is valid there.
  • No other file in the repo references pw_data.ImageFolder.
  • All detectors already use DetectionImageFolder correctly.
  • The broken code path could never have worked, so this is purely additive — no existing working behavior changes.

…\n\nReplace pw_data.ImageFolder with pw_data.ClassificationImageFolder in\nboth timm_base and resnet_base classifiers, and remove the invalid\npath_head keyword argument that ImageFolder does not accept.\n\nImageFolder.__getitem__ is abstract (returns None), so the correct\nsubclass for classification is ClassificationImageFolder, which returns\nthe (img, img_path) tuple the dataloader loop expects.\n\nAlso export ClassificationImageFolder from datasets __all__."
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Batch classification fails when using data_path in PytorchWildlife/models/classification/timm_base/base_classifier.py

1 participant