In [1]:
from modules.utils import GlobalConfig, TrainConfig, Logger, paint, get_env, simulate_episode, evaluate
from modules.DQN import DQN
from modules.reward import Reward
from modules.preprocess import preprocess

import numpy as np
import pickle
from tqdm.auto import trange

import torch
import torch.nn as nn
import math

import random

import warnings
warnings.filterwarnings('ignore')

import os

# Action ids as noted by the author: 1 = right, 2 = left, 3 = up, 4 = down.
# NOTE(review): n_actions is 5, so id 0 is presumably "stay" -- confirm against the environment.

# Settings shared by the model, the environment and the reward.
# The device is 'cuda' whenever torch reports an available GPU, otherwise 'cpu'.
global_config = GlobalConfig(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    map_size=40,
    n_actions=5,
    n_masks=5,
    n_predators=5,
)

  from .autonotebook import tqdm as notebook_tqdm


# TRAIN

In [3]:
# Hyper-parameters of one training run, consumed by `train` and the objects it builds.
train_config = TrainConfig(
    description='some description',
    max_steps_for_episode=300,
    gamma=0.9,
    initial_steps=300,     # random-action warm-up steps before training (author's note: 3000)
    steps=100_000,         # iterations of the main training loop
    steps_per_update=3,    # the policy network is updated every this many steps
    steps_per_paint=500,   # plots are refreshed every this many steps (author's note: 500)
    steps_per_eval=3000,   # gif + evaluation + checkpoint period (author's note: 5000)
    buffer_size=10_000,
    batch_size=64,
    learning_rate=1e-3,
    # Exploration rate: eps_end + (eps_start - eps_end) * exp(-steps_done / eps_decay)
    eps_start=0.9,
    eps_end=0.05,
    eps_decay=1000,
    tau=0.01,  # the update rate of the target network, was 0.005
    reward_params={
        'w_dist_change': -0.5,
        'w_kill_prey': 1.0,
        'w_kill_enemy': 2.1,
        'w_kill_bonus': 1.3,
        'gamma_for_bonus_count': 0.5,
        'n_nearest_targets': 2,
    },
    seed=1234,
)

def _score_tag(score_difference):
    """Return the score part of a checkpoint file name.

    Gives ``round(score_difference, 2)`` when a score is available and the
    placeholder ``'na'`` when no evaluation has produced one yet.
    """
    if score_difference is None:
        return 'na'
    return round(score_difference, 2)


