# 模型构建与损失函数


**目录：**
1. JointBERT模型
    - 分类层;
    - CRF层;


2. 损失函数计算

---


In [9]:
import torch
import torch.nn as nn

from torch.utils.data import TensorDataset, RandomSampler, DataLoader

# 以BERT为预训练模型进行讲解
from transformers import BertPreTrainedModel, BertModel, BertConfig
from torchcrf import CRF  # pip install pytorch-crf

### 两个分类任务各自的MLP层

In [10]:
# Classification head for the sentence-level intent task
class IntentClassifier(nn.Module):
    """Dropout followed by a single linear projection onto the intent labels.

    Args:
        input_dim: Size of the incoming feature vector.
        num_intent_labels: Number of output logits.
        dropout_rate: Dropout probability applied before the projection.
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim,
            out_features=num_intent_labels,
        )

    def forward(self, x):
        """Return intent logits for ``x`` of shape [batch_size, input_dim]."""
        return self.linear(self.dropout(x))

    
# Classification head for the token-level slot-filling task
class SlotClassifier(nn.Module):
    """Dropout followed by a single linear projection onto the slot labels.

    The projection acts on the last dimension, so every token position is
    classified independently.

    Args:
        input_dim: Size of each token's feature vector.
        num_slot_labels: Number of output logits per token.
        dropout_rate: Dropout probability applied before the projection.
    """

    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(
            in_features=input_dim,
            out_features=num_slot_labels,
        )

    def forward(self, x):
        """Return slot logits for ``x`` of shape [batch_size, max_seq_len, input_dim]."""
        return self.linear(self.dropout(x))

### 主要的模型框架

In [11]:
# Plain nn.Module variant of the model (the alternative base class is
# transformers' BertPreTrainedModel, whose __init__ takes the config).
class JointBERT(nn.Module):
    """Joint intent-classification and slot-filling model on a BERT encoder.

    The pooled encoder output feeds the intent classifier and the per-token
    sequence output feeds the slot classifier. ``forward`` returns the sum of
    the intent loss and the weighted slot loss together with both logits.

    Args:
        config: BERT configuration; ``config.hidden_size`` sizes both heads.
        args: Settings object read for ``dropout_rate``, ``use_crf``,
            ``ignore_index`` and ``slot_loss_coef``.
        intent_label_lst: Intent label vocabulary; only its length is used.
        slot_label_lst: Slot label vocabulary; only its length is used.
    """

    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        # nn.Module.__init__() takes no positional arguments, so `config` must
        # not be forwarded here (that call form belongs to BertPreTrainedModel).
        super().__init__()
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)

        # Builds the encoder architecture from `config`; this call alone does
        # not load pretrained weights.
        self.bert = BertModel(config=config)

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        """Run the encoder and both heads and accumulate the joint loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Mask passed to the encoder; also selects the
                positions that contribute to the slot loss.
            token_type_ids: Segment ids passed to the encoder.
            intent_label_ids: Intent targets, or None to skip the intent loss.
            slot_labels_ids: Slot targets, or None to skip the slot loss.

        Returns:
            Tuple ``(total_loss, (intent_logits, slot_logits), *extra)`` where
            ``extra`` is whatever the encoder returned after its first two
            outputs. ``total_loss`` stays the integer 0 when no labels are given.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]  # [bsz, seq_len, hidden_dim]
        pooled_output = outputs[1]  # output at [CLS] after the BertPooler module (linear + tanh)

        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Loss of the intent classification task
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:   # single output: regression (STS-B style)
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels),
                    intent_label_ids.view(-1)
                )
            total_loss += intent_loss

        # 2. Loss of the slot-filling task (CRF or token-level softmax)
        if slot_labels_ids is not None:
            if self.args.use_crf:
                # NOTE(review): the label ids are handed to the CRF unchanged.
                # If ignored positions are marked with a negative id such as
                # -100, confirm they are remapped to a valid tag index first.
                slot_loss = self.crf(
                    slot_logits,
                    slot_labels_ids,
                    mask=attention_mask.byte(),
                    reduction='mean',
                )
                slot_loss = -1 * slot_loss  # the CRF returns a log-likelihood; negate it to get a loss
            else:
                # Targets equal to ignore_index do not contribute to the loss
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep the positions inside the attention mask (non-padding)
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1   # boolean mask, [B * L]
                    print("active_loss: ", active_loss)

                    active_logits = slot_logits.view(
                        -1, self.num_slot_labels
                    )[active_loss]  # [num_active, num_slot_labels]
                    print("active_logits: ", active_logits)

                    active_labels = slot_labels_ids.view(-1)[active_loss]   # [num_active]
                    print("active_labels: ", active_labels)

                    slot_loss = slot_loss_fct(active_logits, active_labels)

                else:
                    slot_loss = slot_loss_fct(
                        slot_logits.view(-1, self.num_slot_labels),
                        slot_labels_ids.view(-1)
                    )
            # total loss = intent_loss + coef * slot_loss
            total_loss += self.args.slot_loss_coef * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits


