In [11]:
!pip install fastapi

Collecting fastapi
  Using cached https://files.pythonhosted.org/packages/9f/33/1b643f650688ad368983bbaf3b0658438038ea84d775dd37393d826c3833/fastapi-0.63.0-py3-none-any.whl
Collecting starlette==0.13.6 (from fastapi)
  Using cached https://files.pythonhosted.org/packages/c5/a4/c9e228d7d47044ce4c83ba002f28ff479e542455f0499198a3f77c94f564/starlette-0.13.6-py3-none-any.whl
Installing collected packages: starlette, fastapi
Successfully installed fastapi-0.63.0 starlette-0.13.6


In [1]:
# (Re)load the autoreload extension and switch it to mode 2 so that edits to
# imported project modules are picked up live without restarting the kernel.
%reload_ext autoreload
%autoreload 2

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed')).History will not be written to the database.


In [10]:
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, List

from fastapi import params
from fastapi.dependencies.utils import get_typed_signature, get_param_field, is_scalar_sequence_field, request_body_to_args, request_params_to_args
from pydantic.fields import ModelField

from recon.corrections import fix_annotations, corrections_from_dict
from recon.dataset import Dataset
from recon.loaders import read_jsonl
from recon.operations import registry
from recon.operations.utils import (
    get_received_operation_data,
    get_required_operation_params,
    # NOTE(review): this rebinds `request_body_to_args`, a name that the
    # fastapi.dependencies.utils import earlier in this cell also provides.
    # Because this import runs later, subsequent calls in the notebook use the
    # recon implementation, not FastAPI's.
    request_body_to_args,
)
from recon.types import Example, OperationState

In [11]:
def example_data() -> Dict[str, List[Example]]:
    """Load the example skills data, one list of examples per split.

    Each split is read with `read_jsonl` from the `examples/data/skills`
    directory, resolved relative to the parent of the working directory.

    Returns:
        Dict[str, List[Example]]: Mapping with the keys "train", "dev" and
            "test", each holding the examples read from that split's file.
    """
    skills_dir = Path("../") / "examples/data/skills"
    split_files = (
        ("train", "train.jsonl"),
        ("dev", "dev.jsonl"),
        ("test", "test.jsonl"),
    )

    splits = {}
    # Read the files in train -> dev -> test order, same as the key order returned.
    for split_name, filename in split_files:
        splits[split_name] = read_jsonl(skills_dir / filename)
    return splits

In [12]:
# Load all example splits and wrap the "train" split in a Dataset named "train".
data = example_data()
train_dataset = Dataset("train", data["train"])

# Build corrections from an {annotation text: target label} mapping. The printed
# output shows each entry becomes a Correction with from_labels=['ANY']; a None
# target is kept as to_label=None.
corrections = corrections_from_dict(
    {"software development engineer": "JOB_ROLE", "model": None}
)
print("CORRECTIONS:", corrections)
# Run the registered operation by its name on the dataset.
# NOTE(review): the trailing underscore presumably means the dataset is modified
# in place -- confirm against recon's Dataset API.
train_dataset.apply_("recon.v1.fix_annotations", corrections)

 46%|████▌     | 49/106 [00:00<00:00, 484.66it/s]

CORRECTIONS: [Correction(annotation='software development engineer', from_labels=['ANY'], to_label='JOB_ROLE'), Correction(annotation='model', from_labels=['ANY'], to_label=None)]
=> Applying operation 'recon.v1.fix_annotations' to dataset 'train'
VALUES:  {'corrections': [Correction(annotation='software development engineer', from_labels=['ANY'], to_label='JOB_ROLE'), Correction(annotation='model', from_labels=['ANY'], to_label=None)], 'case_sensitive': False, 'dryrun': False}


100%|██████████| 106/106 [00:00<00:00, 503.65it/s]

