In [1]:
%reload_ext autoreload
%autoreload 2

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed',)).History will not be written to the database.


In [19]:
from typing import Any, Callable, Dict, List, Optional

from recon.dataset import Dataset
from pydantic import root_validator
from recon.types import Example, Span, Token
import numpy as np
from recon.augmentation import augment_example
from recon.operations import operation, registry
from recon.preprocess import SpacyPreProcessor

import names
from snorkel.augmentation import transformation_function
from snorkel.preprocess.nlp import SpacyPreprocessor
from recon.preprocess import SpacyPreProcessor
import spacy

In [491]:
def substitute_spans(example: Example, span_subs: Dict[Span, str]) -> Example:
    """Substitute spans in an example. Replaces span text and alters the example text
    and span offsets to create a valid example.

    Args:
        example (Example): Input example
        span_subs (Dict[int, str]): Mapping of span hash to a str replacement text

    Returns:
        Example: Output example with substituted spans
    """
    span_sub_start_counter = 0

    new_example_text = example.text
    new_example_spans = []
    
    prev_example_spans = {hash(span) for span in example.spans}
    spans = sorted(set(list(span_subs.keys()) + example.spans), key=lambda s: s.start)
        
    for span in spans:
        should_add_span = hash(span) in prev_example_spans
        
        prev_end = span.end
        new_text = span.text

        if span in span_subs:
            new_text = span_subs[span]
            new_start = span.start + span_sub_start_counter
            new_end = new_start + len(new_text)

            new_example_text = (
                new_example_text[: span.start + span_sub_start_counter]
                + new_text
                + new_example_text[span.end + span_sub_start_counter :]
            )

            span.text = new_text
            span.start = new_start
            span.end = new_end
            
            span_sub_start_counter += new_end - prev_end
        else:
            span.start += span_sub_start_counter
            span.end = span.start + len(new_text)
            span_sub_start_counter = span.end - prev_end

        span.text = new_text
        
        if should_add_span:
            new_example_spans.append(span)
        
    example.text = new_example_text
    example.spans = new_example_spans

    return example


In [493]:
np.random.seed(0)

def augment_example(
    example: Example,
    span_f: Callable[[Span, Any], Optional[str]],
    spans: List[Span] = None,
    span_label: str = None,
    **kwargs: Any,
) -> List[Example]:

    if spans is None:
        spans = example.spans

    prev_example_hash = hash(example)
    example = example.copy(deep=True)
    example_t = None

    if span_label:
        spans = [s for s in spans if s.label == span_label]

    if spans:
        spans_to_sub = [np.random.choice(spans)]

        span_subs = {}
        for span in spans_to_sub:
            res = span_f(span, **kwargs)  #  type: ignore
            if res:
                span_subs[span] = res

        if any(span_subs.values()):
            res = substitute_spans(example, span_subs)
            if hash(res) != prev_example_hash:
                example_t = res.copy(deep=True)

    return example_t

In [None]:
class recon_augmentation:
    
    def __init__(self, )

In [517]:
np.random.seed(0)

from snorkel.augmentation.tf import transformation_function


def ent_label_sub(
    example: Example, label: str, subs: List[str]
) -> List[Example]:
    
    def augmentation_f(span: Span, subs: List[str]) -> Optional[str]:
        subs = [s for s in subs if s != span.text]
        sub = None
        if len(subs) > 0:
            sub = np.random.choice(subs)
        return sub

    return augment_example(example, span_f=augmentation_f, span_label=label, subs=subs)


replacement_names = [names.get_full_name() for _ in range(50)]


@transformation_function()
def person_sub(example: Example):
    return ent_label_sub(example.copy(deep=True), label="PERSON", subs=replacement_names)

@transformation_function()
def gpe_sub(example: Example):
    return ent_label_sub(example.copy(deep=True), label="GPE", subs=["Russia", "USA", "China"])

In [511]:
tfs = [
    person_sub,
    gpe_sub
]

In [512]:
np.random.seed(0)

from snorkel.augmentation import RandomPolicy

random_policy = RandomPolicy(
    len(tfs), sequence_length=2, n_per_original=2, keep_original=True
)

random_policy.generate_for_example()

[[], [0, 1], [1, 0]]

In [513]:
from tqdm import tqdm

from snorkel.augmentation.apply.core import BaseTFApplier


class ReconDatasetTFApplier(BaseTFApplier):
    
    def __init__(self, tfs, policy, span_label: str = None, sub_prob: float = 0.5):
        super().__init__(tfs, policy)
        self.span_label = span_label
        self.sub_prob = sub_prob
    
    def _apply_policy_to_data_point(self, x: Example) -> List[Example]:
        
        x_transformed = set()
        for seq in self._policy.generate_for_example():
            x_t = x.copy(deep=True)
            # Handle empty sequence for `keep_original`
            transform_applied = len(seq) == 0
            # Apply TFs
            for tf_idx in seq:
                tf = self._tfs[tf_idx]                
                x_t_or_none = tf(x_t)
                # Update if transformation was applied
                if x_t_or_none is not None:
                    transform_applied = True
                    x_t = x_t_or_none.copy(deep=True)
            # Add example if original or transformations applied
            if transform_applied:
                x_transformed.add(x_t)
        return list(x_transformed)


    def apply(self, ds: Dataset, progress_bar: bool = True) -> Dataset:
        
        @operation("recon.v1.augment")
        def augment(example: Example):
            transformed_examples = self._apply_policy_to_data_point(example)
            return transformed_examples
            
        ds.apply_("recon.v1.augment")
        
        return ds

In [514]:
from recon.recognizer import SpacyEntityRecognizer

r = SpacyEntityRecognizer(nlp)

examples = list(r.predict([
    "John lives in the United States",
    "Sarah lives in Germany"
]))

