In [34]:
from typing import List, Optional, Dict
import os
import random
import json
from transformers import AutoTokenizer
from tqdm import tqdm
from collections import Counter, defaultdict
import re

CITE_MARKER = '#AUTHOR_TAG'
IGNORE_SENTS = {'----------------------------------'}

# Define the argument values as Python variables
input_directory = "/raid/deallab/CCE_Data/raw_data/multicite/"
file_name = 'full_raw.json'
output_directory = "/raid/deallab/CCE_Data/model_training/data/seq_tagger/"
model_name =  "allenai/scibert_scivocab_uncased"
rand_state = 99
train_prob = 0.7  # Specify the probability if needed
dev_prob = 0.1    # Specify the probability if needed
test_prob = 0.2   # Specify the probability if needed

# stuff need from model
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=False, use_fast=True)
tokenizer.add_special_tokens({'additional_special_tokens': [CITE_MARKER]})


1

In [35]:
#helper

# replace html tags in the sentence
def replace_html(sentence: str, new_cite_token: str) -> str:
    pattern = '<span style="background: yellow; display: inline-block">.*?</span>'
    return re.sub(pattern, new_cite_token, sentence)

# Tokenize the input sentence using the provided tokenizer and return length
def compute_num_tokens(sentence: str, tokenizer):
    encodings_from_sent = tokenizer(
        sentence,
        is_split_into_words=False,
        add_special_tokens=False
    )
    return len(encodings_from_sent['input_ids'])

# describes  where the citation is in the context window
def _describe_window(window_sent_pos: List[int], window_num_tokens: int,
                     cite_sent_pos: List[int],
                     gold_context_sent_pos: List[int], gold_context_num_tokens: int):

    # Determine the positions of citation sentences within the window
    where_are_cite_in_window = [window_sent_pos.index(i) for i in cite_sent_pos]
    print(f'Token Window:{window_num_tokens}\tToken Gold:{gold_context_num_tokens}\tSent Window: {len(window_sent_pos)}\tSent Gold: {len(gold_context_sent_pos)}\tCiteInWindow: {where_are_cite_in_window}')
    
    
def build_section_dict(paper: dict):
    section_dict = {}
    current_section = 0
    for idx, sent_dict in enumerate(paper):
        text = sent_dict['text']
        if text[:2] + text[-2:] == '****':
            current_section += 1
        else:
            section_dict[idx] = current_section
    return section_dict

In [59]:

def split_examples(examples: list, train_prob: float, dev_prob: float, test_prob: float):
    """
    output looks like:
    {
        0: {'train': [...], 'dev': [...], 'test': [...]},
        1: {'train': [...], 'dev': [...], 'test': [...]},
        ...
        k: {'train': [...], 'dev': [...], 'test': [...]},
    }
    """
    assert train_prob + dev_prob + test_prob == 1.0
    num_folds = int(1.0 / test_prob)
    assert 1.0 / test_prob == num_folds  # check that we can evenly divide into folds

    paper_ids = list({example['instance_id'].split('_')[1] for example in examples})
    random.shuffle(paper_ids)

    num_test = int(len(paper_ids) / num_folds)
    num_train = int((len(paper_ids) - num_test) * train_prob / (train_prob + dev_prob))
    num_dev = len(paper_ids) - num_test - num_train
    assert num_train + num_dev + num_test == len(paper_ids)

    # precompute the CV paper ID assignments
    fold_to_paper_ids = {}
    for fold in range(num_folds):
        # Determine the range of paper IDs for the test set in the current fold
        start_test = fold * num_test
        end_test = start_test + num_test
        test = paper_ids[start_test:end_test]

        # Determine the remaining paper IDs for the training and development sets
        remaining_paper_ids = [paper_id for paper_id in paper_ids if paper_id not in test]
        random.shuffle(remaining_paper_ids)
        train = remaining_paper_ids[:num_train]
        dev = remaining_paper_ids[num_train:]
        assert len(train) + len(dev) + len(test) == len(paper_ids)
        fold_to_paper_ids[fold] = (train, dev, test)


    # split examples s.t. all examples from same paper go to same split
    fold_to_examples = {}
    for fold, (train, dev, test) in fold_to_paper_ids.items():
        fold_to_examples[fold] = {
            'train': [],
            'dev': [],
            'test': []
        }
        for example in examples:
            paper_id = example['instance_id'].split('_')[1]
            if paper_id in train:
                fold_to_examples[fold]['train'].append(example)
            elif paper_id in dev:
                fold_to_examples[fold]['dev'].append(example)
            else:
                fold_to_examples[fold]['test'].append(example)
    return fold_to_examples

