In [1]:
%run GlobalConfig.ipynb

import numpy as np
from PIL import Image
import torch.nn.functional as F

def mirror_pad(image, pad_size):
    """Reflect-pad the last two dimensions of ``image`` by ``pad_size`` on every side.

    ``F.pad`` reads the 4-tuple as (left, right, top, bottom), so the last
    dimension and the one before it both grow by ``2 * pad_size``.
    """
    padding = (pad_size,) * 4
    padded = F.pad(image, padding, mode='reflect')
    return padded
    
def apply_lut(lr_image, model, scale_factor=global_scale_factor, device = None):
    """Upscale one LR image by blending 4x4 LR neighbourhoods with model-predicted weights.

    For every HR pixel, the fractional offset (x, y) of its centre relative to
    the LR pixel grid is fed to ``model``. The model's output is reshaped to 4x4
    weights per HR pixel. The HR value is the weighted sum of the 4x4 LR pixels
    around that position. The LR image is reflect-padded so that border pixels
    also have a full neighbourhood.

    Args:
        lr_image: LR image indexed as (H, W, C): dim 0 rows, dim 1 columns,
            dim 2 channels.
        model: callable taking an (N, 2) tensor of (x, y) offsets in [0, 1)
            and returning something with N * 16 elements.
        scale_factor: upscaling factor; the HR size is
            int(H * scale_factor) x int(W * scale_factor).
        device: device to compute on. If None, the device of ``lr_image`` is
            used, so the input, the coordinate grids and the output agree.

    Returns:
        (int(H * s), int(W * s), C) tensor clamped to [0, 1].

    Note:
        Reflect padding by 2 pixels requires H and W to be at least 3.
    """
    if device is None:
        # Without this, the output buffer and coordinate grids were allocated on
        # the default (CPU) device even when lr_image lived elsewhere.
        device = lr_image.device
    lr_image = lr_image.to(device)
    lr_height, lr_width, channels = lr_image.shape
    hr_height, hr_width = int(lr_height * scale_factor), int(lr_width * scale_factor)

    # Pad 2 pixels per side so every 4x4 tap window stays in bounds.
    pad_size = 2
    # F.pad pads the last two dims, so move channels first: (H, W, C) -> (C, H, W),
    # pad, then move back to (H + 4, W + 4, C).
    padded_lr_image = mirror_pad(lr_image.permute(2, 0, 1), pad_size).permute(1, 2, 0)

    hr_image = torch.zeros((hr_height, hr_width, channels), device=device)

    ii, jj = torch.meshgrid(
        torch.arange(hr_height, device=device),
        torch.arange(hr_width, device=device),
        indexing='ij',
    )

    # Position of each HR pixel centre expressed in LR pixel coordinates.
    src_i = (ii + 0.5) / scale_factor
    src_j = (jj + 0.5) / scale_factor

    # Fractional offset from the nearest LR pixel centre to the upper-left.
    # (p + 0.5) % 1 is equivalent to (p - 0.5) mod 1.
    x = (src_i + 0.5) % 1
    y = (src_j + 0.5) % 1

    weights = model(torch.stack([x, y], dim=-1).view(-1, 2)).view(hr_height, hr_width, 4, 4)

    # Top-left tap of the 4x4 window: one LR pixel before floor(p - 0.5),
    # shifted by pad_size into padded coordinates.
    x_lr = torch.floor(src_i - 1.5 + pad_size).long()
    y_lr = torch.floor(src_j - 1.5 + pad_size).long()

    # Gather all channels at once: each tap is (H_hr, W_hr, C), and the
    # per-pixel weight broadcasts over the channel axis.
    for dx in range(4):
        for dy in range(4):
            taps = padded_lr_image[x_lr + dx, y_lr + dy]
            hr_image += weights[:, :, dx, dy].unsqueeze(-1) * taps

    return hr_image.clamp(0, 1)

def apply_lut_batchs(lr_images, model, scale_factor=global_scale_factor, device = None):
    """Apply ``apply_lut`` to every image in a batch and stack the results.

    Args:
        lr_images: batch of LR images indexed as (B, H, W, C); each
            ``lr_images[i]`` is passed to ``apply_lut``.
        model, scale_factor, device: forwarded unchanged to ``apply_lut``.

    Returns:
        (B, H_hr, W_hr, C) tensor of HR images, or None for an empty batch.
        ``torch.stack`` keeps the device, dtype and autograd history of the
        per-image results. A zero-filled buffer created with ``.to(device)``
        would be float32 and, with device=None, always on CPU.
    """
    hr_list = [apply_lut(lr_image, model, scale_factor, device=device) for lr_image in lr_images]
    if not hr_list:
        return None
    return torch.stack(hr_list)
    