In [17]:
import re

from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

In [2]:
# Initialise a local Spark session with the Spark NLP package on the classpath

spark = (
    SparkSession.builder
    .appName("Spark NLP")
    .master("local[4]")  # run locally with 4 worker threads
    .config("spark.driver.memory", "16G")
    .config("spark.driver.maxResultSize", "0")  # "0" disables the driver result-size limit (Spark config docs)
    .config("spark.kryoserializer.buffer.max", "2000M")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:3.3.4")
    .getOrCreate()
)

In [None]:
# Load BERT fine tuned model

class BertClassifier(nn.Module):
    """BERT encoder followed by a small feed-forward classification head.

    The head is applied to the final hidden state of the first token
    (`[CLS]`) and produces one logit per label.
    """
    def __init__(self, freeze_bert=False, hidden_size=50, num_labels=17,
                 pretrained_model_name='bert-base-uncased'):
        """
        @param    freeze_bert (bool): if True, BERT's parameters get
                      `requires_grad = False` so only the classifier head is
                      trained; `False` fine-tunes the BERT model as well.
        @param    hidden_size (int): width of the classifier's hidden layer.
        @param    num_labels (int): number of output logits.
        @param    pretrained_model_name (str): name or path passed to
                      `BertModel.from_pretrained`.
        """
        super().__init__()

        # Instantiate BERT model
        self.bert = BertModel.from_pretrained(pretrained_model_name)

        # The head's input width is BERT's hidden size (768 for bert-base),
        # read from the loaded config instead of being hard-coded.
        bert_hidden_size = self.bert.config.hidden_size

        # One-hidden-layer feed-forward classifier. Keep the layer order
        # (Linear, ReLU, Linear): parameter names such as `classifier.0` and
        # `classifier.2` depend on it, so trained weights only fit this layout.
        self.classifier = nn.Sequential(
            nn.Linear(bert_hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_labels),
        )

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param    input_ids (torch.Tensor): an input tensor with shape (batch_size,
                      max_length)
        @param    attention_mask (torch.Tensor): a tensor that holds attention mask
                      information with shape (batch_size, max_length)
        @return   logits (torch.Tensor): an output tensor with shape (batch_size,
                      num_labels)
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # outputs[0] holds the last hidden states; position 0 along the
        # sequence axis is the `[CLS]` token used for classification.
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits

# Named `bert_model` (not `model`) so that the Spark NLP pipeline model fitted
# further down in this notebook does not overwrite the loaded classifier.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load
# trusted checkpoints. The file is presumably a whole pickled BertClassifier
# (`.eval()` is called on the result), so the class must be defined first.
bert_model = torch.load(
    'model/bert_trials.pth',
    # Remap GPU-saved tensors onto the CPU when no GPU is available;
    # `None` keeps torch's default device placement otherwise.
    map_location=None if torch.cuda.is_available() else 'cpu',
)
bert_model.eval()  # switch off training-only behaviour such as dropout

In [3]:
# Load data

CHEMBL_EVIDENCE_PATH = 'data/chembl-2021-08-23.json.gz'
SAMPLE_FRACTION = 0.01  # fraction of evidence rows kept as a quick test set
SAMPLE_SEED = 42        # fixed seed so the sampled test set is reproducible

stopReasons = (
    spark.read.json(CHEMBL_EVIDENCE_PATH)

    # Extract a reproducible test set
    .sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)

    # Keep only evidence with a reason to stop; filtering before the explode
    # gives the same rows while exploding less data
    .filter(F.col('studyStopReason').isNotNull())

    # One row per URL; keep ClinicalTrials links and take the NCT id as the
    # second-to-last piece of the URL split on '%22' (an encoded double quote)
    .withColumn('urls', F.explode('urls'))
    .filter(F.col('urls.niceName').contains('ClinicalTrials'))
    .withColumn('nct_id', F.element_at(F.split(F.col('urls.url'), '%22'), -2))
    .select('nct_id', 'studyStopReason')
    .distinct()
)

## Create Pipeline

What the Pipeline should consist of:
- Document Assembler: converts the raw string to documents that Spark NLP can handle.
- Tokenize each document with exactly the settings used to fine-tune the BERT model.
  The Hugging Face call below has to be reproduced with Spark NLP's built-in tokenizer:
  ```
  encoded_sent = ( 
            BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
            .encode_plus(
            text=text_preprocessing(sent),  # Preprocess sentence
            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
            max_length=MAX_LEN,             # Max length to truncate/pad
            pad_to_max_length=True,         # Pad sentence to max length
            #return_tensors='pt',           # Return PyTorch tensor
            return_attention_mask=True      # Return attention mask
    )
  ```
