In [None]:
%matplotlib tk

import argparse
import gym
import datetime
import os
import random
import tempfile
import numpy as np
import pickle

import ray
from ray import tune
from ray.tune.logger import Logger, UnifiedLogger, pretty_print
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.examples.models.shared_weights_model import TF2SharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.agents.ppo import ppo
from ray.rllib.models import ModelCatalog
from environment_rllib_3d import MyEnv
from settings.initial_settings import *
from settings.reset_conditions import reset_conditions
#from modules.models import MyConv2DModel_v0B_Small_CBAM_1DConv_Share
from modules.models import DenseNetModelLarge
from tensorflow.keras.utils import plot_model
from modules.savers import save_conditions
from utility.result_env import render_env
from utility.terminate_uavsimproc import teminate_proc
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
import cv2
import ctypes
import warnings

# Kill the UCAV simulator process if one is still running.
teminate_proc.UAVsimprockill(proc_name="UCAV_100.exe")

# Silence deprecation warnings from NumPy and Matplotlib.
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning)
np.set_printoptions(suppress=True, precision=3)

# Run configuration.
PROJECT = "UCAV"
TRIAL_ID = 2
TRIAL = f"test_{TRIAL_ID}"  # trial label: "test_" followed by TRIAL_ID
EVAL_FREQ = 10
CONTINUAL = True

def custom_log_creator(custom_path, custom_str):
    """Build a logger-creator callback that logs into a fresh directory.

    Args:
        custom_path: Parent directory in which the log directory is created.
            It is created (including missing parents) when the returned
            callback is invoked, not when this function is called.
        custom_str: Label used as the first part of the log directory name.

    Returns:
        A function ``logger_creator(config)`` that creates a unique directory
        named ``<custom_str>_<timestamp><random suffix>`` under ``custom_path``
        and returns ``UnifiedLogger(config, logdir, loggers=None)``.
    """
    # The timestamp is taken once, when the creator is built; every logger
    # produced by the same creator therefore shares the same prefix.
    timestr = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    logdir_prefix = "{}_{}".format(custom_str, timestr)

    def logger_creator(config):
        # exist_ok=True replaces a separate exists() check, which could race
        # with another process creating the same directory in between.
        os.makedirs(custom_path, exist_ok=True)
        logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=custom_path)
        return UnifiedLogger(config, logdir, loggers=None)

    return logger_creator

# Shut down any Ray runtime left over from a previous run of this cell, then start a new one.
ray.shutdown()
ray.init(ignore_reinit_error=True, log_to_driver=False)

# NOTE(review): 'my_model' is registered here, but the active config below does not
# reference it; the "model" entry only appears in the disabled config.
ModelCatalog.register_custom_model('my_model', DenseNetModelLarge)

# Disabled configuration, kept for reference:
# config = {"env": MyEnv,
#           "num_workers": NUM_WORKERS,
#           "num_gpus": NUM_GPUS,
#           "num_cpus_per_worker": NUM_CPUS_PER_WORKER,
#           "num_sgd_iter": NUM_SGD_ITER,
#           "lr": LEARNING_RATE,
#           "gamma": GAMMA,  # default=0.99
#           "model": {"custom_model": "my_model"}
#           # "framework": framework
#           }  # use tensorflow 2

# Active configuration. Each key appears exactly once: a repeated key in a dict
# literal is silently overwritten by its last occurrence, which hides mistakes.
config = {
    "env": MyEnv,
    "num_gpus": 0,
    "num_workers": 0,
    "num_cpus_per_worker": 0,
}

# Directory handed to the project's save_conditions helper
# (presumably it writes the run settings there -- confirm in modules.savers).
conditions_dir = './' + PROJECT + '/conditions/'
os.makedirs(conditions_dir, exist_ok=True)
save_conditions(conditions_dir)

# Note: when try_import_tf is used, PPOTrainer() fails with a TensorFlow
# eager-mode error for a reason that is not understood.

trainer = ppo.PPOTrainer(
    config=config,
    logger_creator=custom_log_creator(os.path.expanduser(f"./{PROJECT}/logs"), TRIAL),
)

if CONTINUAL:
    # Continual learning: the checkpoint to resume from must be specified here.
    model_path = f"{PROJECT}/checkpoints/{TRIAL}/checkpoint_000001/checkpoint-1"
    trainer.restore(checkpoint_path=model_path)

# Export a description of the policy's base model: a text summary and a diagram.
models_dir = './' + PROJECT + '/models/'
os.makedirs(models_dir, exist_ok=True)

