# Annotation project workbook


This notebook contains code for annotating materials science examples including (optional) in-the-loop annotation. Reasonable defaults as described in the publication are given, though these can be tweaked to user preference.

Note: While we use OpenAI's API here for GPT-3 fine-tuning, the `openai` calls can be swapped out for calls to your LLM of choice.

### Tips we've learned from annotating many thousands of examples:
* You will encounter examples where it just isn't clear-cut how to fit the information from an abstract into the schema. Don't sweat these. Just do your best to do a "good enough" job and move on. These "bad/ok" examples will not affect the overall training much but it's important the model has seen this kind of thing before we run it over the full dataset.
* Don't succumb to sunk-cost fallacy. If the schema needs to change, do it as soon as you know in your heart it probably should change.
* We are aiming for about 1,000 annotations in total. After each round of 100 new annotations, retrain the model and use the new checkpoint.

In [None]:
!pip install -r requirements.txt

## Step 1. Defining a schema

Use this section to define your annotation schema. It should be a Python dictionary that is JSON-serializable.

Each distinct material mentioned in an abstract should get its own entry, and the result of an annotation "submit" will be a list of these dictionaries in the order they're mentioned in the text.

Here is an example we use for general materials data extraction:
```
material_schema = {
    'name': '',
    'formula': '',
    'description': [''],
    'acronym': '',
    'structure_or_phase': [''],
    'applications': ['']
    }
```

**Note: You can replace this schema with your own as you see fit (as long as the schema is a JSON-type document).** If you use in-the-loop annotation, make sure the intermediate model you use was trained on the same schema you define here.
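
Since every annotation is expected to follow the schema, a small check can catch entries that drift from it. The following is a minimal sketch, not part of the original workflow: `validate_annotation` and `example_schema` are hypothetical names, and the check only compares field names and top-level types against the blank template.

```python
def validate_annotation(annotation, schema):
    """Return a list of problems; an empty list means the annotation fits the schema."""
    if not isinstance(annotation, list):
        return ["annotation must be a list of material entries"]
    problems = []
    for i, material in enumerate(annotation):
        if not isinstance(material, dict):
            problems.append(f"entry {i}: expected a dict, got {type(material).__name__}")
            continue
        # Compare field names against the blank template.
        for key in sorted(set(schema) - set(material)):
            problems.append(f"entry {i}: missing field '{key}'")
        for key in sorted(set(material) - set(schema)):
            problems.append(f"entry {i}: unexpected field '{key}'")
        # Compare top-level types (str vs. list) against the blank template.
        for key, template in schema.items():
            if key in material and not isinstance(material[key], type(template)):
                problems.append(f"entry {i}: field '{key}' should be a {type(template).__name__}")
    return problems


# Hypothetical, shortened schema used only for this illustration.
example_schema = {'name': '', 'formula': '', 'applications': ['']}
good = [{'name': 'magnetite', 'formula': 'Fe3O4', 'applications': ['spintronics']}]
bad = [{'name': 'magnetite', 'formula': ['Fe3O4'], 'phase': 'spinel'}]

print(validate_annotation(good, example_schema))
for problem in validate_annotation(bad, example_schema):
    print(problem)
```

A check like this could be run on each annotation before it is saved, so that schema drift is noticed early rather than at training time.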

In [None]:
my_schema = {
    'name': '',
    'formula': '',
    'description': [''],
    'acronym': '',
    'structure_or_phase': [''],
    'applications': ['']
    }

## Step 2. Set up a query to get training examples

This notebook queries [matscholar.com](https://matscholar.com) using the matscholar query syntax. You can use a simple phrase or some of the more advanced query operators. We advise you to build your query on matscholar.com first before pasting it here.


In our example query, we use "ferrimagnetic". But feel free to choose your own!

In [None]:
MY_QUERY = "ferrimagnetic"

## Some helper functions

In [None]:
import json
import re
from IPython.display import display_javascript, display_html, display

import pprint

pp = pprint.PrettyPrinter(indent=2)

def remove_many_newlines(string):
    """Removes every newline that is followed by one or more spaces, along with those spaces."""
    return re.sub(r'\n +', '', string)

def remove_junk(s):
    for junk in ["<hi>", "</hi>", "<inf>", "</inf>", "<sup>", "</sup>", "<sub>", "</sub>"]:
        s = s.replace(junk, "")
    if "\n  " in s:
        s = remove_many_newlines(s)
    return s
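
To illustrate what this kind of cleaning does to a raw abstract, here is a self-contained sketch using regular expressions. It is an alternative formulation, not the notebook's helper: `strip_markup` is a hypothetical name, and it replaces an indented line break with a single space (an assumption made here so that words on either side are not joined).

```python
import re

# Matches the opening and closing forms of the inline tags handled above.
TAG_PATTERN = re.compile(r"</?(?:hi|inf|sup|sub)>")
# Matches a newline followed by one or more spaces.
INDENTED_NEWLINE = re.compile(r"\n +")


def strip_markup(text):
    """Drop inline markup tags and collapse indented line breaks to one space."""
    text = TAG_PATTERN.sub("", text)
    return INDENTED_NEWLINE.sub(" ", text)


sample = "Fe<inf>3</inf>O<inf>4</inf> is\n    ferrimagnetic"
print(strip_markup(sample))  # Fe3O4 is ferrimagnetic
```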


## Step 3: Get test data from matscholar

In [None]:
import requests
from urllib import parse

def query_matscholar(query, hits=100, exclude_dois=None):
    """Submits a query to matscholar and returns up to `hits` results.

    Args:
        query (str): Matscholar query. Supports matscholar query syntax.
        hits (int): Maximum number of results to fetch.
        exclude_dois (list): DOIs to drop from the returned results.
    """
    exclude_dois = set(exclude_dois or [])
    all_results = []
    # The first request only reads the total number of matching documents.
    search_uri = f"https://matscholar.com/api/search/?query={parse.quote(query)}&type=all&restrict=doc&hits=1"
    num_results = requests.get(search_uri).json()['root']['fields']['totalCount']
    num_results = min(num_results, hits)
    # Page through the results in batches of at most 100 without exceeding `hits`.
    for offset in range(0, num_results, 100):
        page_size = min(100, num_results - offset)
        search_uri = f"https://matscholar.com/api/search/?query={parse.quote(query)}&offset={offset}&type=all&restrict=doc&hits={page_size}"
        root = requests.get(search_uri).json()['root']
        if 'children' not in root:
            break
        all_results.extend(root['children'])
    return [d for d in all_results if d['fields']['doi'] not in exclude_dois]
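
The paging arithmetic can be checked in isolation, without touching the network. The sketch below is an illustration only: `page_requests` is a hypothetical helper, and the page size of 100 is taken from the batch size used in the query code above. It plans `(offset, hits)` pairs so that no more than the requested number of documents is fetched in total.

```python
def page_requests(total_count, hits, page_size=100):
    """Plan (offset, page_hits) pairs that fetch at most `hits` documents overall."""
    wanted = min(total_count, hits)
    return [
        (offset, min(page_size, wanted - offset))
        for offset in range(0, wanted, page_size)
    ]


print(page_requests(total_count=5000, hits=250))
# [(0, 100), (100, 100), (200, 50)]
```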

To use in-the-loop annotation with your own model, specify the API key and your model name. The annotation UI will automatically query the API to get a "good guess" for a completion based on your model. The quality of your model determines the quality of the guess. From this guess, you'll correct the annotation in the UI. We find this method to be much faster than annotating all abstracts from scratch. Usually, you'll need around 50 abstracts to train a decent intermediate model.

In [None]:
# create a completion
import openai

openai.api_key = "YOUR_API_KEY_HERE"
YOUR_MODEL_NAME = "YOUR_MODEL_NAME_HERE"

def extract_materials_data(abstract, model=YOUR_MODEL_NAME):
    """ Sends abstract to OpenAI API and uses custom model name if provided.
    """
    if model and model != "YOUR_MODEL_NAME_HERE":
        start_sequence = "\n\n###\n\n"
        prompt = abstract + start_sequence
        response = openai.Completion.create(
          model = model,
          prompt=prompt,
          temperature=0,
          max_tokens=512,
          top_p=1,
          frequency_penalty=0,
          presence_penalty=0,
          stop=["\n\nEND\n\n"],
        )
        return response.choices[0].text.replace("name_of_mof", "mof_name")
    else:
        return json.dumps([my_schema])


def clean_up_results(results, fields_to_keep=['abstract', 'title', 'doi', 'year']):
    cleaned_results = []
    for r in results:
        new_entry = {key:r['fields'][key] for key in fields_to_keep}
        new_entry['title'] = remove_junk(new_entry['title'])
        new_entry['abstract'] = remove_junk(new_entry['abstract'])
        new_entry['annotation'] = []
        cleaned_results.append(new_entry)
    return cleaned_results

def make_prompt(entry):
    return entry['title'] + "\n" + entry['abstract']
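
An intermediate model trained on few examples may occasionally return text that is not valid JSON, which would make `json.loads` raise inside the annotation UI. One possible guard, sketched under that assumption: `parse_completion` and `blank` are hypothetical names, and falling back to a single blank schema entry is a design choice made here, not something the notebook prescribes.

```python
import copy
import json


def parse_completion(text, blank_schema):
    """Parse a model completion, falling back to one blank entry if it is unusable."""
    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON (or not a string at all): start from a blank entry.
        return [copy.deepcopy(blank_schema)]
    if isinstance(parsed, dict):
        # A single material returned without the enclosing list.
        return [parsed]
    if not isinstance(parsed, list):
        return [copy.deepcopy(blank_schema)]
    return parsed


blank = {'name': '', 'formula': ''}
print(parse_completion('[{"name": "magnetite", "formula": "Fe3O4"}]', blank))
print(parse_completion('not valid json', blank))
```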

### Load previous annotations to avoid getting the same paper twice

In [None]:
PREVIOUS_ANNOTATIONS = "" # e.g. "/content/drive/MyDrive/<path to previous annotations folder>"
ANNOTATIONS_PREFIX = "my_annotations" # prefix for annotation files
if PREVIOUS_ANNOTATIONS:
    with open(PREVIOUS_ANNOTATIONS, "r") as file:
        if PREVIOUS_ANNOTATIONS.endswith("jsonl"):
            lines = file.readlines()
            records = []
            for line in lines:
                records.append(json.loads(line))
        else:
            records = json.load(file)
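
Previously annotated papers can also be identified by DOI rather than by prompt text. The sketch below assumes records shaped like the ones this notebook saves, where the original entry (including its `doi`) is stored under the `record` key; `previously_annotated_dois` is a hypothetical helper, and records without that key (for example, fine-tuning JSONL lines) are simply skipped.

```python
def previously_annotated_dois(records):
    """Collect the DOIs of already-annotated papers, ignoring records without one."""
    dois = set()
    for record in records:
        doi = record.get('record', {}).get('doi')
        if doi:
            dois.add(doi)
    return sorted(dois)


# Made-up records for illustration only.
example_records = [
    {'prompt': 'Title A\nAbstract A', 'completion': [], 'record': {'doi': '10.0000/a'}},
    {'prompt': 'Title B\nAbstract B', 'completion': []},
    {'prompt': 'Title C\nAbstract C', 'completion': [], 'record': {'doi': '10.0000/c'}},
]
print(previously_annotated_dois(example_records))
```

The resulting list could then be handed to the DOI-exclusion argument of the matscholar query function.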


## Step 4: Annotate!
### Run the annotation GUI

In [None]:
import random
import ipywidgets as widgets
from IPython.display import display

results = clean_up_results(query_matscholar(MY_QUERY, hits=900))
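# Sample at most 100 papers; random.sample raises if asked for more items than exist.
results = random.sample(results, min(100, len(results)))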

try:
    previous_prompts = [x['prompt'] for x in records]
except (NameError, KeyError):
    # No previous annotations were loaded, or they lack a 'prompt' field.
    previous_prompts = []
queue = [r for r in results if make_prompt(r) not in previous_prompts]
saved_entries = []

material_schema = my_schema

output = widgets.Output()

# button to add new material
new_material_button = widgets.Button(
    description='Add Material',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Add new material',
    icon='plus' # (FontAwesome names without the `fa-` prefix)
)

def on_new_material_button_clicked(b):
    global entry
    # The text area holds JSON, so parse it as JSON rather than evaluating it as Python.
    entry['annotation'] = json.loads(text_area.value)
    entry['annotation'].append(material_schema.copy())
    with output:
        output.clear_output()
        text_area.value = json.dumps(entry['annotation'], indent=2)
        display(content)

new_material_button.on_click(on_new_material_button_clicked)


back_button = widgets.Button(
    description='Back',
    disabled=False,
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    icon='arrow-left' # (FontAwesome names without the `fa-` prefix)
)

def on_back_button_clicked(b):
    global entry
    global queue
    if not saved_entries:
        return  # nothing has been saved yet, so there is no entry to go back to
    queue.append(entry)
    entry = saved_entries.pop()

    with output:
        output.clear_output()
        gui_abstract.value = make_prompt(entry)
        text_area.value = json.dumps(entry['annotation'], indent=2)
        display(content)

back_button.on_click(on_back_button_clicked)

# save annotation
save_button = widgets.Button(
    description='Save',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Save annotation',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)

def on_save_button_clicked(b):
    global entry
    global queue
    # The text area holds JSON, so parse it as JSON rather than evaluating it as Python.
    entry['annotation'] = json.loads(text_area.value)
    with output:
        gui_abstract.value = "Please wait..."
        text_area.value = "Please wait..."
    saved_entries.append(entry)
    entry = queue.pop()
    with output:
        output.clear_output()
        prompt = make_prompt(entry)
        entry['annotation'] = json.loads(extract_materials_data(prompt))
        gui_abstract.value = prompt
        text_area.value = json.dumps(entry['annotation'], indent=2)
        display(content)

save_button.on_click(on_save_button_clicked)


# skip annotation
skip_button = widgets.Button(
    description='Skip',
    disabled=False,
    button_style='warning', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Skip annotation',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)

def on_skip_button_clicked(b):
    global entry
    global queue
    entry = queue.pop()
    with output:
        output.clear_output()
        prompt = make_prompt(entry)
        entry['annotation'] = json.loads(extract_materials_data(prompt))
        gui_abstract.value = prompt
        text_area.value = json.dumps(entry['annotation'], indent=2)
        display(content)

skip_button.on_click(on_skip_button_clicked)

# clear annotation
clear_button = widgets.Button(
    description='Clear',
    disabled=False,
    button_style='danger', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Clear annotation',
)

def on_clear_button_clicked(b):
    with output:
        output.clear_output()
        text_area.value = json.dumps([], indent=2)
        display(content)

clear_button.on_click(on_clear_button_clicked)


def setup():
    global entry
    global queue
    entry = queue.pop()
    entry['annotation'] = json.loads(extract_materials_data(make_prompt(entry)))

    text_area = widgets.Textarea(value=json.dumps(entry['annotation'], indent=2),
                                 layout=widgets.Layout(height='600px'))
    gui_abstract = widgets.Textarea(disabled=True,
                                    value=make_prompt(entry),
                                    layout=widgets.Layout(width='70%', height='200px'))
    return text_area, gui_abstract

text_area, gui_abstract = setup()

content = widgets.VBox([
    widgets.HBox(
        [save_button,skip_button, back_button]),
    gui_abstract,
    widgets.HBox(
        [new_material_button, clear_button]),
    text_area]
)
display(content)
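
The Save, Skip and Back buttons all move entries between a work queue and a list of saved entries. That bookkeeping can be sketched without any widgets, which makes it easier to reason about edge cases such as an empty queue. `AnnotationSession` is a hypothetical class written for this illustration; it guards against empty lists, which is an addition made here and not necessarily how the GUI callbacks behave.

```python
class AnnotationSession:
    """Tracks the current entry, the remaining queue, and the saved entries."""

    def __init__(self, items):
        self.queue = list(items)
        self.saved = []
        self.current = self.queue.pop() if self.queue else None

    def _advance(self):
        # Take the next entry from the end of the queue, or None when it is empty.
        self.current = self.queue.pop() if self.queue else None

    def save(self):
        if self.current is not None:
            self.saved.append(self.current)
        self._advance()

    def skip(self):
        self._advance()

    def back(self):
        if not self.saved:
            return  # nothing to go back to
        if self.current is not None:
            self.queue.append(self.current)
        self.current = self.saved.pop()


session = AnnotationSession(['a', 'b', 'c'])
session.save()   # saves 'c', moves on to 'b'
session.skip()   # drops 'b', moves on to 'a'
session.back()   # returns 'a' to the queue and reopens 'c'
print(session.current, session.queue, session.saved)
```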

### Saving your annotations

In [None]:
records = []
for entry in saved_entries:
    records.append({'prompt':make_prompt(entry), 'completion':entry['annotation'], 'record':entry})
print("You have", len(records), "records to save.")

In [None]:
from datetime import datetime
def save_records(records, combine_with=None):
    old_records = []
    if combine_with:
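        try:
            # A regular JSON file holding a single list of records.
            with open(combine_with, "r") as file:
                old_records = json.load(file)
        except json.JSONDecodeError:
            # Otherwise assume a JSON Lines file with one record per line.
            with open(combine_with, "r") as file:
                old_records = [json.loads(f) for f in file.readlines()]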
    clean_records = old_records + [record for record in records if record not in old_records]
    filename = f"{ANNOTATIONS_PREFIX}_{datetime.now().strftime('%m_%d_%Y_%H%M%S')}.json"

    with open(filename, "w") as file:
        json.dump(clean_records, file)
    return filename

records_filepath = save_records(records)
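
When combining rounds of annotation, the same abstract may appear twice with different annotations (for example, after correcting an earlier entry). One possible way to handle that, which differs from the exact-match de-duplication in `save_records`, is to key records by prompt and let the newer annotation win. `merge_records` is a hypothetical helper sketched under that assumption.

```python
def merge_records(old_records, new_records):
    """Merge two record lists, keeping one record per prompt (newer ones win)."""
    merged = {}
    for record in list(old_records) + list(new_records):
        # A later record with the same prompt replaces the earlier one
        # while keeping the earlier record's position in the output.
        merged[record['prompt']] = record
    return list(merged.values())


# Made-up records for illustration only.
old = [
    {'prompt': 'paper 1', 'completion': [{'name': 'draft'}]},
    {'prompt': 'paper 2', 'completion': []},
]
new = [
    {'prompt': 'paper 1', 'completion': [{'name': 'corrected'}]},
    {'prompt': 'paper 3', 'completion': []},
]
for record in merge_records(old, new):
    print(record['prompt'], record['completion'])
```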

## Loading previous annotations

In [None]:
def load_records(filename):
    with open(filename, "r") as file:
        return json.load(file)

records = load_records(records_filepath)

Now we prepare a jsonlines file for input to the OpenAI API (or LLM of your choice).

In [None]:
def prepare_fine_tune(filename):
    new_filename = filename.replace(".json", ".jsonl")
    with open(new_filename, "w") as writer:
        # Read the records from the given file rather than relying on the global `records`.
        for r in load_records(filename):
            r_new = {}
            r_new['prompt'] = r['prompt'] + "\n\n###\n\n"
            r_new['completion'] = ' ' + json.dumps(r['completion']) + '\n\nEND\n\n'
            writer.write(json.dumps(r_new) + "\n")
    print(f"JSONL file written to {new_filename}")

prepare_fine_tune(records_filepath)
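
Each line of the JSONL file pairs a prompt ending in the `\n\n###\n\n` separator with a completion ending in the `\n\nEND\n\n` stop sequence, matching the separator and stop sequence used when querying the model earlier in this notebook. The sketch below formats a single record and then recovers the annotation from the completion text, as a quick round-trip check. The function names and the sample record are hypothetical.

```python
import json

PROMPT_SEPARATOR = "\n\n###\n\n"
STOP_SEQUENCE = "\n\nEND\n\n"


def to_fine_tune_line(record):
    """Format one annotation record as a single JSONL line."""
    return json.dumps({
        'prompt': record['prompt'] + PROMPT_SEPARATOR,
        'completion': ' ' + json.dumps(record['completion']) + STOP_SEQUENCE,
    })


def completion_to_annotation(completion_text):
    """Recover the annotation list from a completion string."""
    # Everything before the stop sequence is the JSON payload.
    return json.loads(completion_text.split(STOP_SEQUENCE)[0])


# Made-up record for illustration only.
sample_record = {
    'prompt': 'Some title\nSome abstract text.',
    'completion': [{'name': 'magnetite', 'formula': 'Fe3O4'}],
}
line = to_fine_tune_line(sample_record)
recovered = completion_to_annotation(json.loads(line)['completion'])
print(recovered == sample_record['completion'])  # True
```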

## Step 5: Train model

Now, in a terminal, use the prepared file to train a model through the OpenAI API.

The exact syntax for this depends on the version of the openai python package you are using, the train/validation/test split you desire, and other factors.
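
As one illustration of choosing a split, the records can be shuffled with a fixed seed and divided into training and validation sets before writing the JSONL files. This is a minimal sketch: `split_records` is a hypothetical helper, and the 10% validation fraction and the seed are arbitrary assumptions, not values recommended by the notebook.

```python
import random


def split_records(records, validation_fraction=0.1, seed=0):
    """Shuffle records reproducibly and return (train, validation) lists."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_validation = int(len(shuffled) * validation_fraction)
    return shuffled[n_validation:], shuffled[:n_validation]


# Integers stand in for annotation records in this illustration.
train, validation = split_records(list(range(100)))
print(len(train), len(validation))  # 90 10
```

Each part could then be written to its own JSONL file and supplied as the training and validation files respectively.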

For more info, see the OpenAI API docs or run:


In [21]:
!openai api fine_tunes.create --help

usage: openai api fine_tunes.create [-h] -t TRAINING_FILE [-v VALIDATION_FILE]
                                    [--no_check_if_files_exist] [-m MODEL]
                                    [--suffix SUFFIX] [--no_follow]
                                    [--n_epochs N_EPOCHS]
                                    [--batch_size BATCH_SIZE]
                                    [--learning_rate_multiplier LEARNING_RATE_MULTIPLIER]
                                    [--prompt_loss_weight PROMPT_LOSS_WEIGHT]
                                    [--compute_classification_metrics]
                                    [--classification_n_classes CLASSIFICATION_N_CLASSES]
                                    [--classification_positive_class CLASSIFICATION_POSITIVE_CLASS]
                                    [--classification_betas CLASSIFICATION_BETAS [CLASSIFICATION_BETAS ...]]

optional arguments:
  -h, --help            show this help message and exit
  -t TRAINING_FILE, --trainin