In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Parent of the current working directory, appended to the import search path.
# Presumably this is what makes the `minatar_dqn` and `experiments` packages
# imported below resolvable - confirm against the repository layout.
proj_root = os.path.dirname(os.path.abspath("."))
sys.path.append(proj_root)

import argparse
import copy
import datetime
import os
import random
import time
from pathlib import Path

import numpy as np
import torch

import torch.optim as optim
import torch.nn.functional as F

import gym
import yaml

from minatar_dqn.my_dqn import AgentDQN
from minatar_dqn.replay_buffer import ReplayBuffer
from minatar_dqn.utils.my_logging import setup_logger
from minatar_dqn.models import Conv_QNET, Conv_QNET_one

from experiments.experiment_utils import (
    seed_everything,
    search_files_containing_string,
    split_path_at_substring,
    collect_training_output_files,
)

from minatar_dqn.my_dqn import Conv_QNET, build_environment
from minatar_dqn.redo import (
    apply_redo_parametrization,
    reset_optimizer_states,
    map_layers_to_optimizer_indices,
)
from experiments.experiment_utils import (
    collect_training_output_files,
    collect_pruning_output_files,
)

from experiments.training.training import read_config_files, get_config_paths

from flatten_dict import flatten
import pandas as pd
import seaborn as sns
import scipy

sns.set()

import plotly

plotly.io.kaleido.scope.mathjax = None

import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

import matplotlib.pyplot as plt

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

In [3]:
# Setup config
# Root directory: parent of the current working directory (the same expression
# that defines `proj_root` in the import cell).
root_dir = os.path.dirname(os.path.abspath("."))

# Folder handed to get_config_paths() below to locate the config files.
path_experiments_configs = os.path.join(
    root_dir, "experiments", "training", "training_configs"
)
# Output folder of the regular training experiments; not referenced again in
# the cells visible here.
path_experiments_outputs = os.path.join(root_dir, "experiments", "training", "outputs")

# Returns one default config path plus the per-experiment config paths.
default_config_path, experiment_config_paths = get_config_paths(
    path_experiments_configs
)

# Indexable collection of config dicts. Presumably each experiment config is
# merged over the defaults - confirm in experiments/training/training.py.
experiment_configs = read_config_files(default_config_path, experiment_config_paths)

# Display the first experiment config.
experiment_configs[0]

{'epochs_to_train': 20,
 'seeds': [0],
 'environments': ['breakout'],
 'agent_params': {'agent': 'AgentDQN',
  'args_': {'train_step_cnt': 200000,
   'validation_enabled': True,
   'validation_step_cnt': 125000,
   'validation_epsilon': 0.001,
   'replay_start_size': 5000,
   'batch_size': 32,
   'training_freq': 4,
   'target_model_update_freq': 100,
   'loss_fcn': 'mse_loss',
   'gamma': 0.99,
   'epsilon': {'start': 1.0, 'end': 0.01, 'decay': 250000}}},
 'estimator': {'model': 'Conv_QNET',
  'args_': {'conv_hidden_out_size': 32, 'lin_hidden_out_size': 128}},
 'optim': {'name': 'Adam', 'args_': {'lr': 6.25e-05, 'eps': 0.00015}},
 'replay_buffer': {'max_size': 100000, 'action_dim': 1, 'n_step': 0},
 'redo': {'attach': True,
  'enabled': False,
  'tau': 0.025,
  'beta': 0.1,
  'selection_option': None},
 'reward_perception': None,
 'experiment_name': 'conv32_lin128'}

In [4]:
# Work on a private deep copy of the first experiment config. This cell and a
# later one override entries (including values inside nested dicts); binding
# the list element directly would silently rewrite `experiment_configs[0]` too
# and make the value displayed by the config-loading cell stale.
config = copy.deepcopy(experiment_configs[0])

# This notebook runs a single environment/seed pair, stored under singular
# keys next to the plural "environments"/"seeds" lists from the config file.
config["environment"] = "breakout"
config["seed"] = 0

# Folder passed to the agent as its experiment output folder.
output_path = os.path.join(
    root_dir,
    "experiments",
    "redo",
)

In [5]:
# Display the config after the single-run "environment" and "seed" keys were added.
config

