# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
import pickle
import sys

import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Activation
import h5py
from keras.optimizers import Adamax, Nadam
from interval import interval, inf
import matplotlib.pyplot as plt

from writeNNet import saveNNet
from safe_train import propagate_interval, check_intervals, plot_policy

# Options

In [None]:
ver = 4  # Neural network version
hu = 45  # Number of hidden units in each hidden layer in network
saveEvery = 3  # Epoch frequency of saving
totalEpochs = 20  # Total number of training epochs
BATCH_SIZE = 2**8
EPOCH_TO_PROJECT = 5
trainingDataFiles = (
    "../TrainingData/VertCAS_TrainingData_v2_%02d.h5"  # File format for training data
)
nnetFiles = (
    "../networks/SafeVertCAS_pra%02d_v%d_45HU_%03d.nnet"  # File format for .nnet files
)
advisories = {
    "COC": 0,
    "DNC": 1,
    "DND": 2,
    "DES1500": 3,
    "CL1500": 4,
    "SDES1500": 5,
    "SCL1500": 6,
    "SDES2500": 7,
    "SCL2500": 8,
}

pra = 1

In [None]:
print("Loading Data for VertCAS, pra %02d, Network Version %d" % (pra, ver))
f = h5py.File(trainingDataFiles % pra, "r")
X_train = np.array(f["X"])
Q = np.array(f["y"])
means = np.array(f["means"])
ranges = np.array(f["ranges"])
min_inputs = np.array(f["min_inputs"])
max_inputs = np.array(f["max_inputs"])
print(f"min inputs: {min_inputs}")
print(f"max inputs: {max_inputs}")

N, numOut = Q.shape
print(f"Setting up model with {numOut} outputs and {N} training examples")
num_batches = N / BATCH_SIZE

# Asymmetric loss function
lossFactor = 40.0

# NOTE(nskh): from HorizontalCAS which was updated to use TF
def asymMSE(y_true, y_pred):
    d = y_true - y_pred
    maxes = tf.argmax(y_true, axis=1)
    maxes_onehot = tf.one_hot(maxes, numOut)
    others_onehot = maxes_onehot - 1
    d_opt = d * maxes_onehot
    d_sub = d * others_onehot
    a = lossFactor * (numOut - 1) * (tf.square(d_opt) + tf.abs(d_opt))
    b = tf.square(d_opt)
    c = lossFactor * (tf.square(d_sub) + tf.abs(d_sub))
    d = tf.square(d_sub)
    loss = tf.where(d_sub > 0, c, d) + tf.where(d_opt > 0, a, b)
    return tf.reduce_mean(loss)

# Training: Standard

In [None]:
totalEpochs = 20
saveEvery = 1

In [None]:
# Define model architecture
model = Sequential()
# model.add(Dense(hu, init='uniform', activation='relu', input_dim=4))
# model.add(Dense(hu, init='uniform', activation='relu'))
# model.add(Dense(hu, init='uniform', activation='relu'))
# model.add(Dense(hu, init='uniform', activation='relu'))
# model.add(Dense(hu, init='uniform', activation='relu'))
# model.add(Dense(hu, init='uniform', activation='relu'))
model.add(Dense(hu, activation="relu", input_dim=4))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))

# model.add(Dense(numOut, init="uniform"))
model.add(Dense(numOut))
opt = Nadam(learning_rate=0.0003)
model.compile(loss=asymMSE, optimizer=opt, metrics=["accuracy"])

# # Train and write nnet files
epoch = saveEvery
while epoch <= totalEpochs:
    model.fit(X_train, Q, epochs=saveEvery, batch_size=2**8, shuffle=True)
    saveFile = nnetFiles % (pra, ver, epoch)
    saveNNet(model, saveFile, means, ranges, min_inputs, max_inputs)
    epoch += saveEvery
    output_interval, penultimate_interval = propagate_interval(
        [
            interval[400, 500],
            interval[50, 51],
            interval[-51, -50],
            interval[20, 21],
        ],
        model,
        graph=False,
    )
    print(output_interval)
    plot_policy(model, f"images/standard_vcas_policy_viz_vo50_vi-50_epoch{epoch}.pdf", zoom=True)

In [None]:
model.save("models/july6-standard-acas-8epochs")

# Querying standard model for estimates of loss values

In [None]:
COC_INTERVAL = [
    interval[400, 500],
    interval[50, 51],
    interval[-51, -50],
    interval[20, 21],
]
action_names = [
    "COC",
    "DNC",
    "DND",
    "DES1500",
    "CL1500",
    "SDES1500",
    "SCL1500",
    "SDES2500",
    "SCL2500",
]

In [None]:
np.meshgrid(*[np.arange(400, 510, 25), np.arange(50, 51.1, 0.25), np.arange(-51, -49.9, 0.25)[::-1], np.arange(20, 21.1, 0.25)]).T

