In [None]:
import rasterio
import affine
from rasterio import plot
import matplotlib.pyplot as plt
import shapefile
from pyproj import CRS, transform, Transformer
import os
%matplotlib inline

In [None]:
# Side length, in pixels, of one output tile.
IMAGE_PIXEL = 15
# Number of tiles cut along each axis of a band (732 * 15 = 10980 pixels per side).
IMAGE_WIDTH = 732

# Point shapefiles holding the labelled ground-truth positions.
LIST_SHAPEFILES = [
    "heigvd/central_highlands_2_other/central_highlands_2_other.shp",
    "heigvd/central_highlands_2_test/central_highlands_2_test.shp",
    "heigvd/central_higlands_1_other/central_higlands_1_other.shp",
]

# Label code (as a string) -> name of the output sub-directory for that class.
# "0" is used for tiles without any labelled point, "-1" for tiles holding several
# different labels. The values are directory names, so their spelling is kept as is.
CORRESPONDANCES = {
    "1": "cacao",
    "2": "coffee",
    "3": "complex_oil",
    "4": "nativevege",
    "5": "oil_palm",
    "6": "rubber",
    "7": "unknown",
    "8": "seasonal",
    "9": "urban",
    "10": "water",
    "11": "other_tree",
    "12": "other_no_tree",
    "13": "native_no_tree",
    "14": "water_other",
    "15": "pepper",
    "16": "cassava",
    "17": "tea",
    "18": "rice",
    "19": "banana",
    "20": "baby_palm",
    "21": "cur_off-regrow",
    "22": "natural_wetland",
    "23": "intercrop",
    "24": "deciduous_forest",
    "25": "stick_pepper",
    "26": "flooded_plantation",
    "27": "pine_trees",
    "28": "coconut",
    "29": "banboo",
    "30": "savana",
    "31": "mango",
    "32": "other_fruit_tree_crop",
    "33": "water_mine",
    "0": "not_labeled",
    "-1": "ambiguous",
}

# EPSG code used as the target CRS when converting raster coordinates so they can be
# compared with the shapefile points.
SHAPEFILE_ESPG = 4326

In [None]:
# Load every labelled point of every shapefile once, up front.
# Each entry is (first vertex of the shape, first attribute field of its record);
# get_labels() compares the vertex with a tile's bounds and collects the field value.
points = []
for path in LIST_SHAPEFILES:
    # The context manager closes the reader's underlying files once the shapefile has
    # been read, instead of leaving one set of open handles per shapefile.
    with shapefile.Reader(path) as sf:
        shapes = sf.shapes()
        for point in sf.records():
            # record.oid is the record's index, used to fetch the matching geometry.
            points.append((shapes[point.oid].points[0], point[0]))
        
def get_labels(begin, end):
    """Return the set of labels of all loaded points lying between two positions.

    Both positions are first passed through ``convert_to_format``. A point from the
    module-level ``points`` list is kept when its first coordinate lies between the
    converted first values of ``begin`` and ``end`` (inclusive) and its second
    coordinate lies between the converted second values of ``end`` and ``begin``
    (inclusive). ``begin`` must therefore convert to the smaller first value and the
    larger second value, otherwise nothing matches.

    Args:
        begin: coordinate pair of the first corner, in the raster's coordinates.
        end: coordinate pair of the opposite corner, in the raster's coordinates.

    Returns:
        A set with the label of every matching point (empty when none matches).
    """
    lon_min, lat_max = convert_to_format(begin)
    lon_max, lat_min = convert_to_format(end)
    return {
        label
        for position, label in points
        if lon_min <= position[0] <= lon_max and lat_min <= position[1] <= lat_max
    }

In [None]:
# Paths of the rasters stacked into every output tile, in output-band order:
# for each acquisition (fall, winter, spring) the four band files B08, B04, B03, B02.
bands = [
    "gauche/{}/T48PYU_{}_{}_10m.jp2".format(season, timestamp, band_code)
    for season, timestamp in (
        ("fall", "20191025T030811"),
        ("winter", "20191224T031131"),
        ("spring", "20200318T030539"),
    )
    for band_code in ("B08", "B04", "B03", "B02")
]

In [None]:
# Build the transformer from the raster's CRS to the shapefile CRS.
# The source CRS is the dataset's own ``crs``: the corners later converted by
# convert_to_format() are computed with the dataset's affine ``transform`` and are
# therefore expressed in that CRS (``gcps[1]`` only describes ground control points).
# It is also the same ``crs`` attribute that create_smaller_images() writes into the
# tiles. The band is opened once, in a ``with`` block, so its handle is released.
with rasterio.open(bands[0], driver='JP2OpenJPEG') as reference_band:
    source_crs = reference_band.crs
transformer = Transformer.from_crs(source_crs, SHAPEFILE_ESPG)
print(source_crs)
def convert_to_format(coords):
    """Convert a raster position so it can be compared with the shapefile points.

    The first two values of ``coords`` are passed to the module-level
    ``transformer`` and the two transformed values are returned in swapped order.
    get_labels() reads the result as (longitude, latitude), so the transformer
    presumably yields (latitude, longitude) -- TODO confirm against its axis order.

    Args:
        coords: indexable holding the two coordinates of the position.

    Returns:
        A tuple ``(second transformed value, first transformed value)``.
    """
    first, second = transformer.transform(coords[0], coords[1])
    return (second, first)

