In [15]:
import json
import math
import os
import random
from collections import Counter
from typing import Dict, List, Optional

from tqdm import tqdm
from transformers import AutoTokenizer

# Plain-text markers that replace the HTML highlight around each citation.
CITE_START = '[CITE]'
CITE_END = '[/CITE]'
# One marker token per citation-intent label.
INTENT_TOKENS = ['@BACK@', '@MOT@', '@FUT@', '@SIM@', '@DIF@', '@USE@', '@EXT@', '@UNSURE@']
# All markers registered with the tokenizer as additional special tokens.
SPECIAL_TOKENS = [CITE_START, CITE_END, *INTENT_TOKENS]
# Separator "sentences" removed from each paper before windows are built.
IGNORE_SENTS = {'----------------------------------', '****'}

# ---- Run configuration (edit here) ----
input_directory = "../../data/multicite"   # folder holding the raw annotation JSON
file_name = 'full_data.json'
output_directory = "../data/seq_tagger"    # a per-run subfolder is created under here
model_name = "allenai/scibert_scivocab_uncased"
window = 4          # sentence-window size handed to build_window as max_num_sents
rand_state = 99     # seed for the global `random` module
# Split proportions; 1 / test_prob is the number of cross-validation folds.
train_prob, dev_prob, test_prob = 0.7, 0.1, 0.2

# Tokenizer of the target model, extended with the marker tokens above.
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=False, use_fast=True)
tokenizer.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS})


10

In [16]:
#helper

def replace_html(sentence: str, new_cite_start: str, new_cite_end: str) -> str:
    """Swap the highlighted-citation HTML wrapper in ``sentence`` for plain-text markers.

    Every ``<span style="background: yellow; display: inline-block">`` becomes
    ``new_cite_start`` and every ``</span>`` becomes ``new_cite_end``. The opening
    tag is substituted first, then the closing tag.
    """
    substitutions = (
        ('<span style="background: yellow; display: inline-block">', new_cite_start),
        ('</span>', new_cite_end),
    )
    cleaned = sentence
    for html_tag, marker in substitutions:
        cleaned = cleaned.replace(html_tag, marker)
    return cleaned

def compute_num_tokens(sentence: str, tokenizer):
    """Return how many token ids ``tokenizer`` produces for ``sentence``.

    The sentence is passed as raw text (``is_split_into_words=False``), and the
    tokenizer is asked not to add its own special tokens (``add_special_tokens=False``).
    ``tokenizer`` must be callable and return a mapping with an ``'input_ids'`` entry.
    """
    token_ids = tokenizer(sentence, is_split_into_words=False, add_special_tokens=False)['input_ids']
    return len(token_ids)

# describes  where the citation is in the context window
def _describe_window(window_sent_pos: List[int], window_num_tokens: int,
                     cite_sent_pos: List[int],
                     gold_context_sent_pos: List[int], gold_context_num_tokens: int):

    # Determine the positions of citation sentences within the window
    where_are_cite_in_window = [window_sent_pos.index(i) for i in cite_sent_pos]
    print(f'Token Window:{window_num_tokens}\tToken Gold:{gold_context_num_tokens}\tSent Window: {len(window_sent_pos)}\tSent Gold: {len(gold_context_sent_pos)}\tCiteInWindow: {where_are_cite_in_window}')

In [18]:

