In [1]:
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchjpeg import dct
from PIL import Image
import seaborn as sns
import numpy as np
import torch
import math
import cv2
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
block_size = 4
total_frequency_component = block_size * block_size

overall_img_path_list = []
path_prefix = "/home/jianming/work/multiface/dataset/m--20180227--0000--6795937--GHS/unwrapped_uv_1024/"
all_dir = os.listdir(path_prefix)
# print(all_dir)
for sgl_dir in all_dir:
    path_average = os.path.join(path_prefix + sgl_dir, "average")
        # print(os.path.join(path_average, image))
    overall_img_path_list.append(os.path.join(path_average, os.listdir(path_average)[0]))

def img_reorder(x, bs, ch, h, w):
    x = (x + 1) / 2 * 255
    assert(x.shape[1] == 3, "Wrong input, Channel should equals to 3")
    x = dct.to_ycbcr(x)  # comvert RGB to YCBCR
    x -= 128
    x = x.view(bs * ch, 1, h, w)
    x = F.unfold(x, kernel_size=(block_size, block_size), dilation=1, padding=0, stride=(block_size, block_size))
    x = x.transpose(1, 2)
    x = x.view(bs, ch, -1, block_size, block_size)
    return x

## Image reordering and testing
def img_inverse_reroder(coverted_img, bs, ch, h, w):
    x = coverted_img.view(bs* ch, -1, total_frequency_component)
    x = x.transpose(1, 2)
    x = F.fold(x, output_size=(h, w), kernel_size=(block_size, block_size), stride=(block_size, block_size))
    x += 128
    x = x.view(bs, ch, h, w)
    x = dct.to_rgb(x)#.squeeze(0)
    x = (x / 255.0) * 2 - 1
    return x

def calculate_block_mse(downsample_in, freq_block, num_freq_component=block_size):
    downsample_img = transforms.Resize(size=int(downsample_in.shape[-1]/num_freq_component))(downsample_in)
    assert(downsample_img.shape == freq_block[:,:,0,:,:].shape, "downsample input shape does not match the shape of post-BDCT component")
    loss_vector = torch.zeros(freq_block.shape[2])
    for i in range(freq_block.shape[2]):
        # calculate the MSE between each frequency components and given input downsampled images
        loss_vector[i] = F.mse_loss(downsample_img, freq_block[:,:,i,:,:])
    return loss_vector

def bdct_4x4(img_path):
    # The original input image comes with it and I disable it to reduce the computation overhead.
    # x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=True)
    image = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    x = transform(image).unsqueeze(0)

    back_input = x
    bs, ch, h, w = x.shape
    block_num = h // block_size
    x = img_reorder(x, bs, ch, h, w)
    dct_block = dct.block_dct(x) # BDCT
    dct_block_reorder = dct_block.view(bs, ch, block_num, block_num, total_frequency_component).permute(0, 1, 4, 2, 3) # into (bs, ch, 64, block_num, block_num)

    return  dct_block_reorder

def private_freq_component_thres_based_selection(img_path, mse_threshold):
    # The original input image comes with it and I disable it to reduce the computation overhead.
    # x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=True)
    image = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    x = transform(image).unsqueeze(0)

    back_input = x
    bs, ch, h, w = x.shape
    block_num = h // block_size
    x = img_reorder(x, bs, ch, h, w)
    dct_block = dct.block_dct(x) # BDCT
    dct_block_reorder = dct_block.view(bs, ch, block_num, block_num, total_frequency_component).permute(0, 1, 4, 2, 3) # into (bs, ch, 64, block_num, block_num)
    loss_vector = calculate_block_mse(back_input, dct_block_reorder)
    # Split all component based on the frequency
    private_idx = torch.where(loss_vector > mse_threshold)[0]
    public_idx = []
    all_possible_idx = [i for i in range(total_frequency_component)]
    for element in all_possible_idx:
        if element not in private_idx:
            public_idx.append(element)

    return private_idx,  torch.Tensor(public_idx).to(torch.int64), dct_block_reorder

  assert(x.shape[1] == 3, "Wrong input, Channel should equals to 3")
  assert(downsample_img.shape == freq_block[:,:,0,:,:].shape, "downsample input shape does not match the shape of post-BDCT component")


# Intuitive Illustration

In [35]:
# User 1
overall_img_path_list = []
path_prefix = "/home/jianming/work/multiface/dataset/m--20180227--0000--6795937--GHS/unwrapped_uv_1024/"
all_dir = os.listdir(path_prefix)
for sgl_dir in all_dir:
    path_average = os.path.join(path_prefix + sgl_dir, "average")
    overall_img_path_list.append(os.path.join(path_average, os.listdir(path_average)[0]))

highest_frequency_components_list = []
for img_path in overall_img_path_list:
    image = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    x = transform(image).unsqueeze(0)
    highest_frequency_components_list.append(x)

# User 2
overall_img_path_list2 = []
path_prefix2 = "/scratch1/jianming/multiface/dataset/m--20180226--0000--6674443--GHS/unwrapped_uv_1024/"
all_dir = os.listdir(path_prefix2)
for sgl_dir in all_dir:
    path_average2 = os.path.join(path_prefix2 + sgl_dir, "average")
    overall_img_path_list2.append(os.path.join(path_average2, os.listdir(path_average2)[0]))

highest_frequency_components_list2 = []
for img_path in overall_img_path_list2:
    image = Image.open(img_path).convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    x = transform(image).unsqueeze(0)
    highest_frequency_components_list2.append(x)

# Two users
highest_frequency_components_overall = highest_frequency_components_list + highest_frequency_components_list2
num_images = len(highest_frequency_components_overall)

In [36]:
highest_frequency_components_overall_test = highest_frequency_components_overall[:10]

isotropic_noise_covariance = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

for noise_value in isotropic_noise_covariance:
    l2_norm_list = []
    for i, img  in  enumerate(highest_frequency_components_overall_test):
        img = img.squeeze(0)
        noise_val = torch.Tensor(np.random.normal(0, noise_value, img.shape))
        noisy_img = img + noise_val
        # transforms.functional.to_pil_image(noisy_img).save(f'downsample_original_img_{i}_noise_{noise_value}.png')
        l2_norm_list.append(np.linalg.norm(noisy_img - img))
    avg_l2_norm = np.mean(l2_norm_list)
    print(f"average L2 norm is {avg_l2_norm} for noise {noise_value}")

average L2 norm is 17.73265838623047 for noise 0.01
average L2 norm is 88.67213439941406 for noise 0.05
average L2 norm is 177.37962341308594 for noise 0.1
average L2 norm is 354.73406982421875 for noise 0.2
average L2 norm is 532.0896606445312 for noise 0.3
average L2 norm is 709.4993286132812 for noise 0.4
average L2 norm is 886.6399536132812 for noise 0.5
average L2 norm is 1064.2862548828125 for noise 0.6
average L2 norm is 1241.5096435546875 for noise 0.7
average L2 norm is 1418.840576171875 for noise 0.8
average L2 norm is 1596.6778564453125 for noise 0.9
average L2 norm is 1774.0244140625 for noise 1
