In [1]:
#!/usr/bin/env python3
import gymnasium as gym

from collections import namedtuple
import numpy as np

from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cpu")
device

device(type='cpu')

In [3]:
HIDDEN_SIZE = 128
BATCH_SIZE = 100
PERCENTILE = 30
GAMMA = 0.9

In [4]:
class DiscreteOneHotWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super(DiscreteOneHotWrapper, self).__init__(env)
        assert isinstance(env.observation_space,
                          gym.spaces.Discrete)
        shape = (env.observation_space.n, )
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, shape, dtype=np.float32)

    def observation(self, observation):
        res = np.copy(self.observation_space.low)
        res[observation] = 1.0
        return res

In [5]:
class MLP(nn.Module):
    def __init__(self, obs_size, hidden_size, n_actions):
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x)

In [6]:
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action'])

In [7]:
def iterate_batches(env, net, batch_size):
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs, _ = env.reset()
    sm = nn.Softmax(dim=1).to(device)
    while True:
        obs_v = torch.FloatTensor([obs]).to(device)
        act_probs_v = sm(net(obs_v))
        act_probs = act_probs_v.detach().cpu().numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done, truncated, _ = env.step(action)
        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        if is_done or truncated:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs, _ = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [8]:
def filter_batch(batch, percentile):
    filter_fun = lambda s: s.reward * (GAMMA ** len(s.steps))
    disc_rewards = list(map(filter_fun, batch))
    reward_bound = np.percentile(disc_rewards, percentile)

    train_obs = []
    train_act = []
    elite_batch = []
    for example, discounted_reward in zip(batch, disc_rewards):
        if discounted_reward > reward_bound:
            train_obs.extend(map(lambda step: step.observation,
                                 example.steps))
            train_act.extend(map(lambda step: step.action,
                                 example.steps))
            elite_batch.append(example)

    return elite_batch, train_obs, train_act, reward_bound

In [9]:
from gymnasium.envs.toy_text.frozen_lake import generate_random_map

env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1", desc=generate_random_map(size=4), is_slippery=True))

# env = gym.wrappers.Monitor(env, directory="mon", force=True)
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

net = MLP(obs_size, HIDDEN_SIZE, n_actions).to(device)

objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.01)
writer = SummaryWriter(comment="-frozenlake-mlp")

In [10]:
'number of parameters', sum(p.numel() for p in net.parameters())

('number of parameters', 2692)

In [11]:
full_batch = []

with torch.device(device):
    for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        
        reward_mean = float(np.mean(list(map(lambda s: s.reward, batch))))

        full_batch, obs, acts, reward_bound = filter_batch(full_batch + batch, PERCENTILE)

        if not full_batch:
            continue

        obs_v = torch.FloatTensor(obs).to(device)
        acts_v = torch.LongTensor(acts).to(device)
        full_batch = full_batch[-500:]

        optimizer.zero_grad()

        action_scores_v = net(obs_v)

        loss_v = objective(action_scores_v, acts_v)

        loss_v.backward()
        optimizer.step()

        print("%d: loss=%.3f, rw_mean=%.3f, rw_bound=%.3f, batch=%d" % (
            iter_no, loss_v.item(), reward_mean, reward_bound, len(full_batch))
        )

        writer.add_scalar("loss", loss_v.item(), iter_no)
        writer.add_scalar("reward_mean", reward_mean, iter_no)
        writer.add_scalar("reward_bound", reward_bound, iter_no)

        if reward_mean > 0.8:
            print("Solved!")
            break

    writer.close()

  obs_v = torch.FloatTensor([obs]).to(device)


2: loss=1.415, rw_mean=0.010, rw_bound=0.000, batch=1
3: loss=1.339, rw_mean=0.010, rw_bound=0.000, batch=2
4: loss=1.282, rw_mean=0.000, rw_bound=0.000, batch=2
5: loss=1.226, rw_mean=0.000, rw_bound=0.000, batch=2
6: loss=1.198, rw_mean=0.010, rw_bound=0.000, batch=3
7: loss=1.154, rw_mean=0.000, rw_bound=0.000, batch=3
8: loss=1.115, rw_mean=0.000, rw_bound=0.000, batch=3
9: loss=1.083, rw_mean=0.000, rw_bound=0.000, batch=3
10: loss=1.130, rw_mean=0.020, rw_bound=0.000, batch=5
11: loss=1.121, rw_mean=0.010, rw_bound=0.000, batch=6
12: loss=1.076, rw_mean=0.010, rw_bound=0.000, batch=7
13: loss=1.066, rw_mean=0.010, rw_bound=0.000, batch=8
14: loss=1.103, rw_mean=0.020, rw_bound=0.000, batch=10
15: loss=1.070, rw_mean=0.020, rw_bound=0.000, batch=12
16: loss=1.040, rw_mean=0.010, rw_bound=0.000, batch=13
17: loss=1.034, rw_mean=0.000, rw_bound=0.000, batch=13
18: loss=1.035, rw_mean=0.020, rw_bound=0.000, batch=15
19: loss=1.000, rw_mean=0.010, rw_bound=0.000, batch=16
20: loss=0.9

