In [1]:
import copy
import sys

import torch
from torch import nn
from torch.nn import functional as F

sys.path.append('..')

from src.agents import RainbowAgent, EzExplorerAgent, SurprisalExplorerAgent
from src.agents import SFPredictor
from src.agents.Rainbow import DEFAULT_RAINBOW_ARGS
from src.envs import *
from src.training import *
from src.models import *

In [2]:
# Build the environment used by every cell below. The two positional arguments
# go to the project helper; their meaning isn't visible here — TODO document
# what False and 100 control.
env = create_simple_gridworld_env(False, 100)
# Alternative environment (swap in to run the same pipeline on it):
# env = create_crazy_climber_env()

In [3]:
# Optional gridworld conv encoder, built only when observation shape[1] <= 42;
# otherwise left as None.
# NOTE(review): custom_encoder is not referenced anywhere else in this notebook,
# and the 42 threshold differs from the 32 used to choose CREATE_CONV_FUNC in
# the next cell — confirm which cutoff is intended, or drop this cell.
custom_encoder = (
    create_gridworld_convs(env.observation_space.shape[0])
    if env.observation_space.shape[1] <= 42
    else None
)

In [4]:
# Conv-stack factory shared by the network classes defined below: the
# gridworld stack when observation shape[1] <= 32, otherwise the Atari stack.
# Presumably observations are (C, H, W), making shape[1] the height — confirm.
CREATE_CONV_FUNC = (
    create_gridworld_convs
    if env.observation_space.shape[1] <= 32
    else create_atari_convs
)

def _build_conv_encoder(obs_dim, create_conv_func=None):
    """Build a conv stack for ``obs_dim`` and measure its flattened output size.

    Args:
        obs_dim: Observation shape without the batch axis. Element 0 is passed
            to the conv factory (presumably the channel count — confirm).
        create_conv_func: Factory ``f(obs_dim[0]) -> nn.Module``. Defaults to
            the notebook-level ``CREATE_CONV_FUNC``.

    Returns:
        ``(convs, output_size)``, where ``output_size`` is the number of
        features a single observation yields once the conv output is flattened.
    """
    if create_conv_func is None:
        create_conv_func = CREATE_CONV_FUNC
    convs = create_conv_func(obs_dim[0])

    # Probe with one all-zeros observation to size the first Linear layer.
    # The probe runs in eval mode so that any normalisation layers the factory
    # may add don't fold this dummy batch into their running statistics; the
    # previous train/eval mode is restored afterwards.
    was_training = convs.training
    convs.eval()
    try:
        with torch.no_grad():
            output_size = convs(torch.zeros(1, *obs_dim)).numel()
    finally:
        convs.train(was_training)
    return convs, output_size


class PolicyNetwork(nn.Module):
    """Conv encoder followed by an MLP head with one output per action.

    Args:
        obs_dim: Observation shape without the batch axis.
        n_acts: Number of discrete actions (width of the output layer).
        create_conv_func: Optional conv factory; defaults to ``CREATE_CONV_FUNC``.
        hidden_dim: Width of the hidden Linear layer.
    """

    def __init__(self, obs_dim, n_acts, create_conv_func=None, hidden_dim=128):
        super().__init__()
        convs, self.encoder_output_size = _build_conv_encoder(obs_dim, create_conv_func)

        # Same module layout as before, so state_dict keys are unchanged.
        self.layers = nn.Sequential(
            convs,
            nn.Flatten(),
            nn.Linear(self.encoder_output_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_acts))

    def forward(self, x):
        """Return unnormalised per-action scores of shape (batch, n_acts).

        No softmax is applied here.
        """
        return self.layers(x)


class CriticNetwork(nn.Module):
    """Conv encoder followed by an MLP head with a single scalar output.

    The output is presumably a state-value estimate; that depends on how the
    agent trains it.

    Args:
        obs_dim: Observation shape without the batch axis.
        create_conv_func: Optional conv factory; defaults to ``CREATE_CONV_FUNC``.
        hidden_dim: Width of the hidden Linear layer.
    """

    def __init__(self, obs_dim, create_conv_func=None, hidden_dim=128):
        super().__init__()
        convs, self.encoder_output_size = _build_conv_encoder(obs_dim, create_conv_func)

        # Same module layout as before, so state_dict keys are unchanged.
        self.layers = nn.Sequential(
            convs,
            nn.Flatten(),
            nn.Linear(self.encoder_output_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1))

    def forward(self, x):
        """Return one scalar per observation, shape (batch, 1)."""
        return self.layers(x)


