In [1]:
# coding = utf-8
# author = xy


import torch
from torch.nn.modules import loss
import torch.nn.functional as f
import utils


class MyNLLLoss(loss._Loss):
    """ MLE 最大似然估计 """
    def __init__(self):
        super(MyNLLLoss, self).__init__()

    def forward(self, outputs, batch):
        """
        :param outputs: tensor (2, batch_size, content_seq_len)
        :param batch: tensor
        :return:loss
        """
        y_start = batch[2]
        y_end = batch[3]
        outputs = torch.log(outputs)
        start_loss = f.nll_loss(outputs[0], y_start)
        end_loss = f.nll_loss(outputs[1], y_end)
        return start_loss + end_loss


class RougeLoss(loss._Loss):
    """ MRT 最小风险 """
    def __init__(self, lam):
        super(RougeLoss, self).__init__()

        self.lam = lam

        self.mle = MyNLLLoss()

    def forward(self, outputs, batch):
        """
        :param outputs: tensor (2, batch_size, content_seq_len)
        :param batch: tensor (content, question, start, end)
        :return: loss
        """
        # mrt
        start_y = batch[2]
        end_y = batch[3]
        y = [torch.range(s, e).cuda() for s, e in zip(start_y, end_y)]

        start_pred, end_pred = torch.max(outputs, dim=2)[1]
        pred = [torch.range(s, e).cuda() if s <= e else torch.Tensor([-1]).cuda() for s, e in zip(start_pred, end_pred)]

        loss_mrt = [utils.rouge_score(pred_i, y_i) for pred_i, y_i in zip(pred, y)]
        print(loss_mrt)
        loss_mrt_tmp = 0
        for i in loss_mrt:
            loss_mrt_tmp += i
        loss_mrt = loss_mrt_tmp / len(loss_mrt)

        # mle
        loss_mle = self.mle(outputs, batch)

        loss_value = loss_mle + self.lam * loss_mrt
        return loss_value


In [2]:
outputs = torch.rand(2, 3, 4).cuda()
outputs

tensor([[[ 0.3370,  0.2132,  0.1478,  0.3742],
         [ 0.5618,  0.0211,  0.7462,  0.4303],
         [ 0.4481,  0.9956,  0.9781,  0.1160]],

        [[ 0.8187,  0.3722,  0.5184,  0.7440],
         [ 0.5579,  0.0137,  0.8926,  0.2001],
         [ 0.0819,  0.1607,  0.7611,  0.6231]]], device='cuda:0')

In [3]:
batch = [0, 0, torch.LongTensor([1, 0, 0]).cuda(), torch.LongTensor([1, 3, 0]).cuda()]

In [4]:
loss = RougeLoss(4)

In [5]:
loss(outputs, batch)

[tensor(0., device='cuda:0'), tensor(0.3609, device='cuda:0'), tensor(0., device='cuda:0')]


tensor(3.1562, device='cuda:0')

In [7]:
a = [torch.tensor(3), torch.tensor(3)]
a

[tensor(3), tensor(3)]

In [9]:
torch.mean(torch.Tensor(a))

tensor(3.)

In [10]:
torch.Tensor(a)

tensor([ 3.,  3.])

In [56]:
torch.max(torch.LongTensor([5]), torch.LongTensor([3]))

tensor([ 5])

In [49]:
torch.max(torch.LongTensor([1]), torch.LongTensor([-1]))

tensor([ 1])

In [12]:
a= torch.LongTensor([1,3,4])
b = torch.LongTensor([2, 5, 8])

In [8]:
string = '你妹的..wokao'
sub = '.woo '

In [13]:
y = [torch.range(s, e) for s, e in zip(a, b)]

In [14]:
y

[tensor([ 1.,  2.]),
 tensor([ 3.,  4.,  5.]),
 tensor([ 4.,  5.,  6.,  7.,  8.])]

In [27]:
outputs = torch.rand(2, 3, 4)
outputs

tensor([[[ 0.5194,  0.8005,  0.4448,  0.8597],
         [ 0.8940,  0.6153,  0.0597,  0.1176],
         [ 0.3903,  0.0271,  0.1076,  0.3063]],

        [[ 0.9697,  0.9926,  0.4235,  0.3457],
         [ 0.6630,  0.4112,  0.7384,  0.9590],
         [ 0.8537,  0.1090,  0.2849,  0.4107]]])

In [28]:
torch.max(outputs, dim=2)[1]

tensor([[ 3,  0,  0],
        [ 1,  3,  0]])

In [35]:
a, b = torch.max(outputs, dim=2)[1]

In [40]:
len(a)

3

In [30]:
b

tensor([ 1,  3,  0])

In [39]:
[torch.range(s, e) if s<=e else torch.Tensor([-1]) for s, e in zip(a, b)]

[tensor([-1.]), tensor([ 0.,  1.,  2.,  3.]), tensor([ 0.])]