In [28]:
import ast
import os
from difflib import SequenceMatcher

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel

# Prefer the second GPU when there is one (the original configuration), fall back
# to the first GPU on single-GPU machines ("cuda:1" would fail there), else CPU.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

args = {
    'data_set': 'acl_arc',
    # 'data_set': 'sdp_act',
    'train_size': 1647,  # for acl_arc
    # 'train_size': 3000,  # for sdp_act
    'model_path': 'allenai/specter',
    # 'context_type': 'non_contiguous',
    'context_type': 'contiguous',
    'exp_type': 'exp2'
}

# Load tokenizer and encoder from the configured checkpoint instead of a hard-coded name.
TOKENIZER = AutoTokenizer.from_pretrained(args['model_path'])
MODEL = AutoModel.from_pretrained(args['model_path']).to(device)
MODEL.eval()  # inference only; explicit even though from_pretrained already does this

DATA_PROCESSED_DIR = '../data/'
DATASET_DIR = os.path.join(DATA_PROCESSED_DIR, args['data_set'])

# Output directories are named <data_set>_dc_<c|nc>_<exp_type>_<prev|next|comb>.
_context_tag = 'nc' if args['context_type'] == 'non_contiguous' else 'c'
_output_prefix = f"{args['data_set']}_dc_{_context_tag}_{args['exp_type']}"
OUTPUT_DIR_PREV = os.path.join(DATASET_DIR, f"{_output_prefix}_prev")
OUTPUT_DIR_NEXT = os.path.join(DATASET_DIR, f"{_output_prefix}_next")
OUTPUT_DIR_COMB = os.path.join(DATASET_DIR, f"{_output_prefix}_comb")

for OUTPUT_DIR in (OUTPUT_DIR_PREV, OUTPUT_DIR_NEXT, OUTPUT_DIR_COMB):
    os.makedirs(OUTPUT_DIR, exist_ok=True)

In [29]:
# Helpers
cos = torch.nn.CosineSimilarity(eps=1e-6, dim=1)  # shared row-wise (dim=1) cosine similarity

def tokenize(input_seq):
    """Encode a batch of strings with the module-level model and return first-token embeddings.

    Despite its name, this runs the full encoder forward pass, not just tokenization.

    Args:
        input_seq: list of strings forming one batch. Sequences are padded to the
            longest in the batch and truncated to 512 tokens.

    Returns:
        2-D tensor holding, for each input string, the last-layer hidden state of
        its first token (``last_hidden_state[:, 0, :]``).
    """
    encoded = TOKENIZER(input_seq, padding=True, truncation=True,
                        return_tensors="pt", max_length=512).to(device)
    # Inference only: without no_grad every call builds an autograd graph that stays
    # referenced by the returned embedding, wasting memory and time.
    with torch.no_grad():
        result = MODEL(**encoded)
    return result.last_hidden_state[:, 0, :]


def compute_smilarity(embd1, embd2, cos=cos):
    """Return ``cos(embd1, embd2)``.

    With the default module-level ``cos`` this is the row-wise (``dim=1``) cosine
    similarity of the two embedding tensors.
    """
    return cos(embd1, embd2)


# Extract the previous and next sentences from the paragraph.
def get_prev_next_context(paragraph, citing_sent_index):
    """Split ``paragraph`` around the citation sentence.

    Args:
        paragraph: sequence of sentences; may be empty or None.
        citing_sent_index: index of the citation sentence within ``paragraph``.

    Returns:
        (context_prev, context_next): new lists holding the sentences before and
        after the citation sentence respectively, each in paragraph order. Both
        lists are empty when ``paragraph`` is falsy.
    """
    if not paragraph:
        return [], []
    # The following sentences belong in context_next (they were previously
    # appended to context_prev, leaving context_next always empty).
    context_prev = list(paragraph[:citing_sent_index])
    context_next = list(paragraph[citing_sent_index + 1:])
    return context_prev, context_next

# Contiguous / non-contiguous context extractors

