In [1]:
import torch
from transformers import get_constant_schedule_with_warmup
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import os
from dataloader import CoNLLReader
from tqdm import tqdm
from tutils import *

In [2]:
try:
    del globals()['NERmodelbase']
except:
    pass
from NERmodel3 import NERmodelbase3

In [3]:
encoder_model = 'xlm-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(encoder_model)

In [4]:
def collate_batch(batch):
    batch_ = list(zip(*batch))
    tokens, masks, token_masks, gold_spans, tags, lstm_encoded = batch_[0], batch_[1], batch_[2], batch_[3], \
                                                                 batch_[4], batch_[5]
    # print(tags)
    max_len = np.max([len(token) for token in tokens])
    # print(np.max([len(token) for token in tokens]), max_len)
    token_tensor = torch.empty(size=(len(tokens), max_len), dtype=torch.long).fill_(tokenizer.pad_token_id)
    tag_tensor = torch.empty(size=(len(tokens), max_len), dtype=torch.long).fill_(mconern['O'])
    mask_tensor = torch.zeros(size=(len(tokens), max_len), dtype=torch.bool)
    token_masks_tensor = torch.zeros(size=(len(tokens), max_len), dtype=torch.bool)
    lstm_encoded_tensor = torch.zeros(size=(len(tokens), max_len, 256), dtype=torch.float)
    # print(lstm_encoded.shape)
    for i in range(len(tokens)):
        tokens_ = tokens[i]
        seq_len = len(tokens_)

        token_tensor[i, :seq_len] = tokens_
        tag_tensor[i, :seq_len] = tags[i]
        mask_tensor[i, :seq_len] = masks[i]
        token_masks_tensor[i, :seq_len] = token_masks[i]
        lstm_encoded_tensor[i, 1:seq_len - 1, :] = lstm_encoded[i]

    return token_tensor, tag_tensor, mask_tensor, token_masks_tensor, gold_spans, lstm_encoded_tensor

In [5]:
def get_optimizer(net, opt=False):
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=0.03)
    if opt:
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEP)
        return [optimizer], [scheduler]
    return [optimizer]


In [6]:
NUM_EPOCH = 1
BATCH_SIZE = 64
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [7]:
# wnut_iob = {'B-CORP': 0, 'I-CORP': 1, 'B-CW': 2, 'I-CW': 3, 'B-GRP': 4, 'I-GRP': 5, 'B-LOC': 6, 'I-LOC': 7,
#             'B-PER': 8, 'I-PER': 9, 'B-PROD': 10, 'I-PROD': 11, 'O': 12}
mconern = indvidual(mconer_grouped, True)
reveremap = invert(mconer_grouped)

In [8]:
import pickle
if os.path.exists('train_load.pkl'):
    with open('train_load.pkl', 'rb') as f:
        ds = pickle.load(f)
else:
    print("reading from disk")
    ds = CoNLLReader(target_vocab=mconern, encoder_model=encoder_model, reversemap=reveremap, finegrained=fine)
    ds.read_data(data=r'C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-train.conll')
    with open('train_load.pkl', 'wb') as f:
        pickle.dump(ds, f)

In [9]:
if os.path.exists('valid_load.pkl'):
    with open('valid_load.pkl', 'rb') as f:
        valid = pickle.load(f)
else:
    valid = CoNLLReader(target_vocab=mconern, encoder_model=encoder_model, reversemap=reveremap, finegrained=True)
    valid.read_data(data=r'C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-dev.conll')
    with open('valid_load.pkl', 'wb') as f:
        pickle.dump(ds, f)

