In [0]:
import os
from glob import glob
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import mectools.data as dt
import sklearn.model_selection as sk
import statsmodels.formula.api as smf

In [0]:
import data
import tools
import models

In [0]:
import warnings
warnings.filterwarnings('ignore')

In [0]:
plt = plotter(backend='agg')
import matplotlib as mpl
import seaborn as sns
%matplotlib inline

In [0]:
# Enable memory growth on every visible GPU (tf.config.experimental.set_memory_growth),
# printing each device as it is configured.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    print(device)
    tf.config.experimental.set_memory_growth(device, True)

In [0]:
# Gates the `fig.savefig(...)` calls wrapped in `if save:` below; when True those
# figures are written under ../docs/images/.
save = False

In [0]:
# Data source: `source` selects the firm loader ('asie' or 'census'); `year` and
# `landsat` are forwarded to that loader.
source = 'asie'
year = 2003
landsat = 'mincloud2002'

# Image inputs: one tile directory per entry in `channel`, read from the `{size}px` subfolder.
channel = ['density']
size = 1024

# Columns: `ivar` holds the firm id fed to the tile loader, `yvar` the regression target.
ivar = 'id'
yvar = 'log_tfp'

In [0]:
# Training / pipeline settings.
pix = 256          # first argument to the models.gen_* builders (presumably input side length in pixels)
val_frac = 0.2     # fraction of firms held out for validation
batch_size = 128   # tf.data batch size
buffer = 10_000    # tf.data shuffle buffer size

In [0]:
# Load the firm-level table for the configured survey and Landsat composite.
if source == 'asie':
    firms = data.load_asie_firms(year, landsat)
elif source == 'census':
    firms = data.load_census_firms(year, landsat)
else:
    # Fail loudly: without this branch `firms` would be undefined on a fresh
    # kernel, or silently left over from a previous run with another source.
    raise ValueError(f'unknown source: {source!r}')

In [0]:
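# Alternative: random geographic split by city (presumably holds out whole cities); uncomment to use instead of the pure random split below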
# state = np.random.RandomState(21921351)
# df_train, df_valid = tools.categ_split(firms, 'city', val_frac, state=state)
# print(len(df_valid)/(len(firms)))

In [0]:
# Pure random split. A fixed random_state makes the train/validation partition,
# and therefore every metric and figure below, reproducible across kernel restarts.
df_train, df_valid = sk.train_test_split(firms, test_size=val_frac, random_state=21921351)

In [0]:
tile_path = f'../data/tiles_fast/{source}{year}'

def parse_function(fid, out):
    """Map a (firm id, label) pair to ((firm id, stacked tile image), label).

    One tile is loaded per entry in `channel` from
    ``{tile_path}/{channel}/{size}px`` via ``data.load_tile`` and the tiles are
    concatenated along the last axis.
    """
    layers = []
    for ch in channel:
        ch_dir = f'{tile_path}/{ch}/{size}px'
        layers.append(data.load_tile(fid, ch_dir))
    stacked = tf.concat(layers, axis=-1)
    return (fid, stacked), out

In [0]:
def make_dataset(df):
    """Build a shuffled, batched, infinitely repeating tf.data pipeline.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the id column named by `ivar` and the target column
        named by `yvar`.

    Returns
    -------
    tf.data.Dataset
        Batches of ``((fid, image), label)`` as produced by `parse_function`,
        batched by `batch_size`; ``label`` is float32 with shape ``(batch, 1)``.
        The dataset repeats forever, so callers must bound epochs with
        ``steps_per_epoch`` / ``validation_steps``.
    """
    fids = tf.constant(df[ivar])
    labels = tf.reshape(tf.cast(tf.constant(df[yvar]), tf.float32), (-1, 1))
    # Named `ds` (not `data`) so it cannot be confused with the project `data` module.
    ds = tf.data.Dataset.from_tensor_slices((fids, labels))
    # Shuffle the lightweight (id, label) pairs *before* loading tiles: the
    # shuffle buffer then holds `buffer` ids instead of `buffer` decoded images.
    ds = ds.shuffle(buffer_size=buffer)
    ds = ds.map(parse_function)
    ds = ds.batch(batch_size)
    ds = ds.repeat()
    return ds

In [0]:
# Input pipelines for the training and validation partitions.
train = make_dataset(df_train)
valid = make_dataset(df_valid)

In [0]:
# Medium-size model; presumably consumes the (fid, image) inputs built by parse_function.
n_channels = len(channel)
model = models.gen_dual_medium(pix, n_channels)
model.summary()

In [0]:
# Train the Keras model. Both datasets repeat indefinitely (make_dataset calls
# .repeat()), so steps_per_epoch / validation_steps bound each epoch.
history = model.fit(
    train,
    validation_data=valid,
    epochs=5,
    steps_per_epoch=2000,
    validation_steps=100,
)

In [0]:
# Pull a single validation batch for a quick sanity check.
valid_batches = iter(valid)
x00, y00 = next(valid_batches)

In [0]:
# Sanity check: model predictions on the sample validation batch.
model.predict(x=x00)

In [0]:
# Evaluate the medium model on validation data (100 is passed to tools.predict_data,
# presumably a batch count); figure is saved only when `save` is set.
x_test, y_test, yh_test = tools.predict_data(model, valid, 100)
fig, (ax0, ax1) = plt.subplots(figsize=(11, 5), ncols=2)
tools.eval_model(y_test, yh_test, axs=(ax0, ax1), N=10, qmin=0.02, qmax=0.98)
if save:
    fig.savefig('../docs/images/asie_tfp_medium_valid.svg')

In [0]:
# Same evaluation on training data, for comparison against the validation plot.
x_test, y_test, yh_test = tools.predict_data(model, train, 100)
fig, (ax0, ax1) = plt.subplots(figsize=(11, 5), ncols=2)
tools.eval_model(y_test, yh_test, axs=(ax0, ax1), N=10, qmin=0.02, qmax=0.98)
if save:
    fig.savefig('../docs/images/asie_tfp_medium_train.svg')

## City Oracle

In [0]:
# City fixed-effects benchmark: in-sample OLS of the target on city dummies,
# fit on all firms (train and validation alike).
ret_city = smf.ols(f'{yvar} ~ C(city)', data=firms).fit()
# statsmodels' formula API drops rows with missing values by default, so pair
# the fitted values with the observed target for exactly the rows that were fit;
# comparing against the full column would misalign (or length-mismatch) y and yhat.
yh_city = ret_city.fittedvalues
y_city = firms.loc[yh_city.index, yvar]
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(11, 5))
tools.eval_model(y_city, yh_city.values, qmin=0.02, qmax=0.98, N=10, axs=(ax0, ax1))

## Radial Pooling

In [0]:
# Radial-pooling model. NOTE(review): the meaning of the third argument (3) is
# defined inside models.gen_radial_pool — give it a named constant once confirmed.
n_chan = len(channel)
model_radial = models.gen_radial_pool(pix, n_chan, 3)
model_radial.summary()

In [0]:
# Train the radial model; the datasets repeat forever, so step counts bound each epoch.
history_radial = model_radial.fit(
    train,
    validation_data=valid,
    epochs=5,
    steps_per_epoch=1000,
    validation_steps=100,
)

In [0]:
# Inspect the trained radial model's weights (Keras returns them as a list of arrays,
# displayed in full as this cell's output).
model_radial.get_weights()

In [0]:
# Evaluate the radial model on validation data.
x_test, y_test, yh_test = tools.predict_data(model_radial, valid, 100)
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(11, 5))
tools.eval_model(y_test, yh_test, N=10, axs=(ax0, ax1), qmin=0.02, qmax=0.98)
# Guarded like the medium-model cells, so a plain re-run does not overwrite the docs images.
if save:
    fig.savefig('../docs/images/asie_tfp_radial_valid.svg')

In [0]:
# Evaluate the radial model on training data.
x_test, y_test, yh_test = tools.predict_data(model_radial, train, 100)
fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(11, 5))
tools.eval_model(y_test, yh_test, qmin=0.02, qmax=0.98, N=10, axs=(ax0, ax1))
# Guarded like the medium-model cells, so a plain re-run does not overwrite the docs images.
if save:
    fig.savefig('../docs/images/asie_tfp_radial_train.svg')