def split_examples(examples: list, train_prob: float, dev_prob: float, test_prob: float):
    """Build paper-level cross-validation splits.

    Examples are grouped by the key ``example['instance_id'].split('_')[1]``, so all
    examples sharing a key land in the same split of a fold. There are
    ``1 / test_prob`` folds. Fold ``k`` tests on the ``k``-th contiguous slice of the
    shuffled keys and randomly divides the remaining keys into train/dev in the ratio
    ``train_prob : dev_prob``. Randomness comes from the global ``random`` module;
    seed it first for reproducible splits.

    Output looks like:
    {
        0: {'train': [...], 'dev': [...], 'test': [...]},
        1: {'train': [...], 'dev': [...], 'test': [...]},
        ...
        k: {'train': [...], 'dev': [...], 'test': [...]},
    }

    Raises:
        AssertionError: if the probabilities do not sum to 1 (within float tolerance),
            or if ``1 / test_prob`` is not a whole number of folds.
    """
    # Tolerant comparisons: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floats.
    assert math.isclose(train_prob + dev_prob + test_prob, 1.0), 'split probabilities must sum to 1'
    num_folds = round(1.0 / test_prob)
    # check that we can evenly divide into folds
    assert num_folds >= 1 and math.isclose(1.0 / test_prob, num_folds), '1 / test_prob must be a whole number'

    def paper_key(example) -> str:
        # NOTE(review): the grouping key is the second '_'-separated field of
        # instance_id; confirm it identifies the source paper for the id format in use.
        return example['instance_id'].split('_')[1]

    # Sort before shuffling: the iteration order of a set of str depends on
    # PYTHONHASHSEED, so without sorting the seeded shuffle differs between runs.
    paper_ids = sorted({paper_key(example) for example in examples})
    random.shuffle(paper_ids)

    num_test = int(len(paper_ids) / num_folds)
    num_train = int((len(paper_ids) - num_test) * train_prob / (train_prob + dev_prob))
    num_dev = len(paper_ids) - num_test - num_train
    assert num_train + num_dev + num_test == len(paper_ids)

    # Precompute the per-fold paper assignments.
    # NOTE: if len(paper_ids) is not divisible by num_folds, the last
    # len(paper_ids) % num_folds shuffled papers never fall into any test slice.
    fold_to_paper_ids = {}
    for fold in range(num_folds):
        start_test = fold * num_test
        test = paper_ids[start_test:start_test + num_test]
        test_set = set(test)

        # remaining papers are reshuffled per fold, then cut into train/dev
        remaining_paper_ids = [paper_id for paper_id in paper_ids if paper_id not in test_set]
        random.shuffle(remaining_paper_ids)
        train = remaining_paper_ids[:num_train]
        dev = remaining_paper_ids[num_train:]
        assert len(train) + len(dev) + len(test) == len(paper_ids)
        fold_to_paper_ids[fold] = (set(train), set(dev))

    # Split examples so that all examples from the same paper go to the same split.
    fold_to_examples = {}
    for fold, (train_ids, dev_ids) in fold_to_paper_ids.items():
        splits = {'train': [], 'dev': [], 'test': []}
        for example in examples:
            paper_id = paper_key(example)
            if paper_id in train_ids:
                splits['train'].append(example)
            elif paper_id in dev_ids:
                splits['dev'].append(example)
            else:
                splits['test'].append(example)
        fold_to_examples[fold] = splits
    return fold_to_examples


