In [None]:
# Third-party packages and modules:
import pickle, os, gzip
from collections import deque
import numpy as np
import matplotlib.pyplot as plt
import utils
import my_neural_network as mnn
#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def read_data(path='.warning /data/data_02new.pkl.gzip'):
    """Reads the states and actions recorded by drive_manually.py.

    Args:
        path: Location of the gzip-compressed pickle file to load. The
            default is the path that used to be hard-coded in this function,
            so existing ``read_data()`` calls behave as before.
            NOTE(review): the default starts with '.warning ' (including a
            space), which looks like a mangled './data/...' path -- confirm.
            A previously used alternative was './data/data_new.pkl.gzip'.

    Returns:
        A tuple ``(X, y)`` holding ``utils.vstack`` applied to the
        ``"state"`` and ``"action"`` entries of the unpickled object.
    """
    print("Reading data")
    # SECURITY: pickle.load can execute arbitrary code while deserializing;
    # only point this function at files you recorded or otherwise trust.
    with gzip.open(path, 'rb') as f:
        data = pickle.load(f)
    X = utils.vstack(data["state"])
    y = utils.vstack(data["action"])
    print("finish...")
    return X, y

def preprocess_data(X, y, hist_len, shuffle):
    """Prepare the expert states and actions before they are fed to the agent.

    The steps run in a fixed order: the actions are checked with
    ``utils.check_invalid_actions``, translated with
    ``utils.transl_action_env2agent``, the states go through
    ``utils.preprocess_state``, and finally both are passed to
    ``utils.stack_history``.

    Args:
        X: Recorded states; must expose ``.shape`` (it is printed).
        y: Recorded actions in the environment's format.
        hist_len: Forwarded to ``utils.stack_history``.
        shuffle: Forwarded to ``utils.stack_history`` as ``shuffle=``.

    Returns:
        The ``(states, actions)`` pair produced by ``utils.stack_history``.
    """
    print('Preprocessing states. Shape:', X.shape)

    # Actions: validate first, then translate to the agent's representation.
    utils.check_invalid_actions(y)
    agent_actions = utils.transl_action_env2agent(y)
    print("env2agent finished")

    # States: run the shared state preprocessing.
    states = utils.preprocess_state(X)
    print("preprocess finished")

    # Combine both into history-stacked samples.
    stacked_states, stacked_actions = utils.stack_history(
        states, agent_actions, hist_len, shuffle=shuffle
    )
    return stacked_states, stacked_actions

def split_data(X, y, frac = 0.1):
    """Split the samples into a leading training part and a trailing validation part.

    Args:
        X: Sliceable sequence of inputs.
        y: Sliceable sequence of targets; its length determines the cut.
        frac: Fraction of the samples reserved for validation.

    Returns:
        The 4-tuple ``(X_train, y_train, X_valid, y_valid)``.
    """
    # Index of the first validation sample (truncated towards zero).
    n_train = int((1 - frac) * len(y))
    train_part = (X[:n_train], y[:n_train])
    valid_part = (X[n_train:], y[n_train:])
    return (*train_part, *valid_part)

def plot_states(x_pp, X_tr=None, n=3):
    """Plot some random states before and after preprocessing.

    Draws an ``n`` x 2 grid of images: the right column shows the
    preprocessed state in grayscale, the left column shows the matching
    entry of ``X_tr`` (divided by 255) when ``X_tr`` is given and stays
    empty otherwise.

    Args:
        x_pp: Preprocessed states; indexed by sample along the first axis.
        X_tr: Optional raw states, indexed with the same sample indices as
            ``x_pp``, so the two must be aligned sample-for-sample.
        n: Number of rows (samples) to show. Indices are drawn
            independently, so the same sample may appear more than once.
    """
    pick = np.random.randint(0, len(x_pp), n)
    # squeeze=False keeps `axes` two-dimensional even for n == 1; with the
    # default squeezing, a single row comes back as a 1-D array and the
    # `axes[i, 0]` lookups below raise IndexError.
    fig, axes = plt.subplots(n, 2, sharex=True, sharey=True, figsize=(20,20), squeeze=False)
    for i, p in enumerate(pick):
        if X_tr is not None:
            axes[i,0].imshow(X_tr[p]/255)
        # np.squeeze drops length-1 axes so imshow receives a 2-D image.
        axes[i,1].imshow(np.squeeze(x_pp[p]), cmap='gray')
    fig.tight_layout()
    plt.show()

