<a href="https://colab.research.google.com/github/martin-fabbri/jigsaw-multilingual-toxic-comment/blob/main/notebooks/tpu_jigsaw_multilingual_toxic_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Jigsaw Multilingual Toxic Comment Classification

In [22]:
%%capture
!pip install fsspec
!pip install gcsfs
!pip install --upgrade --pre dvc
!git clone https://github.com/martin-fabbri/jigsaw-multilingual-toxic-comment.git
%cd /content/jigsaw-multilingual-toxic-comment/

In [25]:
# Show which dvc version is installed (the install cell uses --pre, so this may be a pre-release).
!pip list | grep dvc

dvc                           2.0.0a2       


In [24]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
import tensorflow_hub as hub
from matplotlib import pyplot as plt
import pandas as pd

## Local dataset

In [28]:
%%capture
# Pull the two preprocessed CSVs (sequence length 128) from the DVC remote "origin".
# NOTE(review): the config cell points TRAIN_PROCESSED / VALID_DATA at GCS paths,
# so these local copies look unused unless the commented-out local path there is
# re-enabled — confirm.
!dvc pull -r origin data/raw/jigsaw-toxic-comment-train-processed-seqlen128.csv
!dvc pull -r origin data/raw/validation-processed-seqlen128.csv

## TPU setup

In [29]:
# Pick the distribution strategy used by the rest of the notebook: a TPU when one
# can be resolved, otherwise MirroredStrategy on whatever local devices exist.
# This cell intentionally does not use %%capture, so the accelerator count printed
# at the end is actually shown.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # NOTE(review): ValueError is assumed to be what the resolver raises when no
    # TPU is available — confirm for the TensorFlow version in use.
    strategy = tf.distribute.MirroredStrategy()
print("Number of accelerators:", strategy.num_replicas_in_sync)

INFO:tensorflow:Initializing the TPU system: grpc://10.60.38.106:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*

## Load config

In [30]:
# Notebook-wide configuration. Requires `strategy` from the TPU setup cell.

# Length of every token-id / mask / segment-id list; also part of the CSV file names.
SEQUENCE_LENGTH = 128
EPOCHS = 6
# Bucket holding the preprocessed CSVs.
GCS_PATH = "gs://kds-d6b459191750de20505baf9adc31878a65fd287afd0812a8deb1cb15/"
TRAIN_PREFIX = "jigsaw-toxic-comment-train-processed-seqlen"
TRAIN_PROCESSED = f"{GCS_PATH}{TRAIN_PREFIX}{SEQUENCE_LENGTH}.csv"
# Alternative: read the training file from the local DVC checkout instead of GCS.
#TRAIN_PROCESSED = "data/raw/jigsaw-toxic-comment-train-processed-seqlen128.csv"

VALID_PREFIX = "validation-processed-seqlen"
VALID_DATA = f"{GCS_PATH}{VALID_PREFIX}{SEQUENCE_LENGTH}.csv"

# SavedModel location loaded by multilingual_bert_model().
BERT_GCS_PATH = "gs://bert_multilingual_public/bert_multi_cased_L-12_H-768_A-12_2/"
# Global batch size: 128 examples per replica.
BATCH_SIZE = 128 * strategy.num_replicas_in_sync
# Learning-rate schedule parameters consumed by lr_fn(); the peak rate is scaled
# by the replica count.
LR_MAX = 0.001 * strategy.num_replicas_in_sync
LR_EXP_DECAY = .9
LR_MIN = 0.0001
# Hard-coded number of training rows; the original note records it as the result
# of count_dataset_steps on the training data. Update it if the training file changes.
TRAIN_DATA_LENGTH = 223549  # count_dataset_steps(TRAIN_PROCESSED) = 223549
# Floor division: a final partial batch is not counted as a step.
STEPS_PER_EPOCH = TRAIN_DATA_LENGTH // BATCH_SIZE

