In [1]:
%%capture
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.sql.types import StringType, IntegerType

from pyspark.ml.feature import *
from pyspark.ml import Pipeline

# Start the Spark session through Spark NLP's helper; `spark` is the entry
# point the later cells use to read the dataset and build DataFrames.
spark = sparknlp.start()

In [14]:
# Load the processed paper corpus from the project bucket and inspect its schema.
# NOTE(review): the path ends in ".parquet" but is parsed with the CSV reader
# (header row, quoted multi-line fields) -- confirm the on-disk format; if it
# really is Parquet, `spark.read.parquet` is the right reader.
processed_data_path = "gs://bdp_group6_bckt_2/data/processed_data/processed_data.parquet"
csv_options = dict(
    inferSchema=True,   # let Spark guess column types
    header=True,        # first record holds the column names
    multiLine=True,     # quoted fields may span several lines
    quote='\"',
    escape='\"',
)
df = spark.read.csv(processed_data_path, **csv_options)
df.printSchema()

[Stage 10:>                                                         (0 + 1) / 1]

root
 |-- paper_id: string (nullable = true)
 |-- cord_uid: string (nullable = true)
 |-- source_x: string (nullable = true)
 |-- publish_time: string (nullable = true)
 |-- journal: string (nullable = true)
 |-- body_text: string (nullable = false)
 |-- title: string (nullable = true)
 |-- abstract: string (nullable = true)
 |-- authors: string (nullable = true)



                                                                                

In [15]:
# Remove metadata columns that the question-answering cells below do not use,
# then preview a handful of rows.
columns_to_drop = ["authors", "publish_time", "journal", "abstract"]
df = df.drop(*columns_to_drop)

df.show(5)

+--------+--------------------+--------------------+
|paper_id|               title|           body_text|
+--------+--------------------+--------------------+
|PMC35282|Clinical features...|Mycoplasma pneumo...|
|PMC59543|Nitric oxide: a p...|Since its discove...|
|PMC59549|Surfactant protei...|Surfactant protei...|
|PMC59574|Role of endotheli...|ET-1, ET-2, and E...|
|PMC59580|Gene expression i...|RSV and PVM are v...|
+--------+--------------------+--------------------+
only showing top 5 rows



In [24]:
# Use the body text of the first 5 papers as contexts and pair each one with a
# hand-written question.
# NOTE(review): `limit(5)` has no explicit ordering, so the questions only match
# their contexts if Spark returns the same five rows -- confirm on re-run.
sample_papers = df.limit(5).toPandas()
context_samples = sample_papers["body_text"].tolist()

questions = [
    "What is the most common cause of atypical pneumonia?",
    "What is the main reason why NO production is regulated?",
    "What is the name of the molecule that binds to a pulmonary pathogen?",
    "What factors can cause the development of pulmonary fibrosis?",
    "How does pneumoviruses enter respiratory epithelial cells?",
]

# Build the (question, context) test rows; indexing (rather than zip) keeps the
# hard failure if fewer than five contexts were loaded.
sample_texts = [
    [question, context_samples[position]]
    for position, question in enumerate(questions)
]
qa_df = spark.createDataFrame(sample_texts, schema=["question", "context"])
qa_df.show(truncate=50)

In [27]:
# Define a pipeline to test several models
def create_pipeline(model, model_name, case_sensitive=False, max_sentence_length=None):
    """Build a two-stage Spark NLP question-answering pipeline.

    Stage 1 wraps the ``question`` and ``context`` columns into document
    annotations; stage 2 runs a pretrained span-extraction model over them and
    writes its prediction to the ``answer`` column.

    Parameters
    ----------
    model : class
        A Spark NLP ``*ForQuestionAnswering`` annotator class exposing
        ``pretrained(name=...)`` (e.g. ``BertForQuestionAnswering``).
    model_name : str
        Name of the pretrained checkpoint to download.
    case_sensitive : bool, default False
        Forwarded to ``setCaseSensitive``. The default keeps the previous
        hard-coded behaviour; pass ``True`` for cased checkpoints so the input
        is not case-folded before inference.
    max_sentence_length : int or None, default None
        When given, forwarded to ``setMaxSentenceLength``; ``None`` leaves the
        annotator's own default untouched. Useful because the contexts here are
        whole paper bodies.

    Returns
    -------
    pyspark.ml.Pipeline
        An unfitted pipeline; call ``.fit(df).transform(df)`` on a DataFrame
        with ``question`` and ``context`` columns.
    """
    document_assembler = (
        MultiDocumentAssembler()
        .setInputCols(["question", "context"])
        .setOutputCols(["document_question", "document_context"])
    )

    qa_model = (
        model.pretrained(name=model_name)
        .setInputCols(["document_question", "document_context"])
        .setOutputCol("answer")
        .setCaseSensitive(case_sensitive)
    )
    # Only override the length limit when the caller asks for it, so existing
    # two-argument calls behave exactly as before.
    if max_sentence_length is not None:
        qa_model = qa_model.setMaxSentenceLength(max_sentence_length)

    return Pipeline(stages=[document_assembler, qa_model])

