In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from IPython import display
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time
import tensorflow_datasets as tfds
from sklearn.mixture import GaussianMixture
import random
from tqdm import tqdm
from random import sample
from skimage.morphology import skeletonize
import cv2

Load Omniglot Dataset

In [None]:
size = 105  # Side length (in pixels) every Omniglot image is resized to.
class Dataset:
    """Few-shot sampler over one split of the Omniglot dataset.

    On construction the whole split is read once and grouped by class, so
    that mini-datasets with fresh temporary labels can be drawn quickly.

    Attributes:
        data: dict mapping ``str(label)`` to a list of numpy images of shape
            ``(size, size, 1)`` with float32 values.
        labels: list of all keys of ``data``.
    """

    def __init__(self, training):
        """Load the "train" split when ``training`` is truthy, else "test"."""
        split = "train" if training else "test"
        ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)
        # Group every image under the string form of its class label.
        self.data = {}

        def extraction(image, label):
            # Convert to float32, collapse RGB to a single gray channel and
            # resize to (size, size).
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.rgb_to_grayscale(image)
            image = tf.image.resize(image, [size, size])
            return image, label

        for image, label in ds.map(extraction):
            self.data.setdefault(str(label.numpy()), []).append(image.numpy())
        self.labels = list(self.data.keys())

    @staticmethod
    def _draw_distinct(population, k):
        """Draw ``k`` items without replacement when the population allows it.

        Falls back to sampling with replacement (the previous behaviour) only
        when ``k`` exceeds the population size, so existing calls keep working.
        """
        if k <= len(population):
            return random.sample(population, k)
        return random.choices(population, k=k)

    def get_mini_dataset(
        self, batch_size, repetitions, shots, num_classes, split=False
    ):
        """Build one few-shot episode.

        Args:
            batch_size: batch size of the returned ``tf.data.Dataset``.
            repetitions: how many times the batched dataset is repeated.
            shots: number of training images drawn per class.
            num_classes: number of classes in the episode.
            split: when True, one extra image per class is drawn and returned
                separately as a test set.

        Returns:
            ``dataset`` when ``split`` is False, otherwise the tuple
            ``(dataset, test_images, test_labels)``. Labels are the temporary
            indices ``0 .. num_classes - 1``.
        """
        temp_labels = np.zeros(shape=(num_classes * shots))
        temp_images = np.zeros(shape=(num_classes * shots, size, size, 1))
        if split:
            test_labels = np.zeros(shape=(num_classes))
            test_images = np.zeros(shape=(num_classes, size, size, 1))

        # Classes are drawn WITHOUT replacement: with replacement the same
        # character could appear under two different temporary labels.
        label_subset = self._draw_distinct(self.labels, num_classes)
        for class_idx, class_obj in enumerate(label_subset):
            rows = slice(class_idx * shots, (class_idx + 1) * shots)
            # The enumeration index is the temporary label for this episode.
            temp_labels[rows] = class_idx
            if split:
                # Draw shots + 1 distinct images so the held-out test image is
                # never one of the training shots.
                test_labels[class_idx] = class_idx
                images_to_split = self._draw_distinct(self.data[class_obj], shots + 1)
                test_images[class_idx] = images_to_split[-1]
                temp_images[rows] = images_to_split[:-1]
            else:
                temp_images[rows] = self._draw_distinct(self.data[class_obj], shots)

        dataset = tf.data.Dataset.from_tensor_slices(
            (temp_images.astype(np.float32), temp_labels.astype(np.int32))
        )
        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)
        if split:
            return dataset, test_images, test_labels
        return dataset


import urllib3

# Silence the SSL warnings urllib3 may emit while the dataset is downloaded.
urllib3.disable_warnings()
# Build the training split first, then the test split.
train_dataset, test_dataset = [Dataset(training=flag) for flag in (True, False)]

Downloading and preparing dataset 17.95 MiB (download: 17.95 MiB, generated: Unknown size, total: 17.95 MiB) to /root/tensorflow_datasets/omniglot/3.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Organize Alphabets

