In [2]:
import os
import sys
import pathlib
project_root = os.path.abspath("../")
sys.path.append(project_root)
import random
import time
from game.infantry import Infantry,ACTION
from game.tank import *
from game.panal import *
# Initialise pygame's modules before any window/surface is created (Environment calls ui_init()).
# NOTE(review): `pygame` is not imported explicitly in this cell; presumably it is re-exported by
# `from game.tank import *` / `from game.panal import *` -- consider an explicit `import pygame`.
pygame.init()


def get_nearby(infantry_units, unit, max_distance=100):
    """Return the (x, y) position of the closest living infantry to ``unit``.

    Distance is Manhattan distance (|dx| + |dy|). Only infantry with ``hp > 0``
    whose distance is strictly below ``max_distance`` are considered; on ties the
    first unit encountered wins.

    Args:
        infantry_units: iterable of objects exposing ``x``, ``y`` and ``hp``.
        unit: object exposing ``x`` and ``y``.
        max_distance: exclusive search radius (default 100, the previous constant).

    Returns:
        ``(x, y)`` of the nearest living infantry, or ``None`` when no unit
        qualifies (previously this case crashed with ``AttributeError``).
    """
    nearest = None
    nearest_distance = max_distance
    for infantry in infantry_units:
        if infantry.hp <= 0:
            continue  # dead units are never targets
        distance = abs(infantry.x - unit.x) + abs(infantry.y - unit.y)
        if distance < nearest_distance:
            nearest_distance = distance
            nearest = infantry
    return None if nearest is None else (nearest.x, nearest.y)


def transform_state(unit, tank):
    """Encode ``unit`` relative to ``tank``.

    Returns:
        tuple ``(tank.x - unit.x, tank.y - unit.y, unit.hp)``.
    """
    offset_x, offset_y = tank.x - unit.x, tank.y - unit.y
    return offset_x, offset_y, unit.hp


pygame 2.6.1 (SDL 2.28.4, Python 3.10.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [186]:
class QLearningAgent:
    """Tabular Q-learning agent shared by every infantry unit.

    States are ``(dx, dy, hp)`` tuples (tank offset relative to a unit plus the
    unit's hp, as produced by ``transform_state``). ``dx``/``dy`` are clamped to
    ``[-STATE_BOUND, STATE_BOUND]`` before every table read *and* write, so
    callers may pass raw, unclamped states.
    """

    # Offsets beyond this radius are collapsed onto the border of the table.
    STATE_BOUND = 20
    # hp values pre-allocated in a freshly built table; other hp values are added lazily on write.
    HP_LEVELS = (0, 2, 4, 6, 8)

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=1.0,
                 epsilon_min=0.01, epsilon_decay=0.995,
                 actions=None, q_table_path='q_table.pkl'):
        """Create the agent and load/build its Q-table.

        Args:
            env: the environment (stored, not used by the agent itself).
            alpha: learning rate.
            gamma: discount factor.
            epsilon: initial exploration probability.
            epsilon_min: lower bound for ``epsilon`` after decay.
            epsilon_decay: multiplicative decay applied by ``decay_epsilon``.
            actions: action list; defaults to the game's global ``ACTION``.
            q_table_path: pickle file loaded by ``create_q_table`` if it exists.
        """
        self.env = env
        self.actions = list(ACTION if actions is None else actions)
        self.q_table_path = q_table_path
        self.q_table = self.create_q_table()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

    @classmethod
    def _clip_state(cls, state):
        """Clamp dx/dy of ``state`` into the table range; hp is passed through unchanged."""
        bound = cls.STATE_BOUND
        return (max(-bound, min(bound, state[0])),
                max(-bound, min(bound, state[1])),
                state[2])

    def get_q_value(self, state, action):
        """Return Q(state, action); unseen states or actions read as 0 (the initial value)."""
        return self.q_table.get(self._clip_state(state), {}).get(action, 0)

    def choose_action(self, state):
        """Epsilon-greedy action selection; ties between best actions are broken at random."""
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(self.actions)
        q_values = [self.get_q_value(state, a) for a in self.actions]
        max_q = max(q_values)
        best_actions = [a for a, q in zip(self.actions, q_values) if q == max_q]
        return random.choice(best_actions)

    def update_q_table(self, state, action, reward, new_state, done=False):
        """Apply one Q-learning update for the transition (state, action, reward, new_state).

        Args:
            done: if True, ``new_state`` is terminal and no future value is bootstrapped.
        """
        key = self._clip_state(state)  # write to the same key get_q_value reads from
        current_q = self.get_q_value(key, action)
        future_q = 0 if done else max(self.get_q_value(new_state, a) for a in self.actions)
        new_q = current_q + self.alpha * (reward + self.gamma * future_q - current_q)
        row = self.q_table.get(key)
        if row is None:
            # e.g. an hp value outside HP_LEVELS: create the row on demand instead of crashing
            row = self.q_table[key] = dict.fromkeys(self.actions, 0)
        row[action] = new_q

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def create_q_table(self, path=None):
        """Load the Q-table from ``path`` (default ``self.q_table_path``) or build a zero table.

        The fresh table has a row for every (dx, dy, hp) with dx, dy in
        ``[-STATE_BOUND, STATE_BOUND]`` and hp in ``HP_LEVELS``, each mapping every
        action to 0.

        NOTE(security): ``pickle.load`` can execute arbitrary code -- only load
        q-table files you produced yourself.
        """
        import pickle
        path = pathlib.Path(self.q_table_path if path is None else path)
        if path.exists():
            print(f'load {path}')
            with path.open('rb') as f:
                return pickle.load(f)
        coords = range(-self.STATE_BOUND, self.STATE_BOUND + 1)
        return {
            (dx, dy, hp): dict.fromkeys(self.actions, 0)
            for dx in coords
            for dy in coords
            for hp in self.HP_LEVELS
        }

    def save_q_table(self, path=None):
        """Pickle ``self.q_table`` to ``path`` (default ``self.q_table_path``)."""
        import pickle
        path = pathlib.Path(self.q_table_path if path is None else path)
        with path.open('wb') as f:
            pickle.dump(self.q_table, f)




