In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

proj_root = os.path.dirname(os.path.abspath("."))
# print(proj_root)
sys.path.append(proj_root)

import time
import datetime
import torch
import random
import numpy as np
import os
from pathlib import Path
import argparse

import torch.optim as optim
import torch.nn.functional as F

import gym
import yaml

from minatar_dqn.my_dqn import AgentDQN
from minatar_dqn.replay_buffer import ReplayBuffer
from minatar_dqn.utils.my_logging import setup_logger
from minatar_dqn.models import Conv_QNET, Conv_QNET_one

from experiments.experiment_utils import (
    seed_everything,
    search_files_containing_string,
    split_path_at_substring,
    collect_training_output_files,
)

from minatar_dqn.my_dqn import Conv_QNET, build_environment
from minatar_dqn.redo import (
    apply_redo_parametrization,
    reset_optimizer_states,
    map_layers_to_optimizer_indices,
)
from experiments.experiment_utils import (
    collect_training_output_files,
    collect_pruning_output_files,
)

from experiments.training.training import read_config_files, get_config_paths

from flatten_dict import flatten
import pandas as pd
import seaborn as sns
import scipy

sns.set()

import plotly

plotly.io.kaleido.scope.mathjax = None

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

import matplotlib.pyplot as plt

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

In [3]:
# Setup config
root_dir = os.path.dirname(os.path.abspath("."))

path_experiments_configs = os.path.join(
    root_dir, "experiments", "training", "training_configs"
)
path_experiments_outputs = os.path.join(root_dir, "experiments", "training", "outputs")

default_config_path, experiment_config_paths = get_config_paths(
    path_experiments_configs
)

experiment_configs = read_config_files(default_config_path, experiment_config_paths)

experiment_configs[0]

{'epochs_to_train': 20,
 'seeds': [0],
 'environments': ['breakout'],
 'agent_params': {'agent': 'AgentDQN',
  'args_': {'train_step_cnt': 200000,
   'validation_enabled': True,
   'validation_step_cnt': 125000,
   'validation_epsilon': 0.001,
   'replay_start_size': 5000,
   'batch_size': 32,
   'training_freq': 4,
   'target_model_update_freq': 100,
   'loss_fcn': 'mse_loss',
   'gamma': 0.99,
   'epsilon': {'start': 1.0, 'end': 0.01, 'decay': 250000}}},
 'estimator': {'model': 'Conv_QNET',
  'args_': {'conv_hidden_out_size': 32, 'lin_hidden_out_size': 128}},
 'optim': {'name': 'Adam', 'args_': {'lr': 6.25e-05, 'eps': 0.00015}},
 'replay_buffer': {'max_size': 100000, 'action_dim': 1, 'n_step': 0},
 'redo': {'attach': True,
  'enabled': False,
  'tau': 0.025,
  'beta': 0.1,
  'selection_option': None},
 'reward_perception': None,
 'experiment_name': 'conv32_lin128'}

In [4]:
config = experiment_configs[0]
config["environment"] = "breakout"
config["seed"] = 0

output_path = os.path.join(
    root_dir,
    "experiments",
    "redo",
)

In [5]:
config

{'epochs_to_train': 20,
 'seeds': [0],
 'environments': ['breakout'],
 'agent_params': {'agent': 'AgentDQN',
  'args_': {'train_step_cnt': 200000,
   'validation_enabled': True,
   'validation_step_cnt': 125000,
   'validation_epsilon': 0.001,
   'replay_start_size': 5000,
   'batch_size': 32,
   'training_freq': 4,
   'target_model_update_freq': 100,
   'loss_fcn': 'mse_loss',
   'gamma': 0.99,
   'epsilon': {'start': 1.0, 'end': 0.01, 'decay': 250000}}},
 'estimator': {'model': 'Conv_QNET',
  'args_': {'conv_hidden_out_size': 32, 'lin_hidden_out_size': 128}},
 'optim': {'name': 'Adam', 'args_': {'lr': 6.25e-05, 'eps': 0.00015}},
 'replay_buffer': {'max_size': 100000, 'action_dim': 1, 'n_step': 0},
 'redo': {'attach': True,
  'enabled': False,
  'tau': 0.025,
  'beta': 0.1,
  'selection_option': None},
 'reward_perception': None,
 'experiment_name': 'conv32_lin128',
 'environment': 'breakout',
 'seed': 0}

