In [32]:
import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from seqeval.metrics import f1_score, precision_score, recall_score
from seqeval.metrics.sequence_labeling import get_entities, performance_measure
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm import trange
from tqdm.notebook import tqdm
from crf import CRFInference

from tqdm import tqdm


In [25]:


from transformers import (
    AdamW,
    BertConfig,
    BertTokenizer,
    set_seed
)
from utils import (
    convert_examples_to_features,
    read_examples_from_file,
    BertForTokenClassification,
    get_labels, filtered_tp_counts)

logger = logging.getLogger(__name__)

from unittest.mock import MagicMock

In [33]:

def set_seeds(args):
    """Seed every random number generator this notebook relies on.

    Python's ``random``, NumPy and PyTorch are seeded with ``args.seed``.
    When ``args.n_gpu`` is positive the CUDA generators of all devices are
    seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)

def finetune_support(args, model, tokenizer, labels, pad_token_label_id):
    """Fine-tune ``model`` in place on the support set until the loss stops improving.

    The support examples are loaded with ``read_and_load_examples`` using
    ``mode=args.support_path`` and fed to the model as ONE batch, so every
    pass over the loader is exactly one optimizer step. Training stops as soon
    as the loss of a step is larger than the loss of the previous step
    (early stopping with a patience of one step).

    Args:
        args: configuration object; reads ``support_path``, ``weight_decay``,
            ``learning_rate_finetuning``, ``adam_epsilon``, ``device``,
            ``finetune_loss``, ``consider_mutual_O`` and ``max_grad_norm``
            (plus whatever ``read_and_load_examples`` / ``set_seeds`` read).
        model: model whose forward accepts ``labels``, ``loss_type`` and
            ``consider_mutual_O`` and returns the loss as its first output.
        tokenizer, labels, pad_token_label_id: forwarded to
            ``read_and_load_examples``.

    Raises:
        ValueError: if the support set contains no examples.
    """
    sup_dataset = read_and_load_examples(args, tokenizer, labels, pad_token_label_id, mode=args.support_path,
                                         mergeB=True)
    if len(sup_dataset) == 0:
        raise ValueError("Support set is empty; nothing to fine-tune on.")
    dataloader = DataLoader(sup_dataset, sampler=SequentialSampler(sup_dataset), batch_size=len(sup_dataset))

    # Prepare the optimizer (no learning-rate schedule is used here).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate_finetuning, eps=args.adam_epsilon)

    model.zero_grad()
    set_seeds(args)
    previous_loss = float("inf")  # a real infinity, so a large first loss cannot stop training early
    while True:
        for batch in dataloader:
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # here loss can be either KL, or euclidean.
            inputs = {"input_ids": batch[0], "attention_mask": batch[1],
                      "token_type_ids": batch[2], "labels": batch[3],
                      "loss_type": args.finetune_loss,
                      "consider_mutual_O": args.consider_mutual_O}

            loss = model(**inputs)[0]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()

        current_loss = loss.item()
        # Early stopping with single step patience. Written as "not <=" so that a
        # NaN loss also terminates the loop instead of spinning forever.
        if not (current_loss <= previous_loss):
            break
        previous_loss = current_loss

def train(args, train_dataset, model):
    """Train ``model`` on ``train_dataset`` for ``int(args.num_train_epochs)`` epochs.

    Batches are drawn with a ``RandomSampler``; each batch is one optimizer
    step with gradient clipping at ``args.max_grad_norm``. The model is
    updated in place.

    Args:
        args: configuration object; reads ``train_batch_size``,
            ``weight_decay``, ``learning_rate``, ``adam_epsilon``,
            ``model_name_or_path``, ``num_train_epochs``, ``device``,
            ``training_loss``, ``consider_mutual_O`` and ``max_grad_norm``.
        train_dataset: dataset yielding (input_ids, attention_mask,
            token_type_ids, labels) tensors, in that order.
        model: model whose forward returns the loss as its first output.

    Returns:
        ``(global_step, mean_loss)`` where ``mean_loss`` is the sum of the
        per-step losses divided by the number of steps, or ``0`` when no step
        was taken.
    """
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size)

    # Prepare the optimizer; biases and LayerNorm weights get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    # Restore a saved optimizer state when resuming from a local checkpoint directory.
    # NOTE(review): the check still requires scheduler.pt to exist even though no
    # scheduler is created or loaded here; kept as-is to preserve behavior.
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path))

    # Train!
    global_step = 0
    training_loss = 0.0
    model.zero_grad()
    train_iterator = trange(0, int(args.num_train_epochs), desc="Epoch", disable=False)
    set_seeds(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        for batch in epoch_iterator:
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1],
                      "token_type_ids": batch[2], "labels": batch[3], "loss_type": args.training_loss,
                      "consider_mutual_O": args.consider_mutual_O}

            loss = model(**inputs)[0]

            loss.backward()
            training_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            global_step += 1
    return global_step, training_loss / global_step if global_step > 0 else 0

def extract_target_labels(args, dataset, model):
    """Run ``model`` over ``dataset`` and collect per-token outputs of labelled tokens.

    Every token whose label equals ``CrossEntropyLoss().ignore_index`` is
    dropped; the remaining tokens of all sequences are flattened into one
    axis.

    Args:
        args: configuration object; reads ``eval_batch_size`` and ``device``.
        dataset: dataset yielding (input_ids, attention_mask, token_type_ids,
            labels) tensors, in that order.
        model: model whose forward returns three per-token tensors; by the
            names used in this file they are treated as (embedding mu,
            embedding sigma, hidden states).

    Returns:
        ``(mus, sigmas, hidden_states, labels)`` on ``args.device``. The first
        three are 2-D (kept tokens x feature size), ``labels`` is 1-D.

    Raises:
        ValueError: if ``dataset`` yields no batch.
    """
    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=args.eval_batch_size)
    ignore_index = CrossEntropyLoss().ignore_index  # looked up once, not once per token

    mu_chunks, sigma_chunks, hidden_chunks, label_chunks = [], [], [], []
    model.eval()
    for batch in tqdm(dataloader, desc="Support representations"):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            outputs = model(input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2])

        # Filter per batch so only the kept rows are accumulated.
        flat_labels = batch[3].detach().reshape(-1)
        keep = flat_labels != ignore_index
        for chunks, output in ((mu_chunks, outputs[0]), (sigma_chunks, outputs[1]), (hidden_chunks, outputs[2])):
            output = output.detach()
            chunks.append(output.reshape(-1, output.shape[-1])[keep])
        label_chunks.append(flat_labels[keep])

    if not label_chunks:
        raise ValueError("extract_target_labels() received an empty dataset.")

    # A single concatenation replaces the repeated (quadratic) np.append calls.
    vecs_mu = torch.cat(mu_chunks, dim=0).to(args.device)
    vecs_sigma = torch.cat(sigma_chunks, dim=0).to(args.device)
    vecs = torch.cat(hidden_chunks, dim=0).to(args.device)
    labels = torch.cat(label_chunks, dim=0).to(args.device)
    return vecs_mu, vecs_sigma, vecs, labels

def entitywise_max(scores, tags, addone=0, num_labels=None):
    """Reduce per-support-token scores to the best score per tag.

    Args:
        scores: ``n x m`` tensor, one row per query token and one column per
            support token; larger means more similar.
        tags: length-``m`` tensor with the tag id of each support token.
        addone: number of leading columns to reserve in the result; tag ``k``
            is written to column ``k + addone``.
        num_labels: number of tags. When ``None`` it is taken as
            ``max(tags) + 1``.

    Returns:
        ``n x (num_labels + addone)`` float tensor on ``scores.device``.
        Columns of tags without any support token, and the reserved leading
        columns, hold the filler value ``-100000``.
    """
    n, _ = scores.shape
    if num_labels is None:
        max_tag = int(torch.max(tags)) + 1
    else:
        max_tag = num_labels  # extra 1 is not needed since it's already 1 based counting
    filler = -100000.
    ret = torch.full((n, max_tag + addone), filler, device=scores.device)
    for t in range(addone, max_tag + addone):
        # Select by the tag mask itself, not by the sign of the product, so that
        # scores of exactly 0 (a perfect match) or above are not discarded.
        belongs_to_tag = (tags == (t - addone)).view(1, -1)
        masked = scores.masked_fill(~belongs_to_tag, filler)
        ret[:, t] = torch.max(masked, dim=1)[0]
    return ret


def nearest_neighbor(args, rep_mus, rep_sigmas, rep_hidden_states, support_rep_mus, support_rep_sigmas, support_rep, support_tags, evaluation_criteria, num_labels):
    """
    Nearest neighbor decoder for the best named entity tag sequences.

    Each query token is compared with every support token using
    ``evaluation_criteria``:

    * ``"KL"``: ``_loss_kl`` on (mu, sigma); smaller is closer.
    * ``"euclidean"``: ``_euclidean_metric`` on normalized mus; larger is closer.
    * ``"euclidean_hidden_state"``: ``_euclidean_metric`` on normalized hidden
      states; larger is closer.

    When ``args.temp_trans > 0`` the token-level scores are reduced to one
    score per tag with ``entitywise_max`` (one extra leading column), and the
    predicted tag is the arg-max column minus one.

    Returns:
        ``(tags, scores)``: ``tags`` has shape (batch, sent_len); ``scores``
        has shape (batch, sent_len, -1), where the last axis is per support
        token, or per tag column when ``args.temp_trans > 0``.

    Raises:
        ValueError: for an unknown ``evaluation_criteria``.
    """
    batch_size, sent_len, ndim = rep_mus.shape
    _, _, ndim_bert = rep_hidden_states.shape
    if evaluation_criteria == "KL":
        scores = _loss_kl(rep_mus.view(-1, ndim), rep_sigmas.view(-1,ndim), support_rep_mus, support_rep_sigmas, ndim)
        tags = support_tags[torch.argmin(scores, 1)]

    elif evaluation_criteria == "euclidean":
        scores = _euclidean_metric(rep_mus.view(-1, ndim), support_rep_mus, True)
        tags = support_tags[torch.argmax(scores, 1)]

    elif evaluation_criteria == "euclidean_hidden_state":
        scores = _euclidean_metric(rep_hidden_states.view(-1, ndim_bert), support_rep, True)
        tags = support_tags[torch.argmax(scores, 1)]

    else:
        raise ValueError("Unknown decoding criteria detected. Please =specify KL/ euclidean/ euclidean_hidden_state")

    if args.temp_trans > 0:
        # entitywise_max keeps the LARGEST score per tag, so the scores must be
        # "higher is better". KL is a divergence (lower is better), so negate it;
        # without this the per-tag maximum would select the farthest support token.
        similarity = -scores if evaluation_criteria == "KL" else scores
        scores = entitywise_max(similarity, support_tags, 1, num_labels)
        _, tags = torch.max(scores, 1)
        tags = tags - 1

    return tags.view(batch_size, sent_len), scores.view(batch_size, sent_len, -1)

def _euclidean_metric(a, b, normalize=False):
    if normalize:
        a = F.normalize(a)
        b = F.normalize(b)
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b) ** 2).sum(dim=2)
    return logits

def _loss_kl(mu_i, sigma_i, mu_j, sigma_j, embed_dimension):
    n = mu_i.shape[0]
    m = mu_j.shape[0]

    mu_i = mu_i.unsqueeze(1).expand(n,m, -1)
    sigma_i = sigma_i.unsqueeze(1).expand(n,m,-1)
    mu_j = mu_j.unsqueeze(0).expand(n,m,-1)
    sigma_j = sigma_j.unsqueeze(0).expand(n,m,-1)
    sigma_ratio = sigma_j / sigma_i
    trace_fac = torch.sum(sigma_ratio, 2)
    log_det = torch.sum(torch.log(sigma_ratio + 1e-14), axis=2)
    mu_diff_sq = torch.sum((mu_i - mu_j) ** 2 / sigma_i, axis=2)
    ij_kl = 0.5 * (trace_fac + mu_diff_sq - embed_dimension - log_det)
    sigma_ratio = sigma_i / sigma_j
    trace_fac = torch.sum(sigma_ratio, 2)
    log_det = torch.sum(torch.log(sigma_ratio + 1e-14), axis=2)
    mu_diff_sq = torch.sum((mu_j - mu_i) ** 2 / sigma_j, axis=2)
    ji_kl = 0.5 * (trace_fac + mu_diff_sq - embed_dimension - log_det)
    kl_d = 0.5 * (ij_kl + ji_kl)
    return kl_d


def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
    """Evaluate nearest-neighbour tagging of the ``mode`` split against the support set.

    Steps:
      1. Load the support set (``mode=args.support_path``) and extract its
         representations with ``extract_target_labels``.
      2. Run the model over the ``mode`` split and tag every token with
         ``nearest_neighbor``.
      3. Drop positions labelled ``pad_token_label_id`` and map ids to label
         strings. ``B-X`` labels are rewritten to ``I-X`` and ``I-`` labels
         are removed from the id list, matching ``mergeB=True`` loading.
      4. When ``args.temp_trans > 0``, re-decode each sentence with
         ``CRFInference`` (Viterbi) on the log-softmax of the per-tag scores.
      5. Compute and log seqeval metrics and the ``filtered_tp_counts`` sums.

    Returns:
        ``(results, preds_list)``: ``results`` is a dict of metrics and
        ``preds_list`` holds the predicted label strings per sentence.

    Raises:
        ValueError: if the evaluation split yields no batch.
    """
    sup_dataset = read_and_load_examples(args, tokenizer, labels, pad_token_label_id, mode=args.support_path, mergeB=True)
    sup_mus, sup_sigmas, sups, sup_labels = extract_target_labels(args, sup_dataset, model)
    eval_dataset = read_and_load_examples(args, tokenizer, labels, pad_token_label_id, mode=mode, mergeB=True)
    eval_dataloader = DataLoader(eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.eval_batch_size)

    # Eval!
    pred_batches, score_batches, gold_batches = [], [], []
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            outputs = model(input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2])
            # outputs[0] / outputs[1] / outputs[2] are passed as mu / sigma / hidden states.
            nn_predictions, nn_scores = nearest_neighbor(
                args, outputs[0], outputs[1], outputs[2], sup_mus, sup_sigmas, sups, sup_labels,
                evaluation_criteria=args.evaluation_criteria, num_labels=len(labels))
        pred_batches.append(nn_predictions.detach().cpu().numpy())
        score_batches.append(nn_scores.detach().cpu().numpy())
        gold_batches.append(batch[3].detach().cpu().numpy())

    if not gold_batches:
        raise ValueError("evaluate() received an empty evaluation set.")

    # Concatenate once instead of growing the arrays with np.append per batch.
    preds = np.concatenate(pred_batches, axis=0)
    scores = np.concatenate(score_batches, axis=0)
    out_label_ids = np.concatenate(gold_batches, axis=0)

    merged_labels = [label for label in labels if not label.startswith('I-')]
    conv_labels = []
    for label in merged_labels:
        if label.startswith('B-'):
            conv_labels.append('I-' + label[2:])
        else:
            conv_labels.append(label)
    label_map = {i: label for i, label in enumerate(conv_labels)}

    out_label_list, scores_list, preds_list = [], [], []
    for gold_row, score_row, pred_row in zip(out_label_ids, scores, preds):
        keep = gold_row != pad_token_label_id  # positions that carry a real label
        out_label_list.append([label_map[label_id] for label_id in gold_row[keep]])
        scores_list.append(score_row[keep])
        preds_list.append([label_map[label_id] for label_id in pred_row[keep]])

    if args.temp_trans > 0:
        # START: Viterbi!!!
        vit_preds_list = []
        crf = CRFInference(len(label_map) + 1, args.trans_priors, args.temp_trans)
        for sent_scores_np in scores_list:
            if len(sent_scores_np) == 0:
                # A sentence with no labelled token has nothing to decode.
                vit_preds_list.append([])
                continue
            sent_scores = torch.from_numpy(sent_scores_np)
            sent_probs = F.softmax(sent_scores, dim=1)
            sent_len, n_tag = sent_probs.shape
            feats = crf.forward(torch.log(sent_probs).view(1, sent_len, n_tag))
            vit_tags = crf.viterbi(feats)
            vit_tags = vit_tags.view(sent_len)
            vit_tags = vit_tags.detach().cpu().numpy()
            # Tag columns are shifted by one (see nearest_neighbor), hence "tag - 1".
            vit_preds_list.append([label_map[tag - 1] for tag in vit_tags])
        preds_list = vit_preds_list
        # END

    performance_dict = performance_measure(out_label_list, preds_list)
    pred_sum, tp_sum, true_sum = filtered_tp_counts(out_label_list, preds_list)
    results = {
        "precision": precision_score(out_label_list, preds_list),
        "recall": recall_score(out_label_list, preds_list),
        "f1": f1_score(out_label_list, preds_list),
        "TP": performance_dict['TP'],
        "TN": performance_dict['TN'],
        "FP": performance_dict['FP'],
        "FN": performance_dict['FN'],
        "pred_sum": pred_sum,
        "tp_sum": tp_sum,
        "true_sum": true_sum
    }

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_list


def read_and_load_examples(args, tokenizer, labels, pad_token_label_id, mode, mergeB=False):
    """Read the ``mode`` split from ``args.data_dir`` and pack it into a ``TensorDataset``.

    Examples come from ``read_examples_from_file`` and are converted with
    ``convert_examples_to_features`` using BERT-style settings: CLS first,
    one SEP, right padding, segment id 0.

    Returns:
        ``TensorDataset`` of four ``torch.long`` tensors, in this order:
        input ids, input mask, segment ids, label ids.
    """
    examples = read_examples_from_file(args.data_dir, mode)
    conversion_options = dict(
        cls_token_at_end=False,
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=False,
        pad_on_left=False,
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0,
        pad_token_label_id=pad_token_label_id,
        mergeB=mergeB,
    )
    # The label map returned alongside the features is not needed here.
    features, _ = convert_examples_to_features(
        examples, labels, args.max_seq_length, tokenizer, **conversion_options
    )

    # Convert to Tensors, one column of the dataset per feature field.
    field_names = ("input_ids", "input_mask", "segment_ids", "label_ids")
    columns = [
        torch.tensor([getattr(feature, field) for feature in features], dtype=torch.long)
        for field in field_names
    ]
    return TensorDataset(*columns)

def trans_stats(args, labels):
    '''
    Estimate abstract tag-transition frequencies from ``train.txt`` in ``args.data_dir``.

    Tags are read with ``get_tags(..., to_I=True)``, so ``B-`` tags count as
    ``I-`` tags. Every tag other than ``'O'`` is treated as an entity tag.

    Returns a list of seven relative frequencies, in this order:
        [start->O, start->entity,
         O->O, O->entity,
         entity->O, entity->same entity tag, entity->different entity tag]
    Each group (start / from O / from entity) is normalised by its own total.
    A group with no observations yields 0.0 for its entries instead of
    raising ``ZeroDivisionError``.

    Reference: https://aclanthology.org/2020.emnlp-main.516.pdf
    '''
    tag_lists = get_tags(os.path.join(args.data_dir, 'train.txt'), labels, True)
    s_o, s_i = 0., 0.
    o_o, o_i = 0., 0.
    i_o, i_i, x_y = 0., 0., 0.
    for tags in tag_lists:
        if tags[0] == 'O':
            s_o += 1
        else:
            s_i += 1
        for p, n in zip(tags, tags[1:]):
            if p == 'O':
                if n == 'O':
                    o_o += 1
                else:
                    o_i += 1
            elif n == 'O':
                i_o += 1
            elif p != n:
                x_y += 1
            else:
                i_i += 1

    def _share(count, total):
        # Relative frequency that tolerates an empty group.
        return count / total if total else 0.0

    start_total = s_o + s_i
    from_o_total = o_o + o_i
    from_entity_total = i_o + i_i + x_y
    return [
        _share(s_o, start_total),
        _share(s_i, start_total),
        _share(o_o, from_o_total),
        _share(o_i, from_o_total),
        _share(i_o, from_entity_total),
        _share(i_i, from_entity_total),
        _share(x_y, from_entity_total),
    ]


def get_tags(fname, labels, to_I=False):
    """Read the tag column of a whitespace-separated, CoNLL-style file.

    Sentences are separated by blank lines or lines starting with
    ``-DOCSTART-``. The tag is the SECOND whitespace-separated field of a
    line; lines with fewer than two fields are skipped.

    Args:
        fname: path of the file to read (decoded as UTF-8).
        labels: collection of known tags; any other tag is replaced by ``'O'``.
        to_I: when true, a leading ``B-`` is rewritten to ``I-``.

    Returns:
        List of non-empty tag lists, one per sentence.
    """
    known_labels = set(labels)  # constant-time membership test per token
    tag_lists = []
    tag_list = []
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # fail on non-ASCII tokens in the first column.
    with open(fname, encoding="utf-8") as f:
        for line in f:
            if line.startswith("-DOCSTART-") or line.strip() == "":
                if tag_list:
                    tag_lists.append(tag_list)
                    tag_list = []
                continue
            splits = line.split()
            if len(splits) > 1:
                tag = splits[1]
                if tag not in known_labels:
                    tag = 'O'
                if to_I and tag.startswith('B-'):
                    tag = 'I-' + tag[2:]
                tag_list.append(tag)
    # Flush the last sentence when the file does not end with a blank line.
    if tag_list:
        tag_lists.append(tag_list)

    return tag_lists


def convert_examples_to_features(
        examples,
        label_list,
        max_seq_length,
        tokenizer,
        cls_token_at_end=False,
        cls_token="[CLS]",
        cls_token_segment_id=1,
        sep_token="[SEP]",
        sep_token_extra=False,
        pad_on_left=False,
        pad_token=0,
        pad_token_segment_id=0,
        pad_token_label_id=-100,
        sequence_a_segment_id=0,
        mask_padding_with_zero=True,
        mergeB=False,
):
    """Tokenize ``examples`` and turn them into fixed-length ``InputFeatures``.

    Each example provides ``words`` and ``labels``. Every word is split with
    ``tokenizer.tokenize``; the first piece keeps the word's label id and the
    remaining pieces get ``pad_token_label_id``. Sequences are truncated to
    leave room for the special tokens and then padded to ``max_seq_length``.

    `cls_token_at_end` defines where the CLS token goes:
        - False (default, BERT/XLM pattern): [CLS] + A + [SEP]
        - True (XLNet/GPT pattern): A + [SEP] + [CLS]
    `cls_token_segment_id` is the segment id given to the CLS token.
    `mergeB`: when true, ``B-X`` and ``I-X`` share one id (that of ``I-X``).

    Returns ``(features, label_map)``. ``label_map`` is a ``defaultdict(int)``,
    so a label missing from ``label_list`` maps to id 0.

    NOTE(review): relies on ``defaultdict``, ``tqdm`` and ``InputFeatures``
    being available in the notebook namespace when called.
    """
    label_map = defaultdict(int)
    if mergeB:
        next_id = 0
        for label in label_list:
            if label.startswith('B-') or label.startswith('I-'):
                merged_label = 'I-' + label[2:]
                if merged_label not in label_map:
                    label_map[merged_label] = next_id
                    next_id += 1
                label_map[label] = label_map[merged_label]
            else:
                label_map[label] = next_id
                next_id += 1
    else:
        for label_id, label in enumerate(label_list):
            label_map[label] = label_id

    # Room left for word pieces: "- 2" for [CLS] and [SEP], "- 3" with the extra separator (RoBERTa).
    piece_budget = max_seq_length - (3 if sep_token_extra else 2)
    separator_count = 2 if sep_token_extra else 1

    features = []
    for (ex_index, example) in tqdm(enumerate(examples)):
        tokens, label_ids = [], []
        for word, label in zip(example.words, example.labels):
            pieces = tokenizer.tokenize(word)
            if len(pieces) == 0:
                continue
            tokens.extend(pieces)
            # Real label id on the first piece of the word, padding id on the rest.
            label_ids.append(label_map[label])
            label_ids.extend([pad_token_label_id] * (len(pieces) - 1))

        if len(tokens) > piece_budget:
            tokens = tokens[:piece_budget]
            label_ids = label_ids[:piece_budget]

        # BERT convention for a single sequence:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        # Special tokens never carry a real label, so they get the padding label id.
        tokens = tokens + [sep_token] * separator_count
        label_ids = label_ids + [pad_token_label_id] * separator_count
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens = tokens + [cls_token]
            label_ids = label_ids + [pad_token_label_id]
            segment_ids = segment_ids + [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            label_ids = [pad_token_label_id] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids
        assert len(tokens) == len(label_ids), str(tokens) + " vs" + str(label_ids)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # The mask marks real tokens (1 by default) versus padding (0 by default);
        # only real tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad every column up to the sequence length.
        pad_len = max_seq_length - len(input_ids)
        fillers = (
            [pad_token] * pad_len,
            [0 if mask_padding_with_zero else 1] * pad_len,
            [pad_token_segment_id] * pad_len,
            [pad_token_label_id] * pad_len,
        )
        columns = (input_ids, input_mask, segment_ids, label_ids)
        if pad_on_left:
            columns = tuple(filler + column for filler, column in zip(fillers, columns))
        else:
            columns = tuple(column + filler for filler, column in zip(fillers, columns))

        for column in columns:
            assert len(column) == max_seq_length
        input_ids, input_mask, segment_ids, label_ids = columns

        features.append(
            InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)
        )
    return features, label_map


In [27]:
class test_arg:
    """Plain attribute container standing in for a parsed-arguments object.

    Settings are attached as attributes after construction. ``__repr__``
    lists them so that logging the object shows the actual configuration
    instead of an object address.
    """

    def __init__(self):
        self.test = 0

    def __repr__(self):
        fields = ", ".join(
            "{}={!r}".format(name, value) for name, value in sorted(vars(self).items())
        )
        return "{}({})".format(type(self).__name__, fields)
        
    

In [28]:
# Run configuration for the Few-NERD (inter) data with bert-base-uncased.
# NOTE(review): the paths below are absolute and machine-specific; adjust them
# before running elsewhere.
args = test_arg()
args.data_dir = 'C:\\Users\\George\\Documents\\container_ner\\few-nerd\\inter\\'
args.labels_train = 'C:\\Users\\George\\Documents\\container_ner\\CONTaiNER\\data\\few-nerd\\inter\\labels_train.txt'
args.labels_test = 'C:\\Users\\George\\Documents\\container_ner\\CONTaiNER\\data\\few-nerd\\inter\\labels_test.txt'
args.model_name_or_path = 'bert-base-uncased'
args.config_name = None
args.tokenizer_name = ''
args.do_train = True
args.do_predict = True

# Output locations.
args.saved_model_dir = './test_save/'
args.output_dir = './test_output/'
args.cache_dir = './cache/'

# Model / optimisation hyper-parameters.
args.max_seq_length = 128
args.embedding_dimension = 128
args.num_train_epochs = 1
args.train_batch_size = 16
args.weight_decay = 0
args.adam_epsilon = 1e-8

args.max_grad_norm = 1.0
args.learning_rate = 5e-5
args.training_loss = 'KL'
args.finetune_loss = 'KL'
args.evaluation_criteria = 'euclidean_hidden_state'
args.n_shots = 5
# NOTE(review): 'store_true' is an argparse action name, not a boolean; these
# values are non-empty strings and therefore truthy. Confirm True is what is meant.
args.consider_mutual_O = 'store_true'
args.learning_rate_finetuning = 5e-5
args.silent = 'store_true'
args.temp_trans = 0.01
args.do_finetune_support_only = True
args.overwrite_output_dir = 'store_true'

# Device selection: use the first GPU when one is visible and CUDA is not
# disabled, otherwise fall back to the CPU instead of failing on "cuda:0".
args.no_cuda = False
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
args.n_gpu = min(1, args.n_gpu) # we are keeping ourselves restricted to only 1 gpu
device = torch.device("cuda:0" if args.n_gpu > 0 else "cpu")
args.device = device
args.best_validation_f1 = -1
args.seed = 42

In [30]:
# Run configuration for the 2018 n2c2 data (Few-NERD format) with PubMedBERT.
# This cell rebinds `args`, replacing any configuration created earlier.
# NOTE(review): the paths below are absolute and machine-specific; adjust them
# before running elsewhere.
args = test_arg()
args.data_dir = 'C:\\Users\\George\\Documents\\container_ner\\2018_n2c2\\few_nerd_format\\'
args.labels_train = 'C:\\Users\\George\\Documents\\container_ner\\2018_n2c2\\few_nerd_format\\labels_train.txt'
args.labels_test = 'C:\\Users\\George\\Documents\\container_ner\\2018_n2c2\\few_nerd_format\\labels_test.txt'
args.model_name_or_path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
args.config_name = None
args.tokenizer_name = ''
args.do_train = True
args.do_predict = True

# Output locations.
args.saved_model_dir = './test_save/'
args.output_dir = './test_output/'
args.cache_dir = './cache/'

# Model / optimisation hyper-parameters.
args.max_seq_length = 128
args.embedding_dimension = 128
args.num_train_epochs = 10
args.train_batch_size = 16
args.weight_decay = 0
args.adam_epsilon = 1e-8

args.max_grad_norm = 1.0
args.learning_rate = 5e-5
args.training_loss = 'KL'
args.finetune_loss = 'KL'
args.evaluation_criteria = 'euclidean_hidden_state'
args.n_shots = 5
# NOTE(review): 'store_true' is an argparse action name, not a boolean; these
# values are non-empty strings and therefore truthy. Confirm True is what is meant.
args.consider_mutual_O = 'store_true'
args.learning_rate_finetuning = 5e-5
args.silent = 'store_true'
args.temp_trans = 0.01
args.do_finetune_support_only = True
args.overwrite_output_dir = 'store_true'

# Device selection: use the first GPU when one is visible and CUDA is not
# disabled, otherwise fall back to the CPU instead of failing on "cuda:0".
args.no_cuda = False
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
args.n_gpu = min(1, args.n_gpu) # we are keeping ourselves restricted to only 1 gpu
device = torch.device("cuda:0" if args.n_gpu > 0 else "cpu")
args.device = device
args.best_validation_f1 = -1
args.seed = 42

In [7]:
# Configure root logging (timestamped records at INFO level and above) and
# report which device / how many GPUs this run will use.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)
# Emitted at WARNING level, so it is shown even if the level is raised above INFO.
logger.warning(
    "Device: %s, n_gpu: %s",
    args.device,
    args.n_gpu,
)




In [8]:
# Seed the RNGs, then load the train / test label inventories from the files
# named in `args`.
set_seeds(args)
labels_train = get_labels(args.labels_train)
labels_test = get_labels(args.labels_test)
# The model config below is sized by the number of TRAIN labels.
num_labels = len(labels_train)
# Label id used for positions that must be ignored (sub-word pieces, special
# tokens, padding); taken from CrossEntropyLoss's ignore_index.
pad_token_label_id = CrossEntropyLoss().ignore_index

In [9]:
# Release cached, currently unused GPU memory held by PyTorch before loading the model.
torch.cuda.empty_cache()

In [10]:

# Build the model configuration. The label <-> id tables come from the TRAIN
# label list, and the embedding dimension is handed to the model through
# `task_specific_params`.
config = BertConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        id2label={str(i): label for i, label in enumerate(labels_train)},
        label2id={label: i for i, label in enumerate(labels_train)},
        cache_dir=args.cache_dir if args.cache_dir else None,
        task_specific_params={"embedding_dimension": args.embedding_dimension}
    )
# Tokenizer options that are forwarded only when present (and not None) on `args`.
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
# An empty `args.tokenizer_name` falls back to the model name / path.
tokenizer = BertTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
    cache_dir=args.cache_dir if args.cache_dir else None,
    **tokenizer_args,
)
# `BertForTokenClassification` here is the project class imported from `utils`,
# not the transformers class of the same name.
model = BertForTokenClassification.from_pretrained(
    args.model_name_or_path,
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
    cache_dir=args.cache_dir if args.cache_dir else None
)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

In [11]:
# Move the model to the configured device and log the run configuration.
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)

07/17/2023 23:05:21 - INFO - __main__ -   Training/evaluation parameters <__main__.test_arg object at 0x00000207A60FE820>


In [12]:
# Load the pre-tokenised training set from a cached file. The commented block
# below is how the cache was produced; re-enable it to rebuild the file.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load caches you created yourself.
# # Training
# if args.do_train:
#     train_dataset = read_and_load_examples(args, tokenizer, labels_train, pad_token_label_id, mode="train", mergeB=True)
#     torch.save(train_dataset, 'train_dataset_n2c2.pt')

# train_dataset = torch.load('train_dataset.pt')
train_dataset = torch.load('train_dataset_n2c2.pt')

In [34]:
# Run training; `train` returns the number of optimizer steps and the mean loss per step.
global_step, tr_loss = train(args, train_dataset, model)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

Epoch:   0%|                                                                                    | 0/10 [00:00<?, ?it/s]
Iteration:   0%|                                                                              | 0/2598 [00:00<?, ?it/s][A
Iteration:   0%|                                                                      | 1/2598 [00:00<10:55,  3.96it/s][A
Iteration:   0%|                                                                      | 2/2598 [00:00<08:38,  5.00it/s][A
Iteration:   0%|                                                                      | 3/2598 [00:00<07:59,  5.41it/s][A
Iteration:   0%|                                                                      | 4/2598 [00:00<07:53,  5.48it/s][A
Iteration:   0%|▏                                                                     | 5/2598 [00:03<45:31,  1.05s/it][A
Iteration:   0%|▏                                                                     | 6/2598 [00:03<33:55,  1.27it/s][A
Iteration:   0%|▏  

KeyboardInterrupt: 

In [23]:
# Save the current model and tokenizer to a fixed directory.
# NOTE(review): absolute, machine-specific path; the directory name says
# "10ep" but this cell does not check how many epochs were actually run.
model.save_pretrained("C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\")
tokenizer.save_pretrained("C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\")


('C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\tokenizer_config.json',
 'C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\special_tokens_map.json',
 'C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\vocab.txt',
 'C:\\Users\\George\\Documents\\container_ner\\hf_ner_example\\container_bert_best_10ep\\added_tokens.json')

In [133]:
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train:
    # Create output directory if needed
    if args.saved_model_dir is not None:
        if not os.path.exists(args.saved_model_dir):
            os.makedirs(args.saved_model_dir)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    logger.info("Saving model checkpoint")

    model_to_save = (
        model.module if hasattr(model, "module") else model
    )
    if args.saved_model_dir is None:
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    else:
        model_to_save.save_pretrained(args.saved_model_dir)
        tokenizer.save_pretrained(args.saved_model_dir)
        torch.save(args, os.path.join(args.saved_model_dir, "training_args.bin"))


07/17/2023 13:29:35 - INFO - __main__ -   Saving model checkpoint


### Debug

In [104]:
import numpy as np
import logging
import os
import torch
import random
import torch.nn.functional as F
from collections import defaultdict, Counter
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel
from seqeval.metrics.sequence_labeling import get_entities


In [105]:
# Tokenize the example word by word, growing `tokens` and `label_ids` in step.
for word, label in zip(example.words, example.labels):
    word_tokens = tokenizer.tokenize(word)
    if not word_tokens:
        # The tokenizer produced nothing for this word, so it contributes nothing.
        continue
    tokens += word_tokens
    # The first token of the word carries the real label id ...
    label_ids.append(label_map[label])
    # ... and every remaining token of that word carries the padding label id.
    label_ids += [pad_token_label_id] * (len(word_tokens) - 1)


In [106]:
# Read the examples found under args.data_dir; the 'train' argument presumably
# selects the training split (confirm against read_examples_from_file in utils).
examples = read_examples_from_file(args.data_dir, 'train')

In [125]:
# Keyword settings for the feature conversion, gathered in one place so the
# call itself stays short.
feature_options = dict(
    cls_token_at_end=False,
    cls_token=tokenizer.cls_token,
    cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id,
    mergeB=True,
)

# Convert the loaded examples; the call returns the features and the label map.
features, label_map = convert_examples_to_features(
    examples, labels_train, args.max_seq_length, tokenizer, **feature_options
)

130112it [03:50, 563.97it/s]


In [119]:
# Plain-variable copies of the settings passed to convert_examples_to_features,
# so the single-example walk-through below can use them directly.
mergeB = True
label_list = labels_train
max_seq_length = args.max_seq_length

# Special-token strings come straight from the tokenizer.
cls_token_at_end = False
cls_token, sep_token = tokenizer.cls_token, tokenizer.sep_token
sep_token_extra = False

# Padding settings; pad_token is whatever id convert_tokens_to_ids returns for
# the tokenizer's pad token.
pad_on_left = False
pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
# Self-assignment: keeps the value already in the kernel (NameError if absent).
pad_token_label_id = pad_token_label_id
mask_padding_with_zero = True

# All three segment ids are 0 in this walk-through.
cls_token_segment_id = pad_token_segment_id = sequence_a_segment_id = 0

class InputFeatures:
    """Plain container holding the four per-example feature lists.

    The constructor stores its arguments unchanged (no copying, no validation)
    as the attributes ``input_ids``, ``input_mask``, ``segment_ids`` and
    ``label_ids``.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        # Model inputs.
        self.input_ids, self.input_mask = input_ids, input_mask
        # Segment ids and the per-token label ids.
        self.segment_ids, self.label_ids = segment_ids, label_ids


