diff --git a/policy/openbot/associate_frames.py b/policy/openbot/associate_frames.py index b62fb3bf0..35461b546 100644 --- a/policy/openbot/associate_frames.py +++ b/policy/openbot/associate_frames.py @@ -107,15 +107,20 @@ def associate(first_list, second_list, max_offset): return matches -def match_frame_ctrl_cmd( - data_dir, datasets, max_offset, redo_matching=False, remove_zeros=True +def match_frame_ctrl_input( + data_dir, + datasets, + max_offset, + redo_matching=False, + remove_zeros=True, + policy="autopilot", ): frames = [] for dataset in datasets: for folder in utils.list_dirs(os.path.join(data_dir, dataset)): session_dir = os.path.join(data_dir, dataset, folder) frame_list = match_frame_session( - session_dir, max_offset, redo_matching, remove_zeros + session_dir, max_offset, redo_matching, remove_zeros, policy ) for timestamp in list(frame_list): frames.append(frame_list[timestamp][0]) @@ -123,8 +128,24 @@ def match_frame_ctrl_cmd( def match_frame_session( - session_dir, max_offset, redo_matching=False, remove_zeros=True + session_dir, max_offset, redo_matching=False, remove_zeros=True, policy="autopilot" ): + + if policy == "autopilot": + matched_frames_file_name = "matched_frame_ctrl_cmd.txt" + processed_frames_file_name = "matched_frame_ctrl_cmd_processed.txt" + log_file = "indicatorLog.txt" + csv_label_string = "timestamp (frame),time_offset (cmd-frame),time_offset (ctrl-frame),frame,left,right,cmd\n" + csv_label_string_processed = "timestamp,frame,left,right,cmd\n" + elif policy == "point_goal_nav": + matched_frames_file_name = "matched_frame_ctrl_goal.txt" + processed_frames_file_name = "matched_frame_ctrl_goal_processed.txt" + log_file = "goalLog.txt" + csv_label_string = "timestamp (frame),time_offset (goal-frame),time_offset (ctrl-frame),frame,left,right,dist,sinYaw,cosYaw\n" + csv_label_string_processed = "timestamp,frame,left,right,dist,sinYaw,cosYaw\n" + else: + raise Exception("Unknown policy") + sensor_path = os.path.join(session_dir, "sensor_data") img_path = os.path.join(session_dir, "images") print("Processing folder %s" % (session_dir)) @@ -156,7 +177,7 @@ def match_frame_session( print(" Frames and controls matched.") if not redo_matching and os.path.isfile( - os.path.join(sensor_path, "matched_frame_ctrl_cmd.txt") + os.path.join(sensor_path, matched_frames_file_name) ): print(" Frames and commands already matched.") else: @@ -164,58 +185,79 @@ def match_frame_session( frame_list = read_file_list(os.path.join(sensor_path, "matched_frame_ctrl.txt")) if len(frame_list) == 0: raise Exception("Empty matched_frame_ctrl.txt") - cmd_list = read_file_list(os.path.join(sensor_path, "indicatorLog.txt")) - # Set indicator signal to 0 for initial frames - if len(cmd_list) == 0 or sorted(frame_list)[0] < sorted(cmd_list)[0]: - cmd_list[sorted(frame_list)[0]] = ["0"] + cmd_list = read_file_list(os.path.join(sensor_path, log_file)) + + if policy == "autopilot": + # Set indicator signal to 0 for initial frames + if len(cmd_list) == 0 or sorted(frame_list)[0] < sorted(cmd_list)[0]: + cmd_list[sorted(frame_list)[0]] = ["0"] + + elif policy == "point_goal_nav": + if len(cmd_list) == 0: + raise Exception("Empty goalLog.txt") + matches = associate(frame_list, cmd_list, max_offset) - with open(os.path.join(sensor_path, "matched_frame_ctrl_cmd.txt"), "w") as f: - f.write( - "timestamp (frame),time_offset (cmd-frame),time_offset (ctrl-frame),frame,left,right,cmd\n" - ) + with open(os.path.join(sensor_path, matched_frames_file_name), "w") as f: + f.write(csv_label_string) for a, b in matches: f.write( "%d,%d,%s,%s\n" % (a, b - a, ",".join(frame_list[a]), ",".join(cmd_list[b])) ) - print(" Frames and commands matched.") + print(" Frames and high-level commands matched.") if not redo_matching and os.path.isfile( - os.path.join(sensor_path, "matched_frame_ctrl_cmd_processed.txt") + os.path.join(sensor_path, processed_frames_file_name) ): print(" Preprocessing already completed.") else: # Cleanup: Add path and remove frames where vehicle was stationary - frame_list = read_file_list( - os.path.join(sensor_path, "matched_frame_ctrl_cmd.txt") - ) - with open( - os.path.join(sensor_path, "matched_frame_ctrl_cmd_processed.txt"), "w" - ) as f: - f.write("timestamp,frame,left,right,cmd\n") + frame_list = read_file_list(os.path.join(sensor_path, matched_frames_file_name)) + with open(os.path.join(sensor_path, processed_frames_file_name), "w") as f: + f.write(csv_label_string_processed) # max_ctrl = get_max_ctrl(frame_list) for timestamp in list(frame_list): frame = frame_list[timestamp] if len(frame) < 6: continue - left = int(frame[3]) - right = int(frame[4]) - # left = normalize(max_ctrl, frame[3]) - # right = normalize(max_ctrl, frame[4]) - if remove_zeros and left == 0 and right == 0: - print(f" Removed timestamp: {timestamp}") - del frame - else: - frame_name = os.path.join(img_path, frame[2] + "_crop.jpeg") - cmd = int(frame[5]) - f.write( - "%s,%s,%d,%d,%d\n" % (timestamp, frame_name, left, right, cmd) - ) + + if policy == "autopilot": + left = int(frame[3]) + right = int(frame[4]) + # left = normalize(max_ctrl, frame[3]) + # right = normalize(max_ctrl, frame[4]) + if remove_zeros and left == 0 and right == 0: + print(f" Removed timestamp: {timestamp}") + del frame + else: + frame_name = os.path.join(img_path, frame[2] + "_crop.jpeg") + cmd = int(frame[5]) + f.write( + "%s,%s,%d,%d,%d\n" + % (timestamp, frame_name, left, right, cmd) + ) + + elif policy == "point_goal_nav": + left = float(frame_list[timestamp][3]) + right = float(frame_list[timestamp][4]) + if remove_zeros and left == 0.0 and right == 0.0: + print(" Removed timestamp:%s" % (timestamp)) + del frame_list[timestamp] + else: + frame_name = os.path.join( + img_path, frame_list[timestamp][2] + ".jpeg" + ) + dist = float(frame_list[timestamp][5]) + sinYaw = float(frame_list[timestamp][6]) + cosYaw = float(frame_list[timestamp][7]) + f.write( + "%s,%s,%f,%f,%f,%f,%f\n" + % (timestamp, frame_name, left, right, dist, sinYaw, cosYaw) + ) + print(" Preprocessing completed.") - return read_file_list( - os.path.join(sensor_path, "matched_frame_ctrl_cmd_processed.txt") - ) + return read_file_list(os.path.join(sensor_path, processed_frames_file_name)) def normalize(max_ctrl, val): diff --git a/policy/openbot/callbacks.py b/policy/openbot/callbacks.py index b0e21640f..660490b84 100644 --- a/policy/openbot/callbacks.py +++ b/policy/openbot/callbacks.py @@ -18,6 +18,45 @@ def checkpoint_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10): return checkpoint_callback +def checkpoint_last_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10): + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(checkpoint_path, "cp-last.ckpt"), + monitor="val_loss", + verbose=0, + save_best_only=False, + save_weights_only=False, + mode="auto", + save_freq="epoch" if steps_per_epoch < 0 else int(num_epochs * steps_per_epoch), + ) + return checkpoint_callback + + +def checkpoint_best_train_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10): + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(checkpoint_path, "cp-best-train.ckpt"), + monitor="loss", + verbose=0, + save_best_only=True, + save_weights_only=False, + mode="auto", + save_freq="epoch" if steps_per_epoch < 0 else int(num_epochs * steps_per_epoch), + ) + return checkpoint_callback + + +def checkpoint_best_val_cb(checkpoint_path, steps_per_epoch=-1, num_epochs=10): + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(checkpoint_path, "cp-best-val.ckpt"), + monitor="val_loss", + verbose=0, + save_best_only=True, + save_weights_only=False, + mode="auto", + save_freq="epoch" if steps_per_epoch < 0 else int(num_epochs * steps_per_epoch), + ) + return checkpoint_callback + + def tensorboard_cb(log_path): tensorboard_callback = tf.keras.callbacks.TensorBoard( log_dir=log_path, diff --git a/policy/openbot/data_augmentation.py b/policy/openbot/data_augmentation.py index ce3adadc3..8dea50f27 100644 --- a/policy/openbot/data_augmentation.py +++ b/policy/openbot/data_augmentation.py @@ -3,6 +3,7 @@ This script implements several routines for data augmentation. """ import tensorflow as tf +import numpy as np def augment_img(img): @@ -18,6 +19,7 @@ def augment_img(img): img = tf.image.random_saturation(img, 0.6, 1.6) img = tf.image.random_brightness(img, 0.05) img = tf.image.random_contrast(img, 0.7, 1.3) + img = tf.clip_by_value(img, clip_value_min=0.0, clip_value_max=1.0) return img @@ -32,7 +34,7 @@ def augment_cmd(cmd): cmd: augmented command """ if not (cmd > 0 or cmd < 0): - coin = tf.random.uniform(shape=[1], minval=0, maxval=1, dtype=tf.dtypes.float32) + coin = np.random.default_rng().uniform(low=0.0, high=1.0, size=None) if coin < 0.25: cmd = -1.0 elif coin < 0.5: @@ -41,7 +43,7 @@ def augment_cmd(cmd): def flip_sample(img, cmd, label): - coin = tf.random.uniform(shape=[1], minval=0, maxval=1, dtype=tf.dtypes.float32) + coin = np.random.default_rng().uniform(low=0.0, high=1.0, size=None) if coin < 0.5: img = tf.image.flip_left_right(img) cmd = -cmd diff --git a/policy/openbot/dataloader.py b/policy/openbot/dataloader.py index 53844e6b6..d4b51c58d 100644 --- a/policy/openbot/dataloader.py +++ b/policy/openbot/dataloader.py @@ -6,17 +6,30 @@ class dataloader: - def __init__(self, data_dir: str, datasets: List[str]): + def __init__(self, data_dir: str, datasets: List[str], policy: str): self.data_dir = data_dir + self.policy = policy # "autopilot" or "point_goal_nav" self.datasets = datasets self.labels = self.load_labels() self.index_table = self.lookup_table() self.label_values = tf.constant( [(float(label[0]), float(label[1])) for label in self.labels.values()] ) - self.cmd_values = tf.constant( - [(float(label[2])) for label in self.labels.values()] - ) + + if self.policy == "autopilot": + self.label_divider = 255.0 + self.processed_frames_file_name = "matched_frame_ctrl_cmd_processed.txt" + self.cmd_values = tf.constant( + [(float(label[2])) for label in self.labels.values()] + ) + elif self.policy == "point_goal_nav": + self.label_divider = 1.0 + self.processed_frames_file_name = "matched_frame_ctrl_goal_processed.txt" + self.cmd_values = tf.constant( + [(float(l[2]), float(l[3]), float(l[4])) for l in self.labels.values()] + ) + else: + raise Exception("Unknown policy") # Load labels def load_labels(self): @@ -27,34 +40,39 @@ def load_labels(self): for f in os.listdir(os.path.join(self.data_dir, dataset)) if not f.startswith(".") ]: - with open( - os.path.join( - self.data_dir, - dataset, - folder, - "sensor_data", - "matched_frame_ctrl_cmd_processed.txt", - ) - ) as f_input: - # discard header - header = f_input.readline() - data = f_input.read() - lines = ( - data.replace(",", " ") - .replace("\\", "/") - .replace("\r", "") - .replace("\t", " ") - .split("\n") - ) - data = [ - [v.strip() for v in line.split(" ") if v.strip() != ""] - for line in lines - if len(line) > 0 and line[0] != "#" - ] - # Tuples containing id: framepath and label: left,right,cmd - data = [(line[1], line[2:]) for line in data if len(line) > 1] - corpus.extend(data) - return dict(corpus) + labels_file = os.path.join( + self.data_dir, + dataset, + folder, + "sensor_data", + self.processed_frames_file_name, + ) + + if os.path.isfile(labels_file): + with open(labels_file) as f_input: + + # discard header + header = f_input.readline() + data = f_input.read() + lines = ( + data.replace(",", " ") + .replace("\\", "/") + .replace("\r", "") + .replace("\t", " ") + .split("\n") + ) + data = [ + [v.strip() for v in line.split(" ") if v.strip() != ""] + for line in lines + if len(line) > 0 and line[0] != "#" + ] + # Tuples containing id: framepath and respectively labels "left,right,cmd" for autopilot policy + # and labels "left,right,dist,sinYaw,cosYaw" point_goal_nav policy + data = [(line[1], line[2:]) for line in data if len(line) > 1] + corpus.extend(data) + else: + print(f"Skipping {folder}") + return dict(corpus) / self.label_divider # build a lookup table to get the frame index for the label def lookup_table(self): @@ -70,4 +88,4 @@ def lookup_table(self): def get_label(self, file_path): index = self.index_table.lookup(file_path) - return self.cmd_values[index], self.label_values[index] / 255 + return self.cmd_values[index], self.label_values[index] diff --git a/policy/openbot/losses.py b/policy/openbot/losses.py index 06550e8ce..e52916b72 100644 --- a/policy/openbot/losses.py +++ b/policy/openbot/losses.py @@ -3,21 +3,56 @@ import tensorflow as tf +def angle(y): + return y[:, 0] - y[:, 1] + + +def angle_weight(y_gt, eps=0.05): + return tf.math.square(angle(y_gt)) + eps + + +def mse_raw(y_gt, y_pred): + return tf.keras.losses.mean_squared_error(y_gt, y_pred) + + +def mae_raw(y_gt, y_pred): + return tf.keras.losses.mean_absolute_error(y_gt, y_pred) + + +def huber_raw(y_gt, y_pred): + huber = tf.keras.losses.Huber() + return tf.keras.losses.huber(y_gt, y_pred) + + +def mse_angle(y_gt, y_pred): + return tf.keras.losses.mean_squared_error(angle(y_gt), angle(y_pred)) + + +def weighted_mse_raw(y_gt, y_pred): + return angle_weight(y_gt) * mse_raw(y_gt, y_pred) + + +def weighted_mse_angle(y_gt, y_pred): + return angle_weight(y_gt) * mse_angle(y_gt, y_pred) + + +def sq_weighted_mse_angle(y_gt, y_pred): + return angle_weight(y_gt) * mse_angle(y_gt, y_pred) + + +def weighted_mse_raw_angle(y_gt, y_pred): + return angle_weight(y_gt) * (mse_raw(y_gt, y_pred) + mse_angle(y_gt, y_pred)) + + +def mae_raw_weighted_mse_angle(y_gt, y_pred): + return mae_raw(y_gt, y_pred) + weighted_mse_angle(y_gt, y_pred) + + def weighted_mse_raw(y_true, y_pred): weight = tf.math.abs(y_true[:, 0] - y_true[:, 1] + 0.05) return weight * tf.keras.losses.mean_squared_error(y_true, y_pred) -def weighted_mse_angle(y_true, y_pred): - angle_true = y_true[:, 1] - y_true[:, 0] - angle_pred = y_pred[:, 1] - y_pred[:, 0] - weight = tf.math.abs(angle_true + 0.05) - return tf.math.square(weight) * ( - tf.keras.losses.mean_squared_error(angle_true, angle_pred) - + tf.keras.losses.mean_squared_error(y_true, y_pred) - ) - - def sq_weighted_mse_angle(y_true, y_pred): angle_true = y_true[:, 1] - y_true[:, 0] angle_pred = y_pred[:, 1] - y_pred[:, 0] diff --git a/policy/openbot/models.py b/policy/openbot/models.py index 75ed0ecbc..677da7862 100644 --- a/policy/openbot/models.py +++ b/policy/openbot/models.py @@ -63,11 +63,11 @@ def create_cnn( return model -def create_mlp(in_dim, hidden_dim, out_dim, activation="relu", dropout=0.2): +def create_mlp(in_dim, hidden_dim, out_dim, activation="relu", dropout=0.2, name="cmd"): model = tf.keras.Sequential(name="MLP") model.add( tf.keras.layers.Dense( - hidden_dim, input_dim=in_dim, activation=activation, name="cmd" + hidden_dim, input_dim=in_dim, activation=activation, name=name ) ) if dropout > 0: @@ -76,8 +76,15 @@ def create_mlp(in_dim, hidden_dim, out_dim, activation="relu", dropout=0.2): return model -def pilot_net(img_width, img_height, bn=False): - mlp = create_mlp(1, 1, 1, dropout=0) +def pilot_net(img_width, img_height, bn=False, policy="autopilot"): + + if policy == "autopilot": + mlp = create_mlp(1, 1, 1, dropout=0, name="cmd") + elif policy == "point_goal_nav": + mlp = create_mlp(3, 16, 16, dropout=0, name="goal") + else: + raise Exception("Unknown policy") + cnn = create_cnn( img_width, img_height, @@ -108,8 +115,15 @@ def pilot_net(img_width, img_height, bn=False): return model -def cil_mobile(img_width, img_height, bn=True): - mlp = create_mlp(1, 16, 16, dropout=0.5) +def cil_mobile(img_width, img_height, bn=True, policy="autopilot"): + + if policy == "autopilot": + mlp = create_mlp(1, 16, 16, dropout=0.5, name="cmd") + elif policy == "point_goal_nav": + mlp = create_mlp(3, 16, 16, dropout=0.5, name="goal") + else: + raise Exception("Unknown policy") + cnn = create_cnn( img_width, img_height, @@ -144,8 +158,15 @@ def cil_mobile(img_width, img_height, bn=True): return model -def cil_mobile_fast(img_width, img_height, bn=True): - mlp = create_mlp(1, 16, 16) +def cil_mobile_fast(img_width, img_height, bn=True, policy="autopilot"): + + if policy == "autopilot": + mlp = create_mlp(1, 16, 16, name="cmd") + elif policy == "point_goal_nav": + mlp = create_mlp(3, 16, 16, name="goal") + else: + raise Exception("Unknown policy") + cnn = create_cnn( img_width, img_height, @@ -179,8 +200,15 @@ def cil_mobile_fast(img_width, img_height, bn=True): return model -def cil(img_width, img_height, bn=True): - mlp = create_mlp(1, 64, 64, dropout=0.5) +def cil(img_width, img_height, bn=True, policy="autopilot"): + + if policy == "autopilot": + mlp = create_mlp(1, 64, 64, dropout=0.5, name="cmd") + elif policy == "point_goal_nav": + mlp = create_mlp(3, 64, 64, dropout=0.5, name="goal") + else: + raise Exception("Unknown policy") + cnn = create_cnn( img_width, img_height, diff --git a/policy/openbot/server/api.py b/policy/openbot/server/api.py index 36fbaa7e6..9693dc002 100644 --- a/policy/openbot/server/api.py +++ b/policy/openbot/server/api.py @@ -191,6 +191,7 @@ def broadcast(event, payload=None): hyper_params = Hyperparameters() for p in params: setattr(hyper_params, p, params[p]) + setattr(hyper_params, "POLICY", "autopilot") print(hyper_params.__dict__) loop.run_in_executor(None, train, hyper_params, broadcast, event_cancelled) return True @@ -200,7 +201,7 @@ def train(params, broadcast, cancelled): try: broadcast("started", params.__dict__) my_callback = MyCallback(broadcast, cancelled, True) - create_tfrecord(my_callback) + create_tfrecord(my_callback, params.POLICY) tr = start_train(params, my_callback) broadcast("done", {"model": tr.model_name}) except CancelledException: diff --git a/policy/openbot/server/prediction.py b/policy/openbot/server/prediction.py index fb384daa6..3cd5cf57e 100644 --- a/policy/openbot/server/prediction.py +++ b/policy/openbot/server/prediction.py @@ -47,7 +47,7 @@ def getPrediction(params): if params["indicator"] is None: cmd_input = np.array([[ind]], dtype=np.float32) - img = utils.load_img(path) + img = utils.load_img(path, is_crop=False) img_input = np.expand_dims(img, axis=0) interpreter.set_tensor(input_details[0]["index"], cmd_input) diff --git a/policy/openbot/tfrecord.py b/policy/openbot/tfrecord.py index 18a7fc2c9..da0e8c2fe 100644 --- a/policy/openbot/tfrecord.py +++ b/policy/openbot/tfrecord.py @@ -34,8 +34,16 @@ def get_parser(): return parser -def load_labels(data_dir, datasets): - """Returns a dictionary of matched images path[string] and actions tuple (left[int], right[int], cmd[int]).""" +def load_labels(data_dir, datasets, policy="autopilot"): + """Returns a dictionary of matched images path[string] and actions tuple, namely (left[int], right[int], cmd[int]) for autopilot policy and (left[int], right[int], dist[float],sinYaw[float],cosYaw[float]) for point_goal_nav policy.""" + + if policy == "autopilot": + processed_frames_file_name = "matched_frame_ctrl_cmd_processed.txt" + elif policy == "point_goal_nav": + processed_frames_file_name = "matched_frame_ctrl_goal_processed.txt" + else: + raise Exception("Unknown policy") + corpus = [] for dataset in datasets: dataset_folders = [ @@ -46,7 +54,7 @@ def load_labels(data_dir, datasets): for folder in dataset_folders: sensor_data_dir = os.path.join(data_dir, dataset, folder, "sensor_data") with open( - os.path.join(sensor_data_dir, "matched_frame_ctrl_cmd_processed.txt") + os.path.join(sensor_data_dir, processed_frames_file_name) ) as f_input: header = f_input.readline() # discard header data = f_input.read() @@ -63,18 +71,31 @@ def load_labels(data_dir, datasets): for line in lines if len(line) > 0 and line[0] != "#" ] - # Tuples containing id: framepath and label: left,right,cmd + # Tuples containing id: framepath and label: "left,right,cmd" for autopiot policy + # and for "left,right,dist,sinYaw,cosYaw" point_goal_nav policy. data = [(l[1], l[2:]) for l in data if len(l) > 1] corpus.extend(data) return dict(corpus) def convert_dataset( - data_dir, tfrecords_dir, tfrecords_name, redo_matching=True, remove_zeros=True + data_dir, + tfrecords_dir, + tfrecords_name, + redo_matching=True, + remove_zeros=True, + policy="autopilot", ): print(f"Reading dataset from {data_dir}") print(f"TFRecord will be saved at {tfrecords_dir}/{tfrecords_name}") + if policy == "autopilot": + processed_frames_file_name = "matched_frame_ctrl_cmd_processed.txt" + elif policy == "point_goal_nav": + processed_frames_file_name = "matched_frame_ctrl_goal_processed.txt" + else: + raise Exception("Unknown policy") + # load the datasets avaible. datasets = [ d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d)) @@ -83,12 +104,13 @@ def convert_dataset( # match frames. max_offset = 1e3 # 1ms - frames = associate_frames.match_frame_ctrl_cmd( + frames = associate_frames.match_frame_ctrl_input( data_dir, datasets, max_offset, redo_matching=redo_matching, remove_zeros=remove_zeros, + policy=policy, ) # creating TFRecords output folder. @@ -96,12 +118,20 @@ def convert_dataset( os.makedirs(tfrecords_dir) # generate data in the TFRecord format. - samples = load_labels(data_dir, datasets) + samples = load_labels(data_dir, datasets, policy) with tf.io.TFRecordWriter(tfrecords_dir + "/" + tfrecords_name) as writer: - for image_path, ctrl_cmd in samples.items(): + for image_path, ctrl_input in samples.items(): try: image = tf.io.decode_jpeg(tf.io.read_file(image_path)) - example = tfrecord_utils.create_example(image, image_path, ctrl_cmd) + if policy == "autopilot": + example = tfrecord_utils.create_example_autopilot( + image, image_path, ctrl_input + ) + elif policy == "point_goal_nav": + example = tfrecord_utils.create_example_point_goal_nav( + image, image_path, ctrl_input + ) + writer.write(example.SerializeToString()) except: print(f"Oops! Image {image_path} cannot be found.") diff --git a/policy/openbot/tfrecord_utils.py b/policy/openbot/tfrecord_utils.py index bfe446ffd..2eed07452 100644 --- a/policy/openbot/tfrecord_utils.py +++ b/policy/openbot/tfrecord_utils.py @@ -32,7 +32,7 @@ def float_feature_list(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) -def parse_tfrecord_fn(example): +def parse_tfrecord_fn_autopilot(example): """Parse the input `tf.train.Example` proto.""" # Create a description of the features. @@ -51,7 +51,28 @@ def parse_tfrecord_fn(example): return example -def create_example(image, path, ctrl_cmd): +def parse_tfrecord_fn_point_goal_nav(example): + """Parse the input `tf.train.Example` proto.""" + + # Create a description of the features. + feature_description = { + "image": tf.io.FixedLenFeature([], tf.string), + "path": tf.io.FixedLenFeature([], tf.string), + "left": tf.io.FixedLenFeature([], tf.float32), + "right": tf.io.FixedLenFeature([], tf.float32), + "dist": tf.io.FixedLenFeature([], tf.float32), + "sinYaw": tf.io.FixedLenFeature([], tf.float32), + "cosYaw": tf.io.FixedLenFeature([], tf.float32), + } + + example = tf.io.parse_single_example(example, feature_description) + img = tf.io.decode_jpeg(example["image"], channels=3) + img = tf.image.convert_image_dtype(img, tf.float32) + example["image"] = img + return example + + +def create_example_autopilot(image, path, ctrl_cmd): """Converts the train features into a `tf.train.Example` eady to be written to a tfrecord file.""" # Create a dictionary mapping the feature name to the tf.train.Example-compatible data type. @@ -64,3 +85,20 @@ def create_example(image, path, ctrl_cmd): } return tf.train.Example(features=tf.train.Features(feature=feature)) + + +def create_example_point_goal_nav(image, path, ctrl_goal): + """Converts the train features into a `tf.train.Example` eady to be written to a tfrecord file.""" + + # Create a dictionary mapping the feature name to the tf.train.Example-compatible data type. + feature = { + "image": image_feature(image), + "path": bytes_feature(path), + "left": float_feature(float(ctrl_goal[0])), + "right": float_feature(float(ctrl_goal[1])), + "dist": float_feature(float(ctrl_goal[2])), + "sinYaw": float_feature(float(ctrl_goal[3])), + "cosYaw": float_feature(float(ctrl_goal[4])), + } + + return tf.train.Example(features=tf.train.Features(feature=feature)) diff --git a/policy/openbot/train.py b/policy/openbot/train.py index 09d1547db..3b8bb63c4 100644 --- a/policy/openbot/train.py +++ b/policy/openbot/train.py @@ -34,19 +34,19 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -dataset_name = "my_openbot" - @dataclass class Hyperparameters: - MODEL: str = "cil_mobile" + MODEL: str = "pilot_net" + POLICY: str = "autopilot" - TRAIN_BATCH_SIZE: int = 16 + TRAIN_BATCH_SIZE: int = 128 TEST_BATCH_SIZE: int = 16 - LEARNING_RATE: float = 0.0001 - NUM_EPOCHS: int = 10 + LEARNING_RATE: float = 0.0003 + NUM_EPOCHS: int = 100 BATCH_NORM: bool = True + IS_CROP: bool = False FLIP_AUG: bool = False CMD_AUG: bool = False @@ -86,6 +86,7 @@ def __init__(self, params: Hyperparameters): self.hyperparameters = params self.NETWORK_IMG_WIDTH = 0 self.NETWORK_IMG_HEIGHT = 0 + self.INITIAL_EPOCH = 0 self.train_data_dir = "" self.test_data_dir = "" self.train_datasets = [] @@ -98,6 +99,7 @@ def __init__(self, params: Hyperparameters): self.test_ds = None self.history = None self.model_name = "" + self.dataset_name = "openbot" self.checkpoint_path = "" self.log_path = "" self.loss_fn = None @@ -158,19 +160,22 @@ def process_data(tr: Training): # 1ms max_offset = 1e3 - train_frames = associate_frames.match_frame_ctrl_cmd( + train_frames = associate_frames.match_frame_ctrl_input( tr.train_data_dir, tr.train_datasets, max_offset, redo_matching=tr.redo_matching, remove_zeros=tr.remove_zeros, + policy=tr.hyperparameters.POLICY, ) - test_frames = associate_frames.match_frame_ctrl_cmd( + + test_frames = associate_frames.match_frame_ctrl_input( tr.test_data_dir, tr.test_datasets, max_offset, redo_matching=tr.redo_matching, remove_zeros=tr.remove_zeros, + policy=tr.hyperparameters.POLICY, ) tr.image_count_train = len(train_frames) @@ -185,42 +190,94 @@ def load_tfrecord(tr: Training, verbose=0): def process_train_sample(features): # image = tf.image.resize(features["image"], size=(224, 224)) image = features["image"] - cmd = features["cmd"] - label = [features["left"], features["right"]] - image = data_augmentation.augment_img(image) - if tr.hyperparameters.FLIP_AUG: - img, cmd, label = data_augmentation.flip_sample(img, cmd, label) - if tr.hyperparameters.CMD_AUG: - cmd = data_augmentation.augment_cmd(cmd) - return (image, cmd), label + if tr.hyperparameters.POLICY == "autopilot": + cmd_input = features["cmd"] + label = [features["left"], features["right"]] + image = data_augmentation.augment_img(image) + if tr.hyperparameters.FLIP_AUG: + image, cmd_input, label = data_augmentation.flip_sample( + image, cmd_input, label + ) + if tr.hyperparameters.CMD_AUG: + cmd_input = data_augmentation.augment_cmd(cmd_input) + elif tr.hyperparameters.POLICY == "point_goal_nav": + image = tf.image.crop_to_bounding_box( + image, tf.shape(image)[0] - 90, tf.shape(image)[1] - 160, 90, 160 + ) + cmd_input = [features["dist"], features["sinYaw"], features["cosYaw"]] + label = [features["left"], features["right"]] + image = data_augmentation.augment_img(image) + if tr.hyperparameters.FLIP_AUG: + print( + "Image flip augmentation is not implemented for Point Goal Navigation." + ) + if tr.hyperparameters.CMD_AUG: + print( + "Command augmentation is not implemented for Point Goal Navigation." + ) + + return (image, cmd_input), label def process_test_sample(features): image = features["image"] - cmd = features["cmd"] + + if tr.hyperparameters.POLICY == "autopilot": + cmd_input = features["cmd"] + + elif tr.hyperparameters.POLICY == "point_goal_nav": + image = tf.image.crop_to_bounding_box( + image, tf.shape(image)[0] - 90, tf.shape(image)[1] - 160, 90, 160 + ) + cmd_input = [features["dist"], features["sinYaw"], features["cosYaw"]] + label = [features["left"], features["right"]] - return (image, cmd), label + return (image, cmd_input), label - train_dataset = ( - tf.data.TFRecordDataset(tr.train_data_dir, num_parallel_reads=AUTOTUNE) - .map(tfrecord_utils.parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(process_train_sample, num_parallel_calls=AUTOTUNE) - ) + if tr.hyperparameters.POLICY == "autopilot": + train_dataset = ( + tf.data.TFRecordDataset(tr.train_data_dir, num_parallel_reads=AUTOTUNE) + .map( + tfrecord_utils.parse_tfrecord_fn_autopilot, num_parallel_calls=AUTOTUNE + ) + .map(process_train_sample, num_parallel_calls=AUTOTUNE) + ) + elif tr.hyperparameters.POLICY == "point_goal_nav": + train_dataset = ( + tf.data.TFRecordDataset(tr.train_data_dir, num_parallel_reads=AUTOTUNE) + .map( + tfrecord_utils.parse_tfrecord_fn_point_goal_nav, + num_parallel_calls=AUTOTUNE, + ) + .map(process_train_sample, num_parallel_calls=AUTOTUNE) + ) # Obtains the images shapes of records from .tfrecords. - for (image, cmd), label in train_dataset.take(1): + for (image, cmd_input), label in train_dataset.take(1): shape = image.numpy().shape tr.NETWORK_IMG_HEIGHT = shape[0] tr.NETWORK_IMG_WIDTH = shape[1] print("Image shape: ", shape) - print("Command: ", cmd.numpy()) + print("Command: ", cmd_input.numpy()) print("Label: ", label.numpy()) - test_dataset = ( - tf.data.TFRecordDataset(tr.test_data_dir, num_parallel_reads=AUTOTUNE) - .map(tfrecord_utils.parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(process_test_sample, num_parallel_calls=AUTOTUNE) - ) + if tr.hyperparameters.POLICY == "autopilot": + test_dataset = ( + tf.data.TFRecordDataset(tr.test_data_dir, num_parallel_reads=AUTOTUNE) + .map( + tfrecord_utils.parse_tfrecord_fn_autopilot, num_parallel_calls=AUTOTUNE + ) + .map(process_test_sample, num_parallel_calls=AUTOTUNE) + ) + elif tr.hyperparameters.POLICY == "point_goal_nav": + test_dataset = ( + tf.data.TFRecordDataset(tr.test_data_dir, num_parallel_reads=AUTOTUNE) + .map( + tfrecord_utils.parse_tfrecord_fn_point_goal_nav, + num_parallel_calls=AUTOTUNE, + ) + .map(process_test_sample, num_parallel_calls=AUTOTUNE) + ) # Obtains the total number of records from .tfrecords file # https://stackoverflow.com/questions/40472139/obtaining-total-number-of-records-from-tfrecords-file-in-tensorflow @@ -252,8 +309,12 @@ def load_data(tr: Training, verbose=0): list_test_ds = tf.data.Dataset.list_files( [str(tr.test_data_dir + "/" + ds + "/*/images/*") for ds in tr.test_datasets] ) - train_data = dataloader.dataloader(tr.train_data_dir, tr.train_datasets) - test_data = dataloader.dataloader(tr.test_data_dir, tr.test_datasets) + train_data = dataloader.dataloader( + tr.train_data_dir, tr.train_datasets, tr.hyperparameters.POLICY + ) + test_data = dataloader.dataloader( + tr.test_data_dir, tr.test_datasets, tr.hyperparameters.POLICY + ) if verbose: for f in list_train_ds.take(5): @@ -261,36 +322,37 @@ def load_data(tr: Training, verbose=0): print() for f in list_test_ds.take(5): print(f.numpy()) - print("Number of train samples: %d" % len(train_data.labels)) - print("Number of test samples: %d" % len(test_data.labels)) + print() + print("Number of train samples: %d" % len(train_data.labels)) + print("Number of test samples: %d" % len(test_data.labels)) def process_train_path(file_path): - cmd, label = train_data.get_label( + cmd_input, label = train_data.get_label( tf.strings.regex_replace(file_path, "[/\\\\]", "/") ) - img = utils.load_img(file_path) + img = utils.load_img(file_path, tr.hyperparameters.IS_CROP) img = data_augmentation.augment_img(img) if tr.hyperparameters.FLIP_AUG: - img, cmd, label = data_augmentation.flip_sample(img, cmd, label) + img, cmd_input, label = data_augmentation.flip_sample(img, cmd_input, label) if tr.hyperparameters.CMD_AUG: - cmd = data_augmentation.augment_cmd(cmd) - return (img, cmd), label + cmd_input = data_augmentation.augment_cmd(cmd_input) + return (img, cmd_input), label def process_test_path(file_path): - cmd, label = test_data.get_label( + cmd_input, label = test_data.get_label( tf.strings.regex_replace(file_path, "[/\\\\]", "/") ) - img = utils.load_img(file_path) - return (img, cmd), label + img = utils.load_img(file_path, tr.hyperparameters.IS_CROP) + return (img, cmd_input), label # Set `num_parallel_calls` so multiple images are loaded/processed in parallel. labeled_ds = list_train_ds.map(process_train_path, num_parallel_calls=4) - for (image, cmd), label in labeled_ds.take(1): + for (image, cmd_input), label in labeled_ds.take(1): shape = image.numpy().shape tr.NETWORK_IMG_HEIGHT = shape[0] tr.NETWORK_IMG_WIDTH = shape[1] print("Image shape: ", shape) - print("Command: ", cmd.numpy()) + print("Command: ", cmd_input.numpy()) print("Label: ", label.numpy()) tr.train_ds = utils.prepare_for_training( ds=labeled_ds, @@ -304,19 +366,19 @@ def process_test_path(file_path): def visualize_train_data(tr: Training): - (image_batch, cmd_batch), label_batch = next(iter(tr.train_ds)) - utils.show_train_batch(image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy()) - savefig(os.path.join(models_dir, "train_preview.png")) + utils.show_batch(dataset=tr.train_ds, policy=tr.hyperparameters.POLICY, model=None) + utils.savefig(os.path.join(models_dir, "train_preview.png")) def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): - tr.model_name = dataset_name + "_" + str(tr.hyperparameters) + tr.model_name = tr.dataset_name + "_" + str(tr.hyperparameters) tr.checkpoint_path = os.path.join(models_dir, tr.model_name, "checkpoints") + tr.log_path = os.path.join(models_dir, tr.model_name, "logs") + tr.custom_objects = { "direction_metric": metrics.direction_metric, "angle_metric": metrics.angle_metric, } - model_path = os.path.join(models_dir, tr.model_name, "model") if tr.hyperparameters.WANDB: import wandb @@ -330,27 +392,49 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): config.batch_size = tr.hyperparameters.TRAIN_BATCH_SIZE config["model_name"] = tr.model_name - append_logs = False + resume_training = False model: tf.keras.Model + if tr.hyperparameters.USE_LAST: - append_logs = True - model = tf.keras.models.load_model( - model_path, - custom_objects=tr.custom_objects, - compile=False, - ) - else: + try: + dirs = utils.list_dirs(tr.checkpoint_path) + last_checkpoint = os.path.join(tr.checkpoint_path, "cp-last.ckpt") + os.path.join(tr.checkpoint_path, last_checkpoint) + model = tf.keras.models.load_model( + last_checkpoint, + custom_objects=tr.custom_objects, + compile=False, + ) + log_file = open(os.path.join(tr.log_path, "log.csv"), "r") + tr.INITIAL_EPOCH = int(log_file.readlines()[-1].split(",")[0]) + 1 + log_file.close() + resume_training = True + print(f"Resuming from checkpoint: {last_checkpoint}") + # print(f"Resuming from saved model.") + except FileNotFoundError as err: + print("No checkpoint or log file found, training new model!") + print(err) + except Exception as err: + print(err) + raise + + if not resume_training: model = getattr(models, tr.hyperparameters.MODEL)( tr.NETWORK_IMG_WIDTH, tr.NETWORK_IMG_HEIGHT, tr.hyperparameters.BATCH_NORM, + tr.hyperparameters.POLICY, ) dot_img_file = os.path.join(models_dir, tr.model_name, "model.png") tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) callback.broadcast("model", tr.model_name) - tr.loss_fn = losses.sq_weighted_mse_angle + if tr.hyperparameters.POLICY == "autopilot": + tr.loss_fn = losses.sq_weighted_mse_angle + elif tr.hyperparameters.POLICY == "point_goal_nav": + tr.loss_fn = losses.mae_raw_weighted_mse_angle + tr.metric_list = [ "mean_absolute_error", tr.custom_objects["direction_metric"], @@ -362,7 +446,6 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): if verbose: print(model.summary()) - tr.log_path = os.path.join(models_dir, tr.model_name, "logs") if verbose: print(tr.model_name) @@ -371,126 +454,144 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): ) callback.broadcast("message", "Fit model...") callback_list = [ - callbacks.checkpoint_cb(tr.checkpoint_path), + callbacks.checkpoint_last_cb(tr.checkpoint_path), + callbacks.checkpoint_best_train_cb(tr.checkpoint_path), + callbacks.checkpoint_best_val_cb(tr.checkpoint_path), callbacks.tensorboard_cb(tr.log_path), - callbacks.logger_cb(tr.log_path, append_logs), + callbacks.logger_cb(tr.log_path, resume_training), callback, ] if tr.hyperparameters.WANDB: callback_list += [WandbCallback()] - tr.history = model.fit( tr.train_ds, epochs=tr.hyperparameters.NUM_EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, + initial_epoch=tr.INITIAL_EPOCH, validation_data=tr.test_ds, verbose=verbose, callbacks=callback_list, ) - model.save(model_path) if tr.hyperparameters.WANDB: - wandb.save(model_path) + wandb.save(tr.log_path) wandb.finish() + callback.broadcast("message", "...Done") + def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): callback.broadcast("message", "Generate plots...") - plt.plot(tr.history.history["mean_absolute_error"], label="mean_absolute_error") + + x = np.arange(tr.INITIAL_EPOCH + 1, tr.history.params["epochs"] + 1, 1) + + plt.figure().gca().xaxis.get_major_locator().set_params(integer=True) + plt.plot(x, tr.history.history["loss"], label="loss") + plt.plot(x, tr.history.history["val_loss"], label="val_loss") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.legend(loc="upper right") + utils.savefig(os.path.join(tr.log_path, "loss.png")) + + plt.figure().gca().xaxis.get_major_locator().set_params(integer=True) + plt.plot(x, tr.history.history["mean_absolute_error"], label="mean_absolute_error") plt.plot( - tr.history.history["val_mean_absolute_error"], label="val_mean_absolute_error" + x, + tr.history.history["val_mean_absolute_error"], + label="val_mean_absolute_error", ) plt.xlabel("Epoch") plt.ylabel("Mean Absolute Error") - plt.legend(loc="lower right") - savefig(os.path.join(tr.log_path, "error.png")) + plt.legend(loc="upper right") + utils.savefig(os.path.join(tr.log_path, "error.png")) - plt.plot(tr.history.history["direction_metric"], label="direction_metric") - plt.plot(tr.history.history["val_direction_metric"], label="val_direction_metric") + plt.figure().gca().xaxis.get_major_locator().set_params(integer=True) + plt.plot(x, tr.history.history["direction_metric"], label="direction_metric") + plt.plot( + x, tr.history.history["val_direction_metric"], label="val_direction_metric" + ) plt.xlabel("Epoch") plt.ylabel("Direction Metric") plt.legend(loc="lower right") - savefig(os.path.join(tr.log_path, "direction.png")) + utils.savefig(os.path.join(tr.log_path, "direction.png")) - plt.plot(tr.history.history["angle_metric"], label="angle_metric") - plt.plot(tr.history.history["val_angle_metric"], label="val_angle_metric") + plt.figure().gca().xaxis.get_major_locator().set_params(integer=True) + plt.plot(x, tr.history.history["angle_metric"], label="angle_metric") + plt.plot(x, tr.history.history["val_angle_metric"], label="val_angle_metric") plt.xlabel("Epoch") plt.ylabel("Angle Metric") plt.legend(loc="lower right") - savefig(os.path.join(tr.log_path, "angle.png")) - - plt.plot(tr.history.history["loss"], label="loss") - plt.plot(tr.history.history["val_loss"], label="val_loss") - plt.xlabel("Epoch") - plt.ylabel("Loss") - plt.legend(loc="lower right") - savefig(os.path.join(tr.log_path, "loss.png")) + utils.savefig(os.path.join(tr.log_path, "angle.png")) callback.broadcast("message", "Generate tflite models...") checkpoint_path = tr.checkpoint_path print("checkpoint_path", checkpoint_path) - best_index = np.argmax( - np.array(tr.history.history["val_angle_metric"]) - + np.array(tr.history.history["val_direction_metric"]) + + best_train_checkpoint = "cp-best-train.ckpt" + best_train_tflite = utils.generate_tflite(tr.checkpoint_path, best_train_checkpoint) + utils.save_tflite(best_train_tflite, tr.checkpoint_path, "best-train") + best_train_index = np.argmin(np.array(tr.history.history["loss"])) + print( + "Best Train Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f" + % ( + best_train_index, + tr.history.history["angle_metric"][best_train_index], + tr.history.history["val_angle_metric"][best_train_index], + tr.history.history["direction_metric"][best_train_index], + tr.history.history["val_direction_metric"][best_train_index], + ) ) - best_checkpoint = str("cp-%04d.ckpt" % (best_index + 1)) - best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint) - utils.save_tflite(best_tflite, checkpoint_path, "best") + + best_val_checkpoint = "cp-best-val.ckpt" + best_val_tflite = utils.generate_tflite(tr.checkpoint_path, best_val_checkpoint) + utils.save_tflite(best_val_tflite, tr.checkpoint_path, "best") + utils.save_tflite(best_val_tflite, tr.checkpoint_path, "best-val") + best_val_index = np.argmin(np.array(tr.history.history["val_loss"])) print( - "Best Checkpoint (val_angle: %s, val_direction: %s): %s" + "Best Val Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f" % ( - tr.history.history["val_angle_metric"][best_index], - tr.history.history["val_direction_metric"][best_index], - best_checkpoint, + best_val_index, + tr.history.history["angle_metric"][best_val_index], + tr.history.history["val_angle_metric"][best_val_index], + tr.history.history["direction_metric"][best_val_index], + tr.history.history["val_direction_metric"][best_val_index], ) ) - last_checkpoint = sorted(utils.list_dirs(checkpoint_path))[-1] - last_tflite = utils.generate_tflite(checkpoint_path, last_checkpoint) - utils.save_tflite(last_tflite, checkpoint_path, "last") + last_checkpoint = "cp-last.ckpt" + last_tflite = utils.generate_tflite(tr.checkpoint_path, last_checkpoint) + utils.save_tflite(last_tflite, tr.checkpoint_path, "last") print( - "Last Checkpoint (val_angle: %s, val_direction: %s): %s" + "Last Checkpoint - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f" % ( + tr.history.history["angle_metric"][-1], tr.history.history["val_angle_metric"][-1], + tr.history.history["direction_metric"][-1], tr.history.history["val_direction_metric"][-1], - last_checkpoint, ) ) callback.broadcast("message", "Evaluate model...") - best_model = utils.load_model( - os.path.join(checkpoint_path, best_checkpoint), + last_model = utils.load_model( + os.path.join(tr.checkpoint_path, last_checkpoint), tr.loss_fn, tr.metric_list, tr.custom_objects, ) - # test_loss, test_acc, test_dir, test_ang = best_model.evaluate(tr.test_ds, - res = best_model.evaluate( + # test_loss, test_acc, test_dir, test_ang + res = last_model.evaluate( tr.test_ds, steps=tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE, verbose=2, ) print(res) - NUM_SAMPLES = 15 - (image_batch, cmd_batch), label_batch = next(iter(tr.test_ds)) - pred_batch = best_model.predict( - ( - tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]), - tf.slice(cmd_batch, [0], [NUM_SAMPLES]), - ) - ) - utils.show_test_batch( - image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch + utils.show_batch( + dataset=tr.test_ds, policy=tr.hyperparameters.POLICY, model=last_model ) - savefig(os.path.join(tr.log_path, "test_preview.png")) - utils.compare_tf_tflite(best_model, best_tflite) - - -def savefig(path): - plt.savefig(path, bbox_inches="tight") - plt.clf() + utils.savefig(os.path.join(tr.log_path, "test_preview.png")) + utils.compare_tf_tflite(last_model, last_tflite, policy=tr.hyperparameters.POLICY) def start_train( @@ -518,7 +619,7 @@ def start_train( return tr -def create_tfrecord(callback: MyCallback): +def create_tfrecord(callback: MyCallback, policy="autopilot"): callback.broadcast( "message", "Converting data to tfrecord (this may take some time)..." ) @@ -526,11 +627,13 @@ def create_tfrecord(callback: MyCallback): os.path.join(dataset_dir, "train_data"), os.path.join(dataset_dir, "tfrecords"), "train.tfrec", + policy=policy, ) tfrecord.convert_dataset( os.path.join(dataset_dir, "test_data"), os.path.join(dataset_dir, "tfrecords"), "test.tfrec", + policy=policy, ) @@ -549,22 +652,22 @@ def create_tfrecord(callback: MyCallback): type=str, default="pilot_net", choices=["cil_mobile", "cil_mobile_fast", "cil", "pilot_net"], - help="network architecture (default: cil_mobile)", + help="network architecture (default: pilot_net)", ) parser.add_argument( "--batch_size", type=int, - default=16, - help="number of training epochs (default: 16)", + default=128, + help="number of training epochs (default: 128)", ) parser.add_argument( "--learning_rate", type=float, - default=0.0001, - help="learning rate (default: 0.0001)", + default=0.0003, + help="learning rate (default: 0.0003)", ) parser.add_argument( - "--num_epochs", type=int, default=10, help="number of epochs (default: 10)" + "--num_epochs", type=int, default=100, help="number of epochs (default: 100)" ) parser.add_argument("--batch_norm", action="store_true", help="use batch norm") parser.add_argument( @@ -583,11 +686,19 @@ def create_tfrecord(callback: MyCallback): parser.add_argument( "--wandb", action="store_true", help="training logs with weights & biases" ) + parser.add_argument( + "--policy", + type=str, + default="autopilot", + choices=["autopilot", "point_goal_nav"], + help="the type of policy to be trained (default: autopilot)", + ) args = parser.parse_args() params = Hyperparameters() params.MODEL = args.model + params.POLICY = args.policy params.TRAIN_BATCH_SIZE = args.batch_size params.TEST_BATCH_SIZE = args.batch_size params.LEARNING_RATE = args.learning_rate @@ -597,6 +708,7 @@ def create_tfrecord(callback: MyCallback): params.CMD_AUG = args.cmd_aug params.USE_LAST = args.resume params.WANDB = args.wandb + params.IS_CROP = args.policy == "point_goal_nav" def broadcast(event, payload=None): print() @@ -606,6 +718,6 @@ def broadcast(event, payload=None): my_callback = MyCallback(broadcast, event) if args.create_tf_record: - create_tfrecord(my_callback) + create_tfrecord(my_callback, args.policy) start_train(params, my_callback, verbose=1, no_tf_record=args.no_tf_record) diff --git a/policy/openbot/utils.py b/policy/openbot/utils.py index 18b551baa..439260b35 100644 --- a/policy/openbot/utils.py +++ b/policy/openbot/utils.py @@ -49,36 +49,69 @@ def prepare_for_training( return ds -def show_train_batch(image_batch, cmd_batch, label_batch, fig_num=1): - plt.figure(num=fig_num, figsize=(15, 10)) - for n in range(15): - ax = plt.subplot(5, 3, n + 1) - plt.imshow(image_batch[n]) - plt.title( - "Cmd: %s, Label: [%.2f %.2f]" - % (cmd_batch[n], float(label_batch[n][0]), float(label_batch[n][1])) - ) - plt.axis("off") +def show_batch(dataset, policy="autopilot", model=None, fig_num=1): + + (image_batch, cmd_batch), label_batch = next(iter(dataset)) + NUM_SAMPLES = min(image_batch.numpy().shape[0], 15) + + if policy == "autopilot": + command_input_name = "Cmd" + size = (15, 10) + if model is not None: + pred_batch = model.predict( + ( + tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]), + tf.slice(cmd_batch, [0], [NUM_SAMPLES]), + ) + ) + elif policy == "point_goal_nav": + command_input_name = "Goal" + size = (15, 15) + if model is not None: + pred_batch = model.predict( + ( + tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]), + tf.slice(cmd_batch, [0, 0], [NUM_SAMPLES, -1]), + ) + ) + else: + raise Exception("Unknown policy") + plt.figure(num=fig_num, figsize=size) -def show_test_batch(image_batch, cmd_batch, label_batch, pred_batch, fig_num=1): - plt.figure(num=fig_num, figsize=(15, 10)) - for n in range(15): + for n in range(NUM_SAMPLES): ax = plt.subplot(5, 3, n + 1) plt.imshow(image_batch[n]) - plt.title( - "Cmd: %s, Label: [%.2f %.2f], Pred: [%.2f %.2f]" - % ( - cmd_batch[n], - float(label_batch[n][0]), - float(label_batch[n][1]), - float(pred_batch[n][0]), - float(pred_batch[n][1]), + if model is None: + plt.title( + "%s: %s, Label: [%.2f %.2f]" + % ( + command_input_name, + cmd_batch.numpy()[n], + float(label_batch[n][0]), + float(label_batch[n][1]), + ) + ) + else: + plt.title( + "%s: %s, Label: [%.2f %.2f], Pred: [%.2f %.2f]" + % ( + command_input_name, + cmd_batch.numpy()[n], + float(label_batch[n][0]), + float(label_batch[n][1]), + float(pred_batch[n][0]), + float(pred_batch[n][1]), + ) ) - ) plt.axis("off") +def savefig(path): + plt.savefig(path, bbox_inches="tight") + plt.clf() + + def generate_tflite(path, filename): converter = tf.lite.TFLiteConverter.from_saved_model(os.path.join(path, filename)) converter.optimizations = [tf.lite.Optimize.DEFAULT] @@ -103,7 +136,9 @@ def load_model(model_path, loss_fn, metric_list, custom_objects): return model -def compare_tf_tflite(model, tflite_model, img=None, cmd=None): +def compare_tf_tflite( + model, tflite_model, img=None, cmd=None, policy="autopilot", debug=False +): # Load TFLite model and allocate tensors. interpreter = tf.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() @@ -112,6 +147,17 @@ def compare_tf_tflite(model, tflite_model, img=None, cmd=None): input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() + if debug: + print("input_details:", input_details) + print("output_details:", output_details) + + if policy == "autopilot": + command_input_name = "cmd_input" + elif policy == "point_goal_nav": + command_input_name = "goal_input" + else: + raise Exception("Unknown policy") + # Test the TensorFlow Lite model on input data. If no data provided, generate random data. input_data = {} for input_detail in input_details: @@ -124,15 +170,17 @@ def compare_tf_tflite(model, tflite_model, img=None, cmd=None): print(img) input_data["img_input"] = img interpreter.set_tensor(input_detail["index"], input_data["img_input"]) - elif "cmd_input" in input_detail["name"]: + elif command_input_name in input_detail["name"]: if cmd is None: - input_data["cmd_input"] = np.array( + input_data[command_input_name] = np.array( np.random.random_sample(input_detail["shape"]), dtype=np.float32 ) else: print(cmd) input_data[input_detail["name"]] = cmd[0] - interpreter.set_tensor(input_detail["index"], input_data["cmd_input"]) + interpreter.set_tensor( + input_detail["index"], input_data[command_input_name] + ) else: ValueError("Unknown input") @@ -141,19 +189,20 @@ def compare_tf_tflite(model, tflite_model, img=None, cmd=None): # The function `get_tensor()` returns a copy of the tensor data. # Use `tensor()` in order to get a pointer to the tensor. tflite_results = interpreter.get_tensor(output_details[0]["index"]) - print("tflite:", tflite_results) - - # Test the TensorFlow model on random input data. tf_results = model.predict( - (tf.constant(input_data["img_input"]), tf.constant(input_data["cmd_input"])) + ( + tf.constant(input_data["img_input"]), + tf.constant(input_data[command_input_name]), + ) ) + print("tflite:", tflite_results) print("tf:", tf_results) # Compare the result. for tf_result, tflite_result in zip(tf_results, tflite_results): print( - "Almost equal (5% tolerance):", - np.allclose(tf_result, tflite_result, rtol=5e-02), + "Almost equal (10% tolerance):", + np.allclose(tf_result, tflite_result, rtol=0.1), ) # np.testing.assert_almost_equal(tf_result, tflite_result, decimal=2) @@ -162,10 +211,13 @@ def list_dirs(path): return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))] -def load_img(file_path): +def load_img(file_path, is_crop=False): img = tf.io.read_file(file_path) img = tf.image.decode_image(img, channels=3, dtype=tf.float32) - + if is_crop: + img = tf.image.crop_to_bounding_box( + img, tf.shape(img)[0] - 90, tf.shape(img)[1] - 160, 90, 160 + ) return img diff --git a/policy/policy_learning.ipynb b/policy/policy_learning.ipynb index 65f2e2b70..a3bc61183 100644 --- a/policy/policy_learning.ipynb +++ b/policy/policy_learning.ipynb @@ -39,6 +39,19 @@ "tf.__version__" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "43ca9a91", + "metadata": {}, + "outputs": [], + "source": [ + "if tf.test.gpu_device_name():\n", + " print(\"Default GPU Device:{}\".format(tf.test.gpu_device_name()))\n", + "else:\n", + " print(\"Please install GPU version of TF if you have one.\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -73,7 +86,7 @@ "outputs": [], "source": [ "dataset_dir = \"dataset\"\n", - "dataset_name = \"my_openbot\"\n", + "dataset_name = \"openbot\"\n", "train_data_dir = os.path.join(dataset_dir, \"train_data\")\n", "test_data_dir = os.path.join(dataset_dir, \"test_data\")" ] @@ -83,7 +96,8 @@ "id": "4388bbaa", "metadata": {}, "source": [ - "## Hyperparameters" + "## Hyperparameters\n", + "" ] }, { @@ -103,16 +117,20 @@ "source": [ "params = train.Hyperparameters()\n", "\n", - "params.MODEL = \"pilot_net\"\n", - "params.TRAIN_BATCH_SIZE = 16\n", + "params.MODEL = \"pilot_net\" # choices: \"pilot_net\",\"cil_mobile\",\"cil_mobile_fast\",\"cil\"\n", + "params.POLICY = \"autopilot\" # choices: \"autopilot\",\"point_goal_nav\"\n", + "params.TRAIN_BATCH_SIZE = 128\n", "params.TEST_BATCH_SIZE = 16\n", - "params.LEARNING_RATE = 0.0001\n", - "params.NUM_EPOCHS = 10\n", - "params.BATCH_NORM = True\n", - "params.FLIP_AUG = False\n", - "params.CMD_AUG = False\n", - "params.USE_LAST = False\n", - "params.WANDB = False" + "params.LEARNING_RATE = 0.0003\n", + "params.NUM_EPOCHS = 100\n", + "params.BATCH_NORM = True # use batch norm (recommended)\n", + "params.FLIP_AUG = False # flip image and controls as augmentation (only autopilot)\n", + "params.CMD_AUG = False # randomize high-level command as augmentation (only autopilot)\n", + "params.USE_LAST = False # resume training from last checkpoint\n", + "params.WANDB = False\n", + "# policy = \"autopilot\": images are expected to be 256x96 - no cropping required\n", + "# policy = \"point_goal_nav\": images are expected to be 160x120 - cropping to 160x90\n", + "params.IS_CROP = params.POLICY == \"point_goal_nav\"" ] }, { @@ -131,6 +149,7 @@ "outputs": [], "source": [ "tr = train.Training(params)\n", + "tr.dataset_name = dataset_name\n", "tr.train_data_dir = train_data_dir\n", "tr.test_data_dir = test_data_dir" ] @@ -188,7 +207,7 @@ "metadata": {}, "outputs": [], "source": [ - "train.create_tfrecord(my_callback)" + "train.create_tfrecord(my_callback, policy=tr.hyperparameters.POLICY)" ] }, { @@ -241,8 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "(image_batch, cmd_batch), label_batch = next(iter(tr.train_ds))\n", - "utils.show_train_batch(image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy())" + "utils.show_batch(dataset=tr.train_ds, policy=tr.hyperparameters.POLICY, model=None)" ] }, { @@ -253,6 +271,14 @@ "## Training" ] }, + { + "cell_type": "markdown", + "id": "71034420", + "metadata": {}, + "source": [ + "The number of epochs is proportional to the training time. One epoch means going through the complete dataset once. Increasing `NUM_EPOCHS` will mean longer training time, but generally leads to better performance. To get familiar with the code it can be set to small values like `5` or `10`. To train a model that performs well, it should be set to values between `50` and `200`. Setting `USE_LAST` to `true` will resume the training from the last checkpoint. The default values are `NUM_EPOCHS = 100` and `USE_LAST = False`. They are set in [Hyperparameters](#hyperparameters)." + ] + }, { "cell_type": "code", "execution_count": null, @@ -260,6 +286,8 @@ "metadata": {}, "outputs": [], "source": [ + "# params.NUM_EPOCHS = 200\n", + "# params.USE_LAST = True\n", "train.do_training(tr, my_callback, verbose=1)" ] }, @@ -279,6 +307,24 @@ "The loss and mean absolute error should decrease. This indicates that the model is fitting the data well. The custom metrics (direction and angle) should go towards 1. These provide some additional insight to the training progress. The direction metric measures weather or not predictions are in the same direction as the labels. Similarly the angle metric measures if the prediction is within a small angle of the labels. The intuition is that driving in the right direction with the correct steering angle is most critical part for good final performance." ] }, + { + "cell_type": "markdown", + "id": "4867aaa7", + "metadata": {}, + "source": [ + "### Plot metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ca0b291", + "metadata": {}, + "outputs": [], + "source": [ + "x = np.arange(tr.INITIAL_EPOCH + 1, tr.history.params[\"epochs\"] + 1, 1)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -286,12 +332,13 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(tr.history.history[\"loss\"], label=\"loss\")\n", - "plt.plot(tr.history.history[\"val_loss\"], label=\"val_loss\")\n", + "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n", + "plt.plot(x, tr.history.history[\"loss\"], label=\"loss\")\n", + "plt.plot(x, tr.history.history[\"val_loss\"], label=\"val_loss\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", - "plt.legend(loc=\"lower right\")\n", - "plt.savefig(os.path.join(tr.log_path, \"loss.png\"))" + "plt.legend(loc=\"upper right\")\n", + "plt.savefig(os.path.join(tr.log_path, \"loss.png\"), bbox_inches=\"tight\")" ] }, { @@ -301,12 +348,15 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(tr.history.history[\"mean_absolute_error\"], label=\"mean_absolute_error\")\n", - "plt.plot(tr.history.history[\"val_mean_absolute_error\"], label=\"val_mean_absolute_error\")\n", + "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n", + "plt.plot(x, tr.history.history[\"mean_absolute_error\"], label=\"mean_absolute_error\")\n", + "plt.plot(\n", + " x, tr.history.history[\"val_mean_absolute_error\"], label=\"val_mean_absolute_error\"\n", + ")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Mean Absolute Error\")\n", - "plt.legend(loc=\"lower right\")\n", - "plt.savefig(os.path.join(tr.log_path, \"error.png\"))" + "plt.legend(loc=\"upper right\")\n", + "plt.savefig(os.path.join(tr.log_path, \"error.png\"), bbox_inches=\"tight\")" ] }, { @@ -316,12 +366,13 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(tr.history.history[\"direction_metric\"], label=\"direction_metric\")\n", - "plt.plot(tr.history.history[\"val_direction_metric\"], label=\"val_direction_metric\")\n", + "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n", + "plt.plot(x, tr.history.history[\"direction_metric\"], label=\"direction_metric\")\n", + "plt.plot(x, tr.history.history[\"val_direction_metric\"], label=\"val_direction_metric\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Direction Metric\")\n", "plt.legend(loc=\"lower right\")\n", - "plt.savefig(os.path.join(tr.log_path, \"direction.png\"))" + "plt.savefig(os.path.join(tr.log_path, \"direction.png\"), bbox_inches=\"tight\")" ] }, { @@ -331,12 +382,13 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(tr.history.history[\"angle_metric\"], label=\"angle_metric\")\n", - "plt.plot(tr.history.history[\"val_angle_metric\"], label=\"val_angle_metric\")\n", + "plt.figure().gca().xaxis.get_major_locator().set_params(integer=True)\n", + "plt.plot(x, tr.history.history[\"angle_metric\"], label=\"angle_metric\")\n", + "plt.plot(x, tr.history.history[\"val_angle_metric\"], label=\"val_angle_metric\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Angle Metric\")\n", "plt.legend(loc=\"lower right\")\n", - "plt.savefig(os.path.join(tr.log_path, \"angle.png\"))" + "plt.savefig(os.path.join(tr.log_path, \"angle.png\"), bbox_inches=\"tight\")" ] }, { @@ -344,7 +396,7 @@ "id": "e72c1de3", "metadata": {}, "source": [ - "Save tf lite models for best and last checkpoint" + "### Save tf lite models for best train, best val and last checkpoint" ] }, { @@ -354,19 +406,18 @@ "metadata": {}, "outputs": [], "source": [ - "best_index = np.argmax(\n", - " np.array(tr.history.history[\"val_angle_metric\"])\n", - " + np.array(tr.history.history[\"val_direction_metric\"])\n", - ")\n", - "best_checkpoint = str(\"cp-%04d.ckpt\" % (best_index + 1))\n", - "best_tflite = utils.generate_tflite(tr.checkpoint_path, best_checkpoint)\n", - "utils.save_tflite(best_tflite, tr.checkpoint_path, \"best\")\n", + "best_train_checkpoint = \"cp-best-train.ckpt\"\n", + "best_train_tflite = utils.generate_tflite(tr.checkpoint_path, best_train_checkpoint)\n", + "utils.save_tflite(best_train_tflite, tr.checkpoint_path, \"best-train\")\n", + "best_train_index = np.argmin(np.array(tr.history.history[\"loss\"]))\n", "print(\n", - " \"Best Checkpoint (val_angle: %s, val_direction: %s): %s\"\n", + " \"Best Train Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n", " % (\n", - " tr.history.history[\"val_angle_metric\"][best_index],\n", - " tr.history.history[\"val_direction_metric\"][best_index],\n", - " best_checkpoint,\n", + " best_train_index,\n", + " tr.history.history[\"angle_metric\"][best_train_index],\n", + " tr.history.history[\"val_angle_metric\"][best_train_index],\n", + " tr.history.history[\"direction_metric\"][best_train_index],\n", + " tr.history.history[\"val_direction_metric\"][best_train_index],\n", " )\n", ")" ] @@ -378,21 +429,40 @@ "metadata": {}, "outputs": [], "source": [ - "last_checkpoint = sorted(\n", - " [\n", - " d\n", - " for d in os.listdir(tr.checkpoint_path)\n", - " if os.path.isdir(os.path.join(tr.checkpoint_path, d))\n", - " ]\n", - ")[-1]\n", + "best_val_checkpoint = \"cp-best-val.ckpt\"\n", + "best_val_tflite = utils.generate_tflite(tr.checkpoint_path, best_val_checkpoint)\n", + "utils.save_tflite(best_val_tflite, tr.checkpoint_path, \"best\")\n", + "utils.save_tflite(best_val_tflite, tr.checkpoint_path, \"best-val\")\n", + "best_val_index = np.argmin(np.array(tr.history.history[\"val_loss\"]))\n", + "print(\n", + " \"Best Val Checkpoint (epoch %s) - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n", + " % (\n", + " best_val_index,\n", + " tr.history.history[\"angle_metric\"][best_val_index],\n", + " tr.history.history[\"val_angle_metric\"][best_val_index],\n", + " tr.history.history[\"direction_metric\"][best_val_index],\n", + " tr.history.history[\"val_direction_metric\"][best_val_index],\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e2b1644", + "metadata": {}, + "outputs": [], + "source": [ + "last_checkpoint = \"cp-last.ckpt\"\n", "last_tflite = utils.generate_tflite(tr.checkpoint_path, last_checkpoint)\n", "utils.save_tflite(last_tflite, tr.checkpoint_path, \"last\")\n", "print(\n", - " \"Last Checkpoint (val_angle: %s, val_direction: %s): %s\"\n", + " \"Last Checkpoint - angle: %.4f, val_angle: %.4f, direction: %.4f, val_direction: %.4f\"\n", " % (\n", + " tr.history.history[\"angle_metric\"][-1],\n", " tr.history.history[\"val_angle_metric\"][-1],\n", + " tr.history.history[\"direction_metric\"][-1],\n", " tr.history.history[\"val_direction_metric\"][-1],\n", - " last_checkpoint,\n", " )\n", ")" ] @@ -402,7 +472,7 @@ "id": "d0c57018", "metadata": {}, "source": [ - "Evaluate the best model" + "### Evaluate the best model (train loss) on the training set" ] }, { @@ -412,13 +482,63 @@ "metadata": {}, "outputs": [], "source": [ - "best_model = utils.load_model(\n", - " os.path.join(tr.checkpoint_path, best_checkpoint),\n", + "best_train_model = utils.load_model(\n", + " os.path.join(tr.checkpoint_path, best_train_checkpoint),\n", + " tr.loss_fn,\n", + " tr.metric_list,\n", + " tr.custom_objects,\n", + ")\n", + "loss, mae, direction, angle = best_train_model.evaluate(\n", + " tr.train_ds,\n", + " steps=tr.image_count_train / tr.hyperparameters.TRAIN_BATCH_SIZE,\n", + " verbose=1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92462709", + "metadata": {}, + "outputs": [], + "source": [ + "utils.show_batch(\n", + " dataset=tr.train_ds, policy=tr.hyperparameters.POLICY, model=best_train_model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10a078b1", + "metadata": {}, + "outputs": [], + "source": [ + "utils.compare_tf_tflite(best_train_model, best_train_tflite)" + ] + }, + { + "cell_type": "markdown", + "id": "b2ad2605", + "metadata": {}, + "source": [ + "### Evaluate the best model (val loss) on the validation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2ac5b20", + "metadata": {}, + "outputs": [], + "source": [ + "best_val_model = utils.load_model(\n", + " os.path.join(tr.checkpoint_path, best_val_checkpoint),\n", " tr.loss_fn,\n", " tr.metric_list,\n", " tr.custom_objects,\n", ")\n", - "test_loss, test_acc, test_dir, test_ang = best_model.evaluate(\n", + "loss, mae, direction, angle = best_val_model.evaluate(\n", " tr.test_ds,\n", " steps=tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE,\n", " verbose=1,\n", @@ -432,16 +552,8 @@ "metadata": {}, "outputs": [], "source": [ - "NUM_SAMPLES = 15\n", - "(image_batch, cmd_batch), label_batch = next(iter(tr.test_ds))\n", - "pred_batch = best_model.predict(\n", - " (\n", - " tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]),\n", - " tf.slice(cmd_batch, [0], [NUM_SAMPLES]),\n", - " )\n", - ")\n", - "utils.show_test_batch(\n", - " image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch\n", + "utils.show_batch(\n", + " dataset=tr.test_ds, policy=tr.hyperparameters.POLICY, model=best_val_model\n", ")" ] }, @@ -452,7 +564,7 @@ "metadata": {}, "outputs": [], "source": [ - "utils.compare_tf_tflite(best_model, best_tflite)" + "utils.compare_tf_tflite(best_val_model, best_val_tflite)" ] }, { @@ -501,7 +613,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.11" + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } } }, "nbformat": 4,