Fix grayscale channel dimension (1) lost in DataLoader #307
cregouby merged 6 commits into mlverse:main
Conversation
cregouby left a comment
praise: thanks for these improvements
todo: see inline
res <- torch::torch_tensor(img)$permute(c(3, 1, 2))
dims <- dim(img)
if (length(dims) != 3)
  stop("Expected a 2D or 3D array.")
todo: please rely on the functions from conditions.R for correct user-message translation.
if (length(index) != 1) {
  return(lapply(as.integer(index), function(i) self$.getitem(i)))
}
suggestion: this is supposed to be the `.getbatch()` function. I would rather implement `.getbatch()` than try to implement a get-batch-within-getitem.
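For illustration, a minimal sketch of what such a `.getbatch()` body could look like, written here as a plain function over a mock `self` (the name `mnist_getbatch` and the field layout are hypothetical; it assumes `self$data` is an `[N, 28, 28]` array of grayscale images and `self$targets` a vector of labels). The key points are `drop = FALSE` indexing, which keeps the batch dimension even for a single index, and an explicit insertion of the channel dimension:

```r
# Hypothetical .getbatch() body for an MNIST-style dataset (sketch, not the
# actual torchvision implementation). Assumes self$data is [N, 28, 28].
mnist_getbatch <- function(self, index) {
  # drop = FALSE keeps all dimensions even when length(index) == 1
  x <- self$data[index, , , drop = FALSE]
  y <- self$targets[index]
  # Insert the grayscale channel dimension: [B, 28, 28] -> [B, 1, 28, 28].
  # Inserting a length-1 dimension is a pure reshape, so values are unchanged.
  x <- array(x, dim = c(dim(x)[1], 1L, dim(x)[2], dim(x)[3]))
  list(x = x, y = y)
}
```

A real implementation would set this as the `.getbatch` method of a `torch::dataset()` and convert `x` to a tensor, but the indexing logic is the same.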
x <- self$data[index, , ]
y <- self$targets[index]
idx <- as.integer(index)
suggestion: `.getitem()` is mainly an internal function for the dataloader to run, so `index` will always be an integer. I'd prefer to avoid the performance impact and remove this line.
x <- if (grepl("\\.xz$", path, ignore.case = TRUE)) {
  xzfile(path, open = "rb")
} else {
  gzfile(path, open = "rb")
}
if (length(index) != 1) {
  return(lapply(as.integer(index), function(i) self$.getitem(i)))
}

idx <- as.integer(index)
todo: same as for L454-L458.
expect_named(raw_items[[1]], c("x", "y"))
expect_equal(dim(raw_items[[1]]$x), c(1, 28, 28))
expect_named(raw_items, c("x", "y"))
expect_equal(dim(raw_items$x), c(2, 28, 28))
question: so there is no "fix of the grayscale channel dimension lost in the dataloader", true?
if (!is.null(self$target_transform))
  y <- self$target_transform(y)
if (!is.null(self$target_transform)) {
  if (length(index) > 1) {
…s[1] to be ndim=4. Make emnist_collection inherit mnist to remove duplicated code. Fix tests accordingly.
This would fix mlverse/torch#1417 but would raise mlverse/luz#159.
Have we looked at #264 and how this is handled?
Yes, this is where mlverse/luz#160 comes into play. I'll try to better document it in the {luz} pull request.
I believe this is tricky terrain. In theory, torchvision transforms are not implemented to operate on batches of images, so while we can fix this for some luz code, it will inevitably be problematic for other transforms. Otherwise, we should error earlier if a transform takes a 4d tensor as input, which would clearly indicate it's a batch of images.
Good point. I'll revert (again) on that part, and make that transform limitation explicit. |
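One way to make that transform limitation explicit, sketched as a base-R guard (the name `check_single_image` is hypothetical, and a real torchvision implementation would inspect the tensor's ndim rather than an array's `dim()`):

```r
# Sketch: fail fast when a single-image transform receives a batched input.
# Assumes images are represented as arrays with ndim <= 3 ([C, H, W] or [H, W]).
check_single_image <- function(x) {
  if (length(dim(x)) >= 4)
    stop("Transforms operate on single images (ndim <= 3), ",
         "but got a batched input with ndim = ", length(dim(x)), ".")
  invisible(x)
}
```

Erroring here, rather than deep inside a transform, makes the "no batched transforms" contract clear to the user.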
Fix grayscale channel handling in MNIST-style datasets
Summary
This PR fixes an issue where grayscale images could lose their channel dimension during batching, resulting in shapes like [B, 28, 28] instead of [B, 1, 28, 28] (seen on Apple M2). This could cause errors in models expecting channel-first input.
Changes
- .getitem() always preserves the channel dimension
- batch_size = 128
- kmnist_dataset .xz files
Impact
Makes dataset behavior more consistent across environments and avoids shape-related errors during training.
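As background, the root cause of the lost channel dimension can be illustrated in base R (a stand-in for the tensor case): R subsetting drops length-one dimensions by default, which is exactly how a grayscale batch can collapse from [B, 1, 28, 28] to [B, 28, 28]:

```r
# R drops length-1 dimensions on subsetting unless told otherwise.
imgs <- array(0, dim = c(10, 1, 28, 28))       # [N, C, H, W], grayscale C = 1
dim(imgs[1, , , ])                  # 28 28       -- channel dim silently dropped
dim(imgs[1, , , , drop = FALSE])    # 1 1 28 28   -- all dims preserved
```

Passing `drop = FALSE` (or re-adding the channel dimension explicitly, as the fixed `.getitem()` does) keeps shapes consistent regardless of batch size.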