In [490]:
import seqeval
# Packages noted as dependencies of this notebook:
#   transformers
#   seqeval
#   pytorch-crf

In [491]:
seqeval.__version__

In [329]:
# !python main.py --task katanemo \
#                   --model_type bert \
#                   --model_dir katanemo_model \
#                   --save_steps 10 \
#                   --do_train --do_eval

In [332]:
!python main.py --task katanemo \
                  --model_type bert \
                  --model_dir katanemo_model \
                  --do_train --do_eval --num_train_epochs 30 --logging_steps 3 \
                  --save_steps 10

usage: main.py [-h] --task TASK --model_dir MODEL_DIR [--data_dir DATA_DIR]
               [--intent_label_file INTENT_LABEL_FILE]
               [--slot_label_file SLOT_LABEL_FILE] [--model_type MODEL_TYPE]
               [--seed SEED] [--train_batch_size TRAIN_BATCH_SIZE]
               [--eval_batch_size EVAL_BATCH_SIZE] [--max_seq_len MAX_SEQ_LEN]
               [--learning_rate LEARNING_RATE]
               [--num_train_epochs NUM_TRAIN_EPOCHS]
               [--weight_decay WEIGHT_DECAY]
               [--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
               [--adam_epsilon ADAM_EPSILON] [--max_grad_norm MAX_GRAD_NORM]
               [--max_steps MAX_STEPS] [--warmup_steps WARMUP_STEPS]
               [--dropout_rate DROPOUT_RATE] [--logging_steps LOGGING_STEPS]
               [--save_steps SAVE_STEPS] [--do_train] [--do_eval] [--no_cuda]
               [--ignore_index IGNORE_INDEX] [--slot_loss_coef SLOT_LOSS_COEF]
               [--use_crf] [--slot_pad_label S

In [439]:
import predict
from utils import init_logger, load_tokenizer, get_intent_labels, get_slot_labels, MODEL_CLASSES
import torch
import os
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from tqdm import tqdm, trange
import numpy as np
import collections

In [424]:
def get_args(pred_config):
    """Load the training arguments stored next to the trained model.

    Reads ``training_args.bin`` from ``pred_config.model_dir`` with
    ``torch.load`` and returns whatever object was stored in that file.
    """
    args_path = os.path.join(pred_config.model_dir, 'training_args.bin')
    saved_args = torch.load(args_path)
    return saved_args

def process_lines(lines_to_pred):
    """Split each raw sentence into its whitespace-separated words.

    Returns a list holding one list of word strings per input line, in the
    same order as ``lines_to_pred``.
    """
    return [raw_line.strip().split() for raw_line in lines_to_pred]

In [426]:
class args:
    """Prediction-time settings, read as plain class attributes."""

    # Task name; the model directory below is derived from it.
    task = 'katanemo'
    model_dir = './{}_model'.format(task)

    # Data directory and the label file names handed to the label helpers.
    data_dir = './data/'
    intent_label_file = 'intent_label.txt'
    slot_label_file = 'slot_label.txt'

    # Batch size used for the prediction DataLoader.
    batch_size = 32

In [427]:
# Load the training arguments saved in the model directory configured on `args`,
# then show them as the cell output.
model_args = get_args(pred_config = args)
model_args

Namespace(task='katanemo', model_dir='katanemo_model', data_dir='./data', intent_label_file='intent_label.txt', slot_label_file='slot_label.txt', model_type='bert', seed=1234, train_batch_size=32, eval_batch_size=64, max_seq_len=50, learning_rate=5e-05, num_train_epochs=50.0, weight_decay=0.0, gradient_accumulation_steps=1, adam_epsilon=1e-08, max_grad_norm=1.0, max_steps=-1, warmup_steps=0, dropout_rate=0.1, logging_steps=3, save_steps=10, do_train=True, do_eval=True, no_cuda=False, ignore_index=0, slot_loss_coef=1.0, use_crf=False, slot_pad_label='PAD', model_name_or_path='bert-base-uncased')

In [428]:
# Fetch the intent and slot label lists through the `utils` helpers, using the
# settings on `args`. Both lists are indexed by predicted label id further down.
intent_label_lst = get_intent_labels(args)
slot_label_lst = get_slot_labels(args)

In [429]:
get_slot_labels(args)

['PAD',
 'UNK',
 'O',
 'B-count_what',
 'I-count_what',
 'B-entity_type',
 'I-entity_type',
 'B-list_what',
 'B-time_frame',
 'I-time_frame']

In [441]:
# The second element of the 'bert' entry in MODEL_CLASSES is the class that
# provides `from_pretrained`.
bert_model_cls = MODEL_CLASSES['bert'][1]

# Restore the trained model from the model directory, passing the saved training
# arguments and both label lists.
model = bert_model_cls.from_pretrained(
    args.model_dir,
    args=get_args(pred_config=args),
    intent_label_lst=get_intent_labels(args),
    slot_label_lst=get_slot_labels(args),
)

# Put the model in evaluation mode before predicting.
model.eval()
print('model_loaded')

model_loaded


In [478]:
# Sentences to classify. They are defined inline here instead of being read with
# read_input_file(pred_config).
lines_to_pred = [
    "how many devices have high network peaks?",
    "Which network had major discards or errors?",
    "Are there any drops in network over the last 10 months",
]
# Alternative example sentences, kept for reference:
# lines_to_pred = ["show flights saturday evening from st. louis to burbank", 
#                  "which flights travel from las vegas to los angeles california and arrive on april ninth between 4 and 5 pm?"]
lines = process_lines(lines_to_pred)

# Convert the word lists to a TensorDataset; the saved `ignore_index` is used as
# the pad-token label id.
pad_token_label_id = model_args.ignore_index
tokenizer = load_tokenizer(model_args)
dataset = predict.convert_input_file_to_tensor_dataset(
    lines, args, model_args, tokenizer, pad_token_label_id
)

In [479]:
# Walk the dataset in its stored order, `args.batch_size` examples per batch.
sampler = SequentialSampler(dataset)
data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)

# Prediction results; None until the prediction loop has produced something.
all_slot_label_mask = intent_preds = slot_preds = None

In [480]:
# Device every batch tensor is moved to in the prediction loop.
# NOTE(review): the model itself is never moved with `.to(...)` in this notebook, so
# this presumably has to match wherever `from_pretrained` placed it — confirm before changing.
device = 'cpu'

In [481]:
# Run the model over every batch and collect the raw intent / slot outputs.
#
# Results are gathered in lists created inside this cell and concatenated once
# at the end. Re-running the cell therefore starts from scratch instead of
# appending to arrays left over from an earlier run, and the accumulated
# predictions are no longer copied on every batch as np.append did.
intent_batches = []
slot_batches = []
slot_mask_batches = []

for batch in tqdm(data_loader, desc="Predicting"):
    batch = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        # No labels are supplied at prediction time.
        inputs = {"input_ids": batch[0],
                  "attention_mask": batch[1],
                  "intent_label_ids": None,
                  "slot_labels_ids": None}
        # token_type_ids are passed for every model type except distilbert
        if model_args.model_type != "distilbert":
            inputs["token_type_ids"] = batch[2]
        outputs = model(**inputs)
        _, (intent_logits, slot_logits) = outputs[:2]

        # Intent prediction
        intent_batches.append(intent_logits.detach().cpu().numpy())

        # Slot prediction
        if model_args.use_crf:
            # decode() in `torchcrf` returns list with best index directly
            slot_batches.append(np.array(model.crf.decode(slot_logits)))
        else:
            slot_batches.append(slot_logits.detach().cpu().numpy())

        # batch[3] is kept as the per-position slot label mask used when decoding
        slot_mask_batches.append(batch[3].detach().cpu().numpy())

if intent_batches:
    intent_preds = np.concatenate(intent_batches, axis=0)
    slot_preds = np.concatenate(slot_batches, axis=0)
    all_slot_label_mask = np.concatenate(slot_mask_batches, axis=0)
else:
    # Empty data loader: keep the "nothing predicted" state the original produced
    intent_preds = slot_preds = all_slot_label_mask = None

Predicting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.61it/s]


