In [3]:
import rasterio
from rasterio.mask import mask

from shapely.geometry import box
import geopandas as gpd
from fiona.crs import from_epsg

import os
import numpy as np
import cv2
import glob
import collections

%matplotlib inline
import matplotlib.pyplot as plt

import sys
#sys.path.append('../../GranularData/')

#from utils import match
import pycrs

from functools import partial
import pyproj
from shapely.ops import transform

import pandas as pd


In [4]:
# Every entry of the ONERA images folder, in alphabetical order. The listing is not
# restricted to city folders; the loops below skip entries whose name contains 'txt'.
cities = sorted(os.listdir('../datasets/onera/images/'))

In [4]:
def stretch_8bit(bands, lower_percent=5, higher_percent=95):
    """Percentile-stretch every channel of an (H, W, C) image to uint8 0-255.

    For each channel, the `lower_percent` and `higher_percent` percentiles are
    computed over the strictly positive pixels only (zeros are excluded from the
    statistics). Values are linearly mapped so the lower percentile becomes 0 and
    the upper percentile becomes 255; anything outside that range is clipped.

    Parameters
    ----------
    bands : numpy.ndarray
        Image of shape (H, W, C). Typically C == 3; every channel is stretched.
    lower_percent : float
        Percentile mapped to 0 (default 5). Values below it are clipped to 0.
    higher_percent : float
        Percentile mapped to 255 (default 95). Values above it are clipped to 255.

    Returns
    -------
    numpy.ndarray
        np.uint8 array with the same shape as `bands`. A channel with no positive
        pixel is returned as zeros. A channel whose two percentiles coincide is
        thresholded: 255 above the percentile value, 0 elsewhere.
    """
    lo_out, hi_out = 0, 255
    out = np.zeros(bands.shape, dtype=np.uint8)
    for i in range(bands.shape[2]):
        channel = bands[:, :, i]
        positive = channel[channel > 0]
        if positive.size == 0:
            # No usable pixels: np.percentile cannot work on an empty array.
            continue
        c, d = np.percentile(positive, [lower_percent, higher_percent])
        if d != c:
            t = lo_out + (channel - c) * (hi_out - lo_out) / (d - c)
        else:
            # Constant channel: avoid 0/0 (NaN) and keep the step the linear map degenerates to.
            t = np.where(channel > c, hi_out, lo_out)
        # Float -> uint8 assignment truncates, exactly like the former astype(np.uint8).
        out[:, :, i] = np.clip(t, lo_out, hi_out)
    return out

In [226]:
def crop_region(source, target, source_crs=None):
    """Crop the `target` raster to the bounding box of the `source` raster.

    The rectangle given by `source.bounds` is tagged with `source_crs`, reprojected
    into `target`'s CRS and used as a mask with ``crop=True``.

    Parameters
    ----------
    source : rasterio dataset
        Raster whose bounds define the crop window. NOTE(review): the bounds are
        interpreted in `source_crs` (EPSG:4326 unless overridden); `source.crs` is
        not consulted — confirm this matches the input rasters.
    target : rasterio dataset
        Raster to crop.
    source_crs : optional
        CRS of `source.bounds`, any value geopandas accepts. Defaults to EPSG:4326.

    Returns
    -------
    (dict, numpy.ndarray)
        A copy of `target.meta` updated with the GTiff driver and the cropped
        height/width/transform, and the cropped array as returned by rasterio's mask.
    """
    if source_crs is None:
        source_crs = from_epsg(4326)

    # rasterio bounds are (left, bottom, right, top), i.e. shapely's (minx, miny, maxx, maxy).
    bbox = box(*source.bounds)
    footprint = gpd.GeoDataFrame({'geometry': [bbox]}, index=[0], crs=source_crs)
    footprint = footprint.to_crs(crs=target.crs.data)

    # GeoJSON-like geometry dicts as expected by rasterio.mask.mask; built inline so the
    # function does not depend on a helper that is not defined in this notebook.
    shapes = [geom.__geo_interface__ for geom in footprint.geometry]

    out_img, out_transform = mask(target, shapes=shapes, crop=True)

    out_meta = target.meta.copy()
    out_meta.update({"driver": "GTiff", "height": out_img.shape[1],
                     "width": out_img.shape[2], "transform": out_transform})

    return out_meta, out_img

