In [1]:
import re
import pandas as pd
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score

def find_indication(report_text: str, return_matched_phrases=True):
    """Classify the indication of a breast MRI radiology report with regexes.

    The report is lower-cased and searched with a fixed set of regular
    expressions grouped by indication category. The outcome is:

    * ``"unknown"`` if the text is missing (``pd.isna``) or nothing matched,
    * ``"conflict"`` if patterns from more than one category matched,
    * otherwise the name of the single category that matched (one of
      ``"screening"``, ``"followup_or_surveillance"``, ``"workup"``,
      ``"implant"``, ``"extent_of_disease"``, ``"treatment_response"``).

    Args:
        report_text: A string containing the radiology report for an MRI exam.
            Missing values (``None`` / ``NaN``) are accepted.
        return_matched_phrases: If true, return a tuple of the indication and
            the list of ``re.Match`` objects that fired. If false, return only
            the indication string.

    Returns:
        ``(indication, matches)`` when ``return_matched_phrases`` is true,
        otherwise just ``indication``. ``matches`` is an empty list for
        ``"unknown"`` results.
    """
    if pd.isna(report_text):
        # Honour return_matched_phrases here too; callers that ask for a bare
        # string use the result as a dictionary key.
        return ("unknown", []) if return_matched_phrases else "unknown"

    report_text = report_text.lower()

    symptoms_list = [
        "nipple discharge",
        "nipple inversion",
        "pain",
        "palpable",
        "palpated",
        "swelling",
        "asymmetry",
        "breast hardness"
    ]
    # Alternation body shared by every symptom-based pattern below.
    symptoms_alternation = '|'.join(symptoms_list)

    # Open questions / examples collected while reading reports:
    #   - "dilated duct" => possibly a sign of workup after ultrasound?
    #   - "CLINICAL HISTORY: 69-year-old female with history of a right
    #     lumpectomy in 2011." => a sign of surveillance
    #   - "clinical indication:50-year-old woman for screening mri"
    #   - "with new diagnosis"

    # Regexes that directly point to the indication
    regexes = {
        "screening": [
            # Indication: High-risk screening
            #
            # Usually done in patients with high or intermediate risk of breast
            # cancer, such as BRCA2 mutation, family history, etc.
            #
            r'presenting for (a )?screening',
            r'presenting for (a )?high[- ]risk screening',
            r'presents for (a )?screening',
            r'high[- ]risk screening',
            r'routine screening',
            r'clinical history: screening',
            r'indication: screening',
            r'continued annual mr(i)? screening',
            #r'annual screening',  # seems to pull too many exams that are not screening
            r'indication for exam: (patient is a )?\d\d[- ]year[- ]old (female )?with a family history',  # ??? is this safe???
        ],
        "followup_or_surveillance": [
            # Indication: Follow-up or surveillance
            #
            # Follow-up is done with patients that had a slightly suspicious
            # (BI-RADS 3 Probably Benign category) finding in a previous examination.
            #
            # Surveillance is done in patients that have a history of breast cancer
            # and now they are being actively monitored for new findings (disease recurrence).
            #
            r'six months follow-up',
            r'(presenting|presents) for (a )?six-month(s)? follow-up',
            r'(presenting|presents) for (a )?6[- ]month(s)? follow-up',
            r'(indication|clinical history): short interval (six-month )?follow-up',
            r'(indication|clinical history): 6[- ]month(s)? follow-up',
            r'(presenting|presents) for [^\.\n]+follow[- ]up',
            r'(presenting|presents) for follow[- ]up',
            r'(presenting|presents) for [^\.\n]+followup',
            r'clinical indication:[^\.\n]+short interval follow[- ]?up',
            r'clinical indication:[^\.\n]+short[- ]term follow[- ]?up',
            r'clinical indication: \d\d[- ]year[- ]old female\. short interval follow[- ]?up',
            r'this is (a )?six-month(s)? follow[- ]up',
            r'this is (a )?six-month(s)? followup',
            # surveillance
            r'mr(i)? performed for surveillance',
            r'(presenting|presents) for surveillance',
            r'(presenting|presents) for mr(i)? surveillance',
            r'clinical indication: (left |right )?breast ca',
            r'high[- ]risk surveillance',
            r'annual (screening )?surveillance',
            r'high[- ]risk evaluation',
            r'for surveillance mri',
            # NOTE(review): the trailing '.' here (and in the workup patterns
            # ending in '.') is an unescaped wildcard matching any character
            # except a newline, not a literal period -- confirm this is intended.
            r'routine surveillance mri.',
        ],
        "workup": [
            # Indication: Workup aka Problem Solving
            #
            # Workup is done in patients: (i) after a suspicious symptom shows up such as
            # nipple discharge, breast pain, swelling; (ii) after a suspicious finding is
            # found on another examination, e.g. mammogram or ultrasound.
            #
            r'indication:[^\.;\n]+problem solving',
            r'study (was )?performed for problem solving',
            r'evaluate (a )?questioned area',
            # NOTE(review): '(mri ?)' is a required group; if "mri" was meant to
            # be optional this should read '(mri )?' -- confirm.
            r'follow-up (mri ?)(status )?post (benign )?biopsy',
            r'presenting for further evaluation',
            r'presenting for evaluation of reported',
            r'indication:[^\.;\n]+for mri evaluation[^\.;\n]+findings',
            r'presenting for [^\.\n]+({})'.format(symptoms_alternation),
            r'mri[^\.;\n]+to evaluate[^\.;\n]+({})'.format(symptoms_alternation),
            r'mri evaluation[^\.;\n]+({})'.format(symptoms_alternation),
            r'mri[^\.;\n]+to further evaluate[^\.;\n]+({})'.format(symptoms_alternation),
            r'presenting [^\.\n]+problem solving',
            r'clinical information:[^\.;\n]+({})'.format(symptoms_alternation),
            r'underwent mri for ({})'.format(symptoms_alternation),
            r'mri was advised.',
            r'mri for (further )?evaluation.',
            r'clinical history:[^\.;\n]+({})'.format(symptoms_alternation),
            r'clinical indication: \d\d[- ]year[- ]old female with[^\.;\n]+({})'.format(symptoms_alternation),
            r'for which mri was recommended.',
        ],
        "implant": [
            # Indication: Implant assessment
            #
            # These are MRIs performed in patients with breast implants to check their
            # status for ruptures, abnormalities, etc.
            #
            r'evaluate for implant',
            r'evaluation of implant',
        ],
        "extent_of_disease": [
            # Indication: Extent of disease
            #
            # These MRIs are performed in patients who were diagnosed with breast cancer
            # on a biopsy, and now we want to (i) confirm the final size of the findings,
            # (ii) look for any additional areas of concern, (iii) use this MRI to prepare
            # surgeons for excision.
            #
            r'extent of disease',
            r'evaluation of disease',

            r'(indication|clinical history): (presurgical|preoperative)',
            r'presenting for (presurgical|preoperative)',
            r'here for (presurgical|preoperative)',

            r'clinical indication:[^\.\n]+staging',  # NOTE: This can cause some FPs when Pts are "restaging"
        ],
        "treatment_response": [
            # Indication: Treatment response
            #
            # Performed in patients who have been receiving breast cancer treatment, e.g. chemotherapy
            # and now they are evaluated for whether they responded to the treatment.
            r'response to treatment',
            r'assess (for )?treatment',
            r'evaluation of treatment',
        ]
    }

    matched_regexes = []
    regexes_categories = set()

    # Every pattern is tried (no early exit) so that conflicts between
    # categories can be detected and all matches can be reported.
    for regex_group, regex_phrases in regexes.items():
        for regex in regex_phrases:
            regex_result = re.search(regex, report_text)
            if regex_result:
                matched_regexes.append(regex_result)
                regexes_categories.add(regex_group)

    if not matched_regexes:
        indication = "unknown"
    elif len(regexes_categories) > 1:
        indication = "conflict"
    else:
        # Exactly one category matched.
        indication = next(iter(regexes_categories))

    if return_matched_phrases:
        return indication, matched_regexes
    return indication


