In [1]:
import os
import re
import ast 
import unittest
import sys
import torch
import pandas as pd
import json
from sklearn.metrics import f1_score, accuracy_score
from nltk.tokenize import sent_tokenize
import numpy as np
from collections import Counter
from tqdm.auto import tqdm

sys.path.append('../ifcc/')
sys.path.append('../ifcc/tests/')
sys.path.append('../guidedsum/')

from test_nli import TestSimpleNLI
from annotation_utils import expand_to_borders, Annotation

In [2]:
# Hide every GPU so the NLI model below runs on CPU.
# NOTE(review): this only takes effect if CUDA has not been initialized yet;
# torch is imported above but initializes CUDA lazily — keep this before any GPU use.
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [3]:
def get_majority_span_annotation(row):
    """Merge the annotators' spans for one addition into a single majority span.

    Parameters
    ----------
    row : pandas.Series
        One row of the additions frame. Reads ``annotations`` (a list of dicts
        with ``start``/``end`` character offsets), ``majority_vote`` (a mapping
        whose first key is used as the label), ``study_id`` and ``candidate``
        (the text the offsets index into).

    Returns
    -------
    tuple
        ``(span, span_indices)``: the text cut from ``row['candidate']`` and a
        ``{'start': ..., 'end': ...}`` dict holding the merged span's offsets.
    """
    # Every rebuilt annotation carries the same label, so look it up once
    # instead of on every loop iteration.
    majority_label = list(row['majority_vote'].keys())[0]

    annotations = [
        Annotation(
            start=raw['start'],
            end=raw['end'],
            label=majority_label,
            annotator='majority',
            document_id=row['study_id'],
        )
        for raw in row['annotations']
    ]

    # NOTE(review): expand_to_borders presumably merges/expands the spans to
    # token borders; only its first result is kept here — confirm that one
    # row's annotations always collapse into a single span.
    majority_annotation = expand_to_borders(annotations)[0]
    span_indices = {
        'start': majority_annotation.start,
        'end': majority_annotation.end,
    }
    span = row['candidate'][majority_annotation.start:majority_annotation.end]
    return span, span_indices

In [4]:
ADDITIONS_PATH = '../error-analysis/annotations/additions.jsonl'

# Load the annotated additions; wrap each {label: votes} mapping in a Counter
# so that looking up a label nobody voted for yields 0 instead of KeyError.
df_additions = pd.read_json(ADDITIONS_PATH, lines=True).assign(
    majority_vote=lambda frame: frame['majority_vote'].apply(Counter),
)
df_additions['annotation_kind'].value_counts()


EXTEND    186
ACCEPT     76
REVIEW     33
Name: annotation_kind, dtype: int64

In [5]:
# Keep additions whose majority vote is exclusively label '2a' and whose
# annotation kind is ACCEPT or EXTEND (REVIEW rows are dropped).
is_only_2a = df_additions['majority_vote'].apply(
    lambda votes: len(votes) == 1 and votes['2a'] >= 1
)
is_accept_or_extend = df_additions['annotation_kind'].isin(['ACCEPT', 'EXTEND'])
df_filtered = df_additions[is_only_2a & is_accept_or_extend].copy()
df_filtered.head()

Unnamed: 0,study_id,candidate_name,annotations,annotation_kind,majority_vote,reference,candidate,findings+bg
0,50178679,wgsum,"[{'annotator': 'annotator4', 'end': 65, 'start...",EXTEND,{'2a': 1},No acute cardiopulmonary process based on this...,no definite acute cardiopulmonary process give...,History:\n_-year-old female with fever and cou...
2,50296389,wgsum+cl,"[{'annotator': 'annotator1', 'end': 96, 'start...",EXTEND,{'2a': 1},Improving right hydropneumothorax with right l...,decreased though persistent right-sided hydrop...,Indication:\nPatient with collapsed right in s...
3,50394941,bertabs,"[{'annotator': 'annotator1', 'end': 118, 'star...",EXTEND,{'2a': 1},"ET tube ends 2.5 cm above the carina, and coul...",endotracheal tube ends approximately 2.5 cm ab...,"Indication:\n_-year-old, unresponsive man stat..."
4,50394941,bertabs,"[{'annotator': 'annotator2', 'end': 182, 'star...",EXTEND,{'2a': 1},"ET tube ends 2.5 cm above the carina, and coul...",endotracheal tube ends approximately 2.5 cm ab...,"Indication:\n_-year-old, unresponsive man stat..."
5,50394941,bertabs,"[{'annotator': 'annotator1', 'end': 215, 'star...",EXTEND,{'2a': 1},"ET tube ends 2.5 cm above the carina, and coul...",endotracheal tube ends approximately 2.5 cm ab...,"Indication:\n_-year-old, unresponsive man stat..."


In [6]:
# Resolve each filtered addition to its merged majority span: (text, offsets).
span_results = [get_majority_span_annotation(row) for _, row in df_filtered.iterrows()]
majority_spans = [span_text for span_text, _ in span_results]
majority_annotations = [offsets for _, offsets in span_results]

df_filtered['majority_span'] = majority_spans
df_filtered['majority_annotation'] = majority_annotations

In [7]:
def _extract_findings(report_text):
    """Return the text on the line right after the 'Findings:' header.

    Raises AttributeError when no 'Findings:' header line is present.
    NOTE(review): `.` does not match newlines, so only the first line after the
    header is captured — confirm the findings section is always a single line.
    """
    findings_match = re.search(r"^Findings:\n(.*)", report_text, flags=re.MULTILINE)
    return findings_match.group(1)


df_filtered['findings'] = df_filtered['findings+bg'].apply(_extract_findings)

