In [2]:
import torch
def total_variation(images):
    """Return the anisotropic total variation of one image or a batch of images.

    The total variation is the sum of absolute differences between
    neighbouring values along the two axes that precede the last axis.

    Args:
        images: A 3-D tensor, interpreted as ``[height, width, channels]``,
            or a 4-D tensor, interpreted as
            ``[batch, height, width, channels]``.

    Returns:
        For 3-D input, a 0-dim tensor with the total variation of the image.
        For 4-D input, a 1-D tensor holding one total variation per image.

    Raises:
        ValueError: If ``images`` is neither 3- nor 4-dimensional.

    Note:
        A 3-D tensor is always read as a single image: its first axis is
        differenced as "height". A channel-less batch laid out as
        ``[batch, height, width]`` is therefore NOT treated as a batch.
    """
    rank = images.dim()
    if rank not in (3, 4):
        raise ValueError("'images' must be either 3 or 4-dimensional.")

    # The ellipsis absorbs the optional leading batch axis, so the same two
    # expressions address the height and width axes for both layouts.
    height_steps = (images[..., 1:, :, :] - images[..., :-1, :, :]).abs()
    width_steps = (images[..., :, 1:, :] - images[..., :, :-1, :]).abs()

    if rank == 3:
        # Single image: collapse everything into one scalar.
        return height_steps.sum() + width_steps.sum()

    # Batch: keep axis 0 so that every image gets its own value.
    per_image_axes = (1, 2, 3)
    return height_steps.sum(dim=per_image_axes) + width_steps.sum(dim=per_image_axes)

# Load the saved tensors straight onto the CPU. map_location lets the load
# succeed even when the files were written from a CUDA device and this
# machine has no GPU (a trailing .cpu() runs too late to help in that case).
decoded = torch.load('decoded2.pt', map_location='cpu')
gt = torch.load('gt2.pt', map_location='cpu')

# NOTE(review): a later cell prints decoded's shape as [2, 361, 720].
# total_variation reads a 3-D tensor as [height, width, channels], so these
# calls difference along axes 0 and 1. If the layout is really
# [batch, height, width], pass decoded.unsqueeze(-1) instead -- confirm intent.
detv = total_variation(decoded)
gttv = total_variation(gt)

print(detv,gttv)

tensor(185126.6562, grad_fn=<AddBackward0>) tensor(82546.8359)


In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

def total_variation(images):
    """Return the anisotropic total variation of one image or a batch of images.

    The total variation is the sum of absolute differences between
    neighbouring values along the two axes that precede the last axis.

    Args:
        images: A 3-D tensor, interpreted as ``[height, width, channels]``,
            or a 4-D tensor, interpreted as
            ``[batch, height, width, channels]``.

    Returns:
        For 3-D input, a 0-dim tensor with the total variation of the image.
        For 4-D input, a 1-D tensor holding one total variation per image.

    Raises:
        ValueError: If ``images`` is neither 3- nor 4-dimensional.

    Note:
        A 3-D tensor is always read as a single image: its first axis is
        differenced as "height". A channel-less batch laid out as
        ``[batch, height, width]`` is therefore NOT treated as a batch.
    """
    rank = images.dim()
    if rank not in (3, 4):
        raise ValueError("'images' must be either 3 or 4-dimensional.")

    # The ellipsis absorbs the optional leading batch axis, so the same two
    # expressions address the height and width axes for both layouts.
    height_steps = (images[..., 1:, :, :] - images[..., :-1, :, :]).abs()
    width_steps = (images[..., :, 1:, :] - images[..., :, :-1, :]).abs()

    if rank == 3:
        # Single image: collapse everything into one scalar.
        return height_steps.sum() + width_steps.sum()

    # Batch: keep axis 0 so that every image gets its own value.
    per_image_axes = (1, 2, 3)
    return height_steps.sum(dim=per_image_axes) + width_steps.sum(dim=per_image_axes)

