In [1]:
import numpy as np
from stable_baselines3 import PPO
# from sb3_contrib import TRPO
from gymnasium import spaces, Env
from boat_simulation import Boat, wrap_phase

2024-02-29 16:26:28.711912: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-29 16:26:28.730842: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-29 16:26:28.730859: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-29 16:26:28.731391: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-29 16:26:28.734997: I tensorflow/core/platform/cpu_feature_guar

In [2]:
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


def linear(device, x, in_units, out_units, activation):
    x = nn.Linear(in_units, out_units).to(device)(x)
    return activation().to(device)(x)

def mlp(device, x, units):
    for i in range(0, len(units)):
        if not isinstance(units[i], tuple):
            units[i] = (units[i], nn.ReLU)

    for i in range(1, len(units)):
        prev_layer_size, _ = units[i - 1]
        layer_size, activation = units[i]
        x = linear(device, x, prev_layer_size, layer_size, activation)
    return x

def self_attention(device, x, num_heads, num_layers):
    _batch, seq_len, embed_size = x.shape

    for i in range(1, num_layers):
        lin = linear(device, x, embed_size, embed_size, nn.ReLU)
        attn,  _weights = nn.MultiheadAttention(embed_size, num_heads, dropout=0.5).to(device)(lin, lin, lin, need_weights=False)
        x = attn + x
    return x


class CustomMLP(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=1):
        super(CustomMLP, self).__init__(observation_space, features_dim)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, observations):
        x = observations.to(self.device)
        return mlp(self.device, x, [7, 400, 300, (self.features_dim, nn.Tanh)])


