In [5]:
import torch
import timm
from timm.data import create_dataset, create_loader, resolve_data_config

In [6]:
# Create model and get its data config
model = timm.create_model("resnet18", pretrained=True)
model.eval()

# Get the model's expected input config (size, mean, std, etc.)
data_config = resolve_data_config(model.pretrained_cfg)
print(f"Model expects: {data_config}")

Model expects: {'input_size': (3, 224, 224), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 0.95, 'crop_mode': 'center'}


In [17]:
loader

<timm.data.loader.PrefetchLoader at 0x119f06710>

In [7]:
# Create dataset using timm's factory
# For torch datasets, prefix with "torch/"
dataset = create_dataset(
    name="torch/cifar10",
    root="../data",
    split="test",      # "train" or "test" for CIFAR
    download=True,
)

print(f"Dataset size: {len(dataset)}")

Dataset size: 10000


  entry = pickle.load(f, encoding="latin1")


In [11]:
# Create loader with timm - handles transforms automatically!
loader = create_loader(
    dataset,
    input_size=data_config["input_size"],  # Uses model's expected size
    batch_size=8,
    is_training=False,
    mean=data_config["mean"],
    std=data_config["std"],
    crop_pct=data_config["crop_pct"],
    device=torch.device("cpu"),
)

print(f"Number of batches: {len(loader)}")

Number of batches: 1250


In [15]:
# Get one batch and run inference
images, labels = next(iter(loader))

print(f"Batch shape: {images.shape}")  # [8, 3, 224, 224]
print(f"Labels: {labels}")

# Run inference
with torch.no_grad():
    outputs = model(images)
    predictions = torch.argmax(outputs, dim=1)

print(f"Output shape: {outputs.shape}")  # [8, 1000] ImageNet classes
print(f"Predictions: {predictions}")

Batch shape: torch.Size([8, 3, 224, 224])
Labels: tensor([3, 8, 8, 0, 6, 6, 1, 6])
Output shape: torch.Size([8, 1000])
Predictions: tensor([758, 503, 510, 675, 335, 278, 675, 945])
