A deep learning framework for a more agile development process

deng1fan/LazyProjects

The code for the paper "Towards Faithful Dialogs via Focus Learning" is still being cleaned up. In the meantime, here is the core snippet that computes FCE.

Core code from the paper:

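    import torch
    from torch.nn import CrossEntropyLoss

    class CosineSimilarity(torch.nn.Module):
        """Cosine similarity along the last dimension."""
        def forward(self, tensor_1, tensor_2):
            normalized_tensor_1 = tensor_1 / tensor_1.norm(dim=-1, keepdim=True)
            normalized_tensor_2 = tensor_2 / tensor_2.norm(dim=-1, keepdim=True)
            return (normalized_tensor_1 * normalized_tensor_2).sum(dim=-1)

    # The lines below are an excerpt from inside the loss computation: `model`, `knowledges`,
    # `labels_emb`, `logits`, `labels`, and `self` come from surrounding code that is not shown here.
    cal_sim = CosineSimilarity()
    # Embed the knowledge token ids with the model's input embedding layer.
    knowledge_emb = model.get_input_embeddings()(knowledges)
    # Negated cosine similarity in [-1, 1]; lower means closer to the knowledge.
    sim_dist = -cal_sim(knowledge_emb, labels_emb)
    # Map the distance to a per-token weight; fce_lamda keeps the log argument above zero.
    sim_score = -torch.log(sim_dist + 1 + self.config.get("fce_lamda", 0.01)) + 1

    # Scale every vocabulary logit of a token by that token's weight.
    weighted_lm_logits = torch.mul(sim_score.unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits)
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    # Replace padding labels with -100 so the loss ignores them.
    fce_loss = loss_fct(
        weighted_lm_logits.view(-1, weighted_lm_logits.size(-1)),
        torch.where(labels == self.tokenizer.pad_token_id, -100, labels).view(-1),
    )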
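To get a feel for the weighting, the sketch below reproduces the `sim_score` formula from the snippet in plain Python for a single token and applies it to a toy set of logits. It is only an illustration under stated assumptions: the function names `fce_weight` and `scaled_cross_entropy` and the toy logits are hypothetical and are not part of the repository, and the default `lam=0.01` simply mirrors the `fce_lamda` default shown above.

```python
import math

def fce_weight(cos_sim, lam=0.01):
    """Per-token weight, following the snippet: -log(-cos_sim + 1 + lam) + 1."""
    sim_dist = -cos_sim  # negated cosine similarity, in [-1, 1]
    return -math.log(sim_dist + 1 + lam) + 1

def scaled_cross_entropy(logits, target, weight):
    """Cross-entropy for one token after multiplying its logits by `weight`."""
    scaled = [weight * x for x in logits]
    m = max(scaled)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return log_z - scaled[target]

# Hypothetical logits over a 3-word vocabulary; index 0 is the gold token.
toy_logits = [2.0, 0.5, -1.0]
for cos_sim in (1.0, 0.0, -1.0):
    w = fce_weight(cos_sim)
    print(cos_sim, round(w, 3), round(scaled_cross_entropy(toy_logits, 0, w), 4))
```

With the default `lam`, the weight is about 5.61 when the cosine similarity is 1, about 0.99 when it is 0, and about 0.30 when it is -1. Because the weight multiplies the logits before the softmax, it acts like an inverse temperature: tokens close to the knowledge get a sharper distribution and tokens far from it get a flatter one.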
