In [1]:
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import math

In [2]:
## preprocesser une image du jeu (scale down, monochrome)
def prepro(I):
    """ preprocessing 210x160x3 uint8 frame into 6400 (80x80) 1D float vector """
    I = I[35:195] # crop
    I = I[::2,::2,0] # downsample by factor of 2
    I[I == 144] = 0 # erase background (background type 1)
    I[I == 109] = 0 # erase background (background type 2)
    I[I != 0] = 1 # everything else (paddles, ball) just set to 1
    return I.astype(np.float).ravel()

def discount_rewards(r):
    """ take 1D float array of rewards and compute discounted reward """
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(range(0,len(r))):
        if r[t] != 0: running_add = 0 # reset the sum, since this was a game boundary (pong specific!)
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r

In [3]:
class Rl(nn.Module):
        def __init__(self,input_size,h_size, output_size):
            super(Rl,self).__init__()
            self.linear1 = nn.Linear(input_size,h_size)
            self.linear2 = nn.Linear(h_size,output_size)
            self.sigmoid = nn.Sigmoid()
            
        def forward(self,x):
            h = F.relu(self.linear1(x))
            logp = self.sigmoid(self.linear2(h))
            return logp

In [8]:
#TODO : bug quand reward = 0 et du training voir karpathy : http://karpathy.github.io/2016/05/31/rl/
env = gym.make('Pong-v0') #charge l'environement
env.reset()  # reinitialiser le jeu

env.action_space # liste des actions possibles
env.unwrapped.get_action_meanings() # signification des actions
observation, reward, done, info = env.step(1) # joue l'action 1
# print(observation,done,info)

r = 0
r_prec = 0
D = 80*80
R_model = Rl(D,200,2)
gamma = 0.99
l = 0
render = False
optimizer = optim.SGD(R_model.parameters(),lr=0.001)

for num_episode,episode in enumerate(range(1000)):
    log_P_action = []
    rewards = []
    reward_sum = 0
    prec = None
    observation = env.reset()
    for t in range(1000):
        if render : env.render() # afficher l'etat du jeu
        
        #prepocessing et copie de l'etat precdu jeu
        current = prepro(observation)
        x =  current - prec if prec is not None else np.zeros(D)
        prec = current
        
        x = Variable(torch.FloatTensor(x))
        #calcul des probas d'action
        logP = R_model(x)
        #tirage d'une action
        action = torch.multinomial(logP.exp(),1)
        log_P_action.append(logP[action])
        
        #0  aller en bas 
        g_action = 2 if action.data[0] == 0 else 3 
#         print(action)
        observation, reward, done, info = env.step(g_action)
        
        rewards.append(reward)
        reward_sum += reward
        
    if reward_sum != 0:
        discount_reward = discount_rewards(rewards)
        discount_reward -= np.mean(discount_reward)
        discount_reward /= np.std(discount_reward)

        optimizer.zero_grad()
        #calcul du loss
        loss = []
        for i in range(len(log_P_action)):
            loss.append(-log_P_action[i] * discount_reward[i])

        loss = torch.cat(loss).sum()
        loss.backward()
        print(loss.data[0])
        optimizer.step()
    

#en entree du reseau du neurone 


-0.0698164701461792
0.482412189245224
-0.39479994773864746
-0.03588195890188217
-0.11499029397964478
0.1387137472629547
0.006992551032453775
-0.005157307721674442
0.05214235931634903
0.1623355597257614
-0.24651741981506348
-0.20082837343215942
-0.21208174526691437
-0.36587104201316833
-0.4240066111087799
-0.1455814093351364
-0.44285130500793457
-0.092318095266819
-0.5013523697853088
-0.6368620991706848
0.19984303414821625
-0.30826517939567566
-0.5078850984573364
-0.41932961344718933
-0.6721529960632324
-0.40533778071403503
-0.0518607534468174
0.31200721859931946
-0.6708685159683228
-0.8724162578582764
-0.19425415992736816
-0.9713054299354553
-0.420680433511734
-0.34701216220855713
-0.20485687255859375
-0.6603385210037231
-0.45080676674842834
-1.0268237590789795
-0.9250720739364624
-2.0618178844451904
-0.02684801258146763
-1.1003836393356323
-0.8997959494590759
-1.0962839126586914
0.21672065556049347
-1.360263466835022
-2.3127012252807617
-0.9847291707992554
-0.8821638822555542
0.628937

-92.85964965820312
-32.51166534423828
-57.27985763549805
-116.50524139404297
-32.57722473144531
-41.97479248046875
-41.434600830078125
-68.22177124023438
-70.0263900756836
-110.4729995727539
-40.03157043457031
-39.585601806640625
-51.02585983276367
-93.02742004394531
-85.64688873291016
-33.874359130859375
-44.08659744262695
-44.93728256225586
-27.160083770751953
-35.767669677734375
-115.09878540039062
-73.63119506835938
-51.23134994506836
-37.492008209228516
-29.282730102539062
-34.15887451171875
-36.857845306396484
-78.01175689697266
-86.20846557617188
-78.48875427246094
-75.63316345214844
-23.96537208557129
-69.46405792236328
-32.03873825073242
-65.91513061523438
-98.79501342773438
-113.65746307373047
-123.3941879272461
-47.379573822021484
-71.42252349853516
-41.954345703125
-89.60720825195312
-37.421974182128906
-53.60773468017578
-60.274967193603516
-69.67583465576172
-102.91673278808594
-37.068328857421875
-66.96978759765625
-49.6028938293457
-68.48284912109375
-86.24683380126953


-154.56797790527344
-122.67424011230469
-78.34388732910156
-74.06121063232422
-183.59268188476562
-77.2278060913086
-81.3206787109375
-80.8995132446289
-142.18377685546875
-182.62696838378906
-108.29805755615234
-68.23980712890625
-170.44052124023438
-168.84210205078125
-181.81358337402344
-149.45762634277344
-134.1556854248047
-122.8560791015625
-64.0740966796875
-60.27806091308594
-51.658409118652344
-139.85511779785156
-186.76016235351562
-77.35364532470703
-135.15476989746094
-159.6597442626953
-84.28873443603516
-60.20656204223633
-207.0114288330078
-84.90101623535156
-147.73658752441406
-120.05948638916016
-57.969242095947266
-64.7723617553711
-80.46463012695312
-209.6849365234375
-122.10799407958984
-63.7196044921875
-192.67860412597656
-138.447021484375
-82.5154037475586
-152.88388061523438
-50.5294189453125
-52.749855041503906
-189.23265075683594
-112.1329345703125
-56.754981994628906
-60.206459045410156
-157.83840942382812
-145.47726440429688
-177.576171875
-79.70951080322266

In [6]:
env.action_space # liste des actions possibles
env.unwrapped.get_action_meanings() # signification des actions

['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']