In [2]:
import pickle
from pathlib import Path
from collections import defaultdict


In [3]:
# Location of the pickled dataset.
# NOTE(review): this is a hard-coded absolute cluster path; the notebook will
# not run elsewhere without editing it.
data_dir = Path('/gpfs/data/geraslab/ekr6072/projects/study_indication/data')
data_path = data_dir.joinpath('dataset.pkl')

In [4]:
# Translates the labels returned by find_indication() into the category names
# used by the dataset annotations.
heuristic_output_mapping = {
    # Categories with a direct counterpart.
    'screening': '(high-risk) screening',
    'followup_or_surveillance': '6-month follow-up / surveillance',
    'workup': 'additional workup',
    # Implant exams are mapped to the 'exclude' category.
    'implant': 'exclude',
    'extent_of_disease': 'extent of disease / pre-operative planning',
    'treatment_response': 'treatment monitoring',
    # No match and conflicting matches both collapse to 'unknown'.
    'unknown': 'unknown',
    'conflict': 'unknown',
    '[not applicable]': 'exclude',
}

In [5]:
# Category names indexed by integer class id (position in the list == id).
id2category = [
    '(high-risk) screening',                       # 0
    'extent of disease / pre-operative planning',  # 1
    'additional workup',                           # 2
    '6-month follow-up / surveillance',            # 3
    'treatment monitoring',                        # 4
    'exclude',                                     # 5
    'unknown',                                     # 6
]

