此模型采用bert[BERT for Joint Intent Classification and Slot Filling](https://arxiv.org/pdf/1902.10909.pdf)进行意图识别与槽填充。
结构如下：

![model](img/bert.png)

从上可知：

    1.意图识别采用[cls]的输出进行识别
    
    2.槽填充直接输出对应的结果进行序列标注，这里不使用mlm中的mask
    
步骤：

    1.输入是单句，[cls]用于最后的意图识别, [sep]作为句子的最后部分
    
    2.再加入位置信息
    
    3.经过embedding -> transformer-encoder -> 输出

In [1]:
import os
from torchtext import data, datasets
import pandas as pd
import pickle

base_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
atis_data = os.path.join(base_dir, 'atis')

In [2]:
'''
build train and val dataset
'''
    
tokenize = lambda s:s.split()

SOURCE = data.Field(sequential=True, tokenize=tokenize,
                    lower=True, use_vocab=True,
                    init_token='<cls>', eos_token='<sep>',
                    pad_token='<pad>', unk_token='<unk>',
                    batch_first=True, fix_length=50,
                    include_lengths=True) #include_lengths=True为方便之后使用torch的pack_padded_sequence

TARGET = data.Field(sequential=True, tokenize=tokenize,
                    lower=True, use_vocab=True,
                    init_token='<cls>', eos_token='<sep>',
                    pad_token='<pad>', unk_token='<unk>',
                    batch_first=True, fix_length=50,
                    include_lengths=True) #include_lengths=True为方便之后使用torch的pack_padded_sequence
LABEL = data.Field(
                sequential=False,
                use_vocab=True)

train, val = data.TabularDataset.splits(
                                        path=atis_data,
                                        skip_header=True,
                                        train='atis.train.csv',
                                        validation='atis.test.csv',
                                        format='csv',
                                        fields=[('index', None), ('intent', LABEL), ('source', SOURCE), ('target', TARGET)])

SOURCE.build_vocab(train, val)
TARGET.build_vocab(train, val)
LABEL.build_vocab(train, val)

train_iter, val_iter = data.Iterator.splits(
                                            (train, val),
                                            batch_sizes=(64, len(val)), # 训练集设置为32,验证集整个集合用于测试
                                            shuffle=True,
                                            sort_within_batch=True, #为true则一个batch内的数据会按sort_key规则降序排序
                                            sort_key=lambda x: len(x.source)) #这里按src的长度降序排序，主要是为后面pack,pad操作)

In [3]:
# save source words
source_words_path = os.path.join(os.getcwd(), 'source_words.pkl')
with open(source_words_path, 'wb') as f_source_words:
    pickle.dump(SOURCE.vocab, f_source_words)

# save target words
target_words_path = os.path.join(os.getcwd(), 'target_words.pkl')
with open(target_words_path, 'wb') as f_target_words:
    pickle.dump(TARGET.vocab, f_target_words)
    
# save label words
label_words_path = os.path.join(os.getcwd(), 'label_words.pkl')
with open(label_words_path, 'wb') as f_label_words:
    pickle.dump(LABEL.vocab, f_label_words)

In [4]:
print('{}, {}'.format(SOURCE.init_token, SOURCE.eos_token))

<cls>, <sep>


In [5]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import math
from apex import amp
import time

In [6]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False

In [7]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, max_length=100):
        super(Encoder, self).__init__()
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        # 多层encoder
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(DEVICE)

    def forward(self, src, src_mask):
        #src:[batch_size, src_len]
        #src_mask:[batch_size, 1, 1, src_len]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        #位置信息
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(DEVICE)
        #token编码+位置编码
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos)) # [batch_size, src_len, hid_dim]
        
        for layer in self.layers:
            src = layer(src, src_mask) #[batch_size, src_len, hid_dim]
        
        return src
            
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        # src:[batch_size, src_len, hid_dim]
        # src_mask:[batch_size, 1, 1, src_len]
        
        # 1.经过多头attetnion后，再经过add+norm
        # self-attention
        _src = self.self_attention(src, src, src, src_mask)
        
        src = self.self_attn_layer_norm(src + self.dropout(_src)) # [batch_size, src_len, hid_dim]
        
        # 2.经过一个前馈网络后，再经过add+norm
        _src = self.positionwise_feedforward(src)
        
        src = self.ff_layer_norm(src + self.dropout(_src)) # [batch_size, src_len, hid_dim]
        
        return src     

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super(MultiHeadAttentionLayer, self).__init__()
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(DEVICE)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        
        # query:[batch_size, query_len, hid_dim]
        # key:[batch_size, query_len, hid_dim]
        # value:[batch_size, query_len, hid_dim]
        
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) # [batch_size, query_len, n_heads, head_dim]
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale # [batch_size, n_heads, query_len, key_len]
        
        if mask is not None:
            energy = energy.mask_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim=-1) # [batch_size, n_heads, query_len, key_len]
        
        x = torch.matmul(self.dropout(attention), V) # [batch_size, n_heads, query_len, head_dim]
        
        x = x.permute(0, 2, 1, 3).contiguous() # [batch_size, query_len, n_heads, head_dim]
        
        x = x.view(batch_size, -1, self.hid_dim) # [batch_size, query_len, hid_dim]
        
        x = self.fc_o(x) # [batch_size, query_len, hid_dim]
        
        return x
        
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super(PositionwiseFeedforwardLayer, self).__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        
    def forward(self, x):
        # x:[batch_size, seq_len, hid_dim]
        
        x = self.dropout(self.gelu(self.fc_1(x))) # [batch_size, seq_len, pf_dim]
        x = self.fc_2(x) # [batch_size, seq_len, hid_dim]
        
        return x