def plot_action_histogram(actions, title):
    """Show a histogram of the actions found in the expert dataset.

    Args:
        actions: Actions in the form accepted by ``utils.unhot``
            (presumably one-hot rows -- confirm against ``utils``).
        title: Title placed above the plot.
    """
    action_ids = utils.unhot(actions)
    figure, axis = plt.subplots()
    # Bin edges sit on half-integers (-0.5, 0.5, ...), one bin per integer
    # value from 0 up to utils.n_actions - 1.
    edges = np.arange(-.5, utils.n_actions + .5)
    axis.hist(action_ids, range=(0,6), bins=edges, rwidth=.9)
    axis.set_title(title)
    axis.set_xlim(-.5, utils.n_actions -.5)
    plt.show()

class Agent:
    """Driving agent that wraps a neural-network classifier.

    The wrapped ``model`` must provide ``train``, ``predict`` and ``save``.
    Besides delegating to the model, the agent keeps a rolling history of
    preprocessed states (the model input) and of the actions it returned,
    and it temporarily replaces the model's decisions with a scripted
    sequence when the recent actions suggest the car is stuck.

    Call ``begin_new_episode`` before the first ``get_action`` of an episode.
    """

    # Constructor is "overloaded" by the classmethods below.
    def __init__(self, model):
        """Wrap an already constructed ``model``."""
        # The neural network:
        self.model = model
        # Just a constant: the action used for the scripted "accelerate" steps.
        self.accelerate = np.array([0.0, 1.0, 0.0], dtype=np.float32)


    @classmethod  # Constructor for a brand new model
    def from_scratch(cls, n_channels):
        """Build an agent around a new, untrained classifier.

        Args:
            n_channels: Number of channels of the 96x96 network input.
        """
        layers = [
            mnn.layers.Input(input_shape=[96, 96, n_channels]), 
            mnn.layers.Conv2d(filters=16, kernel_size=5, stride=4), 
            mnn.layers.ReLU(), 
            mnn.layers.Dropout(drop_probability=0.5),
            mnn.layers.Conv2d(filters=32, kernel_size=3, stride=2), 
            mnn.layers.ReLU(), 
            mnn.layers.Dropout(drop_probability=0.5),
            mnn.layers.Flatten(), 
            mnn.layers.Linear(n_units=128), 
            mnn.layers.Linear(n_units=utils.n_actions), 
        ]
        model = mnn.models.Classifier_From_Layers(layers)
        # Use cls (not Agent) so subclasses get an instance of their own type.
        return cls(model)
    
    @classmethod  # Constructor to load a model from a file
    def from_file(cls, file_name):
        """Build an agent around a classifier restored from ``file_name``.

        Args:
            file_name: Location passed to ``mnn.models.Classifier_From_File``
                (e.g. 'saved_models/').
        """
        # Honour the argument; it used to be ignored in favour of a
        # hard-coded 'saved_models/'.
        model = mnn.models.Classifier_From_File(file_name)
        return cls(model)

    def train(self, X_train, y_train, X_valid, y_valid, n_batches, batch_size, lr, display_step):
        """Train the wrapped model; all arguments are forwarded unchanged."""
        print("Training model")
        self.model.train(X_train, y_train, X_valid, y_valid, n_batches, batch_size, lr, display_step)

    def begin_new_episode(self, state0):
        """Reset all per-episode bookkeeping.

        Args:
            state0: First state of the episode; its first two dimensions
                define the height and width of the state history buffer.
        """
        # A history of the agent's last actions (at most 100 are kept).
        # Requires `from collections import deque` at the top of the file.
        self.action_history = deque(maxlen=100)
        # Buffer for actions that may eventually overwrite the model
        self.overwrite_actions = []
        # Counts the actions issued so far; only advanced during the initial
        # forced-acceleration phase in get_action.
        self.action_counter = 0
        # This data structure (kind of a deque) will always store the
        # last 'history_length' states and will be fed to the model:
        self.state_hist = np.empty((1, state0.shape[0], state0.shape[1], utils.history_length))
        # Pushing state0 history_length times overwrites every slot of the
        # uninitialised buffer.
        for _ in range(utils.history_length):
            self.__push_state(state0)

    def __push_state(self, state):
        """Insert ``state`` as the newest entry of the state history."""
        # Push the current state to the history. 
        # Oldest state in history is discarded.
        sg = state.astype(np.float32)
        sg = np.expand_dims(sg, 0)
        sg = utils.preprocess_state(sg)
        # Shift every stored state one slot towards the end (dropping the
        # last one), then store the new state in slot 0.
        self.state_hist[0,:,:,1:] = self.state_hist[0,:,:,:-1]
        self.state_hist[0,:,:,0] = sg[0]

    def get_action(self, env_state):
        """Return the action to apply for ``env_state``.

        In order of priority the action is: the fixed accelerate action for
        the first ``utils.dead_start`` calls, a pending scripted overwrite
        action, or the model's prediction translated with
        ``utils.transl_action_agent2env``. Every returned action is also
        appended to the action history.
        """
        # Add the current state to the state history:
        self.__push_state(env_state)

        # First actions will always be to accelerate:
        if self.action_counter < utils.dead_start:
            self.action_history.append(self.accelerate)
            self.action_counter += 1
            return self.accelerate

        # If the car is stuck for too long, the neural network is overwritten:
        if len(self.overwrite_actions) > 0:
            print('Neural network overwritten')
            action = self.overwrite_actions.pop()
            self.action_history.append(action)
            return action

        # Check if the car is frozen. check_freeze fills overwrite_actions,
        # which takes effect on the following calls; this call still uses
        # the model below.
        if self.check_freeze():
            print('Freeze detected. Overwritting neural network from next state onwards')

        # Uses the NN to choose the next action:
        agent_action = self.model.predict(self.state_hist)
        agent_action = utils.transl_action_agent2env(agent_action)
        self.action_history.append(agent_action)
        return agent_action
    
    def check_freeze(self):
        """Detect a stuck car and, if so, schedule the scripted overwrite actions.

        The car counts as stuck when every action in the history is
        identical and none of them is the accelerate action.

        Returns:
            True when a freeze was detected (``overwrite_actions`` is then
            filled with 2 cycles of 10 brake-released copies of the repeated
            action followed by 10 accelerate actions), False otherwise.
        """
        # Nothing recorded yet (e.g. utils.dead_start == 0 on the very first
        # call): there is no evidence of a freeze and indexing would fail.
        if not self.action_history:
            return False

        # If all the last actions are all the same and they 
        # are not accelerate, then the car is stuck somewhere.
        fa = self.action_history[0]
        for a in self.action_history:
            if not np.all(a==fa):
                return False
            if np.all(a == self.accelerate):
                return False
        
        # If the code reaches this point, the car is stuck.
        # Work on a copy: the history entry may be the very array object
        # returned by the action translation, and editing it in place would
        # silently alter both the history and that shared array.
        fa = np.array(fa)
        fa[2] = 0.0  # release brake
        overwrite_cycles = 2
        one_cycle = 10 * [fa] + 10 * [self.accelerate]
        self.overwrite_actions = overwrite_cycles * one_cycle
        return True

    def save(self, file_name):
        """Save the wrapped model to ``file_name``.

        The model is asked to close its session (``close_session=True``).
        NOTE(review): presumably the agent cannot predict or train after
        this call -- confirm against my_neural_network.
        """
        # Save model to a file
        self.model.save(file_name, close_session=True)

 


