In [140]:
import os
import json
import shutil
import numpy as np
import tempfile

from allennlp.common.params import Params
from allennlp.commands.train import train_model

from IPython.display import display
from ipywidgets import Textarea, VBox, HBox, Layout, widgets

from model import InstEntityTagger
from predictor import InstPredictor
from dataset_reader import InstDatasetReader

# Scratch files that receive the text typed into the two prompts; main()
# hands these paths to set_params() as the training/validation data files.
TRAIN_BUFFER_PATH = "../data/train_buffer.tmp"
VALIDATION_BUFFER_PATH = "../data/validate_buffer.tmp"
# Readiness flags: flipped to True by the submit callbacks once each buffer
# file has been written; main() refuses to run until both are set.
TRAIN_DONE = VALIDATION_DONE = False

def _show_text_prompt(label, submit_handler):
    """Print a label, display a wide single-line text box, and hook up its submit handler.

    Args:
        label: Heading printed just before the widget is displayed.
        submit_handler: Callable registered via ``on_submit``; it receives the
            Text widget itself, so it can read the submitted ``value``.
    """
    entry_box = widgets.Text(layout=Layout(width='70%'))
    print(label)
    display(entry_box)
    entry_box.on_submit(submit_handler)


def train_prompt():
    """Show the text box whose submitted contents become the training data."""
    _show_text_prompt("Training text:", read_train)


def validation_prompt():
    """Show the text box whose submitted contents become the validation data."""
    _show_text_prompt("Validation text:", read_validation)

def _write_buffer(path: str, text: str) -> None:
    """Overwrite ``path`` with ``text`` and echo what was written.

    Args:
        path: Buffer file to (re)write.
        text: Contents to store.

    Raises:
        FileNotFoundError: if the directory that should contain ``path`` is missing.
    """
    # Check the buffer's own parent directory instead of a hard-coded
    # "../data/", so the check follows the path constants. An explicit raise
    # (not ``assert``) keeps the check active under ``python -O``.
    parent_dir = os.path.dirname(path) or "."
    if not os.path.isdir(parent_dir):
        raise FileNotFoundError(f"Buffer directory does not exist: {parent_dir}")
    with open(path, 'w') as buffer_file:
        buffer_file.write(text)
    print("============")
    print("Writing to", path)
    print("Wrote:", text)
    print("============")


def read_train(sender):
    """Submit handler for the training prompt.

    Saves the submitted text to TRAIN_BUFFER_PATH, then sets TRAIN_DONE.

    Args:
        sender: The widget that fired the submit event; its ``value`` is saved.

    Raises:
        FileNotFoundError: if the buffer's directory does not exist.
    """
    global TRAIN_DONE
    _write_buffer(TRAIN_BUFFER_PATH, sender.value)
    # Only mark ready once the write has succeeded.
    TRAIN_DONE = True


def read_validation(sender):
    """Submit handler for the validation prompt.

    Saves the submitted text to VALIDATION_BUFFER_PATH, then sets VALIDATION_DONE.

    Args:
        sender: The widget that fired the submit event; its ``value`` is saved.

    Raises:
        FileNotFoundError: if the buffer's directory does not exist.
    """
    global VALIDATION_DONE
    _write_buffer(VALIDATION_BUFFER_PATH, sender.value)
    # Only mark ready once the write has succeeded.
    VALIDATION_DONE = True
    
def set_params(train_buffer_path: str, validation_buffer_path: str,
               template_path: str = 'template.jsonnet') -> Params:
    """Load the AllenNLP config template and point it at the given data files.

    Args:
        train_buffer_path: Stored under the ``train_data_path`` key.
        validation_buffer_path: Stored under the ``validation_data_path`` key.
        template_path: Config file to load; defaults to ``template.jsonnet``
            relative to the current working directory.

    Returns:
        The loaded ``Params`` with both data-path entries overridden.
    """
    params = Params.from_file(template_path)
    # Item assignment is the public spelling of Params.__setitem__.
    params["train_data_path"] = train_buffer_path
    params["validation_data_path"] = validation_buffer_path
    return params

# Show both input prompts when this cell runs (an IPython kernel executes
# cells with __name__ == "__main__"). Submitting each box writes its buffer
# file and sets the matching *_DONE flag that main() checks.
if __name__ == "__main__":
    train_prompt()
    validation_prompt()