In [8]:
# Reuse the NLI setup from the ifcc test suite: TestSimpleNLI appears to be a
# unittest-style test class, so its setUp() is invoked by hand here rather than
# by a test runner.
# NOTE(review): setUp() presumably builds the NLI model exposed as `nli.nli`,
# which the prediction loop calls via `nli.nli.predict` — confirm in test_nli.py.
nli = TestSimpleNLI()
nli.setUp()

In [9]:
probas, preds, report_sents = [], [], []

for _, row in tqdm(df_filtered.iterrows(), total=len(df_filtered)):
    # Split the findings into sentences and pair every sentence with the same
    # majority span, so the NLI model scores the span against each sentence.
    findings_sentences = sent_tokenize(row['findings'])
    report_sents.append(findings_sentences)
    span_per_sentence = [row['majority_span'] for _ in findings_sentences]

    nli_result = nli.nli.predict(findings_sentences, span_per_sentence)
    # Element 0 is stored as the probabilities and element 1 as the
    # per-sentence labels that the aggregation step inspects.
    probas.append(nli_result[0])
    preds.append(nli_result[1])

  0%|          | 0/206 [00:00<?, ?it/s]

In [10]:
def _aggregate_sentence_labels(sentence_labels):
    """Collapse one addition's per-sentence NLI labels into a single symbol.

    Any 'contradiction' wins ('-'); otherwise any 'entailment' gives '+';
    otherwise the addition counts as neutral ('o').
    """
    for label, symbol in (('contradiction', '-'), ('entailment', '+')):
        if label in sentence_labels:
            return symbol
    return 'o'


final_preds = [_aggregate_sentence_labels(sentence_labels) for sentence_labels in preds]

In [11]:
# Attach the per-addition NLI outputs to the frame (one entry per row, in row order).
nli_columns = {
    'preds_counter': [Counter(labels) for labels in preds],
    'preds': preds,
    'probas': probas,
    'sents': report_sents,
    'final_pred': final_preds,
}
for column_name, column_values in nli_columns.items():
    df_filtered[column_name] = column_values

In [12]:
# Share (in %) of additions per aggregated outcome, in first-seen order.
n_rows = len(df_filtered)
pred_counts = Counter(final_preds)
for label, count in pred_counts.items():
    share = np.round((count * 100) / n_rows, 2)
    print(label, ' ', share)

+   33.01
-   25.73
o   41.26


Raw prediction counts by method

In [13]:
# Cross-tabulate methods (rows) against aggregated NLI outcomes (columns).
df_filtered.groupby(['candidate_name', 'final_pred']).size().unstack()

final_pred,+,-,o
candidate_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
bertabs,15,11,21
gsum_thresholding,20,17,21
wgsum,16,12,22
wgsum+cl,17,13,21


Normalize counts to per-method percentages and export as LaTeX

In [14]:
# Percentage of each aggregated NLI outcome per summarization method,
# displayed and exported as a full-width LaTeX table.
df = (
    df_filtered.groupby('candidate_name')['final_pred']
    .value_counts(normalize=True)
    .unstack()
    .rename({'+': 'Entail', '-': 'Contradict', 'o': 'Neutral'}, axis=1)
    .rename({
        'bertabs': 'BertAbs',
        'gsum_thresholding': 'GSum w/ Thresholding',
        'wgsum': 'WGSum',
        'wgsum+cl': 'WGSum+CL',
    }, axis=0)
    # reindex (not df[[...]]) so an outcome that never occurs becomes an
    # all-NaN column rendered as "-" instead of raising KeyError.
    .reindex(columns=['Entail', 'Neutral', 'Contradict'])
)
# Hide the axis names ('candidate_name' / 'final_pred'). Use None, not '':
# an empty-string index name still counts as a name, and to_latex then emits
# an extra blank header row (" & & & \\").
df.index.name = None
df.columns.name = None
df = (df * 100).round(1)
display(df)

# Bold the header cells before rendering instead of string-replacing the
# column names in the LaTeX, which would also hit any matching row text.
df_bold_header = df.rename(columns=lambda name: '\\textbf{' + name + '}')
tex = df_bold_header.to_latex(
    na_rep="-",  # method/outcome combinations that never occurred
    position='t',
    escape=False,  # keep the \textbf{} markup intact
    index_names=True,
    column_format='l' + 'r' * len(df.columns),  # left for row label, right for all numbers
    multicolumn_format='c',
)
# Promote to a full-width float by touching only the environment delimiters.
tex = tex.replace('\\begin{table}', '\\begin{table*}').replace('\\end{table}', '\\end{table*}')
# Escape both backslashes: '\c' is an invalid escape sequence in a normal string.
tex = tex.replace('\\centering', '\\small\n\\centering')
tex = re.sub(r' +', ' ', tex)
print(tex)

Unnamed: 0,Entail,Neutral,Contradict
,,,
BertAbs,31.9,44.7,23.4
GSum w/ Thresholding,34.5,36.2,29.3
WGSum,32.0,44.0,24.0
WGSum+CL,33.3,41.2,25.5


\begin{table*}[t]
\small
\centering
\begin{tabular}{lrrr}
\toprule
{} & \textbf{Entail} & \textbf{Neutral} & \textbf{Contradict} \\
 & & & \\
\midrule
BertAbs & 31.9 & 44.7 & 23.4 \\
GSum w/ Thresholding & 34.5 & 36.2 & 29.3 \\
WGSum & 32.0 & 44.0 & 24.0 \\
WGSum+CL & 33.3 & 41.2 & 25.5 \\
\bottomrule
\end{tabular}
\end{table*}

