In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer

try:
    from params import params
except:
    print("Params not loaded")
    class Params:
        def __init__(self):
            self.bert_type = "bert-base-cased"
            self.device = "cuda"
            self.batch_size = 16
            self.task = "QUANT"
    params = Params()

import random
import os
import json
import nltk
from nltk.tokenize import sent_tokenize
import re

typemap = {"Quantity": "QUANT",
           "MeasuredEntity": "ME", 
           "MeasuredProperty": "MP", 
           "Qualifier": "QUAL"
        }

  from .autonotebook import tqdm as notebook_tqdm


Params not loaded


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\matts\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [11]:
text_path = "./data/train/text/"
label_path = "./data/train/tsv/"

In [6]:
def get_doc_ids(text_path):
    # Get text and associated IDs from each text
    textset = {}
    for fn in os.listdir(text_path):
        with open(text_path+fn, encoding='utf-8', errors='replace') as textfile:
            text = textfile.read()
            textset[fn[:-4]] = text

    return textset

In [12]:
textset = get_doc_ids(text_path)

In [15]:
print(textset['S0012821X12004384-952'])

Carbon isotopic results of total organic matter (δ13CTOC) and amorphous organic matter (δ13CAOM) against core 22/10a-4 lithology and Apectodinium spp. (%). Blue=bulk rock δ13CTOC; black=δ13CAOM; solid red symbols=bulk rock δ13CTOC from samples with <30% wood/plant tissue (determined from palynological residue of the sample); open red symbols=bulk rock δ13CTOC from samples with >30% wood/plant tissue. The first appearance of Apectodinium augustum identifies the PETM in the North Sea (Bujak and Brinkhuis, 1998), and the first negative shift in δ13C identifies the approximate position of the CIE onset and the Paleocene–Eocene boundary. Values shaded at 2614.7 and 2619.6 m are considered possible outliers based on statistical analysis of the palynological residues (see Section 4.1). Lithologic column shows position of sand intervals (yellow), claystone intervals (brown; predominantly laminated claystone, dark brown), and ash layers (pink). (For interpretation of the references to color in 

In [34]:
# Load all annotations
files_with_label = [file_name[:-4] for file_name in os.listdir(label_path)]
all_files_with_or_without_label = [file_name[:-4] for file_name in os.listdir(text_path)]
files_without_label = list(set(all_files_with_or_without_label).difference(set(files_with_label)))
print(files_without_label[0])
print(len(files_with_label), len(files_without_label), len(all_files_with_or_without_label))

S0022000014000026-7850
233 15 248


In [18]:
def load_dataset(files_with_label, files_without_label):

    alldata = []

    # Load annotations from files with label
    for fn_no_ext in files_with_label:
        fn = fn_no_ext + ".tsv"
        entities = {"QUANT": [], "ME": [], "MP": [], "QUAL": []}
        with open(label_path+fn, encoding='utf-8', errors='replace') as annotfile:
            text = textset[fn[:-4]]
            next(annotfile)
            annots = annotfile.read().splitlines()
            for a in annots:
                annot = a.split("\t")
                atype = typemap[annot[2]]
                start = int(annot[3])
                stop = int(annot[4])
                # This is where we toss out the overlaps:
                overlap = False
                for ent in entities[atype]:
                    if ((start >= ent[0] and start <= ent[1]) or (stop >= ent[0] and stop <= ent[1]) or
                        (ent[0] >= start and ent[0] <= stop) or (ent[1] >= start and ent[1] <= stop)):
                        overlap = True
                if overlap == False:    
                    entities[atype].append((start, stop, atype))
            alldata.append((text,
                            {"QUANT": entities["QUANT"],
                             "ME": entities["ME"],
                             "MP": entities["MP"],
                             "QUAL": entities["QUAL"]
                            },
                            (fn,)
                        ))

    # Load annotations from files without label
    for fn in files_without_label:
        text = textset[fn]
        alldata.append((text,
                        {"QUANT": [],
                         "ME": [],
                         "MP": [],
                         "QUAL": []
                    },
                    (fn,)
                ))
    return alldata

In [27]:
all_data = load_dataset(files_with_label, files_without_label)
all_data[0]

('Data were drawn from the Whitehall II study with baseline examination in 1991; follow-up screenings in 1997, 2003, and 2008; and additional disease ascertainment from hospital data and registry linkage on 5318 participants (mean age 54.8 years, 31% women) without depressive symptoms at baseline. Vascular risk was assessed with the Framingham Cardiovascular, Coronary Heart Disease, and Stroke Risk Scores. New depressive symptoms at each follow-up screening were identified by General Health Questionnaire caseness, a Center for Epidemiologic Studies Depression Scale score ≥16, and use of antidepressant medication.',
 {'QUANT': [(73, 77, 'QUANT'),
   (103, 123, 'QUANT'),
   (205, 222, 'QUANT'),
   (233, 243, 'QUANT'),
   (245, 248, 'QUANT'),
   (576, 579, 'QUANT')],
  'ME': [(49, 69, 'ME'),
   (79, 99, 'ME'),
   (25, 43, 'ME'),
   (205, 222, 'ME'),
   (520, 569, 'ME')],
  'MP': [(224, 232, 'MP'), (249, 254, 'MP'), (570, 575, 'MP')],
  'QUAL': []},
 ('S0006322312001096-1136.tsv',))

In [24]:
def tokenize_split_data(all_data):
    # Sentence Tokenize the dataset
    processed_data = []

    cnt_toks = {"figs.": 0, "fig.": 0, "et al.": 0,
                "ref.": 0, "eq.": 0, "e.g.": 0,
                "i.e.": 0, "nos.": 0, "no.": 0,
                "spp.": 0
                }
    regex_end_checker = [".*[a-zA-Z]figs\.$", 
                        ".*[a-zA-Z]fig\.$",
                        ".*[a-zA-Z]et al\.$",
                        ".*[a-zA-Z]ref\.$",
                        ".*[a-zA-Z]eq\.$",
                        ".*[a-zA-Z]e\.g\.$",
                        ".*[a-zA-Z]i\.e\.$",
                        ".*[a-zA-Z]nos\.$",
                        ".*[a-zA-Z]no\.$",
                        ".*[a-zA-Z]spp\.$",
                        # figs., fig., et al., Ref., Eq., e.g., i.e., Nos., No., spp.
                    ]

    assert len(cnt_toks) == len(regex_end_checker)

    # list of sentences
    # for every tokenized sentence obtained
        # check if ends with "fig or Figs or et al."
        # keep track of the index where the this sentence starts and end,

    all_tokenized_data = []
    for doc in all_data:
        flag = False
        sentences = sent_tokenize(doc[0])

        fixed_sentence_tokens = []
        curr_len = 0
        for s in sentences:
            if flag == True:
                assert s[0] != ' '
                white_length = doc[0][curr_len:].find(s[0])

                prev_len = len(fixed_sentence_tokens[-1])
                fixed_sentence_tokens[-1] = fixed_sentence_tokens[-1] + (" "*white_length) + s

                assert fixed_sentence_tokens[-1][prev_len+white_length] == doc[0][curr_len+white_length], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                tmp_this_sent_len = white_length + len(s)
                assert fixed_sentence_tokens[-1][-1] == doc[0][curr_len+tmp_this_sent_len-1], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                curr_len += tmp_this_sent_len
            else:
                if len(fixed_sentence_tokens) != 0:
                    assert s[0] != ' '
                    white_length = doc[0][curr_len:].find(s[0])
                    fixed_sentence_tokens.append( (" "*white_length) + s )
                else:
                    fixed_sentence_tokens.append(s)
                assert fixed_sentence_tokens[-1][0] == doc[0][curr_len], (fixed_sentence_tokens, doc[0], curr_len, tmp_this_sent_len)
                tmp_this_sent_len = len(fixed_sentence_tokens[-1])
                assert fixed_sentence_tokens[-1][-1] == doc[0][curr_len+tmp_this_sent_len-1], (fixed_sentence_tokens[-1], doc[0], curr_len, tmp_this_sent_len)
                curr_len += tmp_this_sent_len

            lower_cased_s = fixed_sentence_tokens[-1].lower()
            flag = False
            for i, k in enumerate(cnt_toks):
                this_regex_pattern = regex_end_checker[i]
                if lower_cased_s.endswith(k) and re.match(this_regex_pattern, lower_cased_s) == None:
                    cnt_toks[k] += 1
                    flag = True
                    break

        all_tokenized_data.append(fixed_sentence_tokens)      
    print("Fixed sentence splitting:", cnt_toks)
    return all_tokenized_data

In [26]:
tokenized_data = tokenize_split_data(all_data)
tokenized_data[0]

Fixed sentence splitting: {'figs.': 4, 'fig.': 92, 'et al.': 31, 'ref.': 3, 'eq.': 3, 'e.g.': 8, 'i.e.': 3, 'nos.': 2, 'no.': 3, 'spp.': 2}


['Data were drawn from the Whitehall II study with baseline examination in 1991; follow-up screenings in 1997, 2003, and 2008; and additional disease ascertainment from hospital data and registry linkage on 5318 participants (mean age 54.8 years, 31% women) without depressive symptoms at baseline.',
 ' Vascular risk was assessed with the Framingham Cardiovascular, Coronary Heart Disease, and Stroke Risk Scores.',
 ' New depressive symptoms at each follow-up screening were identified by General Health Questionnaire caseness, a Center for Epidemiologic Studies Depression Scale score ≥16, and use of antidepressant medication.']

In [28]:
def annotation_map(all_tokenized_data, all_data):
    """
        Inputs:
            all_tokenized_data: list of [list of sentences]
            all_data: list of Tuple[doc, annotations, (doc_id,)]

        Outputs:
            all_annotated_split_data: 
                    [
                        Dict{'doc_id': doc_id,
                            'sentences': [list of string]
                            'offsets': [list of int]
                            'annotations': [list of Dict{"QUANT": [],
                                                        "ME": [],
                                                        "MP": [],
                                                        "QUAL": []
                                                    }
                                            ],

                            ... Repeat for the second doc
            ]
    """

    # in the next loop
        # Replace all numbers by zero.
        # check if any of the annotations are used for this falls in between the start and end. Make sure no overlap
        # add offset as well.

    normalize = lambda x: re.sub(r'\d', '0', x)
    all_annotated_split_data = []

    for doc, sent_splits in zip(all_data, all_tokenized_data):

        this_offsets = []

        prev_end = 0
        for s in sent_splits:
            this_offsets.append([prev_end, prev_end+len(s)])
            prev_end += len(s)

        this_annotations = []
        for s, offset in zip(sent_splits, this_offsets):
            this_sent_ann = {}
            for k,v in doc[1].items():
                this_key_annotation_sentence = []

                for ann in v:
                    if offset[0] <= ann[0] and ann[1] < offset[1]:
                        this_key_annotation_sentence.append((ann[0]-offset[0], ann[1]-offset[0]))

                this_sent_ann[k] = this_key_annotation_sentence

            this_annotations.append(this_sent_ann)

        all_annotated_split_data.append({'doc_id': doc[-1][0],
                    'sentences': [normalize(ss) for ss in sent_splits],
                    'offsets': this_offsets,
                    'annotations': this_annotations
                })
        # print(all_annotated_split_data[-1])
        assert len(all_annotated_split_data[-1]['offsets']) == len(all_annotated_split_data[-1]['sentences'])
        assert len(all_annotated_split_data[-1]['offsets']) == len(all_annotated_split_data[-1]['annotations'])

    return all_annotated_split_data

In [29]:
all_annotated_split_data = annotation_map(tokenized_data, all_data)

In [37]:
all_annotated_split_data[246]

{'doc_id': 'S0019103512004009-2930',
 'sentences': ['The differences between previous models arise from different assumptions regarding heating rates and boundary conditions.',
  ' In addition to modeling the density profiles of the detected heavy species, we have improved these aspects of the calculations in our work.',
  ' For instance, the lower boundary conditions are constrained by results from a detailed photochemical model of the lower atmosphere (Lavvas et al., in preparation).',
  ' With regard to the upper boundary conditions, we demonstrate that for HD000000b the extrapolated ‘outflow’ boundary conditions (e.g., Tian et al., 0000) are consistent with recent results from kinetic theory (Volkov et al., 0000a,b) as long as the upper boundary is at a sufficiently high altitude – although uncertainties regarding the interaction of the atmosphere with the stellar wind may limit the validity of both boundary conditions.',
  ' We highlight the effect of heating efficiency and stella

In [35]:
bert_tok = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)

In [40]:
def batch_dataset(sentence_wise_data, shuffle=False):
    """
        Inputs:
            all_annotated_split_data: 
                    [
                        Dict{'doc_id': doc_id,
                            'sentences': [list of string]
                            'offsets': [list of int]
                            'annotations': [list of Dict{"QUANT": [],
                                                        "ME": [],
                                                        "MP": [],
                                                        "QUAL": []
                                                    }
                                            ],
                            }
                            ... Similar dict for the second doc and so on.
            ]
        Outputs:
            Batched data: [doc_ixs, sent_offsets, sentences, labels, pad_masks]
    """
    # First flatten and shuffle
    # Then Batch
    flattened = [[doc['doc_id'], doc['sentences'][i], doc['offsets'][i], doc['annotations'][i]]
                 for doc in sentence_wise_data for i in range(len(doc['sentences']))]

    print(f'Flattened {len(sentence_wise_data)} docs into {len(flattened)} data points',
            "\nSome examples:", flattened[:2])
    if shuffle:
        random.shuffle(flattened)

    cls_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[CLS]'))[0]
    sep_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[SEP]'))[0]
    pad_token_idx = bert_tok.convert_tokens_to_ids(bert_tok.tokenize('[PAD]'))[0]


    dataset = []
    idx = 0
    num_data = len(flattened)
    while idx < num_data:
        batch_doc_ids = []
        batch_sent_offsets = []
        batch_raw_text = []
        batch_raw_labels = []

        for single_docid, single_sentence, single_offset, single_annotations in \
                    flattened[idx:min(idx + batch_size, num_data)]:

            batch_doc_ids.append(single_docid)
            batch_sent_offsets.append(single_offset)
            batch_raw_text.append(single_sentence)
            batch_raw_labels.append(single_annotations)


        batched_dict = bert_tok.batch_encode_plus(batch_raw_text,
                                                    return_offsets_mapping=True,
                                                    padding=True)


        batch_tokens = torch.LongTensor(batched_dict['input_ids']).to(device)
        batch_maxlen = batch_tokens.shape[-1]
        pad_masks = torch.LongTensor(batched_dict['attention_mask']).to(device)
        batch_offset_mapping = batched_dict['offset_mapping']

        # Create sequence labels using token offsets
        batch_labels = []
        for single_token_offs, single_anns in zip(batch_offset_mapping, batch_raw_labels):
            anns = single_anns[task]
            single_labels = []
            i = 0
            for off in single_token_offs:
                if off == (0,0):
                    single_labels.append(0)
                elif type(anns) == list and i < len(anns):
                    if off[1] < anns[i][0]:
                        single_labels.append(0)
                    elif off[0] > anns[i][1]:
                        i += 1
                        single_labels.append(0)
                    else:
                        single_labels.append(1)
                else:
                    single_labels.append(0)
            batch_labels.append(single_labels)

        batch_labels = torch.LongTensor(batch_labels).to(device)

        b = batch_size if (idx + batch_size) < num_data else (num_data - idx)
        assert batch_tokens.size() == torch.Size([b, batch_maxlen])
        assert batch_labels.size() == torch.Size([b, batch_maxlen])
        assert pad_masks.size() == torch.Size([b, batch_maxlen])

        dataset.append((batch_tokens, batch_labels, pad_masks, batch_doc_ids, batch_sent_offsets, batch_offset_mapping))
        idx += batch_size

    print("num_batches=", len(dataset), " | num_data=", num_data)
    return dataset

In [41]:
batch_size = 64
device = 'cuda'
task = 'QUANT'

In [42]:
batched_dataset = batch_dataset(all_annotated_split_data, shuffle=True if "train" in text_path else False)

Flattened 248 docs into 1265 data points 
Some examples: [['S0006322312001096-1136.tsv', 'Data were drawn from the Whitehall II study with baseline examination in 0000; follow-up screenings in 0000, 0000, and 0000; and additional disease ascertainment from hospital data and registry linkage on 0000 participants (mean age 00.0 years, 00% women) without depressive symptoms at baseline.', [0, 296], {'QUANT': [(73, 77), (103, 123), (205, 222), (233, 243), (245, 248)], 'ME': [(49, 69), (79, 99), (25, 43), (205, 222)], 'MP': [(224, 232), (249, 254)], 'QUAL': []}], ['S0006322312001096-1136.tsv', ' Vascular risk was assessed with the Framingham Cardiovascular, Coronary Heart Disease, and Stroke Risk Scores.', [296, 407], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}]]
num_batches= 20  | num_data= 1265


## End Prep; Begin Train

In [56]:
import time
import torch
from transformers import AutoModel, AutoTokenizer
import torch, torch.nn as nn
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import json
import os

In [43]:
torch_seed = 434

torch.manual_seed(torch_seed)

<torch._C.Generator at 0x22024642410>

In [66]:
#### Train one epoch ####
def train(model, batched_dataset, criterion):
    model.train() # Put the model to train

    # Lets keep track of the losses at each update
    train_losses = []
    num_batch = 0

    for batch in batched_dataset:
        # Unpack the batch
        (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch

        # Make predictions on the model
        preds = model(texts, att_masks)

        # Take into account padded while calculating loss
        loss_unreduced = criterion(preds.permute(0,2,1), labels)
        loss = (loss_unreduced * att_masks).sum() / (att_masks).sum()

        # Update model weights
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if num_batch % 10 == 0:
            print("Train loss at {}:".format(num_batch), loss.item())

        num_batch += 1
        # Append losses
        train_losses.append(loss.item())

    return np.average(train_losses)

In [69]:
def evaluate(model, batched_dataset, criterion):
    target_names=["NOT_" + params.task, params.task]

    # Put the model to eval mode
    model.eval()

    # Keep track on predictions
    valid_losses = []
    predicts = []
    gnd_truths = []

    with torch.no_grad():
        for batch in batched_dataset:
            # Unpack the batch
            (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch
            # Make predictions on the model
            preds = model(texts, att_masks)

            # Take into account padded while calculating loss
            loss_unreduced = criterion(preds.permute(0,2,1), labels)
            loss = (loss_unreduced * att_masks).sum() / (att_masks).sum()

            # Get argmax of non-padded tokens
            for sent_preds, sent_labels, sent_att_masks in zip(preds, labels, att_masks):
                for token_preds, token_labels, token_masks in zip(sent_preds, sent_labels, sent_att_masks):
                    if token_masks.item() != 0:
                        predicts.append(token_preds.argmax().item())
                        gnd_truths.append((token_labels.item()))
            valid_losses.append(loss.item())

            assert len(predicts) == len(gnd_truths)

    # Create confusion matrix and evaluate on the predictions
    confuse_mat = confusion_matrix(gnd_truths, predicts)
    if dummy_run:
        classify_report = None
    else:
        classify_report = classification_report(gnd_truths, predicts,
                                        target_names=target_names,
                                        output_dict=True)

    mean_valid_loss = np.average(valid_losses)
    print("Valid_loss", mean_valid_loss)
    print(confuse_mat)

    if not dummy_run:
        for labl in target_names:
            print(labl,"F1-score:", classify_report[labl]["f1-score"])
        print("Accu:", classify_report["accuracy"])
        print("F1-Weighted", classify_report["weighted avg"]["f1-score"])
        print("F1-Avg", classify_report["macro avg"]["f1-score"])

    return mean_valid_loss, confuse_mat ,classify_report

In [46]:
text_path = "./data/train/text/"
label_path = "./data/train/tsv/"

textset = get_doc_ids(text_path)

# Load all annotations
files_with_label = [file_name[:-4] for file_name in os.listdir(label_path)]
all_files_with_or_without_label = [file_name[:-4] for file_name in os.listdir(text_path)]
files_without_label = list(set(all_files_with_or_without_label).difference(set(files_with_label)))
print(files_without_label[0])
print(len(files_with_label), len(files_without_label), len(all_files_with_or_without_label))

all_data = load_dataset(files_with_label, files_without_label)

tokenized_data = tokenize_split_data(all_data)

all_annotated_split_data = annotation_map(tokenized_data, all_data)

#### RENAME THIS IF RUNNING SEQUENTIALLY!
train_dataset = batch_dataset(all_annotated_split_data, shuffle=True if "train" in text_path else False)

S0022000014000026-7850
233 15 248
Fixed sentence splitting: {'figs.': 4, 'fig.': 92, 'et al.': 31, 'ref.': 3, 'eq.': 3, 'e.g.': 8, 'i.e.': 3, 'nos.': 2, 'no.': 3, 'spp.': 2}
Flattened 248 docs into 1265 data points 
Some examples: [['S0006322312001096-1136.tsv', 'Data were drawn from the Whitehall II study with baseline examination in 0000; follow-up screenings in 0000, 0000, and 0000; and additional disease ascertainment from hospital data and registry linkage on 0000 participants (mean age 00.0 years, 00% women) without depressive symptoms at baseline.', [0, 296], {'QUANT': [(73, 77), (103, 123), (205, 222), (233, 243), (245, 248)], 'ME': [(49, 69), (79, 99), (25, 43), (205, 222)], 'MP': [(224, 232), (249, 254)], 'QUAL': []}], ['S0006322312001096-1136.tsv', ' Vascular risk was assessed with the Framingham Cardiovascular, Coronary Heart Disease, and Stroke Risk Scores.', [296, 407], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}]]
num_batches= 20  | num_data= 1265


In [49]:
text_path = "./data/trial/text/"
label_path = "./data/trial/tsv/"

textset = get_doc_ids(text_path)

# Load all annotations
files_with_label = [file_name[:-4] for file_name in os.listdir(label_path)]
all_files_with_or_without_label = [file_name[:-4] for file_name in os.listdir(text_path)]
files_without_label = list(set(all_files_with_or_without_label).difference(set(files_with_label)))
print(len(files_with_label), len(files_without_label), len(all_files_with_or_without_label))

all_data = load_dataset(files_with_label, files_without_label)

tokenized_data = tokenize_split_data(all_data)

all_annotated_split_data = annotation_map(tokenized_data, all_data)

#### RENAME THIS IF RUNNING SEQUENTIALLY!
valid_dataset = batch_dataset(all_annotated_split_data, shuffle=True if "train" in text_path else False)

65 0 65
Fixed sentence splitting: {'figs.': 6, 'fig.': 61, 'et al.': 22, 'ref.': 0, 'eq.': 2, 'e.g.': 5, 'i.e.': 0, 'nos.': 0, 'no.': 0, 'spp.': 4}
Flattened 65 docs into 420 data points 
Some examples: [['S0012821X12004384-1302.tsv', 'Correspondence analysis (CA) and statistical diversity analysis were carried out on the palynological dataset (total counts per gram) to confirm assemblage designations (Figs. 0 and 0), to identify any disturbance to the core prior to interpretation, and to estimate diversity (Fig. 0).', [0, 286], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}], ['S0012821X12004384-1302.tsv', ' Dinoflagellate cyst assemblages (DA0–DA0) and pollen assemblages (PA0–PA0) were defined by visually comparing changes in the species dominance (Figs. 0 and 0), and confirmed by CA (Fig. 0) using the first three axes (describing the highest percentages of variance).', [286, 552], {'QUANT': [], 'ME': [], 'MP': [], 'QUAL': []}]]
num_batches= 7  | num_data= 420


In [50]:
print("Dataset created")
os.system("nvidia-smi")

Dataset created


0

In [52]:
bert_type = "roberta-base"
device = 'cuda'

In [62]:
class OurBERTModel(nn.Module):
    def __init__(self):
        super(OurBERTModel, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_type)
        self.drop = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, text, att_mask):
        b, num_tokens = text.shape
        token_type = torch.zeros((b, num_tokens), dtype=torch.long).to(device)
        outputs = self.bert(text, attention_mask=att_mask, token_type_ids=token_type)
        return self.classifier(self.drop(outputs['last_hidden_state']))