KeyboardInterrupt: 

In [12]:
# %load_ext autoreload
# %autoreload 2

from collections import OrderedDict
from lib.architecture import Search
from lib.sample import SampleNormal, SampleUniform

def create_model(input_size, hidden_size, output_size):
    encoder = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.LayerNorm(hidden_size)
    )

    search = Search(
        transition=nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2*hidden_size),
        ),
        fitness=nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ),
        sample=nn.Sequential(
            SampleNormal(hidden_size, num_samples=4),
            nn.LayerNorm(hidden_size)
        ),
        max_depth=2,
        beam_width=4
    )

    decoder = nn.Sequential(
        nn.Linear(hidden_size, output_size)
    )

    model = nn.Sequential(OrderedDict([
        ('encoder', encoder),
        ('search', search),
        ('decoder', decoder)
    ]))

    return model

In [13]:
from gymnasium.envs.toy_text.frozen_lake import generate_random_map

env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1", desc=generate_random_map(size=4), is_slippery=True))

# env = gym.wrappers.Monitor(env, directory="mon", force=True)
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

net = create_model(obs_size, 16, n_actions).to(device)

objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.01)
writer = SummaryWriter(comment="-frozenlake-search")

In [14]:
'number of parameters', sum(p.numel() for p in net.parameters())

('number of parameters', 1509)

In [15]:
full_batch = []

temperature = 3.0
gamma = 0.99

with torch.device(device):
    for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        
        reward_mean = float(np.mean(list(map(lambda s: s.reward, batch))))

        full_batch, obs, acts, reward_bound = filter_batch(full_batch + batch, PERCENTILE)

        if not full_batch:
            continue

        obs_v = torch.FloatTensor(obs).to(device)
        acts_v = torch.LongTensor(acts).to(device)
        full_batch = full_batch[-500:]

        optimizer.zero_grad()

        temperature = max(temperature * gamma, 1.0)

        net.search.set_temperature(temperature)

        action_scores_v = net(obs_v)

        loss_v = objective(action_scores_v, acts_v)

        loss_v.backward()
        optimizer.step()

        print("%d: loss=%.3f, rw_mean=%.3f, rw_bound=%.6f, batch=%d" % (
            iter_no, loss_v.item(), reward_mean, reward_bound, len(full_batch))
        )

        writer.add_scalar("loss", loss_v.item(), iter_no)
        writer.add_scalar("reward_mean", reward_mean, iter_no)
        writer.add_scalar("reward_bound", reward_bound, iter_no)

        if reward_mean > 0.8:
            print("Solved!")
            break

    writer.close()

  return func(*args, **kwargs)


0: loss=1.374, rw_mean=0.070, rw_bound=0.000000, batch=7
1: loss=1.411, rw_mean=0.090, rw_bound=0.000000, batch=16
2: loss=1.375, rw_mean=0.040, rw_bound=0.000000, batch=20
3: loss=1.380, rw_mean=0.050, rw_bound=0.000000, batch=25
4: loss=1.384, rw_mean=0.050, rw_bound=0.000000, batch=30
5: loss=1.369, rw_mean=0.070, rw_bound=0.000000, batch=37
6: loss=1.357, rw_mean=0.070, rw_bound=0.000000, batch=44
7: loss=1.370, rw_mean=0.040, rw_bound=0.000000, batch=48
8: loss=1.364, rw_mean=0.060, rw_bound=0.000000, batch=54
9: loss=1.362, rw_mean=0.040, rw_bound=0.000000, batch=58
10: loss=1.363, rw_mean=0.080, rw_bound=0.000000, batch=66
11: loss=1.360, rw_mean=0.050, rw_bound=0.000000, batch=71
12: loss=1.349, rw_mean=0.070, rw_bound=0.000000, batch=78
13: loss=1.360, rw_mean=0.090, rw_bound=0.000000, batch=87
14: loss=1.362, rw_mean=0.050, rw_bound=0.000000, batch=92
15: loss=1.361, rw_mean=0.070, rw_bound=0.000000, batch=99
16: loss=1.353, rw_mean=0.100, rw_bound=0.000000, batch=109
17: los

KeyboardInterrupt: 