In [1]:
import time, random
from collections import deque
from pathlib import Path
from types import SimpleNamespace as sn

import torch
import numpy as np
from tqdm import trange
from rich import print

import wandb
from wandb.integration.sb3 import WandbCallback

In [2]:
# Restrict cuDNN to deterministic algorithms. The comparison helpers further down
# check old and new models for (near-)equal weights, so run-to-run nondeterminism
# would show up as spurious mismatches.
torch.backends.cudnn.deterministic = True

In [3]:
from rainbow import Rainbow, reset_noise
from env_wrappers import create_env
from utils import get_mean_ep_length
from sb3_logger import configure_logger, WandbOutputFormat

In [4]:
#from old_rainbow.common import argp
from old_rainbow.common.rainbow import Rainbow as old_Rainbow
from old_rainbow.common.env_wrappers import create_env as old_create_env #, BASE_FPS_ATARI, BASE_FPS_PROCGEN
from old_rainbow.common.utils import LinearSchedule
from old_rainbow.common.utils import get_mean_ep_length as old_get_mean_ep_length

In [5]:
def same_named(net, old_named, new_named):
    """Compare two sequences of named tensors pairwise and print every difference.

    Args:
        net: Label used as a prefix in the printed messages (e.g. ``'current'``).
        old_named: Iterable of ``(name, tensor)`` pairs from the reference model.
        new_named: Iterable of ``(name, tensor)`` pairs from the model under test.

    Returns:
        ``True`` when both iterables have the same number of entries and every
        pair agrees in shape and value (``torch.allclose``). Gradients are
        compared too whenever the old tensor carries one. ``False`` otherwise.
    """
    # Materialise both sides so that a length mismatch is reported instead of
    # being silently truncated by zip().
    old_items = list(old_named)
    new_items = list(new_named)

    same = len(old_items) == len(new_items)
    if not same:
        print(f'{net}: number of entries differs: {len(old_items)} != {len(new_items)}')

    for (old_name, old_t), (new_name, new_t) in zip(old_items, new_items):
        old_grad, new_grad = old_t.grad, new_t.grad
        # Gradients are only checked when the reference side has one.
        if old_grad is not None:
            if new_grad is None:
                # torch.allclose(tensor, None) would raise; report it as a mismatch instead.
                same = False
                print(f'{net}: {old_name}.grad is set but {new_name}.grad is None')
            elif old_grad.shape != new_grad.shape:
                same = False
                print(f'{net}: {old_name}.grad shape {tuple(old_grad.shape)} != {new_name}.grad shape {tuple(new_grad.shape)}')
            elif not torch.allclose(old_grad, new_grad):
                same = False
                print(f'{net}: {old_name}.grad != {new_name}.grad: {(old_grad - new_grad).abs().max().item()}')

        # Check shapes first: allclose broadcasts, so e.g. shapes (1,) and (3,)
        # could otherwise compare as equal, and incompatible shapes would raise.
        if old_t.shape != new_t.shape:
            same = False
            print(f'{net}: {old_name} shape {tuple(old_t.shape)} != {new_name} shape {tuple(new_t.shape)}')
        elif not torch.allclose(old_t, new_t):
            same = False
            print(f'{net}: {old_name} != {new_name}: {(old_t - new_t).abs().max().item()}')
    return same

def same_net(net, old_net, new_net):
    """Check that two networks agree on their parameters and on their buffers.

    Args:
        net: Label forwarded to ``same_named`` for the printed messages.
        old_net: Reference network exposing ``named_parameters`` / ``named_buffers``.
        new_net: Network under test exposing the same two accessors.

    Returns:
        ``True`` only if both the parameter and the buffer comparison succeed.
    """
    # Both comparisons always run (no short-circuit) so every mismatch is printed.
    outcomes = [
        same_named(net, getattr(old_net, accessor)(), getattr(new_net, accessor)())
        for accessor in ('named_parameters', 'named_buffers')
    ]
    return all(outcomes)

def same_model(old_rainbow, new_rainbow):
    """Compare the online and the target networks of the old and the new agent.

    The old agent's ``q_policy`` / ``q_target`` are matched against the new
    agent's ``q_net`` / ``q_net_target`` respectively.

    Returns:
        ``True`` only if both network pairs match.
    """
    attribute_pairs = (
        ('current', 'q_policy', 'q_net'),
        ('target', 'q_target', 'q_net_target'),
    )
    verdicts = []
    # Evaluate every pair, even after a failure, so all differences are reported.
    for label, old_attr, new_attr in attribute_pairs:
        verdicts.append(same_net(label, getattr(old_rainbow, old_attr), getattr(new_rainbow, new_attr)))
    return all(verdicts)

