Creates TensorFlow graphs for Spark NLP deep-learning (DL) annotators and models


In [None]:
import numpy as np
import os
import tensorflow as tf
import string
import random
import math
import sys
import shutil

from ner_model import NerModel
from dataset_encoder import DatasetEncoder
from ner_model_saver import NerModelSaver
from pathlib import Path

## SETTINGS

In [None]:
# Platform switch: os.name is 'nt' on Windows.
is_for_windows = os.name == 'nt'

# The contrib-based model variant is only used on non-Windows platforms.
use_contrib = not is_for_windows

# Prefix of the generated graph file name; it records which variant was built.
name_prefix = 'blstm-noncontrib' if is_for_windows else 'blstm'

In [None]:
def create_graph(ntags, embeddings_dim, nchars, lstm_size = 128):
    """Build the NER graph and write it as a binary ``.pb`` file in the current directory.

    The file is named
    ``<name_prefix>_<ntags>_<embeddings_dim>_<lstm_size>_<nchars>.pb``.

    Args:
        ntags: Maximum number of tags (at least the number of unique labels).
        embeddings_dim: Dimension of the pretrained word embeddings.
        nchars: Maximum number of characters processed.
        lstm_size: LSTM size passed to the context representation layer.

    Returns:
        None. If the running Python or TensorFlow version is not the supported
        one, a message is printed and nothing is written.
    """
    if sys.version_info[0] != 3 or sys.version_info[1] >= 7:
        print('Python 3.7 or above not supported by tensorflow')
        return
    if tf.__version__ != '1.12.0':
        print('Spark NLP is compiled with TensorFlow 1.12.0. Please use such version.')
        return

    # Start from an empty default graph so repeated calls do not accumulate ops.
    tf.reset_default_graph()
    model_name = name_prefix + '_{}_{}_{}_{}'.format(ntags, embeddings_dim, lstm_size, nchars)
    file_name = model_name + '.pb'

    # The context manager closes the session on exit, so no explicit close is needed.
    with tf.Session():
        # NOTE(review): the session opened above is not handed to NerModel
        # (session=None); the graph written below is taken from ner.session.
        # Confirm against NerModel whether passing the session is intended.
        ner = NerModel(session=None, use_contrib=use_contrib)
        try:
            ner.add_cnn_char_repr(nchars, 25, 30)
            ner.add_bilstm_char_repr(nchars, 25, 30)
            ner.add_pretrained_word_embeddings(embeddings_dim)
            ner.add_context_repr(ntags, lstm_size, 3)
            ner.add_inference_layer(True)
            ner.add_training_op(5)
            ner.init_variables()
            # The Saver object itself is not used; it is constructed because the
            # constructor adds save/restore ops to the graph being exported.
            tf.train.Saver()
            tf.train.write_graph(ner.session.graph, './', file_name, as_text=False)
        finally:
            # Release the model's resources even if building or writing failed.
            ner.close()


### Arguments of `create_graph`
- 1st argument (`ntags`): maximum number of tags (must be at least the number of unique labels, including `O` if the IOB scheme is used)
- 2nd argument (`embeddings_dim`): embeddings dimension
- 3rd argument (`nchars`): maximum number of characters processed (must be at least the largest possible number of characters)
- 4th argument (`lstm_size`): LSTM size (optional, defaults to 128)

In [None]:
# Default configuration: 10 tags, 100-dimensional embeddings, 100 characters.
create_graph(ntags=10, embeddings_dim=100, nchars=100)
# Alternative configurations (ntags, embeddings_dim, nchars); uncomment as needed:
# create_graph(10, 200, 100)
# create_graph(10, 300, 100)
# create_graph(10, 768, 100)
# create_graph(10, 1024, 100)
# create_graph(25, 300, 100)