class InputFeatures:
    """Plain container for one citation instance.

    Stores the citing and cited paper titles, the citation context, the cited
    paper's abstract, and the paragraph (a sequence of sentences) in which the
    citation context is searched by the extractor subclasses.
    """

    def __init__(self, citing_title, cited_title, citation_context, cited_abstract, paragraph):
        (self.citing_title,
         self.cited_title,
         self.citation_context,
         self.cited_abstract,
         self.paragraph) = (citing_title, cited_title, citation_context, cited_abstract, paragraph)


class ContiguousContextExp2(InputFeatures):
    """Dynamic *contiguous* context extraction for experiment 2.

    The process starts from the citation sentence. Neighbouring paragraph sentences
    are added one at a time, moving outward, as long as each is at least as similar
    to the cited-title embedding as the citation context is. The first sentence below
    that threshold stops the expansion in that direction. The kept sentences
    therefore always form a contiguous span around the citation sentence.
    """

    # Minimum difflib ratio for a paragraph sentence to be taken as the citation
    # sentence when the citation context does not appear verbatim in the paragraph.
    FUZZY_MATCH_RATIO = 0.80

    def extract_embeddings(self):
        """Embed the cited title and the citation context and compare them.

        Returns:
            (cited_abstract_emb, similarity_citation_context): the embedding of
            ``cited_title + sep_token`` (the abstract slot is left empty in this
            experiment), and its cosine similarity to the embedding of
            ``sep_token + citation_context``.
        """
        title_abs = [self.cited_title + TOKENIZER.sep_token]
        input_seq = [TOKENIZER.sep_token + self.citation_context]
        citing_context_emb = tokenize(input_seq)
        cited_abstract_emb = tokenize(title_abs)
        similarity_citation_context = compute_smilarity(cited_abstract_emb, citing_context_emb)
        return cited_abstract_emb, similarity_citation_context

    def _find_citing_sent_index(self):
        """Return the paragraph index of the citation sentence, or None if not found.

        An exact match wins. Otherwise the first sentence whose difflib ratio to the
        citation context exceeds FUZZY_MATCH_RATIO is used.
        """
        if not self.paragraph:
            return None
        try:
            return self.paragraph.index(self.citation_context)
        except ValueError:
            pass
        for i, sent in enumerate(self.paragraph):
            if SequenceMatcher(None, self.citation_context, sent).ratio() > self.FUZZY_MATCH_RATIO:
                return i
        return None

    def extract_context(self, cited_abstract_emb, similarity_citation_context):
        """Build the previous, next and combined dynamic contexts.

        Args:
            cited_abstract_emb: embedding returned by ``extract_embeddings``.
            similarity_citation_context: similarity threshold returned by
                ``extract_embeddings``.

        Returns:
            (dynamic_context_prev, dynamic_context_next, dynamic_context_combined):
            sentence lists in paragraph order, each containing ``citation_context``.
            If the citation sentence cannot be located, all three are
            ``[citation_context]``.
        """
        citing_sent_index = self._find_citing_sent_index()
        if citing_sent_index is None:
            return [self.citation_context], [self.citation_context], [self.citation_context]

        prev_sents = list(self.paragraph[:citing_sent_index])
        next_sents = list(self.paragraph[citing_sent_index + 1:])

        # Walk the preceding sentences from nearest to farthest so the stop-at-first-miss
        # rule keeps the span adjacent to the citation sentence, then restore paragraph order.
        prev_kept = self.extract_context_from_paragraph(
            prev_sents[::-1], similarity_citation_context, cited_abstract_emb)[::-1]
        next_kept = self.extract_context_from_paragraph(
            next_sents, similarity_citation_context, cited_abstract_emb)

        dynamic_context_prev = prev_kept + [self.citation_context]
        dynamic_context_next = [self.citation_context] + next_kept
        dynamic_context_combined = prev_kept + [self.citation_context] + next_kept
        return dynamic_context_prev, dynamic_context_next, dynamic_context_combined

    def extract_context_from_paragraph(self, context, similarity_citation_context, cited_abstract_emb):
        """Return the leading sentences of ``context`` that pass the similarity threshold.

        Sentences are scored in the order given against ``cited_abstract_emb``. A sentence
        is kept while its similarity is >= ``similarity_citation_context``, and scanning
        stops at the first sentence that fails.
        """
        dynamic_context = []
        for sent in context or ():
            para_sent_emb = tokenize([TOKENIZER.sep_token + sent])
            similarity_sent = compute_smilarity(cited_abstract_emb, para_sent_emb)
            if similarity_sent >= similarity_citation_context:
                dynamic_context.append(sent)
            else:
                break
        return dynamic_context


