# In [40]:
import numpy as np
import re
import csv
import torch
import spacy
from spacy.tokens import Token
from allennlp.common.file_utils import cached_path
from allennlp.predictors import SemanticRoleLabelerPredictor


# Create custom SRL component for Spacy pipeline using AllenNLP
class SRL(object):
    """spaCy pipeline component that attaches SRL arguments to verb tokens.

    For every token whose ``pos_`` is ``"VERB"`` the AllenNLP semantic role
    labeling model is run with that token marked as the predicate. Each
    recognised argument or modifier is stored on the verb token under the
    custom extension ``token._.srl_arg<LABEL>`` as a ``doc[start:end]``
    slice. Extensions that were not found keep their default value ``False``.
    """

    name = "Semantic Role Labeler"  # Component name

    # Archive of the pretrained SRL model; resolved through cached_path.
    MODEL_URL = "https://s3-us-west-2.amazonaws.com/allennlp/models/srl-model-2018.05.25.tar.gz"

    # Matches a complete "begin" tag of a role this component tracks.
    # group(1) is the role as it appears in the tag (e.g. "ARG1", "ARGM-TMP"),
    # "core" / "mod" hold the label used as extension suffix ("1", "TMP").
    _BEGIN_TAG = re.compile(
        r"B-(ARG(?P<core>[0-5])"
        r"|ARGM-(?P<mod>COM|GOL|MNR|TMP|EXT|PRD|PRP|CAU|NEG|ADV|ADJ))"
    )

    def __init__(self):
        """Register the token extensions; the model itself is loaded lazily."""
        # Create token extension (for args), name it, set default value and override previously set values
        Token.set_extension("srl_arg0", default=False, force=True)  # agent
        Token.set_extension("srl_arg1", default=False, force=True)  # patient
        Token.set_extension("srl_arg2", default=False, force=True)  # instrument, benefactive, attribute
        Token.set_extension("srl_arg3", default=False, force=True)  # starting point, benefactive, attribute
        Token.set_extension("srl_arg4", default=False, force=True)  # ending point
        Token.set_extension("srl_arg5", default=False, force=True)  # modifier

        # Create token extension (for modifiers), name it, set default value and override previously set values
        # Commented-out modifiers will not be used for the purpose of this project, left for completion
        Token.set_extension("srl_argCOM", default=False, force=True)  # Comitatives
        #Token.set_extension("srl_argLOC", default=False, force=True)  # Locatives
        #Token.set_extension("srl_argDIR", default=False, force=True)  # Directional
        Token.set_extension("srl_argGOL", default=False, force=True)  # Goal
        Token.set_extension("srl_argMNR", default=False, force=True)  # Manner
        Token.set_extension("srl_argTMP", default=False, force=True)  # Temporal
        Token.set_extension("srl_argEXT", default=False, force=True)  # Extent
        #Token.set_extension("srl_argREC", default=False, force=True)  # Reciprocal
        Token.set_extension("srl_argPRD", default=False, force=True)  # Secondary Predication
        Token.set_extension("srl_argPRP", default=False, force=True)  # Prepositional
        Token.set_extension("srl_argCAU", default=False, force=True)  # Cause
        #Token.set_extension("srl_argDIS", default=False, force=True)  # Discourse
        #Token.set_extension("srl_argMOD", default=False, force=True)  # Modal
        Token.set_extension("srl_argNEG", default=False, force=True)  # Negation
        #Token.set_extension("srl_argDSP", default=False, force=True)  # Direct speech "quote"
        #Token.set_extension("srl_argLVB", default=False, force=True)  # Light verb
        Token.set_extension("srl_argADV", default=False, force=True)  # Adverbials
        Token.set_extension("srl_argADJ", default=False, force=True)  # Adjectival

        # Filled on first use so that constructing the component stays cheap
        # and the archive is loaded only once instead of once per document.
        self._predictor = None

    def __call__(self, doc):
        """Annotate the verbs of ``doc`` with their arguments and return ``doc``.

        Args:
            doc: The spaCy document being processed; tokens are expected to
                carry part-of-speech tags (``pos_``).

        Returns:
            The same ``doc`` object, with ``srl_arg*`` extensions set on each
            verb token for which the model produced a matching argument.
        """
        verb_positions = [i for i, token in enumerate(doc) if token.pos_ == "VERB"]
        if not verb_positions:
            # Nothing to label: avoid touching (or loading) the model at all.
            return doc

        predictor = self._get_predictor()
        tokens = list(doc)
        for position in verb_positions:
            # One-hot indicator telling the model which token is the predicate.
            verb_labels = [0] * len(tokens)
            verb_labels[position] = 1
            instance = predictor._dataset_reader.text_to_instance(tokens, verb_labels)
            output = predictor._model.forward_on_instance(instance)
            # Model output consists of one BIO tag per token.
            spans = self._argument_spans(output['tags'])
            for extension, (start, end) in spans.items():
                doc[position]._.set(extension, doc[start:end])
        return doc

    def _get_predictor(self):
        """Return the SRL predictor, loading the archive on the first call."""
        if self._predictor is None:
            # from_path is inherited from Predictor, so the class that is
            # actually imported by this file can be used to reach it.
            self._predictor = SemanticRoleLabelerPredictor.from_path(
                cached_path(self.MODEL_URL)
            )
        return self._predictor

    @classmethod
    def _argument_spans(cls, tags):
        """Map extension names to ``(start, end)`` token ranges found in ``tags``.

        Args:
            tags: Sequence of BIO tag strings, one per token, e.g.
                ``["B-ARG0", "B-V", "B-ARGM-TMP", "I-ARGM-TMP"]``.

        Returns:
            Dict ``{"srl_arg<LABEL>": (start, end)}`` with ``end`` exclusive.
            A span covers its ``B-`` tag plus the directly following ``I-``
            tags of the same role. If a role starts more than once, only the
            first span is kept. Roles outside the tracked set are ignored.
        """
        spans = {}
        for start, tag in enumerate(tags):
            match = cls._BEGIN_TAG.fullmatch(tag)
            if match is None:
                continue
            # The continuation tag repeats the full role, including the "M-"
            # prefix of modifiers ("B-ARGM-TMP" -> "I-ARGM-TMP").
            inside = "I-" + match.group(1)
            end = start + 1
            while end < len(tags) and tags[end] == inside:
                end += 1
            label = match.group("core") or match.group("mod")
            spans.setdefault("srl_arg" + label, (start, end))
        return spans