model = OurBERTModel()
print("Model created")
os.system("nvidia-smi")

print(sum(p.numel() for p in model.parameters()),"parameters!")
model = model.to(params.device)
print("Detected", torch.cuda.device_count(), "GPUs!")

Model created
124647170 parameters!
Detected 1 GPUs!


In [73]:
n_epochs = 6
lr = 3e-05
dummy_run = False

In [74]:
########## Optimizer & Loss ###########

criterion = torch.nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

########## Training loop ###########

for epoch in range(n_epochs):
    print("\n\n========= Beginning epoch", epoch+1,"==========")

    train_loss = train(model, train_dataset, criterion)
    print("\n==== EVALUATING On Training set ====\n")
    _, _, _ = evaluate(model, train_dataset, criterion)

    print("\n==== EVALUATING On Validation set ====\n")
    valid_loss, confuse_mat, classify_report = evaluate(model, valid_dataset, criterion)

    epoch_len = len(str(n_epochs))
    print_msg = (f'[{epoch+1:>{epoch_len}}/{n_epochs:>{epoch_len}}]     ' +
                    f'train_loss: {train_loss:.5f} ' +
                    f'valid_loss: {valid_loss:.5f}')
    print(print_msg)



Train loss at 0: 0.006757479626685381
Train loss at 10: 0.007817627862095833

