In [1]:
from modules.utils import TrainConfig, Logger, paint, get_env, simulate_episode, evaluate
from modules.DQN import DQN
from modules.reward import Reward
from modules.preprocess import preprocess

import numpy as np
import pickle
from tqdm.auto import trange

import torch
import torch.nn as nn
import math

import random

import warnings
warnings.filterwarnings('ignore')

import os

# Action index legend:
# 1 right
# 2 left
# 3 up
# 4 down
# NOTE(review): N_ACTIONS is 5 but only indices 1-4 are listed; index 0 is
# presumably "stay in place" - confirm against the environment.

# general settings
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU whenever torch can see one
N_ACTIONS = 5  # forwarded to DQN as n_actions
N_PREDATORS = 5  # forwarded to DQN, get_env, Reward, simulate_episode and evaluate
N_MASKS = 5  # forwarded to DQN as n_masks
MAP_SIZE = 40  # forwarded to DQN as map_size

  from .autonotebook import tqdm as notebook_tqdm


# TRAIN

In [2]:
# Train settings: every hyper-parameter of a run lives in this one object.
# Trailing "# <number>" comments are alternative values noted by the author.
cfg = TrainConfig(
    description='some description',

    # --- episode / discounting ---
    max_steps_for_episode=300,
    gamma=0.9,

    # --- schedule, all measured in environment steps ---
    initial_steps=300,  # 3000
    steps=100_000,
    steps_per_update=3,
    steps_per_paint=250,  # 500
    steps_per_eval=1000,  # 5000

    # --- replay buffer / optimisation ---
    buffer_size=10_000,
    batch_size=64,
    learning_rate=1e-3,

    # --- epsilon-greedy exploration: eps decays from eps_start to eps_end ---
    eps_start=0.9,
    eps_end=0.05,
    eps_decay=1000,

    tau=0.01,  # the update rate of the target network, was 0.005

    # --- reward shaping weights, handed to Reward(...) ---
    reward_params={
        'w_dist_change': -0.8,
        'w_kill_prey': 1.,
        'w_kill_enemy': 2.1,
        'w_kill_bonus': 1.3,
        'gamma_for_bonus_count': 0.5,
        'n_nearest_targets': 2,
    },

    seed=1234,
)