In [482]:
# slot_label_lst

In [483]:
# Decode the raw model outputs into label names and print, per sentence, the
# predicted intent plus the words grouped by their predicted slot label.
#
# The argmax results get their own names so that `intent_preds` / `slot_preds`
# keep the raw outputs of the prediction loop. The cell can then be re-run
# without repeating the prediction pass.
intent_pred_ids = np.argmax(intent_preds, axis=1)

# With a CRF the collected slot predictions are already label indices.
slot_pred_ids = slot_preds if model_args.use_crf else np.argmax(slot_preds, axis=2)

slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}

# Keep one slot label for every position whose mask entry is not the pad label id.
slot_preds_list = [
    [slot_label_map[label_id]
     for label_id, mask_value in zip(row_ids, row_mask)
     if mask_value != pad_token_label_id]
    for row_ids, row_mask in zip(slot_pred_ids, all_slot_label_mask)
]

# Print: <intent> -> original sentence -> {slot label: [words]}
# The loop variable is deliberately not called `slot_preds`, so the prediction
# array is not replaced by the last sentence's label list.
for words, word_slot_labels, intent_pred in zip(lines, slot_preds_list, intent_pred_ids):
    entity_dict = collections.defaultdict(list)
    for word, pred in zip(words, word_slot_labels):
        # Words tagged 'O' carry no slot and are left out of the dictionary.
        if pred != 'O':
            entity_dict[pred].append(word)
    print("<{}> -> {} -> {}\n".format(intent_label_lst[intent_pred], " ".join(words), dict(entity_dict)))



<fact_count> -> how many devices have high network peaks? -> {'B-entity_type': ['devices', 'network']}