# Build the label -> id lookup. Because this is a defaultdict(int), looking up a
# label that was never registered yields 0.
label_map = defaultdict(int)
if mergeB:
    # Merge mode: 'B-X' and 'I-X' share the id registered under 'I-X'. Ids are
    # handed out in order of first appearance.
    next_id = 0
    for tag in label_list:
        if tag.startswith(('B-', 'I-')):
            merged_tag = 'I-' + tag[2:]
            if merged_tag not in label_map:
                label_map[merged_tag] = next_id
                next_id += 1
            label_map[tag] = label_map[merged_tag]
        else:
            # Labels without a B-/I- prefix each get their own id.
            label_map[tag] = next_id
            next_id += 1
else:
    # No merging: every label gets its position in label_list.
    label_map.update((tag, position) for position, tag in enumerate(label_list))

features = []

# Walk through one hard-coded example instead of the whole dataset.
# Stand-in for: for (ex_index, example) in tqdm(enumerate(examples)):
ex_index = 50778
example = examples[ex_index]

tokens, label_ids = [], []
for word, label in zip(example.words, example.labels):
    word_tokens = tokenizer.tokenize(word)
    if not word_tokens:
        # The tokenizer produced nothing for this word, so it contributes nothing.
        continue
    tokens += word_tokens
    # The first token of the word carries the real label id ...
    label_ids.append(label_map[label])
    # ... and every remaining token of that word carries the padding label id.
    label_ids += [pad_token_label_id] * (len(word_tokens) - 1)

