In [8]:
import os

from config import Config

# The dataset root used to be a hard-coded, machine-specific absolute path.
# It is still the default, but can now be overridden by setting the
# MVTEC_DATA_DIR environment variable before starting the kernel, so the
# notebook runs on other machines without editing this cell.
DEFAULT_DATA_DIR = '/mnt/d/datasets/mvtec'

config = Config(
    # `or` (rather than a .get default) also falls back when the variable is set but empty.
    data_dir=os.environ.get('MVTEC_DATA_DIR') or DEFAULT_DATA_DIR,
    categories=['bottle', 'cable', 'capsule'],  # categories handed to MVTecDataset
    img_size=256,     # forwarded to get_transforms(img_size=...)
    batch_size=32,    # batch size of the training dataloader
    valid_ratio=0.2,  # forwarded to split_train_valid(valid_ratio=...)
    seed=42,          # forwarded to split_train_valid(seed=...)
)

In [9]:
from data import MVTecDataset, get_transforms, get_dataloader, split_train_valid

# Build the two transform pipelines at the configured image size.
train_transform, test_transform = get_transforms(img_size=config.img_size)


def make_dataset(split, transform):
    """Create an MVTecDataset for `split` with the configured data dir and categories."""
    return MVTecDataset(config.data_dir, config.categories, split, transform=transform)


# Two views of the 'train' split that differ only in their transform,
# plus one dataset over the 'test' split.
train_dataset = make_dataset('train', train_transform)
valid_dataset = make_dataset('train', test_transform)
test_dataset = make_dataset('test', test_transform)

# Rebind both names to the train/valid pair returned by the split helper.
train_dataset, valid_dataset = split_train_valid(
    train_dataset,
    valid_dataset,
    valid_ratio=config.valid_ratio,
    seed=config.seed,
)

# Training uses the configured batch size; validation and test share a fixed one.
EVAL_BATCH_SIZE = 16
train_loader = get_dataloader(train_dataset, config.batch_size, 'train')
valid_loader = get_dataloader(valid_dataset, EVAL_BATCH_SIZE, 'valid')
test_loader = get_dataloader(test_dataset, EVAL_BATCH_SIZE, 'test')

In [10]:
# Report how many samples back each dataloader; labels are left-aligned to
# five characters so the three lines line up.
for split_label, split_loader in (
    ("Train", train_loader),
    ("Valid", valid_loader),
    ("Test", test_loader),
):
    print(f"{split_label:<5} size: {len(split_loader.dataset)}")

Train size: 522
Valid size: 130
Test  size: 365
