In [2]:
import ast
import json
import os
import re
from collections import Counter

import nltk
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

nltk.download('punkt')

# Placeholder substituted for every <ref ...>...</ref> citation found in a paragraph.
CITE_MARKER = '#AUTHOR_TAG'
# Additional marker token(s) registered with each tokenizer next to CITE_MARKER.
SPECIAL_TOKEN = ['#TAUTHOR_TAG']

# Seed for train_test_split; it is also embedded in the output folder names.
rand_state = 69
# Intended train/dev/test proportions; they are embedded in the output folder names.
train_prob = 0.7  # share of examples for training
dev_prob = 0.1    # share of examples for validation
test_prob = 0.2   # share of examples for testing
# Model identifiers handed to AutoTokenizer.from_pretrained; one tokenizer is loaded per entry below.
model_names =  ["allenai/scibert_scivocab_uncased","McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp","McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp"]

# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
input_path = '/home/dataconv/deallab/lasse/CCE_Data/raw_data/finecite/'
# NOTE: 'ouput_path' is misspelled, but other cells refer to it under exactly this name.
ouput_path = f'/home/dataconv/deallab/lasse/CCE_Data/model_training/data/seq_tagger/fine_cite/'
os.makedirs(ouput_path, exist_ok=True)

# Load each model's tokenizer and register the citation markers as additional special tokens.
tokenizers = []
for model_name in model_names:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({'additional_special_tokens': [CITE_MARKER, *SPECIAL_TOKEN]})
    tokenizers.append(tokenizer)

# Raw annotation table; the segment builders read its columns 'id', 'paragraph',
# 'target_reference_location' and 'context_location1'.
input_df = pd.read_csv(os.path.join(input_path, 'full_data.csv'))

# TODO
# CHANGE: adapt to the new form of ref markers
# include #AUTHOR_TAG for every tag

[nltk_data] Downloading package punkt to /home/dataconv/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [4]:
#helper
def eval_context(context:str, token_legth:int):
    """Parse a serialized label list and clip it to the number of tokens.

    Args:
        context: String holding a Python list literal of per-token labels,
            e.g. ``'[0, 1, 1, 0]'``.
        token_legth: Number of tokens the labels have to line up with.

    Returns:
        The parsed list. If it has more than ``token_legth`` entries it is
        truncated to that length; a shorter list is returned unchanged.

    Raises:
        ValueError, SyntaxError: If ``context`` is not a plain Python literal.
    """
    # literal_eval accepts literals only, so a malformed or hostile CSV cell
    # cannot execute code the way a bare eval() would.
    res = ast.literal_eval(context)
    surplus = len(res) - token_legth
    if surplus > 0:
        print(f'The context was {surplus} labels longer than the tokens')
        res = res[:token_legth]
    elif surplus < 0:
        # Nothing to clip here; report it so the caller can trace later index errors.
        print(f'The context was {-surplus} labels shorter than the tokens')
    return res

def flatten(lists):
    """Concatenate the items of every iterable in ``lists`` into one flat list (one level deep)."""
    flat = []
    for inner in lists:
        flat.extend(inner)
    return flat
    

In [5]:
# Matches one complete <ref ...>...</ref> citation element in the whitespace-stripped paragraph.
_REF_TAG_PATTERN = re.compile(r'<ref[^\>]+>[^\<]+<\/ref>')
# Column layout shared by every segment frame produced in this cell.
_SEGMENT_COLUMNS = ['id', 'segments', 'labels']


def _iter_examples(data_df):
    """Yield ``(example_id, token_list, target_loc, contexts)`` for every row of ``data_df``.

    The 'paragraph' cell is stripped of spaces and single quotes, every
    <ref>...</ref> element is replaced by CITE_MARKER, and the result is split
    on ';' into tokens. The token at 'target_reference_location' is then
    overwritten with '#TAUTHOR_TAG'. ``contexts`` is a list holding the label
    list parsed from 'context_location1', clipped to the token count.
    """
    # Iterate the columns directly so the frame's index does not have to be 0..n-1.
    rows = zip(
        data_df['id'],
        data_df['target_reference_location'],
        data_df['paragraph'],
        data_df['context_location1'],
    )
    for example_id, target_loc, paragraph, raw_context in rows:
        token_string = paragraph.replace(' ', '').replace('\'', '')
        token_list = _REF_TAG_PATTERN.sub(CITE_MARKER, token_string).split(';')
        token_list[target_loc] = '#TAUTHOR_TAG'
        contexts = [eval_context(raw_context, len(token_list))]
        yield example_id, token_list, target_loc, contexts