In [None]:
    # Load the recorded expert demonstrations.
    X, y = read_data()

    # Turn them into agent-ready samples (history stacking, no shuffling).
    X_pp, y_pp = preprocess_data(X, y, hist_len=utils.history_length, shuffle=False)

    # DEBUG: action distribution before rebalancing.
    plot_action_histogram(y_pp, 'Action distribution BEFORE balancing')

    # Balance the samples: gets rid of 50% of the most common action (accelerate).
    X_pp, y_pp = utils.balance_actions(X_pp, y_pp, 0.5)

    # DEBUG: action distribution after rebalancing.
    plot_action_histogram(y_pp, 'Action distribution AFTER balancing')

    # DEBUG (disabled): plot some random states before and after preprocessing.
    # Requires running preprocess_data above with hist_len=1, shuffle=False.
    if False:
        plot_states(X_pp, X)

    # Hold out the last 10% of the samples for validation.
    X_train, y_train, X_valid, y_valid = split_data(X_pp, y_pp, frac=.1)

    # Create a new agent from scratch (or restore a saved one instead):
    agent = Agent.from_scratch(n_channels=utils.history_length)
    #agent = Agent.from_file('saved_models/')

    # Train the agent, then persist it.
    agent.train(
        X_train, y_train, X_valid, y_valid,
        n_batches=10000, batch_size=100, lr=5e-4, display_step=100,
    )
    agent.save('saved_models/')

In [2]:
# Show the current GPU status (driver/CUDA version, memory usage, utilisation).
!nvidia-smi

Thu Feb 16 19:42:59 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.05    Driver Version: 525.85.05    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| 30%   28C    P5    23W / 320W |    574MiB / 10240MiB |     52%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Force-kill process 128296 (signal 9, SIGKILL).
# NOTE(review): the PID is hard-coded; presumably it was a leftover process
# holding the GPU in one particular session. On any other run this PID may
# not exist or may belong to an unrelated process -- check with nvidia-smi
# before re-running this cell.
!kill -9 128296