Permalink
Switch branches/tags
Nothing to show
Find file Copy path
2a80c9b Aug 17, 2016
2 contributors

Users who have contributed to this file

@rocfig @espes
executable file 78 lines (64 sloc) 2.58 KB
#!/usr/bin/env python
"""
Steering angle prediction model
"""
import os
import argparse
import json
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Lambda, ELU
from keras.layers.convolutional import Convolution2D
from server import client_generator
def gen(hwm, host, port):
    """Yield ``(X, Y)`` pairs pulled from ``client_generator``.

    Every item produced by ``client_generator`` is unpacked as a 3-tuple
    whose third element is discarded.  ``Y`` is always reduced to its last
    entry along axis 1; ``X`` is reduced the same way only when its axis 1
    has length one.

    Args:
        hwm: Forwarded unchanged to ``client_generator`` (presumably a queue
            high-water mark -- confirm against the ``server`` module).
        host: Forwarded unchanged to ``client_generator``.
        port: Forwarded unchanged to ``client_generator``.

    Yields:
        A 2-tuple ``(X, Y)`` for each item received.
    """
    stream = client_generator(hwm=hwm, host=host, port=port)
    for frames, targets, _ in stream:
        # Keep only the last entry along axis 1 of the targets.
        targets = targets[:, -1]
        has_single_step = frames.shape[1] == 1
        if has_single_step:
            # No temporal context: collapse the length-one axis.
            frames = frames[:, -1]
        yield frames, targets
def get_model(time_len=1):
    """Build and compile the convolutional steering model.

    The network is a ``Sequential`` stack: an input-scaling ``Lambda``,
    three strided ``Convolution2D`` layers, a flatten step, and two
    ``Dense`` layers ending in a single output unit.  It is compiled with
    the "adam" optimizer and "mse" loss.

    Args:
        time_len: Not used by this implementation; kept so existing callers
            that pass it continue to work.

    Returns:
        The compiled ``Sequential`` model.
    """
    input_dims = (3, 160, 320)  # camera format: (ch, row, col)

    # Layers listed in the exact order they are appended to the model.
    stack = [
        # Rescale inputs so that 0 maps to -1 and 255 maps to +1.
        Lambda(lambda x: x/127.5 - 1.,
               input_shape=input_dims,
               output_shape=input_dims),
        Convolution2D(16, 8, 8, subsample=(4, 4), border_mode="same"),
        ELU(),
        Convolution2D(32, 5, 5, subsample=(2, 2), border_mode="same"),
        ELU(),
        Convolution2D(64, 5, 5, subsample=(2, 2), border_mode="same"),
        Flatten(),
        Dropout(.2),
        ELU(),
        Dense(512),
        Dropout(.5),
        ELU(),
        Dense(1),
    ]

    net = Sequential()
    for layer in stack:
        net.add(layer)

    net.compile(optimizer="adam", loss="mse")
    return net
if __name__ == "__main__":
    # Command-line entry point: parse options, train the model from the
    # streaming generators, then write the weights and the model config
    # under ./outputs/steering_model.
    parser = argparse.ArgumentParser(description='Steering angle model trainer')
    parser.add_argument('--host', type=str, default="localhost", help='Data server ip address.')
    parser.add_argument('--port', type=int, default=5557, help='Port of server.')
    parser.add_argument('--val_port', type=int, default=5556, help='Port of server for validation dataset.')
    # NOTE(review): --batch is parsed but nothing in this block consumes it;
    # presumably the batch size is decided by the data server -- confirm.
    parser.add_argument('--batch', type=int, default=64, help='Batch size.')
    parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
    parser.add_argument('--epochsize', type=int, default=10000, help='How many frames per epoch.')
    parser.add_argument('--skipvalidate', dest='skipvalidate', action='store_true', help='Skip validation during training.')
    parser.set_defaults(skipvalidate=False)
    # No --loadweights option is declared above; this only places a
    # ``loadweights`` attribute on the parsed namespace.
    parser.set_defaults(loadweights=False)
    args = parser.parse_args()

    model = get_model()

    # Validation inputs are supplied only when validation was not skipped.
    fit_kwargs = {}
    if not args.skipvalidate:
        fit_kwargs["validation_data"] = gen(20, args.host, port=args.val_port)
        fit_kwargs["nb_val_samples"] = 1000

    model.fit_generator(
        gen(20, args.host, port=args.port),
        # Honor --epochsize (its default matches the former hard-coded 10000).
        samples_per_epoch=args.epochsize,
        nb_epoch=args.epoch,
        **fit_kwargs
    )

    print("Saving model weights and configuration file.")

    output_dir = "./outputs/steering_model"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Second positional argument is passed as True, as before
    # (presumably the overwrite flag -- confirm against the Keras version in use).
    model.save_weights("./outputs/steering_model/steering_angle.keras", True)
    # NOTE(review): if to_json() already returns a JSON string, json.dump
    # stores it as a quoted JSON string.  The on-disk format is left
    # unchanged so existing readers of this file keep working.
    with open('./outputs/steering_model/steering_angle.json', 'w') as outfile:
        json.dump(model.to_json(), outfile)