In [1]:
import os
import sys
import torch
import argparse
import torchvision
import numpy as np
import torch.nn as nn
import scipy.io as sio
import cv2
import pywt
import pywt.data
import matplotlib.image as image

from PIL import Image
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from os.path import join, isdir, abspath, dirname
import matplotlib.pyplot as plt
# Customized import.
from networks import HED
from datasets import BsdsDataset
from utils import Logger, AverageMeter, \
    load_checkpoint, save_checkpoint, load_vgg16_caffe, load_pretrained_caffe


# Set device.
device = torch.device('cpu')


def wav_trans_grayscale(images):
    images = torchvision.utils.make_grid(images)
    npimg = images.numpy()
    img_np_right_dim = np.transpose(npimg, (1, 2, 0))
    img_np_right_dim = cv2.cvtColor(img_np_right_dim, cv2.COLOR_BGR2GRAY)
    coeffs2 = pywt.dwt2(img_np_right_dim, 'bior4.4')
    LL, (LH, HL, HH) = coeffs2
    return([LL, LH, HL, HH])



def wav_trans(images):
    # convert to numpy
    images = images.numpy() # making images an nchw numpy array
    images = images[0,0,:,:] # making images an hw numpy array
    coeffs2 = pywt.dwt2(images, 'bior4.4')
    LL, (LH, HL, HH) = coeffs2

    # LL is an hw numpy array

    # convert LL into a torch tensor of NCHW
    
    LL[LL>255.0] = 255.0
    height = LL.shape[0]
    width = LL.shape[1]
    LL = torch.from_numpy(LL)
    LL = torch.reshape(LL,(1, 1, height, width))
    
    return(LL)


def WT_grayscale_path_input(path):  
    """Takes as input path string"""
    orig_image = plt.imread(path)
    if(len(orig_image.shape)==3):
        gray_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2GRAY)
    else:
        gray_image = orig_image
    coeffs2 = pywt.dwt2(gray_image, 'bior4.4')
    LL, (LH, HL, HH) = coeffs2
    return [LL, LH, HL, HH]

def im_save_path_modifier(input_path, mode):
    """
    parameter: input_path is the path of the original image
    parameter: mode is one of LL, LH, HL, HH
    """
    dir_path = input_path[:16] + f'{mode}/' + input_path[16:input_path.rfind('/')]
    image_path = input_path[:16] + f'{mode}/' + input_path[16:]
    return(dir_path, image_path)

def imshow(img):
#     img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    
    
    

################################################
# II. Datasets.
################################################
# Datasets and dataloaders.
train_dataset = BsdsDataset(dataset_dir='./data/HED-BSDS', split='train')
test_dataset  = BsdsDataset(dataset_dir='./data/HED-BSDS', split='test')
train_loader  = DataLoader(train_dataset, batch_size=1,
                           num_workers=4, drop_last=True, shuffle=True)
test_loader   = DataLoader(test_dataset,  batch_size=1,
                           num_workers=4, drop_last=False, shuffle=False)

dataiter = iter(train_loader)



for j, data in tqdm(enumerate(train_loader, 0)):

    images, labels, path, label_path = data
    wav_trans_imlist = WT_grayscale_path_input(path[0])
    wav_trans_imlist_names  = ['LL_bior', 'LH_bior', 'HL_bior', 'HH_bior']

    for i in range(len(wav_trans_imlist)):

        # saving image

        dir_path, image_path = im_save_path_modifier(path[0], str(wav_trans_imlist_names[i]))
        os.makedirs(dir_path, exist_ok=True)
        plt.imsave(image_path, wav_trans_imlist[i], cmap = 'gray')

        # saving label
        # retrieve label image
        im_label = plt.imread(label_path[0])
        label_dir_path_new, label_path_new = im_save_path_modifier(label_path[0], str(wav_trans_imlist_names[i]))
        os.makedirs(label_dir_path_new, exist_ok=True)
        plt.imsave(label_path_new, im_label, cmap = 'gray')

    if(j%1000==0):
        print(j)

In [9]:
for j, data in tqdm(enumerate(train_loader, 0)):

    images, labels, path, label_path = data
    wav_trans_imlist = WT_grayscale_path_input(path[0])
    wav_trans_imlist_names  = ['LL_bior', 'LH_bior', 'HL_bior', 'HH_bior']

    for i in range(len(wav_trans_imlist)):

        # saving image

        dir_path, image_path = im_save_path_modifier(path[0], str(wav_trans_imlist_names[i]))
        os.makedirs(dir_path, exist_ok=True)
        plt.imsave(image_path, wav_trans_imlist[i], cmap = 'gray')

        # saving label
        # retrieve label image
        im_label = plt.imread(label_path[0])
        label_dir_path_new, label_path_new = im_save_path_modifier(label_path[0], str(wav_trans_imlist_names[i]))
        os.makedirs(label_dir_path_new, exist_ok=True)
        plt.imsave(label_path_new, im_label, cmap = 'gray')

    if(j%1000==0):
        print(j)

1it [00:00,  8.13it/s]

0


1003it [02:21,  7.54it/s]

1000


2002it [04:48,  6.89it/s]

2000


3001it [07:10,  8.61it/s]

3000


4002it [09:33,  6.57it/s]

4000


5002it [11:55,  9.07it/s]

5000


6002it [14:17,  8.30it/s]

6000


7002it [16:39,  7.13it/s]

7000


8003it [19:02,  9.16it/s]

8000


9001it [21:24,  5.98it/s]

9000


10001it [23:44,  6.48it/s]

10000


11001it [26:06,  6.96it/s]

11000


12001it [28:27,  8.25it/s]

12000


13002it [30:50,  4.90it/s]

13000


14003it [33:13,  7.16it/s]

14000


15002it [35:31,  8.61it/s]

15000


16001it [37:54,  7.24it/s]

16000


17001it [40:14,  6.30it/s]

17000


18003it [42:34,  7.59it/s]

18000


19002it [44:55,  6.64it/s]

19000


20002it [47:15,  6.99it/s]

20000


21002it [49:43,  7.74it/s]

21000


22001it [52:08,  6.76it/s]

22000


23001it [54:30,  6.83it/s]

23000


24002it [56:47,  8.01it/s]

24000


25003it [59:10,  6.97it/s]

25000


26002it [1:01:33,  6.90it/s]

26000


27002it [1:03:52,  5.99it/s]

27000


28001it [1:06:19,  7.75it/s]

28000


28800it [1:08:11,  7.04it/s]
