In [10]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from model import BallBot2D
from thop import profile
import numpy as np
import pickle
import optuna
import logging
import sys
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots

#torch.manual_seed(11)
#torch.autograd.set_detect_anomaly(True) 

In [36]:
h_dim_x = 16
n_layers_x = 8
h_dim_u = 16
n_layers_u = 8
n_snapshot = 300
# Controling Horizon
T = np.linspace(0, 11, n_snapshot)

# LOAD MODEL
model = BallBot2D(h_dim_x = h_dim_x, 
    n_layers_x = n_layers_x, 
    h_dim_u = h_dim_u, 
    n_layers_u = n_layers_u)

fig = make_subplots(rows=2, cols=2)

for _ in range(50):
    # Initial State
    x_init = torch.zeros(4, 1, requires_grad = True).double()
    x_init[0] = np.pi/4
    x_init[1] = np.pi/4
    sigma_theta = np.random.normal(0, 0.3)
    sigma_phi = np.random.normal(0, 0.3)
    x_init[2] = sigma_theta
    x_init[3] = sigma_phi
    # Target variable
    y_tar = torch.zeros(4, 1).double()
    # Optimization setting
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    flops, params = profile(model, inputs=(x_init,))
    print(f"FLOPs: {flops}, Params: {params}")
    # TRAINING LOOP
    STATE = []
    STATE.append(x_init.T)
    CONTROL = []
    CONTROL.append(torch.zeros(4, 1, requires_grad = True).double())

    model.train()
    for i in range(len(T)-1):
        dotx, state_vector = model.forward(x_init)
        # Global loss
        loss = criterion(state_vector.flatten(), y_tar.flatten()) + criterion(dotx.flatten(), y_tar.flatten())
        # Backward pass
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        # Print loss
        print(f'Epoch [{i+1}/{len(T)}], Loss: {loss.item()}')
        STATE.append(state_vector)
        CONTROL.append(model.FNN_U(state_vector))
        # Update
        x_init = state_vector.clone().detach()

    STATE = torch.vstack(STATE)
    CONTROL = torch.vstack(CONTROL)
    
    # Subplot for THETA vs. T
    fig.add_trace(go.Scatter(x=T, y=STATE[:, 0].detach().numpy(),
                            mode='lines+markers', line=dict(color='red', width=2),
                            marker=dict(color='red', size=5, opacity=0.5), opacity=0.5,
                            name='Body Angle', showlegend=False), row=1, col=1)
    fig.add_trace(go.Scatter(x=T, y=[0] * len(T), mode='lines', line=dict(color='black', dash='dash'),
                            showlegend=False), row=1, col=1)
    # Subplot for PHI vs. T
    fig.add_trace(go.Scatter(x=T, y=STATE[:, 1].detach().numpy(),
                            mode='lines+markers', line=dict(color='blue', width=2),
                            marker=dict(color='blue', size=5, opacity = 0.5), opacity=0.5,
                            name='Ball Angle', showlegend=False), row=2, col=1)
    fig.add_trace(go.Scatter(x=T, y=[0] * len(T), mode='lines', line=dict(color='black', dash='dash'),
                            showlegend=False), row=2, col=1)
    fig.add_trace(go.Scatter(x=[sigma_theta], y=[sigma_phi],
                            mode='markers',
                            marker=dict(color='green', size=5), opacity=0.7,
                            name='Noise Distribution', showlegend=False), row=1, col=2)

    fig.add_trace(go.Scatter(x=T, y=CONTROL.flatten().detach().numpy(),
                            mode='lines+markers', line=dict(color='purple', width=2),
                            marker=dict(color='purple', size=5, opacity=0.5), opacity=0.5,
                            name='Control Signal', showlegend=False), row=2, col=2)

fig.update_yaxes(range=[-0.5, 1], row=1, col=1)
fig.update_yaxes(range=[-0.5, 1], row=2, col=1)
fig.update_yaxes(range=[-0.07, 0.07], row=2, col=2)
fig.update_xaxes(range=[0, 3], row=1, col=1)
fig.update_xaxes(range=[0, 3], row=2, col=1)
fig.update_xaxes(range=[0, 3], row=2, col=2)
fig.update_xaxes(title_text='Time [s]',title_font=dict(size=20), tickfont=dict(size=18), row=1, col=1)
fig.update_yaxes(title_text='Body Angle [rad]',title_font=dict(size=20), tickfont=dict(size=18), row=1, col=1)
fig.update_xaxes(title_text='Time [s]', title_font=dict(size=20), tickfont=dict(size=18), row=2, col=1)
fig.update_yaxes(title_text='Ball Angle [rad]', title_font=dict(size=20), tickfont=dict(size=18), row=2, col=1)
fig.update_xaxes(title_text=r'$\sigma_{\theta}$', title_font=dict(size=20), tickfont=dict(size=18), row=1, col=2)
fig.update_yaxes(title_text=r'$\sigma_{\phi}$', title_font=dict(size=20), tickfont=dict(size=18), row=1, col=2)
fig.update_xaxes(title_text='Time [s]',title_font=dict(size=20), tickfont=dict(size=18), row=2, col=2)
fig.update_yaxes(title_text=r'$u^*(t)$',title_font=dict(size=20), tickfont=dict(size=18), row=2, col=2)

fig.update_layout(height=600, width=1200, margin=dict(l=0, r=10, t=20, b=35))
fig.write_image('noise__{}_{}_{}_{}_{}.pdf'.format(n_snapshot, h_dim_x, n_layers_x, h_dim_u, n_layers_u))


fig.show()



[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 4304.0, Params: 4597.0
Epoch [1/300], Loss: 24.164319614119766
Epoch [2/300], Loss: 4.006728656736984
Epoch [3/300], Loss: 0.9696472006994281
Epoch [4/300], Loss: 1.0561908485923688
Epoch [5/300], Loss: 1.8833601943964522
Epoch [6/300], Loss: 2.4328709668749062
Epoch [7/300], Loss: 2.4693350659187407
Epoch [8/300], Loss: 2.1620252085261713
Epoch [9/300], Loss: 1.7571751666523192
Epoch [10/300], Loss: 1.3153533177300614
Epoch [11/300], Loss: 0.8825471749198803
Epoch [12/300], Loss: 0.5942666307320046
Epoch [13/300], Loss: 0.4638346083636217
Epoch [14/300], Loss: 0.41466715497304274
Epoch [15/300], Loss: 0.41232374022114193
Epoch [16/300], Loss: 0.4361609173889153
Epoch [17/300], Loss: 0.4696416477889056
Epoch [18/300], Loss: 0.5003257811663127
Epoch [19/300], Loss: 0.5190012106129398
Epoch [20/300], Loss: 0.5169426719978465
Epoch [21/300], Loss: 0.4990267367733625
Epoch [22/300], Loss: 0.4631434734311959