In [6]:
config["agent_params"]["args_"]["train_step_cnt"] = 20000
config["agent_params"]["args_"]["validation_step_cnt"] = 12500
config["redo"]["beta"] = 1
config["redo"]["tau"] = 0.1

In [7]:
env_name = config["environment"]

logger = setup_logger(
    env_name=env_name,
    identifier_string="redo_test_experiment",
)
logger.info(
    f'Starting up experiment: {config["experiment_name"]}, environment: {config["environment"]}, seed: {config["seed"]}'
)

### Setup environments ###
train_env = build_environment(
    game_name=config["environment"], random_seed=config["seed"]
)
validation_env = build_environment(
    game_name=config["environment"], random_seed=config["seed"]
)

experiment_agent = AgentDQN(
    train_env=train_env,
    validation_env=validation_env,
    experiment_output_folder=output_path,
    experiment_name="redo_test_experiment",
    resume_training_path=None,
    save_checkpoints=True,
    logger=logger,
    config=config,
)

2023-11-01 23:53:54,236 - root - INFO - redo_test_experiment - Starting up experiment: conv32_lin128, environment: breakout, seed: 0
2023-11-01 23:53:54,241 - root - INFO - redo_test_experiment - Loaded configuration settings.
2023-11-01 23:53:54,248 - root - INFO - redo_test_experiment - Initialized newtworks and optimizer.
2023-11-01 23:53:54,248 - root - INFO - redo_test_experiment - Applied redo parametrization to policy model.
2023-11-01 23:53:54,249 - root - INFO - redo_test_experiment - Applied redo parametrization to target model.



