In [2]:
import torch

from modules.utils import GlobalConfig, TrainConfig, Logger, paint, get_env, simulate_episode, evaluate
from modules.DQN import DQN

from world.envs import VersusBotEnv
from world.realm import Realm
from world.map_loaders.two_teams import TwoTeamLabyrinthMapLoader, TwoTeamMapLoader, TwoTeamRocksMapLoader
from world.scripted_agents import ClosestTargetAgent
from world.utils import RenderedEnvWrapper

from modules.create_gif import create_gif, get_text_info, create_video_from_gif
from modules.preprocess import preprocess
from modules.reward import Reward

import os
import random
import numpy as np

from IPython.display import clear_output
from dataclasses import dataclass
from matplotlib import pyplot as plt
from collections import defaultdict

# Settings shared by the model and the evaluation helpers in this notebook.
global_config = GlobalConfig(
    # Use the GPU when PyTorch reports one is available, otherwise the CPU.
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
    n_actions=5,
    n_predators=5,
    n_masks=5,
    map_size=40,
)

# Training hyper-parameters for this experiment.  Where the author recorded an
# alternative value next to a setting, it is kept in the trailing comment.
train_config = TrainConfig(
    description='after adding extra inputs',
    max_steps_for_episode=300,
    gamma=0.9,
    initial_steps=300,       # alternative noted by the author: 3000
    steps=100_000,
    steps_per_update=3,
    steps_per_paint=250,     # alternative noted by the author: 500
    steps_per_eval=1000,     # alternative noted by the author: 5000
    buffer_size=10_000,
    batch_size=64,
    learning_rate=0.001,
    # Epsilon schedule parameters (start value, end value, decay).
    eps_start=0.9,
    eps_end=0.05,
    eps_decay=1000,
    tau=0.005,               # update rate of the target network (previously also 0.005)
    # Weights and penalties handed to the reward computation.
    reward_params={
        'w_dist_change': -0.5,
        'w_kill_prey': 1.0,
        'w_kill_enemy': 3.0,
        'w_kill_bonus': 1.3,
        'standing_still_penalty': -0.7,
        'gamma_for_bonus_count': 0.5,
        'n_nearest_targets': 2,
    },
    seed=1234,
)

  from .autonotebook import tqdm as notebook_tqdm


# Model improvement

In [4]:
# Disabled scratch code: steps one episode using model.get_actions(..., random=True),
# collects per-step text info, and writes an annotated GIF to '123.gif'.
# Only the GIF -> video conversion on the last line of this cell is live.
# NOTE(review): the live call reads '123.gif', which appears to be produced only by
# the commented-out create_gif(...) call below. On a fresh checkout the file must
# already exist, otherwise this cell will presumably fail — confirm, or re-enable
# the recording code before running.
# model = DQN(global_config, train_config).to(global_config.device)

# env = get_env(global_config, train_config, 1.0, render_gif=True)
# state, info = env.reset()
# processed_state = preprocess(state, info)
# done = False
# r = Reward(global_config, train_config)
# actions = model.get_actions(processed_state, random=True)
# text_info = [get_text_info(r, info, env, model)]

# while not done:    
#     next_state, done, next_info = env.step(actions)
#     next_processed_state = preprocess(next_state, next_info)
#     _ = r(processed_state, info, next_processed_state, next_info)
#     info, processed_state = next_info, next_processed_state
#     actions = model.get_actions(processed_state, random=True)
#     text_info.append(get_text_info(r, next_info, env, model))  # for display

# create_gif(env, '123.gif', duration=1., text_info=text_info)
create_video_from_gif('123.gif')

In [15]:
# Rebuild the network, restore the checkpoint from logs/39 (file name: 35k steps,
# 0.39 score) and export it as a single file.
model = DQN(global_config, train_config)
model = model.to(global_config.device)
model.load('logs/39/weights/35k_steps_0.39_score.pt')

# torch.save on the module object pickles the whole model, not just its weights.
# NOTE(review): a state_dict export is the more portable format, but whatever
# consumes 'model.pt' may expect the full object — confirm before changing.
torch.save(model, 'model.pt')