### 损失函数 CrossEntropyLoss
PyTorch 中的 CrossEntropyLoss() 主要是把 softmax -> log -> NLLLoss 这三步合并在一起计算。
$$L = -\sum_{i=1}^{N} y_i \log \hat{y}_i$$
其中 $N$ 是类别数；$y$ 是真实类别的 one-hot 分布，只有真实类别对应的分量 $y_i$ 为 1，其余都是 0；$\hat{y}$ 是 logits 经过 softmax 之后得到的概率分布。

- softmax将输出数据规范化为一个概率分布。

- 然后将Softmax之后的结果取log

- 最后将取 log 之后的结果输入负对数似然损失函数（NLLLoss）

### 举例查看

In [12]:
from transformers import BertConfig, DistilBertConfig, AlbertConfig
from transformers import BertTokenizer, DistilBertTokenizer, AlbertTokenizer

from JointBERT.model import JointBERT, JointDistilBERT, JointAlbert
from JointBERT.utils import init_logger, load_tokenizer, get_intent_labels, get_slot_labels
from JointBERT.data_loader import load_and_cache_examples

# Model type -> (config class, joint model class, tokenizer class)
MODEL_CLASSES = dict(
    bert=(BertConfig, JointBERT, BertTokenizer),
    distilbert=(DistilBertConfig, JointDistilBERT, DistilBertTokenizer),
    albert=(AlbertConfig, JointAlbert, AlbertTokenizer),
)

# Model type -> checkpoint name or local path used as `model_name_or_path`
MODEL_PATH_MAP = dict(
    bert="resources/bert_base_uncased",
    distilbert="distilbert-base-uncased",
    albert="albert-xxlarge-v1",
)

In [13]:
# Build the run settings first
class Args:
    """Plain attribute container for the run settings.

    The four data-related fields exist at class level and default to None;
    every other setting is attached to the instance after construction.
    """

    task = data_dir = intent_label_file = slot_label_file = None

args = Args()

# Data and model locations
vars(args).update(
    task="atis",
    data_dir="./data",
    intent_label_file="intent_label.txt",
    slot_label_file="slot_label.txt",
    max_seq_len=50,
    model_type="bert",
    model_dir="experiments/jointbert_0",
)
# Resolve the checkpoint name/path from the chosen model type
args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

# Loss and batching settings
vars(args).update(
    ignore_index=-100,
    train_batch_size=4,
    dropout_rate=0.1,
    use_crf=False,
    slot_loss_coef=1.0,
)


In [14]:
# Tokenizer and model configuration for the selected model type
tokenizer = load_tokenizer(args)
config = MODEL_CLASSES[args.model_type][0].from_pretrained(args.model_name_or_path)

# Label vocabularies and their sizes (the sizes are reused by the loss demos)
intent_label_lst, slot_label_lst = get_intent_labels(args), get_slot_labels(args)
num_intent_labels, num_slot_labels = len(intent_label_lst), len(slot_label_lst)

# Joint model with the softmax slot head (args.use_crf is False at this point)
model = JointBERT(config, args, intent_label_lst, slot_label_lst)

In [18]:
# Load the training split
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

# RandomSampler yields the dataset indices in random order
train_sampler = RandomSampler(train_dataset)
# The DataLoader groups the sampled examples into batches
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

