<a href="https://colab.research.google.com/github/ldeluigi/supermarket-2077-product-vision/blob/master/ProductDetection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Download datasets

In [None]:
# Start from a clean workspace: remove Colab's bundled sample data folder.
!rm -rf sample_data
# Download the zipped datasets from Google Drive and unpack them in place
# (presumably providing the `Training/` folder read below — verify archive layout).
!gdown --id 1fDr4g4wbnSRkuCYyS3wpuJS7Ax22bVB_ -O all.zip
!unzip -oq all.zip

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Imports

In [None]:
import scipy.io
import os
from pathlib import Path
import re
import cv2
import matplotlib.pyplot as plt
import numpy as np
import math
import itertools
import shutil
from tqdm.notebook import tqdm
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from sklearn.metrics import confusion_matrix

## Data loaders

In [None]:
# Root folder of the extracted training set; each class name is a sub-folder of it.
training_dirname = 'Training'

def create_class_label(class_index, class_name):
  """Return the human-readable label for a class.

  The index is accepted so the labelling scheme can change later, but the
  label is currently just the raw class name.
  """
  label = class_name
  return label

def read_classes():
  """Load the class list stored in `TrainingClassesIndex.mat`.

  Returns:
    A pair of dicts, both keyed by 1-based class index:
    (index -> label from `create_class_label`, index -> raw class name).
  """
  mat_path = os.path.join(training_dirname, 'TrainingClassesIndex.mat')
  mat = scipy.io.loadmat(mat_path)
  raw_names = [entry[0] for entry in mat['classes'][0]]
  indexed_names = list(enumerate(raw_names, start=1))
  labels = {index: create_class_label(index, name) for index, name in indexed_names}
  return labels, dict(indexed_names)

def read_training_data(classes):
  """Load every training image, one sub-folder of `training_dirname` per class.

  Args:
    classes: dict mapping class index -> class name; the name is used as the
      folder (relative to `training_dirname`) holding that class's images.

  Returns:
    A numpy record array with fields `image` (RGB array, stored as an object)
    and `class_index` (int32).

  Files OpenCV cannot decode (e.g. stray non-image files) are skipped and
  reported, instead of crashing `cv2.cvtColor` on a `None` image.
  """
  result = []
  skipped = []
  for class_index, class_name in classes.items():
    dirname_images = os.path.join(training_dirname, class_name)
    # Sorted so the record order, and any seeded sampling later, is reproducible.
    for filename in sorted(os.listdir(dirname_images)):
      path = os.path.join(dirname_images, filename)
      img = cv2.imread(path)
      if img is None:
        skipped.append(path)
        continue
      img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      result.append((img_rgb, class_index))
  if skipped:
    print(f'Warning: skipped {len(skipped)} unreadable file(s), e.g. {skipped[0]}')
  return np.rec.array(result, dtype=[('image', 'O'), ('class_index', 'i4')])

def read_store_data(storename):
  """Load the annotated shelf images of one store.

  Annotation files in `<storename>/annotation` named like `anno.<N>.mat`
  (regex `^anno.(\\d+).mat$`) are paired with `<storename>/images/<N>.jpg`.

  Returns:
    A numpy record array with fields `image` (RGB array) and `items`
    (list of `(bbox, label, class_index)` tuples for that image).

  `.mat` files whose name does not match the pattern, or whose image cannot be
  decoded, are skipped with a warning instead of raising.
  """
  dirname_anno = os.path.join(storename, 'annotation')
  dirname_images = os.path.join(storename, 'images')

  result = []

  for filename in sorted(os.listdir(dirname_anno)):
    if not filename.endswith(".mat"):
      continue
    match = re.search(r'^anno.(\d+).mat$', filename)
    if match is None:
      print(f'Warning: unexpected annotation file name {filename!r}, skipped')
      continue
    number = match.group(1)
    image_path = os.path.join(dirname_images, number + '.jpg')
    img = cv2.imread(image_path)
    if img is None:
      print(f'Warning: cannot read {image_path}, skipped')
      continue

    mat = scipy.io.loadmat(os.path.join(dirname_anno, filename))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_annotation = mat['annotation'][0, 0]
    bboxes = [x[0] for x in img_annotation[0][0]]
    labels = [str(x[0][0][0]) for x in img_annotation[1][0]]
    class_indexes = img_annotation[2][0]
    result.append((img_rgb, list(zip(bboxes, labels, class_indexes))))
  return np.rec.array(result, dtype=[('image', 'O'), ('items', 'O')])

## Data visualization utilities

In [None]:
def show_image(img):
  """Draw `img` on the current matplotlib axes with ticks and frame hidden."""
  hidden = 'off'
  plt.axis(hidden)
  plt.imshow(img)

def show_grayscale_image(img):
  """Display a single-channel image by stacking it into three identical channels."""
  three_channel = cv2.merge([img] * 3)
  show_image(three_channel)

