In [22]:
import sys
sys.path.append("..")
sys.path.append("../..")
sys.path.append("../../..")
import numpy as np
import gym
import collections
import torch
import torch.nn.functional as F
import random
from gym import spaces
from gym.utils import seeding
from aerobench.visualize import anim3d, plot
from aerobench.examples.anim3d.run_fall import fall_simulate
from aerobench.examples.anim3d.run_rise import rise_simulate
from aerobench.examples.anim3d.run_straight import straight_simulate
from aerobench.examples.anim3d.run_right_turn import right_turn_simulate
from aerobench.examples.anim3d.run_left_turn import left_turn_simulate

In [23]:
# Output target for the animation: a video file when the first command-line
# argument ends in .mp4/.gif, otherwise '' (plot to the screen).
if len(sys.argv) >= 2 and sys.argv[1].endswith(('.mp4', '.gif')):
    filename = sys.argv[1]
    print(f"saving result to '{filename}'")
else:
    filename = ''
    print("Plotting to the screen. To save a video, pass a command-line argument ending with '.mp4' or '.gif'.")

# Maneuver simulators indexed by the agent's discrete action (0-4). The order
# fixes what each action means: 0 fall, 1 left turn, 2 right turn, 3 rise,
# 4 straight.
simulation_functions = [
    fall_simulate,
    left_turn_simulate,
    right_turn_simulate,
    rise_simulate,
    straight_simulate,
]

Plotting to the screen. To save a video, pass a command-line argument ending with '.mp4' or '.gif'.


In [24]:
class F_16env(gym.Env):
    """Gym environment in which an F-16 tries to keep away from a fixed missile position.

    Each discrete action (0-4) selects one maneuver simulator from the
    module-level ``simulation_functions`` list. The maneuver starts at the
    aircraft's current (x, y, z) position. The position at the end of the
    simulated trajectory becomes the next observation.

    Args:
        init_x, init_y, init_z: starting position of the aircraft. ``reset()``
            returns the aircraft to this position.
    """

    def __init__(self, init_x, init_y, init_z):
        # Result dict of the most recent maneuver simulation ({} before the first step).
        self.res = {}
        # Missile position; not modified anywhere in this class.
        self.missile = [5000, 5000, 5000]
        self.psi = []
        # Aircraft-missile distance computed by the last calculate_reward() call.
        self.distance = 0
        # Steps taken in the current episode.
        self.count = 0
        # Remembered so reset() restores the configured start rather than a hard-coded one.
        self.init_pos = (init_x, init_y, init_z)
        self.x = init_x
        self.y = init_y
        self.z = init_z
        self.low = np.array([0, 0, 0], dtype=np.float32)
        self.high = np.array([np.inf, np.inf, np.inf], dtype=np.float32)
        self.action_space = spaces.Discrete(5)
        self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def calculate_reward(self):
        """Update ``self.distance`` and return the distance-banded reward.

        Returns -100 when closer than 1000 to the missile, 0 up to and
        including 5000, and +100 beyond that.
        """
        missile = np.array(self.missile)
        plane = np.array([self.x, self.y, self.z])
        self.distance = np.linalg.norm(plane - missile)
        if self.distance < 1000:
            reward = -100
        elif self.distance <= 5000:
            reward = 0
        else:
            reward = 100
        return reward

    def calculate_done(self):
        """Advance the episode step counter and report whether the episode is over.

        The episode ends once the aircraft is more than 10000 from the missile
        or after 20 steps. Call this after ``calculate_reward`` so that
        ``self.distance`` is up to date.
        """
        self.count += 1
        return bool(self.distance > 10000 or self.count >= 20)

    def step(self, action):
        """Fly one maneuver and return the resulting transition.

        Returns a 6-tuple ``(next_state, reward, done, res, init_extra,
        skip_override)``. These are the usual Gym values plus the raw simulator
        result and the two extra values returned by the simulator.
        """
        assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action),)
        select_simulation = simulation_functions[action]
        # Continue from state column 5 of the previous maneuver's last state
        # (presumably the heading psi -- TODO confirm). Start from 0 on the
        # first step of an episode.
        # NOTE(review): the meaning of the trailing 2000 argument is defined by
        # the simulator functions and is not visible here.
        initial_psi = 0 if self.count == 0 else self.res['states'][-1][5]
        self.res, init_extra, skip_override, _ = select_simulation(
            filename, self.x, self.y, self.z, initial_psi, 2000)
        last_state = self.res['states'][-1]
        # Columns 10/9/11 map to x/y/z (presumably east, north, altitude in the
        # aerobench state layout -- TODO confirm).
        self.x, self.y, self.z = last_state[10], last_state[9], last_state[11]
        next_state = (self.x, self.y, self.z)
        reward = self.calculate_reward()
        done = self.calculate_done()
        return next_state, reward, done, self.res, init_extra, skip_override

    def reset(self):
        """Start a new episode at the initial position and clear per-episode state.

        Clearing ``count`` and ``res`` makes the next ``step`` start again from
        heading 0 and restarts the 20-step episode limit.
        """
        self.x, self.y, self.z = self.init_pos
        self.count = 0
        self.distance = 0
        self.res = {}
        return (self.x, self.y, self.z)