In [None]:
x_pred = np.vstack([np.arange(400, 510, 25), np.arange(50, 51.1, 0.25), np.arange(-51, -49.9, 0.25)[::-1], np.arange(20, 21.1, 0.25)]).T
x_pred

In [None]:
y_pred = model.predict(x_pred)
y_pred

In [None]:
advisory_idxs = np.argmax(y_pred, axis=1)

In [None]:
[action_names[idx] for idx in advisory_idxs]

In [None]:
y_pred = model.predict(np.array([[400, 50, -50, 20]]))
action_names[np.argmax(y_pred)]

# Plotting Policy

In [None]:
plot_policy(model)

In [None]:
action_names

In [None]:
hs = np.hstack([np.linspace(-5000, -2000, 20), np.linspace(-2000, 2000, 40), np.linspace(2000, 5000, 20)])
hs.shape

vo = 50
vi = -50
x_grid = None
taus = np.linspace(0, 40, 81)
for tau in taus:
    grid_component = np.vstack([hs, np.ones(hs.shape)*vo, np.ones(hs.shape) * vi, np.ones(hs.shape)*tau]).T
    if x_grid is not None:
        x_grid = np.vstack([x_grid, grid_component])
    else:
        x_grid = grid_component
    
x_grid.shape

In [None]:
plt.scatter(x_grid[:, 3], x_grid[:, 0], s=10)

In [None]:
y_pred = model.predict(x_grid)
advisory_idxs = np.argmax(y_pred, axis=1)
commands = [action_names[idx] for idx in advisory_idxs]

ra1 = (0.9,0.9,0.9) # white
ra2 = (.0,1.0,1.0) # cyan
ra3 = (144.0/255.0,238.0/255.0,144.0/255.0) # lightgreen
ra4 = (30.0/255.0,144.0/255.0,1.0) # dodgerblue
ra5 = (0.0,1.0,.0) # lime
ra6 = (0.0,0.0,1.0) # blue
ra7 = (34.0/255.0,139.0/255.0,34.0/255.0) # forestgreen
ra8 = (0.0,0.0,128.0/255.0) # navy
ra9 = (0.0,100.0/255.0,0.0) # darkgreen
colors = [ra1,ra2,ra3,ra4,ra5,ra6,ra7,ra8,ra9]
bg_colors = [(1.0,1.0,1.0)]

In [None]:
# dict indexed by color/advisory of all points
xs = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}
ys = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}

for i, advisory_idx in enumerate(advisory_idxs):
    color = colors[advisory_idx]
    scatter_x = x_grid[i, 3] # tau
    scatter_y = x_grid[i, 0] # h 
    xs[advisory_idx].append(scatter_x)
    ys[advisory_idx].append(scatter_y)
print("done constructing dicts")

In [None]:
plt.figure()
plt.tight_layout()
for i in range(len(colors)):
    plt.scatter(xs[i], ys[i], s = 10, c = [colors[i]])
plt.legend(action_names)
plt.xlabel("Tau (sec)")
plt.ylabel("h (ft)")
plt.title(f"Policy for vo:{vo} and vi:{vi}")
plt.savefig(f"viz_policy_vo{vo}_vi{vi}.pdf")
plt.show()

# Training: "safe" with projection

In [None]:
np.mean([1, 2, 3.4]).round(3)

In [None]:
# Redefine model architecture
model = Sequential()
model.add(Dense(hu, activation="relu", input_dim=4))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))

# model.add(Dense(numOut, init="uniform"))
model.add(Dense(numOut))
opt = Nadam(learning_rate=0.0003)

