In [1]:
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchjpeg import dct
# from scipy.fftpack import dct, idct
# import torch_dct as dct_2d, idct_2d
from PIL import Image
import os 
import numpy as np
import torch
import torchvision.transforms as T

# Side length of the square DCT blocks; each block therefore carries
# block_size ** 2 frequency components.
block_size = 4
total_frequency_components = block_size ** 2
# Configuration flags; neither is referenced in the cells visible here.
check_reconstruct_img = True
save_block_img_to_drive = False

def load_image(image_path):
    """Load an image file as a float tensor with a leading batch dimension.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        Tensor of shape (1, 3, H, W). ``ToTensor`` scales the 8-bit RGB
        pixel values to [0, 1].
    """
    # The context manager releases the file handle deterministically. convert()
    # returns a new, fully loaded image, so it stays valid after the file closes.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return transforms.ToTensor()(rgb_image).unsqueeze(0)  # add batch dimension

def img_reorder(x, bs, ch, h, w):
    """Convert an RGB batch to level-shifted YCbCr and split it into square blocks.

    Args:
        x: image tensor of shape (bs, ch, h, w) with ch == 3. Values are mapped
            through ``(x + 1) / 2 * 255``, i.e. [-1, 1] -> [0, 255].
            NOTE(review): ``load_image`` yields values in [0, 1] (``ToTensor``),
            which this maps to [127.5, 255]. The round trip through
            ``img_inverse_reroder`` is still consistent, but confirm the
            intended input range.
        bs, ch, h, w: the dimensions of ``x``. h and w are expected to be
            multiples of ``block_size``.

    Returns:
        Tensor of shape (bs, ch, num_blocks, block_size, block_size), where
        num_blocks = (h // block_size) * (w // block_size), with blocks in
        ``F.unfold`` order.

    Raises:
        AssertionError: if ``x`` does not have exactly 3 channels.
    """
    # Fixed: the original `assert(cond, msg)` asserted a non-empty tuple. A
    # non-empty tuple is always truthy, so the channel check never fired.
    assert x.shape[1] == 3, "Wrong input, Channel should equals to 3"
    x = (x + 1) / 2 * 255
    x = dct.to_ycbcr(x)  # convert RGB to YCbCr
    x = x - 128  # level shift; img_inverse_reroder adds the 128 back
    # Treat each channel as its own single-channel image so unfold blocks every channel separately.
    x = x.reshape(bs * ch, 1, h, w)
    x = F.unfold(x, kernel_size=(block_size, block_size), dilation=1, padding=0, stride=(block_size, block_size))
    # unfold gives (bs * ch, block_size ** 2, num_blocks); move the block index before the pixels.
    x = x.transpose(1, 2)
    x = x.reshape(bs, ch, -1, block_size, block_size)
    return x

## Inverse block reordering: stitch blocks back into an RGB image
def img_inverse_reroder(coverted_img, bs, ch, h, w):
    """Undo ``img_reorder``: reassemble blocks into an image and convert back to RGB.

    Args:
        coverted_img: block tensor holding bs * ch * num_blocks blocks of
            ``total_frequency_components`` values each (as produced by
            ``img_reorder`` / the inverse block DCT).
        bs, ch, h, w: target image dimensions.

    Returns:
        Tensor of shape (bs, ch, h, w), mapped through ``x / 255 * 2 - 1``.
    """
    kernel = (block_size, block_size)
    # (bs * ch, num_blocks, pixels) -> (bs * ch, pixels, num_blocks), the layout F.fold expects.
    columns = coverted_img.view(bs * ch, -1, total_frequency_components).transpose(1, 2)
    ycbcr = F.fold(columns, output_size=(h, w), kernel_size=kernel, stride=kernel)
    ycbcr += 128  # reverse the level shift applied in img_reorder
    rgb = dct.to_rgb(ycbcr.view(bs, ch, h, w))
    return rgb / 255.0 * 2 - 1