In [5]:
# The 13 band names used for file lookups and stacking order:
# B01..B08, then B8A, then B09..B12.
bands = ['B%02d' % n for n in range(1, 9)] + ['B8A'] + ['B%02d' % n for n in range(9, 13)]

In [238]:
# For every city, crop each band of every acquisition (one SAFE per PNG in pngs/) to the
# footprint of the city's imgs_1 rasters, writing cropped_safes/<date>/<band>.tif.
# A date folder that already exists is skipped, so the cell can be re-run.
tile_ids = []
images_root = '../datasets/onera/images/'
safes_root = '/media/Drive1/onera_safes/'  # NOTE(review): absolute local path

for city in cities:
    if 'txt' in city:
        continue
    city_dir = images_root + city + '/'

    # Any imgs_1 band file gives the shared path prefix once its 7-char 'Bxx.tif'
    # suffix is stripped; each band name is then appended to that prefix.
    base_path = glob.glob(city_dir + 'imgs_1/**.tif')[0][:-7]
    sources = {band: rasterio.open(base_path + band + '.tif') for band in bands}
    try:
        out_root = city_dir + 'cropped_safes/'
        os.makedirs(out_root, exist_ok=True)

        for png in os.listdir(city_dir + 'pngs/'):
            if not png.endswith('.png'):
                continue
            fields = png.split('_')
            # NOTE(review): in the second branch nothing strips the '.png' extension from
            # the last '_' field — confirm the resulting folder names are intended.
            date = fields[2] if fields[-2][0] == 'T' else fields[-1]

            date_dir = out_root + date
            if os.path.exists(date_dir):
                # Already processed. Caveat: a run interrupted mid-date also leaves the folder.
                continue
            os.mkdir(date_dir)

            # 'name.png' -> 'name.SAFE'; strip 'Bxx.jp2' to get the band-file prefix.
            base_path_safe = glob.glob(safes_root + png[:-3] + "SAFE/GRANULE/**/IMG_DATA/**.jp2")[0][:-7]
            for band in bands:
                with rasterio.open(base_path_safe + band + '.jp2') as safe_band:
                    b_out, b_img = crop_region(sources[band], safe_band)
                with rasterio.open(date_dir + '/' + band + '.tif', "w", **b_out) as fout:
                    fout.write(b_img)
    finally:
        for src in sources.values():
            src.close()
        

In [27]:
# Rough size of the cropped data set. Every date counts len(bands) values per pixel of
# its B02 grid (the other bands are resized to that grid when stacked).
# total_pixels_city holds per-city sizes in MiB despite its name.
BYTES_PER_VALUE = 2  # assumes 16-bit pixel values — TODO confirm against the rasters' dtype

total_pixels = 0
total_dates = 0
total_pixels_city = {}
total_dates_city = {}
for city in cities:
    if 'txt' in city:
        continue
    date_root = '../datasets/onera/images/' + city + '/cropped_safes/'
    dates = os.listdir(date_root)

    city_values = 0
    for date in dates:
        with rasterio.open(date_root + date + '/B02.tif') as ref:
            height, width = ref.shape
        city_values += len(bands) * height * width

    total_pixels_city[city] = city_values * BYTES_PER_VALUE / 1024 / 1024
    total_dates_city[city] = len(dates)

    total_pixels += city_values
    total_dates += len(dates)
    

In [28]:
# (number of band values over all cities and dates, number of dated acquisitions)
total_pixels, total_dates

(5532510763, 967)

In [29]:
# Number of cropped acquisition dates per city.
total_dates_city

{'abudhabi': 96,
 'aguasclaras': 48,
 'beihai': 12,
 'beirut': 32,
 'bercy': 31,
 'bordeaux': 21,
 'brasilia': 85,
 'chongqing': 22,
 'cupertino': 38,
 'dubai': 85,
 'hongkong': 38,
 'lasvegas': 49,
 'milano': 26,
 'montpellier': 27,
 'mumbai': 37,
 'nantes': 34,
 'norcia': 20,
 'paris': 31,
 'pisa': 54,
 'rennes': 26,
 'rio': 27,
 'saclay_e': 36,
 'saclay_w': 36,
 'valencia': 56}

In [30]:
# NOTE: despite the name, these are approximate sizes in MiB (values * 2 / 1024 / 1024), not pixel counts.
total_pixels_city

