## Training an RL agent to predict country capitals

This notebook presents a worked example of how to build an RL agent with Kandula. In this example, we build an RL agent that learns the capitals of different countries.

### Setup

First things first, let's import the requirements. 

In [1]:
from kandula import logging
from kandula.steps import RLStep
from kandula.qtable import QTable
from kandula.q_learning import QL
from typing import List

from functools import reduce
import torch
import json
import random
from nltk import word_tokenize

For this example, in order to train the RL agent, we downloaded a JSON file that contains the countries of the world and their capitals from [this repository](https://github.com/icyrockcom/country-capitals/blob/master/data/country-list-with-ids.json). For convenience, the file is stored in the [resources folder](https://github.com/meghdadFar/kandula/tree/main/resources) of this repository. Let's read this file and reshape it to suit our needs.

In [2]:
with open("../resources/country_capital.json", "r") as fc:
    capitals: List = json.load(fc)
capitals_dict = {}
country_index = {}
index_country = {}
i=1
for jl in capitals:
    capitals_dict[jl["country"]] = jl["capital"]
    country_index[jl["country"]] = i
    index_country[i] = jl["country"]
    i+=1

In the above code, we first read the JSON records into `capitals`. We then create three dictionaries from it: `capitals_dict`, which maps country names to their capitals; `country_index`, which maps country names to incremental indexes (starting at 1); and `index_country`, which maps indexes back to country names.
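The same three mappings can also be built in a single pass with `enumerate`. The sketch below wraps this in a hypothetical helper, `build_mappings`, and runs it on a small hand-written `sample_records` list that stands in for the parsed JSON file. It assumes each record has `country` and `capital` keys, as in the cell above.

```python
from typing import Dict, List, Tuple

# Hypothetical sample standing in for the parsed JSON records.
sample_records = [
    {"country": "Argentina", "capital": "Buenos Aires"},
    {"country": "Bulgaria", "capital": "Sofia"},
    {"country": "Canada", "capital": "Ottawa"},
]


def build_mappings(
    records: List[Dict[str, str]],
) -> Tuple[Dict[str, str], Dict[str, int], Dict[int, str]]:
    """Return (country -> capital, country -> index, index -> country)."""
    country_to_capital: Dict[str, str] = {}
    country_to_index: Dict[str, int] = {}
    index_to_country: Dict[int, str] = {}
    # start=1 mirrors the 1-based indexes used in this notebook.
    for i, record in enumerate(records, start=1):
        country_to_capital[record["country"]] = record["capital"]
        country_to_index[record["country"]] = i
        index_to_country[i] = record["country"]
    return country_to_capital, country_to_index, index_to_country


caps, c2i, i2c = build_mappings(sample_records)
print(c2i)
print(i2c[c2i["Bulgaria"]], "->", caps["Bulgaria"])
```

Keeping both `country -> index` and `index -> country` makes it cheap to move between the numeric state the agent sees and the readable country name.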

### Define the State Space

The first thing to consider is how to map our problem to a state-action space. In this example, I consider each country to represent a state; hence, for 248 countries we have a 1-dimensional state space of size 248. If the space were 2-dimensional, with a first dimension of size N and a second dimension of size M, we would define the `state_space` variable as `[N, M]`.

In [3]:
state_space = [248]
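To make the `[N, M]` case more concrete, here is a small sketch of how a multi-dimensional state can be counted and flattened into a single row index. It only illustrates the idea: `num_states` and `flat_index` are hypothetical helpers, not part of Kandula, and they assume 0-based coordinates, whereas the country indexes in this notebook start at 1.

```python
from math import prod
from typing import List


def num_states(state_space: List[int]) -> int:
    """Total number of distinct states for the given dimension sizes."""
    return prod(state_space)


def flat_index(state: List[int], state_space: List[int]) -> int:
    """Row-major index of a 0-based state within the state space."""
    if len(state) != len(state_space):
        raise ValueError("state and state_space must have the same length")
    index = 0
    for coord, size in zip(state, state_space):
        if not 0 <= coord < size:
            raise ValueError(f"coordinate {coord} is outside [0, {size})")
        index = index * size + coord
    return index


print(num_states([248]))           # 248
print(num_states([3, 4]))          # 12
print(flat_index([2, 1], [3, 4]))  # 2 * 4 + 1 = 9
```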

### Define the Actions

The next thing is to define the actions. In this example, I consider guessing the (right) capital to be the action that my RL agent is supposed to learn; hence, I define my actions as a list of 248 capitals. If, for instance, my RL agent were supposed to take one of two actions, e.g. shifting a gear up or shifting a gear down, my actions variable would have been `[shift_gear_up, shift_gear_down]`.

In [4]:
actions = [v for _, v in capitals_dict.items()]

### Define the RL Step

It's now time to define our RL step. An RL step should be a subclass of the `kandula.steps.RLStep` class and implement its two abstract methods, namely `get_state()` and `get_reward()`. In addition to the definition of states and actions, this is where you make the RL agent specific to your problem.

In [10]:
class CapitalsRLStep(RLStep):

    def get_state(self):
        country = gen_rand_country()
        state = [country_index[country]]
        return state
    
    def get_reward(self, state, action):
        # The state is a one-element list holding the country index.
        s = state[0]
        reward = 1 if capitals_dict[index_country[s]] == action else 0
        return reward

In the above code, the handling of the state is deliberately simple, which suits this particular problem. To get the current state, we simply choose a random country via `gen_rand_country()`, ignoring the previous action and any possible changes to the environment. The reward is calculated by comparing the RL agent's prediction to the actual capital: 1 if they match, 0 otherwise.

In [11]:
def gen_rand_country():
    country, _ = random.choice(list(capitals_dict.items()))
    return country
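The reward logic can be checked in isolation, without Kandula. The following sketch uses a three-country toy dataset and a plain function, `capital_reward` (a hypothetical name), that mirrors what `get_reward()` does above.

```python
from typing import List

# Toy data, assumed for illustration only.
toy_capitals = {
    "Argentina": "Buenos Aires",
    "Bulgaria": "Sofia",
    "Canada": "Ottawa",
}
toy_index_country = {1: "Argentina", 2: "Bulgaria", 3: "Canada"}


def capital_reward(state: List[int], action: str) -> int:
    """Return 1 if `action` is the capital of the country in `state`, else 0."""
    country = toy_index_country[state[0]]
    return 1 if toy_capitals[country] == action else 0


print(capital_reward([1], "Buenos Aires"))  # 1: correct capital
print(capital_reward([1], "Sofia"))         # 0: wrong capital
```

Because the reward is binary and depends only on the current state and action, every wrong guess looks equally bad to the agent. This is one reason the agent may need many iterations to find the single rewarded action among 248.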

### Initiate and Train the RL Agent

In [13]:
%load_ext tensorboard
mrls = CapitalsRLStep()
qt = QTable(state_space=state_space, actions=actions)
ql = QL(qtable=qt, rl_step=mrls)
ql.train(3000000, get_correct_action_for_capitals)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard

Q_learning      - 138 - INFO - Training the RL agent...
Q_learning      - 145 - INFO - Epoch: 1000 - Error: 97.18%
Q_learning      - 145 - INFO - Epoch: 2000 - Error: 95.97%
Q_learning      - 145 - INFO - Epoch: 3000 - Error: 93.15%
Q_learning      - 145 - INFO - Epoch: 4000 - Error: 92.34%
Q_learning      - 145 - INFO - Epoch: 5000 - Error: 91.13%
...
Q_learning      - 145 - INFO - Epoch: 10000 - Error: 87.90%
...
Q_learning      - 145 - INFO - Epoch: 139000 - Error: 79.03%
...
Q_learning      - 145 - INFO - Epoch: 274000 - Error: 70.97%
...
Q_learning      - 145 - INFO - Epoch: 409000 - Error: 63.71%
...
Q_learning      - 145 - INFO - Epoch: 544000 - Error: 55.24%
...
Q_learning      - 145 - INFO - Epoch: 679000 - Error: 50.81%
...
Q_learning      - 145 - INFO - Epoch: 814000 - Error: 47.18%
...
Q_learning      - 145 - INFO - Epoch: 949000 - Error: 43.15%
...
Q_learning      - 145 - INFO - Epoch: 1082000 - Error: 38.71%
...
Q_learning      - 145 - INFO - Epoch: 1215000 - Error: 35.48%
...
Q_learning      - 145 - INFO - Epoch: 1348000 - Error: 32.26%
...
Q_learning      - 145 - INFO - Epoch: 1481000 - Error: 30.65%
...
Q_learning      - 145 - INFO - Epoch: 1614000 - Error: 28.63%
...
Q_learning      - 145 - INFO - Epoch: 1747000 - Error: 25.81%
...
Q_learning      - 145 - INFO - Epoch: 1880000 - Error: 22.58%
...
Q_learning      - 145 - INFO - Epoch: 2013000 - Error: 19.35%
...
Q_learning      - 145 - INFO - Epoch: 2146000 - Error: 18.95%
...
Q_learning      - 145 - INFO - Epoch: 2279000 - Error: 17.34%
...
Q_learning      - 145 - INFO - Epoch: 2412000 - Error: 14.92%
...
Q_learning      - 145 - INFO - Epoch: 2545000 - Error: 14.11%
...
Q_learning      - 145 - INFO - Epoch: 2678000 - Error: 11.69%
...
Q_learning      - 145 - INFO - Epoch: 2811000 - Error: 10.08%
...
Q_learning      - 145 - INFO - Epoch: 2944000 - Error: 9.68%
...
Q_learning      - 145 - INFO - Epoch: 2959000 - Error: 9.68%
...

Note that the training call above takes one more function, `get_correct_action_for_capitals`, which must be defined before training starts. It is used in the training loop, and its main purpose is to decide the best action to take in a given state. You should be able to define this function with respect to the problem that you are solving. For instance, in this example, since we defined the state to be simply a country, the best action is to return the right capital for that country. Hence, we can define the following function:

In [7]:
def get_correct_action_for_capitals(state: List):
    return capitals_dict[index_country[state[0]]]
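To give a feel for what such a training loop does, here is a minimal, self-contained sketch of tabular learning on a three-country toy problem. It is not Kandula's implementation: the names `train_toy_agent` and `error_rate`, the epsilon-greedy exploration, and the hyperparameters are all assumptions made for illustration. Since the next state is a random country that does not depend on the action, the sketch uses a one-step update without a discounted future term. The error is measured as the share of states whose greedy action differs from the correct action, which is one plausible use for a function like `get_correct_action_for_capitals`.

```python
import random
from typing import Dict, List

# Toy data, assumed for illustration only.
toy_capitals = {
    "Argentina": "Buenos Aires",
    "Bulgaria": "Sofia",
    "Canada": "Ottawa",
}
toy_countries = list(toy_capitals)
toy_actions = list(toy_capitals.values())


def greedy_action(q_row: List[float]) -> int:
    """Index of the highest-valued action (first one wins on ties)."""
    return max(range(len(q_row)), key=lambda i: q_row[i])


def train_toy_agent(
    steps: int = 2000, alpha: float = 0.5, epsilon: float = 0.2, seed: int = 0
) -> Dict[str, List[float]]:
    rng = random.Random(seed)
    # One row of action values per state (country), initialised to zero.
    q = {c: [0.0] * len(toy_actions) for c in toy_countries}
    for _ in range(steps):
        country = rng.choice(toy_countries)  # random state, as in get_state()
        if rng.random() < epsilon:
            a = rng.randrange(len(toy_actions))  # explore
        else:
            a = greedy_action(q[country])        # exploit
        reward = 1.0 if toy_actions[a] == toy_capitals[country] else 0.0
        # Move the estimate towards the observed reward.
        q[country][a] += alpha * (reward - q[country][a])
    return q


def error_rate(q: Dict[str, List[float]]) -> float:
    """Share of states whose greedy action is not the correct capital."""
    wrong = sum(
        toy_actions[greedy_action(q[c])] != toy_capitals[c] for c in toy_countries
    )
    return wrong / len(toy_countries)


q_table = train_toy_agent()
print(f"Error after training: {error_rate(q_table):.2%}")
```

With only three actions, the toy agent stumbles on the rewarded action quickly. With 248 actions and a reward of 0 for every wrong guess, far more exploration is needed, which is consistent with the slow decline of the error in the log above.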

Setting `verbose=True` lets the train function log the error at each epoch. This is optional, however, because the error results are always stored in a TensorBoard plot. Outside notebooks, you can access the plots by running `tensorboard --logdir=runs` and then opening http://localhost:6006/. To view the plots directly in the notebook, run the following command:

In [16]:
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 65777), started 0:01:19 ago. (Use '!kill 65777' to kill it.)

The RL agent appears to be trained: the error dropped from about 97% to below 10%, although it took close to 3 million iterations. Let's now write a script that uses the trained RL agent to answer queries about country capitals.

In [20]:
while True:
    country = ""
    query = input("Enter your query: ")
    if query.lower() == "stop":
        print("Have a nice day! Bye.")
        break
    try:
        # Pick the last token that matches a known country name.
        tokens = word_tokenize(query)
        for t in tokens:
            if t in capitals_dict:
                country = t
    except Exception as E:
        logging.error(E)
        continue
    try:
        # Look up the state's row in the Q-table and take the best action.
        state_index = ql.q_table.get_state_index([country_index[country]])
        action_index = torch.argmax(ql.q_table.q_table[state_index]).item()
        res = ql.q_table.actions[action_index]
        print(f'Capital of {country} is {res}')
    except Exception:
        # Raised, for example, when no known country was found in the query.
        logging.error('Make sure the country name is written correctly, and is capitalized.')

Enter your query: What is the capital of Argentina?
Capital of Argentina is Buenos Aires
Enter your query: What is the capital of Bulgaria?
Capital of Bulgaria is Sofia
Enter your query: What is the capital of Germany?
Capital of Germany is Berlin
Enter your query: What is the capital of Australia?
Capital of Australia is Canberra
Enter your query: What is the capital of Canada?
Capital of Canada is Ottawa
Enter your query: Stop
Have a nice day! Bye.
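One limitation of matching single tokens against `capitals_dict` is that a multi-word country name is split into several tokens, none of which equals the full name, and that the match is case-sensitive. A possible alternative, sketched below, searches the query for whole country names instead. The helper `find_country` is hypothetical and not part of Kandula or NLTK; it tries longer names first so that a name contained in a longer one does not win by accident.

```python
import re
from typing import Iterable, Optional


def find_country(query: str, known_countries: Iterable[str]) -> Optional[str]:
    """Return the first known country mentioned in `query`, or None.

    Matching is case-insensitive and respects word boundaries, and longer
    names are tried first ("Papua New Guinea" before "Guinea").
    """
    for name in sorted(known_countries, key=len, reverse=True):
        pattern = r"\b" + re.escape(name) + r"\b"
        if re.search(pattern, query, flags=re.IGNORECASE):
            return name
    return None


known = ["Guinea", "Papua New Guinea", "Niger", "Nigeria", "Canada"]
print(find_country("what is the capital of papua new guinea?", known))
print(find_country("What is the capital of Nigeria?", known))
print(find_country("What is the capital of Atlantis?", known))
```

The returned name can then be used as the `country` key in the loop above, in place of the token-by-token search.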
