In [None]:
# Installs TensorFlow Text, which this notebook imports as `tf_text`.
# NOTE(review): the version is unpinned; consider pinning it (and using `%pip`
# so the install targets the kernel's environment) for reproducible runs.
!pip install tensorflow_text

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics

import tensorflow as tf
import tensorflow_hub as tf_hub
import tensorflow_text as tf_text

In [None]:
# Fixed number of token positions per example: `bert_feat` converts `word_ids`
# to a dense tensor of this width. Also passed to `model()` as `length`.
SEQUENCE_LENGTH = 128
# Examples per batch in the train / validation / test pipelines. Also passed to
# `model()` as `batch_size`, where it sets the width of the hidden Dense layer.
BATCH_SIZE = 32

In [None]:
# Label columns predicted by the model, one sigmoid output per entry. The
# order here fixes the order of the model outputs and prediction columns.
attributes = [
    'antagonize',
    'condescending',
    'dismissive',
    'generalisation',
    'generalisation_unfair',
    'healthy',
    'hostile',
    'sarcastic',
]

Build a model based on English BERT

In [None]:
def model(batch_size, length, output_size, trainable_bert=True,
          hidden_size=None):
  """Builds a BERT-based multi-label classifier.

  Loads the `bert_en_uncased_L-24_H-1024_A-16` encoder from TF Hub and stacks
  a tanh Dense layer followed by a sigmoid Dense layer named 'labels' on top of
  the encoder's pooled output.

  Args:
    batch_size: Width of the intermediate tanh Dense layer when `hidden_size`
      is not given. Despite its name it is not used as a batch dimension.
    length: Unused. The 'word_ids' input accepts any sequence length; the
      parameter is kept so existing callers keep working.
    output_size: Number of units in the sigmoid output layer.
    trainable_bert: Whether the BERT encoder weights are trainable.
    hidden_size: Optional width of the intermediate tanh Dense layer. Defaults
      to `batch_size`, which preserves the original behaviour.

  Returns:
    A tuple `(keras_model, vocab_file, do_lower_case)`. The last two are the
    vocabulary asset path and the lower-casing flag exposed by the Hub module;
    the caller needs them to build a matching tokenizer.
  """
  if hidden_size is None:
    hidden_size = batch_size

  inputs = {
      'word_ids': tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='word_ids'),
  }
  bert_layer = tf_hub.KerasLayer(
      'https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/2',
      trainable=trainable_bert)
  vocab_file = bert_layer.resolved_object.vocab_file.asset_path
  # True means the tokenizer must lower-case its input (an "uncased" model).
  do_lower_case = bert_layer.resolved_object.do_lower_case

  ids = inputs['word_ids']
  # Non-zero ids are attended to (mask 1); zero ids are masked out (mask 0).
  input_mask = tf.cast(tf.cast(ids, tf.bool), tf.int32)
  # Every position belongs to segment 0.
  segment_ids = tf.zeros_like(ids, tf.int32)

  pooled_output, _ = bert_layer([ids, input_mask, segment_ids])
  hidden = tf.keras.layers.Dense(hidden_size, activation='tanh')(pooled_output)
  outputs = tf.keras.layers.Dense(
      output_size, activation='sigmoid', name='labels')(
          hidden)

  return (
      tf.keras.Model(inputs=inputs, outputs=outputs), vocab_file, do_lower_case)

In [None]:
# Build the classifier with one sigmoid output per attribute. `cased` receives
# the encoder's `do_lower_case` flag, the third value returned by `model()`.
bert, vocab_file, cased = model(
    batch_size=BATCH_SIZE, length=SEQUENCE_LENGTH, output_size=len(attributes))
bert.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005),
    # Keras documents `metrics` as a list, so the single metric is wrapped.
    metrics=[tf.keras.metrics.AUC(multi_label=True)])

In [None]:
# Tokenizer built from the encoder's vocabulary file and lower-casing flag.
tokenizer = tf_text.BertTokenizer(vocab_file, lower_case=cased)

# [CLS] and [SEP] token handling is not part of tf text http://b/160406014, so
# their ids are read from the tokenizer's internal vocabulary lookup table.
vocab_table = tokenizer._wordpiece_tokenizer._vocab_lookup_table
special_token_ids = vocab_table.lookup(tf.constant(['[CLS]', '[SEP]']))
cls, sep = special_token_ids.numpy().tolist()

Preprocessing

