# IOB Dataset Prompts

For sequence tagging datasets, it is sometimes easier to define prompt templates as plain Python functions than to use the Jinja-based promptsource tools. Many biomedical NER datasets are provided in IOB format, so these prompts should transfer to other entity types.



## Load Dataset
Some Hugging Face datasets are already provided in IOB-tagged format. We'll load an example biomedical corpus (the NCBI Disease Corpus) to illustrate.

In [None]:
from datasets import load_dataset

# NCBI Disease corpus, training split only.
dataset = load_dataset(path='ncbi_disease', split='train')

In [None]:
from datasets import load_dataset_builder
dataset_builder = load_dataset_builder('ncbi_disease')

dftrs = dataset_builder.info.features
dsplits = dataset_builder.info.splits

class_labels = dftrs['ner_tags'].feature.names

In [None]:
import collections
from itertools import groupby

def get_spans(toks, tags, class_labels):
    """
    Group IOB-tagged tokens into contiguous, typed spans.

    Assumes a sentence that is tokenized and stored as parallel lists of
    `tokens` and integer `ner_tags` in IOB format.

    `class_labels` maps each label id to its *entity type*, not to the raw
    B-/I- tag name. A plain sequence of names is also accepted and is indexed
    by position. For each entity type, the smallest id is taken to be its
    head (B-*) tag, and any larger id of the same type is an inside (I-*) tag.
    For example, for IOB tags
        O, B-Disease, I-Disease, B-Chemical, I-Chemical
    pass
        {0:'O', 1:'Disease', 2:'Disease', 3:'Chemical', 4:'Chemical'}

    Span rules:
    - Every B-tag starts a new span, even directly after a span of the same
      type.
    - An I-tag extends the previous span only if that span has the same type;
      otherwise it starts a new span.
    - Runs of 'O' tokens are merged into one 'O' span.

    NOTE: This does not support BILUO tagging like spaCy. It also does not
    support IO tagging, where an entity type has only a single label id.

    Returns a list of (entity_type, space-joined span text) tuples that cover
    every token, in order.
    """
    labels = class_labels if isinstance(class_labels, dict) else dict(enumerate(class_labels))

    # Head (B-*) label id for each entity type: the smallest id with that name.
    head_ids = {}
    for label_id, name in labels.items():
        head_ids[name] = min(head_ids.get(name, label_id), label_id)

    types = []
    spans = []
    # Walk token by token: grouping identical tags (e.g. with groupby) would
    # wrongly merge adjacent B-tags into a single entity.
    for tok, tag in zip(toks, tags):
        entity_type = labels[tag]
        # Extend the previous span for 'O' runs, or for an I-tag (id strictly
        # above the head id) that continues an entity of the same type.
        extends_prev = bool(types) and types[-1] == entity_type and (
            entity_type == 'O' or tag > head_ids[entity_type]
        )
        if extends_prev:
            spans[-1].append(tok)
        else:
            types.append(entity_type)
            spans.append([tok])

    return [(t, ' '.join(s)) for t, s in zip(types, spans)]

# example: map label ids 1 and 2 (the two disease tags) to one 'disease' type
class_labels = dict(enumerate(['O', 'disease', 'disease']))
entity_spans = get_spans(dataset[11]['tokens'], dataset[11]['ner_tags'], class_labels)
entity_spans


