# In [2]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils
import tensorflow_hub as hub
import tensorflow as tf
import official.nlp.bert.tokenization as tokenization
import numpy as np


def create_raw_input(text, label, test_size, random_state=None):
    """Split texts/labels into stratified train/test sets with one-hot labels.

    Args:
        text: sequence of input texts (the features).
        label: sequence of class labels, one per entry of ``text``.
        test_size: forwarded to ``train_test_split`` (fraction or absolute
            count of samples for the test split).
        random_state: optional seed forwarded to ``train_test_split`` for a
            reproducible split. ``None`` (the default) keeps the unseeded
            behaviour.

    Returns:
        ``(x_train, x_test, dummy_y_train, dummy_y_test)``. The two label
        arrays are one-hot encoded over the full set of classes in ``label``,
        so they always have the same number of columns.
    """
    x_train, x_test, y_train, y_test = train_test_split(
        text, label, test_size=test_size, stratify=label, random_state=random_state
    )

    # Fit on the full label set so both splits share one class -> index mapping.
    label_encoder = LabelEncoder().fit(label)
    num_classes = len(label_encoder.classes_)

    # Pin the one-hot width explicitly. Without num_classes, to_categorical
    # infers the width from the largest index present in each split. A rare
    # class can be absent from the (stratified) test split, which would then
    # get fewer columns than the training labels.
    dummy_y_train = np_utils.to_categorical(
        label_encoder.transform(y_train), num_classes=num_classes
    )
    dummy_y_test = np_utils.to_categorical(
        label_encoder.transform(y_test), num_classes=num_classes
    )

    return x_train, x_test, dummy_y_train, dummy_y_test

def pull_bert(bert_type):
    """Load a TF-Hub BERT model as a frozen Keras layer plus its tokenizer.

    Args:
        bert_type: handle/URL passed to ``hub.load``. The loaded object must
            expose ``vocab_file`` and ``do_lower_case`` attributes, because
            the tokenizer is built from them.

    Returns:
        ``(bert_layer, tokenizer)``: a ``hub.KerasLayer`` with
        ``trainable=False`` and a ``FullTokenizer`` configured from the
        model's own vocabulary file and lower-casing flag.
    """
    bert_layer = hub.KerasLayer(hub.load(bert_type), trainable=False)

    hub_object = bert_layer.resolved_object
    # Read the vocab path and casing flag from the model itself so the
    # tokenizer always matches the checkpoint that was loaded.
    tokenizer = tokenization.FullTokenizer(
        hub_object.vocab_file.asset_path.numpy(),
        hub_object.do_lower_case.numpy(),
    )

    return bert_layer, tokenizer


def encode_names(n, tokenizer):
    """Tokenize one string and return its token ids, terminated by ``[SEP]``.

    Args:
        n: the text to encode.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``
            (e.g. the tokenizer returned by ``pull_bert``).

    Returns:
        The ids of ``tokenizer.tokenize(n)`` followed by the id of ``[SEP]``.
        No ``[CLS]`` token is added here.
    """
    pieces = [*tokenizer.tokenize(n), '[SEP]']
    return tokenizer.convert_tokens_to_ids(pieces)


def get_max_lengths(text, tokenizer):
    """Return the largest encoded length over all strings in ``text``.

    The length of a string is ``len(encode_names(string, tokenizer))``: its
    WordPiece ids plus the trailing ``[SEP]``. It does NOT include the
    ``[CLS]`` token that ``bert_encode`` prepends. Padding to exactly this
    value in ``bert_encode`` would therefore truncate the longest example by
    one position.

    Args:
        text: iterable of strings.
        tokenizer: tokenizer accepted by ``encode_names``.

    Returns:
        The maximum encoded length, as an ``int``.

    Raises:
        ValueError: if ``text`` is empty.
    """
    # Reuse get_lengths instead of duplicating the tokenize-and-measure logic.
    lengths = get_lengths(text, tokenizer)
    if not lengths:
        raise ValueError("get_max_lengths() requires at least one string")
    return max(lengths)


def get_lengths(text, tokenizer):
    """Return the encoded length of each string in ``text``.

    Each length is ``len(encode_names(string, tokenizer))``: the WordPiece ids
    plus the trailing ``[SEP]``, excluding the ``[CLS]`` that ``bert_encode``
    prepends.

    Args:
        text: iterable of strings.
        tokenizer: tokenizer accepted by ``encode_names``.

    Returns:
        A list of ``int`` lengths, in the same order as ``text``. The list is
        empty when ``text`` is empty.
    """
    # Measure the Python id lists directly. Packing them into a
    # tf.RaggedTensor only to read the row lengths back out was pure overhead.
    return [len(encode_names(n, tokenizer)) for n in text]


def bert_encode(string_list, tokenizer, max_seq_length):
    """Build padded BERT input tensors for a batch of single-sentence texts.

    Each string is encoded as ``[CLS] <wordpieces> [SEP]`` and then padded
    with 0 (or truncated on the right) to exactly ``max_seq_length``
    positions.

    Args:
        string_list: texts to encode (anything ``np.array`` can wrap, e.g. a
            list or pandas Series of strings).
        tokenizer: tokenizer as returned by ``pull_bert``.
        max_seq_length: width of every returned tensor. Longer sequences are
            cut off by ``RaggedTensor.to_tensor``, which drops their trailing
            ``[SEP]``. ``get_max_lengths`` does not count the ``[CLS]`` added
            here, so pass its result + 1 to keep the longest example intact.

    Returns:
        A dict with ``'input_word_ids'``, ``'input_mask'`` and
        ``'input_type_ids'``. Each value is a dense tensor of shape
        ``(len(string_list), max_seq_length)``.
    """
    string_tokens = tf.ragged.constant([
        encode_names(n, tokenizer) for n in np.array(string_list)
    ])

    # One [CLS] id per row, built as an (N, 1) nested list so it can be
    # concatenated in front of the ragged token rows along the sequence axis.
    cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])] * string_tokens.shape[0]
    input_word_ids = tf.concat([cls, string_tokens], axis=1)

    padded_shape = (None, max_seq_length)

    # 1 for every real token ([CLS], wordpieces, [SEP]); padding becomes 0.
    input_mask = tf.ones_like(input_word_ids).to_tensor(shape=padded_shape)

    # A single-sentence input is entirely segment A, so every type id is 0.
    # Segment id 1 belongs to the second sentence of a pair, which this
    # encoder never produces.
    input_type_ids = tf.zeros_like(input_word_ids).to_tensor(shape=padded_shape)

    return {
        'input_word_ids': input_word_ids.to_tensor(shape=padded_shape),
        'input_mask': input_mask,
        'input_type_ids': input_type_ids,
    }