def build_window(gold_context_sent_pos: List[int],                                                  # annotations
                 sent_pos_to_num_tokens: List[int],                                                 # resources
                 section_pos: List[int],
                 paper_num_sents: int, max_model_num_tokens: int                                    # constraints
                 ) -> Optional[List[int]]:

    # let's always include the full gold context, at minimum (unless it truncates due to window, in which case we skip)
    gold_context_num_tokens = sum([sent_pos_to_num_tokens[gold_sent_pos] for gold_sent_pos in gold_context_sent_pos])
    section_sent_num_tokens = sum([sent_pos_to_num_tokens[section_sent_pos] for section_sent_pos in section_pos])

    # validate our constraints before getting the actual sentence text
    if gold_context_num_tokens > max_model_num_tokens:  # too long for even this longformer
        print(f'Skipping... Token Gold: {gold_context_num_tokens}')
        return None
    elif section_sent_num_tokens < max_model_num_tokens:
        return section_pos
    else:

        # now let's randomly append sentences to front/back of gold context
        window_sent_pos = [sent_pos for sent_pos in gold_context_sent_pos]
        window_num_tokens = gold_context_num_tokens
        while window_num_tokens <= max_model_num_tokens:
            before_sent_id = min(window_sent_pos) - 1
            after_sent_id = max(window_sent_pos) + 1
            
            #check if sent in section
            if before_sent_id not in section_pos and after_sent_id in section_pos:
                new_sent_id = after_sent_id
            elif after_sent_id not in section_pos and before_sent_id in section_pos:
                new_sent_id = before_sent_id
            elif before_sent_id < min(section_pos) and after_sent_id > max(section_pos):
                break
            else:
                choice = random.randint(0, 1)
                if choice == 0:
                    new_sent_id = after_sent_id
                else:
                    new_sent_id = before_sent_id

            # token length check
            new_sent_num_tokens = sent_pos_to_num_tokens[new_sent_id]
            if window_num_tokens + new_sent_num_tokens <= max_model_num_tokens:

                # now add to window
                window_sent_pos.append(new_sent_id)
                window_num_tokens += sent_pos_to_num_tokens[new_sent_id]

            else:
                # ran out of space so end adding phase
                break

        window_sent_pos = sorted(window_sent_pos)
        return window_sent_pos

def build_all_examples(raw_data_dict: Dict) -> List[Dict]:
    
    active_logger = False
    n = 0
    all_examples = []
    for paper_id, data in tqdm(raw_data_dict.items(), position=0):
        paper = data['x']                          # paper
        intent_to_annotations = data['y']          # intent & context annotations

        # 0) clean out IGNORED SENTS
        clean_paper = [sent_dict for sent_dict in paper if sent_dict['text'] not in IGNORE_SENTS]

        # 1) build an easy lookup for sentences from its ID (str) to its list position (int) and the oposite around
        sent_id_to_list_pos: Dict[str, int] = {sent_dict['sent_id']: i for i, sent_dict in enumerate(clean_paper)}
        sent_pos_to_sent_id: Dict[int, str] = {sent_pos: sent_id for sent_id, sent_pos in sent_id_to_list_pos.items()}
        
        # 1.5) build section lookup table from sentence ID to section idx and the other way around
        sent_pos_to_section_idx = build_section_dict(clean_paper)
        section_idx_to_sent_pos = defaultdict(list)
        for sent_pos, section_idx in sent_pos_to_section_idx.items():
            section_idx_to_sent_pos[section_idx].append(sent_pos)

        # 2) clean html and get tokens length per sentence
        sent_pos_to_clean_text = [replace_html(sent_dict['text'], new_cite_token=CITE_MARKER) for sent_dict in clean_paper]
        sent_pos_to_num_tokens: List[int] = [compute_num_tokens(clean_sent, tokenizer) for clean_sent in sent_pos_to_clean_text]
        
        # 3) convert concat all citation contex always taking the biggest one if ther are multiple
        cite_sent_dict = {}
        for intent, context in intent_to_annotations.items():
            gold_context = context['gold_contexts']
            cite_sents = context['cite_sentences']
            if len(cite_sents) != len(gold_context):
                print('length is not matching')
                continue
            for i, sent in enumerate(cite_sents):
                cite_sent_dict[sent]=gold_context[i]

        # 4) create a training example out of each gold context (i.e. each cite mention)
        #    note, this code is a bit annoying because we want to convert all sent_ids to positions first
        for idx, (sent_id, gold_context_sent_ids) in enumerate(intent_to_annotations.items()):

            # 4a) build valid window
            gold_context_sent_pos = [sent_id_to_list_pos[sent_id] for sent_id in gold_context_sent_ids if sent_id in sent_id_to_list_pos]   # ignores if annotated an ignored sent
            if not gold_context_sent_pos: continue
            counter = Counter([sent_pos_to_section_idx[sent_pos] for sent_pos in gold_context_sent_pos if sent_pos in sent_pos_to_section_idx])
            if not counter: continue
            gold_context_section_idx = counter.most_common()[0][0]
            window_sent_pos: Optional[List[int]] = build_window(gold_context_sent_pos=gold_context_sent_pos,
                                                                sent_pos_to_num_tokens=sent_pos_to_num_tokens,
                                                                section_pos = section_idx_to_sent_pos[gold_context_section_idx],
                                                                paper_num_sents=len(clean_paper),
                                                                max_model_num_tokens=500)   # conservatively stay within 512
            if window_sent_pos is None:
                continue

            # verify there is at least one <CITE>
            if sent_id_to_list_pos[sent_id] not in window_sent_pos:
                continue
            if active_logger:
                _describe_window(window_sent_pos=window_sent_pos,
                                window_num_tokens=sum([sent_pos_to_num_tokens[pos] for pos in window_sent_pos]),
                                cite_sent_pos=sent_id_to_list_pos[sent_id],
                                gold_context_sent_pos=gold_context_sent_pos,
                                gold_context_num_tokens=sum([sent_pos_to_num_tokens[pos] for pos in gold_context_sent_pos]))

            # 4b) get the text and labels
            example = {
                'id': n,
                'instance_id': f'{paper_id}__{sent_id}',
                'sentences': [],
                'labels': [],
                'sent_ids': []
            }
            for w_pos in window_sent_pos:
                example['sentences'].append(sent_pos_to_clean_text[w_pos])
                example['labels'].append('context' if w_pos in gold_context_sent_pos else 'not-context')
                example['sent_ids'].append(sent_pos_to_sent_id[w_pos])

            all_examples.append(example)
            n += 1

    return all_examples

