In [2]:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow.keras import layers
import os
import matplotlib.pyplot as plt

# Limit TensorFlow's Python logger to ERROR-level messages so cell output stays readable.
tf.get_logger().setLevel('ERROR')

In [3]:
# Relation classes to predict; a label's position in this list is its class index.
relations = ['Cause-Effect', 'Component-Whole', 'Entity-Destination', 'Product-Producer', 'Entity-Origin',
             'Member-Collection', 'Message-Topic', 'Content-Container', 'Instrument-Agency', 'Other']

dataset_dir = os.path.join("..", "dataset", "")
train_path = os.path.join(dataset_dir, "train.txt")
max_words = 10000  # NOTE(review): not referenced by any visible cell -- confirm before removing

# train.txt is consumed as alternating lines: a sentence line, then its label line.
texts = []
labels = []
with open(train_path, encoding="utf-8") as fp:
    # zip(fp, fp) takes two consecutive lines per step and stops at the first
    # incomplete pair, so a dangling sentence without a label is ignored.
    for sentence_line, label_line in zip(fp, fp):
        # Drop everything up to the first space (presumably a sentence id -- TODO confirm).
        texts.append(sentence_line.split(" ", 1)[1])
        # Keep the relation name that precedes any "(...)" suffix. strip() matters for
        # labels written without a suffix: they would otherwise keep their trailing
        # newline and fail the relations.index() lookup below.
        labels.append(label_line.split('(')[0].strip())

X = tf.constant(texts)
# Map each relation name to its class index (ValueError on a name missing from `relations`),
# then pick the matching identity-matrix rows to get one-hot targets.
label_ids = [relations.index(name) for name in labels]
Y = np.eye(len(relations))[label_ids]

In [8]:
# Graph: raw strings -> ALBERT preprocessing -> ALBERT encoder -> BiLSTM -> softmax over relations.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)

preprocessor = hub.KerasLayer(
    "https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/2"
)
encoder_inputs = preprocessor(text_input)

# The encoder is loaded with trainable=False, i.e. it is not fine-tuned here.
encoder = hub.KerasLayer(
    "https://hub.tensorflow.google.cn/tensorflow/albert_en_base/2",
    trainable=False,
)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]  # not consumed by the layers below
sequence_output = encoder_outputs["sequence_output"]

# Classification head: the bidirectional LSTM reads the encoder's sequence output,
# and a softmax layer emits one probability per relation class.
recurrent_cell = layers.LSTM(128, dropout=0.2, recurrent_dropout=0.2)
x = layers.Bidirectional(recurrent_cell)(sequence_output)
outputs = layers.Dense(len(relations), activation='softmax')(x)

In [9]:
# Wrap the string-input graph in a Model and configure it for training
# with Adam, categorical cross-entropy and an accuracy metric.
model = tf.keras.Model(inputs=text_input, outputs=outputs)
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [10]:
# Train for 10 epochs in batches of 32; a quarter of the samples is held out for validation.
history = model.fit(
    x=X,
    y=Y,
    batch_size=32,
    epochs=10,
    validation_split=0.25,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f79c4df85f8>

In [11]:
test_path = os.path.join(dataset_dir, "test.txt")

# test.txt is read one sentence per line; everything up to the first space is dropped
# (presumably a sentence id -- TODO confirm).
test_texts = []
with open(test_path, encoding="utf-8") as fp:
    for line in fp:
        if not line.strip():
            # A blank line has nothing after the first space and would raise
            # IndexError in the split below, so skip it.
            continue
        test_texts.append(line.split(" ", 1)[1])

X_test = tf.constant(test_texts)

# Class probabilities -> index of the most likely class -> relation name.
pred_probs = model.predict(X_test)
pred_ids = np.argmax(pred_probs, axis=1)
preds = [relations[class_id] for class_id in pred_ids]

pred_path = os.path.join('..', 'output', "prediction.txt")
# Create the output directory on first run instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(pred_path), exist_ok=True)
# One predicted relation per line, in the same order as the sentences read above.
with open(pred_path, 'w', encoding="utf-8") as fp:
    fp.write("\n".join(preds))