## Purpose: To train/evaluate a custom relation extraction model/prediction pipeline for Sites of Metastases 

RE model training (required entity pairs):
- cancer imaging findings - body part
- cancer imaging findings - anatomical descriptor
- anatomical descriptor - body part
- direction - body part [this is currently not used in sites of mets prediction]
- direction - anatomical descriptor [this is currently not used in sites of mets prediction]

Single pipeline for Sites of Mets prediction:
- custom NER
- custom Assertion
- custom RE
- post processing for term normalization (using mapping files)

refer: https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/10.Clinical_Relation_Extraction.ipynb#scrollTo=6doZTPX_xnEm

In [None]:
# uncomment to run to create the subfolders, for the first time
#!mkdir re_graph saved_models re_output re_result re_logs inference

### Note: Before running this notebook, please configure the following paths

In [None]:
# we are using sparknlp clinical embedding word model
# specify your folder containing the downloaded clinical embedding word model file, or you can use .pretrained during training instead to load it online
embeddings_clinical_local_path = r"path\to\sparknlp_pretrained\embeddings_clinical_en_2.4.0_2.4_1580237286004"
model_type = "clinical_embeddings"

In [None]:
# specify your sparknlp online license key-need internet connection
# we are using v3.4.2
sparknlp_licence_key = r"..\sparknlp_licence_key\yourkey.json"

# specify your sparknlp offline license key-airgap env
# we are using v3.4.2
sparknlp_airgap_licence_key = r"..\sparknlp_licence_key\yourairgapkey.json"

## Import Libraries

In [None]:
import json, os, re, sparknlp, sparknlp_jsl, datetime, time
import pandas as pd
import numpy as np
import glob
import os

from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.training import CoNLL
from sparknlp_jsl.annotator import *
from sparknlp_jsl.training import tf_graph
from sparknlp_display import AssertionVisualizer, NerVisualizer,RelationExtractionVisualizer 

from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

## Start Spark Session (offline)

In [None]:
# Offline: load the airgap license key file (path configured near the top of the notebook).
with open(sparknlp_airgap_licence_key) as f:
    airgap_license_keys = json.load(f)

# Expose every license key/value pair as a notebook-level variable (SECRET, JSL_VERSION, ...)
# and as an environment variable. globals() is used instead of locals() because writing
# through locals() is only reliable at module scope.
globals().update(airgap_license_keys)
os.environ.update(airgap_license_keys)

# Sanity check. The license secret itself is deliberately NOT echoed: cell outputs are stored
# inside the .ipynb file and would leak the credential to anyone the notebook is shared with.
print("SECRET loaded:", bool(airgap_license_keys.get("SECRET")))
print("JSL_VERSION:", airgap_license_keys.get("JSL_VERSION"))
print("PUBLIC_VERSION:", airgap_license_keys.get("PUBLIC_VERSION"))

# Python executables used by PySpark.
os.environ['PYSPARK_PYTHON'] = 'python'
os.environ['PYSPARK_DRIVER_PYTHON'] = 'jupyter'
print(os.environ['PYSPARK_PYTHON'])
print(os.environ['PYSPARK_DRIVER_PYTHON'])

# use this 20-sep-2022
# Start Spark Session with Custom Params (OFFLINE)
# https://spark.apache.org/docs/latest/configuration.html#memory-management
# Important! memory setting need to be adjusted for different work load 

def start(SECRET):
    """Create (or return the already running) Spark session for the offline/airgap setup.

    Args:
        SECRET: Spark NLP for Healthcare license secret. Kept for interface
            compatibility; it is not read here because the jars are loaded from local disk.

    Returns:
        The SparkSession returned by ``builder.getOrCreate()``.

    Note:
        Reads the notebook-level ``JSL_VERSION`` variable that is set when the license
        file is loaded.
    """
    # Raw string literals are required for these Windows paths: in a normal literal the
    # "\a" of "d:\airgap" is the BEL control character, which silently corrupts the path
    # of the second jar. The entries are joined with a bare comma (no padding space).
    local_jars = ",".join([
        rf"d:\content\spark-nlp-jsl-{JSL_VERSION}.jar",
        r"d:\airgap\spark-nlp_2.12-3.4.2.jar",
    ])

    builder = (
        SparkSession.builder
        .appName("Spark NLP Licensed radio mets jupyter")
        .master("local[48]")
        .config("spark.driver.memory", "90G")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.kryoserializer.buffer.max", "2000M")
        .config("spark.driver.maxResultSize", "8000M")
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:3.4.2")
        .config("spark.jars", local_jars)
    )

    return builder.getOrCreate()


print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark = start(SECRET) 

spark

## Start Spark Session (online)

Note: Requires Spark NLP and Spark NLP for Healthcare (licensed version) packages to be installed

## Import train/test data

In [None]:
## configure folder path
data_folder = "dataset"
train_folder = data_folder+"\\02csv"
dataset_name = "train4522"

In [None]:
# read in saved train.csv
df_csv_clean = pd.read_csv(os.path.join(train_folder,"radio_re_"+dataset_name+"_sitesofmets_relations_clean.csv"))

## create pyspark schema

In [None]:
from pyspark.sql.types import StructType,StructField, StringType, IntegerType
# Explicit schema for the relations CSV: text columns are strings, the four entity
# offsets are integers, and every field is nullable.
df1Schema = StructType([
    StructField(field_name, field_type(), True)
    for field_name, field_type in [
        ("relation", StringType),
        ("pairs", StringType),
        ("entity1", StringType),
        ("chunk1", StringType),
        ("entity2", StringType),
        ("chunk2", StringType),
        ("entity1_begin", IntegerType),
        ("entity1_end", IntegerType),
        ("entity2_begin", IntegerType),
        ("entity2_end", IntegerType),
        ("doc_text", StringType),
        ("doc_title", StringType),
        ("dataset", StringType),
    ]
])

## create spark dataframe

In [None]:
data=spark.createDataFrame(df_csv_clean, schema=df1Schema) 
data.count()

In [None]:
data.show(1)

In [None]:
# sort by dataset,relation
data.groupby(['dataset','relation']).count().sort(['dataset','relation']).show(10,truncate=False)

## Create a graph

In [None]:
#!pip install tensorflow==1.15

In [None]:
#!pip install -q tensorflow-addons