ds = Dataset("aug_test", examples)
ds.data

[Example(text='John lives in the United States', spans=[Span(text='John', start=0, end=4, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='the United States', start=14, end=31, label='GPE', token_start=3, token_end=6, kb_id=None)], tokens=[Token(text='John', start=0, end=4, id=0), Token(text='lives', start=5, end=10, id=1), Token(text='in', start=11, end=13, id=2), Token(text='the', start=14, end=17, id=3), Token(text='United', start=18, end=24, id=4), Token(text='States', start=25, end=31, id=5)], meta={}, formatted=True),
 Example(text='Sarah lives in Germany', spans=[Span(text='Sarah', start=0, end=5, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='Germany', start=15, end=22, label='GPE', token_start=3, token_end=4, kb_id=None)], tokens=[Token(text='Sarah', start=0, end=5, id=0), Token(text='lives', start=6, end=11, id=1), Token(text='in', start=12, end=14, id=2), Token(text='Germany', start=15, end=22, id=3)], meta={}, formatted=True)]

In [515]:
np.random.seed(0)


applier = ReconDatasetTFApplier(tfs, random_policy)
applier.apply(ds)

=> Applying operation 'recon.v1.augment' inplace
[38;5;2m✔ Completed operation 'recon.v1.augment'[0m


<recon.dataset.Dataset at 0x7f907c2185c0>

In [516]:
ds.data

[Example(text='Lori Schlueter lives in USA', spans=[Span(text='Lori Schlueter', start=0, end=14, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='USA', start=24, end=27, label='GPE', token_start=3, token_end=6, kb_id=None)], tokens=[Token(text='John', start=0, end=4, id=0), Token(text='lives', start=5, end=10, id=1), Token(text='in', start=11, end=13, id=2), Token(text='the', start=14, end=17, id=3), Token(text='United', start=18, end=24, id=4), Token(text='States', start=25, end=31, id=5)], meta={}, formatted=True),
 Example(text='John lives in the United States', spans=[Span(text='John', start=0, end=4, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='the United States', start=14, end=31, label='GPE', token_start=3, token_end=6, kb_id=None)], tokens=[Token(text='John', start=0, end=4, id=0), Token(text='lives', start=5, end=10, id=1), Token(text='in', start=11, end=13, id=2), Token(text='the', start=14, end=17, id=3), Token(text='United', start=1

In [269]:
ds.data

[Example(text='Kelley Williamson lives in China', spans=[Span(text='Kelley Williamson', start=0, end=17, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='John', start=13, end=17, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='the United States', start=15, end=32, label='GPE', token_start=3, token_end=6, kb_id=None), Span(text='China', start=27, end=32, label='GPE', token_start=3, token_end=6, kb_id=None)], tokens=[Token(text='John', start=0, end=4, id=0), Token(text='lives', start=5, end=10, id=1), Token(text='in', start=11, end=13, id=2), Token(text='the', start=14, end=17, id=3), Token(text='United', start=18, end=24, id=4), Token(text='States', start=25, end=31, id=5)], meta={}, formatted=True),
 Example(text='William Conn lives in the United States', spans=[Span(text='William Conn', start=0, end=12, label='PERSON', token_start=0, token_end=1, kb_id=None), Span(text='Jordan Floyd', start=0, end=12, label='PERSON', token_start=0, token_end=1, k

In [265]:
ds.operations

[]

In [224]:
x_transformed = []
for seq in self._policy.generate_for_example():
    x_t = x
    # Handle empty sequence for `keep_original`
    transform_applied = len(seq) == 0
    # Apply TFs
    for tf_idx in seq:


        if spans is None:
            spans = example.spans

        prev_example = x.copy(deep=True)
        if self.span_label:
            spans = [s for s in spans if s.label == self.span_label]
        mask = mask_1d(len(spans), prob=sub_prob)
        spans_to_sub = list(np.asarray(spans)[mask])

        span_subs = {}
        tf = self._tfs[tf_idx]
        for span in spans_to_sub:
            x_t_or_none = tf(span, **kwargs)  #  type: ignore
            if x_t_or_none is not None:
                transform_applied = True
                span_subs[hash(span)] = res


        x_t_or_none = tf(x_t)
        # Update if transformation was applied
        if x_t_or_none is not None:
            transform_applied = True
            x_t = x_t_or_none
    # Add example if original or transformations applied
    if transform_applied:
        x_transformed.append(x_t)
return x_transformed


NameError: name 'self' is not defined

In [13]:
# spacy = SpacyPreprocessor(text_field="text", doc_field="doc", memoize=True)

spacy_pre = SpacyPreProcessor(nlp)

# Pregenerate some random person names to replace existing ones with
# for the transformation strategies below
replacement_names = [names.get_full_name() for _ in range(50)]


def make_ent_label_sub_tf(label, subs):
    # Replace a random named entity with a different entity of the same type.
#     @operation(f"recon.v1.{label}_subs", pre=[spacy_pre], augmentation=True)
    
    def augmentation(span: Span, subs: List[str]) -> Optional[str]:
        subs = [s for s in subs if s != span.text]
        sub = None
        if len(subs) > 0:
            sub = np.random.choice(subs)
        return sub
        
    return change_ents


change_person_ents = make_ent_label_sub_tf("PERSON", names)
change_gpe_ents = make_ent_label_sub_tf("GPE", ["Russia", "China", "Mongolia"])

In [14]:
tfs = [
    change_person_ents,
    change_gpe_ents
]

In [17]:
from snorkel.augmentation import RandomPolicy

random_policy = RandomPolicy(
    len(tfs), sequence_length=2, n_per_original=2, keep_original=True
)

random_policy.generate_for_example()

[[], [1, 1], [0, 1]]