def same_replays(old_rainbow, new_rainbow):
    """Check that the old and the new agent hold the same replay-buffer contents.

    Compares, in this order: buffer sizes, buffer kind (both sides must agree on
    having a ``max_priority`` attribute), the stored observations, next
    observations, actions, rewards and dones, and finally, for prioritized
    buffers, ``max_priority``, ``priority_sum`` and ``priority_min``.

    Args:
        old_rainbow: Old agent; its buffer is read from ``.buffer`` (``.size``
            attribute, ``.data`` sequence of 5-tuples).
        new_rainbow: New agent; its buffer is read from ``.replay_buffer``
            (``.size()`` method plus one array per field).

    Returns:
        ``True`` if everything matches, ``False`` (after printing the first
        difference found) otherwise.
    """
    old_buffer = old_rainbow.buffer
    new_buffer = new_rainbow.replay_buffer
    size = new_buffer.size()

    if old_buffer.size != size:
        print(f'Old replay buffer size {old_buffer.size} != new size {size}')
        return False
    priority = hasattr(old_buffer, 'max_priority')
    if priority != hasattr(new_buffer, 'max_priority'):
        print('Using different replay buffers (Uniform/Priority)')
        return False

    # With no stored transitions there is nothing to unpack: zip(*[]) yields no
    # columns and the 5-way unpacking below would raise.
    if size > 0:
        o_obs, o_nobs, o_a, o_r, o_d = zip(*old_buffer.data[:size])
        o_obs = np.array([np.array(o) for o in o_obs])  # convert lazy frames to numpy arrays
        o_nobs = np.array([np.array(o) for o in o_nobs])
        o_a = torch.stack(o_a).detach().cpu().numpy()
        o_r = torch.concat(o_r).detach().cpu().numpy()
        o_d = torch.concat(o_d).detach().cpu().numpy()

        new_fields = (
            new_buffer.observations,
            new_buffer.next_observations,
            new_buffer.actions,
            new_buffer.rewards,
            new_buffer.dones,
        )
        n_obs, n_nobs, n_a, n_r, n_d = (field[:size] for field in new_fields)
        n_obs = np.array([np.array(o) for o in n_obs])  # convert lazy frames to numpy arrays
        n_nobs = np.array([np.array(o) for o in n_nobs])

        names = ('observations', 'next_observations', 'actions', 'rewards', 'dones')
        old_values = (o_obs, o_nobs, o_a, o_r, o_d)
        new_values = (n_obs, n_nobs, n_a, n_r, n_d)
        for name, old_v, new_v in zip(names, old_values, new_values):
            # Elementwise `==` raises on incompatible shapes and broadcasts on
            # compatible-but-different ones; treat any shape difference as a mismatch.
            if np.shape(old_v) != np.shape(new_v):
                print(f'{name} have different shapes: {np.shape(old_v)} != {np.shape(new_v)}')
                return False
            if not np.array_equal(old_v, new_v):
                print(f'{name} are not the same')
                return False

    if priority:
        if not np.allclose(old_buffer.max_priority, new_buffer.max_priority):
            print('Max priorities are not close')
            return False
        if not np.allclose(old_buffer.priority_sum, new_buffer.priority_sum):
            print('Priority sums are not close')
            return False
        if not np.allclose(old_buffer.priority_min, new_buffer.priority_min):
            print('Priority mins are not close')
            return False
    return True