In [None]:
# Empty results table; the metrics cell appends one row per trained model with its
# hyper-parameters, timing information and evaluation metrics.
output_df = pd.DataFrame(
    columns=[
        're_model', 'ner_model', 'rels_set',
        'trainset_count', 'testset_count',
        'epoch', 'learning_rate', 'batch_size',
        'start_time', 'end_time', 'duration',
        'overall_accuracy', 'class_accuracy',
        'classification_report', 'confusion_matrix',
    ]
)

## Data Preparation (RE_SITES_OF_METS)

In [None]:
import pyspark.sql.types as T

# Spark struct type describing one annotation as returned by the UDF below:
# annotatorType, begin, end, result, metadata (string -> string map), embeddings (float array).
annotationType = T.StructType([
            T.StructField('annotatorType', T.StringType(), False),
            T.StructField('begin', T.IntegerType(), False),
            T.StructField('end', T.IntegerType(), False),
            T.StructField('result', T.StringType(), False),
            T.StructField('metadata', T.MapType(T.StringType(), T.StringType()), False),
            T.StructField('embeddings', T.ArrayType(T.FloatType()), False)
        ])

# UDF that converts the entity columns of one training row into named-entity chunk annotations

@F.udf(T.ArrayType(annotationType))
def createTrainAnnotations(begin1, end1, begin2, end2, chunk1, chunk2, label1, label2):
    """Build the two "chunk" annotations (entity1, entity2) for one training row.

    Args:
        begin1, end1: begin/end offsets of the first entity.
        begin2, end2: begin/end offsets of the second entity.
        chunk1, chunk2: text of the first/second entity.
        label1, label2: entity labels; stored lower-cased under the 'entity' metadata key.

    Returns:
        A two-element list ``[entity1, entity2]`` of Annotation objects with empty embeddings.
    """
    
    # The sentence id is fixed to '0' for every chunk.
    entity1 = sparknlp.annotation.Annotation("chunk", begin1, end1, chunk1, {'entity': label1.lower(), 'sentence': '0'}, [])
    entity2 = sparknlp.annotation.Annotation("chunk", begin2, end2, chunk2, {'entity': label2.lower(), 'sentence': '0'}, [])    
        
    # NOTE(review): presumably required so that the objects carry an attribute named exactly
    # like the 'annotatorType' field of annotationType when Spark converts them to structs.
    # Do not remove without confirming against the Annotation class of the installed version.
    entity1.annotatorType = "chunk"
    entity2.annotatorType = "chunk"

    return [entity1, entity2]    

# Entity-type pairs kept for training; matched against the `pairs` column.
train_pair = [
    'cancer_imaging_findings-body_part',
    'cancer_imaging_findings-anatomical_descriptor',
    'anatomical_descriptor-body_part',
    'direction-body_part',
    'direction-anatomical_descriptor',
]

# start of data preparation
# SQL predicate that keeps only the rows whose `pairs` value is one of the pairs above.
valid_rel_query = "(" + " OR ".join(f"pairs = '{p}'" for p in train_pair) + ")"
print(valid_rel_query)

# Cast the four offsets to int, drop rows where any offset is missing, keep the valid
# pairs and attach the chunk annotations the RE trainer reads as "train_ner_chunks".
data2 = (
    data
    .withColumn("entity1_begin", F.col("entity1_begin").cast("int"))
    .withColumn("entity1_end", F.col("entity1_end").cast("int"))
    .withColumn("entity2_begin", F.col("entity2_begin").cast("int"))
    .withColumn("entity2_end", F.col("entity2_end").cast("int"))
    .where(
        "entity1_begin IS NOT NULL AND entity1_end IS NOT NULL "
        "AND entity2_begin IS NOT NULL AND entity2_end IS NOT NULL"
    )
    .where(valid_rel_query)
    .withColumn(
        "train_ner_chunks",
        createTrainAnnotations(
            "entity1_begin", "entity1_end", "entity2_begin", "entity2_end",
            "chunk1", "chunk2", "entity1", "entity2",
        ).alias("train_ner_chunks", metadata={'annotatorType': "chunk"}),
    )
)

# The train/test split is taken from the `dataset` column of the input file.
train_data = data2.where("dataset='train'")
test_data = data2.where("dataset='test'")

#===================================================
print("total row count:",data2.count())

# Row counts per entity pair, stored later with the model performance record.
trainset_count = train_data.groupby('pairs').count().collect()
testset_count = test_data.groupby('pairs').count().collect()


In [None]:
train_data.groupby('relation').count().sort('relation').show(50,truncate=False)

In [None]:
test_data.groupby('relation').count().sort('relation').show(50,truncate=False)

## ------------------- START OF TRAINING  --------------------

## Training Pipeline (RE_SITES_OF_METS)

In [None]:
# specify the name of NER model
radio_ner_model = "clinical_embeddings_5_8_0.001_u0.4o1_train4522"

# specify the name of assertion model
radio_assertion_model = "radio_assertion_model_10_16_0.001_2022_11_18_15_39_34"

In [None]:
# remove sentencer, change setInputCol from sentences to document as our training chunk start/end position is with respect to each document

# training hyperparameters
#epoch = 70
epoch = 50
#batch_size = 8
batch_size = 16
#learning_rate = 0.001
learning_rate = 0.005

#======================================
# start training
#======================================
start = time.ctime()
start2 = time.time()
print('start time for training: ', start)
print('...setup training pipeline')

# this is for document level RE
documenter = DocumentAssembler()\
    .setInputCol("doc_text")\
    .setOutputCol("document")

tokenizer = Tokenizer()\
    .setInputCols(["document"])\
    .setOutputCol("tokens")\

words_embedder = WordEmbeddingsModel()\
    .load(embeddings_clinical_local_path)\
    .setInputCols(["document", "tokens"])\
    .setOutputCol("embeddings")

#use .pretrained("pos_clinical", "en", "clinical/models") for sparknlp online session
#use .load() for sparknlp airgap session
pos_tagger = PerceptronModel()\
    .load("./pretrained/pos_clinical_en_3.0.0_3.0_1617052315327") \
    .setInputCols(["document", "tokens"])\
    .setOutputCol("pos_tags")

