# In [None]:
import time
import math

import matplotlib.pyplot as plt
import skimage
from skimage.feature import hog, greycomatrix, greycoprops
from skimage.color import rgb2gray
from skimage.morphology import erosion, dilation, opening, closing, white_tophat
import cv2
from osgeo import gdal
from osgeo import osr
import numpy as np

# Edge margin in input pixels used by hog_feature and glcm_feature: only
# blocks whose top-left row/column lies in [MAX_SCALE, size - MAX_SCALE] are
# processed. Presumably chosen so the largest scale window stays inside the
# image — confirm against the scales actually used.
MAX_SCALE = 150

def write_geotiff(out_file, in_arr, geotran, srs_wkt):
    """
    Write a 2D or 3D array to a Float64 GeoTIFF.

    Parameters:
    -----------
    out_file: str
        destination path
    in_arr: array_like
        (rows, cols) for a single band, or (channels, rows, cols)
    geotran: tuple
        six-element GDAL geotransform of the output
    srs_wkt: str
        WKT of the output spatial reference

    Raises:
    -------
    ValueError
        if in_arr is not 2D or 3D
    OSError
        if GDAL cannot create out_file
    """
    arr = np.asarray(in_arr)
    if arr.ndim == 2:
        arr = arr[np.newaxis]  # a single band is written as a one-band stack
    elif arr.ndim != 3:
        raise ValueError(
            "in_arr must be 2D (rows, cols) or 3D (channels, rows, cols), "
            "got shape {}".format(arr.shape))
    n_bands, n_rows, n_cols = arr.shape

    driver = gdal.GetDriverByName('GTiff')
    out = driver.Create(out_file, n_cols, n_rows, n_bands, gdal.GDT_Float64)
    if out is None:
        raise OSError("GDAL could not create {!r}".format(out_file))
    out.SetGeoTransform(geotran)  # origin is the upper-left of the output grid
    out.SetProjection(srs_wkt)
    for band_index, band_data in enumerate(arr, start=1):
        out.GetRasterBand(band_index).WriteArray(band_data)
    out.FlushCache()
    # GDAL closes (and finalises) a dataset when its last reference is dropped;
    # do it explicitly instead of relying on the function-scope cleanup.
    out = None


def hog_feature(image_name, block, scale, output=None, stat=None):
    """
    Compute one 8-bin HOG descriptor per ``block`` x ``block`` tile of a raster.

    Only tiles whose top-left row and column lie in
    ``[MAX_SCALE, size - MAX_SCALE]`` are processed. For each, ``hog`` is run
    over a window of about ``scale`` pixels centred on the tile, treating the
    whole window as a single cell.

    Parameters:
    ----------
    image_name: str
        path of a raster GDAL can open (multi-band or single-band)
    block: int
        tile size in input pixels; each output cell covers one tile
    scale: int
        side length of the window the descriptor is computed on
    output: str, optional
        if given, the result is written to this GeoTIFF path instead of
        being returned
    stat: optional
        only used together with ``output``: if truthy, the orientation bins
        are reduced with ``calc_stat(..., "all", 0)`` before writing (the
        value itself is not inspected)

    Returns:
    --------
    out_image: 3D ndarray
        (8, tile_rows, tile_cols) when ``output`` is not given, else None

    Raises:
    -------
    OSError
        if GDAL cannot open ``image_name``
    ValueError
        if ``output`` is given but no tile lies inside the MAX_SCALE margin
    """
    ds = gdal.Open(image_name)
    if ds is None:
        raise OSError("GDAL could not open {!r}".format(image_name))
    image = ds.ReadAsArray()
    ulx, cell_width, _, uly, _, cell_height = ds.GetGeoTransform()
    out_srs_wkt = ds.GetProjection()
    ds = None  # release the dataset; only the array and metadata are needed

    if not out_srs_wkt:
        # The output geotransform is in the input's CRS units, so the input CRS
        # is reused; WGS84 is only a fallback when the input carries none.
        out_srs = osr.SpatialReference()
        out_srs.ImportFromEPSG(4326)
        out_srs_wkt = out_srs.ExportToWkt()

    if image.ndim == 2:
        image = image[np.newaxis]  # single band: add a band axis before moving it
    image = np.moveaxis(image, 0, -1)  # hog expects rows, columns, channels
    n_rows, n_cols = image.shape[:2]

    half = scale // 2
    # Window spans [center - half, center + extent): an even `scale`-sized window
    # when block is odd and scale even, otherwise 2*half+1 pixels.
    extent = half if (block % 2 != 0 and scale % 2 == 0) else half + 1

    out_image = []
    out_ulx = out_uly = None
    for i in range(0, n_rows, block):
        if not MAX_SCALE <= i <= n_rows - MAX_SCALE:
            continue
        center_i = int(i + block / 2)
        outrow = []
        for j in range(0, n_cols, block):
            if not MAX_SCALE <= j <= n_cols - MAX_SCALE:
                continue
            center_j = int(j + block / 2)
            if out_uly is None:
                # The first output cell covers input pixels starting at (i, j),
                # so that pixel's corner is the output grid's upper-left.
                out_uly = uly + cell_height * i
                out_ulx = ulx + cell_width * j
            window = image[center_i - half:center_i + extent,
                           center_j - half:center_j + extent]
            # NOTE(review): `multichannel` was replaced by `channel_axis` in newer
            # scikit-image; kept to match the version this file targets.
            fd = hog(window, orientations=8, pixels_per_cell=window.shape[:2],
                     cells_per_block=(1, 1), multichannel=True,
                     feature_vector=False)
            outrow.append(fd.flatten())
        out_image.append(outrow)

    out_arr = np.moveaxis(np.asarray(out_image), -1, 0)  # -> (bins, rows, cols)

    if not output:
        return np.array(out_arr)

    if out_uly is None:
        raise ValueError(
            "no {0}x{0} block lies inside the {1}-pixel margin of {2!r}; "
            "nothing to write".format(block, MAX_SCALE, image_name))
    if stat:
        out_arr = calc_stat(out_arr, "all", 0)
    out_geotran = (out_ulx, block * cell_width, 0, out_uly, 0, block * cell_height)
    write_geotiff(output, out_arr, out_geotran, out_srs_wkt)
    
