In [157]:
import pandas as pd
from transformers import AutoTokenizer
import re
from random import randint
import json
import os

# Model family the data is prepared for; it selects the tokenizer and the output folder.
MODEL = 'mistral'

# Short model name -> Hugging Face hub id.
model_mapping = {
    "mistral":'mistralai/Mistral-7B-Instruct-v0.3',
}
model_id = model_mapping[MODEL]

# Raw ACL-ARC files are read from INPUT; the per-model datasets are written below OUTPUT.
INPUT = '../data/acl_arc/'
OUTPUT = f'../data/acl_arc/{MODEL}/'

# Token budget per example (handed to preprocess_df as max_len).
max_input_len = 400

# Tokenizer of the selected model; only used to count tokens during preprocessing.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [158]:
# Read the tab-separated raw splits.
train_raw = pd.read_csv(INPUT + 'train_raw.txt', sep='\t')
test_raw = pd.read_csv(INPUT + 'test_raw.txt', sep='\t')

# Keep only the citing context and its class id, under the short names
# ('text', 'label') that the rest of the notebook works with.
_column_names = {'citation_context': 'text', 'citation_class_label': 'label'}
train_df = train_raw[list(_column_names)].rename(columns=_column_names)
test_df = test_raw[list(_column_names)].rename(columns=_column_names)


In [159]:
#helper 
def preprocess_df(df, max_len):
    """Clean the citation rows of ``df`` and trim each example to a token budget.

    For every row the citation markers are normalised with
    ``replace_citations`` and the sentence containing ``#TARGET_REF`` is
    located. Starting from that sentence, neighbouring sentences are added on a
    randomly chosen side for as long as the summed token count stays within
    ``max_len``. Rows with zero or more than one target sentence are skipped.

    Args:
        df: DataFrame with a ``text`` column (the citing sentence) and a
            ``label`` column (anything ``label_mapping`` accepts).
        max_len: Maximum summed token count of the kept sentences. The target
            sentence is always kept, even if it alone exceeds this budget.

    Returns:
        DataFrame with the columns ``text`` (list of kept sentences) and
        ``label`` (the value returned by ``label_mapping``), one row per kept
        example.
    """
    records = []
    for _, row in df.iterrows():
        # Citing sentence only. The paragraph variant would parse row['text']
        # into a list of sentences here (prefer ast.literal_eval over eval).
        sentences = [replace_citations(sent) for sent in [row['text']]]

        # Exactly one sentence must carry the target marker.
        target_ids = [pos for pos, sent in enumerate(sentences) if '#TARGET_REF' in sent]
        if len(target_ids) != 1:
            continue

        # Token count per sentence as reported by tokenizer.encode
        # (may include special tokens added by the tokenizer - TODO confirm).
        token_counts = [len(tokenizer.encode(sent)) for sent in sentences]

        # Grow a contiguous window [first, last] around the target sentence.
        first = last = target_ids[0]
        total_len = token_counts[first]
        while last - first + 1 < len(sentences):
            grow_left = randint(0, 1) == 0
            if grow_left and first > 0:
                if total_len + token_counts[first - 1] > max_len:
                    break
                first -= 1
                total_len += token_counts[first]
            elif not grow_left and last + 1 < len(sentences):
                if total_len + token_counts[last + 1] > max_len:
                    break
                last += 1
                total_len += token_counts[last]
            # If the drawn side is already exhausted nothing happens and the
            # next iteration simply draws again.

        records.append({
            'text': sentences[first:last + 1],
            'label': label_mapping(row['label']),
        })

    # Build the frame once instead of appending row by row (which is quadratic).
    return pd.DataFrame(records, columns=['text', 'label'])

def replace_citations(text):
    """Normalise the citation markers of a sentence.

    Three substitutions are applied in order:

    1. ``#AUTHOR_TAG`` is renamed to ``#TARGET_REF``.
    2. Author names standing directly in front of ``#TARGET_REF`` are folded
       into the marker (repairs wrongly parsed targets).
    3. Every other author/year style citation is replaced with ``#REF``.

    Args:
        text: The sentence to normalise.

    Returns:
        The sentence with the markers replaced.
    """
    # Author part shared by both patterns: optional "van ", one or more
    # capitalised (optionally hyphenated) name parts, an optional
    # "and <name>" / "et al." tail, then optional spaces or commas.
    authors = r'(?:van |)(?:[A-Z][a-z]+-?)+ ?(?:and [A-Za-z-]{2,}|et al?t?\.)?(?: |,)*?'
    # Year part: 4-digit year, two-digit "9x", "[12]" or "(2005)".
    year = r'(?:\d{4}|9\d|\[ ?\d{1,2} ?\]|\( ?\d{4} ?\))'

    # replace and fix #AUTHOR_TAG
    text = re.sub(r'#AUTHOR_TAG', '#TARGET_REF', text)
    # The result of this step used to be assigned to a misspelled variable and
    # was therefore discarded; it is now applied.
    # NOTE(review): any capitalised word directly in front of the marker
    # (e.g. "In #TARGET_REF") is absorbed as well - confirm this is acceptable.
    text = re.sub(authors + '#TARGET_REF', '#TARGET_REF', text)

    # replace all other citations with #REF
    text = re.sub(authors + year, '#REF', text)

    return text

