In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.optim as optim
from transformers import DistilBertTokenizer,AdamW,DistilBertPreTrainedModel,  DistilBertModel, DistilBertConfig
import numpy as np
import os
from tqdm import tqdm, trange
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
# Seed PyTorch's RNG so that weight initialisation is repeatable between runs.
torch.manual_seed(1)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sanity check: show how the tokenizer splits a sample word, and the vocabulary size.
print(tokenizer.tokenize('shubham'))
print(tokenizer.vocab_size)

I0305 19:02:47.621206 140580620220224 file_utils.py:35] PyTorch version 1.2.0 available.
I0305 19:02:50.027409 140580620220224 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/shubham/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


['shu', '##bham']
30522


In [2]:


class InputExample(object):
    """A single training/test example for token classification."""

    def __init__(self, guid, words, labels=None):
        """Construct an InputExample.

        Args:
            guid: Unique id for the example.
            words: list. The words of the sequence.
            labels: (Optional) list. The labels for each word of the sequence. This should be
            specified for train and dev examples, but not for test examples. Defaults to
            ``None`` so that unlabeled examples can be built without a placeholder.
        """
        self.guid = guid
        self.words = words
        self.labels = labels


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        # Keep the four feature sequences as given, under attributes named after the arguments.
        (self.input_ids,
         self.input_mask,
         self.segment_ids,
         self.label_ids) = (input_ids, input_mask, segment_ids, label_ids)

def read_examples_from_file(data_dir, mode):
    """Read ``<data_dir>/<mode>.txt`` and group its lines into ``InputExample`` objects.

    The file is consumed line by line. A line that starts with ``-DOCSTART-`` or is
    empty closes the sentence collected so far. Every other line is split on single
    spaces: the first field is the word and the third field from the end is its tag.

    Tag normalisation: a tag is kept when it is alphanumeric, or when it contains
    ``$`` and is longer than one character (e.g. ``PRP$``); any other tag becomes
    ``"PUNC"``. A line with a single field also gets ``"PUNC"``.

    Args:
        data_dir: Directory that contains the data file.
        mode: File stem such as ``"train"`` or ``"dev"``; also used as the guid prefix.

    Returns:
        list of InputExample, one per sentence, with guids ``"<mode>-<n>"`` counted from 1.
    """
    file_path = os.path.join(data_dir, "{}.txt".format(mode))
    guid_index = 1
    examples = []
    with open(file_path, encoding="utf-8") as f:
        words = []
        labels = []
        for line in f:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                # Sentence boundary: emit what has been collected (if anything) and reset.
                if words:
                    examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
                                                 words=words,
                                                 labels=labels))
                    guid_index += 1
                    words = []
                    labels = []
            else:
                splits = line.split(" ")
                words.append(splits[0])
                if len(splits) > 1:
                    # NOTE(review): a line with exactly two fields raises IndexError on
                    # splits[-3]; the reader expects at least three fields per token line.
                    lab = splits[-3].replace("\n", "")
                    keep_tag = ('$' in lab and len(lab) > 1) or lab.isalnum()
                    labels.append(lab if keep_tag else "PUNC")
                else:
                    # Examples could have no label for mode = "test"
                    labels.append("PUNC")
        if words:
            # The file did not end with a blank line: flush the last sentence. The guid
            # uses the same "{}-{}" template as above ("%s-%d".format(...) would emit
            # the literal text "%s-%d" instead of the mode and index).
            examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
                                         words=words,
                                         labels=labels))
    return examples

In [3]:
# Parse the training split into InputExample objects.
train_obj = read_examples_from_file('/home/shubham/Project/pos_tag/data/ner', 'train')

