In [21]:
import html
import json
import os
import shutil
import tempfile

import numpy as np

from allennlp.common.params import Params
from allennlp.commands.train import train_model

from IPython.display import Markdown, display
from ipywidgets import Textarea, VBox, HBox, Layout, widgets

from model import InstEntityTagger
from predictor import InstPredictor
from dataset_reader import InstDatasetReader

# Module-level state shared by the widget callbacks and main().
# Files that the submitted training / validation text is written to.
TRAIN_BUFFER_PATH, VALIDATION_BUFFER_PATH = (
    "../data/train_buffer.tmp",
    "../data/validate_buffer.tmp",
)
# Set to True by read_train / read_validation once each buffer is written.
TRAIN_DONE = VALIDATION_DONE = False

def printmd(string):
    """Render ``string`` as Markdown in the notebook output area."""
    rendered = Markdown(string)
    display(rendered)

def _text_prompt(label, callback):
    """Show a bold ``label`` followed by a single-line text box.

    ``callback`` is registered with the widget's ``on_submit`` hook, so it
    runs when the user presses Enter in the box.
    """
    text_box = widgets.Text(layout=Layout(width='70%'))
    printmd("**" + label + ":**")
    display(text_box)
    text_box.on_submit(callback)

def train_prompt():
    """Prompt for training text; ``read_train`` saves it when submitted."""
    _text_prompt("Training text", read_train)

def validation_prompt():
    """Prompt for validation text; ``read_validation`` saves it when submitted."""
    _text_prompt("Validation text", read_validation)

def _write_buffer(text, path):
    """Overwrite the file at ``path`` with ``text`` and echo what was written.

    Parameters
    ----------
    text : str
        Contents to store; any existing file at ``path`` is replaced.
    path : str
        Destination file. Its parent directory is created if it is missing.
    """
    directory = os.path.dirname(path)
    # Create the directory rather than `assert`-ing it exists: asserts vanish
    # under `python -O`, and deriving it from `path` avoids a second
    # hard-coded "../data/" literal that could drift from the path constants.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as buffer_file:
        buffer_file.write(text)
    printmd("**Writing to:** " + path)
    printmd("**Wrote:** " + text)

def read_train(sender):
    """``on_submit`` callback: save the training widget's text to TRAIN_BUFFER_PATH.

    ``sender`` is the submitting widget; its ``value`` is written out and
    TRAIN_DONE is set so ``main`` knows the buffer is ready.
    """
    global TRAIN_DONE
    _write_buffer(sender.value, TRAIN_BUFFER_PATH)
    TRAIN_DONE = True

def read_validation(sender):
    """``on_submit`` callback: save the validation widget's text to VALIDATION_BUFFER_PATH.

    ``sender`` is the submitting widget; its ``value`` is written out and
    VALIDATION_DONE is set so ``main`` knows the buffer is ready.
    """
    global VALIDATION_DONE
    _write_buffer(sender.value, VALIDATION_BUFFER_PATH)
    VALIDATION_DONE = True
    
def set_params(train_buffer_path: str, validation_buffer_path: str,
               template_path: str = 'template.jsonnet') -> Params:
    """Load the training config template and point it at the buffer files.

    Parameters
    ----------
    train_buffer_path : str
        Stored under the config's ``train_data_path`` key.
    validation_buffer_path : str
        Stored under the config's ``validation_data_path`` key.
    template_path : str, optional
        Config file to load. Defaults to ``'template.jsonnet'``; a relative
        path resolves against the current working directory.

    Returns
    -------
    Params
        The loaded template with both data-path keys overridden.
    """
    params = Params.from_file(template_path)
    # Plain item assignment instead of calling the __setitem__ dunder directly.
    params["train_data_path"] = train_buffer_path
    params["validation_data_path"] = validation_buffer_path
    return params

# Example inputs to paste into the widgets. These were bare string statements
# (no-ops); naming them makes them reusable from code.
# In the training text a leading `*` marks a city and `!` marks a drug --
# presumably InstDatasetReader turns these into the '*' / '!' tags that
# generate_markdown highlights (NOTE(review): confirm in dataset_reader.py).
# The validation text is unmarked; the trained model supplies its tags.
EXAMPLE_TRAIN_TEXT = """
I lived in *Munich last summer. *Germany has a relaxing, slow summer lifestyle. 
One night, I got food poisoning and couldn't find !Tylenol to make the 
pain go away, they insisted I take !aspirin instead.
"""

EXAMPLE_VALIDATION_TEXT = """
When I lived in Paris last year, France was experiencing a recession. 
The night life was too fun, I developed an addiction to Adderall and cocaine.
"""

if __name__ == "__main__":
    train_prompt()
    validation_prompt()

**Training text:**

Text(value='', layout=Layout(width='70%'))

**Validation text:**

Text(value='', layout=Layout(width='70%'))

**Writing to:** ../data/train_buffer.tmp

**Wrote:** I lived in *Munich last summer. *Germany has a relaxing, slow summer lifestyle.  One night, I got food poisoning and couldn't find !Tylenol to make the  pain go away, they insisted I take !aspirin instead.

**Writing to:** ../data/validate_buffer.tmp

**Wrote:** When I lived in Paris last year, France was experiencing a recession.  The night life was too fun, I developed an addiction to Adderall and cocaine.

