In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import random
import os
import json
import nltk
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
try:
    nltk.data.find('tokenizers/punkt')
    print("Punkt tokenizers already installed.")
except:
    print("Punkt tokenizers not found; installing now.")
    nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import re
from datetime import datetime

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


Punkt tokenizers already installed.


In [2]:
def prep_dataset(text_path: str, task: str='QUANT', bert_type: str="bert-base-cased", label_path: str=None, 
                 remove_markers: bool=True, batch_size: int=64, torch_seed: int=434):
    
    """
    Prepare data for model creation with BERT variants.
    =====================================================
    Inputs:
    text_path (str, required): filepath to text files
    task (str): which MeasEval task to prepare dataset for
    bert_type (str): type of BERT variant (huggingface)
    label_path (str): filepath to labels for training datasets
    remove_markers (bool): toggle to remove figure numbers, publication years, etc.
    batch_size (int): batch size for training
    torch_seed (int): seed number for (some) reproducibility
    =====================================================
    Outputs:
    dataset (list): flattened/tokenized/batched sentences w/annotations
    """
        
    torch.manual_seed(torch_seed)
    
    if not label_path:
        dataset_type = "test"
    else:
        dataset_type = "training or validation"
        
    typemap = {"Quantity": "QUANT",
           "MeasuredEntity": "ME", 
           "MeasuredProperty": "MP", 
           "Qualifier": "QUAL"
        }
        
    print("="*40)
    print("Preparing **", dataset_type, "** dataset, based on arguments provided.", sep="")
    if not label_path:
        print("If you intended to prepare training/validation data, provide a label_path to the function.")
    print("="*40)
    
    textset = {}
    for fn in os.listdir(text_path):
        with open(text_path+fn, encoding='utf-8', errors='replace') as textfile:
            text = textfile.read()
            textset[fn[:-4]] = text
    
    if label_path:
        # Load all annotations
        files_with_label = [file_name[:-4] for file_name in os.listdir(label_path)]
        all_files_with_or_without_label = [file_name[:-4] for file_name in os.listdir(text_path)]
        files_without_label = list(set(all_files_with_or_without_label).difference(set(files_with_label)))
        print("Unlabeled files: ", len(files_without_label), ". Labeled files: ", len(files_with_label),".", sep="")
        if len(files_without_label) > 0:
            print("Example of unlabeled file: ", files_without_label[0],".", sep="")
    else:
        files_without_label = [file_name[:-4] for file_name in os.listdir(text_path)]
        
    
    all_data = []
    
    if label_path:
        # Load annotations from files with labels, if applicable
        for fn_no_ext in files_with_label:
            fn = fn_no_ext + ".tsv"
            entities = {"QUANT": [], "ME": [], "MP": [], "QUAL": []}
            with open(label_path+fn, encoding='utf-8', errors='replace') as annotfile:
                text = textset[fn[:-4]]
                next(annotfile)
                annots = annotfile.read().splitlines()
                for a in annots:
                    annot = a.split("\t")
                    atype = typemap[annot[2]]
                    start = int(annot[3])
                    stop = int(annot[4])
                    # This is where we toss out the overlaps:
                    overlap = False
                    for ent in entities[atype]:
                        if ((start >= ent[0] and start <= ent[1]) or (stop >= ent[0] and stop <= ent[1]) or
                            (ent[0] >= start and ent[0] <= stop) or (ent[1] >= start and ent[1] <= stop)):
                            overlap = True
                    if overlap == False:    
                        entities[atype].append((start, stop, atype))
                all_data.append((text,
                                {"QUANT": entities["QUANT"],
                                 "ME": entities["ME"],
                                 "MP": entities["MP"],
                                 "QUAL": entities["QUAL"]
                                },
                                (fn,)
                            ))

    # Load annotations from files without label
    for fn in files_without_label:
        text = textset[fn]
        all_data.append((text,
                        {"QUANT": [],
                         "ME": [],
                         "MP": [],
                         "QUAL": []
                    },
                    (fn,)
                ))
        
    # ===== splits text data into sentences, applying processing if desired =====

    cnt_toks = {"figs.": 0, "fig.": 0, "et al.": 0,
                "ref.": 0, "eq.": 0, "e.g.": 0,
                "i.e.": 0, "nos.": 0, "no.": 0,
                "spp.": 0
                }
    regex_end_checker = [".*[a-zA-Z]figs\.$", 
                        ".*[a-zA-Z]fig\.$",
                        ".*[a-zA-Z]et al\.$",
                        ".*[a-zA-Z]ref\.$",
                        ".*[a-zA-Z]eq\.$",
                        ".*[a-zA-Z]e\.g\.$",
                        ".*[a-zA-Z]i\.e\.$",
                        ".*[a-zA-Z]nos\.$",
                        ".*[a-zA-Z]no\.$",
                        ".*[a-zA-Z]spp\.$",
                        # figs., fig., et al., Ref., Eq., e.g., i.e., Nos., No., spp.
                    ]

    assert len(cnt_toks) == len(regex_end_checker)

    # list of sentences
    # for every sentence obtained
    # check if ends with "fig or Figs or et al."
    # keep track of the index where the this sentence starts and ends
    # adjust pointers to tokens where changes were made

    all_processed_data = []
    for doc in all_data:
        flag = False
        sentences = sent_tokenize(doc[0])

        fixed_sentence_tokens = []
        curr_len = 0
        for s in sentences:
            if flag == True:
                assert s[0] != ' '
                white_length = doc[0][curr_len:].find(s[0])

                prev_len = len(fixed_sentence_tokens[-1])
                fixed_sentence_tokens[-1] = fixed_sentence_tokens[-1] + (" "*white_length) + s

                assert fixed_sentence_tokens[-1][prev_len+white_length] == doc[0][curr_len+white_length], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                tmp_this_sent_len = white_length + len(s)
                assert fixed_sentence_tokens[-1][-1] == doc[0][curr_len+tmp_this_sent_len-1], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                curr_len += tmp_this_sent_len
            else:
                if len(fixed_sentence_tokens) != 0:
                    assert s[0] != ' '
                    white_length = doc[0][curr_len:].find(s[0])
                    fixed_sentence_tokens.append( (" "*white_length) + s )
                else:
                    fixed_sentence_tokens.append(s)
                assert fixed_sentence_tokens[-1][0] == doc[0][curr_len], (fixed_sentence_tokens, doc[0], curr_len, tmp_this_sent_len)
                tmp_this_sent_len = len(fixed_sentence_tokens[-1])
                assert fixed_sentence_tokens[-1][-1] == doc[0][curr_len+tmp_this_sent_len-1], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                curr_len += tmp_this_sent_len

            lower_cased_s = fixed_sentence_tokens[-1].lower()
            flag = False
            if remove_markers:
                for i, k in enumerate(cnt_toks):
                    this_regex_pattern = regex_end_checker[i]
                    if lower_cased_s.endswith(k) and re.match(this_regex_pattern, lower_cased_s) == None:
                        cnt_toks[k] += 1
                        flag = True
                        break

        all_processed_data.append(fixed_sentence_tokens)      
    print("Fixed sentence splitting:", cnt_toks)

        
    # Load in annotations/associate with text from files
    # In each loop,replace all numbers by zero,
    # check if any of the annotations fall in between the start and end. 
    # Deal with overlaps and add offsets

    normalize = lambda x: re.sub(r'\d', '0', x)
    all_annotated_split_data = []

    for doc, sent_splits in zip(all_data, all_processed_data):

        this_offsets = []

        prev_end = 0
        for s in sent_splits:
            this_offsets.append([prev_end, prev_end+len(s)])
            prev_end += len(s)

        this_annotations = []
        for s, offset in zip(sent_splits, this_offsets):
            this_sent_ann = {}
            for k,v in doc[1].items():
                this_key_annotation_sentence = []

                for ann in v:
                    if offset[0] <= ann[0] and ann[1] < offset[1]:
                        this_key_annotation_sentence.append((ann[0]-offset[0], ann[1]-offset[0]))

                this_sent_ann[k] = this_key_annotation_sentence

            this_annotations.append(this_sent_ann)

        all_annotated_split_data.append({'doc_id': doc[-1][0],
                    'sentences': [normalize(ss) for ss in sent_splits],
                    'offsets': this_offsets,
                    'annotations': this_annotations
                })
        # print(all_annotated_split_data[-1])
        assert len(all_annotated_split_data[-1]['offsets']) == len(all_annotated_split_data[-1]['sentences'])
        assert len(all_annotated_split_data[-1]['offsets']) == len(all_annotated_split_data[-1]['annotations'])
        
    bert_tok = AutoTokenizer.from_pretrained(bert_type, use_fast=True)
    
    # First flatten and shuffle
    # Then Batch
    
    flattened = [[doc['doc_id'], doc['sentences'][i], doc['offsets'][i], doc['annotations'][i]]
                 for doc in all_annotated_split_data for i in range(len(doc['sentences']))]

    print(f'Flattened {len(all_annotated_split_data)} docs into {len(flattened)} sentences.',
            "\nSome examples:", flattened[:2])
    
    if label_path:
        random.shuffle(flattened)

    # cls_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[CLS]'))[0]
    # sep_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[SEP]'))[0]
    # pad_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[PAD]'))[0]

    dataset = []
    idx = 0
    num_data = len(flattened)
    while idx < num_data:
        batch_doc_ids = []
        batch_sent_offsets = []
        batch_raw_text = []
        batch_raw_labels = []
        batch_tokens = []
        batch_token_offset = []

        for single_docid, single_sentence, single_offset, single_annotations in \
                    flattened[idx:min(idx + batch_size, num_data)]:

            batch_doc_ids.append(single_docid)
            batch_sent_offsets.append(single_offset)
            batch_raw_text.append(single_sentence)
            if label_path:
                batch_raw_labels.append(single_annotations)


        batched_dict = bert_tok.batch_encode_plus(batch_raw_text,
                                                    return_offsets_mapping=True,
                                                    padding=True)

        batch_tokens = torch.LongTensor(batched_dict['input_ids']).to(device)
        batch_maxlen = batch_tokens.shape[-1]

            
        pad_masks = torch.LongTensor(batched_dict['attention_mask']).to(device)
        batch_offset_mapping = batched_dict['offset_mapping']

        if label_path:
            
            # Create sequence labels using token offsets
            batch_labels = []
            for single_token_offs, single_anns in zip(batch_offset_mapping, batch_raw_labels):
                anns = single_anns[task]
                single_labels = []
                i = 0
                for off in single_token_offs:
                    if off == (0,0):
                        single_labels.append(0)
                    elif type(anns) == list and i < len(anns):
                        if off[1] < anns[i][0]:
                            single_labels.append(0)
                        elif off[0] > anns[i][1]:
                            i += 1
                            single_labels.append(0)
                        else:
                            single_labels.append(1)
                    else:
                        single_labels.append(0)
                batch_labels.append(single_labels)

            batch_labels = torch.LongTensor(batch_labels).to(device)

        b = batch_size if (idx + batch_size) < num_data else (num_data - idx)
        assert batch_tokens.size() == torch.Size([b, batch_maxlen])
        assert pad_masks.size() == torch.Size([b, batch_maxlen])
        
        if label_path:
            assert batch_labels.size() == torch.Size([b, batch_maxlen])
            dataset.append((batch_tokens, batch_labels, pad_masks, batch_doc_ids, batch_sent_offsets, batch_offset_mapping))
        else:
            dataset.append((batch_tokens, batch_raw_text, pad_masks, batch_doc_ids, batch_sent_offsets, batch_offset_mapping))
        assert pad_masks.size() == torch.Size([b, batch_maxlen])
        idx += batch_size

    print("num_batches=", len(dataset), " | num_data=", num_data)
    print("="*40)
    print("Dataset created!")
    print("="*40)
    
    return dataset