In [60]:
# stuff for input
with open(os.path.join(input_directory, file_name)) as f_in:
    d = json.load(f_in)

# build examples
random.seed(rand_state)
all_examples = build_all_examples(raw_data_dict=d)

# stuff for output
outdir = os.path.join(output_directory, f'section__{rand_state}__tr{train_prob}-v{dev_prob}-t{test_prob}'.replace('/', '-').replace('.', ''))
if not os.path.exists(outdir):
    os.makedirs(outdir, exist_ok=True)
out_file = os.path.join(outdir, 'full.json')
with open(out_file, 'w') as outfile:
    json.dump({'data': all_examples}, outfile, indent=4)

# now do the splitting & log it
logfile = os.path.join(outdir, 'log.txt')
with open(logfile, 'w') as f_log:

    fold_to_splits = split_examples(examples=all_examples, train_prob=train_prob, dev_prob=dev_prob, test_prob=test_prob)
    for fold, splits in fold_to_splits.items():
        folddir = os.path.join(outdir, f'{fold}/')
        os.makedirs(folddir, exist_ok=True)
        trainfile = os.path.join(folddir, 'train.jsonl')
        devfile = os.path.join(folddir, 'dev.jsonl')
        testfile = os.path.join(folddir, 'test.jsonl')
        train = splits['train']
        dev = splits['dev']
        test = splits['test']
        with open(trainfile, 'w') as f_train:
            for e in train:
                json.dump(e, f_train)
                f_train.write('\n')
        with open(devfile, 'w') as f_val:
            for e in dev:
                json.dump(e, f_val)
                f_val.write('\n')
        with open(testfile, 'w') as f_test:
            for e in test:
                json.dump(e, f_test)
                f_test.write('\n')

        f_log.write(f'Counting rare class in Fold {fold} \n')
        f_log.write('train ' + f"{Counter([sum([1 if tag == 'context' else 0 for tag in e['labels']]) for e in train])} / {len(train)}\n")
        f_log.write('dev ' + f"{Counter([sum([1 if tag == 'context' else 0 for tag in e['labels']]) for e in dev])} / {len(dev)}\n")
        f_log.write('test ' + f"{Counter([sum([1 if tag == 'context' else 0 for tag in e['labels']]) for e in test])} / {len(test)}\n")
        f_log.write('\n\n')

  0%|          | 0/1193 [00:00<?, ?it/s]

{'gold_contexts': [['8f0aab7fd30ffc56cc477b25e6bb16-C001-11']], 'cite_sentences': ['8f0aab7fd30ffc56cc477b25e6bb16-C001-11']}





KeyError: 'gold_context'