# Tag inventory; the position of a tag in this list is its integer class id.
lab_list = [
    'NNS', 'CD', 'TO', 'VBD', 'WP$', 'LS', 'RP', 'SYM', 'VBN', 'NNPS',
    'RBR', 'JJS', 'VBP', 'MD', 'JJ', 'CC', 'VBG', 'IN', 'WP', 'PRP',
    'PUNC', 'POS', 'FW', 'JJR', 'EX', 'WRB', 'DT', 'UH', 'VB', 'VBZ',
    'RB', 'RBS', 'NN', 'WDT', 'NNP', 'PRP$', 'PDT',
]

# Forward (tag -> id) and reverse (id -> tag) lookup tables.
label_map = dict(zip(lab_list, range(len(lab_list))))
ix_to_tag = dict(enumerate(lab_list))

def input_to_features(words,labels,pad_token_label_id=-1):
    """Turn one tagged sentence into token-id tensors and per-token label ids.

    Each word is split with the module-level ``tokenizer``. Every resulting piece
    receives the label id of its word (looked up in ``label_map``), so the label
    tensor always has exactly one entry per piece.

    Args:
        words: list of word strings.
        labels: list of tag strings parallel to ``words``; each must be a key of ``label_map``.
        pad_token_label_id: Accepted for interface compatibility; not used, because
            continuation pieces receive the word's real label rather than a padding id.

    Returns:
        Tuple ``(dbert_input_ids, input_ids, label_ids)`` of ``torch.long`` tensors on
        ``device``. ``input_ids`` holds the ids of the pieces; ``dbert_input_ids`` is
        the same sequence wrapped in the ids of ``[CLS]`` and ``[SEP]``.
    """
    tokens = []
    label_ids = []
    for word, label in zip(words, labels):
        word_tokens = tokenizer.tokenize(word)
        tokens.extend(word_tokens)
        # One label per piece. Multiplying by the piece count (instead of "1 + (n - 1)")
        # keeps tokens and labels aligned even when a word produces no pieces at all.
        label_ids.extend([label_map[label]] * len(word_tokens))
    # Convert once, after the loop: doing this per word repeated the work for the whole
    # prefix each time and left the variables undefined for an empty sentence.
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    dbert_input_ids = (tokenizer.encode("[CLS]",add_special_tokens=False)
                       + input_ids
                       + tokenizer.encode("[SEP]",add_special_tokens=False))
    return (torch.tensor(dbert_input_ids, dtype=torch.long,device=device),
            torch.tensor(input_ids, dtype=torch.long,device=device),
            torch.tensor(label_ids, dtype=torch.long,device=device))

In [4]:
#train_obj = read_examples_from_file('/home/shubham/Project/pos_tag/data/ner','train')
#lab_list = ['NNS', 'CD', 'TO', 'VBD', 'WP$', 'LS', 'RP', 'SYM', 'VBN', 'NNPS', 'RBR', 'JJS', 'VBP', 'MD', 'JJ', 'CC', 'VBG', 'IN', 'WP', 'PRP', 'PUNC', 'POS', 'FW', 'JJR', 'EX', 'WRB', 'DT', 'UH', 'VB', 'VBZ', 'RB', 'RBS', 'NN', 'WDT', 'NNP', 'PRP$', 'PDT']

# Pair every sentence's words with its tags: [(words, labels), ...].
training_data = [(example.words, example.labels) for example in train_obj]

# Smoke test on the first sentence: print the ids produced by input_to_features next to
# the label ids, then the raw words, then the ids obtained by encoding the word list directly.
first_words, first_labels = training_data[0]
_,a,b = input_to_features(first_words, first_labels)
print(a,b)
print(first_words)
print(tokenizer.encode(first_words,add_special_tokens=False))


tensor([ 7327, 19164,  2446,  2655,  2000, 17757,  2329, 12559,  1012],
       device='cuda:0') tensor([34, 29, 14, 32,  2, 28, 14, 32, 20], device='cuda:0')