def label_mapping(label):
    """Translate an ACL-ARC citation class between its integer id and its name.

    Args:
        label: Either a class id (0-5) or a class name
            (``'BACKGROUND'``, ``'USE'``, ``'COMPARE_CONTRAST'``,
            ``'MOTIVATION'``, ``'EXTENSION'``, ``'FUTURE'``).

    Returns:
        The class name for an id, the id for a class name, or ``None`` when
        the value is not a known id or name.
    """
    from numbers import Integral  # local import keeps this cell self-contained

    names = ('BACKGROUND', 'USE', 'COMPARE_CONTRAST',
             'MOTIVATION', 'EXTENSION', 'FUTURE')

    # Accept every integer type, not just the builtin int: NumPy integer
    # scalars (e.g. np.int64 from scalar access into an integer column) used
    # to fall through to the name branch and silently map to None.
    # bool is excluded so True/False keep mapping to None as before.
    if isinstance(label, Integral) and not isinstance(label, bool):
        return names[label] if 0 <= label < len(names) else None

    # Name -> id; membership uses ==, so unhashable values are simply unknown.
    return names.index(label) if label in names else None

In [160]:
# create the different task schemata for ACL-ARC
'''data schema
    [{
        "gold": {
            "text": [<str>],
            "label": [<str>]
        },
        "input": <str>
        "output": <str>
    }]
'''

def create_xml_data(df):
    """Build the XML-tagging variant of the dataset.

    The model input is the example's sentences joined with single spaces; the
    expected output is the same string with ``<LABEL/>`` inserted right after
    the first ``#TARGET_REF`` marker.

    Args:
        df: DataFrame with a ``text`` column (list of sentences) and a
            ``label`` column.

    Returns:
        List of dicts with the keys ``gold``, ``input`` and ``output``.
    """
    marker = '#TARGET_REF'
    examples = []
    for _, record in df.iterrows():
        sentences, tag = record['text'], record['label']
        joined = ' '.join(sentences)

        # Position directly behind the first marker occurrence.
        cut = joined.find(marker) + len(marker)
        tagged = joined[:cut] + f'<{tag}/>' + joined[cut:]

        examples.append(dict(
            gold=dict(text=sentences, label=[tag]),
            input=joined,
            output=tagged,
        ))
    return examples

def create_json_1_data(df):
    """Build the first JSON variant of the dataset.

    The model input is the example's sentences joined with single spaces; the
    expected output is the JSON string ``{"label": <label>}``.

    Args:
        df: DataFrame with a ``text`` column (list of sentences) and a
            ``label`` column.

    Returns:
        List of dicts with the keys ``gold``, ``input`` and ``output``.
    """
    return [
        {
            "gold": {
                "text": record['text'],
                "label": [record['label']],
            },
            "input": ' '.join(record['text']),
            "output": json.dumps({"label": record['label']}),
        }
        for _, record in df.iterrows()
    ]

def create_json_2_data(df):
    """Build the second JSON variant of the dataset.

    Each sentence is rendered as ``sent<i>: <sentence>`` followed by a newline
    and the rendered sentences are joined with single spaces to form the model
    input; the expected output is the JSON string ``{"label": <label>}``.

    Args:
        df: DataFrame with a ``text`` column (list of sentences) and a
            ``label`` column.

    Returns:
        List of dicts with the keys ``gold``, ``input`` and ``output``.
    """
    examples = []
    for _, record in df.iterrows():
        sentences, tag = record['text'], record['label']
        numbered = ' '.join(
            f'sent{pos}: {sentence}\n' for pos, sentence in enumerate(sentences)
        )
        examples.append(dict(
            gold=dict(text=sentences, label=[tag]),
            input=numbered,
            output=json.dumps({"label": tag}),
        ))
    return examples

In [161]:
def _dump_json(records, directory, file_name):
    """Write ``records`` to ``directory/file_name`` as indented UTF-8 JSON.

    The directory is created when it does not exist yet.
    """
    os.makedirs(directory, exist_ok=True)
    # os.path.join keeps the separator handling uniform for every file; the
    # JSON1 train path used to be built with an extra '/'.
    with open(os.path.join(directory, file_name), 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)


# prepare data
test_df_clean = preprocess_df(test_df, max_input_len)
train_df_clean = preprocess_df(train_df, max_input_len)

# create XML data
test_xml = create_xml_data(test_df_clean)
train_xml = create_xml_data(train_df_clean)
XML_OUTPUT = OUTPUT + 'XML/'
_dump_json(train_xml, XML_OUTPUT, 'train.json')
_dump_json(test_xml, XML_OUTPUT, 'test.json')

# create JSON 1 data
test_json1 = create_json_1_data(test_df_clean)
train_json1 = create_json_1_data(train_df_clean)
JSON1_OUTPUT = OUTPUT + 'JSON1/'
_dump_json(train_json1, JSON1_OUTPUT, 'train.json')
_dump_json(test_json1, JSON1_OUTPUT, 'test.json')

# create JSON 2 data
test_json2 = create_json_2_data(test_df_clean)
train_json2 = create_json_2_data(train_df_clean)
JSON2_OUTPUT = OUTPUT + 'JSON2/'
_dump_json(train_json2, JSON2_OUTPUT, 'train.json')
_dump_json(test_json2, JSON2_OUTPUT, 'test.json')