## Block DCT round trip with low-frequency replacement (reconstruction sanity check)
def test_img_dct_transform_reorder_noise(x, bs, ch, h, w, freq_comp_lb):
    """Round-trip an image through the block DCT, flattening the lowest frequencies.

    The image is cut into ``block_size`` x ``block_size`` YCbCr blocks by
    ``img_reorder`` and transformed with ``dct.block_dct``. Every frequency
    component with index < ``freq_comp_lb`` is then overwritten with
    component ``freq_comp_lb``. Finally, the coefficients are
    inverse-transformed and reassembled by ``img_inverse_reroder``.

    Prints whether the reconstruction matches ``x`` within ``atol=1e-4``.
    A match is only expected when no component is overwritten
    (``freq_comp_lb <= 0``).

    Args:
        x: image tensor of shape (bs, ch, h, w) with ch == 3. h and w are
            expected to be multiples of ``block_size``.
        bs, ch, h, w: the dimensions of ``x``.
        freq_comp_lb: number of leading frequency components to overwrite.

    Returns:
        The reconstructed image tensor, shape (bs, ch, h, w).

    Raises:
        IndexError: if ``freq_comp_lb`` is positive but not smaller than
            ``total_frequency_components``.
    """
    original_input = x
    reordered_img = img_reorder(x, bs, ch, h, w)
    # Fixed: the block grid was hard-coded as h // 4 on both axes. That broke for
    # any block_size other than 4 and for non-square images.
    blocks_h, blocks_w = h // block_size, w // block_size
    dct_block = dct.block_dct(reordered_img)  # BDCT
    # -> (bs, ch, total_frequency_components, blocks_h, blocks_w)
    dct_block_reorder = dct_block.reshape(
        bs, ch, blocks_h, blocks_w, total_frequency_components
    ).permute(0, 1, 4, 2, 3)

    if freq_comp_lb > 0:
        if freq_comp_lb >= total_frequency_components:
            raise IndexError(
                f"freq_comp_lb={freq_comp_lb} is out of range for "
                f"{total_frequency_components} frequency components"
            )
        # Broadcast component freq_comp_lb over components 0 .. freq_comp_lb - 1.
        dct_block_reorder[:, :, :freq_comp_lb] = dct_block_reorder[:, :, freq_comp_lb:freq_comp_lb + 1]

    # Back to (bs, ch, num_blocks, block_size, block_size) for the inverse BDCT.
    idct_dct_block_reorder = (
        dct_block_reorder.reshape(bs, ch, total_frequency_components, blocks_h * blocks_w)
        .permute(0, 1, 3, 2)
        .reshape(bs, ch, blocks_h * blocks_w, block_size, block_size)
    )
    inverse_dct_block = dct.block_idct(idct_dct_block_reorder)  # inverse BDCT
    inverse_transformed_img = img_inverse_reroder(inverse_dct_block, bs, ch, h, w)
    print(torch.allclose(inverse_transformed_img, original_input, atol=1e-4))
    return inverse_transformed_img


