# Convert from GeoTIFF to NumPy
This script reads all of the GeoTIFF files using GDAL, converts them to NumPy arrays, pre-processes the images (rotation, cropping, and clamping of anomalously large values), and saves them in a more easily used format (compressed NumPy arrays) so that we don't have to use GDAL every time we load the data.

In [1]:
import gdal, common, uuid
from osgeo import gdal_array
import numpy as np
from io import BytesIO
from scipy.ndimage import rotate



ROTATION_ANGLE = 14 # determined by brute force

class Rotator:
    '''
    A simple wrapping class that can rotate images based on the rotation information derived from the label image.

    The label image is rotated once at construction time, and the crop window found on it is
    then applied to every image passed to the instance, so all outputs share the same extents.

    Attributes:
        labelImage        -- the label image exactly as it was passed in (not modified).
        rotAngle          -- the rotation angle handed to scipy.ndimage.rotate (degrees).
        sx1, ex1          -- start / end column of the crop window in the rotated frame.
        sy1, ey1          -- start / end row of the crop window in the rotated frame.
        rotatedLabelImage -- the rotated and cropped first band of the label image.
    '''
    def __init__(self, labelImage, rotationAngle=ROTATION_ANGLE, cropRow=3000, cropCol=3000):
        '''
        labelImage    -- array indexed as [row, col, band]; only band 0 is used.
        rotationAngle -- angle passed to scipy.ndimage.rotate.
        cropRow       -- row of the rotated label image that is scanned to find the
                         horizontal crop extents.
        cropCol       -- column of the rotated label image that is scanned to find the
                         vertical crop extents.

        cropRow / cropCol default to 3000, the value that used to be hard-coded; pass smaller
        values for rasters that do not reach that far. An IndexError is raised if the probe
        line is outside the rotated image or contains no label transitions.
        '''
        self.labelImage = labelImage
        self.rotAngle = rotationAngle

        # work on a copy of the first band so the caller's array is left untouched
        nl = np.array(labelImage[:,:,0])
        # 255 is the padding value for the rotated image, so set that to zero.
        nl[nl==255] = 0
        # use zero-order here because this is a nearest neighbor interpolation, which
        # keeps the label values intact instead of blending neighbouring classes
        nl = rotate(nl,self.rotAngle,order=0)

        # along the probe row / column find the first and last position where the value changes,
        # these define the cropping extents for the rotated image, so that we can reduce our
        # image to only the parts that matter
        horz = np.where(abs(np.diff(nl[cropRow])) > 0)[0]
        vert = np.where(abs(np.diff(nl[:,cropCol])) > 0)[0]

        self.sx1,self.ex1 = horz[0],horz[-1]
        self.sy1,self.ey1 = vert[0],vert[-1]

        self.rotatedLabelImage = nl[self.sy1:self.ey1, self.sx1:self.ex1]

    def __call__(self, image):
        '''
        Rotate image by the stored angle and crop it to the window derived from the label image.
        '''
        rotated = rotate(image,self.rotAngle, order=1) # this is linear
        return rotated[self.sy1:self.ey1, self.sx1:self.ex1]
    
    
def cleanLabelImage(rotator,nclasses=10):
    '''
    Extract the most significant classes from the label image, otherwise you wind up with a bunch of classes that
    only have very small counts, and subsequently you'll waste a lot of time trying to train a classifier using almost
    no training data for those classes.

    rotator  -- object exposing `rotatedLabelImage` (the image that is remapped) and
                `labelImage` (only used for the printed per-class comparison).
    nclasses -- number of most frequent classes to keep.

    Returns an array shaped like `rotator.rotatedLabelImage` in which the most frequent class is
    labelled 1, the next 2, and so on up to `nclasses`; every other pixel is 0.
    '''
    # read everything from the argument rather than from a module-level global
    rotatedLabels = rotator.rotatedLabelImage
    originalLabels = rotator.labelImage

    # there are 256 classes in the raw image (values 0..255 inclusive)
    classCounts = np.zeros((256,2))
    for k in range(256):
        rotClass = (rotatedLabels == k).sum()
        origClass = (originalLabels == k).sum()
        classCounts[k,:] = [rotClass, origClass]
        print('Class count : {} -> {} {}'.format(k,rotClass,origClass))

    print('there are {} classes with non-zero counts'.format((classCounts[:,0] > 0).sum()))
    # class ids ordered by their pixel count in the rotated image, largest first
    idx = np.argsort(classCounts[:,0])[::-1] # descending order argsort
    remappedLabelImage = np.zeros_like(rotatedLabels)
    # new labels start at 1 so that 0 stays reserved for "not one of the kept classes"
    for keepCount, clazz in enumerate(idx[:nclasses], start=1):
        remappedLabelImage[rotatedLabels == clazz] = keepCount
    print('the min,max range of the remapped label image is: {}, {}'.format(remappedLabelImage.min(), remappedLabelImage.max()))
    return remappedLabelImage

def cleanSpectralImage(image,thresh=20000):
    '''
    Clamp the anomalously large values that show up in some of the bands so that they can't cause problems later on.

    Returns a copy of image in which every sample greater than thresh has been replaced by thresh;
    the input array itself is left unchanged.
    '''
    clamped = np.array(image)
    tooLarge = image > thresh
    clamped[tooLarge] = thresh
    return clamped

