In [1]:
import matplotlib as mpl
import os
# With no X display (e.g. a headless server / remote kernel), fall back to the
# non-interactive PDF backend. This sits before the kmn imports below —
# presumably because they may import matplotlib.pyplot, after which switching
# backends is unreliable; confirm.
if 'DISPLAY' not in os.environ:
    mpl.use('Pdf')
# Project (kmn) modules used by the cells below.
from kmn.util import split_test_train_data
from kmn.tf_util import init_cfg
from kmn.dataset_generator import DatasetGenerator
from kmn.kmn_model import KmnModel, KmnMultiModel
from kmn.kmn_model_conf import KmnModelConf, KmnMultiModelConf
from kmn.model_evaluation import evaluate_model
import numpy as np
# Notebook-wide figure defaults: white background, 12 x 6 inch figures.
mpl.rcParams['figure.facecolor'] = 'w'
mpl.rcParams['figure.figsize'] = [12.0, 6.0]
# Reload edited modules (e.g. the kmn package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [1]:
# Read every run parameter from a single cfg file; later cells only consult `cfg`.
cfg_name = "kmn/scenes/box/cfgs/basic"
cfg = init_cfg(cfg_name)

# Generate the dataset, then load its train/test splits.
# NOTE(review): small=cfg['DEBUG'] presumably loads a reduced subset for quick
# debug runs — confirm in DatasetGenerator.load_dataset.
dg = DatasetGenerator(cfg)
dg.create_dataset()
x_train, y_train, x_test, y_test = dg.load_dataset(small=cfg['DEBUG'])

# Choose the KMN model class from two switches:
#   cfg['N_CONF'] > 0        -> the *Conf variants
#   cfg['MULTI_MODEL'] truthy -> the *Multi* variants
if cfg['N_CONF'] > 0:
    model_cls = KmnMultiModelConf if cfg['MULTI_MODEL'] else KmnModelConf
else:
    model_cls = KmnMultiModel if cfg['MULTI_MODEL'] else KmnModel
m = model_cls(cfg)

In [1]:
# Working copies of the loaded splits. The loop below replaces or grows the
# *_full arrays; x_train / y_train / x_test / y_test are never reassigned and are
# what every round evaluates on and augments from.
x_train_full, y_train_full = np.copy(x_train), np.copy(y_train)
x_test_full, y_test_full = np.copy(x_test), np.copy(y_test)

# One round per prediction step: train on the current *_full data, reload the
# best weights, evaluate on the original splits, then build the next round's
# training data from the model's own predictions (n_pred = round + 1).
for i_pred in range(cfg['N_PRED']):
    print("### i_pred: ", i_pred)
    # Per-round log prefix, stored in the shared cfg dict before training.
    cfg['LOG_PREFIX'] = str(i_pred) + "/"

    multi_model = cfg['MULTI_MODEL']
    # Multi-model variants are told which round they are in; single models are not.
    round_kwargs = {'i_pred': i_pred} if multi_model else {}
    eval_extra = (i_pred,) if multi_model else ()

    m.train(x_train=x_train_full, y_train=y_train_full,
            x_test=x_test_full, y_test=y_test_full, **round_kwargs)
    m.load_model(**round_kwargs)
    evaluate_model(m, cfg, x_train, y_train, x_test, y_test, *eval_extra)

    if multi_model:
        # The augmented result replaces the training data and is re-split, so
        # the test split is also regenerated every round.
        x_train_full, y_train_full = m.augment_dataset(x=x_train, y=y_train, n_pred=i_pred + 1)
        x_train_full, y_train_full, x_test_full, y_test_full = \
            split_test_train_data(x_train_full, y_train_full)
    else:
        # The augmented samples are stacked onto the accumulated training data;
        # the test split stays as it is.
        x_train_aug, y_train_aug = m.augment_dataset(x=x_train, y=y_train, n_pred=i_pred + 1)
        x_train_full = np.vstack((x_train_full, x_train_aug))
        y_train_full = np.vstack((y_train_full, y_train_aug))