def build_window(gold_context_sent_pos: List[int],                                                  # annotations
                 sent_pos_to_num_tokens: List[int],                                                 # resources
                 max_num_sents: int, paper_num_sents: int, max_model_num_tokens: int                # constraints
                 ) -> Optional[List[int]]:
    """Grow a window of sentence positions around a gold citation context.

    The window starts as the full gold context and grows one sentence at a time.
    When both neighbours exist inside the paper, it picks at random (global ``random``
    module) between the sentence just before the current minimum position and the one
    just after the current maximum. Growth stops when the window holds
    ``max_num_sents`` sentences, no neighbour is left, or the next sentence would push
    the token total past ``max_model_num_tokens``.

    Args:
        gold_context_sent_pos: list positions of the annotated gold-context sentences.
        sent_pos_to_num_tokens: token count of every sentence, indexed by list position.
        max_num_sents: maximum number of sentences in the returned window.
        paper_num_sents: number of sentences in the paper (valid positions 0..n-1).
        max_model_num_tokens: token budget for the whole window.

    Returns:
        The sorted window positions. Returns ``None`` (after printing a reason) when
        the gold context is empty, has more than ``max_num_sents`` sentences, or alone
        exceeds the token budget.
    """
    # Nothing to anchor the window on (min()/max() below would fail).
    if not gold_context_sent_pos:
        print('Skipping... empty gold context')
        return None

    # Always keep the full gold context; if it alone breaks a constraint, skip.
    gold_context_num_tokens = sum(sent_pos_to_num_tokens[pos] for pos in gold_context_sent_pos)
    if len(gold_context_sent_pos) > max_num_sents:  # can't even contain the gold context
        print(f'Skipping... Sent Gold: {len(gold_context_sent_pos)}')
        return None
    if gold_context_num_tokens > max_model_num_tokens:  # too long for the model
        print(f'Skipping... Token Gold: {gold_context_num_tokens}')
        return None

    window_sent_pos = list(gold_context_sent_pos)
    window_num_tokens = gold_context_num_tokens
    # Strict '<' so the window never exceeds max_num_sents sentences. The token budget
    # needs no loop check: it holds initially and is verified before every append.
    while len(window_sent_pos) < max_num_sents:
        before_sent_id = min(window_sent_pos) - 1
        after_sent_id = max(window_sent_pos) + 1
        has_before = before_sent_id >= 0
        has_after = after_sent_id < paper_num_sents

        if has_before and has_after:
            # both neighbours valid: random choice (0 -> after, 1 -> before)
            new_sent_id = after_sent_id if random.randint(0, 1) == 0 else before_sent_id
        elif has_after:
            new_sent_id = after_sent_id
        elif has_before:
            new_sent_id = before_sent_id
        else:
            break  # window already spans the whole paper

        new_sent_num_tokens = sent_pos_to_num_tokens[new_sent_id]
        if window_num_tokens + new_sent_num_tokens > max_model_num_tokens:
            break  # ran out of token budget
        window_sent_pos.append(new_sent_id)
        window_num_tokens += new_sent_num_tokens

    return sorted(window_sent_pos)