## Prompts as Python Functions
This is hacky and unpolished, but it works fine for quick experiments. See the [promptsource](https://github.com/bigscience-workshop/promptsource) contribution guide for more details on designing prompts.

NOTE: With sequence tagging tasks, we have to get a little creative when specifying tasks. It's unclear what the best strategy is for NER tasks and prompts. Some lessons learned from the initial promptsource experiments:

- Performance is better if you generate the text of a span rather than an auxiliary attribute such as the span's index.
- Outputs should minimize ambiguity; e.g., when generating a list of outputs, use natural delimiters such as commas or newlines.
- More templates per task is good for evaluation and other research inquiries 

In [None]:

def prompt_list_of_disease_entities_v1(x):
    """
    Prompt: comma-separated list of the disease entities in a sentence.

    `x` is a dataset row with 'tokens' and 'ner_tags'. The text before '|||'
    is the input and the text after it is the target: the disease spans
    joined by ', ', or 'None' when there are none.
    """
    labels = {0: 'O', 1: 'disease', 2: 'disease'}
    mentions = [text for etype, text in get_spans(x['tokens'], x['ner_tags'], labels) if etype == 'disease']
    sentence = ' '.join(x['tokens']).strip()
    answer = ', '.join(mentions) or 'None'

    instructions = (
        "Create a comma-separated list of all disease named entities found in the following sentence. "
        "If there are no disease mentions, print None. \n"
    )
    return f"{instructions}Sentence: {sentence}\nEntities:|||{answer}"

def prompt_list_of_disease_entities_v2(x):
    """
    Prompt: name every disease mentioned in a quoted sentence.

    `x` is a dataset row with 'tokens' and 'ner_tags'. The target (after
    '|||') is the disease spans joined by ', ', or 'None' when there are none.

    answers in prompt: no (the prompt lists no answer choices)
    original task: no
    """
    spans = get_spans(x['tokens'], x['ner_tags'], {0: 'O', 1: 'disease', 2: 'disease'})
    sentence = ' '.join(x['tokens']).strip()
    diseases = [text for kind, text in spans if kind == 'disease']
    target = ', '.join(diseases) or 'None'

    return (
        "Identify all disease names mentioned in the following sentence. "
        "If there are no disease mentions, print None. \n"
        f"\"{sentence}\"\n|||{target}"
    )


def prompt_list_of_disease_entities_v3(x):
    """
    Prompt: bulleted list of the disease entities in a quoted sentence.

    `x` is a dataset row with 'tokens' and 'ner_tags'. The target (after
    '|||') has one '- <span>' line per disease span, or 'None' when there
    are none.
    """
    labels = {0: 'O', 1: 'disease', 2: 'disease'}
    bullets = []
    for kind, text in get_spans(x['tokens'], x['ner_tags'], labels):
        if kind == 'disease':
            bullets.append(f"- {text}")
    target = '\n'.join(bullets) or 'None'
    sentence = ' '.join(x['tokens']).strip()

    header = (
        "Create a bulleted list of all disease named entities found in the following sentence. "
        "If there are no disease mentions, print None. \n"
    )
    return header + f"\"{sentence}\"\n|||{target}"

def prompt_list_of_disease_entities_v4(x):
    """
    Prompt: newline-separated list of the disease entities in a quoted sentence.

    `x` is a dataset row with 'tokens' and 'ner_tags'. The target (after
    '|||') has one disease span per line, or 'None' when there are none.
    """
    tokens = x['tokens']
    spans = get_spans(tokens, x['ner_tags'], {0: 'O', 1: 'disease', 2: 'disease'})
    sentence = ' '.join(tokens).strip()
    target = '\n'.join(text for kind, text in spans if kind == 'disease') or 'None'

    parts = [
        "Create a list, separated by newlines, of all disease named entities found in the following sentence. ",
        "If there are no disease mentions, print None. \n",
        f"\"{sentence}\"\n|||{target}",
    ]
    return ''.join(parts)


def prompt_mention_of_disease_yes_no_answers(x):
    """
    Recast NER as binary sentence classification:
    "does this sentence contain an entity of type x?"

    `x` is a dataset row with 'tokens' and 'ner_tags'. The target (after
    '|||') is 'yes' if any disease span is present, else 'no'.

    answers in prompt: yes
    original task: no
    """
    spans = get_spans(x['tokens'], x['ner_tags'], {0: 'O', 1: 'disease', 2: 'disease'})
    sentence = ' '.join(x['tokens']).strip()
    has_disease = any(kind == 'disease' for kind, _ in spans)
    target = 'yes' if has_disease else 'no'

    question = "Does the following sentence contain mentions of disease names? yes or no \n"
    return question + f"\"{sentence}\"\n|||{target}"


# Preview one template on the first (up to) 100 training examples. Failures
# are reported with their cause instead of being silently swallowed.
for i in range(min(100, len(dataset))):
    try:
        s = prompt_list_of_disease_entities_v4(dataset[i])
    except Exception as e:  # best-effort preview: keep going, but say why
        print('error', i, f'{type(e).__name__}: {e}')
        continue
    print(s)
    print('-' * 80)