In [1]:
# %%
import sys
import os
import time
import importlib
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim

from torch.utils.data.distributed import DistributedSampler
from torch.utils import data

from tqdm import tqdm, trange
import collections

from transformers import BertModel
from transformers import BertPreTrainedModel, BertModel
import pickle
from transformers import AdamW
from transformers import BertTokenizer

In [2]:
print('Python version ', sys.version)
print('PyTorch version ', torch.__version__)

cuda_yes = torch.cuda.is_available()
# cuda_yes = False
print('Cuda is available?', cuda_yes)
device = torch.device("cuda:0" if cuda_yes else "cpu")
print('Device:', device)

# random.seed(44)
np.random.seed(44)
torch.manual_seed(44)
if cuda_yes:
    torch.cuda.manual_seed_all(44)

Python version  3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
PyTorch version  1.11.0+cu102
Cuda is available? True
Device: cuda:0


In [3]:
#unique_train_labels = {'I-Blinding_Object_Investigators', 'B-Design_Comparative_Intent_Equivalence', 'I-Randomization_Minimization_Criteria', 'B-Blinding_Quadruple_Blind', 'I-Design_Phase_2', 'B-Randomization_Stratified_Criteria', 'B-Settings_Location', 'B-Randomization_Type_Stratified', 'B-Design_Comparative_Intent_NonInferiority', 'B-Randomization_Personnel', 'B-Design_Crossover_Period_Treatment', 'B-Randomization_Type_Minimization', 'B-Settings_Multicenter', 'B-Randomization_Block_Size', 'I-Sample_Size_Target', 'B-Design_Factorial', 'B-Design_Phase_1', 'I-Randomization_Sequence_Generation_Method', 'I-Design_Phase_4', 'B-Blinding_Object_Investigators', 'B-Sample_Size_Target', 'B-Blinding_Open_Label', 'B-Design_Phase_3', 'I-Design_Phase_3', 'B-Blinding_Object_Care_Providers', 'I-Randomization_Type_Simple', 'I-Design_Phase_1', 'B-Randomization_Sequence_Generation_Method', 'I-Randomization_Type_Block', 'I-Randomization_Ratio', 'I-Sample_Size_Calculation_Power_Value', 'B-Design_Phase_2', 'I-Sample_Size_Actual_at_Enrollment', 'B-Blinding_Object_Others', 'B-Design_Comparative_Intent_Superiority', 'I-Blinding_Object_Others', 'I-Blinding_Object_Patients', 'B-Blinding_Object_Patients', 'B-Randomization_Ratio', 'B-Settings_Single_Center', 'I-Design_Parallel_Group', 'B-Sample_Size_Calculation_Power_Value', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'B-Allocation_Concealment_Method', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'B-Blinding_Double_Blind', 'B-Sample_Size_Calculation_Alpha_Value', 'B-Randomization_Minimization_Criteria', 'B-Blinding_Single_Blind', 'I-Blinding_Object_Outcome_Assessors', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'I-Sample_Size_Required', 'B-Design_Crossover', 'I-Settings_Single_Center', 'I-Blinding_Quadruple_Blind', 'I-Design_Crossover', 'B-Design_Parallel_Group', 'I-Randomization_Block_Size', 'O', 'I-Randomization_Personnel', 'I-Sample_Size_Calculation_Alpha_Value', 'I-Blinding_Object_Care_Providers', 'I-Blinding_Open_Label', 'I-Allocation_Concealment_Method', 'B-Design_Phase_4', 'I-Blinding_Double_Blind', 'B-Randomization_Type_Block', 'I-Randomization_Stratified_Criteria', 'I-Design_Crossover_Period_Treatment', 'B-Randomization_Type_Simple', 'B-Sample_Size_Required', 'B-Blinding_Object_Outcome_Assessors', 'I-Settings_Multicenter', 'I-Settings_Location', 'B-Sample_Size_Actual_at_Enrollment', 'I-Sample_Size_Calculation_Dropout_Rate_Value'}
unique_train_labels = {'I-Settings_Multicenter', 'B-Design_Cluster_Allocation', 'B-Randomization_Ratio', 'I-Design_Phase_2', 'I-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Randomization_Personnel', 'I-Design_Phase_4', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'B-Design_Phase_4', 'B-Design_Parallel_Group', 'I-Randomization_Type_Block', 'I-Settings_Single_Center', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'B-Design_Comparative_Intent_Superiority', 'I-Sample_Size_Target', 'B-Design_Comparative_Intent_Equivalence', 'I-Allocation_Concealment_Method', 'B-Randomization_Type_Block', 'I-Blinding_Open_Label', 'O', 'I-Blinding_Object_Investigators', 'I-Randomization_Personnel', 'I-Design_Phase_3', 'I-Design_Cluster_Allocation', 'I-Blinding_Single_Blind', 'B-Randomization_Sequence_Generation_Method', 'B-Allocation_Concealment_Method', 'I-Randomization_Block_Size', 'I-Design_Crossover_Period_Treatment', 'B-Sample_Size_Calculation_Power_Value', 'I-Settings_Location', 'B-Blinding_Object_Investigators', 'B-Blinding_Single_Blind', 'I-Randomization_Type_Simple', 'B-Blinding_Object_Outcome_Assessors', 'I-Design_Crossover', 'B-Randomization_Stratified_Criteria', 'I-Blinding_Object_Patients', 'B-Blinding_Object_Care_Providers', 'B-Design_Comparative_Intent_NonInferiority', 'B-Blinding_Open_Label', 'B-Randomization_Type_Stratified', 'I-Blinding_Object_Care_Providers', 'I-Sample_Size_Calculation_Alpha_Value', 'B-Blinding_Double_Blind', 'I-Design_Parallel_Group', 'B-Design_Phase_3', 'B-Design_Crossover', 'B-Settings_Location', 'B-Sample_Size_Required', 'B-Design_Phase_1', 'I-Blinding_Object_Others', 'B-Randomization_Type_Minimization', 'I-Blinding_Object_Outcome_Assessors', 'B-Blinding_Quadruple_Blind', 'B-Sample_Size_Actual_at_Enrollment', 'I-Randomization_Minimization_Criteria', 'I-Randomization_Sequence_Generation_Method', 'I-Blinding_Quadruple_Blind', 'B-Randomization_Minimization_Criteria', 'I-Randomization_Ratio', 'B-Design_Crossover_Period_Treatment', 'B-Sample_Size_Target', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Settings_Single_Center', 'I-Sample_Size_Calculation_Power_Value', 'I-Sample_Size_Required', 'I-Sample_Size_Actual_at_Enrollment', 'B-Randomization_Block_Size', 'B-Design_Phase_2', 'I-Randomization_Stratified_Criteria', 'B-Blinding_Object_Others', 'B-Blinding_Object_Patients', 'I-Design_Phase_1', 'B-Randomization_Type_Simple', 'B-Settings_Multicenter', 'B-Sample_Size_Calculation_Alpha_Value', 'I-Blinding_Double_Blind'}
unique_test_labels = {'I-Design_Comparative_Intent_NonInferiority', 'B-Blinding_Object_Others', 'I-Randomization_Sequence_Generation_Method', 'I-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Randomization_Sequence_Generation_Method', 'I-Randomization_Type_Block', 'I-Blinding_Object_Others', 'B-Sample_Size_Actual_at_Enrollment', 'B-Randomization_Type_Block', 'B-Blinding_Object_Outcome_Assessors', 'I-Allocation_Concealment_Method', 'B-Randomization_Personnel', 'I-Randomization_Type_Minimization', 'I-Sample_Size_Actual_at_Enrollment', 'B-Blinding_Double_Blind', 'I-Randomization_Ratio', 'B-Settings_Multicenter', 'B-Randomization_Type_Simple', 'I-Randomization_Type_Simple', 'B-Randomization_Block_Size', 'B-Design_Comparative_Intent_Superiority', 'B-Settings_Single_Center', 'I-Randomization_Block_Size', 'B-Design_Phase_2', 'B-Design_Factorial_Factor_Treatment', 'B-Design_Factorial', 'I-Sample_Size_Required', 'I-Design_Parallel_Group', 'B-Design_Parallel_Group', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'I-Sample_Size_Calculation_Power_Value', 'B-Randomization_Ratio', 'I-Blinding_Object_Outcome_Assessors', 'B-Allocation_Concealment_Method', 'O', 'B-Design_Comparative_Intent_NonInferiority', 'I-Blinding_Object_Care_Providers', 'B-Design_Crossover_Period_Treatment', 'B-Randomization_Type_Minimization', 'B-Blinding_Object_Patients', 'I-Blinding_Open_Label', 'I-Randomization_Minimization_Criteria', 'B-Blinding_Object_Care_Providers', 'I-Sample_Size_Calculation_Alpha_Value', 'B-Blinding_Object_Investigators', 'I-Randomization_Stratified_Criteria', 'B-Sample_Size_Required', 'B-Settings_Location', 'B-Sample_Size_Calculation_Power_Value', 'B-Design_Crossover', 'B-Randomization_Type_Stratified', 'I-Design_Factorial_Factor_Treatment', 'I-Settings_Single_Center', 'I-Blinding_Double_Blind', 'I-Design_Phase_2', 'I-Randomization_Personnel', 'I-Sample_Size_Target', 'B-Sample_Size_Calculation_Alpha_Value', 'I-Settings_Multicenter', 'I-Settings_Location', 'B-Blinding_Open_Label', 'B-Randomization_Minimization_Criteria', 'I-Blinding_Single_Blind', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Randomization_Stratified_Criteria', 'B-Sample_Size_Target', 'I-Sample_Size_Actual_at_Outcome_Analysis'}
all_unique_labels = unique_train_labels.union(unique_test_labels)
print (len(unique_train_labels))
print (len(unique_test_labels))
print (len(all_unique_labels))
print (all_unique_labels)


78
67
83
{'I-Randomization_Personnel', 'I-Design_Phase_2', 'B-Randomization_Type_Minimization', 'B-Blinding_Open_Label', 'B-Design_Parallel_Group', 'B-Randomization_Personnel', 'I-Sample_Size_Calculation_Power_Value', 'I-Blinding_Object_Care_Providers', 'B-Sample_Size_Required', 'B-Allocation_Concealment_Method', 'B-Randomization_Minimization_Criteria', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'B-Design_Phase_2', 'I-Design_Phase_4', 'I-Blinding_Object_Patients', 'I-Randomization_Sequence_Generation_Method', 'B-Blinding_Single_Blind', 'I-Settings_Multicenter', 'I-Randomization_Type_Minimization', 'B-Design_Comparative_Intent_Superiority', 'B-Design_Crossover_Period_Treatment', 'I-Sample_Size_Target', 'B-Blinding_Object_Patients', 'B-Randomization_Sequence_Generation_Method', 'B-Design_Comparative_Intent_Equivalence', 'I-Design_Crossover_Period_Treatment', 'I-Blinding_Object_Others', 'I-Design_Parallel_Group', 'I-Design_Phase_3', 'B-Design_P

In [4]:
# PREPARE DATA
data_dir1 = "/home/lhoang2/Notebooks/RCT_Methodology_Extraction/NER-BiLSTM-CRF/DATA_09122022/"
data_dir2 = "/home/lhoang2/Notebooks/RCT_Methodology_Extraction/DATA_07232022/"

class InputExample(object):
    """A single training/test example for NER."""

    def __init__(self, guid, words, labels):
        """Constructs a InputExample.
        Args:
          guid: Unique id for the example(a sentence or a pair of sentences).
          words: list of words of sentence
          labels_a/labels_b: (Optional) string. The label seqence of the text_a/text_b. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        # list of words of the sentence,example: [EU, rejects, German, call, to, boycott, British, lamb .]
        self.words = words
        # list of label sequence of the sentence,like: [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]
        self.labels = labels


class InputFeatures(object):
    """A single set of features of data.
    result of convert_examples_to_features(InputExample)
    """

    def __init__(self, input_ids, input_mask, segment_ids,  predict_mask, label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.predict_mask = predict_mask
        self.label_ids = label_ids


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data(cls, input_file):
        """
        Reads a BIO data.
        """
        with open(input_file) as f:
            # out_lines = []
            out_lists = []
            entries = f.read().strip().split("\n\n")
            for entry in entries:
                words = []
                ner_labels = []
                for line in entry.splitlines():
                    pieces = line.strip().split()
                    if len(pieces) < 1:
                        continue
                    word = pieces[0]
                    words.append(word)
                    ner_labels.append(pieces[-1])

                out_lists.append([words,ner_labels])
        return out_lists


class CoNLLDataProcessor(DataProcessor):

    def __init__(self):
#        self._label_types = ['X','[CLS]', '[SEP]','O','I-Randomization_Stratified_Criteria', 'I-Blinding_Double_Blind', 'I-Design_Factorial_Factor_Treatment', 'I-Design_Phase_3', 'I-Sample_Size_Required', 'B-Randomization_Type_Stratified', 'I-Sample_Size_Calculation_Power_Value', 'B-Blinding_Object_Others', 'I-Randomization_Block_Size', 'B-Randomization_Stratified_Criteria', 'I-Design_Comparative_Intent_NonInferiority', 'B-Design_Factorial', 'I-Blinding_Object_Investigators', 'B-Blinding_Object_Patients', 'B-Settings_Single_Center', 'B-Blinding_Object_Care_Providers', 'B-Randomization_Type_Simple', 'B-Design_Comparative_Intent_Superiority', 'I-Randomization_Sequence_Generation_Method', 'I-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Blinding_Object_Outcome_Assessors', 'B-Sample_Size_Required', 'I-Randomization_Type_Block', 'I-Randomization_Personnel', 'I-Blinding_Object_Outcome_Assessors', 'B-Sample_Size_Calculation_Power_Value', 'I-Sample_Size_Target', 'I-Randomization_Ratio', 'B-Blinding_Single_Blind', 'I-Settings_Multicenter', 'I-Settings_Single_Center', 'I-Randomization_Type_Minimization', 'B-Randomization_Type_Block', 'B-Design_Comparative_Intent_Equivalence', 'B-Randomization_Personnel', 'B-Settings_Location', 'I-Design_Crossover_Period_Treatment', 'B-Design_Phase_4', 'B-Design_Crossover', 'B-Blinding_Open_Label', 'O', 'I-Blinding_Object_Others', 'B-Randomization_Type_Minimization', 'B-Sample_Size_Actual_at_Enrollment', 'I-Design_Phase_2', 'B-Allocation_Concealment_Method', 'B-Design_Phase_2', 'I-Blinding_Object_Care_Providers', 'B-Randomization_Block_Size', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'B-Blinding_Double_Blind', 'I-Sample_Size_Calculation_Alpha_Value', 'I-Sample_Size_Actual_at_Enrollment', 'B-Design_Comparative_Intent_NonInferiority', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'I-Randomization_Minimization_Criteria', 'I-Randomization_Type_Simple', 'I-Allocation_Concealment_Method', 'I-Blinding_Open_Label', 'I-Design_Phase_4', 'B-Sample_Size_Calculation_Alpha_Value', 'B-Randomization_Sequence_Generation_Method', 'B-Design_Factorial_Factor_Treatment', 'B-Design_Parallel_Group', 'B-Blinding_Object_Investigators', 'B-Sample_Size_Target', 'I-Design_Parallel_Group', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'I-Settings_Location', 'B-Design_Crossover_Period_Treatment', 'B-Design_Phase_3', 'B-Settings_Multicenter', 'B-Randomization_Minimization_Criteria', 'B-Design_Phase_1', 'B-Randomization_Ratio', 'I-Blinding_Single_Blind', 'I-Blinding_Object_Patients', 'I-Design_Phase_1']
#        self._label_types = ['X','[CLS]', '[SEP]','O','B-Randomization_Block_Size', 'I-Blinding_Object_Outcome_Assessors', 'I-Sample_Size_Required', 'B-Settings_Location', 'I-Randomization_Personnel', 'I-Settings_Single_Center', 'B-Randomization_Personnel', 'I-Randomization_Block_Size', 'B-Design_Comparative_Intent_Equivalence', 'B-Design_Factorial_Factor_Treatment', 'I-Allocation_Concealment_Method', 'I-Randomization_Stratified_Criteria', 'B-Randomization_Type_Minimization', 'I-Sample_Size_Target', 'I-Design_Parallel_Group', 'I-Blinding_Double_Blind', 'B-Design_Phase_2', 'B-Blinding_Object_Outcome_Assessors', 'I-Design_Phase_2', 'B-Blinding_Object_Others', 'B-Sample_Size_Actual_at_Enrollment', 'B-Blinding_Quadruple_Blind', 'I-Settings_Location', 'I-Sample_Size_Calculation_Power_Value', 'I-Design_Crossover', 'B-Blinding_Object_Investigators', 'I-Design_Phase_1', 'B-Design_Crossover_Period_Treatment', 'B-Sample_Size_Target', 'B-Randomization_Type_Stratified', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Randomization_Ratio', 'B-Design_Comparative_Intent_Superiority', 'B-Allocation_Concealment_Method', 'I-Blinding_Object_Care_Providers', 'B-Randomization_Stratified_Criteria', 'B-Settings_Multicenter', 'B-Design_Factorial', 'I-Design_Comparative_Intent_NonInferiority', 'B-Design_Phase_1', 'I-Design_Phase_4', 'B-Design_Parallel_Group', 'I-Sample_Size_Actual_at_Enrollment', 'I-Blinding_Single_Blind', 'B-Settings_Single_Center', 'I-Sample_Size_Calculation_Alpha_Value', 'I-Blinding_Object_Others', 'B-Blinding_Open_Label', 'I-Randomization_Sequence_Generation_Method', 'B-Blinding_Double_Blind', 'I-Randomization_Minimization_Criteria', 'B-Design_Crossover', 'B-Design_Phase_3', 'I-Blinding_Open_Label', 'B-Design_Comparative_Intent_NonInferiority', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'B-Randomization_Sequence_Generation_Method', 'B-Randomization_Type_Block', 'I-Blinding_Quadruple_Blind', 'B-Randomization_Type_Simple', 'B-Sample_Size_Calculation_Power_Value', 'B-Blinding_Object_Patients', 'I-Blinding_Object_Investigators', 'B-Sample_Size_Required', 'I-Design_Phase_3', 'I-Randomization_Ratio', 'I-Randomization_Type_Minimization', 'I-Sample_Size_Calculation_Dropout_Rate_Value', 'I-Design_Crossover_Period_Treatment', 'I-Randomization_Type_Block', 'I-Settings_Multicenter', 'B-Design_Phase_4', 'I-Randomization_Type_Simple', 'B-Blinding_Object_Care_Providers', 'B-Randomization_Minimization_Criteria', 'I-Blinding_Object_Patients', 'B-Blinding_Single_Blind', 'I-Design_Factorial_Factor_Treatment', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'B-Sample_Size_Calculation_Alpha_Value']
        self._label_types = ['X','[CLS]', '[SEP]','O','B-Randomization_Type_Block', 'B-Design_Phase_3', 'B-Design_Comparative_Intent_Superiority', 'B-Blinding_Object_Patients', 'B-Sample_Size_Calculation_Dropout_Rate_Value', 'B-Blinding_Single_Blind', 'B-Randomization_Block_Size', 'B-Sample_Size_Actual_at_Enrollment', 'I-Randomization_Sequence_Generation_Method', 'B-Randomization_Type_Stratified', 'B-Settings_Location', 'I-Sample_Size_Actual_at_Enrollment', 'I-Design_Comparative_Intent_NonInferiority', 'B-Design_Factorial', 'I-Settings_Single_Center', 'I-Blinding_Object_Outcome_Assessors', 'I-Blinding_Quadruple_Blind', 'I-Allocation_Concealment_Method', 'B-Blinding_Open_Label', 'I-Randomization_Type_Simple', 'B-Design_Cluster_Allocation', 'I-Blinding_Double_Blind', 'I-Sample_Size_Calculation_Dropout_Rate_Value', 'I-Design_Phase_3', 'B-Design_Phase_4', 'I-Design_Crossover_Period_Treatment', 'B-Design_Crossover_Period_Treatment', 'B-Blinding_Object_Others', 'I-Design_Factorial_Factor_Treatment', 'B-Blinding_Object_Care_Providers', 'I-Sample_Size_Calculation_Alpha_Value', 'B-Randomization_Sequence_Generation_Method', 'B-Blinding_Object_Investigators', 'B-Randomization_Type_Simple', 'B-Blinding_Object_Outcome_Assessors', 'B-Settings_Multicenter', 'B-Design_Parallel_Group', 'I-Blinding_Object_Investigators', 'B-Design_Phase_2', 'I-Sample_Size_Calculation_Power_Value', 'I-Blinding_Object_Others', 'I-Sample_Size_Target', 'B-Randomization_Type_Minimization', 'I-Settings_Location', 'B-Design_Comparative_Intent_Equivalence', 'B-Randomization_Minimization_Criteria', 'I-Design_Phase_1', 'B-Design_Comparative_Intent_NonInferiority', 'I-Blinding_Single_Blind', 'B-Sample_Size_Target', 'B-Randomization_Personnel', 'I-Blinding_Object_Patients', 'I-Design_Parallel_Group', 'B-Sample_Size_Calculation_Alpha_Value', 'B-Settings_Single_Center', 'I-Design_Phase_2', 'B-Randomization_Ratio', 'B-Randomization_Stratified_Criteria', 'B-Design_Phase_1', 'B-Blinding_Double_Blind', 'I-Randomization_Block_Size', 'B-Allocation_Concealment_Method', 'I-Blinding_Object_Care_Providers', 'I-Design_Cluster_Allocation', 'B-Sample_Size_Required', 'I-Blinding_Open_Label', 'I-Randomization_Type_Block', 'I-Design_Crossover', 'B-Design_Crossover', 'I-Randomization_Ratio', 'I-Sample_Size_Required', 'B-Blinding_Quadruple_Blind', 'B-Sample_Size_Calculation_Power_Value', 'I-Settings_Multicenter', 'I-Randomization_Minimization_Criteria', 'I-Randomization_Personnel', 'I-Sample_Size_Actual_at_Outcome_Analysis', 'I-Design_Phase_4', 'B-Design_Factorial_Factor_Treatment', 'B-Sample_Size_Actual_at_Outcome_Analysis', 'I-Randomization_Stratified_Criteria', 'I-Randomization_Type_Minimization']
        self._num_labels = len(self._label_types)
        self._label_map = {label: i for i,
                           label in enumerate(self._label_types)}

    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_data(os.path.join(data_dir1, "TRAININGDATA_positive_sentences_only_CRF_09122022.txt")))

#     def get_dev_examples(self, data_dir):
#         return self._create_examples(
#             self._read_data(os.path.join(data_dir, "valid.txt")))

    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_data(os.path.join(data_dir2, "TESTINGDATA_07232022_CRF.txt")))

    def get_labels(self):
        return self._label_types

    def get_num_labels(self):
        return self.get_num_labels

    def get_label_map(self):
        return self._label_map

    def get_start_label_id(self):
        return self._label_map['[CLS]']

    def get_stop_label_id(self):
        return self._label_map['[SEP]']

    def _create_examples(self, all_lists):
        examples = []
        for (i, one_lists) in enumerate(all_lists):
            guid = i
            words = one_lists[0]
            labels = one_lists[-1]
            examples.append(InputExample(
                guid=guid, words=words, labels=labels))
        return examples

    def _create_examples2(self, lines):
        examples = []
        for (i, line) in enumerate(lines):
            guid = i
            text = line[0]
            ner_label = line[-1]
            examples.append(InputExample(
                guid=guid, text_a=text, labels_a=ner_label))
        return examples


def example2feature(example, tokenizer, label_map, max_seq_length):

    add_label = 'X'
    # tokenize_count = []
    tokens = ['[CLS]']
    predict_mask = [0]
    label_ids = [label_map['[CLS]']]
    for i, w in enumerate(example.words):
        # use bertTokenizer to split words
        # 1996-08-22 => 1996 - 08 - 22
        # sheepmeat => sheep ##me ##at
        sub_words = tokenizer.tokenize(w)
        if not sub_words:
            sub_words = ['[UNK]']
        # tokenize_count.append(len(sub_words))
        tokens.extend(sub_words)
        for j in range(len(sub_words)):
            if j == 0:
                predict_mask.append(1)
                label_ids.append(label_map[example.labels[i]])
            else:
                # '##xxx' -> 'X' (see bert paper)
                predict_mask.append(0)
                label_ids.append(label_map[add_label])

    # truncate
    if len(tokens) > max_seq_length - 1:
        print('Example No.{} is too long, length is {}, truncated to {}!'.format(example.guid, len(tokens), max_seq_length))
        tokens = tokens[0:(max_seq_length - 1)]
        predict_mask = predict_mask[0:(max_seq_length - 1)]
        label_ids = label_ids[0:(max_seq_length - 1)]
    tokens.append('[SEP]')
    predict_mask.append(0)
    label_ids.append(label_map['[SEP]'])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    input_mask = [1] * len(input_ids)

    feat=InputFeatures(
                # guid=example.guid,
                # tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                predict_mask=predict_mask,
                label_ids=label_ids)

    return feat

class NerDataset(data.Dataset):
    def __init__(self, examples, tokenizer, label_map, max_seq_length):
        self.examples=examples
        self.tokenizer=tokenizer
        self.label_map=label_map
        self.max_seq_length=max_seq_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        feat=example2feature(self.examples[idx], self.tokenizer, self.label_map, max_seq_length)
        return feat.input_ids, feat.input_mask, feat.segment_ids, feat.predict_mask, feat.label_ids

    @classmethod
    def pad(cls, batch):

        seqlen_list = [len(sample[0]) for sample in batch]
        maxlen = np.array(seqlen_list).max()

        f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: X for padding
        input_ids_list = torch.LongTensor(f(0, maxlen))
        input_mask_list = torch.LongTensor(f(1, maxlen))
        segment_ids_list = torch.LongTensor(f(2, maxlen))
        predict_mask_list = torch.ByteTensor(f(3, maxlen))
        label_ids_list = torch.LongTensor(f(4, maxlen))

        return input_ids_list, input_mask_list, segment_ids_list, predict_mask_list, label_ids_list


# Load pre-trained model tokenizer (vocabulary)
conllProcessor = CoNLLDataProcessor()
label_list = conllProcessor.get_labels()
label_map = conllProcessor.get_label_map()
train_examples = conllProcessor.get_train_examples(data_dir1)
test_examples = conllProcessor.get_test_examples(data_dir2)

print ("Training data: ", len(train_examples))
print ("Testing data: ", len(test_examples))
print (label_map)

Training data:  1143
Testing data:  1657
{'X': 0, '[CLS]': 1, '[SEP]': 2, 'O': 3, 'B-Randomization_Type_Block': 4, 'B-Design_Phase_3': 5, 'B-Design_Comparative_Intent_Superiority': 6, 'B-Blinding_Object_Patients': 7, 'B-Sample_Size_Calculation_Dropout_Rate_Value': 8, 'B-Blinding_Single_Blind': 9, 'B-Randomization_Block_Size': 10, 'B-Sample_Size_Actual_at_Enrollment': 11, 'I-Randomization_Sequence_Generation_Method': 12, 'B-Randomization_Type_Stratified': 13, 'B-Settings_Location': 14, 'I-Sample_Size_Actual_at_Enrollment': 15, 'I-Design_Comparative_Intent_NonInferiority': 16, 'B-Design_Factorial': 17, 'I-Settings_Single_Center': 18, 'I-Blinding_Object_Outcome_Assessors': 19, 'I-Blinding_Quadruple_Blind': 20, 'I-Allocation_Concealment_Method': 21, 'B-Blinding_Open_Label': 22, 'I-Randomization_Type_Simple': 23, 'B-Design_Cluster_Allocation': 24, 'I-Blinding_Double_Blind': 25, 'I-Sample_Size_Calculation_Dropout_Rate_Value': 26, 'I-Design_Phase_3': 27, 'B-Design_Phase_4': 28, 'I-Design_Cros

In [5]:
import pandas as pd
print (type(label_map))

labels_df = pd.DataFrame(label_map.items())

display(labels_df)
labels_df.to_csv("labels_map_withnewdata.csv")

<class 'dict'>


Unnamed: 0,0,1
0,X,0
1,[CLS],1
2,[SEP],2
3,O,3
4,B-Randomization_Type_Block,4
...,...,...
81,I-Design_Phase_4,81
82,B-Design_Factorial_Factor_Treatment,82
83,B-Sample_Size_Actual_at_Outcome_Analysis,83
84,I-Randomization_Stratified_Criteria,84


In [6]:
batch_size = 4
gradient_accumulation_steps = 1
total_train_epochs = 20


total_train_steps = int(len(train_examples) / batch_size / gradient_accumulation_steps * total_train_epochs)

print("***** Running training *****")
print("  Num examples = %d"% len(train_examples))
print("  Batch size = %d"% batch_size)
print("  Num steps = %d"% total_train_steps)
print("  Num train epochs = %d"% total_train_epochs)

***** Running training *****
  Num examples = 1143
  Batch size = 4
  Num steps = 5715
  Num train epochs = 20


In [7]:
max_seq_length = 256

tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

train_dataset = NerDataset(train_examples,tokenizer,label_map,max_seq_length)
test_dataset = NerDataset(test_examples,tokenizer,label_map,max_seq_length)

#input_ids_list, input_mask_list, segment_ids_list, predict_mask_list, label_ids_list
print (train_dataset[1])

train_dataloader = data.DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=4,
                                collate_fn=NerDataset.pad)
test_dataloader = data.DataLoader(dataset=test_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=NerDataset.pad)

([2, 3399, 16, 5650, 18, 28, 9, 1927, 3297, 5721, 1920, 7421, 6966, 21, 2935, 2256, 7507, 22, 16, 3002, 18, 23, 9, 5721, 1920, 26, 17, 2935, 12718, 7421, 16, 1930, 6550, 18, 20, 9, 5721, 1920, 2161, 18, 3], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0], [1, 3, 3, 3, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 3, 3, 3, 3, 0, 0, 3, 3, 3, 3, 83, 0, 0, 80, 3, 3, 3, 3, 2])


In [8]:
print('*** Use BertModel + CRF ***')

def log_sum_exp_1vec(vec):  # shape(1,m)
    max_score = vec[0, np.argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

def log_sum_exp_mat(log_M, axis=-1):  # shape(n,m)
    return torch.max(log_M, axis)[0]+torch.log(torch.exp(log_M-torch.max(log_M, axis)[0][:, None]).sum(axis))

def log_sum_exp_batch(log_Tensor, axis=-1): # shape (batch_size,n,m)
    return torch.max(log_Tensor, axis)[0]+torch.log(torch.exp(log_Tensor-torch.max(log_Tensor, axis)[0].view(log_Tensor.shape[0],-1,1)).sum(axis))


class BERT_CRF_NER(nn.Module):

    def __init__(self, bert_model, start_label_id, stop_label_id, num_labels, max_seq_length, batch_size, device, return_dict=False):
        super(BERT_CRF_NER, self).__init__()
        self.hidden_size = 768
        self.start_label_id = start_label_id
        self.stop_label_id = stop_label_id
        self.num_labels = num_labels
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.device=device

        # use pretrainded BertModel
        self.bert = bert_model
        
        self.dropout = torch.nn.Dropout(0.2)
        # Maps the output of the bert into label space.
        self.hidden2label = nn.Linear(self.hidden_size, self.num_labels)

        # Matrix of transition parameters.  Entry i,j is the score of transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.num_labels, self.num_labels))

        self.transitions.data[start_label_id, :] = -10000
        self.transitions.data[:, stop_label_id] = -10000

        nn.init.xavier_uniform_(self.hidden2label.weight)
        nn.init.constant_(self.hidden2label.bias, 0.0)

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _forward_alg(self, feats):
        '''
        this also called alpha-recursion or forward recursion, to calculate log_prob of all barX
        '''

        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # alpha_recursion,forward, alpha(zt)=p(zt,bar_x_1:t)
        log_alpha = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        log_alpha[:, 0, self.start_label_id] = 0

        for t in range(1, T):
            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)

        # log_prob of all barX
        log_prob_all_barX = log_sum_exp_batch(log_alpha)
        return log_prob_all_barX

    def _get_bert_features(self, input_ids, segment_ids, input_mask):
        '''
        sentances -> word embedding -> lstm -> MLP -> feats
        '''
        bert_seq_out, _ = self.bert(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, return_dict=False)
        bert_seq_out = self.dropout(bert_seq_out)
        bert_feats = self.hidden2label(bert_seq_out)
        return bert_feats

    def _score_sentence(self, feats, label_ids):
        '''
        Gives the score of a provided label sequence
        p(X=w1:t,Zt=tag1:t)=...p(Zt=tag_t|Zt-1=tag_t-1)p(xt|Zt=tag_t)...
        '''
        T = feats.shape[1]
        batch_size = feats.shape[0]

        batch_transitions = self.transitions.expand(batch_size,self.num_labels,self.num_labels)
        batch_transitions = batch_transitions.flatten(1)

        score = torch.zeros((feats.shape[0],1)).to(device)

        for t in range(1, T):
            score = score + \
                batch_transitions.gather(-1, (label_ids[:, t]*self.num_labels+label_ids[:, t-1]).view(-1,1)) \
                    + feats[:, t].gather(-1, label_ids[:, t].view(-1,1)).view(-1,1)
        return score

    def _viterbi_decode(self, feats):
        '''
        Max-Product Algorithm or viterbi algorithm, argmax(p(z_0:t|x_0:t))
        '''

        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # batch_transitions=self.transitions.expand(batch_size,self.num_labels,self.num_labels)

        log_delta = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        log_delta[:, 0, self.start_label_id] = 0

        # psi is for the vaule of the last latent that make P(this_latent) maximum.
        psi = torch.zeros((batch_size, T, self.num_labels), dtype=torch.long).to(self.device)  # psi[0]=0000 useless
        for t in range(1, T):
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # trace back
        path = torch.zeros((batch_size, T), dtype=torch.long).to(self.device)

        # max p(z1:t,all_x|theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)

        for t in range(T-2, -1, -1):
            # choose the state of z_t according the state choosed of z_t+1.
            path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()

        return max_logLL_allz_allx, path

    def neg_log_likelihood(self, input_ids, segment_ids, input_mask, label_ids):
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)
        forward_score = self._forward_alg(bert_feats)

        gold_score = self._score_sentence(bert_feats, label_ids)

        return torch.mean(forward_score - gold_score)

    def forward(self, input_ids, segment_ids, input_mask):
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)
        score, label_seq_ids = self._viterbi_decode(bert_feats)
        return score, label_seq_ids


start_label_id = conllProcessor.get_start_label_id()
stop_label_id = conllProcessor.get_stop_label_id()

bert_model = BertModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
model = BERT_CRF_NER(bert_model, start_label_id, stop_label_id, len(label_list), max_seq_length, batch_size, device,return_dict=False)

*** Use BertModel + CRF ***


Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
learning_rate0 = 3e-5
lr0_crf_fc = 0

model.to(device)

# Prepare optimizer
param_optimizer = list(model.named_parameters())

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
new_param = ['transitions', 'hidden2label.weight', 'hidden2label.bias']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) \
        and not any(nd in n for nd in new_param)], 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) \
        and not any(nd in n for nd in new_param)], 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if n in ('transitions','hidden2label.weight')] \
        , 'lr':lr0_crf_fc, 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if n == 'hidden2label.bias'] \
        , 'lr':lr0_crf_fc, 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate0)



In [10]:
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x

start_epoch = 0
# train procedure
global_step_th = int(len(train_examples) / batch_size / gradient_accumulation_steps * start_epoch)

print ("Number of epochs: ", total_train_epochs)

for epoch in range(start_epoch, total_train_epochs):
    print ("EPOCH: ", epoch)
    tr_loss = 0
    train_start = time.time()
    model.train()
    optimizer.zero_grad()
    # for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch

        neg_log_likelihood = model.neg_log_likelihood(input_ids, segment_ids, input_mask, label_ids)

        if gradient_accumulation_steps > 1:
            neg_log_likelihood = neg_log_likelihood / gradient_accumulation_steps

        neg_log_likelihood.backward()

        tr_loss += neg_log_likelihood.item()

        if (step + 1) % gradient_accumulation_steps == 0:
            # modify learning rate with special warm up BERT uses
            lr_this_step = learning_rate0 * warmup_linear(global_step_th/total_train_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            global_step_th += 1

        print("Epoch:{}-{}/{}, Negative loglikelihood: {} ".format(epoch, step, len(train_dataloader), neg_log_likelihood.item()))

    print('--------------------------------------------------------------')
    print("Epoch:{} completed, Total training's Loss: {}, Spend: {}m".format(epoch, tr_loss, (time.time() - train_start)/60.0))


Number of epochs:  20
EPOCH:  0
Epoch:0-0/286, Negative loglikelihood: 7653.009765625 
Epoch:0-1/286, Negative loglikelihood: 7803.0849609375 
Epoch:0-2/286, Negative loglikelihood: 7815.12255859375 
Epoch:0-3/286, Negative loglikelihood: 7945.5869140625 
Epoch:0-4/286, Negative loglikelihood: 7924.73046875 
Epoch:0-5/286, Negative loglikelihood: 7876.068359375 
Epoch:0-6/286, Negative loglikelihood: 7712.5380859375 
Epoch:0-7/286, Negative loglikelihood: 7702.58984375 
Epoch:0-8/286, Negative loglikelihood: 7748.48388671875 
Epoch:0-9/286, Negative loglikelihood: 7698.2734375 
Epoch:0-10/286, Negative loglikelihood: 7735.5009765625 
Epoch:0-11/286, Negative loglikelihood: 7802.08984375 
Epoch:0-12/286, Negative loglikelihood: 7906.55615234375 
Epoch:0-13/286, Negative loglikelihood: 8012.767578125 
Epoch:0-14/286, Negative loglikelihood: 7590.8251953125 
Epoch:0-15/286, Negative loglikelihood: 7587.7646484375 
Epoch:0-16/286, Negative loglikelihood: 7684.1044921875 
Epoch:0-17/286, Ne

Epoch:0-144/286, Negative loglikelihood: 7526.2099609375 
Epoch:0-145/286, Negative loglikelihood: 7509.22998046875 
Epoch:0-146/286, Negative loglikelihood: 7512.85986328125 
Epoch:0-147/286, Negative loglikelihood: 7505.33447265625 
Epoch:0-148/286, Negative loglikelihood: 7528.060546875 
Epoch:0-149/286, Negative loglikelihood: 7535.84228515625 
Epoch:0-150/286, Negative loglikelihood: 7514.568359375 
Epoch:0-151/286, Negative loglikelihood: 7509.712890625 
Epoch:0-152/286, Negative loglikelihood: 7529.66748046875 
Epoch:0-153/286, Negative loglikelihood: 7565.24755859375 
Epoch:0-154/286, Negative loglikelihood: 7545.16748046875 
Epoch:0-155/286, Negative loglikelihood: 7517.77001953125 
Epoch:0-156/286, Negative loglikelihood: 7509.66943359375 
Epoch:0-157/286, Negative loglikelihood: 7529.35498046875 
Epoch:0-158/286, Negative loglikelihood: 7552.57421875 
Epoch:0-159/286, Negative loglikelihood: 7535.68408203125 
Epoch:0-160/286, Negative loglikelihood: 7513.1171875 
Epoch:0-161

Epoch:1-0/286, Negative loglikelihood: 7516.93603515625 
Epoch:1-1/286, Negative loglikelihood: 7507.53271484375 
Epoch:1-2/286, Negative loglikelihood: 7498.748046875 
Epoch:1-3/286, Negative loglikelihood: 7501.9326171875 
Epoch:1-4/286, Negative loglikelihood: 7501.751953125 
Epoch:1-5/286, Negative loglikelihood: 7529.13623046875 
Epoch:1-6/286, Negative loglikelihood: 7492.4296875 
Epoch:1-7/286, Negative loglikelihood: 7520.99755859375 
Epoch:1-8/286, Negative loglikelihood: 7502.74951171875 
Epoch:1-9/286, Negative loglikelihood: 7504.19775390625 
Epoch:1-10/286, Negative loglikelihood: 7507.89794921875 
Epoch:1-11/286, Negative loglikelihood: 7496.9326171875 
Epoch:1-12/286, Negative loglikelihood: 7501.708984375 
Epoch:1-13/286, Negative loglikelihood: 7504.24365234375 
Epoch:1-14/286, Negative loglikelihood: 7523.70166015625 
Epoch:1-15/286, Negative loglikelihood: 7511.0185546875 
Epoch:1-16/286, Negative loglikelihood: 7493.505859375 
Epoch:1-17/286, Negative loglikelihood:

Epoch:1-145/286, Negative loglikelihood: 7487.921875 
Epoch:1-146/286, Negative loglikelihood: 7495.3076171875 
Epoch:1-147/286, Negative loglikelihood: 7505.0390625 
Epoch:1-148/286, Negative loglikelihood: 7493.013671875 
Epoch:1-149/286, Negative loglikelihood: 7498.662109375 
Epoch:1-150/286, Negative loglikelihood: 7489.6474609375 
Epoch:1-151/286, Negative loglikelihood: 7494.40234375 
Epoch:1-152/286, Negative loglikelihood: 7495.3544921875 
Epoch:1-153/286, Negative loglikelihood: 7491.802734375 
Epoch:1-154/286, Negative loglikelihood: 7496.7373046875 
Epoch:1-155/286, Negative loglikelihood: 7501.07421875 
Epoch:1-156/286, Negative loglikelihood: 7502.07958984375 
Epoch:1-157/286, Negative loglikelihood: 7509.1611328125 
Epoch:1-158/286, Negative loglikelihood: 7509.3349609375 
Epoch:1-159/286, Negative loglikelihood: 7502.46337890625 
Epoch:1-160/286, Negative loglikelihood: 7506.1171875 
Epoch:1-161/286, Negative loglikelihood: 7506.08984375 
Epoch:1-162/286, Negative logli

Epoch:2-0/286, Negative loglikelihood: 7499.494140625 
Epoch:2-1/286, Negative loglikelihood: 7485.2265625 
Epoch:2-2/286, Negative loglikelihood: 7495.3798828125 
Epoch:2-3/286, Negative loglikelihood: 7495.859375 
Epoch:2-4/286, Negative loglikelihood: 7488.4599609375 
Epoch:2-5/286, Negative loglikelihood: 7486.4423828125 
Epoch:2-6/286, Negative loglikelihood: 7495.9072265625 
Epoch:2-7/286, Negative loglikelihood: 7498.86328125 
Epoch:2-8/286, Negative loglikelihood: 7487.9814453125 
Epoch:2-9/286, Negative loglikelihood: 7493.61767578125 
Epoch:2-10/286, Negative loglikelihood: 7507.353515625 
Epoch:2-11/286, Negative loglikelihood: 7487.19482421875 
Epoch:2-12/286, Negative loglikelihood: 7489.341796875 
Epoch:2-13/286, Negative loglikelihood: 7491.14453125 
Epoch:2-14/286, Negative loglikelihood: 4999.8134765625 
Epoch:2-15/286, Negative loglikelihood: 7491.43310546875 
Epoch:2-16/286, Negative loglikelihood: 7491.146484375 
Epoch:2-17/286, Negative loglikelihood: 7489.71386718

Epoch:2-146/286, Negative loglikelihood: 7497.71875 
Epoch:2-147/286, Negative loglikelihood: 7486.1708984375 
Epoch:2-148/286, Negative loglikelihood: 7487.51953125 
Epoch:2-149/286, Negative loglikelihood: 7493.681640625 
Epoch:2-150/286, Negative loglikelihood: 7487.263671875 
Epoch:2-151/286, Negative loglikelihood: 7484.52392578125 
Epoch:2-152/286, Negative loglikelihood: 7486.64453125 
Epoch:2-153/286, Negative loglikelihood: 7492.962890625 
Epoch:2-154/286, Negative loglikelihood: 7496.1962890625 
Epoch:2-155/286, Negative loglikelihood: 7491.60791015625 
Epoch:2-156/286, Negative loglikelihood: 7502.98046875 
Epoch:2-157/286, Negative loglikelihood: 7494.01708984375 
Epoch:2-158/286, Negative loglikelihood: 7496.0380859375 
Epoch:2-159/286, Negative loglikelihood: 7516.6396484375 
Epoch:2-160/286, Negative loglikelihood: 7500.3173828125 
Epoch:2-161/286, Negative loglikelihood: 7491.73828125 
Epoch:2-162/286, Negative loglikelihood: 7501.904296875 
Epoch:2-163/286, Negative lo

Epoch:3-1/286, Negative loglikelihood: 7491.5986328125 
Epoch:3-2/286, Negative loglikelihood: 7485.55078125 
Epoch:3-3/286, Negative loglikelihood: 7499.9287109375 
Epoch:3-4/286, Negative loglikelihood: 7496.4326171875 
Epoch:3-5/286, Negative loglikelihood: 7483.8681640625 
Epoch:3-6/286, Negative loglikelihood: 7492.0712890625 
Epoch:3-7/286, Negative loglikelihood: 7489.8994140625 
Epoch:3-8/286, Negative loglikelihood: 7493.2646484375 
Epoch:3-9/286, Negative loglikelihood: 7487.64013671875 
Epoch:3-10/286, Negative loglikelihood: 7488.52099609375 
Epoch:3-11/286, Negative loglikelihood: 7489.1669921875 
Epoch:3-12/286, Negative loglikelihood: 7486.01708984375 
Epoch:3-13/286, Negative loglikelihood: 7496.74658203125 
Epoch:3-14/286, Negative loglikelihood: 7488.6953125 
Epoch:3-15/286, Negative loglikelihood: 7488.791015625 
Epoch:3-16/286, Negative loglikelihood: 7483.61669921875 
Epoch:3-17/286, Negative loglikelihood: 7487.28955078125 
Epoch:3-18/286, Negative loglikelihood: 

Epoch:3-145/286, Negative loglikelihood: 7482.314453125 
Epoch:3-146/286, Negative loglikelihood: 7488.9462890625 
Epoch:3-147/286, Negative loglikelihood: 7494.9912109375 
Epoch:3-148/286, Negative loglikelihood: 7491.62744140625 
Epoch:3-149/286, Negative loglikelihood: 7486.5224609375 
Epoch:3-150/286, Negative loglikelihood: 7491.6201171875 
Epoch:3-151/286, Negative loglikelihood: 7487.0107421875 
Epoch:3-152/286, Negative loglikelihood: 7494.5380859375 
Epoch:3-153/286, Negative loglikelihood: 7489.3623046875 
Epoch:3-154/286, Negative loglikelihood: 7490.16796875 
Epoch:3-155/286, Negative loglikelihood: 7506.4140625 
Epoch:3-156/286, Negative loglikelihood: 7500.01953125 
Epoch:3-157/286, Negative loglikelihood: 7489.51708984375 
Epoch:3-158/286, Negative loglikelihood: 7482.662109375 
Epoch:3-159/286, Negative loglikelihood: 7489.3681640625 
Epoch:3-160/286, Negative loglikelihood: 7491.037109375 
Epoch:3-161/286, Negative loglikelihood: 7483.56298828125 
Epoch:3-162/286, Nega

Epoch:4-0/286, Negative loglikelihood: 7487.41162109375 
Epoch:4-1/286, Negative loglikelihood: 7481.83740234375 
Epoch:4-2/286, Negative loglikelihood: 7484.78369140625 
Epoch:4-3/286, Negative loglikelihood: 4999.18115234375 
Epoch:4-4/286, Negative loglikelihood: 7486.181640625 
Epoch:4-5/286, Negative loglikelihood: 7481.3720703125 
Epoch:4-6/286, Negative loglikelihood: 7481.068359375 
Epoch:4-7/286, Negative loglikelihood: 7493.740234375 
Epoch:4-8/286, Negative loglikelihood: 7483.5341796875 
Epoch:4-9/286, Negative loglikelihood: 7489.40087890625 
Epoch:4-10/286, Negative loglikelihood: 7483.61572265625 
Epoch:4-11/286, Negative loglikelihood: 7490.685546875 
Epoch:4-12/286, Negative loglikelihood: 7485.328125 
Epoch:4-13/286, Negative loglikelihood: 7485.765625 
Epoch:4-14/286, Negative loglikelihood: 7485.60693359375 
Epoch:4-15/286, Negative loglikelihood: 7488.07666015625 
Epoch:4-16/286, Negative loglikelihood: 7484.40771484375 
Epoch:4-17/286, Negative loglikelihood: 7487

Epoch:4-146/286, Negative loglikelihood: 7479.296875 
Epoch:4-147/286, Negative loglikelihood: 7484.412109375 
Epoch:4-148/286, Negative loglikelihood: 7481.359375 
Epoch:4-149/286, Negative loglikelihood: 7480.24267578125 
Epoch:4-150/286, Negative loglikelihood: 7485.35009765625 
Epoch:4-151/286, Negative loglikelihood: 7479.166015625 
Epoch:4-152/286, Negative loglikelihood: 7479.0478515625 
Epoch:4-153/286, Negative loglikelihood: 7480.65283203125 
Epoch:4-154/286, Negative loglikelihood: 7478.89404296875 
Epoch:4-155/286, Negative loglikelihood: 7481.427734375 
Epoch:4-156/286, Negative loglikelihood: 7481.63525390625 
Epoch:4-157/286, Negative loglikelihood: 7488.494140625 
Epoch:4-158/286, Negative loglikelihood: 7484.666015625 
Epoch:4-159/286, Negative loglikelihood: 7482.330078125 
Epoch:4-160/286, Negative loglikelihood: 7484.34765625 
Epoch:4-161/286, Negative loglikelihood: 7483.619140625 
Epoch:4-162/286, Negative loglikelihood: 7489.8232421875 
Epoch:4-163/286, Negative 

Epoch:5-0/286, Negative loglikelihood: 7507.7333984375 
Epoch:5-1/286, Negative loglikelihood: 7510.3232421875 
Epoch:5-2/286, Negative loglikelihood: 7504.697265625 
Epoch:5-3/286, Negative loglikelihood: 7507.2666015625 
Epoch:5-4/286, Negative loglikelihood: 7512.0244140625 
Epoch:5-5/286, Negative loglikelihood: 7508.97265625 
Epoch:5-6/286, Negative loglikelihood: 7512.88525390625 
Epoch:5-7/286, Negative loglikelihood: 7501.583984375 
Epoch:5-8/286, Negative loglikelihood: 7499.625 
Epoch:5-9/286, Negative loglikelihood: 7502.0966796875 
Epoch:5-10/286, Negative loglikelihood: 7498.2607421875 
Epoch:5-11/286, Negative loglikelihood: 7508.33056640625 
Epoch:5-12/286, Negative loglikelihood: 7516.86669921875 
Epoch:5-13/286, Negative loglikelihood: 7498.8056640625 
Epoch:5-14/286, Negative loglikelihood: 7494.5830078125 
Epoch:5-15/286, Negative loglikelihood: 7513.07275390625 
Epoch:5-16/286, Negative loglikelihood: 7499.4169921875 
Epoch:5-17/286, Negative loglikelihood: 7496.265

Epoch:5-145/286, Negative loglikelihood: 7482.0908203125 
Epoch:5-146/286, Negative loglikelihood: 7482.38671875 
Epoch:5-147/286, Negative loglikelihood: 7487.4228515625 
Epoch:5-148/286, Negative loglikelihood: 7500.54296875 
Epoch:5-149/286, Negative loglikelihood: 7480.1064453125 
Epoch:5-150/286, Negative loglikelihood: 7479.236328125 
Epoch:5-151/286, Negative loglikelihood: 7507.662109375 
Epoch:5-152/286, Negative loglikelihood: 7484.0673828125 
Epoch:5-153/286, Negative loglikelihood: 7488.20849609375 
Epoch:5-154/286, Negative loglikelihood: 7486.248046875 
Epoch:5-155/286, Negative loglikelihood: 7481.171875 
Epoch:5-156/286, Negative loglikelihood: 7482.703125 
Epoch:5-157/286, Negative loglikelihood: 7481.857421875 
Epoch:5-158/286, Negative loglikelihood: 7484.76708984375 
Epoch:5-159/286, Negative loglikelihood: 7488.8349609375 
Epoch:5-160/286, Negative loglikelihood: 7480.49462890625 
Epoch:5-161/286, Negative loglikelihood: 7486.7373046875 
Epoch:5-162/286, Negative l

Epoch:6-0/286, Negative loglikelihood: 7479.9365234375 
Epoch:6-1/286, Negative loglikelihood: 7476.1796875 
Epoch:6-2/286, Negative loglikelihood: 7482.97802734375 
Epoch:6-3/286, Negative loglikelihood: 7482.78466796875 
Epoch:6-4/286, Negative loglikelihood: 7476.8154296875 
Epoch:6-5/286, Negative loglikelihood: 7478.9482421875 
Epoch:6-6/286, Negative loglikelihood: 7475.9111328125 
Epoch:6-7/286, Negative loglikelihood: 7479.8232421875 
Epoch:6-8/286, Negative loglikelihood: 7477.65478515625 
Epoch:6-9/286, Negative loglikelihood: 7479.0 
Epoch:6-10/286, Negative loglikelihood: 7481.53125 
Epoch:6-11/286, Negative loglikelihood: 7480.9296875 
Epoch:6-12/286, Negative loglikelihood: 7478.06298828125 
Epoch:6-13/286, Negative loglikelihood: 7478.41162109375 
Epoch:6-14/286, Negative loglikelihood: 7476.4111328125 
Epoch:6-15/286, Negative loglikelihood: 7477.3984375 
Epoch:6-16/286, Negative loglikelihood: 7477.8369140625 
Epoch:6-17/286, Negative loglikelihood: 7478.4521484375 
Ep

Epoch:6-145/286, Negative loglikelihood: 7474.06298828125 
Epoch:6-146/286, Negative loglikelihood: 7473.49169921875 
Epoch:6-147/286, Negative loglikelihood: 7478.14990234375 
Epoch:6-148/286, Negative loglikelihood: 7477.4169921875 
Epoch:6-149/286, Negative loglikelihood: 7475.61181640625 
Epoch:6-150/286, Negative loglikelihood: 7474.7763671875 
Epoch:6-151/286, Negative loglikelihood: 7473.4365234375 
Epoch:6-152/286, Negative loglikelihood: 7475.55126953125 
Epoch:6-153/286, Negative loglikelihood: 7483.107421875 
Epoch:6-154/286, Negative loglikelihood: 7475.0966796875 
Epoch:6-155/286, Negative loglikelihood: 7472.65283203125 
Epoch:6-156/286, Negative loglikelihood: 7476.388671875 
Epoch:6-157/286, Negative loglikelihood: 7475.0810546875 
Epoch:6-158/286, Negative loglikelihood: 7473.62353515625 
Epoch:6-159/286, Negative loglikelihood: 7480.9521484375 
Epoch:6-160/286, Negative loglikelihood: 7475.78759765625 
Epoch:6-161/286, Negative loglikelihood: 7477.7109375 
Epoch:6-162

Epoch:7-0/286, Negative loglikelihood: 7474.9765625 
Epoch:7-1/286, Negative loglikelihood: 7476.1591796875 
Epoch:7-2/286, Negative loglikelihood: 7474.33056640625 
Epoch:7-3/286, Negative loglikelihood: 7472.2568359375 
Epoch:7-4/286, Negative loglikelihood: 7474.68359375 
Epoch:7-5/286, Negative loglikelihood: 7475.9609375 
Epoch:7-6/286, Negative loglikelihood: 7474.0244140625 
Epoch:7-7/286, Negative loglikelihood: 7475.33203125 
Epoch:7-8/286, Negative loglikelihood: 7472.1279296875 
Epoch:7-9/286, Negative loglikelihood: 7481.755859375 
Epoch:7-10/286, Negative loglikelihood: 7474.42919921875 
Epoch:7-11/286, Negative loglikelihood: 4982.98681640625 
Epoch:7-12/286, Negative loglikelihood: 7471.75048828125 
Epoch:7-13/286, Negative loglikelihood: 7479.734375 
Epoch:7-14/286, Negative loglikelihood: 7474.6357421875 
Epoch:7-15/286, Negative loglikelihood: 7473.9345703125 
Epoch:7-16/286, Negative loglikelihood: 7475.2490234375 
Epoch:7-17/286, Negative loglikelihood: 7472.8037109

Epoch:7-145/286, Negative loglikelihood: 7474.1494140625 
Epoch:7-146/286, Negative loglikelihood: 7476.20751953125 
Epoch:7-147/286, Negative loglikelihood: 7470.73974609375 
Epoch:7-148/286, Negative loglikelihood: 7479.33056640625 
Epoch:7-149/286, Negative loglikelihood: 7473.8173828125 
Epoch:7-150/286, Negative loglikelihood: 7473.09375 
Epoch:7-151/286, Negative loglikelihood: 7471.0986328125 
Epoch:7-152/286, Negative loglikelihood: 7471.92431640625 
Epoch:7-153/286, Negative loglikelihood: 7476.50732421875 
Epoch:7-154/286, Negative loglikelihood: 7472.3564453125 
Epoch:7-155/286, Negative loglikelihood: 7471.08203125 
Epoch:7-156/286, Negative loglikelihood: 7469.86376953125 
Epoch:7-157/286, Negative loglikelihood: 7470.9892578125 
Epoch:7-158/286, Negative loglikelihood: 7474.69873046875 
Epoch:7-159/286, Negative loglikelihood: 7476.81298828125 
Epoch:7-160/286, Negative loglikelihood: 7473.69677734375 
Epoch:7-161/286, Negative loglikelihood: 7470.037109375 
Epoch:7-162/2

Epoch:8-0/286, Negative loglikelihood: 7473.57568359375 
Epoch:8-1/286, Negative loglikelihood: 7469.79736328125 
Epoch:8-2/286, Negative loglikelihood: 7468.71875 
Epoch:8-3/286, Negative loglikelihood: 7473.29931640625 
Epoch:8-4/286, Negative loglikelihood: 7475.96630859375 
Epoch:8-5/286, Negative loglikelihood: 7468.58203125 
Epoch:8-6/286, Negative loglikelihood: 7470.79931640625 
Epoch:8-7/286, Negative loglikelihood: 7469.2587890625 
Epoch:8-8/286, Negative loglikelihood: 7470.4482421875 
Epoch:8-9/286, Negative loglikelihood: 7469.6279296875 
Epoch:8-10/286, Negative loglikelihood: 7468.70263671875 
Epoch:8-11/286, Negative loglikelihood: 7474.2060546875 
Epoch:8-12/286, Negative loglikelihood: 7468.91064453125 
Epoch:8-13/286, Negative loglikelihood: 7470.75146484375 
Epoch:8-14/286, Negative loglikelihood: 7470.40576171875 
Epoch:8-15/286, Negative loglikelihood: 7471.99462890625 
Epoch:8-16/286, Negative loglikelihood: 7485.65576171875 
Epoch:8-17/286, Negative loglikelihoo

Epoch:8-145/286, Negative loglikelihood: 7469.1123046875 
Epoch:8-146/286, Negative loglikelihood: 2489.395751953125 
Epoch:8-147/286, Negative loglikelihood: 7470.36181640625 
Epoch:8-148/286, Negative loglikelihood: 7468.74169921875 
Epoch:8-149/286, Negative loglikelihood: 7469.43359375 
Epoch:8-150/286, Negative loglikelihood: 7470.8818359375 
Epoch:8-151/286, Negative loglikelihood: 7468.03369140625 
Epoch:8-152/286, Negative loglikelihood: 7468.0927734375 
Epoch:8-153/286, Negative loglikelihood: 7469.26904296875 
Epoch:8-154/286, Negative loglikelihood: 7468.24560546875 
Epoch:8-155/286, Negative loglikelihood: 7467.33056640625 
Epoch:8-156/286, Negative loglikelihood: 7472.52490234375 
Epoch:8-157/286, Negative loglikelihood: 7468.421875 
Epoch:8-158/286, Negative loglikelihood: 7470.6689453125 
Epoch:8-159/286, Negative loglikelihood: 7468.732421875 
Epoch:8-160/286, Negative loglikelihood: 7469.2919921875 
Epoch:8-161/286, Negative loglikelihood: 7468.0048828125 
Epoch:8-162/

Epoch:9-0/286, Negative loglikelihood: 7468.2705078125 
Epoch:9-1/286, Negative loglikelihood: 7469.806640625 
Epoch:9-2/286, Negative loglikelihood: 7467.294921875 
Epoch:9-3/286, Negative loglikelihood: 7467.2568359375 
Epoch:9-4/286, Negative loglikelihood: 7467.4453125 
Epoch:9-5/286, Negative loglikelihood: 7466.6171875 
Epoch:9-6/286, Negative loglikelihood: 7469.88720703125 
Epoch:9-7/286, Negative loglikelihood: 7471.1474609375 
Epoch:9-8/286, Negative loglikelihood: 7469.6396484375 
Epoch:9-9/286, Negative loglikelihood: 7467.77978515625 
Epoch:9-10/286, Negative loglikelihood: 7467.060546875 
Epoch:9-11/286, Negative loglikelihood: 7468.8857421875 
Epoch:9-12/286, Negative loglikelihood: 7468.97607421875 
Epoch:9-13/286, Negative loglikelihood: 7465.7509765625 
Epoch:9-14/286, Negative loglikelihood: 7468.24169921875 
Epoch:9-15/286, Negative loglikelihood: 7470.0927734375 
Epoch:9-16/286, Negative loglikelihood: 7469.431640625 
Epoch:9-17/286, Negative loglikelihood: 7468.07

Epoch:9-146/286, Negative loglikelihood: 7464.8486328125 
Epoch:9-147/286, Negative loglikelihood: 7466.79296875 
Epoch:9-148/286, Negative loglikelihood: 7465.935546875 
Epoch:9-149/286, Negative loglikelihood: 7466.83056640625 
Epoch:9-150/286, Negative loglikelihood: 7473.2763671875 
Epoch:9-151/286, Negative loglikelihood: 7467.376953125 
Epoch:9-152/286, Negative loglikelihood: 7466.7021484375 
Epoch:9-153/286, Negative loglikelihood: 7465.81494140625 
Epoch:9-154/286, Negative loglikelihood: 7467.017578125 
Epoch:9-155/286, Negative loglikelihood: 7463.892578125 
Epoch:9-156/286, Negative loglikelihood: 7467.556640625 
Epoch:9-157/286, Negative loglikelihood: 7466.435546875 
Epoch:9-158/286, Negative loglikelihood: 7468.1806640625 
Epoch:9-159/286, Negative loglikelihood: 7469.6904296875 
Epoch:9-160/286, Negative loglikelihood: 7468.388671875 
Epoch:9-161/286, Negative loglikelihood: 7465.6552734375 
Epoch:9-162/286, Negative loglikelihood: 7465.62109375 
Epoch:9-163/286, Negati

Epoch:10-0/286, Negative loglikelihood: 7468.2119140625 
Epoch:10-1/286, Negative loglikelihood: 7464.9130859375 
Epoch:10-2/286, Negative loglikelihood: 7465.88330078125 
Epoch:10-3/286, Negative loglikelihood: 7465.1083984375 
Epoch:10-4/286, Negative loglikelihood: 7467.5205078125 
Epoch:10-5/286, Negative loglikelihood: 7466.6611328125 
Epoch:10-6/286, Negative loglikelihood: 7468.50634765625 
Epoch:10-7/286, Negative loglikelihood: 7465.666015625 
Epoch:10-8/286, Negative loglikelihood: 7464.72802734375 
Epoch:10-9/286, Negative loglikelihood: 4977.2021484375 
Epoch:10-10/286, Negative loglikelihood: 7473.56494140625 
Epoch:10-11/286, Negative loglikelihood: 7465.09228515625 
Epoch:10-12/286, Negative loglikelihood: 7463.7724609375 
Epoch:10-13/286, Negative loglikelihood: 7464.44677734375 
Epoch:10-14/286, Negative loglikelihood: 7464.6005859375 
Epoch:10-15/286, Negative loglikelihood: 7467.61376953125 
Epoch:10-16/286, Negative loglikelihood: 7466.50634765625 
Epoch:10-17/286, 

Epoch:10-142/286, Negative loglikelihood: 7464.375 
Epoch:10-143/286, Negative loglikelihood: 7463.751953125 
Epoch:10-144/286, Negative loglikelihood: 7464.1015625 
Epoch:10-145/286, Negative loglikelihood: 4975.9052734375 
Epoch:10-146/286, Negative loglikelihood: 7463.26611328125 
Epoch:10-147/286, Negative loglikelihood: 7464.2109375 
Epoch:10-148/286, Negative loglikelihood: 7463.548828125 
Epoch:10-149/286, Negative loglikelihood: 7462.91455078125 
Epoch:10-150/286, Negative loglikelihood: 4976.44775390625 
Epoch:10-151/286, Negative loglikelihood: 7463.7841796875 
Epoch:10-152/286, Negative loglikelihood: 7464.927734375 
Epoch:10-153/286, Negative loglikelihood: 7465.6171875 
Epoch:10-154/286, Negative loglikelihood: 7464.48046875 
Epoch:10-155/286, Negative loglikelihood: 7465.265625 
Epoch:10-156/286, Negative loglikelihood: 7463.70703125 
Epoch:10-157/286, Negative loglikelihood: 7464.1650390625 
Epoch:10-158/286, Negative loglikelihood: 7463.7685546875 
Epoch:10-159/286, Neg

Epoch:10-284/286, Negative loglikelihood: 7463.609375 
Epoch:10-285/286, Negative loglikelihood: 6637.3369140625 
--------------------------------------------------------------
Epoch:10 completed, Total training's Loss: 2119104.2817382812, Spend: 1.1721292813618978m
EPOCH:  11
Epoch:11-0/286, Negative loglikelihood: 7462.35693359375 
Epoch:11-1/286, Negative loglikelihood: 7466.13623046875 
Epoch:11-2/286, Negative loglikelihood: 7464.91259765625 
Epoch:11-3/286, Negative loglikelihood: 7463.5498046875 
Epoch:11-4/286, Negative loglikelihood: 7464.59765625 
Epoch:11-5/286, Negative loglikelihood: 7462.7783203125 
Epoch:11-6/286, Negative loglikelihood: 7463.095703125 
Epoch:11-7/286, Negative loglikelihood: 7465.697265625 
Epoch:11-8/286, Negative loglikelihood: 7462.4921875 
Epoch:11-9/286, Negative loglikelihood: 7463.45703125 
Epoch:11-10/286, Negative loglikelihood: 7463.509765625 
Epoch:11-11/286, Negative loglikelihood: 7462.22998046875 
Epoch:11-12/286, Negative loglikelihood: 7

Epoch:11-137/286, Negative loglikelihood: 7461.3720703125 
Epoch:11-138/286, Negative loglikelihood: 7461.6796875 
Epoch:11-139/286, Negative loglikelihood: 7461.3818359375 
Epoch:11-140/286, Negative loglikelihood: 4974.15771484375 
Epoch:11-141/286, Negative loglikelihood: 7462.11865234375 
Epoch:11-142/286, Negative loglikelihood: 7461.697265625 
Epoch:11-143/286, Negative loglikelihood: 7461.376953125 
Epoch:11-144/286, Negative loglikelihood: 7462.2978515625 
Epoch:11-145/286, Negative loglikelihood: 7460.671875 
Epoch:11-146/286, Negative loglikelihood: 7462.7158203125 
Epoch:11-147/286, Negative loglikelihood: 7460.96484375 
Epoch:11-148/286, Negative loglikelihood: 7462.8115234375 
Epoch:11-149/286, Negative loglikelihood: 7460.7421875 
Epoch:11-150/286, Negative loglikelihood: 7461.001953125 
Epoch:11-151/286, Negative loglikelihood: 7462.5869140625 
Epoch:11-152/286, Negative loglikelihood: 7461.6416015625 
Epoch:11-153/286, Negative loglikelihood: 7462.9130859375 
Epoch:11-1

Epoch:11-277/286, Negative loglikelihood: 7461.4658203125 
Epoch:11-278/286, Negative loglikelihood: 7461.1669921875 
Epoch:11-279/286, Negative loglikelihood: 7461.5302734375 
Epoch:11-280/286, Negative loglikelihood: 7461.8291015625 
Epoch:11-281/286, Negative loglikelihood: 7462.029296875 
Epoch:11-282/286, Negative loglikelihood: 7461.03466796875 
Epoch:11-283/286, Negative loglikelihood: 7462.25 
Epoch:11-284/286, Negative loglikelihood: 7469.7236328125 
Epoch:11-285/286, Negative loglikelihood: 6635.87646484375 
--------------------------------------------------------------
Epoch:11 completed, Total training's Loss: 2120981.9609375, Spend: 1.1669305205345153m
EPOCH:  12
Epoch:12-0/286, Negative loglikelihood: 7460.17431640625 
Epoch:12-1/286, Negative loglikelihood: 7462.8681640625 
Epoch:12-2/286, Negative loglikelihood: 7461.005859375 
Epoch:12-3/286, Negative loglikelihood: 7462.705078125 
Epoch:12-4/286, Negative loglikelihood: 7461.4111328125 
Epoch:12-5/286, Negative loglik

Epoch:12-132/286, Negative loglikelihood: 7460.7490234375 
Epoch:12-133/286, Negative loglikelihood: 7461.6865234375 
Epoch:12-134/286, Negative loglikelihood: 7459.9736328125 
Epoch:12-135/286, Negative loglikelihood: 7461.474609375 
Epoch:12-136/286, Negative loglikelihood: 7460.0595703125 
Epoch:12-137/286, Negative loglikelihood: 7459.9072265625 
Epoch:12-138/286, Negative loglikelihood: 7460.42138671875 
Epoch:12-139/286, Negative loglikelihood: 7458.89111328125 
Epoch:12-140/286, Negative loglikelihood: 7460.33740234375 
Epoch:12-141/286, Negative loglikelihood: 7458.91748046875 
Epoch:12-142/286, Negative loglikelihood: 7462.5341796875 
Epoch:12-143/286, Negative loglikelihood: 7459.8681640625 
Epoch:12-144/286, Negative loglikelihood: 7459.57080078125 
Epoch:12-145/286, Negative loglikelihood: 7459.25146484375 
Epoch:12-146/286, Negative loglikelihood: 7460.0234375 
Epoch:12-147/286, Negative loglikelihood: 7461.5322265625 
Epoch:12-148/286, Negative loglikelihood: 7461.5869140

Epoch:12-272/286, Negative loglikelihood: 7459.4716796875 
Epoch:12-273/286, Negative loglikelihood: 7459.1572265625 
Epoch:12-274/286, Negative loglikelihood: 7459.5869140625 
Epoch:12-275/286, Negative loglikelihood: 7459.48046875 
Epoch:12-276/286, Negative loglikelihood: 7460.9677734375 
Epoch:12-277/286, Negative loglikelihood: 7460.556640625 
Epoch:12-278/286, Negative loglikelihood: 7461.93408203125 
Epoch:12-279/286, Negative loglikelihood: 7458.77734375 
Epoch:12-280/286, Negative loglikelihood: 7459.21435546875 
Epoch:12-281/286, Negative loglikelihood: 7459.04296875 
Epoch:12-282/286, Negative loglikelihood: 7459.013671875 
Epoch:12-283/286, Negative loglikelihood: 7458.96044921875 
Epoch:12-284/286, Negative loglikelihood: 7459.361328125 
Epoch:12-285/286, Negative loglikelihood: 6633.8818359375 
--------------------------------------------------------------
Epoch:12 completed, Total training's Loss: 2117912.919189453, Spend: 1.1663667241732278m
EPOCH:  13
Epoch:13-0/286, N

Epoch:13-125/286, Negative loglikelihood: 7459.5498046875 
Epoch:13-126/286, Negative loglikelihood: 7459.1953125 
Epoch:13-127/286, Negative loglikelihood: 7459.0400390625 
Epoch:13-128/286, Negative loglikelihood: 7458.541015625 
Epoch:13-129/286, Negative loglikelihood: 7457.9931640625 
Epoch:13-130/286, Negative loglikelihood: 7458.5400390625 
Epoch:13-131/286, Negative loglikelihood: 7458.76953125 
Epoch:13-132/286, Negative loglikelihood: 7458.91943359375 
Epoch:13-133/286, Negative loglikelihood: 7460.341796875 
Epoch:13-134/286, Negative loglikelihood: 7457.7490234375 
Epoch:13-135/286, Negative loglikelihood: 7460.40673828125 
Epoch:13-136/286, Negative loglikelihood: 7461.927734375 
Epoch:13-137/286, Negative loglikelihood: 7461.5673828125 
Epoch:13-138/286, Negative loglikelihood: 7459.8857421875 
Epoch:13-139/286, Negative loglikelihood: 7457.1787109375 
Epoch:13-140/286, Negative loglikelihood: 7459.150390625 
Epoch:13-141/286, Negative loglikelihood: 7460.25244140625 
Epo

Epoch:13-266/286, Negative loglikelihood: 7458.896484375 
Epoch:13-267/286, Negative loglikelihood: 7457.28515625 
Epoch:13-268/286, Negative loglikelihood: 7458.16650390625 
Epoch:13-269/286, Negative loglikelihood: 7457.830078125 
Epoch:13-270/286, Negative loglikelihood: 7457.9697265625 
Epoch:13-271/286, Negative loglikelihood: 7456.79931640625 
Epoch:13-272/286, Negative loglikelihood: 7459.19091796875 
Epoch:13-273/286, Negative loglikelihood: 7459.0712890625 
Epoch:13-274/286, Negative loglikelihood: 7457.6279296875 
Epoch:13-275/286, Negative loglikelihood: 7457.1103515625 
Epoch:13-276/286, Negative loglikelihood: 7458.7734375 
Epoch:13-277/286, Negative loglikelihood: 7457.55859375 
Epoch:13-278/286, Negative loglikelihood: 7457.87109375 
Epoch:13-279/286, Negative loglikelihood: 7458.8525390625 
Epoch:13-280/286, Negative loglikelihood: 7457.05029296875 
Epoch:13-281/286, Negative loglikelihood: 7458.33251953125 
Epoch:13-282/286, Negative loglikelihood: 7457.599609375 
Epoc

Epoch:14-120/286, Negative loglikelihood: 7458.21728515625 
Epoch:14-121/286, Negative loglikelihood: 7457.111328125 
Epoch:14-122/286, Negative loglikelihood: 7458.66357421875 
Epoch:14-123/286, Negative loglikelihood: 4971.021484375 
Epoch:14-124/286, Negative loglikelihood: 7456.9375 
Epoch:14-125/286, Negative loglikelihood: 7456.419921875 
Epoch:14-126/286, Negative loglikelihood: 7457.6328125 
Epoch:14-127/286, Negative loglikelihood: 7456.99169921875 
Epoch:14-128/286, Negative loglikelihood: 7458.23779296875 
Epoch:14-129/286, Negative loglikelihood: 7456.00048828125 
Epoch:14-130/286, Negative loglikelihood: 7456.595703125 
Epoch:14-131/286, Negative loglikelihood: 7455.95654296875 
Epoch:14-132/286, Negative loglikelihood: 7455.55078125 
Epoch:14-133/286, Negative loglikelihood: 7457.66357421875 
Epoch:14-134/286, Negative loglikelihood: 7456.4892578125 
Epoch:14-135/286, Negative loglikelihood: 7456.8125 
Epoch:14-136/286, Negative loglikelihood: 7457.28173828125 
Epoch:14-1

Epoch:14-260/286, Negative loglikelihood: 7456.65673828125 
Epoch:14-261/286, Negative loglikelihood: 7456.36767578125 
Epoch:14-262/286, Negative loglikelihood: 7455.5224609375 
Epoch:14-263/286, Negative loglikelihood: 7455.828125 
Epoch:14-264/286, Negative loglikelihood: 7455.79248046875 
Epoch:14-265/286, Negative loglikelihood: 7456.427734375 
Epoch:14-266/286, Negative loglikelihood: 7457.21484375 
Epoch:14-267/286, Negative loglikelihood: 7457.1650390625 
Epoch:14-268/286, Negative loglikelihood: 7455.7880859375 
Epoch:14-269/286, Negative loglikelihood: 7455.75830078125 
Epoch:14-270/286, Negative loglikelihood: 7456.2353515625 
Epoch:14-271/286, Negative loglikelihood: 7456.59375 
Epoch:14-272/286, Negative loglikelihood: 7456.4541015625 
Epoch:14-273/286, Negative loglikelihood: 7457.3330078125 
Epoch:14-274/286, Negative loglikelihood: 7455.939453125 
Epoch:14-275/286, Negative loglikelihood: 7456.6572265625 
Epoch:14-276/286, Negative loglikelihood: 7458.060546875 
Epoch:1

Epoch:15-114/286, Negative loglikelihood: 7455.564453125 
Epoch:15-115/286, Negative loglikelihood: 7454.53125 
Epoch:15-116/286, Negative loglikelihood: 7457.5224609375 
Epoch:15-117/286, Negative loglikelihood: 7455.224609375 
Epoch:15-118/286, Negative loglikelihood: 7456.21875 
Epoch:15-119/286, Negative loglikelihood: 7458.330078125 
Epoch:15-120/286, Negative loglikelihood: 7454.4794921875 
Epoch:15-121/286, Negative loglikelihood: 7456.8408203125 
Epoch:15-122/286, Negative loglikelihood: 7454.8935546875 
Epoch:15-123/286, Negative loglikelihood: 7455.001953125 
Epoch:15-124/286, Negative loglikelihood: 7455.33984375 
Epoch:15-125/286, Negative loglikelihood: 7457.2099609375 
Epoch:15-126/286, Negative loglikelihood: 7455.96435546875 
Epoch:15-127/286, Negative loglikelihood: 7454.833984375 
Epoch:15-128/286, Negative loglikelihood: 7455.3056640625 
Epoch:15-129/286, Negative loglikelihood: 7457.1533203125 
Epoch:15-130/286, Negative loglikelihood: 7455.484375 
Epoch:15-131/286,

Epoch:15-254/286, Negative loglikelihood: 7454.86767578125 
Epoch:15-255/286, Negative loglikelihood: 7454.8056640625 
Epoch:15-256/286, Negative loglikelihood: 7455.5478515625 
Epoch:15-257/286, Negative loglikelihood: 7455.8974609375 
Epoch:15-258/286, Negative loglikelihood: 7455.58154296875 
Epoch:15-259/286, Negative loglikelihood: 7455.13671875 
Epoch:15-260/286, Negative loglikelihood: 7455.88330078125 
Epoch:15-261/286, Negative loglikelihood: 7454.85595703125 
Epoch:15-262/286, Negative loglikelihood: 7455.04541015625 
Epoch:15-263/286, Negative loglikelihood: 7456.41064453125 
Epoch:15-264/286, Negative loglikelihood: 7456.0615234375 
Epoch:15-265/286, Negative loglikelihood: 7455.435546875 
Epoch:15-266/286, Negative loglikelihood: 7454.9921875 
Epoch:15-267/286, Negative loglikelihood: 7455.927734375 
Epoch:15-268/286, Negative loglikelihood: 7454.99609375 
Epoch:15-269/286, Negative loglikelihood: 7454.4736328125 
Epoch:15-270/286, Negative loglikelihood: 7456.208984375 
E

Epoch:16-108/286, Negative loglikelihood: 7455.27294921875 
Epoch:16-109/286, Negative loglikelihood: 7455.203125 
Epoch:16-110/286, Negative loglikelihood: 7454.4072265625 
Epoch:16-111/286, Negative loglikelihood: 7455.376953125 
Epoch:16-112/286, Negative loglikelihood: 7456.0419921875 
Epoch:16-113/286, Negative loglikelihood: 7454.57080078125 
Epoch:16-114/286, Negative loglikelihood: 7454.3955078125 
Epoch:16-115/286, Negative loglikelihood: 7454.634765625 
Epoch:16-116/286, Negative loglikelihood: 7454.458984375 
Epoch:16-117/286, Negative loglikelihood: 7455.509765625 
Epoch:16-118/286, Negative loglikelihood: 7455.1650390625 
Epoch:16-119/286, Negative loglikelihood: 7454.994140625 
Epoch:16-120/286, Negative loglikelihood: 7454.51025390625 
Epoch:16-121/286, Negative loglikelihood: 7455.2099609375 
Epoch:16-122/286, Negative loglikelihood: 7454.97998046875 
Epoch:16-123/286, Negative loglikelihood: 7454.59326171875 
Epoch:16-124/286, Negative loglikelihood: 7455.4560546875 
E

Epoch:16-249/286, Negative loglikelihood: 7454.125 
Epoch:16-250/286, Negative loglikelihood: 7453.677734375 
Epoch:16-251/286, Negative loglikelihood: 7454.916015625 
Epoch:16-252/286, Negative loglikelihood: 7454.70263671875 
Epoch:16-253/286, Negative loglikelihood: 7454.39013671875 
Epoch:16-254/286, Negative loglikelihood: 7454.1279296875 
Epoch:16-255/286, Negative loglikelihood: 7453.630859375 
Epoch:16-256/286, Negative loglikelihood: 7453.67041015625 
Epoch:16-257/286, Negative loglikelihood: 7453.6552734375 
Epoch:16-258/286, Negative loglikelihood: 7453.70458984375 
Epoch:16-259/286, Negative loglikelihood: 7453.451171875 
Epoch:16-260/286, Negative loglikelihood: 7454.3857421875 
Epoch:16-261/286, Negative loglikelihood: 7454.15234375 
Epoch:16-262/286, Negative loglikelihood: 7454.4326171875 
Epoch:16-263/286, Negative loglikelihood: 7454.5625 
Epoch:16-264/286, Negative loglikelihood: 7455.62353515625 
Epoch:16-265/286, Negative loglikelihood: 7455.587890625 
Epoch:16-266

Epoch:17-103/286, Negative loglikelihood: 7452.8759765625 
Epoch:17-104/286, Negative loglikelihood: 7453.2177734375 
Epoch:17-105/286, Negative loglikelihood: 7453.72412109375 
Epoch:17-106/286, Negative loglikelihood: 7452.931640625 
Epoch:17-107/286, Negative loglikelihood: 7454.5615234375 
Epoch:17-108/286, Negative loglikelihood: 7453.47265625 
Epoch:17-109/286, Negative loglikelihood: 7456.09375 
Epoch:17-110/286, Negative loglikelihood: 7454.802734375 
Epoch:17-111/286, Negative loglikelihood: 7453.8984375 
Epoch:17-112/286, Negative loglikelihood: 7454.875 
Epoch:17-113/286, Negative loglikelihood: 7453.44140625 
Epoch:17-114/286, Negative loglikelihood: 7455.205078125 
Epoch:17-115/286, Negative loglikelihood: 7454.1513671875 
Epoch:17-116/286, Negative loglikelihood: 7454.0224609375 
Epoch:17-117/286, Negative loglikelihood: 7452.248046875 
Epoch:17-118/286, Negative loglikelihood: 7452.85546875 
Epoch:17-119/286, Negative loglikelihood: 7453.775390625 
Epoch:17-120/286, Nega

Epoch:17-244/286, Negative loglikelihood: 7453.88330078125 
Epoch:17-245/286, Negative loglikelihood: 7451.88916015625 
Epoch:17-246/286, Negative loglikelihood: 7452.84033203125 
Epoch:17-247/286, Negative loglikelihood: 7453.068359375 
Epoch:17-248/286, Negative loglikelihood: 7453.26953125 
Epoch:17-249/286, Negative loglikelihood: 7452.64111328125 
Epoch:17-250/286, Negative loglikelihood: 7453.44921875 
Epoch:17-251/286, Negative loglikelihood: 7453.67919921875 
Epoch:17-252/286, Negative loglikelihood: 7455.27880859375 
Epoch:17-253/286, Negative loglikelihood: 7453.12890625 
Epoch:17-254/286, Negative loglikelihood: 7453.0 
Epoch:17-255/286, Negative loglikelihood: 7452.8193359375 
Epoch:17-256/286, Negative loglikelihood: 7452.0400390625 
Epoch:17-257/286, Negative loglikelihood: 7454.9150390625 
Epoch:17-258/286, Negative loglikelihood: 7453.5654296875 
Epoch:17-259/286, Negative loglikelihood: 7456.0966796875 
Epoch:17-260/286, Negative loglikelihood: 7453.8896484375 
Epoch:1

Epoch:18-97/286, Negative loglikelihood: 7452.16162109375 
Epoch:18-98/286, Negative loglikelihood: 7451.8505859375 
Epoch:18-99/286, Negative loglikelihood: 7452.77685546875 
Epoch:18-100/286, Negative loglikelihood: 7453.4267578125 
Epoch:18-101/286, Negative loglikelihood: 7453.26611328125 
Epoch:18-102/286, Negative loglikelihood: 7453.73388671875 
Epoch:18-103/286, Negative loglikelihood: 7454.9912109375 
Epoch:18-104/286, Negative loglikelihood: 7453.6904296875 
Epoch:18-105/286, Negative loglikelihood: 7452.79443359375 
Epoch:18-106/286, Negative loglikelihood: 7454.85546875 
Epoch:18-107/286, Negative loglikelihood: 7454.71923828125 
Epoch:18-108/286, Negative loglikelihood: 7452.77783203125 
Epoch:18-109/286, Negative loglikelihood: 7453.3505859375 
Epoch:18-110/286, Negative loglikelihood: 7452.580078125 
Epoch:18-111/286, Negative loglikelihood: 7453.3095703125 
Epoch:18-112/286, Negative loglikelihood: 7453.41259765625 
Epoch:18-113/286, Negative loglikelihood: 7454.2070312

Epoch:18-237/286, Negative loglikelihood: 7453.53515625 
Epoch:18-238/286, Negative loglikelihood: 7454.7509765625 
Epoch:18-239/286, Negative loglikelihood: 7452.9296875 
Epoch:18-240/286, Negative loglikelihood: 7454.435546875 
Epoch:18-241/286, Negative loglikelihood: 7452.9921875 
Epoch:18-242/286, Negative loglikelihood: 7453.8671875 
Epoch:18-243/286, Negative loglikelihood: 7454.6689453125 
Epoch:18-244/286, Negative loglikelihood: 7453.2919921875 
Epoch:18-245/286, Negative loglikelihood: 7452.9482421875 
Epoch:18-246/286, Negative loglikelihood: 7452.681640625 
Epoch:18-247/286, Negative loglikelihood: 7453.17041015625 
Epoch:18-248/286, Negative loglikelihood: 7452.40380859375 
Epoch:18-249/286, Negative loglikelihood: 7452.94189453125 
Epoch:18-250/286, Negative loglikelihood: 7452.421875 
Epoch:18-251/286, Negative loglikelihood: 7453.97705078125 
Epoch:18-252/286, Negative loglikelihood: 7453.6484375 
Epoch:18-253/286, Negative loglikelihood: 7453.0185546875 
Epoch:18-254/

Epoch:19-90/286, Negative loglikelihood: 7452.95947265625 
Epoch:19-91/286, Negative loglikelihood: 7453.19921875 
Epoch:19-92/286, Negative loglikelihood: 7453.5302734375 
Epoch:19-93/286, Negative loglikelihood: 7453.22119140625 
Epoch:19-94/286, Negative loglikelihood: 7453.078125 
Epoch:19-95/286, Negative loglikelihood: 7453.4189453125 
Epoch:19-96/286, Negative loglikelihood: 7453.21875 
Epoch:19-97/286, Negative loglikelihood: 7453.474609375 
Epoch:19-98/286, Negative loglikelihood: 7452.34765625 
Epoch:19-99/286, Negative loglikelihood: 7451.80419921875 
Epoch:19-100/286, Negative loglikelihood: 7452.9677734375 
Epoch:19-101/286, Negative loglikelihood: 7452.1884765625 
Epoch:19-102/286, Negative loglikelihood: 7451.91455078125 
Epoch:19-103/286, Negative loglikelihood: 7453.68603515625 
Epoch:19-104/286, Negative loglikelihood: 7453.345703125 
Epoch:19-105/286, Negative loglikelihood: 7453.70068359375 
Epoch:19-106/286, Negative loglikelihood: 4968.63232421875 
Epoch:19-107/28

Epoch:19-230/286, Negative loglikelihood: 7452.166015625 
Epoch:19-231/286, Negative loglikelihood: 7453.291015625 
Epoch:19-232/286, Negative loglikelihood: 7453.1162109375 
Epoch:19-233/286, Negative loglikelihood: 7452.9287109375 
Epoch:19-234/286, Negative loglikelihood: 7452.9228515625 
Epoch:19-235/286, Negative loglikelihood: 7452.95556640625 
Epoch:19-236/286, Negative loglikelihood: 7453.4501953125 
Epoch:19-237/286, Negative loglikelihood: 7453.87109375 
Epoch:19-238/286, Negative loglikelihood: 7453.0625 
Epoch:19-239/286, Negative loglikelihood: 7453.54638671875 
Epoch:19-240/286, Negative loglikelihood: 7451.90771484375 
Epoch:19-241/286, Negative loglikelihood: 7453.5693359375 
Epoch:19-242/286, Negative loglikelihood: 7451.4853515625 
Epoch:19-243/286, Negative loglikelihood: 7452.64453125 
Epoch:19-244/286, Negative loglikelihood: 7452.97802734375 
Epoch:19-245/286, Negative loglikelihood: 7452.02734375 
Epoch:19-246/286, Negative loglikelihood: 7452.1181640625 
Epoch:1

In [11]:
from sklearn.metrics import classification_report,  accuracy_score, precision_score, recall_score, f1_score
import pandas as pd

def evaluate(model, predict_dataloader, batch_size, epoch_th, dataset_name):
    # print("***** Running prediction *****")
    model.eval()
    all_preds = []
    all_labels = []
    total=0
    correct=0
    start = time.time()
    
    len_of_sentences = []
    total_length = 0
    
    with torch.no_grad():

        for batch in predict_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, predict_mask, label_ids = batch

            for i in range(0, len(label_ids)):
                sentence_true_label = []
                for item in label_ids[i]:
                    if item not in [0,1,2]:
                        sentence_true_label.append(item.tolist())

                len_of_sentences.append(len(sentence_true_label))
                total_length += len(sentence_true_label)
          
            _, predicted_label_seq_ids = model(input_ids, segment_ids, input_mask)           

            # _, predicted = torch.max(out_scores, -1)
            valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)
#             print ("Linh test prediction: ", len(valid_predicted))
            valid_label_ids = torch.masked_select(label_ids, predict_mask)
#             print ("Linh test true: ", len(valid_label_ids))
            all_preds.extend(valid_predicted.tolist())
            all_labels.extend(valid_label_ids.tolist())
            # print(len(valid_label_ids),len(valid_predicted),len(valid_label_ids)==len(valid_predicted))
            total += len(valid_label_ids)
            correct += valid_predicted.eq(valid_label_ids).sum().item()
    print ("Length of predictions: ", len(all_preds))
    print ("Length of true labels: ", len(all_labels))
    print ("Length of test set: ", total_length)
    print ("Length of sentences list: ", len_of_sentences)

    print (all_preds)
    unique_labels = list(set(all_labels))
    unique_labels.remove(3)
    
    test_acc = correct/total
    precision, recall, f1 = f1_score_custom(np.array(all_labels), np.array(all_preds))
        
    print ('Accuracy:', accuracy_score(all_labels,all_preds))
    print ('Precision:', precision_score(all_labels,all_preds, average = "weighted",labels = unique_labels))
    print ('Recall:', recall_score(all_labels,all_preds,average = "weighted",labels = unique_labels))
    print ('F1 score:', f1_score(all_labels,all_preds,average = "weighted",labels = unique_labels))
    print (classification_report(all_labels,all_preds,labels = unique_labels))
    
#     report = classification_report(all_labels,all_preds, labels=unique_labels, output_dict=True)
#     report_df = pd.DataFrame(report).transpose()
#     report_df.to_csv("/home/lhoang2/Notebooks/RCT_Methodology_Extraction/DATA_07272022/RESULTS_07272022/test_set_lr35_epoch20_positive_sentences_only_CRF_07272022.csv")
#     textfile = open("/home/lhoang2/Notebooks/RCT_Methodology_Extraction/DATA_07272022/RESULTS_07272022/outputs_lr35_epoch20_positive_sentences_only_CRF_07272022.txt", "w")
#     for element in all_preds:
#         textfile.write(str(element) + "\n")
#     textfile.close()

    
#     textfile2 = open("/home/lhoang2/Notebooks/RCT_Methodology_Extraction/DATA_07272022/RESULTS_07272022/true_labels_test_CRF_07272022.txt", "w")
#     for element in all_labels:
#         textfile2.write(str(element) + "\n")
#     textfile2.close()
    
#     end = time.time()
#     print('Epoch:%d, Acc:%.2f, Precision: %.2f, Recall: %.2f, F1: %.2f on %s, Spend:%.3f minutes for evaluation' \
#         % (epoch_th, 100.*test_acc, 100.*precision, 100.*recall, 100.*f1, dataset_name,(end-start)/60.0))
#     print('--------------------------------------------------------------')
#     return test_acc, f1

def f1_score_custom(y_true, y_pred):
    '''
    0,1,2,3 are [CLS],[SEP],[X],O
    '''
    ignore_id=3

    num_proposed = len(y_pred[y_pred>ignore_id])
    num_correct = (np.logical_and(y_true==y_pred, y_true>ignore_id)).sum()
    num_gold = len(y_true[y_true>ignore_id])

    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0

    return precision, recall, f1

In [12]:
evaluate(model, test_dataloader, batch_size, total_train_epochs-1, 'Test_set')

  valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)
  valid_label_ids = torch.masked_select(label_ids, predict_mask)


Length of predictions:  55483
Length of true labels:  55483
Length of test set:  55483
Length of sentences list:  [13, 43, 3, 45, 12, 50, 31, 26, 24, 58, 3, 30, 22, 53, 72, 3, 11, 27, 32, 17, 48, 35, 86, 49, 62, 38, 32, 17, 18, 26, 180, 21, 63, 14, 22, 44, 18, 20, 3, 20, 89, 146, 10, 53, 67, 58, 69, 22, 54, 24, 28, 72, 27, 40, 30, 32, 21, 30, 35, 24, 46, 59, 47, 42, 39, 12, 34, 35, 21, 53, 29, 60, 10, 24, 56, 30, 16, 16, 55, 35, 45, 13, 43, 3, 45, 12, 50, 31, 26, 24, 58, 3, 30, 22, 53, 72, 3, 11, 27, 32, 17, 48, 35, 86, 49, 62, 38, 32, 17, 18, 26, 180, 21, 63, 14, 22, 44, 18, 20, 3, 20, 89, 146, 10, 53, 67, 58, 69, 22, 54, 24, 28, 72, 27, 40, 30, 32, 21, 30, 35, 24, 46, 59, 47, 42, 39, 12, 34, 35, 21, 53, 29, 60, 10, 24, 56, 30, 16, 16, 55, 35, 45, 14, 25, 53, 20, 17, 34, 16, 24, 54, 38, 35, 53, 12, 27, 27, 30, 30, 15, 66, 23, 16, 15, 31, 18, 46, 59, 32, 38, 29, 41, 29, 35, 54, 42, 14, 14, 40, 24, 28, 24, 28, 34, 26, 32, 30, 18, 13, 20, 23, 44, 10, 18, 40, 33, 43, 10, 75, 54, 29, 36, 4

F1 score: 0.5044417564846099
              precision    recall  f1-score   support

           4       0.33      0.75      0.46         4
           6       1.00      1.00      1.00         1
           7       0.32      1.00      0.48         6
           8       0.50      0.71      0.59         7
          10       0.20      0.33      0.25         3
          11       0.62      0.97      0.76        34
          12       0.69      0.49      0.57        98
          13       0.42      0.89      0.57         9
          14       0.31      0.77      0.44        26
          15       0.17      0.80      0.28        10
          16       0.00      0.00      0.00         1
          17       0.00      0.00      0.00         5
          18       0.36      0.67      0.47        15
          19       0.00      0.00      0.00         2
          21       0.00      0.00      0.00         4
          22       0.86      0.50      0.63        12
          23       0.00      0.00      0.00         

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
