In [43]:
import json
import os
import numpy as np
import datetime
import skimage
from osgeo import gdal, osr

In [86]:
class Dataset(object):
    """Accumulate SpaceNet road GeoJSON annotations into a COCO-style dataset.

    Each ``*.geojson`` annotation file is paired with a GeoTIFF image. Road
    geometries are converted from geographic coordinates to pixel coordinates
    with the image's GDAL geotransform and stored as COCO ``segmentation``
    polylines. Call ``load_data`` to ingest a directory and ``write_json`` to
    serialize the result.
    """

    def __init__(self):
        now = datetime.datetime.now()
        # NOTE(review): COCO's "info" is normally a single dict; a one-element
        # list is kept here to preserve the existing output layout.
        self.info_info = [{'year': now.year, 'version': '', 'description': '',
                           'contributor': '', 'url': '', 'date_created': str(now)}]
        # List of license dicts.
        self.license_info = [{'id': 0, 'name': 'GEOJsonToCOCO', 'url': ''}]
        # Categories are registered lazily by add_category as features are read;
        # convert_cat maps unrecognised attribute values to (0, 'BG').
        self.categories_info = []
        # Lists of COCO image / annotation dicts.
        self.images_info = []
        self.annotations_info = []

    def load_geojson_file(self, filePath, imagePath):
        """Add the road annotations from one GeoJSON file (and its image) to the dataset.

        Args:
            filePath: path to the ``.geojson`` annotation file.
            imagePath: path to the matching GeoTIFF. Its basename (minus the
                extension) must end in ``_<3-char prefix><digits>``, e.g.
                ``..._img123.tif``; the digits become the COCO image id.

        LineString and MultiLineString features become annotations (one per
        line part); any other geometry type is printed and skipped. The image
        entry is only added when this file contributed at least one annotation.

        Raises:
            IOError: if GDAL cannot open ``imagePath``.
        """
        with open(filePath) as json_file:
            geo_json = json.load(json_file)

        imageFilename = os.path.basename(imagePath)
        # '..._img123.tif' -> 'img123' -> drop the 3-character prefix -> 123
        imageID = int(os.path.splitext(imageFilename)[0].split('_')[-1][3:])

        src_raster = gdal.Open(imagePath)
        if src_raster is None:
            raise IOError('Unable to open image: {}'.format(imagePath))
        # GDAL geotransform: (origin_x, pixel_width, row_rotation,
        #                     origin_y, column_rotation, pixel_height).
        # The rotation terms are ignored, i.e. a north-up image is assumed.
        geom_transform = src_raster.GetGeoTransform()
        longCorner = geom_transform[0]
        latCorner = geom_transform[3]
        pixelWidth = geom_transform[1]
        pixelHeight = geom_transform[5]
        # RasterXSize is the column count (width), RasterYSize the row count (height).
        imageWidth = src_raster.RasterXSize
        imageHeight = src_raster.RasterYSize
        src_raster = None  # release the GDAL dataset handle

        annotations_before = len(self.annotations_info)
        for feature in geo_json['features']:
            properties = feature['properties']
            cat_id, cat_name = self.convert_cat(properties['paved'], properties['bridge_type'])
            # Categories are registered for every feature, including skipped geometry types.
            self.add_category(cat_id, cat_name)

            geometry = feature['geometry']
            if geometry['type'] == 'LineString':
                line_parts = [geometry['coordinates']]
            elif geometry['type'] == 'MultiLineString':
                line_parts = geometry['coordinates']
            else:
                print(geometry['type'])
                continue

            objectID = properties['road_id']
            for coordinatesSet in line_parts:
                pixelCoordinates = self.convert_coordinatesSet(
                    coordinatesSet, longCorner, latCorner, pixelWidth, pixelHeight)
                # NOTE(review): every part of a MultiLineString reuses the same road_id,
                # so annotation ids are not guaranteed unique as COCO expects — confirm.
                self.add_annotation(objectID, imageID, cat_id, pixelCoordinates)

        # Only list the image if this file produced at least one annotation for it.
        if len(self.annotations_info) > annotations_before:
            self.add_image(imageFilename, imageID, imageWidth, imageHeight)

    def convert_cat(self, paved_id, bridge_id):
        """Map SpaceNet road attributes to a ``(category_id, category_name)`` pair.

        ``paved`` takes precedence: 1 -> Paved, 2 -> Unpaved, 3 -> Unknown.
        Only when ``paved`` is none of those does ``bridge_type == 2`` give
        Bridge; everything else is ``(0, 'BG')``. Values are compared as
        strings, so both ``'1'`` and ``1`` are accepted.
        """
        # NOTE(review): because paved is checked first, a bridge whose paved value is
        # 1-3 is never categorised as Bridge — confirm this precedence is intended.
        paved_categories = {'1': (1, 'Paved'), '2': (2, 'Unpaved'), '3': (3, 'Unknown')}
        paved_key = str(paved_id)
        if paved_key in paved_categories:
            return paved_categories[paved_key]
        if str(bridge_id) == '2':
            return 4, 'Bridge'
        return 0, 'BG'

    def convert_coordinatesSet(self, coordinatesSet, longCorner, latCorner, pixelWidth, pixelHeight):
        """Convert a sequence of geographic points to a flat pixel-coordinate list.

        Each point is indexed as ``(x, y, ...)``; components past the second are ignored.
        Returns ``[x0, y0, x1, y1, ...]`` in pixel units (the COCO polyline layout).
        """
        pixelCoordinates = []
        for coordinates in coordinatesSet:
            pixelCoordinates.extend(self.convert_coordinates(
                coordinates[0], coordinates[1], longCorner, latCorner, pixelWidth, pixelHeight))
        return pixelCoordinates

    def convert_coordinates(self, xcoord, ycoord, longCorner, latCorner, pixelWidth, pixelHeight):
        """Convert one geographic point to ``(column, row)`` pixel coordinates.

        ``pixelHeight`` is the signed geotransform value (negative for north-up
        rasters), so ``(ycoord - latCorner) / pixelHeight`` yields increasing rows
        going south — the same formula convert_coordinatesSet relies on.
        """
        xPixel = (xcoord - longCorner) / pixelWidth
        yPixel = (ycoord - latCorner) / pixelHeight
        return xPixel, yPixel

    def add_annotation(self, objectID, imageID, category_id, roadCoordsList):
        """Append one COCO annotation for a road polyline.

        ``roadCoordsList`` is a flat ``[x0, y0, x1, y1, ...]`` pixel list. The
        bbox is its ``[x, y, width, height]`` extent (``[]`` when there are no
        points); ``area`` stays 0 because a polyline encloses no area.
        """
        xs = roadCoordsList[0::2]
        ys = roadCoordsList[1::2]
        if xs and ys:
            bbox = [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)]
        else:
            bbox = []
        self.annotations_info.append({'id': objectID,
                                      'image_id': imageID,
                                      'category_id': category_id,
                                      'segmentation': [roadCoordsList],
                                      'area': 0,
                                      'bbox': bbox,
                                      'iscrowd': 0})

    def add_image(self, imageFilename, imageID, imageWidth, imageHeight):
        """Append one COCO image record (width/height in pixels)."""
        # NOTE(review): COCO spells this key 'flickr_url'; kept as-is for output compatibility.
        self.images_info.append({'id': imageID,
                                 'width': imageWidth,
                                 'height': imageHeight,
                                 'file_name': imageFilename,
                                 'license': '',
                                 'flicker_url': '',
                                 'coco_url': '',
                                 'date_captured': ''})

    def load_data(self, annotation_dir, images_dir):
        """Load every ``*.geojson`` file in ``annotation_dir`` together with its image.

        Annotation files are expected to be named ``<prefix>geojson_roads<suffix>.geojson``.
        The matching image is ``<prefix><image type><suffix>.tif`` inside ``images_dir``,
        where the image type is the last directory component of ``images_dir``
        (e.g. ``PAN``). Files are processed in sorted order so the output is
        reproducible; files without the ``geojson_roads`` marker are skipped.
        """
        # Last path component, tolerant of a missing or present trailing slash.
        imageType = os.path.basename(os.path.normpath(images_dir))
        for jsonFile in sorted(os.listdir(annotation_dir)):
            root, ext = os.path.splitext(jsonFile)
            if ext != '.geojson':
                continue
            prefix, marker, suffix = root.partition('geojson_roads')
            if not marker:
                print('Skipping annotation file with unexpected name: {}'.format(jsonFile))
                continue
            imagePath = os.path.join(images_dir, prefix + imageType + suffix + '.tif')
            filePath = os.path.join(annotation_dir, jsonFile)
            self.load_geojson_file(filePath, imagePath)

    def add_category(self, cat_id, cat_name):
        """Register a category unless one with the same ``cat_id`` already exists."""
        if any(category['id'] == cat_id for category in self.categories_info):
            return
        self.categories_info.append({
            'id': cat_id,
            'name': cat_name,
            'supercategory': ''
        })

    def write_json(self, outputName):
        """Serialize the accumulated dataset to ``outputName`` as indented UTF-8 JSON."""
        # NOTE(review): COCO names the top-level key 'licenses'; 'license' is kept for compatibility.
        json_dict = {'license': self.license_info, 'info': self.info_info,
                     'categories': self.categories_info,
                     'images': self.images_info, 'annotations': self.annotations_info}
        with open(outputName, 'w', encoding='utf-8') as f:
            json.dump(json_dict, f, ensure_ascii=False, indent=4)
        

In [87]:
# Input/output locations for the AOI_5_Khartoum roads data.
# NOTE: keep the trailing slashes — the image type (e.g. PAN) is derived from
# the image directory's last path component.
DATA_ROOT = '/mnt/shared/bryan/spacenet-data/AOI_5_Khartoum/'
annotation_dir = DATA_ROOT + 'geojson_roads/'
image_dir = DATA_ROOT + 'PAN/'
OUTPUT_PATH = DATA_ROOT + 'test.json'
WRITE_OUTPUT = False  # set True to write the COCO JSON to OUTPUT_PATH

dataset = Dataset()
dataset.load_data(annotation_dir, image_dir)
if WRITE_OUTPUT:
    dataset.write_json(OUTPUT_PATH)