<a href="https://colab.research.google.com/github/hungpham13/Vietnamese-HTR/blob/main/ScrabbleGAN_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ScrabbleGAN DEMO (TF 2.X)

More information: https://github.com/Nikolai10/scrabble-gan

## Enabling and testing the GPU

First, you'll need to enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- select GPU from the Hardware Accelerator drop-down

Next, confirm that TensorFlow can see the GPU, e.g. by running `tf.config.list_physical_devices('GPU')` in a code cell (it should list at least one device).

## Download Project

In [1]:
# download project
# NOTE(review): this clones the current tip of the `dev` branch, not a pinned
# commit, so the code this notebook runs against can change between runs.
!git clone -b dev https://github.com/Nikolai10/scrabble-gan.git

Cloning into 'scrabble-gan'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (38/38), done.[K
remote: Compressing objects: 100% (22/22), done.[K
remote: Total 69 (delta 26), reused 16 (delta 16), pack-reused 31[K
Unpacking objects: 100% (69/69), done.


## Download Dataset (For Demonstration Purposes Only)

In [3]:
! gdown https://drive.google.com/uc?id=1duoY9gBmx6quHNGWDlQGKIYO2ubsVZ-y

Downloading...
From: https://drive.google.com/uc?id=1duoY9gBmx6quHNGWDlQGKIYO2ubsVZ-y
To: /content/data.zip
100% 1.39G/1.39G [00:21<00:00, 63.6MB/s]


In [6]:
# external users: manually download https://drive.google.com/file/d/1duoY9gBmx6quHNGWDlQGKIYO2ubsVZ-y/view?usp=sharing
# place files as described in https://github.com/Nikolai10/scrabble-gan (Setup)
!mkdir -p /content/scrabble-gan/res/data/iamDB
!unzip /content/data.zip -d /content/scrabble-gan/res/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/scrabble-gan/res/data/iamDB/words/h07/h07-071a/h07-071a-07-07.png  
  inflating: /content/scrabble-gan/res/__MACOSX/data/iamDB/words/h07/h07-071a/._h07-071a-07-07.png  
  inflating: /content/scrabble-gan/res/data/iamDB/words/h07/h07-071a/h07-071a-05-02.png  
  inflating: /content/scrabble-gan/res/__MACOSX/data/iamDB/words/h07/h07-071a/._h07-071a-05-02.png  
  inflating: /content/scrabble-gan/res/data/iamDB/words/h07/h07-071a/h07-071a-05-03.png  
  inflating: /content/scrabble-gan/res/__MACOSX/data/iamDB/words/h07/h07-071a/._h07-071a-05-03.png  
  inflating: /content/scrabble-gan/res/data/iamDB/words/h07/h07-071a/h07-071a-07-06.png  
  inflating: /content/scrabble-gan/res/__MACOSX/data/iamDB/words/h07/h07-071a/._h07-071a-07-06.png  
  inflating: /content/scrabble-gan/res/data/iamDB/words/h07/h07-071a/h07-071a-09-03.png  
  inflating: /content/scrabble-gan/res/__MACOSX/data/iamDB/words/h07/h07-071a/._h

## Visualize some of the images

## Import Libs

In [None]:
!pip install git+https://github.com/tensorflow/docs

In [None]:
import sys
# make the cloned repository importable as the `src` package; this must run
# before the `src.` imports below
sys.path.extend(['/content/scrabble-gan'])

import os
import random

import gin
import numpy as np
import tensorflow as tf
import tensorflow_docs.vis.embed as embed  # provided by the tensorflow/docs pip install cell
import matplotlib.pyplot as plt

from src.bigacgan.arch_ops import spectral_norm
from src.bigacgan.data_utils import load_prepare_data, train, make_gif, load_random_word_list
from src.bigacgan.net_architecture import make_generator, make_discriminator, make_recognizer, make_gan
from src.bigacgan.net_loss import hinge, not_saturating

# register these project functions with gin so the .gin config file parsed
# later can reference them by name
gin.external_configurable(hinge)
gin.external_configurable(not_saturating)
gin.external_configurable(spectral_norm)

# NOTE(review): imported after the gin registrations above; the order is kept
# in case this module depends on them -- confirm before moving it up
from src.dinterface.dinterface import init_reading

## Init Config Params

In [None]:
@gin.configurable
def setup_optimizer(g_lr, d_lr, r_lr, beta_1, beta_2, loss_fn, disc_iters):
    """Create one Adam optimizer per network from gin-bound hyper-parameters.

    The generator, discriminator and recognizer optimizers share ``beta_1`` and
    ``beta_2`` and differ only in learning rate. ``loss_fn`` and ``disc_iters``
    are not used here; they are passed through so that one gin scope configures
    the whole optimisation setup.

    Returns:
        tuple: (generator_optimizer, discriminator_optimizer,
                recognizer_optimizer, loss_fn, disc_iters)
    """
    def _adam(lr):
        # shared betas, per-network learning rate
        return tf.keras.optimizers.Adam(learning_rate=lr, beta_1=beta_1, beta_2=beta_2)

    # created in G, D, R order, exactly like three separate constructor calls
    g_opt, d_opt, r_opt = (_adam(lr) for lr in (g_lr, d_lr, r_lr))
    return g_opt, d_opt, r_opt, loss_fn, disc_iters


@gin.configurable('shared_specs')
def get_shared_specs(epochs, batch_size, latent_dim, embed_y, num_gen, kernel_reg, g_bw_attention, d_bw_attention):
    """Return the hyper-parameters bound in the gin scope ``shared_specs``.

    This is a pure pass-through: gin injects every argument. The tuple order
    below is the order in which callers unpack the values.
    """
    specs = (
        epochs,
        batch_size,
        latent_dim,
        embed_y,
        num_gen,
        kernel_reg,
        g_bw_attention,
        d_bw_attention,
    )
    return specs


@gin.configurable('io')
def setup_io(base_path, checkpoint_dir, gen_imgs_dir, model_dir, raw_dir, read_dir, input_dim, buf_size, n_classes,
             seq_len, char_vec, bucket_size):
    """Resolve the I/O locations and dataset settings bound in the gin scope ``io``.

    Each ``*_dir`` argument is appended to ``base_path`` by plain string
    concatenation, not ``os.path.join``. Any path separator must therefore
    already be part of the configured values.

    Returns:
        tuple: (input_dim, buf_size, n_classes, seq_len, bucket_size,
                ckpt_path, gen_path, m_path, raw_dir, read_dir, char_vec)
    """
    def _under_base(sub_dir):
        return base_path + sub_dir

    ckpt_path, gen_path, m_path, raw_path, read_path = map(
        _under_base, (checkpoint_dir, gen_imgs_dir, model_dir, raw_dir, read_dir))

    return (input_dim, buf_size, n_classes, seq_len, bucket_size,
            ckpt_path, gen_path, m_path, raw_path, read_path, char_vec)

In [None]:
# init params: parse the gin bindings first, because get_shared_specs() and
# setup_io() take all of their arguments from this config file
gin.parse_config_file('/content/scrabble-gan/src/scrabble_gan.gin')

# hyper-parameters shared by the networks (gin scope `shared_specs`)
(epochs, batch_size, latent_dim, embed_y,
 num_gen, kernel_reg, g_bw_attention, d_bw_attention) = get_shared_specs()

# paths and dataset settings (gin scope `io`)
(in_dim, buf_size, n_classes, seq_len, bucket_size,
 ckpt_path, gen_path, m_path, raw_dir, read_dir, char_vec) = setup_io()

## Load and Preprocess Dataset

In [None]:
# Convert the IAM Handwriting (words) dataset to the GAN format. The conversion
# is skipped whenever read_dir already exists, so re-running this cell is cheap.
# NOTE(review): an interrupted conversion also leaves read_dir behind and would
# be skipped next time -- delete read_dir manually in that case.
if not os.path.exists(read_dir):
    print('converting iamDB-Dataset to GAN format...')
    init_reading(raw_dir, read_dir, in_dim, bucket_size)

# words kept in memory; labels for G's word generation are sampled from them
random_words = load_random_word_list(read_dir, bucket_size, char_vec)

# training input pipeline (a python generator, per the original notes)
train_dataset = load_prepare_data(in_dim, batch_size, read_dir, char_vec, bucket_size)

## Build Composite Model

In [None]:
# init generator, discriminator and recognizer
generator = make_generator(latent_dim, in_dim, embed_y, gen_path, kernel_reg, g_bw_attention, n_classes)
discriminator = make_discriminator(gen_path, in_dim, kernel_reg, d_bw_attention)
# NOTE(review): the recognizer gets n_classes + 1 classes -- presumably the extra
# one is a CTC blank label; confirm against make_recognizer
recognizer = make_recognizer(in_dim, seq_len, n_classes + 1, gen_path)

# build composite model (update G through composite model)
gan = make_gan(generator, discriminator, recognizer, gen_path)

# init optimizers for the generator, discriminator and recognizer; setup_optimizer
# also returns loss_fn and disc_iters from its gin bindings
generator_optimizer, discriminator_optimizer, recognizer_optimizer, loss_fn, disc_iters = setup_optimizer()

## Define Checkpoint-Saver (the optimizers were created in the previous cell)


In [None]:
# Checkpoint everything needed to resume training: the three networks and
# their optimizers. Checkpoint files are written under ckpt_path with the
# prefix "ckpt".
checkpoint_prefix = os.path.join(ckpt_path, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    recognizer_optimizer=recognizer_optimizer,
    generator=generator,
    discriminator=discriminator,
    recognizer=recognizer,
)


## Start Training

**Note:** If you use the [free Colab version](https://colab.research.google.com/signup), first reduce the number of epochs (e.g. `shared_specs.epochs = 5` in `src/scrabble_gan.gin`) so that training does not exceed the 12-hour runtime limit.

In [None]:
# Fixed noise and labels, reused over time so that every frame of the animated
# progress GIF renders the same inputs.
seed = tf.random.normal([num_gen, latent_dim])

# bucket drawn from [4, bucket_size - 1]; presumably buckets are indexed by word
# length, so very short words are skipped -- TODO confirm in load_random_word_list
random_bucket_idx = random.randint(4, bucket_size - 1)
labels = np.array(
    [random.choice(random_words[random_bucket_idx]) for _ in range(num_gen)],
    dtype=np.int32,
)

In [None]:
# start training -- every argument is positional, so this order must match
# the signature of src.bigacgan.data_utils.train
train(
    train_dataset, generator, discriminator, recognizer, gan,
    checkpoint, checkpoint_prefix,
    generator_optimizer, discriminator_optimizer, recognizer_optimizer,
    [seed, labels],  # the fixed progress-GIF inputs from the previous cell
    buf_size, batch_size, epochs, m_path,
    latent_dim, gen_path, loss_fn, disc_iters,
    random_words, bucket_size, char_vec,
)

In [None]:
# use imageio to create an animated gif using the images saved during training.
# NOTE(review): the GIF path below is built by plain concatenation, so gen_path is
# assumed to end with a path separator -- confirm against the gin `io` bindings.
make_gif(gen_path)
# kept as the cell's last expression so Jupyter displays its result inline
embed.embed_file(gen_path + 'biggan.gif')

## Run Inference On Your Data

In [None]:
# Run the generator exported after the final training epoch on a user-chosen
# word and plot n_samples handwritten renderings of it (one per noise vector).
path_to_saved_model = '/content/scrabble-gan/res/out/big_ac_gan/model/generator_' + str(epochs)

# number of samples to generate
n_samples = 10
# your sample string
sample_string = 'machinelearning'

if n_samples < 1:
    raise ValueError(f'n_samples must be >= 1, got {n_samples}')
if not sample_string:
    raise ValueError('sample_string must contain at least one character')
# fail early with a clear message instead of the bare ValueError from .index()
unknown_chars = sorted(set(sample_string) - set(char_vec))
if unknown_chars:
    raise ValueError(f'characters {unknown_chars} in {sample_string!r} are not in char_vec')

# load trained model
imported_model = tf.saved_model.load(path_to_saved_model)

# encode the word once, then use one label row per noise vector so the two
# batch dimensions always agree (previously hard-coded to 10 rows)
encoded_word = [char_vec.index(char) for char in sample_string]
fake_labels = np.array([encoded_word] * n_samples, np.int32)
noise = tf.random.normal([n_samples, latent_dim])

# run inference process
predictions = imported_model([noise, fake_labels], training=False)
# transform values into range [0, 1] (presumes generator output lies in [-1, 1])
predictions = (np.asarray(predictions) + 1) / 2.0

# plot all samples in a single figure, one row per sample, shown once
fig, axes = plt.subplots(n_samples, 1, squeeze=False)
fig.suptitle(sample_string)
for ax, image in zip(axes[:, 0], predictions):
    ax.imshow(image[:, :, 0], cmap='gray')
    ax.axis('off')
plt.show()