def _glcm_variance(glcm):
    """
    Return the GLCM variance for every (distance, angle) pair.

    Parameters:
    -----------
    glcm: 4D ndarray
        co-occurrence counts indexed [i, j, d, a], as returned by greycomatrix

    Returns:
    --------
    out: 2D ndarray
        [d, a] variance, laid out like the output of greycoprops
    """
    p = np.asarray(glcm, dtype=np.float64)
    totals = p.sum(axis=(0, 1), keepdims=True)
    # normalise each (d, a) slice to a probability table; empty slices stay 0
    p = np.divide(p, totals, out=np.zeros_like(p), where=totals > 0)
    levels = np.arange(p.shape[0], dtype=np.float64).reshape(-1, 1, 1, 1)
    mean = np.sum(levels * p, axis=(0, 1), keepdims=True)
    return np.sum(p * (levels - mean) ** 2, axis=(0, 1))


def glcm_feature(image_name, block, scale, output=None, prop=None, stat=None):
    """
    Compute a grey-level co-occurrence (GLCM) feature per block of a raster.

    The raster is converted to 8-bit grey. Only tiles whose top-left
    row/column lies in ``[MAX_SCALE, size - MAX_SCALE]`` are processed; for
    each, a GLCM is built over a window of about ``scale`` pixels centred on the
    tile, at distances 10 and 20 (those below ``scale``) and 8 angles in
    [0, pi).

    Parameters:
    -----------
    image_name: str
        path of a raster GDAL can open (RGB bands, or a single band)
    block: int
        tile size in input pixels; each output cell covers one tile
    scale: int
        side length of the GLCM window
    output: str, optional
        if given, the result is written to this GeoTIFF path instead of
        being returned
    prop: str, optional
        a greycoprops property name, or "variance"
    stat: str, optional
        a calc_stat name used to reduce each tile's result to a scalar

    Returns:
    --------
    out_image: 2D or 3D ndarray (depends on the input)
        2D when ``stat`` is given, otherwise one [d, a] array per tile.
        With neither ``prop`` nor ``stat``, the raw GLCM of the first tile
        is returned immediately.

    Raises:
    -------
    OSError
        if GDAL cannot open ``image_name``
    ValueError
        if ``output`` is given but no tile lies inside the MAX_SCALE margin
    """
    ds = gdal.Open(image_name)
    if ds is None:
        raise OSError("GDAL could not open {!r}".format(image_name))
    image = ds.ReadAsArray()
    ulx, cell_width, _, uly, _, cell_height = ds.GetGeoTransform()
    out_srs_wkt = ds.GetProjection()
    ds = None  # release the dataset; only the array and metadata are needed

    if not out_srs_wkt:
        # The output geotransform is in the input's CRS units, so the input CRS
        # is reused; WGS84 is only a fallback when the input carries none.
        out_srs = osr.SpatialReference()
        out_srs.ImportFromEPSG(4326)
        out_srs_wkt = out_srs.ExportToWkt()

    if image.ndim == 2:
        gray = image  # already a single band: no transpose, no RGB conversion
    else:
        gray = rgb2gray(np.moveaxis(image, 0, -1))  # rows, columns, channels
    image = skimage.img_as_ubyte(gray)
    n_rows, n_cols = image.shape

    angles = [0., np.pi / 6., np.pi / 4., np.pi / 3., np.pi / 2.,
              (2. * np.pi) / 3., (3. * np.pi) / 4., (5. * np.pi) / 6.]
    distances = [d for d in (10, 20) if d < scale]

    half = scale // 2
    # Window spans [center - half, center + extent): an even `scale`-sized window
    # when block is odd and scale even, otherwise 2*half+1 pixels.
    extent = half if (block % 2 != 0 and scale % 2 == 0) else half + 1

    out_image = []
    out_ulx = out_uly = None
    for i in range(0, n_rows, block):
        if not MAX_SCALE <= i <= n_rows - MAX_SCALE:
            continue
        center_i = int(i + block / 2)
        outrow = []
        for j in range(0, n_cols, block):
            if not MAX_SCALE <= j <= n_cols - MAX_SCALE:
                continue
            center_j = int(j + block / 2)
            if out_uly is None:
                # The first output cell covers input pixels starting at (i, j),
                # so that pixel's corner is the output grid's upper-left.
                out_uly = uly + cell_height * i
                out_ulx = ulx + cell_width * j
            window = image[center_i - half:center_i + extent,
                           center_j - half:center_j + extent]

            glcm = greycomatrix(window, distances, angles)
            if prop:
                if prop == "variance":
                    # greycoprops has no variance property, so compute it here
                    value = _glcm_variance(glcm)
                else:
                    # [d, a]: property for the d'th distance and a'th angle
                    value = greycoprops(glcm, prop)
                if stat:
                    value = calc_stat(value, stat, None)
            elif stat:
                value = calc_stat(glcm, stat, None)
            else:
                # NOTE(review): with neither prop nor stat, the raw GLCM of the
                # first window is returned and nothing is written; kept for
                # backward compatibility.
                return glcm
            outrow.append(value)
        out_image.append(outrow)
    out_image = np.array(out_image)

    if not output:
        return out_image

    if out_uly is None:
        raise ValueError(
            "no {0}x{0} block lies inside the {1}-pixel margin of {2!r}; "
            "nothing to write".format(block, MAX_SCALE, image_name))
    out_geotran = (out_ulx, block * cell_width, 0, out_uly, 0, block * cell_height)
    write_geotiff(output, out_image, out_geotran, out_srs_wkt)

    