# Define the LSTM model
class NonLSTM(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(NonLSTM, self).__init__(observation_space, features_dim)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hidden_layer_size = 64
        x1, x2 = observation_space.shape
        input_size = x1 * x2
        self.linear = lambda x: mlp(self.device, x, [input_size, self.hidden_layer_size, self.hidden_layer_size, features_dim])

    def forward(self, input_seq):
        batch_size = input_seq.shape[0]
        input = input_seq.view(batch_size, -1)
        predictions = self.linear(input)
        return predictions

class SelfAttentionExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(SelfAttentionExtractor, self).__init__(observation_space, features_dim)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        (seq_size, input_size,) = observation_space.shape
        self.attention = lambda x: self_attention(self.device, x, 8, 1)
        self.input_layer = lambda x: mlp(self.device, x, [input_size, features_dim])
        self.output_layer = lambda x: mlp(self.device, x, [seq_size * features_dim, (seq_size // 2 + 1) * features_dim, features_dim])

    def forward(self, input_seq):
        input = self.input_layer(input_seq.to(self.device))
        hidden = self.attention(input)
        hidden_flat = nn.Flatten().to(self.device)(hidden)
        output = self.output_layer(hidden_flat)
        return output

In [3]:
from environments import MultiMarkEnv

In [4]:
from IPython.core.display import display, HTML, Javascript

disable_scroll_script = """
var idx = Jupyter.notebook.get_selected_index();
var cell = Jupyter.notebook.get_cell(idx);
cell.output_area._should_scroll = function(lines) {
    return false;
}
"""

display(Javascript(disable_scroll_script))

  from IPython.core.display import display, HTML, Javascript


<IPython.core.display.Javascript object>

In [5]:
import plotly.graph_objects as go
import time
import ipywidgets as widgets
from IPython.display import display

outer_radius = 2 * 250 + 2.5 * 0.1 * 250

dt = 15
num_marks = 2
max_seconds = 500 * num_marks / dt

# Initialize the figure and scatter plot
fig = go.FigureWidget()
scatter = fig.add_scatter(mode='markers+lines', name='trajectory')
marks_scatter = fig.add_scatter(mode='markers+text', name='marks')
fig.update_xaxes(range=[-outer_radius, outer_radius],dtick=25)
fig.update_yaxes(range=[-outer_radius, outer_radius],dtick=25)
fig.layout.width=800
fig.layout.height=800

heading_fig = go.FigureWidget()
heading_scatter = heading_fig.add_scatter(mode='markers+lines', name='heading')
rudder_angle_scatter = heading_fig.add_scatter(mode='markers+lines', name='rudder angle')
heading_fig.update_xaxes(range=[0, max_seconds],dtick=25)
heading_fig.update_yaxes(range=[-210, 210],dtick=30)
heading_fig.layout.width=800
heading_fig.layout.height=600
heading_fig.layout.title = "Heading over time"

vmg_fig = go.FigureWidget()
vmg_scatter = vmg_fig.add_scatter(mode='markers+lines')
vmg_fig.update_xaxes(range=[0, max_seconds],dtick=25)
vmg_fig.update_yaxes(range=[-4, 4],dtick=2)
vmg_fig.layout.width=800
vmg_fig.layout.height=400
vmg_fig.layout.title = "VMG over time"

# Initialize output for the text
out = widgets.Output()

# Display text and figure
display(out)
display(fig)
display(heading_fig)
display(vmg_fig)

episode = 0

colormap = ['blue', 'red', 'green', 'orange', 'black']

def plot(data, marks, bounds):
    global episode, out, fig
    episode += 1
    if episode % 50 != 0:
        return
    x_values = [point['x'] for point in data]
    y_values = [point['y'] for point in data]
    meta_values = [{k: '%.3f' % v for (k, v) in point['meta'].items()} for point in data]

    min_x, max_x, min_y, max_y = bounds


    color_meta = [point['meta']['current_mark'] for point in data]
    colors = [colormap[color % len(colormap)] for color in color_meta]

    headings = [point['meta']['heading'] for point in data]
    rudder_angles = [point['meta']['rudder_angle'] for point in data]
    vmg = [point['meta']['vmg'] for point in data]

    global_meta = {'final_position': (x_values[-1], y_values[-1]), 'reward': sum([point['meta']['reward'] for point in data]), 'iters': len(data), 'episode': episode}

    # Update data
    scatter = fig.data[0]
    scatter.x = x_values
    scatter.y = y_values
    scatter.hovertext = meta_values  # Add this line
    scatter.marker.color = colors
    scatter.line.color = 'lightgrey'

    marks_scatter = fig.data[1]
    marks_scatter.x = [x for x, _ in marks]
    marks_scatter.y = [y for _, y in marks]
    marks_scatter.text = list(range(len(marks)))
    marks_scatter.textposition = 'bottom right'
    marks_scatter.marker.color = 'black'
    marks_scatter.marker.symbol = 'x'
    marks_scatter.marker.size = 10

    heading_scatter = heading_fig.data[0]
    heading_scatter.x = list(range(len(headings)))
    heading_scatter.y = headings
    heading_scatter.marker.color = colors
    heading_scatter.line.color = 'lightgrey'

    rudder_angle_scatter = heading_fig.data[1]
    rudder_angle_scatter.x = list(range(len(rudder_angles)))
    rudder_angle_scatter.y = rudder_angles
    rudder_angle_scatter.marker.color = 'lightgreen'
    rudder_angle_scatter.line.color = 'lightgreen'

    vmg_scatter = vmg_fig.data[0]
    vmg_scatter.x = list(range(len(vmg)))
    vmg_scatter.y = vmg
    vmg_scatter.marker.color = colors
    vmg_scatter.line.color = 'lightgrey'

    fig.update_xaxes(range=[min_x * 1.1, max_x * 1.1],dtick=25)
    fig.update_yaxes(range=[min_y * 1.1, max_y * 1.1],dtick=25)
    fig.layout.width=800
    fig.layout.height=800

    heading_fig.update_yaxes(range=[-210, 210],dtick=30)
    heading_fig.layout.width=800
    heading_fig.layout.height=600

    vmg_fig.layout.width=800
    vmg_fig.layout.height=400

    # Update the text output
    with out:
        out.clear_output(wait=True)
        print(f"Global Meta: {str(global_meta)}")

None

Output()

FigureWidget({
    'data': [{'mode': 'markers+lines',
              'name': 'trajectory',
              'type': 'scatter',
              'uid': '2a31f430-7a86-42f9-82cf-882b37e79cff'},
             {'mode': 'markers+text', 'name': 'marks', 'type': 'scatter', 'uid': '732814e4-aa01-4014-9e19-b4d777e6da1f'}],
    'layout': {'height': 800,
               'template': '...',
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [-562.5, 562.5]},
               'yaxis': {'dtick': 25, 'range': [-562.5, 562.5]}}
})

FigureWidget({
    'data': [{'mode': 'markers+lines',
              'name': 'heading',
              'type': 'scatter',
              'uid': '9647bf5f-0196-48fb-ab22-5cf9fc7c4b58'},
             {'mode': 'markers+lines',
              'name': 'rudder angle',
              'type': 'scatter',
              'uid': '4cf6d123-7c2c-4060-bc49-023eff9b79ad'}],
    'layout': {'height': 600,
               'template': '...',
               'title': {'text': 'Heading over time'},
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [0, 66.66666666666667]},
               'yaxis': {'dtick': 30, 'range': [-210, 210]}}
})

FigureWidget({
    'data': [{'mode': 'markers+lines', 'type': 'scatter', 'uid': 'bc712978-2839-4b86-8655-47542e6c336a'}],
    'layout': {'height': 400,
               'template': '...',
               'title': {'text': 'VMG over time'},
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [0, 66.66666666666667]},
               'yaxis': {'dtick': 2, 'range': [-4, 4]}}
})

In [6]:
# Configuration

r = 250
actions = np.array([-90, -45, 0, 45, 90]) / 180.0

config = {
    'max_marks': 1,
    'max_seconds_per_leg': 500,
    'plot_fn': plot,
    'leg_radius': r,
    'actions': actions,
    'target_tolerance_multiplier': 5
}

outer_radius = config['max_marks'] * 2 * r + 0.2 * r
bounds = [-outer_radius, outer_radius, -outer_radius, outer_radius]
# Initialize Environment
env = MultiMarkEnv(config, dt=5, bounds=bounds, seq_size=3, target_phase_steps=2, heading_phase_steps=4, radius_multipliers=[1])

# Initialize PPO model
model = PPO("MlpPolicy", env, verbose=1, device='cuda', policy_kwargs={
    "net_arch": [64, 400, 300],
    # "features_extractor_class": SelfAttentionExtractor,
    # "features_extractor_kwargs": {"features_dim": 64}
})

# Train the model
# model.learn(total_timesteps=1_000_000)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [7]:
model.learn(total_timesteps=4_000_000)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 81.7     |
|    ep_rew_mean     | -0.754   |
| time/              |          |
|    fps             | 1770     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 80           |
|    ep_rew_mean          | -0.667       |
| time/                   |              |
|    fps                  | 1409         |
|    iterations           | 2            |
|    time_elapsed         | 2            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0042808233 |
|    clip_fraction        | 0.0566       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.42        |
|    explained_variance   | -1.08        |
|    learning_r

<stable_baselines3.ppo.ppo.PPO at 0x7f44522e8220>

In [8]:
# Initialize Environment
env = MultiMarkEnv(config, dt=5, bounds=bounds, seq_size=1, target_phase_steps=4, heading_phase_steps=4)

# Initialize PPO model
model = PPO("MlpPolicy", env, verbose=1, device='cuda', policy_kwargs={
    "net_arch": [64, 400, 300],
    # "features_extractor_class": SelfAttentionExtractor,
    # "features_extractor_kwargs": {"features_dim": 64}
})

model.learn(total_timesteps=1_000_000)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 85.3     |
|    ep_rew_mean     | -1.22    |
| time/              |          |
|    fps             | 2000     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 82.9        |
|    ep_rew_mean          | -1.1        |
| time/                   |             |
|    fps                  | 1595        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.004321207 |
|    clip_fraction        | 0.027       |
|    clip_range           | 0.2         |
|    entropy_loss  

<stable_baselines3.ppo.ppo.PPO at 0x7f44400dcbb0>

In [9]:
model.save("BoatControl_boat_model_coarse.dat")

In [10]:
import plotly.graph_objects as go
import time
import ipywidgets as widgets
from IPython.display import display

outer_radius = 2 * 250 + 2.5 * 0.1 * 250

dt = 5
num_marks = 2
max_seconds = 500 * num_marks / dt

# Initialize the figure and scatter plot
fig = go.FigureWidget()
scatter = fig.add_scatter(mode='markers+lines', name='trajectory')
marks_scatter = fig.add_scatter(mode='markers+text', name='marks')
fig.update_xaxes(range=[-outer_radius, outer_radius],dtick=25)
fig.update_yaxes(range=[-outer_radius, outer_radius],dtick=25)
fig.layout.width=800
fig.layout.height=800

heading_fig = go.FigureWidget()
heading_scatter = heading_fig.add_scatter(mode='markers+lines', name='heading')
rudder_angle_scatter = heading_fig.add_scatter(mode='markers+lines', name='rudder angle')
heading_fig.update_xaxes(range=[0, max_seconds],dtick=25)
heading_fig.update_yaxes(range=[-210, 210],dtick=30)
heading_fig.layout.width=800
heading_fig.layout.height=600
heading_fig.layout.title = "Heading over time"

vmg_fig = go.FigureWidget()
vmg_scatter = vmg_fig.add_scatter(mode='markers+lines')
vmg_fig.update_xaxes(range=[0, max_seconds],dtick=25)
vmg_fig.update_yaxes(range=[-4, 4],dtick=2)
vmg_fig.layout.width=800
vmg_fig.layout.height=400
vmg_fig.layout.title = "VMG over time"

# Initialize output for the text
out = widgets.Output()

# Display text and figure
display(out)
display(fig)
display(heading_fig)
display(vmg_fig)

episode = 0

colormap = ['blue', 'red', 'green', 'orange', 'black']

def plot_new(data, marks, bounds):
    global episode, out, fig
    episode += 1
    if episode % 50 != 0:
        return
    x_values = [point['x'] for point in data]
    y_values = [point['y'] for point in data]
    meta_values = [{k: '%.3f' % v for (k, v) in point['meta'].items()} for point in data]

    min_x, max_x, min_y, max_y = bounds


    color_meta = [point['meta']['current_mark'] for point in data]
    colors = [colormap[color % len(colormap)] for color in color_meta]

    headings = [point['meta']['heading'] for point in data]
    rudder_angles = [point['meta']['rudder_angle'] for point in data]
    vmg = [point['meta']['vmg'] for point in data]

    global_meta = {'final_position': (x_values[-1], y_values[-1]), 'reward': sum([point['meta']['reward'] for point in data]), 'iters': len(data), 'episode': episode}

    # Update data
    scatter = fig.data[0]
    scatter.x = x_values
    scatter.y = y_values
    scatter.hovertext = meta_values  # Add this line
    scatter.marker.color = colors
    scatter.line.color = 'lightgrey'

    marks_scatter = fig.data[1]
    marks_scatter.x = [x for x, _ in marks]
    marks_scatter.y = [y for _, y in marks]
    marks_scatter.text = list(range(len(marks)))
    marks_scatter.textposition = 'bottom right'
    marks_scatter.marker.color = 'black'
    marks_scatter.marker.symbol = 'x'
    marks_scatter.marker.size = 10

    heading_scatter = heading_fig.data[0]
    heading_scatter.x = list(range(len(headings)))
    heading_scatter.y = headings
    heading_scatter.marker.color = colors
    heading_scatter.line.color = 'lightgrey'

    rudder_angle_scatter = heading_fig.data[1]
    rudder_angle_scatter.x = list(range(len(rudder_angles)))
    rudder_angle_scatter.y = rudder_angles
    rudder_angle_scatter.marker.color = 'lightgreen'
    rudder_angle_scatter.line.color = 'lightgreen'

    vmg_scatter = vmg_fig.data[0]
    vmg_scatter.x = list(range(len(vmg)))
    vmg_scatter.y = vmg
    vmg_scatter.marker.color = colors
    vmg_scatter.line.color = 'lightgrey'

    fig.update_xaxes(range=[min_x * 1.1, max_x * 1.1],dtick=25)
    fig.update_yaxes(range=[min_y * 1.1, max_y * 1.1],dtick=25)
    fig.layout.width=800
    fig.layout.height=800

    heading_fig.update_yaxes(range=[-210, 210],dtick=30)
    heading_fig.layout.width=800
    heading_fig.layout.height=600

    vmg_fig.layout.width=800
    vmg_fig.layout.height=400

    # Update the text output
    with out:
        out.clear_output(wait=True)
        print(f"Global Meta: {str(global_meta)}")

None

Output()

FigureWidget({
    'data': [{'mode': 'markers+lines',
              'name': 'trajectory',
              'type': 'scatter',
              'uid': '15832426-2f60-4b0e-9e4c-20cafe0d250e'},
             {'mode': 'markers+text', 'name': 'marks', 'type': 'scatter', 'uid': '7f1109cf-5c4f-4f1d-85d2-d7b875515239'}],
    'layout': {'height': 800,
               'template': '...',
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [-562.5, 562.5]},
               'yaxis': {'dtick': 25, 'range': [-562.5, 562.5]}}
})