In [None]:
# train_dataset / test_dataset are already built in the dataset cell, so they
# are not re-created here (rebuilding re-reads and re-processes both splits).
ds = tfds.load("omniglot", split="train", as_supervised=False)

# Map each alphabet id (as a string) to the labels of all its examples.
label_alphabet = {}
for example in ds:
    alphabet_str = str(example['alphabet'].numpy())
    label_alphabet.setdefault(alphabet_str, []).append(example['label'].numpy())

# Alphabet ids in first-seen order, e.g. ['27', '30', '17', ...].
alphabets = list(label_alphabet)

# One list per alphabet (same order as `alphabets`) holding the sorted, unique
# character labels of that alphabet as strings, matching the keys of
# train_dataset.data.
all_alphabets = [
    [str(label) for label in np.unique(label_alphabet[alphabet_id])]
    for alphabet_id in alphabets
]


Necessary Functions

In [None]:
def plot(xy, labels):
  """Scatter the points in ``xy`` on a new figure, coloured by ``labels``.

  ``xy`` is indexed as ``xy[:, 0]`` (horizontal) and ``xy[:, 1]`` (vertical).
  """
  plt.subplots()
  first_axis, second_axis = xy[:, 0], xy[:, 1]
  plt.scatter(first_axis, second_axis, c=labels, s=40, cmap='viridis')

#use with scale shift
def plot_xy(xy):
  """Scatter ``xy`` on the current axes with both axes fixed to [-10, 35]."""
  axis_limits = (-10, 35)
  plt.ylim(*axis_limits)
  plt.xlim(*axis_limits)
  first_axis, second_axis = xy[:, 0], xy[:, 1]
  plt.scatter(first_axis, second_axis)

#generates scatter samples of the character
def generate(xy, n):
  """Fit an ``n``-component Gaussian mixture to ``xy`` and sample from it.

  Args:
    xy: array of point coordinates the mixture is fitted to.
    n: number of mixture components.

  Returns:
    ``(labels, points)``: the component index of each of the 300 samples and
    the first two coordinates of each sample.
  """
  # `GaussianMixture` is imported by name at the top of the notebook; the
  # `mixture` module itself is never imported, so `mixture.GaussianMixture`
  # raised NameError.
  gen = GaussianMixture(n_components=n).fit(xy)
  sampled_xy, sampled_labels = gen.sample(300)  # 300 is density of pixels
  return sampled_labels, sampled_xy[:, :2]

#outputs two examples of the same class
def character(char_class, example1_index, example2_index):
  """Return the ink coordinates of two examples of the same class.

  Both examples are looked up in ``train_dataset.data[str(char_class)]``,
  expanded to a 3-channel uint8 image and converted with ``pix2cart``.
  """
  def _example_coords(example_index):
    gray = train_dataset.data[str(char_class)][example_index]
    # Replicate the single channel three times, then rescale to 0..255.
    rgb = np.stack((gray[:, :, 0],) * 3, axis=2)
    rgb *= 255
    return pix2cart(np.clip(rgb, 0, 255).astype("uint8"))

  first = _example_coords(example1_index)
  second = _example_coords(example2_index)
  return first, second

#converts generated characters back into images to feed into the neural network
def cart2pix(generated_a, grid_size=28):
  """Rasterise (row, col) coordinates into a square binary image.

  Coordinates are scaled so that the largest coordinate maps to ``grid_size``,
  rounded, clipped to ``[0, grid_size - 1]`` and the matching pixels set to 1.

  Args:
    generated_a: array-like of shape (N, 2).
    grid_size: side length of the output image (default 28, the size the VAE
      in this notebook expects).

  Returns:
    ``(grid_size, grid_size)`` float array with 1 at occupied pixels and 0
    elsewhere. Empty input yields an all-zero image.
  """
  coords = np.asarray(generated_a, dtype=float)
  image = np.zeros((grid_size, grid_size))
  if coords.size == 0:
    return image
  # Largest coordinate over both axes.
  max_x, max_y = np.amax(coords, axis=0)
  extent = max(max_x, max_y)
  # When no coordinate is positive there is nothing to scale (and dividing by
  # extent would divide by zero or flip signs); everything clips to pixel 0.
  scale = grid_size / extent if extent > 0 else 1.0
  pixels = np.round(coords * scale).astype(int)
  pixels = np.clip(pixels, 0, grid_size - 1)
  image[pixels[:, 0], pixels[:, 1]] = 1
  return image