class NonContiguousContextExp2(InputFeatures):
    """Dynamic *non-contiguous* context extraction for experiment 2.

    Every sentence before (respectively after) the citation sentence is kept if it
    is at least as similar to the cited-title embedding as the citation context is.
    This holds regardless of whether the sentences in between were kept.
    """

    # Minimum difflib ratio for a paragraph sentence to be taken as the citation
    # sentence when the citation context does not appear verbatim in the paragraph.
    FUZZY_MATCH_RATIO = 0.80

    def extract_embeddings(self):
        """Embed the cited title and the citation context and compare them.

        Returns:
            (cited_abstract_emb, similarity_citation_context): the embedding of
            ``cited_title + sep_token`` (the abstract slot is left empty in this
            experiment), and its cosine similarity to the embedding of
            ``sep_token + citation_context``.
        """
        title_abs = [self.cited_title + TOKENIZER.sep_token]
        input_seq = [TOKENIZER.sep_token + self.citation_context]
        citing_context_emb = tokenize(input_seq)
        cited_abstract_emb = tokenize(title_abs)
        similarity_citation_context = compute_smilarity(cited_abstract_emb, citing_context_emb)
        return cited_abstract_emb, similarity_citation_context

    def _find_citing_sent_index(self):
        """Return the paragraph index of the citation sentence, or None if not found.

        An exact match wins. Otherwise the first sentence whose difflib ratio to the
        citation context exceeds FUZZY_MATCH_RATIO is used.
        """
        if not self.paragraph:
            return None
        try:
            return self.paragraph.index(self.citation_context)
        except ValueError:
            pass
        for i, sent in enumerate(self.paragraph):
            if SequenceMatcher(None, self.citation_context, sent).ratio() > self.FUZZY_MATCH_RATIO:
                return i
        return None

    def extract_context(self, cited_abstract_emb, similarity_citation_context):
        """Build the previous, next and combined dynamic contexts.

        Args:
            cited_abstract_emb: embedding returned by ``extract_embeddings``.
            similarity_citation_context: similarity threshold returned by
                ``extract_embeddings``.

        Returns:
            (dynamic_context_prev, dynamic_context_next, dynamic_context_combined):
            sentence lists in paragraph order, each containing ``citation_context``.
            If the citation sentence cannot be located, all three are
            ``[citation_context]``.
        """
        citing_sent_index = self._find_citing_sent_index()
        if citing_sent_index is None:
            return [self.citation_context], [self.citation_context], [self.citation_context]

        prev_sents = list(self.paragraph[:citing_sent_index])
        next_sents = list(self.paragraph[citing_sent_index + 1:])

        # Score each side once (each call runs the encoder per sentence) and reuse the
        # result for all three outputs; the "next" side uses the following sentences.
        prev_kept = self.extract_context_from_paragraph(
            prev_sents, similarity_citation_context, cited_abstract_emb)
        next_kept = self.extract_context_from_paragraph(
            next_sents, similarity_citation_context, cited_abstract_emb)

        dynamic_context_prev = prev_kept + [self.citation_context]
        dynamic_context_next = [self.citation_context] + next_kept
        dynamic_context_combined = prev_kept + [self.citation_context] + next_kept
        return dynamic_context_prev, dynamic_context_next, dynamic_context_combined

    def extract_context_from_paragraph(self, context, similarity_citation_context, cited_abstract_emb):
        """Return every sentence of ``context`` that passes the similarity threshold.

        Each sentence is scored against ``cited_abstract_emb`` and kept when its
        similarity is >= ``similarity_citation_context``. Order is preserved.
        """
        dynamic_context = []
        for sent in context or ():
            para_sent_emb = tokenize([TOKENIZER.sep_token + sent])
            similarity_sent = compute_smilarity(cited_abstract_emb, para_sent_emb)
            if similarity_sent >= similarity_citation_context:
                dynamic_context.append(sent)
        return dynamic_context