epoch_losses = []
epoch_accuracies = []
weights_before_projection = []
weights_after_projection = []
for epoch in range(totalEpochs):
    print(f"on epoch {epoch}")

    rng = np.random.default_rng()

    train_indices = np.arange(X_train.shape[0])

    rng.shuffle(train_indices)  # in-place

    x_shuffled = X_train[train_indices, :]
    y_shuffled = Q[train_indices, :]

    x_batched = np.split(
        x_shuffled, np.arange(BATCH_SIZE, len(x_shuffled), BATCH_SIZE)
    )
    y_batched = np.split(
        y_shuffled, np.arange(BATCH_SIZE, len(y_shuffled), BATCH_SIZE)
    )

    dataset_batched = list(zip(x_batched, y_batched))
    batch_losses = []
    epoch_accuracy = keras.metrics.CategoricalAccuracy()
    for step, (x_batch_train, y_batch_train) in enumerate(dataset_batched):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch_train, training=True)  # Forward pass
            loss = asymMSE(y_batch_train, y_pred)
            batch_losses.append(loss.numpy())
            epoch_accuracy.update_state(y_batch_train, y_pred)

        if step % int(num_batches / 500) == 0:
            print(
                f"{np.round(step / num_batches * 100, 1)}% through this epoch with loss",
                f"{loss.numpy()} and accuracy {epoch_accuracy.result()}\r",
                end="",
            )

        # Compute gradients
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        opt.apply_gradients(zip(gradients, trainable_vars))
        
    epoch_accuracies.append(epoch_accuracy.result())
    epoch_losses.append(np.mean(batch_losses))

    print("")
    print(f"Mean loss over this epoch: {np.mean(batch_losses)}")
    print(f"Mean accuracy over this epoch: {epoch_accuracy.result()}")
    
    weights_before_projection.append(model.layers[-1].weights)
    
    # Parameters:
    # - h (ft): Altitude of intruder relative to ownship, [-8000, 8000]
    # - vO (ft/s): ownship vertical climb rate, [-100, 100]
    # - vI (ft/s): intruder vertical climb rate, [-100, 100]
    # - τ (sec): time to loss of horizontal separation
    output_interval, penultimate_interval = propagate_interval(
        [
            interval[7880, 7900],
            interval[95, 96],
            interval[5, 6],
            interval[38, 40],
        ],
        model,
        graph=False,
    )
    
    if not check_intervals(output_interval, desired_interval):
        print(f"safe region test FAILED, interval was {output_interval}")
        if epoch % EPOCH_TO_PROJECT == 0:
            print(f"\nProjecting weights at epoch {epoch}.")
            intervals_to_project = []
            assert type(output_interval) == type(desired_interval)
            if type(output_interval) is list:
                assert len(output_interval) == len(desired_interval)
                for i in range(len(output_interval)):
                    if (
                        desired_interval[i] is not None
                        and output_interval[i] not in desired_interval[i]
                    ):
                        intervals_to_project.append(i)
            else:
                intervals_to_project.append(0)

            weights_tf = model.layers[-1].weights
            weights_np = weights_tf[0].numpy()
            biases_np = weights_tf[1].numpy()

            for idx in intervals_to_project:
                weights_to_project = np.hstack([weights_np[:, idx], biases_np[idx]])
                proj = project_weights(
                    desired_interval[idx], penultimate_interval, weights_to_project
                )
                weights_np[:, idx] = proj[:-1]
                biases_np[idx] = proj[-1]

            model.layers[-1].set_weights([weights_np, biases_np])
            output_interval, _ = propagate_interval(
                COC_INTERVAL,
                model,
                graph=False,
            )
            weights_after_projection.append(model.layers[-1].weights)
            print(f"After projecting, output interval is {output_interval}")
    else:
        print(f"safe region test passed, interval was {output_interval}")
        
    with open("projection_acas.pickle", "wb") as f:
        data = {
            "accuracies": epoch_accuracies,
            "losses": epoch_losses,
            "weights_before_projection": weights_before_projection,
            "weights_after_projection" : weights_after_projection
        }
        pickle.dump()

In [None]:
# Redefine model architecture
model = Sequential()
model.add(Dense(hu, activation="relu", input_dim=4))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))
model.add(Dense(hu, activation="relu"))

# model.add(Dense(numOut, init="uniform"))
model.add(Dense(numOut))
opt = Nadam(learning_rate=0.0003)

epoch_losses = []
epoch_accuracies = []

rng = np.random.default_rng()

train_indices = np.arange(X_train.shape[0])

rng.shuffle(train_indices)  # in-place

x_shuffled = X_train[train_indices, :]
y_shuffled = Q[train_indices, :]

x_batched = np.split(
    x_shuffled, np.arange(BATCH_SIZE, len(x_shuffled), BATCH_SIZE)
)
y_batched = np.split(
    y_shuffled, np.arange(BATCH_SIZE, len(y_shuffled), BATCH_SIZE)
)

dataset_batched = list(zip(x_batched, y_batched))
batch_losses = []
epoch_accuracy = keras.metrics.CategoricalAccuracy()

In [None]:
x_batch_train, y_batch_train = dataset_batched[0]

In [None]:
x_batch_train

In [None]:
x_batch_train.shape

In [None]:
y_batch_train

In [None]:
y_batch_train.shape

In [None]:
with tf.GradientTape() as tape:
    y_pred = model(x_batch_train, training=True)  # Forward pass
    loss = asymMSE(y_batch_train, y_pred)
    batch_losses.append(loss.numpy())
    epoch_accuracy.update_state(y_batch_train, y_pred)
    print(epoch_accuracy.result())

In [None]:
y_pred.numpy()

In [None]:
y_pred.numpy().shape

In [None]:
np.argmin(y_pred, axis=1)

In [None]:
np.argmin(y_batch_train, axis=1)

In [None]:
np.mean(np.argmin(y_pred, axis=1) == np.argmin(y_batch_train, axis=1))