def train(kindergarten: bool):
    """Train a DQN on the predator environment and return the model.

    The run has two phases:

    1. Warm-up: ``cfg.initial_steps`` transitions collected with random
       actions and fed to ``model.consume_transition``.
    2. Training: ``cfg.steps`` epsilon-greedy steps. The policy network is
       updated every ``cfg.steps_per_update`` steps, the target network is
       soft-updated on every step, plots are refreshed every
       ``cfg.steps_per_paint`` steps, and every ``cfg.steps_per_eval`` steps
       an episode GIF is rendered, the model is evaluated and a checkpoint is
       written.

    The logger and a final checkpoint are always saved when the training
    phase ends, whether it finished normally, was interrupted from the
    keyboard or crashed.

    Args:
        kindergarten: if True the difficulty passed to ``get_env`` is
            ``-1 + steps_done / cfg.steps``; otherwise it is
            ``steps_done / cfg.steps``.

    Returns:
        The DQN model in whatever state it reached. The model is also
        returned after a KeyboardInterrupt or an ``Exception`` inside the
        training loop (the error is reported with its traceback first).

    Raises:
        Whatever ``logger.save()`` or ``model.save()`` raise while writing the
        final artefacts; these errors are no longer silently discarded.
    """
    import traceback  # local import: only needed for crash reporting below

    model = DQN(
        n_masks=N_MASKS,
        n_actions=N_ACTIONS,
        n_predators=N_PREDATORS,
        map_size=MAP_SIZE,
        device=DEVICE,
        config=cfg,
    ).to(DEVICE).train()

    # "Kindergarten" mode shifts the whole difficulty schedule down by one.
    difficulty_offset = -1 if kindergarten else 0

    def get_difficulty():
        """Current difficulty, growing linearly with the training progress."""
        return difficulty_offset + model.steps_done / cfg.steps

    def make_env():
        """Build a fresh environment at the current difficulty."""
        return get_env(n_predators=N_PREDATORS, difficulty=get_difficulty(),
                       step_limit=cfg.max_steps_for_episode)

    logger = Logger(cfg)

    # INITIAL STEPS: collect transitions with purely random actions.
    env = make_env()
    state, info = env.reset()
    processed_state = preprocess(state, info)
    r = Reward(N_PREDATORS, cfg.reward_params)
    for _ in trange(cfg.initial_steps):
        actions = model.get_actions(processed_state, random=True)
        next_state, done, next_info = env.step(actions)
        next_processed_state = preprocess(next_state, next_info)
        reward = r(processed_state, info, next_processed_state, next_info)
        model.consume_transition(processed_state, actions, next_processed_state, reward, done)
        state, info = (next_state, next_info) if not done else env.reset()
        processed_state = preprocess(state, info)

    # with open(f'pre_calc_buffer_simple_10000.pkl', 'wb') as handle:
    #     pickle.dump(model.buffer, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # TRAINING
    # with open('pre_calc_buffer_simple_10000.pkl', 'rb') as handle:
    #     model.buffer = pickle.load(handle)

    env = make_env()
    state, info = env.reset()
    processed_state = preprocess(state, info)
    r = Reward(N_PREDATORS, cfg.reward_params)

    # Values that are logged on every step but refreshed only occasionally.
    # They start as None so the logging calls never hit an unbound local,
    # even if the first iteration does not trigger an update or an evaluation.
    score_difference = None
    reward_batch = None
    loss_batch = None

    try:
        for _ in trange(cfg.steps):
            # ========== step ==========================================================
            # Exponential decay of epsilon from eps_start towards eps_end.
            eps_threshold = cfg.eps_end + (cfg.eps_start - cfg.eps_end) * \
                math.exp(-1. * model.steps_done / cfg.eps_decay)
            # A single draw per step decides whether random actions are requested.
            actions = model.get_actions(processed_state, random=(random.random() < eps_threshold))
            next_state, done, next_info = env.step(actions)
            next_processed_state = preprocess(next_state, next_info)
            reward = r(processed_state, info, next_processed_state, next_info)
            model.consume_transition(processed_state, actions, next_processed_state, reward, done)

            if done:
                # Build a new env instead of just resetting, in order to pick up
                # map changes (the difficulty depends on steps_done).
                env = make_env()
                state, info = env.reset()
            else:
                state, info = next_state, next_info

            processed_state = preprocess(state, info)

            # ========== updates =======================================================
            if model.steps_done % cfg.steps_per_update == 0:
                reward_batch, loss_batch = model.update_policy_network()

            model.soft_update_target_network()  # each step

            if model.steps_done % cfg.steps_per_paint == 0 and model.steps_done > 0:
                paint(logger)

            if model.steps_done % cfg.steps_per_eval == 0 and model.steps_done > 0:
                os.makedirs(logger.curr_subfolder_path + '/gifs', exist_ok=True)
                path = f'{logger.curr_subfolder_path}/gifs/{model.steps_done}_steps.gif'
                simulate_episode(model, get_difficulty(), N_PREDATORS, cfg, path, render_gif=True)
                score_difference = evaluate(model, N_PREDATORS, cfg)
                model.save(logger.curr_subfolder_path + f'/model_steps_{model.steps_done}_score_{round(score_difference, 2)}.pt')

            model.steps_done += 1

            # ========== logs ==========================================================
            logger.add('eps', eps_threshold)
            logger.add('reward', reward.mean())
            logger.add('reward_batch', reward_batch)
            logger.add('loss_batch', loss_batch)
            logger.add('score_difference', score_difference)

    except KeyboardInterrupt:
        print('Training interrupted')

    except Exception as e:
        # Best effort: report the failure (with its traceback, so it can be
        # debugged) and still hand the partially trained model back.
        print(f'Exception: {e}')
        traceback.print_exc()

    finally:
        # Always persist what has been collected so far. The nested
        # try/finally makes sure the checkpoint is attempted even when saving
        # the logger fails.
        try:
            logger.save()
        finally:
            model.save(logger.curr_subfolder_path + f'/model_steps_{model.steps_done}.pt')

    # Deliberately outside the `finally` clause: a `return` inside `finally`
    # would silently discard any exception that is still propagating.
    return model


# Run the full training with the "kindergarten" difficulty schedule
# (difficulty = -1 + steps_done / cfg.steps) and keep the resulting model.
model = train(kindergarten=True)

100%|██████████| 300/300 [00:47<00:00,  6.36it/s]
  0%|          | 113/100000 [00:22<6:37:09,  4.19it/s]

# TODO
- monospaced...

# Идеи

1. принудительно добавить оптимальные действия в начальный (initial) буфер ??
2. добавить шедулер ??
3. если заработает бейзлайн, подумать, как добавить возможность выучить BFS (поиск в ширину)
4. добавить маску с 1 в точке (20, 20)
5. double or dueling DQN

In [None]:
# model = DQN(
#     n_masks=N_MASKS,
#     n_actions=N_ACTIONS,
#     n_predators=N_PREDATORS,
#     map_size=MAP_SIZE,
#     device=DEVICE,
#     config=cfg
# ).to(DEVICE).train()

# with open('pre_calculated_buffer_10000.pkl', 'rb') as handle:
#     buffer = pickle.load(handle)