In [None]:
%matplotlib inline
import os
import sys
import gc
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from commons.configuration_manager import ConfigurationManager
from src.learning.training.label_collector import LabelCollector
from src.learning.training.training_file_reader import TrainingFileReader
from src.learning.training.training_transformer import TrainingTransformer
from src.learning.training.generator import Generator
from src.learning.models import create_cnn, create_mlp, create_multi_model

import pygame
import cv2

# Size of the pygame preview window, in pixels (passed to pygame.display.set_mode).
window_width = 500
window_height = 200

# RGB colour tuples used when filling the screen and rendering text.
black = (0, 0, 0)
red = (255, 0, 0)
green = (0, 255, 0)
blue = (0, 0, 255)

# Initialise pygame once for the notebook and create the font used for the
# prediction text in the viewer.
pygame.init()
font = pygame.font.SysFont('Roboto', 20)

In [None]:
# Load the project configuration; the transformer below is built from it.
config_manager = ConfigurationManager()
config = config_manager.config

# Directory holding the recorded laps. Defined once, before it is used, so the
# reader and any other code referring to `path_to_training` cannot drift apart.
path_to_training = '../../training/laps/'

reader = TrainingFileReader(path_to_training=path_to_training)
transformer = TrainingTransformer(config)
collector = LabelCollector()

def plot_stuff(title, plot_elems, figsize=(18, 10)):
    """Draw several data series on a single figure and show it.

    Args:
        title: Text placed above the plot.
        plot_elems: Iterable of dicts, each providing the keys ``'data'``
            (the values to plot), ``'label'`` (legend entry) and ``'alpha'``
            (line opacity).
        figsize: Figure size forwarded to ``plt.figure``.

    Raises:
        KeyError: If an element lacks one of the three required keys.
    """
    plt.figure(figsize=figsize)
    plt.title(title)
    plt.xlabel('Count')

    for plot_elem in plot_elems:
        # The format string has to be positional: pyplot.plot treats keyword
        # arguments as Line2D properties, and there is no `fmt` property, so
        # `fmt='-o'` raises instead of selecting the line/marker style.
        plt.plot(plot_elem['data'], '-o', label=plot_elem['label'], alpha=plot_elem['alpha'])

    plt.grid(axis='y')
    plt.legend(loc='best')
    plt.show()
    
# axis=2 for frames, axis=0 for telem and diffs
def memory_creator(instance, memory, length=4, interval=2, axis=2):
    """Append ``instance`` to the rolling ``memory`` list and build a stacked sample.

    The buffer is sampled backwards from the newest entry, taking every
    ``interval``-th item, and the sampled items are concatenated along ``axis``
    (newest first).

    Args:
        instance: Array to add to the history.
        memory: List used as the rolling buffer; it is mutated in place.
        length: Number of sampled items required to form one output.
        interval: Step between sampled items in the buffer.
        axis: Axis along which the sampled items are concatenated.

    Returns:
        The concatenated array, or ``None`` while the buffer does not yet yield
        ``length`` sampled items.
    """
    memory.append(instance)

    # Newest entry first, then every `interval`-th older one. The slice is a
    # copy, so trimming the buffer below does not affect it.
    sampled = memory[::-interval]

    if len(sampled) >= length:
        # Once the buffer has reached length * interval entries, drop the oldest
        # so the next append brings it back to exactly that size.
        if len(memory) >= length * interval:
            del memory[0]
        return np.concatenate(sampled, axis=axis)

    return None