def oneshot():
  """Draw a 10-way one-shot task from ``train_dataset`` and display it.

  Ten distinct classes are sampled; one of them is picked as the query class.
  The query image is example index 1 of the query class, and each comparison
  image is example index 0 of its class. All images are returned as 3-channel
  uint8 arrays. The ten comparison images are shown on a 2x5 grid.

  Returns:
    ``(query, query_class, comparisons, selected_classes)`` where
    ``query_class`` is a one-element list.
  """
  _, axarr = plt.subplots(nrows=2, ncols=5, figsize=(10, 6))

  def _to_uint8_rgb(gray):
    # Replicate the single channel three times, then rescale to 0..255.
    rgb = np.stack((gray[:, :, 0],) * 3, axis=2)
    rgb *= 255
    return np.clip(rgb, 0, 255).astype("uint8")

  # Ten different character classes, then one of them as the query class.
  selected_classes = random.sample(list(train_dataset.data.keys()), 10)
  query_class = random.sample(selected_classes, 1)
  query = _to_uint8_rgb(train_dataset.data[query_class[0]][1])

  # Candidates to compare against the query: first example of every class.
  comparisons = [
      _to_uint8_rgb(train_dataset.data[class_name][0])
      for class_name in selected_classes
  ]

  # Row-major walk over the 2x5 grid, one panel per class.
  for panel, class_name, picture in zip(axarr.flat, selected_classes, comparisons):
    panel.set_title("Class : " + class_name)
    panel.imshow(picture, cmap="gray")
    panel.xaxis.set_visible(True)
    panel.yaxis.set_visible(True)
  plt.show()

  return query, query_class, comparisons, selected_classes



#Scale and shift code
from sklearn.preprocessing import MinMaxScaler
def scale_shift(xy_coords):
  """Min-max scale coordinates to [0, 28] and centre their centroid at (14, 14).

  Args:
    xy_coords: array of shape (N, 2) holding point coordinates.

  Returns:
    (N, 2) float array. Each column is first min-max scaled to [0, 28], then
    the whole cloud is translated so its mean lies at (14, 14); translated
    values may therefore fall outside [0, 28].
  """
  scaler = MinMaxScaler(feature_range=(0, 28), clip=True)
  xy = scaler.fit_transform(xy_coords)[:, :2]
  # Separate x/y targets: the previous code derived the y shift from center_x,
  # which only worked because both targets happen to be 14.
  center = np.array([14.0, 14.0])
  centroid = xy.mean(axis=0)
  return xy + (center - centroid)


def pix2cart(image):
  """Return the unique (row, col) positions of dark pixels in ``image``.

  A pixel counts as dark when its value is below 0.5. For multi-channel
  images only the first two index dimensions are kept, so a position is
  reported once even if several channels are dark there.

  Returns:
    Integer array of shape (K, 2), sorted lexicographically by row then
    column; shape (0, 2) when an image with at least two dimensions has no
    dark pixel.
  """
  coords = np.argwhere(np.asarray(image) < 0.5)[:, :2]
  if coords.shape[0] == 0:
    return coords
  # np.unique(axis=0) de-duplicates rows without a Python-level set of tuples
  # and gives a well-defined ordering.
  return np.unique(coords, axis=0)


def synt_char(image, n_examples, density):
  """Generate ``n_examples`` synthetic point clouds for one character image.

  For every example a Gaussian mixture with a randomly chosen number of
  components (6 to 10, weighted towards 10) is fitted to the character's ink
  coordinates and ``density`` points are sampled from it.

  Args:
    image: character image accepted by ``pix2cart``.
    n_examples: number of synthetic examples to generate.
    density: number of points sampled per example.

  Returns:
    ``(synt_data, labels)``: a list of ``n_examples`` coordinate arrays, and
    the component labels of the LAST sample only (``None`` when
    ``n_examples`` is 0).
  """
  synt_data = []
  labels = None  # stays None when no example is generated
  range_n = [6,7,8,9,10]
  probabilities = [.05,.05,.2,.25,.45]
  # The ink coordinates do not change between examples; compute them once.
  coords = pix2cart(image)
  for _ in range(n_examples):
    current_n = np.random.choice(range_n, p = probabilities)
    gmm = GaussianMixture(n_components=current_n).fit(coords)
    samples, labels = gmm.sample(density)
    synt_data.append(np.array(samples))
  return synt_data,labels #xy coords