#use .pretrained("dependency_conllu", "en") for sparknlp online session
#use .load() for sparknlp airgap session
dependency_parser = DependencyParserModel()\
    .load("./pretrained/dependency_conllu_en_3.0.0_3.0_1656858083101")\
    .setInputCols(["document", "pos_tags", "tokens"])\
    .setOutputCol("dependencies")

# set training params and upload model graph (see ../Healthcare/8.Generic_Classifier.ipynb)
reApproach = RelationExtractionApproach()\
    .setInputCols(["embeddings", "pos_tags", "train_ner_chunks", "dependencies"])\
    .setOutputCol("relations")\
    .setLabelColumn("relation")\
    .setEpochsNumber(epoch)\
    .setBatchSize(batch_size)\
    .setDropout(0.2)\
    .setLearningRate(learning_rate)\
    .setModelFile("./re_graph/rel_in1200_out57.pb")\
    .setFixImbalance(True)\
    .setValidationSplit(0.10)\
    .setFromEntity("entity1_begin", "entity1_end", "entity1")\
    .setToEntity("entity2_begin", "entity2_end", "entity2")\
    .setOutputLogsPath('./re_logs')

finisher = Finisher()\
    .setInputCols(["relations"])\
    .setOutputCols(["relations_out"])\
    .setCleanAnnotations(False)\
    .setValueSplitSymbol(",")\
    .setAnnotationSplitSymbol(",")\
    .setOutputAsArray(False)

re_mets_train_pipeline = Pipeline(stages=[
    documenter, 
    tokenizer, 
    words_embedder, 
    pos_tagger, 
    dependency_parser, 
    reApproach,
    finisher
])

print('...train model')
%time re_mets_model = re_mets_train_pipeline.fit(train_data)
print('...training completed')
done = time.ctime()
done2 = time.time()
duration = done2-start2
print('end time for training: ', done)
#======================================
# end training
#======================================

In [None]:
# save model
# Name pattern: re_sites_of_mets_<epochs>_<batch size>_<learning rate>_<timestamp>_<dataset>
re_mets_model_name = "_".join([
    "re_sites_of_mets",
    str(epoch),
    str(batch_size),
    str(learning_rate),
    datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S"),
    dataset_name,
])
print(f'...save models to folder: ./saved_models/{re_mets_model_name}')
# stages[-2] is the fitted relation extraction model (the last stage is the Finisher).
re_mets_model.stages[-2].write().overwrite().save(f'./saved_models/{re_mets_model_name}')

In [None]:
#======================================
# model evaluation
#======================================
print('...evaluate model')
pred_result = re_mets_model.transform(test_data)
print('...evaluation completed')


In [None]:
pred_result.select('relation','entity1','chunk1','entity2','chunk2','relations_out').show(10)

In [None]:
# Save the gold and predicted relations with Spark's CSV writer into a single part file
# (takes about 5 minutes, depending on file size).
filename = f'./re_output/{re_mets_model_name}_prediction'
print(filename)
(
    pred_result
    .select('relation', 'entity1', 'chunk1', 'entity2', 'chunk2', 'relations_out')
    .coalesce(1)
    .write
    .options(header=True)
    .mode('overwrite')
    .csv(filename)
)

# The generated CSV file name has this format: part-00000-c4d1d33c-1254-4716-9866-44a79e73be35-c000.csv

In [None]:
# get the output filename in the above folder
# eg part-00000-199c0e31-85c9-4f4c-a894-d6ca8cdaaf38-c000.csv

In [None]:
# Read the prediction CSV written by Spark to get y_true / y_pred.
# Spark generates a new random part-file name on every write, so the file is located with
# a glob pattern instead of hard-coding the name produced by one particular run.
pred_dir = os.path.join('./re_output', re_mets_model_name + '_prediction')
pred_part_files = sorted(glob.glob(os.path.join(pred_dir, 'part-*.csv')))
if not pred_part_files:
    raise FileNotFoundError("no part-*.csv prediction file found in " + pred_dir)

# The predictions are written with coalesce(1), so a single part file is expected;
# concatenating keeps this correct if more than one part is ever present.
pred_df = pd.concat((pd.read_csv(part) for part in pred_part_files), ignore_index=True)
pred_df.head(1)

In [None]:
# get model performance metrics

y_true = pred_df["relation"].values
y_pred = pred_df["relations_out"].values
accuracy = accuracy_score(y_true,y_pred)
print("accuracy: ", accuracy)

# Classes reported on: every label that occurs in the ground truth.
classes = np.unique(y_true)

report = classification_report(y_true,y_pred, digits=4, labels=classes)
print(report)

# Pin the confusion-matrix row/column order explicitly. By default scikit-learn orders the
# matrix by the sorted union of y_true and y_pred, so a label that is only ever *predicted*
# shifts the positions relative to np.unique(y_true). Indexing the matrix with a position
# in `classes` would then read the wrong row/column.
cm_labels = np.unique(np.concatenate([y_true, y_pred]))
cm = confusion_matrix(y_true, y_pred, labels=cm_labels)
print(cm)
label_to_idx = {label: pos for pos, label in enumerate(cm_labels)}

# get per class accuracy
# https://stackoverflow.com/questions/39770376/scikit-learn-get-accuracy-scores-for-each-class

# We will store the results in a dictionary for easy access later
per_class_accuracies = {}

# Calculate the accuracy for each one of our classes
for cls in classes:
    # Row/column of this class inside the confusion matrix
    idx = label_to_idx[cls]

    # True negatives are all the samples that are not our current GT class (not the current row)
    # and were not predicted as the current class (not the current column)
    true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))

    # True positives are all the samples of our current GT class that were predicted as such
    true_positives = cm[idx, idx]

    # 03-jul-2023: dont consider TN, use TP/(TP+FP+FN), same formulae for whole manuscript
    # (sum(cm) - TN equals TP + FP + FN)
    per_class_accuracies[cls] = (true_positives) / (np.sum(cm)-true_negatives)


# Combine class accuracies to classification report
report_dict = classification_report(y_true,y_pred, digits=4, labels=classes, output_dict=True)
classification_report_df = pd.DataFrame(report_dict).transpose()
per_class_accuracies_df = pd.DataFrame.from_dict(per_class_accuracies, orient='index', columns=['class_accuracy'])
combine_report_df = pd.concat([per_class_accuracies_df,classification_report_df], axis=1)

