In [None]:
# Laurent LEQUIEVRE
# Research Engineer, CNRS (France)
# Institut Pascal UMR6602
# laurent.lequievre@uca.fr

# Solution based on :
# https://www.kaggle.com/wuhao1542/pytorch-rl-0-frozenlake-q-network-learning

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import numpy as np 

In [2]:
import gym
from gym.envs.registration import register

In [3]:
# Register a non-slippery variant of the 4x4 FrozenLake map under a custom id,
# then build an environment instance from that id.
register(
    id="FrozenLakeNotSlippery-v0",
    entry_point="gym.envs.toy_text:FrozenLakeEnv",
    kwargs=dict(map_name="4x4", is_slippery=False),
)

env = gym.make("FrozenLakeNotSlippery-v0")

In [4]:
# Number of discrete states and number of discrete actions exposed by the environment.
observation_space, action_space = env.observation_space.n, env.action_space.n

print("observation space = {}, action space = {}".format(observation_space, action_space))

observation space = 16, action space = 4


In [12]:
def uniform_linear_layer(linear_layer):
    """Re-initialise a linear layer in place.

    The weights are overwritten with samples drawn by ``Tensor.uniform_()``
    (called with its default bounds) and every bias entry is set to -0.02.
    The layer is modified through ``.data``; nothing is returned.
    """
    weights = linear_layer.weight.data
    weights.uniform_()
    biases = linear_layer.bias.data
    biases.fill_(-0.02)

class Agent(nn.Module):
    """Two-layer network that maps a discrete state index to one value per action.

    The state is one-hot encoded, passed through a linear layer followed by a
    sigmoid, then through a second linear layer. The hidden layer has the same
    width as the observation space. Both layers are initialised with
    ``uniform_linear_layer`` (a helper defined elsewhere in this notebook).
    """

    def __init__(self, observation_space_size, action_space_size):
        """Build the two linear layers.

        Args:
            observation_space_size: number of discrete states (one-hot width).
            action_space_size: number of discrete actions (output width).
        """
        super().__init__()
        self.observation_space_size = observation_space_size
        self.hidden_size = observation_space_size
        self.l1 = nn.Linear(observation_space_size, self.hidden_size)
        self.l2 = nn.Linear(self.hidden_size, action_space_size)
        for layer in (self.l1, self.l2):
            uniform_linear_layer(layer)

    def forward(self, state):
        """Return a 1-D tensor holding one value per action for ``state``.

        ``state`` is converted with ``int()`` and one-hot encoded over
        ``observation_space_size`` classes before being fed to the network.
        """
        state_index = torch.LongTensor([int(state)])
        one_hot_state = F.one_hot(state_index, num_classes=self.observation_space_size)
        hidden = torch.sigmoid(self.l1(one_hot_state.float()))
        action_values = self.l2(hidden)
        # Drop the leading batch dimension of size 1: (1, n_actions) -> (n_actions,)
        return action_values.view(-1)

In [13]:
def take_action(action, env):
    """Step ``env`` with ``action`` and replace the reward with a shaped one.

    Shaped reward:
        * -1    if the new state is a hole,
        * +1    if the new state is the goal,
        * -0.01 otherwise (small penalty for every extra step).

    The tile under the new state is read from the environment's ``desc`` grid
    (looked up on ``env.unwrapped`` when present), so the same function works
    for any map size. This assumes the FrozenLake convention of a 2-D grid of
    single letters with 'H' for holes and 'G' for the goal, and
    state = row * n_cols + col. If the environment exposes no ``desc``, the
    hole/goal indices of the default 4x4 map are used.

    Args:
        action: action passed unchanged to ``env.step``.
        env: environment whose ``step`` returns ``(state, reward, done, info)``.

    Returns:
        Tuple ``(new_state, reward, done, info)`` with ``reward`` replaced.
    """
    new_state, reward, done, info = env.step(action)

    layout = getattr(getattr(env, "unwrapped", env), "desc", None)
    if layout is not None:
        tile = np.asarray(layout).reshape(-1)[int(new_state)]
        # The grid may hold bytes (b'H') or str ('H'); normalise to str.
        tile = tile.decode() if isinstance(tile, bytes) else str(tile)
        is_hole = tile == "H"
        is_goal = tile == "G"
    else:
        # Fallback: hole and goal indices of the default 4x4 map.
        is_hole = new_state in (5, 7, 11, 12)
        is_goal = new_state == 15

    # Reward function
    if is_hole:
        reward = -1
    elif is_goal:
        reward = 1
    else:
        # penalize wandering
        reward = -0.01
    return new_state, reward, done, info