def synt_alphabet(alphabet,n_examples,density):
  """Run ``synt_char`` over every image in ``alphabet`` with a progress bar.

  Returns:
    ``(synthetic_characters, labelz)``: all generated examples in one flat
    list (``n_examples`` per input image, in input order), and one entry per
    input image holding the labels ``synt_char`` returned for it.
  """
  synthetic_characters, labelz = [], []
  for char in tqdm(alphabet):
    generated, component_ids = synt_char(char, n_examples, density)
    synthetic_characters += generated
    labelz.append(component_ids)
  return synthetic_characters, labelz


def oneshot_within(n,alphabet,visual = False): # n-way different characters from single alphabet
  """Draw an n-way one-shot task whose classes all come from one alphabet.

  Args:
    n: number of classes in the task.
    alphabet: list of class labels to sample from.
    visual: when truthy, show the comparison images on a grid.

  Returns:
    ``(query, query_class, comparisons, selected_classes)``. ``query`` is
    example index 1 of the randomly chosen query class as a 3-channel uint8
    image; ``query_class`` is a one-element list; ``comparisons`` are the
    images returned by ``n_chars_alphabet``.
  """
  images, labels = n_chars_alphabet(n,alphabet)
  selected_classes = labels
  query_class = random.sample(selected_classes,1)
  query = train_dataset.data[query_class[0]][1]
  query = np.stack((query[:, :, 0],) * 3, axis=2)
  query *= 255
  query = np.clip(query, 0, 255).astype("uint8")
  comparisons = images #the other example candidates to compare against the query

  if visual:
    # The figure is created only after sampling succeeded, so a failure above
    # does not leave an empty figure behind.
    ncols = 5
    if n == 5:
      nrows, figsize = 1, (10, 3)
    elif n == 10:
      nrows, figsize = 2, (10, 6)
    else:
      # At least the previous 4x5 layout; grows for n > 20, which used to
      # raise IndexError.
      nrows = max(4, int(np.ceil(n / ncols)))
      figsize = (10, 2.5 * nrows)
    # squeeze=False keeps axarr 2-D even for a single row, so one indexing
    # scheme covers every layout.
    _, axarr = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    for panel, class_name, picture in zip(axarr.flat, selected_classes, images):
      panel.set_title("Class : " + class_name)
      panel.imshow(picture, cmap="gray")
      panel.xaxis.set_visible(True)
      panel.yaxis.set_visible(True)
    plt.show()

  return query,query_class,comparisons, selected_classes


#Extracts n different characters from the same alphabet
def n_chars_alphabet(n,alphabet):
  """Sample ``n`` distinct classes from ``alphabet`` and fetch one image each.

  This function used to be defined twice with identical bodies; only one
  definition is kept.

  Args:
    n: number of classes to draw (without replacement).
    alphabet: sequence of class labels; ``str(label)`` must be a key of
      ``train_dataset.data``.

  Returns:
    ``(images, labels)``: the first stored example of every drawn class as a
    3-channel uint8 image, and the drawn labels as strings, in the same order.
  """
  labels = sample(alphabet,n)
  images = []
  for label in labels:
    image = train_dataset.data[str(label)][0] #takes first example
    # Replicate the single channel three times, then rescale to 0..255.
    image = np.stack((image[:, :, 0],) * 3, axis=2)
    image *= 255
    image = np.clip(image, 0, 255).astype("uint8")
    images.append(image)

  return images, [str(label) for label in labels]

In [None]:
numgen = 100  # synthetic examples generated per comparison character