class BERT(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, slot_size, intent_size, src_pad_idx):
        super(BERT, self).__init__()
        self.src_pad_idx = src_pad_idx
        self.encoder = Encoder(input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout)
        self.gelu = nn.GELU()
        
        self.fc = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.Dropout(dropout), nn.Tanh())
        self.intent_out = nn.Linear(hid_dim, intent_size)
        self.linear = nn.Linear(hid_dim, hid_dim)
       
        embed_weight = self.encoder.tok_embedding.weight
        self.slot_out = nn.Linear(hid_dim, slot_size, bias=False)
        self.slot_out.weight = embed_weight
    
    def make_src_mask(self, src):
        # src: [batch_size, src_len]
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, src_len]
        
    def forward(self, src):
        src_mask = self.make_src_mask(src)
        encoder_out = self.encoder(src, src_mask) #[batch_size, src_len, hid_dim]
        
        # 拿到[cls] token进行意图分类
        cls_hidden = self.fc(encoder_out[:, 0]) # [batch_size, hid_dim]
        intent_output = self.intent_out(cls_hidden) # [batch_size, intent_size]
        
        # 排除cls进行slot预测
        other_hidden = self.gelu(self.linear(encoder_out[:,1:])) # [batch_sze, src_len-1, hid_dim]
        slot_output = self.slot_out(other_hidden) # [batch_size, src_len-1, slot_size]
        return intent_output, slot_output

In [9]:
n_layers = 6 # transformer-encoder层数
n_heads = 12 # 多头self-attention
hid_dim =768 
dropout = 0.5
pf_dim = 768 * 4 

input_dim = len(SOURCE.vocab)
slot_size = len(TARGET.vocab) # slot size
intent_size = len(LABEL.vocab) # intent size
src_pad_idx = SOURCE.vocab.stoi[SOURCE.pad_token]

model = BERT(input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, slot_size, intent_size, src_pad_idx).to(DEVICE)

# 优化函数
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 损失函数(slot)
loss_slot = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

# 定义损失函数(意图识别)
loss_intent = nn.CrossEntropyLoss()

In [10]:
# 训练
def train(model, iterator, optimizer, loss_slot, loss_intent, clip):
    
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src, _ = batch.source  # src=[batch_size, seq_len]，这里batch.src返回src和src的长度，因为在使用torchtext.Field时设置include_lengths=True
        trg, _ = batch.target  # trg=[batch_size, seq_len]
        label = batch.intent # [batch_size]
        src = src.to(DEVICE)
        trg = trg.to(DEVICE)
        label = label.to(DEVICE)
        
        optimizer.zero_grad()
        
        intent_output, slot_output = model(src) # [batch_size, intent_size]; [batch_size, trg_len-1, slot_size]
        
        # 1.计算slot loss
        slot_output_dim = slot_output.shape[-1]
        
        slot_output = slot_output.reshape(-1, slot_output_dim) # [batch_size * (trg_len-1), slot_output_dim]
        
        trg = trg[:,1:].contiguous().view(-1) # [batch_size * (trg_len-1)]
        
        # 1.计算slot loss
        loss1 = loss_slot(slot_output, trg)
        
        # 2.计算intent loss
        loss2 = loss_intent(intent_output, label)
        
        # 3.联合slot loss + intent loss
        loss = loss1 + loss2
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)
        

In [11]:
# val loss
def evaluate(model, iterator, loss_slot, loss_intent):
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src, _ = batch.source  # src=[batch_size, seq_len]
            trg, _ = batch.target  # trg=[batch_size, seq_len]
            label = batch.intent
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)
            label = label.to(DEVICE)
            
            intent_output, slot_output = model(src) # [batch_size, intent_dim]; [batch_size, trg_len-1, slot_size]
            
            # 1.计算slot loss
            slot_output_dim = slot_output.shape[-1]

            slot_output = slot_output.reshape(-1, slot_output_dim) # [batch_size * (trg_len-1), slot_output_dim]

            trg = trg[:,1:].contiguous().view(-1) # [batch_size * (trg_len-1)]

            loss1 = loss_slot(slot_output, trg)

            # 2.计算intent loss
            loss2 = loss_intent(intent_output, label)

            loss = loss1 + loss2
        
            
            epoch_loss += loss.item()
            
    return epoch_loss / len(iterator)

In [12]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
n_epochs = 100 # 迭代次数
clip = 0.1 # 梯度裁剪

model_path = os.path.join(os.getcwd(), "model.h5")

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    
    start_time = time.time()
    
    train_loss = train(model, train_iter, optimizer, loss_slot, loss_intent, clip)
    valid_loss = evaluate(model, val_iter, loss_slot, loss_intent)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time) # 每个epoch花费的时间
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), model_path)
        
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 0m 22s
	Train Loss: 37.131 | Train PPL: 13362582122889738.000
	 Val. Loss: 8.320 |  Val. PPL: 4104.540
Epoch: 02 | Time: 0m 21s
	Train Loss: 3.868 | Train PPL:  47.841
	 Val. Loss: 5.999 |  Val. PPL: 402.855