[33mWARN: The environment MinAtar/Breakout-v0 is out of date. You should consider upgrading to version `v1`.[0m


[33mWARN: It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: bool. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.[0m


[33mWARN: It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.[0m



In [8]:
experiment_agent.train(train_epochs=10)

2023-11-01 23:53:54,308 - root - INFO - redo_test_experiment - Starting training session at: 0
2023-11-01 23:53:54,308 - root - INFO - redo_test_experiment - Starting training epoch at t = 0



`np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)



2023-11-01 23:54:19,208 - root - INFO - redo_test_experiment - TRAINING STATS | Frames seen: 20000 | Episode: 1750 | Max reward: 5.0 | Avg reward: 0.532 | Avg frames (episode): 11.425142857142857 | Avg max Q: -69.70157273895877 | Epsilon: 0.9406 | Train epoch time: 0:00:24.895504
2023-11-01 23:54:19,209 - root - INFO - redo_test_experiment - Starting validation epoch at t = 20000
2023-11-01 23:54:24,917 - root - INFO - redo_test_experiment - VALIDATION STATS | Max reward: 1.0 | Avg reward: 0.003864734299516908 | Avg frames (episode): 6.038647342995169 | Avg max Q: -87.59372736653678 | Validation epoch time: 0:00:05.706717
2023-11-01 23:54:24,918 - root - INFO - redo_test_experiment - Saving checkpoint at t = 20000 ...
2023-11-01 23:54:24,922 - root - DEBUG - redo_test_experiment - Models saved at t = 20000
2023-11-01 23:54:24,927 - root - DEBUG - redo_test_experiment - Training status saved at t = 20000
2023-11-01 23:54:26,210 - root - INFO - redo_test_experiment - Checkpoint saved at 

True

In [9]:
# collect a set of states to check Q on before and after redo

eval_states = []

samples_nr = 100
skip_nr = 100
for i in range(samples_nr):
    for j in range(skip_nr):
        # get a random action from the environment
        action = experiment_agent.train_env.action_space.sample()
        s_prime, reward, is_terminated, truncated, info = experiment_agent.train_env.step(
            action
        )
        s_prime = torch.tensor(s_prime, device="cpu").float()

        if is_terminated:
            experiment_agent.train_env.reset()
            continue

    eval_states.append(s_prime)

In [10]:
import copy

eval_states_tensor = torch.stack(eval_states)

predictions_init = experiment_agent.policy_model(eval_states_tensor)
predictions_init

max_q_vals_init = []
for state in eval_states_tensor:
    max_q_val = experiment_agent.get_max_q_val_for_state(state.unsqueeze(0))
    max_q_vals_init.append(max_q_val)

# get the weights before redo
state_dict_init = copy.deepcopy(experiment_agent.policy_model.state_dict())

In [11]:
# apply redo
reset_details = experiment_agent.policy_model.apply_redo()
reset_details

[{'indexes': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
  'inbound': 'features.conv1',
  'outbound': 'features.conv2'},
 {'indexes': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
  'inbound': 'features.conv2',
  'outbound': 'fc.lin1'},
 {'indexes': tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
           14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
           28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
           42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
           56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
           70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
           84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  

In [12]:
layer_to_optim_idx = map_layers_to_optimizer_indices(
    experiment_agent.policy_model, experiment_agent.optimizer
)
layer_to_optim_idx

reset_optimizer_states(reset_details, experiment_agent.optimizer, layer_to_optim_idx)

In [13]:
# get the difference between the values before and after
predictions_after = experiment_agent.policy_model(eval_states_tensor)
predictions_after

max_q_vals_after = []
for state in eval_states_tensor:
    max_q_val = experiment_agent.get_max_q_val_for_state(state.unsqueeze(0))
    max_q_vals_after.append(max_q_val)

# get the weights before redo
state_dict_after = experiment_agent.policy_model.state_dict()

In [14]:
diff = np.array(max_q_vals_init) - np.array(max_q_vals_after)
sum(diff)

-6218.0907286554575

In [18]:
max_q_vals_init

[-64.28768157958984,
 -64.709228515625,
 -68.58647918701172,
 -67.16693878173828,
 -83.45682525634766,
 -63.717437744140625,
 -63.442543029785156,
 -66.2746353149414,
 -61.1007080078125,
 -58.52145004272461,
 -58.478302001953125,
 -59.55796813964844,
 -60.09797286987305,
 -61.80244064331055,
 -62.32272720336914,
 -63.70777130126953,
 -64.34134674072266,
 -65.33708953857422,
 -64.15252685546875,
 -72.15023040771484,
 -63.717437744140625,
 -63.442543029785156,
 -62.82538604736328,
 -61.1007080078125,
 -58.52145004272461,
 -58.478302001953125,
 -59.55796813964844,
 -60.09797286987305,
 -61.80244064331055,
 -62.32272720336914,
 -66.83450317382812,
 -64.34134674072266,
 -66.03160095214844,
 -76.37629699707031,
 -75.23518371582031,
 -64.28768157958984,
 -68.33306884765625,
 -68.58647918701172,
 -67.16693878173828,
 -83.45682525634766,
 -63.662872314453125,
 -64.76863861083984,
 -65.44690704345703,
 -63.429786682128906,
 -59.09043884277344,
 -58.70704650878906,
 -59.836082458496094,
 -60.7317

In [19]:
max_q_vals_after

[0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.06966846436262131,
 0.0696684

In [17]:
for layer in state_dict_after:
    if not torch.equal(state_dict_init[layer], state_dict_after[layer]):
        print(f"Weights changed in layer: {layer}")

Weights changed in layer: features.conv1.weight
Weights changed in layer: features.conv1.bias
Weights changed in layer: features.relu1.running_avg
Weights changed in layer: features.relu1.running_avg_cnt
Weights changed in layer: features.conv2.weight
Weights changed in layer: features.conv2.bias
Weights changed in layer: features.relu2.running_avg
Weights changed in layer: features.relu2.running_avg_cnt
Weights changed in layer: fc.lin1.weight
Weights changed in layer: fc.lin1.bias
Weights changed in layer: fc.relu3.running_avg
Weights changed in layer: fc.relu3.running_avg_cnt
Weights changed in layer: fc.lin2.weight
