In [1]:
# Environment setup: imports, a headless-safe matplotlib backend, default plot
# styling, and automatic reloading of edited modules.
import matplotlib as mpl
import os
# With no X display available (e.g. running on a headless server), switch
# matplotlib to the non-interactive PDF backend. This is done before the kmn
# imports below, presumably because they may import pyplot — NOTE(review): confirm.
if 'DISPLAY' not in os.environ:
    mpl.use('Pdf')
from kmn.util import split_test_train_data
from kmn.tf_util import init_cfg
from kmn.dataset_generator import DatasetGenerator
from kmn.kmn_model import KmnModel, KmnMultiModel
from kmn.kmn_model_conf import KmnModelConf, KmnMultiModelConf
from kmn.model_evaluation import evaluate_model
import numpy as np
# Default figure styling: white background and a 12 x 6 inch canvas.
mpl.rcParams['figure.facecolor'] = 'w'
mpl.rcParams['figure.figsize'] = [12.0, 6.0]
# Re-import edited modules (e.g. the kmn package) automatically before each
# cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

Using TensorFlow backend.


In [2]:
# Resolve every run parameter from the named cfg file.
cfg_name = "kmn/scenes/box/cfgs/basic"
cfg = init_cfg(cfg_name)

# Build the dataset, then load it; cfg['DEBUG'] selects the reduced ("small") variant.
dg = DatasetGenerator(cfg)
dg.create_dataset()
x_train, y_train, x_test, y_test = dg.load_dataset(small=cfg['DEBUG'])

# Choose the KMN model class from two switches:
#   N_CONF > 0   -> configuration-aware variants (Kmn*Conf)
#   MULTI_MODEL  -> multi-model variants (KmnMulti*)
if cfg['N_CONF'] > 0:
    model_cls = KmnMultiModelConf if cfg['MULTI_MODEL'] else KmnModelConf
else:
    model_cls = KmnMultiModel if cfg['MULTI_MODEL'] else KmnModel
m = model_cls(cfg)


descr  : 
log dir: /Users/englerpr/git/kinematic_morphing_networks/kmn/scenes/box/logs/basic_s-8d3_0725_1136/
scene  : box
cfg    : basic_s-8d3_0725_1136
data   : Transformation
model  : conv_basic
param  : ['xy scaling' '-']
80  data points for testing
20  data points for training
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
I (InputLayer)               (None, 12288)             0         
_________________________________________________________________
Im (Reshape)                 (None, 96, 128, 1)        0         
_________________________________________________________________
Scaling (Lambda)             (None, 96, 128, 1)        0         
_________________________________________________________________
conv1 (Conv2D)               (None, 96, 128, 2)        20        
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 48, 64, 2)         0 

In [3]:
# Iterative training: each round trains on the current "full" training set,
# evaluates on the original (un-augmented) data, then derives a new training
# set from the model's own predictions for the next round.
# Work on copies so re-running this cell always starts from the data loaded
# in the previous cell rather than from a previously augmented set.
x_train_full = np.copy(x_train)
y_train_full = np.copy(y_train)
x_test_full = np.copy(x_test)
y_test_full = np.copy(y_test)

# iterate over number of predictions
for i_pred in range(cfg['N_PRED']):
    print("### i_pred: ", i_pred)
    # Per-round log prefix. NOTE(review): this only takes effect if the model
    # reads cfg after construction rather than copying it — confirm.
    cfg['LOG_PREFIX'] = str(i_pred) + "/"
    if cfg['MULTI_MODEL']:
        # train the network for this round (addressed by i_pred)
        m.train(x_train=x_train_full, y_train=y_train_full, x_test=x_test_full, y_test=y_test_full,
                i_pred=i_pred)
        # load best weights
        m.load_model(i_pred=i_pred)
        # eval current network on the original data
        evaluate_model(m, cfg, x_train, y_train, x_test, y_test, i_pred)
        # Multi-model: REPLACE the full set with augment_dataset's output and
        # re-split it, so the validation set changes every round as well.
        x_train_full, y_train_full = m.augment_dataset(x=x_train, y=y_train, n_pred=i_pred + 1)
        x_train_full, y_train_full, x_test_full, y_test_full = \
            split_test_train_data(x_train_full, y_train_full)
    else:
        # train network
        m.train(x_train=x_train_full, y_train=y_train_full, x_test=x_test_full, y_test=y_test_full)
        # load best weights
        m.load_model()
        # eval current network on the original data
        evaluate_model(m, cfg, x_train, y_train, x_test, y_test)
        # Single model: APPEND the prediction-augmented samples to the
        # training set; the validation set stays the original x_test/y_test.
        x_train_aug, y_train_aug = m.augment_dataset(x=x_train, y=y_train, n_pred=i_pred + 1)
        # Append along the sample axis. For ndim >= 2 this equals np.vstack,
        # but unlike vstack it also appends 1-D arrays instead of stacking
        # them into two rows (or failing on differing lengths).
        x_train_full = np.concatenate([x_train_full, x_train_aug], axis=0)
        y_train_full = np.concatenate([y_train_full, y_train_aug], axis=0)

### i_pred:  0


Train on 80 samples, validate on 20 samples


Epoch 1/12





Epoch 2/12


Epoch 3/12





Epoch 4/12





Epoch 5/12


Epoch 6/12





Epoch 7/12





Epoch 8/12





Epoch 9/12





Epoch 10/12





Epoch 11/12





Epoch 12/12








Evaluation on Train set
[                                                                                                    ] 0%

[...........                                                                                         ] 11%

[......................                                                                              ] 22%

[.................................                                                                   ] 33%

[............................................                                                        ] 44%

[.......................................................                                             ] 55%

KeyboardInterrupt: 