In [4]:
import json
import os
import random
import re
from random import randint

import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# Nominal maximum input length; forwarded to preprocess_data in the run cell
# (preprocess_data does not currently enforce it).
max_input_len = 300
# Optional data-size tag (e.g. 1000); when set, outputs are written into a
# size-specific sub-folder of the output directory.
data_size = None

# Input / output locations of the MultiCite data.
INPUT = '../data/multicite/'
OUTPUT = f'../data/multicite/{data_size}/' if data_size else '../data/multicite/'

# The token counts hardly differ between the models, so Mistral's tokenizer is
# used for the evaluation.
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.3')

In [5]:
# load the raw MultiCite dump; UTF-8 is explicit so non-ASCII text decodes the
# same on every platform (the output files are written as UTF-8 too)
with open(INPUT + 'full_raw.json', 'r', encoding='utf-8') as file:
    full_raw = json.load(file)

# split at the document level (top-level keys of full_raw) into train / test;
# the fixed random_state keeps the split reproducible
train_data_keys, test_data_keys = train_test_split(list(full_raw.keys()), test_size=0.2, random_state=82)
train_data = {key: full_raw[key] for key in train_data_keys}
test_data = {key: full_raw[key] for key in test_data_keys}

In [6]:
#helper
def label_mapping(label):
    """Translate a citation-intent label.

    Two directions are supported:
      * raw MultiCite tag -> label name, e.g. '@BACK@' -> 'BACKGROUND'
      * label name -> integer class id, e.g. 'BACKGROUND' -> 0

    Any other value is printed and ``None`` is returned.
    """
    translations = (
        # raw tag -> label name
        ('@BACK@', 'BACKGROUND'),
        ('@MOT@', 'MOTIVATION'),
        ('@USE@', 'USE'),
        ('@EXT@', 'EXTENDS'),
        ('@SIM@', 'SIMILARITY'),
        ('@DIF@', 'DIFFERENCE'),
        ('@FUT@', 'FUTURE'),
        # label name -> class id
        ('BACKGROUND', 0),
        ('MOTIVATION', 1),
        ('USE', 2),
        ('EXTENDS', 3),
        ('SIMILARITY', 4),
        ('DIFFERENCE', 5),
        ('FUTURE', 6),
    )
    for source, target in translations:
        if label == source:
            return target
    # unknown label: report it so it can be spotted in the cell output
    print(label)
    return None

def seperate_segments(text_json):
    """Split a paper's sentences into sections delimited by dash separator lines.

    Args:
        text_json: iterable of dicts with ``'sent_id'`` and ``'text'``. The
            sentence order is the integer after the last '-' in ``sent_id``.

    Returns:
        (position_of, sections): ``position_of`` maps every sent_id to
        ``(section_index, index_within_section)``, or to -1 for separator
        lines. ``sections`` is a list of sections, each a list of sentence texts.
    """
    texts_by_id = {}
    for entry in text_json:
        texts_by_id[entry['sent_id']] = entry['text']
    ordered_ids = sorted(texts_by_id, key=lambda sent_id: int(sent_id.split('-')[-1]))

    position_of, sections = {}, [[]]
    for sent_id in ordered_ids:
        sentence = texts_by_id[sent_id]
        if re.match(r'----------------------------------', sentence):
            # separator line: start a new, empty section; the separator itself is not stored
            sections.append([])
            position_of[sent_id] = -1
        else:
            # current section is always the last one; its length is the next free slot
            position_of[sent_id] = (len(sections) - 1, len(sections[-1]))
            sections[-1].append(sentence)
    return position_of, sections

