# 模型构建与损失函数


**目录：**
1. FewShotMultiLabel模型： SchemaFewShotTextClassifier
    - ContextEmbedder；
    - emission scorer；
    - similarity scorer；
    - decoder；


2. 损失函数计算

---


In [None]:
from typing import Optional, Union

import torch

### ContextEmbedder



In [None]:
# Base class
class ContextEmbedderBase(torch.nn.Module):
    """Abstract interface for modules that turn token ids into contextual representations."""

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """
        Encode the inputs; concrete embedders must override this.

        :return: test_token_reps, support_token_reps, test_sent_reps, support_sent_reps
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

In [None]:
# BERT-based context embedder

class BertContextEmbedder(ContextEmbedderBase):
    """
    Context embedder backed by a pre-trained BERT model (loaded from ``opt.bert_path``).

    With support inputs, the query is concatenated with each support sentence and every pair
    is encoded jointly; without them, only the query is encoded.
    """

    def __init__(self, opt):
        super(BertContextEmbedder, self).__init__()
        self.opt = opt
        self.embedder = self.build_embedder()  # load the pre-trained BERT model

    def forward(
            self,
            test_token_ids: torch.Tensor,
            test_segment_ids: torch.Tensor,
            test_nwp_index: torch.Tensor,
            test_input_mask: torch.Tensor,
            support_token_ids: torch.Tensor = None,
            support_segment_ids: torch.Tensor = None,
            support_nwp_index: torch.Tensor = None,
            support_input_mask: torch.Tensor = None,
    ) -> (torch.Tensor, torch.Tensor):
        """
        get context representation
        :param test_token_ids: (batch_size, test_len)
        :param test_segment_ids: (batch_size, test_len)
        :param test_nwp_index: (batch_size, test_len, 1)
        :param test_input_mask: (batch_size, test_len)

        ======= Support features ======
        We allow to only embed query to enable single sentence embedding, but such feature is NOT used now.
        (Separate embedding is achieved through special sub classes)

        :param support_token_ids: (batch_size, support_size, support_len)
        :param support_segment_ids: (batch_size, support_size, support_len)
        :param support_nwp_index: (batch_size, support_size, support_len, 1)
        :param support_input_mask: (batch_size, support_size, support_len)
        :return:
            With support inputs (concatenated representation):
                (test_reps, support_reps, test_sent_reps, support_sent_reps)
                    test_reps, support_reps: token reps (batch_size, support_size, nwp_sent_len, emb_len)
                    test_sent_reps, support_sent_reps: sent reps (batch_size, support_size, 1, emb_len)
            Without support inputs (only the query is encoded):
                (test_reps, test_sent_reps)
                    test_reps: token reps (batch_size, nwp_sent_len, emb_len)
                    test_sent_reps: sent reps (batch_size, 1, emb_len)
        """
        if support_token_ids is not None:
            return self.concatenate_reps(
                test_token_ids, test_segment_ids, test_nwp_index, test_input_mask,
                support_token_ids, support_segment_ids, support_nwp_index, support_input_mask,
            )
        else:
            return self.single_reps(test_token_ids, test_segment_ids, test_nwp_index, test_input_mask,)

    def build_embedder(self):
        """ load bert here """
        return BertModel.from_pretrained(self.opt.bert_path)

    def concatenate_reps(
            self,
            test_token_ids: torch.Tensor,
            test_segment_ids: torch.Tensor,
            test_nwp_index: torch.Tensor,
            test_input_mask: torch.Tensor,
            support_token_ids: torch.Tensor,
            support_segment_ids: torch.Tensor,
            support_nwp_index: torch.Tensor,
            support_input_mask: torch.Tensor,
    ) -> (torch.Tensor, torch.Tensor):
        """
        Get token and sentence reps of (query, support) sentence pairs.

        The query is paired with every support sentence, so query reps are produced
        ``support_size`` times per batch item.
        :return: (nwp_test_reps, nwp_support_reps, test_sent_reps, support_sent_reps)
        """
        # Maximum number of support sentences per batch item.
        support_size = support_token_ids.shape[1]

        # Sequence lengths without special tokens.
        test_len = test_token_ids.shape[-1] - 2  # max len, exclude [CLS] and [SEP] token
        support_len = support_token_ids.shape[-1] - 1  # max len, exclude [SEP] token

        batch_size = support_token_ids.shape[0]

        # Expand query inputs to (batch_size, support_size, test_len) so every support
        # sentence gets its own copy of the query.
        test_token_ids, test_segment_ids, test_input_mask, test_nwp_index = self.expand_test_item(
            test_token_ids, test_segment_ids, test_input_mask, test_nwp_index, support_size)

        # Concatenate query and each support sentence along the last (token) axis.
        input_ids = self.cat_test_and_support(test_token_ids, support_token_ids)
        segment_ids = self.cat_test_and_support(test_segment_ids, support_segment_ids)
        input_mask = self.cat_test_and_support(test_input_mask, support_input_mask)

        # Merge the first two dims: (batch_size, support_size, cat_len) -> (batch_size * support_size, cat_len)
        input_ids, segment_ids, input_mask = self.flatten_input(input_ids, segment_ids, input_mask)
        # (batch_size, support_size, index_len, 1) -> (batch_size * support_size, index_len, 1)
        test_nwp_index, support_nwp_index = self.flatten_index(test_nwp_index), self.flatten_index(support_nwp_index)

        # BERT output: (batch_size * support_size, cat_len, hidden_dim).
        # Mask and segment ids are passed by keyword so the call does not depend on the
        # embedder's positional argument order.
        sequence_output = self.embedder(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)[0]

        # narrow(dim, start, length) keeps positions [start, start + length) along `dim`.
        # Query tokens start after [CLS]; support tokens start after [CLS] + query + [SEP].
        test_reps = sequence_output.narrow(-2, 1, test_len)  # shape:(batch * support_size, test_len, rep_size)
        support_reps = sequence_output.narrow(-2, 2 + test_len, support_len)  # shape:(batch * support_size, support_len, rep_size)

        # Keep one rep per word (the word piece selected by the nwp index).
        # test_reps: (batch_size * support_size, test_len, hidden_dim)
        # test_nwp_index: (batch_size * support_size, index_len, 1)
        nwp_test_reps = self.extract_non_word_piece_reps(test_reps, test_nwp_index)
        nwp_support_reps = self.extract_non_word_piece_reps(support_reps, support_nwp_index)

        # resize to shape (batch_size, support_size, sent_len, emb_len)
        reps_size = nwp_test_reps.shape[-1]
        nwp_test_reps = nwp_test_reps.view(batch_size, support_size, -1, reps_size)
        nwp_support_reps = nwp_support_reps.view(batch_size, support_size, -1, reps_size)
        test_reps = test_reps.view(batch_size, support_size, -1, reps_size)
        support_reps = support_reps.view(batch_size, support_size, -1, reps_size)

        # Whole-sentence reps by masked average pooling.
        test_sent_reps = self.get_sent_reps(test_reps, test_input_mask)
        support_sent_reps = self.get_sent_reps(support_reps, support_input_mask)
        return nwp_test_reps, nwp_support_reps, test_sent_reps, support_sent_reps

    def single_reps(
            self,
            test_token_ids: torch.Tensor,
            test_segment_ids: torch.Tensor,
            test_nwp_index: torch.Tensor,
            test_input_mask: torch.Tensor,
    ) -> (torch.Tensor, torch.Tensor):
        """
        Get token reps of a single sentence.

        :return: (nwp_test_reps, test_sent_reps); sentence reps have shape (batch_size, 1, emb_len)
        """
        test_len = test_token_ids.shape[-1] - 2  # max len, exclude [CLS] and [SEP] token
        # (batch_size, seq_len, hidden_dim); keyword args, see concatenate_reps.
        test_sequence_output = self.embedder(
            test_token_ids, attention_mask=test_input_mask, token_type_ids=test_segment_ids)[0]
        # select pure sent part, remove [SEP] and [CLS]
        test_reps = test_sequence_output.narrow(-2, 1, test_len)  # shape:(batch, test_len, rep_size)
        # select non-word-piece tokens' representation
        nwp_test_reps = self.extract_non_word_piece_reps(test_reps, test_nwp_index)
        # get whole sentence reps; add/remove a support axis to fit get_sent_reps' interface
        test_sent_reps = self.get_sent_reps(test_reps.unsqueeze(1), test_input_mask.unsqueeze(1)).squeeze(1)
        return nwp_test_reps, test_sent_reps

    def get_sent_reps(self, reps, input_mask):
        """
        Average token reps to get a whole-sentence rep, ignoring padding.

        :param reps:   (batch_size, support_size, sent_len, emb_len)
        :param input_mask:  (batch_size, support_size, mask_len), where mask_len is sent_len
            plus 1 ([SEP]) or 2 ([CLS] and [SEP]) special tokens
        :return:  averaged reps (batch_size, support_size, 1, emb_len)
        :raises RuntimeError: if mask_len - sent_len is not 1 or 2
        """
        batch_size, support_size, sent_len, reps_size = reps.shape
        mask_len = input_mask.shape[-1]

        # Count real tokens per sentence so padding does not dilute the average.
        # shape: (batch_size * support_size, 1)
        token_counts = torch.sum(input_mask.contiguous().view(-1, mask_len), dim=1).unsqueeze(-1)
        sp_token_num = input_mask.shape[-1] - reps.shape[-2]  # num of [CLS], [SEP] tokens
        token_counts = token_counts - sp_token_num + 0.00001  # pure token num; epsilon avoids div by 0

        # Drop the special-token positions from the mask so it lines up with `reps`.
        if sp_token_num == 2:
            trimed_mask = input_mask.narrow(-1, 1, reps.shape[-2]).float()  # remove mask of [CLS], [SEP]
        elif sp_token_num == 1:
            trimed_mask = input_mask.narrow(-1, 0, reps.shape[-2]).float()  # remove mask of [SEP]
        else:
            raise RuntimeError("Unexpected sp_token_num.")

        # Zero pad-token reps (BERT does not produce zero vectors for pads by itself).
        reps = reps * trimed_mask.unsqueeze(-1)
        # sum reps, shape (batch_size * support_size, emb_len)
        sum_reps = torch.sum(reps.contiguous().view(-1, sent_len, reps_size), dim=1)
        # averaged reps (batch_size, support_size, emb_len)
        ave_reps = torch.div(sum_reps, token_counts.float()).contiguous().view(batch_size, support_size, reps_size)
        return ave_reps.unsqueeze(-2)

    def expand_test_item(
            self,
            test_token_ids: torch.Tensor,
            test_segment_ids: torch.Tensor,
            test_input_mask: torch.Tensor,
            test_nwp_index: torch.Tensor,
            support_size: int,
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """Add a support axis at dim 1 to every query input and broadcast it to ``support_size``."""
        return self.expand_it(test_token_ids, support_size), self.expand_it(test_segment_ids, support_size), \
               self.expand_it(test_input_mask, support_size), self.expand_it(test_nwp_index, support_size)

    def expand_it(self, item: torch.Tensor, support_size):
        """
        Insert an axis at dim 1 and broadcast ``item`` to ``support_size`` along it.

        Uses out-of-place ``unsqueeze`` so the caller's tensor keeps its shape. The former
        in-place ``unsqueeze_`` reshaped the caller's inputs, so reusing a batch broke.
        ``expand`` returns a view and does not allocate new memory.
        """
        item = item.unsqueeze(1)
        expand_shape = list(item.shape)
        expand_shape[1] = support_size
        return item.expand(expand_shape)

    def cat_test_and_support(self, test_item, support_item):
        """Concatenate query and support tensors along the last (token) axis."""
        return torch.cat([test_item, support_item], dim=-1)

    def flatten_input(self, input_ids, segment_ids, input_mask):
        """ resize shape (batch_size, support_size, cat_len) to shape (batch_size * support_size, sent_len) """
        sent_len = input_ids.shape[-1]
        input_ids = input_ids.view(-1, sent_len)
        segment_ids = segment_ids.view(-1, sent_len)
        input_mask = input_mask.view(-1, sent_len)
        return input_ids, segment_ids, input_mask

    def flatten_index(self, nwp_index):
        """ resize shape (batch_size, support_size, index_len, 1) to shape (batch_size * support_size, index_len, 1) """
        nwp_sent_len = nwp_index.shape[-2]
        return nwp_index.contiguous().view(-1, nwp_sent_len, 1)

    def extract_non_word_piece_reps(self, reps, index):
        """
        Use the first word piece as entire word representation
        As we have only one index for each token, we need to expand to the size of reps dim.
        """
        expand_shape = list(index.shape)
        expand_shape[-1] = reps.shape[-1]  # expand index over embedding dim, like 768
        index = index.expand(expand_shape)
        # Gather along the token axis: one rep vector per index.
        nwp_reps = torch.gather(input=reps, index=index, dim=-2)
        return nwp_reps

In [None]:
class BertSchemaContextEmbedder(BertContextEmbedder):
    """
    BERT context embedder that can also embed label (schema) descriptions.

    ``forward`` dispatches on ``reps_type``: 'test_support' behaves like the parent class,
    and 'label' embeds label descriptions through ``get_label_reps``.
    """

    def __init__(self, opt):
        super(BertSchemaContextEmbedder, self).__init__(opt)

    def forward(
            self,
            test_token_ids: torch.Tensor,
            test_segment_ids: torch.Tensor,
            test_nwp_index: torch.Tensor,
            test_input_mask: torch.Tensor,
            support_token_ids: torch.Tensor = None,
            support_segment_ids: torch.Tensor = None,
            support_nwp_index: torch.Tensor = None,
            support_input_mask: torch.Tensor = None,
            reps_type: str = 'test_support',
    ) -> (torch.Tensor, torch.Tensor):
        """
        get context representation
        :param test_token_ids: (batch_size, test_len)
        :param test_segment_ids: (batch_size, test_len)
        :param test_nwp_index: (batch_size, test_len, 1)
        :param test_input_mask: (batch_size, test_len)
        :param support_token_ids: (batch_size, support_size, support_len)
        :param support_segment_ids: (batch_size, support_size, support_len)
        :param support_nwp_index: (batch_size, support_size, support_len, 1)
        :param support_input_mask: (batch_size, support_size, support_len)
        :param reps_type: 'test_support' (default) for query/support token reps, or 'label'
            to embed label descriptions (passed through the test_* arguments)
        :return:
            if do concatenating representation:
                return (test_reps, support_reps, test_sent_reps, support_sent_reps):
                    test_reps, support_reps:  token reps (batch_size, support_size, nwp_sent_len, emb_len)
                    test_sent_reps, support_sent_reps: sent reps (batch_size, support_size, 1, emb_len)
            else do representation for a single sent (no support set):
                return (test_reps, test_sent_reps), test_reps shape is (batch_size, nwp_sent_len, emb_len)
            for reps_type == 'label': see get_label_reps
        :raises ValueError: if reps_type is neither 'test_support' nor 'label'
        """
        if reps_type == 'test_support':
            if support_token_ids is not None:
                return self.concatenate_reps(
                    test_token_ids, test_segment_ids, test_nwp_index, test_input_mask,
                    support_token_ids, support_segment_ids, support_nwp_index, support_input_mask,
                )
            return self.single_reps(test_token_ids, test_segment_ids, test_nwp_index, test_input_mask,)
        if reps_type == 'label':
            return self.get_label_reps(test_token_ids, test_segment_ids, test_nwp_index, test_input_mask)
        # Fail loudly instead of silently returning None for an unknown type.
        raise ValueError("Wrong reps_type choice: {}".format(reps_type))

    def get_label_reps(self, test_token_ids, test_segment_ids, test_nwp_index, test_input_mask):
        """
        Embed label descriptions according to ``self.opt.label_reps``.

        - 'cat': embed with ``single_reps`` and return its (nwp_reps, sent_reps) tuple.
        - 'sep': flatten the label axis, encode each description separately and use its
          [CLS] vector (the usual BERT sentence encoding); returns (batch_size, label_num, hidden_dim).
        - 'sep_sum': not implemented.

        :param test_token_ids: label description ids; for 'sep' presumably
            (batch_size, label_num, label_des_len) — TODO confirm against the caller
        :raises NotImplementedError: for 'sep_sum'
        :raises ValueError: for any other ``label_reps`` value
        """
        label_reps = self.opt.label_reps
        if label_reps == 'cat':
            # todo: use label mask to represent a label with only in domain info
            return self.single_reps(test_token_ids, test_segment_ids, test_nwp_index, test_input_mask, )
        if label_reps == 'sep':
            batch_size = test_token_ids.shape[0]
            # (batch_size * label_num, label_des_len)
            input_ids, segment_ids, input_mask = self.flatten_input(test_token_ids, test_segment_ids,
                                                                    test_input_mask)
            # (batch_size * label_num, label_des_len, hidden_dim)
            sequence_output = self.embedder(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)[0]
            reps_size = sequence_output.shape[-1]
            # [CLS] is the first token of each description: (batch_size * label_num, 1, hidden_dim)
            cls_reps = sequence_output.narrow(-2, 0, 1)
            # (batch_size, label_num, hidden_dim)
            return cls_reps.contiguous().view(batch_size, -1, reps_size)
        if label_reps == 'sep_sum':
            # todo: average each description's token reps with its mask; on the flattened
            # mask the broadcast mask would be self.expand_mask(input_mask, reps_size, 2).
            # (The previous code passed expand_mask's arguments in the wrong order and crashed
            # with an IndexError before reaching NotImplementedError.)
            raise NotImplementedError("label_reps == 'sep_sum' is not implemented")
        raise ValueError("Wrong label_reps choice")

    def expand_mask(self, item: torch.Tensor, expand_size, dim):
        """
        Insert a new axis at ``dim`` and broadcast ``item`` to ``expand_size`` along it (a view, no copy).

        Note the argument order: size first, then dim.
        """
        new_item = item.unsqueeze(dim)
        expand_shape = list(new_item.shape)
        expand_shape[dim] = expand_size
        return new_item.expand(expand_shape)

### re-scale input tensor's value

对输入向量大小进行缩放

我们主要关注： emission_normalizer=None, emission_scaler=learn

In [None]:
class ScaleControllerBase(torch.nn.Module):
    """
    Base class for callable modules that re-scale the values of a tensor.

    Classic choices: softmax, L2 normalisation, ReLU, ...
    Advanced choice: a learnable scale parameter.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x:  torch.Tensor, dim: int = 0, p: int = 1):
        """
        Bring ``x`` into a suitable value range; subclasses must override this.

        :param x: tensor to re-scale
        :param dim: axis to operate on (mainly for the classic methods)
        :param p: norm order for the classic methods
        :return: the re-scaled tensor
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

In [None]:
class LearnableScaleController(ScaleControllerBase):
    """
    Multiply the input by a single learnable scalar, then optionally apply a normalizer.

    Scale parameter mentioned in [Tadam: Task dependent adaptive metric for improved few-shot learning. (NIPS2018)]
    """
    def __init__(self, normalizer: Optional[ScaleControllerBase] = None):
        """
        :param normalizer: optional controller applied after scaling, called as
            ``normalizer(x, dim=dim, p=p)``
        """
        super(LearnableScaleController, self).__init__()
        # One trainable scalar, randomly initialised with torch.rand (values in [0, 1)).
        self.scale_rate = torch.nn.Parameter(torch.rand(1), requires_grad=True)
        self.normalizer = normalizer

    def forward(self, x:  torch.Tensor, dim: int = 0, p: int = 1):
        """
        :param x: input tensor
        :param dim: forwarded to the normalizer
        :param p: forwarded to the normalizer
        :return: ``x * scale_rate``, normalized when a normalizer is configured
        """
        x = x * self.scale_rate
        # Explicit None check: an nn.Module's truthiness can depend on __len__ (e.g. containers).
        if self.normalizer is not None:
            x = self.normalizer(x, dim=dim, p=p)
        return x

In [None]:
def build_scale_controller(name: str, kwargs=None) -> Union[ScaleControllerBase, None]:
    """
    Look up a scale controller by name and instantiate it.

    :param name: one of 'learn', 'fix', 'relu', 'exp', 'softmax', 'norm'; a falsy value or
        the string 'none' means that no controller is wanted
    :param kwargs: dict of constructor keyword arguments; when falsy, the constructor is
        called without arguments
    :return: the controller instance, or None when no controller is requested
    :raises KeyError: if ``name`` is not a known controller
    """
    wants_controller = bool(name) and name != 'none'
    if not wants_controller:
        return None

    registry = {
        'learn': LearnableScaleController,
        'fix': FixedScaleController,
        'relu': ReluScaleController,
        'exp': ExpScaleController,
        'softmax': SoftmaxScaleController,
        'norm': NormalizeScaleController,
    }
    controller_cls = registry.get(name)
    if controller_cls is None:
        raise KeyError('Wrong scale controller name.')

    if not kwargs:
        return controller_cls()
    return controller_cls(**kwargs)

In [None]:
def make_scaler_args(name: str, normalizer: Optional[ScaleControllerBase], scale_r: Optional[float] = None) -> Optional[dict]:
    """
    Build the constructor kwargs for the scaler called ``name`` (see build_scale_controller).

    :param name: scaler name
    :param normalizer: controller to apply after scaling, or None
    :param scale_r: fixed scale rate; only used for the 'fix' scaler
    :return: kwargs dict for 'learn' and 'fix'; None for every other name
    """
    if name == 'learn':
        return {'normalizer': normalizer}
    if name == 'fix':
        return {'normalizer': normalizer, 'scale_rate': scale_r}
    return None

In [None]:
# Example:

class Config:
    do_debug = False

opt = Config()
opt.emission_normalizer = None
opt.emission_scaler = "learn"
# Scale rate for the 'fix' scaler. make_scaler_args only forwards it for 'fix', but it is
# read below regardless, so it must exist (it was missing and raised AttributeError).
opt.ems_scale_r = 1.0  # placeholder value; unused by the 'learn' scaler

ems_normalizer = build_scale_controller(name=opt.emission_normalizer)
ems_scaler = build_scale_controller(
    name=opt.emission_scaler,
    kwargs=make_scaler_args(opt.emission_scaler, ems_normalizer, opt.ems_scale_r),
)

### similarity function

相似度计算函数

可选的有 'cosine', 'dot', 'l2'

In [None]:
def reps_dot(sent1_reps: torch.Tensor, sent2_reps: torch.Tensor) -> torch.Tensor:
    """
    Calculate pairwise dot products between two sets of representations.

    :param sent1_reps: (N, sent1_len, reps_dim); additional leading batch dims are also accepted
    :param sent2_reps: (N, sent2_len, reps_dim)
    :return: (N, sent1_len, sent2_len)
    """
    # (N, sent1_len, reps_dim) @ (N, reps_dim, sent2_len). torch.matmul equals torch.bmm for
    # 3-D inputs and additionally batches over any extra leading dims.
    return torch.matmul(sent1_reps, sent2_reps.transpose(-1, -2))


def reps_l2_sim(sent1_reps: torch.Tensor, sent2_reps: torch.Tensor) -> torch.Tensor:
    """
    Calculate representation L2 similarity: the negative Euclidean distance between every
    row of ``sent1_reps`` and every row of ``sent2_reps``.

    :param sent1_reps: (N, sent1_len, reps_dim)
    :param sent2_reps: (N, sent2_len, reps_dim)
    :return: (N, sent1_len, sent2_len)
    """
    # Broadcast (N, sent1_len, 1, dim) against (N, 1, sent2_len, dim) to get all pairwise
    # differences: (N, sent1_len, sent2_len, dim). The former explicit expand built sent1's
    # target shape from sent2_reps and failed whenever sent1_len != sent2_len.
    diff = sent1_reps.unsqueeze(2) - sent2_reps.unsqueeze(1)

    # shape: (N, sent1_len, sent2_len)
    dist = torch.norm(diff, dim=-1, p=2)
    return -dist  # we calculate similarity not distance here

def reps_cosine_sim(sent1_reps: torch.Tensor, sent2_reps: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Calculate pairwise cosine similarity between two sets of representations.
    Unlike torch.nn.functional.cosine_similarity (element-wise pairs), this compares every row
    of ``sent1_reps`` with every row of ``sent2_reps``: dot product divided by the norms.

    :param sent1_reps: (N, sent1_len, reps_dim)
    :param sent2_reps: (N, sent2_len, reps_dim)
    :param eps: lower bound for the norm product, so zero vectors give 0 instead of NaN
    :return: (N, sent1_len, sent2_len)
    """
    dot_sim = torch.bmm(sent1_reps, torch.transpose(sent2_reps, -1, -2))  # shape: (batch, seq_len1, seq_len2)
    sent1_reps_norm = torch.norm(sent1_reps, dim=-1, keepdim=True)  # shape: (batch, seq_len1, 1)
    sent2_reps_norm = torch.norm(sent2_reps, dim=-1, keepdim=True)  # shape: (batch, seq_len2, 1)
    norm_product = torch.bmm(sent1_reps_norm,
                             torch.transpose(sent2_reps_norm, -1, -2))  # shape: (batch, seq_len1, seq_len2)
    # Zero vectors (e.g. an all-zero prototype) would otherwise yield 0 / 0 = NaN.
    sim_predicts = dot_sim / norm_product.clamp(min=eps)  # shape: (batch, seq_len1, seq_len2)
    return sim_predicts