class SFNetwork(nn.Module):
    """Encoder to LayerNorm'd embeddings, plus an MLP that predicts from them.

    The "SF" name suggests the predictions are successor features; that depends
    on the training loss, which is not visible here.

    Args:
        obs_dim: Observation shape without the batch axis.
        embed_dim: Width of both the embedding and the predictor output.
        create_conv_func: Optional conv factory; defaults to ``CREATE_CONV_FUNC``.
    """

    def __init__(self, obs_dim, embed_dim=64, create_conv_func=None):
        super().__init__()
        convs, self.encoder_output_size = _build_conv_encoder(obs_dim, create_conv_func)

        self.encoder = nn.Sequential(
            convs,
            nn.Flatten(),
            nn.Linear(self.encoder_output_size, embed_dim),
            nn.LayerNorm(embed_dim))

        # Three-layer MLP mapping an embedding to a same-width prediction.
        self.sf_predictor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, x):
        """Return ``(embeds, sfs)``, each of shape (batch, embed_dim)."""
        embeds = self.encoder(x)
        sfs = self.sf_predictor(embeds)
        return embeds, sfs

# sf_model = SFNetwork(list(env.observation_space.shape), 64)
# lstate, sfs = sf_model(torch.zeros([2] + list(env.observation_space.shape)))
# print(lstate.shape, sfs.shape)

In [5]:
# Shared settings: embedding width for the SF network and the device every
# model is placed on (CUDA when available, otherwise CPU).
embed_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Shallow copy, so the attribute assignment below lands on the copy rather
# than on DEFAULT_RAINBOW_ARGS itself.
rainbow_args = copy.copy(DEFAULT_RAINBOW_ARGS)
rainbow_args.device = device
# rainbow_args.replay_frequency = 8

# Representation learner: the SF network on `device`, wrapped by SFPredictor.
sf_model = SFNetwork(list(env.observation_space.shape), embed_dim).to(device)
repr_learner = SFPredictor(
    sf_model,
    batch_size=32,
    update_freq=16,
    log_freq=200,
    target_net_update_freq=64,
    discount_factor=0.99,
    lr=1e-4,
)

In [None]:
# Exploration phase: policy and critic nets plus the shared SF representation
# learner, driven by SurprisalExplorerAgent.
policy_net = PolicyNetwork(list(env.observation_space.shape), env.action_space.n).to(device)
critic_net = CriticNetwork(list(env.observation_space.shape)).to(device)

explore_agent = SurprisalExplorerAgent(
    env,
    policy_net,
    critic_net,
    repr_learner,
    log_freq=50,
    update_freq=2000,
    batch_size=2000,
)
# Alternative entry point: train_exploration_model(explore_agent, env, int(1e6))
train_task_model(explore_agent, env, 1_000_000, print_rewards=True)

In [6]:
# Give a fresh Rainbow agent its own copy of the SF encoder. deepcopy creates
# independent parameters, so training the agent doesn't modify sf_model.encoder.
# The copy is moved to CPU first; presumably RainbowAgent places it on
# rainbow_args.device itself — confirm.
# NOTE(review): the exploration cell above is marked In [None] in this saved
# notebook, so unless it was run elsewhere this encoder is still at its random
# initialisation when the results below were produced — confirm.
encoder_copy = copy.deepcopy(sf_model.encoder).to('cpu')

agent = RainbowAgent(rainbow_args, env, encoder_copy, repr_learner=None)
# The deepcopy above left sf_model where it was; this re-applies the device move.
sf_model = sf_model.to(device)

In [7]:
# Train the Rainbow agent for 100k steps; the log below reports progress every 5k steps.
train_task_model(agent, env, 100_000)

Step: 5000	# Episodes: 50	Avg ep reward: 0.06
Step: 10000	# Episodes: 877	Avg ep reward: 0.99
Step: 15000	# Episodes: 1000	Avg ep reward: 1.00
Step: 20000	# Episodes: 1000	Avg ep reward: 1.00
Step: 25000	# Episodes: 1000	Avg ep reward: 1.00
Step: 30000	# Episodes: 1000	Avg ep reward: 1.00
Step: 35000	# Episodes: 1000	Avg ep reward: 1.00
Step: 40000	# Episodes: 1000	Avg ep reward: 1.00
Step: 45000	# Episodes: 1000	Avg ep reward: 1.00
Step: 50000	# Episodes: 1000	Avg ep reward: 1.00
Step: 55000	# Episodes: 1000	Avg ep reward: 1.00
Step: 60000	# Episodes: 1000	Avg ep reward: 1.00
Step: 65000	# Episodes: 1000	Avg ep reward: 1.00
Step: 70000	# Episodes: 1000	Avg ep reward: 1.00
Step: 75000	# Episodes: 1000	Avg ep reward: 1.00
Step: 80000	# Episodes: 1000	Avg ep reward: 1.00
Step: 85000	# Episodes: 1000	Avg ep reward: 1.00
Step: 90000	# Episodes: 1000	Avg ep reward: 1.00
Step: 95000	# Episodes: 1000	Avg ep reward: 1.00
Step: 100000	# Episodes: 1000	Avg ep reward: 1.00