### English BertForQuestionAnswering Cased model (from Callmenicky) [[Model Details](https://nlp.johnsnowlabs.com/2022/07/07/bert_qa_callmenicky_finetuned_squad_en_3_0.html)]

In [28]:
# Run the BERT checkpoint over the sample questions.
pipeline1 = create_pipeline(
    BertForQuestionAnswering,
    "bert_qa_callmenicky_finetuned_squad",
)
fitted_pipeline1 = pipeline1.fit(qa_df)
# Despite its name, `model1` is the DataFrame of predictions, not a fitted model.
model1 = fitted_pipeline1.transform(qa_df)
model1.select("question", "answer.result").show(truncate=False)

bert_qa_callmenicky_finetuned_squad download started this may take some time.
Approximate size to download 385.6 MB
[ | ]bert_qa_callmenicky_finetuned_squad download started this may take some time.
Approximate size to download 385.6 MB
[ — ]Download done! Loading the resource.
[OK!]




+--------------------------------------------------------------------+-----------------------+
|question                                                            |result                 |
+--------------------------------------------------------------------+-----------------------+
|What is the most common cause of atypical pneumonia?                |[Mycoplasma pneumoniae]|
|What is the main reason why NO production is regulated?             |[]                     |
|What is the name of the molecule that binds to a pulmonary pathogen?|[]                     |
|What factors can cause the development of pulmonary fibrosis?       |[]                     |
|How does pneumoviruses enter respiratory epithelial cells?          |[]                     |
+--------------------------------------------------------------------+-----------------------+



                                                                                

### English DeBertaForQuestionAnswering model (from nbroad) [[Model Details](https://nlp.johnsnowlabs.com/2022/06/15/deberta_v3_xsmall_qa_squad2_en_3_0.html)]

In [30]:
# Run the DeBERTa checkpoint over the sample questions.
pipeline2 = create_pipeline(
    DeBertaForQuestionAnswering,
    "deberta_v3_xsmall_qa_squad2",
)
fitted_pipeline2 = pipeline2.fit(qa_df)
# Despite its name, `model2` is the DataFrame of predictions, not a fitted model.
model2 = fitted_pipeline2.transform(qa_df)
model2.select("question", "answer.result").show(truncate=False)

deberta_v3_xsmall_qa_squad2 download started this may take some time.
Approximate size to download 240.6 MB
[ | ]deberta_v3_xsmall_qa_squad2 download started this may take some time.
Approximate size to download 240.6 MB
[ — ]Download done! Loading the resource.
[OK!]




+--------------------------------------------------------------------+-----------------------+
|question                                                            |result                 |
+--------------------------------------------------------------------+-----------------------+
|What is the most common cause of atypical pneumonia?                |[mycoplasma pneumoniae]|
|What is the main reason why NO production is regulated?             |[]                     |
|What is the name of the molecule that binds to a pulmonary pathogen?|[]                     |
|What factors can cause the development of pulmonary fibrosis?       |[endothelins]          |
|How does pneumoviruses enter respiratory epithelial cells?          |[]                     |
+--------------------------------------------------------------------+-----------------------+



                                                                                

### English LongformerForQuestionAnswering model (from allenai) [[Model Details](https://nlp.johnsnowlabs.com/2022/06/26/longformer_qa_large_4096_finetuned_triviaqa_en_3_0.html)]

In [31]:
# Run the Longformer checkpoint over the sample questions.
pipeline3 = create_pipeline(
    LongformerForQuestionAnswering,
    "longformer_qa_large_4096_finetuned_triviaqa",
)
fitted_pipeline3 = pipeline3.fit(qa_df)
# Despite its name, `model3` is the DataFrame of predictions, not a fitted model.
model3 = fitted_pipeline3.transform(qa_df)
model3.select("question", "answer.result").show(truncate=False)

longformer_qa_large_4096_finetuned_triviaqa download started this may take some time.
Approximate size to download 1.5 GB
[ | ]longformer_qa_large_4096_finetuned_triviaqa download started this may take some time.
Approximate size to download 1.5 GB
[ / ]Download done! Loading the resource.
[OK!]




+--------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|question                                                            |result                                                                                                                                                                                                                                                                                                                                                                                                                       

                                                                                

### English RoBertaForQuestionAnswering model (from deepset) [[Model Details](https://nlp.johnsnowlabs.com/2022/06/20/roberta_qa_roberta_base_squad2_covid_en_3_0.html)]

In [34]:
# Run the RoBERTa checkpoint over the sample questions.
pipeline4 = create_pipeline(
    RoBertaForQuestionAnswering,
    "roberta_qa_roberta_base_squad2_covid",
)
fitted_pipeline4 = pipeline4.fit(qa_df)
# Despite its name, `model4` is the DataFrame of predictions, not a fitted model.
model4 = fitted_pipeline4.transform(qa_df)
model4.select("question", "answer.result").show(truncate=False)

roberta_qa_roberta_base_squad2_covid download started this may take some time.
Approximate size to download 442.8 MB
[OK!]




+--------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                