def get_cit_dict(label_json):
    """Group a paper's citation annotations by citing sentence.

    Args:
        label_json: mapping label -> {'cite_sentences': [sent_id, ...],
            'gold_contexts': [[sent_id, ...], ...]}; the i-th context is paired
            with the i-th citing sentence.

    Returns:
        dict sent_id -> {'label': [label, ...], 'context': [sent_id, ...]}
        (keys in that order). '@UNSURE@' annotations and labels whose context
        and sentence lists differ in length are ignored. When one sentence
        carries several labels, their contexts are merged with duplicates
        removed in first-seen order.
    """
    res_dict = {}
    for label, context_data in label_json.items():
        if label == '@UNSURE@':
            continue
        cite_sentences = context_data['cite_sentences']
        gold_contexts = context_data['gold_contexts']
        # without a 1:1 pairing we cannot tell which context belongs to which sentence
        if len(gold_contexts) != len(cite_sentences):
            continue
        for sent_id, context in zip(cite_sentences, gold_contexts):
            entry = res_dict.get(sent_id)
            if entry is None:
                # copy so the result never aliases the raw annotation lists
                res_dict[sent_id] = {'label': [label], 'context': list(context)}
            else:
                entry['label'].append(label)
                # dict.fromkeys de-duplicates deterministically; a set's iteration
                # order depends on string hash randomisation (PYTHONHASHSEED)
                entry['context'] = list(dict.fromkeys(entry['context'] + context))
    return res_dict

def replace_citations(text):
    """Normalise citation markers in a sentence.

    1. ``#AUTHOR_TAG`` becomes ``#TARGET_REF``, and an author name directly in
       front of it (e.g. "Smith et al. #TARGET_REF") is folded into the marker
       to fix wrong parsing.
    2. Every other "Name [and Name | et al.] year / [n] / (year)" style
       citation is replaced with ``#REF``.

    Returns the rewritten text.
    """
    text = re.sub(r'#AUTHOR_TAG', '#TARGET_REF', text)
    # must be assigned back to `text`, otherwise the author-name cleanup is lost
    text = re.sub(r'(?:van |)(?:[A-Z][a-z]+-?)+ ?(?:and [A-Za-z-]{2,}|et al?t?\.)?(?: |,)*?#TARGET_REF', '#TARGET_REF', text)

    # replace all other citations with #REF
    text = re.sub(r'(?:van |)(?:[A-Z][a-z]+-?)+ ?(?:and [A-Za-z-]{2,}|et al?t?\.)?(?: |,)*?(?:\d{4}|9\d|\[ ?\d{1,2} ?\]|\( ?\d{4} ?\))', '#REF', text)

    return text

def _pad_window(window, section_len):
    """Randomly grow a contiguous list of sentence indices in place and return it.

    One index at a time is added to the front or back (coin flip via the
    module-level ``randint``) until the window holds 5 indices or
    ``section_len - 1`` of them. Index 0 (presumably the section heading) is
    never added by padding.
    """
    while len(window) < 5 and len(window) < section_len - 1:
        chance = randint(0, 1)
        prev_idx = window[0] - 1
        next_idx = window[-1] + 1
        if chance == 0 and prev_idx > 0:
            window.insert(0, prev_idx)
        elif chance == 1 and next_idx < section_len:
            window.append(next_idx)
    return window


def _clean_sentence(sent, is_citing):
    """Remove the span annotation from one sentence and replace reference markers with #REF."""
    if is_citing:
        # in the citing sentence the annotated span(s) become the #TARGET_REF marker
        clean_sent = re.sub(r'<span.*?>(.*?)<\/span>', ' #TARGET_REF', sent)
        clean_sent = ' '.join(clean_sent.split())
    else:
        # elsewhere only the span tags are dropped; the enclosed text is kept
        clean_sent = re.sub(r'<span.*?>(.*?)<\/span>', r'\1', sent)
    # replace all other reference markers
    return re.sub(r'(?:van |)(?:[A-Z][a-z]+-?)+ ?(?:and [A-Za-z-]{2,}|et al?t?\.)?(?: |,)*?(?:\d{4}|9\d|\[ ?\d{1,2} ?\]|\( ?\d{4} ?\))', '#REF', clean_sent)