In [18]:
# Evaluate the checkpoint from logs/41 (file name: 26k steps, 0.53 score).
model_loaded = DQN(global_config, train_config).to(global_config.device)
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written from CUDA tensors still loads when only the CPU is available.
model_loaded.load_state_dict(
    torch.load(
        'logs/41/weights/26k_steps_0.53_score.pt',
        map_location=global_config.device,
    )
)
# The second argument matches the total shown by the "Evaluation" progress bar.
evaluate(model_loaded, 20)

Evaluation:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 20/20 [02:36<00:00,  7.85s/it]


0.5102318855018495

In [19]:
# Evaluate the checkpoint from logs/41 (file name: 16k steps, 0.55 score).
model_loaded = DQN(global_config, train_config).to(global_config.device)
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written from CUDA tensors still loads when only the CPU is available.
model_loaded.load_state_dict(
    torch.load(
        'logs/41/weights/16k_steps_0.55_score.pt',
        map_location=global_config.device,
    )
)
# The second argument matches the total shown by the "Evaluation" progress bar.
evaluate(model_loaded, 20)

  distance_mask = distance_mask / distance_mask.max()
  distance_mask = distance_mask / distance_mask.max()
Evaluation: 100%|██████████| 20/20 [02:43<00:00,  8.18s/it]


0.48153686894242176

In [20]:
# Evaluate the checkpoint from logs/41 (file name: 28k steps, 0.55 score).
model_loaded = DQN(global_config, train_config).to(global_config.device)
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written from CUDA tensors still loads when only the CPU is available.
model_loaded.load_state_dict(
    torch.load(
        'logs/41/weights/28k_steps_0.55_score.pt',
        map_location=global_config.device,
    )
)
# The second argument matches the total shown by the "Evaluation" progress bar.
evaluate(model_loaded, 20)

Evaluation: 100%|██████████| 20/20 [02:49<00:00,  8.47s/it]


0.5380277806732942

In [25]:
# Checkpoint from logs/41 (file name: 30k steps, 0.51 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/41/weights/30k_steps_0.51_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 20)

Evaluation:   0%|          | 0/20 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 20/20 [02:50<00:00,  8.51s/it]


0.5200848616549356

In [28]:
# Checkpoint from logs/41 (file name: 32k steps, 0.49 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/41/weights/32k_steps_0.49_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 10)

Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 10/10 [01:19<00:00,  7.93s/it]


0.5294536852435661

In [35]:
# Checkpoint from logs/41 (file name: 34k steps, 0.53 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/41/weights/34k_steps_0.53_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 20)

  distance_mask = distance_mask / distance_mask.max()
  distance_mask = distance_mask / distance_mask.max()
Evaluation: 100%|██████████| 20/20 [02:32<00:00,  7.63s/it]


0.5383105047981548

In [37]:
# Checkpoint from logs/41 (file name: 36k steps, 0.57 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/41/weights/36k_steps_0.57_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 30)

Evaluation: 100%|██████████| 30/30 [04:24<00:00,  8.82s/it]


0.5271020697716284

In [4]:
# Checkpoint from logs/42 (file name: 32k steps, 0.54 score).
# Author's note on repeated runs of this cell: prev was 54.18, next 53.68.
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/42/weights/32k_steps_0.54_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 30)

Evaluation:   0%|          | 0/30 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 30/30 [03:48<00:00,  7.63s/it]


0.5368669507405168

In [5]:
# Checkpoint from logs/42 (file name: 38k steps, 0.52 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/42/weights/38k_steps_0.52_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 30)

Evaluation: 100%|██████████| 30/30 [04:19<00:00,  8.65s/it]


0.5126152804838585

In [6]:
# Checkpoint from logs/42 (file name: 40k steps, 0.5 score).
model_loaded = DQN(global_config, train_config)
model_loaded = model_loaded.to(global_config.device)
model_loaded.load('logs/42/weights/40k_steps_0.5_score.pt')
# Last expression of the cell, so the returned score is displayed.
evaluate(model_loaded, 30)

Evaluation: 100%|██████████| 30/30 [03:51<00:00,  7.72s/it]


0.5136014512674422