# Reserve room for [CLS] and [SEP]: 2 slots, or 3 when an extra separator is used (RoBERTa).
if sep_token_extra:
    special_tokens_count = 3
else:
    special_tokens_count = 2
max_wordpieces = max_seq_length - special_tokens_count
if len(tokens) > max_wordpieces:
    # Too long: cut tokens and labels back to the same length.
    tokens = tokens[:max_wordpieces]
    label_ids = label_ids[:max_wordpieces]

# The convention in BERT is:
# (a) For sequence pairs:
#  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
#  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
# (b) For single sequences:
#  tokens:   [CLS] the dog is hairy . [SEP]
#  type_ids:   0   0   0   0  0     0   0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
# Close the sequence with one separator, or two when sep_token_extra is set
# (roberta uses an extra separator b/w pairs of sentences). Each separator
# carries the padding label id.
sep_count = 2 if sep_token_extra else 1
tokens.extend([sep_token] * sep_count)
label_ids.extend([pad_token_label_id] * sep_count)
segment_ids = [sequence_a_segment_id] * len(tokens)

# Attach the classification token at the configured end of the sequence,
# keeping tokens, labels and segment ids aligned.
if not cls_token_at_end:
    tokens = [cls_token, *tokens]
    label_ids = [pad_token_label_id, *label_ids]
    segment_ids = [cls_token_segment_id, *segment_ids]
