# Relation extraction on BERN2 data

Processing steps:
1. Download BERN2 data: http://nlp.dmis.korea.edu/projects/bern2-sung-et-al-2022/annotation_v1.1.tar.gz
2. Process the files into rows of the form **sentence, plant_mention, disease_mention**. The original files contain the PubMed identifier, the sentence number (which can also be obtained with a sentence splitter), the entity mentions, and their CURIEs (such as MeSH and NCBITaxon identifiers). Since the BERN2 dump is over 60 GB and cannot be processed on a laptop, the easiest approach is to first remove all entities that do not have MeSH or NCBITaxon identifiers, and then format the remaining data as described above (**sentence, plant_mention, disease_mention**).
3. Load the processed data into a dataframe and run this notebook, calling the **run_relation_extraction** function on each row.

In [None]:
import logging
from typing import Dict, Optional, Any, Tuple

from thefuzz import fuzz
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Module-level logger; run_prompt reports prompts it had to skip through it.
logger = logging.getLogger(__name__)

In [None]:
# Hugging Face model identifiers that initialize_model can be pointed at.
BASE_MODEL = "bigscience/T0_11B"
HEAVIEST_MODEL = "bigscience/T0pp"


def initialize_model(huggingface_model: str = BASE_MODEL) -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
    """Load a seq2seq checkpoint together with its tokenizer.

    Args:
        huggingface_model: identifier passed to both ``from_pretrained`` calls;
            defaults to ``BASE_MODEL``.

    Returns:
        A ``(model, tokenizer)`` pair, in that order.
    """
    # The tokenizer is loaded first, then the model (same order as before).
    loaded_tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(huggingface_model)

    pair = (loaded_model, loaded_tokenizer)
    return pair


def _process_answer(answer: str) -> str:
    """Process answer from model."""
    # This requires to remove special characters, stripping and lower casing
    return answer.replace('</s>', '').replace('<pad>', '').strip().lower()


def run_prompt(model, tokenizer, prompt, max_new_tokens: Optional[int] = None):
    """Run a single prompt through a seq2seq model and return its normalised answer.

    Args:
        model: model exposing ``generate`` (e.g. the one returned by ``initialize_model``).
        tokenizer: tokenizer matching ``model``.
        prompt: text prompt to encode and send to the model.
        max_new_tokens: optional cap on the number of generated tokens, forwarded to
            ``model.generate``. When ``None`` (the default) the model's own
            generation settings apply, as before.

    Returns:
        The decoded answer cleaned by ``_process_answer``, or ``''`` if encoding
        or generation raised. In that case the error is logged with its traceback
        and the prompt is skipped.
    """
    generate_kwargs = {} if max_new_tokens is None else {'max_new_tokens': max_new_tokens}

    try:
        inputs = tokenizer.encode(prompt, return_tensors="pt")
        # Put the input ids on the model's device. Otherwise a model moved to a GPU
        # fails on every prompt, and the except below silently skips every sentence.
        device = getattr(model, 'device', None)
        if device is not None:
            inputs = inputs.to(device)
        outputs = model.generate(inputs, **generate_kwargs)
    except Exception as e:  # best effort: one failing prompt must not abort a whole run
        logger.exception('skip sentence: %s \n prompt %s', e, prompt)
        return ''

    # Return the answer after processing
    return _process_answer(tokenizer.decode(outputs[0]))


def _evaluate_plant_disease_prompt(
    plant_entity: str,
    disease_entity: str,
    answer_1: str,
    answer_2: str,
    answer_3: str,
):
    """Evaluate the answers of the model for a plant-disease prompt."""
    answers = []

    # Evaluate answer 1 (which plants are used to treat {disease_entity}?)
    if plant_entity in answer_1 or fuzz.partial_ratio(answer_1, plant_entity) > 90:
        answers.append(True)
    else:
        answers.append(False)

    # Evaluate answer 2 (is {plant_entity} used to treat {disease_entity}?)
    if 'true' in answer_2:
        answers.append(True)
    else:
        answers.append(False)

    # Evaluate answer 3 (which diseases are associated with {plant_entity}?)
    if disease_entity in answer_3 or fuzz.partial_ratio(answer_3, disease_entity) > 90:
        answers.append(True)
    else:
        answers.append(False)

    # If any of the answers if True
    if any(answers):
        edge_exist = True

        if all(answers):
            confidence = 'high'
        elif sum(answers) == 2:
            confidence = 'medium'
        else:
            confidence = 'low'
    else:
        edge_exist = False
        confidence = ''

    return {
        'edge_exists': edge_exist,
        'confidence': confidence,
    }


def run_relation_extraction(
    model,
    tokenizer,
    sentence: str,
    plant_mention: str,
    disease_mention: str,
) -> Dict[str, Any]:
    """Ask the language model three questions to decide whether a plant-disease relation exists.

    Args:
        model: seq2seq model (e.g. from ``initialize_model``).
        tokenizer: tokenizer matching ``model``.
        sentence: sentence containing both mentions; it is used as the prompt context.
        plant_mention: plant mention in ``sentence`` (surrounding whitespace ignored).
        disease_mention: disease mention in ``sentence`` (surrounding whitespace ignored).

    Returns:
        The dict produced by ``_evaluate_plant_disease_prompt``, with keys
        ``'edge_exists'`` and ``'confidence'``.
    """
    plant_entity = plant_mention.strip()
    disease_entity = disease_mention.strip()

    # Terminate the context sentence exactly once. Sentences usually already end
    # with '.', and blindly appending '. ' produced prompts like "...treatment.. In ...".
    context = sentence.strip()
    if context and not context.endswith(('.', '!', '?')):
        context += '.'

    # The order matters: it defines answer_1/answer_2/answer_3 for the evaluator.
    questions = (
        f"which plants are used to treat {disease_entity}?",
        f"is {plant_entity} used to treat {disease_entity}?",
        f"which diseases are associated with {plant_entity}?",
    )
    answer_1, answer_2, answer_3 = [
        run_prompt(
            model=model,
            tokenizer=tokenizer,
            prompt=f"{context} In the previous sentence, {question}",
        )
        for question in questions
    ]

    return _evaluate_plant_disease_prompt(
        plant_entity=plant_entity,
        disease_entity=disease_entity,
        answer_1=answer_1,
        answer_2=answer_2,
        answer_3=answer_3,
    )