def geoToArray(gdalImage):
    '''
    Copy every raster band of a gdal image into one numpy array indexed as [row, col, band].

    The array's dtype is derived from the data type of the first band.
    '''
    #>>>>>>>> This particular segment of code was refactored from: https://github.com/bhavesh907/Crop-Classification/blob/master/data-preprocessing.ipynb
    rowCount = gdalImage.RasterYSize
    colCount = gdalImage.RasterXSize
    bandCount = gdalImage.RasterCount
    pixelType = gdal_array.GDALTypeCodeToNumericTypeCode(gdalImage.GetRasterBand(1).DataType)
    image = np.zeros((rowCount, colCount, bandCount), pixelType)

    # GDAL band numbers start at 1, the numpy band axis starts at 0
    for bandIndex in range(bandCount):
        image[:, :, bandIndex] = gdalImage.GetRasterBand(bandIndex + 1).ReadAsArray()

    #<<<<<<<<<'
    return image

def info(dataset):
    '''
    Print the projection of a dataset and, when it has a geo transform, the transform together with
    the origin and pixel size taken from it.

    This was an attempt to back out the rotation directly inside of GDAL, but it was taking too long to
    figure out, so it was abandoned in favour of brute forcing the rotation angle.
    '''
    projection = dataset.GetProjection()
    print("Projection is {}".format(projection))

    geotransform = dataset.GetGeoTransform()
    if not geotransform:
        # nothing more to report without a transform
        return

    originX, originY = geotransform[0], geotransform[3]
    pixelWidth, pixelHeight = geotransform[1], geotransform[5]
    print('geo transform', geotransform)
    print("Origin = ({}, {})".format(originX, originY))
    print("Pixel Size = ({}, {})".format(pixelWidth, pixelHeight))
        

                            
def readGeoTiff(key):
    '''
    Read geo tiff data from raw bytes, and convert to numpy arrays.

    key -- identifier handed to common.readTIFF to fetch the raw TIFF bytes.

    Returns the array produced by geoToArray for the opened dataset.
    Raises RuntimeError if GDAL cannot open the bytes as a dataset.
    '''
    # >>>> this particular code snippet was modified from: https://gist.github.com/jleinonen/5781308
    byteData = common.readTIFF(key)
    # a unique name so that successive reads never collide inside GDAL's in-memory file system
    mmap_name = "/vsimem/"+str(uuid.uuid4())
    gdal.FileFromMemBuffer(mmap_name, byteData)
    gdal_dataset = None
    try:
        gdal_dataset = gdal.Open(mmap_name)
        if gdal_dataset is None:
            # fail here with a clear message instead of an AttributeError inside geoToArray
            raise RuntimeError('GDAL could not open the TIFF data for key {!r}'.format(key))
#         info(gdal_dataset)
        return geoToArray(gdal_dataset)
    finally:
        # drop the dataset reference first, then always release the in-memory file, even when
        # opening or converting failed, so the buffer is never leaked
        gdal_dataset = None
        gdal.Unlink(mmap_name)
    # <<<<<


# NOTE(review): `images` is not used anywhere in the visible part of this file.
images = {}

# load the raw label raster as a numpy array
labelImage = readGeoTiff(common.LABEL_KEY)

# derive the rotation / crop window from the label image; `rot` is applied to the
# spectral images further down
rot = Rotator(labelImage)

# from here on `labelImage` holds the remapped labels, not the raw raster
labelImage = cleanLabelImage(rot)
common.saveNumpy(labelImage,'labels')

Class count : 0 -> 51584 6048
Class count : 1 -> 51136 51120
Class count : 2 -> 70224 70704
Class count : 3 -> 0 0
Class count : 4 -> 3087 3492
Class count : 5 -> 0 0
Class count : 6 -> 0 0
Class count : 7 -> 0 0
Class count : 8 -> 0 0
Class count : 9 -> 0 0
Class count : 10 -> 0 0
Class count : 11 -> 0 0
Class count : 12 -> 0 0
Class count : 13 -> 0 0
Class count : 14 -> 0 0
Class count : 15 -> 0 0
Class count : 16 -> 0 0
Class count : 17 -> 0 0
Class count : 18 -> 0 0
Class count : 19 -> 0 0
Class count : 20 -> 0 0
Class count : 21 -> 46673 47700
Class count : 22 -> 0 0
Class count : 23 -> 870 864
Class count : 24 -> 490124 494748
Class count : 25 -> 0 0
Class count : 26 -> 0 0
Class count : 27 -> 0 0
Class count : 28 -> 73417 73512
Class count : 29 -> 0 0
Class count : 30 -> 0 0
Class count : 31 -> 0 0
Class count : 32 -> 0 0
Class count : 33 -> 5988 6120
Class count : 34 -> 0 0
Class count : 35 -> 0 0
Class count : 36 -> 1540735 1543068
Class count : 37 -> 23706 23724
Class count :

In [2]:
import tqdm

# Run every image key through the same pipeline: read -> rotate and crop -> clamp -> save.
for key in tqdm.tqdm(common.IMAGE_KEYS):
    im = readGeoTiff(key)
    # same rotation and crop window that was derived from the label image
    rotated = rot(im)
    finalIm = cleanSpectralImage(rotated)
    # NOTE(review): split('.')[0] keeps only the text before the FIRST '.', so a key that
    # contains more than one dot loses more than its extension -- confirm the key format.
    common.saveNumpy(finalIm, key.split('.')[0]) # strip the extension
    

100%|██████████| 10/10 [05:13<00:00, 31.30s/it]
