In [10]:
import torch
import numpy as np

In [11]:
def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])
    coord = coord.view(img.shape[-2],img.shape[-1],-1)
    rgb = img.view(3, -1).permute(1, 0)
    return coord, rgb

In [12]:
img = torch.randn([3,16,16])
coord,rgb = to_pixel_samples(img)

print(coord.shape)

torch.Size([16, 16, 2])


In [17]:
def snake_scan(tensor):
    B, C, H, W = tensor.shape
    tensor = tensor.view(B, C, H*W)

    for i in range(H):
        if i % 2 == 1:  # If it's an odd row, reverse the order
            tensor[:, :, i*W:(i+1)*W] = tensor[:, :, i*W:(i+1)*W].flip(-1)

    tensor = tensor.view(B, C, -1)
    return tensor

In [18]:
def hilbert_distance(x, y, N):
    rx, ry = 0, 0
    s = (1 << (N - 1))
    d = 0
    while s > 0:
        ry = ((y & s) > 0)
        rx = ((x & s) > 0)
        d += s * s * ((3 * rx) ^ ry)
        x, y = rot(s, x, y, rx, ry)
        s //= 2
    return d

def rot(n, x, y, rx, ry):
    if ry == 0:
        if rx == 1:
            x = n - 1 - x
            y = n - 1 - y
        x, y = y, x
    return x, y

def hilbert_sort(B, C, H, W):
    N = int(np.ceil(np.log2(max(H, W))))
    max_d = 1 << (2 * N)

    hilbert_map = torch.zeros((H, W), dtype=torch.long)
    for y in range(H):
        for x in range(W):
            hilbert_map[y, x] = hilbert_distance(x, y, N)

    hilbert_map = hilbert_map.view(-1).sort()[1].view(H, W)

    return hilbert_map

def hilbert_scan(tensor):
    B, C, H, W = tensor.shape
    hilbert_map = hilbert_sort(B, C, H, W)

    tensor = tensor.view(B, C, -1)
    tensor = tensor[:, :, hilbert_map.view(-1)]
    tensor = tensor.view(B, C, -1)

    return tensor

In [19]:
# 创建一个4x4的RGB图像作为示例
# image = torch.tensor([
#     [[1, 2, 3, 4],
#      [8, 7, 6, 5],
#      [9, 10, 11, 12],
#      [16, 15, 14, 13]],
#      [[1, 2, 3, 4],
#      [8, 7, 6, 5],
#      [9, 10, 11, 12],
#      [16, 15, 14, 13]],
#      [[1, 2, 3, 4],
#      [8, 7, 6, 5],
#      [9, 10, 11, 12],
#      [16, 15, 14, 13]]
# ], dtype=torch.float32)
image = torch.tensor([
    [[1, 2, 3, 4],
     [9, 10, 11, 12],
     [16, 15, 14, 13]],
     [[1, 2, 3, 4],
     [9, 10, 11, 12],
     [16, 15, 14, 13]],
     [[1, 2, 3, 4],
     [9, 10, 11, 12],
     [16, 15, 14, 13]]
], dtype=torch.float32)
image = image.unsqueeze(0)  # 添加batch维度
b,c,h,w = image.shape
print(image)


tensor([[[[ 1.,  2.,  3.,  4.],
          [ 9., 10., 11., 12.],
          [16., 15., 14., 13.]],

         [[ 1.,  2.,  3.,  4.],
          [ 9., 10., 11., 12.],
          [16., 15., 14., 13.]],

         [[ 1.,  2.,  3.,  4.],
          [ 9., 10., 11., 12.],
          [16., 15., 14., 13.]]]])


In [20]:
zigzag_image = snake_scan(image)
print(zigzag_image)


tensor([[[ 1.,  2.,  3.,  4., 12., 11., 10.,  9., 16., 15., 14., 13.],
         [ 1.,  2.,  3.,  4., 12., 11., 10.,  9., 16., 15., 14., 13.],
         [ 1.,  2.,  3.,  4., 12., 11., 10.,  9., 16., 15., 14., 13.]]])


In [21]:
hilbert_image = hilbert_scan(image)
print(hilbert_image)


tensor([[[ 1.,  2., 11., 12., 16., 15., 14., 13.,  9., 10.,  3.,  4.],
         [ 1.,  2., 11., 12., 16., 15., 14., 13.,  9., 10.,  3.,  4.],
         [ 1.,  2., 11., 12., 16., 15., 14., 13.,  9., 10.,  3.,  4.]]])


In [2]:
import os
from PIL import Image

# Specify the folder path
folder_path = "E:/Research/test_results/groundtruth"
save_folder_path = "E:/Research/test_results/LR-x2"

# Get the list of image files in the folder
image_files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]

# Iterate through each image file
for file_name in image_files:
    # Open the image
    image_path = os.path.join(folder_path, file_name)
    image = Image.open(image_path)

    # Downsample the image to 1/2
    width, height = image.size
    new_width = width // 2
    new_height = height // 2
    downscaled_image = image.resize((new_width, new_height))

    # Save the downscaled image
    save_path = os.path.join(save_folder_path, file_name)
    downscaled_image.save(save_path)

    # Close the image
    image.close()


: 