## Block DCT round trip with low-frequency replacement and optional correlated noise
def test_img_dct_transform_reorder_noise_outsource(x, bs, ch, h, w, freq_comp_lb, path_variance_matrix_tensor, add_noise):
    """Block-DCT an image, flatten low frequencies, optionally add noise, and invert.

    This follows the same pipeline as ``test_img_dct_transform_reorder_noise``
    (without the reconstruction print). The difference is that zero-mean
    Gaussian noise can be added to the DCT coefficients before the inverse
    transform. The noise uses the covariance matrix stored at
    ``path_variance_matrix_tensor``.

    Args:
        x: image tensor of shape (bs, ch, h, w) with ch == 3. h and w are
            expected to be multiples of ``block_size``.
        bs, ch, h, w: the dimensions of ``x``.
        freq_comp_lb: number of leading frequency components overwritten with
            component ``freq_comp_lb``.
        path_variance_matrix_tensor: path to a ``torch.save``-d square
            covariance tensor whose size equals ``w // block_size``. It is only
            read when ``add_noise`` is truthy.
        add_noise: whether to add the correlated noise.

    Returns:
        The reconstructed image tensor, shape (bs, ch, h, w).

    Raises:
        IndexError: if ``freq_comp_lb`` is positive but not smaller than
            ``total_frequency_components``.
        ValueError: if the covariance size does not match ``w // block_size``.
    """
    reordered_img = img_reorder(x, bs, ch, h, w)
    # Fixed: the block grid was hard-coded as h // 4 on both axes.
    blocks_h, blocks_w = h // block_size, w // block_size
    dct_block = dct.block_dct(reordered_img)  # BDCT
    # -> (bs, ch, total_frequency_components, blocks_h, blocks_w)
    dct_block_reorder = dct_block.reshape(
        bs, ch, blocks_h, blocks_w, total_frequency_components
    ).permute(0, 1, 4, 2, 3)

    if freq_comp_lb > 0:
        if freq_comp_lb >= total_frequency_components:
            raise IndexError(
                f"freq_comp_lb={freq_comp_lb} is out of range for "
                f"{total_frequency_components} frequency components"
            )
        # Broadcast component freq_comp_lb over components 0 .. freq_comp_lb - 1.
        dct_block_reorder[:, :, :freq_comp_lb] = dct_block_reorder[:, :, freq_comp_lb:freq_comp_lb + 1]

    if add_noise:
        # map_location lets a tensor saved on GPU load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only pass trusted paths.
        covariance = torch.load(path_variance_matrix_tensor, map_location="cpu").detach().numpy()
        noise_dim = covariance.shape[0]
        # Each noise vector spans the block-column axis (the last axis of
        # dct_block_reorder). The original hard-coded a 256-dim mean, which
        # only matched w == 1024 with block_size == 4.
        if noise_dim != blocks_w:
            raise ValueError(
                f"covariance is {noise_dim}x{noise_dim} but the image has {blocks_w} block columns"
            )
        # One sample per (batch, channel, frequency, block-row), drawn from NumPy's global RNG.
        total_sample_noises = bs * ch * total_frequency_components * blocks_h
        noise = np.random.multivariate_normal(np.zeros(noise_dim), covariance, total_sample_noises)
        noise_sample = (
            torch.from_numpy(noise)
            .to(device=dct_block_reorder.device, dtype=torch.float)
            .reshape(dct_block_reorder.shape)
        )
        dct_block_reorder = dct_block_reorder + noise_sample

    # Back to (bs, ch, num_blocks, block_size, block_size) for the inverse BDCT.
    idct_dct_block_reorder = (
        dct_block_reorder.reshape(bs, ch, total_frequency_components, blocks_h * blocks_w)
        .permute(0, 1, 3, 2)
        .reshape(bs, ch, blocks_h * blocks_w, block_size, block_size)
    )
    inverse_dct_block = dct.block_idct(idct_dct_block_reorder)  # inverse BDCT
    inverse_transformed_img = img_inverse_reroder(inverse_dct_block, bs, ch, h, w)
    return inverse_transformed_img

  from .autonotebook import tqdm as notebook_tqdm
  assert(x.shape[1] == 3, "Wrong input, Channel should equals to 3")


In [2]:
# Dataset root: override with the MULTIFACE_DATASET_DIR environment variable
# instead of editing this hard-coded absolute path.
DATASET_DIR = os.environ.get("MULTIFACE_DATASET_DIR", "/home/jianming/work/multiface/dataset")
test_img_path = os.path.join(
    DATASET_DIR,
    "m--20180227--0000--6795937--GHS/unwrapped_uv_1024/E001_Neutral_Eyes_Open/average/000102.png",
)
test_image_tensor = load_image(test_img_path)
bs, ch, h, w = test_image_tensor.shape
# freq_comp_lb=0 leaves every frequency component untouched, so the round trip should print True.
test_inverse_transformed_img = test_img_dct_transform_reorder_noise(test_image_tensor, bs, ch, h, w, 0)

True


In [3]:
# Frequency-component indices from the highest (total_frequency_components - 1) down to 1.
Test_list = list(range(total_frequency_components - 1, 0, -1))
for component_index in Test_list:
    print(component_index)
print(len(Test_list))

15
14
13
12
11
10
9
8
7
6
5
4
3
2
1
15