def _ensure_target_labelled(context, target_loc):
    """Give the target citation token a non-zero label (in place) if it has none.

    The label of the preceding token is reused when that one is non-zero,
    otherwise the label 1 is assigned.
    """
    if context[target_loc] != 0:
        return
    # Only look left when there is a left neighbour: at position 0 the index -1
    # would silently wrap around to the last token of the paragraph.
    previous_label = context[target_loc - 1] if target_loc > 0 else 0
    context[target_loc] = previous_label if previous_label != 0 else 1
    print(f'set target reference market to {context[target_loc]}')


def _sentence_token_ids(token_list):
    """Split the tokens into sentences and map each sentence to its token positions.

    Returns:
        ``(sent_list, token_ids_by_sentence)`` where ``token_ids_by_sentence[k]``
        lists the positions in ``token_list`` that belong to ``sent_list[k]``,
        or ``None`` when the sentence splitter did not preserve the token count.
    """
    sent_list = nltk.sent_tokenize(' '.join(token_list))
    token_ids_by_sentence = []
    next_token_id = 0
    for sent in sent_list:
        n_tokens = len(sent.split(' '))
        token_ids_by_sentence.append(list(range(next_token_id, next_token_id + n_tokens)))
        next_token_id += n_tokens
    # The positional mapping is only valid if re-splitting yields exactly the original tokens.
    if next_token_id != len(token_list):
        print('There is a mismatch between token id and token list length')
        return None
    return sent_list, token_ids_by_sentence


def _majority_label(sent_labels):
    """Return the most frequent non-zero label (first seen wins on ties)."""
    return next(label for label, _ in Counter(sent_labels).most_common() if label != 0)


def _priority_label(sent_labels):
    """Return the smallest non-zero label."""
    return min(label for label in sent_labels if label != 0)


def _sentence_segment_df(data_df, reduce_labels):
    """Build a sentence-level segment frame.

    Args:
        data_df: Input frame, see ``_iter_examples`` for the columns that are read.
        reduce_labels: Callable that collapses the token labels of one sentence
            (guaranteed to contain a non-zero label) into a single label.

    Returns:
        DataFrame with columns 'id' (``'<id>_<context index>'``), 'segments'
        (list of sentences) and 'labels' (one label per sentence; 0 for a
        sentence whose tokens are all labelled 0). Rows whose sentence split
        does not match the token count and contexts without any non-zero
        label are skipped.
    """
    records = []
    for example_id, token_list, target_loc, contexts in _iter_examples(data_df):
        mapping = _sentence_token_ids(token_list)
        if mapping is None:
            continue
        sent_list, token_ids_by_sentence = mapping

        for context_idx, context in enumerate(contexts):
            if not any(context):
                continue
            _ensure_target_labelled(context, target_loc)

            sentence_labels = []
            for token_ids in token_ids_by_sentence:
                sent_labels = [context[token_id] for token_id in token_ids]
                sentence_labels.append(reduce_labels(sent_labels) if any(sent_labels) else 0)

            records.append({
                'id': f'{example_id}_{context_idx}',
                'segments': sent_list,
                'labels': sentence_labels,
            })
    # Assemble the frame once instead of appending row by row.
    return pd.DataFrame(records, columns=_SEGMENT_COLUMNS)


def return_majority_token_sentence_segment_df(data_df):
    """Return sentence segments labelled with the most frequent non-zero token label.

    See ``_sentence_segment_df`` for the layout of the returned frame.
    """
    return _sentence_segment_df(data_df, _majority_label)


def return_priority_token_sentence_segment_df(data_df):
    """Return sentence segments labelled with the smallest non-zero token label.

    See ``_sentence_segment_df`` for the layout of the returned frame.
    """
    return _sentence_segment_df(data_df, _priority_label)