def train(global_config, train_config):
    """Train a DQN agent and return it.

    The run has two phases:

    1. Warm-up: ``train_config.initial_steps`` environment steps with random
       actions; every transition is handed to ``model.consume_transition``.
    2. Training: ``train_config.steps`` steps with an exponentially decaying
       epsilon-greedy policy, a policy-network update every
       ``steps_per_update`` steps, a soft target-network update every step,
       and periodic plotting / evaluation / checkpointing.

    Whatever way the training phase ends, the logger is saved, a final
    checkpoint is written and the plots are stored.

    Args:
        global_config: shared settings; ``device`` is read here, the rest is
            forwarded to ``DQN``, ``get_env`` and ``Reward``.
        train_config: run hyper-parameters (step counts, epsilon schedule, ...).

    Returns:
        The trained ``DQN`` model. Also returned when training is stopped
        with Ctrl+C / "interrupt kernel".

    Raises:
        Any exception other than ``KeyboardInterrupt`` raised during training
        is propagated after the final artifacts have been saved.
    """
    model = DQN(global_config, train_config).to(global_config.device).train()

    def get_difficulty():
        # Maps training progress (steps_done / steps) linearly from -1 to 1.
        return -1 + (model.steps_done / train_config.steps) * 2

    def checkpoint_path(score_difference):
        # One naming scheme for both periodic and final checkpoints.
        return logger.curr_subfolder_path + f'/weights/{model.steps_done//1000}k_steps_{_score_tag(score_difference)}_score.pt'

    logger = Logger(train_config, model)

    # INITIAL STEPS
    env = get_env(global_config, train_config, difficulty=get_difficulty())
    state, info = env.reset()
    processed_state = preprocess(state, info)
    r = Reward(global_config, train_config)
    for _ in trange(train_config.initial_steps):
        actions = model.get_actions(processed_state, random=True)
        next_state, done, next_info = env.step(actions)
        next_processed_state = preprocess(next_state, next_info)
        reward = r(processed_state, info, next_processed_state, next_info)
        model.consume_transition(processed_state, actions, next_processed_state, reward, done)
        # Continue the episode, or start a new one on the same env when it ended.
        state, info = (next_state, next_info) if not done else env.reset()
        processed_state = preprocess(state, info)

    # with open(f'pre_calc_buffer_simple_10000.pkl', 'wb') as handle:
    #     pickle.dump(model.buffer, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # TRAINING
    # with open('pre_calc_buffer_simple_10000.pkl', 'rb') as handle:
    #     model.buffer = pickle.load(handle)

    env = get_env(global_config, train_config, difficulty=get_difficulty())
    state, info = env.reset()
    processed_state = preprocess(state, info)
    r = Reward(global_config, train_config)
    score_difference = None  # stays None until the first evaluation
    loss = None              # stays None until the first policy update

    try:
        for _ in trange(train_config.steps):
            # ========== step ==========================================================
            eps_threshold = train_config.eps_end + (train_config.eps_start - train_config.eps_end) * math.exp(-1. * model.steps_done / train_config.eps_decay)
            actions = model.get_actions(processed_state, random=(random.random() < eps_threshold))
            next_state, done, next_info = env.step(actions)
            next_processed_state = preprocess(next_state, next_info)
            reward = r(processed_state, info, next_processed_state, next_info)
            model.consume_transition(processed_state, actions, next_processed_state, reward, done)

            if done:
                # A finished episode is replaced by a fresh env at the current difficulty.
                env = get_env(global_config, train_config, difficulty=get_difficulty())
                state, info = env.reset()
            else:
                state, info = next_state, next_info

            processed_state = preprocess(state, info)

            # ========== updates =======================================================
            if model.steps_done % train_config.steps_per_update == 0:
                loss = model.update_policy_network()

            model.soft_update_target_network()  # each step

            if model.steps_done % train_config.steps_per_paint == 0 and model.steps_done > 0:
                paint(logger)

            if model.steps_done % train_config.steps_per_eval == 0 and model.steps_done > 0:
                os.makedirs(logger.curr_subfolder_path + '/gifs', exist_ok=True)
                os.makedirs(logger.curr_subfolder_path + '/weights', exist_ok=True)
                gif_path = f'{logger.curr_subfolder_path}/gifs/{model.steps_done}_steps.gif'
                simulate_episode(model, get_difficulty(), gif_path)
                score_difference = evaluate(model)
                model.save(checkpoint_path(score_difference))

            model.steps_done += 1

            # ========== logs ==========================================================
            # `loss` and `score_difference` carry their most recent value between updates.
            logger.add('eps', eps_threshold)
            logger.add('reward', reward.mean())
            logger.add('loss', loss)
            logger.add('score_difference', score_difference)

    except KeyboardInterrupt:
        # A manual stop is a normal way to end a run: keep the model and fall through.
        print('Training interrupted')

    finally:
        # Always persist logs, a final checkpoint and the plots. There is no
        # `return` here on purpose: returning from `finally` would silently
        # discard any exception raised during training.
        logger.save()
        # The eval branch may never have run, so the folder may not exist yet.
        os.makedirs(logger.curr_subfolder_path + '/weights', exist_ok=True)
        model.save(checkpoint_path(score_difference))
        paint(logger, save_plots=True)

    return model


# Run the warm-up and training phases with the configs defined above; `train` returns the DQN model.
model = train(global_config, train_config)

AssertionError: 

# TODO
- monospaced...
- пофиксить Q-values (Q-значения)
- добавить дополнительные параметры на вход модели

# Идеи

1. зафорсить оптимальные действия в начальный буфер (initial buffer)?
2. добавить шедулер (scheduler)?
3. если заработает бейзлайн, подумать, как добавить возможность выучить BFS (поиск в ширину)
4. добавить маску с 1 в точке (20, 20)
5. Double DQN или Dueling DQN

In [None]:
# model = DQN(
#     n_masks=N_MASKS,
#     n_actions=N_ACTIONS,
#     n_predators=N_PREDATORS,
#     map_size=MAP_SIZE,
#     device=DEVICE,
#     config=cfg
# ).to(DEVICE).train()

# with open('pre_calculated_buffer_10000.pkl', 'rb') as handle:
#     buffer = pickle.load(handle)