In [None]:
opt.similarity = "dot"

# Resolve the configured similarity name to its implementation.
if opt.similarity not in ('dot', 'cosine', 'l2'):
    raise TypeError('wrong component type')
sim_func = {
    'dot': reps_dot,
    'cosine': reps_cosine_sim,
    'l2': reps_l2_sim,
}[opt.similarity]

In [None]:
class SimilarityScorerBase(torch.nn.Module):
    """
    Base class for modules that score query tokens against the support set.

    Stores the similarity function and provides masking/expansion helpers; subclasses
    implement ``forward``.
    """

    def __init__(self, sim_func, emb_log=None):
        """
        :param sim_func: callable computing pairwise similarity between two rep sets
        :param emb_log: optional writable object that subclasses may log embeddings to
        """
        super().__init__()
        self.sim_func = sim_func
        self.emb_log = emb_log
        self.log_content = ''
        self.mlc_support_tags_mask = None

    def update_mlc_support_tags_mask(self, support_targets, support_output_mask):
        """
        Record which tags occur (in an unmasked position) anywhere in each batch item's support set.

        :param support_targets: (batch_size, support_size, max_label_num, num_tags)
        :param support_output_mask: (batch_size, support_size, max_label_num)
        :return: None; stores a (batch_size, num_tags) float 0/1 tensor in ``self.mlc_support_tags_mask``
        """
        n_batch = support_targets.size(0)
        n_tags = support_targets.size(-1)

        # Broadcast the per-label mask over the tag axis, then zero out padded labels.
        valid = support_output_mask.unsqueeze(-1).expand_as(support_targets)
        kept_targets = valid * support_targets  # (batch_size, support_size, max_label_num, num_tags)

        # Count occurrences of each tag over all support labels of a batch item.
        hits = kept_targets.contiguous().view(n_batch, -1, n_tags).sum(dim=1)
        self.mlc_support_tags_mask = (hits >= 1).float()

    def forward(
            self,
            test_reps: torch.Tensor,
            support_reps: torch.Tensor,
            test_output_mask: torch.Tensor,
            support_output_mask: torch.Tensor,
            support_targets: torch.Tensor = None,
            label_reps: torch.Tensor = None, ) -> torch.Tensor:
        """
            :param test_reps: (batch_size, support_size, test_seq_len, dim)
            :param support_reps: (batch_size, support_size, support_seq_len, dim)
            :param test_output_mask: (batch_size, test_seq_len)
            :param support_output_mask: (batch_size, support_size, support_seq_len)
            :param support_targets: one-hot label targets: (batch_size, support_size, support_seq_len, num_tags)
            :param label_reps: (batch_size, num_tags, dim)
            :return: similarity
            :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def mask_sim(self, sim: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
        """
        Zero out similarities that involve a masked (e.g. padded) position on either side;
        for example the padded multi-label slots of the support set.

        :param sim: similarity matrix (num_sim, test_len, support_len)
        :param mask1: (num_sim, test_len)
        :param mask2: (num_sim, support_len)
        :return: ``sim`` multiplied by the outer product of the two masks
        """
        left = mask1.unsqueeze(-1).float()  # (b * s, test_label_num, 1)
        right = mask2.unsqueeze(-1).float()  # (b * s, support_label_num, 1)
        # Outer product of the masks: (b * s, test_label_num, support_label_num)
        pair_mask = torch.bmm(left, torch.transpose(right, -1, -2))
        return sim * pair_mask

    def expand_it(self, item: torch.Tensor, support_size):
        """Insert an axis at dim 1 and broadcast ``item`` to ``support_size`` along it (a view)."""
        widened = item.unsqueeze(1)
        target_shape = list(widened.shape)
        target_shape[1] = support_size
        return widened.expand(target_shape)

In [None]:
class ProtoWithLabelSimilarityScorer(SimilarityScorerBase):
    """
    Score each query token against per-tag prototypes, optionally anchored by label reps.

    A prototype is the mean of the support token reps carrying that tag. When ``label_reps``
    is given, each prototype is mixed with its label representation using ``scaler`` as weight.
    """

    def __init__(self, sim_func, scaler=None, emb_log=None):
        """
        :param sim_func: pairwise similarity (N, len1, dim) x (N, len2, dim) -> (N, len1, len2)
        :param scaler: float weight of the label reps when anchoring prototypes; required
            whenever ``label_reps`` is passed to ``forward``
        :param emb_log: optional writable object; in eval mode reps are written to it
        """
        super(ProtoWithLabelSimilarityScorer, self).__init__(sim_func=sim_func, emb_log=emb_log)
        self.scaler = scaler  # float mixing weight (see forward)
        self.idx = 0  # running count of logged batch items, used to name logged embeddings

    def forward(
            self,
            test_reps: torch.Tensor,
            support_reps: torch.Tensor,
            test_output_mask: torch.Tensor,
            support_output_mask: torch.Tensor,
            support_targets: torch.Tensor = None,
            label_reps: torch.Tensor = None,) -> torch.Tensor:
        """
            calculate similarity between token and each label's prototype.
            :param test_reps: (batch_size, support_size, test_seq_len, dim)
            :param support_reps: (batch_size, support_size, support_seq_len, dim)
            :param test_output_mask: (batch_size, test_seq_len)
            :param support_output_mask: (batch_size, support_size, support_seq_len), marks padding
            :param support_targets: one-hot (multi-)label targets: (batch_size, support_size, support_seq_len, num_tags)
            :param label_reps: label reps, presumably without the PAD tag (a zero PAD rep is
                prepended here) — TODO confirm shape against the caller
            :return: similarity: (batch_size, test_seq_len, num_tags)
            :raises ValueError: if label_reps is given but ``self.scaler`` is None
        """
        '''get data attribute'''
        support_size = support_reps.shape[1]
        support_len = support_reps.shape[-2]  # non-word-piece max support token num
        emb_dim = test_reps.shape[-1]
        batch_size = test_reps.shape[0]
        num_tags = support_targets.shape[-1]

        # Prototype = (one-hot targets)^T x support reps, summed over the support set and
        # divided by the tag counts.
        '''get prototype reps'''
        # Merge batch and support dims: (batch_size * support_size, support_len, emb_dim).
        # reshape (rather than view) also accepts non-contiguous inputs.
        support_reps = support_reps.reshape(-1, support_len, emb_dim)

        # (batch_size * support_size, support_len, num_tags)
        support_targets = support_targets.reshape(batch_size * support_size, support_len, num_tags).float()

        # Per-sentence sum of token reps for each tag: (batch_size * support_size, num_tags, emb_dim)
        sum_reps = torch.bmm(torch.transpose(support_targets, -1, -2), support_reps)
        # sum up tag emb over support set, shape (batch_size, num_tags, emb_dim)
        sum_reps = torch.sum(sum_reps.view(batch_size, support_size, num_tags, emb_dim), dim=1)

        # get num of each tag in support set, shape: (batch_size, num_tags, 1)
        tag_count = torch.sum(support_targets.view(batch_size, -1, num_tags), dim=1).unsqueeze(-1)
        # divide by 0 occurs when the tags, such as "I-x", are not existing in support.
        tag_count = self.remove_0(tag_count)
        prototype_reps = torch.div(sum_reps, tag_count)  # shape (batch_size, num_tags, emb_dim)

        # Prototypes stay per batch item because each batch item carries its own support set
        # (support_reps is (batch_size, support_size, ...)), so items may have different prototypes.

        # add PAD label
        if label_reps is not None:
            if self.scaler is None:
                raise ValueError('scaler must be set to anchor prototypes with label_reps')
            # Prepend an all-zero rep for the PAD tag.
            pad_reps = torch.zeros_like(label_reps).narrow(dim=-2, start=0, length=1)
            label_reps = torch.cat((pad_reps, label_reps), dim=-2)
            # Anchored label representation: weighted average of label reps and prototypes.
            prototype_reps = (1 - self.scaler) * prototype_reps + self.scaler * label_reps

        '''get final test data reps'''
        # The query was encoded once per support sentence (concatenated with each of them),
        # so its reps differ per support sentence; average over the support axis.
        test_reps = torch.mean(test_reps, dim=1)  # shape (batch_size, sent_len, emb_dim)

        '''calculate similarity'''
        sim_score = self.sim_func(test_reps, prototype_reps)  # shape (batch_size, sent_len, num_tags)

        '''store visualization embedding'''
        if not self.training and self.emb_log:
            log_context = '\n'.join(
                ['test_' + str(self.idx) + '-' + str(idx) + '-' + str(idx2) + '\t' + '\t'.join(map(str, item2))
                 for idx, item in enumerate(test_reps.tolist()) for idx2, item2 in enumerate(item)]) + '\n'
            log_context += '\n'.join(
                ['proto_' + str(self.idx) + '-' + str(idx) + '-' + str(idx2) + '\t' + '\t'.join(map(str, item2))
                 for idx, item in enumerate(prototype_reps.tolist()) for idx2, item2 in enumerate(item)]) + '\n'
            self.idx += batch_size
            self.emb_log.write(log_context)

        return sim_score

    def remove_nan(self, my_tensor):
        """
        Replace NaN entries with 0.
        Using 'torch.where' here because:
        modifying tensors in-place can cause issues with backprop.
        """
        return torch.where(torch.isnan(my_tensor), torch.zeros_like(my_tensor), my_tensor)

    def remove_0(self, my_tensor):
        """
        Add 1e-4 so the tensor can safely be used as a divisor.
        Note: this also slightly shifts non-zero values.
        """
        return my_tensor + 0.0001

### emission

相似度计算 + 相似度分数调整的方法；

支持：'mnet', 'rank', 'proto', 'proto_with_label', 'tapnet'

主要关注： proto_with_label

In [None]:
class EmissionScorerBase(torch.nn.Module):
    """
    Base class for emission scorers, which turn query/support reps into per-token tag scores.

    Holds a similarity scorer and an optional scaler; subclasses implement ``forward``.
    """

    def __init__(self, similarity_scorer: SimilarityScorerBase, scaler: ScaleControllerBase = None):
        """
        :param similarity_scorer: Module for calculating token similarity
        :param scaler: optional scale controller, presumably applied to scores by subclasses
        """
        super().__init__()
        self.similarity_scorer = similarity_scorer
        self.scaler = scaler

    def forward(
            self,
            test_reps: torch.Tensor,
            support_reps: torch.Tensor,
            test_output_mask: torch.Tensor,
            support_output_mask: torch.Tensor,
            support_targets: torch.Tensor,
            label_reps: torch.Tensor = None, ) -> torch.Tensor:
        """
        :param test_reps: (batch_size, support_size, test_seq_len, dim), notice: reps has been expand to support size
        :param support_reps: (batch_size, support_size, support_seq_len, dim)
        :param test_output_mask: (batch_size, test_seq_len)
        :param support_output_mask: (batch_size, support_size, support_seq_len)
        :param support_targets: one-hot label targets: (batch_size, support_size, support_seq_len, num_tags)
        :param label_reps: (batch_size, num_tags, dim)
        :return: emission, shape: (batch_size, test_len, no_pad_num_tags)
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