def read_stored_data(filename, numeric_columns, diff_columns):
    """Load the video and telemetry stored under the base name ``filename``.

    Args:
        filename: Base name without extension; ``.csv`` and ``.avi`` are appended.
        numeric_columns: Columns requested from the CSV for the telemetry input.
        diff_columns: Columns requested from the CSV for the diff labels.

    Returns:
        A tuple ``(resized_frames, telemetry, diffs)``: the video after
        ``transformer.resize_and_normalize_video`` and the two column
        selections converted with ``.to_numpy()``.
    """
    csv_name = filename + '.csv'
    video_name = filename + '.avi'

    # The same CSV is read twice, once per column selection.
    telemetry_table = reader.read_specific_telemetry_columns(csv_name, numeric_columns)
    diffs_table = reader.read_specific_telemetry_columns(csv_name, diff_columns)

    raw_video = reader.read_video(video_name)
    prepared_video = transformer.resize_and_normalize_video(raw_video)

    return prepared_video, telemetry_table.to_numpy(), diffs_table.to_numpy()


def read_stored_video(filename):
    """Return the ``.avi`` recording for ``filename`` exactly as the reader yields it.

    Unlike ``read_stored_data``, no resizing or normalisation is applied here.
    """
    video_name = filename + '.avi'
    return reader.read_video(video_name)
    
    
def create_memorized_dataset(frames, telemetry, diffs, length, interval):
    """Stack each sample with its history to form memory-augmented inputs.

    Every output row combines ``length`` samples spaced ``interval`` steps
    apart (built by ``memory_creator``). The first ``(length - 1) * interval``
    rows have no complete history, so the outputs are shorter by that amount.

    Args:
        frames: Array indexed by sample on axis 0; the last axis is multiplied
            by ``length`` in the output.
        telemetry: 2-D array indexed by sample on axis 0; axis 1 is multiplied
            by ``length`` in the output.
        diffs: Label array; its leading rows are dropped to stay aligned.
        length: Number of samples combined per output row.
        interval: Step between the combined samples.

    Returns:
        A tuple ``(mem_frames, mem_telems, mem_diffs)`` with equal first dimension.
    """
    # Number of leading samples that cannot form a complete history window.
    len_diff = (length - 1) * interval

    frames_shape = (frames.shape[0] - len_diff, *frames.shape[1:-1], frames.shape[-1] * length)
    telems_shape = (telemetry.shape[0] - len_diff, telemetry.shape[1] * length)
    mem_frames = np.zeros(frames_shape)
    mem_telems = np.zeros(telems_shape)

    # Rolling buffers that memory_creator appends to and trims.
    frame_history = []
    telem_history = []

    for idx, frame in enumerate(frames):
        stacked_frame = memory_creator(frame, frame_history, length=length, interval=interval, axis=2)
        stacked_telem = memory_creator(telemetry[idx], telem_history, length=length, interval=interval, axis=0)

        # memory_creator yields None until enough history has accumulated.
        if stacked_frame is None:
            continue

        out_row = idx - len_diff
        mem_frames[out_row] = stacked_frame
        mem_telems[out_row] = stacked_telem

    # Drop the labels belonging to the samples that produced no output row.
    mem_diffs = diffs[len_diff:]

    assert mem_frames.shape[0] == mem_telems.shape[0] == mem_diffs.shape[0], "Lengths differ!"
    return mem_frames, mem_telems, mem_diffs
            
    
def get_mem_dataset_with_full_video(filename, length, interval):
    """Build the memory-stacked dataset for one recording plus matching display frames.

    Args:
        filename: Base name of the recording (no extension).
        length: Number of samples combined per output row.
        interval: Step between the combined samples.

    Returns:
        A tuple ``(mem_frames, mem_telemetry, mem_diffs, view_frames)`` whose
        members all share the same first dimension. ``view_frames`` is the
        video as read by ``read_stored_video``, trimmed to line up with the
        stacked arrays.
    """
    stored_frames, stored_telemetry, stored_diffs = read_stored_data(
        filename, collector.steering_columns(), collector.diff_steering_columns())
    mem_frames_np, mem_telemetry_np, mem_diffs_np = create_memorized_dataset(
        stored_frames, stored_telemetry, stored_diffs, length, interval)

    # Skip the same number of leading frames that the stacked dataset lost.
    skipped = (length - 1) * interval
    view_frames_np = read_stored_video(filename)[skipped:]

    assert view_frames_np.shape[0] == mem_frames_np.shape[0], "Frames length mismatch!"
    assert mem_frames_np.shape[0] == mem_telemetry_np.shape[0] == mem_diffs_np.shape[0], "Mem lengths differ!"
    return mem_frames_np, mem_telemetry_np, mem_diffs_np, view_frames_np

