In [None]:
import os
import cv2
import json
import numpy as np
import pandas
import skimage
import shapefile
import matplotlib.pyplot as plt

import modules

### load data

In [None]:
class DatasetGenerator:
    """Build batched ``tf.data`` datasets of (image, one-hot class label) pairs.

    Per-country tables are loaded through the project-local ``modules.data``
    helpers. Each row is labelled by its ``"class"`` column, using the index
    order given in ``CLASSES``.

    Config keys read by this class:
        image_size      -- passed to the image-filename loader; also the declared
                           (image_size, image_size) image shape of the dataset.
        batch_size      -- batch size (the final partial batch is dropped).
        shuffle_buffer  -- optional; shuffle buffer size (missing/falsy = no shuffle).
        sample          -- optional; ``{"size": n}`` draws ``n`` valid rows per
                           class instead of using every valid row.

    NOTE(review): ``tf`` (TensorFlow) is referenced but not imported anywhere in
    the visible file; it must be in scope before ``generate_kenya`` is called.
    """

    # Label index order: major => 0, minor => 1, two-track => 2.
    CLASSES = ("major", "minor", "two-track")

    def __init__(self, config):
        self.config = config

        # Per-country caches keyed by country name; populated by ``_setup``.
        self.filenames = {}
        self.dataframes = {}
        self.shapefiles = {}

    def _setup(self, country):
        """Load and cache the merged table, shapefile and image-filename set for *country*.

        Raises:
            ValueError: if *country* is not "kenya" or "peru".
        """
        if country not in ("kenya", "peru"):
            raise ValueError("Country must be either \'kenya\' or \'peru\'.")

        geo = modules.data.load_geodata(country)
        osm, sf = modules.data.load_shapefile(country)

        self.shapefiles[country] = sf
        # Join the two tables on their shared "index" column.
        self.dataframes[country] = pandas.DataFrame.merge(geo, osm, on="index")
        # NOTE(review): this set is cached but never used to filter rows;
        # confirm whether rows without a saved image should be dropped.
        self.filenames[country] = set(
            modules.data.util.load_image_filenames(country, D=self.config["image_size"])
        )

    def _valid(self, country):
        """Return the rows of the cached *country* table whose ``"valid"`` column equals True."""
        df = self.dataframes[country]
        return df[df["valid"] == True]  # noqa: E712 -- keep the original `== True` comparison

    def _one_hot(self, classes):
        """Return a float one-hot matrix of shape ``(len(classes), len(CLASSES))``.

        The column width is always ``len(CLASSES)``, even when some class is
        absent from *classes*.

        Raises:
            KeyError: if a class name is not listed in ``CLASSES``.
        """
        position = {name: i for i, name in enumerate(self.CLASSES)}
        indices = np.array([position[name] for name in classes], dtype=int)
        return np.eye(len(self.CLASSES))[indices]

    def sample(self, country):
        """Draw ``config["sample"]["size"]`` distinct valid rows per class.

        Returns:
            ``(filenames, labels)``: filenames in class order (all "major" rows,
            then "minor", then "two-track") and the matching one-hot label matrix.

        Raises:
            ValueError: propagated from ``np.random.choice`` when a class has
                fewer valid rows than the requested size.
        """
        size = self.config["sample"]["size"]
        valid = self._valid(country)

        def _sample(cls):
            candidates = valid.index[valid["class"] == cls]
            # np.random.choice returns index *labels*, so select them with .loc
            # (.iloc is only correct when the index is a default RangeIndex).
            return valid.loc[np.random.choice(candidates, size=size, replace=False)]

        df = pandas.concat([_sample(cls) for cls in self.CLASSES])

        filenames = self._extract_filenames(df)
        # Derive labels from the rows themselves rather than from their position.
        labels = self._one_hot(df["class"])
        return filenames, labels

    def generate_kenya(self):
        """Build the batched (and optionally shuffled) Kenya dataset.

        Uses ``sample`` when ``"sample"`` is in the config; otherwise uses every
        valid row whose class is listed in ``CLASSES``.
        """
        self._setup("kenya")

        if "sample" in self.config:
            filenames, labels = self.sample("kenya")
        else:
            df = self._valid("kenya")
            # Drop rows of unknown class rather than silently labelling them 0.
            df = df[df["class"].isin(self.CLASSES)]
            filenames = self._extract_filenames(df)
            labels = self._one_hot(df["class"])

        image_size = self.config["image_size"]
        dataset = tf.data.Dataset.from_generator(
            # from_generator needs a callable that returns a fresh iterator.
            lambda: self._generator(filenames, labels),
            output_types=(tf.int32, tf.int32),
            # Labels are one-hot rows, so their shape is (number of classes,).
            output_shapes=((image_size, image_size), (len(self.CLASSES),)),
        )

        shuffle_buffer = self.config.get("shuffle_buffer")
        if shuffle_buffer:
            dataset = dataset.shuffle(shuffle_buffer)

        return dataset.batch(self.config["batch_size"], drop_remainder=True)

    def generate_peru(self):
        """Not implemented yet."""
        raise NotImplementedError()

    def _extract_filenames(self, df):
        """Return ``"<row index>_<int(id)>.npy"`` for every row of *df*, in row order."""
        return [
            "{}_{}.npy".format(idx, int(row_id))
            for idx, row_id in zip(df.index, df["id"])
        ]

    @staticmethod
    def _generator(filenames, labels):
        """Yield ``(np.load(filename), label)`` pairs, pairing the two sequences in order."""
        # NOTE(review): filenames are bare "<index>_<id>.npy" names, so np.load
        # resolves them against the current working directory -- confirm.
        for filename, label in zip(filenames, labels):
            yield np.load(filename), label