In [190]:
class Environment:
    """Infantry-vs-tank game wrapped as a multi-agent RL environment.

    ``states`` holds one ``transform_state`` tuple per infantry unit, in the
    same order as ``player_units``.
    """

    NUM_INFANTRY = 10
    # Per-unit rewards applied in ``step``.
    CRUSH_PENALTY = -35   # unit was run over by the tank this turn
    ATTACK_REWARD = 6     # living unit had the tank in range and hit it
    IDLE_PENALTY = -1     # living unit could not reach the tank

    def __init__(self):
        self.game_end = False
        self._spawn_units()
        self.screen = ui_init()

    def _spawn_units(self):
        """Create fresh infantry (left half of the board) and the tank, then refresh ``states``."""
        # random.randint is inclusive on both ends; cap y at BOARD_HEIGHT - 1 so units start
        # on the drawn grid (show_state draws rows 0 .. BOARD_HEIGHT - 1).
        self.player_units = [
            Infantry(random.randint(0, BOARD_WIDTH // 2), random.randint(0, BOARD_HEIGHT - 1))
            for _ in range(self.NUM_INFANTRY)
        ]
        self.tank = Tank(BOARD_WIDTH - 5, BOARD_HEIGHT // 2)
        self.tank_ai = TankAI(self.tank, None)
        self._refresh_states()

    def _refresh_states(self):
        """Recompute ``self.states`` from the current unit and tank positions."""
        self.states = [transform_state(unit, self.tank) for unit in self.player_units]

    def show_state(self):
        """Render the grid, living infantry, the tank and the hp panel, then flip the display."""
        # Let pygame process OS window events so the window is not flagged as unresponsive.
        pygame.event.pump()
        self.screen.fill((0, 0, 0))

        # Draw the grid.
        for x in range(BOARD_WIDTH):
            for y in range(BOARD_HEIGHT):
                rect = pygame.Rect(x * CELL_WIDTH, y * CELL_HEIGHT, CELL_WIDTH, CELL_HEIGHT)
                pygame.draw.rect(self.screen, "#2F4F4F", rect, 1)

        # Draw living infantry; colour encodes selected / not yet moved / moved.
        for infantry in self.player_units:
            if infantry.hp <= 0:
                continue
            rect = pygame.Rect(infantry.x * CELL_WIDTH, infantry.y * CELL_HEIGHT, CELL_WIDTH, CELL_HEIGHT)
            if infantry.selected:
                pygame.draw.rect(self.screen, GREEN, rect)
            elif not infantry.moved:
                pygame.draw.rect(self.screen, BLUE, rect)
            else:
                pygame.draw.rect(self.screen, "#4444AA", rect)
            pygame.draw.rect(self.screen, BLACK, rect, 1)

        # Draw the tank; its rect starts one cell up/left of (tank.x, tank.y).
        tank_rect = pygame.Rect(
            self.tank.x * CELL_WIDTH - CELL_WIDTH,
            self.tank.y * CELL_HEIGHT - CELL_HEIGHT,
            CELL_WIDTH * self.tank.size[0],
            CELL_HEIGHT * self.tank.size[1]
        )
        pygame.draw.rect(self.screen, RED, tank_rect)
        pygame.draw.rect(self.screen, BLACK, tank_rect, 1)

        draw_hp_panel(self.tank, self.player_units, self.screen)
        pygame.display.flip()

    def reset(self):
        """Start a new game with freshly spawned units and return the initial ``states``."""
        self.game_end = False
        self._spawn_units()  # also refreshes self.states (previously left stale)
        self.show_state()
        return self.states

    def step(self, actions, use_screen=False):
        """Advance the game by one turn.

        Args:
            actions: one action per unit, in ``player_units`` order; ``action[0]`` and
                ``action[1]`` are passed to ``Infantry.move``.
            use_screen: if True, the screen is passed to ``check_victory_conditions``.

        Returns:
            ``(states, rewards, game_end)`` where ``rewards`` has one entry per unit and
            ``game_end`` is whatever ``check_victory_conditions`` returned (truthy when
            the game is over -- presumably a result message; it is printed).
        """
        rewards = [0] * len(self.player_units)
        for index, unit in enumerate(self.player_units):
            action = actions[index]
            unit.move(action[0], action[1])

        # Tank moves; every unit crushed anywhere along its route is penalised once.
        self.tank_ai.decide_movement(self.player_units)
        crushed = set()
        for x, y in self.tank.route:
            crushed.update(self.tank.crush(x, y, self.player_units))
        for index in crushed:
            rewards[index] += self.CRUSH_PENALTY
        self.tank.route = []

        self.tank_ai.decide_attack(self.player_units)
        for index, infantry in enumerate(self.player_units):
            if infantry.hp <= 0:
                continue
            if abs(infantry.x - self.tank.x) + abs(infantry.y - self.tank.y) <= infantry.range:
                self.tank.hp -= infantry.attack_power
                rewards[index] += self.ATTACK_REWARD
            else:
                rewards[index] += self.IDLE_PENALTY

        self.show_state()  # already flips the display
        if use_screen:
            game_end = check_victory_conditions(self.tank, self.player_units, self.screen)
        else:
            game_end = check_victory_conditions(self.tank, self.player_units)
        pygame.display.flip()  # show anything check_victory_conditions drew
        self.game_end = game_end
        self._refresh_states()
        if game_end:
            print(game_end)

        return self.states, rewards, game_end


# 正式训练代码

In [192]:
# Training configuration.
num_episodes = 1000
RENDER_AFTER_EPISODE = 600   # later episodes pass use_screen=True and are slowed down for viewing
STEP_DELAY_S = 0.5           # pause between steps while rendering
STATE_BOUND = 20             # the Q-table only has keys for dx, dy in [-STATE_BOUND, STATE_BOUND]
SAVE_Q_TABLE = False         # True -> pickle agent.q_table to Q_TABLE_PATH when training stops
Q_TABLE_PATH = 'q_table.pkl'


def clip_state(state, bound=STATE_BOUND):
    """Clamp the (dx, dy) part of a (dx, dy, hp) state into [-bound, bound] so it is a Q-table key."""
    return (max(-bound, min(bound, state[0])), max(-bound, min(bound, state[1])), state[2])


env = Environment()
agent = QLearningAgent(env)

try:
    for episode in range(num_episodes):
        env.reset()
        total_reward = 0
        render = episode > RENDER_AFTER_EPISODE

        while True:
            current_states = [transform_state(unit, env.tank) for unit in env.player_units]
            # Exactly one action per unit, in player_units order: env.step indexes actions by unit position.
            actions = [agent.choose_action(state) for state in current_states]

            next_states, rewards, done = env.step(actions, render)

            for state, action, reward, next_state in zip(current_states, actions, rewards, next_states):
                # Clamp before updating so the write targets an existing Q-table key.
                agent.update_q_table(clip_state(state), action, reward, next_state)

            total_reward += sum(rewards)

            if done:
                print(f"Episode {episode + 1}: Total Reward = {total_reward}! ")
                break
            if render:
                time.sleep(STEP_DELAY_S)

        agent.decay_epsilon()
finally:
    # Runs on normal completion and on KeyboardInterrupt, so an interrupted run can still be kept.
    if SAVE_Q_TABLE:
        import pickle
        with open(Q_TABLE_PATH, 'wb') as f:
            pickle.dump(agent.q_table, f)


load q_table.pkl
Player Loses!
Episode 1: Total Reward = -169! 
Player Loses!
Episode 2: Total Reward = -280! 
Player Loses!
Episode 3: Total Reward = -453! 
Player Loses!
Episode 4: Total Reward = -351! 
Player Loses!
Episode 5: Total Reward = -198! 
Player Loses!
Episode 6: Total Reward = -112! 
Player Loses!
Episode 7: Total Reward = -293! 
Player Loses!
Episode 8: Total Reward = -240! 
Player Loses!
Episode 9: Total Reward = -360! 
Player Loses!
Episode 10: Total Reward = -352! 
Player Loses!
Episode 11: Total Reward = -415! 
Player Loses!
Episode 12: Total Reward = -391! 
Player Loses!
Episode 13: Total Reward = -311! 
Player Loses!
Episode 14: Total Reward = -249! 
Player Loses!
Episode 15: Total Reward = -349! 
Player Loses!
Episode 16: Total Reward = -377! 
Player Loses!
Episode 17: Total Reward = -392! 
Player Loses!
Episode 18: Total Reward = -216! 
Player Loses!
Episode 19: Total Reward = -186! 
Player Loses!
Episode 20: Total Reward = -248! 
Player Loses!
Episode 21: Total 

KeyboardInterrupt: 