In [6]:
# Category name -> integer class id. Ids are assigned by position, so the
# order of the names below is significant.
category2id = {
    category: index
    for index, category in enumerate((
        '(high-risk) screening',
        'extent of disease / pre-operative planning',
        'additional workup',
        '6-month follow-up / surveillance',
        'treatment monitoring',
        'exclude',
        'unknown',
    ))
}

In [7]:
def heuristic_model(text, return_id=True):
    """Predict the indication category of a report with the regex heuristic.

    Args:
        text: Report text passed straight to ``find_indication``.
        return_id: If true, return the integer id from ``category2id``;
            otherwise return the category name from
            ``heuristic_output_mapping``.

    Returns:
        The predicted category as an id or as a name, depending on
        ``return_id``.
    """
    raw_label = find_indication(text, return_matched_phrases=False)
    category = heuristic_output_mapping[raw_label]
    return category2id[category] if return_id else category

In [8]:
# NOTE: unpickling can execute arbitrary code, so only load files from a
# trusted source.
with data_path.open('rb') as f:
    dataset = pickle.load(f)

In [9]:
def clean_dataset(dataset):
    """Drop samples labelled 'exclude' or 'unknown' from every subset.

    Args:
        dataset: Mapping of subset name to an iterable of samples, where each
            sample is indexable by ``'label'``.

    Returns:
        A new dict with the same keys; each value is a new list holding the
        original sample objects (not copies) whose label was kept.
    """
    dropped_labels = ('exclude', 'unknown')
    return {
        subset_name: [
            sample for sample in samples
            if sample['label'] not in dropped_labels
        ]
        for subset_name, samples in dataset.items()
    }

In [10]:
# Replace the loaded dataset with a copy that has the 'exclude' and 'unknown'
# samples removed; the unfiltered version is no longer reachable after this.
dataset = clean_dataset(dataset)

In [19]:
# Score the regex heuristic on each labelled split and print accuracy and
# macro-averaged F1.
names = ['train', 'val']
for name in names:
    subset = dataset[name]
    if not subset:
        # An empty split would raise ZeroDivisionError in the accuracy below.
        print(f'{name}:  no samples to evaluate')
        continue

    labels = [category2id[data['label']] for data in subset]
    predictions = [
        heuristic_model(data['text']['longText'], return_id=True)
        for data in subset
    ]

    correct = sum(label == pred for label, pred in zip(labels, predictions))
    accuracy = correct / len(subset)
    # NOTE(review): the heuristic can predict the 'unknown' / 'exclude' ids,
    # which were filtered out of the labels by clean_dataset(); with
    # average='macro' those predicted-only classes presumably enter the
    # average too -- confirm this is the intended metric.
    f1 = f1_score(labels, predictions, average='macro')
    print(f'{name}:  accuracy = {accuracy:.4f}, f1_score = {f1:.4f}')

train:  accuracy = 0.4536, f1_score = 0.3917
val:  accuracy = 0.3667, f1_score = 0.2858
