In [None]:
# Mount Google Drive storage into the Colab runtime at /content/gdrive/.
from google.colab import drive
drive.mount('/content/gdrive/')
# Flag for the initial-setup workflow: 0 means the project already exists on Drive;
# the (commented-out) initial-setup cell sets it to 1 after a fresh clone.
initFlag = 0

# Name of the folder inside MyDrive that holds (or will hold) the cloned repository.
gitDir = 'Github'

In [None]:
# Initial setup. Execute this cell only when you want to freshly set up the project.

# Setup cloned repository in google drive for the first time.

#initFlag = 1
#%cd 'gdrive/MyDrive'
#%mkdir $gitDir
#%cd $gitDir
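# NOTE: never store a personal access token in the notebook; supply it at runtime (e.g. via getpass) and revoke any token that was ever committed here.
#!git clone 'https://<GITHUB_TOKEN>@github.com/gohmmagn/HJDQN.git'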

# Install hjdqn and gym_lqr for the first time.

#%cd 'HJDQN'
#!pip install -e .

In [None]:
# Import dolfinx for Google Colab.
# If dolfinx is not importable yet, download the FEniCSx installer script from the
# fem-on-colab release page to /tmp, run it with bash, and retry the import.
try:
    import dolfinx
except ImportError:
    !wget "https://fem-on-colab.github.io/releases/fenicsx-install-real.sh" -O "/tmp/fenicsx-install.sh" && bash "/tmp/fenicsx-install.sh"
    import dolfinx

In [None]:
# Import PyTorch and check version.
!pip install torch
import torch
print('torch: '+torch.__version__)

# Import the OpenAI Gym and check version.
!pip install gymnasium
import gymnasium as gym
print('gym: '+gym.__version__)

# Install control package to calulate the solution of the riccati equation.
!pip install control
import control
print('control: '+control.__version__)

In [None]:
# Setup gdrive file handling and goto main project directory.
if initFlag == 0:
  subDir = 'gdrive/MyDrive/' + gitDir + '/HJDQN/fileHandling'
else:
  subDir = 'fileHandling'

%cd $subDir
from gdrive_File_Handler import gdriveFileHandler
%cd ..

In [None]:
# To train new gym environments we first need to register them.
%cd 'gym_lqr'
!pip install -e .
%cd ..

In [None]:
# Select the environment to work on and bind the file handler to it via its id (envId).
# envId is reused by later cells (e.g. for plotting); uncomment one of the
# alternatives below to switch to another environment.
envId = 'NonLinearPDEEnv-v0'
#envId = 'Linear1dPDEEnv-v0'
#envId = 'Linear2dPDEEnv-v0'
gdriveFH = gdriveFileHandler(envId)

In [None]:
# Show models of given environment.
df_modelNames = gdriveFH.getModelsOfEnvironment()
# k indexes df_modelNames['Model name'] and selects the model inspected in the cells below.
k = 0
df_modelNames

In [None]:
# Merge training and evaluation logs into single files
# (presumably for every model of the selected environment - confirm in gdriveFileHandler).
gdriveFH.mergeAllLogsOfEnvironment()
# Alternative: merge only the eval/train logs of model k for one parameter file id.
#parFileId = 0
#gdriveFH.mergeLogs(df_modelNames['Model name'][k], 'eval_log', parFileId)
#gdriveFH.mergeLogs(df_modelNames['Model name'][k], 'train_log', parFileId)

In [None]:
# Check out the checkpoints which were saved during training of model k.
df_checkpoints = gdriveFH.getCheckpointFiles(df_modelNames['Model name'][k])
df_checkpoints

In [None]:
# Check out the parameters of model k for the specified parameter file id (parFileId).
parFileId = 0
parameter_list = gdriveFH.getModelParameters(df_modelNames['Model name'][k],parFileId)
parameter_list

In [None]:
# Get an overview of the saved Riccati solutions.
df_ricatti_solutions = gdriveFH.getRicattiSolutionFiles()
df_ricatti_solutions

In [None]:
# Train hjdqn model.

!python main.py --env=NonLinearPDEEnv-v0 --model='Critic_NN1' --algo='hjdqn' --T=4.0 --L=10 --tau=1e-3 --lr=1e-4 --sigma=0.1 --noise='gaussian' --max_iter=2e4 --eval_interval=50 --fill_buffer=0 --start_train=400 --batch_size=512 --gamma=0.99999

In [None]:
# Calculate state solution.
# The environment id, model name and checkpoint file are hard-coded for one specific
# training run; adjust all three together when evaluating another model
# (presumably the names listed by the checkpoint overview above - verify).

!python calculateStateSolution.py --envId='NonLinearPDEEnv-v0' --modelName='HJDQN_2023-09-20T131710' --savedModel='HJDQN_2023-09-20T131710_0_19998.pth.tar'

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Models whose evaluation log is plotted. Currently only the first listed model;
# use list(df_modelNames['Model name']) to plot all of them.
modelNames = [df_modelNames['Model name'][0]]
# One log-file id suffix ("0") per selected model.
modelId = ["0" for _ in modelNames]

def plotEvaluation(envId, modelName, modelId, ax=None, outputDirectory=None):
  """Plot the absolute gap between two return columns of a model's evaluation log.

  Reads ``<outputDirectory>/<envId>/<modelName>/eval_log/<modelName>_<modelId>.csv``,
  prints the smallest absolute difference between its second and third column and
  plots that difference against its first column.

  Args:
    envId: Environment id; used as a directory name below the output directory.
    modelName: Model name; used as directory name and as prefix of the log file.
    modelId: Suffix of the log file name (converted with ``str``).
    ax: Object providing ``plot(x, y)``, e.g. a matplotlib Axes. Defaults to the
      current pyplot axes (``plt.gca()``), so the function no longer depends on
      a global variable being created after its definition.
    outputDirectory: Root directory of the outputs. Defaults to the ``outputs``
      folder of the HJDQN project on the mounted Google Drive (built from ``gitDir``).

  Returns:
    Whatever ``ax.plot`` returns (for a matplotlib Axes: the list of created lines).
  """
  if outputDirectory is None:
    outputDirectory = '/content/gdrive/MyDrive/'+ gitDir +'/HJDQN/outputs'
  if ax is None:
    ax = plt.gca()

  evallog_path = "{}_{}.csv".format(modelName, str(modelId))
  avgCostEvalAndExact = pd.read_csv('{}/{}/{}/eval_log/{}'.format(outputDirectory, envId, modelName, evallog_path)).values

  # Column 0 is the x value; columns 1 and 2 are the two quantities being compared
  # (presumably evaluated vs. exact average return, going by the axis labels - confirm).
  # The difference is computed once and reused for both the printout and the plot.
  absDifference = np.abs(avgCostEvalAndExact[:,1]-avgCostEvalAndExact[:,2])
  print("Minimal difference: {}".format(np.min(absDifference)))
  return ax.plot(avgCostEvalAndExact[:,0], absDifference)

# Draw one curve per selected model into a single, labelled axes.
fig, ax1 = plt.subplots()
ax1.set(ylabel='abs(Average Return - Average Exact Return)', xlabel='steps')
for idx, name in enumerate(modelNames):
  plotEvaluation(envId, name, modelId[idx])
plt.show()