{'abudhabi': 1565.6748046875,
 'aguasclaras': 303.7210693359375,
 'beihai': 211.67330932617188,
 'beirut': 1059.128662109375,
 'bercy': 123.26870727539062,
 'bordeaux': 137.65128135681152,
 'brasilia': 441.8478298187256,
 'chongqing': 223.0949249267578,
 'cupertino': 791.0493850708008,
 'dubai': 1068.9301872253418,
 'hongkong': 369.6859130859375,
 'lasvegas': 751.4050483703613,
 'milano': 208.62613677978516,
 'montpellier': 138.44833374023438,
 'mumbai': 452.6576900482178,
 'nantes': 282.1007537841797,
 'norcia': 51.28013610839844,
 'paris': 137.95782852172852,
 'pisa': 801.2918243408203,
 'rennes': 137.5458755493164,
 'rio': 103.88177490234375,
 'saclay_e': 433.4387969970703,
 'saclay_w': 434.1734390258789,
 'valencia': 323.89312744140625}

In [40]:
# Stack all bands of all dates per city into ../datasets/onera/npys/<city>.npy,
# shaped (dates, H, W, len(bands)) on each date's B02 grid.
# NOTE: the hist-matched cell further down stacks (dates, bands, H, W) instead.
stack_root = '../datasets/onera/images/'
npy_root = '../datasets/onera/npys/'
os.makedirs(npy_root, exist_ok=True)


def _read_band(date_dir, band):
    """Return layer 1 of `<date_dir>/<band>.tif` and close the file."""
    with rasterio.open(date_dir + '/' + band + '.tif') as src:
        return src.read(1)


for city in cities:
    if 'txt' in city:
        continue
    date_root = stack_root + city + '/cropped_safes/'
    dates = sorted(os.listdir(date_root))

    stacked = []
    for date in dates:
        layers = [_read_band(date_root + date, band) for band in bands]
        ref_height, ref_width = layers[bands.index('B02')].shape
        # cv2.resize takes (width, height); layers already on the B02 grid are kept as-is.
        layers = [layer if layer.shape == (ref_height, ref_width)
                  else cv2.resize(layer, (ref_width, ref_height))
                  for layer in layers]
        stacked.append(np.stack(layers, axis=2))

    np.save(npy_root + city + '.npy', np.asarray(stacked))
            
            

In [45]:
# Counts every directory entry; the per-city results above have 24 keys because loops skip names containing 'txt'.
len(cities)

27

In [6]:
def match_bands(source, template):
    """Histogram-match `source` to `template`.

    The values of `source` are remapped so that their empirical CDF follows the
    one of `template`. The pixels/shape of `source` are kept; only the value
    distribution comes from `template`.

    Arguments:
    -----------
        source: np.ndarray
            Image to transform; the histogram is computed over the flattened array.
        template: np.ndarray
            Image whose value distribution is imposed on `source`; may have a
            different shape. NaN values are ignored (pandas value_counts drops them).
    Returns:
    -----------
        matched: np.ndarray
            float64 array with the shape of `source`.
    Raises:
    -----------
        ValueError
            If `source` is empty or `template` has no countable value.
    """
    oldshape = np.shape(source)
    source = np.ravel(source)
    template = np.ravel(template)
    if source.size == 0:
        raise ValueError("match_bands: source image is empty")

    # Per-pixel index into the sorted unique source values, and the count of each
    # unique value (what np.unique(return_inverse=True, return_counts=True) gives),
    # computed via a heapsort argsort.
    perm = source.argsort(kind='heapsort')
    aux = source[perm]
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))  # True where a new value starts
    bin_idx = np.empty(source.shape, dtype=np.intp)
    bin_idx[perm] = np.cumsum(flag) - 1
    run_starts = np.concatenate(np.nonzero(flag) + ([source.size],))
    s_counts = np.diff(run_starts)

    # Template histogram, sorted by value (replaces the deprecated top-level pd.value_counts).
    t_hist = pd.Series(template).value_counts().sort_index()
    if t_hist.empty:
        raise ValueError("match_bands: template image has no countable values")
    t_values = np.asarray(t_hist.index)
    t_counts = np.asarray(t_hist.values)

    # Empirical CDFs (value -> quantile) of source and template.
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    # For each source quantile, linearly interpolate the template value at that quantile.
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)