Training text:


Text(value='', layout=Layout(width='70%'))

Validation text:


Text(value='', layout=Layout(width='70%'))

Writing to ../data/train_buffer.tmp
Wrote: When I lived in *Paris last year, *France was experiencing a recession. The night life was too fun, I developed an addiction to !Adderall and !cocaine
Writing to ../data/validate_buffer.tmp
Wrote: I lived in Munich last summer. Germany has a relaxing, slow summer lifestyle. One night, I got food poisoning and couldn't find Tylenol to make the pain go away, they insisted I take aspirin instead


In [141]:
def main(log_path: str = "log.log", print_stdout: bool = True):
    """Train a tagger on the submitted training text, then tag the validation text.

    The model is trained in a temporary serialization directory, which is
    always removed afterwards. Every validation token is appended to
    ``log_path`` as ``<tag><token>`` and optionally echoed to stdout.

    Args:
        log_path: File the tagged tokens are appended to.
        print_stdout: Whether to also print each tagged token.

    Raises:
        RuntimeError: if the training or validation text has not been submitted yet.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not (TRAIN_DONE and VALIDATION_DONE):
        raise RuntimeError(
            "Submit both the training and validation text before calling main().")
    params = set_params(TRAIN_BUFFER_PATH, VALIDATION_BUFFER_PATH)
    # AllenNLP's train_model pops entries out of ``params`` while building the
    # model, so read the prediction path from an untouched copy first.
    prediction_path = params.duplicate().pop(key="validation_data_path")

    serialization_dir = tempfile.mkdtemp()
    try:
        model = train_model(params, serialization_dir)

        # A single reader serves both the predictor and the per-token log below.
        reader = InstDatasetReader()
        predictor = InstPredictor(model, dataset_reader=reader)
        with open(prediction_path, "r") as text_file:
            # Joining all lines sends the whole file through the model at once.
            all_text = " ".join(text_file.readlines())
        tags = predictor.predict(all_text)['tags']

        # ``tags`` spans the whole file, so keep one running position across
        # instances. Restarting at 0 per instance would reuse the first line's
        # tags for every later line.
        # NOTE(review): assumes _read yields instances in file order with the
        # same tokenization predict() used — confirm against InstDatasetReader.
        position = 0
        with open(log_path, 'a') as log:
            for instance in reader._read(prediction_path):
                for token in instance['sentence']:
                    tagged = tags[position] + str(token)
                    position += 1
                    log.write(tagged + "\n")
                    if print_stdout:
                        print(tagged)
    finally:
        # Remove the checkpoint directory even if training or prediction fails.
        shutil.rmtree(serialization_dir)
    print("DONE.")

# Train and tag as soon as this cell runs (an IPython kernel executes cells
# with __name__ == "__main__"). main() fails unless both prompts from the
# previous cell have been submitted.
if __name__ == "__main__":
    main()

0it [00:00, ?it/s]
1it [00:00, 63.09it/s]

0it [00:00, ?it/s]
1it [00:00, 63.10it/s]

0it [00:00, ?it/s]
2it [00:00, 304.98it/s]

  0%|          | 0/500000 [00:00<?, ?it/s]
100%|##########| 500000/500000 [00:07<00:00, 64685.12it/s]

  0%|          | 0/500000 [00:00<?, ?it/s]
100%|##########| 500000/500000 [00:07<00:00, 68398.20it/s]

  0%|          | 0/1 [00:00<?, ?it/s]
accuracy: 0.8621, loss: 0.9244 ||: 100%|##########| 1/1 [00:00<00:00, 22.76it/s]

  0%|          | 0/1 [00:00<?, ?it/s]
accuracy: 0.8537, loss: 0.9700 ||: 100%|##########| 1/1 [00:00<00:00, 48.82it/s]



_I
_lived
_in
*Munich
_last
_summer
_.
*Germany
_has
_a
_relaxing
_,
_slow
_summer
_lifestyle
_.
_One
_night
_,
_I
_got
!food
_poisoning
_and
_could
*n't
_find
!Tylenol
_to
_make
_the
_pain
_go
_away
_,
_they
_insisted
_I
_take
!aspirin
_instead
DONE.