class TVLoss(nn.Module):
    """Summed squared error plus a weighted total-variation penalty.

    ``loss = sum((decoded - GT) ** 2) + beta * TV(decoded)``

    Args:
        beta: Weight applied to the total-variation term.

    Note:
        The penalty is computed by ``total_variation``, which reads a 3-D
        tensor as ``[height, width, channels]`` and a 4-D tensor as
        ``[batch, height, width, channels]``.
        NOTE(review): if callers pass channel-less batches shaped
        ``[batch, height, width]``, the first axis is differenced as if it
        were image height -- confirm that this is intended.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, decoded, GT):
        """Return the scalar combined loss for ``decoded`` against ``GT``.

        Also prints the weighted TV term and the MSE term on every call.
        """
        # Sum (not mean) of squared errors.
        loss_mse = F.mse_loss(decoded, GT, reduction='sum')

        # For 4-D input total_variation yields one value per image. Collapse
        # it to a scalar so the formatted print below does not raise on a
        # multi-element tensor and the returned loss is always 0-dim. For the
        # 3-D path the value is already 0-dim, so .sum() leaves it unchanged.
        tvloss = total_variation(decoded).sum()
        print(f"tvloss:{tvloss*self.beta:.4f},mseloss:{loss_mse:.4f}")

        return loss_mse + tvloss * self.beta

class SmoothLoss(nn.Module):
    """Summed squared error plus a neighbour-difference smoothness penalty.

    ``loss = sum((decoded - GT) ** 2) + beta * (D1 + alpha * D2)``

    where ``D1`` is the sum of absolute differences between neighbours along
    axis 1 of ``decoded`` and ``D2`` the same along axis 2.

    Args:
        alpha: Weight of the axis-2 difference term only (it does not scale
            the axis-1 term).
        beta: Weight of the whole smoothness penalty.
    """

    def __init__(self, alpha=1.0, beta=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, decoded, GT):
        """Return the combined loss as a 0-dim tensor."""
        # Data term: sum (not mean) of squared errors.
        data_term = F.mse_loss(decoded, GT, reduction='sum')

        # Absolute steps between adjacent entries along axis 1 and axis 2.
        axis1_steps = (decoded[:, 1:, :] - decoded[:, :-1, :]).abs()
        axis2_steps = (decoded[:, :, 1:] - decoded[:, :, :-1]).abs()

        # alpha multiplies the axis-2 sum only.
        penalty = axis1_steps.sum() + axis2_steps.sum() * self.alpha

        return data_term + penalty * self.beta

def median_filter2d(img, kernel_size=5):
    """Apply a square median filter to every channel of a 4-D image batch.

    Args:
        img: Tensor unpacked as ``[batch, channels, height, width]``.
        kernel_size: Side length of the square window. The final reshape
            assumes the result has the input's spatial size, which holds for
            odd values.

    Returns:
        A tensor shaped like ``img`` in which every value is the median of
        the ``kernel_size x kernel_size`` window centred on it. The borders
        are handled by reflect padding.
    """
    half = kernel_size // 2

    # Reflect-pad both spatial axes so border pixels get a full window.
    padded = F.pad(img, (half,) * 4, mode='reflect')
    n_batch, n_chan, padded_h, padded_w = padded.shape

    # unfold lays every window out as a column; regroup the columns as
    # [batch, channel, window elements, window positions].
    windows = F.unfold(padded, kernel_size=kernel_size)
    windows = windows.view(n_batch, n_chan, kernel_size * kernel_size, -1)

    # Median over the elements of each window (values only, indices dropped).
    medians, _ = windows.median(dim=2)

    # Fold the flat list of window positions back into the unpadded grid.
    return medians.view(n_batch, n_chan, padded_h - 2 * half, padded_w - 2 * half)

def gaussian_kernel(size, sigma):
    """Build a normalised 2-D Gaussian kernel.

    Args:
        size: Side length of the square kernel; the peak sits at index
            ``size // 2`` on both axes.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A float32 tensor of shape ``[1, 1, size, size]`` whose entries are
        divided by their sum.
    """
    centre = size // 2
    amplitude = 1 / (2.0 * np.pi * sigma**2)
    two_variance = 2 * sigma**2

    rows = []
    for y in range(size):
        dy_sq = (y - centre)**2
        rows.append([amplitude * np.exp(-((x - centre)**2 + dy_sq) / two_variance)
                     for x in range(size)])

    weights = torch.tensor(rows).float()
    weights /= weights.sum()

    # Prepend the out-channel and in-channel axes expected by conv2d weights.
    return weights.unsqueeze(0).unsqueeze(0)

def gaussian_filter2d(img, kernel_size=5, sigma=1.0):
    """Blur every channel of an image batch with the same Gaussian kernel.

    Args:
        img: Tensor whose axis 1 is the channel axis; it is passed to
            ``F.conv2d``, i.e. laid out as ``[batch, channels, height, width]``.
        kernel_size: Side length of the square Gaussian kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        The filtered tensor. ``padding = kernel_size // 2`` keeps the spatial
        size unchanged for odd ``kernel_size``. ``conv2d`` pads with zeros, so
        values near the border are attenuated.
    """
    kernel = gaussian_kernel(kernel_size, sigma)

    # gaussian_kernel always returns a CPU float32 tensor. conv2d needs the
    # weight on the image's device and in its dtype, so match them here.
    # Integer images keep the float kernel; casting it to an integer dtype
    # would truncate every weight to zero.
    target_dtype = img.dtype if img.is_floating_point() else kernel.dtype
    kernel = kernel.to(device=img.device, dtype=target_dtype)

    # One copy of the kernel per channel; groups=channels filters each
    # channel independently (depthwise convolution).
    channels = img.shape[1]
    kernel = kernel.repeat(channels, 1, 1, 1)
    padding = kernel_size // 2
    filtered_img = F.conv2d(img, kernel, padding=padding, groups=channels)
    return filtered_img

# Load the images.
# map_location lets the load succeed even when the files were written from a
# CUDA device and this machine has no GPU (a trailing .cpu() runs too late).
decoded = torch.load('decoded2.pt', map_location='cpu')
gt = torch.load('gt2.pt', map_location='cpu')

sigma = 4.
# Treat the data as grayscale and reshape to [batch_size, channels, height, width].
# decoded is [2, 361, 720]; we assume batch_size = 2 and channels = 1.
img = torch.Tensor(decoded).unsqueeze(1)  # add the channel axis
filtered_img_median = median_filter2d(img, kernel_size=5)  # median filter
filtered_img_gaussian = gaussian_filter2d(img, kernel_size=5, sigma=sigma)  # Gaussian filter
bothMG_img = gaussian_filter2d(filtered_img_median, kernel_size=5, sigma=sigma)  # median then Gaussian -- this one works best
# Despite the "GM" in its name, this is the median filter applied a second
# time to the median-filtered image (it matches the "两次中值" label printed below).
bothGM_img = median_filter2d(filtered_img_median, kernel_size=5)

# Drop the channel axis again so every tensor is back to [2, 361, 720].
img = img.squeeze(1)
filtered_img_median = filtered_img_median.squeeze(1)
filtered_img_gaussian = filtered_img_gaussian.squeeze(1)
bothMG_img = bothMG_img.squeeze(1)
bothGM_img = bothGM_img.squeeze(1)
gt = torch.Tensor(gt).squeeze(1)  # make sure gt has the matching shape

# Plain summed squared error of each variant against the ground truth.
mseloss = torch.nn.MSELoss(reduction='sum')
loss0 = mseloss(img, gt)
loss1 = mseloss(filtered_img_median, gt)
loss2 = mseloss(filtered_img_gaussian, gt)
loss3 = mseloss(bothMG_img,gt)
loss4 = mseloss(bothGM_img,gt)
print(f'Mseloss原来loss={loss0:.4f}，中值滤波后loss={loss1:.4f}，高斯滤波后loss={loss2:.4f}，先中值后高斯loss={loss3:.4f}，两次中值后loss={loss4:.4f}')

# Summed squared error plus the neighbour-difference smoothness penalty.
smloss = SmoothLoss()
smloss0 = smloss(img, gt)
smloss1 = smloss(filtered_img_median, gt)
smloss2 = smloss(filtered_img_gaussian, gt)
smloss3 = smloss(bothMG_img,gt)
smloss4 = smloss(bothGM_img,gt)
print(f'Smoothloss原来loss={smloss0:.4f}，中值滤波后loss={smloss1:.4f}，高斯滤波后loss={smloss2:.4f}，先中值后高斯loss={smloss3:.4f}，两次中值后loss={smloss4:.4f}')

print(img.shape)
# NOTE(review): the tensors passed here are 3-D ([2, 361, 720]).
# total_variation reads a 3-D tensor as [height, width, channels], so the TV
# term differences along axes 0 and 1, whereas SmoothLoss above differences
# along axes 1 and 2. Confirm which layout is intended; passing
# x.unsqueeze(-1) would make TV treat axis 0 as the batch axis.
tvloss = TVLoss()
vt0 = tvloss(img, gt)
vt1 = tvloss(filtered_img_median, gt)
vt2 = tvloss(filtered_img_gaussian, gt)
vt3 = tvloss(bothMG_img,gt)
vt4 = tvloss(bothGM_img,gt)
print(f'TVloss原来loss={vt0:.4f}，中值滤波后loss={vt1:.4f}，高斯滤波后loss={vt2:.4f}，先中值后高斯loss={vt3:.4f}，两次中值后loss={vt4:.4f}')


Mseloss原来loss=142607.2031，中值滤波后loss=141088.7031，高斯滤波后loss=140381.4688，先中值后高斯loss=139440.8281，两次中值后loss=140711.9688
Smoothloss原来loss=177706.5469，中值滤波后loss=154292.3281，高斯滤波后loss=151825.8125，先中值后高斯loss=149060.7344，两次中值后loss=151170.3750
torch.Size([2, 361, 720])
tvloss:185126.6562,mseloss:142607.2031
tvloss:172791.6406,mseloss:141088.7031
tvloss:171196.7188,mseloss:140381.4688
tvloss:169329.6719,mseloss:139440.8281
tvloss:170797.7812,mseloss:140711.9688
TVloss原来loss=327733.8750，中值滤波后loss=313880.3438，高斯滤波后loss=311578.1875，先中值后高斯loss=308770.5000，两次中值后loss=311509.7500