['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
[100, 19164, 100, 2655, 2000, 17757, 100, 12559, 1012]


In [5]:
class DistilBertForTokenClassificationCustom(DistilBertPreTrainedModel):
    """DistilBERT encoder followed by dropout and a linear per-token classifier.

    ``forward`` returns the classification logits only; no loss is computed.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Reads ``config.num_labels``, ``config.dropout`` and ``config.hidden_size``.
        """
        super(DistilBertForTokenClassificationCustom, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, head_mask=None,
                inputs_embeds=None, labels=None):
        """Compute per-token logits.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional mask, forwarded to the encoder unchanged.
            head_mask: Optional head mask, forwarded to the encoder unchanged.
            inputs_embeds: Optional embeddings, forwarded to the encoder unchanged.
            labels: Accepted for interface compatibility; not used.

        Returns:
            The logits tensor produced by the classifier (returned directly, not
            wrapped in a tuple).
        """
        # Forward the optional arguments instead of replacing them with None, so that a
        # caller-supplied mask is actually honoured.
        outputs = self.distilbert(input_ids,
                                  attention_mask=attention_mask,
                                  head_mask=head_mask,
                                  inputs_embeds=inputs_embeds)

        # First element of the encoder output: the hidden states fed to the classifier.
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits
    

In [6]:

# Rebuild the 37-label DistilBERT tagger and restore its saved weights; this network is
# the teacher queried during distillation.
config = DistilBertConfig(num_labels=37)
teacher_model = DistilBertForTokenClassificationCustom(config)
state_dict = torch.load(
    "/home/shubham/Project/pos_tag/models/distil/pytorch_model.bin",
    map_location=device,
)
teacher_model.load_state_dict(state_dict)
# Move to the target device and switch to evaluation mode; as the last expression of the
# cell, the returned module is displayed.
teacher_model.to(device).eval()

DistilBertForTokenClassificationCustom(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (dropout): Dropout(p=0.1, inplace=False)
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
      

In [7]:
# Sizes of the student tagger: token-embedding width and LSTM hidden width (per direction).
EMBEDDING_DIM, HIDDEN_DIM = 256, 512
# Optimiser settings; in this notebook they are referenced only by the disabled AdamW setup.
learning_rate, adam_epsilon, weight_decay = 5e-5, 1e-8, 0.0


class LSTMTagger(nn.Module):
    """Bidirectional multi-layer LSTM tagger over token ids.

    Pipeline: embedding lookup -> bidirectional LSTM -> linear projection to tag scores.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, num_layers=5):
        """Create the tagger.

        Args:
            embedding_dim: Width of the token embeddings.
            hidden_dim: Hidden size of the LSTM in each direction.
            vocab_size: Number of rows in the embedding table.
            tagset_size: Number of output tags.
            num_layers: Number of stacked LSTM layers. Defaults to 5, the depth that
                was previously hard-coded, so existing callers are unaffected.
        """
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim (per direction).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=num_layers, bidirectional=True)

        # The linear layer that maps from hidden state space to tag space; the input
        # width is doubled because forward and backward states are concatenated.
        self.hidden2tag = nn.Linear(hidden_dim*2, tagset_size)

    def forward(self, sentence):
        """Tag one sentence.

        Args:
            sentence: 1-D tensor of token ids (a single sequence, no batch dimension).

        Returns:
            Tuple ``(tag_scores, tag_space)``, both of shape ``(len(sentence), tagset_size)``:
            ``tag_space`` holds the raw logits and ``tag_scores`` their log-softmax
            over the tag dimension.
        """
        embeds = self.word_embeddings(sentence)
        # Reshape to (sequence, batch=1, features), the layout the LSTM is fed with here.
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores,tag_space

In [8]:
def evaluation(model, data_dir='/home/shubham/Project/pos_tag/data/ner', mode='dev'):
    """Evaluate ``model`` on a labelled split and print/return its F1 score.

    The model is switched to evaluation mode and is left in that mode on return.

    Args:
        model: Callable returning ``(tag_scores, logits)`` for a 1-D tensor of token
            ids, where ``tag_scores`` has one row of scores per token.
        data_dir: Directory holding the data files. Defaults to the path that was
            previously hard-coded.
        mode: File stem of the split to score. Defaults to ``'dev'``.

    Returns:
        The value returned by ``f1_score`` for the gold and predicted tag sequences.
    """
    model.eval()
    val_obj = read_examples_from_file(data_dir, mode)
    validation_data = [(example.words, example.labels) for example in val_obj]
    out_list = []
    pred_list = []
    for sentence, tags in validation_data:
        dbert_input_ids,sentence_in,targets = input_to_features(sentence,tags)
        # Gold tags, expanded to one tag per token exactly as input_to_features labelled them.
        out_list.append([ix_to_tag[i] for i in targets.tolist()])
        with torch.no_grad():
            tag_scores,_ = model(sentence_in)
        # One argmax over the tag dimension for the whole sentence, instead of copying
        # every row to the host separately.
        best_ids = tag_scores.argmax(dim=1).tolist()
        pred_list.append([ix_to_tag[i] for i in best_ids])
    # NOTE(review): f1_score/classification_report are imported from seqeval; confirm
    # that its scoring is the intended metric for these flat tags.
    sc = f1_score(out_list,pred_list)
    print(sc)
    print(classification_report(out_list,pred_list))
    return sc

In [9]:
# Student tagger: vocabulary size taken from the tokenizer, one output class per entry
# of label_map; created and moved to the target device in one step.
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, tokenizer.vocab_size, len(label_map)).to(device)
#model.load_state_dict(torch.load("/home/shubham/Project/pos_tag/code/distilation_experiments/lstm_models/model_0.8884873404025104.pt"))
#loss_function = nn.NLLLoss()
#loss_function = CrossEntropyLoss()

def custom_loss(lstm_prob, bert_prob, real_label, alpha=0.8):
    """Distillation loss: weighted sum of a hard-label term and a teacher-matching term.

    ``alpha * CrossEntropy(lstm_prob, real_label) + (1 - alpha) * MSE(lstm_prob, bert_prob)``

    Despite the parameter names, the training loop in this notebook passes raw logits
    (not probabilities) for both ``lstm_prob`` and ``bert_prob``.

    Args:
        lstm_prob: Student scores, one row per token and one column per tag.
        bert_prob: Teacher scores with the same shape as ``lstm_prob``.
        real_label: Gold class index for every token.
        alpha: Weight of the cross-entropy term; the MSE term gets ``1 - alpha``.
            Defaults to 0.8, the value that was previously hard-coded.

    Returns:
        Scalar loss tensor.
    """
    criterion_mse = nn.MSELoss()
    criterion_ce = CrossEntropyLoss()
    hard_label_term = criterion_ce(lstm_prob, real_label)
    teacher_term = criterion_mse(lstm_prob, bert_prob)
    return alpha*hard_label_term + (1-alpha)*teacher_term


# Adam over all student parameters, with the optimiser's default settings.
#optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = optim.Adam(model.parameters())
"""
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
"""

# Distillation training: 100 passes over the training sentences, one sentence per update.
for epoch in tqdm(range(100)):
    model.train()
    for sentence, tags in training_data:
        model.zero_grad()
        dbert_input_ids, sentence_in, targets = input_to_features(sentence, tags)

        # Teacher forward pass on a batch of one, without building an autograd graph.
        with torch.no_grad():
            dbert_logits = teacher_model(dbert_input_ids.unsqueeze(0))

        tag_scores, logits = model(sentence_in)

        # Take the only batch item and drop its first and last positions ([CLS]/[SEP])
        # so that the teacher rows line up with the student's tokens.
        teacher_token_logits = dbert_logits[0][1:-1]
        loss = custom_loss(logits, teacher_token_logits, targets)
        loss.backward()
        optimizer.step()

    # Loss of the last sentence processed in this epoch.
    print('loss',loss)
    f1 = evaluation(model)
    # Keep a checkpoint, named after its score, whenever the dev F1 exceeds 0.89.
    if f1 > 0.89:
        checkpoint_path = '/home/shubham/Project/pos_tag/code/distilation_experiments/lstm_models/model_'+str(f1)+'.pt'
        torch.save(model.state_dict(), checkpoint_path)

  0%|          | 0/100 [00:00<?, ?it/s]

loss tensor(0.5234, device='cuda:0', grad_fn=<AddBackward0>)
0.7572721467421224


  1%|          | 1/100 [10:47<17:48:16, 647.44s/it]

           precision    recall  f1-score   support

       RB       0.56      0.62      0.59       939
       JJ       0.52      0.53      0.52      2765
       IN       0.94      0.96      0.95      4892
       DT       0.99      0.98      0.98      3511
      POS       0.96      0.98      0.97       423
       TO       0.98      0.99      0.99       905
       NN       0.66      0.54      0.60      5358
      NNP       0.48      0.63      0.54      5817
       CD       0.81      0.85      0.83      2935
      VBP       0.76      0.64      0.70       365
       EX       0.93      0.33      0.48        40
      NNS       0.69      0.63      0.66      2486
     PUNC       0.96      0.99      0.97      6029
      WDT       1.00      0.66      0.80       155
      VBG       0.27      0.44      0.33       699
       MD       0.95      0.93      0.94       300
      PRP       0.99      0.92      0.96       862
      VBZ       0.74      0.67      0.70       509
       CC       0.97      0.96

  2%|▏         | 2/100 [21:32<17:36:07, 646.60s/it]

           precision    recall  f1-score   support

       RB       0.72      0.71      0.72       939
       JJ       0.63      0.65      0.64      2765
       IN       0.95      0.98      0.96      4892
       DT       0.99      0.98      0.99      3511
      POS       0.98      0.99      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.74      0.72      0.73      5358
      NNP       0.62      0.72      0.67      5817
       CD       0.82      0.91      0.86      2935
      VBP       0.73      0.78      0.75       365
       EX       0.87      1.00      0.93        40
      NNS       0.75      0.78      0.77      2486
     PUNC       0.97      0.99      0.98      6029
      WDT       0.96      0.88      0.92       155
      VBG       0.41      0.59      0.48       699
       MD       0.97      0.96      0.96       300
      PRP       0.98      0.97      0.98       862
      VBZ       0.66      0.78      0.71       509
       CC       0.99      0.98

  3%|▎         | 3/100 [32:17<17:24:45, 646.24s/it]

           precision    recall  f1-score   support

       RB       0.69      0.76      0.72       939
       JJ       0.59      0.69      0.64      2765
       IN       0.96      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.99      0.99      0.99       423
       TO       1.00      1.00      1.00       905
       NN       0.77      0.74      0.76      5358
      NNP       0.68      0.73      0.70      5817
       CD       0.82      0.88      0.85      2935
      VBP       0.70      0.78      0.74       365
       EX       1.00      0.90      0.95        40
      NNS       0.80      0.79      0.79      2486
     PUNC       0.98      0.98      0.98      6029
      WDT       0.97      0.88      0.92       155
      VBG       0.54      0.65      0.59       699
       MD       0.97      0.98      0.97       300
      PRP       0.99      0.96      0.98       862
      VBZ       0.64      0.84      0.73       509
       CC       0.99      0.98

  4%|▍         | 4/100 [43:02<17:13:31, 645.95s/it]

           precision    recall  f1-score   support

       RB       0.68      0.77      0.72       939
       JJ       0.63      0.69      0.66      2765
       IN       0.96      0.98      0.97      4892
       DT       0.99      0.98      0.99      3511
      POS       0.99      0.98      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.77      0.75      0.76      5358
      NNP       0.69      0.75      0.72      5817
       CD       0.84      0.89      0.86      2935
      VBP       0.82      0.75      0.78       365
       EX       0.88      0.95      0.92        40
      NNS       0.81      0.82      0.81      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.95      0.90      0.92       155
      VBG       0.54      0.69      0.61       699
       MD       0.98      0.96      0.97       300
      PRP       0.98      0.98      0.98       862
      VBZ       0.77      0.81      0.79       509
       CC       0.99      0.98

  5%|▌         | 5/100 [53:47<17:02:14, 645.62s/it]

           precision    recall  f1-score   support

       RB       0.77      0.76      0.76       939
       JJ       0.62      0.72      0.66      2765
       IN       0.97      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.98      0.98      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.77      0.77      0.77      5358
      NNP       0.73      0.75      0.74      5817
       CD       0.87      0.92      0.89      2935
      VBP       0.85      0.77      0.81       365
       EX       0.95      0.90      0.92        40
      NNS       0.86      0.80      0.83      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.95      0.90      0.92       155
      VBG       0.68      0.68      0.68       699
       MD       0.98      0.97      0.97       300
      PRP       0.99      0.97      0.98       862
      VBZ       0.71      0.84      0.77       509
       CC       1.00      0.99

  6%|▌         | 6/100 [1:04:32<16:51:12, 645.46s/it]

           precision    recall  f1-score   support

       RB       0.76      0.78      0.77       939
       JJ       0.63      0.73      0.68      2765
       IN       0.97      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.99      0.98      0.99       423
       TO       0.99      1.00      1.00       905
       NN       0.80      0.77      0.79      5358
      NNP       0.73      0.75      0.74      5817
       CD       0.86      0.93      0.89      2935
      VBP       0.83      0.77      0.80       365
       EX       1.00      0.68      0.81        40
      NNS       0.82      0.82      0.82      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.93      0.92      0.93       155
      VBG       0.68      0.72      0.70       699
       MD       0.98      0.97      0.97       300
      PRP       0.98      0.98      0.98       862
      VBZ       0.72      0.85      0.78       509
       CC       1.00      0.99

  7%|▋         | 7/100 [1:15:17<16:40:06, 645.23s/it]

           precision    recall  f1-score   support

       RB       0.76      0.79      0.77       939
       JJ       0.66      0.74      0.70      2765
       IN       0.97      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.99      0.97      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.80      0.79      0.79      5358
      NNP       0.74      0.76      0.75      5817
       CD       0.88      0.92      0.90      2935
      VBP       0.79      0.80      0.80       365
       EX       0.97      0.95      0.96        40
      NNS       0.81      0.82      0.82      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.96      0.92      0.94       155
      VBG       0.70      0.72      0.71       699
       MD       0.98      0.97      0.97       300
      PRP       0.99      0.97      0.98       862
      VBZ       0.75      0.83      0.79       509
       CC       0.99      0.99

  8%|▊         | 8/100 [1:26:02<16:29:12, 645.14s/it]

           precision    recall  f1-score   support

       RB       0.76      0.77      0.77       939
       JJ       0.62      0.74      0.67      2765
       IN       0.97      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.98      0.99      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.78      0.78      0.78      5358
      NNP       0.75      0.75      0.75      5817
       CD       0.88      0.92      0.90      2935
      VBP       0.84      0.75      0.80       365
       EX       0.97      0.85      0.91        40
      NNS       0.83      0.82      0.83      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.95      0.92      0.93       155
      VBG       0.69      0.69      0.69       699
       MD       0.98      0.97      0.97       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.78      0.81      0.79       509
       CC       0.99      0.98

  9%|▉         | 9/100 [1:36:47<16:18:18, 645.04s/it]

           precision    recall  f1-score   support

       RB       0.77      0.78      0.77       939
       JJ       0.67      0.71      0.69      2765
       IN       0.98      0.98      0.98      4892
       DT       0.99      0.99      0.99      3511
      POS       0.99      0.98      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.81      0.77      0.79      5358
      NNP       0.71      0.78      0.74      5817
       CD       0.88      0.92      0.90      2935
      VBP       0.78      0.79      0.78       365
       EX       0.92      0.90      0.91        40
      NNS       0.86      0.82      0.84      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.95      0.94      0.94       155
      VBG       0.70      0.74      0.72       699
       MD       0.98      0.97      0.97       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.76      0.83      0.80       509
       CC       0.99      0.99

 10%|█         | 10/100 [1:47:31<16:07:28, 644.98s/it]

           precision    recall  f1-score   support

       RB       0.78      0.77      0.78       939
       JJ       0.66      0.75      0.70      2765
       IN       0.97      0.97      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.97      0.99      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.79      0.80      0.79      5358
      NNP       0.74      0.77      0.75      5817
       CD       0.88      0.93      0.90      2935
      VBP       0.84      0.75      0.79       365
       EX       0.90      0.95      0.93        40
      NNS       0.86      0.81      0.83      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.97      0.89      0.93       155
      VBG       0.73      0.73      0.73       699
       MD       0.99      0.98      0.98       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.77      0.84      0.80       509
       CC       0.99      0.99

 11%|█         | 11/100 [1:58:16<15:56:27, 644.80s/it]

           precision    recall  f1-score   support

       RB       0.79      0.79      0.79       939
       JJ       0.64      0.74      0.69      2765
       IN       0.97      0.97      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.97      1.00      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.79      0.79      0.79      5358
      NNP       0.74      0.76      0.75      5817
       CD       0.89      0.91      0.90      2935
      VBP       0.83      0.78      0.80       365
       EX       0.86      0.95      0.90        40
      NNS       0.82      0.84      0.83      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.96      0.92      0.94       155
      VBG       0.80      0.71      0.75       699
       MD       0.98      0.98      0.98       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.77      0.82      0.80       509
       CC       0.99      0.98

 12%|█▏        | 12/100 [2:09:06<15:47:55, 646.31s/it]

           precision    recall  f1-score   support

       RB       0.80      0.77      0.78       939
       JJ       0.67      0.73      0.70      2765
       IN       0.97      0.98      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.97      1.00      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.81      0.78      0.80      5358
      NNP       0.73      0.77      0.75      5817
       CD       0.88      0.91      0.90      2935
      VBP       0.81      0.77      0.79       365
       EX       0.90      0.93      0.91        40
      NNS       0.81      0.83      0.82      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.95      0.91      0.93       155
      VBG       0.69      0.75      0.72       699
       MD       0.97      0.97      0.97       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.77      0.84      0.80       509
       CC       1.00      0.99

 13%|█▎        | 13/100 [2:19:50<15:36:28, 645.84s/it]

           precision    recall  f1-score   support

       RB       0.79      0.80      0.80       939
       JJ       0.66      0.74      0.70      2765
       IN       0.98      0.97      0.98      4892
       DT       0.99      0.99      0.99      3511
      POS       0.98      0.99      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.81      0.79      0.80      5358
      NNP       0.74      0.78      0.76      5817
       CD       0.89      0.92      0.90      2935
      VBP       0.82      0.78      0.80       365
       EX       0.90      0.93      0.91        40
      NNS       0.85      0.83      0.84      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.94      0.94      0.94       155
      VBG       0.75      0.74      0.74       699
       MD       0.98      0.98      0.98       300
      PRP       0.98      0.98      0.98       862
      VBZ       0.77      0.83      0.80       509
       CC       0.99      0.98

 14%|█▍        | 14/100 [2:30:39<15:26:59, 646.73s/it]

           precision    recall  f1-score   support

       RB       0.77      0.80      0.79       939
       JJ       0.67      0.75      0.71      2765
       IN       0.98      0.97      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.99      1.00      0.99       423
       TO       0.99      1.00      1.00       905
       NN       0.79      0.79      0.79      5358
      NNP       0.75      0.76      0.75      5817
       CD       0.88      0.93      0.91      2935
      VBP       0.83      0.76      0.79       365
       EX       0.95      0.93      0.94        40
      NNS       0.81      0.83      0.82      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.95      0.92      0.94       155
      VBG       0.77      0.72      0.75       699
       MD       0.98      0.97      0.98       300
      PRP       0.98      0.98      0.98       862
      VBZ       0.74      0.82      0.78       509
       CC       0.99      0.98

 15%|█▌        | 15/100 [2:41:23<15:14:48, 645.75s/it]

           precision    recall  f1-score   support

       RB       0.81      0.81      0.81       939
       JJ       0.63      0.75      0.69      2765
       IN       0.98      0.97      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.97      0.99      0.98       423
       TO       1.00      1.00      1.00       905
       NN       0.81      0.78      0.80      5358
      NNP       0.74      0.76      0.75      5817
       CD       0.88      0.92      0.90      2935
      VBP       0.82      0.77      0.79       365
       EX       1.00      0.88      0.93        40
      NNS       0.82      0.84      0.83      2486
     PUNC       0.99      0.99      0.99      6029
      WDT       0.94      0.93      0.93       155
      VBG       0.78      0.72      0.75       699
       MD       0.99      0.97      0.98       300
      PRP       0.99      0.98      0.99       862
      VBZ       0.81      0.81      0.81       509
       CC       0.99      0.98

 16%|█▌        | 16/100 [2:52:07<15:03:14, 645.17s/it]

           precision    recall  f1-score   support

       RB       0.76      0.79      0.78       939
       JJ       0.67      0.73      0.70      2765
       IN       0.97      0.97      0.97      4892
       DT       0.99      0.99      0.99      3511
      POS       0.98      0.99      0.98       423
       TO       0.99      1.00      1.00       905
       NN       0.79      0.80      0.80      5358
      NNP       0.75      0.76      0.75      5817
       CD       0.88      0.92      0.90      2935
      VBP       0.83      0.74      0.78       365
       EX       0.93      0.95      0.94        40
      NNS       0.81      0.84      0.82      2486
     PUNC       0.98      0.99      0.99      6029
      WDT       0.93      0.94      0.93       155
      VBG       0.72      0.74      0.73       699
       MD       0.96      0.98      0.97       300
      PRP       0.99      0.98      0.98       862
      VBZ       0.84      0.81      0.83       509
       CC       1.00      0.98

KeyboardInterrupt: 

In [None]:
"""# See what the scores are after training
with torch.no_grad():
    #inputs = prepare_sequence(training_data[0][0], word_to_ix)
    inputs = torch.tensor(tokenizer.encode(training_data[0][0],add_special_tokens=False), dtype=torch.long)
    print('input ',inputs)
    tag_scores = model(inputs)

    # The sentence is "the dog ate the apple".  i,j corresponds to score for tag j
    # for word i. The predicted tag is the maximum scoring tag.
    # Here, we can see the predicted sequence below is 0 1 2 0 1
    # since 0 is index of the maximum value of row 1,
    # 1 is the index of maximum value of row 2, etc.
    # Which is DET NOUN VERB DET NOUN, the correct sequence!
    for i in tag_scores:
        print(int(np.argmax(i)))
    print(tag_scores.shape)"""

In [None]:
"""out_list = []
for i in tag_scores:
    out_list.append(ix_to_tag[int(np.argmax(i))])
print((training_data[0][0],training_data[0][1]))
print(out_list)
sc = f1_score(out_list,training_data[0][1])
print(sc)"""

In [None]:
# Score the current student model on the dev split; the returned F1 is the cell's output.
evaluation(model)