# Imports and utils

In [None]:
import collections
import glob
import rst_lib
from rst_lib import TRAIN, DOUBLE
import yaml
import seaborn as sns
import pandas as pd

# Index of the corpus as returned by rst_lib: `paths` and `files` are passed
# straight through to build_annotation_pair below.
paths, files = rst_lib.build_file_map()

# One annotation pair per identifier listed under files[TRAIN][DOUBLE].
annotation_pairs = [
    rst_lib.build_annotation_pair(files, paths, doc_id)
    for doc_id in files[TRAIN][DOUBLE]
]


def f1_score(p, r):
  """Return the harmonic mean of precision `p` and recall `r`.

  Args:
    p: Precision value.
    r: Recall value.

  Returns:
    2 * p * r / (p + r), or 0 when p + r is zero (falsy).
  """
  denominator = p + r
  return 2 * p * r / denominator if denominator else 0

def mean(l):
  """Return the arithmetic mean of the values in `l`.

  Args:
    l: A non-empty sized collection of numbers.

  Returns:
    sum(l) / len(l).

  Raises:
    ValueError: If `l` is empty.
  """
  # Raise explicitly instead of `assert False`: assert statements are removed
  # under `python -O`, which would turn an empty input into a bare
  # ZeroDivisionError from the division below.
  if not l:
    raise ValueError("mean() requires a non-empty sequence")
  return sum(l) / len(l)


# Load the label lookup used by get_metrics to compare relation labels
# (labels are compared via LABEL_CLASS_MAP[label]).
# NOTE(review): yaml.Loader is the full loader and can construct arbitrary
# Python objects; this is only appropriate while label_classes.yaml is a
# trusted, project-local file.
with open('label_classes.yaml', 'r') as label_file:
  LABEL_CLASS_MAP = yaml.load(label_file.read(), Loader=yaml.Loader)

# Span-level (s_*) and relation-level (r_*) precision (*_p) and recall (*_r),
# as computed by get_metrics below.
Metrics = collections.namedtuple("Metrics", "s_p s_r r_p r_r")


def get_metrics(span_map_1, span_map_2):
  """Compute span and relation agreement between two span maps.

  A span counts as a true positive when it is a key of both maps. A relation
  counts as a true positive when, in addition, the two labels stored for that
  span map to equal values in LABEL_CLASS_MAP.

  Args:
    span_map_1: Mapping from span to label; its size is the recall
      denominator.
    span_map_2: Mapping from span to label; its size is the precision
      denominator.

  Returns:
    A Metrics tuple (s_p, s_r, r_p, r_r). A ratio whose denominator is zero
    (an empty span map) is reported as 0.

  Raises:
    KeyError: If the label of a shared span is missing from LABEL_CLASS_MAP.
  """

  def ratio(count, total):
    # An empty span map has no defined precision/recall. Score it as 0, the
    # same convention f1_score uses for p + r == 0, rather than letting the
    # division raise ZeroDivisionError.
    return count / total if total else 0

  shared_spans = set(span_map_1.keys()).intersection(span_map_2.keys())

  matching_relations = sum(
      1 for span in shared_spans
      if LABEL_CLASS_MAP[span_map_1[span]] == LABEL_CLASS_MAP[span_map_2[span]])

  num_orig_spans = len(span_map_1)
  num_final_spans = len(span_map_2)

  return Metrics(
    s_p=ratio(len(shared_spans), num_final_spans),
    s_r=ratio(len(shared_spans), num_orig_spans),
    r_p=ratio(matching_relations, num_final_spans),
    r_r=ratio(matching_relations, num_orig_spans))

In [None]:
def get_index_intervals(edus):
  """Return cumulative (start, end) offsets for every EDU after the first.

  The first element of `edus` is skipped. NOTE(review): presumably it is a
  placeholder entry produced by rst_lib; confirm against the parser.

  Args:
    edus: Sequence of sized items (anything supporting len()).

  Returns:
    A list of (start, end) tuples, one per element of edus[1:], where each
    interval starts where the previous one ended and has the length of its
    EDU. The first interval starts at 0.
  """
  index_intervals = []
  start = 0
  for edu in edus[1:]:
    end = start + len(edu)
    index_intervals.append((start, end))
    start = end

  # Offsets only ever grow, so the list is already in ascending order and
  # needs no extra sort.
  return index_intervals


def reindex_intervals(old_intervals, shared_intervals):
  """Express each old interval in terms of positions in `shared_intervals`.

  Args:
    old_intervals: Iterable of (start, end) intervals; processed in sorted
      order.
    shared_intervals: List of (start, end) intervals to index into.

  Returns:
    A list with one entry per old interval (in sorted order). The entry is
    an int position when the interval itself occurs in `shared_intervals`;
    otherwise it is the list of positions of all shared intervals whose
    start lies in [interval start, interval end).
  """

  def lookup(interval):
    # Exact match: a single position.
    if interval in shared_intervals:
      return shared_intervals.index(interval)
    # No exact match: collect every shared interval starting inside it.
    return [
        position
        for position, candidate in enumerate(shared_intervals)
        if candidate[0] >= interval[0] and candidate[0] < interval[1]
    ]

  return [lookup(interval) for interval in sorted(old_intervals)]

def create_new_edus(edus_1, edus_2):
  """Print both segmentations re-expressed over their shared boundaries.

  The end offsets of both EDU lists are merged into one sorted set of
  boundaries; consecutive boundaries (starting from 0) form the shared
  spans. Each original interval is then reindexed against those spans and
  the two reindexed lists are printed pairwise, followed by a blank line.

  Args:
    edus_1: EDU list of the first annotation.
    edus_2: EDU list of the second annotation.

  Returns:
    None; the result is only printed.
  """
  first_intervals = get_index_intervals(edus_1)
  second_intervals = get_index_intervals(edus_2)

  # Every end offset appearing in either segmentation, ascending, no dupes.
  boundaries = sorted(
      {interval[1] for interval in first_intervals + second_intervals})
  # Each shared span begins where the previous one ended; the first at 0.
  shared_spans = list(zip([0] + boundaries[:-1], boundaries))

  reindexed_first = reindex_intervals(first_intervals, shared_spans)
  reindexed_second = reindex_intervals(second_intervals, shared_spans)

  for pair in zip(reindexed_first, reindexed_second):
    print(*pair)
  print()


In [None]:
# Collect one row of agreement statistics per double-annotated file
# (train set only).
file_level_info = []
for pair in annotation_pairs:
    # Skip missing pairs and pairs where either annotation is flagged
    # invalid. NOTE(review): fields 2 and 3 are presumably the main and
    # double annotations; confirm against rst_lib.build_annotation_pair.
    if pair is None:
        continue
    if not (pair[3].is_valid and pair[2].is_valid):
        continue

    # Prints the two segmentations aligned on shared boundaries.
    create_new_edus(pair.main_annotation.edus, pair.double_annotation.edus)

    main_span_map = pair.main_annotation.span_map
    double_span_map = pair.double_annotation.span_map
    metrics = get_metrics(main_span_map, double_span_map)

    file_level_info.append({
        "identifier": pair.identifier,
        "span_f1": f1_score(metrics.s_p, metrics.s_r),
        # Despite the key name, this is the number of entries in the main
        # annotation's EDU list, not a character count.
        "doc_len_chars": len(pair.main_annotation.edus),
    })

file_level_df = pd.DataFrame.from_dict(file_level_info)

In [None]:
# Distribution of the per-file "doc_len_chars" column in 10 bins.
sns.histplot(data=file_level_df, x="doc_len_chars", bins=10)