In [None]:
# 其实就是前面的SimilarityScorer与scaler的结合

class ProtoWithLabelEmissionScorer(EmissionScorerBase):
    """Emission scorer for the 'proto_with_label' setting.

    Combines the prototype-with-label similarity scorer with the optional
    emission scaler: similarities are computed, the [PAD] tag column is
    removed, and the scaler (if any) is applied.
    """

    def __init__(self, similarity_scorer: ProtoWithLabelSimilarityScorer, scaler: ScaleControllerBase = None):
        """
        :param similarity_scorer: Module for calculating token similarity
        :param scaler: optional module applied to the emission scores
        """
        super(ProtoWithLabelEmissionScorer, self).__init__(similarity_scorer, scaler)

    def forward(
            self,
            test_reps: torch.Tensor,
            support_reps: torch.Tensor,
            test_output_mask: torch.Tensor,
            support_output_mask: torch.Tensor,
            support_targets: torch.Tensor,
            label_reps: torch.Tensor = None, ) -> torch.Tensor:
        """
        :param test_reps: (batch_size, support_size, test_seq_len, dim), notice: reps has been expand to support size
        :param support_reps: (batch_size, support_size, support_seq_len[, dim]) — TODO confirm trailing dim
        :param test_output_mask: (batch_size, test_seq_len)
        :param support_output_mask: (batch_size, support_size, support_seq_len)
        :param support_targets: one-hot label targets: (batch_size, support_size, support_seq_len, num_tags)
        :param label_reps: (batch_size, num_tags, dim)
        :return: emission, shape: (batch_size, test_len, no_pad_num_tags)
        """
        similarity = self.similarity_scorer(
            test_reps, support_reps, test_output_mask, support_output_mask, support_targets, label_reps)
        # shape: (batch_size, test_len, no_pad_num_tags)
        return self.get_emission(similarity, support_targets)

    def get_emission(self, similarities: torch.Tensor, support_targets: torch.Tensor):
        """
        Remove the [PAD] column and optionally scale the scores.

        :param similarities: (batch_size, test_len, num_tags); tag id 0 is the [PAD] label
        :param support_targets: not used here; kept for interface parity with other scorers
        :return: emission: shape: (batch_size, test_len, num_tags - 1)
        :raises ValueError: if ``similarities`` is not 3-dimensional
        """
        # Unpacking the shape also enforces the 3-D layout documented above.
        _, _, num_tags = similarities.shape
        no_pad_num_tags = num_tags - 1

        # Cut emission to block predictions on [PAD] (label id 0).
        emission = similarities.narrow(-1, 1, no_pad_num_tags)

        if self.scaler:
            emission = self.scaler(emission, p=3, dim=-1)
        return emission