In [10]:
model = NERmodelbase3(tag_to_id=mconern, device=device, encoder_model=encoder_model, dropout=0.3, use_lstm=True).to(
    device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  nn.init.uniform(self.w_omega, -0.1, 0.1)


In [11]:
trainloader = DataLoader(ds, batch_size=BATCH_SIZE, collate_fn=collate_batch, num_workers=0, shuffle=False)
validloader = DataLoader(valid, batch_size=BATCH_SIZE, collate_fn=collate_batch, num_workers=0)

In [12]:
WARMUP_STEP = int(len(trainloader) * NUM_EPOCH * 0.1)
print(f"Number of warm up step is {WARMUP_STEP}")
optim, scheduler = get_optimizer(model, True)

Number of warm up step is 26


In [13]:
from pytorchtools import EarlyStopping
import random
from tensorboardX import SummaryWriter
run_id = random.randint(1, 10000)
run_name = f"runid_{run_id}_EP_{NUM_EPOCH}_fine_xlm-b-birnnn-focal-loss-0_8-sep_lr-alpha-2-gama-4"
writer = SummaryWriter(run_name)
step = 0
running_loss = 0
early_stopping = EarlyStopping(patience=10, verbose=True, path=run_name + '.pt')

In [15]:
import numpy as np
eval_step = 100
for epoch in range(NUM_EPOCH):
    val_track = []
    with tqdm(trainloader, unit='batch') as tepoch:
        # model.train()
        tepoch.set_description(f"Epoch {epoch}")
        for i, data in enumerate(tepoch):
            optim[0].zero_grad()
            outputs, focal_loss, all_prob, token_scores, mask, metadata = model(data)
            loss = 0.2 * outputs['loss'] + 0.8 * focal_loss
            running_loss += loss
            loss.backward()
            optim[0].step()
            scheduler[0].step()
            # if i % 10 == 0:  # print every 2000 mini-batches
            model.spanf1.reset()
            # writer.add_scalar('lr', scheduler[0].get_last_lr()[0], step)
            # run validation
            step += 1
            if (step + 1) % eval_step == 0:
                model.eval()
                with torch.no_grad():
                    with tqdm(validloader, unit='batch') as tepoch:
                        val_loss = 0
                        for i, data in enumerate(tepoch):
                            outputs, focal_loss, all_prob, token_scores, mask, metadata = model(data, mode='predict')
                            val_loss += 0.2 * outputs['loss'] + 0.8 * focal_loss
                model.train()
                writer.add_scalars("Loss",
                                   {
                                       "Train Loss": round(running_loss.detach().cpu().numpy().ravel()[0] / 20, 4),
                                       "Valid Loss": round(val_loss.detach().cpu().numpy().ravel()[0] / 20, 4),
                                   }
                                   , step)
                writer.add_scalars("Metrics", outputs['results'], step)
                # print(outputs['results'])
                # writer.add_scalar("Loss/Test", round(val_loss.numpy()[0] / len(validloader), 4), step)
                # writer.add_scalar("Loss/Valid",
                #                   round(val_loss.detach().cpu().numpy().ravel()[0] / 20, 4), step)
                val_track.append(round(val_loss.detach().cpu().numpy().ravel()[0] / eval_step, 4))
                running_loss = 0

                early_stopping(round(val_loss.detach().cpu().numpy().ravel()[0] / eval_step, 4), model)
                if early_stopping.early_stop:
                    print("Stopping early")
                    break

writer.close()

Epoch 0:  37%|███████████████████████████████████████████████████████████▌                                                                                                    | 98/263 [01:07<01:52,  1.46batch/s]
  0%|                                                                                                                                                                                   | 0/14 [00:00<?, ?batch/s][A
  7%|████████████▏                                                                                                                                                              | 1/14 [00:00<00:07,  1.77batch/s][A

tensor([[-1.1259e+01, -1.0008e+04, -8.2031e+00,  ..., -8.7335e+00,
         -1.0009e+04, -1.0547e+00],
        [-1.0671e+01, -1.9902e+01, -1.0335e+01,  ..., -9.6462e+00,
         -1.7437e+01, -1.0900e+00],
        [-7.3997e+00, -1.6369e+01, -7.2292e+00,  ..., -6.5972e+00,
         -1.5234e+01, -1.5071e+00],
        ...,
        [-3.7593e+01, -3.7997e+01, -3.7270e+01,  ..., -3.6562e+01,
         -3.7791e+01, -2.7648e+01],
        [-3.7563e+01, -4.6462e+01, -3.7224e+01,  ..., -3.6527e+01,
         -4.5734e+01, -2.7679e+01],
        [-3.9265e+01, -4.5919e+01, -3.8370e+01,  ..., -3.6927e+01,
         -4.8371e+01, -2.6347e+01]], device='cuda:0') tensor([72, 50, 51,  ..., 72, 72, 72], device='cuda:0')
tensor(3.5233, device='cuda:0')



 14%|████████████████████████▍                                                                                                                                                  | 2/14 [00:00<00:05,  2.16batch/s][A

tensor([[-1.1323e+01, -1.0009e+04, -8.2771e+00,  ..., -8.8125e+00,
         -1.0009e+04, -1.0540e+00],
        [-1.0753e+01, -2.0031e+01, -1.0424e+01,  ..., -9.7350e+00,
         -1.7653e+01, -1.0880e+00],
        [-5.8942e+00, -1.5405e+01, -5.9268e+00,  ..., -5.4355e+00,
         -1.4151e+01, -3.3925e+00],
        ...,
        [-7.2552e+01, -7.4089e+01, -7.2614e+01,  ..., -7.2125e+01,
         -7.3353e+01, -7.1020e+01],
        [-7.5641e+01, -7.7182e+01, -7.5714e+01,  ..., -7.5236e+01,
         -7.6463e+01, -7.4142e+01],
        [-8.5666e+01, -8.3933e+01, -8.4774e+01,  ..., -8.3315e+01,
         -8.6942e+01, -7.2811e+01]], device='cuda:0') tensor([72, 40, 41,  ..., 41, 41, 72], device='cuda:0')
tensor(4.2999, device='cuda:0')



 21%|████████████████████████████████████▋                                                                                                                                      | 3/14 [00:01<00:04,  2.30batch/s][A

tensor([[-1.1264e+01, -1.0008e+04, -8.2312e+00,  ..., -8.7586e+00,
         -1.0009e+04, -1.0542e+00],
        [-1.0848e+01, -2.0096e+01, -1.0533e+01,  ..., -9.8345e+00,
         -1.7703e+01, -1.0864e+00],
        [-1.1299e+01, -2.0165e+01, -1.0956e+01,  ..., -1.0245e+01,
         -1.9230e+01, -1.1149e+00],
        ...,
        [-3.6037e+01, -4.5310e+01, -3.5716e+01,  ..., -3.4973e+01,
         -4.4303e+01, -2.5839e+01],
        [-3.5902e+01, -4.5195e+01, -3.5557e+01,  ..., -3.4868e+01,
         -4.4263e+01, -2.5869e+01],
        [-3.7506e+01, -4.4297e+01, -3.6604e+01,  ..., -3.5163e+01,
         -4.6728e+01, -2.4537e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.7286, device='cuda:0')



 29%|████████████████████████████████████████████████▊                                                                                                                          | 4/14 [00:01<00:04,  2.32batch/s][A

tensor([[-1.1229e+01, -1.0008e+04, -8.1836e+00,  ..., -8.7138e+00,
         -1.0009e+04, -1.0547e+00],
        [-1.0894e+01, -2.0125e+01, -1.0557e+01,  ..., -9.8767e+00,
         -1.7664e+01, -1.0862e+00],
        [-1.1197e+01, -2.0070e+01, -1.0879e+01,  ..., -1.0138e+01,
         -1.9131e+01, -1.1161e+00],
        ...,
        [-3.1608e+01, -4.0683e+01, -3.1284e+01,  ..., -3.0582e+01,
         -3.9767e+01, -2.1655e+01],
        [-3.1622e+01, -4.0681e+01, -3.1289e+01,  ..., -3.0592e+01,
         -3.9799e+01, -2.1685e+01],
        [-3.3257e+01, -3.9976e+01, -3.2373e+01,  ..., -3.0931e+01,
         -4.2420e+01, -2.0354e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.4289, device='cuda:0')



 36%|█████████████████████████████████████████████████████████████                                                                                                              | 5/14 [00:02<00:03,  2.33batch/s][A

tensor([[-1.1220e+01, -1.0008e+04, -8.1855e+00,  ..., -8.7218e+00,
         -1.0009e+04, -1.0546e+00],
        [-1.0875e+01, -2.0081e+01, -1.0550e+01,  ..., -9.8641e+00,
         -1.7685e+01, -1.0864e+00],
        [-1.1160e+01, -2.0089e+01, -1.0815e+01,  ..., -1.0147e+01,
         -1.9165e+01, -1.1154e+00],
        ...,
        [-2.0757e+01, -2.9965e+01, -2.0425e+01,  ..., -1.9771e+01,
         -2.9081e+01, -1.1011e+01],
        [-2.0780e+01, -2.9655e+01, -2.0440e+01,  ..., -1.9783e+01,
         -2.8816e+01, -1.1043e+01],
        [-2.2624e+01, -2.9133e+01, -2.1726e+01,  ..., -2.0304e+01,
         -3.1639e+01, -9.7114e+00]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.4356, device='cuda:0')



 43%|█████████████████████████████████████████████████████████████████████████▎                                                                                                 | 6/14 [00:02<00:03,  2.34batch/s][A

tensor([[-1.1101e+01, -1.0008e+04, -8.0828e+00,  ..., -8.6081e+00,
         -1.0009e+04, -1.0556e+00],
        [-1.0894e+01, -2.0006e+01, -1.0579e+01,  ..., -9.8935e+00,
         -1.7599e+01, -1.0866e+00],
        [-1.1314e+01, -2.0230e+01, -1.0994e+01,  ..., -1.0265e+01,
         -1.9323e+01, -1.1147e+00],
        ...,
        [-3.4578e+01, -3.5057e+01, -3.4232e+01,  ..., -3.3539e+01,
         -3.4793e+01, -2.4538e+01],
        [-3.4566e+01, -4.3610e+01, -3.4208e+01,  ..., -3.3527e+01,
         -4.2788e+01, -2.4567e+01],
        [-3.6205e+01, -4.2960e+01, -3.5304e+01,  ..., -3.3869e+01,
         -4.5396e+01, -2.3236e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.5973, device='cuda:0')



 50%|█████████████████████████████████████████████████████████████████████████████████████▌                                                                                     | 7/14 [00:03<00:02,  2.36batch/s][A

tensor([[-1.1346e+01, -1.0009e+04, -8.3000e+00,  ..., -8.8366e+00,
         -1.0009e+04, -1.0537e+00],
        [-1.0810e+01, -2.0112e+01, -1.0485e+01,  ..., -9.7898e+00,
         -1.7738e+01, -1.0868e+00],
        [-1.1366e+01, -2.0156e+01, -1.1023e+01,  ..., -1.0301e+01,
         -1.9256e+01, -1.1146e+00],
        ...,
        [-1.2324e+01, -2.1612e+01, -1.1971e+01,  ..., -1.1273e+01,
         -2.0713e+01, -2.0904e+00],
        [-1.2312e+01, -2.1581e+01, -1.1946e+01,  ..., -1.1259e+01,
         -2.0690e+01, -2.1186e+00],
        [-1.3826e+01, -2.0735e+01, -1.2910e+01,  ..., -1.1487e+01,
         -2.3209e+01, -7.8652e-01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(5.0346, device='cuda:0')



 57%|█████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/14 [00:03<00:02,  2.39batch/s][A

tensor([[-1.1382e+01, -1.0009e+04, -8.3305e+00,  ..., -8.8572e+00,
         -1.0009e+04, -1.0536e+00],
        [-1.0789e+01, -2.0108e+01, -1.0463e+01,  ..., -9.7562e+00,
         -1.7737e+01, -1.0873e+00],
        [-1.1361e+01, -2.0125e+01, -1.1019e+01,  ..., -1.0304e+01,
         -1.9220e+01, -1.1151e+00],
        ...,
        [-1.8422e+01, -2.7708e+01, -1.8075e+01,  ..., -1.7372e+01,
         -2.6799e+01, -8.2602e+00],
        [-1.8420e+01, -2.7643e+01, -1.8063e+01,  ..., -1.7368e+01,
         -2.6749e+01, -8.2887e+00],
        [-1.9975e+01, -2.6832e+01, -1.9075e+01,  ..., -1.7636e+01,
         -2.9286e+01, -6.9568e+00]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(4.0720, device='cuda:0')



 64%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                             | 9/14 [00:03<00:02,  2.41batch/s][A

tensor([[-1.1310e+01, -1.0009e+04, -8.2670e+00,  ..., -8.8036e+00,
         -1.0009e+04, -1.0539e+00],
        [-1.0844e+01, -2.0115e+01, -1.0525e+01,  ..., -9.8302e+00,
         -1.7753e+01, -1.0864e+00],
        [-1.0974e+01, -1.9876e+01, -1.0632e+01,  ..., -9.9601e+00,
         -1.8969e+01, -1.1168e+00],
        ...,
        [-2.3926e+01, -2.6231e+01, -2.3593e+01,  ..., -2.2881e+01,
         -2.5793e+01, -1.3858e+01],
        [-2.3909e+01, -3.3060e+01, -2.3564e+01,  ..., -2.2863e+01,
         -3.2144e+01, -1.3887e+01],
        [-2.5583e+01, -3.2333e+01, -2.4683e+01,  ..., -2.3230e+01,
         -3.4780e+01, -1.2555e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.5446, device='cuda:0')



 71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                | 10/14 [00:04<00:01,  2.50batch/s][A

tensor([[-1.1290e+01, -1.0009e+04, -8.2444e+00,  ..., -8.7803e+00,
         -1.0009e+04, -1.0541e+00],
        [-1.0820e+01, -2.0087e+01, -1.0492e+01,  ..., -9.8001e+00,
         -1.7670e+01, -1.0870e+00],
        [-5.7602e+00, -1.5429e+01, -5.7955e+00,  ..., -5.3544e+00,
         -1.4131e+01, -3.9795e+00],
        ...,
        [-2.3793e+01, -3.2781e+01, -2.3460e+01,  ..., -2.2785e+01,
         -3.1908e+01, -1.4053e+01],
        [-2.3822e+01, -3.2697e+01, -2.3476e+01,  ..., -2.2807e+01,
         -3.1836e+01, -1.4085e+01],
        [-2.5609e+01, -3.2128e+01, -2.4712e+01,  ..., -2.3284e+01,
         -3.4605e+01, -1.2754e+01]], device='cuda:0') tensor([72,  4,  5,  ..., 72, 72, 72], device='cuda:0')
tensor(4.0704, device='cuda:0')



 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                    | 11/14 [00:04<00:01,  2.59batch/s][A

tensor([[-1.1158e+01, -1.0008e+04, -8.1217e+00,  ..., -8.6648e+00,
         -1.0009e+04, -1.0551e+00],
        [-1.0934e+01, -2.0094e+01, -1.0598e+01,  ..., -9.9248e+00,
         -1.7687e+01, -1.0859e+00],
        [-1.1077e+01, -2.0059e+01, -1.0741e+01,  ..., -1.0042e+01,
         -1.9167e+01, -1.1155e+00],
        ...,
        [-3.3734e+01, -3.5846e+01, -3.3414e+01,  ..., -3.2736e+01,
         -3.5458e+01, -2.3981e+01],
        [-3.3677e+01, -4.2576e+01, -3.3347e+01,  ..., -3.2683e+01,
         -4.1708e+01, -2.4014e+01],
        [-3.5666e+01, -4.2067e+01, -3.4761e+01,  ..., -3.3324e+01,
         -4.4549e+01, -2.2682e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(4.8810, device='cuda:0')



 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                        | 12/14 [00:05<00:00,  2.50batch/s][A

tensor([[-1.1280e+01, -1.0008e+04, -8.2413e+00,  ..., -8.7715e+00,
         -1.0009e+04, -1.0542e+00],
        [-1.0880e+01, -2.0119e+01, -1.0556e+01,  ..., -9.8593e+00,
         -1.7730e+01, -1.0862e+00],
        [-1.1308e+01, -2.0191e+01, -1.0978e+01,  ..., -1.0261e+01,
         -1.9247e+01, -1.1145e+00],
        ...,
        [-3.5495e+01, -3.8826e+01, -3.5159e+01,  ..., -3.4453e+01,
         -3.8330e+01, -2.5433e+01],
        [-3.5463e+01, -4.4608e+01, -3.5123e+01,  ..., -3.4427e+01,
         -4.3722e+01, -2.5462e+01],
        [-3.7184e+01, -4.3901e+01, -3.6273e+01,  ..., -3.4833e+01,
         -4.6381e+01, -2.4130e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.7363, device='cuda:0')



 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊            | 13/14 [00:05<00:00,  2.48batch/s][A

tensor([[-1.1326e+01, -1.0009e+04, -8.2748e+00,  ..., -8.8169e+00,
         -1.0009e+04, -1.0539e+00],
        [-1.0809e+01, -2.0093e+01, -1.0483e+01,  ..., -9.7920e+00,
         -1.7717e+01, -1.0869e+00],
        [-1.1249e+01, -2.0068e+01, -1.0887e+01,  ..., -1.0199e+01,
         -1.9162e+01, -1.1155e+00],
        ...,
        [-6.8807e+01, -6.9274e+01, -6.8478e+01,  ..., -6.7781e+01,
         -6.9030e+01, -5.8868e+01],
        [-6.8780e+01, -7.7734e+01, -6.8441e+01,  ..., -6.7757e+01,
         -7.6950e+01, -5.8898e+01],
        [-7.0539e+01, -7.7169e+01, -6.9637e+01,  ..., -6.8184e+01,
         -7.9612e+01, -5.7567e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.4025, device='cuda:0')



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:05<00:00,  2.44batch/s][A

tensor([[-1.1334e+01, -1.0009e+04, -8.2843e+00,  ..., -8.8112e+00,
         -1.0009e+04, -1.0539e+00],
        [-1.0779e+01, -2.0064e+01, -1.0447e+01,  ..., -9.7448e+00,
         -1.7668e+01, -1.0876e+00],
        [-1.1316e+01, -2.0084e+01, -1.0951e+01,  ..., -1.0251e+01,
         -1.9167e+01, -1.1157e+00],
        ...,
        [-1.4470e+01, -2.3612e+01, -1.4133e+01,  ..., -1.3454e+01,
         -2.2721e+01, -4.5135e+00],
        [-1.4499e+01, -2.3553e+01, -1.4152e+01,  ..., -1.3482e+01,
         -2.2680e+01, -4.5431e+00],
        [-1.6257e+01, -2.2953e+01, -1.5351e+01,  ..., -1.3927e+01,
         -2.5453e+01, -3.2107e+00]], device='cuda:0') tensor([72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,
        72, 72, 28, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 28, 29, 29,
        29, 29, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,
        72, 72, 28, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 54,
        55, 55, 55, 72, 72, 28,




1.0806
Validation loss decreased (inf --> 1.080600).  Saving model ...


Epoch 0:  75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                       | 198/263 [02:24<00:46,  1.40batch/s]
  0%|                                                                                                                                                                                   | 0/14 [00:00<?, ?batch/s][A
  7%|████████████▏                                                                                                                                                              | 1/14 [00:00<00:06,  1.87batch/s][A

tensor([[-9.0412e+00, -1.0007e+04, -8.0746e+00,  ..., -8.4767e+00,
         -1.0009e+04, -1.0554e+00],
        [-7.1213e+00, -1.4684e+01, -8.7745e+00,  ..., -7.9318e+00,
         -1.5342e+01, -1.1703e+00],
        [-5.2888e+00, -1.1870e+01, -5.8623e+00,  ..., -5.5004e+00,
         -1.2465e+01, -3.2819e+00],
        ...,
        [-3.4838e+01, -3.5887e+01, -3.6531e+01,  ..., -3.5741e+01,
         -3.7203e+01, -2.7541e+01],
        [-3.4812e+01, -4.1937e+01, -3.6490e+01,  ..., -3.5710e+01,
         -4.4002e+01, -2.7583e+01],
        [-3.6669e+01, -4.1406e+01, -3.7766e+01,  ..., -3.6238e+01,
         -4.6850e+01, -2.6262e+01]], device='cuda:0') tensor([72, 50, 51,  ..., 72, 72, 72], device='cuda:0')
tensor(3.5537, device='cuda:0')



 14%|████████████████████████▍                                                                                                                                                  | 2/14 [00:00<00:05,  2.28batch/s][A

tensor([[-8.3458e+00, -1.0006e+04, -7.2473e+00,  ..., -7.6691e+00,
         -1.0008e+04, -1.0754e+00],
        [-7.4113e+00, -1.4463e+01, -9.0763e+00,  ..., -8.2135e+00,
         -1.4846e+01, -1.1705e+00],
        [-5.2930e+00, -1.2189e+01, -5.8026e+00,  ..., -5.4610e+00,
         -1.2719e+01, -3.5697e+00],
        ...,
        [-5.2123e+01, -5.5012e+01, -5.2773e+01,  ..., -5.2360e+01,
         -5.5104e+01, -4.9773e+01],
        [-5.3936e+01, -5.6849e+01, -5.4598e+01,  ..., -5.4177e+01,
         -5.6960e+01, -5.1545e+01],
        [-5.7972e+01, -5.8400e+01, -5.8412e+01,  ..., -5.7121e+01,
         -6.2318e+01, -5.0666e+01]], device='cuda:0') tensor([72, 40, 41,  ..., 41, 41, 72], device='cuda:0')
tensor(3.5508, device='cuda:0')



 21%|████████████████████████████████████▋                                                                                                                                      | 3/14 [00:01<00:04,  2.38batch/s][A

tensor([[-8.4441e+00, -1.0006e+04, -7.3290e+00,  ..., -7.8052e+00,
         -1.0008e+04, -1.0675e+00],
        [-8.4767e+00, -1.5635e+01, -1.0239e+01,  ..., -9.4508e+00,
         -1.6215e+01, -1.1075e+00],
        [-5.7020e+00, -1.3475e+01, -6.6919e+00,  ..., -6.1455e+00,
         -1.4641e+01, -1.7739e+00],
        ...,
        [-1.4527e+01, -1.8940e+01, -1.5607e+01,  ..., -1.4988e+01,
         -1.9651e+01, -1.0338e+01],
        [-1.5556e+01, -1.9946e+01, -1.6767e+01,  ..., -1.6138e+01,
         -2.0934e+01, -1.0620e+01],
        [-1.9147e+01, -2.1746e+01, -2.0159e+01,  ..., -1.8621e+01,
         -2.6654e+01, -9.3212e+00]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.1312, device='cuda:0')



 29%|████████████████████████████████████████████████▊                                                                                                                          | 4/14 [00:01<00:04,  2.34batch/s][A

tensor([[-8.6438e+00, -1.0007e+04, -7.5777e+00,  ..., -8.0317e+00,
         -1.0008e+04, -1.0623e+00],
        [-8.3641e+00, -1.5777e+01, -1.0101e+01,  ..., -9.3276e+00,
         -1.6319e+01, -1.1054e+00],
        [-5.7258e+00, -1.3341e+01, -6.7520e+00,  ..., -6.2023e+00,
         -1.4555e+01, -1.7144e+00],
        ...,
        [-3.1268e+01, -3.3815e+01, -3.2951e+01,  ..., -3.2170e+01,
         -3.5292e+01, -2.3985e+01],
        [-3.1244e+01, -3.8362e+01, -3.2910e+01,  ..., -3.2140e+01,
         -4.0423e+01, -2.4027e+01],
        [-3.3340e+01, -3.8027e+01, -3.4492e+01,  ..., -3.2950e+01,
         -4.3542e+01, -2.2703e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.8394, device='cuda:0')



 36%|█████████████████████████████████████████████████████████████                                                                                                              | 5/14 [00:02<00:03,  2.36batch/s][A

tensor([[-9.0604e+00, -1.0007e+04, -8.0763e+00,  ..., -8.4965e+00,
         -1.0009e+04, -1.0543e+00],
        [-8.4304e+00, -1.5995e+01, -1.0246e+01,  ..., -9.4036e+00,
         -1.6835e+01, -1.0947e+00],
        [-9.0303e+00, -1.6082e+01, -1.0876e+01,  ..., -1.0065e+01,
         -1.8383e+01, -1.1244e+00],
        ...,
        [-1.8118e+01, -2.3834e+01, -1.9966e+01,  ..., -1.9126e+01,
         -2.5990e+01, -1.0106e+01],
        [-1.8136e+01, -2.5813e+01, -1.9978e+01,  ..., -1.9141e+01,
         -2.8227e+01, -1.0134e+01],
        [-1.9636e+01, -2.5047e+01, -2.0832e+01,  ..., -1.9256e+01,
         -3.0736e+01, -8.8062e+00]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.3262, device='cuda:0')



 43%|█████████████████████████████████████████████████████████████████████████▎                                                                                                 | 6/14 [00:02<00:03,  2.40batch/s][A

tensor([[-8.8646e+00, -1.0007e+04, -7.8429e+00,  ..., -8.2793e+00,
         -1.0008e+04, -1.0574e+00],
        [-8.4623e+00, -1.5957e+01, -1.0254e+01,  ..., -9.4432e+00,
         -1.6669e+01, -1.0973e+00],
        [-7.9285e+00, -1.5224e+01, -9.5578e+00,  ..., -8.7931e+00,
         -1.7191e+01, -1.1550e+00],
        ...,
        [-2.9602e+01, -3.0643e+01, -3.1351e+01,  ..., -3.0545e+01,
         -3.2067e+01, -2.2028e+01],
        [-2.9612e+01, -3.6942e+01, -3.1357e+01,  ..., -3.0558e+01,
         -3.9149e+01, -2.2064e+01],
        [-3.1131e+01, -3.6175e+01, -3.2225e+01,  ..., -3.0699e+01,
         -4.1682e+01, -2.0743e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.2834, device='cuda:0')



 50%|█████████████████████████████████████████████████████████████████████████████████████▌                                                                                     | 7/14 [00:02<00:02,  2.47batch/s][A

tensor([[-8.9655e+00, -1.0007e+04, -7.9603e+00,  ..., -8.3910e+00,
         -1.0008e+04, -1.0557e+00],
        [-8.5803e+00, -1.6092e+01, -1.0409e+01,  ..., -9.5728e+00,
         -1.6902e+01, -1.0928e+00],
        [-5.8425e+00, -1.3695e+01, -6.8924e+00,  ..., -6.3208e+00,
         -1.4936e+01, -1.6173e+00],
        ...,
        [-1.5532e+01, -2.1309e+01, -1.7198e+01,  ..., -1.6420e+01,
         -2.3203e+01, -8.3205e+00],
        [-1.5591e+01, -2.2617e+01, -1.7260e+01,  ..., -1.6487e+01,
         -2.4675e+01, -8.3626e+00],
        [-1.7791e+01, -2.2447e+01, -1.8963e+01,  ..., -1.7399e+01,
         -2.8002e+01, -7.0358e+00]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.3196, device='cuda:0')



 57%|█████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/14 [00:03<00:02,  2.50batch/s][A

tensor([[-8.9665e+00, -1.0007e+04, -7.9761e+00,  ..., -8.3891e+00,
         -1.0008e+04, -1.0558e+00],
        [-8.3330e+00, -1.5856e+01, -1.0136e+01,  ..., -9.2865e+00,
         -1.6623e+01, -1.0991e+00],
        [-5.7090e+00, -1.3304e+01, -6.7337e+00,  ..., -6.1863e+00,
         -1.4495e+01, -1.7190e+00],
        ...,
        [-2.8370e+01, -3.1837e+01, -2.9900e+01,  ..., -2.9171e+01,
         -3.3209e+01, -2.1782e+01],
        [-2.8385e+01, -3.4924e+01, -2.9898e+01,  ..., -2.9187e+01,
         -3.6689e+01, -2.1850e+01],
        [-3.0616e+01, -3.4727e+01, -3.1638e+01,  ..., -3.0137e+01,
         -3.9970e+01, -2.0538e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.6540, device='cuda:0')



 64%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                             | 9/14 [00:03<00:01,  2.52batch/s][A

tensor([[-8.8062e+00, -1.0007e+04, -7.7700e+00,  ..., -8.2025e+00,
         -1.0008e+04, -1.0587e+00],
        [-8.5947e+00, -1.6001e+01, -1.0413e+01,  ..., -9.5799e+00,
         -1.6736e+01, -1.0956e+00],
        [-8.4539e+00, -1.5806e+01, -1.0170e+01,  ..., -9.3744e+00,
         -1.7919e+01, -1.1358e+00],
        ...,
        [-2.6425e+01, -2.9233e+01, -2.8002e+01,  ..., -2.7257e+01,
         -3.0563e+01, -1.9650e+01],
        [-2.6438e+01, -3.3150e+01, -2.8002e+01,  ..., -2.7270e+01,
         -3.4982e+01, -1.9710e+01],
        [-2.8363e+01, -3.2705e+01, -2.9363e+01,  ..., -2.7876e+01,
         -3.7929e+01, -1.8402e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.4598, device='cuda:0')



 71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                | 10/14 [00:04<00:01,  2.54batch/s][A

tensor([[-9.1491e+00, -1.0007e+04, -8.2161e+00,  ..., -8.5915e+00,
         -1.0009e+04, -1.0535e+00],
        [-7.5556e+00, -1.5144e+01, -9.2968e+00,  ..., -8.4127e+00,
         -1.5938e+01, -1.1304e+00],
        [-5.2402e+00, -1.2287e+01, -5.8504e+00,  ..., -5.4776e+00,
         -1.2945e+01, -3.1474e+00],
        ...,
        [-2.1365e+01, -2.7835e+01, -2.3184e+01,  ..., -2.2359e+01,
         -3.0075e+01, -1.3448e+01],
        [-2.1402e+01, -2.9004e+01, -2.3217e+01,  ..., -2.2398e+01,
         -3.1389e+01, -1.3477e+01],
        [-2.2881e+01, -2.8236e+01, -2.4048e+01,  ..., -2.2491e+01,
         -3.3897e+01, -1.2150e+01]], device='cuda:0') tensor([72,  4,  5,  ..., 72, 72, 72], device='cuda:0')
tensor(2.0232, device='cuda:0')



 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                    | 11/14 [00:04<00:01,  2.57batch/s][A

tensor([[-8.8636e+00, -1.0007e+04, -7.8474e+00,  ..., -8.2732e+00,
         -1.0008e+04, -1.0574e+00],
        [-8.5764e+00, -1.6070e+01, -1.0385e+01,  ..., -9.5674e+00,
         -1.6793e+01, -1.0946e+00],
        [-6.5214e+00, -1.4187e+01, -7.8063e+00,  ..., -7.1460e+00,
         -1.5725e+01, -1.3160e+00],
        ...,
        [-3.2596e+01, -3.4834e+01, -3.4323e+01,  ..., -3.3526e+01,
         -3.6348e+01, -2.5112e+01],
        [-3.2590e+01, -3.9847e+01, -3.4302e+01,  ..., -3.3520e+01,
         -4.2018e+01, -2.5149e+01],
        [-3.4273e+01, -3.9190e+01, -3.5370e+01,  ..., -3.3846e+01,
         -4.4708e+01, -2.3827e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(3.7690, device='cuda:0')



 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                        | 12/14 [00:04<00:00,  2.51batch/s][A

tensor([[-8.6986e+00, -1.0007e+04, -7.6419e+00,  ..., -8.0939e+00,
         -1.0008e+04, -1.0611e+00],
        [-8.4885e+00, -1.5836e+01, -1.0271e+01,  ..., -9.4643e+00,
         -1.6509e+01, -1.1007e+00],
        [-6.4125e+00, -1.3993e+01, -7.6568e+00,  ..., -7.0314e+00,
         -1.5493e+01, -1.3477e+00],
        ...,
        [-4.5044e+01, -4.7675e+01, -4.6495e+01,  ..., -4.5797e+01,
         -4.8862e+01, -3.8774e+01],
        [-4.5178e+01, -5.1423e+01, -4.6642e+01,  ..., -4.5945e+01,
         -5.3095e+01, -3.8855e+01],
        [-4.7485e+01, -5.1406e+01, -4.8475e+01,  ..., -4.6991e+01,
         -5.6578e+01, -3.7548e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.9710, device='cuda:0')



 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊            | 13/14 [00:05<00:00,  2.49batch/s][A

tensor([[-8.6756e+00, -1.0007e+04, -7.6055e+00,  ..., -8.0581e+00,
         -1.0008e+04, -1.0618e+00],
        [-8.5299e+00, -1.5823e+01, -1.0321e+01,  ..., -9.4964e+00,
         -1.6507e+01, -1.1006e+00],
        [-6.4063e+00, -1.4054e+01, -7.6257e+00,  ..., -6.9917e+00,
         -1.5492e+01, -1.3632e+00],
        ...,
        [-7.2599e+01, -7.4151e+01, -7.4107e+01,  ..., -7.3388e+01,
         -7.5276e+01, -6.6088e+01],
        [-7.2813e+01, -7.9255e+01, -7.4352e+01,  ..., -7.3631e+01,
         -8.1038e+01, -6.6150e+01],
        [-7.5218e+01, -7.9412e+01, -7.6311e+01,  ..., -7.4792e+01,
         -8.4747e+01, -6.4831e+01]], device='cuda:0') tensor([72, 72, 72,  ..., 72, 72, 72], device='cuda:0')
tensor(2.7744, device='cuda:0')



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:05<00:00,  2.50batch/s][A

tensor([[-8.7615e+00, -1.0007e+04, -7.7149e+00,  ..., -8.1475e+00,
         -1.0008e+04, -1.0599e+00],
        [-8.4746e+00, -1.5815e+01, -1.0276e+01,  ..., -9.4345e+00,
         -1.6516e+01, -1.1000e+00],
        [-6.4489e+00, -1.4010e+01, -7.6745e+00,  ..., -7.0442e+00,
         -1.5481e+01, -1.3455e+00],
        ...,
        [-1.8276e+01, -2.3710e+01, -1.9880e+01,  ..., -1.9127e+01,
         -2.5467e+01, -1.1359e+01],
        [-1.8243e+01, -2.5060e+01, -1.9824e+01,  ..., -1.9087e+01,
         -2.6952e+01, -1.1414e+01],
        [-2.0315e+01, -2.4681e+01, -2.1363e+01,  ..., -1.9858e+01,
         -3.0016e+01, -1.0098e+01]], device='cuda:0') tensor([72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,
        72, 72, 28, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 28, 29, 29,
        29, 29, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,
        72, 72, 28, 29, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 54,
        55, 55, 55, 72, 72, 28,


Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 263/263 [03:18<00:00,  1.32batch/s]


In [69]:
torch.cat(all_prob, dim=1).argmax(-1).shape

torch.Size([1, 238])

In [76]:
torch.argmax(torch.cat(all_prob, dim=1), dim=-1).squeeze(0)
torch.masked_select(metadata, mask)
# torch.nn.CrossEntropyLoss()(torch.cat(all_prob, dim=1).squeeze(0),
#                          torch.masked_select(metadata, mask))
from sklearn.metrics import f1_score
f1_score(torch.cat(all_prob, dim=1).argmax(-1).squeeze(0).cpu(),
        torch.masked_select(metadata, mask).cpu(), average='micro')

0.7268907563025211

In [17]:
from kornia.losses import FocalLoss
fl = FocalLoss(alpha=2, gamma=5, reduction='mean')
fl(token_scores.permute(0, 2, 1), metadata)

tensor(1.8132, device='cuda:0', grad_fn=<MeanBackward0>)

In [18]:
fl(torch.cat(all_prob, dim=1).squeeze(0),
    torch.masked_select(metadata, mask))

tensor(3.7399, device='cuda:0')

In [19]:
mask.unsqueeze(-1).expand(token_scores.size()).shape

torch.Size([10, 34, 73])

In [20]:
torch.masked_select(token_scores, mask.unsqueeze(-1).expand(token_scores.size())).ravel().shape

torch.Size([17374])

In [21]:
torch.masked_select(metadata, mask).shape

torch.Size([238])

In [22]:
all_prob[0].shape

torch.Size([1, 17, 73])

In [23]:
model.crf.transitions.cpu().detach().numpy().shape

(73, 73)

In [24]:
model.crf.viterbi_tags(torch.randn([2, 10, 73]))

([([4, 56, 10, 24, 16, 12, 0, 72, 28, 64], 24.494735717773438),
  ([4, 5, 16, 20, 22, 62, 56, 28, 18, 20], 23.203876495361328)],
 [tensor([[[-2.7083e+00, -9.9991e+03, -3.2452e+00, -1.0000e+04,  3.3397e+00,
            -9.9985e+03, -8.0218e-01, -9.9996e+03, -5.2809e-02, -9.9976e+03,
            -1.6960e+00, -1.0000e+04,  8.0595e-01, -9.9994e+03, -1.1864e+00,
            -9.9993e+03,  1.1634e+00, -9.9992e+03, -3.3863e+00, -1.0001e+04,
             2.3295e+00, -1.0000e+04,  6.1551e-01, -1.0001e+04, -3.6707e-01,
            -1.0001e+04, -9.2977e-01, -9.9997e+03,  4.8140e-01, -1.0001e+04,
            -2.1461e+00, -9.9992e+03, -2.0133e+00, -9.9994e+03, -1.0128e+00,
            -1.0001e+04, -4.6495e-02, -1.0000e+04,  2.6608e+00, -9.9999e+03,
            -4.0778e-01, -9.9998e+03, -4.0299e+00, -1.0000e+04, -1.1274e-01,
            -9.9996e+03, -9.5995e-02, -9.9989e+03,  1.9019e+00, -9.9996e+03,
            -2.8813e+00, -9.9999e+03, -2.0686e-01, -9.9990e+03,  7.8721e-01,
            -1.0000e+04,

In [25]:
import allennlp.nn.util as util
def viterbi_tags(
    logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None,
    _constraint_mask= None, _transitions=None, start_transitions=None, end_transitions=None
):
    if mask is None:
        mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device)

    if top_k is None:
        top_k = 1
        flatten_output = True
    else:
        flatten_output = False
    all_prob_val = []
    _, max_seq_length, num_tags = logits.size()
#     print(num_tags)
    # Get the tensors out of the variables
    logits, mask = logits.data, mask.data

    # Augment transitions matrix with start and end transitions
    start_tag = num_tags
    end_tag = num_tags + 1
    transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device)

    # Apply transition constraints
    constrained_transitions = _transitions * _constraint_mask[:num_tags, :num_tags] + -10000.0 * (1 - _constraint_mask[:num_tags, :num_tags])
    transitions[:num_tags, :num_tags] = constrained_transitions.data

    if True:
        transitions[
            start_tag, :num_tags
        ] = start_transitions.detach() * _constraint_mask[
            start_tag, :num_tags
        ].data + -10000.0 * (
            1 - _constraint_mask[start_tag, :num_tags].detach()
        )
        transitions[:num_tags, end_tag] = end_transitions.detach() * _constraint_mask[
            :num_tags, end_tag
        ].data + -10000.0 * (1 - _constraint_mask[:num_tags, end_tag].detach())
    else:
        transitions[start_tag, :num_tags] = -10000.0 * (
            1 - self._constraint_mask[start_tag, :num_tags].detach()
        )
        transitions[:num_tags, end_tag] = -10000.0 * (
            1 - self._constraint_mask[:num_tags, end_tag].detach()
        )

    best_paths = []
    # Pad the max sequence length by 2 to account for start_tag + end_tag.
    tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device)

    for prediction, prediction_mask in zip(logits, mask):
        mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze()
        masked_prediction = torch.index_select(prediction, 0, mask_indices)
        sequence_length = masked_prediction.shape[0]

        # Start with everything totally unlikely
        tag_sequence.fill_(-10000.0)
        # At timestep 0 we must have the START_TAG
        tag_sequence[0, start_tag] = 0.0
        # At steps 1, ..., sequence_length we just use the incoming prediction
        tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction
        # And at the last timestep we must have the END_TAG
        tag_sequence[sequence_length + 1, end_tag] = 0.0
#         print(torch.argmax(tag_sequence, dim=-1))

        # We pass the tags and the transitions to `viterbi_decode`.
        viterbi_paths, viterbi_scores, all_prob = viterbi_decode(
            tag_sequence=tag_sequence[: (sequence_length + 2)],
            transition_matrix=transitions,
            top_k=top_k,
        )
        stacked_t = torch.stack(list(reversed(all_prob))[1:])
        all_prob_val.append(stacked_t.unsqueeze(0))
#         print(viterbi_paths, viterbi_scores)
        top_k_paths = []
        for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores):
            # Get rid of START and END sentinels and append.
            viterbi_path = viterbi_path[1:-1]
            top_k_paths.append((viterbi_path, viterbi_score.item()))
        best_paths.append(top_k_paths)

    if flatten_output:
        return [top_k_paths[0] for top_k_paths in best_paths]

    return best_paths, all_prob_val


In [26]:
from typing import Optional, List

def viterbi_decode(
    tag_sequence: torch.Tensor,
    transition_matrix: torch.Tensor,
    tag_observations: Optional[List[int]] = None,
    allowed_start_transitions: torch.Tensor = None,
    allowed_end_transitions: torch.Tensor = None,
    top_k: int = None,
):
    global all_score
    all_prob = []
#     print(tag_sequence.shape)
    if top_k is None:
        top_k = 1
        flatten_output = True
    elif top_k >= 1:
        flatten_output = False
    else:
        raise ValueError(f"top_k must be either None or an integer >=1. Instead received {top_k}")

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = (
        allowed_end_transitions is not None or allowed_start_transitions is not None
    )

    if has_start_end_restrictions:

        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.

        allowed_start_transitions = torch.cat(
            [allowed_start_transitions, torch.tensor([-math.inf, -math.inf])]
        )
        allowed_end_transitions = torch.cat(
            [allowed_end_transitions, torch.tensor([-math.inf, -math.inf])]
        )

        # First define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            raise ConfigurationError(
                "Observations were provided, but they were not the same length "
                "as the sequence. Found sequence of length: {} and evidence: {}".format(
                    sequence_length, tag_observations
                )
            )
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if has_start_end_restrictions:
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.0
        path_scores.append(one_hot.unsqueeze(0))
    else:
        path_scores.append(tag_sequence[0, :].unsqueeze(0))

    # Evaluate the scores for all possible paths.
#     print("length", path_scores[0].shape)
#     all_score = []
    for timestep in range(1, sequence_length):
#         print(timestep, path_scores[timestep - 1].shape, transition_matrix.shape)
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)
#         print(summed_potentials.shape, num_tags)
        # Best pairwise potential path score from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
#         print(scores, paths)
        all_score.append(summed_potentials)
#         print("best paths", paths)
        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logger.warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!"
                )
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.0
            path_scores.append(one_hot.unsqueeze(0))
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths.squeeze())
#     print("Simple Argmax", torch.argmax(torch.cat(path_scores[1:-1], dim=0), dim=-1))
#     print(torch.argmax(torch.cat(all_score, dim=0), dim=-1))
    # Construct the most likely sequence backwards.
    path_scores_v = path_scores[-1].view(-1)
#     print(path_scores[-1].shape)
    max_k = min(path_scores_v.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores_v, k=max_k, dim=0)
#     print(viterbi_scores, best_paths, max_k)
    viterbi_paths = []
    
    for i in range(max_k):
        viterbi_path = [best_paths[i]]
        
        for idx, backward_timestep in enumerate(reversed(path_indices)):
#             print(idx)
#             print("All Collected", backward_timestep, viterbi_path[-1])
#             print("recreated", all_score[-(idx+1)])
#             print("Tensor", all_score[-(idx+1)][:, viterbi_path[-1]])
            all_prob.append(all_score[-(idx+1)][:-2][:, viterbi_path[-1]])
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()

        if has_start_end_restrictions:
            viterbi_path = viterbi_path[1:-1]

        # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)

    if flatten_output:
        return viterbi_paths[0], viterbi_scores[0]

    return viterbi_paths, viterbi_scores, all_prob


In [27]:
dimen = 73
z = torch.randn([2, 100, dimen])
all_prob_val = []
x, all_prob = viterbi_tags(logits=z, top_k=1,
             _constraint_mask= model.crf._constraint_mask[:dimen+2, :dimen+2], _transitions=model.crf.transitions[:dimen, :dimen],
             start_transitions=model.crf.start_transitions[:dimen],
            end_transitions = model.crf.end_transitions[:dimen])
print(torch.argmax(torch.cat(all_prob, dim=0), dim=-1))
print(x[0])
# for i in torch.tensor_split(z, len(z), dim=0):
#     all_score = []

#     final_val = []
#     x, all_prob = viterbi_tags(logits=i, top_k=1,
#                  _constraint_mask= model.crf._constraint_mask[:dimen+2, :dimen+2], _transitions=model.crf.transitions[:dimen, :dimen],
#                  start_transitions=model.crf.start_transitions[:dimen],
#                 end_transitions = model.crf.end_transitions[:dimen])
# #     print(x)
#     stacked_t = torch.stack(list(reversed(all_prob))[1:])
#     all_prob_val.append(stacked_t.unsqueeze(0))
#     print(torch.argmax(stacked_t, dim=1))
#     print(all(torch.tensor(x[0][0][0]) == torch.argmax(stacked_t, dim=1)))

NameError: name 'all_score' is not defined

In [None]:
for i in x:
    print(i)

In [None]:
torch.argmax(torch.stack(list(reversed(all_prob))), dim=0)


In [None]:
print(torch.topk(all_score[-2], 1, 0))
print(all_score[-1])
all_score[-1][:, 6]

In [None]:
p = torch.randn([1,10]).unsqueeze(2)
q = torch.randn([10,10])
# print(p,q)
model.crf._constraint_mask.shape
model.crf.transitions.shape

In [None]:
def _viterbi_decoding(emissions, transitions):
    scores = torch.zeros(emissions.size(1))
#     print(scores)
#     back_pointers = torch.zeros(emissions.size()).int()
#     scores = scores + emissions[0]
#     # Generate most likely scores and paths for each step in sequence
#     for i in range(1, emissions.size(0)):
#         scores_with_transitions = scores.unsqueeze(1).expand_as(transitions) + transitions
#         max_scores, back_pointers[i] = torch.max(scores_with_transitions, 0)
#         scores = emissions[i] + max_scores
#     # Generate the most likely path
#     viterbi = [scores.numpy().argmax()]
#     back_pointers = back_pointers.numpy()
#     for bp in reversed(back_pointers[1:]):
#         viterbi.append(bp[viterbi[-1]])
#     viterbi.reverse()
#     viterbi_score = scores.numpy().max()
#     return viterbi_score, viterbi

In [None]:
_viterbi_decoding(z, model.crf.transitions)

In [None]:
torch.topk((p+q).view(-1,10), k=1, dim=0)

In [None]:
model.crf._constraint_mask[:73, :73] *  model.crf.transitions

In [None]:
p+q