FigureWidget({
    'data': [{'mode': 'markers+lines',
              'name': 'heading',
              'type': 'scatter',
              'uid': '1079a7a1-7bba-4fbc-9682-ab48acb25668'},
             {'mode': 'markers+lines',
              'name': 'rudder angle',
              'type': 'scatter',
              'uid': '5e4fd900-599e-41bd-87ee-e82ea1883eee'}],
    'layout': {'height': 600,
               'template': '...',
               'title': {'text': 'Heading over time'},
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [0, 200.0]},
               'yaxis': {'dtick': 30, 'range': [-210, 210]}}
})

FigureWidget({
    'data': [{'mode': 'markers+lines', 'type': 'scatter', 'uid': 'd269d9e4-352c-422e-beec-10b63ca048ff'}],
    'layout': {'height': 400,
               'template': '...',
               'title': {'text': 'VMG over time'},
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [0, 200.0]},
               'yaxis': {'dtick': 2, 'range': [-4, 4]}}
})

In [11]:
config_new = {**config}
config_new['plot_fn'] = plot_new
config_new['target_tolerance_multiplier'] = 3

model.set_env(MultiMarkEnv(config, dt=5, seq_size=1, bounds=bounds, target_phase_steps=8, heading_phase_steps=8))

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [12]:
model.learn(total_timesteps=1500_000)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 49.8     |
|    ep_rew_mean     | 1.68     |
| time/              |          |
|    fps             | 2022     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 46.9      |
|    ep_rew_mean          | 1.45      |
| time/                   |           |
|    fps                  | 1582      |
|    iterations           | 2         |
|    time_elapsed         | 2         |
|    total_timesteps      | 4096      |
| train/                  |           |
|    approx_kl            | 5.4508963 |
|    clip_fraction        | 0.459     |
|    clip_range           | 0.2       |
|    entropy_loss         | 1.51      |
|    explained_variance   | 0.821     |
|    learning_rate        | 0.0003    |
|    loss           