In [7]:
# Reference acquisition date per Sentinel-2 tile id (without its leading 'T'). The
# matching cell reads the template histograms from cropped_safes/<this date>/.
#
# NOTE(review): the original literal listed 22LHH, 31UDQ and 31UDP twice. The first two
# repeats had identical dates. For 31UDP, "20160125T111611" was silently overwritten by
# "20151126T110402" (a dict keeps the last duplicate), so only the latter was ever used.
# The original entries appear to follow the alphabetical city order (saclay_e and
# saclay_w share 31UDP), so saclay_e may have been meant to use 20160125T111611 — a
# tile-keyed map cannot express that; consider keying by city. TODO confirm.
hist_source_maps = {
    "39QZG": "20180924T065619",
    "22LHH": "20150916T133636",
    "49QCD": "20181005T030549",
    "36SYC": "20150820T082006",
    "31UDQ": "20181006T110029",
    "30TXQ": "20161130T110422",
    "48RXT": "20161228T034142",
    "10SEG": "20150918T190346",
    "40RCN": "20151211T070232",
    "49QHE": "20181004T024541",
    "11SPV": "20150909T183316",
    "32TNR": "20150806T102012",
    "31TEJ": "20150802T104026",
    "43QBB": "20151120T054122",
    "30TXT": "20150821T111616",
    "33TUH": "20151218T101215",
    "32TPP": "20150704T101337",
    "30UWU": "20150821T111616",
    "23KPQ": "20150808T130816",
    "31UDP": "20151126T110402",
    "30SYJ": "20150706T105351",
}

In [8]:
# Histogram-match every date of every city to the city's reference date
# (hist_source_maps) band by band, resize to the reference B02 grid, and save
# ../datasets/onera/hist_matched_npys/<city>.npy shaped (dates, bands, H, W).
images_root = '../datasets/onera/images/'
hist_npy_root = '../datasets/onera/hist_matched_npys/'
os.makedirs(hist_npy_root, exist_ok=True)


def _read_first_layer(path):
    """Return layer 1 of the raster at `path` and close the file."""
    with rasterio.open(path) as src:
        return src.read(1)


for city in cities:
    if 'txt' in city:
        continue
    date_path = images_root + city + '/cropped_safes/'

    fields = os.listdir(images_root + city + '/imgs_1')[0].split('_')
    tile_id = fields[-2] if fields[-2][0] == 'T' else fields[0]
    source_safe = hist_source_maps[tile_id[1:]]
    ref_dir = date_path + source_safe + '/'

    with rasterio.open(ref_dir + 'B02.tif') as ref:
        shape = ref.shape
    # Template bands do not change across dates: read them once per city.
    templates = {band: _read_first_layer(ref_dir + band + '.tif') for band in bands}

    dates_stacked = []
    for date in sorted(os.listdir(date_path)):
        bands_stacked = []
        for band in bands:
            current = _read_first_layer(date_path + date + '/' + band + '.tif')
            # match_bands(source, template) keeps `source`'s pixels and borrows
            # `template`'s histogram, so the current date must be the source; the
            # reverse order would return the reference image's content for every date.
            mod = match_bands(current, templates[band])
            # cv2.resize takes (width, height).
            bands_stacked.append(cv2.resize(mod, (shape[1], shape[0])))
        dates_stacked.append(bands_stacked)

    dates_stacked = np.asarray(dates_stacked)
    print(dates_stacked.shape)

    np.save(hist_npy_root + city + '.npy', dates_stacked)

(96, 13, 852, 772)
(48, 13, 487, 524)
(12, 13, 941, 756)
(32, 13, 1319, 1012)
(31, 13, 514, 312)
(21, 13, 637, 415)
(85, 13, 447, 469)
(22, 13, 788, 519)
(38, 13, 1158, 725)
(85, 13, 822, 617)
(38, 13, 732, 536)
(49, 13, 930, 665)
(26, 13, 670, 483)
(27, 13, 517, 400)
(37, 13, 889, 555)
(34, 13, 660, 507)
(20, 13, 298, 347)
(31, 13, 531, 338)
(54, 13, 938, 638)
(26, 13, 439, 486)
(27, 13, 373, 416)
(36, 13, 823, 590)
(36, 13, 823, 591)
(56, 13, 535, 436)