In [25]:
class ReplayBuffer:
    """Experience replay buffer: a bounded FIFO of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity):
        # deque(maxlen=...) silently evicts the oldest transition once full.
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest if the buffer is full."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns ``(states, actions, rewards, next_states, dones)``. States and
        next_states are numpy arrays; the other three are tuples. Raises
        ValueError (from ``random.sample``) if ``batch_size`` exceeds the
        number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        columns = list(zip(*batch))
        return (np.array(columns[0]), columns[1], columns[2],
                np.array(columns[3]), columns[4])

    def size(self):
        """Number of transitions currently stored."""
        stored = len(self.buffer)
        return stored

In [26]:
class Qnet(torch.nn.Module):
    """Two-layer MLP mapping a state vector to one Q-value per action.

    Architecture: Linear(state_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, action_dim).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # Creation order (fc1 then fc2) fixes how seeded RNG draws map to
        # weights; the names are the state_dict keys used for weight copying.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return Q-values with the last dimension of size ``action_dim``."""
        return self.fc2(F.relu(self.fc1(x)))

In [27]:
class DQN:
    """Deep Q-Network agent with an epsilon-greedy policy and a target network.

    Args:
        state_dim: length of the state vector fed to the Q-network.
        hidden_dim: width of the Q-network's hidden layer.
        action_dim: number of discrete actions.
        learning_rate: Adam learning rate for the online network.
        gamma: discount factor used in the bootstrapped target.
        epsilon: probability of choosing a uniformly random action.
        target_update: the target network is overwritten with the online
            weights on every ``target_update``-th call to ``update``
            (starting with the first call).
        device: torch device holding both networks and every batch.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,
                 epsilon, target_update, device):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)
        self.target_q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)
        # Start the target as an exact copy of the online network; otherwise the
        # first update bootstraps from an independently initialised network.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        # Number of update() calls so far; drives the target-network sync.
        self.count = 0
        self.device = device

    def take_action(self, state):
        """Return an int action index chosen epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the action is the argmax of the online network's Q-values for
        ``state``.
        """
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            # Inference only: no autograd graph is needed for action selection.
            with torch.no_grad():
                state = torch.tensor([state], dtype=torch.float).to(self.device)
                action = self.q_net(state).argmax().item()
        return action

    def update(self, transition_dict):
        """Take one gradient step on a batch of transitions.

        Args:
            transition_dict: mapping with keys 'states', 'actions', 'rewards',
                'next_states' and 'dones', each a batch-length sequence.
        """
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        # gather() requires int64 indices.
        actions = torch.tensor(transition_dict['actions'], dtype=torch.long).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        q_values = self.q_net(states).gather(1, actions)  # Q(s, a)
        # The bootstrapped targets are constants for this loss. Computing them
        # without autograd keeps gradients from piling up in the target
        # network, which the optimizer never zeroes.
        with torch.no_grad():
            # max(1) reduces over actions and returns (values, indices); keep the values.
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = F.mse_loss(q_values, q_targets)  # default reduction is already the mean
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
        




In [28]:
def main():
    """Fly one epsilon-greedy DQN episode in F_16env and animate the trajectory.

    Every step's transition goes into the replay buffer. A gradient update runs
    only once the buffer holds more than ``minimal_size`` transitions. The
    simulation result of every maneuver is concatenated into one result dict,
    which is passed to ``anim3d.make_anim`` at the end.
    """
    lr = 2e-3
    hidden_dim = 128  # number of neurons in the hidden layer
    gamma = 0.98
    epsilon = 0.01
    target_update = 10
    buffer_size = 10000
    minimal_size = 500
    batch_size = 64
    init_x, init_y, init_z = 5100, 5100, 5100
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    env = F_16env(init_x, init_y, init_z)
    random.seed(0)
    np.random.seed(0)
    env.seed(0)
    torch.manual_seed(0)
    replay_buffer = ReplayBuffer(buffer_size)
    state_dim = 3
    action_dim = 5
    accumulated_res = {
        'status': [], 'times': [], 'states': [], 'modes': [],
        'xd_list': [], 'ps_list': [], 'Nz_list': [], 'Ny_r_list': [], 'u_list': [], 'runtime': []
    }
    agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device)
    state = env.reset()
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, done, res, init_extra, skip_override = env.step(action)
        replay_buffer.add(state, action, reward, next_state, done)
        state = next_state
        if replay_buffer.size() > minimal_size:
            n_s, n_a, n_r, n_ns, n_done = replay_buffer.sample(batch_size)
            transition_dict = {
                'states': n_s,
                'actions': n_a,
                'rewards': n_r,
                'next_states': n_ns,
                'dones': n_done
            }
            agent.update(transition_dict)

        # Record every flown maneuver for the animation, whether or not a
        # training update ran. With minimal_size=500 and at most 20 steps per
        # episode, the update branch above never executes in this run.
        # NOTE(review): if each simulator restarts its clock at 0, the
        # concatenated 'times' are not monotonic -- confirm what
        # anim3d.make_anim expects.
        accumulated_res['status'] = res['status']
        accumulated_res['times'].extend(res['times'])
        accumulated_res['states'].append(res['states'])
        accumulated_res['modes'].extend(res['modes'])
        if 'xd_list' in res:
            accumulated_res['xd_list'].extend(res['xd_list'])
            accumulated_res['ps_list'].extend(res['ps_list'])
            accumulated_res['Nz_list'].extend(res['Nz_list'])
            accumulated_res['Ny_r_list'].extend(res['Ny_r_list'])
            accumulated_res['u_list'].extend(res['u_list'])
        accumulated_res['runtime'].append(res['runtime'])

    # Stack the per-maneuver state arrays once, after the loop. Stacking inside
    # the loop turns the list into an ndarray, and the next .append() fails.
    # The loop always runs at least once, so the list is non-empty.
    accumulated_res['states'] = np.vstack(accumulated_res['states'])

    anim3d.make_anim(accumulated_res, filename, f16_scale=70, viewsize=5000, viewsize_z=4000, trail_pts=np.inf,
                     elev=27, azim=-107, skip_frames=skip_override,
                     chase=True, fixed_floor=True, init_extra=init_extra)

main()

Waypoint transition Waypoint 1 -> Waypoint 2 at time 0
Waypoint transition Waypoint 2 -> Waypoint 3 at time 11.699999999999969
Waypoint transition Waypoint 3 -> Done at time 27.93333333333388
Waypoint simulation completed in 0.35 seconds (extended_states=True)



If you passed *frames* as a generator it may be exhausted due to a previous display or save.
