In [1]:
# Execute the sibling script inside this kernel so that the names it defines
# are available to the cells below.
# NOTE(review): TFFlaubertForTokenClassification is used later in this notebook
# but is not imported here — presumably it is defined by this script; confirm.
%run ../flaubert_token_classification.py

In [2]:
import re
import os
import csv
import tensorflow as tf
import collections
import numpy as np

from transformers import (
    TFFlaubertForSequenceClassification,
    FlaubertTokenizer
)

# Dataset

## Import

In [3]:
# Every location is resolved from the parent of the current working directory.
# Each folder path keeps a trailing "/" so file names can simply be appended.
ROOT_FOLDER = os.path.abspath(os.pardir) + "/"
MODELS_PATH = "{}models/".format(ROOT_FOLDER)
DATASET_PATH = "{}dataset/custom_dataset/".format(ROOT_FOLDER)

# Passed to the tokenizer as max_length.
SEQUENCE_LENGTH = 64

## Transform dataset

In [6]:
# Load the three saved models, keyed by task name. Each checkpoint lives in a
# folder of MODELS_PATH named after its task.
# The two token-level taggers use the same model class, so they are loaded in
# one pass; the sequence classifier is added afterwards (insertion order:
# "ner", "pos", "categorisation").
models = {
    task: TFFlaubertForTokenClassification.from_pretrained(MODELS_PATH + task)
    for task in ("ner", "pos")
}
models["categorisation"] = TFFlaubertForSequenceClassification.from_pretrained(
    MODELS_PATH + "categorisation"
)

In [7]:
# Tokenizer loaded from the "jplu/tf-flaubert-base-cased" checkpoint.
# NOTE(review): the models in the previous cell are loaded from local folders;
# presumably they were fine-tuned from this same checkpoint so that the
# vocabularies agree — confirm.
tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-base-cased")

# SEQUENCE_LENGTH is deliberately not reassigned here: it is already defined
# (with the same value, 64) in the configuration cell next to the path
# constants, and keeping a single definition prevents the two from drifting.

In [8]:
from time import time

# Destination TFRecord file for the serialized examples produced by this cell.
writer = tf.io.TFRecordWriter("./test_caching.tf_record")

# Reference timestamp for the "Elapsed time" progress messages.
t1 = time()

def create_int_feature(values):
    """Build a `tf.train.Feature` that stores `values` as an int64 list.

    Args:
        values: iterable of integers; it is materialized with `list()` first.

    Returns:
        A `tf.train.Feature` whose `int64_list` field holds the values.
    """
    int64_list = tf.train.Int64List(value=list(values))
    return tf.train.Feature(int64_list=int64_list)

def create_float_feature(values):
    """Build a `tf.train.Feature` that stores `values` as a float list.

    Args:
        values: iterable of floats; it is materialized with `list()` first.

    Returns:
        A `tf.train.Feature` whose `float_list` field holds the values.
    """
    float_list = tf.train.FloatList(value=list(values))
    return tf.train.Feature(float_list=float_list)

# Cache model outputs: for each CSV row, run the three models on the text in
# the first column and write one tf.train.Example to the TFRecord file behind
# `writer`.

TOTAL_EXAMPLES = 553030  # denominator of the progress percentage
MAX_EXAMPLES = 1000      # highest row index processed; reading stops after it
LOG_EVERY = 100          # a progress line is printed every LOG_EVERY rows

# `with writer` closes the TFRecord file even when tokenization, a model call
# or the CSV read raises part-way through; previously close() was only reached
# on success. The CSV file is closed first, then the writer.
# NOTE(review): open() is called without an encoding, so the platform default
# applies — confirm it matches the CSV.
with writer, open(DATASET_PATH + 'shuffled_since_january.csv', 'r', newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',', quotechar='"')
    for idx, article in enumerate(reader):
        # Stop before logging so that no progress line is printed for a row
        # that will not be processed.
        if idx > MAX_EXAMPLES:
            break

        if idx % LOG_EVERY == 0:
            print("Writing example", idx, "of", TOTAL_EXAMPLES, ": ", str(round(idx / TOTAL_EXAMPLES * 100, 2)) + "%")
            elapsed = time() - t1
            print('Elapsed time is %f seconds.' % elapsed)

        # Empty CSV row: there is no text column to encode.
        if len(article) < 1:
            continue

        input_tokens = tokenizer.encode_plus(
            article[0],
            max_length=SEQUENCE_LENGTH,
            # Boolean flag. The integer SEQUENCE_LENGTH used to be passed here
            # and only worked because it is truthy.
            pad_to_max_length=True,
            add_special_tokens=True,
            return_tensors='tf',
            return_token_type_ids=True,
            return_attention_mask=True,
        )

        # Keyword arguments shared by every model call below.
        inputs = {
            "attention_mask": input_tokens["attention_mask"],
            "token_type_ids": input_tokens["token_type_ids"],
            "training": False
        }

        # First output of the categorisation model's `transformer` sub-module.
        # NOTE(review): the full categorisation model is called again below,
        # which presumably repeats this same transformer pass — consider
        # reusing one result if the library exposes it.
        transformer_outputs = models["categorisation"].transformer(input_tokens["input_ids"], **inputs)
        output = transformer_outputs[0]

        record_feature = collections.OrderedDict()

        ner_predictions = models["ner"](input_tokens["input_ids"], **inputs)
        pos_predictions = models["pos"](input_tokens["input_ids"], **inputs)
        cat_predictions = models["categorisation"](input_tokens["input_ids"], **inputs)

        # output[0][0]: first element along the first two axes of the
        # transformer output (assumes (batch, position, hidden) — TODO confirm).
        record_feature["cls_token"] = create_float_feature(output[0][0])
        # Index of the largest score along axis 2, for the single encoded text.
        record_feature["ner"] = create_int_feature(np.argmax(ner_predictions[0], axis=2)[0])
        record_feature["pos"] = create_int_feature(np.argmax(pos_predictions[0], axis=2)[0])
        # Rounded first output of the categorisation model.
        # NOTE(review): whether these are logits or probabilities cannot be
        # told from this cell — verify that rounding is the intended threshold.
        record_feature["categorisation"] = create_float_feature(np.round(cat_predictions[0])[0])
        # Row index in the CSV, kept so the record can be traced back to it.
        record_feature["label_idx"] = create_int_feature([idx])

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
        writer.write(tf_example.SerializeToString())

Writing example 0 of 553030 :  0.0%
Elapsed time is 0.000998 seconds.
Writing example 100 of 553030 :  0.02%
Elapsed time is 86.205768 seconds.
Writing example 200 of 553030 :  0.04%
Elapsed time is 176.172176 seconds.
Writing example 300 of 553030 :  0.05%
Elapsed time is 264.363220 seconds.
Writing example 400 of 553030 :  0.07%
Elapsed time is 359.242074 seconds.
Writing example 500 of 553030 :  0.09%
Elapsed time is 452.591935 seconds.
Writing example 600 of 553030 :  0.11%
Elapsed time is 537.975520 seconds.
Writing example 700 of 553030 :  0.13%
Elapsed time is 621.078281 seconds.
Writing example 800 of 553030 :  0.14%
Elapsed time is 703.591844 seconds.
Writing example 900 of 553030 :  0.16%
Elapsed time is 799.414015 seconds.
Writing example 1000 of 553030 :  0.18%
Elapsed time is 901.373733 seconds.
Writing example 1100 of 553030 :  0.2%
Elapsed time is 1003.186386 seconds.
Writing example 1200 of 553030 :  0.22%
Elapsed time is 1095.906729 seconds.
Writing example 1300 of 553