<a href="https://colab.research.google.com/github/inyong37/Study/blob/master/_Library/Keras/Keras_Example_FSL_Reptile.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Title: Few-Shot learning with Reptile
# Re-Author: Inyong Hwang
# Date: 2022-01-07-Fri.
# Reference: https://keras.io/examples/vision/reptile/
# Date #2: 2022-01-21-Fri.

# Introduction
The Reptile algorithm was developed by OpenAI to perform model-agnostic meta-learning. Specifically, it was designed to learn new tasks quickly with minimal training (few-shot learning). In each of a fixed number of meta-iterations, the model is trained on a mini-batch of never-before-seen data, and the difference between the trained weights and the weights prior to training is used as the update direction for Stochastic Gradient Descent.
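
The update described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not this notebook's training loop: the helper name `reptile_meta_update` and the toy weight values are hypothetical, and the inner SGD training that would produce `new_weights` is left out.

```python
import numpy as np

def reptile_meta_update(old_weights, new_weights, meta_step_size):
    """Move each weight array a fraction of the way toward its task-trained value."""
    return [
        old + meta_step_size * (new - old)
        for old, new in zip(old_weights, new_weights)
    ]

# Toy values (assumed): weights before and after a few inner SGD steps on one task.
old_weights = [np.array([0.0, 1.0]), np.array([2.0])]
new_weights = [np.array([1.0, 3.0]), np.array([0.0])]

updated = reptile_meta_update(old_weights, new_weights, meta_step_size=0.25)
for layer in updated:
    print(layer)
```

With a meta step size of 0.25, each weight moves a quarter of the way from its old value toward its task-trained value, so a single task can never fully overwrite the meta-learned initialization.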

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [3]:
# Define the Hyperparameters
learning_rate = 0.003
meta_step_size = 0.25

inner_batch_size = 25
eval_batch_size = 25

meta_iters = 2000
eval_iters = 5
inner_iters = 4

eval_interval = 1
train_shots = 20
shots = 5
classes = 5

# Prepare the data
The Omniglot dataset contains 1,623 characters taken from 50 different alphabets, with 20 examples of each character. The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the few-shot learning task, k samples (or "shots") are drawn randomly from n randomly chosen classes. The n classes are assigned a new set of temporary numerical labels, which are used to test the model's ability to learn a new task given few examples. In other words, if you are training on 5 classes, your new class labels will be 0, 1, 2, 3, or 4. Omniglot is a great dataset for this task since there are many different classes to draw from, with a reasonable number of samples for each class.
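
The relabeling step can be illustrated on toy data before looking at the full class. The following is a rough sketch, assuming a hypothetical helper `sample_episode` and a made-up dictionary of tiny arrays standing in for Omniglot; it samples classes and images without replacement, which is a choice made for this illustration.

```python
import random
import numpy as np

def sample_episode(data, num_classes, shots, rng):
    """Draw `shots` images from each of `num_classes` distinct classes
    and relabel the classes as 0..num_classes-1."""
    chosen = rng.sample(sorted(data.keys()), k=num_classes)
    images, labels = [], []
    for temp_label, original_label in enumerate(chosen):
        picked = rng.sample(data[original_label], k=shots)
        images.extend(picked)
        labels.extend([temp_label] * shots)
    return np.stack(images), np.array(labels), chosen

# Hypothetical stand-in for Omniglot: 10 classes with 20 tiny "images" each.
toy_data = {
    f"char_{i}": [np.full((2, 2), i, dtype=np.float32) for _ in range(20)]
    for i in range(10)
}

rng = random.Random(0)
images, labels, chosen = sample_episode(toy_data, num_classes=5, shots=3, rng=rng)
print(images.shape)  # (15, 2, 2)
print(labels)        # [0 0 0 1 1 1 2 2 2 3 3 3 4 4 4]
```

The original class names in `chosen` are discarded by the model's point of view: it only ever sees the temporary labels, so every sampled task looks like a fresh 5-way classification problem.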

In [14]:
class Dataset:
  # This class facilitates the creation of a few-shot dataset from the
  # Omniglot dataset that can be sampled from quickly, while also allowing
  # new labels to be created at the same time.
  def __init__(self, training):
    # Download the tfrecord files containing the Omniglot data and convert
    # them to a dataset.
    split = "train" if training else "test"
    ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)
    # Iterate over the dataset to get each individual image and its class,
    # and put that data into a dictionary.
    self.data = {}

    def extraction(image, label):
      # This function will shrink the Omniglot images to the desired size,
      # scale pixel values and convert the RGB image to grayscale
      image = tf.image.convert_image_dtype(image, tf.float32)
      image = tf.image.rgb_to_grayscale(image)
      image = tf.image.resize(image, [28, 28])
      return image, label
    
    for image, label in ds.map(extraction):
      image = image.numpy()
      label = str(label.numpy())
      if label not in self.data:
        self.data[label] = []
      self.data[label].append(image)
    self.labels = list(self.data.keys())
  
  def get_mini_dataset(
      self, batch_size, repetitions, shots, num_classes, split=False
  ):
    temp_labels = np.zeros(shape=(num_classes * shots))
    temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))
    if split:
      test_labels = np.zeros(shape=(num_classes))
      test_images = np.zeros(shape=(num_classes, 28, 28, 1))
    
    # Get a random subset of distinct labels from the entire label set.
    # Sampling without replacement keeps a class from appearing twice in a task.
    label_subset = random.sample(self.labels, k=num_classes)
    for class_idx, class_obj in enumerate(label_subset):
      # Use the enumerated index value as a temporary label for the mini-batch
      # in few-shot learning.
      temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
      # If creating a split dataset for testing, select an extra example from each
      # label to create the test dataset.
      if split:
        test_labels[class_idx] = class_idx
        images_to_split = random.choices(
          self.data[label_subset[class_idx]], k=shots + 1
        )
        test_images[class_idx] = images_to_split[-1]
        temp_images[
          class_idx * shots : (class_idx + 1) * shots
        ] = images_to_split[:-1]
      else:
        # For each index in the randomly selected label_subset, sample the
        # necessary number of images.
        temp_images[
          class_idx * shots : (class_idx + 1) * shots
        ] = random.choices(self.data[label_subset[class_idx]], k=shots)

    dataset = tf.data.Dataset.from_tensor_slices(
      (temp_images.astype(np.float32), temp_labels.astype(np.int32))
    )
    dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)
    if split:
      return dataset, test_images, test_labels
    return dataset

import urllib3

urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

Downloading and preparing dataset omniglot/3.0.0 (download: 17.95 MiB, generated: Unknown size, total: 17.95 MiB) to /root/tensorflow_datasets/omniglot/3.0.0...


Shuffling and writing examples to /root/tensorflow_datasets/omniglot/3.0.0.incompleteGCKKPU/omniglot-train.tfrecord
  0%|          | 0/19280 [00:00<?, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/omniglot/3.0.0.incompleteGCKKPU/omniglot-test.tfrecord
  0%|          | 0/13180 [00:00<?, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/omniglot/3.0.0.incompleteGCKKPU/omniglot-small1.tfrecord
  0%|          | 0/2720 [00:00<?, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/omniglot/3.0.0.incompleteGCKKPU/omniglot-small2.tfrecord
  0%|          | 0/3120 [00:00<?, ? examples/s]

Dataset omniglot downloaded and prepared to /root/tensorflow_datasets/omniglot/3.0.0. Subsequent calls will reuse this data.

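To see how the hyperparameters interact with `get_mini_dataset`, it helps to count the batches one mini dataset yields. The sketch below is plain arithmetic and makes two assumptions that are not shown in this section: that training calls `get_mini_dataset` with a batch size of 25, `inner_iters` as `repetitions`, `train_shots` as `shots`, and `classes` as `num_classes`, and that evaluation uses `eval_iters`, `shots`, and `classes` in the same positions. The helper name `num_inner_batches` is hypothetical.

```python
import math

def num_inner_batches(num_classes, shots, batch_size, repetitions):
    """Count the batches produced by batch(batch_size) followed by repeat(repetitions)."""
    examples = num_classes * shots
    batches_per_pass = math.ceil(examples / batch_size)
    return batches_per_pass * repetitions

# Assumed training call: 5 classes x 20 shots, batch size 25, 4 repetitions.
train_steps = num_inner_batches(num_classes=5, shots=20, batch_size=25, repetitions=4)
print(train_steps)  # 16

# Assumed evaluation call: 5 classes x 5 shots, batch size 25, 5 repetitions.
eval_steps = num_inner_batches(num_classes=5, shots=5, batch_size=25, repetitions=5)
print(eval_steps)  # 5
```

Under these assumptions, one training task yields 16 inner gradient steps, while one evaluation task yields 5 steps, each over the full 25-image support set.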