<a href="https://colab.research.google.com/github/mvenouziou/Project-Attention-Is-What-You-Get/blob/main/bms_molecular_translation_AttentionIsWhatYouGet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention is What You Get

This is my entry into the [Bristol-Myers Squibb Molecular Translation](https://www.kaggle.com/c/bms-molecular-translation)  Kaggle competition.

-----

AUTHOR: 

Mo Venouziou

- *Email: mvenouziou@gmail.com*
- *LinkedIn: www.linkedin.com/in/movenouziou/*

Updates:

 - *Original Posting: June 2, 2021*
 - *06/21/21: added TPU support*
 - *06/17/21: improved training & inference speed. Allows full AIAYN model size on TPU, faster small model training on GPU.*

----

### Our Goal: Predict the "InChI" value of any given chemical compound diagram. 

International Chemical Identifiers ("InChI values") are a standardized encoding to describe chemical compounds. They take the form of a string of letters, numbers and delimiters, often 100 to 400 characters long. 

The chemical diagrams are provided as PNG files, often of such low quality that it may take a human several seconds to decipher. 

Label length and image quality become a serious challenge here, because we must predict labels for a very large quantity of images. There are 1.6 million images in the test set and 2.4 million images available in the training set!

In [None]:
"""
# Example (image, target label) pair
for val in train_ds.unbatch().take(1):
    print('Example Label:\n', val['InChI'].numpy())
    print('\nCorresponding Image:', plt.imshow(val['image'][:,:,0], cmap='binary'))
### note: load datasets before running this cell
"""

## MODEL STRUCTURE: 

**Image CNN + Attention Features encoder --> text Attention + (optional) CNN feature layer decoder.**

This is a hybrid approach with:
 
 - Image Encoder from [*Show, Attend and Tell: Neural Image Caption Generation with Visual Attention*](https://proceedings.mlr.press/v37/xuc15.pdf).  Generate image feature vectors using intermediate layer outputs from a pretrained CNN. (Here I use the more modern EfficientNet model (recommended by [*Darien Schettler*](https://www.kaggle.com/dschettler8845/bms-efficientnetv2-tpu-e2e-pipeline-in-3hrs/notebook)) with fixed weights and a trainable Dense layer for customization.)
 
 - T2T encoder-decoder model from [*Attention Is All You Need*](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf): self-attention feature extraction for both encoder and decoder, joint encoder-decoder attention feature interactions, and a dense prediction output block. Includes parameters to control the number of encoder / decoder blocks.

 - ***PLUS*** *(optional):* Decoder Output Blocks placed in Series (not stacked). Increase the number of trainable parameters without adding inference computational complexity, while also allowing decoders to specialize on different regions of the output.
 
 - ***PLUS*** *(optional):* Is attention really all you need? Add a convolutional layer to enhance text features before decoder self-attention, to compare performance with and without the extra convolutional layer(s). The use of CNNs in NLP comes from [*Convolutional Sequence to Sequence Learning*](http://proceedings.mlr.press/v70/gehring17a.html).

 - ***PLUS*** *(optional):* Beam-Search Alternative, an extra decoding layer applied after the full logits prediction has been made. This takes the form of a bidirectional RNN with attention, applied to the full logits sequence. Because a full (initial) prediction has already been made, computations can be parallelized using stateful RNNs. (See more details below.)

*Optional features can be enabled/disabled using parameters in my model definitions.*

----

## NEXT STEPS:

 - (Low priority, specific to Kaggle's TPU implementation.) Fix "session.run()" TPU calls on Kaggle. (It works correctly on Colab.) This severely impacts inference speed on Kaggle.

 - Experiment with **"Tokens-to-Token ViT"** in place of the image CNN. (Technique from [*Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet*](https://arxiv.org/pdf/2101.11986.pdf).)
  
 - Train my **Beam-search Alternative**. 

    - Beam search is a technique to modify model predictions to reflect the (local) maximum likelihood estimate. However, it is *very* local in that computational expense increases quickly with the number of character steps taken into account. It is also a hard-coded algorithm, which is somewhat contrary to the philosophy of deep learning.

    - A *Beam-search Alternative* would be an extra decoding layer applied *after* the full logits prediction has been made. This might be in the form of a stateful, bidirectional RNN that is computationally parallelizable because it is applied to the full logits sequence.

    - Need to revamp code to accept main model changes made for TPU support.

 - Treat the number of convolutional layers (decoder feature extraction) and the number of decoders placed in series (decoder prediction output) as **new hyperparameters** to tune.

 - *6/21/21: TPU Support added.* ~~Implement TPU compatibility.~~

 - *6/17/21: Increased model size and efficiency.* ~~Implement full size model (matching AIAYN) with efficient training and inference speeds for the large dataset. (TPU required. GPU doesn't have enough memory to train such a large model.)~~
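
For readers unfamiliar with beam search (discussed under the Beam-search Alternative item above), here is a minimal sketch in plain Python. It is not the decoding code used in this notebook. The scoring function `toy_step_log_probs` and its probability table are made up purely for illustration; in practice the per-step scores would come from the decoder. The toy table is chosen so that greedy decoding (beam width 1) commits to the locally best first token and ends with a lower-probability sequence than a beam of width 2.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = '<EOS>'
Sequence_ = Tuple[str, ...]

def toy_step_log_probs(prefix: Sequence_) -> Dict[str, float]:
    """Hypothetical next-token log-probabilities, standing in for a real decoder."""
    table = {
        (): {'A': 0.6, 'B': 0.4},
        ('A',): {'x': 0.5, 'y': 0.5},
        ('B',): {'z': 0.9, 'x': 0.1},
    }
    probs = table.get(tuple(prefix), {EOS: 1.0})  # any longer prefix must end
    return {token: math.log(p) for token, p in probs.items()}

def beam_search(step_log_probs: Callable[[Sequence_], Dict[str, float]],
                beam_width: int, max_len: int) -> Tuple[Sequence_, float]:
    """Keep the `beam_width` best partial sequences at every step."""
    beams: List[Tuple[Sequence_, float]] = [((), 0.0)]
    finished: List[Tuple[Sequence_, float]] = []
    for _ in range(max_len):
        # Expand every surviving prefix by every possible next token.
        candidates = [(prefix + (token,), score + log_p)
                      for prefix, score in beams
                      for token, log_p in step_log_probs(prefix).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            (finished if seq[-1] == EOS else beams).append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # sequences cut off at max_len still compete
    return max(finished, key=lambda c: c[1])

greedy_seq, greedy_score = beam_search(toy_step_log_probs, beam_width=1, max_len=3)
beam_seq, beam_score = beam_search(toy_step_log_probs, beam_width=2, max_len=3)
print(greedy_seq, round(math.exp(greedy_score), 2))
print(beam_seq, round(math.exp(beam_score), 2))
```

Each step scores `beam_width * vocabulary_size` candidates and needs a fresh decoder call per surviving prefix, which is why widening the search over a long InChI string becomes expensive.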

----

### CITATIONS

- "Attention is All You Need." 
 - Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. NIPS (2017). *https://research.google/pubs/pub46201/*

- "Convolutional Sequence to Sequence Learning."
 
  - Gehring, J., Auli, M., Grangier, D., Yarats, D. & Dauphin, Y.N.. (2017). Convolutional Sequence to Sequence Learning. Proceedings of the 34th International Conference on Machine Learning, in Proceedings of Machine Learning Research 70:1243-1252, *http://proceedings.mlr.press/v70/gehring17a.html.*


- "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks."
 
  - Mingxing Tan, Quoc V. Le (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. International Conference on Machine Learning. *http://arxiv.org/abs/1905.11946.*


-  "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention."
  -  Xu, K., Ba, J., Kiros, R., Cho, K., Courville, A., Salakhudinov, R., Zemel, R. & Bengio, Y.. (2015). Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. Proceedings of the 32nd International Conference on Machine Learning, in Proceedings of Machine Learning Research 37:2048-2057. *http://proceedings.mlr.press/v37/xuc15.html.* 
            

- "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"

  - Li Yuan, Yunpeng Chen, Tao Wang, Weihao Yu, Yujun Shi, Zihang Jiang, Francis EH Tay, Jiashi Feng, Shuicheng Yan. Preprint (2021). *https://arxiv.org/abs/2101.11986*.

- Tensorflow documentation tutorial "Transformer model for language understanding." I found this tutorial only after fully completing the model, at which point I discovered that my attention mask was incorrect. My use of "tf.linalg.band_part" (only) is due to this tutorial. *www.tensorflow.org/text/tutorials/transformer#masking*

- Special thanks to [Darien Schettler](https://www.kaggle.com/dschettler8845/bms-efficientnetv2-tpu-e2e-pipeline-in-3hrs/notebook) for leading readers to the "Show" and "Attention" papers cited above, for using *session.run()* to improve inference speed in distributed settings, and for providing detailed info on creating TF Records. This work is otherwise derived independently from his.

- It is possible my idea of a Beam Search Alternative is based on a lecture video from DeepLearning.ai's [Deep Learning Specialization](https://www.coursera.org/specializations/deep-learning)  on Coursera.

- **Dataset / Kaggle Competition:** "Bristol-Myers Squibb – Molecular Translation" competition on Kaggle (2021). *https://www.kaggle.com/c/bms-molecular-translation*

----


## Contents

1. [Imports](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=TjuUOVXao__C&line=4&uniqifier=1)
2. [Data Pipeline](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=lrLHKs5Ni7Sz)
3. [Model Layers](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=W0T-u0vZamI8)
    - [InChI Encoding](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=DYApmA2lf1hp&line=1&uniqifier=1)
    - [Image Encoding and Self-Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=FESofcGdEaWF&line=1&uniqifier=1)
    - [Decoder Self-Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=6qFDs9RTjvod&line=1&uniqifier=1)
    - [Joint Encoder-Decoder Attention](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=jP-t1MkKnD5L)
    - [Decoder Head (Prediction Output)](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=38GA7wtNEhqW&line=1&uniqifier=1)
    - [Update Mechanism](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=_2UR1DLljD0S&line=1&uniqifier=1)
4. [Full Model](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=D6GIs3f3rpu0&line=1&uniqifier=1)
5. [Training](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=otxdN02mf1ht&line=1&uniqifier=1)
6. [Inference](https://colab.research.google.com/drive/1i6LMwu7BRfs955U4AdtV2oaI_9_A_Awq#scrollTo=Sbvzr5rdmjgs&line=5&uniqifier=1)

---

In [2]:
#### PACKAGE IMPORTS ####

# system management
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'  # Intel's TF optimization (must be set before importing tensorflow)

# TF Model design
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from tensorflow.data import TFRecordDataset
from tensorflow.data.experimental import TFRecordWriter

# Text processing
import re
import string

# metric for Kaggle Competition
!pip install -q leven
from leven import levenshtein

# Kaggle (for TPU)
#from kaggle_datasets import KaggleDatasets

# Visualizations
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

# Tensorboard Profiler
!pip install -U -q tensorboard
!pip install -U -q tensorboard_plugin_profile
!pip install --upgrade -q "cloud-tpu-profiler>=2.3.0"
%load_ext tensorboard

"""
# Debugger
tf.debugging.experimental.enable_dump_debug_info('./logs/', tensor_debug_mode="FULL_HEALTH", 
                                                 circular_buffer_size=-1)
"""

# data management
import numpy as np
import pandas as pd
import itertools

## Model parameters

The 'ModelParameters' class manages global hyperparameters for portability between Colab and Kaggle notebook environments. Once set, all other cells will run on either platform.

On Colab, connection to my personal Google Drive is required, as ModelParameters will extract the dataset from a zip file to the hosted environment. This process may take several minutes. (It would not be difficult for the reader to update the code to point to their own drive and download the zip dataset using the Kaggle API code below.)

In [None]:
""" Kaggle API for downloading the compressed dataset from Kaggle's servers.

# imports
!pip uninstall -y kaggle
!pip install --upgrade pip
!pip install kaggle==1.5.6

# if needed, download data using '!kaggle competitions download -c bms-molecular-translation'
# then unzip with '! unzip bms-molecular-translation.zip -d datasets'
os.environ['KAGGLE_CONFIG_DIR'] = '/content/gdrive/MyDrive/Kaggle'  # api token location
"""


In [None]:
# check for TPU & initialize
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)

    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))

    STRATEGY = tf.distribute.TPUStrategy(resolver)
    TPU = True
    os.environ["TFHUB_MODEL_LOAD_FORMAT"] = "UNCOMPRESSED"  # for TF Hub models on TPU
    PRECISION_TYPE = 'mixed_bfloat16' 
    #PRECISION_TYPE = 'float32' 

    # extra imports for GCS
    !pip install -q fsspec
    !pip install -q gcsfs
    import fsspec, gcsfs 

except Exception:  # no TPU found (or TPU setup failed): fall back to the default strategy
    TPU = False
    STRATEGY = tf.distribute.get_strategy()
    PRECISION_TYPE = 'mixed_float16'

# enable mixed precision
tf.keras.mixed_precision.set_global_policy(PRECISION_TYPE)

INFO:tensorflow:Initializing the TPU system: grpc://10.6.229.10:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU')]
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

In [None]:
class ModelParameters:
    def __init__(self, cloud_server='kaggle'):
               
        # universal parameters
        self._cloud_server = cloud_server  # stored so that cloud_server() below can return it
        self._batch_size = 16  # used on GPU. TPU batch size increased below
        self._padded_length = 200
        self._image_size = (320, 320)  # shape to process images in data pipeline, matches HUB model
        self.SOS_string = 'InChI=1S/'  # start of sentence value
        self.EOS_string = '<EOS>'  # end of sentence value
        self._strategy = STRATEGY
        self._precision_type = PRECISION_TYPE
        self._tpu = TPU
        
        # TPU batch size
        if self._tpu:
            self._batch_size = 128 * self._strategy.num_replicas_in_sync

        # File Paths       
        if cloud_server == 'colab':  # Google Colab
            
            # load drive for saving checkpoints
            try:
                from google.colab import drive
                drive.mount('/content/gdrive/') 
            except:
                pass  # drive already mounted
            
            # check for TPU 
            if self._tpu: 

                # TPU file structure (via Kaggle GCS folder)
                self._dataset_dir = 'gs://kds-df3031ee4e277d641d1044cc3e9386a923ca98833b0d51a2575d2932' # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('bms-molecular-translation')
                self._prepared_files_dir = 'gs://kds-96b617b700ddb4d07bc42a47c0a7abfe3a68d4510c459b7cd7b216e6'  # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('periodic-table')
                self._tfrec_dir = 'gs://kds-dc74fe0494d010e8c9544cd7fff86e64f08cb0cffd4c608156ff3f41'  # from Kaggle. Get updated directory on Kaggle via KaggleDatasets().get_gcs_path('bmsshards')
                self._checkpoint_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/checkpoints/'  # gdrive
                self._load_checkpoint_dir = self._checkpoint_dir
                self._csv_save_dir = './'

            else:
                # unzip data
                if not os.path.isdir('/content/bms-molecular-translation'):
                    !unzip -q /content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/bms-molecular-translation.zip -d '/content/bms-molecular-translation'
                
                # file paths
                self._dataset_dir = 'bms-molecular-translation/'
                self._prepared_files_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/'
                self._checkpoint_dir = '/content/gdrive/MyDrive/Colab_Notebooks/models/MolecularTranslation/checkpoints/'
                self._load_checkpoint_dir = self._checkpoint_dir
                self._csv_save_dir = self._prepared_files_dir 
                self._tfrec_dir = None
                
        elif cloud_server == 'kaggle': # Kaggle cloud notebook (CPU / GPU)
            from kaggle_datasets import KaggleDatasets
            
            # check for TPU 
            if self._tpu: 
                
                # file paths
                self._dataset_dir = '' #KaggleDatasets().get_gcs_path('bms-molecular-translation')
                self._prepared_files_dir = KaggleDatasets().get_gcs_path('periodic-table')
                self._tfrec_dir = KaggleDatasets().get_gcs_path('bmsshards')
                self._checkpoint_dir = './'
                self._load_checkpoint_dir = './'
                self._csv_save_dir = './'

            # set GPU instance info
            else:  
                # file paths
                self._dataset_dir = '../input/bms-molecular-translation/'
                self._prepared_files_dir = '../input/periodic-table/'
                self._tfrec_dir = '../input/bmsshards/'
                self._checkpoint_dir = './'
                self._load_checkpoint_dir = '../input/k/mvenou/bms-molecular-translation/checkpoints/'
                self._csv_save_dir = './'
                self._tfrec_dir = None

        # common file paths
        self._periodic_table_csv = os.path.join(self._prepared_files_dir, 'periodic_table_elements.csv')
        self._vocab_csv = os.path.join(self._prepared_files_dir, 'vocab.csv')        
        self._test_images_dir = os.path.join(self._dataset_dir, 'test/')
        self._train_images_dir = os.path.join(self._dataset_dir, 'train/')
        self._extra_labels_csv = os.path.join(self._dataset_dir, 'extra_approved_InChIs.csv')
        self._train_labels_csv = os.path.join(self._dataset_dir, 'train_labels.csv')
        self._sample_submission_csv = os.path.join(self._dataset_dir, 'sample_submission.csv')
        
    # functions to access params
    def padded_length(self):
        return self._padded_length
    def mixed_precision(self):
        return self._precision_type
    def tpu(self):
        return self._tpu
    def tfrec_dir(self):
        return self._tfrec_dir
    def cloud_server(self):
        return self._cloud_server
    def strategy(self):
        return self._strategy
    def csv_save_dir(self):
        return self._csv_save_dir
    def train_labels_csv(self):
        return self._train_labels_csv
    def vocab_csv(self):
        return self._vocab_csv
    def periodic_table_csv(self):
        return self._periodic_table_csv
    def batch_size(self):
        return self._batch_size  
    def image_size(self):
        return self._image_size    
    def SOS(self):
        return self.SOS_string
    def EOS(self):
        return self.EOS_string
    def train_images_dir(self):
        return self._train_images_dir
    def test_images_dir(self):
        return self._test_images_dir   
    def checkpoint_dir(self):
        return self._checkpoint_dir
    def load_checkpoint_dir(self):
        return self._load_checkpoint_dir


Initialize Parameter Options

In [None]:
PARAMETERS = ModelParameters(cloud_server='colab')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


# **Input Pipeline**

Load train labels as DataFrame

In [None]:
if not PARAMETERS.tpu():
    # Load CSV as dataframe
    train_labels_df = pd.read_csv(PARAMETERS.train_labels_csv())
    train_labels_df.head()

### InChI Text Parsing

We split each InChI label into its "vocabulary" of logical subunits: element abbreviations, numbers, common symbols, and the required string 'InChI=1S/', which starts every InChI label. We want to narrow this vocabulary down to the smallest set represented in our training data. The functions below provide a system for finding this minimal set and for preparing a new CSV file with parsed labels, ready to be fed into a tokenizer layer.

(For clarity and to reduce reliance on loading external files, the true code has been commented out and replaced with corresponding hard-coded values.)
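
To illustrate the splitting idea, here is a small self-contained sketch using a reduced vocabulary (not the full list used below). The label is bromoethane's InChI, used only as an example input, and `demo_parse` is a name introduced here for illustration. The point to notice is the ordering inside the regex alternation: longer alternatives ('Br', 'Cl', multi-digit numbers) must come before their prefixes ('B', 'C', single digits), because Python's `re` takes the first alternative that matches.

```python
import re

SOS = 'InChI=1S/'

# Reduced vocabulary for illustration; longer tokens listed before their prefixes.
symbols = ['(', ')', '+', ',', '-', '/', 'Br', 'B', 'Cl', 'C', 'H', 'N', 'O', 'c', 'h']
numbers = [str(n) for n in reversed(range(168))]  # '167', ..., '10', '9', ..., '0'

pattern = '(' + '|'.join(re.escape(v) for v in [SOS] + symbols + numbers) + ')'

def demo_parse(label: str) -> str:
    """Split a label into space-separated vocabulary tokens."""
    return ' '.join(re.findall(pattern, label))

print(demo_parse('InChI=1S/C2H5Br/c1-2-3/h2H2,1H3'))

# With the wrong ordering, 'Br' is never reached: 'B' matches first and 'r' is dropped.
bad_pattern = '(' + '|'.join(re.escape(v) for v in ['B', 'Br']) + ')'
print(re.findall(bad_pattern, 'Br'))
```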

In [None]:
def inchi_parsing_regex(parameters=PARAMETERS):
    # regex for splitting an InChI string while preserving chemical element abbreviations and three-digit numbers
    
    # shortcut: hard coded values
    vocab = [parameters.EOS(), parameters.SOS(), '(',
            ')', '+', ',', '-', '/', 'Br', 'B', 'Cl', 'C', 'D', 'F',
            'H', 'I', 'N', 'O', 'P', 'Si', 'S', 'T', 'b', 'c', 'h', 'i',
            'm', 's', 't']
        
    vocab += [str(num) for num in reversed(range(168))]
    vocab = [re.escape(val) for val in vocab]
       
    """ # to create vocab from scratch, use:
    SOS = parameters.SOS()
    EOS = parameters.EOS()
    
    # load list of elements we should search for within InChI strings: 
    periodic_elements = pd.read_csv(PARAMETERS.periodic_table_csv(), header=None)[1].to_list()
    periodic_elements = periodic_elements + [val.lower() for val in periodic_elements] + [val.upper() for val in periodic_elements]
    
    punctuation = list(string.punctuation)
    punctuation = [re.escape(val) for val in punctuation]   # update values with regex escape chars added as needed

    three_dig_nums_list = [str(i) for i in range(1000, -1, -1)]

    vocab = [SOS, EOS] + periodic_elements + three_dig_nums_list + punctuation
    """

    split_elements_regex = rf"({'|'.join(vocab)})"
    
    return split_elements_regex

In [None]:
INCHI_PARSING_REGEX = inchi_parsing_regex()

def parse_InChI(texts, parsing_regex=INCHI_PARSING_REGEX):  
    return ' '.join(re.findall(parsing_regex, texts))


# TF dataset map-compatible version
def parse_InChI_py_fn(texts, parsing_regex=INCHI_PARSING_REGEX):
    def tf_parse_InChI(texts):  
        texts = np.char.array(texts.numpy())
        texts = np.char.decode(texts).tolist()
        texts = tf.constant([parse_InChI(val) for val in texts])
        return tf.squeeze(texts)
    return tf.py_function(func=tf_parse_InChI, inp=[texts], Tout=tf.string)


# extracts filepath from image name
def path_from_image_id(x, root_folder):
    folder_a = tf.strings.substr(x, pos=0, len=1)
    folder_b = tf.strings.substr(x, pos=1, len=1)
    folder_c = tf.strings.substr(x, pos=2, len=1)
    filename =  tf.strings.join([x, '.png'])
    return tf.strings.join([root_folder, folder_a, folder_b, folder_c, filename], separator='/')
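
The directory layout assumed by `path_from_image_id` can be seen with a pure-Python mirror of the same logic. This is a sketch for illustration only: `path_from_image_id_py` is a hypothetical helper and the image ID is made up. Images are nested three folders deep, one folder per leading character of the ID.

```python
def path_from_image_id_py(image_id: str, root_folder: str) -> str:
    """Pure-Python mirror of path_from_image_id (illustration only)."""
    folders = [image_id[0], image_id[1], image_id[2]]  # first three characters
    filename = image_id + '.png'
    return '/'.join([root_folder, *folders, filename])

# Made-up image ID, chosen only to show the layout.
print(path_from_image_id_py('abc123def456', 'train'))

# Note: a root folder that already ends in '/' produces a doubled separator,
# because the parts are joined with '/' regardless.
print(path_from_image_id_py('abc123def456', 'train/'))
```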

Tokenizer

Note: This must be kept outside the model (and used in dataset prep) for TPU compatibility.
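
As a rough picture of what the tokenizer does to a parsed label, here is a plain-Python sketch that maps tokens to integers, then crops and pads to a fixed length. It assumes the index convention of `TextVectorization` in integer mode (0 for padding, 1 for out-of-vocabulary, vocabulary entries from 2 onward). The helper `demo_tokenize` and the four-entry vocabulary are hypothetical and exist only for this example.

```python
from typing import List

PAD_ID = 0  # assumed padding index
OOV_ID = 1  # assumed out-of-vocabulary index

def demo_tokenize(parsed_label: str, vocab: List[str], padded_length: int) -> List[int]:
    """Map space-separated tokens to ids, then crop / pad to padded_length."""
    index = {token: i + 2 for i, token in enumerate(vocab)}  # ids start at 2
    ids = [index.get(token, OOV_ID) for token in parsed_label.split(' ')]
    ids = ids[:padded_length]                           # crop long labels
    return ids + [PAD_ID] * (padded_length - len(ids))  # pad short labels

demo_vocab = ['<EOS>', 'InChI=1S/', 'C', '2']  # tiny illustrative vocabulary

print(demo_tokenize('InChI=1S/ C 2 <EOS>', demo_vocab, padded_length=6))
print(demo_tokenize('InChI=1S/ C 2 <EOS>', demo_vocab, padded_length=2))
print(demo_tokenize('InChI=1S/ Xe <EOS>', demo_vocab, padded_length=4))
```

Because the output length is fixed, every batch has a static shape, which is what TPU execution needs.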

In [None]:
def Tokenizer(parameters):
    """ tokenizes, crops & pads to parameters.padded_length() """

    SOS = parameters.SOS()
    EOS = parameters.EOS()
    padded_length = parameters.padded_length()  # use the argument, not the global
    
    # Create vocabulary for tokenizer
    def create_vocab():       
        hard_coded_vocab = [parameters.EOS(), parameters.SOS(), '(',
            ')', '+', ',', '-', '/', 'B', 'Br',  'C', 'Cl', 'D', 'F',
            'H', 'I', 'N', 'O', 'P', 'S', 'Si', 'T', 'b', 'c', 'h', 'i',
            'm', 's', 't']
        
        numbers = [str(num) for num in range(168)]
        
        vocab = hard_coded_vocab + numbers
        
        """
        # get from saved file
        vocab = pd.read_csv(PARAMETERS.vocab_csv())['vocab_value'].to_list()   
        vocab = list(vocab)
        """

        """ 
        # To create from scratch, extract all vocab elements appearing in train set:
        df = pd.read_csv(PARAMETERS.train_labels_csv())  
        seg_len = 250000
        num_breaks = len(df) // seg_len

        vocab = set()
        for i in range(num_breaks):

            df_i =  df['InChI'].iloc[seg_len * i: seg_len * (i+1)]
            texts =  df_i.apply(lambda x: set(parse_InChI(x).split()))
            texts = texts.tolist()

            vocab = vocab.union(*texts)

            print(f'completed {i} / {num_breaks}')

        vocab = list(vocab)
        vocab_df = pd.DataFrame({'vocab_value': vocab})

        # save results
        filename = os.path.join(PARAMETERS.csv_save_dir(), 'vocab.csv')
        vocab_df.to_csv(filename, index=False)
        """
               
        return vocab

    vocab = create_vocab()
    
    # create tokenizer
    tokenizer_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        standardize=None, split=lambda x: tf.strings.split(x, sep=' ', maxsplit=-1), 
        output_mode='int', output_sequence_length=padded_length, vocabulary=vocab)

    # record EOS token
    tokenized_EOS = tokenizer_layer(tf.constant([EOS]))
    
    # create inverse (de-tokenizer)
    inverse_tokenizer = tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=vocab, invert=True)

    return tokenizer_layer, inverse_tokenizer, tokenized_EOS

In [None]:
TOKENIZER_LAYER, INVERSE_TOKENIZER, TOKENIZED_EOS = \
    Tokenizer(parameters=PARAMETERS)

def tokenize_text(w, x, y, z):
    # note: requires batch dim
    y = TOKENIZER_LAYER(y)
    return w, x, y, z

Image Loader

In [None]:
# Image loaders
def load_image(image_path):
    image_path = tf.squeeze(image_path)
    image = keras.layers.Lambda(lambda x: tf.io.read_file(x))(image_path)
    return image   

def decode_image(image, target_size):
    image = keras.layers.Lambda(lambda x: tf.io.decode_image(x, channels=1, expand_animations=False))(image)
    image = keras.layers.experimental.preprocessing.Resizing(*target_size)(image)
    return image    

## Datasets

Here we create efficient tf.data.Dataset train / validation / test sets.

Our data pipeline reads our prepared CSV of (image filename, parsed InChI, standard InChI) tuples. (If this file is not found, it will be created from scratch, which may take several minutes.) Iterating through the list, it loads batches of corresponding images and labels.

Our datasets contain the following information, accessible by dict keys: image, image_id, tokenized_InChI, InChI. (The test set uses InChI = parsed_InChI = 'InChI=1S/', the known required starting value for any InChI code.)
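
Two of the label-side steps in the pipeline are easy to picture without TensorFlow: the first 10% of rows are held out for validation, and five extra EOS markers are appended to each parsed label before tokenization. The sketch below mirrors those two steps on plain Python lists. The helper names `split_train_valid` and `extend_eos` are introduced here for illustration and are not part of the notebook.

```python
from typing import List, Sequence, Tuple

EOS = '<EOS>'

def split_train_valid(rows: Sequence, valid_split: float = 0.10) -> Tuple[List, List]:
    """Hold out the first `valid_split` fraction of rows for validation."""
    num_valid = int(valid_split * len(rows))
    return list(rows[num_valid:]), list(rows[:num_valid])

def extend_eos(parsed_label: str, copies: int = 5) -> str:
    """Append `copies` EOS markers, space-separated, to a parsed label."""
    return ' '.join([parsed_label] + [EOS] * copies)

rows = list(range(20))  # stand-in for 20 (image_id, InChI) rows
train_rows, valid_rows = split_train_valid(rows)
print(len(train_rows), len(valid_rows))

print(extend_eos('InChI=1S/ C H 4'))
```

Since the split is positional, the same rows end up in the validation set on every run, even though each split is shuffled afterwards.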

In [None]:
def data_generator(image_set, parameters=PARAMETERS, labels_df=None, decode_images=True):
       
    # get global params
    batch_size = parameters.batch_size()
    target_size = parameters.image_size()
    SOS = parameters.SOS()
    EOS = parameters.EOS()
    
    # dataset options
    options = tf.data.Options()
    options.experimental_optimization.autotune_buffers = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.apply_default_optimizations = True
        
    # Train & Validation Datasets
    if image_set in ['train', 'valid']:
        root_folder = parameters.train_images_dir()  # train / valid images
        valid_split = 0.10
        
        # load labels into memory as dataframe
        if labels_df is None:
            labels_df = pd.read_csv(parameters.train_labels_csv())

        # train / validation split (first `valid_split` fraction of rows is held out)
        num_valid_samples = int(valid_split * len(labels_df))
        train_df = labels_df.iloc[num_valid_samples: ]  # get train split
        valid_df = labels_df.iloc[: num_valid_samples]  # get validation split

        # shuffle
        train_df = train_df.sample(frac=1)
        valid_df = valid_df.sample(frac=1)

        # load into datasets  # (image_id, InChI)
        train_ds = tf.data.Dataset.from_tensor_slices(train_df.values)
        valid_ds = tf.data.Dataset.from_tensor_slices(valid_df.values)

        train_ds = train_ds.with_options(options)
        valid_ds = valid_ds.with_options(options)

        # update image paths  
        def map_path(x):  # (image_path, image_id, InChI)
            image_id = x[0]
            image_path = path_from_image_id(image_id, root_folder)
            return image_path, x[0], x[1]

        train_ds = train_ds.map(map_path, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_path, num_parallel_calls=tf.data.AUTOTUNE)

        def map_parse(x, y, z):  # (image_path, image_id, InChI)
            parsed_InChI = parse_InChI_py_fn(z)
            return x, y, parsed_InChI, z
   
        train_ds = train_ds.map(map_parse, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_parse, num_parallel_calls=tf.data.AUTOTUNE)
                
        # load images into dataset       
        def open_images(w, x, y, z):
            w = load_image(w)
            return w, x, y, z
        
        train_ds = train_ds.map(open_images, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(open_images, num_parallel_calls=tf.data.AUTOTUNE)    

        # PREFETCH dataset BEFORE decoding images
        train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
        valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)

        def decode(w, x, y, z):
            w = decode_image(w, target_size)
            return w, x, y, z

        if decode_images:
            train_ds = train_ds.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
            valid_ds = valid_ds.map(decode, num_parallel_calls=tf.data.AUTOTUNE)    

        # BATCH dataset AFTER decoding images (required by tf.io)
        # should batch before other pure TF Lambda layer ops
        train_ds = train_ds.batch(batch_size, drop_remainder=True)
        valid_ds = valid_ds.batch(batch_size, drop_remainder=True)
        
        # add extra "EOS" values to end of parsed inchi
        def extend_EOS(w, x, y, z):
            y = tf.strings.join([y, EOS, EOS, EOS, EOS, EOS], separator=' ')
            y = tf.reshape(y, [-1])
            return w, x, y, z

        train_ds = train_ds.map(extend_EOS, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(extend_EOS, num_parallel_calls=tf.data.AUTOTUNE)

        # Tokenize parsed_inchi.  Note: ds must be batched before this step (size=1 is ok) 
        train_ds = train_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)

        # name the elements
        def map_names(w, x, y, z):
            return  {'image': w, 'image_id': x, 'tokenized_InChI': y, 'InChI': z}
        
        train_ds = train_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        
        return train_ds, valid_ds
    
    # Test Dataset
    elif image_set == 'test':

        # note: image resizing and batching done during this loading step
        # other elements must be batched before combining
        image_ds = tf.keras.preprocessing.image_dataset_from_directory(
            directory=parameters.test_images_dir(), labels='inferred', label_mode=None,
            class_names=None, color_mode='grayscale', batch_size=1, 
            image_size=target_size, shuffle=False, seed=None, validation_split=None, 
            subset=None, follow_links=False)

        # set filenames as label and batch
        image_id_ds = tf.data.Dataset.from_tensor_slices(image_ds.file_paths)
        image_id_ds = image_id_ds.map(lambda x: tf.strings.split(x, os.path.sep)[-1],
                                      num_parallel_calls=tf.data.AUTOTUNE)
        
        # prepare images for TF Records creations. 
        # Note: do this step AFTER filenames step
        if decode_images is False:  
            # convert image to raw byte string. Note: cannot have batch dim for encoding
            image_ds = image_ds.unbatch()
            image_ds = image_ds.map(lambda x: tf.cast(x, dtype=tf.uint16))
            image_ds = image_ds.map(lambda image: tf.io.encode_png(image))
            image_ds = image_ds.map(lambda image: tf.io.serialize_tensor(image))
            
        # dataset consisting solely of InChI start 'InChI=1S/'
        inchi_ds = image_id_ds.map(lambda x: tf.constant(SOS, dtype=tf.string),
                                   num_parallel_calls=tf.data.AUTOTUNE)
        
        # merge datasets
        test_ds = tf.data.Dataset.zip((image_ds, image_id_ds, inchi_ds, inchi_ds))
        
        # prefetch
        test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
        test_ds = test_ds.batch(batch_size)

        # Tokenize parsed_inchi.  Note: ds must be batched before this step (size=1 is ok) 
        test_ds = test_ds.map(tokenize_text, num_parallel_calls=tf.data.AUTOTUNE)

        # set key names
        def map_names(w, x, y, z):
            return  {'image': w, 'image_id': x, 'tokenized_InChI': y, 'InChI': z}
        
        test_ds = test_ds.map(map_names, num_parallel_calls=tf.data.AUTOTUNE)
        


        return test_ds

Create Test, Train and Validation Datasets
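
The generator above holds out the first 10% of the label rows for validation, trains on the rest, and shuffles each part separately. Here is a minimal sketch of that split in plain Python. `split_train_valid` and the sample rows are hypothetical stand-ins for illustration, not part of the notebook's pipeline:

```python
import random

def split_train_valid(rows, valid_split=0.10, seed=None):
    """Hold out the first `valid_split` fraction of rows, then shuffle each part."""
    num_valid = int(valid_split * len(rows))
    valid_rows = list(rows[:num_valid])   # first block -> validation
    train_rows = list(rows[num_valid:])   # remainder   -> training
    rng = random.Random(seed)
    rng.shuffle(train_rows)
    rng.shuffle(valid_rows)
    return train_rows, valid_rows

# Hypothetical (image_id, InChI) rows standing in for the labels CSV.
rows = [(f"id_{i:02d}", f"InChI=1S/example{i}") for i in range(20)]
train_rows, valid_rows = split_train_valid(rows, seed=0)
print(len(train_rows), len(valid_rows))
```

Because the split is positional and happens before shuffling, the same rows land in the validation set on every run as long as the CSV order is unchanged.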

In [None]:
if not PARAMETERS.tpu():
    train_ds, valid_ds = data_generator('train', parameters=PARAMETERS, labels_df=train_labels_df, decode_images=True)
    #test_ds = data_generator('test', parameters=PARAMETERS, labels_df=None, decode_images=True)

Examine data shapes

In [None]:
if not PARAMETERS.tpu():

    print('Train DS')
    for val in train_ds.take(1):    
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    print('\nValidation DS')
    for val in valid_ds.take(1):
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    try:
        print('\nTest DS')
        for val in test_ds.take(1):
            print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
                'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)
    except:
        pass

### TF Records Implementation
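
The record-writing code in this section carves each dataset into shards with `skip` and `take`. Here is a minimal sketch of that bookkeeping in plain Python, assuming equal-sized shards with any remainder dropped. `shard_bounds` is a hypothetical helper used only for illustration:

```python
def shard_bounds(num_samples, num_shards):
    """Return (skip, take) pairs for equal-sized shards; any remainder is dropped."""
    shard_size = num_samples // num_shards
    return [(shard_num * shard_size, shard_size) for shard_num in range(num_shards)]

bounds = shard_bounds(num_samples=10, num_shards=3)
print(bounds)  # [(0, 3), (3, 3), (6, 3)]
```

Stepping in whole multiples of the shard size keeps consecutive shards contiguous: each one starts exactly where the previous one ended.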

In [None]:
# Create TF Examples
def make_example(image, image_id, tokenized_InChI, InChI):
    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image.numpy()])  # image provided as raw bytestring
    )
    image_id_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_id.numpy()])
    )
    tokenized_InChI_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(tokenized_InChI).numpy()])
    )
    InChI_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[InChI.numpy()])
    )

    features = tf.train.Features(feature={
        'image': image_feature,
        'image_id': image_id_feature,
        'tokenized_InChI': tokenized_InChI_feature,
        'InChI': InChI_feature
    })
    
    example = tf.train.Example(features=features)

    return example.SerializeToString()


def make_example_py_fn(image, image_id, tokenized_InChI, InChI):
    # wraps make_example for use inside Dataset.map; argument order matches make_example
    return tf.py_function(func=make_example, 
                   inp=[image, image_id, tokenized_InChI, InChI], 
                   Tout=tf.string)


# Decode TF Examples
def decode_example(example, parameters=PARAMETERS):        
    feature_description = {'image': tf.io.FixedLenFeature([], tf.string),
                           'image_id': tf.io.FixedLenFeature([], tf.string),
                           'tokenized_InChI': tf.io.FixedLenFeature([], tf.string),
                           'InChI': tf.io.FixedLenFeature([], tf.string)}
    
    values = tf.io.parse_single_example(example, feature_description)
    
    
    values['image'] = decode_image(values['image'], parameters.image_size())
    values['tokenized_InChI'] = tf.io.parse_tensor(values['tokenized_InChI'],
                                                  out_type=tf.int64)
    values['tokenized_InChI'] = tf.cast(values['tokenized_InChI'], tf.int32)
    
    return values

In [None]:
def serialized_dataset_gen(parameters=PARAMETERS, set_type='train', labels_df=None):
    
    if set_type == 'train':
        train_ds, valid_ds = data_generator(image_set='train', 
                                            parameters=parameters, 
                                            labels_df=labels_df,  # use the argument, not the global
                                            decode_images=False)  # output images as bytestrings

        train_ds = train_ds.unbatch()
        valid_ds = valid_ds.unbatch()

        # Create TF Examples
        train_ds = train_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                                num_parallel_calls=tf.data.AUTOTUNE)
        valid_ds = valid_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                                num_parallel_calls=tf.data.AUTOTUNE)
        
        return train_ds, valid_ds
    
    else: #test_set:
        test_ds = data_generator(image_set='test', 
                                 parameters=parameters, 
                                 labels_df=None, 
                                 decode_images=False)  # output images as bytestrings
        
        test_ds = test_ds.unbatch()
            
        # Create TF Examples
        test_ds = test_ds.map(lambda x: make_example_py_fn(x['image'], x['image_id'], x['tokenized_InChI'], x['InChI']), 
                              num_parallel_calls=tf.data.AUTOTUNE)
        
        return test_ds

In [None]:
# Create TF Record Shards
"""
NOTE: Changes have been made to the other dataset pipeline functions. 
Test / Revise this for compatibility before running.
"""
def create_records(dataset, subset, num_shards):
    
    folder = subset + '_tfrec'
    
    if subset =='train':
        num_samples = int(.9 * len(train_labels_df))    # train / valid split
    elif subset == 'valid':
        num_samples = int(.1 * len(train_labels_df))
    else:
        num_samples = 2000000

    if not os.path.isdir(folder):
        os.mkdir(folder)
        
    shard_size = num_samples // num_shards  # samples per shard (any remainder is dropped)

    for shard_num in range(num_shards):
        
        filename = os.path.join(folder, f'{subset}_shard_{shard_num+1}')
        try:
            # step in whole shards so consecutive shards neither overlap nor leave gaps
            this_shard = dataset.skip(shard_num * shard_size).take(shard_size)
        
            print(f'Writing shard {shard_num+1}/{num_shards} to {filename}')
            writer = tf.data.experimental.TFRecordWriter(filename)
            writer.write(this_shard)
        except:
            break
    return None 
    
# Load dataset from saved TF Record Shards
def dataset_from_records(subset, parameters=PARAMETERS):

    # optimizations
    options = tf.data.Options()
    options.experimental_optimization.autotune_buffers = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.apply_default_optimizations = True

    filepath = os.path.join(parameters.tfrec_dir(), 
                            subset + '_tfrec/*')

    dataset = tf.data.Dataset.list_files(filepath)  # put all tf rec filenames in a ds
    dataset = dataset.shuffle(10**6)
 
    # merge the files
    num_readers = parameters.strategy().num_replicas_in_sync
    dataset = dataset.interleave(tf.data.TFRecordDataset,  
                                 cycle_length=num_readers, block_length=1,
                                 deterministic=False, num_parallel_calls=tf.data.AUTOTUNE)
    
    dataset = dataset.shuffle(10**6)
    
    # decode examples
    dataset = dataset.map(decode_example, num_parallel_calls=tf.data.AUTOTUNE)

    # note: tokenized InChI element spec needs help determining shape
    for val in dataset.take(1):
        padded_length = val['tokenized_InChI'].shape[-1]

    # coerce unknown shape
    dataset = dataset.map(lambda x: {'image':x['image'],
                                     'image_id': x['image_id'],
                                     'tokenized_InChI': tf.reshape(x['tokenized_InChI'], [padded_length]),
                                     'InChI': x['InChI']},
                          num_parallel_calls=tf.data.AUTOTUNE)  

    dataset = dataset.batch(parameters.batch_size(), drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
    return dataset

In [None]:
# To create TF_Records files
# note: can take 8+ hours for train set alone!
"""
serial_train_ds, serial_valid_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='train', labels_df=train_labels_df)
serial_test_ds = serialized_dataset_gen(parameters=PARAMETERS, set_type='test', labels_df=None)
"""

In [None]:
# IF USING TF_RECORDS:
# NOTE: If no dataset loads, check whether the Kaggle GCS directories have changed. (This happens periodically.)
if PARAMETERS.tpu():
    with PARAMETERS.strategy().scope(): 
        train_ds = dataset_from_records('train', parameters=PARAMETERS)
        valid_ds = dataset_from_records('valid', parameters=PARAMETERS)

    print('Train DS')
    for val in train_ds.take(1):    
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    print('\nValidation DS')
    for val in valid_ds.take(1):
        print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
              'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)

    try:
        print('\nTest DS')
        for val in test_ds.take(1):
            print('image:', val['image'].shape, 'image_id:', val['image_id'].shape, 
                'InChI:', val['InChI'].shape, 'tokenized_InChI:', val['tokenized_InChI'].shape)
    except:
        pass

Train DS
image: (1024, 320, 320, 1) image_id: (1024,) InChI: (1024,) tokenized_InChI: (1024, 200)

Validation DS
image: (1024, 320, 320, 1) image_id: (1024,) InChI: (1024,) tokenized_InChI: (1024, 200)

Test DS


# **Model Layers**

## InChI Encoding

Tokenizer and Embedding to convert parsed InChI strings to tensors of numbers
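
The prep layer defined below turns one tokenized sequence into an (input, target) pair: the target is the first `max_len` tokens, and the input is the same sequence shifted right by one slot, with a start element in front. Here is a minimal sketch of that alignment on plain token ids. The `START` placeholder and the sample ids are hypothetical. In the notebook the start slot is a trainable vector that is concatenated after embedding, not a token:

```python
START = "<start>"   # hypothetical placeholder for the trainable start vector

def split_input_target(token_ids, max_len):
    """Mirror the slicing in InchiPrep on a single (unbatched) sequence."""
    target = token_ids[:max_len]                  # what the model should predict
    shifted = [START] + token_ids[:max_len - 1]   # what the model is shown
    return shifted, target

tokens = [11, 12, 13, 14, 0, 0]   # hypothetical token ids, zero-padded
model_input, target = split_input_target(tokens, max_len=4)
for seen, wanted in zip(model_input, target):
    print(f"input {seen!s:>8} -> target {wanted}")
```

Both sequences have length `max_len`, and the input at each position holds the token that precedes the target at that position.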

InChI Input Prep Layer

In [None]:
def InchiPrep(batch_dim, padded_length, max_len, embedding_dim, vocab_size, name='InchiPrep'):
    """ initial InChI prep step. Handles separation into input / target pair,
    including trainable start variable and positional encoding.  """
        
    # embedding layer
    token_embedding_layer = keras.layers.Embedding(input_dim=vocab_size, 
                                                   output_dim=embedding_dim, 
                                                   mask_zero=False,
                                                   input_length=None,  # embedding will also be used on individual token predictions
                                                   name='TokenEmbeddingLayer')    
    
    # inputs
    tokenized_inchi = keras.layers.Input([padded_length], name='tokenized_inchi')
    positional_encoding = keras.layers.Input([max_len, embedding_dim], name='positional_encoding')
    start_var = keras.layers.Input([1, embedding_dim], name='start_var')
    
    inputs = [tokenized_inchi, positional_encoding, start_var]
    
    # split into input / target pairs
    inchi_target = tokenized_inchi[:, :max_len]
    inchi_input = tokenized_inchi[:, :max_len-1]
    
    # input embedding
    inchi_input = token_embedding_layer(inchi_input)

    # add start variable
    inchi_input = tf.cast(inchi_input, dtype=start_var.dtype)
    inchi_input = keras.layers.Concatenate(-2, name='concat_start_var')([start_var, inchi_input])

    # Add positional encoding
    positional_encoding = tf.cast(positional_encoding, dtype=inchi_input.dtype)
    inchi_input = keras.layers.Add(name='add_pos_encoding')([positional_encoding, inchi_input])

    # outputs
    outputs = [inchi_input, inchi_target]
    
    return keras.Model(inputs, outputs, name=name)

In [None]:
def InchiIndividalEmbedding(token_embedding_layer, name='InchiIndividalEmbedding'):
    """ used during character-by-character prediction generation loop """
        
    # embedding layer
    embedding_dim = token_embedding_layer.output_dim
    vocab_size = token_embedding_layer.input_dim
    
    # inputs
    tokenized_inchi = keras.layers.Input([1], name='tokenized_inchi')  #, dtype=tf.int32, name='tokenized_inchi')
    positional_encoding_step = keras.layers.Input([embedding_dim], name='positional_encoding_step')

    inputs = [tokenized_inchi, positional_encoding_step]
   
    # input embedding
    inchi_input = token_embedding_layer(tokenized_inchi)

    # Add positional encoding
    inchi_input = keras.layers.Add(name='add_pos_encoding')([positional_encoding_step, inchi_input])
    
    # outputs
    outputs = [inchi_input]
    
    return keras.Model(inputs, outputs, name=name)

In [None]:
temp_inchi_prep = InchiPrep(batch_dim=16, padded_length=TOKENIZER_LAYER.get_config()['output_sequence_length'],
        max_len=177, embedding_dim=512, vocab_size=200, name='InchiPrep')
temp_inchi_prep.summary()

temp_token_embedding_layer = temp_inchi_prep.get_layer('TokenEmbeddingLayer')

Model: "InchiPrep"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tokenized_inchi (InputLayer)    [(None, 200)]        0                                            
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, 176)          0           tokenized_inchi[0][0]            
__________________________________________________________________________________________________
TokenEmbeddingLayer (Embedding) (None, 176, 512)     102400      tf.__operators__.getitem_1[0][0] 
__________________________________________________________________________________________________
positional_encoding (InputLayer [(None, 177, 512)]   0                                            
__________________________________________________________________________________________

In [None]:
InchiIndividalEmbedding(token_embedding_layer=temp_token_embedding_layer, name='InchiIndividalEmbedding').summary()

Model: "InchiIndividalEmbedding"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tokenized_inchi (InputLayer)    [(None, 1)]          0                                            
__________________________________________________________________________________________________
positional_encoding_step (Input [(None, 512)]        0                                            
__________________________________________________________________________________________________
TokenEmbeddingLayer (Embedding) multiple             102400      tokenized_inchi[0][0]            
__________________________________________________________________________________________________
add_pos_encoding (Add)          (None, 1, 512)       0           positional_encoding_step[0][0]   
                                                                 TokenEmbedd

## Image Encoder

Feature Extraction Step 1: Run the images through a pre-trained image network, extracting features as the output of an intermediate convolutional layer. [Technique from "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention," cited at the top of this notebook.]  A dense layer is added for transfer learning and to control the dimension of the attention mechanism used later.
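
The intermediate layer outputs a grid of feature vectors (height x width x channels), which the encoder flattens into a list of height * width vectors for the attention layers. Here is a minimal sketch of that flattening in plain Python, with a tiny hypothetical grid standing in for the CNN output:

```python
def flatten_feature_map(feature_map):
    """Turn an H x W grid of C-dim vectors into a flat list of H*W vectors."""
    return [vector for row in feature_map for vector in row]

# Hypothetical 2 x 3 feature map with 4 channels per location.
height, width, channels = 2, 3, 4
feature_map = [[[float(r * width + c)] * channels for c in range(width)]
               for r in range(height)]

vectors = flatten_feature_map(feature_map)
print(len(vectors), len(vectors[0]))
```

As an illustration, and assuming the backbone downsamples by a factor of 32, a 320 x 320 input would give a 10 x 10 grid, i.e. 100 feature vectors per image.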

Transfer Model

In [None]:
def TransferModel(image_shape, name='TransferModel'):
    
    # Note: temporarily disable mixed precision during load. (Model doesn't handle it properly)
    tf.keras.mixed_precision.set_global_policy('float32')  # removed mixed precision

    base_transfer_model = keras.applications.EfficientNetB0(
                            include_top=False, weights='imagenet', input_tensor=None,
                            input_shape=(*image_shape[:2], 3), pooling=None, classes=1000,
                            classifier_activation='softmax')

    model = keras.Model(inputs=base_transfer_model.inputs, 
                        outputs=base_transfer_model.get_layer('block7a_project_bn').output, 
                        name=name)
    
    # revert to orig mixed precision policy
    tf.keras.mixed_precision.set_global_policy(PARAMETERS.mixed_precision()) 

    return model

In [None]:
#TransferModel((224,224,1)).summary()

In [None]:
def ImageEncoder(image_shape, output_dim, use_dense_top, name='ImageEncoder'):

    # transfer model
    transfer_model = TransferModel(image_shape, name='TransferModel')
    
    # Input
    image = keras.layers.Input(shape=image_shape, name='image')
    inputs = [image]

    # standardize and apply model
    # Note: 'keras.applications.EfficientNet' models expect color image values in [0, 255] with any shape
    if image_shape[2] == 1:  # convert to color
        image = keras.layers.Lambda(lambda x: tf.image.grayscale_to_rgb(x),
                                    name='grayscale_to_rgb')(image)
    image = keras.applications.efficientnet.preprocess_input(image)  # this is only a pass-through on efficientnet
    image = transfer_model(image)
    
    # extract feature vector
    features_dim = image.shape[3]
    num_vectors = image.shape[2] * image.shape[1]
    image_features = tf.keras.layers.Reshape(target_shape=[num_vectors, features_dim])(image)
    
    # Extra customization layer    
    if use_dense_top:    
        image_features = keras.layers.Dense(output_dim, activation='relu',
                                            name='dense')(image_features)

    outputs = [image_features]
    
    return keras.Model(inputs, outputs, name=name)

In [None]:
# ImageEncoder(image_shape=(320, 320,1), output_dim=256, use_dense_top=True).summary()

## Encoder Attention

Feature Extraction Step 2: Now that we have basic feature vectors, we use self-attention to generate more complex features. This is the encoding step used in "Attention is All You Need," cited above. 
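
At the core of each self-attention block is scaled dot-product attention: every feature vector is replaced by a weighted average of all the feature vectors, with weights from a softmax over dot products. Here is a minimal single-head sketch in plain Python, assuming identity query/key/value projections. The real layer learns these projections and uses several heads:

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(vectors):
    """Single-head scaled dot-product self-attention with identity projections."""
    dim = len(vectors[0])
    outputs, all_weights = [], []
    for query in vectors:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in vectors]
        weights = softmax(scores)
        mixed = [sum(w * value[d] for w, value in zip(weights, vectors))
                 for d in range(dim)]
        outputs.append(mixed)
        all_weights.append(weights)
    return outputs, all_weights

# Three hypothetical 2-dimensional image feature vectors.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(features)
print(len(outputs), len(outputs[0]))
```

Without positional encoding, reordering the input vectors simply reorders the outputs in the same way. The encoder, which treats the image feature vectors as an unordered set, relies on this property.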

In [None]:
def EncoderAttention(num_blocks, encoder_feature_dim, num_att_elems, name='EncoderAttention'):

    """ note: use num_blocks=6 to match "Attention is All You Need" """

    # inputs
    encoder_vectors = keras.layers.Input([num_att_elems, encoder_feature_dim], name='encoder_vectors')   # from image encoder
    inputs = [encoder_vectors]

    # attention (uses "Attention is All You Need" structure, 
    # except without positional encoding because image feature vectors are unordered)

    # shared params (follows 'num_heads * key_dim = units' from paper)
    num_heads = 8
    key_dim = encoder_feature_dim // 8
    
    for i in range(num_blocks):

        # self-attention block
        attention = tf.keras.layers.MultiHeadAttention(
                                            num_heads=num_heads, 
                                            key_dim=key_dim, 
                                            dropout=.1,
                                            name=f'encoder_attention_{i}')(  
                                query=encoder_vectors, value=encoder_vectors)
        
        attention = keras.layers.Add()([encoder_vectors, attention])
        attention = keras.layers.LayerNormalization(axis=[1,2], epsilon=1e-6)(attention)    

        # Feed Forward Block
        encoder_vectors = keras.layers.Dense(encoder_feature_dim * 4, 'relu',
                                                name=f'dense_encodeR_{i}')(attention)   
        encoder_vectors = keras.layers.Dense(encoder_feature_dim, activation=None,
                                                )(encoder_vectors) 
        encoder_vectors = keras.layers.Dropout(rate=.1)(encoder_vectors)
        encoder_vectors = keras.layers.Add()([attention, encoder_vectors])
        encoder_vectors = keras.layers.LayerNormalization(
                                            axis=[1,2], epsilon=1e-6)(encoder_vectors)     

    # output
    outputs = [encoder_vectors]

    return keras.Model(inputs, outputs, name=name)

In [None]:
EncoderAttention(num_blocks=1, encoder_feature_dim=512, num_att_elems=100).summary()

Model: "EncoderAttention"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_vectors (InputLayer)    [(None, 100, 512)]   0                                            
__________________________________________________________________________________________________
encoder_attention_0 (MultiHeadA (None, 100, 512)     1050624     encoder_vectors[0][0]            
                                                                 encoder_vectors[0][0]            
__________________________________________________________________________________________________
add (Add)                       (None, 100, 512)     0           encoder_vectors[0][0]            
                                                                 encoder_attention_0[0][0]        
___________________________________________________________________________________

## Decoder Attention

Text Feature extraction + Encoder/Decoder Joint Attention interaction.

With use_convolutions set to False, this is the decoder self-attention feature-extraction step from "Attention is All You Need," cited above (with learned positional encoding). 

Includes an (optional) parameter to add a small convolutional layer for feature enhancement before the attention layer. This is included for experimentation / verification that attention really is all you need.
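
The decoder's self-attention uses a look-ahead mask so that a position cannot draw on later positions. Here is a plain-Python sketch of the mask the code below builds with `1 - tf.linalg.band_part(tf.ones(...), 0, -1)`. Rows index queries and columns index keys, and in Keras `MultiHeadAttention` a 1 marks a pair that may attend. `look_ahead_mask` is a hypothetical helper for illustration:

```python
def look_ahead_mask(max_len):
    """Plain-Python equivalent of 1 - band_part(ones, 0, -1).

    band_part(..., 0, -1) keeps the diagonal and everything above it,
    so subtracting from 1 leaves ones strictly below the diagonal.
    """
    return [[1 if key < query else 0 for key in range(max_len)]
            for query in range(max_len)]

mask = look_ahead_mask(4)
for row in mask:
    print(row)
# [0, 0, 0, 0]
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
```

With this mask each position may attend only to strictly earlier positions. The first row is all zeros, so what the first position receives depends on how the attention layer handles a fully masked row.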


In [None]:
def DecoderAttention(num_blocks, encoder_units, decoder_units, num_encoder_vectors, max_len, 
                     use_convolutions, name='DecoderAttention'):
    
    """ note: "Attention is All You Need" uses num_blocks = 6 """

    # Inputs
    encoder_features = keras.layers.Input([num_encoder_vectors, encoder_units], 
                                          name='encoder_features')   # from image
    decoder_features = keras.layers.Input([max_len, decoder_units], 
                                          name='decoder_features')   # from known text
    
    inputs = [encoder_features, decoder_features]

    """
    # (Optional convolution feature extraction, for experimentation) 
    if use_convolutions:
        
        # crop to masked input
        step = tf.math.argmin(mask[0, :, 0])
        decoder_features = decoder_features[:, :step, :]

        # apply convolutions
        decoder_features = tf.keras.layers.Conv1D(filters=decoder_units, kernel_size=3, 
                    strides=1, padding='same', groups=1)(decoder_features)

        # pad back to full size for (masked) Attention input
        decoder_features = tf.pad(decoder_features, [[0,0], [0, max_len - step], [0,0]])
    """

    # shared params (follows 'num_heads * key_dim = units' from paper)
    num_heads = 8
    key_dim = decoder_units // 8

    # create look-ahead mask
    mask = 1 - tf.linalg.band_part(tf.ones((max_len, max_len)), 0, -1)
    
    for i in range(num_blocks):
      
        # decoder self-attention block
        decoder_attention = tf.keras.layers.MultiHeadAttention(
                                            num_heads=num_heads, 
                                            key_dim=key_dim, 
                                            dropout=.1,
                                            name=f'decoder_attention_{i}')(  
            query=decoder_features, value=decoder_features, attention_mask=mask)
        
        decoder_features = keras.layers.Add()([decoder_features, decoder_attention])
        decoder_features = keras.layers.LayerNormalization(
                                            axis=[1,2], epsilon=1e-6)(decoder_features)    

        # joint attention block
        joint_attention = tf.keras.layers.MultiHeadAttention(
                                            num_heads=num_heads, 
                                            key_dim=key_dim, 
                                            dropout=.1,
                                            name=f'joint_attention_{i}')(  
            query=decoder_features, value=encoder_features)
        
        joint_features = keras.layers.Add()([decoder_features, joint_attention])
        joint_features = keras.layers.LayerNormalization(
                                            axis=[1,2], epsilon=1e-6)(joint_features)    

        # Reshape -- (note: required for XLA / TPU support only) 
        # Has no effect other than allowing XLA to properly infer all shapes
        decoder_features = keras.layers.Reshape([max_len, decoder_units])(decoder_features)  
        joint_features = keras.layers.Reshape([max_len, decoder_units])(joint_features)  

        # Feed Forward Block
        decoder_features = keras.layers.Dense(decoder_units * 4, activation='relu',
                                                name=f'dense_decoder_{i}')(joint_features)   
        decoder_features = keras.layers.Dense(decoder_units, activation=None,
                                                )(decoder_features) 
        decoder_features = keras.layers.Dropout(rate=.1)(decoder_features)

        # Reshape -- (note: Required for XLA / TPU support only) 
        decoder_features = keras.layers.Reshape([max_len, decoder_units])(decoder_features)  
        joint_features = keras.layers.Reshape([max_len, decoder_units])(joint_features)  

        decoder_features = keras.layers.Add()([joint_features, decoder_features])
        decoder_features = keras.layers.LayerNormalization(
                                        axis=[1,2], epsilon=1e-6)(decoder_features)     


    outputs = [decoder_features]
    
    return keras.Model(inputs, outputs, name=name)

In [None]:
DecoderAttention(num_blocks=2, encoder_units = 208, decoder_units=512, num_encoder_vectors=64, 
                 max_len=200, use_convolutions=False, name='DecoderAttention').summary()

Model: "DecoderAttention"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
decoder_features (InputLayer)   [(None, 200, 512)]   0                                            
__________________________________________________________________________________________________
decoder_attention_0 (MultiHeadA (None, 200, 512)     1050624     decoder_features[0][0]           
                                                                 decoder_features[0][0]           
__________________________________________________________________________________________________
add_2 (Add)                     (None, 200, 512)     0           decoder_features[0][0]           
                                                                 decoder_attention_0[0][0]        
___________________________________________________________________________________

## Decoder Head (Prediction Output)

This is where we use what was learned in the encoder-decoder attention to output predicted labels. It is the prediction step from "Attention is All You Need."
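
The head produces one row of logits per sequence position, and a softmax turns each row into a probability distribution over the vocabulary. Here is a minimal sketch of that step in plain Python, followed by picking the most likely token at each position. The logits and `greedy_tokens` are hypothetical and for illustration only:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_tokens(logits_per_position):
    """Pick the highest-probability token id at every position."""
    picks = []
    for logits in logits_per_position:
        probs = softmax(logits)
        picks.append(max(range(len(probs)), key=probs.__getitem__))
    return picks

# Hypothetical logits: 3 sequence positions, vocabulary of 4 tokens.
logits = [[0.1, 2.0, -1.0, 0.3],
          [1.5, 0.2, 0.1, 0.0],
          [0.0, 0.0, 0.0, 3.0]]
print(greedy_tokens(logits))  # [1, 0, 3]
```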

In [None]:
def DecoderHead(decoder_units, vocab_size, max_len, use_dual_heads, 
                split_char_num, name='DecoderHead'):
    
    decoder_features = keras.layers.Input([max_len, decoder_units], 
                                          name='decoder_features')  # from Decoder Attention layer

    inputs = [decoder_features]

    # Prediction Block               
    # include activation dtype on final output layer for overriding mixed precision policies
    if not use_dual_heads:
        logits = keras.layers.Dense(vocab_size, activation=None, 
            kernel_initializer= tf.keras.initializers.HeNormal())(decoder_features)  
    
    else:
        decoder_features_0 = decoder_features[:, :split_char_num, :]
        decoder_features_1 = decoder_features[:, split_char_num: , :]
        
        logits_0 = keras.layers.Dense(vocab_size, activation=None, 
            kernel_initializer= tf.keras.initializers.HeNormal())(decoder_features_0)  

        logits_1 = keras.layers.Dense(vocab_size, activation=None, 
            kernel_initializer= tf.keras.initializers.HeNormal())(decoder_features_1)  

        logits = keras.layers.Concatenate(1)([logits_0, logits_1])


    probs = keras.layers.Activation('softmax', dtype=tf.float32, name='probs')(logits)  

    outputs = [probs]

    return keras.Model(inputs, outputs, name=name)

In [None]:
DecoderHead(decoder_units=512, vocab_size=199, max_len=200,
            use_dual_heads=True, split_char_num=50).summary()

Model: "DecoderHead"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
decoder_features (InputLayer)   [(None, 200, 512)]   0                                            
__________________________________________________________________________________________________
tf.__operators__.getitem_2 (Sli (None, 50, 512)      0           decoder_features[0][0]           
__________________________________________________________________________________________________
tf.__operators__.getitem_3 (Sli (None, 150, 512)     0           decoder_features[0][0]           
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 50, 199)      102087      tf.__operators__.getitem_2[0][0] 
________________________________________________________________________________________
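
The parameter count in the summary above can be checked by hand. This is a minimal sketch in plain Python; `dense_param_count` is a hypothetical helper for illustration and not part of the notebook. A Dense layer acts on the last axis only, so the sequence length (50 or 150 here) does not affect the count.

```python
def dense_param_count(input_units, output_units, use_bias=True):
    """Number of trainable parameters in a fully connected (Dense) layer."""
    kernel = input_units * output_units      # one weight per (input, output) pair
    bias = output_units if use_bias else 0   # one bias per output unit
    return kernel + bias


# Values from the DecoderHead call above: decoder_units=512, vocab_size=199
single_head = dense_param_count(512, 199)
print(single_head)       # 102087, matching the dense_3 row

# With use_dual_heads=True there are two independent Dense layers of this size
print(2 * single_head)   # 204174
```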

## Update Mechanism (Optional)

*Note: this is fully coded but I have not had time to train parameters with it. I leave that as a future opportunity for exploration.*

NLP techniques typically output logits to find the highest-likelihood token prediction. This can be improved to a (local) maximum likelihood selection using a "beam step" that may override the initial prediction choice. 

This layer is an alternative system for updating predictions. Unlike "beam," it is trainable and includes longer-range dependencies (instead of the very "local" beam step). The entire original prediction is passed through a bidirectional RNN (decoder feature extraction) followed by AIAYN-style attention blocks. No masking is needed since we have the full RNN output to work with.
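
To make the "beam step" idea concrete, here is a toy comparison of greedy selection and a small beam search. It is only a sketch: `TOY_PROBS`, `greedy_decode` and `beam_decode` are hypothetical names, and the probability table is made up for illustration. It is not the trainable update layer defined in this section.

```python
import math

# Hypothetical toy model: next-token probabilities given the previous token.
TOY_PROBS = {
    '<s>': {'A': 0.6, 'B': 0.4},
    'A': {'x': 0.5, 'y': 0.5},
    'B': {'x': 0.9, 'y': 0.1},
}


def greedy_decode(start, steps):
    """Pick the single most likely token at every step."""
    seq, logp = [start], 0.0
    for _ in range(steps):
        dist = TOY_PROBS[seq[-1]]
        token = max(dist, key=dist.get)
        logp += math.log(dist[token])
        seq.append(token)
    return seq[1:], logp


def beam_decode(start, steps, beam_width):
    """Keep the `beam_width` most likely partial sequences at every step."""
    beams = [([start], 0.0)]
    for _ in range(steps):
        candidates = []
        for seq, logp in beams:
            for token, prob in TOY_PROBS[seq[-1]].items():
                candidates.append((seq + [token], logp + math.log(prob)))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    best_seq, best_logp = beams[0]
    return best_seq[1:], best_logp


greedy_seq, greedy_logp = greedy_decode('<s>', steps=2)
beam_seq, beam_logp = beam_decode('<s>', steps=2, beam_width=2)
print(greedy_seq, round(math.exp(greedy_logp), 2))  # ['A', 'x'] 0.3
print(beam_seq, round(math.exp(beam_logp), 2))      # ['B', 'x'] 0.36
```

Greedy selection commits to `A` at the first step and cannot recover, while the beam keeps `B` alive long enough to find the more likely full sequence.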

In [39]:
def BeamUpdate(num_beam_blocks, num_att_blocks,
               num_encoder_vectors, encoder_units, 
               decoder_units, max_len, vocab_size, name='BeamUpdate'):
    
    # update to required GRU model dtypes
    tf.keras.mixed_precision.set_global_policy('float32')
    
    # layers
    # note: GRU doesn't appear to be compatible with reduced precision
    BeamUnit = keras.layers.GRU(decoder_units, return_sequences=True, 
                    return_state=True, go_backwards=True,
                    dtype=tf.keras.mixed_precision.Policy('float32'))  

    use_convolutions = False
    BeamDecoderAttention = DecoderAttention(num_att_blocks, encoder_units, decoder_units, 
                                            num_encoder_vectors, max_len, 
                                            use_convolutions, name='BeamDecoderAttention')

    BeamDecoderHead = DecoderHead(decoder_units, vocab_size, max_len, 
                                  use_dual_heads=False, split_char_num=None,
                                  name='BeamDecoderHead')

    
    # Inputs
    beam_input = keras.layers.Input([max_len, vocab_size], name='beam_input') 
    hidden_state = keras.layers.Input([decoder_units], name='hidden_state')
    encoder_features = keras.layers.Input([num_encoder_vectors, encoder_units], name='encoder_features')   # from image 
    mask = keras.layers.Input([max_len, max_len], name='mask')   # should pass in all 1's, i.e. no masking
    
    inputs = [beam_input, hidden_state, encoder_features, mask]

    # create initial hidden state for RNN dim = decoder_units
    beam_hidden_state = tf.reduce_mean(encoder_features, -2)
    beam_hidden_state = keras.layers.Dense(decoder_units, activation='relu')(beam_hidden_state)
    
    # Decoder encoding using 1 or more Beam layers
    for i in range(num_beam_blocks):
        beam_out, beam_hidden_state = \
            BeamUnit(beam_input, initial_state=[beam_hidden_state])

    # Attention & Prediction (uses "Attention is All You Need" structure)
    decoder_features = BeamDecoderAttention([encoder_features, beam_out, mask])

    probs = BeamDecoderHead([decoder_features])  

    outputs = [probs]

    return keras.Model(inputs, outputs, name=name)

In [40]:
BeamUpdate(num_beam_blocks=1, num_att_blocks=1, num_encoder_vectors=50, encoder_units=256, 
            decoder_units=512, max_len=200, vocab_size=160, name='BeamUpdate').summary()

Model: "BeamUpdate"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_features (InputLayer)   [(None, 50, 256)]    0                                            
__________________________________________________________________________________________________
tf.math.reduce_mean (TFOpLambda (None, 256)          0           encoder_features[0][0]           
__________________________________________________________________________________________________
beam_input (InputLayer)         [(None, 200, 160)]   0                                            
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 512)          131584      tf.math.reduce_mean[0][0]        
_________________________________________________________________________________________

# **Full Model**

All the components are combined into a full encoder/decoder model. This is implemented using the subclassing API with custom call, train, evaluation and prediction steps. Once initialized, the models have full access to the high-level model.fit(), model.compile() and model.save_weights() methods.

An extra feature implemented is having Decoder() elements in *series* (not stacked). This adds more trainable parameters without affecting inference speed, and allows decoders to specialize more on different regions of the text.

BaseTrainer() model has the BeamUpdate mechanism disabled. InchiGenerator() models include the BeamUpdate.

In [41]:
class BaseTrainer(keras.Model):
    
    def __init__(self, image_dense_output_dim, decoder_units, 
                 num_encoder_blocks, num_decoder_blocks, 
                 use_dense_encoder_top, use_convolutions, use_dual_decoders, 
                 max_len, parameters, name='BaseTrainer', **kwargs):
        super().__init__(name=name, **kwargs)

        """ Beam updates turned off. Training conducted with teacher-fed inputs.
        note: dataset provided as (image, image_id, tokenized_InChI, InChI) """
        
        # tokenizer info
        
        tokenizer_layer, self.inverse_tokenizer, self.tokenized_EOS = \
                Tokenizer(parameters=parameters)
        self.vocab_size = len(tokenizer_layer.get_vocabulary())
        
        # save params
        self.max_len = max_len  # max number of token prediction length
        self.image_dense_output_dim = image_dense_output_dim
        self.decoder_units = decoder_units
        self.num_encoder_blocks = num_encoder_blocks
        self.num_decoder_blocks = num_decoder_blocks
        self.use_dense_encoder_top = tf.constant(use_dense_encoder_top, tf.bool)
        self.use_convolutions = tf.constant(use_convolutions, tf.bool)
        self.use_dual_decoders = tf.constant(use_dual_decoders, tf.bool)
        self.parameters = parameters
        self.SOS = parameters.SOS()
        self.EOS = parameters.EOS()
        self.embedding_dim = self.decoder_units  # required for consistency

        #  Initialize positional encoding. (Batch dim added during build)
        initializer = tf.random_normal_initializer()(shape=[self.max_len, self.embedding_dim])
        self.positional_encoding = tf.Variable(initializer, trainable=True,
                                               dtype=tf.float32,
                                               name='positional_encoding', )
        self.positional_encoding = tf.expand_dims(self.positional_encoding, axis=0)
    
        # trainable start variable  (Batch dim added during build)
        initializer = tf.random_normal_initializer()(shape=[1, self.embedding_dim])
        self.start_var = tf.Variable(initializer, trainable=True,
                                     dtype=tf.float32,
                                     name='start_variable')
        self.start_var = tf.expand_dims(self.start_var, axis=0)

        # Layers (note: additional layers initialized within '.build()')
        split_char_num = 50  # controls dual decoder head switch (if applicable)
        self.decoder_head = DecoderHead(decoder_units=self.decoder_units, 
                                        vocab_size=self.vocab_size, 
                                        max_len=self.max_len, 
                                        use_dual_heads=self.use_dual_decoders,
                                        split_char_num=split_char_num,
                                        name='DecoderHead')

    def get_config(self):
        config = {'image_dense_output_dim': self.image_dense_output_dim,
                  'decoder_units': self.decoder_units,
                  'num_encoder_blocks': self.num_encoder_blocks,
                  'num_decoder_blocks': self.num_decoder_blocks,
                  'use_dense_encoder_top': self.use_dense_encoder_top,
                  'use_convolutions': self.use_convolutions,
                  'use_dual_decoders': self.use_dual_decoders,
                  'max_len': self.max_len,
                  'parameters': self.parameters}
        return config 

    def build(self, input_shape):
        # note: dataset prepared with dict keys (image, tokenized_InChI, image_id, InChI)
        self.batch_size = input_shape[0][0]  # 'image' batch size
        self.padded_length = input_shape[1][-1]  # 'tokenized_InChI' token sequ. length

        # InChI Encoders
        self.inchi_prep = InchiPrep(batch_dim=self.batch_size,
                                    padded_length=self.padded_length,
                                    max_len=self.max_len, 
                                    embedding_dim=self.embedding_dim, 
                                    vocab_size=self.vocab_size,
                                    name='InchiPrep')
        
        # Token Embedding Layers
        self.token_embedding_layer = self.inchi_prep.get_layer('TokenEmbeddingLayer')               
        self.inchi_indiv_embedding = InchiIndividalEmbedding(
            token_embedding_layer=self.token_embedding_layer, name='InchiIndividalEmbedding')
        
        # Image Encoders
        self.image_shape = input_shape[0][1:]  # drops batch dims from 'image' shape   
        self.image_encoder = ImageEncoder(image_shape=self.image_shape, 
                                          output_dim=self.image_dense_output_dim,
                                          use_dense_top=self.use_dense_encoder_top,
                                          name='ImageEncoder')   
        # collect params
        self.num_encoder_vectors = self.image_encoder.output_shape[1]
        self.image_features_dim = self.image_encoder.output_shape[2]
        
        # image encoding
        self.encoder_attention = EncoderAttention(num_blocks=self.num_encoder_blocks, 
                                                  encoder_feature_dim=self.image_features_dim, 
                                                  num_att_elems=self.num_encoder_vectors,
                                                  name='EncoderAttention')

        # Decoders
        self.decoder_attention = DecoderAttention(num_blocks=self.num_decoder_blocks,
                                                  encoder_units=self.image_features_dim,
                                                  decoder_units=self.decoder_units, 
                                                  num_encoder_vectors=self.num_encoder_vectors,
                                                  max_len=self.max_len,
                                                  use_convolutions=self.use_convolutions,
                                                  name='DecoderAttention') 
        
        # positional encoding and start variables: add batch dim
        self.positional_encoding = tf.tile(self.positional_encoding, [self.batch_size, 1, 1])
        self.start_var = tf.tile(self.start_var, [self.batch_size, 1, 1])

        # save Transfer Model (for controlling trainability)
        self.transfer_model = self.image_encoder.get_layer('TransferModel')

    def encoding_step(self, image, tokenized_InChI, training):
        
        # get inputs / target pairs
        inchi_vectors, targets = self.inchi_prep([tokenized_InChI, self.positional_encoding, self.start_var]) 

        # for inference: zero out all inchi vectors except the initial value
        if not training:  
            zeros = tf.zeros_like(inchi_vectors[:, 1:, :], dtype=inchi_vectors.dtype)
            inchi_vectors = tf.concat([inchi_vectors[:, 0:1, :], zeros], axis=1)

        # image encoding
        image = self.image_encoder(image)
        image_encoder_vectors = self.encoder_attention(image)

        return image_encoder_vectors, inchi_vectors, targets
    
    @tf.function
    def call(self, inputs, training=False):

        # inputs
        image = inputs[0]
        tokenized_InChI = inputs[1]
        
        # Encoder
        encoder_vectors, inchi_vectors, targets = \
            self.encoding_step(image, tokenized_InChI, training)

        # Decoder (char generation loop using generated predictions)
        # inference loop
        if not training:  # inference loop
            predicted_probs = self.generation_loop(encoder_vectors=encoder_vectors, 
                                                   inchi_vectors=inchi_vectors)
        
        # teacher-fed training
        else:  
            decoder_attention = self.decoder_attention([encoder_vectors, inchi_vectors])
            predicted_probs = self.decoder_head([decoder_attention])
        
        return predicted_probs
     
    @tf.function(jit_compile=False)
    def tokens_to_string(self, tokens):

        parsed_string_vals = self.inverse_tokenizer(tokens)
        string_vals = keras.layers.Lambda(
            lambda x: tf.strings.reduce_join(x, axis=-1))(parsed_string_vals)

        # remove first EOS generated and everything after
        pattern = ''.join([self.EOS, '.*$'])
        string_vals = tf.strings.regex_replace(string_vals, pattern, rewrite='', 
                                               replace_global=True, name='remove_EOS')   

        return string_vals

    # Full Generation Loop
    def generation_loop(self, encoder_vectors, inchi_vectors):

        # containers
        # note: use fixed size arrays for XLA / JIT
        probs_array = tf.TensorArray(size=self.max_len, 
                                     dtype=tf.float32, 
                                     dynamic_size=False, 
                                     element_shape=[None, self.vocab_size],
                                     tensor_array_name='probs_array')
        
        # create initial (embedded) predictions
        embedded_preds_array = tf.TensorArray(size=self.max_len, 
                                              dtype=tf.float32, 
                                              clear_after_read = False, 
                                              dynamic_size=False, 
                                              element_shape=[None, self.embedding_dim],
                                              tensor_array_name='embedded_preds_array')
        
        # pre-populate embedded predictions array with InChI values
        # (These are just start value and zeros during inference-- 
        # make sure this happened during the InChI encoding step)
        # (note: structure below would allow for teacher-training an RNN with attention)
        for step in range(1, self.max_len):
            embedded_preds_array = embedded_preds_array.write(
                            step, tf.cast(inchi_vectors[:, step, :], tf.float32))
       
        # character generation loop (tf.while_loop)
        def loop_fn(step, continue_cond, inchi_vectors, probs_array, embedded_preds_array):

            # attention update
            decoder_attention = self.decoder_attention([encoder_vectors, inchi_vectors])
            
            # get current step's probabilities and save results
            probs = self.decoder_head([decoder_attention])
            probs = probs[:, step, :]
            probs_array = probs_array.write(step, tf.cast(probs, dtype=probs_array.dtype))

            # make prediction
            predictions = tf.argmax(probs, axis=-1)  # shape: (batch_size)
            positional_encoding_step = self.positional_encoding[:, step, :]

            # generated next iteration's input
            embedded_preds = self.inchi_indiv_embedding(
                                [predictions, positional_encoding_step])

            # increment counter and check stopping conditions
            step = tf.math.add(step, 1)

            # check: final step reached
            continue_cond = tf.math.less(step, self.max_len)
           
            # check: all batch elements reached EOS
            def true_fn_2(predictions):
                predictions = tf.expand_dims(predictions, axis=1)
                continue_cond = tf.math.reduce_any(predictions != self.tokenized_EOS)
                predictions = tf.squeeze(predictions, axis=1)
                return continue_cond

            continue_cond = tf.cond(tf.math.equal(continue_cond, True), 
                                    lambda: true_fn_2(predictions),
                                    lambda: continue_cond)
            
            # save input embedding
            def true_fn_3(embedded_preds_array):
                embedded_preds_array = embedded_preds_array.write(
                    step, tf.cast(tf.squeeze(embedded_preds), dtype=embedded_preds_array.dtype))  # match array dtype
                return embedded_preds_array

            embedded_preds_array = tf.cond(tf.math.equal(continue_cond, True), 
                                           lambda: true_fn_3(embedded_preds_array),
                                           lambda: embedded_preds_array)

            # prepare padded input for next iteration
            def true_fn_4(embedded_preds_array, orig_dtype):
                inchi_vectors = embedded_preds_array.stack()
                inchi_vectors = tf.transpose(inchi_vectors, perm=[1, 0, 2])  
                
                # cast back to original dtype
                inchi_vectors = tf.cast(inchi_vectors, dtype=orig_dtype)  
                return inchi_vectors

            orig_dtype = inchi_vectors.dtype  # (careful casting needed for mixed precision)
            inchi_vectors = tf.cond(tf.math.equal(continue_cond, True), 
                                    lambda: true_fn_4(embedded_preds_array, orig_dtype),
                                    lambda: inchi_vectors)

            return [step, continue_cond, inchi_vectors, probs_array, embedded_preds_array]

        # stopping condition function
        def cond_fn(step, continue_cond, inchi_vectors, probs_array, embedded_preds_array):
            return continue_cond

        # generation loop
        # initial loop values
        step = 0
        continue_cond = tf.constant(True, dtype=tf.bool)

        step, continue_cond, inchi_vectors, probs_array, embedded_preds_array \
            = tf.while_loop(
                    maximum_iterations=self.max_len-1,  
                    cond=cond_fn, 
                    body=loop_fn, 
                    loop_vars=[step, continue_cond, inchi_vectors,  
                               probs_array, embedded_preds_array],               
                    shape_invariants=
                        [tf.TensorShape([]), # step
                         tf.TensorShape([]), # continue_cond
                         tf.TensorShape([None, self.max_len, 
                                         self.embedding_dim]), # inchi_vectors
                         None,  # probs_array
                         None],  # embedded_preds_array
                )
        
        # unpack probs_array
        predicted_probs = probs_array.stack()  # predicted probabilities (softmax outputs)
        predicted_probs = tf.squeeze(predicted_probs)
        predicted_probs = tf.transpose(predicted_probs, perm=[1, 0, 2])  

        return predicted_probs
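
The control flow of `generation_loop` and the EOS clean-up in `tokens_to_string` can be sketched without TensorFlow. This is a simplified illustration under stated assumptions: `greedy_generate`, `strip_after_eos`, `replay` and `SCRIPT` are hypothetical names, token ids stand in for embedded vectors, and the stand-in "model" just replays fixed sequences. Like the loop above, it writes into fixed-size buffers and stops early only when every batch element predicts EOS at the current step.

```python
def greedy_generate(step_fn, batch_size, max_len, eos_id, pad_id=0):
    """Greedy decoding into fixed-size buffers.

    step_fn(tokens, step) is a hypothetical stand-in for the decoder: given the
    tokens generated so far, it returns one predicted token id per batch element.
    """
    tokens = [[pad_id] * max_len for _ in range(batch_size)]
    for step in range(max_len):
        predictions = step_fn(tokens, step)
        for row, token in zip(tokens, predictions):
            row[step] = token
        # stop once every batch element predicts EOS at the current step
        if all(token == eos_id for token in predictions):
            break
    return tokens


def strip_after_eos(row, eos_id):
    """Drop the first EOS and everything after it."""
    return row[:row.index(eos_id)] if eos_id in row else row


# toy "model" that replays fixed sequences (EOS id is 9)
SCRIPT = [[5, 6, 9, 9, 9, 9], [5, 9, 9, 9, 9, 9]]


def replay(tokens, step):
    return [seq[step] for seq in SCRIPT]


generated = greedy_generate(replay, batch_size=2, max_len=6, eos_id=9)
print(generated)                                      # [[5, 6, 9, 0, 0, 0], [5, 9, 9, 0, 0, 0]]
print([strip_after_eos(row, 9) for row in generated])  # [[5, 6], [5]]
```

The second sequence reaches EOS one step earlier than the first, but the loop keeps running until both predict EOS at the same step. The trailing tokens are then discarded by the clean-up step.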

In [42]:
class InchiGenerator(BaseTrainer):
    """
    NOTE: Updates have been made to the main BaseTrainer model class since 
    InChiGenerator was last revised.  Test / Revise this class before using.

    Beam updates turned on, training conducted using generated preds.
    """

    def __init__(self, base_model, num_beam_att_blocks, name='BeamInchiTrainer', **kwargs):
        
        super().__init__(image_dense_output_dim=base_model.image_dense_output_dim,
                         use_dense_encoder_top=base_model.use_dense_encoder_top,
                         decoder_units=base_model.decoder_units,
                         num_decoder_blocks=base_model.num_decoder_blocks,
                         num_encoder_blocks=base_model.num_encoder_blocks,
                         use_convolutions=base_model.use_convolutions,
                         use_dual_decoders= base_model.use_dual_decoders,
                         max_len = base_model.max_len,
                         parameters=base_model.parameters, 
                         name=name, **kwargs)
        
        self.num_beam_att_blocks = num_beam_att_blocks
            
    def get_config(self):
        config = {'num_beam_att_blocks': self.num_beam_att_blocks,
                  'image_dense_output_dim': self.image_dense_output_dim,
                  'decoder_units': self.decoder_units,
                  'num_decoder_blocks': self.num_decoder_blocks,
                  'num_encoder_blocks': self.num_encoder_blocks,
                  'use_dense_encoder_top': self.use_dense_encoder_top,
                  'use_convolutions': self.use_convolutions,
                  'max_len': self.max_len,
                  'use_dual_decoders': self.use_dual_decoders,
                  'parameters': self.parameters}
        return config 

    def build(self, input_shape):

        super().build(input_shape)

        # image_encoder is an instance attribute created in BaseTrainer.build()
        self.num_encoder_vectors = self.image_encoder.output_shape[1]
        self.image_features_dim = self.image_encoder.output_shape[2]

        self.beam = BeamUpdate(num_beam_blocks=self.num_beam_att_blocks, 
                               num_att_blocks=self.num_beam_att_blocks,
                               num_encoder_vectors=self.num_encoder_vectors,
                               encoder_units=self.decoder_units, 
                               decoder_units=self.decoder_units, 
                               max_len=self.max_len, 
                               vocab_size=self.vocab_size, 
                               name='BeamUpdate')
    
    def call(self, inputs, training=False):

        #### first portion matches base model path:  #############
        # inputs
        image = inputs[0]
        tokenized_InChI = inputs[1]
        
        # Encoder
        encoder_vectors, inchi_vectors, targets = \
            self.encoding_step(image, tokenized_InChI, training)

        # Decoder (char generation loop using generated predictions)
        # inference loop
        if not training:  # inference loop
            predicted_probs = self.generation_loop(encoder_vectors=encoder_vectors, 
                                                   inchi_vectors=inchi_vectors)
        
         # training steps (before gradient calculations)
        else:  
            decoder_attention = self.decoder_attention([encoder_vectors, inchi_vectors])
            predicted_probs = self.decoder_head([decoder_attention])
       
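        #### beam update #############
        # create initial RNN state
        hidden_state = tf.math.reduce_mean(encoder_vectors, axis=1)

        # all-ones mask (i.e. no masking), as expected by the BeamUpdate 'mask' input
        batch_size = tf.shape(predicted_probs)[0]
        mask = tf.ones([batch_size, self.max_len, self.max_len])

        # get probs and predictions
        predicted_probs = self.beam([predicted_probs, hidden_state, encoder_vectors, mask])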

        return predicted_probs

In [43]:
class EditDistanceMetric(tf.keras.metrics.Metric):
    def __init__(self, name='edit_distance', **kwargs):
        super().__init__(name=name, **kwargs)
        self.edit_distance = self.add_weight(name='edit_distance', initializer='zeros')
        self.batch_counter = self.add_weight(name='batch_counter', initializer='zeros')
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.sparse.from_dense(y_true)
        y_pred = tf.sparse.from_dense(tf.argmax(y_pred, axis=-1))  # convert probs to preds

        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        # compute edit distance (of parsed tokens)
        edit_distance = tf.edit_distance(y_pred, y_true, normalize=False)
        self.edit_distance.assign_add(tf.reduce_mean(edit_distance))

        # update counter
        self.batch_counter.assign_add(tf.reduce_sum(1.))
    
    def result(self):
        return self.edit_distance / self.batch_counter

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.edit_distance.assign(0.0)
        self.batch_counter.assign(0.0)
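
For readers unfamiliar with the quantity this metric tracks, the unnormalized edit (Levenshtein) distance can be computed in plain Python as below. This is a reference sketch only; `levenshtein` is a hypothetical helper and is not used by the notebook, which relies on `tf.edit_distance` over token ids.

```python
def levenshtein(a, b):
    """Unnormalized edit distance between two sequences (strings or token lists)."""
    previous = list(range(len(b) + 1))  # distances from the empty prefix of `a`
    for i, item_a in enumerate(a, start=1):
        current = [i]
        for j, item_b in enumerate(b, start=1):
            substitution = previous[j - 1] + (item_a != item_b)
            insertion = current[j - 1] + 1
            deletion = previous[j] + 1
            current.append(min(substitution, insertion, deletion))
        previous = current
    return previous[-1]


print(levenshtein('kitten', 'sitting'))   # 3
print(levenshtein([1, 2, 3], [1, 3]))     # 1
```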

In [44]:
# Learning rate schedule used in "Attention is All You Need"
class LRScheduleAIAYN(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, scale_factor=1, warmup_steps=4000):  # defaults reflect paper's values
        # cast dtypes
        self.warmup_steps = tf.constant(warmup_steps, dtype=tf.float32)
        # note: the paper scales by d_model ** -0.5; this schedule uses dim ** -1.5 with dim = 352
        dim = tf.constant(352, dtype=tf.float32)
        scale_factor = tf.constant(scale_factor, dtype=tf.float32)
        
        self.scale = scale_factor * tf.math.pow(dim, -1.5)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        opt_1 = step * tf.math.pow(self.warmup_steps, -1.5)  # linear increase
        opt_2 = tf.math.pow(step, -.5) # decay
        return self.scale * tf.math.reduce_min([opt_1, opt_2])

# visualize learning rate 
#temp_lr = LRScheduleAIAYN()
#plt.plot([i for i in range(1, 8000)], [temp_lr(i) for i in range(1, 8000)])
#print('Learning Rate Schedule')
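
For comparison, the schedule as written in "Attention is All You Need" is `d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)`. The plain-Python sketch below evaluates that formula directly; `aiayn_learning_rate` is a hypothetical helper, and the defaults are the paper's base-model values rather than anything tuned for this notebook.

```python
def aiayn_learning_rate(step, d_model=512, warmup_steps=4000):
    """Learning rate formula from the paper: linear warmup, then inverse-sqrt decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


peak = aiayn_learning_rate(4000)  # the two branches meet at step == warmup_steps
for step in (1, 1000, 4000, 16000):
    print(step, aiayn_learning_rate(step))
```

The rate rises linearly until `warmup_steps`, peaks there, and then decays with the inverse square root of the step number.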

## Build Model

Model Parameters

In [45]:
NAME_MODIFIER = ''

# build model
IMAGE_DENSE_OUTPUT_DIM = 256  # note: only used with USE_DENSE_ENCODER_TOP = True.
DECODER_UNITS = 512  # "Attention is All You Need" uses 512 units
BEAM_RNN_UNITS = 128  # note: only used in beam_model.
NUM_ENCODER_BLOCKS = 6  # "Attention is All You Need" uses 6 blocks
NUM_DECODER_BLOCKS = 6  # "Attention is All You Need" uses 6 blocks
MAX_LEN = 200
USE_DENSE_ENCODER_TOP = False  # for fine tuning image features pre-self-attention
USE_DUAL_DECODERS = False
USE_CONVOLUTIONS = False
if USE_CONVOLUTIONS:
    checkpoint_save_name = 'ConvAtt_model_checkpoints' + NAME_MODIFIER
else:
    checkpoint_save_name = 'AISAYN_model_checkpoints' + NAME_MODIFIER

LOAD_CHECKPOINT_FILE = os.path.join(PARAMETERS.load_checkpoint_dir(), checkpoint_save_name, checkpoint_save_name)
SAVE_CHECKPOINT_FILE = os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, checkpoint_save_name)

# note: in Kaggle,
# LOAD_CHECKPOINT_FILE points to saved outputs from prev session
# SAVE_CHECKPOINT_FILE points to saved outputs from current session

Initialize model

In [46]:
# NOTE: If nothing happens on model build / call, check if Kaggle GCS directories have changed. (This happens periodically)

# Update inputs: remove string keys, as they are not compatible with TPU
train_ds_int_index = train_ds.map(lambda x: (x['image'], x['tokenized_InChI'][:, :MAX_LEN], 
                                             x['image_id'], x['InChI'])).prefetch(tf.data.AUTOTUNE)

valid_ds_int_index = valid_ds.map(lambda x: (x['image'], x['tokenized_InChI'][:, :MAX_LEN], 
                                             x['image_id'], x['InChI'])).prefetch(tf.data.AUTOTUNE)

# create model using distribution strategy
with PARAMETERS.strategy().scope():
    base_model = BaseTrainer(image_dense_output_dim=IMAGE_DENSE_OUTPUT_DIM,
                                decoder_units=DECODER_UNITS,
                                num_encoder_blocks=NUM_ENCODER_BLOCKS,
                                num_decoder_blocks=NUM_DECODER_BLOCKS,
                                use_dense_encoder_top=USE_DENSE_ENCODER_TOP,
                                use_convolutions=USE_CONVOLUTIONS,
                                use_dual_decoders=USE_DUAL_DECODERS,
                                max_len=MAX_LEN,  
                                parameters=PARAMETERS, 
                                name='BaseTrainer')

    # build (Note: make sure to do this before compiling!)
    # NOTE: On distributed training, the batch size is distributed among replicas, 
    # so a smaller batch size must be used on build /first call
    if PARAMETERS.tpu():  
    
        temp_ds = PARAMETERS.strategy().experimental_distribute_dataset(train_ds_int_index)
        temp_ds = iter(temp_ds)
        val = next(temp_ds)

        # build with new val
        temp_func = tf.function(func=base_model, experimental_relax_shapes=True,
                                experimental_follow_type_hints=True)
        PARAMETERS.strategy().run(temp_func, args=[(val[0], val[1])])  # use strategy.run() on TPU
            

    else:  
        # build with original val
        for val in train_ds_int_index.take(1): 
            base_model(val, training=False)
            #base_model(val, training=True)
    
    # compiler components
    learning_rate = LRScheduleAIAYN(scale_factor=PARAMETERS.strategy().num_replicas_in_sync, 
                                    warmup_steps=4000)  # from "Attention is All You Need"       

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,  # params from "Attention is All You Need"
                                         beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # callbacks
    checkpoint = tf.keras.callbacks.ModelCheckpoint(SAVE_CHECKPOINT_FILE, monitor='loss', 
                            save_weights_only=True, save_best_only=True, save_freq='epoch',
                            options=tf.train.CheckpointOptions(experimental_io_device='/job:localhost'))
    nan_stop = tf.keras.callbacks.TerminateOnNaN()
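    # note: a '/' component makes os.path.join discard the earlier parts, so use '' for the trailing separator
    backup_checkpoint = tf.keras.callbacks.experimental.BackupAndRestore(
        os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, ''))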
    
    # metrics
    edit_dist_metric = EditDistanceMetric()
    
    # loss with label smoothing
    loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=.1)

    # optimizations
    if not PARAMETERS.tpu():
        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'  # better balances CPU / GPU interaction in tf.data
        tf.config.optimizer.set_jit("autoclustering")  # XLA compiler optimization
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)  # required with mixed precision on GPU / CPU

        # compile       
        base_model.compile(optimizer=optimizer, 
                           loss=loss_fn,
                           metrics=['categorical_accuracy', edit_dist_metric],
                           steps_per_execution=8)
    else:
        # compile (without EditDistance metric - not functioning on TPU)
        base_model.compile(optimizer=optimizer, 
                           loss=loss_fn,
                           metrics=['categorical_accuracy'],
                           steps_per_execution=4*PARAMETERS.strategy().num_replicas_in_sync)

    # show summary
    base_model.summary()  # summary() prints directly and returns None
    print('Models initialized.')

#"""
# verify model calls & methods work
if not PARAMETERS.tpu():
    base_model(val, training=False)

    # sync weights
    # WARNING!: in Kaggle this loads from prev session saved weights
    try:
        with PARAMETERS.strategy().scope(): 
            base_model.load_weights(LOAD_CHECKPOINT_FILE)  
    except Exception as e:  # e.g. no checkpoint saved by a previous session
        print('No weights loaded:', e)  

else:
    # sync weights
    # WARNING!: in Kaggle this loads from prev session saved weights
    try:
        with PARAMETERS.strategy().scope(): 
            base_model.load_weights(LOAD_CHECKPOINT_FILE, 
                                    options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost"))  

    except Exception as e:  # e.g. no checkpoint saved by a previous session
        print('No weights loaded:', e)    
#"""


Model: "BaseTrainer"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
string_lookup_3 (StringLooku multiple                  0 (unused)
_________________________________________________________________
DecoderHead (Functional)     (None, 200, 199)          102087    
_________________________________________________________________
InchiPrep (Functional)       [(None, 200, 512), (None, 101888    
_________________________________________________________________
TokenEmbeddingLayer (Embeddi multiple                  101888    
_________________________________________________________________
InchiIndividalEmbedding (Fun (None, 1, 512)            101888    
_________________________________________________________________
ImageEncoder (Functional)    (None, 100, 320)          3634851   
_________________________________________________________________
EncoderAttention (Functional (None, 100, 320)          
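
The optimizer above borrows its warmup schedule from "Attention is All You Need". As a rough illustration of what that schedule computes, here is a minimal pure-Python sketch. It is not the notebook's `LRScheduleAIAYN` class: the function name `aiayn_learning_rate` is hypothetical, `d_model=512` is assumed from the embedding width shown in the summary, and applying `scale_factor` as a plain multiplier is my assumption.

```python
def aiayn_learning_rate(step, d_model=512, warmup_steps=4000, scale_factor=1.0):
    """Learning rate from the paper's formula:

        lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    `scale_factor` is applied as a simple multiplier (an assumption, see above).
    """
    step = max(step, 1)  # avoid division by zero at step 0
    return scale_factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate grows linearly during warmup, peaks at step == warmup_steps,
# then decays proportionally to 1 / sqrt(step).
early = aiayn_learning_rate(2000)
peak = aiayn_learning_rate(4000)
late = aiayn_learning_rate(8000)
```

With the defaults, `early` is exactly half of `peak` (linear warmup), and `late` is smaller than `peak` by a factor of `sqrt(2)`.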

Inference Functions

In [47]:
def run_inference(model, dataset, return_lev_score=False, take_num=100, skip_set_num=0):
    """ produces image_id, pred_string pairs """
    
    # update model to tf.functions
    # NOTE: these model methods are decorated with @tf.function within the class definition
    model_tf = model.call
    tokens_to_string_tf = model.tokens_to_string

    """  # if @tf.function not used in class def, adapt functions with:
    def tf_fn(orig_func):
        return tf.function(func=orig_func, experimental_relax_shapes=True,
                           experimental_follow_type_hints=True,
                           jit_compile=False)  # JIT compile causes GPU system to crash
    """

    # initialize containers
    image_ids_list = []
    generated_predictions_list = []
    true_InChI_list = []
    
    # prepare dataset for parallel / distributed execution
    if PARAMETERS.tpu():
        if not take_num:  # use full dataset (~ 4 min on TPU)
            dataset = PARAMETERS.strategy().experimental_distribute_dataset(dataset)
        
        else:  # use restricted dataset (useful for testing purposes)
            dataset = PARAMETERS.strategy().experimental_distribute_dataset(
                        dataset.skip(take_num * skip_set_num).take(take_num))
            
        # convert distributed ds to iterator
        dataset = iter(dataset)

    else:
        if not take_num:  # use full dataset
            pass

        else:  # use restricted dataset (useful for testing purposes)
            dataset = dataset.skip(take_num * skip_set_num).take(take_num).prefetch(tf.data.AUTOTUNE)
        
    
    # generate (image_id, token preds)
    for val in dataset:

        # get actual values and gather into single batch
        true_InChI = val[3]
        if PARAMETERS.tpu():
            true_InChI = PARAMETERS.strategy().gather(true_InChI, axis=0)

        # get corresponding image ids and gather into single batch
        image_ids = val[2]
        if PARAMETERS.tpu():
            image_ids = PARAMETERS.strategy().gather(image_ids, axis=0)

        # get predictions (as tokens)  and gather into single batch
        if PARAMETERS.tpu():
            generated_probs = PARAMETERS.strategy().run(model_tf, args=[val[:2]])
        else:
            generated_probs = model_tf(val[:2])
        
        if PARAMETERS.tpu():
            generated_probs = PARAMETERS.strategy().gather(generated_probs, axis=0)
        generated_probs = tf.squeeze(generated_probs)

        # convert predictions to strings
        generated_predictions = tf.argmax(generated_probs, axis=-1)
        generated_predictions = tokens_to_string_tf(generated_predictions)

        # decode bytestrings and update containers
        image_ids_list.extend([x.decode() for x in image_ids.numpy().tolist()])
        generated_predictions_list.extend([x.decode() for x in generated_predictions.numpy().tolist()])
        true_InChI_list.extend([x.decode() for x in true_InChI.numpy().tolist()])

    output = [image_ids_list, generated_predictions_list, true_InChI_list]

    if return_lev_score:
        
        # compute scores
        lev_score = [levenshtein(pred, orig) for (pred, orig)
                      in zip(generated_predictions_list, true_InChI_list)]

        # add to outputs
        output.append(lev_score)

    return output
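
`run_inference` scores each prediction with a `levenshtein` helper that is defined elsewhere in the notebook. For readers unfamiliar with the metric, here is a minimal pure-Python sketch of edit distance. The name `edit_distance` is hypothetical, and this is not necessarily how the notebook's helper is implemented.

```python
def edit_distance(pred, target):
    """Levenshtein distance between two strings, using a two-row dynamic programme."""
    # previous[j] = distance between the processed prefix of `pred` and target[:j]
    previous = list(range(len(target) + 1))

    for i, p_char in enumerate(pred, start=1):
        current = [i]  # distance from pred[:i] to the empty string
        for j, t_char in enumerate(target, start=1):
            cost = 0 if p_char == t_char else 1
            current.append(min(previous[j] + 1,          # delete p_char
                               current[j - 1] + 1,       # insert t_char
                               previous[j - 1] + cost))  # substitute (or match)
        previous = current

    return previous[-1]


score = edit_distance('kitten', 'sitting')
```

A score of 0 means the predicted InChI string matches the label exactly; each additional point is one character insertion, deletion or substitution.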

Test inference speed

In [48]:
"""
# test inference speed - time for 'take_num' batches
%timeit run_inference(base_model, dataset=train_ds_int_index, return_lev_score=True, take_num=2, skip_set_num=0)
"""

1 loop, best of 5: 47 s per loop


In [None]:
if not PARAMETERS.tpu():
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs/')
    %tensorboard --logdir './logs/'

In [None]:
"""
# Full model: inference speed (with beam)
%%timeit
num_batches = 3

for val in train_ds.unbatch().batch(PARAMETERS.inference_batch_size()).take(num_batches): 
    im_id, preds = (base_model.predict(val))
"""

# Training

Prepare dataset for training

In [49]:
# Dataset: Random Perturbations and One-Hot-Encoding
# note: one-hot needed so we can use label smoothing in CrossEntropy Loss
tf.keras.mixed_precision.set_global_policy('float32')  # temporarily removed mixed precision

# define transformations
rotate = keras.layers.experimental.preprocessing.RandomRotation(
            factor=(-0.5, 0.5), fill_mode='constant')
contrast = keras.layers.experimental.preprocessing.RandomContrast(factor=.1)
depth = base_model.vocab_size
one_hot = keras.layers.Lambda(lambda x: tf.one_hot(x, depth=depth)) 

# apply transformations
train_ds_prepared = train_ds_int_index.map(lambda w, x, y, z: (rotate(w), x),
                                           num_parallel_calls=tf.data.AUTOTUNE)
train_ds_prepared = train_ds_prepared.map(lambda w, x: (contrast(w), x),
                                          num_parallel_calls=tf.data.AUTOTUNE)
train_ds_prepared = train_ds_prepared.map(lambda w, x: ((w, x), one_hot(x)),
                                          num_parallel_calls=tf.data.AUTOTUNE)\
                                          .prefetch(tf.data.AUTOTUNE)

valid_ds_prepared = valid_ds_int_index.map(lambda w, x, y, z: ((w, x), one_hot(x)),
                                          num_parallel_calls=tf.data.AUTOTUNE)\
                                          .prefetch(tf.data.AUTOTUNE)

# re-enable mixed precision
tf.keras.mixed_precision.set_global_policy(PARAMETERS.mixed_precision())
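
To show why the labels are one-hot encoded before training, here is a small pure-Python sketch of what label smoothing does to a one-hot target. It uses the standard formula `y * (1 - smoothing) + smoothing / num_classes`; the helper name `smooth_one_hot` is hypothetical, and the tiny `depth=4` is only for readability (the real depth is `base_model.vocab_size`).

```python
def smooth_one_hot(index, depth, label_smoothing=0.1):
    """One-hot encode `index`, then blend the result with a uniform distribution."""
    uniform = label_smoothing / depth
    return [(1.0 - label_smoothing) * (1.0 if k == index else 0.0) + uniform
            for k in range(depth)]


# Token index 2 in a toy vocabulary of 4 tokens
target = smooth_one_hot(index=2, depth=4)
```

The true class keeps most of the probability mass and every other class receives a small share, which is only expressible when the target is a full distribution rather than an integer index.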

Train base model

In [None]:
# Train base model (teacher-fed training, prediction-fed validation, no beam update)
if not PARAMETERS.tpu():
    steps_per_epoch = 512
    validation_steps = 128
    callbacks = [checkpoint, nan_stop, backup_checkpoint]  # optionally add: tensorboard
    validation_freq = 6

else:
    steps_per_epoch = 12 * (4 * PARAMETERS.strategy().num_replicas_in_sync)
    validation_steps = PARAMETERS.strategy().num_replicas_in_sync  
    callbacks=[checkpoint, nan_stop]
    validation_freq = 10  # note: validation step causes major slowdown on TPU


epoch_multiple = 100
epochs = epoch_multiple * int(1.8 * 1e6) // (steps_per_epoch * PARAMETERS.batch_size())
    
# Important! Lock transfer model during initial training. (Optional) unlock after initial convergence
base_model.transfer_model.trainable = False  

# train
history = base_model.fit(train_ds_prepared.repeat(),
                         validation_data=valid_ds_prepared.repeat(),
                         epochs=epochs,
                         steps_per_epoch=steps_per_epoch,
                         validation_freq=validation_freq, 
                         validation_steps=validation_steps, 
                         callbacks=callbacks,
                         verbose=1)

Epoch 1/457
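
For reference, the epoch count shown above can be reproduced by hand. The sketch below repeats the arithmetic from the training cell. The replica count (8) and batch size (1024) are assumptions, chosen because they are consistent with the `Epoch 1/457` output; the helper name `planned_epochs` is hypothetical.

```python
def planned_epochs(steps_per_epoch, batch_size, epoch_multiple=100, dataset_size=1_800_000):
    """Number of Keras 'epochs' needed to cover `epoch_multiple` passes over the data."""
    return epoch_multiple * dataset_size // (steps_per_epoch * batch_size)


num_replicas = 8    # assumption: one TPU with 8 cores
batch_size = 1024   # assumption: value of PARAMETERS.batch_size()

steps_per_epoch = 12 * (4 * num_replicas)  # same formula as the TPU branch
epochs = planned_epochs(steps_per_epoch, batch_size)
```

Each Keras "epoch" here covers only `steps_per_epoch * batch_size` images, so many short epochs are needed for one real pass over the training set. This keeps checkpoints frequent.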

TPU-safe saving to local directory

In [None]:
base_model.save_weights(os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, 'saved_model'), 
                        options=tf.train.CheckpointOptions(experimental_io_device='/job:localhost'))

In [None]:
base_model.load_weights(os.path.join(PARAMETERS.checkpoint_dir(), checkpoint_save_name, 'saved_model'), 
                        options=tf.train.CheckpointOptions(experimental_io_device='/job:localhost'))

Train beam update model

In [None]:
"""
Not yet implemented
"""

# Inference

Here we define a function to conduct inference on the test set. Results are saved to "submission.csv".

Intermediate results are written to the same file at regular intervals. This allows inference to be conducted in stages and is a safeguard in case of interruptions before the full set has been processed. 

In [None]:
def make_inference_progress(dataset, model, return_lev_score=True, save_freq=50, parameters=PARAMETERS):
    """ runs inference in stages of 'take_num' batches, saving progress to 'submission.csv' """

    batch_size = 1024
    est_num_batches = 2*10e7 // batch_size  # generous upper bound (10e7 == 1e8); loop exits early
    take_num = 100

    # initialize new dataframe
    predictions_df = pd.DataFrame(columns=['image_id', 'InChI', 'lev_score'])

    for i in range(int(est_num_batches // take_num)):
        try:
            # get predictions (lev score is always requested: it is unpacked below)
            inference_outputs = run_inference(model, dataset, return_lev_score=True, 
                                              take_num=take_num, skip_set_num=i)
            im_id, pred, true_val, lev_score = inference_outputs

            # stop once the dataset is exhausted
            if not im_id:
                print(f'completed at step {i}')
                break

            # add to dataframe
            new_preds = pd.DataFrame({'image_id': im_id, 'InChI': pred, 'lev_score': lev_score})
            predictions_df = pd.concat([predictions_df, new_preds], ignore_index=True)

            # save to CSV
            if i % save_freq == 0:
                predictions_df = predictions_df.drop_duplicates(subset='image_id', keep='last')
                predictions_df[['image_id', 'InChI']].to_csv(parameters.csv_save_dir() + 'submission.csv', index=False)
                print(f'iteration {i}')

        except Exception as e:  # report the error instead of silently stopping
            print(f'stopped at step {i}: {e}')
            break

    return predictions_df
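
The staged loop above relies on `skip(take_num * skip_set_num).take(take_num)` to walk through the dataset one window at a time. The same windowing pattern can be sketched with plain Python iterables. The helper name `window` and the toy list of seven "batches" are hypothetical stand-ins for the `tf.data` pipeline.

```python
from itertools import islice


def window(batches, take_num, skip_set_num):
    """Return the `skip_set_num`-th window of `take_num` batches (skip, then take)."""
    start = take_num * skip_set_num
    return list(islice(batches, start, start + take_num))


batches = list(range(7))  # stand-in for 7 batches of images

windows = []
i = 0
while True:
    chunk = window(batches, take_num=3, skip_set_num=i)
    if not chunk:  # an empty window means the data is exhausted
        break
    windows.append(chunk)
    i += 1
```

Because each window is addressed by its index, an interrupted run can resume from any `skip_set_num` without reprocessing earlier windows. The last window may be shorter than `take_num`.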

Load previously generated predictions

In [None]:
try:
    predictions_df = pd.read_csv(PARAMETERS.csv_save_dir() + 'submission.csv')
except Exception:  # no saved predictions yet: start with an empty dataframe
    predictions_df = pd.DataFrame({'image_id':[], 'InChI':[]}, dtype=str)

Generate additional predictions

In [None]:
""" On first pass or to start from scratch, initialize the dataframe with:
predictions_df = pd.DataFrame({'image_id':[], 'InChI':[]}, dtype=str)
"""

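# NOTE: arguments must match the signature of make_inference_progress defined above.
# The dataset needs true labels for the Levenshtein score; 'valid_ds_int_index' is used
# here as an example, so substitute the dataset you want predictions for.
predictions_df = make_inference_progress(dataset=valid_ds_int_index, model=base_model,
                                         save_freq=100, parameters=PARAMETERS)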
predictions_df