def return_token_segment_df(data_df, tokenizer):
    """Return sub-word segments in which every piece inherits the label of its source token.

    Args:
        data_df: Input frame, see ``_iter_examples`` for the columns that are read.
        tokenizer: Object whose ``tokenize(str)`` returns the list of pieces for one token.

    Returns:
        DataFrame with columns 'id' (``'<id>_<context index>'``), 'segments'
        (list of pieces) and 'labels' (one label per piece). Contexts without
        any non-zero label are skipped.
    """
    records = []
    for example_id, token_list, target_loc, contexts in _iter_examples(data_df):
        for context_idx, context in enumerate(contexts):
            if not any(context):
                continue
            _ensure_target_labelled(context, target_loc)

            pieces = []
            piece_labels = []
            for token_pos, token in enumerate(token_list):
                token_pieces = tokenizer.tokenize(token)
                pieces.extend(token_pieces)
                # Index lookup (not zip) so a too-short label list still fails loudly.
                piece_labels.extend([context[token_pos]] * len(token_pieces))

            records.append({
                'id': f'{example_id}_{context_idx}',
                'segments': pieces,
                'labels': piece_labels,
            })
    # Assemble the frame once instead of appending row by row.
    return pd.DataFrame(records, columns=_SEGMENT_COLUMNS)

In [6]:
# Build one (id, segments, labels) frame per segmentation scheme, write train/val/test
# CSVs for each of them and store class weights computed from the training labels.
data_frames = []
segment_class = []
scope_weights_by_segment_class={}
total_weights_by_segment_class={}


# compute majority token sentence segments
data_frames.append(return_majority_token_sentence_segment_df(input_df))
segment_class.append('sentence_majo')

# compute priority token sentence segments
data_frames.append(return_priority_token_sentence_segment_df(input_df))
segment_class.append('sentence_prio')

# compute sub-word token segments, one frame per tokenizer
for model_name, tokenizer in zip(model_names, tokenizers):
    data_frames.append(return_token_segment_df(input_df, tokenizer))
    segment_class.append('token_' + re.search(r'scibert|llama|mistral', model_name.lower()).group(0))

# The split sizes come from the configured proportions, so the folder name (which
# embeds them) always describes the data that is actually written.
assert abs(train_prob + dev_prob + test_prob - 1) < 1e-9, 'train_prob, dev_prob and test_prob must sum to 1'
# The validation share is relative to what remains after the test split. Rounding removes
# float noise (0.1 / (0.7 + 0.1) is slightly above 0.125), which could otherwise shift the
# split by one row because train_test_split rounds the test size up.
val_fraction = round(dev_prob / (train_prob + dev_prob), 10)

for i, res_df in enumerate(data_frames):
    folder_path  = os.path.join(ouput_path,f'{segment_class[i]}_{rand_state}__{train_prob}-{dev_prob}-{test_prob}'.replace('.', ''))
    os.makedirs(folder_path, exist_ok=True)
    train_df, test_df = train_test_split(res_df, test_size=test_prob, random_state=rand_state)
    train_df, val_df = train_test_split(train_df, test_size=val_fraction, random_state=rand_state)
    train_df.to_csv(os.path.join(folder_path,'train.csv'), index=False)
    test_df.to_csv(os.path.join(folder_path,'test.csv'), index=False)
    val_df.to_csv(os.path.join(folder_path,'val.csv'), index=False)

    # Label frequencies over all segments of the training split.
    counter = Counter(flatten(train_df['labels'].to_list()))
    sorted_counter = sorted(counter.items())
    counter_sum = sum(counter.values())

    # Per-label weight n / (k * count), ordered by label value.
    ratio_scopes = [counter_sum / (len(counter) * count) for _, count in sorted_counter]
    scope_weights_by_segment_class[segment_class[i]] = ratio_scopes

    # Binary weights: label 0 versus all non-zero labels taken together.
    zero_count = counter.get(0, 0)
    nonzero_count = counter_sum - zero_count
    ratio_total = [counter_sum / (2 * zero_count), counter_sum / (2 * nonzero_count)]
    total_weights_by_segment_class[segment_class[i]] = ratio_total

# Make sure the target directory exists before writing the weight files.
weights_dir = './output'
os.makedirs(weights_dir, exist_ok=True)

with open(os.path.join(weights_dir, 'finecite_scopes_weights.json'), 'w') as f_out:
    json.dump(scope_weights_by_segment_class, f_out)

with open(os.path.join(weights_dir, 'finecite_total_weights.json'), 'w') as f_out:
    json.dump(total_weights_by_segment_class, f_out)

set target reference market to 2
set target reference market to 2
set target reference market to 2
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
There is a mismatch between token id and token list length
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
There is a mismatch between token id and token list length
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 1
set target reference market to 2
There is a mismatch between token id and token list length
There is a mismatch between token id and token list length
set target reference market to 1
There is a mismatch between token id and token list length
set target r