- Create a DataLoader. The fine-tuned BERT model is fed two tensors: `input_ids` (the vocabulary id of each token) and `attention_mask` (a mask that marks whether each position is a real token or padding).

In [16]:
# Stage 1: wrap each raw stop reason into a Spark NLP `document` annotation
document = DocumentAssembler() \
    .setInputCol('studyStopReason') \
    .setOutputCol('document')

# Stage 2: split each document into `token` annotations
tokenizer = (
    Tokenizer()
    .setInputCols('document')
    .setOutputCol('token')
)

In [10]:
# Assemble the two stages in order and fit them on the sampled stop reasons
nlp_stages = [document, tokenizer]
pipeline = Pipeline(stages=nlp_stages)

model = pipeline.fit(stopReasons)

In [13]:
model.transform(stopReasons).head()  # first annotated row (None if the frame is empty)

Row(nct_id='NCT00880373', studyStopReason='The funding withdrawal and early termination of the trial is based upon lack of suitable recruitment figures in order to reach the required trial endpoints.', document=[Row(annotatorType='document', begin=0, end=155, result='The funding withdrawal and early termination of the trial is based upon lack of suitable recruitment figures in order to reach the required trial endpoints.', metadata={'sentence': '0'}, embeddings=[])], token=[Row(annotatorType='token', begin=0, end=2, result='The', metadata={'sentence': '0'}, embeddings=[]), Row(annotatorType='token', begin=4, end=10, result='funding', metadata={'sentence': '0'}, embeddings=[]), Row(annotatorType='token', begin=12, end=21, result='withdrawal', metadata={'sentence': '0'}, embeddings=[]), Row(annotatorType='token', begin=23, end=25, result='and', metadata={'sentence': '0'}, embeddings=[]), Row(annotatorType='token', begin=27, end=31, result='early', metadata={'sentence': '0'}, embeddings=[

### Problem: a custom tokenizer (here the Hugging Face `BertTokenizer`) cannot be added as a stage of the Spark NLP pipeline, so tokenization has to be done more manually

In [None]:
# Same tokenizer configuration as the fine-tuning call shown in the markdown above
bert_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

def clean_sentence(sentece:str) -> str:
    """
    Normalise a free-text sentence before BERT tokenisation.

    - Remove entity mentions (eg. '@united'), including one at the very end
      of the string.
    - Correct HTML-escaped ampersands (eg. '&amp;' to '&').
    - Collapse every run of whitespace into a single space and strip both ends.

    @param    sentece (str): the string to be processed (parameter name kept
                  unchanged for callers that pass it by keyword).
    @return   (str): the processed string.
    """
    # Remove '@name' tokens. `@\S*` also catches a mention at the end of the
    # string, which the former `(@.*?)[\s]` pattern (it required a following
    # whitespace character) silently left in place.
    cleaned = re.sub(r'@\S*', ' ', sentece)

    # Replace '&amp;' with '&'
    cleaned = re.sub(r'&amp;', '&', cleaned)

    # Collapse internal whitespace runs to one space and strip both ends
    return re.sub(r'\s+', ' ', cleaned).strip()

def apply_bert_tokenizer(sentence:str, bert_tokenizer, max_len:int=64):
    """
    Clean a sentence and encode it with a BERT tokenizer.

    `max_len` previously had no default: the original signature
    `max_len:int==64` declared an annotation (`int == 64`, i.e. False) instead
    of a default value, so callers were silently forced to pass it.

    @param    sentence (str): raw text; it is passed through `clean_sentence` first.
    @param    bert_tokenizer: tokenizer exposing `encode_plus` (e.g. a Hugging
                  Face `BertTokenizer`); should match the one used for fine-tuning.
    @param    max_len (int): length every encoded sequence is truncated/padded to.
    @return   the `encode_plus` result, including `input_ids` and
              `attention_mask`, as PyTorch tensors.
    """
    cleaned_sentence = clean_sentence(sentence)

    return bert_tokenizer.encode_plus(
        text=cleaned_sentence,        # Preprocessed sentence
        add_special_tokens=True,      # Add `[CLS]` and `[SEP]`
        max_length=max_len,           # Max length to truncate/pad
        padding='max_length',         # Pad sentence to max length (replaces deprecated pad_to_max_length)
        truncation=True,              # Truncate explicitly rather than relying on implicit truncation
        return_tensors='pt',          # Return PyTorch tensors
        return_attention_mask=True,   # Return attention mask
    )