{'epochs_to_train': 20,
 'seeds': [0],
 'environments': ['breakout'],
 'agent_params': {'agent': 'AgentDQN',
  'args_': {'train_step_cnt': 200000,
   'validation_enabled': True,
   'validation_step_cnt': 125000,
   'validation_epsilon': 0.001,
   'replay_start_size': 5000,
   'batch_size': 32,
   'training_freq': 4,
   'target_model_update_freq': 100,
   'loss_fcn': 'mse_loss',
   'gamma': 0.99,
   'epsilon': {'start': 1.0, 'end': 0.01, 'decay': 250000}}},
 'estimator': {'model': 'Conv_QNET',
  'args_': {'conv_hidden_out_size': 32, 'lin_hidden_out_size': 128}},
 'optim': {'name': 'Adam', 'args_': {'lr': 6.25e-05, 'eps': 0.00015}},
 'replay_buffer': {'max_size': 100000, 'action_dim': 1, 'n_step': 0},
 'redo': {'attach': True,
  'enabled': False,
  'tau': 0.025,
  'beta': 0.1,
  'selection_option': None},
 'reward_perception': None,
 'experiment_name': 'conv32_lin128',
 'environment': 'breakout',
 'seed': 0}

In [6]:
# Shorten the run for this check: a tenth of the train and validation step
# counts found in the loaded config (200000 -> 20000, 125000 -> 12500).
config["agent_params"]["args_"].update(
    train_step_cnt=20000,
    validation_step_cnt=12500,
)
# Override the ReDo `beta` setting (0.1 in the loaded config) with 1.
config["redo"].update(beta=1)

In [7]:
env_name = config["environment"]

# Logger shared by this cell and the agent.
logger = setup_logger(env_name=env_name, identifier_string="redo_test_experiment")
logger.info(
    f'Starting up experiment: {config["experiment_name"]}, '
    f'environment: {config["environment"]}, '
    f'seed: {config["seed"]}'
)

### Setup environments ###
# Two separately built environments with identical game and seed; the training
# one is created first, then the validation one.
train_env, validation_env = (
    build_environment(game_name=config["environment"], random_seed=config["seed"])
    for _ in range(2)
)

# Fresh agent (no checkpoint to resume from) that saves checkpoints under
# `output_path`.
experiment_agent = AgentDQN(
    config=config,
    logger=logger,
    train_env=train_env,
    validation_env=validation_env,
    experiment_name="redo_test_experiment",
    experiment_output_folder=output_path,
    resume_training_path=None,
    save_checkpoints=True,
)

2023-11-01 21:56:22,918 - root - INFO - redo_test_experiment - Starting up experiment: conv32_lin128, environment: breakout, seed: 0
2023-11-01 21:56:22,923 - root - INFO - redo_test_experiment - Loaded configuration settings.
2023-11-01 21:56:22,924 - root - INFO - redo_test_experiment - Initialized newtworks and optimizer.
2023-11-01 21:56:22,924 - root - INFO - redo_test_experiment - Applied redo parametrization to policy model.
2023-11-01 21:56:22,924 - root - INFO - redo_test_experiment - Applied redo parametrization to target model.



