# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import cv2
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import pytorch_ssim

class ConvLutModel(nn.Module):
    """Three-layer MLP: ``input_size -> 2*hidden_size -> hidden_size -> output_size``.

    ReLU follows the first two linear layers; dropout is applied to the second
    hidden layer's activations, and the final linear layer has no activation.

    Args:
        input_size: Number of input features (last dimension of the input).
        hidden_size: Width of the second hidden layer; the first is twice as wide.
        output_size: Number of output features.
        dropout_p: Dropout probability applied before ``fc3`` (default 0.2).
    """

    def __init__(self, input_size=2, hidden_size=64, output_size=16, dropout_p=0.2):
        super(ConvLutModel, self).__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation order and state_dict keys are unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size * 2)
        self.fc2 = nn.Linear(hidden_size * 2, hidden_size)
        # Dropout probability is dropout_p (0.2 by default).
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` (last dim ``input_size``) to a tensor with last dim ``output_size``."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Apply dropout: random zeroing in train mode, identity in eval mode.
        x = self.dropout(x)
        return self.fc3(x)

def perceptual_loss(sr, hr, vgg):
    """Return the MSE between ``vgg`` feature maps of ``sr`` and ``hr``.

    ``vgg`` is any callable mapping an image tensor to a feature tensor; it is
    applied to ``sr`` first, then to ``hr``, and the two outputs are compared
    with mean-squared error.
    """
    feature_pair = [vgg(image) for image in (sr, hr)]
    return nn.functional.mse_loss(*feature_pair)

def ssim_loss(sr, hr):
    """Return ``1 - ssim(sr, hr)`` so that higher similarity gives lower loss.

    Uses ``ssim`` with its default window size and averaging; the loss is 0
    exactly when ``ssim`` returns 1.
    """
    similarity = ssim(sr, hr)
    return 1 - similarity

def ssim(img1, img2, window_size=11, size_average=True):
    """Compute SSIM between two image batches.

    Args:
        img1, img2: Image tensors; ``img1.size()`` is unpacked as four
            dimensions with channels at index 1 (assumed N, C, H, W).
        window_size: Side length of the square window built by
            ``pytorch_ssim.create_window``.
        size_average: If True, return a scalar mean; otherwise one value per
            batch element (see ``_ssim``).

    Returns:
        The SSIM value(s) computed by ``_ssim``.
    """
    (_, channel, _, _) = img1.size()
    # Match both device AND dtype of the inputs: conv2d requires the window
    # and the image to share a dtype, and create_window's output dtype need
    # not match (e.g. float64 or half-precision images would otherwise fail).
    window = pytorch_ssim.create_window(window_size, channel).to(
        device=img1.device, dtype=img1.dtype
    )
    # Integer padding; for odd window sizes this keeps the filtered maps at
    # the input's spatial size.
    padding = window_size // 2
    return _ssim(img1, img2, window, window_size, channel, size_average, padding)

def _ssim(img1, img2, window, window_size, channel, size_average, padding):
    mu1 = nn.functional.conv2d(img1, window, padding=padding, groups=channel)
    mu2 = nn.functional.conv2d(img2, window, padding=padding, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = nn.functional.conv2d(img1 * img1, window, padding=padding, groups=channel) - mu1_sq
    sigma2_sq = nn.functional.conv2d(img2 * img2, window, padding=padding, groups=channel) - mu2_sq
    sigma12 = nn.functional.conv2d(img1 * img2, window, padding=padding, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class VGGFeatureExtractor(nn.Module):
    """Frozen VGG-19 feature extractor, e.g. for ``perceptual_loss``.

    Keeps the first 35 modules of torchvision's VGG-19 ``features`` stack. In
    torchvision's standard layout this ends at the last 3x3 convolution, before
    its ReLU; verify this if the torchvision version changes. All parameters
    are frozen (``requires_grad = False``) and the stack is put in eval mode.

    Args:
        pretrained: If True (default), load torchvision's ImageNet weights.
            Torchvision downloads and caches them on first use, so network
            access may be needed. If False, the network is randomly
            initialised, matching the old behaviour.

    NOTE(review): ImageNet-pretrained VGG normally expects inputs normalised
    with the ImageNet mean/std. ``forward`` does not normalise, so confirm
    that callers do.
    """

    def __init__(self, pretrained=True):
        super(VGGFeatureExtractor, self).__init__()
        if pretrained:
            # Feature-space (perceptual) losses rely on trained filters;
            # randomly initialised VGG features carry little perceptual meaning.
            vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        else:
            vgg19 = models.vgg19()
        self.features = nn.Sequential(*list(vgg19.features)[:35]).eval()
        # The extractor is a fixed loss network: never update its weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated VGG-19 feature map of ``x``."""
        return self.features(x)