def preprocess_data(data, max_len):
    """Build one example per citing sentence from the raw MultiCite data.

    For each citing sentence, a window of at most five sentences is taken from
    its section. The window is the span covering the gold context, randomly
    padded to the front/back (module-level ``randint``; seed ``random`` for
    reproducible output). Sentences are cleaned with ``_clean_sentence``.

    A citation is skipped when:
      * the citing sentence is unknown or is a separator line;
      * its context is empty;
      * its context references an unknown or separator sentence, or a
        sentence in a different section;
      * the context span exceeds five sentences;
      * the section has fewer than three sentences.

    Args:
        data: mapping id -> {'x': list of {'sent_id', 'text'} dicts,
            'y': label annotations as accepted by ``get_cit_dict``}.
        max_len: not used by this function (kept for interface compatibility).

    Returns:
        DataFrame with the columns below, one row per kept citation:
          * ``text``: list of cleaned window sentences;
          * ``label``: list of labels passed through ``label_mapping``;
          * ``context``: list of 0/1 flags marking the gold-context
            sentences inside the window.
    """
    rows = []
    for paper in data.values():
        sent_lookup, sent_arr = seperate_segments(paper['x'])
        cit_dict = get_cit_dict(paper['y'])

        for key, citation in cit_dict.items():
            # citing sentence must exist and must not be a separator (-1)
            position = sent_lookup.get(key, -1)
            if position == -1:
                continue
            sec_id, citing_sent_id = position
            section = sent_arr[sec_id]
            labels, context = citation['label'], citation['context']

            # resolve context sentences; a context sentence from another section
            # would otherwise be indexed into the citing sentence's section
            sorted_context = sorted(context, key=lambda x: int(x.split('-')[-1]))
            context_positions = [sent_lookup.get(sent_id, -1) for sent_id in sorted_context]
            if not context_positions or any(pos == -1 or pos[0] != sec_id for pos in context_positions):
                continue
            context_sent_ids = [pos[1] for pos in context_positions]

            # contiguous window covering the context and everything in between
            input_arr = list(range(context_sent_ids[0], context_sent_ids[-1] + 1))
            if len(input_arr) > 5 or len(input_arr) == 0 or len(section) < 3:
                continue
            input_arr = _pad_window(input_arr, len(section))

            res_text = [_clean_sentence(section[idx], idx == citing_sent_id) for idx in input_arr]
            res_labels = [label_mapping(label) for label in labels]
            res_context = [1 if idx in context_sent_ids else 0 for idx in input_arr]

            rows.append({'text': res_text, 'label': res_labels, 'context': res_context})

    # build the frame once instead of growing it row by row (quadratic)
    return pd.DataFrame(rows, columns=['text', 'label', 'context'])

In [7]:
# create the different task schemata for MultiCite
'''data schema
    [{
        "gold": {
            "text": [<str>],
            "label": [<str>],
            "context": [<int>]
        },
        "input": <str>,
        "output": <str>
    }]
'''

def create_xml_data(df):
    """Build the XML-style task from a preprocessed frame.

    For every row of ``df`` (columns ``text``: list of sentences, ``label``:
    list of label names, ``context``: 0/1 flag per sentence):
      * ``input``: the sentences joined with single spaces;
      * ``output``: the same text with each run of context sentences wrapped
        in ``<CONT> ... </CONT>``, and one ``<LABEL/>`` tag per label placed
        directly after the first ``#TARGET_REF``. If the text contains no
        ``#TARGET_REF``, the label tags are appended at the end instead.

    Returns:
        list of {"gold": {"text", "label", "context"}, "input": str, "output": str}.
    """
    target = '#TARGET_REF'
    res_data = []
    for _, row in df.iterrows():
        text, labels, context = row['text'], row['label'], row['context']

        # open/close a <CONT> tag whenever the 0/1 context flags change
        sent_arr = []
        prev_flag = 0
        for sent, flag in zip(text, context):
            if prev_flag == 0 and flag == 1:
                sent = '<CONT> ' + sent
            elif prev_flag == 1 and flag == 0:
                sent = '</CONT> ' + sent
            sent_arr.append(sent)
            prev_flag = flag

        output = ' '.join(sent_arr)
        # add closing tag if the last sentence is context
        if prev_flag == 1:
            output += ' </CONT>'

        # set label XML tags
        label_tags = ''.join(f'<{label}/>' for label in labels)
        target_pos = output.find(target)
        if target_pos == -1:
            # find() returns -1 when the marker is missing; splicing at
            # -1 + len(target) would drop the tags into the middle of the text
            if label_tags:
                output = f'{output} {label_tags}'
        else:
            tag_location = target_pos + len(target)
            output = output[:tag_location] + label_tags + output[tag_location:]

        res_data.append({
            "gold": {
                "text": text,
                "label": labels,
                "context": context
            },
            "input": ' '.join(text),
            "output": output
        })
    return res_data

