In [1]:
import os
import sys
import torch
import argparse
import torchvision
import numpy as np
import torch.nn as nn
import scipy.io as sio
import cv2
import pywt
import pywt.data
import matplotlib.image as image

from PIL import Image
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from os.path import join, isdir, abspath, dirname
import matplotlib.pyplot as plt
# Customized import.
from networks import HED
from datasets import BsdsDataset
from utils import Logger, AverageMeter, \
    load_checkpoint, save_checkpoint, load_vgg16_caffe, load_pretrained_caffe


# Compute device for this script: everything runs on the CPU.
device = torch.device(type='cpu')


def wav_trans_grayscale(images):
    """Return the one-level 'bior4.4' 2-D DWT of a batch tiled into one grayscale image.

    The batch is tiled into a single CHW image with
    ``torchvision.utils.make_grid``, moved to HWC layout, converted to
    grayscale with OpenCV's BGR->gray weights and decomposed with
    ``pywt.dwt2``.

    NOTE(review): the BGR->gray conversion assumes the tensor's channels are
    stored in BGR order; confirm against how BsdsDataset loads images.

    Parameters:
        images: image tensor accepted by ``make_grid`` (presumably NCHW on CPU).

    Returns:
        list ``[LL, LH, HL, HH]`` of numpy arrays: the approximation subband
        followed by the three detail subbands, in ``pywt.dwt2`` order.
    """
    grid_hwc = torchvision.utils.make_grid(images).numpy().transpose(1, 2, 0)
    gray = cv2.cvtColor(grid_hwc, cv2.COLOR_BGR2GRAY)
    approx, details = pywt.dwt2(gray, 'bior4.4')
    return [approx, *details]



def wav_trans(images):
    """Return the 'bior4.4' LL (approximation) subband of ``images[0, 0]`` as a tensor.

    Only the first channel of the first image in the batch is transformed;
    every other image and channel is ignored. Coefficients greater than 255.0
    are clipped down to 255.0. Smaller (including negative) values are left
    untouched.

    Parameters:
        images: 4-D tensor indexed as ``[batch, channel, row, col]`` (the code
            takes ``[0, 0, :, :]``); must be convertible with ``.numpy()``.

    Returns:
        torch tensor of shape ``(1, 1, h, w)``, where ``(h, w)`` is the shape of
        the LL subband returned by ``pywt.dwt2``.
    """
    plane = images.numpy()[0, 0, :, :]
    approx, _details = pywt.dwt2(plane, 'bior4.4')
    # Upper-only clip; NaNs propagate unchanged, as with a masked assignment.
    clipped = np.minimum(approx, 255.0)
    # Add the batch and channel axes in front: (h, w) -> (1, 1, h, w).
    return torch.from_numpy(clipped)[None, None]


def WT_grayscale_path_input(path):
    """Load an image from disk and return its one-level 'bior4.4' 2-D DWT.

    Colour images (3-D arrays) are converted to grayscale first; 2-D images
    are used as-is.

    Parameters:
        path: filesystem path of the image, read with ``plt.imread``.

    Returns:
        list ``[LL, LH, HL, HH]``: the approximation subband followed by the
        three detail subbands, in the order ``pywt.dwt2`` returns them.
    """
    orig_image = plt.imread(path)
    if orig_image.ndim == 3:
        # plt.imread returns channels in RGB(A) order, so use the RGB->gray
        # weights. COLOR_BGR2GRAY would apply the red weight to the blue
        # channel and vice versa.
        gray_image = cv2.cvtColor(orig_image, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = orig_image
    LL, (LH, HL, HH) = pywt.dwt2(gray_image, 'bior4.4')
    return [LL, LH, HL, HH]

def im_save_path_modifier(input_path, mode, prefix_len=16):
    """Build the output directory and file path for one wavelet subband image.

    The first ``prefix_len`` characters of ``input_path`` are treated as the
    dataset root. ``f'{mode}/'`` is inserted right after that root, so
    ``./data/HED-BSDS/test/x.jpg`` with mode ``LL_bior`` maps to
    ``./data/HED-BSDS/LL_bior/test/x.jpg``. The default of 16 is presumably
    ``len('./data/HED-BSDS/')``, matching the ``dataset_dir`` passed to
    BsdsDataset in this script.

    Parameters:
        input_path: path of the original image; expected to start with the root.
        mode: subband name inserted as a new directory level (e.g. LL, LH, HL, HH).
        prefix_len: number of leading characters of ``input_path`` forming the root.

    Returns:
        ``(dir_path, image_path)``: the directory to create and the full output path.
    """
    root, rest = input_path[:prefix_len], input_path[prefix_len:]
    # Directory portion below the root. It is '' when the file sits directly in
    # the root. rpartition avoids the rfind('/') == -1 case, where slicing up
    # to -1 silently chopped the last character off the directory name.
    sub_dir = rest.rpartition('/')[0]
    dir_path = f'{root}{mode}/{sub_dir}'
    image_path = f'{root}{mode}/{rest}'
    return dir_path, image_path

def imshow(img):
    """Display a CHW image tensor with matplotlib.

    The tensor is converted to numpy and its axes are moved to HWC order, which
    is what ``plt.imshow`` expects. No unnormalization is applied.
    """
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(hwc)
    plt.show()
    
    
    

################################################
# II. Datasets.
################################################
# BSDS train/test datasets and their loaders (batch size 1).
train_dataset = BsdsDataset(dataset_dir='./data/HED-BSDS', split='train')
test_dataset  = BsdsDataset(dataset_dir='./data/HED-BSDS', split='test')
train_loader  = DataLoader(train_dataset, batch_size=1, num_workers=4,
                           drop_last=True, shuffle=True)
test_loader   = DataLoader(test_dataset, batch_size=1, num_workers=4,
                           drop_last=False, shuffle=False)

# NOTE(review): `dataiter` is not consumed in this cell; presumably kept for
# use by later cells.
dataiter = iter(test_loader)

# Output-directory tag for each subband, in the order that
# WT_grayscale_path_input returns them ([LL, LH, HL, HH]).
wav_trans_imlist_names = ['LL_bior', 'LH_bior', 'HL_bior', 'HH_bior']

# For every test image: decompose its grayscale version with a one-level
# 'bior4.4' DWT and save each subband under <root>/<subband>/...
for j, (images, path) in tqdm(enumerate(test_loader)):
    src_path = path[0]
    wav_trans_imlist = WT_grayscale_path_input(src_path)
    for band_name, band in zip(wav_trans_imlist_names, wav_trans_imlist):
        dir_path, image_path = im_save_path_modifier(src_path, band_name)
        os.makedirs(dir_path, exist_ok=True)
        # No vmin/vmax is given, so each subband is scaled to its own value range.
        plt.imsave(image_path, band, cmap='gray')
    # Coarse progress marker printed alongside the tqdm bar.
    if j % 100 == 0:
        print(j)

4it [00:00,  9.27it/s]

0


104it [00:04, 24.54it/s]

100


200it [00:09, 21.92it/s]
