Fix grayscale channel dimension (1) lost in DataLoader #307

Merged: cregouby merged 6 commits into mlverse:main from Chandraveersingh1717:fix-grayscale-dimension on Mar 30, 2026

@Chandraveersingh1717
Contributor

Fix grayscale channel handling in MNIST-style datasets

Summary

This PR fixes an issue where grayscale images could lose their channel dimension during batching, resulting in shapes like [B, 28, 28] instead of [B, 1, 28, 28] (seen on Apple M2). This could cause errors in models expecting channel-first input.

Changes

  • Ensure .getitem() always preserves the channel dimension
  • Add a regression test for batch_size = 128
  • Make tensor conversion more robust for different input layouts
  • Fix a small MD5 check bug
  • Correct a typo in kmnist_dataset
  • Improve QMNIST download handling
  • Add support for .xz files
  • Improve EMNIST download warning message

Impact

Makes dataset behavior more consistent across environments and avoids shape-related errors during training.
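
The shape issue described in the summary can be illustrated without torch. The base R sketch below is not the package code: the `data` array and the `get_item()` and `collate()` helpers are hypothetical names, assumed only to show why each item needs an explicit channel dimension before batching.

```r
# Hypothetical stand-in for an MNIST-style image store: N x H x W
data <- array(runif(4 * 28 * 28), dim = c(4, 28, 28))

# Indexing a single image drops the first dimension: the result is H x W
dim(data[2, , ])

# Hypothetical helper: return one image with an explicit channel dimension
get_item <- function(data, index) {
  x <- data[index, , ]
  dim(x) <- c(1, dim(x))  # channel-first: 1 x H x W
  x
}

# Hypothetical collate step: stack items along a new leading batch dimension
collate <- function(items) {
  stacked <- array(unlist(items), dim = c(dim(items[[1]]), length(items)))
  aperm(stacked, c(4, 1, 2, 3))  # C x H x W x B  ->  B x C x H x W
}

batch <- collate(lapply(1:3, function(i) get_item(data, i)))
dim(batch)
```

Under these assumptions the batch has shape 3 x 1 x 28 x 28. Without the `dim(x) <- c(1, dim(x))` step, the same stacking would yield 3 x 28 x 28, which is the shape reported in the summary.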

@Chandraveersingh1717 changed the title from "Update deprecated parameters and improve documentation" to "Fix grayscale channel dimension (1) lost in DataLoader (#306)" on Mar 24, 2026
Collaborator

@cregouby cregouby left a comment

praise: thanks for these improvements
todo: see inline

Comment thread R/transforms-array.R Outdated
res <- torch::torch_tensor(img)$permute(c(3, 1, 2))
dims <- dim(img)
if (length(dims) != 3)
stop("Expected a 2D or 3D array.")
Collaborator

todo: Please rely on the functions from conditions.R so that user messages are translated correctly.

Comment thread R/dataset-mnist.R Outdated
Comment on lines +454 to +456
if (length(index) != 1) {
return(lapply(as.integer(index), function(i) self$.getitem(i)))
}
Collaborator

suggestion: This is what the .getbatch() function is for. I would rather implement .getbatch() than try to implement a get-batch-within-getitem.
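
A minimal sketch of the separation suggested here, written in base R with plain lists rather than the package's dataset class. `make_dataset()` and its fields are hypothetical names used for illustration only; the point is that `.getitem()` handles exactly one index while `.getbatch()` handles a vector of indices and adds the channel dimension once for the whole batch.

```r
# Hypothetical dataset-like object built from plain lists (not torch::dataset)
make_dataset <- function(data, targets) {
  list(
    # One index in, one item out: 1 x H x W
    .getitem = function(index) {
      x <- data[index, , , drop = FALSE]
      list(x = x, y = targets[index])
    },
    # A vector of indices in, one batch out: B x 1 x H x W
    .getbatch = function(index) {
      x <- data[index, , , drop = FALSE]        # B x H x W
      dim(x) <- c(dim(x)[1], 1, dim(x)[2:3])    # insert the channel dimension
      list(x = x, y = targets[index])
    }
  )
}

ds <- make_dataset(array(0, dim = c(5, 28, 28)), targets = 1:5)
item <- ds$.getitem(2L)
batch <- ds$.getbatch(c(1L, 3L))
dim(item$x)
dim(batch$x)
```

Keeping the two entry points separate avoids the `length(index) != 1` branch inside `.getitem()` that this thread comments on.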

Comment thread R/dataset-mnist.R Outdated

x <- self$data[index, , ]
y <- self$targets[index]
idx <- as.integer(index)
Collaborator

suggestion: `.getitem()` is mainly an internal function for the dataloader to run, so index will always be an integer. I'd prefer to avoid the performance impact and remove this line.

Comment thread R/dataset-mnist.R
Comment on lines +478 to +482
x <- if (grepl("\\.xz$", path, ignore.case = TRUE)) {
xzfile(path, open = "rb")
} else {
gzfile(path, open = "rb")
}
Collaborator

praise: nice addition

Comment thread R/dataset-mnist.R Outdated
Comment on lines +152 to +156
if (length(index) != 1) {
return(lapply(as.integer(index), function(i) self$.getitem(i)))
}

idx <- as.integer(index)
Collaborator

todo: Same as for L454-L458

Comment thread tests/testthat/test-dataset-mnist.R Outdated
expect_named(raw_items[[1]], c("x", "y"))
expect_equal(dim(raw_items[[1]]$x), c(1, 28, 28))
expect_named(raw_items, c("x", "y"))
expect_equal(dim(raw_items$x), c(2, 28, 28))
Collaborator

question: So there is no "fix of the grayscale channel dimension lost in dataloader", right?

Comment thread R/dataset-mnist.R Outdated
if (!is.null(self$target_transform))
y <- self$target_transform(y)
if (!is.null(self$target_transform)) {
if (length(index) > 1) {
Collaborator

todo: same as above

…s[1] to be ndim=4.

make emnist_collection inherit mnist to remove duplicated code
fix tests accordingly
@cregouby
Collaborator

cregouby commented Mar 29, 2026

This would fix mlverse/torch#1417 but would raise mlverse/luz#159

@cregouby cregouby merged commit 6c86e09 into mlverse:main Mar 30, 2026
3 checks passed
@cregouby changed the title from "Fix grayscale channel dimension (1) lost in DataLoader (#306)" to "Fix grayscale channel dimension (1) lost in DataLoader" on Mar 30, 2026
@dfalbel
Member

dfalbel commented Mar 30, 2026

Have we looked at #264 and how this is handled?

@cregouby
Collaborator

> Have we looked at #264 and how this is handled?

Yes, this is where mlverse/luz#160 comes into play. I'll try to better document it in the {luz} pull request.

@dfalbel
Member

dfalbel commented Mar 31, 2026

I believe this is tricky terrain. In theory, torchvision transforms are not implemented to operate on batches of images, so while we can fix this for some luz code, it will inevitably be problematic for other transforms.

Otherwise, we should error earlier if a transform receives a 4D tensor as input, which would clearly indicate that it is a batch of images.
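
One way the "error earlier" idea could look, as a rough base R sketch. `assert_single_image()` is a hypothetical helper, not an existing torchvision function, and it uses plain `stop()` where the package would presumably go through its own condition helpers.

```r
# Hypothetical guard: reject inputs that look like a batch of images
assert_single_image <- function(x) {
  n_dims <- length(dim(x))
  if (n_dims == 4) {
    stop("Expected a single image (2D or 3D), got a 4D input that looks like a batch.")
  }
  if (!n_dims %in% c(2, 3)) {
    stop("Expected a 2D or 3D array.")
  }
  invisible(x)
}

single <- array(0, dim = c(1, 28, 28))
batched <- array(0, dim = c(8, 1, 28, 28))

assert_single_image(single)  # passes silently

# Capture the error message instead of aborting the script
result <- tryCatch(
  assert_single_image(batched),
  error = function(e) conditionMessage(e)
)
```

Calling such a guard at the top of a transform would turn a silent shape mismatch into an explicit message about batched input.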

@cregouby
Collaborator

Good point. I'll revert (again) on that part, and make that transform limitation explicit.