# Registries: experiment name -> class that extracts that experiment's dynamic context.
processors_dynamic_context_contiguous = dict(exp2=ContiguousContextExp2)
processors_dynamic_context_non_contiguous = dict(exp2=NonContiguousContextExp2)

# Extraction function
def extract_dynamic_context(data_df, context_type, exp_type):
    """Compute the dynamic citation contexts for every row of ``data_df``.

    Args:
        data_df: DataFrame with the columns 'cite_context_paragraph' (string form
            of a list of sentences), 'citation_context', 'citing_title',
            'cited_title' and 'cited_abstract'. It is not modified.
        context_type: 'non_contiguous' selects the non-contiguous processors; any
            other value selects the contiguous ones.
        exp_type: key into the processor registries (e.g. 'exp2').

    Returns:
        (dynamic_contexts_prev, dynamic_contexts_next, dynamic_contexts_combined):
        three lists with one sentence list per row, in row order.

    Raises:
        KeyError: if no processor is registered for ``exp_type``.
    """
    registry = (processors_dynamic_context_non_contiguous
                if context_type == 'non_contiguous'
                else processors_dynamic_context_contiguous)
    processor_cls = registry[exp_type]

    dynamic_contexts_prev = []
    dynamic_contexts_next = []
    dynamic_contexts_combined = []

    # Inference only: no autograd graph is needed for any embedding computed below.
    with torch.no_grad():
        for _, row in data_df.iterrows():
            raw_paragraph = row['cite_context_paragraph']
            # ast.literal_eval parses Python literals only; eval() would execute
            # arbitrary code read from the data file. Missing cells (NaN) become
            # an empty paragraph instead of crashing.
            if isinstance(raw_paragraph, str):
                paragraph = ast.literal_eval(raw_paragraph)
            elif isinstance(raw_paragraph, (list, tuple)):
                paragraph = list(raw_paragraph)
            else:
                paragraph = []

            cited_abstract = row['cited_abstract']
            if not isinstance(cited_abstract, str):
                cited_abstract = ''
                print('abstract None')

            processor = processor_cls(row['citing_title'], row['cited_title'],
                                      row['citation_context'], cited_abstract, paragraph)
            cited_abstract_emb, similarity_citation_context = processor.extract_embeddings()
            dynamic_context_prev, dynamic_context_next, dynamic_context_combined = \
                processor.extract_context(cited_abstract_emb, similarity_citation_context)

            dynamic_contexts_prev.append(dynamic_context_prev)
            dynamic_contexts_next.append(dynamic_context_next)
            dynamic_contexts_combined.append(dynamic_context_combined)

    return dynamic_contexts_prev, dynamic_contexts_next, dynamic_contexts_combined


In [30]:
for dataset in ["train", "test"]:
    print(dataset)
    data_df = pd.read_csv(f"{DATASET_DIR}/{dataset}_raw.txt", sep="\t", engine="python", dtype=object)
    dynamic_contexts_prev, dynamic_contexts_next, dynamic_contexts_combined = extract_dynamic_context(
        data_df, args['context_type'], args['exp_type'])

    labels = data_df['citation_class_label']
    prev_df, next_df, comb_df = (
        pd.DataFrame({'CC': contexts, 'label': labels})
        for contexts in (dynamic_contexts_prev, dynamic_contexts_next, dynamic_contexts_combined)
    )

    # 'train' is written to train.csv; every other split is written to test.csv.
    file_name = 'train.csv' if dataset == 'train' else 'test.csv'
    for out_dir, frame in ((OUTPUT_DIR_PREV, prev_df),
                           (OUTPUT_DIR_NEXT, next_df),
                           (OUTPUT_DIR_COMB, comb_df)):
        frame.to_csv(out_dir + '/' + file_name, index=False)



train
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
abstract None