query,query_class,comparisons,selected_classes = oneshot()
# 300 points are sampled per synthetic example.
data,labels = synt_alphabet(comparisons,numgen,300)
# Rasterise the raw sampled coordinates into 28x28 images.
compared_characters = [cart2pix(k) for k in data]
# scale_shift operates on (x, y) coordinates, so it is applied to the sampled
# point clouds. It used to be applied to the already rasterised 28x28 images,
# which treated image rows as points and produced meaningless output.
scaled_compared_characters = [scale_shift(k) for k in data]
scaled_char = [cart2pix(k) for k in scaled_compared_characters]
# Both arrays have shape (num_characters * numgen, 28, 28, 1).
sa_normalized = np.reshape(scaled_char,(len(comparisons)*numgen,28,28,1))
sa = np.reshape(compared_characters,(len(comparisons)*numgen,28,28,1))

VAE Components (Sampling Layer, Encoder, Decoder, Model)

In [None]:
class Sampling(layers.Layer):
    """Reparameterisation layer: draws z from (z_mean, z_log_var)."""

    def call(self, inputs):
        """Return ``z_mean + exp(0.5 * z_log_var) * epsilon``.

        ``inputs`` is the pair ``(z_mean, z_log_var)``; ``epsilon`` is standard
        normal noise with the same (batch, dim) shape as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        noise_shape = (tf.shape(z_mean)[0], tf.shape(z_mean)[1])
        epsilon = tf.random.normal(shape=noise_shape)
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * epsilon

In [None]:
#Encoder
# Maps a 28x28x1 image to the parameters (z_mean, z_log_var) of a Gaussian over
# a `latent_dim`-dimensional latent space, plus one sample z drawn from it.

latent_dim = 2  # 2-D latent space, so it can be plotted directly
encoder_inputs = keras.Input(shape=(28, 28, 1))
# Two stride-2 convolutions with "same" padding: 28x28 -> 14x14 -> 7x7.
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
# Two parallel linear heads: mean and log-variance of the latent Gaussian.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Draw z from (z_mean, z_log_var) via the Sampling layer.
z = Sampling()([z_mean, z_log_var])
# The model exposes all three tensors; VAE.train_step unpacks them in this order.
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
#Decoder
# Maps a latent vector of size `latent_dim` back to a 28x28x1 image with
# values in (0, 1).

latent_inputs = keras.Input(shape=(latent_dim,))
# Project to a 7x7x64 feature map.
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
# Two stride-2 transposed convolutions with "same" padding: 7x7 -> 14x14 -> 28x28.
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Single-channel sigmoid output, matching the binary cross-entropy
# reconstruction loss used in VAE.train_step.
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
class VAE(keras.Model):
    """Variational autoencoder wrapping an encoder and a decoder model.

    Training minimises the sum of a binary cross-entropy reconstruction loss
    and the KL divergence term; running means of the total, reconstruction
    and KL losses are tracked as metrics.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means reported by train_step.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """The three loss trackers, in the order total, reconstruction, KL."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the running losses."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Cross-entropy per pixel, summed over both image axes and averaged
            # over the batch.
            per_pixel_bce = keras.losses.binary_crossentropy(data, reconstruction)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(per_pixel_bce, axis=(1, 2))
            )
            # KL term per latent dimension, summed over dimensions and averaged
            # over the batch.
            kl_terms = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_terms, axis=1))
            total_loss = reconstruction_loss + kl_loss
        weights = self.trainable_weights
        grads = tape.gradient(total_loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))
        tracked = (
            (self.total_loss_tracker, total_loss),
            (self.reconstruction_loss_tracker, reconstruction_loss),
            (self.kl_loss_tracker, kl_loss),
        )
        for tracker, value in tracked:
            tracker.update_state(value)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


Prepare Synthetic Data for VAE

In [None]:
# Reuse the generation count from the synthesis cell: a separately hard-coded
# value could silently drift from the number of examples actually generated.
num_examples = numgen

def preprocess_images(images):
  """Cast the image array to float32."""
  return images.astype('float32')

train_images = preprocess_images(sa)
# Class index of every image: `numgen` consecutive images per selected class,
# in the order the classes were synthesised.
train_labels = np.repeat(np.arange(len(selected_classes)), numgen).tolist()
assert len(train_labels) == train_images.shape[0], "labels and images are out of sync"

train_size = len(comparisons) * num_examples
batch_size = 32
test_size = len(comparisons) * num_examples

# Shuffled copy used for (unsupervised) VAE training. train_images and
# train_labels keep their original, aligned order for the latent-space plots.
s_train_images = tf.random.shuffle(train_images)

Train VAE with Synthetic Data

In [None]:
# Wrap the module-level encoder/decoder models in the VAE training model.
# NOTE: the encoder and decoder objects are reused, so re-running this cell
# continues training their existing weights instead of starting fresh.
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))
# Unsupervised training on the shuffled synthetic images (no labels passed).
vae.fit(s_train_images, epochs=50, batch_size=32)

Display how the latent space clusters different classes

In [None]:
def plot_label_clusters(vae, data, labels):
    """Plot the encoder's latent means and log-variances, coloured by label.

    Two figures are shown: the first scatters ``z_mean``, the second
    ``z_log_var``; both use the two latent dimensions as axes.

    Returns:
        ``(xmin, xmax, ymin, ymax)``: the extent of ``z_mean`` along the first
        and second latent dimension.
    """
    z_mean, z_log_var, z = vae.encoder.predict(data)

    # Same scatter layout for both encoder outputs; only the axis prefix differs.
    for values, axis_prefix in ((z_mean, "z"), (z_log_var, "zvar")):
        plt.figure(figsize=(12, 10))
        plt.scatter(values[:, 0], values[:, 1], c=labels)
        plt.colorbar()
        plt.xlabel(axis_prefix + "[0]")
        plt.ylabel(axis_prefix + "[1]")
        plt.show()

    first_dim, second_dim = z_mean[:, 0], z_mean[:, 1]
    return min(first_dim), max(first_dim), min(second_dim), max(second_dim)


xmin, xmax, ymin, ymax = plot_label_clusters(vae, train_images,train_labels)

Plot Omniglot Latent Space

In [None]:
import matplotlib.pyplot as plt


def plot_latent_space(vae, n=7, figsize=15, bounds=None):
    """Display an n x n grid of characters decoded from the 2-D latent space.

    Args:
        vae: model whose ``decoder`` maps (N, 2) latent points to 28x28 images.
        n: number of grid points along each latent axis.
        figsize: side length of the square figure.
        bounds: optional ``(xmin, xmax, ymin, ymax)`` extent of the grid. When
            omitted, the notebook-level ``xmin, xmax, ymin, ymax`` produced by
            ``plot_label_clusters`` are used, as before.
    """
    if bounds is None:
        bounds = (xmin, xmax, ymin, ymax)
    x_lo, x_hi, y_lo, y_hi = bounds

    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # Linearly spaced latent coordinates; y is reversed so larger z[1] values
    # appear at the top of the image.
    grid_x = np.linspace(x_lo, x_hi, n)
    grid_y = np.linspace(y_lo, y_hi, n)[::-1]

    # Decode the whole grid in a single predict call instead of n*n calls.
    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
    decoded = vae.decoder.predict(z_grid)
    for idx in range(n * n):
        i, j = divmod(idx, n)  # row (y index), column (x index)
        figure[
            i * digit_size : (i + 1) * digit_size,
            j * digit_size : (j + 1) * digit_size,
        ] = decoded[idx].reshape(digit_size, digit_size)

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    # Invert so strokes are dark on a light background. The decoder output is
    # in [0, 1] and imshow rescales to the data range, so this renders exactly
    # like the former `28 - figure`.
    plt.imshow(1.0 - figure, cmap="gray")
    plt.show()


plot_latent_space(vae)

Sample a Single Example

Visualize Different Thresholds

In [None]:
def single(x,y,vae,figsize=15):
    """Decode the latent point ``(x, y)`` and show the resulting image.

    Args:
        x, y: coordinates of the point in the 2-D latent space.
        vae: model whose ``decoder`` maps a (1, 2) latent array to an image.
        figsize: accepted for backward compatibility; currently unused.

    Returns:
        The decoded image as a (28, 28) array.
    """
    digit_size = 28
    z_sample = np.array([[x, y]])
    x_decoded = vae.decoder.predict(z_sample)
    digit = x_decoded[0].reshape(digit_size, digit_size)
    plt.imshow(digit,cmap="Greys_r")
    return digit

d = single(1.3,1.4,vae,figsize=15) #example z in (1.3,1.4) of latent space
min_value = np.min(d)
max_value = np.max(d)
# Min-max normalise to [0, 1]; a perfectly flat image maps to all zeros
# instead of dividing by zero.
value_range = max_value - min_value
normalized_image = (d - min_value) / value_range if value_range > 0 else np.zeros_like(d)
# Quantise to integer intensity levels 0..27.
scaled_image = (normalized_image * 27).astype(int)
# Kept under its own name so the loop variable below no longer overwrites the
# range it is iterating over.
thresholds = range(0, 15)

fig, axes = plt.subplots(3, 5, figsize=(10, 5))

# One panel per threshold: binarise, skeletonise, display.
for i, threshold in enumerate(thresholds):
    row, col = divmod(i, 5)

    # Pixels strictly brighter than the threshold count as ink.
    binary_image = (scaled_image > threshold).astype(int)

    # Thin the strokes to one-pixel-wide skeletons.
    skeletonized_image = skeletonize(binary_image)

    # `28 - ...` inverts the display so the skeleton is dark on light.
    axes[row, col].imshow(28 - skeletonized_image, cmap="gray")
    axes[row, col].set_title(f'Threshold: {threshold}')
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()


Example: New Generated Character Given Alphabet

In [None]:
import cv2
import numpy as np
from skimage.morphology import skeletonize
import matplotlib.pyplot as plt


def save_real(alphabet, selected_classes):
    """Save skeletonised 28x28 images of the alphabet's unselected characters.

    For every label in ``alphabet`` that is not in ``selected_classes``, the
    first training example is resized to 28x28, inverted, binarised,
    skeletonised and written to ``real_<label>.png`` in the working directory.
    The list of processed labels is printed first.

    Args:
        alphabet: iterable of class labels (keys of ``train_dataset.data``).
        selected_classes: labels to exclude.
    """
    not_in = [i for i in alphabet if i not in selected_classes]
    print(not_in)
    for i in not_in:
        image = train_dataset.data[i][0]
        image = cv2.resize(image, (28, 28))

        # Stored images are float32 in [0, 1], so invert against 1.0. The
        # former `255 - image` only worked because the min-max normalisation
        # below cancels the offset.
        inverted_image_np = 1.0 - image
        min_value = np.min(inverted_image_np)
        max_value = np.max(inverted_image_np)
        value_range = max_value - min_value
        # A flat image has no strokes; avoid dividing by zero.
        if value_range > 0:
            normalized_image = (inverted_image_np - min_value) / value_range
        else:
            normalized_image = np.zeros_like(inverted_image_np)
        scaled_image = (normalized_image * 27).astype(int)
        binary_image = (scaled_image > 1).astype(int)

        # Skeletonise once and reuse the result for display and saving.
        skeleton_picture = 28 - skeletonize(binary_image)
        plt.imshow(skeleton_picture, cmap='gray')

        # Save the figure with a filename
        filename = f'real_{i}.png'
        plt.imsave(filename, skeleton_picture, cmap='gray')


save_real(all_alphabets[26],selected_classes)

In [None]:
# Plan: spatially represent the latent space.
# 1. Fit a GMM with the specified number of classes (20) in z.
# 2. Look at the probability distribution of where the query falls.
# 3. Use the similarity metric to weight the probabilities.
# - Example probabilities: A (0.3), B (0.5), C (0.2)
# - Sim(Q,A) = 200, Sim(Q,B) = 150, Sim(Q,C) = 20
# - Final scores: A = 200 * 0.3 = 60, B = 150 * 0.5 = 75, C = 20 * 0.2 = 4
# - Winner is B.
# (We might have to transform the similarity scores into probabilities to keep this consistent and meaningful.)