[33mWARN: The environment MinAtar/Breakout-v0 is out of date. You should consider upgrading to version `v1`.[0m


[33mWARN: It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: bool. If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector.[0m


[33mWARN: It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not.[0m



In [8]:
# Take the epoch count from the config (its "epochs_to_train" entry is 20, the
# value that used to be repeated here as a literal) so the notebook has a
# single source of truth for the run length.
experiment_agent.train(train_epochs=config["epochs_to_train"])

2023-11-01 21:56:22,999 - root - INFO - redo_test_experiment - Starting training session at: 0
2023-11-01 21:56:22,999 - root - INFO - redo_test_experiment - Starting training epoch at t = 0



`np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)



2023-11-01 21:56:46,164 - root - INFO - redo_test_experiment - TRAINING STATS | Frames seen: 20000 | Episode: 1801 | Max reward: 6.0 | Avg reward: 0.4997223764575236 | Avg frames (episode): 11.10494169905608 | Avg max Q: -72.96229731319742 | Epsilon: 0.9406 | Train epoch time: 0:00:23.160869
2023-11-01 21:56:46,164 - root - INFO - redo_test_experiment - Starting validation epoch at t = 20000
2023-11-01 21:56:51,578 - root - INFO - redo_test_experiment - VALIDATION STATS | Max reward: 2.0 | Avg reward: 0.5438756855575868 | Avg frames (episode): 11.438756855575868 | Avg max Q: -87.82619462405682 | Validation epoch time: 0:00:05.410856
2023-11-01 21:56:51,579 - root - INFO - redo_test_experiment - Saving checkpoint at t = 20000 ...
2023-11-01 21:56:51,582 - root - DEBUG - redo_test_experiment - Models saved at t = 20000
2023-11-01 21:56:51,588 - root - DEBUG - redo_test_experiment - Training status saved at t = 20000
2023-11-01 21:56:52,842 - root - INFO - redo_test_experiment - Checkpoin

True

In [10]:
# Collect a set of states on which Q-values are compared before and after ReDo.
# NOTE(review): actions come from an unseeded `action_space.sample()`, so the
# collected states differ between runs - consider seeding the action space.
N_EVAL_STEPS = 100  # random-action steps to take; terminal steps are skipped

eval_env = experiment_agent.train_env

# Training leaves the environment in an unknown state (possibly at the end of
# an episode), so start from a fresh episode before stepping.
eval_env.reset()

eval_states = []
for _ in range(N_EVAL_STEPS):
    # get a random action from the environment
    action = eval_env.action_space.sample()
    s_prime, reward, is_terminated, truncated, info = eval_env.step(action)

    # Keep every observation that is not a terminal state.
    if not is_terminated:
        eval_states.append(torch.tensor(s_prime, device="cpu").float())

    # An episode ends on termination or on truncation; either way the
    # environment has to be reset before it is stepped again.
    if is_terminated or truncated:
        eval_env.reset()

# Fail here with a clear message instead of later inside torch.stack().
assert eval_states, "no non-terminal states were collected"

In [11]:
# Baseline (pre-ReDo) measurements on the collected evaluation states.
eval_states_tensor = torch.stack(eval_states)

# Inference only: no autograd graph is needed for this comparison.
# NOTE(review): the policy model's state dict contains relu*.running_avg
# buffers; if forward passes update them, evaluating here may influence which
# units apply_redo() selects - confirm in minatar_dqn.redo.
with torch.no_grad():
    predictions_init = experiment_agent.policy_model(eval_states_tensor)

# Max Q-value per state, queried one state at a time with a batch dimension added.
max_q_vals_init = [
    experiment_agent.get_max_q_val_for_state(state.unsqueeze(0))
    for state in eval_states_tensor
]

# get the weights before redo
# Deep copy: state_dict() returns references to the live tensors, so without
# it the later reset would modify this snapshot as well.
state_dict_init = copy.deepcopy(experiment_agent.policy_model.state_dict())

In [12]:
# apply redo
reset_details = experiment_agent.policy_model.apply_redo()
reset_details

[{'indexes': tensor([ 0,  4,  7,  8, 10, 11, 13, 16, 17, 19, 20, 23, 24, 25, 26, 28, 29, 30]),
  'inbound': 'features.conv1',
  'outbound': 'features.conv2'},
 {'indexes': tensor([ 5,  7, 11, 20, 23, 29]),
  'inbound': 'features.conv2',
  'outbound': 'fc.lin1'},
 {'indexes': tensor([  0,   1,   2,   4,   5,   6,   7,   8,   9,  10,  11,  12,  15,  16,
           18,  19,  20,  21,  22,  23,  24,  26,  28,  29,  30,  31,  32,  33,
           34,  35,  36,  37,  38,  39,  41,  42,  43,  45,  46,  47,  48,  49,
           50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
           64,  65,  66,  67,  68,  70,  71,  72,  73,  74,  75,  76,  77,  79,
           81,  83,  84,  85,  86,  87,  88,  90,  91,  92,  93,  96,  97,  98,
           99, 100, 101, 102, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113,
          114, 116, 117, 118, 121, 122, 123, 124, 125, 126, 127]),
  'inbound': 'fc.lin1',
  'outbound': 'fc.lin2'}]

In [13]:
# Build the mapping from the policy model's layers to the optimizer, which
# reset_optimizer_states() needs below.
layer_to_optim_idx = map_layers_to_optimizer_indices(
    experiment_agent.policy_model, experiment_agent.optimizer
)
# NOTE(review): this bare expression is not the last statement of the cell, so
# with Jupyter's default display settings it shows nothing - wrap it in
# display() or move it to its own cell if the mapping should be inspected.
layer_to_optim_idx

# Reset the optimizer state for the entries listed in `reset_details`
# (presumably so stale optimizer statistics are not applied to
# re-initialised weights - confirm in minatar_dqn.redo).
reset_optimizer_states(reset_details, experiment_agent.optimizer, layer_to_optim_idx)

In [15]:
# Repeat the measurements after ReDo so they can be compared with the baseline.
# Inference only: no autograd graph is needed for this comparison.
with torch.no_grad():
    predictions_after = experiment_agent.policy_model(eval_states_tensor)

# Max Q-value per state, queried one state at a time with a batch dimension added.
max_q_vals_after = [
    experiment_agent.get_max_q_val_for_state(state.unsqueeze(0))
    for state in eval_states_tensor
]

# get the weights after redo
# Deep copy, like the pre-ReDo snapshot: state_dict() returns references to
# the live tensors, so an uncopied snapshot would change with the model.
state_dict_after = copy.deepcopy(experiment_agent.policy_model.state_dict())

In [16]:
# Per-state change in max Q-value, computed as (before ReDo) - (after ReDo).
# The cell displays the total, which is negative when the values rose overall.
diff = np.subtract(np.asarray(max_q_vals_init), np.asarray(max_q_vals_after))
sum(diff)

-2929.767496109009

In [18]:
# Display the per-state max Q-values recorded before ReDo was applied.
max_q_vals_init

[-60.39997863769531,
 -60.462135314941406,
 -62.19587707519531,
 -66.5285873413086,
 -59.822235107421875,
 -61.21587371826172,
 -61.4658203125,
 -62.421051025390625,
 -68.46141815185547,
 -63.34773254394531,
 -58.512962341308594,
 -58.057411193847656,
 -56.5676383972168,
 -58.349308013916016,
 -59.57539367675781,
 -59.264286041259766,
 -57.844608306884766,
 -54.27384948730469,
 -67.55104064941406,
 -61.71174621582031,
 -49.753257751464844,
 -56.757232666015625,
 -63.935829162597656,
 -72.09099578857422,
 -70.3806381225586,
 -72.32054138183594,
 -87.32705688476562,
 -55.98558044433594,
 -53.44734573364258,
 -53.56566619873047,
 -54.94139099121094,
 -55.80167007446289,
 -60.07108688354492,
 -63.54494094848633,
 -66.42874145507812,
 -67.56927490234375,
 -73.65245056152344,
 -78.41280364990234,
 -83.2714614868164,
 -86.79576873779297,
 -80.65792846679688,
 -85.84222412109375,
 -52.30204391479492,
 -56.55108642578125,
 -57.007503509521484,
 -56.17133331298828,
 -73.1443862915039,
 -53.16310

In [19]:
# Display the per-state max Q-values recorded after ReDo was applied.
max_q_vals_after

[-29.0242919921875,
 -29.39809799194336,
 -32.019535064697266,
 -33.661067962646484,
 -30.81812286376953,
 -31.3681640625,
 -33.19950485229492,
 -32.55437469482422,
 -31.29922866821289,
 -29.91390609741211,
 -29.478397369384766,
 -28.429950714111328,
 -28.31380844116211,
 -30.242435455322266,
 -33.16775131225586,
 -32.690006256103516,
 -30.675861358642578,
 -31.174137115478516,
 -33.79548645019531,
 -31.235971450805664,
 -29.314197540283203,
 -29.421100616455078,
 -29.99227523803711,
 -32.75470733642578,
 -34.542198181152344,
 -35.31593704223633,
 -34.324790954589844,
 -28.495248794555664,
 -31.971786499023438,
 -32.03752517700195,
 -33.806541442871094,
 -29.913204193115234,
 -30.655254364013672,
 -33.32368087768555,
 -32.15947341918945,
 -32.25957107543945,
 -33.42523956298828,
 -35.668067932128906,
 -35.38266372680664,
 -36.02852249145508,
 -35.106266021728516,
 -33.81298065185547,
 -27.989044189453125,
 -32.409278869628906,
 -32.46217346191406,
 -33.55458068847656,
 -31.390003204345

In [17]:
# Report every state-dict entry whose tensor differs between the pre-ReDo and
# post-ReDo snapshots.
for layer, tensor_after in state_dict_after.items():
    if torch.equal(state_dict_init[layer], tensor_after):
        continue
    print(f"Weights changed in layer: {layer}")

Weights changed in layer: features.conv1.weight
Weights changed in layer: features.conv1.bias
Weights changed in layer: features.relu1.running_avg
Weights changed in layer: features.relu1.running_avg_cnt
Weights changed in layer: features.conv2.weight
Weights changed in layer: features.conv2.bias
Weights changed in layer: features.relu2.running_avg
Weights changed in layer: features.relu2.running_avg_cnt
Weights changed in layer: fc.lin1.weight
Weights changed in layer: fc.lin1.bias
Weights changed in layer: fc.relu3.running_avg
Weights changed in layer: fc.relu3.running_avg_cnt
Weights changed in layer: fc.lin2.weight
