<a href="https://colab.research.google.com/github/mvenouziou/Project-Attention-Is-What-You-Get/blob/main/bms_molecular_translation_AttentionIsWhatYouGet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention is What You Get

This is my entry into the [Bristol-Myers Squibb Molecular Translation](https://www.kaggle.com/c/bms-molecular-translation)  Kaggle competition.

-----

AUTHOR: 

Mo Venouziou

- *Email: mvenouziou@gmail.com*
- *LinkedIn: www.linkedin.com/in/movenouziou/*

Updates:

 - *Original Posting: June 2, 2021*
 - *07/04/21: fixed kaggle compat issues*
 - *06/25/21: fixed conv decoder layers*
 - *06/21/21: added TPU compatability*
 - *06/17/21: improved training & inference speed. Allows full AIAYN model size on TPU, although my experiments show that limited memory capacity is better served with smaller transformers and larger image encoding steps.*

----

### Our Goal: Predict the "InChI" value of any given chemical compound diagram. 

International Chemical Identifiers ("InChI values") are a standardized encoding to describe chemical compounds. They take the form of a string of letters, numbers and deliminators, often between 100 - 400 characters long. 

The chemical diagrams are provided as PNG files, often of such low quality that it may take a human several seconds to decipher. 

Label length and image quality become a serious challenge here, because we must predict labels for a very large quantity of images. There are 1.6 million images in the test set abd 2.4 million images available in the training set!

In [49]:
"""
# Example (image, target label) pair\n\n'
for val in train_ds.unbatch().take(1):
    print('Example Label:\n', val['InChI'].numpy())
    print('\nCorresponding Image:', plt.imshow(val['image'][:,:,0], cmap='binary'))
### note: load datasets before running this cell
"""

SyntaxError: ignored

## MODEL STRUCTURE: 

**Image CNN + Attention Features encoder --> text Attention + (optional )CNN feature layer decoder.**

This is a hybrid approach with:
 
 - Image Encoder strategy from [*Show, Attend and Tell: Neural Image Caption Generation with Visual Attention*](https://proceedings.mlr.press/v37/xuc15.pdf).  Generate image feature vectors using intermediate layer outputs from a pretrained CNN. (Here I use the more modern EfficientNet model (recommended by [*Darien Schettler*](https://www.kaggle.com/dschettler8845/bms-efficientnetv2-tpu-e2e-pipeline-in-3hrs/notebook)) with final layer before head, fixed weights and a trainable Dense layer for tuning.)
 
 - Transformer encoder-decoder model from [*All You Need is Attention*](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) (Self-attention feature extraction for both encoder and decoder, joint encoder-decoder attention feature interactions, and a dense prediction output block. Includes parameters to control number of encoder / decoder blocks.)

 - ***PLUS*** *(optional):* Decoder Output Blocks placed in Series (not stacked). Increase the number of trainable parameters without adding inference computational complexity, while also allowing decoders to specialize on different regions of the output.
 
 - ***PLUS*** *(optional):* Is attention really all you need? Add a convolutional layer to enhance text features before decoder self-attention to experiment with performance differences with and without extra convolutional layer(s). Use of CNN's in NLP comes from [*Convolutional Sequence to Sequence Learning*](http://proceedings.mlr.press/v70/gehring17a.html.)

 - ***PLUS*** *(optional):* Beam-Search Alternative, an extra decoding layer applied after the full logits prediction has been made. This takes the form of a bidirectional RNN with attention, applied to the full logits sequence. Because a full (initial) prediction has already been made, computations can be parallelized using statefull RNNs. (See more details below.)

----

## NEXT STEPS:

 - Experiment with **"Tokens-to-Token ViT"** in place of the image CNN. (Technique from [*Training Vision Transformers from Scratch on ImageNet*](https://arxiv.org/pdf/2101.11986.pdf)
  
 - Train my **Beam-search Alternative**. 

    - Beam search is a technique to modify model predictions to reflect the (local) maximum likelihood estimate. However, it is *very* local in that computation expense increases quickly with the number of character steps taken into account. This is also a hard-coded algorithm, which is somewhat contrary to the philosophy of deep learning.

    - A *Beam-search Alternative* would be an extra decoding layer applied *after* the full logits prediction has been made. This might be in the form of a stateful, bidirectional RNN that is computationally parallizable because it is applied to the full logits sequence.

    - Need to revamp code to accept main model changes made for TPU support.

 - Treat the number of convolutional layers (decoder feature extraction) and number of decoders places in series (decoder prediction output) as **new hyperparameters** to tune.

 - *6/21/21: TPU Support added* ~~Implement TPU compatability.~~

 - *6/17/21: Increased model size and efficiency. * ~~ Improve code to have capacity for full size model (matching AISYN) with more efficient training and inference speeds for the large dataset. (TPU required. GPU doesn't have enough memory to train such a large model.)~~

----

### CITATIONS

- "Attention is All You Need." 
 - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. NIPS (2017). *https://research.google/pubs/pub46201/*

- "Convolutional Sequence to Sequence Learning."
 
  - Gehring, J., Auli, M., Grangier, D., Yarats, D. & Dauphin, Y.N.. (2017). Convolutional Sequence to Sequence Learning. Proceedings of the 34th International Conference on Machine Learning, in Proceedings of Machine Learning Research 70:1243-1252, *http://proceedings.mlr.press/v70/gehring17a.html.*


- "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks."
 
  - Mingxing Tan, Quoc V. Le (2019). Convolutional Sequence to Sequence Learning. International Conference on Machine Learning. *http://arxiv.org/abs/1905.11946.*


-  "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention."
  -  Xu, K., Ba, J., Kiros, R., Cho, K., Courville, A., Salakhudinov, R., Zemel, R. & Bengio, Y.. (2015). Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Proceedings of the 32nd International Conference on Machine Learning, in Proceedings of Machine Learning Research 37:2048-2057. *http://proceedings.mlr.press/v37/xuc15.html.* 
            

- "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"

  - Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Zihang Jiang, Francis EH Tay, Jiashi Feng, Shuicheng Yan. Preprint (2021). *https://arxiv.org/abs/2101.11986*.

- Tensorflow documentation tutorial "Transformer model for language understanding." I found this after fully completing the model and found the attention mask was incorrect. My use of "tf.linalg.band_part" (only) is due to this tutorial. *www.tensorflow.org/text/tutorials/transformer#masking*

- Special thanks to:

 -  [Darien Schettler](https://www.kaggle.com/dschettler8845/bms-efficientnetv2-tpu-e2e-pipeline-in-3hrs/notebook.) for leading readers to the "Show" and "Attention" papers cited above, using *session.run()* to improve inference speed in distributed settings and providing detailed info on creating TF Records. I also utilized his scores as performance benchmarks, testing against his results and experimenting with models that sometimes included his hyper-parameter choices, most significantly by dropping the encoder attention component altogether. (My model allows this by setting encoder_heads=0.) This work is otherwise derived independently from his.

 - [Qishen Ha Team](https://www.kaggle.com/c/bms-molecular-translation/discussion/243943) for sharing their structure results. I switched to larger input image dimensions based on their success with much larger values than I had previously used. However image and model size was *significantly* smaller than theirs due to Colab memory allowances vs their top of the line hardware.

- It is possible my idea of a Beam Search Alternative is based on a lecture video from DeepLearning.ai's [Deep Learning Specialization](https://www.coursera.org/specializations/deep-learning)  on Coursera.

- **Dataset / Kaggle Competition:** "Bristol-Myers Squibb – Molecular Translation" competition on Kaggle (2021). *https://www.kaggle.com/c/bms-molecular-translation*

----


## Contents

1. [Imports](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=TjuUOVXao__C&line=4&uniqifier=1)
2. [Data Pipeline](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=lrLHKs5Ni7Sz)
3. [Model Layers](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=W0T-u0vZamI8)
    - [InChI Encoding](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=DYApmA2lf1hp&line=1&uniqifier=1)
    - [Image Encoding and Self-Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=FESofcGdEaWF&line=1&uniqifier=1)
    - [Decoder Self-Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=6qFDs9RTjvod&line=1&uniqifier=1)
    - [Joint Encoder-Decoder Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=jP-t1MkKnD5L)
    - [Decoder Head (Prediction Output)](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=38GA7wtNEhqW&line=1&uniqifier=1)
    - [Update Mechanism](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=_2UR1DLljD0S&line=1&uniqifier=1)
4. [Full Model](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=D6GIs3f3rpu0&line=1&uniqifier=1)
5. [Training](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=otxdN02mf1ht&line=1&uniqifier=1)
6. [Inference](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=Sbvzr5rdmjgs&line=5&uniqifier=1)

---

In [1]:
#### PACKAGE IMPORTS ####

# system management
import os
os.environ['TF_ENABLE_ONEDNN.OPTS'] = '1'  # Intel's TF optimization

# TF Model design
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from tensorflow.data import TFRecordDataset
from tensorflow.data.experimental import TFRecordWriter

# Text processing
import re
import string

# metric for Kaggle Competition
!pip install -q leven
from leven import levenshtein

# Kaggle (for TPU)
#from kaggle_datasets import KaggleDatasets

# Visualizations
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

# Tensorboard Profiler
!pip install -U -q tensorboard
!pip install -U -q tensorboard_plugin_profile
!pip install --upgrade -q "cloud-tpu-profiler>=2.3.0"
%load_ext tensorboard

"""
# Debugger
tf.debugging.experimental.enable_dump_debug_info('./logs/', tensor_debug_mode="FULL_HEALTH", 
                                                 circular_buffer_size=-1)
"""

# data management
import numpy as np
import pandas as pd
import itertools

[?25l[K     |██▏                             | 10kB 25.7MB/s eta 0:00:01[K     |████▎                           | 20kB 16.3MB/s eta 0:00:01[K     |██████▍                         | 30kB 10.7MB/s eta 0:00:01[K     |████████▌                       | 40kB 8.8MB/s eta 0:00:01[K     |██████████▋                     | 51kB 5.0MB/s eta 0:00:01[K     |████████████▊                   | 61kB 5.4MB/s eta 0:00:01[K     |██████████████▉                 | 71kB 5.9MB/s eta 0:00:01[K     |█████████████████               | 81kB 6.1MB/s eta 0:00:01[K     |███████████████████             | 92kB 6.5MB/s eta 0:00:01[K     |█████████████████████▏          | 102kB 6.7MB/s eta 0:00:01[K     |███████████████████████▎        | 112kB 6.7MB/s eta 0:00:01[K     |█████████████████████████▍      | 122kB 6.7MB/s eta 0:00:01[K     |███████████████████████████▌    | 133kB 6.7MB/s eta 0:00:01[K     |█████████████████████████████▋  | 143kB 6.7MB/s eta 0:00:01[K     |████████████████████████

## Model parameters

The 'ModelParameters' class manages global hyperparamaters for portability between Colab and Kaggle notebook environments. Once set, all other cells will run on either platform.

On Colab, connection to my personal Google Drive is required, as ModelParameters will extract the dataset from a zip file to the hosted environment. This process may take several minutes. (It would not be difficult for the reader to update the code to point to their own drive and download the zip dataset using the Kaggle API code below.)

In [2]:
""" Kaggle api for download the compressed dataset from Kaggle's servers.

# imports
!pip uninstall -y kaggle
!pip install --upgrade pip
!pip install kaggle==1.5.6

# if needed, download data using '!kaggle competitions download -c bms-molecular-translation'
# then unzip with '! unzip bms-molecular-translation.zip -d datasets'
os.environ['KAGGLE_CONFIG_DIR'] = '/content/gdrive/MyDrive/Kaggle'  # api token location
"""

" Kaggle api for download the compressed dataset from Kaggle's servers.\n\n# imports\n!pip uninstall -y kaggle\n!pip install --upgrade pip\n!pip install kaggle==1.5.6\n\n# if needed, download data using '!kaggle competitions download -c bms-molecular-translation'\n# then unzip with '! unzip bms-molecular-translation.zip -d datasets'\nos.environ['KAGGLE_CONFIG_DIR'] = '/content/gdrive/MyDrive/Kaggle'  # api token location\n"

In [3]:
# check for TPU & initialize
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)

    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))

    STRATEGY = tf.distribute.TPUStrategy(resolver)
    TPU = True
    os.environ["TFHUB_MODEL_LOAD_FORMAT"] = "UNCOMPRESSED"  # for TF Hub models on TPU
    PRECISION_TYPE = 'mixed_bfloat16' 
    #PRECISION_TYPE = 'float32' 

    # extra imports for GCS
    !pip install -q fsspec
    !pip install -q gcsfs
    import fsspec, gcsfs 

except:
    TPU = False
    STRATEGY = tf.distribute.get_strategy()
    PRECISION_TYPE = 'mixed_float16'

# enable mixed precision
tf.keras.mixed_precision.set_global_policy(PRECISION_TYPE)

INFO:tensorflow:Initializing the TPU system: grpc://10.93.81.34:8470


INFO:tensorflow:Initializing the TPU system: grpc://10.93.81.34:8470


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')]
INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


[K     |████████████████████████████████| 122kB 6.7MB/s 
[K     |████████████████████████████████| 1.3MB 8.1MB/s 
[K     |████████████████████████████████| 143kB 36.5MB/s 
[K     |████████████████████████████████| 296kB 36.5MB/s 
[?25h

In [4]:
class ModelParameters:
    def __init__(self, cloud_server='kaggle'):
               
        # universal parameters
        self._batch_size = 16  # used on GPU. TPU batch size increased below
        self._padded_length = 200
        self._image_size = (224, 224)  # shape to process images in data pipeline. Size is restricted by memory constraints.
        self.SOS_string = 'InChI=1S/'  # start of sentence value
        self.EOS_string = '<EOS>'  # end of sentence value
        self._strategy = STRATEGY
        self._precision_type = PRECISION_TYPE
        self._tpu = TPU
        
        # TPU batch size
        if self._tpu:
            # note: utilize steps_per_execution compile parameter to increase TPU throughput
            self._batch_size = 128 * self._strategy.num_replicas_in_sync  

        # File Paths       
        if cloud_server == 'colab':  # Google Colab
            
            # load drive for saving checkpoints
            try:
                from google.colab import drive
                drive.mount('/content/gdrive/') 
            except:
                pass  # drive already mounted
            
            # check for TPU 
            if self._tpu: 

                # TPU file structure (via Kaggle GCS folder)
                self._dataset_dir = 'gs://kds-df3031ee4e277d641d1044cc3e9386a923ca98833b0d51a2575d2932' # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('bms-molecular-translation')
                self._prepared_files_dir = 'gs://kds-96b617b700ddb4d07bc42a47c0a7abfe3a68d4510c459b7cd7b216e6'  # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('periodic-table')
                self._tfrec_dir = 'gs://kds-cf4c41e9d27a61775a28e7b7420f2258a75bd762c3ffcd88c9e83dc2'  # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('bmsshards')
                self._checkpoint_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/checkpoints/'  # gdrive
                self._load_checkpoint_dir = self._checkpoint_dir
                self._csv_save_dir = './'

            else:
                # unzip data
                if not os.path.isdir('/content/bms-molecular-translation'):
                    !unzip -q /content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/bms-molecular-translation.zip -d '/content/bms-molecular-translation'
                
                # file paths
                self._dataset_dir = 'bms-molecular-translation/'
                self._prepared_files_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/'
                self._checkpoint_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/checkpoints/'
                self._load_checkpoint_dir = self._checkpoint_dir
                self._csv_save_dir = self._prepared_files_dir 
                self._tfrec_dir = None
                
        elif cloud_server == 'kaggle': # Kaggle cloud notebook (CPU / GPU)
            from kaggle_datasets import KaggleDatasets
            
            # check for TPU 
            if self._tpu: 
                
                # file paths
                self._dataset_dir = '' #KaggleDatasets().get_gcs_path('bms-molecular-translation')
                self._prepared_files_dir = KaggleDatasets().get_gcs_path('periodic-table')
                self._tfrec_dir = KaggleDatasets().get_gcs_path('bmsshards')
                self._checkpoint_dir = './'
                self._load_checkpoint_dir = './'
                self._csv_save_dir = './'

            # set GPU instance info
            else:  
                # file paths
                self._dataset_dir = '../input/bms-molecular-translation/'
                self._prepared_files_dir = '../input/periodic-table/'
                self._tfrec_dir = '../input/bmsshards/'
                self._checkpoint_dir = './'
                self._load_checkpoint_dir = '../input/k/mvenou/bms-molecular-translation/checkpoints/'
                self._csv_save_dir = './'
                self._tfrec_dir = None

        # common file paths
        self._periodic_table_csv = os.path.join(self._prepared_files_dir, 'periodic_table_elements.csv')
        self._vocab_csv = os.path.join(self._prepared_files_dir, 'vocab.csv')        
        self._test_images_dir = os.path.join(self._dataset_dir, 'test/')
        self._train_images_dir = os.path.join(self._dataset_dir, 'train/')
        self._extra_labels_csv = os.path.join(self._dataset_dir, 'extra_approved_InChIs.csv')
        self._train_labels_csv = os.path.join(self._dataset_dir, 'train_labels.csv')
        self._sample_submission_csv = os.path.join(self._dataset_dir, 'sample_submission.csv')
        
    # functions to access params
    def padded_length(self):
        return self._padded_length
    def mixed_precision(self):
        return self._precision_type
    def tpu(self):
        return self._tpu
    def tfrec_dir(self):
        return self._tfrec_dir
    def cloud_server(self):
        return self._cloud_server
    def strategy(self):
        return self._strategy
    def csv_save_dir(self):
        return self._csv_save_dir
    def train_labels_csv(self):
        return self._train_labels_csv
    def vocab_csv(self):
        return self._vocab_csv
    def periodic_table_csv(self):
        return self._periodic_table_csv
    def batch_size(self):
        return self._batch_size  
    def image_size(self):
        return self._image_size    
    def SOS(self):
        return self.SOS_string
    def EOS(self):
        return self.EOS_string
    def train_images_dir(self):
        return self._train_images_dir
    def test_images_dir(self):
        return self._test_images_dir   
    def checkpoint_dir(self):
        return self._checkpoint_dir
    def load_checkpoint_dir(self):
        return self._load_checkpoint_dir


Initialize Parameter Options

In [5]:
PARAMETERS = ModelParameters(cloud_server='colab')

Mounted at /content/gdrive/


# **Input Pipeline**

Load train labels as DataFrame

In [6]:
if not PARAMETERS.tpu():
    # Load CSV as dataframe
    train_labels_df = pd.read_csv(PARAMETERS.train_labels_csv())
    train_labels_df.head()

### InChI Text Parsing

We split each InChI label into its "vocabulary" of logical subunits, consisting of element abbreviations numbers, common symbols and the required string 'InChI=1S/', which is at the start of every InChI label. We want to narrow down this vocabulary to the smallest set represented in our training data. The functions below provide a system for finding this minimal set, as well as preparing a new CSV file with parsed labels ready to be fed into a tokenizer layer.

(For clarity and to reduce reliance on loading external files, the true code has been commented out and replaced with corresponding hard-coded values.)

In [7]:
def inchi_parsing_regex(parameters=PARAMETERS):
    # regex for spliting on InChi, but preserving chemical element abbreviations and three-digit numbers
    
    # shortcut: hard coded values
    vocab = [parameters.EOS(), parameters.SOS(), '(',
            ')', '+', ',', '-', '/', 'Br', 'B', 'Cl', 'C', 'D', 'F',
            'H', 'I', 'N', 'O', 'P', 'Si', 'S', 'T', 'b', 'c', 'h', 'i',
            'm', 's', 't']
        
    vocab += [str(num) for num in reversed(range(168))]
    vocab = [re.escape(val) for val in vocab]
       
    """ # to create vocab from scratch, use:
    SOS = parameters.SOS()
    EOS = parameters.EOS()
    
    # load list of elements we should search for within InChI strings: 
    periodic_elements = pd.read_csv(PARAMETERS.periodic_table_csv(), header=None)[1].to_list()
    periodic_elements = periodic_elements + [val.lower() for val in periodic_elements] + [val.upper() for val in periodic_elements]
    
    punctuation = list(string.punctuation)
    punctuation = [re.escape(val) for val in punctuation]   # update values with regex escape chars added as needed

    three_dig_nums_list = [str(i) for i in range(1000, -1, -1)]

    vocab = [SOS, EOS] + periodic_elements + three_dig_nums_list + punctuation
    """

    split_elements_regex = rf"({'|'.join(vocab)})"
    
    return split_elements_regex

In [8]:
INCHI_PARSING_REGEX = inchi_parsing_regex()

def parse_InChI(texts, parsing_regex=INCHI_PARSING_REGEX):  
    return ' '.join(re.findall(parsing_regex, texts))


# TF dataset map-compatible version
def parse_InChI_py_fn(texts, parsing_regex=INCHI_PARSING_REGEX):
    def tf_parse_InChI(texts):  
        texts = np.char.array(texts.numpy())
        texts = np.char.decode(texts).tolist()
        texts = tf.constant([parse_InChI(val) for val in texts])
        return tf.squeeze(texts)
    return tf.py_function(func=tf_parse_InChI, inp=[texts], Tout=tf.string)


# extracts filepath from image name
def path_from_image_id(x, root_folder):
    folder_a = tf.strings.substr(x, pos=0, len=1)
    folder_b = tf.strings.substr(x, pos=1, len=1)
    folder_c = tf.strings.substr(x, pos=2, len=1)
    filename =  tf.strings.join([x, '.png'])
    return tf.strings.join([root_folder, folder_a, folder_b, folder_c, filename], separator='/')

Tokenizer

Note: This must be kept outside the model (and used in dataset prep) for TPU compatability

In [9]:
def Tokenizer(parameters):
    """ tokenizes, crops & pads to parameters.padded_length() """

    SOS = parameters.SOS()
    EOS = parameters.EOS()
    padded_length = PARAMETERS.padded_length()
    
    # Create vocabulary for tokenizer
    def create_vocab():       
        hard_coded_vocab = [PARAMETERS.EOS(), PARAMETERS.SOS(), '(',
            ')', '+', ',', '-', '/', 'B', 'Br',  'C', 'Cl', 'D', 'F',
            'H', 'I', 'N', 'O', 'P', 'S', 'Si', 'T', 'b', 'c', 'h', 'i',
            'm', 's', 't']
        
        numbers = [str(num) for num in range(168)]
        
        vocab = hard_coded_vocab + numbers
        
        """
        # get from saved file
        vocab = pd.read_csv(PARAMETERS.vocab_csv())['vocab_value'].to_list()   
        vocab = list(vocab)
        """

        """ 
        # To create from scratch, extract all vocab elements appearing in train set:
        df = pd.read_csv(PARAMETERS.train_labels_csv())  
        seg_len = 250000
        num_breaks = len(df) // seg_len

        vocab = set()
        for i in range(num_breaks):

            df_i =  df['InChI'].iloc[seg_len * i: seg_len * (i+1)]
            texts =  df_i.apply(lambda x: set(parse_InChI(x).split()))
            texts = texts.tolist()

            vocab = vocab.union(*texts)

            print(f'completed {i} / {num_breaks}')

        vocab = list(vocab)
        vocab_df = pd.DataFrame({'vocab_value': vocab})

        # save results
        filename = os.path.join(PARAMETERS.csv_save_dir(), 'vocab.csv')
        vocab_df.to_csv(filename, index=False)
        """
               
        return vocab

    vocab = create_vocab()
    
    # create tokenizer
    tokenizer_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        standardize=None, split=lambda x: tf.strings.split(x, sep=' ', maxsplit=-1), 
        output_mode='int', output_sequence_length=padded_length, vocabulary=vocab)

    # record EOS token
    tokenized_EOS = tokenizer_layer(tf.constant([EOS]))
    
    # create inverse (de-tokenizer)
    inverse_tokenizer = tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=vocab, invert=True)

    return tokenizer_layer, inverse_tokenizer, tokenized_EOS

In [10]:
TOKENIZER_LAYER, INVERSE_TOKENIZER, TOKENIZED_EOS = \
    Tokenizer(parameters=PARAMETERS)

def tokenize_text(w, x, y, z):
    # note: requires batch dim
    y = TOKENIZER_LAYER(y)
    return w, x, y, z

Image Loader

In [11]:
# Image loaders
def load_image(image_path):
    image_path = tf.squeeze(image_path)
    image = keras.layers.Lambda(lambda x: tf.io.read_file(x))(image_path)
    return image   

def decode_image(image, target_size):
    image = keras.layers.Lambda(lambda x: tf.io.decode_image(x, channels=3, expand_animations=False))(image)
    image = keras.layers.experimental.preprocessing.Resizing(*target_size)(image)
    return image    

## Datasets

Here we create efficient tf.data.Dataset train / validation / test sets.

Out data pipeline will read our prepared CSV of (image filename, parsed InChI and standard InChI) tuples. (If this file is not found, it will be created from scratch. This may take several minutes)  Iterating through the list, it will load batches of corresponding images and labels.

Our datasets contain the following information, accessible by dict keys: images, image_id, InChI, parsed_InChI. (The test set uses InChI = parsed_InChI = 'InChI=1S/', the known required stating value for any InChI code.)

In [12]:
def data_generator(image_set, parameters=PARAMETERS, labels_df=None, decode_images=True):
       
    # get global params
    batch_size = parameters.batch_size()
    target_size = parameters.image_size()
    SOS = parameters.SOS()
    EOS = parameters.EOS()
    
    # dataset options
    options = tf.data.Options()
    options.experimental_optimization.autotune_buffers = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.apply_default_optimizations = True
        
    # Train & Validation Datasets
    if image_set in ['train', 'valid']:
        root_folder = parameters.train_images_dir()  # train / valid images
        valid_split = 0.10
        
        # load labels into memory as dataframe
        if labels_df is None:
            labels_df = pd.read_csv(parameters.train_labels_csv())

        # test / train split
        num_valid_samples = int(valid_split * len(labels_df))
        train_df = labels_df.iloc[num_valid_samples: ]  # get train split
        valid_df = labels_df.iloc[: num_valid_samples]  # get validation split

        # shuffle
        train_df = train_df.sample(frac=1)
        valid_df = valid_df.sample(frac=1)

        # load into datasets  # (image_id, InChI)
        train_ds = tf.data.Dataset.from_tensor_slices(train_df.values)
        valid_ds = tf.data.Dataset.from_tensor_slices(valid_df.values)

        train_ds = train_ds.with_options(options)
        valid_ds = valid_ds.with_options(options)

        # update image paths  
        def map_path(x):  # (image_path, image_id, InChI)
            image_id = x[0]
            image_path = path_from_image_id(image_id, root_folder)
            return image_path, x[0], x[1]

        train_ds = train_ds.map(map_path, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_path, num_parallel_calls=tf.data.AUTOTUNE)

        def map_parse(x, y, z):  # (image_path, image_id, InChI)
            parsed_InChI = parse_InChI_py_fn(z)
            return x, y, parsed_InChI, z
   
        train_ds = train_ds.map(map_parse, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_parse, num_parallel_calls=tf.data.AUTOTUNE)
                
        # load images into dataset       
        def open_images(w, x, y, z):
            w = load_image(w)
            return w, x, y, z
        
        train_ds = train_ds.map(open_images, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(open_images, num_parallel_calls=tf.data.AUTOTUNE)    

        # PREFETCH dataset BEFORE decoding images
        train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
        valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)

        def decode(w, x, y, z):
            w = decode_image(w, target_size)
            return w, x, y, z

        if decode_images:
            train_ds = train_ds.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
            valid_ds = valid_ds.map(decode, num_parallel_calls=tf.data.AUTOTUNE)    

        # BATCH dataset AFTER decoding images (required by tf.io)
        # should batch before other pure TF Lambda layer ops
        train_ds = train_ds.batch(batch_size, drop_remainder=True)
        valid_ds = valid_ds.batch(batch_size, drop_remainder=True)
        
        # add extra "EOS" values to end of parsed inchi
        def extend_EOS(w, x, y, z):
            y = tf.strings.join([y, EOS, EOS, EOS, EOS, EOS], separator=' ')
            y = tf.reshape(y, [-1])
            return w, x, y, z

        train_ds = train_ds.map(extend_EOS, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(extend_EOS, num_parallel_calls=tf.data.AUTOTUNE)

        # Tokenize parsed_inchi.  Note: ds must be batched before this step (size=1 is ok) 
        train_ds = train_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)

        # name the elements
        def map_names(w, x, y, z):
            return  {'image': w, 'image_id': x, 'tokenized_InChI': y, 'InChI': z}
        
        train_ds = train_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        
        return train_ds, valid_ds
    
    # Test Dataset
    elif image_set == 'test':

        # note: image resizing and batching done during this loading step
        # other elements must be batched before combining
        image_ds = tf.keras.preprocessing.image_dataset_from_directory(
            directory=parameters.test_images_dir(), labels='inferred', label_mode=None,
            class_names=None, color_mode='grayscale', batch_size=1, 
            image_size=target_size, shuffle=False, seed=None, validation_split=None, 
            subset=None, follow_links=False)

        # set filenames as label and batch
        image_id_ds = tf.data.Dataset.from_tensor_slices(image_ds.file_paths)
        image_id_ds = image_id_ds.map(lambda x: tf.strings.split(x, os.path.sep)[-1],
                                      num_parallel_calls=tf.data.AUTOTUNE)
        
        # prepare images for TF Records creations. 
        # Note: do this step AFTER filenames step
        if decode_images is False:  
            # convert image to raw byte string. Note: cannot have batch dim for encoding
            image_ds = image_ds.unbatch()
            image_ds = image_ds.map(lambda x: tf.cast(x, dtype=tf.uint16))
            image_ds = image_ds.map(lambda image: tf.io.encode_png(image))
            image_ds = image_ds.map(lambda image: tf.io.serialize_tensor(image))
            
        # dataset consisting solely of InChI start 'InChI=1S/'
        inchi_ds = image_id_ds.map(lambda x: tf.constant(SOS, dtype=tf.string),
                                   num_parallel_calls=tf.data.AUTOTUNE)
        
        # merge datasets
        test_ds = tf.data.Dataset.zip((image_ds, image_id_ds, inchi_ds, inchi_ds))
        
        # prefetch
        test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
        test_ds = test_ds.batch(batch_size)

        # Tokenize parsed_inchi.  Note: ds must be batched before this step (size=1 is ok) 
        test_ds = test_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)

        # set key names
        def map_names(w, x, y, z):
            return  {'image': w, 'image_id': x, 'tokenized_InChI': y, 'InChI': z}
        
        test_ds = test_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        
        return test_ds

Create Test, Train and Validation Datasets

In [13]:
if not PARAMETERS.tpu():
    train_ds, valid_ds = data_generator('train', parameters=PARAMETERS, labels_df=train_labels_df, decode_images=True)
    #test_ds = data_generator('test', parameters=PARAMETERS, labels_df=None, decode_images=True)

Examine data shapes

In [14]:
if not PARAMETERS.tpu():

    print('Train DS')
    for val in train_ds.take(1):    
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    print('\nValidation DS')
    for val in valid_ds.take(1):
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    try:
        print('\nTest DS')
        for val in test_ds.take(1):
            print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
                'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)
    except:
        pass

### TF Records Implementation

In [15]:
# Create TF Examples
def make_example(image, image_id, tokenized_InChI, InChI):
    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image.numpy()])  # image provided as raw bytestring
    )
    image_id_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_id.numpy()])
    )
    tokenized_InChI_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(tokenized_InChI).numpy()])
    )
    InChI_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[InChI.numpy()])
    )

    features = tf.train.Features(feature={
        'image': image_feature,
        'image_id': image_id_feature,
        'tokenized_InChI': tokenized_InChI_feature,
        'InChI': InChI_feature
    })
    
    example = tf.train.Example(features=features)

    return example.SerializeToString()


def make_example_py_fn(image, image_id, InChI, tokenized_InChI):
    return tf.py_function(func=make_example, 
                   inp=[image, image_id, InChI, tokenized_InChI], 
                   Tout=tf.string)


# Decode TF Examples
def decode_example(example, parameters=PARAMETERS):        
    feature_description = {'image': tf.io.FixedLenFeature([], tf.string),
                           'image_id': tf.io.FixedLenFeature([], tf.string),
                           'tokenized_InChI': tf.io.FixedLenFeature([], tf.string),
                           'InChI': tf.io.FixedLenFeature([], tf.string)}
    
    values = tf.io.parse_single_example(example, feature_description)
    
    
    values['image'] = decode_image(values['image'], parameters.image_size())
    values['tokenized_InChI'] = tf.io.parse_tensor(values['tokenized_InChI'],
                                                  out_type=tf.int64)
    values['tokenized_InChI'] = tf.cast(values['tokenized_InChI'], tf.int32)
    
    return values

In [16]:
def serialized_dataset_gen(parameters=PARAMETERS, set_type='train', labels_df=None):
    
    if set_type == 'train':
        train_ds, valid_ds = data_generator(image_set='train', 
                                            parameters=parameters, 
                                            labels_df=train_labels_df, 
                                            decode_images=False)  # output images as bytestrings

        train_ds = train_ds.unbatch()
        valid_ds = valid_ds.unbatch()

        # Create TF Examples
        train_ds = train_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                                num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                                num_parallel_calls=tf.data.AUTOTUNE)
        
        return train_ds, valid_ds
    
    else: #test_set:
        test_ds = data_generator(image_set='test', 
                                 parameters=parameters, 
                                 labels_df=None, 
                                 decode_images=False)  # output images as bytestrings
        
        test_ds = test_ds.unbatch()
            
        # Create TF Examples
        test_ds = test_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                              num_parallel_calls=tf.data.AUTOTUNE)
        
        return test_ds

In [17]:
# Create TF Record Shards
"""
NOTE: Changes have been made to the other dataset pipeline functions. 
Test / Revise this for compatability before running.
"""
def create_records(dataset, subset, num_shards):
    
    folder = subset + '_tfrec'
    
    if subset =='train':
        num_samples = int(.9 * len(train_labels_df))    # test / valid split
    elif subset == 'valid':
        num_samples = int(.1 * len(train_labels_df))
    else:
        num_samples = 2000000

    if not os.path.isdir(folder):
        os.mkdir(folder)
        
    for shard_num in range(num_shards):
        
        filename = os.path.join(folder, f'{subset}_shard_{shard_num+1}')
        try:
            this_shard = dataset.skip(shard_num * num_samples//num_shards).take(num_samples//num_shards)
        
            print(f'Writing shard {shard_num+1}/{num_shards} to {filename}')
            writer = tf.data.experimental.TFRecordWriter(filename)
            writer.write(this_shard)
        except:
            break
    return None 
    
# Load dataset from saved TF Record Shards
def dataset_from_records(subset, parameters=PARAMETERS):

    # optimizations
    options = tf.data.Options()
    options.experimental_optimization.autotune_buffers = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.apply_default_optimizations = True

    filepath = os.path.join(parameters.tfrec_dir(), 
                            subset + '_tfrec/*')

    dataset = tf.data.Dataset.list_files(filepath)  # put all tf rec filenames in a ds
    dataset = dataset.shuffle(10**6)
 
    # merge the files
    num_readers = parameters.strategy().num_replicas_in_sync
    dataset = dataset.interleave(tf.data.TFRecordDataset,  
                                 cycle_length=num_readers, block_length=1,
                                 deterministic=False, num_parallel_calls=tf.data.AUTOTUNE)
    
    dataset = dataset.shuffle(10**6)
    
    # decode examples
    dataset = dataset.map(decode_example, num_parallel_calls=tf.data.AUTOTUNE)

    # note: tokenized InChI element spec needs help determining shape
    for val in dataset.take(1):
        padded_length = val['tokenized_InChI'].shape[-1]

    # coerce unknown shape
    dataset = dataset.map(lambda x: {'image':x['image'],
                                     'image_id': x['image_id'],
                                     'tokenized_InChI': tf.reshape(x['tokenized_InChI'], [padded_length]),
                                     'InChI': x['InChI']},
                          num_parallel_calls=tf.data.AUTOTUNE)  

    dataset = dataset.batch(parameters.batch_size(), drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
    return dataset

In [18]:
# To create TF_Records files
# note: can take 8+ hours for train set alone!
"""
serial_train_ds, serial_valid_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='train', labels_df=train_labels_df)
serial_test_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='test', labels_df=None)
"""

"\nserial_train_ds, serial_valid_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='train', labels_df=train_labels_df)\nserial_test_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='test', labels_df=None)\n"

In [19]:
# IF USING TF_RECORDS:
#NOTE: If no dataset loads, check if Kaggle GCS directories have changed. (This happends periodically)
if PARAMETERS.tpu():
    with PARAMETERS.strategy().scope(): 
        train_ds = dataset_from_records('train', parameters=PARAMETERS)
        valid_ds = dataset_from_records('valid', parameters=PARAMETERS)

    print('Train DS')
    for val in train_ds.take(1):    
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    print('\nValidation DS')
    for val in valid_ds.take(1):
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    try:
        print('\nTest DS')
        for val in test_ds.take(1):
            print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
                'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)
    except:
        pass

Train DS
image: (1024, 224, 224, 3) image_id: (1024,) InChI: (1024,) tokenized_InChI: (1024, 200)

Validation DS
image: (1024, 224, 224, 3) image_id: (1024,) InChI: (1024,) tokenized_InChI: (1024, 200)

Test DS


# **Model Layers**

## InChI Encoding

Tokenizer and Embedding to convert parsed InChI strings to tensors of numbers. Includes (optional) masked convolutional layer as an extra preprocessing step

InChI Input Prep Layer

In [20]:
def InChIEncoder(vocab_size, inchi_embedding_dim, num_chars, use_convolutions=False):

    EmbeddingLayer = tf.keras.layers.Embedding(input_dim=vocab_size, 
        output_dim=inchi_embedding_dim, mask_zero=False, input_length=num_chars-1)
    
    inchi = keras.layers.Input([num_chars], name='tokenized_inchi')
    start_var = keras.layers.Input([1, inchi_embedding_dim], name='start_var')
    pos_encoding = keras.layers.Input([num_chars, inchi_embedding_dim], name='positional_encoding')

    inputs = [inchi, start_var, pos_encoding]

    # embedding
    inchi = inchi[:, :-1]  # drop last val
    inchi = EmbeddingLayer(inchi)
    
    # (Optional: masked convolution)
    if use_convolutions:
        
        # extend to (batch, len, len, dim) and mask for parallelized convolutions
        ones = tf.ones((num_chars-1, num_chars-1))
        mask = tf.linalg.band_part(ones, -1, 0)
        mask = tf.reshape(mask, [1, num_chars-1, 1, num_chars-1])

        inchi = tf.tile(tf.expand_dims(inchi, -1), [1, 1, 1, num_chars-1])
        inchi *= mask

        # apply parallel convolutions (maintains independence to mask future steps)
        inchi = keras.layers.DepthwiseConv2D(kernel_size=3, 
                    strides=1, padding='same', data_format='channels_first',
                    activation='relu')(inchi)

        # squeeze out last dim
        inchi = keras.layers.Dense(1, activation='relu')(inchi)
        inchi = tf.squeeze(inchi, axis=-1)

    # append start token
    inchi = keras.layers.Concatenate(axis=1)(
        [tf.cast(start_var, dtype=inchi.dtype), inchi])  
    
    # add positional encoding
    inchi = keras.layers.Add()(
        [tf.cast(pos_encoding, dtype=inchi.dtype), inchi])
    
    outputs = [inchi]

    return keras.Model(inputs, outputs, name='InChIEncoder')

In [21]:
InChIEncoder(vocab_size=199, inchi_embedding_dim=512, num_chars=200, use_convolutions=False).summary()

Model: "InChIEncoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tokenized_inchi (InputLayer)    [(None, 200)]        0                                            
__________________________________________________________________________________________________
start_var (InputLayer)          [(None, 1, 512)]     0                                            
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 199)          0           tokenized_inchi[0][0]            
__________________________________________________________________________________________________
positional_encoding (InputLayer [(None, 200, 512)]   0                                            
_______________________________________________________________________________________

# Image Encoder

Feature Extraction Step 1: Run the images through a pre-trained image network, extracting features as the output of an intermediate convolutional layer. [Technique from "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention," cited at the top of this notebook.]  A dense layer is added for transfer learning and to control the dimension of the attention mechanism used later.

In [22]:
def ImageFeaturesExtractor(image_shape):
    
    # Note: temporarily disable mixed precision during load. (Model doesn't handle it properly)
    tf.keras.mixed_precision.set_global_policy('float32')  # removed mixed precision

    base_transfer_model = keras.applications.EfficientNetB2(
                            include_top=False, 
                            weights=None,
                            input_shape=(*image_shape[:2], 3))
    tf.keras.mixed_precision.set_global_policy('float32')  # removed mixed precision

    # revert to orig mixed precision policy
    tf.keras.mixed_precision.set_global_policy(PARAMETERS.mixed_precision()) 

    model = keras.Model(inputs=base_transfer_model.inputs, 
                        outputs=base_transfer_model.get_layer('top_activation').output, 
                        name='ImageFeaturesExtractor')
    
    return model

In [23]:
def ImageEncoder(image_shape, encoder_dim, ImageFeaturesExtractor):

    num_features = ImageFeaturesExtractor.output_shape[-3] * ImageFeaturesExtractor.output_shape[-2]

    image = keras.layers.Input(image_shape, name='image')
    pos_encoding = keras.layers.Input([num_features, encoder_dim], name='positional_encoding')

    inputs = [image, pos_encoding]
    
    # get features
    image = ImageFeaturesExtractor(image)

    # update dims
    image = keras.layers.Dense(encoder_dim)(image)
    image = keras.layers.Dense(encoder_dim)(image)

    # reshape as feature vectors
    image = keras.layers.Reshape([-1, encoder_dim])(image)

    # add positional encoding
    pos_encoding = tf.cast(pos_encoding, dtype=image.dtype)
    image = keras.layers.Add(name='add_positional_encoding')([image, pos_encoding])

    outputs = [image]

    return keras.Model(inputs, outputs, name='ImageEncoder')

In [24]:
temp_image_shape=(320,320,3)

ImageEncoder(image_shape=temp_image_shape, 
             encoder_dim=256,
             ImageFeaturesExtractor=ImageFeaturesExtractor(temp_image_shape)).summary()

Model: "ImageEncoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image (InputLayer)              [(None, 320, 320, 3) 0                                            
__________________________________________________________________________________________________
ImageFeaturesExtractor (Functio (None, 10, 10, 1408) 7768569     image[0][0]                      
__________________________________________________________________________________________________
dense (Dense)                   (None, 10, 10, 256)  360704      ImageFeaturesExtractor[0][0]     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10, 10, 256)  65792       dense[0][0]                      
_______________________________________________________________________________________

## Encoder Attention

Feature Extraction Step 2: Now that we have basic feature vectors, we use self-attention to generate more complex features. This is the encoding step used in "Attention is All You Need," cited above. 

In [25]:
def EncoderAttention(num_blocks, num_attention_heads, features_dim, num_encoder_features):

    # Inputs
    encoder_features = keras.layers.Input([num_encoder_features, features_dim], name='encoder_features')
    inputs = [encoder_features]

    for i in range(num_blocks):
        # self attention block
        AttentionLayer = keras.layers.MultiHeadAttention(num_heads=num_attention_heads, 
                                       key_dim = features_dim//num_attention_heads, 
                                       dropout=0.1, name=f'SelfAttention_{i}')

        attention_features = AttentionLayer(query=encoder_features, 
                                    value=encoder_features, 
                                    attention_mask=None)
        
        encoder_features = keras.layers.Add()([encoder_features, attention_features])
        encoder_features = keras.layers.LayerNormalization()(encoder_features)

        # feed forward block
        dense_features = keras.layers.Dense(features_dim, activation='relu')(encoder_features)
        dense_features = keras.layers.Dense(features_dim, activation=None)(dense_features)

        dense_features = keras.layers.Dropout(rate=.1)(dense_features)
        encoder_features = keras.layers.Add()([encoder_features, dense_features])
        encoder_features = keras.layers.LayerNormalization()(encoder_features)

    outputs = [encoder_features]

    return keras.Model(inputs, outputs, name='EncoderAttention')

In [26]:
EncoderAttention(num_blocks=1, num_attention_heads=8, features_dim=256, num_encoder_features=100).summary()

Model: "EncoderAttention"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_features (InputLayer)   [(None, 100, 256)]   0                                            
__________________________________________________________________________________________________
SelfAttention_0 (MultiHeadAtten (None, 100, 256)     263168      encoder_features[0][0]           
                                                                 encoder_features[0][0]           
__________________________________________________________________________________________________
add_1 (Add)                     (None, 100, 256)     0           encoder_features[0][0]           
                                                                 SelfAttention_0[0][0]            
___________________________________________________________________________________

## Decoder Attention

Text Feature extraction + Encoder/Decoder Joint Attention interaction. This is implemented as a subclassed model (not Functional API) for extra flexibility in input shapes. This will allow for a quicker inference loop.

In [27]:
class DecoderAttentionBlock(keras.layers.Layer):

    def __init__(self, num_attention_heads, max_steps, **kwargs):
        super().__init__(**kwargs)

        self.num_attention_heads = num_attention_heads

        self.max_steps = max_steps

       

        # mask
        ones = tf.ones([self.max_steps, self.max_steps])
        self.full_self_attention_mask = tf.linalg.band_part(ones, -1, 0)

    def get_config(self):
        config = {'num_attention_heads': self.num_attention_heads,
                  'max_steps': max_steps
        }
        return config

    def build(self, input_shape):
        decoder_shape = input_shape[1]
        self.decoder_dim = decoder_shape[-1]

        # Self Attention Layers
        self.SelfAttentionLayer = keras.layers.MultiHeadAttention(
                                    num_heads=self.num_attention_heads, 
                                    key_dim=self.decoder_dim//self.num_attention_heads,
                                    dropout=0.1,
                                    name=f'SelfAttention')
        self.SelfLayerNorm = keras.layers.LayerNormalization()

        
        # Joint Attention Layers
        self.JointAttentionLayer = keras.layers.MultiHeadAttention(
                                    num_heads=self.num_attention_heads, 
                                    key_dim=self.decoder_dim//self.num_attention_heads,
                                    dropout=0.1,
                                    name=f'JointAttention')
        self.JointLayerNorm = keras.layers.LayerNormalization()

        # Feed Forward Layers                
        self.DenseRelu = keras.layers.Dense(self.decoder_dim, activation='relu')
        self.Dense = keras.layers.Dense(self.decoder_dim, activation=None)
        self.LayerNorm = keras.layers.LayerNormalization()


    # fixed shape call for XLA
    def call(self, inputs, training=False):  
        mask = self.full_self_attention_mask        
        return self.shared_call_steps(inputs=inputs, mask=mask, training=training)

    # variable shape for inference (not XLA compatible)
    def flexible_call(self, inputs, inference_step, training=False):  
        
        encoder_features = inputs[0]
        decoder_features = inputs[1]

        # restrict to features of interest       
        # note: '+1' needed due to indexing used in main function
        decoder_features = decoder_features[:, :inference_step+1, :]  
        inputs = [encoder_features, decoder_features]

        # self-attention mask
        ones = tf.ones([inference_step+1, inference_step+1])
        self_attention_mask = tf.linalg.band_part(ones, -1, 0)
        
        return self.shared_call_steps(inputs=inputs, mask=self_attention_mask, training=training)

    def shared_call_steps(self, inputs, mask, training=False):
        encoder_features = inputs[0]
        decoder_features = inputs[1]


        # Self Attention Block                
        attention_features = self.SelfAttentionLayer(query=decoder_features, 
                                                    value=decoder_features, 
                                                    attention_mask=mask,   
                                                    training=training)

        decoder_features = keras.layers.Add()([decoder_features, attention_features])
        decoder_features = self.SelfLayerNorm (decoder_features, training=training)

        

        # Joint Attention Block
        attention_features = self.JointAttentionLayer(query=decoder_features, 
                                                      value=encoder_features, 
                                                      attention_mask=None,
                                                      training=training)      
        decoder_features = keras.layers.Add()([decoder_features, attention_features])
        decoder_features = self.JointLayerNorm(decoder_features, training=training)

        
        # Feed Forward Block                
        dense_features = self.DenseRelu(decoder_features)
        dense_features = self.Dense(dense_features)
        
        dense_features = keras.layers.Dropout(rate=.1)(dense_features, training=training)
        decoder_features = keras.layers.Add()([decoder_features, dense_features])
        decoder_features = self.LayerNorm(decoder_features, training=training)

        return decoder_features      
        

In [28]:
class DecoderAttention(keras.Model):
    def __init__(self, num_blocks, num_attention_heads, max_steps, name='DecoderAttention', **kwargs):
        super().__init__(name=name, **kwargs)

        self.num_blocks = num_blocks
        self.num_attention_heads = num_attention_heads
        self.max_steps = max_steps

        self.Blocks = []
        for i in range(self.num_blocks):
            block = DecoderAttentionBlock(num_attention_heads=self.num_attention_heads, 
                                          max_steps=self.max_steps, 
                                          name=f'DecoderBlock_{i}')
            self.Blocks.append(block)
            
    def get_config(self):
        config = {'num_blocks': self.num_blocks,
                  'num_attention_heads': self.num_attention_heads,
                  'max_steps': max_steps
        }
        return config 
        
    # fixed shapes for XLA compiler
    def call(self, inputs, training=False):
        encoder_features = inputs[0]
        decoder_features = inputs[1]

        # regularization
        decoder_features = keras.layers.SpatialDropout1D(rate=.05)(decoder_features, training=training)
    
        for i in range(self.num_blocks ):
            decoder_features = self.Blocks[i]([encoder_features, decoder_features], training=training)

        return decoder_features

    # alt call with variable shapes (not XLA compatible). 
    # Provides faster inference on GPU but not yet working on TPU
    def flexible_call(self, inputs, inference_step, training=False):
        encoder_features = inputs[0]
        decoder_features = inputs[1]

        # regularization
        decoder_features = keras.layers.SpatialDropout1D(rate=.05)(decoder_features, training=training)
    
        for i in range(self.num_blocks ):
            decoder_features = self.Blocks[i].flexible_call(inputs=[encoder_features, decoder_features], 
                                              inference_step=inference_step, training=training)

        return decoder_features

In [29]:
temp_dec = DecoderAttention(num_blocks=2, num_attention_heads=2, max_steps=200)
temp_dec([tf.ones((2,100,256)), tf.ones((2,200,512))])
temp_dec.flexible_call([tf.ones((2,100,256)), tf.ones((2,200,512))], inference_step=3)
temp_dec.summary()

Model: "DecoderAttention"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
DecoderBlock_0 (DecoderAtten multiple                  2367488   
_________________________________________________________________
DecoderBlock_1 (DecoderAtten multiple                  2367488   
Total params: 4,734,976
Trainable params: 4,734,976
Non-trainable params: 0
_________________________________________________________________


## Decoder Head (Prediction Output)

This is where we use what was learned in the encoder-decoder attention to output predicted labels. It is the prediction step from "Attention is All You Need."

In [30]:
def DecoderHead(decoder_dim, vocab_size, dual_heads_split_step=None):
   
    # Inputs
    decoder_features = keras.layers.Input([None, decoder_dim], name='decoder_features')
    inputs = [decoder_features]

    # Model
    if dual_heads_split_step is None:
        logits = keras.layers.Dense(vocab_size)(decoder_features)

    else:
        decoder_features_0 = decoder_features[:, :dual_heads_split_step, :]
        decoder_features_1 = decoder_features[:, dual_heads_split_step:, :]

        logits_0 = keras.layers.Dense(vocab_size)(decoder_features_0)
        logits_1 = keras.layers.Dense(vocab_size)(decoder_features_1)

        logits = keras.layers.Concatenate(axis=1)([logits_0, logits_1])
    
    probabilities = keras.layers.Softmax(dtype=tf.float32)(logits)

    outputs = [probabilities]

    return keras.Model(inputs, outputs, name='DecoderHead')

In [31]:
DecoderHead(decoder_dim=512, vocab_size=199, dual_heads_split_step=None).summary()

Model: "DecoderHead"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_features (InputLayer [(None, None, 512)]       0         
_________________________________________________________________
dense_8 (Dense)              (None, None, 199)         102087    
_________________________________________________________________
softmax_4 (Softmax)          (None, None, 199)         0         
Total params: 102,087
Trainable params: 102,087
Non-trainable params: 0
_________________________________________________________________


## Update Mechanism (Optional)

*Note: this is fully coded but I have not had time to train parameters with it. I leave that as a future opportunity for exploration.*

NLP technicques typically output logits to find the highest likelhood token prediction. This can be improved to a (local) maximum likelihood selection using a "beam step" that ay override the initial prediction choice. 

This layer is an alternative system for updating predictions. Unlike "beam," it is trainable and includes longer-range dependencies (instead of the very "local" beam step.) The entire original prediction is passed through a bidirectional RNN (decoder feature extraction) followed by AIAYN stye attention blocks. No masking is needed since we have the full RNN output to work with.

In [32]:
def BeamUpdate(num_beam_blocks, num_att_blocks, num_attention_heads, 
               num_encoder_vectors, encoder_units, 
               decoder_units, max_len, vocab_size, name='BeamUpdate'):
        
    # layers
    # note: GRU doesn't appear to be compatible with reduced precision
    tf.keras.mixed_precision.set_global_policy('float32')  # temporarily disable mixed precision
    BeamUnit = keras.layers.GRU(decoder_units, return_sequences=True, 
                    return_state=True, go_backwards=True,
                    dtype=tf.keras.mixed_precision.Policy('float32'))  
    
    #tf.keras.mixed_precision.set_global_policy(PARAMETERS.mixed_precision())  # reset mixed precision

    BeamDecoderAttention = DecoderAttention(num_blocks=num_att_blocks, 
                                            num_attention_heads=num_attention_heads, 
                                            max_steps=max_len,
                                            name='BeamDecoderAttention')

    BeamDecoderHead = DecoderHead(decoder_dim=decoder_units, 
                                  vocab_size=vocab_size,  
                                  dual_heads_split_step=None)

    # Inputs
    beam_input = keras.layers.Input([max_len, vocab_size], name='beam_input') 
    hidden_state = keras.layers.Input([decoder_units], name='hidden_state')
    encoder_features = keras.layers.Input([num_encoder_vectors, encoder_units], name='encoder_features')   # from image 
    mask = keras.layers.Input([max_len, max_len], name='mask')   # should pass in all 1's, i.e. no masking
    
    inputs = [beam_input, hidden_state, encoder_features, mask]

    # create initial hidden state for RNN dim = decoder_units
    beam_hidden_state = tf.reduce_mean(encoder_features, -2)
    beam_hidden_state = keras.layers.Dense(decoder_units, activation='relu')(beam_hidden_state)
    
    # Decoder encoding using 1 or more Beam layers
    for i in range(num_beam_blocks):
        beam_out, beam_hidden_state = \
            BeamUnit(beam_input, initial_state=[beam_hidden_state])

    # Attention & Prediction (uses "Attention is All You Need" structure)
    decoder_features = BeamDecoderAttention([encoder_features, beam_out, mask])

    probs = BeamDecoderHead([decoder_features])  

    outputs = [probs]

    return keras.Model(inputs, outputs, name=name)


BeamUpdate(num_beam_blocks=1, num_att_blocks=1, num_attention_heads=8, num_encoder_vectors=50, encoder_units=256, 
            decoder_units=512, max_len=200, vocab_size=160, name='BeamUpdate').summary()

Model: "BeamUpdate"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_features (InputLayer)   [(None, 50, 256)]    0                                            
__________________________________________________________________________________________________
tf.math.reduce_mean (TFOpLambda (None, 256)          0           encoder_features[0][0]           
__________________________________________________________________________________________________
beam_input (InputLayer)         [(None, 200, 160)]   0                                            
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 512)          131584      tf.math.reduce_mean[0][0]        
_________________________________________________________________________________________

# **Full Model**

All the components are combined into a full encoder/decoder model. This is implemented using the subclassing API with custom call, train,  evaluation and prediction steps. Once initialized, the models have full access to high-level model.fit(), model.compile() and model.save_weights() methods.

An extra features implemented is having Decoder() elements in *series* (not stacked). This adds more trainable parameters without affecting inference speed, and allows decoders to specialize more on different regions of the text.

BaseTrainer() model has the BeamUpdate mechanism disabled. InchiGenerator() models include the BeamUpdate.

In [33]:
class TrainerModel(keras.Model):

    def __init__(self, encoder_blocks, encoder_dim, decoder_blocks, decoder_dim, 
                 dual_heads_split_step=None, use_convolutions=False, parameters=PARAMETERS, 
                 name='TrainerModel', **kwargs):
        
        super().__init__(name=name, **kwargs)

        self.encoder_blocks = encoder_blocks
        self.encoder_dim = encoder_dim
        self.decoder_blocks = decoder_blocks
        self.decoder_dim = decoder_dim
        self.dual_heads_split_step = dual_heads_split_step
        self.use_convolutions = use_convolutions
        self.parameters = parameters

        tokenizer_layer, self.inverse_tokenizer, self.tokenized_EOS = \
            Tokenizer(parameters=self.parameters)
        self.vocab_size = tokenizer_layer.vocabulary_size()
        self.EOS = parameters.EOS()


    def get_config(self):
        config = {'encoder_blocks': self.encoder_blocks,
                  'encoder_dim': self.encoder_dim,
                  'decoder_blocks':self.decoder_blocks,
                  'decoder_dim': self.decoder_dim,
                  'dual_heads_split_step': self.dual_heads_split_step,
                  'use_convolutions': self.use_convolutions,
                  'parameters':self.parameters,
        }
        return config 


    def build(self, input_shape):

        self.batch_size = input_shape[0][0]
        image_shape = input_shape[0][1:]
        tokenized_inchi_shape = input_shape[1][1:]
        self.padded_length = tokenized_inchi_shape[0]
        
        ###### InChI  ######
        # InChI start variable
        initializer = tf.random_normal_initializer()(shape=[1, 1, self.decoder_dim])
        start_var = tf.Variable(initializer, trainable=True, dtype=tf.float32, name='start_var')
        self.start_var = tf.tile(start_var, [self.batch_size, 1, 1])

        # InChI encoder
        self.InChIEncoder = InChIEncoder(vocab_size=self.vocab_size, 
                                         inchi_embedding_dim=self.decoder_dim, 
                                         num_chars=self.padded_length,
                                         use_convolutions=self.use_convolutions)
        
        # InChI positional encoding variable
        initializer = tf.random_normal_initializer()(
            shape=[1, self.padded_length, self.decoder_dim], dtype=tf.float32)
        positional_encoding_inchi = tf.Variable(initializer, trainable=True, 
                                                     name='positional_encoding_inchi')
        self.positional_encoding_inchi = tf.tile(positional_encoding_inchi, 
                                                 [self.batch_size, 1, 1])

        ###### Image  ######
        self.ImageFeaturesExtractor = ImageFeaturesExtractor(image_shape)
        
        self.ImageEncoder = ImageEncoder(image_shape=image_shape, 
                                         encoder_dim=self.encoder_dim, 
                                         ImageFeaturesExtractor=self.ImageFeaturesExtractor)
        
        self.num_image_features = self.ImageEncoder.output_shape[-2]

        # Image positional encoding variable
        initializer = tf.random_normal_initializer()(
            shape=[1, self.num_image_features, self.encoder_dim], dtype=tf.float32)
        positional_encoding_image = tf.Variable(initializer, trainable=True, 
                                                     name='positional_encoding_image')
        self.positional_encoding_image = tf.tile(positional_encoding_image, 
                                                 [self.batch_size, 1, 1])

        ###### Transformers  ######
        self.EncoderAttention = EncoderAttention(num_blocks=self.encoder_blocks, 
                                                 num_attention_heads=8, 
                                                 features_dim=self.encoder_dim,
                                                 num_encoder_features=self.num_image_features)
        
        self.DecoderAttention = DecoderAttention(num_blocks=self.decoder_blocks, 
                                                 num_attention_heads=8,
                                                 max_steps=self.padded_length)
        
        self.DecoderHead = DecoderHead(decoder_dim=self.decoder_dim, 
                                       vocab_size=self.vocab_size, 
                                       dual_heads_split_step=self.dual_heads_split_step)
        
    # NOTE: call is used for training only
    def call(self, inputs, training=False):

        if training:
            image = inputs[0]
            tokenized_inchi = inputs[1]

            # encoder
            encoder_features = self.ImageEncoder(
                [image, self.positional_encoding_image], training=training)
            
            encoder_features = self.EncoderAttention(encoder_features, training=training)
            
            # decoder
            decoder_features = self.InChIEncoder(
                [tokenized_inchi, self.start_var, self.positional_encoding_inchi], training=training)

            decoder_features = self.DecoderAttention(
                [encoder_features, decoder_features], training=training)

            # predictions
            probabilities = self.DecoderHead(decoder_features, training=training)

        else:
            probabilities = self.generation_loop(inputs, training=False)

        return probabilities

    # inference yielding generated probabilities from a single batch input
    @tf.function
    def generation_loop(self, inputs, training=False):
        
        image = inputs[0]
        tokenized_inchi = inputs[1]

        # get shapes
        padded_length = self.padded_length
        batch_size = self.batch_size

        # encoder
        encoder_features = self.ImageEncoder(
            [image, self.positional_encoding_image], training=training)
        
        encoder_features = self.EncoderAttention(encoder_features, training=training)
        
        # decoder
        # create containers
        generated_probs = tf.TensorArray(dtype=tf.float32, size=padded_length,
                element_shape=tf.TensorShape([self.batch_size, self.vocab_size]))
        generated_inchi = tf.TensorArray(dtype=tf.int32, size=padded_length,
                element_shape=tf.TensorShape([self.batch_size]))
        
        # initialize generated probs array
        zeros = tf.zeros((self.batch_size, self.vocab_size), dtype=generated_probs.dtype)
        for i in range(padded_length):
            generated_probs = generated_probs.write(i, zeros)

        # initialize generated InChI values array
        zeros = tf.zeros((self.batch_size), dtype=generated_inchi.dtype)
        for i in range(padded_length):
            generated_inchi = generated_inchi.write(i, zeros)

        # initialize step
        step = tf.constant(0, dtype=tf.int32)
        
        # loop body function
        def body_fn(generated_inchi, generated_probs, step):

            inchi = tf.transpose(generated_inchi.stack(), [1,0])
            
            # get current step probs and save result
            probs = self.decoder_step(encoder_features, tokenized_inchi, step)
            generated_probs = generated_probs.write(step, tf.cast(probs, dtype=generated_probs.dtype))

            # get new token prediction and save result
            predicted_token = tf.argmax(probs, axis=-1)
            generated_inchi = generated_inchi.write(
                    step, tf.cast(predicted_token, dtype=generated_inchi.dtype))

            # update step
            step = step + 1
            step = tf.cast(step, dtype=step.dtype)
            
            return [generated_inchi, generated_probs, step]

        # loop conditional function
        def cond_fn(generated_inchi, generated_probs, step):
            return tf.math.less(step, padded_length)

        # run generation loop
        generated_inchi, generated_probs, step = \
            tf.while_loop(cond=cond_fn,
                          body=body_fn,
                          loop_vars=[generated_inchi, generated_probs, step],
                          parallel_iterations=1,
                          maximum_iterations=padded_length,
                          shape_invariants=[None, None, tf.TensorShape([])],
                          )
        
        # unpack generated probabilities
        probabilities = tf.transpose(generated_probs.stack(), [1, 0, 2])

        return probabilities
    
    @tf.function
    def decoder_step(self, encoder_features, tokenized_inchi, step):

        # decoder
        decoder_features = self.InChIEncoder(
            [tokenized_inchi, self.start_var, self.positional_encoding_inchi], training=False)


        decoder_features = self.DecoderAttention(
            [encoder_features, decoder_features], training=False)
        
        """ TODO: implementation that doesn't require full InChI vecs at each step
        decoder_features = self.DecoderAttention.flexible_call(
            [encoder_features, decoder_features], inference_step=step, training=False)
        """  

        # get probabilities
        probabilities = self.DecoderHead(decoder_features, training=False)[:, step, :]

        return probabilities

    @tf.function(jit_compile=False, experimental_relax_shapes=True)
    def tokens_to_string(self, token_predictions):

        # convert to strings
        parsed_string_vals = self.inverse_tokenizer(token_predictions)
        string_vals = keras.layers.Lambda(
            lambda x: tf.strings.reduce_join(x, axis=-1))(parsed_string_vals)

        # remove first EOS generated and everything after
        pattern = ''.join([self.EOS, '.*$'])
        string_vals = tf.strings.regex_replace(string_vals, pattern, rewrite='', 
                                               replace_global=True, name='remove_EOS')   

        return string_vals

In [34]:
class EditDistanceMetric(tf.keras.metrics.Metric):
    def __init__(self, name='edit_distance', **kwargs):
        super().__init__(name=name, **kwargs)
        self.edit_distance = self.add_weight(name='edit_distance', initializer='zeros')
        self.batch_counter = self.add_weight(name='batch_counter', initializer='zeros')
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true, axis=-1)  # convert one_hot vectors back to sparce categorical
        y_true = tf.sparse.from_dense(y_true)
        y_pred = tf.sparse.from_dense(tf.argmax(y_pred, axis=-1))  # convert probs to preds

        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        # compute edit distance (of parsed tokens)
        edit_distance = tf.edit_distance(y_pred, y_true, normalize=False)
        self.edit_distance.assign_add(tf.reduce_mean(edit_distance))

        # update counter
        self.batch_counter.assign_add(tf.reduce_sum(1.))
    
    def result(self):
        return self.edit_distance / self.batch_counter

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.edit_distance.assign(0.0)
        self.batch_counter.assign(0.0)

In [35]:
# Modified "Attention is All You Need" learning scheduler (to become cyclic)
class LRScheduleAIAYN(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, scale_factor=1, warmup_steps=4000):  # defaults reflect paper's values
        # cast dtypes
        self.warmup_steps = tf.constant(warmup_steps, dtype=tf.float32)
        dim = tf.constant(352, dtype=tf.float32)
        scale_factor = tf.constant(scale_factor, dtype=tf.float32)
        
        self.scale = scale_factor * tf.math.pow(dim, -1.5)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        crit = self.warmup_steps

        def false_fn(step):
            adj_step = (step - crit) % (2.0*crit) + crit
            return tf.math.pow(adj_step, -.5)

        val = tf.cond(tf.math.less(step, crit),
                      lambda: step * tf.math.pow(crit, -1.5),  # linear increase
                      lambda: false_fn(step)  # decay
                      )

        return self.scale * val

"""
# visualize learning rate 
temp_lr = LRScheduleAIAYN()
plt.plot([i for i in range(1, 16000)], [temp_lr(i) for i in range(1, 16000)])
print('Learning Rate Schedule')
"""

"\n# visualize learning rate \ntemp_lr = LRScheduleAIAYN()\nplt.plot([i for i in range(1, 16000)], [temp_lr(i) for i in range(1, 16000)])\nprint('Learning Rate Schedule')\n"

## Build Model

Model Parameters

In [36]:
NAME_MODIFIER = ''

# build model
IMAGE_DENSE_OUTPUT_DIM = 256  # note: only used with USE_DENSE_ENCODER_TOP = True.
ENCODER_ATT_UNITS = 256
DECODER_UNITS = 512  # # "All You Need is Attention" uses 512 units
BEAM_RNN_UNITS = 128  # note: only used in beam_model.

# Note: model has capacity for up to 6 encoder and 6 decoder blocks. (as in AISAYN base model)
NUM_ENCODER_BLOCKS = 2  # note: can set to 0 to skip encoder block
NUM_DECODER_BLOCKS = 4  # max 6 enc and 6 dec with Colab memory constraints and cuts elsewhere
USE_DUAL_DECODERS = False
USE_CONVOLUTIONS = False
if USE_CONVOLUTIONS:
    checkpoint_save_name = 'ConvAtt_model_checkpoints' + NAME_MODIFIER
else:
    checkpoint_save_name = 'AISAYN_model_checkpoints' + NAME_MODIFIER

LOAD_CHECKPOINT_FILE = os.path.join(PARAMETERS.load_checkpoint_dir(), checkpoint_save_name, checkpoint_save_name)
SAVE_CHECKPOINT_FILE = os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, checkpoint_save_name)

# note: in Kaggle,
# LOAD_CHECKPOINT_FILE points to saved outputs from prev session
# SAVE_CHECKPOINT_FILE points to saved outputs from current session

Initialize model

In [37]:
# Update inputs: remove string keys, as they are not compatible with TPU
train_ds_int_index = train_ds.map(lambda x: (x['image'], x['tokenized_InChI'], 
                                            x['image_id'], x['InChI'])).prefetch(tf.data.AUTOTUNE)
valid_ds_int_index = valid_ds.map(lambda x: (x['image'], x['tokenized_InChI'], 
                                            x['image_id'], x['InChI'])).prefetch(tf.data.AUTOTUNE)

In [38]:
# callbacks
checkpoint = tf.keras.callbacks.ModelCheckpoint(SAVE_CHECKPOINT_FILE, monitor='loss', 
        save_weights_only=True, save_best_only=False, save_freq='epoch',
        options=tf.train.CheckpointOptions(experimental_io_device='/job:localhost'))

nan_stop = tf.keras.callbacks.TerminateOnNaN()

In [39]:
def compile_model(model=None, load_checkpoint=True, lr_scale_factor=1.0, label_smoothing=.1):

    # compile using distribution strategy
    with PARAMETERS.strategy().scope():
     
        # initialize if no model provided
        if model is None:

            model = TrainerModel(encoder_blocks=NUM_ENCODER_BLOCKS, 
                               encoder_dim=ENCODER_ATT_UNITS, 
                               decoder_blocks=NUM_DECODER_BLOCKS, 
                               decoder_dim=DECODER_UNITS, 
                               dual_heads_split_step=USE_DUAL_DECODERS,
                               use_convolutions=USE_CONVOLUTIONS,
                               parameters=PARAMETERS)

            if PARAMETERS.tpu():  
        
                temp_ds = PARAMETERS.strategy().experimental_distribute_dataset(train_ds_int_index)
                temp_ds = iter(temp_ds)
                val = next(temp_ds)

                # build with new val (inference and training modes)
                temp_func_train = tf.function(func=lambda x: model(x, True), experimental_relax_shapes=True,
                                        experimental_follow_type_hints=True)
                temp_func_inference = tf.function(func=lambda x: model(x, False), experimental_relax_shapes=True,
                                        experimental_follow_type_hints=True)

                PARAMETERS.strategy().run(temp_func_train, args=[(val[0], val[1])])  # use strategy.run() on TPU
                    
            else:  
                # build with original val
                for val in train_ds_int_index.take(1): 
                    model(val, training=True)

            # show summary
            print(model.summary())
            print('Models initialized.')
        
        # compiler components
        # cyclic modification to AIAYN lr
        learning_rate = LRScheduleAIAYN(scale_factor=lr_scale_factor, 
                                        warmup_steps=5000)  

        optimizer = tf.keras.optimizers.Adam(learning_rate,  # params from AIAYN
                                             beta_1=0.9, beta_2=0.98, epsilon=10e-9)
        
        # metrics
        edit_dist_metric = EditDistanceMetric()
        
        # loss with label smoothing
        loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing)

        # optimizations
        tf.config.optimizer.set_jit("autoclustering")  # XLA compiler optimization

        if not PARAMETERS.tpu():
            os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'  # better balances CPU / GPU interaction in tf.data    
            optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)  # required with mixed precision on GPU / CPU. Not used on TPU

            # compile       
            model.compile(optimizer=optimizer, 
                            loss=loss_fn,
                            metrics=['categorical_accuracy', edit_dist_metric],
                            steps_per_execution=8)
        else:
            # compile (note: EditDistance metric not compatible with TPU)
            model.compile(optimizer=optimizer, 
                          loss=loss_fn,
                          metrics=['categorical_accuracy'],
                          steps_per_execution=8*PARAMETERS.strategy().num_replicas_in_sync)

    if load_checkpoint:
        # verify model calls & methods work
        if not PARAMETERS.tpu():
            model(val, training=False)

            # sync weights
            # WARNING!: in Kaggle this loads from prev session saved weights
            try:
                with PARAMETERS.strategy().scope(): 
                    model.load_weights(LOAD_CHECKPOINT_FILE)  
                    pass
            except:
                print('No weights loaded')  

        else:
            # sync weights
            # WARNING!: in Kaggle this loads from prev session saved weights
            try:
                with PARAMETERS.strategy().scope(): 
                    model.load_weights(LOAD_CHECKPOINT_FILE, 
                                       options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost"))  

            except:
                print('No weights loaded')    

    return model

In [47]:
base_model = compile_model(model=None, load_checkpoint=True, 
                           lr_scale_factor=500.0, label_smoothing=.1)

Model: "TrainerModel"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
string_lookup_5 (StringLooku multiple                  0 (unused)
_________________________________________________________________
InChIEncoder (Functional)    (None, 200, 512)          101888    
_________________________________________________________________
ImageFeaturesExtractor (Func (None, 7, 7, 1408)        7768569   
_________________________________________________________________
ImageEncoder (Functional)    (None, 49, 256)           8195065   
_________________________________________________________________
EncoderAttention (Functional (None, 49, 256)           791552    
_________________________________________________________________
DecoderAttention (DecoderAtt multiple                  9469952   
_________________________________________________________________
DecoderHead (Functional)     (None, None, 199)        

Inference Functions

In [48]:
def inference(model, dataset, take_num=None, skip_set_num=0, show_sample=False):

    with model.parameters.strategy().scope():

        generation_fn = model.generation_loop
               
        # select batches
        dataset = dataset.skip(skip_set_num)
        if take_num is not None:
            dataset = dataset.take(take_num)
        
        # distribute dataset
        distributed_ds = model.parameters.strategy().experimental_distribute_dataset(dataset)
        distributed_ds = iter(distributed_ds)

        # initialize containers
        image_ids_list = []
        generated_predictions_list = []
        true_InChI_list = []

        for val in distributed_ds:
            
            # unpack ds element (and distribute across replicas if needed)
            inputs = (val[0], val[1])  # (image, tokenized InChI)
            
            image_ids = val[2]
            if model.parameters.tpu():
                image_ids = model.parameters.strategy().gather(image_ids, axis=0)
            
            true_InChI = val[3]
            if model.parameters.tpu():
                true_InChI = model.parameters.strategy().gather(true_InChI, axis=0)

            # generate probs
            generated_probs = model.parameters.strategy().run(generation_fn, args=[inputs])  # training=False
            if model.parameters.tpu():
                generated_probs = model.parameters.strategy().gather(generated_probs, axis=0)
                generated_probs = tf.squeeze(generated_probs)

            # get predicted InChI
            generated_predictions = tf.argmax(generated_probs, axis=2)
            
            # convert predictions to strings
            generated_predictions = model.tokens_to_string(generated_predictions)

            # decode bytestrings and update containers
            image_ids_list.extend([x.decode() for x in image_ids.numpy().tolist()])
            generated_predictions_list.extend([x.decode() for x in generated_predictions.numpy().tolist()])
            true_InChI_list.extend([x.decode() for x in true_InChI.numpy().tolist()])

        # compute Levenshtein scores
        lev_score_list = [levenshtein(pred, orig) for (pred, orig)
                        in zip(generated_predictions_list, true_InChI_list)]

    if show_sample:
        print(f'Mean Lev Score: {np.mean(lev_score_list)}\n')
        for i in range(0, 100, 5):
            print(generated_predictions_list[i])
            print(true_InChI_list[i])
            print(lev_score_list[i])
            print()

    return image_ids_list, generated_predictions_list, true_InChI_list, lev_score_list

temp_results = inference(base_model, dataset=train_ds_int_index, take_num=1, show_sample=True)


Mean Lev Score: 123.7822265625


InChI=1S/C10H15N3O2/c11-3-9-4-12-5-10(13-9)15-7-8-1-2-14-6-8/h4-5,8H,1-3,6-7,11H2
81

--,,
InChI=1S/C22H24N4O2/c27-21(9-4-11-25-12-10-23-16-25)24-19-13-22(28)26(15-19)14-18-7-3-6-17-5-1-2-8-20(17)18/h1-3,5-8,10,12,16,19H,4,9,11,13-15H2,(H,24,27)/t19-/m1/s1
161


InChI=1S/C12H13N3O3S2/c16-12-15-10(7-19-12)6-14-20(17,18)11-2-1-8-4-13-5-9(8)3-11/h1-3,7,13-14H,4-6H2,(H,15,16)
111


InChI=1S/C14H18O/c1-2-3-10-15-14-9-8-12-6-4-5-7-13(12)11-14/h4,6,8-9,11H,2-3,5,7,10H2,1H3
89


InChI=1S/C17H17FN4S/c1-11-3-2-4-13(9-11)20-17(19)22-21-15-7-8-23-16-6-5-12(18)10-14(15)16/h2-6,9-10H,7-8H2,1H3,(H3,19,20,22)/b21-15+
132


InChI=1S/C8H5F6NO/c1-16-5-3(6(9)10)2-15-7(11)4(5)8(12,13)14/h2,6H,1H3
69


InChI=1S/C18H22N2O5/c1-11-16(12(2)25-20-11)18(22)24-13(3)17(21)19-10-9-14-5-7-15(23-4)8-6-14/h5-8,13H,9-10H2,1-4H3,(H,19,21)/t13-/m0/s1
135

14--/22
InChI=1S/C16H23N/c1-4-6-15-7-8-16(13-14(15)5-2)9-11-17(3)12-10-16/h1,5-6H,7-13H2,2-3H3/b14-5-,15-6-
92

,,,,,,,,,,,
InChI=1S/C27

Test inference speed

In [42]:
#"""
# test inference speed - time for 'take_num' batches
%timeit inference(base_model, dataset=train_ds_int_index, take_num=2, skip_set_num=0)
#"""

1 loop, best of 5: 5.72 s per loop


In [43]:
if not PARAMETERS.tpu():
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs/')
    %tensorboard --logdir './logs/'

In [44]:
"""
# Full model: inference speed (with beam)
%%timeit
num_batches = 3

for val in train_ds.unbatch().batch(PARAMETERS.inference_batch_size()).take(num_batches): 
    im_id, preds = (base_model.predict(val))
"""

'\n# Full model: inference speed (with beam)\n%%timeit\nnum_batches = 3\n\nfor val in train_ds.unbatch().batch(PARAMETERS.inference_batch_size()).take(num_batches): \n    im_id, preds = (base_model.predict(val))\n'

# Training

Prepare dataset for training

In [45]:
# Dataset: Random Perturbations and One-Hot-Encoding
# note: one-hot needed so we can use label smoothing in CrossEntropy Loss
tf.keras.mixed_precision.set_global_policy('float32')  # temporarily removed mixed precision

train_ds_prepared = train_ds_int_index
valid_ds_prepared = valid_ds_int_index

# define transformations
"""
rotate = keras.layers.experimental.preprocessing.RandomRotation(
            factor=(-0.5, 0.5), fill_mode='constant')
contrast = keras.layers.experimental.preprocessing.RandomContrast(factor=.1)

# apply transformations
train_ds_prepared = train_ds_prepared.map(lambda w, x, y, z: (rotate(w), x, y, z),
                                           num_parallel_calls=tf.data.AUTOTUNE)
train_ds_prepared = train_ds_prepared.map(lambda w, x: (contrast(w), x, y, z),
                                          num_parallel_calls=tf.data.AUTOTUNE)
"""

# create one-hot encoded targets (allows label smoothing)
depth = base_model.vocab_size
one_hot = keras.layers.Lambda(lambda x: tf.one_hot(x, depth=depth)) 

train_ds_prepared = train_ds_prepared.map(lambda w, x, y, z: ((w, x), one_hot(x)),
                                          num_parallel_calls=tf.data.AUTOTUNE)\
                                          .prefetch(tf.data.AUTOTUNE)

valid_ds_prepared = valid_ds_int_index.map(lambda w, x, y, z: ((w, x), one_hot(x)),
                                          num_parallel_calls=tf.data.AUTOTUNE)\
                                          .prefetch(tf.data.AUTOTUNE)

# re-enable to mixed precision
tf.keras.mixed_precision.set_global_policy(PARAMETERS.mixed_precision())

Train base model

In [None]:
"""
Note: validation results use inference step char generation. Training uses 
teacher-fed inputs.
"""

# Note: training accuracy based on teacher-training, not full generation loop

# Note: label smoothing is helpful in training, but needs to be removed 
# in the final stages

# Train base model (teacher-fed training, prediction-fed validation, no beam update)
if not PARAMETERS.tpu():
    steps_per_epoch = 1024 * max(1, PARAMETERS.batch_size() // 16)
    validation_steps = 128
    callbacks=[checkpoint, nan_stop, tensorboard]
    validation_freq = 6
    lr_scale_factor=10.

else:
    steps_per_epoch = 128 * max(1, 128 // PARAMETERS.batch_size())
    validation_steps = 32
    callbacks=[checkpoint, nan_stop]
    validation_freq = 6
    lr_scale_factor=50.


epoch_multiple = 100
epochs = epoch_multiple * int(1.8 * 1e6) // (steps_per_epoch * PARAMETERS.batch_size())
    
# recompile (optional)
#base_model = compile_model(model=base_model, load_checkpoint=True, lr_scale_factor=lr_scale_factor, label_smoothing=0.0)

# (Optional: focused training.)
base_model.ImageFeaturesExtractor.trainable = True  
base_model.get_layer('EncoderAttention').trainable = True
base_model.get_layer('DecoderAttention').trainable = True
base_model.get_layer('DecoderHead').trainable = True

# train
history = base_model.fit(train_ds_prepared.repeat(),
                         validation_data=valid_ds_prepared.repeat(),
                         epochs=epochs,
                         steps_per_epoch=steps_per_epoch,
                         validation_freq=validation_freq, 
                         validation_steps=validation_steps, 
                         callbacks=callbacks,
                         verbose=1)

# New Section

TPU-safe saving to local directory

In [None]:
base_model.save_weights(os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, 'saved_model'), 
                        options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'))

In [None]:
base_model.load_weights(os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, 'saved_model'), 
                        options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'))

Train beam update model

In [None]:
"""
Not yet implemented
"""

# Inference

Here we define function to conduct inference on the test set. Results are saved to "submission.csv".

Intermediate results are saved at regular intervals to. This allows inference to be conducted in stages and is a safeguard in case of interruptions before the full set has been processed. 

In [None]:
def make_inference_progress(dataset, model, return_lev_score=True, save_freq=50, parameters=PARAMETERS):

    batch_size = 1024
    est_num_batches = 2*10e7 // batch_size
    take_num = 100

    #initialize new dataframe
    predictions_df = pd.DataFrame(columns=['image_id', 'InChI', 'lev_score'])


    for i in range(int(est_num_batches // take_num)):
        try:

             # get predictions
            inference_outputs = run_inference(model, dataset, return_lev_score=True, 
                                              take_num=take_num, skip_set_num=i)
            
            im_id, pred, true_val, lev_score  = inference_outputs[:]

            # add to dataframe
            new_preds = pd.DataFrame({'image_id': im_id, 'InChI': pred, 'lev_score': lev_score})
            predictions_df = predictions_df.append(new_preds)

            # save to CSV
            if i % save_freq == 0:
                predictions_df = predictions_df.drop_duplicates(subset='image_id', keep='last')
                predictions_df[['image_id', 'InChI']].to_csv(PARAMETERS.csv_save_dir() + 'submission.csv', index=False)
                print(f'iteration {i}')

        except:
            print(f'completed at step {i}')
            break

    return predictions_df

Load previosuly generated predictions

In [None]:
try:
    predictions_df = pd.read_csv(PARAMETERS.csv_save_dir() + 'submission.csv')
except:
    predictions_df = pd.DataFrame({'image_id':[], 'InChI':[]}, dtype=str)

Generate additional predictions

In [None]:
""" On first pass or to start from scratch, initialize the dataframe with:
predictions_df = pd.DataFrame({'image_id':[], 'InChI':[]}, dtype=str)
"""

predictions_df = make_inference_progress(predictions_df, save_freq=100, 
                                         num_batches=1, starting_batch=0, 
                                         parameters=PARAMETERS)
predictions_df