In [1]:
# Download sample images plus their _labels.txt and _data.npz (imagenet_classes.txt is fetched in a later cell)
import itertools
import os
from pathlib import Path

import numpy as np
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image

# Image preprocessing pipeline applied to every sample:
# resize (256), centre-crop to 224x224, convert the PIL image to a tensor,
# then normalise each channel with the ImageNet RGB mean / std values below.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def download_dataset(data_dir: Path, split: str, size: int):
  """Stream the first ``size`` samples of an ImageNet-1k split and cache them locally.

  Everything is written under ``<data_dir>/<split>/``:
    * ``<i>.jpg``      - each sample's image, converted to RGB;
    * ``_labels.txt``  - one ``<i>.jpg,<label>`` line per sample;
    * ``_data.npz``    - arrays ``images`` (stacked output of ``preprocess``)
                         and ``labels``.

  Args:
    data_dir: Root output directory (``str`` or ``Path``); the split
      sub-directory is created if missing.
    split: Dataset split name passed to ``load_dataset``, e.g. ``'validation'``.
    size: Number of samples to take from the start of the stream.
      Values <= 0 produce empty outputs.
  """
  image_dir = Path(data_dir) / str(split)
  image_dir.mkdir(parents=True, exist_ok=True)

  # Streaming mode fetches samples lazily instead of downloading the full dataset.
  # NOTE(review): trust_remote_code lets `datasets` execute the dataset repo's
  # loading script - confirm this is still required for 'imagenet-1k'.
  dataset = load_dataset(
    'imagenet-1k',
    split=split,
    streaming=True,
    trust_remote_code=True,
  )

  label_file = image_dir / '_labels.txt'

  labels = []
  images = []
  with open(label_file, 'w', encoding='utf-8') as f:
    # islice stops after exactly `size` samples; the previous `if i >= size: break`
    # pulled (and decoded) one extra sample from the remote stream first.
    for i, sample in enumerate(itertools.islice(dataset, max(size, 0))):
      label = sample['label']
      # Convert *before* saving: JPEG cannot store every PIL mode (e.g. RGBA, P),
      # and this keeps the files on disk consistent with the preprocessed data.
      image = sample['image'].convert('RGB')
      image.save(image_dir / f'{i}.jpg')

      images.append(preprocess(image).numpy())
      labels.append(label)
      f.write(f'{i}.jpg,{label}\n')

  # np.stack needs at least one array, so fall back to an empty array when size <= 0.
  image_array = np.stack(images) if images else np.empty((0,), dtype=np.float32)
  np.savez(image_dir / '_data.npz', images=image_array, labels=np.array(labels))

  # Report only after every output (including _data.npz) has actually been written.
  print(f"Images saved to {image_dir} and labels saved to {label_file}.")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from pathlib import Path

# Local cache root for the ImageNet sample, made absolute so the paths printed
# by download_dataset are unambiguous.
data_dir = Path('data', 'imagenet').resolve()
data_dir.mkdir(exist_ok=True, parents=True)

# Stream the first 300 validation samples into data_dir / 'validation'.
download_dataset(data_dir, split='validation', size=300)

Images saved to C:\Dev\AI\Models\resnet\data\imagenet\validation and labels saved to C:\Dev\AI\Models\resnet\data\imagenet\validation\_labels.txt.


In [3]:
# Download the ImageNet class-name list published in the pytorch/hub repository.
# `import urllib.request` (not bare `import urllib`) guarantees the submodule is
# loaded; `urllib.URLopener` does not exist in Python 3, so the old try/except
# always fell through and its bare `except:` hid any real download error.
import urllib.request

url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
classFile = os.path.join(data_dir, "imagenet_classes.txt")

# Skip the network round-trip when the file is already cached, so re-running
# the notebook does not re-download it.
if not os.path.exists(classFile):
    urllib.request.urlretrieve(url, classFile)

In [4]:
# url, imgFile = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")