device = "cpu"
for step, batch in enumerate(train_dataloader):
    batch = tuple(t.to(device) for t in batch)  # move the batch to `device` (CPU here)
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "token_type_ids": batch[2],
              "intent_label_ids": batch[3],
              "slot_labels_ids": batch[4]}

    input_ids = inputs["input_ids"]  # [B, L]

    attention_mask = inputs["attention_mask"]  # [B, L]
    token_type_ids = inputs["token_type_ids"]  # [B, L]
    intent_label_ids = inputs["intent_label_ids"]   # [B, ]

    slot_labels_ids = inputs["slot_labels_ids"]   # [B, L]

    # Only walk through the first two batches
    if step > 1:
        break

    print("input_ids: ", input_ids.shape)
    print("slot_labels_ids: ", slot_labels_ids.shape)
    print("slot_labels_ids: ", slot_labels_ids)

    outputs = model.bert(input_ids, attention_mask=attention_mask,
                         token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)

    sequence_output = outputs[0]   # [B, L, H]
    print("sequence_output: ", sequence_output.shape)

    pooled_output = outputs[1]   # [B, H]
    print("pooled_output: ", pooled_output.shape)

    # Intent classification loss
    intent_logits = model.intent_classifier(pooled_output)   # [B, num_intent_labels]
    print("intent_logits: ", intent_logits.shape)

    intent_loss_fct = nn.CrossEntropyLoss()
    intent_loss = intent_loss_fct(intent_logits.view(-1, num_intent_labels), intent_label_ids.view(-1))
    print("intent_loss: ", intent_loss)

    ####################################################################################
    # JointBERT style: "active loss", i.e. only the non-padding positions of each
    # sentence (attention_mask == 1) take part in the loss
    ####################################################################################

    # Positions such as [CLS], [SEP] and non-initial word pieces lie inside the
    # attention mask but presumably still carry the ignore label, so the
    # criterion below needs ignore_index as well.

    slot_logits = model.slot_classifier(sequence_output)
    print("slot_logits: ", slot_logits.shape)

    active_loss = attention_mask.view(-1) == 1
    print("active_loss: ", active_loss.shape)

    active_logits = slot_logits.view(-1, num_slot_labels)[active_loss]
    print("active_logits: ", active_logits.shape)

    active_labels = slot_labels_ids.view(-1)[active_loss]
    print("active_labels: ", active_labels.shape)

    # Pass ignore_index explicitly, as JointBERT.forward does; relying on the
    # PyTorch default only works while args.ignore_index happens to be -100.
    slot_loss_fct = nn.CrossEntropyLoss(ignore_index=args.ignore_index)
    slot_loss = slot_loss_fct(active_logits, active_labels)
    print("slot_loss: ", slot_loss)

    ####################################################################################
    # Direct computation: rely on ignore_index alone, without the attention mask
    ####################################################################################
    slot_loss_fct = nn.CrossEntropyLoss(ignore_index=args.ignore_index)
    slot_loss = slot_loss_fct(
        slot_logits.view(-1, num_slot_labels),
        slot_labels_ids.view(-1)
    )
    print("slot_loss: ", slot_loss)
    
    
    

input_ids:  torch.Size([4, 50])
slot_labels_ids:  torch.Size([4, 50])
slot_labels_ids:  tensor([[-100,    2,    2,    2,   63,    2,    2,   73,    2,  114, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100],
        [-100,   48,   48,    2,    2,   73,    2,  114, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100],
        [-100,    2,    2,    2,   81,    2,    2,   73,    2,  114,  115,    2,
           38,   41,   39, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -

In [19]:
# Switch to the CRF slot head: use_crf = True

args.use_crf = True
model = JointBERT(config, args, intent_label_lst, slot_label_lst)

# Load the training split
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

# RandomSampler yields the dataset indices in random order
train_sampler = RandomSampler(train_dataset)
# The DataLoader groups the sampled examples into batches
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

device = "cpu"
for step, batch in enumerate(train_dataloader):
    batch = tuple(t.to(device) for t in batch)  # move the batch to `device` (CPU here)
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "token_type_ids": batch[2],
              "intent_label_ids": batch[3],
              "slot_labels_ids": batch[4]}

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    token_type_ids = inputs["token_type_ids"]
    intent_label_ids = inputs["intent_label_ids"]
    slot_labels_ids = inputs["slot_labels_ids"]

    # Only walk through the first batch
    if step > 0:
        break

    outputs = model.bert(input_ids, attention_mask=attention_mask,
                         token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
    sequence_output = outputs[0]

    slot_logits = model.slot_classifier(sequence_output)

    # The CRF uses the label ids as indices into its score tables, so every id
    # must be a valid tag index. The labels here mark ignored positions with
    # args.ignore_index (-100), which would wrap around as a negative index or
    # raise IndexError. Map those positions to tag 0 instead
    # (assumed to be the padding tag - confirm against the slot label file).
    crf_labels = slot_labels_ids.clamp(min=0)

    slot_loss = model.crf(slot_logits, crf_labels, mask=attention_mask.byte(), reduction='mean')
    slot_loss = -1 * slot_loss  # the CRF returns a log-likelihood; negate it to get a loss
    print("slot_loss: ", slot_loss)
    


slot_loss:  tensor(66.9170, grad_fn=<MulBackward0>)