====EVALUATING On Training set :====

Valid_loss 0.002977533807279542
[[43973    41]
 [   42  3538]]
NOT_QUANT F1-score: 0.9990571289007033
QUANT F1-score: 0.9884062019835173
Accu: 0.9982560826995
F1-Weighted 0.9982559708059113
F1-Avg 0.9937316654421102

====EVALUATING On Validation set :====

Valid_loss 0.07035595870443753
[[15646   136]
 [   69  1135]]
NOT_QUANT F1-score: 0.993491443629552
QUANT F1-score: 0.9171717171717172
Accu: 0.9879312374896974
F1-Weighted 0.9880817562013621
F1-Avg 0.9553315804006346
[1/6]     train_loss: 0.00555 valid_loss: 0.07036


Train loss at 0: 0.0020206396002322435
Train loss at 10: 0.00983401108533144

====EVALUATING On Training set :====

Valid_loss 0.004273403610568494
[[43914   100]
 [    2  3578]]
NOT_QUANT F1-score: 0.9988399863527806
QUANT F1-score: 0.9859465417470378
Accu: 0.9978568727150481
F1-Weighted 0.9978701470518486
F1-Avg 0.9923932640499091

====EVALUATING On Va

In [76]:
model.eval()
valid_losses = []
correct_sentences = 0
total_sentences = 0 

