In [13]:
import pandas as pd
import ast

In [14]:
df = pd.read_csv('output_train.csv')
df['A_raw_entities'][0]

"[{'entity': 'MISC', 'score': 0.99866974, 'index': 4, 'word': 'ĠAmericans', 'start': 14, 'end': 23}]"

In [None]:
def tag_A_entities(text, raw_entities):
    import ast
    try:
        entities = ast.literal_eval(raw_entities)
    except:
        return text
    
    # sort by start index to avoid messing up positions as we insert
    entities = sorted(entities, key=lambda x: x['start'])
    
    offset = 0
    for ent in entities:
        label = ent.get('entity')
        start = ent.get('start')
        end = ent.get('end')
        if start is None or end is None or not label:
            continue
        
        start += offset
        end += offset
        start_tag = f"<{label}>"
        end_tag = f"</{label}>"
        text = text[:start] + start_tag + text[start:end] + end_tag + text[end:]
        offset += len(start_tag) + len(end_tag)
    
    return text

def tag_B_entities(text, raw_entities):
    import ast
    try:
        entities = ast.literal_eval(raw_entities)
    except:
        return text
    
    tagged = text
    offset = 0

    for ent in entities:
        word = ent.get('word')
        label = ent.get('entity')
        if not word or not label:
            continue

        start = tagged.find(word, offset)
        if start == -1:
            continue

        end = start + len(word)
        start_tag = f"<{label}>"
        end_tag = f"</{label}>"
        tagged = tagged[:start] + start_tag + word + end_tag + tagged[end:]
        offset = end + len(start_tag) + len(end_tag)

    return tagged

In [16]:
df['A_tagged'] = df.apply(lambda row: tag_A_entities(row['statement'], row['A_raw_entities']), axis=1)
df['B_tagged'] = df.apply(lambda row: tag_B_entities(row['statement'], row['B_raw_entities']), axis=1)

In [17]:
print(df.head())
df.to_csv('AB_tagged_train.csv', index=False)

                                           statement  label  label_binary  \
0  90 percent of Americans "support universal bac...      5             1   
1  Last year was one of the deadliest years ever ...      1             0   
2  Bernie Sanders's plan is "to raise your taxes ...      0             0   
3  Voter ID is supported by an overwhelming major...      4             1   
4  Says Barack Obama "robbed Medicare (of) $716 b...      2             0   

                                      A_raw_entities  \
0  [{'entity': 'MISC', 'score': 0.99866974, 'inde...   
1                                                 []   
2  [{'entity': 'PER', 'score': 0.9983652, 'index'...   
3  [{'entity': 'MISC', 'score': 0.9153446, 'index...   
4  [{'entity': 'PER', 'score': 0.9980445, 'index'...   

                                      B_raw_entities  \
0  [{'word': '90 percent', 'entity': 'PERCENT'}, ...   
1  [{'word': 'Last year', 'entity': 'DATE'}, {'wo...   
2  [{'word': "Bernie Sanders's",

In [18]:

entity_labels = set()

for row in df['A_raw_entities']:
    try:
        entities = ast.literal_eval(row)
    except:
        continue
    for ent in entities:
        label = ent.get('entity')
        if label:
            entity_labels.add(label)

# now create special tokens
special_tokens = []
for label in entity_labels:
    special_tokens.append(f"[{label}_]")
    special_tokens.append(f"[/{label}_]")

print(special_tokens)

['[PER_]', '[/PER_]', '[MISC_]', '[/MISC_]', '[LOC_]', '[/LOC_]', '[ORG_]', '[/ORG_]']