class Trainer:
    """Online Q-learning trainer for an ``Agent`` network on a discrete environment.

    One optimisation step is taken after every environment step. Actions are
    chosen epsilon-greedily, and transitions are produced by the module-level
    ``take_action(action, env)`` helper, which returns
    ``(new_state, reward, done, info)``.
    """

    # Discount factor applied to the bootstrapped value of the next state.
    GAMMA = 0.99
    # Probability of sampling a random action instead of the greedy one.
    EPSILON = 0.1
    # Maximum number of steps in a single training episode.
    MAX_STEPS = 200

    def __init__(self, env):
        """Create the agent network and its Adam optimiser for ``env``.

        Args:
            env: environment exposing ``observation_space.n``, ``action_space.n``,
                ``action_space.sample()`` and ``reset()``.
        """
        self.agent = Agent(env.observation_space.n, env.action_space.n)
        self.optimizer = optim.Adam(params=self.agent.parameters())
        self.env = env

    def train(self, epoch, verbose=True):
        """Run ``epoch`` training episodes.

        Args:
            epoch: number of episodes to run.
            verbose: when True (default) print one line per episode and a
                final completion message.
        """
        for i in range(epoch):
            if verbose:
                print("Perform epoch i = {}".format(i))
            s = self.env.reset()
            for _step in range(self.MAX_STEPS):
                # perform chosen action
                a = self.choose_action(s)
                s1, r, d, _info = take_action(a, self.env)

                # Q-learning target. A terminal transition has no successor,
                # so its target is the reward alone; bootstrapping from the
                # network's value of a terminal state would bias the estimate.
                if d:
                    target_q = torch.tensor(float(r))
                else:
                    with torch.no_grad():  # the target is treated as a constant
                        target_q = r + self.GAMMA * torch.max(self.agent(s1))
                loss = F.smooth_l1_loss(self.agent(s)[a], target_q)

                # update model to optimize Q
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # update state
                s = s1
                if d:
                    break

        if verbose:
            print("Train is done !")

    def choose_action(self, s):
        """Return an epsilon-greedy action for state ``s``.

        With probability ``EPSILON`` a random action is sampled from the
        environment's action space; otherwise the index of the largest
        network output for ``s`` is returned as a Python int.
        """
        if np.random.rand() < self.EPSILON:
            return self.env.action_space.sample()
        agent_out = self.agent(s).detach()
        return torch.argmax(agent_out).item()

In [23]:
# Build a Trainer around the 4x4 environment and train its Agent network
# for 2000 episodes.
t = Trainer(env)
t.train(epoch=2000)

Perform epoch i = 0
Perform epoch i = 1
Perform epoch i = 2
Perform epoch i = 3
Perform epoch i = 4
Perform epoch i = 5
Perform epoch i = 6
Perform epoch i = 7
Perform epoch i = 8
Perform epoch i = 9
Perform epoch i = 10
Perform epoch i = 11
Perform epoch i = 12
Perform epoch i = 13
Perform epoch i = 14
Perform epoch i = 15
Perform epoch i = 16
Perform epoch i = 17
Perform epoch i = 18
Perform epoch i = 19
Perform epoch i = 20
Perform epoch i = 21
Perform epoch i = 22
Perform epoch i = 23
Perform epoch i = 24
Perform epoch i = 25
Perform epoch i = 26
Perform epoch i = 27
Perform epoch i = 28
Perform epoch i = 29
Perform epoch i = 30
Perform epoch i = 31
Perform epoch i = 32
Perform epoch i = 33
Perform epoch i = 34
Perform epoch i = 35
Perform epoch i = 36
Perform epoch i = 37
Perform epoch i = 38
Perform epoch i = 39
Perform epoch i = 40
Perform epoch i = 41
Perform epoch i = 42
Perform epoch i = 43
Perform epoch i = 44
Perform epoch i = 45
Perform epoch i = 46
Perform epoch i = 47
Pe

