In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import subprocess
from pathlib import Path

import osr
import gdal
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from fastai.utils.collect_env import *
import numpy as np
# Keep PIL's Image after the fastai star imports: fastai.vision also exports a name `Image`,
# and the processing cell relies on `Image` being PIL's (Image.open / Image.NEAREST).
from PIL import Image
import geopandas as gp

# --- Folder layout (relative to the notebook's working directory) ---
raster_in_folder_name = 'raster_in'      # rasters to convert
data_folder_name = 'raster_to_shape'     # scratch space: tiles/, masks/, shapes/
shape_out_folder_name = 'shape_out'      # merged shapefiles are written here

config_files_path = Path('config_files')
segmenter_path = Path('model')
raster_in_path = Path(raster_in_folder_name)
data_tiles_path = Path(data_folder_name) / 'tiles'
data_masks_path = Path(data_folder_name) / 'masks'

# Tile size in pixels passed to gdal_retile (-ps).
tile_width = tile_height = 256

# `type` is the value stored in every output feature's 'type' column; `model` is the
# file name of the exported learner inside `segmenter_path`.
# NOTE(review): `type` shadows the builtin; kept because the processing cell reads this name.
type = 'roads'
model = 'roads2'

# Segmentation class labels from codes.txt; a label's position is its class id.
codes = np.loadtxt(config_files_path / 'codes.txt', dtype=str)
name2id = dict(zip(codes, range(len(codes))))
void_code = name2id['Void']

def acc_camvid(input, target):
    """Pixel accuracy of a segmentation prediction, ignoring pixels labelled ``void_code``.

    Args:
        input: per-class scores with the class axis at dim 1 (reduced with argmax).
        target: ground-truth class ids; dim 1 is squeezed away (presumably (N, 1, H, W)
            — TODO confirm).

    Returns:
        0-d float tensor: fraction of non-void pixels whose arg-max class equals the target.
    """
    labels = target.squeeze(1)
    keep = labels.ne(void_code)  # exclude pixels whose ground truth is the void class
    predicted = input.argmax(dim=1)
    return predicted[keep].eq(labels[keep]).float().mean()

# `metrics` is not read anywhere else in this notebook. acc_camvid still has to be defined
# before load_learner runs — presumably the exported learner references it by name (TODO confirm).
metrics = acc_camvid

# NOTE(review): fastai's load_learner unpickles the exported file — only load model files you trust.
learn = load_learner(segmenter_path, file=model)

In [None]:
# Convert every raster in `raster_in_folder_name` into a shapefile in `shape_out_folder_name`:
# tile it, segment each tile with `learn`, georeference and polygonize each predicted mask,
# merge the per-tile shapefiles, then add ID / centroid / type / area attributes.

_GDAL_SCRIPTS_DIR = os.path.join('venv', 'bin')  # GDAL's Python utility scripts


def _run_gdal_script(script, *args):
    """Run a GDAL utility from `_GDAL_SCRIPTS_DIR`; raises CalledProcessError on failure.

    Arguments are passed as a list rather than through a shell, so paths containing
    spaces or shell metacharacters cannot break (or inject into) the command.
    """
    subprocess.run([os.path.join(_GDAL_SCRIPTS_DIR, script), *map(str, args)], check=True)


def _clear_dir(directory):
    """Delete every regular file directly inside `directory`."""
    for entry in os.listdir(directory):
        path = os.path.join(directory, entry)
        if os.path.isfile(path):
            os.remove(path)


def _predict_mask(tile_path, mask_path, size):
    """Segment one tile with `learn`, save the mask to `mask_path` resized to `size`, return it.

    `size` is (width, height) in pixels; the resized mask is returned as a NumPy array.
    """
    _, pred_idx, _ = learn.predict(open_image(tile_path))
    # Pixels predicted as class 0 stay 0 in the saved mask.
    # NOTE(review): assumes a binary prediction; with more classes `* 255` can wrap in uint8.
    ImageSegment(pred_idx * 255).save(mask_path)
    with Image.open(mask_path) as saved:
        # NEAREST keeps label values intact; smooth resampling would invent in-between
        # values along class boundaries, which polygonize turns into spurious polygons.
        resized = saved.resize(size, resample=Image.NEAREST)
    resized.save(mask_path)
    return np.asarray(resized)