def pantex_feature(image_name, block, scale, output=None):
    """
    PanTex-style texture: the minimum GLCM contrast for each block.

    Thin wrapper around ``glcm_feature`` with ``prop="contrast"`` and
    ``stat="min"``, i.e. the minimum contrast over all distance/angle pairs.

    Parameters:
    -----------
    image_name: str
    block: int
    scale: int
    output: str, optional
        GeoTIFF path to write to instead of returning the result

    Returns:
    --------
    out: ndarray or None
        the per-block values when ``output`` is falsy, otherwise None
    """
    settings = {"prop": "contrast", "stat": "min"}
    if not output:
        return glcm_feature(image_name, block, scale, **settings)
    glcm_feature(image_name, block, scale, output=output, **settings)

def get_se_set(sizes):
    """
    Build linear structuring elements in four directions for each size.

    Parameters:
    -----------
    sizes: iterable of int
        side lengths of the square kernels; each must be a positive odd number
        so the line passes through a centre pixel

    Returns:
    --------
    se_set: list of lists of 2D float ndarrays
        se_set[k] holds the [0, 45, 90, 135] degree line kernels for sizes[k]

    Raises:
    -------
    ValueError
        if a size is not a positive odd integer
    """
    se_set = []
    for se_size in sizes:
        # `assert` would be stripped under `python -O`, so validate explicitly
        if se_size < 1 or se_size % 2 == 0:
            raise ValueError(
                "structuring element sizes must be positive and odd, "
                "got {!r}".format(se_size))
        se0 = np.zeros((se_size, se_size))
        se0[se_size // 2, :] = 1                # horizontal line
        se45 = np.eye(se_size)[::-1]            # anti-diagonal (bottom-left to top-right)
        se90 = np.zeros((se_size, se_size))
        se90[:, se_size // 2] = 1               # vertical line
        se135 = np.eye(se_size)                 # main diagonal
        se_set.append([se0, se45, se90, se135])
    return se_set

def MBI_feature(image_name, postprocess=True, threshold=5.5,
                se_sizes=(5, 9, 13, 19, 23, 27)):
    """
    Compute the MBI (presumably the morphological building index) of a raster.

    Brightness is the per-pixel maximum over the bands. For each structuring
    element size, white top-hats with four linear kernels (see get_se_set)
    are averaged. The absolute differences between consecutive sizes are
    then averaged into the index.

    Parameters:
    -----------
    image_name: str
        path of a raster GDAL can open (multi-band or single-band)
    postprocess: bool
        if True, return a 0/1 mask of pixels whose index is >= threshold
    threshold: float
        cut-off used when postprocess is True (default 5.5, the previous
        hard-coded value)
    se_sizes: sequence of int
        odd kernel sizes, at least two, in the order they are differenced

    Returns:
    --------
    mbi: 2D ndarray
        the index, or the integer 0/1 mask when postprocess is True

    Raises:
    -------
    OSError
        if GDAL cannot open image_name
    ValueError
        if fewer than two se_sizes are given
    """
    if len(se_sizes) < 2:
        raise ValueError("at least two structuring element sizes are required")

    ds = gdal.Open(image_name)
    if ds is None:
        raise OSError("GDAL could not open {!r}".format(image_name))
    image = ds.ReadAsArray()
    ds = None

    if image.ndim == 2:
        brightness = image  # single band: it is its own brightness
    else:
        # per-pixel maximum over the bands (bands moved last: rows, cols, bands)
        brightness = calc_stat(np.moveaxis(image, 0, -1), "max", 2)

    # White top-hat (brightness minus its opening) with each directional
    # linear kernel, averaged over the directions of one size.
    mean_w_tophats = [
        calc_stat([white_tophat(brightness, kernel) for kernel in kernels], "mean", 0)
        for kernels in get_se_set(se_sizes)
    ]
    # Differential profile: |change| between consecutive sizes.
    th_dmp = [np.absolute(larger - smaller)
              for smaller, larger in zip(mean_w_tophats, mean_w_tophats[1:])]
    mbi = calc_stat(np.array(th_dmp), "mean", 0)
    if postprocess:
        mbi = np.where(mbi >= threshold, 1, 0)
    return mbi

def LSR_feature(dat_directory):
    """
    Parse the comma-separated line records of every file in a directory.

    Each non-blank line must start with four comma-separated fields: line
    length, middle x coordinate, middle y coordinate and line orientation.
    Extra fields are ignored.

    Parameters:
    -----------
    dat_directory: str
        directory whose regular files are read (subdirectories are skipped)

    Returns:
    --------
    out: dict
        maps each file name to a list of (length, mid_x, mid_y, orientation)
        tuples of stripped strings, in file order

    Raises:
    -------
    ValueError
        if a non-blank line has fewer than four fields
    """
    import os  # not imported at module level

    records = {}
    for name in sorted(os.listdir(dat_directory)):
        path = os.path.join(dat_directory, name)
        if not os.path.isfile(path):
            continue
        rows = []
        with open(path) as infile:
            for lineno, raw in enumerate(infile, start=1):
                line = raw.strip()
                if not line:
                    continue
                fields = [field.strip() for field in line.split(",")]
                if len(fields) < 4:
                    raise ValueError(
                        "{}:{}: expected at least 4 comma-separated fields, "
                        "got {}".format(path, lineno, len(fields)))
                llen, lmx, lmy, lorn = fields[:4]
                rows.append((llen, lmx, lmy, lorn))
        records[name] = rows
    return records
            

def textons_feature(image_name, block, scale):
    """
    Placeholder for a texton-based texture feature.

    Raises:
    -------
    NotImplementedError
        always; `NotImplemented` is a sentinel for binary dunder methods and
        must not be returned as a "not done yet" marker
    """
    raise NotImplementedError("textons_feature is not implemented yet")
        
    
def calc_stat(arr, stat_name, axis=None):
    """
    Compute a named statistic of an array.

    Parameters:
    -----------
    arr: array_like
        the input array
    stat_name: str
        one of "min", "max", "mean", "var", "std", "sum", or "all".
        "all" stacks [min, max, mean, var, sum] along a new first axis
        (std is not included).
    axis: int, optional
        the axis over which the statistic is calculated; None means all

    Returns:
    --------
    out: ndarray or scalar

    Raises:
    -------
    ValueError
        if stat_name is not recognised (previously any unknown name silently
        fell through to "sum")
    """
    stat_funcs = {
        "min": np.amin,
        "max": np.amax,
        "mean": np.mean,
        "var": np.var,
        "std": np.std,
        "sum": np.sum,
    }
    if stat_name == "all":
        return np.array([stat_funcs[name](arr, axis)
                         for name in ("min", "max", "mean", "var", "sum")])
    try:
        func = stat_funcs[stat_name]
    except KeyError:
        raise ValueError("unknown stat_name {!r}; expected one of {}".format(
            stat_name, sorted(list(stat_funcs) + ["all"]))) from None
    return func(arr, axis)