In [15]:
from datasets import load_dataset
import torchvision.transforms as T

train_dataset = load_dataset("imagenet-1k", split="train")
eval_dataset = load_dataset("imagenet-1k", split="validation")

mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

# Define the data augmentation and preprocessing pipeline for training images
train_transform = T.Compose([
    T.Lambda(lambda x: x.convert('RGB') if x.mode != 'RGB' else x),  # Ensure 3 channels (convert grayscale to RGB)
    T.RandomResizedCrop(224, scale=(0.08, 1.0)),      # Randomly crop and resize to 224x224 (simulates zoom/scale)
    T.RandomHorizontalFlip(),                         # Randomly flip images horizontally (augmentation)
    T.RandAugment(num_ops=2, magnitude=9),            # Apply 2 random augmentations with magnitude 9 (extra augmentation)
    T.ToTensor(),                                     # Convert PIL Image or numpy.ndarray to tensor and scale to [0, 1]
    T.Normalize(mean, std),                           # Normalize using ImageNet mean and std
    T.RandomErasing(p=0.25, scale=(0.02, 0.1)),       # Randomly erase a rectangle region (extra augmentation, 25% chance)
])

# Define the preprocessing pipeline for evaluation images (no heavy augmentation)
eval_transform = T.Compose([
    T.Lambda(lambda x: x.convert('RGB') if x.mode != 'RGB' else x),  # Ensure 3 channels (convert grayscale to RGB)
    T.Resize(256),                                    # Resize shorter side to 256 pixels
    T.CenterCrop(224),                                # Crop the center 224x224 region
    T.ToTensor(),                                     # Convert to tensor and scale to [0, 1]
    T.Normalize(mean, std),                           # Normalize using ImageNet mean and std
])

def train_transform_fn(examples):
    # Handle both single examples and batches
    if isinstance(examples['image'], list):
        # Batch processing
        examples["pixel_values"] = [train_transform(image) for image in examples["image"]]
    else:
        # Single example processing  
        examples["pixel_values"] = train_transform(examples["image"])
    
    # Remove the original image to avoid DataLoader issues
    del examples["image"]
    return examples

def eval_transform_fn(examples):
    # Handle both single examples and batches
    if isinstance(examples['image'], list):
        # Batch processing
        examples["pixel_values"] = [eval_transform(image) for image in examples["image"]]
    else:
        # Single example processing
        examples["pixel_values"] = eval_transform(examples["image"])
    
    # Remove the original image to avoid DataLoader issues
    del examples["image"]
    return examples

# Now you can use these with with_transform()
train_dataset = train_dataset.with_transform(train_transform_fn)
eval_dataset = eval_dataset.with_transform(eval_transform_fn)

Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]

In [11]:
train_dataset[0], eval_dataset[0]

({'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=817x363>,
  'label': 726,
  'pixel_values': tensor([[[-0.1828, -0.1657, -0.1657,  ...,  0.1254,  0.1083,  0.1254],
           [-0.1828, -0.1828, -0.1657,  ...,  0.1254,  0.1083,  0.1083],
           [-0.1657, -0.1657, -0.1657,  ...,  0.1083,  0.1083,  0.1083],
           ...,
           [ 0.2111,  0.1939,  0.1939,  ...,  0.2796,  0.2453,  0.2624],
           [ 0.1939,  0.1939,  0.1939,  ...,  0.2624,  0.2624,  0.2624],
           [ 0.1939,  0.1939,  0.2111,  ...,  0.2796,  0.2796,  0.2796]],
  
          [[ 0.0126,  0.0301,  0.0301,  ...,  0.3277,  0.3102,  0.3277],
           [ 0.0301,  0.0301,  0.0301,  ...,  0.3102,  0.2927,  0.2927],
           [ 0.0126,  0.0126,  0.0126,  ...,  0.2927,  0.2927,  0.2927],
           ...,
           [ 0.3452,  0.3277,  0.3277,  ...,  0.4153,  0.3978,  0.4153],
           [ 0.3277,  0.3277,  0.3277,  ...,  0.4153,  0.4153,  0.4153],
           [ 0.3277,  0.3452,  0.3627,  ...,  0.4328,

In [12]:
train_dataset[0]['pixel_values'].shape, eval_dataset[0]['pixel_values'].shape

(torch.Size([3, 224, 224]), torch.Size([3, 224, 224]))

In [14]:
# Test with actual grayscale image
from PIL import Image
import numpy as np
import torch

print("Testing RGB conversion with actual grayscale image...")

# Create a test grayscale image
gray_array = np.random.randint(0, 256, (300, 400), dtype=np.uint8)
grayscale_image = Image.fromarray(gray_array, mode='L')  # 'L' mode = grayscale

print(f"Original grayscale image mode: {grayscale_image.mode}")
print(f"Original grayscale image size: {grayscale_image.size}")

# Test train transform on grayscale image
print("\nTesting train transform on grayscale image:")
try:
    train_result = train_transform(grayscale_image)
    print(f"‚úÖ Train transform result shape: {train_result.shape}")
    print(f"‚úÖ Successfully converted to {train_result.shape[0]} channels")
    
    # Verify all 3 channels have the same values (since it was grayscale)
    channel_equality = torch.allclose(train_result[0], train_result[1]) and torch.allclose(train_result[1], train_result[2])
    print(f"‚úÖ All 3 channels identical (as expected): {channel_equality}")
    
except Exception as e:
    print(f"‚ùå Error in train transform: {e}")

# Test eval transform on grayscale image  
print("\nTesting eval transform on grayscale image:")
try:
    eval_result = eval_transform(grayscale_image)
    print(f"‚úÖ Eval transform result shape: {eval_result.shape}")
    print(f"‚úÖ Successfully converted to {eval_result.shape[0]} channels")
    
    # Verify all 3 channels have the same values (since it was grayscale)
    channel_equality = torch.allclose(eval_result[0], eval_result[1]) and torch.allclose(eval_result[1], eval_result[2])
    print(f"‚úÖ All 3 channels identical (as expected): {channel_equality}")
    
except Exception as e:
    print(f"‚ùå Error in eval transform: {e}")

print("\nüéØ Grayscale to RGB conversion test completed!")


Testing RGB conversion with actual grayscale image...
Original grayscale image mode: L
Original grayscale image size: (400, 300)

Testing train transform on grayscale image:
‚úÖ Train transform result shape: torch.Size([3, 224, 224])
‚úÖ Successfully converted to 3 channels
‚úÖ All 3 channels identical (as expected): False

Testing eval transform on grayscale image:
‚úÖ Eval transform result shape: torch.Size([3, 224, 224])
‚úÖ Successfully converted to 3 channels
‚úÖ All 3 channels identical (as expected): False

üéØ Grayscale to RGB conversion test completed!