def _write_georeferenced(path, array, projection, geotransform):
    """Write a 2-D `array` to a single-band Byte GeoTIFF carrying `projection`/`geotransform`.

    The dataset is flushed and closed before returning, so external tools see complete data.
    Raises RuntimeError if GDAL cannot create the file.
    """
    height, width = array.shape
    dataset = gdal.GetDriverByName("GTiff").Create(path, width, height, 1, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError('GDAL could not create ' + path)
    dataset.SetProjection(projection)
    dataset.SetGeoTransform(geotransform)
    dataset.GetRasterBand(1).WriteArray(array)
    dataset.FlushCache()
    dataset = None  # dropping the last reference closes the GDAL dataset


def _merge_shapes(shapes_dir, out_base):
    """Merge all `*.shp` files in `shapes_dir` into one layer at `out_base + '.shp'`.

    Existing shapefile components at `out_base` are removed first so ogrmerge can write.
    Raises RuntimeError if there is nothing to merge.
    """
    for ext in ('.shp', '.dbf', '.prj', '.cpg', '.shx'):
        if os.path.exists(out_base + ext):
            os.remove(out_base + ext)
    inputs = sorted(str(p) for p in Path(shapes_dir).glob('*.shp'))
    if not inputs:
        raise RuntimeError('no per-tile shapefiles found in ' + shapes_dir)
    _run_gdal_script('ogrmerge.py', '-single', '-o', out_base + '.shp', *inputs)


def _add_attributes(shp_path, feature_type):
    """Rewrite `shp_path` with ID, x, y, type and area columns, reprojected to EPSG:4326.

    x/y are polygon centroids in the layer's original CRS; `area` is measured in EPSG:3116.
    The polygonize value column 'DN' is dropped.
    """
    gdf = gp.read_file(shp_path)
    centroids = gdf.centroid  # NOTE(review): computed in the source CRS, before reprojection
    gdf['x'] = centroids.x
    gdf['y'] = centroids.y
    gdf['type'] = feature_type
    gdf = gdf.drop(columns=['DN'])
    gdf.insert(0, 'ID', range(1, 1 + len(gdf)))
    # 'EPSG:XXXX' strings replace the deprecated {'init': 'epsg:XXXX'} dict form.
    gdf['area'] = gdf.to_crs('EPSG:3116').area
    gdf.to_crs('EPSG:4326').to_file(shp_path)


tiles_dir = os.path.join(data_folder_name, 'tiles')
masks_dir = os.path.join(data_folder_name, 'masks')
shapes_dir = os.path.join(data_folder_name, 'shapes')
for folder in (tiles_dir, masks_dir, shapes_dir, shape_out_folder_name):
    os.makedirs(folder, exist_ok=True)

for rasterin in sorted(os.listdir(raster_in_folder_name)):
    raster_file = os.path.join(raster_in_folder_name, rasterin)
    if rasterin.startswith('.') or not os.path.isfile(raster_file):
        continue  # skip hidden files (e.g. .DS_Store) and sub-folders

    # Start from empty scratch folders so leftovers from an interrupted run are not merged in.
    for folder in (tiles_dir, masks_dir, shapes_dir):
        _clear_dir(folder)

    _run_gdal_script('gdal_retile.py', '-ps', tile_width, tile_height,
                     '-targetDir', tiles_dir, raster_file)

    for img_name in sorted(os.listdir(tiles_dir)):
        tile_path = os.path.join(tiles_dir, img_name)
        mask_path = os.path.join(masks_dir, 'mask_' + img_name)
        georef_path = os.path.join(masks_dir, 'srs_' + img_name)

        tile = gdal.Open(tile_path)
        geotransform = tile.GetGeoTransform()
        projection = tile.GetProjection()
        # Resize the prediction to this tile's real size: edge tiles from gdal_retile can be
        # smaller than tile_width x tile_height, and the geotransform only fits the real size.
        tile_size = (tile.RasterXSize, tile.RasterYSize)
        tile = None  # close the tile; only its metadata is needed

        mask = _predict_mask(tile_path, mask_path, tile_size)
        _write_georeferenced(georef_path, mask, projection, geotransform)
        _run_gdal_script('gdal_polygonize.py', '-8', '-f', 'ESRI Shapefile', '-mask', mask_path,
                         georef_path, os.path.join(shapes_dir, img_name + '.shp'))

    # splitext keeps dots inside the base name ("a.2020.tif" -> "a.2020"); split('.') did not.
    out_base = os.path.join(shape_out_folder_name, os.path.splitext(rasterin)[0])
    _merge_shapes(shapes_dir, out_base)

    for folder in (tiles_dir, shapes_dir, masks_dir):
        _clear_dir(folder)

    _add_attributes(out_base + '.shp', type)