<a href="https://colab.research.google.com/github/debarghaBhattacharjee/drl_for_tbg/blob/main/drl_for_tbg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Reinforcement Learning for Text-based Games

## Install necessary libraries

In [None]:
# System build tools (presumably needed to build jericho, a textworld dependency, from source).
# apt-get has a stable scripting CLI; -y avoids blocking on the confirmation prompt.
!sudo apt-get update && sudo apt-get install -y build-essential libffi-dev python3-dev curl git

Get:1 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]
Ign:2 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease [1,581 B]
Hit:4 http://archive.ubuntu.com/ubuntu bionic InRelease
Get:5 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]
Hit:6 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  Release
Get:7 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]
Get:8 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  Packages [910 kB]
Get:9 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease [15.9 kB]
Get:10 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [74.6 kB]
Get:12 http://security.ubuntu.com/ubuntu bionic-security/main amd64 Packages [2,937 kB]
Get:13 http://archive.ubuntu.com/ubuntu bionic-up

In [None]:
# %pip installs into the running kernel's environment (!pip may not).
# NOTE(review): consider pinning textworld/gym; this notebook calls
# gym.make(..., new_step_api=True), a flag only some gym releases (0.25.x) accept.
%pip install -q textworld

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting textworld
  Downloading textworld-1.5.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.7 MB)
[K     |████████████████████████████████| 6.7 MB 3.9 MB/s 
Collecting jericho>=3.0.3
  Downloading jericho-3.1.0.tar.gz (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 54.5 MB/s 
Collecting mementos>=1.3.1
  Downloading mementos-1.3.1-py2.py3-none-any.whl (12 kB)
Collecting hashids>=1.2.0
  Downloading hashids-1.3.1-py2.py3-none-any.whl (6.6 kB)
Collecting tatsu>=4.3.0
  Downloading TatSu-4.4.0-py2.py3-none-any.whl (85 kB)
[K     |████████████████████████████████| 85 kB 4.5 MB/s 
Building wheels for collected packages: jericho
  Building wheel for jericho (setup.py) ... [?25l[?25hdone
  Created wheel for jericho: filename=jericho-3.1.0-py3-none-any.whl size=333859 sha256=9090969b227f91f29ced3f3b3861d8d4d7b9b6b77fbdad6289a9d06bfcda501c
 

## Create coin-collector game

In [None]:
# !tw-make tw-coin_collector --level 5 -f -v --seed 0 --output "/content/drive/MyDrive/text_world_games/tw_games/training/coin_collector-l5.ulx"

In [None]:
# !tw-extract vocab -f -v "/content/drive/MyDrive/text_world_games/tw_games/training/coin_collector-l5.ulx" --output "/content/drive/MyDrive/text_world_games/tw_games/training/coin_collector-l5-vocab.txt"

## Play text-based games using DRL agent

### Basic imports

In [None]:
import re
import os
import glob
import time, datetime
import pickle
from tqdm import tqdm

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

### Configuration/Setup file

In [None]:
class Config:
    """Experiment hyper-parameters and output locations.

    Every value is fixed in this constructor; edit it to run a different
    experiment.

    Args:
        warm_start: when True, the actor learning rate ``a_lr`` is reduced
            to one tenth of its default value.
    """

    def __init__(self, warm_start=False):
        # ---- Experimental setup -------------------------------------------
        self.game_name = "coin_collector-easy-l5-seed_1"  # Default: "coin_collector-l5"
        self.root_dir = "/content/drive/MyDrive/text_world_games"
        self.game_file = f"{self.root_dir}/tw_games/training/{self.game_name}.ulx"
        # Pinned explicitly rather than derived from game_name, i.e. instead of
        # f"{self.root_dir}/tw_games/training/{self.game_name}-vocab.txt".
        self.vocab_file = "/content/drive/MyDrive/text_world_games/tw_games/training/coin_collector-l5-vocab.txt"
        self.nb_epochs, self.nb_episodes, self.max_episode_steps = 200, 50, 50

        # All run artefacts live under <root_dir>/<game_name>/.
        run_dir = f"{self.root_dir}/{self.game_name}"
        self.result_dir = f"{run_dir}/result_dir"
        self.log_dir = f"{run_dir}/log_dir"
        self.ckpt_dir = f"{run_dir}/model_dir"
        self.plot_dir = f"{run_dir}/plot_dir"
        self.save_log_freq = 5
        self.save_ckpt_freq = 10

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        print(f"Device available: {self.device}")

        # ---- Actor network --------------------------------------------------
        self.max_seq_len = 128
        self.embedding_size = 64
        # NOTE(review): StateEncoder passes this straight to
        # Embedding.from_pretrained, which expects a tensor rather than a file
        # path -- confirm the intended type.
        self.embeddings = None
        self.freeze_embeddings = False
        self.hidden_size = 64
        self.nb_layers = 1
        default_lr = 1e-3
        # Reduce by a factor of 10 (i.e., 1/10th) when warm starting.
        self.a_lr = default_lr / 10 if warm_start else default_lr
        self.max_grad_norm = 5.0
        self.t_update_freq = 5

        # ---- Replay buffer ----------------------------------------------------
        self.alpha_beta_replay = True
        self.max_buffer_len = 1_000_000  # alpha buffer max len + beta buffer max len
        self.alpha_storage_fraction = 0.25
        self.alpha_sampling_fraction = 0.50
        self.alpha_threshold = 0.00
        self.replay_batch_size = 128
        self.gamma = 0.99

        # ---- Epsilon-greedy action selection --------------------------------
        # The decay rate spreads (epsilon_init - epsilon_end) over
        # e_greedy_stop - 1 steps, with e_greedy_stop = 40% of nb_epochs.
        self.epsilon_init = 1.00
        self.epsilon_end = 0.05
        self.e_greedy_stop = 0.40 * self.nb_epochs
        epsilon_span = self.epsilon_init - self.epsilon_end
        self.epsilon_decay_rate = epsilon_span / (self.e_greedy_stop - 1)

### Create the TextWorld game environment

In [None]:
import gym
import textworld
import textworld.gym

  from collections import defaultdict, Mapping


In [None]:
class TextWorldGame:
    """Wrap one TextWorld game file as a gym environment plus its vocabulary.

    Args:
        config: object exposing ``game_name``, ``game_file``, ``vocab_file``
            and ``max_episode_steps`` (e.g. ``Config``).

    Raises:
        FileNotFoundError: if ``config.vocab_file`` does not exist.
    """

    def __init__(self, config):
        # Validate the vocab file before registering the environment so a bad
        # path fails fast with a clear error.
        if not os.path.isfile(config.vocab_file):
            raise FileNotFoundError(f"Vocab file not found: {config.vocab_file}")
        self.vocab_files = [config.vocab_file]

        self.name = config.game_name
        self.game_files = [config.game_file]
        if os.path.isfile(config.game_file):
            self.game_files = glob.glob(config.game_file)
        self.max_episode_steps = config.max_episode_steps
        self.env_id = textworld.gym.register_games(
            gamefiles=self.game_files,
            request_infos=self.request_infos(),
            max_episode_steps=self.max_episode_steps
        )
        self.env = gym.make(self.env_id, new_step_api=True, disable_env_checker=True)
        self.vocab = self.get_vocab(vocab_file=self.vocab_files[0])


    def get_game_info(self):
        """Return a dict with the gym ``env``, game ``name``,
        ``max_episode_steps`` and token ``vocab``."""
        game_info = {
            "env": self.env,
            "name": self.name,
            "max_episode_steps": self.max_episode_steps,
            "vocab": self.vocab
        }
        return game_info


    def get_vocab(self, vocab_file):
        """Read one token per line from ``vocab_file``.

        Tokens are lower-cased and stripped of all whitespace; blank lines are
        skipped (the file is not required to contain any).

        Returns:
            list of tokens in file order (duplicates are kept).
        """
        with open(vocab_file, "r") as f:
            tokens = (re.sub(r"\s", "", line.lower()) for line in f)
            return [token for token in tokens if token]


    def get_game_controls(self):
        """Reset the environment and return the ``verbs`` and ``entities``
        entries of the info dict from ``env.reset()`` (None when absent)."""
        _, info = self.env.reset()
        game_controls = {
            "verbs": info.get("verbs", None),
            "entities": info.get("entities", None)
        }
        return game_controls


    def request_infos(self):
        """Build the ``textworld.EnvInfos`` requesting description, last
        command, won flag, verbs and entities."""
        request_infos = textworld.EnvInfos()
        request_infos.description = True
        request_infos.last_command = True
        request_infos.won = True
        request_infos.verbs = True
        request_infos.entities = True
        return request_infos

### Random agent

In [None]:
class RandomAgent:
    """Baseline agent that issues uniformly random "<verb> <entity>" commands.

    It never learns: ``play`` only evaluates it for ``nb_epochs`` epochs of
    ``nb_episodes`` episodes, logging average reward, steps and wins.

    Args:
        config: experiment configuration (see ``Config``).
        vocab: vocabulary tokens; they get ids 4, 5, ... because ids 0-3 are
            reserved for <pad>, <unk>, <sos>, <eos>.
        verbs: verb candidates; list position is the verb id.
        entities: entity candidates; list position is the entity id.
    """

    def __init__(self, config, vocab, verbs, entities):
        self.agent = "random_agent"
        self.name = config.game_name
        self.nb_epochs = config.nb_epochs
        self.nb_episodes = config.nb_episodes
        self.max_episode_steps = config.max_episode_steps

        # Results directory
        self.result_dir = f"{config.result_dir}/{self.agent}"
        os.makedirs(self.result_dir, exist_ok=True)

        # Log directory
        self.log_dir = f"{config.log_dir}/{self.agent}"
        os.makedirs(self.log_dir, exist_ok=True)
        self.summary_writer = SummaryWriter(log_dir=self.log_dir)
        self.save_log_freq = config.save_log_freq

        # Checkpoint directory. ``Config`` names these ``ckpt_dir`` and
        # ``save_ckpt_freq`` (it defines no ``model_dir``/``save_model_freq``).
        self.ckpt_dir = f"{config.ckpt_dir}/{self.agent}"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.save_ckpt_freq = config.save_ckpt_freq

        # Vocab: for dealing with the state space
        self.id2tok, self.tok2id = self._build_token_maps(vocab)

        # Verbs and entities: for dealing with action space
        self.id2verb = {id: verb for (id, verb) in enumerate(verbs)}
        self.verb2id = {verb: id for (id, verb) in enumerate(verbs)}
        self.id2entity = {id: entity for (id, entity) in enumerate(entities)}
        self.entity2id = {entity: id for (id, entity) in enumerate(entities)}


    @staticmethod
    def _build_token_maps(vocab):
        """Return ``(id2tok, tok2id)``; ids 0-3 are the special tokens."""
        id2tok = {0: "<pad>", 1: "<unk>", 2: "<sos>", 3: "<eos>"}
        for (id, tok) in enumerate(vocab, start=4):
            id2tok[id] = tok
        # Invert through .items(): enumerating the dict itself yields its
        # integer keys, which would map ids to ids instead of tokens to ids.
        tok2id = {tok: id for (id, tok) in id2tok.items()}
        return id2tok, tok2id


    def save_data(self, data, file_name, directory):
        """Pickle ``data`` to ``<directory>/<file_name>``, creating
        ``directory`` if it does not exist."""
        os.makedirs(directory, exist_ok=True)
        # ``with`` closes the file even if pickling raises.
        with open(f"{directory}/{file_name}", "wb") as pickle_out:
            pickle.dump(data, pickle_out)
        return


    def get_command(self):
        """
        This method returns the command that should be issued
        by the random agent: a uniformly drawn verb followed by a
        uniformly drawn entity.
        """
        # Select verb
        v_id = np.random.randint(low=0, high=len(self.verb2id), size=1)
        verb = self.id2verb.get(v_id.item(), None)
        # Select entity
        e_id = np.random.randint(low=0, high=len(self.entity2id), size=1)
        entity = self.id2entity.get(e_id.item(), None)
        command = f"{verb} {entity}"
        return command


    def test(self, test_env):
        """Play ``nb_episodes`` episodes in ``test_env`` with random commands.

        ``test_env.step`` is expected to return
        ``(obs, score, done, truncated, info)`` with ``info["won"]``; the
        per-step reward is the change in score.

        Returns:
            dict with scalar ``avg_reward``, ``avg_steps`` and ``avg_wins``
            (NaN-aware means over the episodes).
        """
        test_stats = {
            "avg_reward": [],
            "avg_steps" : [],
            "avg_wins"  : []
        }

        # Per-episode results for this evaluation run.
        e_rewards = []
        e_steps   = []
        e_wins    = []

        for episode in range(self.nb_episodes):
            # Reset at the beginning of every episode.
            obs, info = test_env.reset()
            prev_score = 0
            cum_sum_reward = 0
            nb_steps = 0
            done = False
            won = False

            while not done:
                command = self.get_command()
                obs, score, done, _, info = test_env.step(command)
                reward = score - prev_score
                cum_sum_reward += reward
                nb_steps += 1
                won = info["won"]
                if nb_steps >= self.max_episode_steps:
                    break
                prev_score = score

            e_rewards.append(cum_sum_reward)
            e_steps.append(nb_steps)
            e_wins.append(won)

        test_stats["avg_reward"] = np.nanmean(e_rewards)
        test_stats["avg_steps"] = np.nanmean(e_steps)
        test_stats["avg_wins"] = np.nanmean(e_wins)
        return test_stats


    def play(self, test_env):
        """
        Play the game using the agent.

        Runs ``test`` once per epoch for ``nb_epochs`` epochs. It prints a
        summary for the first epoch and every ``save_log_freq`` epochs, logs
        the averages to TensorBoard with the epoch as the x-axis, and pickles
        the running statistics to ``<result_dir>/ep_stats`` after each epoch.

        Returns:
            dict of per-epoch lists ``avg_reward_test``, ``avg_steps_test``
            and ``avg_wins_test``.
        """
        ep_stats = {
            "avg_reward_test": [],
            "avg_steps_test" : [],
            "avg_wins_test"  : []
        }

        epoch_offset = 0
        first_epoch = epoch_offset + 1
        for epoch in tqdm(range(first_epoch, self.nb_epochs+1), unit=" epoch"):
            ep_test_stats = self.test(test_env)
            # Epochs start at 1, so log the first one explicitly.
            if (epoch == first_epoch) or ((epoch % self.save_log_freq) == 0):
                msg = f"Epoch: {epoch}\n" + \
                    f"Test avg. reward: {np.around(ep_test_stats['avg_reward'], 2)} " + \
                    f"Test avg. steps: {np.around(ep_test_stats['avg_steps'], 2)} " + \
                    f"Test avg. wins: {np.around(ep_test_stats['avg_wins'], 2)}"
                print(msg)
                print("=" * 75)
            # global_step places each point at its epoch on the TensorBoard x-axis.
            self.summary_writer.add_scalar("avg. reward", ep_test_stats["avg_reward"], global_step=epoch)
            self.summary_writer.add_scalar("avg. steps", ep_test_stats["avg_steps"], global_step=epoch)
            self.summary_writer.add_scalar("avg. wins", ep_test_stats["avg_wins"], global_step=epoch)

            ep_stats["avg_reward_test"].append(ep_test_stats["avg_reward"])
            ep_stats["avg_steps_test"].append(ep_test_stats["avg_steps"])
            ep_stats["avg_wins_test"].append(ep_test_stats["avg_wins"])
            self.save_data(
                data=ep_stats, file_name="ep_stats",
                directory=self.result_dir
            )
        self.summary_writer.flush()
        return ep_stats

#### Test random agent

Seed the random number generators (Python, NumPy and PyTorch) for reproducibility.

In [None]:
# Seed every RNG the notebook draws from: Python's `random`, NumPy, PyTorch.
seed = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

<torch._C.Generator at 0x7f092f570f10>

Create the test environment and the random agent.

In [None]:
play_random_agent = False  # Set to True if you want to play using random agent.
if play_random_agent:
    config = Config()
    text_world_game = TextWorldGame(config=config)
    # Query once: every get_game_controls() call resets the environment.
    game_info = text_world_game.get_game_info()
    game_controls = text_world_game.get_game_controls()
    random_agent = RandomAgent(
        config=config,
        vocab=game_info.get("vocab", None),
        verbs=["go", "take"],  # Pruned the verb space to simplify the problem
        # Full verb space: game_controls.get("verbs", None)
        entities=game_controls.get("entities", None),
    )
    test_env = game_info.get("env", None)

Run the random agent in the test environment.

In [None]:
if play_random_agent:
    random_agent_stats = random_agent.play(test_env=test_env)
    random_agent_stats_df = pd.DataFrame(random_agent_stats)
    random_agent_stats_df["epoch"] = np.arange(start=1, stop=random_agent_stats_df.shape[0]+1)
    # Jupyter only echoes a bare expression at the top level of a cell, not
    # inside an `if` body, so render the table explicitly.
    display(random_agent_stats_df)

In [None]:
if play_random_agent:
    plot_dir = f"{config.plot_dir}/random_agent"
    os.makedirs(plot_dir, exist_ok=True)

    # One figure per metric: seaborn draws onto the current axes, so sharing
    # a figure would accumulate earlier curves into later JPEGs.
    for metric in ("avg_reward_test", "avg_steps_test", "avg_wins_test"):
        fig, ax = plt.subplots()
        sns.lineplot(x="epoch", y=metric, data=random_agent_stats_df, ax=ax)
        ax.set_title(f"Random agent: {metric}")
        fig.savefig(f"{plot_dir}/{metric}.jpeg")
        plt.show()
        plt.close(fig)

### RL Agent
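The RL agent uses a neural-network function approximator to estimate action values (Q-values), from which its epsilon-greedy policy is derived. Specifically, we use the LSTM-DQN model (Narasimhan et al., 2015): an LSTM encodes the textual observation, and two Q-value heads score the candidate verbs and entities. Unlike the random agent, the RL agent learns from experience which commands to issue.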

#### Replay memory

In [None]:
from collections import namedtuple, deque

In [None]:
# One replay transition: state, chosen verb id, chosen entity id, reward,
# next state and the episode-termination flag.
Experience = namedtuple(
    "Experience",
    ["cur_state", "verb", "entity", "reward", "next_state", "done"],
)

In [None]:
class AlphaBetaReplayBuffer(object):
    """Two-stream experience replay.

    Experiences go to the *alpha* stream (capacity
    ``int(max_len * alpha_storage_fraction)``) or the *beta* stream (the rest
    of ``max_len``), as chosen by the caller of ``push``. Both streams are
    bounded deques, so the oldest entries are dropped when full.
    """

    def __init__(self, max_len=1_000_000, alpha_storage_fraction=0.25,
                 alpha_sampling_fraction=0.50):
        super(AlphaBetaReplayBuffer, self).__init__()
        self.max_len = max_len
        self.alpha_storage_fraction = alpha_storage_fraction
        self.alpha_sampling_fraction = alpha_sampling_fraction
        # Alpha buffer modules
        # -------------------------------------------------------------
        self.alpha_max_len = int(self.max_len * self.alpha_storage_fraction)
        self.alpha_buffer = deque(maxlen=self.alpha_max_len)
        # Beta buffer modules
        # --------------------------------------------------------------
        self.beta_max_len = self.max_len - self.alpha_max_len
        self.beta_buffer = deque(maxlen=self.beta_max_len)

    def push(self, is_alpha=False, *args):
        """Adds an experience to replay buffer.

        ``args`` should hold a single iterable with the six ``Experience``
        fields, e.g. ``push(is_alpha, (s, v, e, r, s_next, done))``.
        """
        target = self.alpha_buffer if is_alpha else self.beta_buffer
        target.append(Experience._make(*args))

    def sample(self, batch_size):
        """Return up to ``batch_size`` shuffled experiences.

        Aims for ``int(alpha_sampling_fraction * batch_size)`` experiences
        from the alpha stream and the rest from the beta stream. If either
        stream cannot supply its share, the shortfall is back-filled from the
        other stream. A full batch is therefore returned whenever the two
        streams together hold at least ``batch_size`` experiences.
        """
        alpha_batch_size = int(self.alpha_sampling_fraction * batch_size)
        beta_batch_size = batch_size - alpha_batch_size
        n_alpha = min(len(self.alpha_buffer), alpha_batch_size)
        n_beta = min(len(self.beta_buffer), beta_batch_size)

        # Back-fill: without this, a sparse alpha stream (e.g. early in
        # training) yields short batches that callers requiring a full batch
        # discard.
        shortfall = batch_size - n_alpha - n_beta
        if shortfall > 0:
            extra_beta = min(shortfall, len(self.beta_buffer) - n_beta)
            n_beta += extra_beta
            shortfall -= extra_beta
            n_alpha += min(shortfall, len(self.alpha_buffer) - n_alpha)

        samples = random.sample(self.alpha_buffer, n_alpha) + \
            random.sample(self.beta_buffer, n_beta)
        random.shuffle(samples)
        return samples

#### Actor loss function
This is the DQN loss: half the mean squared temporal-difference (TD) error, computed separately for the verb and entity Q-value heads and then summed.

In [None]:
class ActorLoss(torch.nn.Module):
    """Half mean-squared TD error, summed over the verb and entity heads:

        loss = 0.5 * mean(vq_error ** 2) + 0.5 * mean(eq_error ** 2)
    """

    def __init__(self):
        super().__init__()

    def forward(self, vq_error, eq_error):
        """
        INPUT
        ------------------------------
        vq_error: TD errors of the verb head, (replay_batch_size, 1)
        eq_error: TD errors of the entity head, (replay_batch_size, 1)

        OUTPUT
        ------------------------------
        scalar loss tensor
        """
        def half_mse(err):
            return (1/2) * torch.square(err).mean()

        return half_mse(vq_error) + half_mse(eq_error)

#### Actor

In [None]:
class StateEncoder(torch.nn.Module):
    """Encode a batch of padded token-id sequences into fixed-size vectors.

    Pipeline: embedding -> bidirectional LSTM over packed sequences (padding
    steps are skipped) -> mean over the valid time steps only -> linear
    projection to ``hidden_size``. Token id 0 is treated as padding, since
    the mask is ``sign(in_seq)``.
    """

    def __init__(self, vocab_size, embedding_size, embeddings,
                 freeze_embeddings, hidden_size, nb_layers):
        super(StateEncoder, self).__init__()
        self.vocab_size = vocab_size
        # Embedding module
        if embeddings is not None:
            # ``from_pretrained`` is a classmethod: the table (and its shape)
            # comes from ``embeddings``, presumably a tensor -- confirm.
            self.embedding = torch.nn.Embedding(
                num_embeddings=vocab_size, embedding_dim=embedding_size
            ).from_pretrained(
                embeddings=embeddings, freeze=freeze_embeddings
            )
        else:
            self.embedding = torch.nn.Embedding(
                num_embeddings=vocab_size, embedding_dim=embedding_size
            )

        # Input sequence encoder module
        self.lstm = torch.nn.LSTM(
            input_size=embedding_size, hidden_size=hidden_size,
            num_layers=nb_layers, batch_first=True,
            dropout=0.0, bidirectional=True
        )
        # Projects concatenated forward/backward features back to hidden_size.
        self.encoder = torch.nn.Linear(
            in_features=2*hidden_size, out_features=hidden_size
        )

    def forward(self, in_seq):
        """
        INPUT
        in_seq: (batch, time) integer token ids, 0 = padding; each row is
            assumed to contain at least one non-pad token (packing requires it).
        OUTPUT
        (batch, hidden_size) state encodings.
        """
        in_seq_mts = in_seq.shape[1]
        in_seq_embed = self.embedding(in_seq)
        in_seq_mask = torch.sign(in_seq)
        in_seq_len = torch.sum(in_seq_mask, dim=1)
        lstm_in = pack_padded_sequence(
            input=in_seq_embed,
            lengths=in_seq_len.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        enc_seq, _ = self.lstm(lstm_in)
        enc_seq, _ = pad_packed_sequence(
            sequence=enc_seq,
            batch_first=True,
            total_length=in_seq_mts
        ) # [b, t, 2*hidden]
        # Masked mean: pad_packed_sequence fills padded steps with zeros, so
        # sum / true length averages the real steps only. A plain mean over
        # ``t`` would make a state's encoding depend on how much padding its
        # batch happens to contain.
        lengths = in_seq_len.unsqueeze(dim=1).clamp(min=1).to(enc_seq.dtype)
        enc_seq = enc_seq.sum(dim=1) / lengths
        enc_seq = self.encoder(enc_seq)
        return enc_seq

In [None]:
class ActionValueEstimator(torch.nn.Module):
    """Two Q-value heads over an encoded state: one scores every verb, the
    other every entity. Each head is Linear -> ReLU -> Linear."""

    def __init__(self, verb_space_size, entity_space_size, hidden_size):
        super(ActionValueEstimator, self).__init__()
        # Action-value estimator for verbs
        self.vq_estimator_init = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size//2
        )
        self.vq_estimator_final = torch.nn.Linear(
            in_features=hidden_size//2, out_features=verb_space_size
        )
        # Action-value estimator for entities
        self.eq_estimator_init = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size//2
        )
        self.eq_estimator_final = torch.nn.Linear(
            in_features=hidden_size//2, out_features=entity_space_size
        )

    def forward(self, in_state):
        """
        INPUT
        in_state: (batch, hidden_size) encoded states.
        OUTPUT
        vq_vals: (batch, verb_space_size); eq_vals: (batch, entity_space_size).
        """
        # The ReLU between the two Linear layers is what makes each head a
        # real MLP: two stacked Linear layers alone collapse to one affine map.
        # (ReLU has no parameters, so state_dict keys are unchanged.)
        # Obtain the action-values for output verbs.
        vq_vals = self.vq_estimator_init(in_state)
        vq_vals = self.vq_estimator_final(torch.relu(vq_vals))

        # Obtain the action-values for output entities.
        eq_vals = self.eq_estimator_init(in_state)
        eq_vals = self.eq_estimator_final(torch.relu(eq_vals))

        return vq_vals, eq_vals

In [None]:
class Actor(torch.nn.Module):
    """LSTM-DQN actor: a ``StateEncoder`` feeding an ``ActionValueEstimator``.

    ``forward(in_seq)`` returns ``(verb Q-values, entity Q-values)``.
    """

    def __init__(self, vocab_size, embedding_size, embeddings,
                 freeze_embeddings, hidden_size, nb_layers,
                 verb_space_size, entity_space_size):
        super().__init__()
        encoder_kwargs = dict(
            vocab_size=vocab_size,
            embedding_size=embedding_size,
            embeddings=embeddings,
            freeze_embeddings=freeze_embeddings,
            hidden_size=hidden_size,
            nb_layers=nb_layers,
        )
        estimator_kwargs = dict(
            verb_space_size=verb_space_size,
            entity_space_size=entity_space_size,
            hidden_size=hidden_size,
        )
        # These attribute names prefix the state_dict keys stored in checkpoints.
        self.state_encoder = StateEncoder(**encoder_kwargs)
        self.action_value_estimator = ActionValueEstimator(**estimator_kwargs)

    def forward(self, in_seq):
        encoded = self.state_encoder(in_seq=in_seq)
        return self.action_value_estimator(in_state=encoded)

#### RL agent

In [None]:
class RLAgent:
    def __init__(self, config, vocab, verbs, entities, load_checkpoint=False):
        """
        Build the LSTM-DQN agent: output directories, vocab/action lookup
        tables, main and target actors, optimizer, replay buffer and the
        epsilon-greedy schedule.

        INPUT
        config: experiment configuration (see ``Config``).
        vocab: vocabulary tokens; ids 0-3 are reserved for <pad>, <unk>,
            <sos>, <eos>, so tokens start at id 4.
        verbs / entities: action-space vocabularies; list position is the id.
        load_checkpoint: when True, resume from ``<ckpt_dir>/ckpt`` via
            ``load_ckpt``.
        """
        self.agent = "rl_agent-complex"
        self.name = config.game_name
        self.device = config.device
        self.nb_epochs = config.nb_epochs
        self.nb_episodes = config.nb_episodes
        self.max_episode_steps = config.max_episode_steps

        # Results directory
        self.result_dir = f"{config.result_dir}/{self.agent}"
        os.makedirs(self.result_dir, exist_ok=True)

        # Log directory
        self.log_dir = f"{config.log_dir}/{self.agent}"
        os.makedirs(self.log_dir, exist_ok=True)
        self.summary_writer = SummaryWriter(log_dir=self.log_dir)
        self.save_log_freq = config.save_log_freq

        # Checkpoint directory
        self.ckpt_dir = f"{config.ckpt_dir}/{self.agent}"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.save_ckpt_freq = config.save_ckpt_freq

        # Vocab: for dealing with the state space
        self.id2token = {0: "<pad>", 1: "<unk>", 2: "<sos>", 3: "<eos>"}
        for (idx, token) in enumerate(vocab, start=4):
            self.id2token[idx] = token
        self.token2id = {token: idx for (idx, token) in self.id2token.items()}

        # Verbs and entities: for dealing with action space
        self.id2verb = {idx: verb for (idx, verb) in enumerate(verbs)}
        self.verb2id = {verb: idx for (idx, verb) in enumerate(verbs)}
        self.id2entity = {idx: entity for (idx, entity) in enumerate(entities)}
        self.entity2id = {entity: idx for (idx, entity) in enumerate(entities)}

        # Actor networks: the main actor is optimized; the target actor starts
        # as an exact copy of it.
        self.max_seq_len = config.max_seq_len
        self.actor_m = self._build_actor(config)
        self.max_grad_norm = config.max_grad_norm
        self.gamma = config.gamma
        self.a_optimizer = torch.optim.Adam(
            params=self.actor_m.parameters(),
            lr=config.a_lr
        )
        self.a_criterion = ActorLoss()
        self.actor_t = self._build_actor(config)
        self.actor_t.load_state_dict(self.actor_m.state_dict())
        self.t_update_freq = config.t_update_freq

        # Replay buffer parameters
        self.alpha_threshold = config.alpha_threshold
        self.replay_buffer = AlphaBetaReplayBuffer(
            max_len=config.max_buffer_len,
            alpha_storage_fraction=config.alpha_storage_fraction,
            alpha_sampling_fraction=config.alpha_sampling_fraction
        )
        self.replay_batch_size = config.replay_batch_size

        # Annealing epsilon-greedy policy configuration
        self.epsilon_init = config.epsilon_init
        self.epsilon_end = config.epsilon_end
        self.epsilon_decay_rate = config.epsilon_decay_rate
        self.epsilon = self.epsilon_init

        # Load checkpoint (if required).
        self.epoch_offset = 0
        if load_checkpoint:
            self.load_ckpt()


    def _build_actor(self, config):
        """Construct an ``Actor`` sized to this agent's vocab and action space,
        placed on ``config.device``."""
        return Actor(
            vocab_size=len(self.token2id),
            embedding_size=config.embedding_size,
            embeddings=config.embeddings,
            freeze_embeddings=config.freeze_embeddings,
            hidden_size=config.hidden_size,
            nb_layers=config.nb_layers,
            verb_space_size=len(self.verb2id),
            entity_space_size=len(self.entity2id)
        ).to(config.device)


    def load_ckpt(self):
        """
        This methop loads a previously saved checkpoint (if required 
        and avialble) and resumes training from that point. 
        """
        ckpt = torch.load(f"{self.ckpt_dir}/ckpt")
        self.epoch_offset = ckpt["epoch"]
        self.actor_m.load_state_dict(ckpt["actor_m_state_dict"])
        self.actor_t.load_state_dict(ckpt["actor_m_state_dict"])
        self.a_optimizer.load_state_dict(ckpt["a_optimizer_state_dict"])
        self.epsilon = ckpt["epsilon"]
        print(f"Loaded checkpoint from: {self.ckpt_dir}/ckpt")
        return


    def save_ckpt(self, epoch):
        ckpt = {
            "epoch": epoch,
            "actor_m_state_dict": self.actor_m.state_dict(),
            "a_optimizer_state_dict": self.a_optimizer.state_dict(),
            "epsilon": self.epsilon
        }
        torch.save(ckpt, f"{self.ckpt_dir}/ckpt")
        return


    def copy_actor_weights(self, src_path=None):
        """
        This method copy weights from a another rl agent's 
        actor network to current agent's actor network.
        """
        if src_path is None:
            print("No source provided. Couldn't copy weights.")
            return
        src_actor = torch.load(src_path)
        self.actor_m.load_state_dict(src_actor["actor_m_state_dict"])
        self.actor_t.load_state_dict(src_actor["actor_m_state_dict"])
        print(f"Sucessfully copied weights from: {src_path}")
        return


    def save_data(self, data, file_name, directory):
        # Create directory if it does not exist.
        if not os.path.isdir(directory):
            os.makedirs(directory)

        # Save data as pickled object.
        pickle_out = open(
            f"{directory}/{file_name}", "wb"
        )
        pickle.dump(data, pickle_out)
        pickle_out.close()
        return


    def clean_text(self, text):
        """Lower-case ``text``, replace every character other than letters,
        digits, ``-``, ``'``, ``<``, ``>`` and space with a space, and strip
        the ends."""
        lowered = text.lower()
        spaced = re.sub(r"[^a-zA-Z0-9\-'<> ]", r" ", lowered)
        return spaced.strip()
        
    
    def get_token_id(self, token):
        return self.token2id.get(token, self.token2id["<unk>"])
    

    def tokenize_text(self, text):
        token_ids = list(map(self.get_token_id, text.split()))
        return token_ids


    def make_tensor(self, text):
        token_ids_list = self.tokenize_text(text)
        tensor = torch.tensor([token_ids_list]).to(self.device)
        return tensor


    def get_padded_tensors(self, tensors_list):
        # print(f"tensors_list: \n{tensors_list}")
        # print("-" * 75)
        padded_tensors = pad_sequence(
            sequences=tensors_list,
            batch_first=True,
            padding_value=self.token2id["<pad>"]
        ).to(self.device)
        return padded_tensors


    def get_train_command(self, in_seq):
        """
        This method returns the command that should be issued
        by the RL agent in the training phase.
        """
        self.actor_m.eval()
        batch_size = in_seq.shape[0]
        command = "" 
        action_ids = {"v_id": 0, "e_id": 0}
        with torch.no_grad():
            vq_vals, eq_vals = self.actor_m(in_seq=in_seq)
            prob = random.random()
            if prob > self.epsilon:
                _, v_id = vq_vals.topk(k=1, dim=1)
                _, e_id = eq_vals.topk(k=1, dim=1)
            else:
                v_id = torch.randint(low=0, high=vq_vals.shape[1], size=(batch_size, 1))
                e_id = torch.randint(low=0, high=eq_vals.shape[1], size=(batch_size, 1))
            # Select verb
            verb = self.id2verb.get(v_id.item(), None)
            # Select entity
            entity = self.id2entity.get(e_id.item(), None)
            command = f"{verb} {entity}"
            action_ids["v_id"] = v_id
            action_ids["e_id"] = e_id
        return command, action_ids

    
    def save_experience(self, experience, is_alpha):
        """
        This methods saves the experience in the in either the
        stream or beta stream of the replay buffer depending on
        is_alpha flag.
        """
        self.replay_buffer.push(is_alpha, experience)
        return


    def sample_experiences(self):
        # Sample replay_batch_size no. of stored experiences from Replay Memory.
        replay_batch = self.replay_buffer.sample(batch_size=self.replay_batch_size)
        if (len(replay_batch) < self.replay_batch_size):
            return None
        
        # Extract necessary information.
        cur_state  = [data.cur_state for data in replay_batch] # Current state
        verb       = [data.verb for data in replay_batch]  # Verb
        entity     = [data.entity for data in replay_batch]  # Entity
        reward     = [data.reward for data in replay_batch]  # Reward
        next_state = [data.next_state for data in replay_batch]  # Next state
        done       = [data.done for data in replay_batch]  # Done

        # Get tensor representaions.
        cur_state_tensor = self.get_padded_tensors(cur_state)
        verb_tensor = torch.tensor(verb).view(self.replay_batch_size, 1).to(self.device)
        entity_tensor = torch.tensor(entity).view(self.replay_batch_size, 1).to(self.device)
        reward_tensor = torch.tensor(reward).view(self.replay_batch_size, 1).to(self.device)
        next_state_tensor = self.get_padded_tensors(next_state)
        done_tensor = torch.tensor(done).view(self.replay_batch_size, 1).to(self.device)
        return cur_state_tensor, verb_tensor, entity_tensor, \
            reward_tensor, next_state_tensor, done_tensor

    
    def compute_actor_loss(self, replay_batch):
        """
        Compute the actor loss for one replay batch (Double-DQN style target).

        The main actor (``self.actor_m``) picks the greedy next verb/entity and
        the target actor (``self.actor_t``) evaluates them. Targets are
        ``reward + gamma * Q_target(next, argmax Q_main(next))``, with the
        bootstrap term zeroed where ``done`` is set.

        Args:
            replay_batch: tuple ``(cur_state, verb, entity, reward,
                next_state, done)`` as produced by the batch-tensor helper
                above; verb/entity/reward/done are reshaped there to
                ``(replay_batch_size, 1)``.

        Returns:
            Whatever ``self.a_criterion(vq_error=..., eq_error=...)`` returns
            (expected to be a scalar loss tensor).
        """
        (cur_state_tensor, verb_tensor, entity_tensor,
         reward_tensor, next_state_tensor, done_tensor) = replay_batch

        # Q-values of the actions actually taken; gradients flow only here.
        pred_vq_vals, pred_eq_vals = self.actor_m(in_seq=cur_state_tensor)
        pred_vq_val = pred_vq_vals.gather(index=verb_tensor, dim=1)
        pred_eq_val = pred_eq_vals.gather(index=entity_tensor, dim=1)

        # The bootstrapped target is a constant for this optimisation step.
        # Without no_grad, backward() would also accumulate gradients on the
        # target actor's parameters and build a useless graph for the main
        # actor's next-state pass.
        with torch.no_grad():
            next_vq_vals_m, next_eq_vals_m = self.actor_m(in_seq=next_state_tensor)
            next_v_id = next_vq_vals_m.argmax(dim=1).unsqueeze(dim=1)
            next_e_id = next_eq_vals_m.argmax(dim=1).unsqueeze(dim=1)

            next_vq_vals_t, next_eq_vals_t = self.actor_t(in_seq=next_state_tensor)
            not_done = done_tensor.logical_not()
            trg_vq_val = reward_tensor.add(
                self.gamma * next_vq_vals_t.gather(index=next_v_id, dim=1).mul(not_done)
            )
            trg_eq_val = reward_tensor.add(
                self.gamma * next_eq_vals_t.gather(index=next_e_id, dim=1).mul(not_done)
            )

        vq_error = trg_vq_val.sub(pred_vq_val)
        eq_error = trg_eq_val.sub(pred_eq_val)
        return self.a_criterion(vq_error=vq_error, eq_error=eq_error)


    def optimize_actor(self, actor_loss):
        """
        Take one optimiser step on the main actor from ``actor_loss``.

        Clears the previous gradients, backpropagates, clips the global
        gradient norm of ``self.actor_m`` to ``self.max_grad_norm`` and steps
        ``self.a_optimizer``.

        Args:
            actor_loss: scalar loss tensor (e.g. from ``compute_actor_loss``).

        Returns:
            None.
        """
        self.a_optimizer.zero_grad()
        actor_loss.backward()
        # ``clip_grad_norm`` (no underscore) is deprecated; the trailing
        # underscore is PyTorch's in-place variant that modifies .grad.
        torch.nn.utils.clip_grad_norm_(
            self.actor_m.parameters(), self.max_grad_norm
        )
        self.a_optimizer.step()

    
    def update_main_actor(self):
        """
        Run one learning step for the main actor, if possible.

        The main actor is switched to training mode first. When the replay
        buffer yields no batch (``sample_experiences`` returns None) the
        update is skipped; otherwise the loss is computed and one optimiser
        step is taken.

        Returns:
            None.
        """
        self.actor_m.train()
        batch = self.sample_experiences()
        if batch is not None:
            loss = self.compute_actor_loss(batch)
            self.optimize_actor(actor_loss=loss)


    def train(self, train_env):
        """
        Play ``self.nb_episodes`` training episodes in ``train_env``.

        Actions come from ``self.get_train_command``. After every environment
        step the transition is pushed to the replay buffer (flagged
        ``is_alpha`` when its reward exceeds ``self.alpha_threshold``) and
        ``self.update_main_actor`` is called once. An episode ends when the
        environment reports ``done`` or after ``self.max_episode_steps``
        steps.

        Args:
            train_env: environment whose ``reset()`` returns ``(obs, _)`` and
                whose ``step(command)`` returns ``(obs, score, done, _, info)``
                with ``info`` holding "description", "last_command" and "won".

        Returns:
            dict with the per-episode means (``np.nanmean``) of the
            cumulative reward ("avg_reward"), the number of steps
            ("avg_steps") and the win flag ("avg_wins").
        """
        rewards_per_episode, steps_per_episode, wins_per_episode = [], [], []

        for _ in range(self.nb_episodes):
            initial_obs, _ = train_env.reset()
            state_tensor = self.make_tensor(
                text=self.clean_text(text="<sos> " + initial_obs + " <eos>")
            )
            last_score = 0
            episode_reward = 0
            step_count = 0
            done = False
            won = False

            while not done:
                command, action_ids = self.get_train_command(in_seq=state_tensor)
                verb_id = action_ids.get("v_id", None)
                entity_id = action_ids.get("e_id", None)
                obs, score, done, _, info = train_env.step(command)

                # Normalise the feedback; the cleaned description/command are
                # written back into ``info``.
                obs = self.clean_text(obs)
                info["description"] = self.clean_text(info["description"])
                info["last_command"] = self.clean_text(info["last_command"])

                # Next state = last command [+ observation] + description; the
                # observation is dropped when it equals the description.
                segments = [info["last_command"]]
                if obs != info["description"]:
                    segments.append(obs)
                segments.append(info["description"])
                next_text = " ".join("<sos> " + seg + " <eos>" for seg in segments)
                next_state_tensor = self.make_tensor(self.clean_text(next_text))

                reward = score - last_score
                episode_reward += reward
                step_count += 1
                won = info["won"]

                transition = (
                    state_tensor.detach().cpu().view(-1),
                    verb_id,
                    entity_id,
                    reward,
                    next_state_tensor.detach().cpu().view(-1),
                    done,
                )
                self.save_experience(
                    transition, is_alpha=bool(reward > self.alpha_threshold)
                )

                # One update attempt per environment step (skipped inside
                # update_main_actor when no batch can be sampled).
                self.update_main_actor()

                if step_count >= self.max_episode_steps:
                    break
                state_tensor = next_state_tensor
                last_score = score

            rewards_per_episode.append(episode_reward)
            steps_per_episode.append(step_count)
            wins_per_episode.append(won)

        return {
            "avg_reward": np.nanmean(rewards_per_episode),
            "avg_steps": np.nanmean(steps_per_episode),
            "avg_wins": np.nanmean(wins_per_episode),
        }


    def get_test_command(self, in_seq):
        """
        Return the greedy command chosen by the main actor at test time.

        Puts ``self.actor_m`` in eval mode and, without tracking gradients,
        selects the highest-valued verb and entity (``topk(k=1, dim=1)``).

        Args:
            in_seq: encoded state passed as ``self.actor_m(in_seq=...)``.
                Assumed to hold a single state, since ``.item()`` is called
                on the selected indices. TODO: confirm for batched calls.

        Returns:
            ``(command, action_ids)``. ``command`` is ``"<verb> <entity>"``,
            and an id missing from ``id2verb``/``id2entity`` renders as
            ``"None"``. ``action_ids`` maps "v_id"/"e_id" to the selected
            index tensors returned by ``topk``.
        """
        self.actor_m.eval()
        with torch.no_grad():
            vq_vals, eq_vals = self.actor_m(in_seq=in_seq)
            _, v_id = vq_vals.topk(k=1, dim=1)
            _, e_id = eq_vals.topk(k=1, dim=1)
        verb = self.id2verb.get(v_id.item(), None)
        entity = self.id2entity.get(e_id.item(), None)
        command = f"{verb} {entity}"
        return command, {"v_id": v_id, "e_id": e_id}

    
    def test(self, test_env):
        test_stats = {
            "avg_reward": [],
            "avg_steps" : [],
            "avg_wins"  : []
        }

        # Initialize empty list for storing episode info at 
        # the start of every epoch.
        e_rewards = []
        e_steps   = []
        e_wins    = []

        for episode in range(self.nb_episodes):
            # Reset at the beginning of every episode.
            obs, _ = test_env.reset()
            prev_score = 0
            cum_sum_reward = 0
            nb_steps = 0
            done = False
            won = False
            cur_state_text = "<sos> " + obs + " <eos>"
            cur_state_text = self.clean_text(text=cur_state_text)
            cur_state_tensor = self.make_tensor(text=cur_state_text)

            while not done:
                # Play one step using RL agent.
                command, action_ids = self.get_test_command(in_seq=cur_state_tensor)
                v_id = action_ids.get("v_id", None)
                e_id = action_ids.get("e_id", None)
                obs, score, done, _, info = test_env.step(command)
                next_state_text = ""
                obs = self.clean_text(obs)
                info["description"] = self.clean_text(info["description"])
                info["last_command"] = self.clean_text(info["last_command"])
                if obs == info["description"]:
                    next_state_text = "<sos> " + info["last_command"] + " <eos> " + \
                        "<sos> " + info["description"] + " <eos>"
                else:
                    next_state_text = "<sos> " + info["last_command"] + " <eos> " + \
                        "<sos> " + obs + " <eos> " + \
                        "<sos> " + info["description"] + " <eos>"
                next_state_text = self.clean_text(next_state_text)
                next_state_tensor = self.make_tensor(next_state_text)
                reward = score - prev_score
                cum_sum_reward += reward
                nb_steps += 1
                won = info["won"]
                if nb_steps >= self.max_episode_steps:
                    break
                cur_state_tensor = next_state_tensor
                prev_score = score
            
            e_rewards.append(cum_sum_reward) 
            e_steps.append(nb_steps)
            e_wins.append(won)

        test_stats["avg_reward"] = np.nanmean(e_rewards) 
        test_stats["avg_steps"] = np.nanmean(e_steps)
        test_stats["avg_wins"] = np.nanmean(e_wins)
        return test_stats


    def play(self, train_env, test_env):
        """
        Play the game using the agent.

        For each epoch from ``self.epoch_offset + 1`` to ``self.nb_epochs``
        (inclusive), this method:

        - decays ``self.epsilon`` (from epoch 2 on, floored at
          ``self.epsilon_end``);
        - runs ``self.train`` and ``self.test``;
        - prints a summary on epoch 1 and every ``self.save_log_freq``
          epochs;
        - logs the test stats to ``self.summary_writer`` at
          ``global_step=epoch``;
        - saves a checkpoint on epoch 1 and every ``self.save_ckpt_freq``
          epochs;
        - copies the main actor into the target actor on epoch 1 and every
          ``self.t_update_freq`` epochs;
        - saves the accumulated stats via ``self.save_data``.

        Args:
            train_env: environment passed to ``self.train``.
            test_env: environment passed to ``self.test``.

        Returns:
            dict mapping "avg_{reward,steps,wins}_{train,test}" to lists with
            one value per epoch played by this call.
        """
        ep_stats = {
            "avg_reward_train": [],
            "avg_steps_train" : [],
            "avg_wins_train"  : [],
            "avg_reward_test" : [],
            "avg_steps_test"  : [],
            "avg_wins_test"   : []
        }

        # NOTE(review): ep_stats only covers the epochs played in this call.
        # When resuming (epoch_offset > 0), save_data presumably overwrites
        # the earlier "ep_stats" file. Confirm the intended behaviour.
        for epoch in tqdm(range(self.epoch_offset + 1, self.nb_epochs + 1), unit="epoch"):
            if epoch > 1:
                self.epsilon = max(
                    self.epsilon_end, self.epsilon - self.epsilon_decay_rate
                )
            ep_train_stats = self.train(train_env)
            ep_test_stats = self.test(test_env)

            if epoch == 1 or epoch % self.save_log_freq == 0:
                msg = (
                    f"\nEpoch: {epoch}\n"
                    f"Train avg. reward: {np.around(ep_train_stats['avg_reward'], 2)} "
                    f"Train avg. steps: {np.around(ep_train_stats['avg_steps'], 2)} "
                    f"Train avg. wins: {np.around(ep_train_stats['avg_wins'], 2)}\n"
                    f"Test avg. reward: {np.around(ep_test_stats['avg_reward'], 2)} "
                    f"Test avg. steps: {np.around(ep_test_stats['avg_steps'], 2)} "
                    f"Test avg. wins: {np.around(ep_test_stats['avg_wins'], 2)}\n"
                    f"Epsilon: {np.around(self.epsilon, 2)}"
                )
                print(msg)
                print("=" * 75)

            # Pass the epoch as global_step. Without it every value is
            # recorded at the same default step, so TensorBoard cannot plot
            # the metrics over epochs.
            self.summary_writer.add_scalar(
                "avg. reward", ep_test_stats["avg_reward"], global_step=epoch
            )
            self.summary_writer.add_scalar(
                "avg. steps", ep_test_stats["avg_steps"], global_step=epoch
            )
            self.summary_writer.add_scalar(
                "avg. wins", ep_test_stats["avg_wins"], global_step=epoch
            )

            # Save checkpoint periodically.
            if epoch == 1 or epoch % self.save_ckpt_freq == 0:
                self.save_ckpt(epoch=epoch)
            # Update target Actor network parameters periodically.
            if epoch == 1 or epoch % self.t_update_freq == 0:
                self.actor_t.load_state_dict(self.actor_m.state_dict())

            for split, split_stats in (("train", ep_train_stats), ("test", ep_test_stats)):
                for key in ("avg_reward", "avg_steps", "avg_wins"):
                    ep_stats[f"{key}_{split}"].append(split_stats[key])
            self.save_data(
                data=ep_stats, file_name="ep_stats",
                directory=self.result_dir
            )
        return ep_stats

#### Test RL agent

Seed the Python, NumPy and PyTorch random number generators for reproducibility.

In [None]:
# One seed shared by every RNG the notebook draws from:
# Python's `random`, NumPy's global RNG and PyTorch.
seed = 0
for rng_seeder in (random.seed, np.random.seed, torch.manual_seed):
    rng_seeder(seed)

<torch._C.Generator at 0x7f092f570f10>

Create the train and test environments and the RL agent.

In [None]:
# Flip to True to build the games and train/evaluate the RL agent.
play_rl_agent = False
if play_rl_agent:
    config = Config()

    # Training game: supplies the vocabulary and the verb/entity action space.
    train_twg = TextWorldGame(config=config)
    agent_vocab = train_twg.get_game_info().get("vocab", None)
    # To simplify the problem the verb space can be pruned instead, e.g.
    # agent_verbs = ["go", "take"]
    agent_verbs = train_twg.get_game_controls().get("verbs", None)
    agent_entities = train_twg.get_game_controls().get("entities", None)
    rl_agent = RLAgent(
        config=config,
        vocab=agent_vocab,
        verbs=agent_verbs,
        entities=agent_entities,
    )
    train_env = train_twg.get_game_info().get("env", None)

    # A second TextWorldGame instance provides the evaluation environment.
    test_twg = TextWorldGame(config=config)
    test_env = test_twg.get_game_info().get("env", None)

Train the RL agent in the training environment and evaluate it in the test environment.

In [None]:
if play_rl_agent:
    from IPython.display import display

    rl_agent_stats = rl_agent.play(train_env=train_env, test_env=test_env)
    # play() only covers epochs epoch_offset+1 .. nb_epochs, so label the
    # rows from there rather than from 1 (matters when resuming).
    first_epoch = rl_agent.epoch_offset + 1
    rl_agent_stats_df = pd.DataFrame(rl_agent_stats).assign(
        epoch=lambda df: np.arange(first_epoch, first_epoch + len(df))
    )
    # Jupyter only renders a top-level last expression, so a bare name
    # inside this `if` would show nothing; display explicitly.
    display(rl_agent_stats_df)

In [None]:
if play_rl_agent:
    plot_dir = f"{config.plot_dir}/rl_agent-complex"
    os.makedirs(plot_dir, exist_ok=True)

    # One figure per metric. Drawing all three into the implicit current
    # axes overlays the curves, so each saved JPEG would also contain the
    # previously plotted metrics.
    for metric in ("avg_reward_test", "avg_steps_test", "avg_wins_test"):
        fig, ax = plt.subplots()
        sns.lineplot(x="epoch", y=metric, data=rl_agent_stats_df, ax=ax)
        ax.set(title=f"{metric} per epoch", xlabel="epoch", ylabel=metric)
        fig.savefig(f"{plot_dir}/{metric}.jpeg")
        plt.show()