# Make a population weight file using the trained random forest (RF) model

In [1]:
import os
import sys
import argparse 
import pickle
import rasterio
import pandas as pd
import numpy as np
import tqdm
import gdal
import glob
import itertools
import multiprocess as mp

In [2]:
# Step banner: one print call, with the three lines separated by newlines.
print(
    '-----------------------------------------------------------',
    'Population data process 3/7: random forest model prediction',
    '-----------------------------------------------------------',
    sep='\n',
)

-----------------------------------------------------------
Population data process 3/7: random forest model prediction
-----------------------------------------------------------


In [3]:
#setting up paths for files
#top_folder="/home/dohyungkim/population"
#ISO="SGP"
#year="2015"

In [None]:
# Command-line interface: working folder, 3-character ISO country code and
# population year (all taken as strings).
my_parser = argparse.ArgumentParser(description='initial input')
my_parser.add_argument('top_folder',metavar='top_folder',type=str,help='working folder')
my_parser.add_argument('ISO',metavar='ISO',type=str,help='3 character country iso code')
my_parser.add_argument('year',metavar='year',type=str,help='population year')
args = my_parser.parse_args()
top_folder = args.top_folder
ISO = args.ISO
year = args.year
if not os.path.isdir(top_folder):
    # Report on stderr and exit non-zero so a calling pipeline sees the
    # failure (a bare sys.exit() would report success with status 0).
    print('The path specified does not exist', file=sys.stderr)
    sys.exit(1)

In [4]:
# Locate this step's folders and load the random forest model trained at
# cluster level; it is applied below to make predictions at 100m resolution.
country_folder = os.path.join(top_folder, ISO)
wp_path = os.path.join(country_folder, "wp_data")
tif_path = os.path.join(wp_path, 'tif')
vrt_path = os.path.join(tif_path, 'vrt_tiles')
prd_path = os.path.join(wp_path, 'prd_files')
# Prediction tiles are written into prd_path. Create it up front, otherwise
# every tile write fails and no output is produced.
os.makedirs(prd_path, exist_ok=True)

rf_file = 'rf_model'
# NOTE(review): pickle.load can execute arbitrary code from the file. Only
# load a model file produced by this pipeline, never one from an external source.
with open(os.path.join(country_folder, rf_file), 'rb') as f:
    rf = pickle.load(f)

In [5]:
# Worker pool for the per-tile predictions. processes=None is the default and
# starts one worker per CPU reported by os.cpu_count().
# NOTE(review): the pool is created before rf_predict is defined. This
# presumably works because `multiprocess` serializes functions with dill;
# re-check if it is ever swapped for the stdlib multiprocessing.
p = mp.Pool(processes=None)

In [6]:
# Every entry in the tile folder becomes one prediction job (names only,
# relative to vrt_path).
with os.scandir(vrt_path) as tile_entries:
    tif_tiles = [tile_entry.name for tile_entry in tile_entries]

In [7]:
# Status line; print's default single-space separator joins the two parts.
print('start', 'random forest model prediction')

start random forest model prediction


In [8]:
def rf_predict(file_name):
    """Predict one tile with the random forest model and save it as a GeoTIFF.

    Reads all bands of ``vrt_path/file_name``. Each pixel becomes one row
    and each band one column of the matrix passed to ``rf.predict``. The
    predictions are reshaped back to the tile grid, cast to float32 and
    written as a single-band GeoTIFF to ``prd_path/file_name``, keeping the
    input tile's CRS and transform.

    Uses the module-level ``vrt_path``, ``prd_path`` and ``rf``.

    Returns:
        1 if the tile was predicted and written, 0 on any failure. Failures
        are reported on stderr so failed tiles can be identified.
    """
    try:
        # Close the input dataset as soon as its pixels and georeferencing are read.
        with rasterio.open(os.path.join(vrt_path, file_name)) as src:
            vals = src.read()
            crs = src.crs
            transform = src.transform
        n_bands, height, width = vals.shape
        # (bands, rows, cols) -> (rows*cols, bands): one row per pixel.
        features = vals.reshape(n_bands, height * width).T
        result = rf.predict(features).reshape(height, width).astype(np.float32)
        with rasterio.open(
            os.path.join(prd_path, file_name),
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=1,
            dtype=result.dtype,
            crs=crs,
            transform=transform,
        ) as dst:
            dst.write(result, 1)
        return 1
    except Exception as err:
        # Best-effort per tile: report the failure instead of swallowing it,
        # but keep the 0/1 return contract the caller counts on.
        print(f'rf_predict failed for {file_name}: {err!r}', file=sys.stderr)
        return 0

In [9]:
# Run rf_predict on every tile in the worker pool, with a tqdm progress bar.
# imap_unordered yields results as tiles finish, so `results` does not follow
# the order of tif_tiles; each entry is 1 (success) or 0 (failure).
results = list(tqdm.tqdm(p.imap_unordered(rf_predict, tif_tiles), total=len(tif_tiles)))
n_failed = results.count(0)
if n_failed:
    print(f'warning: random forest prediction failed for {n_failed} of {len(tif_tiles)} tiles',
          file=sys.stderr)


100%|██████████| 2730/2730 [04:15<00:00, 10.70it/s]


In [29]:
# Shut down the worker pool: close() stops it accepting new tasks and must
# come before join(), which waits for all worker processes to exit.
p.close()
p.join()

In [30]:
# Collect the prediction GeoTIFFs written by rf_predict. The folder path is
# escaped so glob metacharacters in it ('[', '*', '?') are matched literally,
# and the list is sorted so the VRT source order is deterministic.
prd_files = sorted(glob.glob(os.path.join(glob.escape(prd_path), '*.tif')))

In [31]:
# Mosaic all prediction tiles into one virtual raster, prd_path/rf_predict.vrt.
# Resolve the paths before changing directory. With a relative top_folder,
# prd_files and the output path are relative to the original working
# directory and would no longer resolve after os.chdir.
vrt_output = os.path.abspath(os.path.join(prd_path, "rf_predict.vrt"))
vrt_sources = [os.path.abspath(tile) for tile in prd_files]
if not vrt_sources:
    sys.exit('No prediction tiles found in ' + prd_path)
os.chdir(prd_path)
vrt = gdal.BuildVRT(vrt_output, vrt_sources)
# Without gdal.UseExceptions(), BuildVRT signals failure by returning None.
if vrt is None:
    sys.exit('gdal.BuildVRT failed to create ' + vrt_output)
vrt.FlushCache()