Perform epoch i = 385
Perform epoch i = 386
Perform epoch i = 387
Perform epoch i = 388
Perform epoch i = 389
Perform epoch i = 390
Perform epoch i = 391
Perform epoch i = 392
Perform epoch i = 393
Perform epoch i = 394
Perform epoch i = 395
Perform epoch i = 396
Perform epoch i = 397
Perform epoch i = 398
Perform epoch i = 399
Perform epoch i = 400
Perform epoch i = 401
Perform epoch i = 402
Perform epoch i = 403
Perform epoch i = 404
Perform epoch i = 405
Perform epoch i = 406
Perform epoch i = 407
Perform epoch i = 408
Perform epoch i = 409
Perform epoch i = 410
Perform epoch i = 411
Perform epoch i = 412
Perform epoch i = 413
Perform epoch i = 414
Perform epoch i = 415
Perform epoch i = 416
Perform epoch i = 417
Perform epoch i = 418
Perform epoch i = 419
Perform epoch i = 420
Perform epoch i = 421
Perform epoch i = 422
Perform epoch i = 423
Perform epoch i = 424
Perform epoch i = 425
Perform epoch i = 426
Perform epoch i = 427
Perform epoch i = 428
Perform epoch i = 429
Perform ep

Perform epoch i = 776
Perform epoch i = 777
Perform epoch i = 778
Perform epoch i = 779
Perform epoch i = 780
Perform epoch i = 781
Perform epoch i = 782
Perform epoch i = 783
Perform epoch i = 784
Perform epoch i = 785
Perform epoch i = 786
Perform epoch i = 787
Perform epoch i = 788
Perform epoch i = 789
Perform epoch i = 790
Perform epoch i = 791
Perform epoch i = 792
Perform epoch i = 793
Perform epoch i = 794
Perform epoch i = 795
Perform epoch i = 796
Perform epoch i = 797
Perform epoch i = 798
Perform epoch i = 799
Perform epoch i = 800
Perform epoch i = 801
Perform epoch i = 802
Perform epoch i = 803
Perform epoch i = 804
Perform epoch i = 805
Perform epoch i = 806
Perform epoch i = 807
Perform epoch i = 808
Perform epoch i = 809
Perform epoch i = 810
Perform epoch i = 811
Perform epoch i = 812
Perform epoch i = 813
Perform epoch i = 814
Perform epoch i = 815
Perform epoch i = 816
Perform epoch i = 817
Perform epoch i = 818
Perform epoch i = 819
Perform epoch i = 820
Perform ep

Perform epoch i = 1157
Perform epoch i = 1158
Perform epoch i = 1159
Perform epoch i = 1160
Perform epoch i = 1161
Perform epoch i = 1162
Perform epoch i = 1163
Perform epoch i = 1164
Perform epoch i = 1165
Perform epoch i = 1166
Perform epoch i = 1167
Perform epoch i = 1168
Perform epoch i = 1169
Perform epoch i = 1170
Perform epoch i = 1171
Perform epoch i = 1172
Perform epoch i = 1173
Perform epoch i = 1174
Perform epoch i = 1175
Perform epoch i = 1176
Perform epoch i = 1177
Perform epoch i = 1178
Perform epoch i = 1179
Perform epoch i = 1180
Perform epoch i = 1181
Perform epoch i = 1182
Perform epoch i = 1183
Perform epoch i = 1184
Perform epoch i = 1185
Perform epoch i = 1186
Perform epoch i = 1187
Perform epoch i = 1188
Perform epoch i = 1189
Perform epoch i = 1190
Perform epoch i = 1191
Perform epoch i = 1192
Perform epoch i = 1193
Perform epoch i = 1194
Perform epoch i = 1195
Perform epoch i = 1196
Perform epoch i = 1197
Perform epoch i = 1198
Perform epoch i = 1199
Perform epo

