In [None]:
import torch, os
import matplotlib.pyplot as plt

import torch.nn as nn

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from timm.layers import DropPath, to_2tuple, trunc_normal_

In [None]:
# Every dataset artifact lives under <current working directory>/imagenet1k.
current_dir = os.getcwd()
dataset_root = os.path.join(current_dir, 'imagenet1k')

train_path = os.path.join(dataset_root, 'train')
test_path = os.path.join(dataset_root, 'test')
idx_to_label_path = os.path.join(dataset_root, 'idx_to_label.txt')

## Create datasets and dataloaders

In [None]:
import ast

# Default number of DataLoader worker processes.
# os.cpu_count() returns None when the count cannot be determined, and DataLoader
# requires an int. It also reports every logical CPU on the host rather than the
# CPUs this process is restricted to via CPU affinity, so prefer the affinity set
# where the platform exposes it.
if hasattr(os, "sched_getaffinity"):
    NUM_WORKERS = len(os.sched_getaffinity(0))
else:
    NUM_WORKERS = os.cpu_count() or 0

def _load_idx_to_label(idx_to_label_path):
    """Parse the class-index -> label mapping file.

    The file must contain a single Python literal. The caller in this notebook
    treats the result as a dict (it calls ``.items()`` on it). ``ast.literal_eval``
    only evaluates literals, so the file cannot execute code.
    """
    with open(idx_to_label_path, 'r', encoding='utf-8') as f:
        return ast.literal_eval(f.read())


def create_loaders(train_dir, test_dir, idx_to_label_path, transform, batch_size, num_workers=NUM_WORKERS):
    """Build ImageFolder datasets and DataLoaders for the train and test splits.

    Args:
        train_dir: Root of the training split, one sub-folder per class
            (torchvision ImageFolder layout).
        test_dir: Root of the test split; must contain the same class folders
            as ``train_dir``.
        idx_to_label_path: Text file holding a Python literal that maps class
            index to label.
        transform: Transform applied to every image of both splits.
        batch_size: Number of samples per batch for both loaders.
        num_workers: DataLoader worker processes (defaults to ``NUM_WORKERS``).

    Returns:
        Tuple ``(train_loader, test_loader, class_names, idx_to_label)``.
        ``class_names`` is the training split's class-folder list in index
        order. The train loader shuffles; the test loader does not.

    Raises:
        ValueError: If the test split's class folders differ from the training
            split's. ImageFolder would then assign the two splits inconsistent
            integer labels.
    """
    train_dataset = datasets.ImageFolder(train_dir, transform=transform)
    test_dataset = datasets.ImageFolder(test_dir, transform=transform)

    # ImageFolder derives label indices per directory by enumerating the sorted
    # class-folder names. A missing or extra folder in one split would silently
    # shift every label after it, so refuse to continue on a mismatch.
    if test_dataset.class_to_idx != train_dataset.class_to_idx:
        missing = sorted(set(train_dataset.classes) - set(test_dataset.classes))
        extra = sorted(set(test_dataset.classes) - set(train_dataset.classes))
        raise ValueError(
            f"Class folders differ between {train_dir!r} and {test_dir!r} "
            f"(first 10 shown): missing from test={missing[:10]}, extra in test={extra[:10]}"
        )

    class_names = train_dataset.classes
    idx_to_label = _load_idx_to_label(idx_to_label_path)

    # Pinned host memory only helps host->GPU copies. Without CUDA it has no
    # benefit and DataLoader emits a warning, so enable it only when CUDA exists.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, test_loader, class_names, idx_to_label

In [None]:
# Each image is resized to exactly IMG_SIZE x IMG_SIZE (a (h, w) size does not
# preserve aspect ratio) and then converted to a tensor. No normalization here.
IMG_SIZE = 224
TARGET_SIZE = (IMG_SIZE, IMG_SIZE)

manual_transforms = transforms.Compose(
    [
        transforms.Resize(TARGET_SIZE),
        transforms.ToTensor(),
    ]
)

print(f"Manually created transforms: {manual_transforms}")

In [None]:
BATCH_SIZE = 32

# Build both loaders together with the class-name list and the index->label mapping.
train_loader, test_loader, class_names, idx_to_label = create_loaders(
    train_dir=train_path,
    test_dir=test_path,
    idx_to_label_path=idx_to_label_path,
    transform=manual_transforms,
    batch_size=BATCH_SIZE,
)

# Quick sanity report on what was loaded.
print(f"Number of classes: {len(class_names)}")
preview_mappings = list(idx_to_label.items())[:10]
print(f"First 10 idx_to_label mappings: {preview_mappings}")
print(f"Train loader: {train_loader}")
print(f"Test loader: {test_loader}")