def build_all_examples(raw_data_dict: Dict,
                       max_num_sents: Optional[int] = None,
                       max_model_num_tokens: int = 500,
                       verbose: bool = False) -> List[Dict]:
    """Turn raw citation-context annotations into sentence-tagging examples.

    For each paper:
    - sentences whose text is in ``IGNORE_SENTS`` are dropped;
    - citation HTML is replaced by ``CITE_START``/``CITE_END``;
    - every sentence is measured with the module-level ``tokenizer``.

    Each gold context (per intent) then becomes one example over a window from
    ``build_window``, labelling every window sentence ``'context'`` or ``'not-context'``.

    Args:
        raw_data_dict: paper id -> {'x': [sentence dicts with 'sent_id' and 'text'],
            'y': {intent: {'gold_contexts': [[sent_id, ...], ...],
            'cite_sentences': [sent_id, ...]}}}. (Schema inferred from the lookups
            below; confirm against the dataset.)
        max_num_sents: window size passed to ``build_window``; defaults to the
            module-level ``window`` setting.
        max_model_num_tokens: token budget per window (500 conservatively stays within 512).
        verbose: print a ``_describe_window`` summary for every kept example.

    Returns:
        A list of dicts with keys ``id``, ``instance_id``, ``intent``, ``sentences``,
        ``labels`` and ``sent_ids``. Gold contexts rejected by ``build_window``, or
        containing no citing sentence, are skipped.
    """
    if max_num_sents is None:
        max_num_sents = int(window)

    all_examples = []
    for paper_id, data in tqdm(raw_data_dict.items()):
        paper = data['x']                          # paper
        intent_to_annotations = data['y']          # intent & context annotations

        # 0) clean out IGNORED SENTS
        clean_paper = [sent_dict for sent_dict in paper if sent_dict['text'] not in IGNORE_SENTS]

        # 1) lookups between sentence IDs (str) and list positions (int)
        sent_pos_to_sent_id: List[str] = [sent_dict['sent_id'] for sent_dict in clean_paper]
        sent_id_to_list_pos: Dict[str, int] = {sent_id: pos for pos, sent_id in enumerate(sent_pos_to_sent_id)}

        # 2) clean html and get token length per sentence
        sent_pos_to_clean_text = [replace_html(sent_dict['text'], new_cite_start=CITE_START, new_cite_end=CITE_END)
                                  for sent_dict in clean_paper]
        sent_pos_to_num_tokens: List[int] = [compute_num_tokens(clean_sent, tokenizer)
                                             for clean_sent in sent_pos_to_clean_text]

        # 3) one training example per gold context (i.e. per cite mention);
        #    all sent_ids are converted to list positions first
        for intent, annotations in intent_to_annotations.items():
            for idx_gold_context, gold_context_sent_ids in enumerate(annotations['gold_contexts']):

                # 3a) build a valid window; gold sentences dropped in step 0 are ignored
                gold_context_sent_pos = [sent_id_to_list_pos[sent_id] for sent_id in gold_context_sent_ids
                                         if sent_id in sent_id_to_list_pos]
                window_sent_pos: Optional[List[int]] = build_window(gold_context_sent_pos=gold_context_sent_pos,
                                                                    sent_pos_to_num_tokens=sent_pos_to_num_tokens,
                                                                    max_num_sents=max_num_sents,
                                                                    paper_num_sents=len(clean_paper),
                                                                    max_model_num_tokens=max_model_num_tokens)
                if window_sent_pos is None:
                    continue

                # require at least one citing sentence inside the gold context; cite ids
                # dropped in step 0 are skipped rather than raising KeyError
                cite_sent_pos = [sent_id_to_list_pos[cite_sent_id] for cite_sent_id in annotations['cite_sentences']
                                 if cite_sent_id in gold_context_sent_ids and cite_sent_id in sent_id_to_list_pos]
                if not cite_sent_pos:
                    continue
                if verbose:
                    _describe_window(window_sent_pos=window_sent_pos,
                                     window_num_tokens=sum(sent_pos_to_num_tokens[pos] for pos in window_sent_pos),
                                     cite_sent_pos=cite_sent_pos,
                                     gold_context_sent_pos=gold_context_sent_pos,
                                     gold_context_num_tokens=sum(sent_pos_to_num_tokens[pos] for pos in gold_context_sent_pos))

                # 3b) collect the text and labels (ids are sequential over kept examples)
                gold_pos_set = set(gold_context_sent_pos)
                all_examples.append({
                    'id': len(all_examples),
                    'instance_id': f'{paper_id}__{intent}__{idx_gold_context}',
                    'intent': intent,
                    'sentences': [sent_pos_to_clean_text[pos] for pos in window_sent_pos],
                    'labels': ['context' if pos in gold_pos_set else 'not-context' for pos in window_sent_pos],
                    'sent_ids': [sent_pos_to_sent_id[pos] for pos in window_sent_pos],
                })

    return all_examples

In [19]:
# Load the raw annotations (explicit UTF-8: the default codec is locale-dependent).
with open(os.path.join(input_directory, file_name), encoding='utf-8') as f_in:
    d = json.load(f_in)

# Build examples. Seed once: build_window and split_examples draw from the global `random` state.
random.seed(rand_state)
all_examples = build_all_examples(raw_data_dict=d)

# Write the full dataset into a run-specific directory.
outdir = os.path.join(output_directory, f'{model_name}__{window}__{rand_state}__{train_prob}-{dev_prob}-{test_prob}'.replace('/', '-').replace('.', ''))
os.makedirs(outdir, exist_ok=True)
out_file = os.path.join(outdir, 'full.json')
with open(out_file, 'w', encoding='utf-8') as outfile:
    json.dump({'data': all_examples}, outfile, indent=4)


def _write_jsonl(path, records):
    """Write one JSON object per line to ``path``."""
    with open(path, 'w', encoding='utf-8') as f_out:
        for record in records:
            json.dump(record, f_out)
            f_out.write('\n')


def _count_context_sents(examples):
    """Counter of how many 'context' sentences each example has (rare-class statistics)."""
    return Counter(sum(1 if tag == 'context' else 0 for tag in e['labels']) for e in examples)