def plot_grid(images, columns, show_axis=False, labels=None):
  """Plot `images` in a grid with `columns` columns.

  Args:
    images: sequence of image arrays. Floating-point images are assumed to be
      in [0, 1] (TODO confirm for every caller); they are clipped to that
      range before conversion to uint8, so values pushed above 1 by
      brightness augmentation do not wrap around.
    columns: number of grid columns.
    show_axis: keep the axes visible when True.
    labels: optional per-column titles; when None each cell is titled with
      its 1-based position.

  Does nothing when `images` is empty (previously `images[0]` raised IndexError).
  """
  if len(images) == 0:
    return
  rows = math.ceil(len(images) / columns)
  height = 1 + rows * 2
  width = columns * 2
  # Scale resolution with the first image's size; at least 1 dpi for tiny images.
  dpi = max(1, max(images[0].shape[0], images[0].shape[1]) // 4)
  fig = plt.figure(figsize=(width, height), dpi=dpi)
  fig.subplots_adjust(hspace=0.4)
  for index, img in enumerate(images, start=1):
    # `dtype.str` is e.g. '<f4', so the old `'float' in img.dtype.str` test never matched.
    if np.issubdtype(img.dtype, np.floating):
      img = (np.clip(img, 0, 1) * 255).astype('uint8')
    sp = fig.add_subplot(rows, columns, index)
    if not show_axis:
      plt.axis('off')
    plt.imshow(img)
    if labels is not None:
      sp.set_title(labels[(index-1) % columns], fontsize=10)
    else:
      sp.set_title(index, fontsize=10)

def dataset_plot_grid(indexes, columns, dataset, draw_item):
  """Draw `dataset[i]` for each i in `indexes` in a grid, via `draw_item(row, subplot)`."""
  fig = plt.figure(figsize=(12, 6), dpi=120)
  grid_rows = math.ceil(len(indexes) / columns)
  for position, row_index in enumerate(indexes, start=1):
    subplot = fig.add_subplot(grid_rows, columns, position)
    draw_item(dataset[row_index], subplot)

# Image search

## Prepare products class dictionary

In [None]:
classes, raw_classes = read_classes()

def class_name(class_index):
  """Return the label of `class_index`, or None for a negative index."""
  if class_index < 0:
    return None
  return classes[class_index]

## Load training raw images

In [None]:
# Read every training image into memory as (RGB image, class index) records;
# `raw_classes` supplies the per-class folder names.
products = read_training_data(raw_classes)

## Products visualization

In [None]:
def show_products_with_class(indexes, columns, dataset):
  """Plot the dataset rows at `indexes`, each titled with its class label."""
  def draw_product(row, subplot):
    plt.axis('off')
    plt.imshow(row.image)
    subplot.set_title(class_name(row.class_index), fontsize=10)

  dataset_plot_grid(indexes, columns, dataset, draw_product)

# Six random products (indexes may repeat) in a 3-column grid.
show_products_with_class(np.random.randint(0, len(products), 6), 3, products)

## Image preprocessing

### Background removal

In [None]:
# code taken from https://www.kaggle.com/vadbeg/opencv-background-removal and modified

def remove_background(img, threshold):
    """Crop an RGB photo to its largest foreground object and make the rest transparent.

    Pixels whose gray level is above `threshold` count as background. The
    foreground mask is closed with an elliptical kernel, the largest external
    contour is filled, and the image is cropped to that contour's bounding box.

    Args:
        img: RGB image of shape (H, W, 3).
        threshold: gray level above which a pixel is treated as background.

    Returns:
        RGBA image cropped to the object; alpha comes from the filled mask.
        If no foreground contour is found, the whole image is returned with a
        fully opaque alpha channel instead of raising IndexError.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    _, threshed = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)

    # Kernel ~2% of the longest side, at least 1 px so tiny images still work.
    kernel_size = max(1, round(max(img.shape[0], img.shape[1]) * 0.02))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    morphed = cv2.morphologyEx(threshed, cv2.MORPH_CLOSE, kernel)

    found = cv2.findContours(morphed,
                             cv2.RETR_EXTERNAL,
                             cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
    cnts = found[0] if len(found) == 2 else found[1]

    if len(cnts) == 0:
        r, g, b = cv2.split(img)
        opaque = np.full(gray.shape, 255, dtype=gray.dtype)
        return cv2.merge([r, g, b, opaque])

    largest = sorted(cnts, key=cv2.contourArea)[-1]

    # drawContours fills the largest contour into `threshed` in place and returns it.
    mask = cv2.drawContours(threshed, [largest], 0, [255], cv2.FILLED)
    masked_data = cv2.bitwise_and(img, img, mask=mask)

    x, y, w, h = cv2.boundingRect(largest)
    dst = masked_data[y: y + h, x: x + w]

    alpha = mask[y: y + h, x: x + w]
    r, g, b = cv2.split(dst)

    # The second positional argument of cv2.merge is the output array, so no extra args.
    return cv2.merge([r, g, b, alpha])

# Show a random product before and after background removal (background threshold 250).
n = np.random.randint(products.shape[0])
print(f'Index: {n}')
print(f'Class: {class_name(products[n].class_index)}')
plot_grid([products[n].image, remove_background(products[n].image, 250)], 2, show_axis=True)

### Image resize

In [None]:
def resize_image(img, size, color=(0, 0, 0, 0)):
  """Letterbox `img` into `size` while preserving its aspect ratio.

  The image is scaled to fit entirely inside the target and then padded
  symmetrically (any odd extra pixel goes to the bottom/right) with `color`.

  Args:
    img: image array of shape (H, W) or (H, W, C).
    size: target (width, height) in pixels.
    color: border value for `cv2.copyMakeBorder`; the default is transparent
      black for RGBA images. A tuple, so the default cannot be mutated.

  Returns:
    Array with height/width equal to `size` (channels preserved).
  """
  target_w, target_h = size
  # Only height/width are needed, so single-channel images work too.
  original_h, original_w = img.shape[:2]
  target_ar = target_w / target_h
  original_ar = original_w / original_h

  scale_factor = target_h / original_h if target_ar > original_ar else target_w / original_w
  # At least 1 px so very thin inputs do not make cv2.resize fail on a 0 size.
  scaled_w = max(1, round(original_w * scale_factor))
  scaled_h = max(1, round(original_h * scale_factor))
  resized = cv2.resize(img, (scaled_w, scaled_h))

  delta_h = target_h - scaled_h
  delta_w = target_w - scaled_w
  top    = delta_h // 2
  left   = delta_w // 2
  bottom = delta_h - top
  right  = delta_w - left

  return cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

# Show a random product (background removed) before and after letterboxing to 400x400.
n = np.random.randint(products.shape[0])
image = products[n].image
image = remove_background(image, 250)
plot_grid([image, resize_image(image, (400, 400))], 2, show_axis=True)

## Dataset preparation



### Image cleaning

In [None]:
# (width, height) every cleaned image is letterboxed to.
size = (224, 224)

def clean_image(img):
  """Strip the white background (threshold 250) and letterbox the result to `size`."""
  return resize_image(remove_background(img, 250), size)

### Export the new dataset to disk

The cleaned images are written to disk so that they do not all have to be kept in memory.

In [None]:
# Number of products randomly sampled for export when `reduce_dataset` is True.
number_of_images_to_save = 6400
# When False, every loaded product is exported.
reduce_dataset = True

In [None]:
def get_class_from_sub_classes(complete_class):
  """Map a '/'-separated class path to a coarser training class.

  'Food' and 'HouseProducts' are split further into their second-level
  category; every other top-level class is used as-is. A 'Food' or
  'HouseProducts' path with no second level falls back to the top-level
  name instead of raising IndexError.
  """
  sub_classes = complete_class.split('/')
  top_level_class = sub_classes[0]
  if top_level_class in ('Food', 'HouseProducts') and len(sub_classes) > 1:
    return sub_classes[1]
  return top_level_class

In [None]:
# Each exported product gets its own sub-directory, i.e. its own class for
# `flow_from_directory`; the image-search pairs below compare an image with an
# augmented copy of itself (same class) or of another image (different class).
train_data_directory = 'Temp'
shutil.rmtree(train_data_directory, ignore_errors=True)

if reduce_dataset:
  # Never request more samples than exist: choice(..., replace=False) would raise.
  products_size = min(number_of_images_to_save, products.shape[0])
  products_to_dump = np.random.choice(products.shape[0], products_size, replace = False)
else:
  products_size = products.shape[0]
  products_to_dump = np.arange(products_size)

for index, (image, class_index) in tqdm(enumerate(products[products_to_dump]), total=products_size, desc='Writing files...'):
  output_dir = os.path.join(train_data_directory, str(index))
  Path(output_dir).mkdir(parents=True, exist_ok=True)
  # clean_image returns RGBA; OpenCV writes BGRA, so convert before saving.
  out = cv2.cvtColor(clean_image(image), cv2.COLOR_RGBA2BGRA)
  output_path = os.path.join(output_dir, f'{index}.png')
  if not cv2.imwrite(output_path, out):
    raise IOError(f'Could not write {output_path}')


## Data Augmentation

### 3D rotation

In [None]:
def image_3D_rotation(img, theta = 0, phi = 0, gamma = 0, dx = 0, dy = 0, dz = 0):
  """
  Rotate an image in 3D around its centre and project it back to 2D.

  A 3x3 homography (2D -> 3D lift, rotation, translation, 3D -> 2D projection)
  is built and applied with `cv2.warpPerspective`; the output has the same
  width and height as the input.

  Parameters:
      img       : the image data as numpy array, indexed (height, width, channels)
      theta     : rotation around the x axis, in degrees
      phi       : rotation around the y axis, in degrees
      gamma     : rotation around the z axis (basically a 2D rotation), in degrees
      dx        : translation along the x axis
      dy        : translation along the y axis
      dz        : translation along the z axis (distance to the image).
                  NOTE(review): currently ignored -- it is overwritten with the
                  computed focal length before use.
  Output:
      image     : the rotated image
  
  Reference:
      1.        : http://stackoverflow.com/questions/17087446/how-to-calculate-perspective-transform-for-opencv-from-rotation-angles
      2.        : http://jepsonsblog.blogspot.tw/2012/11/rotation-in-3d-using-opencvs.html
      3.        : Code taken from https://github.com/eborboihuc/rotate_3d/blob/master/image_transformer.py
  """
  def deg_to_rad(deg):
    return deg * math.pi / 180.0
  def get_M(theta, phi, gamma, dx, dy, dz, size, focal):
    # Angles here are already in radians; size is (width, height).
    w = size[0]
    h = size[1]
    f = focal
    # Projection 2D -> 3D matrix: maps (x, y, 1) to the centred point (x - w/2, y - h/2, 1, 1)
    A1 = np.array([ [1, 0, -w/2],
                    [0, 1, -h/2],
                    [0, 0, 1],
                    [0, 0, 1]])
    # Rotation matrices around the X, Y, and Z axis
    RX = np.array([ [1, 0, 0, 0],
                    [0, np.cos(theta), -np.sin(theta), 0],
                    [0, np.sin(theta), np.cos(theta), 0],
                    [0, 0, 0, 1]])
    RY = np.array([ [np.cos(phi), 0, -np.sin(phi), 0],
                    [0, 1, 0, 0],
                    [np.sin(phi), 0, np.cos(phi), 0],
                    [0, 0, 0, 1]])
    RZ = np.array([ [np.cos(gamma), -np.sin(gamma), 0, 0],
                    [np.sin(gamma), np.cos(gamma), 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]])
    # Composed rotation matrix with (RX, RY, RZ)
    R = np.dot(np.dot(RX, RY), RZ)
    # Translation matrix
    T = np.array([  [1, 0, 0, dx],
                    [0, 1, 0, dy],
                    [0, 0, 1, dz],
                    [0, 0, 0, 1]])
    # Projection 3D -> 2D matrix (pinhole with focal f, re-centred at (w/2, h/2))
    A2 = np.array([ [f, 0, w/2, 0],
                    [0, f, h/2, 0],
                    [0, 0, 1, 0]])
    # Final transformation matrix: (3x4)(4x4)(4x4)(4x3) -> 3x3 homography
    return np.dot(A2, np.dot(T, np.dot(R, A1)))
  height = img.shape[0]
  width = img.shape[1]
  # NOTE(review): unused.
  num_channels = img.shape[2]
  rtheta = deg_to_rad(theta)
  rphi = deg_to_rad(phi)
  rgamma = deg_to_rad(gamma)
  # Image diagonal length in pixels.
  d = np.sqrt(height**2 + width**2)
  # NOTE(review): focal length depends on gamma only; it falls back to d when
  # sin(gamma) == 0 and is negative for negative gamma. Confirm this is intended.
  focal = d / (2 * np.sin(rgamma) if np.sin(rgamma) != 0 else 1)
  # Place the image plane at the focal distance (overrides the `dz` argument).
  dz = focal
  mat = get_M(rtheta, rphi, rgamma, dx, dy, dz, (width, height), focal)
  # Pixels mapped from outside the source use warpPerspective's default constant border.
  return cv2.warpPerspective(img.copy(), mat, (width, height))

def random_spatial_rotation(theta_range, phi_range, gamma_range):
  """Build an augmentation callable that applies a random 3D rotation.

  Each range is an inclusive (low, high) pair of integer degrees; fresh
  angles are drawn on every call of the returned function.
  """
  def rotate(img):
    # Draw order (theta, phi, gamma) is kept so seeded runs produce the same angles.
    theta = np.random.randint(theta_range[0], theta_range[1] + 1)
    phi = np.random.randint(phi_range[0], phi_range[1] + 1)
    gamma = np.random.randint(gamma_range[0], gamma_range[1] + 1)
    return image_3D_rotation(img, theta=theta, phi=phi, gamma=gamma)
  return rotate

### Data generator parameters definition

In [None]:
# Plain generator: yields the exported images unchanged (no augmentation).
real_datagen = ImageDataGenerator(
    data_format = 'channels_last',
)

# Augmenting generator; `flow_to_tuple` uses its `preprocessing_function` and
# `random_transform` to build distorted copies of an image.
augmented_datagen = ImageDataGenerator(
    brightness_range = [0.5, 1.2],
    # Integer shift ranges are interpreted by Keras as pixels (224 // 10 = 22 px).
    width_shift_range = size[0] // 10,
    height_shift_range = size[1] // 10,
    zoom_range = 0.1,
    # Pixels exposed by shifts/zoom are filled with 0 in every channel.
    fill_mode = 'constant',
    cval = 0,
    data_format = 'channels_last',
    # Random 3D rotation with integer-degree angles drawn per image.
    preprocessing_function = random_spatial_rotation(
      theta_range = (-20, 20),
      phi_range = (-30, 30),
      gamma_range = (-10, 10)
    )
)

def create_flow():
  """Return a shuffled, un-augmented flow over `train_data_directory`.

  Each batch holds a single RGBA image; labels are sparse class indices, one
  class per sub-directory.
  """
  flow_options = dict(
    directory = train_data_directory,
    target_size = size,
    color_mode = 'rgba',
    class_mode = 'sparse',
    batch_size = 1,
    shuffle = True,
  )
  return real_datagen.flow_from_directory(**flow_options)

def flow_to_tuple(t):
  """Turn a single-image flow batch into (original, augmented, label).

  Both images are divided by 255; the augmented copy gets the random 3D
  rotation first and then Keras' random transform.
  """
  scale = 1. / 255
  image = t[0][0]
  original = image * scale
  rotated = augmented_datagen.preprocessing_function(image)
  augmented = augmented_datagen.random_transform(rotated)
  return original, augmented * scale, t[1][0]

def mismatched_images():
  """Endlessly yield (image, augmented image of a different class) pairs.

  Two independent shuffled flows advance in lock-step; whenever their labels
  differ, the first image and an augmented copy of the second are yielded,
  both divided by 255.
  """
  first_flow, second_flow = create_flow(), create_flow()
  while True:
    batch_a = next(first_flow)
    batch_b = next(second_flow)
    if batch_a[1][0] == batch_b[1][0]:
      continue
    augmented = flow_to_tuple(batch_b)[1]
    yield batch_a[0][0] / 255, augmented

def matched_images():
  """Endlessly yield (image, augmented copy of the same image) pairs."""
  flow = create_flow()
  while True:
    original, transformed, _ = flow_to_tuple(next(flow))
    yield original, transformed

# Preview 10 matched pairs (image, augmented copy of itself) followed by
# 10 mismatched pairs (image, augmented image of a different class).
res = []
it = matched_images()
for _ in range(10):
  t = next(it)
  res.append(t[0])
  res.append(t[1])

it = mismatched_images()
for _ in range(10):
  t = next(it)
  res.append(t[0])
  res.append(t[1])

plot_grid(res, 2, labels=["First", "Second"])
# Free the preview images and generators.
del res, it

In [None]:
def probability_merge(it_1, it_2, p = 0.5):
  """Endlessly interleave two pair iterators at random.

  With probability `p` the next pair comes from `it_1` and is labelled 1;
  otherwise it comes from `it_2` and is labelled 0. Yields
  (original, transformed, label).
  """
  while True:
    from_first = np.random.random() < p
    source = it_1 if from_first else it_2
    original, transformed = next(source)
    yield original, transformed, int(from_first)

# Preview 20 pairs from the random mix of mismatched (label 1) and matched (label 0) pairs.
it = probability_merge(mismatched_images(), matched_images(), p = 0.5)
res = []
for _ in range(20):
  t = next(it)
  res.append(t[0])
  res.append(t[1])

plot_grid(res, 2, labels=["First", "Second"])
# Free the preview images and generator.
del res, it

## Model definition

### Performance Evaluation

In [None]:
def _plot_pairs(pairs, label):
  """Plot (original, transformed) pairs side by side; just report when there are none."""
  if not pairs:
    print(f'No {label.lower()}s.')
    return
  images = [image for pair in pairs for image in pair]
  plot_grid(images, 2, labels=['Original', label])

def accuracy(extract_features, classify_features, n = 100):
  """Evaluate a pair classifier on `n` random matched/mismatched pairs.

  Pairs come from `probability_merge(mismatched_images(), matched_images())`,
  so label 1 means "different" and label 0 means "same image".
  Prints the confusion matrix (rows = actual, columns = predicted, both
  ordered [0, 1]) and the overall accuracy, then plots the errors: false
  positives (predicted 0 but actually different) and false negatives
  (predicted 1 but actually the same).

  Args:
    extract_features: callable mapping an image to a feature vector.
    classify_features: callable mapping two feature vectors to 0 or 1.
    n: number of pairs to evaluate.
  """
  actual = []
  predicted = []
  false_positives = []
  false_negatives = []
  it = probability_merge(mismatched_images(), matched_images(), p = 0.5)
  for _ in tqdm(range(n), total=n, desc='Calculating accuracy...'):
    original, transformed, label = next(it)
    original_features = extract_features(original)
    transformed_features = extract_features(transformed)
    prediction = classify_features(original_features, transformed_features)
    actual.append(label)
    predicted.append(prediction)
    if label != prediction:
      (false_positives if prediction == 0 else false_negatives).append((original, transformed))
  # Fixed label order keeps the matrix 2x2 even if only one class shows up.
  print(confusion_matrix(actual, predicted, labels=[0, 1]))
  if n > 0:
    correct = sum(a == b for a, b in zip(actual, predicted))
    print(f'Accuracy: {correct / n:.3f}')
  # Empty error lists are skipped: plot_grid cannot lay out zero images.
  _plot_pairs(false_positives, 'False positive')
  _plot_pairs(false_negatives, 'False negative')

### Method 1


#### Feature extractor

In [None]:
def rmse(predictions, targets):
  """Root-mean-square difference between two equally shaped arrays."""
  squared_error = np.square(predictions - targets)
  return np.sqrt(np.mean(squared_error))

def create_feature_extractor():
  """Build an ImageNet-pretrained MobileNetV2 backbone used as an image embedder.

  The classification head is dropped and global max pooling reduces the last
  feature map to one vector per image. Inputs are `(*size, 3)`, i.e. RGB.
  """
  return tf.keras.applications.MobileNetV2(
    input_shape = (*size, 3),
    alpha = 1.0,
    include_top = False,
    weights = 'imagenet',
    pooling = 'max',
  )

feature_extractor = create_feature_extractor()

def extract_features(img):
  """Embed one image with `feature_extractor`, dropping any alpha channel.

  NOTE(review): pixels are fed as-is (callers divide by 255), while MobileNetV2's
  ImageNet weights expect `preprocess_input` scaling; `rmse_threshold` was
  presumably tuned on the current scaling, so changing it needs re-tuning.
  """
  rgb = img[:, :, :3]
  batch = np.stack([rgb])
  return feature_extractor.predict(batch, batch_size=1)

def f(it, n = 1):
  """Print summary statistics of feature RMSE over `n` pairs drawn from `it`."""
  def pair_rmse(pair):
    original, transformed = pair
    return rmse(extract_features(original), extract_features(transformed))

  rmses = [pair_rmse(next(it)) for _ in range(n)]

  for title, statistic in (('Avg:', np.mean), ('Median:', np.median),
                           ('Max:', np.max), ('Min:', np.min), ('Stdev:', np.std)):
    print(title, statistic(rmses))

# Compare RMSE statistics of same-image and different-image pairs
# (presumably used to pick `rmse_threshold` below).
f(matched_images(), 100)
f(mismatched_images(), 100)

#tf.keras.utils.plot_model(feature_extractor, show_shapes=True, show_layer_names=True)

#### Feature classification

In [None]:
# Feature RMSE above this value is classified as "different".
rmse_threshold = 2.311
equal_label = 0
different_label = 1

def classify_features(a, b):
  """Return `different_label` when the features' RMSE exceeds the threshold, else `equal_label`."""
  if rmse(a, b) > rmse_threshold:
    return different_label
  return equal_label

#### Evaluate performance

In [None]:
# Evaluate the RMSE-threshold classifier on 1000 random matched/mismatched pairs.
accuracy(extract_features, classify_features, n = 1000)

## Testing model on stores

# Product Class Detection

## Prepare products class dictionary

In [None]:
classes, raw_classes = read_classes()

def class_name(class_index):
  """Return the label of `class_index`, or None for a negative index."""
  if class_index < 0:
    return None
  return classes[class_index]

## Load training raw images

In [None]:
# Read every training image into memory as (RGB image, class index) records;
# `raw_classes` supplies the per-class folder names.
products = read_training_data(raw_classes)

## Products visualization

In [None]:
def show_products_with_class(indexes, columns, dataset):
  """Plot the dataset rows at `indexes`, each titled with its class label."""
  def draw_product(row, subplot):
    plt.axis('off')
    plt.imshow(row.image)
    subplot.set_title(class_name(row.class_index), fontsize=10)

  dataset_plot_grid(indexes, columns, dataset, draw_product)

# Six random products (indexes may repeat) in a 3-column grid.
show_products_with_class(np.random.randint(0, len(products), 6), 3, products)

## Image preprocessing

### Background removal

In [None]:
# code taken from https://www.kaggle.com/vadbeg/opencv-background-removal and modified

def remove_background(img, threshold):
    """Crop an RGB photo to its largest foreground object and make the rest transparent.

    Pixels whose gray level is above `threshold` count as background. The
    foreground mask is closed with an elliptical kernel, the largest external
    contour is filled, and the image is cropped to that contour's bounding box.

    Args:
        img: RGB image of shape (H, W, 3).
        threshold: gray level above which a pixel is treated as background.

    Returns:
        RGBA image cropped to the object; alpha comes from the filled mask.
        If no foreground contour is found, the whole image is returned with a
        fully opaque alpha channel instead of raising IndexError.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    _, threshed = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)

    # Kernel ~2% of the longest side, at least 1 px so tiny images still work.
    kernel_size = max(1, round(max(img.shape[0], img.shape[1]) * 0.02))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    morphed = cv2.morphologyEx(threshed, cv2.MORPH_CLOSE, kernel)

    found = cv2.findContours(morphed,
                             cv2.RETR_EXTERNAL,
                             cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
    cnts = found[0] if len(found) == 2 else found[1]

    if len(cnts) == 0:
        r, g, b = cv2.split(img)
        opaque = np.full(gray.shape, 255, dtype=gray.dtype)
        return cv2.merge([r, g, b, opaque])

    largest = sorted(cnts, key=cv2.contourArea)[-1]

    # drawContours fills the largest contour into `threshed` in place and returns it.
    mask = cv2.drawContours(threshed, [largest], 0, [255], cv2.FILLED)
    masked_data = cv2.bitwise_and(img, img, mask=mask)

    x, y, w, h = cv2.boundingRect(largest)
    dst = masked_data[y: y + h, x: x + w]

    alpha = mask[y: y + h, x: x + w]
    r, g, b = cv2.split(dst)

    # The second positional argument of cv2.merge is the output array, so no extra args.
    return cv2.merge([r, g, b, alpha])

# Show a random product before and after background removal (background threshold 250).
n = np.random.randint(products.shape[0])
print(f'Index: {n}')
print(f'Class: {class_name(products[n].class_index)}')
plot_grid([products[n].image, remove_background(products[n].image, 250)], 2, show_axis=True)

### Image resize

In [None]:
def resize_image(img, size, color=(0, 0, 0, 0)):
  """Letterbox `img` into `size` while preserving its aspect ratio.

  The image is scaled to fit entirely inside the target and then padded
  symmetrically (any odd extra pixel goes to the bottom/right) with `color`.

  Args:
    img: image array of shape (H, W) or (H, W, C).
    size: target (width, height) in pixels.
    color: border value for `cv2.copyMakeBorder`; the default is transparent
      black for RGBA images. A tuple, so the default cannot be mutated.

  Returns:
    Array with height/width equal to `size` (channels preserved).
  """
  target_w, target_h = size
  # Only height/width are needed, so single-channel images work too.
  original_h, original_w = img.shape[:2]
  target_ar = target_w / target_h
  original_ar = original_w / original_h

  scale_factor = target_h / original_h if target_ar > original_ar else target_w / original_w
  # At least 1 px so very thin inputs do not make cv2.resize fail on a 0 size.
  scaled_w = max(1, round(original_w * scale_factor))
  scaled_h = max(1, round(original_h * scale_factor))
  resized = cv2.resize(img, (scaled_w, scaled_h))

  delta_h = target_h - scaled_h
  delta_w = target_w - scaled_w
  top    = delta_h // 2
  left   = delta_w // 2
  bottom = delta_h - top
  right  = delta_w - left

  return cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

# Show a random product (background removed) before and after letterboxing to 400x400.
n = np.random.randint(products.shape[0])
image = products[n].image
image = remove_background(image, 250)
plot_grid([image, resize_image(image, (400, 400))], 2, show_axis=True)

## Dataset preparation



### Image cleaning

In [None]:
# (width, height) every cleaned image is letterboxed to.
size = (224, 224)

def clean_image(img):
  """Strip the white background (threshold 250) and letterbox the result to `size`."""
  return resize_image(remove_background(img, 250), size)

### Export the new dataset to disk

The cleaned images are written to disk so that they do not all have to be kept in memory.

In [None]:
# Number of products randomly sampled for export when `reduce_dataset` is True.
number_of_images_to_save = 6400
# When False, every loaded product is exported.
reduce_dataset = True

In [None]:
def get_class_from_sub_classes(complete_class):
  """Map a '/'-separated class path to a coarser training class.

  'Food' and 'HouseProducts' are split further into their second-level
  category; every other top-level class is used as-is. A 'Food' or
  'HouseProducts' path with no second level falls back to the top-level
  name instead of raising IndexError.
  """
  sub_classes = complete_class.split('/')
  top_level_class = sub_classes[0]
  if top_level_class in ('Food', 'HouseProducts') and len(sub_classes) > 1:
    return sub_classes[1]
  return top_level_class

In [None]:
# Exported images are grouped into one sub-directory per coarse class
# (see get_class_from_sub_classes), which flow_from_directory uses as labels.
train_data_directory = 'Temp'
shutil.rmtree(train_data_directory, ignore_errors=True)

if reduce_dataset:
  # Never request more samples than exist: choice(..., replace=False) would raise.
  products_size = min(number_of_images_to_save, products.shape[0])
  products_to_dump = np.random.choice(products.shape[0], products_size, replace = False)
else:
  products_size = products.shape[0]
  products_to_dump = np.arange(products_size)

for index, (image, class_index) in tqdm(enumerate(products[products_to_dump]), total=products_size, desc='Writing files...'):
  complete_class = class_name(class_index)
  class_directory = get_class_from_sub_classes(complete_class)
  output_dir = os.path.join(train_data_directory, class_directory)
  Path(output_dir).mkdir(parents=True, exist_ok=True)
  # clean_image returns RGBA; OpenCV writes BGRA, so convert before saving.
  out = cv2.cvtColor(clean_image(image), cv2.COLOR_RGBA2BGRA)
  output_path = os.path.join(output_dir, f'{class_directory}.{index}.png')
  if not cv2.imwrite(output_path, out):
    raise IOError(f'Could not write {output_path}')


## Data Augmentation

### 3D rotation

In [None]:
def image_3D_rotation(img, theta = 0, phi = 0, gamma = 0, dx = 0, dy = 0, dz = 0):
  """
  Rotate an image in 3D around its centre and project it back to 2D.

  A 3x3 homography (2D -> 3D lift, rotation, translation, 3D -> 2D projection)
  is built and applied with `cv2.warpPerspective`; the output has the same
  width and height as the input.

  Parameters:
      img       : the image data as numpy array, indexed (height, width, channels)
      theta     : rotation around the x axis, in degrees
      phi       : rotation around the y axis, in degrees
      gamma     : rotation around the z axis (basically a 2D rotation), in degrees
      dx        : translation along the x axis
      dy        : translation along the y axis
      dz        : translation along the z axis (distance to the image).
                  NOTE(review): currently ignored -- it is overwritten with the
                  computed focal length before use.
  Output:
      image     : the rotated image
  
  Reference:
      1.        : http://stackoverflow.com/questions/17087446/how-to-calculate-perspective-transform-for-opencv-from-rotation-angles
      2.        : http://jepsonsblog.blogspot.tw/2012/11/rotation-in-3d-using-opencvs.html
      3.        : Code taken from https://github.com/eborboihuc/rotate_3d/blob/master/image_transformer.py
  """
  def deg_to_rad(deg):
    return deg * math.pi / 180.0
  def get_M(theta, phi, gamma, dx, dy, dz, size, focal):
    # Angles here are already in radians; size is (width, height).
    w = size[0]
    h = size[1]
    f = focal
    # Projection 2D -> 3D matrix: maps (x, y, 1) to the centred point (x - w/2, y - h/2, 1, 1)
    A1 = np.array([ [1, 0, -w/2],
                    [0, 1, -h/2],
                    [0, 0, 1],
                    [0, 0, 1]])
    # Rotation matrices around the X, Y, and Z axis
    RX = np.array([ [1, 0, 0, 0],
                    [0, np.cos(theta), -np.sin(theta), 0],
                    [0, np.sin(theta), np.cos(theta), 0],
                    [0, 0, 0, 1]])
    RY = np.array([ [np.cos(phi), 0, -np.sin(phi), 0],
                    [0, 1, 0, 0],
                    [np.sin(phi), 0, np.cos(phi), 0],
                    [0, 0, 0, 1]])
    RZ = np.array([ [np.cos(gamma), -np.sin(gamma), 0, 0],
                    [np.sin(gamma), np.cos(gamma), 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]])
    # Composed rotation matrix with (RX, RY, RZ)
    R = np.dot(np.dot(RX, RY), RZ)
    # Translation matrix
    T = np.array([  [1, 0, 0, dx],
                    [0, 1, 0, dy],
                    [0, 0, 1, dz],
                    [0, 0, 0, 1]])
    # Projection 3D -> 2D matrix (pinhole with focal f, re-centred at (w/2, h/2))
    A2 = np.array([ [f, 0, w/2, 0],
                    [0, f, h/2, 0],
                    [0, 0, 1, 0]])
    # Final transformation matrix: (3x4)(4x4)(4x4)(4x3) -> 3x3 homography
    return np.dot(A2, np.dot(T, np.dot(R, A1)))
  height = img.shape[0]
  width = img.shape[1]
  # NOTE(review): unused.
  num_channels = img.shape[2]
  rtheta = deg_to_rad(theta)
  rphi = deg_to_rad(phi)
  rgamma = deg_to_rad(gamma)
  # Image diagonal length in pixels.
  d = np.sqrt(height**2 + width**2)
  # NOTE(review): focal length depends on gamma only; it falls back to d when
  # sin(gamma) == 0 and is negative for negative gamma. Confirm this is intended.
  focal = d / (2 * np.sin(rgamma) if np.sin(rgamma) != 0 else 1)
  # Place the image plane at the focal distance (overrides the `dz` argument).
  dz = focal
  mat = get_M(rtheta, rphi, rgamma, dx, dy, dz, (width, height), focal)
  # Pixels mapped from outside the source use warpPerspective's default constant border.
  return cv2.warpPerspective(img.copy(), mat, (width, height))

def random_spatial_rotation(theta_range, phi_range, gamma_range):
  """Build an augmentation callable that applies a random 3D rotation.

  Each range is an inclusive (low, high) pair of integer degrees; fresh
  angles are drawn on every call of the returned function.
  """
  def rotate(img):
    # Draw order (theta, phi, gamma) is kept so seeded runs produce the same angles.
    theta = np.random.randint(theta_range[0], theta_range[1] + 1)
    phi = np.random.randint(phi_range[0], phi_range[1] + 1)
    gamma = np.random.randint(gamma_range[0], gamma_range[1] + 1)
    return image_3D_rotation(img, theta=theta, phi=phi, gamma=gamma)
  return rotate

### Data generator parameters definition

In [None]:
# Augmentation settings shared by the classification generators below;
# pixels are also rescaled by 1/255 here.
datagen_kwargs = dict(
  brightness_range = [0.5, 1.2],
  # Integer shift ranges are interpreted by Keras as pixels (224 // 8 = 28 px).
  width_shift_range = size[0] // 8,
  height_shift_range = size[1] // 8,
  zoom_range = 0.1,
  # Pixels exposed by shifts/zoom are filled with 0 in every channel.
  fill_mode = 'constant',
  cval = 0,
  data_format = 'channels_last',
  # Random 3D rotation with integer-degree angles drawn per image.
  preprocessing_function = random_spatial_rotation(
        theta_range = (-20, 20),
        phi_range = (-30, 30),
        gamma_range = (-10, 10)
      ),
  rescale = 1.0 / 255,
)

## [1] Training with splitting and augmentation

### Data flow definition

In [None]:
# Shared augmentation, with 20% of the images reserved for the 'validation' subset.
datagen = ImageDataGenerator(validation_split = 0.2, **datagen_kwargs)

In [None]:
# Preview 40 augmented images (RGBA, one per batch) drawn from the whole export directory.
demo_flow = datagen.flow_from_directory(
  directory = train_data_directory,
  target_size = size,
  color_mode = 'rgba',
  class_mode = 'categorical',
  batch_size = 1,
  shuffle = True
)
plot_grid([next(demo_flow)[0][0] for _ in range(40)], 4)
# Discard the preview iterator.
del demo_flow

### Hyperparameters

In [None]:
# MobileNetV2 width multiplier.
alpha = 1.0
# Start from ImageNet-pretrained weights.
weights = 'imagenet'
# Global max pooling of the backbone's last feature map.
pooling = 'max'
# Activation of the classification layer.
activation = 'softmax'
batch_size = 32
epochs = 20
optimizer = 'rmsprop'
# Matches class_mode='categorical' (one-hot labels) in the flows below.
loss = 'categorical_crossentropy'

### Generator definition

In [None]:
def create_flow(subset):
  """Return a shuffled RGB flow over the `subset` ('training' or 'validation') split.

  Both subsets come from the augmenting `datagen`, so validation images are
  augmented as well.
  """
  return datagen.flow_from_directory(
    train_data_directory,
    subset = subset,
    target_size = size,
    color_mode = 'rgb',
    class_mode = 'categorical',
    batch_size = batch_size,
    shuffle = True,
  )

train_flow, validation_flow = create_flow('training'), create_flow('validation')

### Model definition

In [None]:
# MobileNetV2 backbone without its ImageNet head; global pooling turns each
# image into a single feature vector.
base_model = tf.keras.applications.MobileNetV2(
  input_shape = (*size, 3),
  alpha = alpha,
  include_top = False,
  weights = weights,
  pooling = pooling
)

# With include_top=False Keras ignores `classes` and `classifier_activation`,
# so the backbone alone outputs a feature vector rather than per-class scores
# and categorical_crossentropy cannot match it against the one-hot labels.
# Add an explicit classification layer for our classes.
model = tf.keras.Sequential([
  base_model,
  tf.keras.layers.Dense(train_flow.num_classes, activation = activation),
])

model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
#tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)
#tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)

### Training phase

In [None]:
history = model.fit(
  train_flow,
  validation_data = validation_flow,
  # At least one step each, even when a split holds fewer samples than a batch.
  validation_steps = max(1, validation_flow.samples // batch_size),
  epochs = epochs,
  steps_per_epoch = max(1, train_flow.samples // batch_size)
)

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
# The second curve is the validation split, not a held-out test set.
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

## [2] Training with validation drawn from the same images as training

### Data flow definition

In [None]:
# Training images get the full augmentation; validation images are only rescaled by 1/255.
train_datagen = ImageDataGenerator(**datagen_kwargs)
test_datagen = ImageDataGenerator(rescale = 1. / 255)

In [None]:
# Preview 40 augmented training images. This uses this section's `train_datagen`;
# the previous version used section [1]'s `datagen`, so it showed the wrong
# generator and failed when section [1] had not been run.
demo_flow = train_datagen.flow_from_directory(
  directory = train_data_directory,
  target_size = size,
  color_mode = 'rgba',
  class_mode = 'categorical',
  batch_size = 1,
  shuffle = True
)
plot_grid([next(demo_flow)[0][0] for _ in range(40)], 4)
del demo_flow

### Hyperparameters

In [None]:
# MobileNetV2 width multiplier.
alpha = 1.0
# Train from scratch (no pretrained weights).
weights = None
# Has no effect on MobileNetV2 when include_top=True, as in the model below.
pooling = 'max'
# Activation of the built-in classification layer.
activation = 'softmax'
batch_size = 1
epochs = 20
optimizer = 'adam'
# Matches class_mode='categorical' (one-hot labels) in the flows below.
loss = 'categorical_crossentropy'

### Generator definition

In [None]:
def create_flow(datagen):
  """Return a shuffled RGB flow over the whole export directory using `datagen`."""
  flow_options = dict(
    target_size = size,
    color_mode = 'rgb',
    class_mode = 'categorical',
    batch_size = batch_size,
    shuffle = True,
  )
  return datagen.flow_from_directory(train_data_directory, **flow_options)

# Validation reads the very same images as training, just without augmentation.
train_flow = create_flow(train_datagen)
validation_flow = create_flow(test_datagen)

### Model definition

In [None]:
# Full MobileNetV2 with its own classification head, sized to our classes;
# `pooling` has no effect when include_top=True.
model = tf.keras.applications.MobileNetV2(
  input_shape = (*size, 3),
  include_top = True,
  classes = train_flow.num_classes,
  classifier_activation = activation,
  weights = weights,
  alpha = alpha,
  pooling = pooling,
)

model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
#tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)

### Training phase

In [None]:
history = model.fit(
  train_flow,
  validation_data = validation_flow,
  # At least one step each, even when a flow holds fewer samples than a batch.
  validation_steps = max(1, validation_flow.samples // batch_size),
  epochs = epochs,
  steps_per_epoch = max(1, train_flow.samples // batch_size)
)

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
# The second curve is the (un-augmented) validation flow, not a held-out test set.
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()