# Echo the derived values so the run is self-describing.
print(f"EPOCHS:            {EPOCHS:,}")
print(f"BATCH_SIZE:        {BATCH_SIZE:,}")
print(f"STEPS_PER_EPOCH:   {STEPS_PER_EPOCH:,}")
print(f"TRAIN_DATA_LENGTH: {TRAIN_DATA_LENGTH:,}")

EPOCHS:            6
BATCH_SIZE:        1,024
STEPS_PER_EPOCH:   218
TRAIN_DATA_LENGTH: 223,549


## Explore dataset

In [31]:
# Render DataFrames with Colab's interactive table widget.
%load_ext google.colab.data_table
# Load the training CSV with pandas and display its first two rows.
# NOTE(review): this reads the whole file just to show two rows; passing `nrows=`
# would be cheaper if `comments` is not needed in full elsewhere — confirm.
comments = pd.read_csv(TRAIN_PROCESSED)
comments.head(2)

The google.colab.data_table extension is already loaded. To reload it, use:
  %reload_ext google.colab.data_table


Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,input_word_ids,input_mask,all_segment_id
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0,"(101, 27746, 31609, 11809, 24781, 10105, 70971...","(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0,"(101, 141, 112, 56237, 10874, 106, 10357, 1825...","(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [32]:
def format_sentences(data, label="toxic", remove_language=False):
    """Split one parsed CSV row into model features and its label.

    NOTE: this notebook defines ``format_sentences`` again in a later cell; on a
    top-to-bottom run that later definition is the one in effect.

    Args:
        data: Mapping of column name to value for a single row. It is modified
            in place: the label entry (and optionally ``"lang"``) is removed and
            every remaining value is replaced by the result of
            ``parse_string_list_into_ints``.
        label: Name of the column to extract as the target.
        remove_language: When truthy, the ``"lang"`` entry is removed so it is
            not parsed as an integer list.

    Returns:
        A ``(data, labels)`` tuple where ``labels`` is ``{"labels": <label value>}``.
    """
    labels = {"labels": data.pop(label)}
    if remove_language:
        # The language value itself is not needed here; only drop the column.
        data.pop("lang")
    # Snapshot the keys so the mapping can be updated while walking it.
    for key in list(data):
        data[key] = parse_string_list_into_ints(data[key])
    return data, labels

In [33]:
def make_sentence_dataset_from_csv(
    filename, label="toxic", language_to_filter=None
):
    """Build an unbatched, unshuffled dataset of (features, labels) from a CSV.

    Args:
        filename: CSV path handed to ``make_csv_dataset``.
        label: Column used as the target. It is read as int32 when it is
            ``"id"`` and as float32 otherwise.
        language_to_filter: Optional language code. When truthy, the ``"lang"``
            column is also read and only rows whose ``"lang"`` equals this
            value are kept.

    Returns:
        A dataset whose elements are the ``(data, labels)`` pairs produced by
        ``format_sentences``. Repeating, shuffling and batching are left to
        the caller.
    """
    # Column names and their defaults are kept as two parallel lists.
    column_names = [label, "input_word_ids", "input_mask", "all_segment_id"]
    column_types = [
        tf.int32 if label == "id" else tf.float32,
        tf.string,
        tf.string,
        tf.string,
    ]

    if language_to_filter:
        # "lang" is placed first, except when the label is "id", where it goes second.
        lang_index = 1 if label == "id" else 0
        column_names.insert(lang_index, "lang")
        column_types.insert(lang_index, tf.string)

    # make_csv_dataset requires a batch size; read one row per batch and undo the
    # batching immediately so real batching can be applied later.
    rows = tf.data.experimental.make_csv_dataset(
        filename,
        column_defaults=column_types,
        select_columns=column_names,
        batch_size=1,
        num_epochs=1,
        shuffle=False,
    ).unbatch()

    if language_to_filter:

        def _is_requested_language(row):
            return tf.math.equal(row["lang"], tf.constant(language_to_filter))

        rows = rows.filter(_is_requested_language)

    def _to_features_and_labels(row):
        return format_sentences(
            row, label=label, remove_language=language_to_filter
        )

    return rows.map(_to_features_and_labels)
   

In [34]:
def parse_string_list_into_ints(strlist):
    """Parse a string such as ``"(101, 27746, ...)"`` into an int32 tensor.

    Args:
        strlist: Scalar string tensor holding a parenthesised, comma-separated
            list of integers.

    Returns:
        An int32 tensor reshaped to ``[SEQUENCE_LENGTH]``.
    """
    s = tf.strings.strip(strlist)
    # Slice the *stripped* string. The length below is measured on the stripped
    # value, so slicing the raw input would misplace the cut whenever the field
    # carries leading whitespace.
    s = tf.strings.substr(
        s, 1, tf.strings.length(s) - 2)  # Remove parentheses around list
    s = tf.strings.split(s, ',', maxsplit=SEQUENCE_LENGTH)
    s = tf.strings.to_number(s, tf.int32)
    s = tf.reshape(s, [SEQUENCE_LENGTH])  # Force shape here needed for XLA compilation (TPU)
    return s

In [35]:
def format_sentences(data, label='toxic', remove_language=False):
    """Turn one parsed CSV row into a ``(features, labels)`` pair.

    NOTE: an earlier cell of this notebook defines the same function; this later
    definition replaces it when the notebook is run top to bottom.

    Args:
        data: Mapping of column name to value for one row. Modified in place:
            the label (and optionally ``'lang'``) is removed, and each
            remaining value is replaced by ``parse_string_list_into_ints`` of it.
        label: Column to extract as the target.
        remove_language: When truthy, also remove the ``'lang'`` column.

    Returns:
        ``(data, labels)`` with ``labels`` equal to ``{'labels': <label value>}``.
    """
    labels = {'labels': data.pop(label)}
    if remove_language:
        # Drop the language column; its value is not used afterwards.
        data.pop('lang')
    # The remaining items ("input_word_ids", "input_mask", "all_segment_id") are
    # string-encoded lists of integers. Iterate over a key snapshot while
    # replacing the values.
    for name in tuple(data):
        data[name] = parse_string_list_into_ints(data[name])
    return data, labels

In [36]:
# Smoke test: build the training dataset and print the features of its first example.
ds = make_sentence_dataset_from_csv(TRAIN_PROCESSED)
X, y = next(iter(ds.take(1)))
for feature_name in X:
    print(f"--- {feature_name} ---")
    print(X[feature_name].numpy())

UnimplementedError: ignored

In [None]:
# Label of the example fetched by the preceding cell (`X, y = ...`), as a NumPy value.
y["labels"].numpy()

In [None]:
def count_dataset_steps(dataset):
    """Return how many elements iterating ``dataset`` yields.

    The whole iterable is walked once, so only use this on small datasets.
    """
    return sum(1 for _ in dataset)

In [None]:
# count_dataset_steps(ds)

In [None]:
def make_dataset_pipeline(dataset, repeat_and_shuffle=True):
    """Cache, batch and prefetch ``dataset`` for training or evaluation.

    Args:
        dataset: Unbatched dataset of examples.
        repeat_and_shuffle: When true, the cached data is repeated, shuffled
            with a 2048-element buffer and batched with incomplete batches
            dropped. When false, it is only batched and the final partial
            batch is kept.

    Returns:
        The batched dataset (batches of ``BATCH_SIZE``) with prefetching enabled.
    """
    pipeline = dataset.cache()
    if not repeat_and_shuffle:
        pipeline = pipeline.batch(BATCH_SIZE)
    else:
        pipeline = (
            pipeline.repeat()
            .shuffle(2048)
            .batch(BATCH_SIZE, drop_remainder=True)
        )
    return pipeline.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Training input: the English training CSV, cached, repeated, shuffled and batched.
english_train_dataset = make_dataset_pipeline(
    dataset=make_sentence_dataset_from_csv(filename=TRAIN_PROCESSED),
    repeat_and_shuffle=True,
)
english_train_dataset

In [None]:
# Validation inputs keyed by display name, plus the number of batches in each.
non_english_val_datasets = {}
non_english_val_datasets_steps = {}
# Each entry is (display name, language code). A code of None applies no language
# filter, which gives the "Combined" split over all validation rows.
for split_name, language_code in [
    ("Spanish", "es"),
    ('Italian', 'it'),
    ("Combined", None),
]:
    split_dataset = make_dataset_pipeline(
        make_sentence_dataset_from_csv(
            VALID_DATA, language_to_filter=language_code
        ),
        repeat_and_shuffle=False,
    )
    non_english_val_datasets[split_name] = split_dataset
    # Counting iterates the batched dataset once, so this is a batch count.
    non_english_val_datasets_steps[split_name] = count_dataset_steps(
        split_dataset
    )

## Model

In [17]:
def multilingual_bert_model(max_seq_lenght=SEQUENCE_LENGTH):
    """Build a binary classifier on top of the multilingual BERT SavedModel.

    Args:
        max_seq_lenght: Length of each of the three int32 input sequences.

    Returns:
        A Keras ``Model`` taking a dict with keys ``"input_word_ids"``,
        ``"input_mask"`` and ``"all_segment_id"`` and producing a single
        sigmoid output from the layer named ``"labels"``.
    """
    # One int32 input per feature; the Keras input name matches the dict key.
    inputs = {
        feature: Input(shape=(max_seq_lenght,), dtype=tf.int32, name=feature)
        for feature in ("input_word_ids", "input_mask", "all_segment_id")
    }

    # Load the SavedModel from BERT_GCS_PATH and wrap it as a trainable layer.
    bert_layer = hub.KerasLayer(
        tf.saved_model.load(BERT_GCS_PATH), trainable=True
    )

    # The layer returns two outputs; only the first (pooled) one feeds the head.
    pooled_output, _ = bert_layer(
        [
            inputs["input_word_ids"],
            inputs["input_mask"],
            inputs["all_segment_id"],
        ]
    )
    hidden = Dense(32, activation="relu")(pooled_output)
    probability = Dense(
        1, activation="sigmoid", name="labels", dtype=tf.float32
    )(hidden)

    return Model(inputs=inputs, outputs=probability)

In [18]:
# Create and compile the model inside the strategy scope.
with strategy.scope():
    multilingual_bert = multilingual_bert_model()

    multilingual_bert.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        # Use the configured peak rate instead of repeating its formula here.
        # LR_MAX is defined as 0.001 * strategy.num_replicas_in_sync in the
        # config cell, and lr_fn(0) returns the same value.
        optimizer=tf.keras.optimizers.SGD(learning_rate=LR_MAX),
        metrics=[tf.keras.metrics.AUC()],
        steps_per_execution=16,
    )

multilingual_bert.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_word_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 128)]        0                                            
__________________________________________________________________________________________________
all_segment_id (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
keras_layer (KerasLayer)        [(None, 768), (None, 177853441   input_word_ids[0][0]             
                                                                 input_mask[0][0]             

In [19]:
def lr_fn(epoch):
    """Return the learning rate for ``epoch``.

    The rate starts at ``LR_MAX`` for epoch 0 and decays exponentially towards
    ``LR_MIN`` by a factor of ``LR_EXP_DECAY`` per epoch.
    """
    span = LR_MAX - LR_MIN
    return span * LR_EXP_DECAY ** epoch + LR_MIN

In [20]:
# Train on English Wikipedia comment data.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_fn)
history = multilingual_bert.fit(
    english_train_dataset,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    # validation_data=non_english_val_datasets["Combined"],
    # validation_steps=non_english_val_datasets_steps["Combined"],
    callbacks=[lr_callback],
)

Epoch 1/6


  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)


Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6