<stable_baselines3.ppo.ppo.PPO at 0x7f44400dcbb0>

In [13]:
model.save("BoatControl_boat_model_level_2.dat")

In [14]:
config_new = {**config}
config_new['plot_fn'] = plot_new
config_new['target_tolerance_multiplier'] = 2

model.set_env(MultiMarkEnv(config, dt=5, seq_size=1, bounds=bounds, target_phase_steps=8, heading_phase_steps=8))

model.learn(total_timesteps=1_500_000)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 41.5     |
|    ep_rew_mean     | 2.15     |
| time/              |          |
|    fps             | 2133     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 45.6       |
|    ep_rew_mean          | 2.27       |
| time/                   |            |
|    fps                  | 1721       |
|    iterations           | 2          |
|    time_elapsed         | 2          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.45054722 |
|    clip_fraction        | 0.375      |
|    clip_range           | 0.2        |
|    entropy_loss         | 1.8        |
|    explained_variance   | 0.971      |
|    learning_rate        | 0.0003     |
|   

<stable_baselines3.ppo.ppo.PPO at 0x7f44400dcbb0>

In [15]:
model.save("BoatControl_boat_model_level_3.dat")

In [16]:
config_new = {**config}
config_new['plot_fn'] = plot_new
config_new['target_tolerance_multiplier'] = 1.5