else:
    tokens.append(cls_token)
    label_ids.append(pad_token_label_id)
    segment_ids.append(cls_token_segment_id)

# Every token must have exactly one label id.
assert len(tokens) == len(label_ids), str(tokens) + " vs" + str(label_ids)
# input_ids = tokenizer.convert_tokens_to_ids(tokens)
# # The mask has 1 for real tokens and 0 for padding tokens. Only real
# # tokens are attended to.
# input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

# # Zero-pad up to the sequence length.
# padding_length = max_seq_length - len(input_ids)
# if pad_on_left:
#     input_ids = ([pad_token] * padding_length) + input_ids
#     input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
#     segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
#     label_ids = ([pad_token_label_id] * padding_length) + label_ids
# else:
#     input_ids += [pad_token] * padding_length
#     input_mask += [0 if mask_padding_with_zero else 1] * padding_length
#     segment_ids += [pad_token_segment_id] * padding_length
#     label_ids += [pad_token_label_id] * padding_length

# assert len(input_ids) == max_seq_length
# assert len(input_mask) == max_seq_length
# assert len(segment_ids) == max_seq_length
# assert len(label_ids) == max_seq_length


# features.append(
#     InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids)
# )

In [121]:
# Number of tokens built for the inspected example (special tokens included).
len(tokens)

57

In [122]:
# Number of label ids for the same example; expected to equal len(tokens).
len(label_ids)

57