This repository is a PyTorch implementation of "Zero-shot User Intent Detection via Capsule Neural Networks".
Details of this model are available in the original paper, "Zero-shot user intent detection via capsule neural networks"[1].
This PyTorch implementation is a revised and upgraded version of the original repository, "Zero-shot User Intent Detection via Capsule Neural Networks (PyTorch Implementation)"[2].
The details of the model structure are as follows.
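In the paper [1], multi-head self-attention turns the encoded utterance into `r` semantic capsules, which are then routed to intent capsules by dynamic routing. The routing step can be sketched in PyTorch as below. This is a minimal sketch of the standard routing-by-agreement algorithm, not code taken from this repository; the tensor layout and the mapping of its dimensions onto the `r`, `num_props` and `num_iters` arguments are assumptions.

```python
import torch
import torch.nn.functional as F


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # Keeps the direction of s but maps its length into [0, 1).
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)


def dynamic_routing(u_hat: torch.Tensor, num_iters: int = 1) -> torch.Tensor:
    """Routing-by-agreement from R lower capsules to K upper capsules.

    u_hat: prediction vectors of shape (batch, R, K, D) (assumed layout).
    Returns the upper (intent) capsules with shape (batch, K, D).
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    batch, R, K, _ = u_hat.shape
    # Routing logits start at zero, i.e. uniform coupling.
    b = torch.zeros(batch, R, K, dtype=u_hat.dtype, device=u_hat.device)
    for it in range(num_iters):
        c = F.softmax(b, dim=2)                   # each lower capsule distributes over K
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum over R -> (batch, K, D)
        v = squash(s)
        if it < num_iters - 1:
            # Raise the logits where a prediction agrees with the current output.
            b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
    return v


torch.manual_seed(0)
u_hat = torch.randn(2, 3, 7, 10)  # (batch, r, num_intents, num_props), assumed mapping
v = dynamic_routing(u_hat, num_iters=3)
print(v.shape)  # torch.Size([2, 7, 10])
```

The length of each output capsule stays below 1, so it can be read as the strength of the corresponding intent.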
Unlike the existing version, this implementation can use BERT[3] as the encoder.
Besides word2vec, it can also use an `nn.Embedding` layer that is trained from scratch.
The tokenizer and embedding method depend on the model type you choose.

| | `bert_capsnet` | `basic_capsnet` | `w2v_capsnet` |
|---|---|---|---|
| Encoder | BERT (`bert-base-uncased`) | BiLSTM | BiLSTM |
| Tokenizer | BERT Tokenizer | BERT Tokenizer | Whitespace |
| Embedding | BERT Embedding | PyTorch `nn.Embedding` | GoogleNews Word2Vec |
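For `w2v_capsnet`, texts are split on whitespace before the embedding lookup. A minimal sketch of whitespace tokenization with truncation and padding to `max_len` follows; the vocabulary-building helper and the `<pad>`/`<unk>` tokens are hypothetical illustrations, not this repository's code (which looks words up in GoogleNews Word2Vec).

```python
from typing import Dict, List

PAD, UNK = "<pad>", "<unk>"  # hypothetical special tokens


def build_vocab(texts: List[str]) -> Dict[str, int]:
    # Reserve index 0 for padding and 1 for unknown words.
    vocab = {PAD: 0, UNK: 1}
    for text in texts:
        for tok in text.split():  # whitespace tokenization
            vocab.setdefault(tok, len(vocab))
    return vocab


def encode(text: str, vocab: Dict[str, int], max_len: int = 128) -> List[int]:
    # Map tokens to ids, truncate to max_len, then pad up to exactly max_len.
    ids = [vocab.get(tok, vocab[UNK]) for tok in text.split()][:max_len]
    return ids + [vocab[PAD]] * (max_len - len(ids))


vocab = build_vocab(["play some jazz", "book a table"])
ids = encode("play some rock", vocab, max_len=5)
print(ids)  # [2, 3, 1, 0, 0]
```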
In addition to the zero-shot intent detection task, you can train and test the model on the original seen intent classification task.
To do so, simply set the `mode` option.
This repository contains a sample dataset in the `data/raw` directory.
The sample is the SNIPS NLU benchmark dataset[4], parsed to keep only the intent tags and texts.
You can use a different dataset, but its raw data files must follow the same format as the sample:
each `txt` file represents one intent, and each line in a file consists of the intent and the text, separated by `\t`.
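A minimal sketch of reading raw files in this format with the Python standard library is shown below. The function name `load_raw_dir` is hypothetical, and the sketch assumes UTF-8 files and splits each line on its first tab only.

```python
import os
import tempfile
from typing import List, Tuple


def load_raw_dir(raw_dir: str) -> List[Tuple[str, str]]:
    """Read every *.txt file in raw_dir; each line is '<intent>\\t<text>'."""
    samples = []
    for name in sorted(os.listdir(raw_dir)):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(raw_dir, name), encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                if not line.strip():
                    continue  # skip blank lines
                intent, text = line.split("\t", 1)
                samples.append((intent, text))
    return samples


# Usage: build a throwaway raw directory with one intent file and read it back.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "intent0.txt"), "w", encoding="utf-8") as f:
        f.write("PlayMusic\tplay some jazz\nPlayMusic\tput on a song\n")
    demo = load_raw_dir(tmp)
print(demo)  # [('PlayMusic', 'play some jazz'), ('PlayMusic', 'put on a song')]
```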

| argument | type | description | default |
|---|---|---|---|
| `seed` | int | The random seed. | 0 |
| `batch_size` | int | The batch size. | 16 |
| `learning_rate` | float | The learning rate. | 1e-4 |
| `num_epochs` | int | The total number of epochs. | 10 |
| `max_len` | int | The maximum input length. | 128 |
| `dropout` | float | The dropout rate. | 0.0 |
| `d_a` | int | The dimension of the internal vector in self-attention. | 80 |
| `num_props` | int | The number of properties in each capsule. | 10 |
| `r` | int | The number of semantic features. | 3 |
| `num_iters` | int | The number of iterations of the dynamic routing algorithm. | 1 |
| `alpha` | float | The coefficient of the loss term that encourages discrepancies among different attention heads. | 1e-4 |
| `sim_scale` | int | The scaling factor for intent similarity. | 1 |
| `num_layers` | int | The number of layers in the LSTM encoder. | 1 |
| `ckpt_dir` | str | The directory for trained checkpoints. | "saved_models" |
| `data_dir` | str | The directory for data. | "data" |
| `raw_dir` | str | The directory for raw data. | "raw" |
| `train_frac` | float | The fraction of the data to be included in the train set. | 0.8 |
| `train_prefix` | str | The file name prefix of the train data. | "train" |
| `valid_prefix` | str | The file name prefix of the validation data. | "valid" |
| `model_type` | str | The model type (`"bert_capsnet"`, `"basic_capsnet"`, `"w2v_capsnet"`). | "bert_capsnet" |
| `mode` | str | The task: seen class classification or zero-shot detection (`"seen_class"`, `"zero_shot"`). | "seen_class" |
| `bert_embedding_frozen` | str | Whether to freeze BERT's embedding layer (`"True"`, `"False"`). | "False" |
| `gpu` | str | The index of the GPU to use. | "0" |
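The arguments above map naturally onto an `argparse` parser. The sketch below assumes each argument is exposed as a `--name` flag with the listed type and default; `build_parser` is a hypothetical helper, and the repository's actual parser may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the argument table (names, types, defaults).
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--max_len", type=int, default=128)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--d_a", type=int, default=80)
    p.add_argument("--num_props", type=int, default=10)
    p.add_argument("--r", type=int, default=3)
    p.add_argument("--num_iters", type=int, default=1)
    p.add_argument("--alpha", type=float, default=1e-4)
    p.add_argument("--sim_scale", type=int, default=1)
    p.add_argument("--num_layers", type=int, default=1)
    p.add_argument("--ckpt_dir", type=str, default="saved_models")
    p.add_argument("--data_dir", type=str, default="data")
    p.add_argument("--raw_dir", type=str, default="raw")
    p.add_argument("--train_frac", type=float, default=0.8)
    p.add_argument("--train_prefix", type=str, default="train")
    p.add_argument("--valid_prefix", type=str, default="valid")
    p.add_argument("--model_type", type=str, default="bert_capsnet",
                   choices=["bert_capsnet", "basic_capsnet", "w2v_capsnet"])
    p.add_argument("--mode", type=str, default="seen_class",
                   choices=["seen_class", "zero_shot"])
    p.add_argument("--bert_embedding_frozen", type=str, default="False",
                   choices=["True", "False"])
    p.add_argument("--gpu", type=str, default="0")
    return p


args = build_parser().parse_args([])  # defaults only
# The flag is a string, so compare explicitly: bool("False") would be True.
freeze_bert = args.bert_embedding_frozen == "True"
```

Because `bert_embedding_frozen` is a string rather than a boolean, it must be compared against `"True"` explicitly instead of being passed to `bool()`.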
Install all required packages.

`pip install -r requirements.txt`
Run the script below to train and test a model. (You may need to adjust the arguments as desired.)

`sh exec_train.sh`

If you keep the default directory names, the processed data files will be organized as follows.

    data
    ├── raw
    │   ├── intent0.txt
    │   ├── intent1.txt
    │   ├── ...
    │   └── intent(I-1).txt
    └── MODE (seen_class / zero_shot)
        ├── train.txt
        └── valid.txt
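How the raw files are divided into `train.txt` and `valid.txt` depends on the mode. One plausible scheme is sketched below: in `seen_class` mode every intent appears in both sets and only its utterances are divided by `train_frac`, while in `zero_shot` mode the intents themselves are divided so that validation intents never occur during training. This is an illustrative assumption, not necessarily how this repository splits the data, and `split_samples` is a hypothetical helper.

```python
import random
from typing import Dict, List, Tuple

Sample = Tuple[str, str]  # (intent, text)


def split_samples(samples: List[Sample], mode: str,
                  train_frac: float = 0.8, seed: int = 0):
    rng = random.Random(seed)  # seeded so the split is reproducible
    if mode == "seen_class":
        # Split the utterances of each intent, so every intent is in both sets.
        by_intent: Dict[str, List[Sample]] = {}
        for s in samples:
            by_intent.setdefault(s[0], []).append(s)
        train, valid = [], []
        for intent in sorted(by_intent):
            group = by_intent[intent][:]
            rng.shuffle(group)
            cut = int(len(group) * train_frac)
            train += group[:cut]
            valid += group[cut:]
        return train, valid
    if mode == "zero_shot":
        # Split the intents, so validation intents are unseen during training.
        intents = sorted({intent for intent, _ in samples})
        rng.shuffle(intents)
        seen = set(intents[:int(len(intents) * train_frac)])
        train = [s for s in samples if s[0] in seen]
        valid = [s for s in samples if s[0] not in seen]
        return train, valid
    raise ValueError(f"unknown mode: {mode}")


samples = [(f"intent{k}", f"utterance {j}") for k in range(5) for j in range(10)]
train, valid = split_samples(samples, "zero_shot", train_frac=0.8, seed=0)
print(len(train), len(valid))  # 40 10
```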
[1] Xia, C., Zhang, C., Yan, X., Chang, Y., & Yu, P. S. (2018). Zero-shot user intent detection via capsule neural networks. arXiv preprint arXiv:1809.00385. (https://arxiv.org/abs/1809.00385)
[2] Zero-shot User Intent Detection via Capsule Neural Networks (PyTorch Implementation). (https://github.com/nhhoang96/ZeroShotCapsule-PyTorch-)
[3] Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
[4] Natural Language Understanding benchmark. (https://github.com/sonos/nlu-benchmark)