def create_json_data(df):
    """Build the JSON-style task from a preprocessed frame.

    For every row of ``df`` (columns ``text``, ``label``, ``context``), the
    model input is the space-joined sentences. The target is a JSON string
    ``{"label": [...], "context": [...]}`` listing the labels and the
    sentences whose context flag is 1.

    Returns:
        list of {"gold": {"text", "label", "context"}, "input": str, "output": str}.
    """
    examples = []
    for _, row in df.iterrows():
        sentences = row['text']
        labels = row['label']
        flags = row['context']

        target = {
            "label": labels,
            "context": [sentence for sentence, flag in zip(sentences, flags) if flag == 1],
        }
        examples.append({
            "gold": {"text": sentences, "label": labels, "context": flags},
            "input": ' '.join(sentences),
            "output": json.dumps(target),
        })
    return examples

# def create_json_2_data(df):
#     res_data = []
#     for idx, row in df.iterrows():
#         text = row['text']
#         label = row['label']
#         context = row['context']
#         input = [f'sent{idx}: {sent}\n' for idx, sent in enumerate(text)]
        
#         #create json object
#         output = {"label": label,
#                   "context": [f'sent{idx}' for idx, context_label in enumerate(context) if context_label == 1]}
        
#         # add example to response
#         res_data.append({
#             "gold": {
#                 "text": text,
#                 "label": label,
#                 "context": context
#             },
#             "input": ' '.join(input),
#             "output": json.dumps(output)
#         })
#     return res_data

In [8]:
# preprocess_data pads its windows with the global `random` generator; seed it
# (same value as the train/test split) so the generated files are reproducible
random.seed(82)

# prepare data
test_df_clean = preprocess_data(test_data, max_input_len)
train_df_clean = preprocess_data(train_data, max_input_len)


def write_splits(out_dir, train_examples, test_examples):
    """Write train/test examples as indented UTF-8 JSON files into ``out_dir`` (created if missing)."""
    os.makedirs(out_dir, exist_ok=True)
    for split, examples in (('train', train_examples), ('test', test_examples)):
        # os.path.join avoids the accidental '//' a manual '/' prefix produced
        with open(os.path.join(out_dir, f'{split}.json'), 'w', encoding='utf-8') as f:
            json.dump(examples, f, ensure_ascii=False, indent=4)


# create XML data
test_xml = create_xml_data(test_df_clean)
train_xml = create_xml_data(train_df_clean)
XML_OUTPUT = OUTPUT + 'XML/'
write_splits(XML_OUTPUT, train_xml, test_xml)

# create JSON 1 data
test_json1 = create_json_data(test_df_clean)
train_json1 = create_json_data(train_df_clean)
JSON1_OUTPUT = OUTPUT + 'JSON/'
write_splits(JSON1_OUTPUT, train_json1, test_json1)

# # create JSON 2 data
# test_json2 = create_json_2_data(test_df_clean)
# train_json2 = create_json_2_data(train_df_clean)
# JSON2_OUTPUT = OUTPUT + 'JSON2/'
# os.makedirs(JSON2_OUTPUT, exist_ok=True)
# with open(JSON2_OUTPUT + 'train.json', 'w', encoding='utf-8') as f:
#     json.dump(train_json2, f, ensure_ascii=False, indent=4)
# with open(JSON2_OUTPUT + 'test.json', 'w', encoding='utf-8') as f:
#     json.dump(test_json2, f, ensure_ascii=False, indent=4)