<fact_list> -> Which network had major discards or errors? -> {'B-entity_type': ['network']}

<time_series> -> Are there any drops in network over the last 10 months -> {'B-entity_type': ['network'], 'I-time_frame': ['months']}



In [477]:
dict(entity_dict)

{'B-entity_type': ['network'], 'I-time_frame': ['months'], 1: [], 0: []}

In [403]:
# Decode the raw model outputs and print each sentence with its slot-tagged
# words annotated inline as [word:label].
#
# The argmax results get their own names so that `intent_preds` / `slot_preds`
# keep the raw outputs of the prediction loop. The cell can then be re-run
# without repeating the prediction pass.
intent_pred_ids = np.argmax(intent_preds, axis=1)

# With a CRF the collected slot predictions are already label indices.
slot_pred_ids = slot_preds if model_args.use_crf else np.argmax(slot_preds, axis=2)

slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}

# Keep one slot label for every position whose mask entry is not the pad label id.
slot_preds_list = [
    [slot_label_map[label_id]
     for label_id, mask_value in zip(row_ids, row_mask)
     if mask_value != pad_token_label_id]
    for row_ids, row_mask in zip(slot_pred_ids, all_slot_label_mask)
]

# The loop variable is deliberately not called `slot_preds`, so the prediction
# array is not replaced by the last sentence's label list.
for words, word_slot_labels, intent_pred in zip(lines, slot_preds_list, intent_pred_ids):
    # Words tagged 'O' are printed as-is; every other word is shown with its label.
    annotated = [
        word if pred == 'O' else "[{}:{}]".format(word, pred)
        for word, pred in zip(words, word_slot_labels)
    ]
    print("<{}> -> {}\n".format(intent_label_lst[intent_pred], " ".join(annotated)))
        
# # Write to output file
# with open(pred_config.output_file, "w", encoding="utf-8") as f:
#     for words, slot_preds, intent_pred in zip(lines, slot_preds_list, intent_preds):
#         line = ""
#         for word, pred in zip(words, slot_preds):
#             if pred == 'O':
#                 line = line + word + " "
#             else:
#                 line = line + "[{}:{}] ".format(word, pred)
#         f.write("<{}> -> {}\n".format(intent_label_lst[intent_pred], line.strip()))

# logger.info("Prediction Done!")

<atis_flight> -> show flights [saturday:B-depart_date.day_name] [evening:B-depart_time.period_of_day] from [st.:B-fromloc.city_name] [louis:I-fromloc.city_name] to [burbank:B-toloc.city_name]

<atis_flight> -> which flights travel from [las:B-fromloc.city_name] [vegas:I-fromloc.city_name] to [los:B-toloc.city_name] [angeles:I-toloc.city_name] [california:B-toloc.state_name] and arrive on [april:B-arrive_date.month_name] [ninth:B-arrive_date.day_number] between [4:B-arrive_time.start_time] and [5:B-arrive_time.end_time] [pm?:I-arrive_time.end_time]



In [385]:
slot_preds_list

[['B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type'],
 ['B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type'],
 ['B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type',
  'B-entity_type']]

In [404]:
# Rebuild the annotated text of each sentence: 'O' words are copied as-is, every
# other word becomes "[word:label]". Nothing is printed here; after the loop
# `line` holds the text of the last sentence only.
for words, slot_preds, intent_pred in zip(lines, slot_preds_list, intent_preds):
    line = ""
    for word, pred in zip(words, slot_preds):
        line += (word + " ") if pred == 'O' else "[{}:{}] ".format(word, pred)

In [406]:
lines

[['show',
  'flights',
  'saturday',
  'evening',
  'from',
  'st.',
  'louis',
  'to',
  'burbank'],
 ['which',
  'flights',
  'travel',
  'from',
  'las',
  'vegas',
  'to',
  'los',
  'angeles',
  'california',
  'and',
  'arrive',
  'on',
  'april',
  'ninth',
  'between',
  '4',
  'and',
  '5',
  'pm?']]

In [388]:
slot_logits[2][5]

tensor([-1.0772, -1.0156, -0.3700, -0.3498, -0.3873,  3.2950, -0.6526, -0.3704,
        -0.6510, -0.4977])

In [389]:
intent_label_lst

['time_series', 'fact_count', 'fact_list', 'UNK']

In [390]:
# Debug aid: print every word next to its predicted slot label.
# NOTE(review): `words` and `slot_preds` are not defined in this cell. They are whatever
# the last iteration of an earlier `for words, slot_preds, ... in zip(...)` loop left
# behind, so this only shows the final sentence and depends on cell execution order.
for word, pred in zip(words, slot_preds):
    print(word, pred)

Are B-entity_type
there B-entity_type
any B-entity_type
drops B-entity_type
in B-entity_type
network B-entity_type
over B-entity_type
the B-entity_type
last B-entity_type
10 B-entity_type
months B-entity_type


In [391]:
model_args.use_crf

False