## Decoder

多标签的解码器

支持： 'mlc', 'eamlc', 'msmlc', 'krnmsmlc'；

主要关注： krnmsmlc

In [None]:
class MultiLabelTextClassifier(torch.nn.Module):
    """Threshold-based multi-label decoder.

    A tag is predicted for a sample when its logit exceeds a threshold. Here the
    threshold is one scalar parameter (learnable when ``grad_threshold`` is True);
    subclasses override ``get_threshold`` to make it sample-dependent.
    """

    def __init__(self, threshold=0.6, grad_threshold=True):
        super(MultiLabelTextClassifier, self).__init__()

        # The threshold is a model parameter, so the optimizer updates it when
        # requires_grad=True.
        self.threshold = nn.Parameter(torch.FloatTensor([threshold]), requires_grad=grad_threshold)

        # Multi-label classification loss.
        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.right_estimate = None

    def forward(self,
                logits: torch.Tensor,
                mask: torch.Tensor,
                tags: torch.Tensor) -> torch.Tensor:
        """
        :param logits: (batch_size, 1, n_tags)
        :param mask: (batch_size, 1)
        :param tags: (batch_size, max_label_num), eg [[2, 15], [2, 0]]
        :return: scalar loss
        """
        return self._compute_loss(logits, mask, tags)

    def _compute_loss(self,
                      logits: torch.Tensor,
                      mask: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:
        """
        :param logits: (batch_size, 1, n_tags)
        :param mask: (batch_size, 1); currently not used by the loss
        :param targets: (batch_size, max_label_num), eg [[2, 15], [2, 0]]
        :return: scalar loss
        """
        # todo: test different normalization effectiveness.
        # todo: check pad label
        threshold = self.get_threshold(logits)
        # Normalization has been done in emission's scaler.
        # For each tag: > 0 means positive, < 0 means negative.
        filtered_logits = logits - threshold
        # Flatten (batch_size, 1, n_tags) -> (batch_size, n_tags) so the input lines
        # up with the (batch_size, n_tags) multi-hot target. Without this,
        # broadcasting (N, 1, C) against (N, C) produces (N, N, C), and every
        # sample's logits get scored against every other sample's labels.
        filtered_logits = filtered_logits.reshape(-1, filtered_logits.size(-1))

        # shape (N, C)
        multi_hot_target = self.create_multi_hot(targets, label_num=logits.shape[-1])

        return self.criterion(filtered_logits, multi_hot_target)

    def create_multi_hot(self, y: torch.Tensor, label_num: int):
        """
        Turn padded label-id lists into a multi-hot matrix.

        :param y: (batch_size, max_label_num) label ids, eg [[2, 15], [2, 0]]
        :param label_num: number of label columns in the output
        :return: float tensor (batch_size, label_num) with 1 at every id listed in ``y``.
            NOTE(review): padding ids present in ``y`` are set to 1 as well —
            confirm that callers shift or mask padding.
        """
        y_one_hot = torch.zeros(y.shape[0], label_num, device=y.device)
        return y_one_hot.scatter(1, y.long(), 1)

    def decode(self, logits: torch.Tensor) -> List[List[int]]:
        """
        Collect, per sample, the tag ids whose logit exceeds the threshold.

        If no tag passes, the single highest-scoring tag is predicted.

        :param logits: (batch_size, 1, n_tags)
        :return: list of predicted tag-id lists, one per sample
        """
        threshold = self.get_threshold(logits)
        # (batch_size, 1, n_tags) -> (batch_size, n_tags). reshape (not squeeze)
        # keeps the batch axis when batch_size == 1.
        preds = (logits - threshold).reshape(logits.shape[0], -1)
        ret = []
        for pred in preds:
            temp = [l_id for l_id, score in enumerate(pred) if bool(score > 0)]
            # No tag above threshold: fall back to the most probable one.
            if not temp:
                temp = [int(torch.argmax(pred))]
            ret.append(temp)
        return ret

    def get_threshold(self, logits):
        """Return the (shared, scalar) threshold parameter; ``logits`` is ignored here."""
        return self.threshold

In [None]:
# 加入论文中的 logits adaptive机制

class EAMultiLabelTextClassifier(MultiLabelTextClassifier):
    """Emission-adaptive MultiLabelTextClassifier.

    (1) The adaptive threshold λ is derived from the emission scores:
            λ = (E_max − E_min) × r + E_min
        where E is the emission and r the emission rate.
    (2) Each sample therefore gets its own threshold.
    """

    def __init__(self, threshold=0.6, grad_threshold=True):
        """
        :param threshold: reused as the emission rate r in the formula above
        :param grad_threshold: whether r is learnable
        """
        super(EAMultiLabelTextClassifier, self).__init__(threshold, grad_threshold)

    def get_threshold(self, logits):
        """
        Interpolate between each sample's lowest and highest logit.

        Open design question from the original notes: this assumes every sample
        has at least one label; other settings may need a different rule.

        :param logits: (batch_size, 1, n_tags)
        :return: (batch_size, 1, 1)
        """
        lowest = logits.min(dim=-1).values    # (batch_size, 1)
        highest = logits.max(dim=-1).values   # (batch_size, 1)
        # self.threshold acts as the weight r here, not as an absolute threshold.
        adaptive = lowest + self.threshold * (highest - lowest)
        return adaptive.unsqueeze(-1)

In [None]:
class MetaStatsMultiLabelTextClassifier(EAMultiLabelTextClassifier):
    """Meta-statistic MultiLabelTextClassifier.

    The decision threshold blends two estimates:
      * one derived from how many labels the support-set samples carry
        (``estimate_threshold``), and
      * a meta threshold: the emission-adaptive one from the parent class, or
        the raw ``self.threshold`` parameter when ``ab_ea`` is set.
    """

    def __init__(self, threshold=0.6, grad_threshold=True, meta_rate=0.5, ab_ea=False):
        super(MetaStatsMultiLabelTextClassifier, self).__init__(threshold, grad_threshold)
        # One Counter per batch item (label count -> frequency); set by update_statistics().
        self.num_stats = None
        # Weight of the meta threshold in calibrate_threshold().
        self.meta_rate = meta_rate
        # Ablation switch: calibrate with self.threshold instead of the EA threshold.
        self.ab_ea = ab_ea

    def update_statistics(self, support_targets):
        """
        Record, for each batch item, how often each support label count occurs.

        :param support_targets: one-hot targets (batch_size, support_size, max_label_num, num_tags)
        :return: None
        """
        _, _, _, tag_count = support_targets.shape
        # Drop tag column 0 ([PAD]).
        no_pad_targets = support_targets.narrow(dim=-1, start=1, length=tag_count - 1)
        # Collapse the one-hot rows of each support sample into one multi-hot row.
        per_sample_multi_hot = no_pad_targets.sum(dim=-2)
        # Number of labels per support sample: (batch_size, support_size)
        labels_per_sample = per_sample_multi_hot.sum(dim=-1)
        # Zero counts (padded support samples) are skipped.
        self.num_stats = [Counter(n for n in row.tolist() if n) for row in labels_per_sample]

    def get_threshold(self, logits):
        """
        Step 1: estimate λ′ from the support set:
                λ′ = ∑_N^i [p(k=i) E_(i+1)]
            where N is the label count, E_j the j-th largest emission score and
            p(k=i) the empirical frequency.
        Step 2: calibrate λ′ with the meta threshold.

        :param logits: (batch_size, 1, n_tags)
        :return: (batch_size, 1, 1)
        """
        estimated = self.estimate_threshold(logits)  # (batch_size,)
        return self.calibrate_threshold(logits, estimated)

    def estimate_threshold(self, logits) -> torch.FloatTensor:
        """
        Frequency-weighted average of ranked logits, using the empirical label-count distribution.

        :param logits: (batch_size, 1, n_tags)
        :return: shape (batch_size)
        """
        # todo: check support set pad influence of
        per_item = []
        for ind, logit in enumerate(logits):
            ranked = sorted(logit[0], reverse=True)  # descending scores
            stats = self.num_stats[ind]
            weighted_sum = 0
            for num, count in stats.items():
                # Index `num` is the (num + 1)-th largest score.
                weighted_sum += ranked[int(num)] * count
            per_item.append(weighted_sum / sum(stats.values()))
        return torch.stack(per_item).to(logits.device)

    def calibrate_threshold(self, logits, thresholds) -> torch.FloatTensor:
        """
        Mix the estimated thresholds with the meta threshold.

        :param logits: (batch_size, 1, n_tags)
        :param thresholds: estimated thresholds, (batch_size,)
        :return: (batch_size, 1, 1)
        """
        est_threshold = thresholds.unsqueeze(-1).unsqueeze(-1)
        # ab_ea is the emission-adapting ablation: use the raw parameter instead of the EA threshold.
        meta_threshold = self.threshold if self.ab_ea else super().get_threshold(logits)
        return est_threshold * (1 - self.meta_rate) + self.meta_rate * meta_threshold

    def get_ea_threshold(self, logits):
        """Expose the parent's emission-adaptive threshold."""
        return super().get_threshold(logits)

In [None]:

def gaussian_kernel(input1, input2, bandwidth):
    """Gaussian kernel value between paired feature vectors.

        K(x1, x2) = exp(-||(x1 - x2) / h||^2 / 2) / sqrt(pi)

    The constant factor is irrelevant for GaussianKernelSimilarityScorer.forward,
    which normalizes the kernel values into weights.

    :param input1: (batch_size, support_size, feature_len)
    :param input2: (batch_size, support_size, feature_len)
    :param bandwidth: kernel bandwidth h
    :return: (batch_size, support_size) kernel values
    """
    # Scaled difference as a row vector: (batch_size, support_size, 1, feature_len)
    scaled = ((input1 - input2) / bandwidth).unsqueeze(-2)
    # Squared norm via row @ column: (batch_size, support_size, 1, 1)
    sq_norm = torch.matmul(scaled, scaled.permute(0, 1, 3, 2))
    coef = 1.0 / np.sqrt(np.pi)
    kernel = coef * torch.exp(-sq_norm / 2)
    return kernel.squeeze(-1).squeeze(-1)


class GaussianKernelSimilarityScorer(torch.nn.Module):
    """Kernel-regression weights over the support set from a Gaussian kernel.

    Support and test sentence features are (optionally) passed through a learned
    mapping and min-max normalized. Each support sample is then weighted by
    ``gaussian_kernel(test, support, bandwidth)``, normalized so the weights of
    every batch item sum to 1.
    """

    # Activation appended after each mapping Linear layer, keyed by 'feature_map_act'.
    _ACTIVATIONS = {
        'none': None,
        'relu': torch.nn.ReLU,
        'sigmoid': torch.nn.Sigmoid,
        'tanh': torch.nn.Tanh,
    }

    def __init__(self, bandwidth=0.5, learnable=False, map_dict=None):
        """
        :param bandwidth: kernel bandwidth (initial value when learnable)
        :param learnable: whether the bandwidth receives gradients
        :param map_dict: feature-mapping config with keys 'feature_map', 'feature_num',
            'feature_map_dim', 'feature_map_act' ('none' | 'relu' | 'sigmoid' | 'tanh')
            and 'feature_map_layer_num'. ``None`` disables the feature mapping.
        :raises NotImplementedError: for an unknown 'feature_map_act' when mapping is enabled
        """
        super(GaussianKernelSimilarityScorer, self).__init__()
        self.bandwidth = torch.nn.Parameter(torch.Tensor([bandwidth]), requires_grad=learnable)
        self.map_dict = map_dict
        if map_dict is None:
            # No mapping configured: the kernel works on the raw features.
            self.feature_map = False
            self.feature_num = None
            self.feature_map_dim = None
            self.feature_map_act = None
            self.feature_map_layer_num = None
        else:
            self.feature_map = map_dict['feature_map']
            self.feature_num = map_dict['feature_num']
            self.feature_map_dim = map_dict['feature_map_dim']
            self.feature_map_act = map_dict['feature_map_act']
            self.feature_map_layer_num = map_dict['feature_map_layer_num']

        if self.feature_map:
            self.map_linear = self._build_feature_mapper()
            self.map_linear.apply(self.init_weights)

    def _build_feature_mapper(self):
        """Build the Linear(+activation) stack that maps raw features before the kernel."""
        if self.feature_map_act not in self._ACTIVATIONS:
            raise NotImplementedError
        activation = self._ACTIVATIONS[self.feature_map_act]
        modules = []
        # At least one layer is always built, even when feature_map_layer_num < 1.
        for layer_idx in range(max(self.feature_map_layer_num, 1)):
            in_dim = self.feature_num if layer_idx == 0 else self.feature_map_dim
            modules.append(torch.nn.Linear(in_dim, self.feature_map_dim))
            if activation is not None:
                modules.append(activation())
        return torch.nn.Sequential(*modules)

    def init_weights(self, w):
        """Xavier-normal initialization for Linear layers; other modules are left untouched."""
        if isinstance(w, torch.nn.Linear):
            torch.nn.init.xavier_normal_(w.weight)

    def feature_norm(self, support_features, test_feature):
        """
        Min-max normalize features jointly over the support set and the test sample.

        :param support_features:  (batch_size, support_size, feature_len)
        :param test_feature: (batch_size, feature_len)
        :return: (normalized support features, normalized test feature), same shapes as the inputs
        """
        if self.feature_map:
            support_features = self.map_linear(support_features)
            test_feature = self.map_linear(test_feature)

        s_max = torch.max(support_features, dim=-2)[0]  # (batch_size, feature_len)
        s_min = torch.min(support_features, dim=-2)[0]  # (batch_size, feature_len)

        s_max = torch.max(s_max, test_feature)  # (batch_size, feature_len)
        s_min = torch.min(s_min, test_feature)  # (batch_size, feature_len)

        scale = s_max - s_min  # (batch_size, feature_len)
        # A constant feature has scale 0; use 1 instead to avoid dividing by zero.
        scale = scale + (scale == 0).float()

        s_features = (support_features - s_min.unsqueeze(-2)) / scale.unsqueeze(-2)
        t_feature = (test_feature - s_min) / scale

        return s_features, t_feature

    def forward(self,
                support_sentence_feature: torch.Tensor,
                test_sentence_feature: torch.Tensor,
                support_sentence_target: torch.Tensor,
                test_sentence_target: torch.Tensor) -> torch.Tensor:
        """
        :param support_sentence_feature:  (batch_size, support_size, feature_len)
        :param test_sentence_feature:  (batch_size, feature_len)
        :param support_sentence_target:  (batch_size, support_size); not used in the computation
        :param test_sentence_target:  (batch_size, 1); not used in the computation
        :return: kernel weights (batch_size, support_size), summing to 1 over the support set
        """
        # Min-max normalize the features.
        support_sentence_feature, test_sentence_feature = \
            self.feature_norm(support_sentence_feature, test_sentence_feature)

        # Kernel values between the test sample and each support sample.
        test_sentence_feature = test_sentence_feature.unsqueeze(1).expand_as(support_sentence_feature)
        # Tensor.to is not in-place: keep the returned tensor.
        bandwidth = self.bandwidth.to(support_sentence_feature.device)
        k_values = gaussian_kernel(test_sentence_feature, support_sentence_feature, bandwidth)

        # Normalize into weights: (batch_size, support_size)
        return k_values / torch.sum(k_values, dim=-1, keepdim=True)

In [None]:
class KRNMetaStatsMultiLabelTextClassifier(MetaStatsMultiLabelTextClassifier):
    """Meta-statistic classifier whose label-count statistics come from kernel regression.

    Support label counts are weighted by the kernel similarity between each support
    sample and the test sample (see ``update_statistics``). With ``use_gold`` the
    gold test label count is used instead. With ``learnable``, an MSE loss between
    the kernel-predicted and gold label counts is added to the training loss.
    """

    def __init__(self, threshold=0.6, grad_threshold=True,
                 meta_rate=0.5, ab_ea=False,
                 kernel='gaussian', bandwidth=0.5,
                 use_gold=False, learnable=False,
                 map_dict=None):
        super(KRNMetaStatsMultiLabelTextClassifier, self).__init__(threshold, grad_threshold, meta_rate, ab_ea)
        self.kernel = kernel
        self.bandwidth = bandwidth
        self.learnable = learnable
        self.use_gold = use_gold
        self.map_dict = map_dict
        self.similarity_scorer = self.choose_kernel_similar()
        if self.learnable:
            self.label_num_criterion = nn.MSELoss()
            # Set by update_statistics() when the kernel prediction is used; stays None with use_gold.
            self.label_num_loss = None

    def choose_kernel_similar(self):
        """Instantiate the kernel similarity scorer selected by ``self.kernel``.

        :raises NotImplementedError: for any kernel other than 'gaussian'
        """
        if self.kernel == 'gaussian':
            similarity_scorer = GaussianKernelSimilarityScorer(bandwidth=self.bandwidth, learnable=self.learnable,
                                                               map_dict=self.map_dict)
        else:
            raise NotImplementedError
        return similarity_scorer

    def update_statistics(self,
                          support_targets,
                          support_sentence_feature=None,
                          test_sentence_feature=None,
                          support_sentence_label_num=None,
                          test_sentence_label_num=None):
        """
        Update stats for each sample in batch.

        :param support_targets: one-hot targets (batch_size, support_size, max_label_num, num_tags)
        :param support_sentence_feature: (batch_size, support_size, feature_len)
        :param test_sentence_feature: (batch_size, feature_len)
        :param support_sentence_label_num: (batch_size, support_size)
        :param test_sentence_label_num: (batch_size)
        :return: None
        """
        # Count label num.
        if self.use_gold:
            # Oracle: the gold test label count, with weight 1.
            t_target = test_sentence_label_num.squeeze(-1)
            self.num_stats = [{item: [1]} for item in t_target.long().tolist()]
            self.right_estimate = (t_target == t_target).long()
        else:
            # Kernel weights of each support sample w.r.t. the test sample.
            label_num_weights = self.similarity_scorer(support_sentence_feature, test_sentence_feature,
                                                       support_sentence_label_num, test_sentence_label_num)

            # Distributed num stats: label count -> list of kernel weights.
            batch_size = label_num_weights.size(0)
            self.num_stats = []
            for b_idx in range(batch_size):
                tmp_stat = {}
                for s_label_num, weight in zip(support_sentence_label_num[b_idx].long().tolist(),
                                               label_num_weights[b_idx].tolist()):
                    tmp_stat.setdefault(s_label_num, []).append(weight)
                self.num_stats.append(tmp_stat)

            # Check whether the kernel regression predicts the label count correctly.
            pred_label_num = torch.sum(label_num_weights * support_sentence_label_num, dim=-1)
            pred_label_num_int = torch.round(pred_label_num)
            self.right_estimate = (pred_label_num_int == test_sentence_label_num.squeeze(-1)).long()

            # Per the original notes, the configuration used in practice sets learnable=True.
            if self.learnable:
                self.label_num_loss = self.label_num_criterion(pred_label_num, test_sentence_label_num.squeeze(-1))

    def estimate_threshold(self, logits) -> torch.FloatTensor:
        """
        Kernel-weighted sum of ranked logits.

        :param logits: (batch_size, 1, n_tags)
        :return: shape (batch_size)
        """
        # todo: check support set pad influence of
        ret = []
        for ind, logit in enumerate(logits):
            sorted_logits = sorted(logit[0], reverse=True)
            stats = self.num_stats[ind]  # dict: label count -> list of weights
            l_sum = 0
            for num, count_lst in stats.items():
                # Index `num` is the (num + 1)-th largest score.
                l_sum += sorted_logits[int(num)] * sum(count_lst)
            ret.append(l_sum)
        ret = torch.stack(ret).to(logits.device)
        return ret

    def _compute_loss(self,
                      logits: torch.Tensor,
                      mask: torch.Tensor,
                      targets: torch.Tensor) -> torch.Tensor:
        """
        Tag-prediction loss, plus the label-count regression loss when available.

        :param logits: (batch_size, 1, n_tags)
        :param mask: (batch_size, 1)
        :param targets: (batch_size, max_label_num), eg [[2, 15], [2, 0]]
        :return: scalar loss
        """
        loss = super()._compute_loss(logits, mask, targets)
        # label_num_loss stays None when use_gold is set (no kernel prediction is made),
        # so it is only added when it was actually computed.
        if self.learnable and self.label_num_loss is not None:
            loss = loss + self.label_num_loss
        return loss

    def decode(self, logits: torch.Tensor, test_label_num=None) -> List[List[int]]:
        """
        Collect the values greater than the threshold.

        With ``use_gold``, each sample keeps exactly its gold number of top-scoring tags
        (ties aside). Otherwise this defers to the parent decode.

        :param logits: (batch_size, 1, no_pad_num_tag)
        :param test_label_num: gold label counts, required when ``use_gold`` is set
        :return: list of predicted tag-id lists, one per sample
        """
        if not self.use_gold:
            return super().decode(logits)

        gold_nums = test_label_num.squeeze(-1).long().tolist()  # (batch_size, )
        ret = []
        for label_num, logit in zip(gold_nums, logits):
            scores = logit[0]  # (n_tags,)
            sorted_logits = sorted(scores, reverse=True)
            if label_num >= len(sorted_logits):
                # The gold count covers every tag: predict all of them.
                ret.append(list(range(len(sorted_logits))))
                continue
            # sorted_logits[label_num] is the (label_num + 1)-th largest score. A strict
            # comparison keeps the top `label_num` tags, matching estimate_threshold
            # combined with the parent's `> 0` decode.
            pred = scores - sorted_logits[label_num]
            temp = [l_id for l_id, score in enumerate(pred) if bool(score > 0)]
            # Predict the most probable label if nothing passes.
            if not temp:
                temp = [int(torch.argmax(pred))]
            ret.append(temp)
        return ret

In [None]:
## 模型汇总

将模型的各个模块串起来

In [None]:
def make_model(opt, config):
    """Customize and build the few-shot learning model from components.

    :param opt: run options; its attributes select and configure every component
        (context_emb, use_schema, emb_log, output_dir, emission_*, similarity, emission,
        task, decoder, threshold, ...)
    :param config: model config dict; for 'tapnet' its 'num_anchors' entry is set in place
    :return: the assembled (Schema)FewShotSeqLabeler or (Schema)FewShotTextClassifier
    :raises TypeError: for an unknown component or task name
    :raises ValueError: for an unknown transition type
    :raises NotImplementedError: for the 'elmo' context embedder
    """

    ''' Build context_embedder '''
    if opt.context_emb == 'bert':
        context_embedder = BertSchemaContextEmbedder(opt=opt) if opt.use_schema else BertContextEmbedder(opt=opt)
    elif opt.context_emb == 'sep_bert':
        context_embedder = BertSchemaSeparateContextEmbedder(opt=opt) if opt.use_schema else \
            BertSeparateContextEmbedder(opt=opt)
    elif opt.context_emb == 'electra':
        context_embedder = ElectraSchemaContextEmbedder(opt=opt) if opt.use_schema else ElectraContextEmbedder(opt=opt)
    elif opt.context_emb == 'elmo':
        raise NotImplementedError
    elif opt.context_emb == 'raw':
        context_embedder = NormalContextEmbedder(opt=opt, num_token=len(opt.word2id))
    else:
        raise TypeError('wrong component type')

    ''' Create log file to record testing data '''
    if opt.emb_log:
        # Deliberately left open: the handle is handed to the scorers and the model below.
        emb_log = open(os.path.join(opt.output_dir, 'emb.log'), 'w')
        if 'id2label' in config:
            emb_log.write('id2label\t' + '\t'.join([str(k) + ':' + str(v) for k, v in config['id2label'].items()]) + '\n')
    else:
        emb_log = None

    '''Build emission scorer and similarity scorer '''
    # build scaler
    ems_normalizer = build_scale_controller(
        name=opt.emission_normalizer
    )
    ems_scaler = build_scale_controller(
        name=opt.emission_scaler,
        kwargs=make_scaler_args(opt.emission_scaler, ems_normalizer, opt.ems_scale_r)
    )
    if opt.similarity == 'dot':
        sim_func = reps_dot
    elif opt.similarity == 'cosine':
        sim_func = reps_cosine_sim
    elif opt.similarity == 'l2':
        sim_func = reps_l2_sim
    else:
        raise TypeError('wrong component type')

    if opt.emission == 'mnet':
        similarity_scorer = MatchingSimilarityScorer(sim_func=sim_func, emb_log=emb_log)
        emission_scorer = MNetEmissionScorer(similarity_scorer, ems_scaler, opt.div_by_tag_num)
    elif opt.emission == 'proto':
        similarity_scorer = PrototypeSimilarityScorer(sim_func=sim_func, emb_log=emb_log)
        emission_scorer = PrototypeEmissionScorer(similarity_scorer, ems_scaler)
    elif opt.emission == 'proto_with_label':
        similarity_scorer = ProtoWithLabelSimilarityScorer(sim_func=sim_func, scaler=opt.ple_scale_r, emb_log=emb_log)
        emission_scorer = ProtoWithLabelEmissionScorer(similarity_scorer, ems_scaler)
    elif opt.emission == 'tapnet':
        # set num of anchors:
        # (1) if provided in config, use it (usually in load model case.)
        # (2) *3 is used to ensure enough anchors ( > num_tags of unseen domains )
        num_anchors = config['num_anchors'] if 'num_anchors' in config else config['num_tags'] * 3
        config['num_anchors'] = num_anchors
        anchor_dim = 256 if opt.context_emb == 'electra' else 768
        similarity_scorer = TapNetSimilarityScorer(
            sim_func=sim_func, num_anchors=num_anchors, mlp_out_dim=opt.tap_mlp_out_dim,
            random_init=opt.tap_random_init, random_init_r=opt.tap_random_init_r,
            mlp=opt.tap_mlp, emb_log=emb_log, tap_proto=opt.tap_proto, tap_proto_r=opt.tap_proto_r,
            anchor_dim=anchor_dim)
        emission_scorer = TapNetEmissionScorer(similarity_scorer, ems_scaler)
    else:
        raise TypeError('wrong component type')

    ''' Build decoder '''
    if opt.task == 'sl':  # for sequence labeling
        if opt.decoder == 'sms':
            transition_scorer = None
            decoder = SequenceLabeler()
        elif opt.decoder == 'rule':
            transition_scorer = None
            decoder = RuleSequenceLabeler(config['id2label'])
        elif opt.decoder == 'crf':
            # Notice: only train back-off now
            trans_normalizer = build_scale_controller(name=opt.trans_normalizer)
            trans_scaler = build_scale_controller(
                name=opt.trans_scaler, kwargs=make_scaler_args(opt.trans_scaler, trans_normalizer, opt.trans_scale_r))
            if opt.transition == 'learn':
                transition_scorer = FewShotTransitionScorer(
                    num_tags=config['num_tags'], normalizer=trans_normalizer, scaler=trans_scaler,
                    r=opt.trans_r, backoff_init=opt.backoff_init)
            elif opt.transition == 'learn_with_label':
                label_trans_normalizer = build_scale_controller(name=opt.label_trans_normalizer)
                label_trans_scaler = build_scale_controller(name=opt.label_trans_scaler, kwargs=make_scaler_args(
                        opt.label_trans_scaler, label_trans_normalizer, opt.label_trans_scale_r))
                transition_scorer = FewShotTransitionScorerFromLabel(
                    num_tags=config['num_tags'], normalizer=trans_normalizer, scaler=trans_scaler,
                    r=opt.trans_r, backoff_init=opt.backoff_init, label_scaler=label_trans_scaler)
            else:
                raise ValueError('Wrong choice of transition.')
            if opt.add_transition_rules and 'id2label' in config:
                # Id 0 is the [PAD] label: drop it and shift the remaining label ids down
                # by one so they index the no-pad tag space.
                non_pad_id2label = {int(k) - 1: v for k, v in config['id2label'].items() if int(k) != 0}
                constraints = allowed_transitions(constraint_type='BIO', labels=non_pad_id2label)
            else:
                constraints = None
            decoder = ConditionalRandomField(
                num_tags=transition_scorer.num_tags, constraints=constraints)  # accurate tags
        else:
            raise TypeError('wrong component type')
    elif opt.task == 'mlc':  # for multi-label text classification task
        grad_threshold = opt.threshold_type == 'learn'
        if opt.decoder == 'mlc':
            decoder = MultiLabelTextClassifier(opt.threshold, grad_threshold)
        elif opt.decoder == 'eamlc':
            decoder = EAMultiLabelTextClassifier(opt.threshold, grad_threshold)
        elif opt.decoder == 'msmlc':
            decoder = MetaStatsMultiLabelTextClassifier(opt.threshold, grad_threshold, meta_rate=opt.meta_rate,
                                                        ab_ea=opt.ab_ea)
        elif opt.decoder == 'krnmsmlc':
            map_dict = {
                "feature_map": opt.feature_map,
                "feature_num": opt.feature_num,
                "feature_map_dim": opt.feature_map_dim,
                "feature_map_act": opt.feature_map_act,
                "feature_map_layer_num": opt.feature_map_layer_num,
            }
            decoder = KRNMetaStatsMultiLabelTextClassifier(opt.threshold, grad_threshold, meta_rate=opt.meta_rate,
                                                           ab_ea=opt.ab_ea, kernel=opt.kernel, bandwidth=opt.bandwidth,
                                                           use_gold=opt.use_gold, learnable=opt.kernel_learnable,
                                                           map_dict=map_dict)
        else:
            raise TypeError('wrong component type')
    elif opt.task == 'sc':  # for single-label text classification task
        decoder = SingleLabelTextClassifier()
    else:
        raise TypeError('wrong task type')

    ''' Build the whole model '''
    if opt.task == 'sl':
        seq_labeler = SchemaFewShotSeqLabeler if opt.use_schema else FewShotSeqLabeler
        model = seq_labeler(
            opt=opt,
            context_embedder=context_embedder,
            emission_scorer=emission_scorer,
            decoder=decoder,
            transition_scorer=transition_scorer,
            config=config,
            emb_log=emb_log
        )
    elif opt.task in ['sc', 'mlc']:
        text_classifier = SchemaFewShotTextClassifier if opt.use_schema else FewShotTextClassifier
        model = text_classifier(
            opt=opt,
            context_embedder=context_embedder,
            emission_scorer=emission_scorer,
            decoder=decoder,
            config=config,
            emb_log=emb_log
        )
    else:
        raise TypeError('wrong task type')
    return model


def load_model(path):
    """Rebuild a model from a checkpoint file.

    The checkpoint is read with ``torch.load`` and must hold 'opt', 'config' and
    'state_dict' entries. The model is rebuilt with ``make_model``, wrapped by
    ``prepare_model``, and then given the saved weights.

    NOTE(review): torch.load unpickles the file; only load trusted checkpoints.

    :param path: checkpoint file path
    :return: the loaded model, or None (after logging) when an IOError occurs
    """
    try:
        with open(path, 'rb') as reader:
            checkpoint = torch.load(reader, map_location='cpu')
            saved_opt = checkpoint['opt']
            model = prepare_model(
                args=saved_opt,
                model=make_model(opt=saved_opt, config=checkpoint['config']),
                device=saved_opt.device,
                n_gpu=saved_opt.n_gpu,
            )
            model.load_state_dict(checkpoint['state_dict'])
            return model
    except IOError as err:
        logger.info("Failed to load model from {} \n {}".format(path, err))
        return None

In [None]:
# 处理并行&device信息
def prepare_model(args, model, device, n_gpu):
    """Place the model on its device and wrap it for parallel execution.

    * ``args.fp16`` set: parameters are converted to half precision first.
    * ``args.local_rank != -1``: wrapped in DistributedDataParallel on that rank.
    * otherwise, ``n_gpu > 1``: wrapped in DataParallel.
    * otherwise the (moved) model itself is returned.
    """
    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank == -1:
        return torch.nn.DataParallel(model) if n_gpu > 1 else model

    return torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)