In [0]:
cd ../code

In [0]:
import numpy as np
import sklearn.model_selection as sk

In [0]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
import torch_tools as tt

In [0]:
import train_data
import train_tools

In [0]:
import warnings
warnings.filterwarnings('ignore')

In [0]:
plt = plotter()
import matplotlib as mpl
%config InlineBackend.figure_format = 'retina'
%matplotlib inline

In [0]:
# Root directory (relative to the working directory) under which the
# per-channel source directories passed to tt.ImageDataset are looked up.
data_path = '../data/tiles_fast'

In [0]:
# Experiment configuration; each value is consumed by the cells further down.
source = 'asie'  # 'asie' or 'census': selects the firm loader and is part of the source directory name
year = 2003  # passed to the firm loader and appended to `source` in the directory name
channel = ['density', 'landsat']  # one source sub-directory per entry; len(channel) is passed to tt.make_resnet
landsat = 'mincloud2002'  # forwarded to the firm loader; presumably names the Landsat composite — TODO confirm
size = 1024  # forms the '{size}px' component of the source directory
ivar = 'id'  # not referenced in the cells visible here; presumably the firm id column — TODO confirm
yvar = 'log_tfp'  # firm-table column passed to tt.ImageDataset (NaNs dropped); presumably the regression target
split = 'geo' # 'geo': train_tools.categ_split on the 'city' column; 'rand': sk.train_test_split over rows

In [0]:
# Training / data-loading parameters.
pix = 256  # not referenced in the cells visible here
val_frac = 0.2  # fraction handed to the train/test split as the held-out share
batch_size = 128  # batch size given to both DataLoaders through load_args
buffer = 10000  # not referenced in the cells visible here

In [0]:
# Load the firm table for the configured source and year.
if source == 'asie':
    firms = train_data.load_asie_firms(year, landsat, drop=False)
elif source == 'census':
    firms = train_data.load_census_firms(year, landsat)
else:
    # Fail here with a clear message instead of leaving `firms` undefined and
    # hitting a NameError in a later cell.
    raise ValueError(f"unknown source {source!r}; expected 'asie' or 'census'")

In [0]:
# Quick look at the 'sic2' column: number of distinct codes, then a bar chart
# of how many rows carry each code (bars ordered by code).
sic2_codes = firms['sic2']
print(sic2_codes.nunique())
sic2_counts = sic2_codes.value_counts().sort_index()
sic2_counts.plot.bar();

In [0]:
# Hold out `val_frac` of the firms for evaluation. One fixed-seed RandomState
# feeds whichever split is selected, so both modes are reproducible.
state = np.random.RandomState(21921351)
if split == 'geo':
    # Categorical split on the 'city' column (presumably whole cities go to one
    # side, so train and test do not share locations — confirm in train_tools).
    df_train, df_test = train_tools.categ_split(firms, 'city', val_frac, state=state)
elif split == 'rand':
    # Plain random split over rows.
    df_train, df_test = sk.train_test_split(firms, test_size=val_frac, random_state=state)
else:
    # A typo in `split` must not silently select the random split.
    raise ValueError(f"unknown split {split!r}; expected 'geo' or 'rand'")
# Realised test share; it can differ from val_frac.
print(len(df_test)/(len(firms)))

In [0]:
# Seed torch before anything random happens (weight initialisation, DataLoader
# shuffling) so repeated Restart-and-Run-All runs start from the same state.
# The value is the seed already used for the data split.
torch.manual_seed(21921351)
# Build the network from the number of configured channels and move it to the GPU.
model = tt.make_resnet(len(channel)).to('cuda')

In [0]:
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=0.001,
)

In [0]:
# Per-image preprocessing pipeline: a single ToTensor step.
img_proc = transforms.Compose([transforms.ToTensor()])

In [0]:
# One source directory per configured channel:
#   <data_path>/<source><year>/<channel>/<size>px
sources = [
    f'{data_path}/{source}{year}/{ch}/{size}px'
    for ch in channel
]
# Keyword arguments shared by the train and test DataLoaders.
load_args = {
    'batch_size': batch_size,
    'pin_memory': True,
    'num_workers': 1,
}

In [0]:
# Training data: the non-missing `yvar` values of the training firms, paired
# with the channel sources and the preprocessing pipeline. Batches are shuffled.
train_dataset = tt.ImageDataset(
    sources,
    df_train[yvar].dropna(),
    transform=img_proc,
)
train_loader = tt.data.DataLoader(train_dataset, **load_args, shuffle=True)

In [0]:
# Held-out data, built the same way from the test firms; no shuffle argument
# is passed to this loader.
test_dataset = tt.ImageDataset(
    sources,
    df_test[yvar].dropna(),
    transform=img_proc,
)
test_loader = tt.data.DataLoader(
    test_dataset,
    **load_args,
)

In [0]:
# Alternate one tt.train pass over the training loader with one tt.test pass
# over the test loader, for a fixed number of epochs.
num_epochs = 10
for epoch in range(num_epochs):
    tt.train(model, train_loader, optimizer, epoch)
    tt.test(model, test_loader, epoch)

In [0]:
# Run the model over the test loader and keep the two values tt.evaluate
# returns. Both have .cpu() called on them before plotting, so they are
# presumably tensors (targets and predictions?) — TODO confirm in torch_tools.
test_stat, test_pred = tt.evaluate(model, test_loader)

In [0]:
# Two side-by-side axes for train_tools.eval_model; both evaluation outputs are
# moved to the CPU first. qmin/qmax are passed through as 0.01 and 0.99.
fig, axs = plt.subplots(figsize=(10, 4), ncols=2)
train_tools.eval_model(
    test_stat.cpu(),
    test_pred.cpu(),
    qmax=0.99,
    qmin=0.01,
    axs=axs,
)

In [0]:
# Persist the trained weights. The file name is derived from the target column
# (for yvar = 'log_tfp' this is models/resnet_log_tfp.pt, as before), so a run
# on another target does not overwrite this checkpoint. The output directory is
# created first because torch.save does not create missing directories.
import os
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), f'models/resnet_{yvar}.pt')