In [1]:
import sys

# Put the sibling `scripts` directory on the import path so that modules
# located there can be imported by the cells below.
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry to sys.path.
SCRIPTS_DIR = '../scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [2]:
from rules.cnf_resolver import RulebasedCNFResolver
from evaluation import Metrics
import pandas as pd
from pathlib import Path
from dataset import load_data
from evaluation import error_analysis, get_scores

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import hydra
from hydra import compose, initialize

# Clear the global Hydra instance first -- presumably so that this cell can be
# re-run without failing on an already-initialized Hydra.
hydra.core.global_hydra.GlobalHydra.instance().clear()

# config_path points at the parent directory ('..'); the experiment
# configuration is then composed from 'experiment.yaml'.
initialize(
    version_base='1.1',
    job_name='foo',
    config_path=Path('..'),
)
config = compose(config_name='experiment.yaml')

In [4]:
#!python -m spacy download de_core_news_lg

## Load Data

In [14]:
# Both TSV locations come from the config and are resolved against the parent
# directory ('..').
parent_dir = Path('..')
train_df, valid_df, test_df = load_data(
    parent_dir / config.data.cnf_tsv_path,
    parent_dir / config.data.controls_tsv_path,
)

In [7]:
#m = Metrics(['exact_match', 'google_bleu'], tokenizer=None)

In [8]:
#def print_metrics(pred, ground_truth):
#    exact = m.compute_exact_match(pred, ground_truth)
#    bleu = m.compute_bleu(pred, ground_truth)
#    print(f"Exact match: {exact:.3f} \n GLEU score: {bleu:.3f}")

### Optimal values (gold resolutions scored against themselves)

In [18]:
# Reference point: the gold resolutions of the validation split are passed as
# both the first and the second argument, i.e. they are scored against themselves.
gold_valid = valid_df["full_resolution"].values
ea = error_analysis(gold_valid, gold_valid, valid_df["raw_sentence"].values)
get_scores(ea, "eval")

{'eval/tp': 1.0,
 'eval/tp_abs': 1031,
 'eval/fn': 0.0,
 'eval/fn_abs': 0,
 'eval/replace': 0.0,
 'eval/replace_abs': 0,
 'eval/insert': 0.0,
 'eval/insert_abs': 0,
 'eval/delete': 0.0,
 'eval/delete_abs': 0,
 'eval/complex': 0.0,
 'eval/complex_abs': 0,
 'eval/edit_distance_rel': 1.0,
 'eval/exact_match': 1.0,
 'eval/gleu': 1.0,
 'eval/edit_distance_abs': 0.0}

## Trivial baseline: do nothing

In [17]:
# Do-nothing baseline on the validation split: the unmodified raw sentences are
# scored against the gold resolutions.
raw_valid = valid_df["raw_sentence"].values
ea = error_analysis(raw_valid, valid_df["full_resolution"].values, raw_valid)
get_scores(ea, "eval")

{'eval/tp': 0.5053346265761397,
 'eval/tp_abs': 521,
 'eval/fn': 0.49466537342386036,
 'eval/fn_abs': 510,
 'eval/replace': 0.0,
 'eval/replace_abs': 0,
 'eval/insert': 0.0,
 'eval/insert_abs': 0,
 'eval/delete': 0.0,
 'eval/delete_abs': 0,
 'eval/complex': 0.0,
 'eval/complex_abs': 0,
 'eval/edit_distance_rel': 0.5053346265761397,
 'eval/exact_match': 0.5053346265761397,
 'eval/gleu': 0.9431774678698387,
 'eval/edit_distance_abs': 5.346265761396702}

In [19]:
# Do-nothing baseline on the test split: the unmodified raw sentences are
# scored against the gold resolutions.
raw_test = test_df["raw_sentence"].values
ea = error_analysis(raw_test, test_df["full_resolution"].values, raw_test)
get_scores(ea, "test")

{'test/tp': 0.5154845154845155,
 'test/tp_abs': 516,
 'test/fn': 0.48451548451548454,
 'test/fn_abs': 485,
 'test/replace': 0.0,
 'test/replace_abs': 0,
 'test/insert': 0.0,
 'test/insert_abs': 0,
 'test/delete': 0.0,
 'test/delete_abs': 0,
 'test/complex': 0.0,
 'test/complex_abs': 0,
 'test/edit_distance_rel': 0.5154845154845155,
 'test/exact_match': 0.5154845154845155,
 'test/gleu': 0.9425790379824756,
 'test/edit_distance_abs': 5.588411588411589}

## Rule-based resolver

In [20]:
# Instantiate the rule-based resolver that is evaluated in the cells below.
# NOTE(review): the meaning of the positional argument 5 is not visible in this
# notebook -- confirm against RulebasedCNFResolver's signature and consider
# passing it by keyword or as a named constant.
resolver = RulebasedCNFResolver(5)

In [22]:
# Run the rule-based resolver over the raw sentences of the validation split.
valid_sentences = valid_df["raw_sentence"]
valid_preds = resolver.predict_all(valid_sentences)

In [24]:
# Score the rule-based predictions on the validation split against the gold
# resolutions; the raw sentences are passed along as the third argument.
ea = error_analysis(
    valid_preds,
    valid_df["full_resolution"].values,
    valid_df["raw_sentence"].values,
)
get_scores(ea, "eval")

{'eval/tp': 0.6207565470417071,
 'eval/tp_abs': 640,
 'eval/fn': 0.24733268671193018,
 'eval/fn_abs': 255,
 'eval/replace': 0.05625606207565471,
 'eval/replace_abs': 58,
 'eval/insert': 0.016488845780795344,
 'eval/insert_abs': 17,
 'eval/delete': 0.03976721629485936,
 'eval/delete_abs': 41,
 'eval/complex': 0.019398642095053348,
 'eval/complex_abs': 20,
 'eval/edit_distance_rel': 0.702262819636803,
 'eval/exact_match': 0.6207565470417071,
 'eval/gleu': 0.9556788975775141,
 'eval/edit_distance_abs': 3.5460717749757515}

In [25]:
# Run the rule-based resolver over the raw sentences of the test split.
test_sentences = test_df["raw_sentence"]
test_preds = resolver.predict_all(test_sentences)

In [26]:
# Score the rule-based predictions on the test split against the gold
# resolutions; the raw sentences are passed along as the third argument.
ea = error_analysis(
    test_preds,
    test_df["full_resolution"].values,
    test_df["raw_sentence"].values,
)
get_scores(ea, "test")

{'test/tp': 0.6723276723276723,
 'test/tp_abs': 673,
 'test/fn': 0.18681318681318682,
 'test/fn_abs': 187,
 'test/replace': 0.05194805194805195,
 'test/replace_abs': 52,
 'test/insert': 0.013986013986013986,
 'test/insert_abs': 14,
 'test/delete': 0.03996003996003996,
 'test/delete_abs': 40,
 'test/complex': 0.03496503496503497,
 'test/complex_abs': 35,
 'test/edit_distance_rel': 0.7584831053670736,
 'test/exact_match': 0.6723276723276723,
 'test/gleu': 0.9597314790771035,
 'test/edit_distance_abs': 3.3226773226773227}

In [2]:
# TODO: Analyze the baseline's errors, e.g., failure to detect TRUNC, complex ellipses, etc.