[38;5;2m✔ Completed operation 'recon.v1.fix_annotations'[0m





In [13]:
# Directory (relative to the working directory) used by the following cells to
# save the dataset and load it back.
tmp_path = "./test_argument_resolution_dataset/"

In [14]:
# Save the dataset under tmp_path.
# NOTE(review): force=True presumably allows overwriting an existing directory --
# confirm against Dataset.to_disk.
train_dataset.to_disk(tmp_path, force=True)

In [15]:
# Create a new Dataset named "train" and populate it from the saved directory,
# so the following cells work with state that went through a save/load round trip.
train_dataset_loaded_2 = Dataset("train").from_disk(tmp_path)

In [16]:
# Take the first operation stored on the reloaded dataset.
op = train_dataset_loaded_2.operations[0]

In [19]:
# Inspect the stored call arguments. In the recorded output the corrections come
# back as plain dicts rather than Correction objects, which is why they need to
# be resolved again before the operation can be re-run.
op.args, op.kwargs

([],
 {'corrections': [{'annotation': 'software development engineer',
    'from_labels': ['ANY'],
    'to_label': 'JOB_ROLE'},
   {'annotation': 'model', 'from_labels': ['ANY'], 'to_label': None}],
  'case_sensitive': False,
  'dryrun': False})

In [21]:
# Look up the registered operation by the stored name and get its parameters
# (the recorded output shows an OrderedDict of parameter name -> ModelField).
required_params = get_required_operation_params(registry.operations.get(op.name).op)
# Collect the values stored on `op` for those parameters.
received_data = get_received_operation_data(required_params, op)

# Display both for comparison.
required_params, received_data

(OrderedDict([('corrections',
               ModelField(name='corrections', type=List[Correction], required=True)),
              ('case_sensitive',
               ModelField(name='case_sensitive', type=bool, required=False, default=False)),
              ('dryrun',
               ModelField(name='dryrun', type=bool, required=False, default=False))]),
 {'corrections': [{'annotation': 'software development engineer',
    'from_labels': ['ANY'],
    'to_label': 'JOB_ROLE'},
   {'annotation': 'model', 'from_labels': ['ANY'], 'to_label': None}],
  'case_sensitive': False,
  'dryrun': False})

In [22]:
# Resolve the received data against the parameter fields. The call returns a
# (values, errors) pair; in the recorded output the correction dicts have been
# turned back into Correction objects and the error list is empty.
values, errors = request_body_to_args(list(required_params.values()), received_data)
values, errors

({'corrections': [Correction(annotation='software development engineer', from_labels=['ANY'], to_label='JOB_ROLE'),
   Correction(annotation='model', from_labels=['ANY'], to_label=None)],
  'case_sensitive': False,
  'dryrun': False},
 [])

In [132]:
# Show the `strip_chars` parameter field of the strip_annotations operation.
# `required_params` in this notebook describes fix_annotations, which has no
# "strip_chars" entry, so indexing it here fails on a top-to-bottom run. The
# field is therefore looked up on the operation that actually declares it, and
# displayed as-is (no validate call), matching the recorded output.
strip_annotations_params = get_required_operation_params(
    registry.operations.get("recon.v1.strip_annotations").op
)
strip_annotations_params["strip_chars"]

ModelField(name='strip_chars', type=List[str], required=False, default=['.', '!', '?', '-', ':', ' '])

In [23]:
# Call the operation function directly on the first example of the reloaded
# dataset, passing the resolved values as keyword arguments.
fix_annotations(train_dataset_loaded_2.data[0], **values)

Example(text='History of Texas from Spanish period to present day.', spans=[Span(text='History', start=0, end=7, label='SKILL', token_start=0, token_end=1, kb_id=None), Span(text='Spanish', start=22, end=29, label='SKILL', token_start=4, token_end=5, kb_id=None)], tokens=[Token(text='History', start=0, end=7, id=0), Token(text='of', start=8, end=10, id=1), Token(text='Texas', start=11, end=16, id=2), Token(text='from', start=17, end=21, id=3), Token(text='Spanish', start=22, end=29, id=4), Token(text='period', start=30, end=36, id=5), Token(text='to', start=37, end=39, id=6), Token(text='present', start=40, end=47, id=7), Token(text='day', start=48, end=51, id=8), Token(text='.', start=51, end=52, id=9)], meta={'source': 'Courses', 'sourceLink': 'https://catalog.tamu.edu/undergraduate/general-information/university-core-curriculum/'}, formatted=True)

In [31]:
# Repeat the argument resolution for the strip_annotations operation with an
# explicit list for `strip_chars`.
req_params = get_required_operation_params(registry.operations.get("recon.v1.strip_annotations").op)
# [1] selects the errors part of the returned pair, [0] its first entry, and .exc
# the underlying exception (a pydantic ListError in the recorded output).
request_body_to_args(list(req_params.values()), {"strip_chars": ['.', '!', '?', '-', ':', ' ']})[1][0].exc

pydantic.errors.ListError()