Perform epoch i = 1541
Perform epoch i = 1542
Perform epoch i = 1543
Perform epoch i = 1544
Perform epoch i = 1545
Perform epoch i = 1546
Perform epoch i = 1547
Perform epoch i = 1548
Perform epoch i = 1549
Perform epoch i = 1550
Perform epoch i = 1551
Perform epoch i = 1552
Perform epoch i = 1553
Perform epoch i = 1554
Perform epoch i = 1555
Perform epoch i = 1556
Perform epoch i = 1557
Perform epoch i = 1558
Perform epoch i = 1559
Perform epoch i = 1560
Perform epoch i = 1561
Perform epoch i = 1562
Perform epoch i = 1563
Perform epoch i = 1564
Perform epoch i = 1565
Perform epoch i = 1566
Perform epoch i = 1567
Perform epoch i = 1568
Perform epoch i = 1569
Perform epoch i = 1570
Perform epoch i = 1571
Perform epoch i = 1572
Perform epoch i = 1573
Perform epoch i = 1574
Perform epoch i = 1575
Perform epoch i = 1576
Perform epoch i = 1577
Perform epoch i = 1578
Perform epoch i = 1579
Perform epoch i = 1580
Perform epoch i = 1581
Perform epoch i = 1582
Perform epoch i = 1583
Perform epo

Perform epoch i = 1920
Perform epoch i = 1921
Perform epoch i = 1922
Perform epoch i = 1923
Perform epoch i = 1924
Perform epoch i = 1925
Perform epoch i = 1926
Perform epoch i = 1927
Perform epoch i = 1928
Perform epoch i = 1929
Perform epoch i = 1930
Perform epoch i = 1931
Perform epoch i = 1932
Perform epoch i = 1933
Perform epoch i = 1934
Perform epoch i = 1935
Perform epoch i = 1936
Perform epoch i = 1937
Perform epoch i = 1938
Perform epoch i = 1939
Perform epoch i = 1940
Perform epoch i = 1941
Perform epoch i = 1942
Perform epoch i = 1943
Perform epoch i = 1944
Perform epoch i = 1945
Perform epoch i = 1946
Perform epoch i = 1947
Perform epoch i = 1948
Perform epoch i = 1949
Perform epoch i = 1950
Perform epoch i = 1951
Perform epoch i = 1952
Perform epoch i = 1953
Perform epoch i = 1954
Perform epoch i = 1955
Perform epoch i = 1956
Perform epoch i = 1957
Perform epoch i = 1958
Perform epoch i = 1959
Perform epoch i = 1960
Perform epoch i = 1961
Perform epoch i = 1962
Perform epo

In [22]:
# Test the best solution from an Agent (with a neural network)
# Initial state = 0, Final state = 15
S = env.reset()  # S is the initial state = 0
env.render()

# Follow the greedy policy of the trained network. The rollout stops when the
# goal is reached, when the environment reports the end of the episode, or
# after a fixed number of steps, so a poorly trained agent (one that falls
# into a hole or walks in circles) cannot hang the notebook.
done = False
steps_left = 100
while S != 15 and not done and steps_left > 0:
    agent_out = t.agent(S).detach()
    A = torch.argmax(agent_out).item()
    print("Take action {}".format(A))
    S, _, done, _ = env.step(A)
    env.render()
    steps_left -= 1

if S != 15:
    print("Goal not reached, stopped in state {}".format(S))