# save performance to csv
# model,rels_set,trainset_count,epoch,learning_rate,batch_size,start_time,end_time,duration,accuracy,classification_report,confusion_matrix
# `start`, `done` and `duration` are the timing values set by the training cell.
to_append = [re_mets_model_name,radio_ner_model,train_pair,trainset_count,testset_count,epoch,learning_rate,batch_size,start,done,duration,accuracy,per_class_accuracies_df,report,cm]
df_length = len(output_df)

output_df.loc[df_length] = to_append
filename_prefix = "./re_result/radio_re_sites_of_mets" + "_" + str(datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
filename = "%s.csv" % filename_prefix
output_df.to_csv(filename, header=True)
print(50*'-')
print("<<<Model Performance saved!>>>")
print(50*'-')
print(50*'-')

## ------------------- END OF TRAINING  --------------------

## ------------------- MODEL INFERENCE --------------------
same code as 04predict_radio_re_sites_of_mets

### Test single prediction pipeline (NER + Assertion Detection + RE_SITES_OF_METS)
copy the radio NER model and radio cancer assertion model to be used for this pipeline in radio_re_model

### Load trained model from disk

In [None]:
import pandas as pd

def get_relations_df(results, col='relations'):
    """Flatten the relation annotations of the first annotated document into a DataFrame.

    Args:
        results: list of per-document annotation dicts (e.g. the output of
            ``LightPipeline.fullAnnotate``); only ``results[0]`` is read.
        col: key of the relation annotations inside the annotation dict.

    Returns:
        A pandas DataFrame with one row per relation: the predicted relation label
        followed by the entity labels, offsets, chunk texts and confidence taken from
        the annotation metadata.
    """
    metadata_keys = [
        'entity1', 'entity1_begin', 'entity1_end', 'chunk1',
        'entity2', 'entity2_begin', 'entity2_end', 'chunk2',
        'confidence',
    ]

    rows = [
        (annotation.result, *(annotation.metadata[key] for key in metadata_keys))
        for annotation in results[0][col]
    ]

    return pd.DataFrame(rows, columns=['relation'] + metadata_keys)

In [None]:
# common block
# 11-Aug-2022 add chunk filterer to filter entities of interest
# Sentence-level pipeline: every annotator below reads "sentences".
documenter = (
    DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")
)

sentencer = (
    SentenceDetector()
    .setInputCols(["document"])
    .setOutputCol("sentences")
)

tokenizer = (
    Tokenizer()
    .setInputCols(["sentences"])
    .setOutputCol("tokens")
)

words_embedder = (
    WordEmbeddingsModel()
    .load(embeddings_clinical_local_path)
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("embeddings")
)

pos_tagger = (
    PerceptronModel()
    .load("./pretrained/pos_clinical_en_3.0.0_3.0_1617052315327")
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("pos_tags")
)

dependency_parser = (
    DependencyParserModel()
    .load("./pretrained/dependency_conllu_en_3.0.0_3.0_1656858083101")
    .setInputCols(["sentences", "pos_tags", "tokens"])
    .setOutputCol("dependencies")
)

# to detect all radio ner
radio_ner = (
    MedicalNerModel.load(radio_ner_model)
    .setInputCols(["sentences", "tokens", "embeddings"])
    .setOutputCol("ner_tags")
)

# to get all radio ner chunks
radio_ner_converter = (
    NerConverter()
    .setInputCols(["sentences", "tokens", "ner_tags"])
    .setOutputCol("ner_chunks")
)

# to filter for required ners
non_cancer_ner_converter = (
    NerConverter()
    .setInputCols(["sentences", "tokens", "ner_tags"])
    .setOutputCol("non_cancer_ner_chunks")
    .setWhiteList([
        "body_part", "anatomical_descriptor", "direction",
        "probability_high", "probability_medium", "probability_uncertain", "probability_low",
    ])
)

## to filter for cancer_imaging_findings ner chunks, and send for assertion detection
cancer_ner_converter = (
    NerConverter()
    .setInputCols(["sentences", "tokens", "ner_tags"])
    .setOutputCol("cancer_ner_chunks")
    .setWhiteList(["cancer_imaging_findings"])
)

## assertion detection for cancer_imaging_findings
cancer_ner_assertion = (
    AssertionDLModel.load(radio_assertion_model)
    .setInputCols(["sentences", "cancer_ner_chunks", "embeddings"])
    .setOutputCol("cancer_assertion")
)

## add filterer to filter for probability_high/probability_medium
cancer_ner_assertion_filterer = (
    AssertionFilterer()
    .setInputCols("sentences", "cancer_ner_chunks", "cancer_assertion")
    .setOutputCol("cancer_assertion_filtered")
    .setWhiteList(["probability_high", "probability_medium"])
)

# merge ner_chunks by prioritizing the overlapping indices (chunks with longer lengths and highest information will be kept from each ner model)
chunk_merger = (
    ChunkMergeApproach()
    .setInputCols('non_cancer_ner_chunks', "cancer_assertion_filtered")
    .setOutputCol('merged_ner_chunks')
)

In [None]:
# use this to set the entity pairs you need for this prediction pipeline
pairs_mets = [
    'cancer_imaging_findings-body_part',
    'cancer_imaging_findings-anatomical_descriptor',
    'anatomical_descriptor-body_part',
    'direction-body_part',
    'direction-anatomical_descriptor',
]

mets_rel_pairs = pairs_mets

##============================================================================
# NOTE(review): `re_mets_model_name` is set by the "save model" cell; set it by hand when
# this section is run without training in the same kernel session.
loaded_re_mets_model = (
    RelationExtractionModel()
    .load('./saved_models/' + re_mets_model_name)
    .setInputCols(["embeddings", "pos_tags", "merged_ner_chunks", "dependencies"])
    .setOutputCol("mets_relations")
    .setRelationPairs(mets_rel_pairs)
    .setMaxSyntacticDistance(4)
    .setPredictionThreshold(0.8)
)

# Full prediction pipeline: NER -> assertion filter -> chunk merge -> relation extraction.
re_mets_pipeline = Pipeline(
    stages=[
        documenter,
        sentencer,
        tokenizer,
        words_embedder,
        pos_tagger,
        dependency_parser,
        radio_ner,
        radio_ner_converter,
        non_cancer_ner_converter,
        cancer_ner_converter,
        cancer_ner_assertion,
        cancer_ner_assertion_filterer,
        chunk_merger,
        loaded_re_mets_model,
    ]
)

# The pipeline is fitted on a one-row dataframe holding an empty text.
empty_data = spark.createDataFrame([[""]]).toDF("text")
re_mets_pipeline_model = re_mets_pipeline.fit(empty_data)

In [None]:
# site of mets
mtext1 = """
your sample text
"""

In [None]:
text = mtext1
sample_data = spark.createDataFrame([[text]]).toDF("text")
sample_data.show()
sample_data.dtypes

model = re_mets_pipeline_model
#model = re_mets_merge_pipeline_model
preds = model.transform(sample_data)

In [None]:
# check cancer assertion status
# The sentence index of a chunk is stored under the metadata key 'sentence' (singular);
# reading 'sentences' only yields nulls in the sent_id column.
preds.select(F.explode(F.arrays_zip(preds.cancer_ner_chunks.result, 
                                     preds.cancer_ner_chunks.metadata, 
                                     preds.cancer_assertion.result)).alias("cols")) \
      .select(F.expr("cols['0']").alias("chunks"),
              F.expr("cols['1']['entity']").alias("ner_label"),
              F.expr("cols['1']['sentence']").alias("sent_id"),
              F.expr("cols['2']").alias("assertion")).show(50,truncate=False)

In [None]:
# check cancer assertion filterer
# The sentence index of a chunk is stored under the metadata key 'sentence' (singular);
# reading 'sentences' only yields nulls in the sent_id column.
# NOTE(review): cancer_assertion_filtered can hold fewer entries than cancer_ner_chunks
# (chunks rejected by the white list), in which case the positional zip below pairs
# unrelated rows - confirm before relying on this table.
preds.select(F.explode(F.arrays_zip(preds.cancer_ner_chunks.result, 
                                     preds.cancer_ner_chunks.metadata, 
                                     preds.cancer_assertion_filtered.result)).alias("cols")) \
      .select(F.expr("cols['0']").alias("chunks"),
              F.expr("cols['1']['entity']").alias("ner_label"),
              F.expr("cols['1']['sentence']").alias("sent_id"),
              F.expr("cols['2']").alias("filtered_assertion")).show(50,truncate=False)

In [None]:
# check merged_ner_chunks
# The sentence index of a chunk is stored under the metadata key 'sentence' (singular);
# reading 'sentences' only yields nulls in the sent_id column.
preds.select(F.explode(F.arrays_zip(preds.merged_ner_chunks.result, 
                                     preds.merged_ner_chunks.metadata)).alias("cols")) \
      .select(F.expr("cols['0']").alias("chunks"),
              F.expr("cols['1']['entity']").alias("ner_label"),
              F.expr("cols['1']['sentence']").alias("sent_id")).show(50,truncate=False)

In [None]:
# One row per predicted relation: explode the zipped (result, metadata) arrays, then pull
# the fields of interest out of the metadata map. The character offsets are stored in the
# same map (entity1_begin, entity1_end, entity2_begin, entity2_end) and can be added to
# the list below when needed.
result_df = (
    preds
    .select(
        F.explode(
            F.arrays_zip('mets_relations.result', 'mets_relations.metadata')
        ).alias("cols")
    )
    .select(
        F.expr("cols['0']").alias("relations"),
        *[
            F.expr(f"cols['1']['{meta_key}']").alias(out_name)
            for meta_key, out_name in [
                ("sentence", "sentence_id"),
                ("entity1", "entity1"),
                ("chunk1", "chunk1"),
                ("entity2", "entity2"),
                ("chunk2", "chunk2"),
                ("confidence", "confidence"),
            ]
        ]
    )
)

result_df.show(n=10, truncate=False, vertical=True)

### Create a light pipeline for annotating free text

In [None]:
model = re_mets_pipeline_model
light_model = LightPipeline(model)

annotations = light_model.fullAnnotate(text)

rel_df = get_relations_df(annotations,"mets_relations")
rel_df

In [None]:
vis= RelationExtractionVisualizer()
vis.display(annotations[0], 'mets_relations', show_relations=True)

## Get prediction with sample.csv

In [None]:
df_text = pd.read_csv("./inference/sample.csv", usecols=['sn_report_number','Report','report_date','Conclusion'])
df_text.count()

In [None]:
df_text.head(2)

In [None]:
# check for null text
df_text.isnull().sum()

In [None]:
# fill null
df_text['Conclusion'] = df_text['Conclusion'].fillna('')

In [None]:
# save the re visualisation to html file for review
# save the re annotation to csv for review
# For every report: save the RE annotation as CSV and the RE visualisation as HTML for review.
display_dir = './inference/display_result/'
# to_csv does not create missing folders, so make sure the output folder exists.
os.makedirs(display_dir, exist_ok=True)

# Iterate over all rows by position. len() is used instead of Series.count(), which
# ignores missing report numbers and would silently skip the last rows.
for i in range(len(df_text)):
    # str(): report numbers parsed as numbers by read_csv cannot be concatenated to a path.
    report_id = str(df_text['sn_report_number'].iloc[i])
    text = df_text['Conclusion'].iloc[i]
    annotations = light_model.fullAnnotate(text)
    temp_df = get_relations_df(annotations,'mets_relations')
    print(temp_df)
    # write to csv
    temp_df.to_csv(display_dir+report_id+'_tabular.csv')
    vis.display(annotations[0], 'mets_relations', show_relations=True, save_path=display_dir+report_id+"_report.html")

## Get Prediction (testset 460 reports for manuscript)

In [None]:
input_df = pd.read_csv("./inference/testset_460.csv")
input_df.count()

In [None]:
input_df["conclusion"] = input_df["conclusion"].fillna("")

## Post Processing to derive site of mets (term normalization)
mapping files:
Refer README_mapping_file doc on how to use these files.
- 01_map_anatomical_descriptor_to_body_part.csv
- 02_map_cancer_body_part_anat.csv
- 03_map_normalized_body_part.csv
- 04_map_merge_body_part.csv
- 05_map_drop_body_part.csv

### Import mapping files

In [None]:
# Mapping 01: anatomical descriptor -> body part.
temp_df = pd.read_csv("./mapping/01_map_anatomical_descriptor_to_body_part.csv")
# Normalise key and value columns: lower case, trailing whitespace removed.
temp_df["anatomical_descriptor"] = temp_df["anatomical_descriptor"].str.lower().str.rstrip()
temp_df["map_to"] = temp_df["map_to"].str.lower().str.rstrip()

# Lookup dictionary; for a repeated descriptor the last row in the file wins.
map_anat_descriptor_dict = dict(
    zip(temp_df["anatomical_descriptor"].tolist(), temp_df["map_to"].tolist())
)
#map_anat_descriptor_dict.keys()

In [None]:
# Mapping 03: body part -> normalized body part.
temp_df = pd.read_csv("./mapping/03_map_normalized_body_part.csv")
# Convert both columns to lower case and drop trailing whitespace. The normalized value is
# right-stripped as well, consistent with every other mapping column in this notebook, so a
# stray trailing blank in the file cannot produce a distinct normalized term.
temp_df.body_part = temp_df.body_part.str.lower().str.rstrip()
temp_df.normalized_body_part = temp_df.normalized_body_part.str.lower().str.rstrip()
map_nbody_part_dict = dict(zip(list(temp_df.body_part), list(temp_df.normalized_body_part)))
#map_nbody_part_dict.keys()

In [None]:
# Mapping 04: all three columns are lower-cased and right-stripped.
mbody_part_df = pd.read_csv("./mapping/04_map_merge_body_part.csv").assign(
    body_part1=lambda frame: frame["body_part1"].str.lower().str.rstrip(),
    body_part2=lambda frame: frame["body_part2"].str.lower().str.rstrip(),
    output_body_part=lambda frame: frame["output_body_part"].str.lower().str.rstrip(),
)
mbody_part_df

In [None]:
# Mapping 02: the two lookup columns are lower-cased and right-stripped, then duplicate
# rows are dropped. The output_body_part column is left as read from the file.
map_cancer_chunk = (
    pd.read_csv("./mapping/02_map_cancer_body_part_anat.csv")
    .assign(
        cancer_imaging_findings=lambda frame: frame["cancer_imaging_findings"].str.lower().str.rstrip(),
        body_part_anat=lambda frame: frame["body_part_anat"].str.lower().str.rstrip(),
    )
    .drop_duplicates()
)
map_cancer_chunk

In [None]:
# Mapping 05: body parts to drop, normalised (lower case, right-stripped) and finally
# reduced to a set for membership tests.
map_to_drop = pd.read_csv("./mapping/05_map_drop_body_part.csv")
map_to_drop["body_part_to_drop"] = map_to_drop["body_part_to_drop"].str.lower().str.rstrip()
map_to_drop = set(map_to_drop.drop_duplicates()["body_part_to_drop"])

### post processing logic based on dataframe approach

In [None]:
def get_mapping02_bpart(cancer,bpart, df_mapping02):
    """Resolve a (cancer finding, body part) pair against mapping file 02.

    Parameters
    ----------
    cancer : str
        Cancer imaging finding chunk to look up.
    bpart : str
        Body part (or anatomical descriptor) linked to that finding.
    df_mapping02 : pandas.DataFrame
        Mapping table with columns ``cancer_imaging_findings``,
        ``body_part_anat`` (comma-separated list of body parts) and
        ``output_body_part``.

    Returns
    -------
    str
        * the row's ``output_body_part`` when ``bpart`` is listed in a row for
          ``cancer``;
        * ``""`` when the matching row has no output body part (meaning the
          combination should yield no site, e.g. cake + omental);
        * ``bpart`` unchanged when ``cancer`` or the combination is not in the
          mapping.
    """
    # find mapping rows for the input cancer chunk
    df_cancer_map = df_mapping02[df_mapping02["cancer_imaging_findings"]==cancer]

    if len(df_cancer_map) == 0:
        print(cancer + " not found in mapping 02")
        # return original body_part
        return bpart

    for _, row in df_cancer_map.iterrows():

        candidates = row["body_part_anat"]
        # A blank CSV cell is loaded by pandas as NaN (a float), which has no
        # .split(); such a row cannot match any body part, so skip it.
        if not isinstance(candidates, str):
            continue

        # convert the comma-separated string to a list of clean values
        map_list = [x.strip() for x in candidates.split(',')]
        if bpart not in map_list:
            # bpart not listed in this row, check the next one
            continue

        output_body_part = row["output_body_part"]
        # 22-mar-2023 handle no value in output_body_part (eg cake + omental return no value).
        # Blank cells arrive as NaN rather than "", so test for "not a
        # non-empty string" instead of comparing with "" only.
        if not isinstance(output_body_part, str) or output_body_part == "":
            print("***map ",row["cancer_imaging_findings"],"+",bpart,"to no value")
            return ""

        print("***map ",row["cancer_imaging_findings"],"+",bpart,"to ",output_body_part)
        return output_body_part

    # after processing all available rows, still not found
    print(f"{cancer}+{bpart} not found in mapping 02")
    return bpart
    
# Alias for the mapping-02 table; this is the name passed to get_mapping02_bpart in step 3.
df_mapping02 = map_cancer_chunk    

In [None]:
# Run the NER/RE pipeline on every report conclusion and post-process the
# predicted relations into a set of normalized sites of metastases.
# The result is stored per row in mets_df["pred_site_of_mets"].
mets_df = input_df.copy()
mets_df["pred_site_of_mets"]=""

# Iterate over (index label, text) pairs instead of range(count()):
# count() ignores NaN rows, so trailing rows were skipped whenever a conclusion
# was missing, and .loc[i] only worked for a default 0..n-1 index.
for i, raw_text in input_df['conclusion'].items():

    print("row: ", i)

    # a missing conclusion cannot be annotated; leave the default "" prediction
    if not isinstance(raw_text, str):
        print("***no conclusion text, skipping")
        continue

    # clean text
    text = raw_text.lower()
    print(text)

    # get predictions
    # rel_df is used below with columns entity1/entity2, chunk1/chunk2 and
    # entity1_begin/entity1_end/entity2_begin/entity2_end (one row per relation)
    annotations = light_model.fullAnnotate(text)
    rel_df = get_relations_df(annotations,"mets_relations")

    ###############################################################################################
    # step 1 - get relation: cancer_imaging_findings > body_part (both direction)
    # columns: cancer_imaging_findings, body_part
    ###############################################################################################
    print("\n-------------------- step 1 --------------------")
    print("***get cancer to body_part\n")
    pair1 = rel_df[(rel_df['entity1']=='cancer_imaging_findings') & (rel_df['entity2']=='body_part')][['chunk1','chunk2']]
    pair2 = rel_df[(rel_df['entity2']=='cancer_imaging_findings') & (rel_df['entity1']=='body_part')][['chunk2','chunk1']]
    # swap the names so that chunk1 is always the cancer finding
    pair2 = pair2.rename(columns={"chunk2":"chunk1", "chunk1":"chunk2"})
    # ignore_index gives a clean 0..n-1 index; the frames below are only ever
    # combined by column values, never by index
    df_cancer_bpart = pd.concat([pair1,pair2], ignore_index=True)
    df_cancer_bpart = df_cancer_bpart.rename(columns={"chunk1":"cancer_imaging_findings", "chunk2":"body_part"})
    print(df_cancer_bpart)

    # check for lymphangitis carcinomatosis (on either side of a relation), add pulmonary
    s = "lymphangitis carcinomatosis"
    if (s in rel_df[rel_df['entity1']=='cancer_imaging_findings']['chunk1'].str.lower().values.tolist()
            or s in rel_df[rel_df['entity2']=='cancer_imaging_findings']['chunk2'].str.lower().values.tolist()):
        print("***found ",s, "add pulmonary")
        # Safe append: the index is 0..n-1 (see ignore_index above), so label
        # n is guaranteed to be new. With the original rel_df labels this
        # assignment could overwrite an existing row.
        df_cancer_bpart.loc[len(df_cancer_bpart)] = [s,"pulmonary"]

    ###############################################################################################
    # step 2a - get relation: cancer_imaging_findings > anatomical_descriptor (both direction)
    # columns: cancer_imaging_findings, anatomical_descriptor, anat_begin, anat_end
    ###############################################################################################
    print("\n-------------------- step 2a --------------------")
    print("***get cancer to anatomical_descriptor\n")
    # get cancer_imaging_findings - anatomical_descriptor (need begin/end to link to bodypart) (both direction)
    pair1 = rel_df[(rel_df['entity1']=='cancer_imaging_findings') & (rel_df['entity2']=='anatomical_descriptor')][['chunk1','chunk2','entity2_begin','entity2_end']]
    pair2 = rel_df[(rel_df['entity2']=='cancer_imaging_findings') & (rel_df['entity1']=='anatomical_descriptor')][['chunk2','chunk1','entity1_begin','entity1_end']]
    pair2 = pair2.rename(columns={"chunk2":"chunk1", "chunk1":"chunk2","entity1_begin":"entity2_begin","entity1_end":"entity2_end"})
    df_cancer_anat = pd.concat([pair1,pair2], ignore_index=True)
    df_cancer_anat = df_cancer_anat.rename(columns={"chunk1":"cancer_imaging_findings", "chunk2":"anatomical_descriptor","entity2_begin":"anat_begin","entity2_end":"anat_end"})
    print(df_cancer_anat)

    ###############################################################################################
    # step 2b - get relation: anatomical descriptor > body_part (both direction)
    # columns: anatomical_descriptor, body_part, anat_begin, anat_end
    ###############################################################################################
    print("\n-------------------- step 2b --------------------")
    print("***get anatomical_descriptor to body_part\n")
    # get anatomical_descriptor - body_part (need begin/end of the descriptor to link back to the cancer finding) (both direction)
    pair1 = rel_df[(rel_df['entity1']=='anatomical_descriptor') & (rel_df['entity2']=='body_part')][['chunk1','chunk2','entity1_begin','entity1_end']]
    pair2 = rel_df[(rel_df['entity2']=='anatomical_descriptor') & (rel_df['entity1']=='body_part')][['chunk2','chunk1','entity2_begin','entity2_end']]
    pair2 = pair2.rename(columns={"chunk2":"chunk1", "chunk1":"chunk2", "entity2_begin":"entity1_begin", "entity2_end":"entity1_end"})
    df_anat_bpart = pd.concat([pair1,pair2], ignore_index=True)
    df_anat_bpart = df_anat_bpart.rename(columns={"chunk1":"anatomical_descriptor", "entity1_begin":"anat_begin","entity1_end":"anat_end","chunk2":"body_part"})
    print(df_anat_bpart)

    ###############################################################################################
    # step 2c1 - link cancer > anatomical descriptor > body_part (2 hops), get body_part
    # link based on anatomical_descriptor+begin+end (the columns both frames share)
    # inner join: descriptors without a linked body part are excluded here
    # (step 2c2 handles them through mapping file 01)
    ###############################################################################################
    print("\n-------------------- step 2c1 --------------------")
    print("***get cancer to anatomical_descriptor to body_part (RE-2 hops)")
    df_cancer_anat_bpart1 = pd.merge(df_cancer_anat,df_anat_bpart,how="inner")
    print(df_cancer_anat_bpart1)

    ###############################################################################################
    # step 2c2 - cancer > anatomical descriptor > map body_part (mapping file 01)
    ###############################################################################################
    print("\n-------------------- step 2c2 --------------------")
    print("***get cancer to anatomical_descriptor to mapped body_part (mapping file 01)")

    # make a copy of df_cancer_anat to store the mapped body_part
    df_cancer_anat_bpart2 = df_cancer_anat.copy()
    df_cancer_anat_bpart2["body_part"] = ""

    for idx,row in df_cancer_anat_bpart2.iterrows():
        c = row['anatomical_descriptor'].lower()

        # any descriptor mentioning "segment" is treated as hepatic
        if "segment" in c:
            print("***found segment*, update body_part to hepatic for segment*")
            df_cancer_anat_bpart2.at[idx,'body_part'] = "hepatic"

        # explicit membership test instead of a bare try/except, so that
        # unrelated errors are no longer silently swallowed
        elif c in map_anat_descriptor_dict:
            v = map_anat_descriptor_dict[c]
            print("***map ",c," -> ",v)
            df_cancer_anat_bpart2.at[idx,'body_part'] = v
            print("row updated")

    # fill null with "", and filter for rows with body_part values
    df_cancer_anat_bpart2 = df_cancer_anat_bpart2.fillna("")
    df_cancer_anat_bpart2 = df_cancer_anat_bpart2[df_cancer_anat_bpart2["body_part"]!=""]
    print(df_cancer_anat_bpart2)

    # resolve conflict between df_cancer_anat_bpart1(via link, prediction may be wrong), df_cancer_anat_bpart2(via map, more accurate)
    # logic1-take all from df_cancer_anat_bpart2, union remaining from df_cancer_anat_bpart1
    print("\n***resolving conflict between linked body_part / mapped body_part")
    df_cancer_anat_bpart1 = df_cancer_anat_bpart1.rename(columns={"body_part":"linked_body_part"})
    df_cancer_anat_bpart2= df_cancer_anat_bpart2.rename(columns={"body_part":"mapped_body_part"})
    df_step2_cancer_bpart = pd.merge(df_cancer_anat_bpart1,df_cancer_anat_bpart2, how="outer")
    # prefer the mapped value; fall back to the linked one when no mapping exists
    df_step2_cancer_bpart["body_part"] = np.where(df_step2_cancer_bpart["mapped_body_part"].notnull(), df_step2_cancer_bpart["mapped_body_part"], df_step2_cancer_bpart["linked_body_part"])
    print(df_step2_cancer_bpart)

    # merge with step 1 body_part
    df_step2_cancer_bpart = pd.concat([df_cancer_bpart,df_step2_cancer_bpart])

    # drop duplicate values
    df_step2_cancer_bpart = df_step2_cancer_bpart[["cancer_imaging_findings","body_part"]].drop_duplicates()
    print(df_step2_cancer_bpart)

    # remove \n in body_part
    df_step2_cancer_bpart["body_part"] = df_step2_cancer_bpart["body_part"].str.replace("\n","",regex=False)

    # reset index
    df_step2_cancer_bpart = df_step2_cancer_bpart.reset_index(drop=True)
    print("\n>>>step 2 sites after resolving anatomical to body_part: \n")
    print(df_step2_cancer_bpart)


    ###############################################################################################
    # step 3 - process mapping file 02 map cancer+body_part > body_part
    # 22-mar-2023: to handle no return value (eg cake + omental > no value)
    ###############################################################################################
    print("\n-------------------- step 3 --------------------")
    print("***map cancer+body_part > output_body_part (mapping file 02)\n")

    # use function to search against mapping 02
    df_step2_cancer_bpart['02_body_part']=""
    for idx,row in df_step2_cancer_bpart.iterrows():
        print(idx, row["cancer_imaging_findings"],row["body_part"])
        output_body_part = get_mapping02_bpart(row["cancer_imaging_findings"],row["body_part"], df_mapping02)
        df_step2_cancer_bpart.at[idx,'02_body_part'] = output_body_part

    # fill null with "", and filter for rows with 02_body_part values
    # (a missing/empty mapping result means "this combination yields no site")
    df_step2_cancer_bpart = df_step2_cancer_bpart.fillna("")
    df_step2_cancer_bpart = df_step2_cancer_bpart[df_step2_cancer_bpart["02_body_part"]!=""]

    # use set to get all 02_body_part
    site = set(df_step2_cancer_bpart['02_body_part'])

    print("\n>>>step 3 sites after cancer+body_part mapping: ", site)

    ###############################################################################################
    # step 4 - process mapping file 03 map body_part to normalized_body_part
    ###############################################################################################
    print("\n-------------------- step 4 --------------------")
    print("***map body_part to normalized_body_part (mapping file 03)\n")
    print("current site: ", site)
    newsite=set()
    for body_part in site:
        if body_part in map_nbody_part_dict:
            v = map_nbody_part_dict[body_part]
            print(body_part, "-> ",v)
            newsite.add(v)
        else:
            # no map, keep current site
            print(body_part)
            newsite.add(body_part)
    print("\n>>>step 4 sites after normalized body_part: ", newsite)

    ###############################################################################################
    # step 5 - process mapping file 04 merge body part
    ###############################################################################################
    print("\n-------------------- step 5 --------------------")
    print("***merge body part (mapping file 04)\n")
    for index, row in mbody_part_df.iterrows():
        combine_site=set([row.body_part1,row.body_part2])
        # merge only when BOTH body parts of the rule are present
        if len(newsite.intersection(combine_site)) == 2:
            print("to merge :", combine_site)
            newsite = newsite.difference(combine_site)
            newsite.add(row.output_body_part)
            print("\n>>>step 5 sites after merge body_part: ", newsite)

    ###############################################################################################
    # step 6 - process mapping file 05 drop body part values
    # (e.g. remove nodal if it still exists in output)
    ###############################################################################################
    print("\n-------------------- step 6 --------------------")
    print("***drop body part (mapping file 05)\n")
    sites_to_drop = newsite.intersection(map_to_drop)
    if len(sites_to_drop) > 0:
        print(">>>dropping site from mapping 05: ",sites_to_drop )
    newsite = newsite.difference(map_to_drop)
    print("\n>>>step 6 sites after drop body_part: ", newsite, "\n")

    print("\n>>>final sites to output :", newsite, "\n")

    mets_df.at[i, 'pred_site_of_mets'] = set(newsite)


In [None]:
# Build the output name from the input file's stem. str.replace(".csv", ...)
# left the name unchanged for any other extension (e.g. ".CSV") and replaced
# every ".csv" occurrence, not just the extension.
input_stem, _ = os.path.splitext(input_filename)
output_filename = input_stem + "_pred_site_of_mets.csv"
output_path = os.path.join("./inference", output_filename)
# make sure the target folder exists so a long inference run is not lost to a missing directory
os.makedirs(os.path.dirname(output_path), exist_ok=True)
print("saving predictions to: ", output_filename)
mets_df.to_csv(output_path, index=False)

In [None]:
# Re-run the pipeline on one report (index label i) to inspect its predicted relations.
i = 0

text = input_df.loc[i, 'conclusion'].lower()
annotations = light_model.fullAnnotate(text)

# relations table for this single report; shown as the cell output
rel_df = get_relations_df(annotations, "mets_relations")
rel_df

In [None]:
# Render the first annotation result with its 'mets_relations' relations drawn
# (uses `annotations` produced by the single-case cell above).
vis.display(annotations[0], 'mets_relations', show_relations=True)