In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow.keras import layers
import os
import matplotlib.pyplot as plt

tf.get_logger().setLevel('ERROR')

In [12]:
relations = ['Cause-Effect', 'Component-Whole', 'Entity-Destination', 'Product-Producer', 'Entity-Origin',
             'Member-Collection', 'Message-Topic', 'Content-Container', 'Instrument-Agency', 'Other']

dataset_dir = os.path.join("..", "dataset", "")
text_train_path = os.path.join(dataset_dir, "text_train.txt")
label_train_path = os.path.join(dataset_dir, "label_train.txt")
max_words = 10000

texts = list()
labels = list()
with open(text_train_path) as fp:
    while True:
        line = fp.readline()
        if not line:
            break
        texts.append(line)

with open(label_train_path) as fp:
    while True:
        label = fp.readline()
        if not label:
            break
        labels.append(label.split('(')[0])

X = tf.constant(texts)
Y = list(map(relations.index, labels))
Y = np.eye(len(relations))[Y]

In [13]:
preprocessor = hub.load("https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/2")
text_inputs = [tf.keras.layers.Input(shape=(), dtype=tf.string)]
tokenize = hub.KerasLayer(preprocessor.tokenize)
tokenized_inputs = [tokenize(segment) for segment in text_inputs]

seq_length = 64
bert_pack_inputs = hub.KerasLayer(preprocessor.bert_pack_inputs, arguments=dict(seq_length=seq_length))  # Optional argument.
encoder_inputs = bert_pack_inputs(tokenized_inputs)

encoder = hub.KerasLayer("https://hub.tensorflow.google.cn/tensorflow/albert_en_base/2", trainable=False)
outputs = encoder(encoder_inputs)
pooled_output = outputs["pooled_output"]      
sequence_output = outputs["sequence_output"]

In [15]:
x = sequence_output
x = layers.Bidirectional(layers.LSTM(256, recurrent_dropout=0.2, dropout=0.2))(x)
outputs = layers.Dense(len(relations), activation='softmax')(x)
model = tf.keras.Model(inputs=text_inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, Y, epochs=10, batch_size=32, validation_split=0.2)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [8]:
texts

['system described greatest application arrayed configuration antenna elements\n',
 'child carefully wrapped bound cradle means cord\n',
 'author keygen uses disassembler look raw assembly code\n',
 'misty ridge uprises surge\n',
 'student association voice undergraduate student population state university new york buffalo\n',
 'sprawling complex peru largest producer silver\n',
 'current view chronic inflammation distal part stomach caused helicobacter pylori infection results increased acid production non infected upper corpus region stomach\n',
 'people moving back downtown\n',
 'lawsonite contained platinum crucible counter weight plastic crucible metal pieces\n',
 'solute placed inside beaker 5 ml solvent pipetted 25 ml glass flask trial\n',
 'fifty essays collected volume testify prominent themes professor quispel scholarly career\n',
 'composer sunk oblivion\n',
 'pulitzer committee issues official citation explaining reasons award\n',
 'burst caused water hammer pressure\n',
 '

In [25]:
aux = tf.keras.Model(inputs=model.inputs, outputs=[model.layers[1].output])
rs = aux.predict(X)

In [39]:
lens = [r.shape[0] for r in rs]

In [48]:
ls =  np.asarray(lens)
(ls < 56).sum()/6400

0.990625

In [5]:
test_path = os.path.join(dataset_dir, "test.txt")

test_texts = list()
with open(test_path) as fp:
    while True:
        line = fp.readline()
        if not line:
            break
        test_texts.append(line.split(" ", 1)[1])
        
X_test = tf.constant(test_texts)

preds = model.predict(X_test)
preds = np.argmax(preds,axis=1)
preds = list(map(lambda x:relations[x],preds))
pred_path = os.path.join('..','output', "prediction.txt")
with open(pred_path,'w+') as fp:
    fp.write("\n".join(preds))