# Split into CV folds (one directory per fold) and log rare-class counts per split.
fold_to_splits = split_examples(examples=all_examples, train_prob=train_prob, dev_prob=dev_prob, test_prob=test_prob)
logfile = os.path.join(outdir, 'log.txt')
with open(logfile, 'w', encoding='utf-8') as f_log:
    for fold, splits in fold_to_splits.items():
        folddir = os.path.join(outdir, f'{fold}/')
        os.makedirs(folddir, exist_ok=True)
        f_log.write(f'Counting rare class in Fold {fold} \n')
        for split_name in ('train', 'dev', 'test'):
            split = splits[split_name]
            _write_jsonl(os.path.join(folddir, f'{split_name}.jsonl'), split)
            f_log.write(f'{split_name} {_count_context_sents(split)} / {len(split)}\n')
        f_log.write('\n\n')

  2%|▏         | 26/1193 [00:00<00:17, 65.66it/s]

Skipping... Sent Gold: 5


  3%|▎         | 33/1193 [00:00<00:26, 44.38it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 12


  5%|▍         | 55/1193 [00:01<00:46, 24.42it/s]

Skipping... Sent Gold: 5


 10%|█         | 123/1193 [00:03<00:25, 41.79it/s]

Skipping... Sent Gold: 6


 11%|█         | 128/1193 [00:03<00:24, 43.79it/s]

Skipping... Sent Gold: 7


 13%|█▎        | 152/1193 [00:04<00:22, 46.94it/s]

Skipping... Sent Gold: 5


 13%|█▎        | 158/1193 [00:04<00:21, 48.14it/s]

Skipping... Sent Gold: 9


 14%|█▎        | 163/1193 [00:04<00:21, 46.93it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6


 15%|█▌        | 184/1193 [00:05<00:29, 33.90it/s]

Skipping... Sent Gold: 6


 16%|█▋        | 196/1193 [00:05<00:24, 41.24it/s]

Skipping... Sent Gold: 9
Skipping... Sent Gold: 9


 17%|█▋        | 201/1193 [00:05<00:26, 37.93it/s]

Skipping... Sent Gold: 6


 18%|█▊        | 218/1193 [00:06<00:28, 33.94it/s]

Skipping... Sent Gold: 5


 19%|█▊        | 222/1193 [00:06<00:27, 34.99it/s]

Skipping... Sent Gold: 5


 20%|██        | 239/1193 [00:06<00:25, 37.14it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 21%|██        | 247/1193 [00:07<00:27, 34.03it/s]

Skipping... Sent Gold: 5


 22%|██▏       | 259/1193 [00:07<00:28, 33.17it/s]

Skipping... Sent Gold: 7


 22%|██▏       | 263/1193 [00:07<00:29, 32.05it/s]

Skipping... Sent Gold: 5


 47%|████▋     | 558/1193 [00:19<00:22, 28.63it/s]

Skipping... Sent Gold: 6


 48%|████▊     | 570/1193 [00:19<00:28, 21.75it/s]

Skipping... Sent Gold: 5


 50%|█████     | 601/1193 [00:21<00:24, 23.69it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 51%|█████     | 610/1193 [00:21<00:24, 23.71it/s]

Skipping... Sent Gold: 7
Skipping... Sent Gold: 7


 52%|█████▏    | 616/1193 [00:21<00:25, 22.27it/s]

Skipping... Sent Gold: 5


 52%|█████▏    | 619/1193 [00:21<00:24, 23.41it/s]

Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 8


 52%|█████▏    | 625/1193 [00:21<00:18, 31.29it/s]

Skipping... Sent Gold: 5


 53%|█████▎    | 632/1193 [00:22<00:14, 38.19it/s]

Skipping... Sent Gold: 14
Skipping... Sent Gold: 5


 55%|█████▌    | 662/1193 [00:23<00:24, 21.83it/s]

Skipping... Sent Gold: 6


 62%|██████▏   | 734/1193 [00:26<00:17, 26.61it/s]

Skipping... Sent Gold: 7


 62%|██████▏   | 743/1193 [00:26<00:18, 24.70it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 6


 63%|██████▎   | 746/1193 [00:26<00:19, 22.55it/s]

Skipping... Sent Gold: 6
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5


 63%|██████▎   | 749/1193 [00:26<00:20, 22.19it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 64%|██████▍   | 768/1193 [00:27<00:13, 30.77it/s]

Skipping... Sent Gold: 5


 68%|██████▊   | 812/1193 [00:28<00:10, 34.77it/s]

Skipping... Sent Gold: 10


 75%|███████▍  | 891/1193 [00:31<00:09, 30.74it/s]

Skipping... Sent Gold: 8


 78%|███████▊  | 935/1193 [00:32<00:05, 49.40it/s]

Skipping... Sent Gold: 9


 80%|███████▉  | 953/1193 [00:32<00:03, 64.28it/s]

Skipping... Sent Gold: 5


 80%|████████  | 960/1193 [00:32<00:03, 60.25it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 82%|████████▏ | 974/1193 [00:33<00:05, 41.64it/s]

Skipping... Sent Gold: 7
Skipping... Sent Gold: 7


 85%|████████▌ | 1016/1193 [00:36<00:08, 21.19it/s]

Skipping... Sent Gold: 7


 86%|████████▌ | 1022/1193 [00:36<00:07, 23.34it/s]

Skipping... Sent Gold: 5


 86%|████████▋ | 1029/1193 [00:36<00:06, 25.59it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 87%|████████▋ | 1032/1193 [00:36<00:06, 26.04it/s]

Skipping... Sent Gold: 6


 88%|████████▊ | 1046/1193 [00:37<00:05, 29.08it/s]

Skipping... Sent Gold: 5


 90%|████████▉ | 1071/1193 [00:38<00:04, 29.18it/s]

Skipping... Sent Gold: 12
Skipping... Sent Gold: 5
Skipping... Sent Gold: 12
Skipping... Sent Gold: 5


 90%|█████████ | 1074/1193 [00:38<00:04, 27.90it/s]

Skipping... Sent Gold: 7
Skipping... Sent Gold: 6
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 90%|█████████ | 1077/1193 [00:38<00:04, 25.08it/s]

Skipping... Sent Gold: 6
Skipping... Sent Gold: 7
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 7
Skipping... Sent Gold: 7
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 7


 91%|█████████ | 1080/1193 [00:38<00:04, 23.18it/s]

Skipping... Sent Gold: 9


 91%|█████████ | 1083/1193 [00:38<00:04, 23.69it/s]

Skipping... Sent Gold: 8


 91%|█████████ | 1086/1193 [00:38<00:04, 22.69it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 10


 91%|█████████▏| 1089/1193 [00:38<00:04, 22.35it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 92%|█████████▏| 1092/1193 [00:39<00:04, 23.95it/s]

Skipping... Sent Gold: 10
Skipping... Sent Gold: 5


 92%|█████████▏| 1095/1193 [00:39<00:04, 23.48it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 7
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 92%|█████████▏| 1098/1193 [00:39<00:04, 22.51it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 10
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 10
Skipping... Sent Gold: 6
Skipping... Sent Gold: 6


 92%|█████████▏| 1101/1193 [00:39<00:03, 23.47it/s]

Skipping... Sent Gold: 10
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 9
Skipping... Sent Gold: 9
Skipping... Sent Gold: 6
Skipping... Sent Gold: 6


 93%|█████████▎| 1104/1193 [00:39<00:03, 24.00it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6


 93%|█████████▎| 1108/1193 [00:39<00:03, 26.50it/s]

Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 6
Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 93%|█████████▎| 1111/1193 [00:39<00:03, 26.34it/s]

Skipping... Sent Gold: 6


 96%|█████████▋| 1151/1193 [00:41<00:01, 27.10it/s]

Skipping... Sent Gold: 5


 98%|█████████▊| 1169/1193 [00:42<00:01, 19.40it/s]

Skipping... Sent Gold: 5
Skipping... Sent Gold: 5


 98%|█████████▊| 1172/1193 [00:42<00:01, 19.61it/s]

Skipping... Sent Gold: 6


100%|█████████▉| 1190/1193 [00:43<00:00, 13.09it/s]

Skipping... Sent Gold: 5


100%|██████████| 1193/1193 [00:43<00:00, 27.13it/s]