In [None]:
# Base names of the recordings (no extension; '.csv' / '.avi' are appended when reading).
filenames = ['lap_5_2020_01_24', 'lap_6_2020_01_24', 'lap_7_2020_01_24']
# Candidate memory settings; each tuple is unpacked below as (length, interval).
experiments = [(1, 1), (4, 1), (4, 3), (16, 1)]
epochs = 8
batch_size = 32
verbose = 1
# Selected setting: the second experiment, i.e. (4, 1).
memory = experiments[1]
generator = Generator(memory=memory, batch_size=batch_size)

# Load a single recording (filenames[1]). Its arrays determine the model
# input/output shapes below, and mem_frames, mem_telems and display_frames are
# reused by the prediction viewer cell.
mem_frames, mem_telems, mem_diffs, display_frames = get_mem_dataset_with_full_video(filenames[1], *memory)

# Build both branches from the shape of one sample and combine them; the
# output size is the number of diff columns.
mlp = create_mlp(input_shape=mem_telems[0].shape)
cnn = create_cnn(input_shape=mem_frames[0].shape)
multi = create_multi_model(mlp, cnn, output_shape=mem_diffs[0].shape[0])

# Training and validation batches come from the Generator, not from the arrays
# loaded above.
# NOTE(review): which files the Generator reads is not visible here - confirm
# its batches match the shapes used to build the model.
multi.fit(generator.generate(data='train'),
          steps_per_epoch=generator.train_batch_count,
          validation_data=generator.generate(data='test'),
          validation_steps=generator.test_batch_count,
          epochs=epochs, verbose=verbose)

In [None]:
# Replay the recording in a pygame window, overlaying the model's prediction
# for every frame. Press Escape or close the window to stop early.
pygame.display.init()
pygame.display.set_caption("Prediction viewer")
screen = pygame.display.set_mode((window_width, window_height))

try:
    for i in range(display_frames.shape[0]):
        screen.fill(black)

        # Add a leading batch axis of size 1 so a single sample can be predicted.
        converted_telem = mem_telems[i][np.newaxis, :]
        converted_frame = mem_frames[i].astype(np.float32)[np.newaxis, :]

        # NOTE(review): the original comment listed the outputs as
        # "gear, steering, throttle, braking", yet index 0 is treated as
        # steering below - confirm the model's output column order.
        predictions = multi.predict([converted_frame, converted_telem])[0]
        steering = predictions[0]

        # Green for non-negative steering, blue for negative.
        text_colour = green if steering >= 0.0 else blue
        prediction_texture = font.render("{}".format(steering), True, text_colour)

        # Rotate the frame before handing it to surfarray, then centre it
        # horizontally using the rotated array's first dimension.
        frame = np.rot90(display_frames[i])
        surface = pygame.surfarray.make_surface(frame)
        x = (window_width - frame.shape[0]) // 2
        y = 0

        screen.blit(prediction_texture, (200, window_height - 25))
        screen.blit(surface, (x, y))
        pygame.display.update()

        # Drain the event queue; stop on Escape or when the window is closed.
        stop_requested = any(
            event.type == pygame.QUIT
            or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE)
            for event in pygame.event.get()
        )
        if stop_requested:
            break
except KeyboardInterrupt:
    # Interrupting the kernel is an accepted way to stop the viewer.
    pass
finally:
    # Always close the window, including when prediction or rendering raises,
    # so a failed run does not leave a stale pygame window behind.
    pygame.display.quit()

In [None]:
# Write a diagram of the combined model to model.png.
# NOTE(review): this import sits in the last cell rather than with the other
# imports at the top of the notebook.
from tensorflow.keras.utils import plot_model
plot_model(multi, to_file='model.png')