with torch.no_grad():
    for batch in valid_dataset:
        # Unpack the batch and feed into the model
        (texts, labels, att_masks, doc_ids, offsets, token_offsets) = batch
        preds = model(texts, att_masks)

        # Check for mispredicts
        for sent_preds, sent_labels, sent_att_masks, sent_doc_id, sent_offset in zip(preds, labels, att_masks, doc_ids, offsets):
            this_correct = (sent_preds.argmax(1) == sent_labels).sum()
            this_total = len(sent_labels)

            if this_correct != this_total:
                print("Mispredict:", sent_doc_id, "for sentence at offset", sent_offset)

Mispredict: S0012821X12004384-1302.tsv for sentence at offset [552, 745]
Mispredict: S0012821X12004384-1302.tsv for sentence at offset [745, 864]
Mispredict: S0012821X12004384-1302.tsv for sentence at offset [1476, 1951]
Mispredict: S0012821X12004384-1302.tsv for sentence at offset [1951, 2142]
Mispredict: S0012821X12004384-1415.tsv for sentence at offset [0, 473]
Mispredict: S0012821X13002185-1217.tsv for sentence at offset [199, 367]
Mispredict: S0012821X13002185-1217.tsv for sentence at offset [367, 562]
Mispredict: S0012821X13002185-1217.tsv for sentence at offset [683, 883]
Mispredict: S0012821X13002185-1217.tsv for sentence at offset [944, 1057]
Mispredict: S0012821X13002185-1231.tsv for sentence at offset [0, 379]
Mispredict: S0012821X13002185-1231.tsv for sentence at offset [565, 663]
Mispredict: S0012821X13002185-1231.tsv for sentence at offset [1039, 1375]
Mispredict: S0012821X13002185-835.tsv for sentence at offset [129, 370]
Mispredict: S0012821X13007309-1605.tsv for senten