model.set_env(MultiMarkEnv(config, dt=5, seq_size=1, bounds=bounds, target_phase_steps=8, heading_phase_steps=8))

model.learn(total_timesteps=1_500_000)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 41.6     |
|    ep_rew_mean     | 2.06     |
| time/              |          |
|    fps             | 2110     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 42.2       |
|    ep_rew_mean          | 2.09       |
| time/                   |            |
|    fps                  | 1676       |
|    iterations           | 2          |
|    time_elapsed         | 2          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.20003253 |
|    clip_fraction        | 0.348      |
|    clip_range           | 0.2        |
|    entropy_loss         | 1.64       |
|    explained_variance   | 0.986      |
|    learning_rate        | 0.0003     |
|   

<stable_baselines3.ppo.ppo.PPO at 0x7f44400dcbb0>

In [17]:
model.save("BoatControl_boat_model_fine.dat")


In [18]:
env = MultiMarkEnv(config, dt=5, bounds=bounds, seq_size=1)


def render(data, marks, bounds):
    x_values = [point['x'] for point in data]
    y_values = [point['y'] for point in data]
    meta_values = [{k: '%.3f' % v for (k, v) in point['meta'].items()} for point in data]

    min_x, max_x, min_y, max_y = bounds


    color_meta = [point['meta']['current_mark'] for point in data]
    colors = [colormap[color % len(colormap)] for color in color_meta]

    headings = [point['meta']['heading'] for point in data]
    rudder_angles = [point['meta']['rudder_angle'] for point in data]
    vmg = [point['meta']['vmg'] for point in data]

    global_meta = {'final_position': (x_values[-1], y_values[-1]), 'reward': sum([point['meta']['reward'] for point in data]), 'iters': len(data), 'episode': episode}

    # Update data
    scatter = fig.data[0]
    scatter.x = x_values
    scatter.y = y_values
    scatter.hovertext = meta_values  # Add this line
    scatter.marker.color = colors
    scatter.line.color = 'lightgrey'

    marks_scatter = fig.data[1]
    marks_scatter.x = [x for x, _ in marks]
    marks_scatter.y = [y for _, y in marks]
    marks_scatter.text = list(range(len(marks)))
    marks_scatter.textposition = 'bottom right'
    marks_scatter.marker.color = 'black'
    marks_scatter.marker.symbol = 'x'
    marks_scatter.marker.size = 10

    heading_scatter = heading_fig.data[0]
    heading_scatter.x = list(range(len(headings)))
    heading_scatter.y = headings
    heading_scatter.marker.color = colors
    heading_scatter.line.color = 'lightgrey'

    rudder_angle_scatter = heading_fig.data[1]
    rudder_angle_scatter.x = list(range(len(rudder_angles)))
    rudder_angle_scatter.y = rudder_angles
    rudder_angle_scatter.marker.color = 'lightgreen'
    rudder_angle_scatter.line.color = 'lightgreen'

    vmg_scatter = vmg_fig.data[0]
    vmg_scatter.x = list(range(len(vmg)))
    vmg_scatter.y = vmg
    vmg_scatter.marker.color = colors
    vmg_scatter.line.color = 'lightgrey'

    fig.update_xaxes(range=[min_x * 1.1, max_x * 1.1],dtick=25)
    fig.update_yaxes(range=[min_y * 1.1, max_y * 1.1],dtick=25)
    fig.layout.width=800
    fig.layout.height=800

    heading_fig.update_yaxes(range=[-210, 210],dtick=30)
    heading_fig.layout.width=800
    heading_fig.layout.height=600

    vmg_fig.layout.width=800
    vmg_fig.layout.height=400

    # Update the text output
    with out:
        out.clear_output(wait=True)
        print(f"Global Meta: {str(global_meta)}")


