In [None]:
import math
import os
import re
import json
import csv
import string
import numpy as np
import tensorflow as tf
import pickle
from tensorflow.data import Dataset
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
from tqdm import tqdm
from random import randrange

# Token window size: used as the model's input length and as the chunk size
# when documents are split for prediction.
CONTEXT_LEN = 512
# NOTE(review): MODEL_DIR is not referenced in the visible cells — confirm it is still needed.
MODEL_DIR = "/kaggle/input/huggingface-bert/"
configuration = BertConfig()  # default BERT configuration; NOTE(review): not used in the visible cells

# Load the fast WordPiece tokenizer from the saved SciBERT vocabulary file (lowercasing enabled).
tokenizer = BertWordPieceTokenizer("../input/scibert-210605/vocab.txt", lowercase=True)

In [None]:
def load_model():
    """Build the span-extraction model and restore its saved weights.

    The network is a BERT encoder (loaded from a PyTorch checkpoint) with three
    heads on top of the per-token hidden states:

    * start head: ``Dense(1)`` per token, flattened, softmax over positions;
    * end head: same layout as the start head;
    * confidence head: all token states flattened, ``Dense(1)``, sigmoid.

    After compilation, layer weights are restored from a pickled list that is
    paired positionally with ``model.layers``; falsy entries are skipped.

    Returns:
        keras.Model: compiled model taking ``[input_ids, attention_mask]``
        (each ``(batch, CONTEXT_LEN)``, int32) and returning
        ``[start_probs, end_probs, confidence]``.
    """
    # NOTE: the statement order below is kept deliberately — the saved weights
    # are matched to ``model.layers`` by position, so layer creation order matters.

    ## BERT encoder
    encoder = TFBertModel.from_pretrained("../input/scibert-210605/", from_pt=True)

    input_ids = layers.Input(shape=(CONTEXT_LEN,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(CONTEXT_LEN,), dtype=tf.int32)
    # Element 0 of the encoder output is used as the per-token representation.
    embedding = encoder(
        input_ids, attention_mask=attention_mask
    )[0]

    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    confidence = layers.Flatten()(embedding)
    confidence = layers.Dense(1, name="confidence", use_bias=False)(confidence)
    confidence = layers.Activation(keras.activations.sigmoid)(confidence)

    model = keras.Model(
        inputs=[input_ids, attention_mask],
        outputs=[start_probs, end_probs, confidence]
    )
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # ``lr`` is a deprecated alias that newer Keras releases reject;
    # ``learning_rate`` is accepted by every TF 2.x version.
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(
        optimizer=optimizer,
        loss=[loss, loss, keras.losses.BinaryCrossentropy(from_logits=False)],
    )

    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # weight files from a trusted source.
    with open("../input/coleridgebertweights/part2_scibert_len512.dat", "rb") as file:
        weights = pickle.load(file)

    # zip() stops at the shorter sequence, so a length mismatch between the
    # saved list and model.layers goes unnoticed here.
    for layer, layer_weights in zip(model.layers, weights):
        if layer_weights:
            layer.set_weights(layer_weights)

    return model

In [None]:
# Build the model once at notebook level; make_excerpts reads this global.
model = load_model()

In [None]:
def load_json(path, pub_id):
    """Load the parsed JSON document of one publication.

    Args:
        path: Directory that contains the publication files. A trailing path
            separator is optional.
        pub_id: Publication id, i.e. the file name without the ``.json`` suffix.

    Returns:
        The decoded JSON content of ``<path>/<pub_id>.json``.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # os.path.join works with or without a trailing separator on ``path``.
    filepath = os.path.join(path, "{}.json".format(pub_id))
    with open(filepath, "r", encoding="utf-8") as file:
        return json.load(file)

def clean_text(txt):
    """Lower-case ``txt`` and replace every run of non-alphanumeric characters with one space.

    The argument is converted with ``str()`` first, so non-string values are accepted.
    """
    lowered = str(txt).lower()
    non_alnum_run = re.compile('[^A-Za-z0-9]+')
    return non_alnum_run.sub(' ', lowered)

def concat_sections(sections):
    """Join the ``'text'`` field of every section into one space-separated string."""
    texts = []
    for section in sections:
        texts.append(section['text'])
    return " ".join(texts)

def find_matches(text, label, ignore_case=False):
    """Return the start offset of every non-overlapping occurrence of ``label`` in ``text``.

    ``label`` is matched literally (regex metacharacters are escaped).

    Args:
        text: String to search in.
        label: Literal string to look for. An empty label matches at every
            position of ``text``.
        ignore_case: When true, match case-insensitively. Defaults to
            ``False`` (exact-case matching).

    Returns:
        list[int]: Start indices of the matches, in ascending order.
    """
    flags = re.IGNORECASE if ignore_case else 0
    pattern = re.compile(re.escape(label), flags)
    return [match.start() for match in pattern.finditer(text)]

In [None]:
# Known dataset labels, normalised with clean_text so they can be compared
# against cleaned document text.
# SECURITY: pickle.load executes arbitrary code from the file; trusted input only.
with open("../input/crilabels/db_labels.dat", "rb") as file:
    labels = list(map(clean_text, pickle.load(file)))
print(labels)

In [None]:
def chunk_text(text):
    """Tokenize ``text`` and split the token ids into fixed-size model windows.

    The ids are laid out consecutively and cut into rows of ``CONTEXT_LEN``
    tokens with no overlap; the unused tail of the last row is zero-filled.

    Args:
        text: Document text to tokenize.

    Returns:
        list: ``[ids, masks]`` — two int32 arrays of shape
        ``(chunk_count, CONTEXT_LEN)``. ``masks`` is 1 on real tokens and 0 on
        padding positions.
    """
    tokens = tokenizer.encode(text).ids
    token_count = len(tokens)

    chunk_count = int(math.ceil(token_count / CONTEXT_LEN))
    padded_len = chunk_count * CONTEXT_LEN

    # Token ids and masks are integers, and the model inputs are declared as
    # int32, so build them with an integer dtype rather than float32.
    # Padding uses id 0 — presumably the vocabulary's [PAD] id; verify.
    flattened_ids = np.zeros((padded_len,), dtype=np.int32)
    flattened_masks = np.zeros((padded_len,), dtype=np.int32)

    flattened_ids[:token_count] = tokens
    flattened_masks[:token_count] = 1

    ids = flattened_ids.reshape((chunk_count, CONTEXT_LEN))
    masks = flattened_masks.reshape((chunk_count, CONTEXT_LEN))

    return [ids, masks]

def make_excerpts(text):
    """Predict dataset mentions for one document.

    Two sources are combined:

    * spans decoded from the model output, for every chunk whose confidence
      is at least 0.5;
    * every entry of the global ``labels`` list that occurs in the cleaned text.

    Args:
        text: Full document text.

    Returns:
        set[str]: Non-empty excerpts, normalised with ``clean_text`` and with
        whitespace collapsed to single spaces.
    """
    chunk_ids, chunk_masks = chunk_text(text)
    start_probs, end_probs, confs = model.predict([chunk_ids, chunk_masks])

    # Keep only the chunks the confidence head flags as containing a mention.
    label_chunks = np.ravel(confs) >= 0.5
    starts = np.argmax(start_probs[label_chunks], axis=1)
    # NOTE(review): ``ends`` is argmax + 1 and the slice below uses ``end + 1``,
    # so the decoded span runs one token past the predicted end token. Kept
    # as-is because the training-time label convention is not visible here —
    # confirm whether the double offset is intended.
    ends = np.argmax(end_probs[label_chunks], axis=1) + 1
    excerpts = [
        # decode() is given a plain list of Python ints.
        tokenizer.decode(chunk[start:end + 1].astype(int).tolist())
        for chunk, start, end in zip(chunk_ids[label_chunks], starts, ends)
    ]

    # Labels are literal strings, so a substring test is enough (no regex needed).
    cleaned_text = clean_text(text)
    excerpts.extend(lbl for lbl in labels if lbl in cleaned_text)

    normalised = (" ".join(clean_text(excerpt).split()) for excerpt in excerpts)
    # Drop empty results (e.g. predicted start after end, or spans that decode
    # to nothing) so they do not end up as blank predictions.
    return {excerpt for excerpt in normalised if excerpt}

In [None]:
test_path = "../input/coleridgeinitiative-show-us-the-data/test"
_, _, filenames = next(os.walk(test_path))
# Only the JSON publications are processed; sorting makes the row order reproducible.
json_filenames = sorted(name for name in filenames if name.endswith(".json"))

# newline="" lets the csv module control line endings (avoids blank rows on
# platforms that translate "\n").
with open("submission.csv", "w", newline="") as submissions:
    writer = csv.writer(submissions)

    # headers
    writer.writerow(["Id", "PredictionString"])
    # entries
    for filename in json_filenames:
        pub_id = os.path.splitext(filename)[0]
        # Pass the directory with a trailing separator so the file name can be
        # appended to it directly.
        sections = load_json(os.path.join(test_path, ""), pub_id)
        text = concat_sections(sections)
        excerpts = make_excerpts(text)
        # Sets have no stable order; sort so reruns produce identical files.
        writer.writerow([pub_id, "|".join(sorted(excerpts))])
