In [None]:
import torch
import torch.nn as nn

def total_variation(images):
    """Return the mean anisotropic (L1) total variation of ``images``.

    Absolute differences between neighbouring values are taken along the last
    two axes only (treated as height and width); leading axes are never
    differenced (an earlier version wrongly differenced the batch axis).
    Each direction's absolute differences are *averaged* (not summed) and the
    two averages are added.

    Args:
        images: 3-D tensor (e.g. ``[N, H, W]``; the leading axis is not
            differenced) or 4-D tensor (e.g. ``[N, C, H, W]``).

    Returns:
        For 3-D input, a 0-D tensor averaged over the whole input. For 4-D
        input, a 1-D tensor of length ``images.shape[0]`` with one value per
        leading-axis entry. A direction with no neighbouring pairs (height or
        width of 1) contributes 0 rather than NaN.

    Raises:
        ValueError: if ``images`` is neither 3- nor 4-dimensional.
    """
    ndims = images.dim()
    if ndims == 3:
        # Reduce over every axis -> a single scalar.
        reduce_dims = None
    elif ndims == 4:
        # Reduce over all but the leading axis -> one value per sample.
        reduce_dims = (1, 2, 3)
    else:
        raise ValueError("'images' must be either 3 or 4-dimensional.")

    def mean_abs(diff):
        # torch.mean over an empty tensor yields NaN. An axis of length 1 has
        # no neighbours and therefore no variation, so report zeros instead.
        if diff.numel() == 0:
            shape = () if reduce_dims is None else (diff.shape[0],)
            return diff.new_zeros(shape)
        if reduce_dims is None:
            return torch.mean(torch.abs(diff))
        return torch.mean(torch.abs(diff), dim=reduce_dims)

    # Horizontal neighbours (last axis) and vertical neighbours (second-to-last axis).
    dif_w = images[..., :, 1:] - images[..., :, :-1]
    dif_h = images[..., 1:, :] - images[..., :-1, :]
    return mean_abs(dif_w) + mean_abs(dif_h)

class MaxLoss(nn.Module):
    """Composite loss: L1 + ``beta`` * total variation + ``gama`` * peak error.

    * L1 term: mean absolute error between prediction and target.
    * TV term: ``total_variation`` of the prediction only (smoothness prior).
    * Peak term: mean absolute difference between the maxima of prediction
      and target taken over the last two (spatial) axes.

    Args:
        beta: weight of the total-variation term.
        gama: weight of the peak-error term (spelling kept for backward
            compatibility with existing callers).
    """

    def __init__(self, beta=1.0, gama=1.0):
        super().__init__()
        self.beta = beta
        self.gama = gama

    def forward(self, decoded, GT):
        """Return the scalar weighted loss of ``decoded`` against ``GT``.

        Args:
            decoded: predicted tensor, 3-D (e.g. ``[N, H, W]``) or 4-D
                (e.g. ``[N, C, H, W]``); the last two axes are spatial.
            GT: ground-truth tensor with the same shape as ``decoded``.

        Returns:
            A 0-D tensor.
        """
        loss_l1 = nn.functional.l1_loss(decoded, GT, reduction='mean')
        # total_variation yields one value per sample for 4-D input; average it
        # so the total is always a scalar (a no-op for 3-D input).
        tv_loss = total_variation(decoded).mean()
        # Spatial maxima over the last two axes; identical to dim=(1, 2) for
        # 3-D input and still spatial (not channel) for 4-D input.
        peak_pred = torch.amax(decoded, dim=(-2, -1))
        peak_gt = torch.amax(GT, dim=(-2, -1))
        max_loss = torch.mean(torch.abs(peak_pred - peak_gt))
        return loss_l1 + tv_loss * self.beta + max_loss * self.gama
        # return loss_L1

if __name__ == "__main__":
    # Smoke test on random data; a fixed seed keeps the printed value reproducible.
    torch.manual_seed(0)
    gt = torch.rand(10, 36, 72)
    decoded = torch.rand(10, 36, 72)
    loss = MaxLoss()
    # forward(decoded, GT): the prediction must come first because the TV
    # term is computed on it alone.
    print(loss(decoded, gt))


L1=0.3324,TV=0.6674,max=0.0003
tensor(1.0001)
