## 配置

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

# Short alias for tf.data.experimental.AUTOTUNE. It is not referenced anywhere
# in this section; presumably later tf.data pipeline calls use it — TODO confirm.
AUTOTUNE = tf.data.experimental.AUTOTUNE

## 下载并检查数据集
### 检索图片
在开始训练之前，你需要一组图片来教会网络识别你想要训练的新类别。这里已经准备好了一个压缩包，其中包含拥有创作共用（Creative Commons）许可的花卉照片，供一开始使用。

In [None]:
import pathlib
import random

# Download the flower-photo archive and unpack it (untar=True). The value
# returned by get_file is used below as the root directory of the dataset.
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)

# After the 218 MB download you should now have a copy of the flower photos:
for item in data_root.iterdir():
  print(item)

# Collect every entry located one level below data_root's sub-directories as a
# plain string, then shuffle the list in place. No seed is set, so the order
# differs from run to run.
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
# Only the last expression of a cell is echoed by the notebook, so the count is
# printed explicitly instead of being left as a bare, silently discarded
# expression in the middle of the cell.
print(image_count)
all_image_paths[:10]

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
  2793472/228813984 [..............................] - ETA: 8:06:14

### 检查图片
现在让我们快速浏览几张图片，这样你知道你在处理什么：

In [None]:
import os
# Build a lookup from LICENSE.txt: the text before ' CC-BY' on each line is the
# key, the text after it is the value. The first four lines of the file are
# skipped. The file is read inside a `with` block so its handle is closed
# deterministically instead of being left to the garbage collector.
with (data_root/"LICENSE.txt").open(encoding='utf-8') as license_file:
    attributions = license_file.readlines()[4:]
# Every remaining line must split into exactly two parts, otherwise dict()
# raises ValueError.
attributions = [line.split(' CC-BY') for line in attributions]
attributions = dict(attributions)

import IPython.display as display
def caption_image(image_path):
    """Build the attribution caption for one image of the dataset.

    Args:
        image_path: Path (string or path-like) of an image located under
            the module-level `data_root`.

    Returns:
        The string "Image (CC BY 2.0) " followed by the image's entry in
        `attributions` with its final ' - '-separated field removed.

    Raises:
        ValueError: If `image_path` is not located under `data_root`.
        KeyError: If the relative path has no entry in `attributions`.
    """
    image_rel = pathlib.Path(image_path).relative_to(data_root)
    # str(image_rel) uses backslashes on Windows, which would miss keys written
    # with forward slashes. as_posix() always yields forward slashes and is
    # identical to str() on POSIX systems.
    # NOTE(review): assumes LICENSE.txt lists paths with '/' — confirm.
    attribution_key = image_rel.as_posix()
    return "Image (CC BY 2.0) " + ' - '.join(attributions[attribution_key].split(' - ')[:-1])

# Preview three images picked at random (with replacement) from the dataset.
# Each image is rendered first, then its attribution caption is printed,
# followed by one blank line.
for n in range(3):
    image_path = random.choice(all_image_paths)
    display.display(display.Image(image_path))
    print(caption_image(image_path), end='\n\n')

## 确定每张图片的标签

In [None]:
# List the available labels: the sorted names of data_root's sub-directories.
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
# Only the last expression of a cell is echoed by the notebook, so the
# intermediate results are printed explicitly instead of being left as bare,
# silently discarded expressions.
print(label_names)

# Assign an index to each label:
label_to_index = {name: index for index, name in enumerate(label_names)}
print(label_to_index)

# Build a list holding the label index of every file, looked up from the name
# of the directory that contains the file:
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

print("First 10 labels indices: ", all_image_labels[:10])