In [81]:
# Predict spans on Test, Train and Val set

def predict_spans(batched_dataset, save_path):
    # Put the model to eval mode and track the predictions
    model.eval()

    with torch.no_grad():
        span_dict = {}
        for batch in batched_dataset:
            (texts, raw_texts, att_masks, doc_ids, offsets, token_offsets) = batch
            preds = model(texts, att_masks)

            for sent_preds, sent_raw_text, sent_att_masks, sent_doc_id, sent_offset, sent_token_offsets in zip(preds, raw_texts, att_masks, doc_ids, offsets, token_offsets):

                this_sentence_positives = []
                curr_positive_idx = -1
                for i, (token_preds, token_labels, token_masks) in enumerate(zip(sent_preds, sent_labels, sent_att_masks)):
                    if token_masks.item() != 0:
                        if token_preds.argmax().item() == 1:
                            if curr_positive_idx == -1:
                                curr_positive_idx = i
                        else:
                            if curr_positive_idx != -1:
                                this_sentence_positives.append([curr_positive_idx, i-1])
                                curr_positive_idx = -1
                    else:
                        if curr_positive_idx != -1:
                            this_sentence_positives.append([curr_positive_idx, i-1])
                            curr_positive_idx = -1
                            break

                # Here convert indices to offsets
                if sent_doc_id not in span_dict.keys():
                    span_dict[sent_doc_id] = {}

                this_sent_spans = []
                for span_offsets in this_sentence_positives:
                    this_sent_spans.append([sent_token_offsets[span_offsets[0]][0],
                                            sent_token_offsets[span_offsets[1]][1]
                                        ])
                
                assert sent_offset[0] not in span_dict[sent_doc_id].keys()
                span_dict[sent_doc_id][sent_offset[0]] = this_sent_spans
    
    json.dump(span_dict, open(save_path, 'w+'))
    

# Save model and predicition
from datetime import datetime
folder_name = "./measeval/task1/output/" + bert_type.replace('/', '_') #datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
if os.path.isdir(folder_name):
    os.system("rm -rf " + folder_name)
os.mkdir(folder_name)

test_dataset = METestDataset(eval_doc_path)

predict_spans(train_dataset, folder_name+"/train_spans.json")
predict_spans(valid_dataset, folder_name+"/trial_spans.json")
predict_spans(test_dataset, folder_name+"/test_spans.json")

torch.save(model.state_dict(), folder_name+"/model.pt")
json.dump(vars(params), open(folder_name+"/params.json", 'w+'))

NameError: name 'METestDataset' is not defined