In [3]:
train_text = "../data/raw/train/text/"
train_labels = "../data/raw/train/tsv/"
bert_type = "roberta-base"

task = 'QUANT'

train_dataset = prep_dataset(text_path=train_text, task=task, bert_type=bert_type, label_path=train_labels, batch_size=128)

Preparing **training or validation** dataset, based on arguments provided.
Unlabeled files: 15. Labeled files: 233.
Example of unlabeled file: S1570870512000637-1206.
Fixed sentence splitting: {'figs.': 4, 'fig.': 92, 'et al.': 31, 'ref.': 3, 'eq.': 3, 'e.g.': 8, 'i.e.': 3, 'nos.': 2, 'no.': 3, 'spp.': 2}
Flattened 248 docs into 1265 sentences. 
Some examples: [['S0006322312001096-1136.tsv', 'Data were drawn from the Whitehall II study with baseline examination in 0000; follow-up screenings in 0000, 0000, and 0000; and additional disease ascertainment from hospital data and registry linkage on 0000 participants (mean age 00.0 years, 00% women) without depressive symptoms at baseline.', [0, 296], {'QUANT': [(73, 77), (103, 123), (205, 222), (233, 243), (245, 248)], 'ME': [(49, 69), (79, 99), (25, 43), (205, 222)], 'MP': [(224, 232), (249, 254)], 'QUAL': []}], ['S0006322312001096-1136.tsv', ' Vascular risk was assessed with the Framingham Cardiovascular, Coronary Heart Disease, and Strok

In [4]:
trial_text = "../data/raw/trial/txt/"
trial_labels = "../data/raw/trial/tsv/"
bert_type = "roberta-base"

valid_dataset = prep_dataset(text_path=trial_text, task='QUANT', bert_type=bert_type, label_path=trial_labels, batch_size=128)

Preparing **training or validation** dataset, based on arguments provided.
Unlabeled files: 0. Labeled files: 65.
Fixed sentence splitting: {'figs.': 6, 'fig.': 61, 'et al.': 22, 'ref.': 0, 'eq.': 2, 'e.g.': 5, 'i.e.': 0, 'nos.': 0, 'no.': 0, 'spp.': 4}
Flattened 65 docs into 420 sentences. 
Some examples: [['S0012821X12004384-1302.tsv', 'Correspondence analysis (CA) and statistical diversity analysis were carried out on the palynological dataset (total counts per gram) to confirm assemblage designations (Figs. 0 and 0), to identify any disturbance to the core prior to interpretation, and to estimate diversity (Fig. 0).', [0, 286], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}], ['S0012821X12004384-1302.tsv', ' Dinoflagellate cyst assemblages (DA0–DA0) and pollen assemblages (PA0–PA0) were defined by visually comparing changes in the species dominance (Figs. 0 and 0), and confirmed by CA (Fig. 0) using the first three axes (describing the highest percentages of variance).', [286, 552],

In [5]:
class OurBERTModel(nn.Module):
    def __init__(self):
        super(OurBERTModel, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_type)
        self.drop = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, text, att_mask):
        b, num_tokens = text.shape
        token_type = torch.zeros((b, num_tokens), dtype=torch.long).to(device)
        outputs = self.bert(text, attention_mask=att_mask, token_type_ids=token_type)
        return self.classifier(self.drop(outputs['last_hidden_state']))

In [6]:
model = OurBERTModel()
print("Model created. Model has", sum(p.numel() for p in model.parameters()), "parameters.")
model = model.to(device)

Model created. Model has 124647170 parameters.


In [7]:
def train_model(bert_model, task: str='QUANT', n_epochs: int=6, lr: float=3e-05, train_dataset: list=train_dataset, 
                valid_dataset: list=valid_dataset, torch_seed: int=434, dummy_run: bool=False):
    
    """
    Build and train BERT model according to specifications provided.
    =====================================================
    Inputs:
    bert_model (OurBERTModel): model to train
    task (str): which MeasEval task to train on
    n_epochs (int): number of epochs to train the model
    lr (float): learning rate
    train_dataset (list): training dataset prepared with prep_dataset func
    valid_dataset (list): validation dataset prepared with prep_dataset func
    torch_seed (int): seed number for (some) reproducibility
    dummy_run (bool): flag to not run full evaluation and only test pipeline
    =====================================================
    Outputs:
    dataset (list): flattened/tokenized/batched sentences w/annotations
    """
    
    if torch.cuda.is_available():
        device = 'cuda'
        print("Detected", torch.cuda.device_count(), "GPUs; will train using CUDA.")
    else:
        device = 'cpu'
        print("No CUDA-enabled GPUs detected; will train using CPU.")
        
    torch.manual_seed(torch_seed)

    
    #### Train one epoch ####
    def train(model, batched_dataset, criterion):
        model.train() # Put the model to train

        # Lets keep track of the losses at each update
        train_losses = []
        num_batch = 0

        for batch in batched_dataset:
            # Unpack the batch
            (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch

            # Make predictions on the model
            preds = model(texts, att_masks)

            # Take into account padded while calculating loss
            loss_unreduced = criterion(preds.permute(0,2,1), labels)
            loss = (loss_unreduced * att_masks).sum() / (att_masks).sum()

            # Update model weights
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if num_batch % 10 == 0:
                print("Train loss at {}:".format(num_batch), loss.item())

            num_batch += 1
            # Append losses
            train_losses.append(loss.item())

        return np.average(train_losses)
    
    def evaluate(model, batched_dataset, criterion):
        target_names=["NOT_" + task, task]

        # Put the model to eval mode
        model.eval()

        # Keep track on predictions
        valid_losses = []
        predicts = []
        gnd_truths = []

        with torch.no_grad():
            for batch in batched_dataset:
                # Unpack the batch
                (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch
                # Make predictions on the model
                preds = model(texts, att_masks)

                # Take into account padded while calculating loss
                loss_unreduced = criterion(preds.permute(0,2,1), labels)
                loss = (loss_unreduced * att_masks).sum() / (att_masks).sum()

                # Get argmax of non-padded tokens
                for sent_preds, sent_labels, sent_att_masks in zip(preds, labels, att_masks):
                    for token_preds, token_labels, token_masks in zip(sent_preds, sent_labels, sent_att_masks):
                        if token_masks.item() != 0:
                            predicts.append(token_preds.argmax().item())
                            gnd_truths.append((token_labels.item()))
                valid_losses.append(loss.item())

                assert len(predicts) == len(gnd_truths)

        # Create confusion matrix and evaluate on the predictions
        confuse_mat = confusion_matrix(gnd_truths, predicts)
        if dummy_run:
            classify_report = None
        else:
            classify_report = classification_report(gnd_truths, predicts,
                                            target_names=target_names,
                                            output_dict=True)

        mean_valid_loss = np.average(valid_losses)
        print("Valid_loss", mean_valid_loss)
        print(confuse_mat)

        if not dummy_run:
            for labl in target_names:
                print(labl,"F1-score:", classify_report[labl]["f1-score"])
            print("Accu:", classify_report["accuracy"])
            print("F1-Weighted", classify_report["weighted avg"]["f1-score"])
            print("F1-Avg", classify_report["macro avg"]["f1-score"])

        return mean_valid_loss, confuse_mat ,classify_report
    
    ########## Optimizer & Loss ###########

    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    ########## Training loop ###########

    for epoch in range(n_epochs):
        print("\n\n========= Beginning epoch", epoch+1,"==========")

        train_loss = train(model, train_dataset, criterion)
        print("\n==== EVALUATING On Training set ====\n")
        _, _, _ = evaluate(model, train_dataset, criterion)

        print("\n==== EVALUATING On Validation set ====\n")
        valid_loss, confuse_mat, classify_report = evaluate(model, valid_dataset, criterion)

        epoch_len = len(str(n_epochs))
        print_msg = (f'[{epoch+1:>{epoch_len}}/{n_epochs:>{epoch_len}}]     ' +
                        f'train_loss: {train_loss:.5f} ' +
                        f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

In [8]:
train_model(model, task=task, n_epochs=6, lr=3e-05)

Detected 1 GPUs; will train using CUDA.


Train loss at 0: 0.7087559103965759

==== EVALUATING On Training set ====

Valid_loss 0.1631327599287033
[[44011     3]
 [ 3555    25]]
NOT_QUANT F1-score: 0.9611487224284777
QUANT F1-score: 0.013858093126385808
Accu: 0.925242677648443
F1-Weighted 0.8898939328982536
F1-Avg 0.4875034077774318

==== EVALUATING On Validation set ====

Valid_loss 0.15639495477080345
[[15781     1]
 [ 1204     0]]
NOT_QUANT F1-score: 0.9632251960814234
QUANT F1-score: 0.0
Accu: 0.9290592252443188
F1-Weighted 0.8949499614127532
F1-Avg 0.4816125980407117
[1/6]     train_loss: 0.33031 valid_loss: 0.15639


Train loss at 0: 0.19527728855609894

==== EVALUATING On Training set ====

Valid_loss 0.07517297379672527
[[43347   667]
 [  531  3049]]
NOT_QUANT F1-score: 0.986369635461703
QUANT F1-score: 0.8358004385964911
Accu: 0.974828759927722
F1-Weighted 0.9750438858971053
F1-Avg 0.911085037029097

==== EVALUATING On Validation set ====

Valid_loss 0.05786129832267761
[[1555

In [9]:
# Predict spans on Test, Train and Val set

def predict_spans(batched_dataset, save_path):
    # Put the model to eval mode and track the predictions
        
    model.eval()
    valid_losses = []
    correct_sentences = 0
    total_sentences = 0 

    with torch.no_grad():
        for batch in valid_dataset:
            # Unpack the batch and feed into the model
            (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch
            preds = model(texts, att_masks)

            # Check for mispredicts
            for sent_preds, sent_labels, sent_att_masks, sent_doc_id, sent_offset in zip(preds, labels, att_masks, doc_ids, offsets):
                this_correct = (sent_preds.argmax(1) == sent_labels).sum()
                this_total = len(sent_labels)

                if this_correct != this_total:
                    print("Mispredict:", sent_doc_id, "for sentence at offset", sent_offset)

    with torch.no_grad():
        span_dict = {}
        for batch in batched_dataset:
            (texts, raw_texts, att_masks, doc_ids, offsets, token_offsets) = batch
            preds = model(texts, att_masks)

            for sent_preds, sent_raw_text, sent_att_masks, sent_doc_id, sent_offset, sent_token_offsets in zip(preds, raw_texts, att_masks, doc_ids, offsets, token_offsets):

                this_sentence_positives = []
                curr_positive_idx = -1
                for i, (token_preds, token_labels, token_masks) in enumerate(zip(sent_preds, sent_labels, sent_att_masks)):
                    if token_masks.item() != 0:
                        if token_preds.argmax().item() == 1:
                            if curr_positive_idx == -1:
                                curr_positive_idx = i
                        else:
                            if curr_positive_idx != -1:
                                this_sentence_positives.append([curr_positive_idx, i-1])
                                curr_positive_idx = -1
                    else:
                        if curr_positive_idx != -1:
                            this_sentence_positives.append([curr_positive_idx, i-1])
                            curr_positive_idx = -1
                            break

                # Here convert indices to offsets
                if sent_doc_id not in span_dict.keys():
                    span_dict[sent_doc_id] = {}

                this_sent_spans = []
                for span_offsets in this_sentence_positives:
                    this_sent_spans.append([sent_token_offsets[span_offsets[0]][0],
                                            sent_token_offsets[span_offsets[1]][1]
                                        ])
                
                assert sent_offset[0] not in span_dict[sent_doc_id].keys()
                span_dict[sent_doc_id][sent_offset[0]] = this_sent_spans
    
    json.dump(span_dict, open(save_path, 'w+'))

In [10]:
# Save model and prediction

eval_doc_path = "../data/raw/eval/text/"

folder_name = "../outputs/span_predictions_" + bert_type.replace('/', '_') + '_' + datetime.today().strftime('%Y-%m-%d_%H_%M_%S')
if os.path.isdir(folder_name):
    os.system("rm -rf " + folder_name)
os.mkdir(folder_name)

test_dataset = prep_dataset(eval_doc_path, task='QUANT')

predict_spans(train_dataset, folder_name+"/train_spans.json")
predict_spans(valid_dataset, folder_name+"/trial_spans.json")
predict_spans(test_dataset, folder_name+"/test_spans.json")

# don't write model to repo haha, it's ~500 MB...
# torch.save(model.state_dict(), folder_name+"/model.pt")

print("="*40)
print("Predictions complete and written to file!")
print("="*40)

## Add code to write parameters used to file (lr, epochs, what else?)

Preparing **test** dataset, based on arguments provided.
If you intended to prepare training/validation data, provide a label_path to the function.
Fixed sentence splitting: {'figs.': 1, 'fig.': 59, 'et al.': 33, 'ref.': 5, 'eq.': 0, 'e.g.': 3, 'i.e.': 1, 'nos.': 0, 'no.': 0, 'spp.': 1}
Flattened 135 docs into 737 sentences. 
Some examples: [['S0012821X12004384-1610', 'The brief peak in Apectodinium, AOM and low salinity dinoflagellate cysts (Deflandrea) at 0000.0 m (Fig. 0) indicate a sporadic episode of surface water freshening/eutrophication before the CIE, which is best explained by an increase in regional precipitation due to its rapid nature.', [0, 284], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}], ['S0012821X12004384-1610', ' A coincident reduction in I. hiatus swamp conifers indicates possible disturbance of nearby coastal environments possibly from flooding (see Section 0.0).', [284, 439], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}]]
num_batches= 12  | num_data= 737
Data

Mispredict: S0012821X13002185-1231.tsv for sentence at offset [0, 379]
Mispredict: S0019103512004009-5507.tsv for sentence at offset [0, 312]
Mispredict: S0012821X12004384-1405.tsv for sentence at offset [601, 658]
Mispredict: S0019103512004009-4492.tsv for sentence at offset [471, 663]
Mispredict: S0019103512003533-4685.tsv for sentence at offset [473, 610]
Mispredict: S0019103512003995-2760.tsv for sentence at offset [201, 341]
Mispredict: S0019103512003995-1910.tsv for sentence at offset [163, 346]
Mispredict: S0012821X13002185-1217.tsv for sentence at offset [944, 1057]
Mispredict: S0019103512004009-3825.tsv for sentence at offset [334, 667]
Mispredict: S0012821X12004384-1302.tsv for sentence at offset [745, 864]
Mispredict: S0019103512004009-4492.tsv for sentence at offset [297, 471]
Mispredict: S0019103513005058-4210.tsv for sentence at offset [437, 647]
Mispredict: S0022000014000026-18167.tsv for sentence at offset [5663, 5830]
Mispredict: S0019103512003533-3348.tsv for sentence

Mispredict: S0019103512004009-3962.tsv for sentence at offset [415, 584]
Mispredict: S0016236113008041-967.tsv for sentence at offset [0, 71]
Mispredict: S0021979713004438-1969.tsv for sentence at offset [188, 307]
Mispredict: S0019103512003995-1910.tsv for sentence at offset [0, 163]
Mispredict: S0012821X13007309-1989.tsv for sentence at offset [918, 1133]
Mispredict: S0022459611006116-1195.tsv for sentence at offset [138, 237]
Mispredict: S0019103512003533-5031.tsv for sentence at offset [413, 575]
Mispredict: S0019103512003995-3548.tsv for sentence at offset [612, 894]
Predictions complete and written to file!