In [None]:
def split_band(band, width, image_length):
    """Split one raster band into a ``width`` x ``width`` grid of square tiles.

    Args:
        band: object whose ``read(1)`` returns the band as a 2-D array supporting
            ``.shape`` and two-axis slicing (for example an opened rasterio dataset).
        width: number of tiles along each axis.
        image_length: side of one tile in pixels. It is also the step between two
            neighbouring tiles, so tiles never overlap and leave no gap.

    Returns:
        A list of ``width`` lists of ``width`` tiles. ``result[x][y]`` is the
        ``image_length`` x ``image_length`` tile whose top-left pixel is at row
        ``image_length * x`` and column ``image_length * y`` of the band, so
        ``result[x][y][i][j]`` is the band value at
        ``(image_length * x + i, image_length * y + j)``.
        Each tile is a slice of the array returned by ``band.read(1)`` (with a NumPy
        array this is a view, no pixel data is copied), not a nested Python list.

    Raises:
        IndexError: if the band has fewer than ``width * image_length`` rows or
            columns.
    """
    my_band = band.read(1)
    needed = width * image_length
    n_rows, n_cols = my_band.shape[0], my_band.shape[1]
    # Slicing silently truncates out-of-range requests, so check the size explicitly
    # to keep failing loudly on a band that is too small.
    if n_rows < needed or n_cols < needed:
        raise IndexError(
            "band of shape {}x{} is too small for {}x{} tiles of {} pixels".format(
                n_rows, n_cols, width, width, image_length
            )
        )
    # One slice per tile instead of copying pixel by pixel in Python; the offset is
    # derived from image_length so tile size and tile step always agree.
    return [
        [
            my_band[
                image_length * x:image_length * (x + 1),
                image_length * y:image_length * (y + 1),
            ]
            for y in range(width)
        ]
        for x in range(width)
    ]

In [None]:
def create_smaller_images(bands, path_name, file_name):
    """Cut the input rasters into small multi-band GeoTIFF tiles sorted by label.

    Every raster of ``bands`` is split with ``split_band`` into
    IMAGE_WIDTH x IMAGE_WIDTH tiles of IMAGE_PIXEL x IMAGE_PIXEL pixels. For each tile
    position one GeoTIFF is written that stacks this tile from every input raster
    (output band ``k`` comes from ``bands[k - 1]``). CRS, affine transform and dtype
    are taken from the first raster.

    The destination folder depends on the labels ``get_labels`` finds in the tile:

    * no label      -> ``<path_name>/<CORRESPONDANCES['0']>/<x>/``
    * several       -> ``<path_name>/<CORRESPONDANCES['-1']>/``
    * exactly one   -> ``<path_name>/<CORRESPONDANCES[str(label)]>/``

    Files are named ``<file_name>_<x>_<y>.tiff`` where ``x`` is the tile's row index
    and ``y`` its column index.

    Args:
        bands: paths of the JPEG 2000 rasters to stack, in output-band order.
        path_name: root output directory; created (with its parents) if missing.
        file_name: prefix of every tile file name.

    Raises:
        ValueError: if ``bands`` is empty.
        KeyError: if a tile's single label has no entry in ``CORRESPONDANCES``.
    """
    images = []
    crs = None
    transformFix = None
    dtype = None
    for band in bands:
        # The context manager closes the raster even if splitting fails.
        with rasterio.open(band, driver='JP2OpenJPEG') as image:
            images.append(split_band(image, IMAGE_WIDTH, IMAGE_PIXEL))
            if transformFix is None:
                # Georeferencing and dtype are taken from the first raster only.
                crs = image.crs
                transformFix = image.transform
                dtype = image.dtypes[0]
        print("Fin", band)
    if not images:
        raise ValueError("bands must contain at least one raster path")

    # makedirs also creates missing parents (e.g. "output" for "output/final_test")
    # and does not fail when the directory already exists.
    os.makedirs(path_name, exist_ok=True)
    for directory in CORRESPONDANCES.values():
        os.makedirs(os.path.join(path_name, directory), exist_ok=True)

    for x in range(IMAGE_WIDTH):
        row_off = x * IMAGE_PIXEL
        for y in range(IMAGE_WIDTH):
            col_off = y * IMAGE_PIXEL
            # The affine transform maps (column, row) pixel offsets to coordinates.
            begin = transformFix * (col_off, row_off)
            end = transformFix * (col_off + IMAGE_PIXEL, row_off + IMAGE_PIXEL)
            # Same pixel size/orientation as the source raster, origin moved to the
            # tile's top-left corner (instead of a hard-coded 10 m pixel).
            tile_transform = transformFix * affine.Affine.translation(col_off, row_off)

            # Sort the tile according to the labelled points it contains.
            labels = get_labels(begin, end)
            if len(labels) == 0:
                # Unlabelled tiles are spread over one sub-directory per row index.
                target_dir = os.path.join(path_name, CORRESPONDANCES['0'], str(x))
                os.makedirs(target_dir, exist_ok=True)
            elif len(labels) > 1:
                target_dir = os.path.join(path_name, CORRESPONDANCES['-1'])
            else:
                target_dir = os.path.join(path_name, CORRESPONDANCES[str(next(iter(labels)))])

            tile_path = os.path.join(
                target_dir, file_name + '_' + str(x) + '_' + str(y) + '.tiff'
            )
            with rasterio.open(tile_path, 'w', driver='Gtiff',
                               width=IMAGE_PIXEL, height=IMAGE_PIXEL,
                               count=len(images), crs=crs,
                               transform=tile_transform,
                               dtype=dtype) as smaller_image:
                # Output band numbers start at 1 and follow the order of ``bands``.
                for band_number, image in enumerate(images, start=1):
                    smaller_image.write(image[x][y], band_number)
        print(str(x) + "/" + str(IMAGE_WIDTH))
    print("Fichiers tiff créés.")

In [None]:
# Cut all configured rasters into labelled tiles under output/final_test/.
create_smaller_images(
    bands=bands,
    path_name="output/final_test",
    file_name="final_test",
)