[41mS[0mFFF
FHFH
FFFH
HFFG
Take action 2
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG
Take action 2
  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG
Take action 1
  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG
Take action 1
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
Take action 1
  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG
Take action 2
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m


In [24]:
# Register a non-slippery variant of the 8x8 FrozenLake map under a custom id,
# then build an environment instance from that id.
register(
    id="FrozenLakeNotSlippery-v1",
    entry_point="gym.envs.toy_text:FrozenLakeEnv",
    kwargs=dict(map_name="8x8", is_slippery=False),
)

env2 = gym.make("FrozenLakeNotSlippery-v1")

In [34]:
# Reset the 8x8 environment and display its grid.
env2.reset()
env2.render()


[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG


In [30]:
# Define new take_action function for a frozen lake 8x8
def take_action(action, env):
    """Step ``env`` with ``action`` and replace the reward with a shaped one.

    Shaped reward:
        * -1    if the new state is a hole,
        * +1    if the new state is the goal,
        * -0.01 otherwise (small penalty for every extra step).

    The tile under the new state is read from the environment's ``desc`` grid
    (looked up on ``env.unwrapped`` when present), so the same function works
    for any map size. This assumes the FrozenLake convention of a 2-D grid of
    single letters with 'H' for holes and 'G' for the goal, and
    state = row * n_cols + col. If the environment exposes no ``desc``, the
    hole/goal indices of the default 8x8 map are used.

    Args:
        action: action passed unchanged to ``env.step``.
        env: environment whose ``step`` returns ``(state, reward, done, info)``.

    Returns:
        Tuple ``(new_state, reward, done, info)`` with ``reward`` replaced.
    """
    new_state, reward, done, info = env.step(action)

    layout = getattr(getattr(env, "unwrapped", env), "desc", None)
    if layout is not None:
        tile = np.asarray(layout).reshape(-1)[int(new_state)]
        # The grid may hold bytes (b'H') or str ('H'); normalise to str.
        tile = tile.decode() if isinstance(tile, bytes) else str(tile)
        is_hole = tile == "H"
        is_goal = tile == "G"
    else:
        # Fallback: hole and goal indices of the default 8x8 map.
        is_hole = new_state in (19, 29, 35, 41, 42, 46, 49, 52, 54, 59)
        is_goal = new_state == 63

    # Reward function
    if is_hole:
        reward = -1
    elif is_goal:
        reward = 1
    else:
        # penalize wandering
        reward = -0.01
    return new_state, reward, done, info

In [31]:
# Build a Trainer around the 8x8 environment and train its Agent network
# for 2000 episodes.
t = Trainer(env2)
t.train(epoch=2000)

Perform epoch i = 0
Perform epoch i = 1
Perform epoch i = 2
Perform epoch i = 3
Perform epoch i = 4
Perform epoch i = 5
Perform epoch i = 6
Perform epoch i = 7
Perform epoch i = 8
Perform epoch i = 9
Perform epoch i = 10
Perform epoch i = 11
Perform epoch i = 12
Perform epoch i = 13
Perform epoch i = 14
Perform epoch i = 15
Perform epoch i = 16
Perform epoch i = 17
Perform epoch i = 18
Perform epoch i = 19
Perform epoch i = 20
Perform epoch i = 21
Perform epoch i = 22
Perform epoch i = 23
Perform epoch i = 24
Perform epoch i = 25
Perform epoch i = 26
Perform epoch i = 27
Perform epoch i = 28
Perform epoch i = 29
Perform epoch i = 30
Perform epoch i = 31
Perform epoch i = 32
Perform epoch i = 33
Perform epoch i = 34
Perform epoch i = 35
Perform epoch i = 36
Perform epoch i = 37
Perform epoch i = 38
Perform epoch i = 39
Perform epoch i = 40
Perform epoch i = 41
Perform epoch i = 42
Perform epoch i = 43
Perform epoch i = 44
Perform epoch i = 45
Perform epoch i = 46
Perform epoch i = 47
Pe

Perform epoch i = 384
Perform epoch i = 385
Perform epoch i = 386
Perform epoch i = 387
Perform epoch i = 388
Perform epoch i = 389
Perform epoch i = 390
Perform epoch i = 391
Perform epoch i = 392
Perform epoch i = 393
Perform epoch i = 394
Perform epoch i = 395
Perform epoch i = 396
Perform epoch i = 397
Perform epoch i = 398
Perform epoch i = 399
Perform epoch i = 400
Perform epoch i = 401
Perform epoch i = 402
Perform epoch i = 403
Perform epoch i = 404
Perform epoch i = 405
Perform epoch i = 406
Perform epoch i = 407
Perform epoch i = 408
Perform epoch i = 409
Perform epoch i = 410
Perform epoch i = 411
Perform epoch i = 412
Perform epoch i = 413
Perform epoch i = 414
Perform epoch i = 415
Perform epoch i = 416
Perform epoch i = 417
Perform epoch i = 418
Perform epoch i = 419
Perform epoch i = 420
Perform epoch i = 421
Perform epoch i = 422
Perform epoch i = 423
Perform epoch i = 424
Perform epoch i = 425
Perform epoch i = 426
Perform epoch i = 427
Perform epoch i = 428
Perform ep

Perform epoch i = 776
Perform epoch i = 777
Perform epoch i = 778
Perform epoch i = 779
Perform epoch i = 780
Perform epoch i = 781
Perform epoch i = 782
Perform epoch i = 783
Perform epoch i = 784
Perform epoch i = 785
Perform epoch i = 786
Perform epoch i = 787
Perform epoch i = 788
Perform epoch i = 789
Perform epoch i = 790
Perform epoch i = 791
Perform epoch i = 792
Perform epoch i = 793
Perform epoch i = 794
Perform epoch i = 795
Perform epoch i = 796
Perform epoch i = 797
Perform epoch i = 798
Perform epoch i = 799
Perform epoch i = 800
Perform epoch i = 801
Perform epoch i = 802
Perform epoch i = 803
Perform epoch i = 804
Perform epoch i = 805
Perform epoch i = 806
Perform epoch i = 807
Perform epoch i = 808
Perform epoch i = 809
Perform epoch i = 810
Perform epoch i = 811
Perform epoch i = 812
Perform epoch i = 813
Perform epoch i = 814
Perform epoch i = 815
Perform epoch i = 816
Perform epoch i = 817
Perform epoch i = 818
Perform epoch i = 819
Perform epoch i = 820
Perform ep

Perform epoch i = 1144
Perform epoch i = 1145
Perform epoch i = 1146
Perform epoch i = 1147
Perform epoch i = 1148
Perform epoch i = 1149
Perform epoch i = 1150
Perform epoch i = 1151
Perform epoch i = 1152
Perform epoch i = 1153
Perform epoch i = 1154
Perform epoch i = 1155
Perform epoch i = 1156
Perform epoch i = 1157
Perform epoch i = 1158
Perform epoch i = 1159
Perform epoch i = 1160
Perform epoch i = 1161
Perform epoch i = 1162
Perform epoch i = 1163
Perform epoch i = 1164
Perform epoch i = 1165
Perform epoch i = 1166
Perform epoch i = 1167
Perform epoch i = 1168
Perform epoch i = 1169
Perform epoch i = 1170
Perform epoch i = 1171
Perform epoch i = 1172
Perform epoch i = 1173
Perform epoch i = 1174
Perform epoch i = 1175
Perform epoch i = 1176
Perform epoch i = 1177
Perform epoch i = 1178
Perform epoch i = 1179
Perform epoch i = 1180
Perform epoch i = 1181
Perform epoch i = 1182
Perform epoch i = 1183
Perform epoch i = 1184
Perform epoch i = 1185
Perform epoch i = 1186
Perform epo

Perform epoch i = 1511
Perform epoch i = 1512
Perform epoch i = 1513
Perform epoch i = 1514
Perform epoch i = 1515
Perform epoch i = 1516
Perform epoch i = 1517
Perform epoch i = 1518
Perform epoch i = 1519
Perform epoch i = 1520
Perform epoch i = 1521
Perform epoch i = 1522
Perform epoch i = 1523
Perform epoch i = 1524
Perform epoch i = 1525
Perform epoch i = 1526
Perform epoch i = 1527
Perform epoch i = 1528
Perform epoch i = 1529
Perform epoch i = 1530
Perform epoch i = 1531
Perform epoch i = 1532
Perform epoch i = 1533
Perform epoch i = 1534
Perform epoch i = 1535
Perform epoch i = 1536
Perform epoch i = 1537
Perform epoch i = 1538
Perform epoch i = 1539
Perform epoch i = 1540
Perform epoch i = 1541
Perform epoch i = 1542
Perform epoch i = 1543
Perform epoch i = 1544
Perform epoch i = 1545
Perform epoch i = 1546
Perform epoch i = 1547
Perform epoch i = 1548
Perform epoch i = 1549
Perform epoch i = 1550
Perform epoch i = 1551
Perform epoch i = 1552
Perform epoch i = 1553
Perform epo

Perform epoch i = 1883
Perform epoch i = 1884
Perform epoch i = 1885
Perform epoch i = 1886
Perform epoch i = 1887
Perform epoch i = 1888
Perform epoch i = 1889
Perform epoch i = 1890
Perform epoch i = 1891
Perform epoch i = 1892
Perform epoch i = 1893
Perform epoch i = 1894
Perform epoch i = 1895
Perform epoch i = 1896
Perform epoch i = 1897
Perform epoch i = 1898
Perform epoch i = 1899
Perform epoch i = 1900
Perform epoch i = 1901
Perform epoch i = 1902
Perform epoch i = 1903
Perform epoch i = 1904
Perform epoch i = 1905
Perform epoch i = 1906
Perform epoch i = 1907
Perform epoch i = 1908
Perform epoch i = 1909
Perform epoch i = 1910
Perform epoch i = 1911
Perform epoch i = 1912
Perform epoch i = 1913
Perform epoch i = 1914
Perform epoch i = 1915
Perform epoch i = 1916
Perform epoch i = 1917
Perform epoch i = 1918
Perform epoch i = 1919
Perform epoch i = 1920
Perform epoch i = 1921
Perform epoch i = 1922
Perform epoch i = 1923
Perform epoch i = 1924
Perform epoch i = 1925
Perform epo

In [32]:
# Test the best solution from an Agent (with a neural network)
# Initial state = 0, Final state = 63
S = env2.reset()  # S is the initial state = 0
env2.render()

# Follow the greedy policy of the trained network. The rollout stops when the
# goal is reached, when the environment reports the end of the episode, or
# after a fixed number of steps, so a poorly trained agent (one that falls
# into a hole or walks in circles) cannot hang the notebook.
done = False
steps_left = 200
while S != 63 and not done and steps_left > 0:
    agent_out = t.agent(S).detach()
    A = torch.argmax(agent_out).item()
    print("Take action {}".format(A))
    S, _, done, _ = env2.step(A)
    env2.render()
    steps_left -= 1

if S != 63:
    print("Goal not reached, stopped in state {}".format(S))


[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 2
  (Right)
S[41mF[0mFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 2
  (Right)
SF[41mF[0mFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 1
  (Down)
SFFFFFFF
FF[41mF[0mFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 2
  (Right)
SFFFFFFF
FFF[41mF[0mFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 2
  (Right)
SFFFFFFF
FFFF[41mF[0mFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 1
  (Down)
SFFFFFFF
FFFFFFFF
FFFH[41mF[0mFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 1
  (Down)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFF[41mF[0mHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 1
  (Down)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFH[41mF[0mFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG
Take action 2
  (Right)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHF[41mF[0mFF
FHHF