In [22]:
def _pair_tokens_with_tags(tags, text_path, log_path, print_stdout):
    """Pair each token the dataset reader yields for ``text_path`` with its tag.

    Each pair is appended to ``log_path`` as ``tag + token`` (one per line)
    and, if ``print_stdout`` is true, echoed to stdout.

    Returns
    -------
    list of [str, str]
        ``[token_text, tag]`` for every token, in reader order.

    Raises
    ------
    ValueError
        If there are more tokens than predicted tags.
    """
    pairs = []
    # `tags` were predicted for the whole file joined into one string, so the
    # index must keep running across instances; restarting at 0 per instance
    # would reuse the first instance's tags for every later one.
    tag_index = 0
    with open(log_path, 'a') as log:
        # NOTE(review): `_read` is private; the public `read` may be preferable
        # if it yields the same instances -- confirm before switching.
        for instance in InstDatasetReader()._read(text_path):
            for token in instance['sentence']:
                if tag_index >= len(tags):
                    raise ValueError(
                        "Predictor returned {} tags, fewer than the tokens in {}"
                        .format(len(tags), text_path))
                tag, token_text = tags[tag_index], str(token)
                tag_index += 1
                log.write(tag + token_text + "\n")
                if print_stdout:
                    print(tag + token_text)
                pairs.append([token_text, tag])
    return pairs

def main(print_stdout: bool = False, log_path: str = "log.log"):
    """Train on the training buffer, then tag every token of the validation buffer.

    Parameters
    ----------
    print_stdout : bool, optional
        Also print each ``tag + token`` line (default False).
    log_path : str, optional
        File the ``tag + token`` lines are appended to (default ``"log.log"``).

    Returns
    -------
    list of [str, str]
        ``[token_text, tag]`` pairs for the validation text.

    Raises
    ------
    RuntimeError
        If either widget has not been submitted yet.
    """
    if not (TRAIN_DONE and VALIDATION_DONE):
        raise RuntimeError(
            "Submit both the training and validation text before running main().")
    params = set_params(TRAIN_BUFFER_PATH, VALIDATION_BUFFER_PATH)
    # Read the prediction path from a copy taken before training; presumably
    # train_model pops keys off the Params object it is handed.
    prediction_path = params.duplicate().pop(key="validation_data_path")

    serialization_dir = tempfile.mkdtemp()
    try:
        model = train_model(params, serialization_dir)

        # Make predictions
        predictor = InstPredictor(model, dataset_reader=InstDatasetReader())
        with open(prediction_path, "r") as text_file:
            all_text = " ".join(text_file.readlines())  # Makes it all 1 batch.
        tags = predictor.predict(all_text)['tags']

        out = _pair_tokens_with_tags(tags, prediction_path, log_path, print_stdout)
    finally:
        # Remove training artifacts even when training or prediction raises.
        shutil.rmtree(serialization_dir)
    print("DONE.")
    return out

if __name__ == "__main__":
    out = main()

0it [00:00, ?it/s]
1it [00:00, 94.65it/s]

0it [00:00, ?it/s]
1it [00:00, 68.76it/s]

0it [00:00, ?it/s]
2it [00:00, 8586.09it/s]

  0%|          | 0/500000 [00:00<?, ?it/s]
100%|##########| 500000/500000 [00:06<00:00, 73084.67it/s]

  0%|          | 0/500000 [00:00<?, ?it/s]
100%|##########| 500000/500000 [00:06<00:00, 74718.06it/s]

  0%|          | 0/1 [00:00<?, ?it/s]
accuracy: 0.8571, loss: 0.9295 ||: 100%|##########| 1/1 [00:00<00:00, 17.40it/s]

  0%|          | 0/1 [00:00<?, ?it/s]
accuracy: 0.8667, loss: 0.9748 ||: 100%|##########| 1/1 [00:00<00:00, 46.60it/s]



DONE.


<span style="background-color: #3399ff">This text is highlighted in blue.</span>

In [25]:
# Highlight colours: red (#ff5050) for cities, blue (#3399ff) for drugs.
CITY_SPAN_OPEN = "<span style=\"background-color: #ff5050\">"
DRUG_SPAN_OPEN = "<span style=\"background-color: #3399ff\">"
SPAN_CLOSE = "</span>"

# Predicted tag -> opening span that highlights the token.
TAG_SPAN_OPEN = {
    '*': CITY_SPAN_OPEN,  # Cities
    '!': DRUG_SPAN_OPEN,  # Drugs
}

def generate_markdown(out):
    """Turn ``[token, tag]`` pairs into Markdown/HTML fragments.

    Tokens tagged ``'*'`` (cities) or ``'!'`` (drugs) are wrapped in a
    coloured ``<span>``; all other tokens pass through as plain text.
    Token text is HTML-escaped so user input containing ``<`` or ``&``
    cannot break the markup.

    Returns
    -------
    list of str
        Exactly one fragment per input pair, meant to be joined with ``' '``.
    """
    md_list = []
    for token, tag in out:
        text = html.escape(token, quote=False)
        span_open = TAG_SPAN_OPEN.get(tag)
        if span_open is None:
            md_list.append(text)
        else:
            # One element per token: appending the opening span separately
            # made ' '.join() insert a space inside every highlight.
            md_list.append(span_open + text + SPAN_CLOSE)
    return md_list

if __name__ == "__main__":
    md_list = generate_markdown(out)
    md = ' '.join(md_list)
    printmd(md)

        
        

When I lived in <span style="background-color: #ff5050"> Paris</span> last year , <span style="background-color: #ff5050"> France</span> was experiencing a recession . The night life was too fun , I developed an addiction to <span style="background-color: #3399ff"> Adderall</span> and <span style="background-color: #3399ff"> cocaine</span> .