In [6]:
def set_random(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``."""
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [7]:
def set_up_env(args, old=False):
    """Seed the RNGs and build the environment for either implementation.

    Args:
        args: Run configuration; reads ``env_name``, ``parallel_envs`` and ``decorr``.
        old: When truthy, use the old implementation's env factory and
            episode-length helper instead of the new ones.

    Returns:
        The environment produced by the selected factory.
    """
    set_random(args)

    if old:
        create, ep_length = old_create_env, old_get_mean_ep_length
    else:
        create, ep_length = create_env, get_mean_ep_length

    # Breakout always gets a fixed decorrelation budget, independent of args.decorr.
    decorr_steps = 160 // args.parallel_envs if args.env_name == 'gym:Breakout' else None

    # Otherwise derive it from the mean episode length, but only when decorrelation
    # is requested and the env is not a procgen one.
    needs_estimate = args.decorr and not args.env_name.startswith('procgen:') and decorr_steps is None
    if needs_estimate:
        decorr_steps = ep_length(args) // args.parallel_envs

    return create(args, decorr_steps=decorr_steps)

In [8]:
def initialize_sb3(args):
    """Create the new ``Rainbow`` model together with its environment.

    Args:
        args: Run configuration providing the buffer, batch, schedule, optimizer
            and network settings read below.

    Returns:
        Tuple ``(model, env)``.
    """
    env = set_up_env(args)

    # Settings handed to the policy: linear-layer flavour, optimizer and extractor size.
    policy_kwargs = {
        'noisy_linear': args.noisy_linear,
        'linear_kwargs': args.linear_kwargs,
        'optimizer_class': args.optimizer_class,
        'optimizer_kwargs': args.optimizer_kwargs,
        'features_extractor_kwargs': {'model_size': args.model_size},
    }
    algorithm_kwargs = dict(
        buffer_size=args.buffer_size,
        batch_size=args.batch_size,
        learning_starts=args.burnin,
        target_update_interval=args.sync_dqn_target_every,
        gradient_steps=args.gradient_steps,
        exploration_fraction=args.exploration_fraction,
        policy_kwargs=policy_kwargs,
    )
    model = Rainbow('CnnPolicy', env, **algorithm_kwargs)

    return model, env

In [9]:
def add_default_args(args):
    """Fill in, in place, the extra settings read by the old training path.

    The values are derived from the attributes already present on ``args``
    (``training_frames``, ``model_size``, ``noisy_linear``, ``linear_kwargs``,
    ``batch_size``, ``gradient_steps``).

    Notes:
        * With noisy layers the epsilon schedule is always overwritten.
        * Without noisy layers a caller-supplied ``eps_decay_frames`` is kept;
          if it is missing (or ``None``) it falls back to 5% of
          ``training_frames`` so that later readers of the attribute do not fail.
        * ``adam_eps`` defaults to ``0.005 / batch_size`` when missing or ``None``.

    Args:
        args: Mutable namespace-like configuration object.

    Returns:
        The same ``args`` object, for convenience.
    """
    args.prioritized_er_beta0 = 0.45
    args.prioritized_er_time = args.training_frames
    args.use_amp = False
    args.network_arch = f'impala_large:{args.model_size}'
    args.spectral_norm = 'all'
    args.noisy_dqn = args.noisy_linear
    if args.noisy_dqn:
        args.noisy_sigma0 = args.linear_kwargs['sigma_0']
        args.init_eps = 0.002
        args.final_eps = 0.0
        args.eps_decay_frames = max(int(0.002 * args.training_frames), 1)
    else:
        args.init_eps = 1.
        args.final_eps = 0.01
        # Keep an explicitly configured horizon; only fill the gap when none was given.
        if getattr(args, 'eps_decay_frames', None) is None:
            args.eps_decay_frames = max(int(0.05 * args.training_frames), 1)
    args.double_dqn = True
    args.prioritized_er = True
    args.n_step = 3
    args.max_grad_norm = 10
    args.lr = 0.00025
    adam_eps = getattr(args, 'adam_eps', None)
    args.adam_eps = 0.005 / args.batch_size if adam_eps is None else adam_eps
    args.lr_decay_steps = None
    args.loss_fn = 'huber'
    args.train_count = args.gradient_steps
    return args

In [10]:
def initialize_old(args):
    """Create the old Rainbow agent together with its environment.

    The environment is built first; ``args`` is then extended in place by
    ``add_default_args`` before being handed to the old agent's constructor.

    Returns:
        Tuple ``(rainbow, env)``.
    """
    env = set_up_env(args, old=True)
    full_args = add_default_args(args)
    return old_Rainbow(env, full_args), env

In [11]:
def compare_initialization(args):
    """Build both agents and report whether their freshly initialised networks match.

    Args:
        args: Run configuration shared by both implementations.

    Returns:
        Tuple ``(new_model, old_model)``, returned whether or not the comparison
        succeeded so the models can be inspected afterwards.
    """
    new_model, _ = initialize_sb3(args)
    old_model, _ = initialize_old(args)

    try:
        # Explicit check instead of `assert`: a bare AssertionError prints as an
        # empty message and asserts disappear under `python -O`.
        if not same_model(old_model, new_model):
            print('Models differ after initialization')
    except Exception as e:
        # The comparison itself failed; report it and still hand back the models.
        print(e)
    # Deliberately not inside `finally`: returning from there would also swallow
    # KeyboardInterrupt and other non-Exception errors.
    return new_model, old_model

In [12]:
def train_sb3(args):
    """Create the new model and train it for ``args.training_frames`` steps.

    Args:
        args: Run configuration; ``training_frames`` is passed to ``model.learn``.

    Returns:
        The trained model. The environment is closed before returning.
    """
    model, env = initialize_sb3(args)

    try:
        model.learn(args.training_frames, progress_bar=True)
    finally:
        # Release the environment even when training raises or is interrupted.
        env.close()

    return model

In [13]:
def train_old(args):
    """Run the old Rainbow training loop and return the agent.

    Each iteration advances all parallel environments by one step, in this order:
    choose actions, start the asynchronous env step, run ``args.train_count``
    training steps once the buffer reports ``burnedin``, optionally sync the
    target network, wait for the env step, and store the transitions.

    Args:
        args: Run configuration. After ``initialize_old`` has extended it, the
            epsilon / PER-beta schedule settings, ``training_frames``,
            ``parallel_envs``, ``batch_size``, ``train_count``, ``noisy_dqn`` and
            ``sync_dqn_target_every`` are read from it.

    Returns:
        The old Rainbow agent after the loop has finished. The environment is
        closed before returning, also when the loop raises.
    """
    rainbow, old_env = initialize_old(args)

    try:
        states = old_env.reset()

        eps_schedule = LinearSchedule(0, initial_value=args.init_eps, final_value=args.final_eps, decay_time=args.eps_decay_frames)
        per_beta_schedule = LinearSchedule(0, initial_value=args.prioritized_er_beta0, final_value=1.0, decay_time=args.prioritized_er_time)

        episode_count = 0

        t = trange(0, args.training_frames + 1, args.parallel_envs)
        for game_frame in t:
            eps = eps_schedule(game_frame)
            per_beta = per_beta_schedule(game_frame)

            # reset the noisy-nets noise in the policy
            if args.noisy_dqn:
                rainbow.reset_noise(rainbow.q_policy)

            # compute actions to take in all parallel envs, asynchronously start environment step
            actions = rainbow.act(states, eps)
            old_env.step_async(actions)

            # if training has started, perform args.train_count training steps, each on a batch of size args.batch_size
            if rainbow.buffer.burnedin:
                print('Trained')
                for train_iter in range(args.train_count):
                    if args.noisy_dqn and train_iter > 0:
                        rainbow.reset_noise(rainbow.q_policy)
                    # The returned statistics were never consumed, so they are not kept.
                    rainbow.train(args.batch_size, beta=per_beta)

            # copy the Q-policy weights over to the Q-target net
            # (see also https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/launcher.py#L155)
            if game_frame % args.sync_dqn_target_every == 0 and rainbow.buffer.burnedin:
                rainbow.sync_Q_target()

            # block until environments are ready, then collect transitions and add them to the replay buffer
            next_states, rewards, dones, infos = old_env.step_wait()

            for state, action, reward, done, j in zip(states, actions, rewards, dones, range(args.parallel_envs)):
                rainbow.buffer.put(state, action, reward, done, j=j)
            states = next_states

            # count the episodes that finished during this step (shown in the progress bar)
            for info in infos:
                if 'episode_metrics' in info.keys():
                    episode_count += 1

            t.set_description(f' [{game_frame:>8} frames, {episode_count:>5} episodes]', refresh=False)
    finally:
        # Release the environment even when the loop raises or is interrupted.
        old_env.close()

    return rainbow

In [14]:
def compare_pre_optimization(args):
    """Run both implementations up to just before training starts and compare them.

    Note:
        ``args.training_frames`` is overwritten in place with ``args.burnin - 5``.

    Args:
        args: Run configuration shared by both implementations.

    Returns:
        Tuple ``(new_model, old_model)``, returned whether or not the comparison
        succeeded so the models can be inspected afterwards.
    """
    args.training_frames = args.burnin - 5 # stop procedure before training starts

    new_model = train_sb3(args)
    old_model = train_old(args)

    try:
        # Explicit checks instead of `assert`: a bare AssertionError prints as an
        # empty message and asserts disappear under `python -O`. As before, the
        # replay buffers are only compared when the models match.
        if not same_model(old_model, new_model):
            print('Models differ before the first update')
        elif not same_replays(old_model, new_model):
            print('Replay buffers differ before the first update')
    except Exception as e:
        # The comparison itself failed; report it and still hand back the models.
        print(e)
    # Deliberately not inside `finally`: returning from there would also swallow
    # KeyboardInterrupt and other non-Exception errors.
    return new_model, old_model

In [15]:
def compare_first_update(args):
    """Run both implementations slightly past the burn-in point and compare them.

    Note:
        ``args.training_frames`` and ``args.gradient_steps`` are overwritten in place.

    Args:
        args: Run configuration shared by both implementations.

    Returns:
        Tuple ``(new_model, old_model)``, returned whether or not the comparison
        succeeded so the models can be inspected afterwards.
    """
    args.training_frames = ((args.burnin // args.parallel_envs) + 5) * args.parallel_envs # stop procedure after first update
    args.gradient_steps = 1 # only do one update, otherwise non-determinism gets amplified and prevents proper comparison

    new_model = train_sb3(args)
    old_model = train_old(args)

    try:
        # Explicit checks instead of `assert`: a bare AssertionError prints as an
        # empty message and asserts disappear under `python -O`. As before, the
        # replay buffers are only compared when the models match.
        if not same_model(old_model, new_model):
            print('Models differ after the first update')
        elif not same_replays(old_model, new_model):
            print('Replay buffers differ after the first update')
    except Exception as e:
        # The comparison itself failed; report it and still hand back the models.
        print(e)
    # Deliberately not inside `finally`: returning from there would also swallow
    # KeyboardInterrupt and other non-Exception errors.
    return new_model, old_model

In [16]:
from argparse import Namespace
from cbp import CBP

# All directly specified settings live in one place; values derived from them follow below.
args = Namespace(
    # --- environment ---
    env_name='gym:Breakout',
    parallel_envs=64,
    subproc_vecenv=False,
    time_limit=108_000,
    frame_stack=4,
    frame_skip=4,
    grayscale=True,
    gamma=0.99,
    resolution=(84, 84),
    save_dir='tmp',
    record_every=60*50,
    decorr=True,
    seed=3605,
    # --- schedule (alternative values kept from the original notes) ---
    burnin=500,                  # 100_000
    buffer_size=2**12,           # 2**19
    batch_size=256,
    sync_dqn_target_every=320,   # 32_000
    training_frames=64*11,       # 1000 / 2_000_000
    eps_decay_frames=700,
    # --- network / optimizer ---
    noisy_linear=False,
    adam_eps=None,
    model_size=2,
    gradient_steps=2,
    optimizer_class=CBP,
)

# Derived settings: these depend on the values chosen above.
args.exploration_fraction = args.eps_decay_frames / (args.training_frames + args.parallel_envs)
args.linear_kwargs = {'sigma_0': 0.5} if args.noisy_linear else {}
args.optimizer_kwargs = {'eps': args.adam_eps, 'm': 500, 'rho': 10**-5}

In [17]:
#new_model, old_model = compare_initialization(args)

In [18]:
#new_model, old_model = compare_pre_optimization(args)

In [19]:
# Run the comparison that stops shortly after burn-in; both models are kept in the
# notebook namespace for the inspection cells below.
# Note: compare_first_update overwrites args.training_frames and args.gradient_steps.
new_model, old_model = compare_first_update(args)

Output()

 [     704 frames,     1 episodes]:  92%|██████████████████████████████████████████▍   | 12/13 [00:01<00:00,  7.83it/s]

 [     768 frames,     1 episodes]: 100%|██████████████████████████████████████████████| 13/13 [00:02<00:00,  5.67it/s]


In [20]:
# Total number of scalar parameters in the new model's feature extractor.
sum(p.numel() for p in new_model.q_net.features_extractor.parameters())

389024

In [21]:
# Total number of scalar parameters in the new model's dueling head.
sum(p.numel() for p in new_model.q_net.dueling.parameters())

2098949