In [1]:
import os
from edgetpu.basic.basic_engine import BasicEngine
from edgetpu.classification.engine import ClassificationEngine
from edgetpu.learn.imprinting.engine import ImprintingEngine
import numpy as np
from PIL import Image

In [2]:
def _ReadData(path, test_ratio):
  """Parses data from given directory, split them into two sets.

  Args:
    path: string, path of the data set. Images are stored in sub-directory
      named by category.
    test_ratio: float in (0,1), ratio of data used for testing.

  Returns:
    (train_set, test_set), A tuple of two dicts. Keys are the categories and
      values are lists of image file names.
  """
  train_set = {}
  test_set = {}
  for category in os.listdir(path):
    category_dir = os.path.join(path, category)
    if os.path.isdir(category_dir):
      images = [f for f in os.listdir(category_dir)
                if os.path.isfile(os.path.join(category_dir, f))]
      if images:
        k = int(test_ratio * len(images))
        test_set[category] = images[:k]
        assert test_set[category], 'No images to test [{}]'.format(category)
        train_set[category] = images[k:]
        assert train_set[category], 'No images to train [{}]'.format(category)
  return train_set, test_set

In [3]:
def _PrepareImages(image_list, directory, shape):
  """Reads images and converts them to numpy array with given shape.

  Args:
    image_list: a list of strings storing file names.
    directory: string, path of directory storing input images.
    shape: a 2-D tuple represents the shape of required input tensor.

  Returns:
    A list of numpy.array.
  """
  ret = []
  for filename in image_list:
    with Image.open(os.path.join(directory, filename)) as img:
      img = img.resize(shape, Image.NEAREST)
      ret.append(np.asarray(img).flatten())
  return np.array(ret)


In [4]:
def _SaveLabels(labels, model_path):
  """Output labels as a txt file.

  Args:
    labels: {int : string}, map between label id and label.
    model_path: string, path of the model.
  """
  label_file_name = model_path.replace('.tflite', '.txt')
  with open(label_file_name, 'w') as f:
    for label_id, label in labels.items():
      f.write(str(label_id) + '  ' + label + '\n')
  print('Labels file saved as :', label_file_name)

In [5]:
def _GetRequiredShape(model_path):
  """Looks up the image size a model expects as input.

  A BasicEngine is created for the model only to query its input tensor shape.

  Args:
    model_path: string, path of the model.

  Returns:
    A (width, height) tuple, taken from positions 2 and 1 of the shape that
    the engine reports for its input tensor.
  """
  probe_engine = BasicEngine(model_path)
  dims = probe_engine.get_input_tensor_shape()
  width = dims[2]
  height = dims[1]
  return (width, height)



In [14]:
# Settings read by the training cell below.
# Root folder of the data set; _ReadData expects one sub-directory per category.
data_path = '/home/mendel/trainImage/flower_photos'
# Fraction of each category's images that _ReadData puts in the test set.
# NOTE(review): 0.95 leaves only about 5% of the images for training -
# confirm this is intended.
test_ratio = 0.95
# Model file given to _GetRequiredShape and to ImprintingEngine.
extractor = '/home/mendel/trainImage/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite'
# Where the trained model is saved; _SaveLabels derives the label file name
# from this path as well.
output = '/home/mendel/trainImage/flower_model.tflite'

In [15]:
# Training cell: split the data set, load the training images, train the
# imprinting engine, then save the model and its label file.
# NOTE(review): test_set is not used anywhere in this cell - no evaluation
# is run here.
train_set, test_set = _ReadData(data_path, test_ratio)

print('Image list successfully parsed! Category Num = ', len(train_set))
# (width, height) reported by _GetRequiredShape for the extractor model.
shape = _GetRequiredShape(extractor)

print('---------------- Processing training data ----------------')
print('This process may take more than 30 seconds.')
# Maps each category name to the array returned by _PrepareImages for that
# category's training files.
train_input = {}
for category, image_list in train_set.items():
    print('Processing category:', category)
    # Images of a category live in the sub-directory named after it.
    train_input[category] = _PrepareImages(
        image_list, os.path.join(data_path, category), shape)

print('----------------      Start training     -----------------')
engine = ImprintingEngine(extractor)
# The value returned by TrainAll is handed to _SaveLabels below, which
# documents its argument as a {label id: label} map.
labels_map = engine.TrainAll(train_input)
print('----------------     Training finished!  -----------------')

# Save the model first; the label file name is derived from the same path.
engine.SaveModel(output)
print('Model saved as : ', output)
_SaveLabels(labels_map, output)


Image list successfully parsed! Category Num =  5
---------------- Processing training data ----------------
This process may take more than 30 seconds.
Processing category: dandelion
Processing category: daisy
Processing category: tulips
Processing category: roses
Processing category: sunflowers
----------------      Start training     -----------------
----------------     Training finished!  -----------------
Model saved as :  /home/mendel/trainImage/flower_model.tflite
Labels file saved as : /home/mendel/trainImage/flower_model.txt
