In [20]:
from functools import reduce
import itertools

import numpy as np
import torch, torchtext
from torchtext import data, datasets
from transformers import BertConfig, BertTokenizer, BertForTokenClassification

In [2]:
# Name of the pretrained BERT checkpoint the tokenizer is loaded from.
bert_model_type = 'bert-base-uncased'
# Fixed row width (in tokens) of the arrays built by conv_to_idxes.
max_tok_len = 128
# Special tokens that load_train wraps around every sentence.
cls_token = '[CLS]'
sep_token = '[SEP]'

In [7]:
def load_train():
    """Read the ATIS training sentences, slot labels and intents from ``data/``.

    Each line of the sentence and slot files is split on commas. The first
    and last field of every sentence are replaced by ``cls_token`` and
    ``sep_token`` (presumably they are BOS/EOS markers -- confirm against
    the data files). The first slot label of every row is replaced by the
    intent found on the same line number of the intent file.

    Returns:
        tuple: ``(sents, lower_slots, max_len)`` where ``sents`` is a list of
        space-joined sentences, ``lower_slots`` is a list of lower-cased
        label lists (intent first, then the remaining slot labels), and
        ``max_len`` is the largest number of comma-separated fields found on
        any sentence line (0 when the file is empty).
    """
    with open('data/atis.sentences.train.csv', 'r') as sent_file:
        word_rows = [line.split(',') for line in sent_file.read().splitlines()]
    # Length is measured on the raw fields, before the markers are swapped.
    max_len = max((len(row) for row in word_rows), default=0)
    # Wordpiece tokenizer (downstream) doesn't play well with the default
    # BertTokenizer's cls_token and sep_token parameters, so the special
    # tokens are written into the text by hand.
    sents = [' '.join([cls_token, *row[1:-1], sep_token]) for row in word_rows]

    with open('data/atis.slots.train.csv', 'r') as slot_file:
        slot_rows = [line.split(',') for line in slot_file.read().splitlines()]

    with open('data/atis.intent.train.csv', 'r') as intent_file:
        intents = intent_file.read().splitlines()

    lower_slots = []
    for row_no, slot_row in enumerate(slot_rows):
        # Replace BOS slot label with intent.
        labels = [intents[row_no]] + slot_row[1:]
        # Lower-case all labels - text lower casing is left to the tokenizer.
        lower_slots.append([label.lower() for label in labels])
    return sents, lower_slots, max_len

In [14]:
def conv_to_idxes(tokenizer, sents, max_tok_len):
    """Convert sentences into fixed-width token-id, relevance and attention arrays.

    Args:
        tokenizer: object providing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]``.
        sents: sequence of sentence strings.
        max_tok_len: number of columns in every returned array.

    Returns:
        tuple: ``(tokens, relevant_tok_mask, attn_mask)``, three ``int64``
        arrays of shape ``(len(sents), max_tok_len)``:

        * ``tokens`` -- token ids, right-padded with 0.
        * ``relevant_tok_mask`` -- 1 for every token that is neither a
          ``##`` continuation piece nor the literal ``'EOS'`` token, else 0.
        * ``attn_mask`` -- 1 for real tokens, 0 for padding.

    Raises:
        ValueError: if a sentence produces more than ``max_tok_len`` tokens.
    """
    shape = (len(sents), max_tok_len)
    # Choosing pad-token idx to be 0 by default.
    # The `np.int` alias was removed in NumPy 1.24, so use an explicit dtype.
    tokens = np.zeros(shape, dtype=np.int64)
    relevant_tok_mask = np.zeros(shape, dtype=np.int64)
    attn_mask = np.zeros(shape, dtype=np.int64)

    # Use wordpiece tokenizer, and maintain idxes of those tokens that are useful.
    # In this case, that is the CLS-token, and the first subword token for every word.
    for row, sent in enumerate(sents):
        sent_toks = tokenizer.tokenize(sent)
        n_toks = len(sent_toks)
        if n_toks > max_tok_len:
            # Fail before touching any array so the outputs never end up
            # half-written for this row.
            raise ValueError(
                "Sentence %d has %d wordpiece tokens, which exceeds "
                "max_tok_len=%d" % (row, n_toks, max_tok_len))
        attn_mask[row, :n_toks] = 1
        tokens[row, :n_toks] = tokenizer.convert_tokens_to_ids(sent_toks)
        # Continuation pieces carry the '##' prefix; a standalone '#' token is
        # a word of its own. 'EOS' is compared by equality: `tok in ('EOS')`
        # would be a substring test against the string 'EOS'.
        relevant_tok_mask[row, :n_toks] = [
            0 if (tok.startswith('##') or tok == 'EOS') else 1
            for tok in sent_toks
        ]
    return tokens, relevant_tok_mask, attn_mask

In [24]:
def to_categorical(labels):
    """Map string labels to contiguous integer ids.

    Labels are compared case-insensitively (lower-cased). Ids are assigned in
    sorted order of the distinct lower-cased labels, so the mapping is the
    same on every run; iterating a ``set`` of strings directly would give an
    order that depends on per-process hash randomisation.

    Args:
        labels: iterable of label lists (one list of strings per sentence).

    Returns:
        tuple: ``(label_map, idxes)`` where ``label_map`` maps each
        lower-cased label to its id in ``range(len(label_map))`` and
        ``idxes`` is a list of id lists mirroring ``labels``.
    """
    # Lower-case before de-duplicating so case variants share a single id and
    # the id range has no gaps.
    normalised = [[label.lower() for label in label_list] for label_list in labels]
    vocab = sorted(set(itertools.chain.from_iterable(normalised)))
    label_map = {label: i for i, label in enumerate(vocab)}
    idxes = [[label_map[label] for label in label_list] for label_list in normalised]
    return label_map, idxes

In [8]:
# Load the training sentences, their per-token labels (intent first), and the
# longest sentence length measured in comma-separated fields.
sents, slots, max_sent_token_len = load_train()

In [10]:
# NOTE: Wordpiece tokenizer doesn't respect cls_token/sep_token supplied here,
# which is why load_train writes the special tokens into the text itself.
# NOTE(review): lower-casing appears to be done by the basic tokenizer in
# transformers, so do_lower_case may have no effect with
# do_basic_tokenize=False -- confirm against the installed version.
tokenizer = BertTokenizer.from_pretrained(bert_model_type,
                                          do_basic_tokenize=False,
                                          do_lower_case=True)

I1224 03:03:24.523384 140286189291328 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/sduddu/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [17]:
# FIXME: This check is only approximate: max_sent_token_len counts
# comma-separated words, while max_tok_len bounds the number of *wordpiece*
# tokens, and a word can split into several wordpieces.
assert (max_sent_token_len <= max_tok_len), "Max naive token len greater than max tok len"
# Build the padded token-id, relevant-token and attention-mask arrays.
tokens, rel_tok_mask, attn_mask = conv_to_idxes(tokenizer, sents, max_tok_len)

In [25]:
# Build the label -> id mapping and convert every sentence's labels to ids.
slot_map, slot_ids = to_categorical(slots)