policy_base_model = trainer.get_policy().model.base_model

text_name = models_dir + TRIAL + '.txt'
with open(text_name, "w") as fp:
    # The file is opened in text mode, so "\n" is already translated to the
    # platform line ending; writing "\r\n" would produce "\r\r\n" on Windows.
    policy_base_model.summary(print_fn=lambda line: fp.write(line + "\n"))

png_name = models_dir + TRIAL + '.png'
plot_model(policy_base_model, to_file=png_name, show_shapes=True)

# Instantiate the evaluation environment with an empty env config.
eval_env = MyEnv({})

# Directory that receives the checkpoints of this trial.
# exist_ok=True avoids a separate exists() check that could race with another process.
check_point_dir = os.path.join('./' + PROJECT + '/checkpoints/', TRIAL)
os.makedirs(check_point_dir, exist_ok=True)

In [None]:
# Training & evaluation
#
# The training call is currently disabled, so this cell only evaluates the
# policy held by `trainer` and renders every evaluation episode.

# record_mode == 0: every rendered frame is also written to an .mp4 file.
# Any other value: the episode is only rendered on screen.
record_mode = 1

results_dir = './' + PROJECT + '/results/'
os.makedirs(results_dir, exist_ok=True)
# Not written at the moment: no active code dumps trajectories to this file.
results_file = results_dir + TRIAL + '.pkl'

# Latest training metrics. Stays None while the training call below is disabled;
# without this initialisation the print in the evaluation branch raises NameError.
results = None

for steps in range(10001):
    # Training (disabled):
    # print(f'\n----------------- Training at steps:{steps} start! -----------------')
    # results = trainer.train()

    # Evaluation runs only every EVAL_FREQ iterations.
    if steps % EVAL_FREQ != 0:
        continue

    print(f'\n----------------- Evaluation at steps:{steps} starting ! -----------------')
    if results is not None:
        print(pretty_print(results))
    check_point = trainer.save(checkpoint_dir=check_point_dir)

    for i in range(NUM_EVAL):
        obs = eval_env.reset()
        fig = plt.figure(1)

        # One entry per environment step, as returned by render_env.copy_from_env.
        env_blue_pos = []
        env_red_pos = []
        env_mrm_pos = []

        video = None
        if record_mode == 0:
            file_name = "test_num" + str(steps) + str(i)
            # NOTE(review): "WINNDOW_SIZE_lon" is spelled differently from
            # "WINDOW_SIZE_lat"; confirm the attribute name against MyEnv.
            # NOTE(review): the writer's frame size is presumably expected to match
            # the size of the frames grabbed from the figure canvas below -- verify.
            video = cv2.VideoWriter(file_name + '.mp4', 0x00000020, 20.0,
                                    (eval_env.WINNDOW_SIZE_lon, eval_env.WINDOW_SIZE_lat))

        try:
            while True:
                # One action per blue agent, computed from that agent's observation.
                action_dict = {}
                for j in range(eval_env.blue_num):
                    agent_id = 'blue_' + str(j)
                    action_dict[agent_id] = trainer.compute_action(obs[agent_id])

                obs, rewards, dones, infos = eval_env.step(action_dict)

                # Extend the position histories with the current step.
                env_blue_pos_temp, env_red_pos_temp, env_mrm_pos_temp = render_env.copy_from_env(eval_env)
                env_blue_pos.append(env_blue_pos_temp)
                env_red_pos.append(env_red_pos_temp)
                env_mrm_pos.append(env_mrm_pos_temp)

                hist_blue_pos = np.vstack(env_blue_pos)
                hist_red_pos = np.vstack(env_red_pos)
                hist_mrm_pos = np.vstack(env_mrm_pos)

                # Redraw the whole figure from the accumulated histories.
                plt.clf()
                render_env.rend_3d(eval_env, hist_blue_pos, "b", 1)
                render_env.rend_3d(eval_env, hist_red_pos, "r", 1)
                render_env.rend_3d(eval_env, hist_mrm_pos, "k", 1)
                fig.canvas.draw()
                plt.pause(.05)

                if video is not None:
                    # Grab the rendered canvas and convert it to the BGR layout OpenCV writes.
                    img = np.array(fig.canvas.renderer.buffer_rgba())
                    img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
                    video.write(img.astype('uint8'))

                # End of episode.
                if dones['__all__']:
                    break
        finally:
            # Release the writer even if the episode aborts with an exception.
            if video is not None:
                video.release()

ray.shutdown()