# Initialize output for the text
out = widgets.Output()

# Display text and figure
display(out)
display(fig)
display(heading_fig)
display(vmg_fig)

Output()

FigureWidget({
    'data': [{'hovertext': [{'current_mark': '0.000', 'vmg': '0.000', 'heading':
                            '135.000', 'rudder_angle': '-33.773', 'reward':
                            '0.000', 'speed': '0.000'}, {'current_mark': '0.000',
                            'vmg': '2.996', 'heading': '74.092', 'rudder_angle':
                            '90.000', 'reward': '0.108', 'speed': '30.130'},
                            {'current_mark': '0.000', 'vmg': '2.941', 'heading':
                            '68.719', 'rudder_angle': '34.740', 'reward': '0.102',
                            'speed': '32.872'}, {'current_mark': '0.000', 'vmg':
                            '3.221', 'heading': '103.459', 'rudder_angle':
                            '53.688', 'reward': '0.134', 'speed': '32.561'},
                            {'current_mark': '0.000', 'vmg': '3.159', 'heading':
                            '106.658', 'rudder_angle': '-39.876', 'reward':
                            '0.126

FigureWidget({
    'data': [{'line': {'color': 'lightgrey'},
              'marker': {'color': [blue, blue, blue, blue, blue, blue, blue, blue,
                                   blue, blue, blue, blue, blue, blue, red]},
              'mode': 'markers+lines',
              'name': 'heading',
              'type': 'scatter',
              'uid': '1079a7a1-7bba-4fbc-9682-ab48acb25668',
              'x': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
              'y': [135.0, 74.09188375514128, 68.71947826225363,
                    103.4586950241547, 106.65795710166974, 108.25969366695571,
                    83.72957089386499, 79.01299931448028, 69.67214845590752,
                    85.02164375502265, 100.48684919760012, 105.72235451131648,
                    99.8772977065597, 99.33460591073566, 86.90682382708678]},
             {'line': {'color': 'lightgreen'},
              'marker': {'color': 'lightgreen'},
              'mode': 'markers+lines',
              'name': 'rudde

FigureWidget({
    'data': [{'line': {'color': 'lightgrey'},
              'marker': {'color': [blue, blue, blue, blue, blue, blue, blue, blue,
                                   blue, blue, blue, blue, blue, blue, red]},
              'mode': 'markers+lines',
              'type': 'scatter',
              'uid': 'd269d9e4-352c-422e-beec-10b63ca048ff',
              'x': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
              'y': [0, 2.996167977003396, 2.940631725537705, 3.2209688148045132,
                    3.159118154631177, 3.101486278885751, 3.249165085813891,
                    3.187346622222959, 2.9567615128788645, 3.223936651914875,
                    3.2621681376175937, 3.17633584692732, 3.2193715438327537,
                    3.1719728399430847, 3.2618623596767518]}],
    'layout': {'height': 400,
               'template': '...',
               'title': {'text': 'VMG over time'},
               'width': 800,
               'xaxis': {'dtick': 25, 'range': [0, 20

In [None]:
target_x = 0
target_y = 250

best_trajectory = None
best_reward = None
for i in range(20):
    obs, _ = env.reset()
    env.target_x[0] = target_x
    env.target_y[0] = target_y
    env.boat.heading = np.pi / 4
    env.heading = np.pi/4
    done = False
    total_reward = 0
    while not done:
        action, _states = model.predict(obs)
        obs, reward, is_terminal, is_truncated, info = env.step(action)
        done = is_terminal or is_truncated
        total_reward += reward

    if best_reward is None or total_reward > best_reward:
        best_reward = total_reward
        best_trajectory = env.trajectory

render(best_trajectory, [(target_x, target_y)], bounds)