In [None]:
def bert_feat(text, label=None):
  """Maps a batch of strings to BERT `word_ids`, optionally pairing labels.

  Each row becomes `[CLS] wordpieces [SEP]`, zero-padded on the right to
  SEQUENCE_LENGTH. Rows with more than `SEQUENCE_LENGTH - 2` wordpieces are
  truncated before the special tokens are attached, so the closing [SEP] is
  never cut off.

  Args:
    text: 1-D batch of strings; its size is used as the number of rows.
    label: Optional labels for the batch; cast to float32 when given.

  Returns:
    `{'word_ids': ids}` when `label` is None, otherwise the tuple
    `({'word_ids': ids}, {'labels': float32_labels})`.
  """
  rows = tf.size(text)
  # Collapse the two innermost ragged dimensions into one token axis per row.
  tokens = tokenizer.tokenize(text).merge_dims(-2, -1)
  # Reserve two positions for [CLS] and [SEP]. Truncating only in to_tensor()
  # below would drop [SEP] from every over-long row.
  max_wordpieces = max(SEQUENCE_LENGTH - 2, 0)
  tokens = tokens[:, :max_wordpieces]
  left = tf.fill((rows, 1), tf.cast(cls, dtype=tokens.dtype))
  right = tf.fill((rows, 1), tf.cast(sep, dtype=tokens.dtype))
  # Pad short rows with id 0 up to the fixed width.
  ids = {'word_ids': tf.concat([left, tokens, right], axis=1).to_tensor(
      0, shape=(None, SEQUENCE_LENGTH))}
  if label is not None:
    return (ids, {'labels': tf.cast(label, dtype=tf.float32)})
  return ids

Load Data

In [None]:
# Colab-only: upload the dataset from the local machine. The next cell reads
# 'unhealthy_aggregated.csv' from the working directory.
from google.colab import files
files.upload()

In [None]:
# Columns read later in this notebook: '_unit_id' (train/dev/test split),
# 'comment' (model input text) and one column per entry in `attributes`.
data = pd.read_csv('unhealthy_aggregated.csv')

In [None]:
# Deterministic split on the last digit of `_unit_id`:
# digits 0-7 -> train, 8 -> dev, 9 -> test.
unit_bucket = data['_unit_id'] % 10
train = data.loc[unit_bucket < 8]
dev = data.loc[unit_bucket == 8]
test = data.loc[unit_bucket == 9]

In [None]:
# Unbatched (comment text, attribute labels) pairs for the training split.
train_text = train['comment'].astype(str)
train_labels = train[attributes]
train_dataset = tf.data.Dataset.from_tensor_slices((train_text, train_labels))

In [None]:
# Dev split, batched and tokenized. drop_remainder=True means a final partial
# batch of dev examples is not included in validation.
dev_text = dev['comment'].astype(str)
dev_labels = dev[attributes]
validation_dataset = (
    tf.data.Dataset.from_tensor_slices((dev_text, dev_labels))
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(bert_feat))

In [None]:
# Training input pipeline: shuffle, repeat forever, batch, tokenize, prefetch.
# Shuffling before repeating keeps epochs separate: every example of one pass
# is emitted before the next pass starts. Repeating first would let the shuffle
# buffer mix examples across epoch boundaries.
# NOTE: despite the name, nothing is cached here.
cached_train = (
    train_dataset
    .shuffle(1024)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .map(bert_feat)
    .prefetch(tf.data.experimental.AUTOTUNE))

In [None]:
# The training dataset repeats indefinitely, so Keras must be told how many
# batches make up one epoch: the number of full batches in the train split.
steps_per_epoch = len(train) // BATCH_SIZE

In [None]:
# Train for 4 epochs of `steps_per_epoch` batches each, evaluating on the dev
# split after every epoch.
history = bert.fit(
    cached_train,
    epochs=4,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    verbose=1,
)

In [None]:
# Unbatched (comment text, attribute labels) pairs for the held-out test split.
test_text = test['comment'].astype(str)
test_labels = test[attributes]
test_dataset = tf.data.Dataset.from_tensor_slices((test_text, test_labels))

In [None]:
# Score every test comment. No remainder is dropped here, so the result has one
# row per test example and can reuse the test frame's index.
test_batches = test_dataset.batch(BATCH_SIZE).map(bert_feat)
test_scores = bert.predict(test_batches)
predictions = pd.DataFrame(test_scores, columns=attributes, index=test.index)

In [None]:
# One ROC curve per attribute on the test split; each legend entry shows the
# attribute name followed by its ROC AUC.
fig, ax = plt.subplots()
for attribute in attributes:
  y_true = test[attribute].astype(bool)
  y_score = predictions[attribute]
  fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
  auc = metrics.roc_auc_score(y_true, y_score)
  ax.plot(fpr, tpr, label='%s %g' % (attribute, auc))
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend(loc='lower right')

  