In [1]:
%matplotlib inline

In [2]:
import gym
import math
import random
import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.distributions import Categorical
from torch.autograd import Variable

# Create the MountainCarContinuous-v0 environment; `.unwrapped` is used so the
# rest of the notebook talks to the underlying environment object directly.
env = gym.make('MountainCarContinuous-v0').unwrapped

# True when the active matplotlib backend name contains 'inline'.
# `display` is imported only in that case; show_state() relies on it.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
print("Is python : {}".format(is_ipython))

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device : {}".format(device))

# NOTE(review): despite the name, this stores the action *space* object
# (the recorded output prints `Box(1,)`), not a number of actions.
ACTIONS_NUM = env.action_space
print("Number of actions : {}".format(ACTIONS_NUM))

Is python : True
Device : cpu
Number of actions : Box(1,)


In [3]:
# Size constants. STATE_SIZE / STATE_W / STATE_H are not referenced anywhere
# else in the visible notebook; MEMSIZE is the default replay capacity.
STATE_SIZE = 4
STATE_W = 84
STATE_H = 84
MEMSIZE = 70000

# One environment step as stored in the replay buffer.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records.

    The buffer grows until it holds ``capacity`` items; after that every new
    transition overwrites the oldest one.
    """

    def __init__(self, capacity = MEMSIZE):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Store one transition; ``args`` are the ``Transition`` fields in order."""
        slot = self.position
        # Grow the backing list by one placeholder while it is not full yet.
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        # Advance the write cursor, wrapping around at the end of the buffer.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [4]:
class DQN(nn.Module):
    """Two-layer fully connected network: 2 inputs -> 16 hidden (ReLU) -> 1 output (tanh)."""

    def __init__(self):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(2, 16)
        self.l2 = nn.Linear(16, 1)
        self.init_weights()

    @staticmethod
    def _xavier_init(layer):
        """Xavier-normal initialise the weight of a single ``nn.Linear`` layer."""
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)

    def init_weights(self):
        """Apply Xavier-normal initialisation to every ``nn.Linear`` in this module.

        ``Module.apply`` visits all sub-modules and the module itself, so this
        also works when invoked as ``DQN.init_weights(some_linear_layer)``.
        Biases are left untouched.
        """
        self.apply(DQN._xavier_init)

    def forward(self, x):
        """Map a batch of inputs ``x`` to a tanh-squashed output in [-1, 1]."""
        x = F.relu(self.l1(x))
        return torch.tanh(self.l2(x))

In [5]:
# Online network (trained) and target network (periodically synced copy).
policy_net = DQN().to(device)
target_net = DQN().to(device)

# Start both networks from the same weights; the target network is only
# ever used for inference, so keep it in eval mode.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=3e-4, amsgrad=True)

# Replay buffer sized by the shared MEMSIZE constant instead of a repeated literal.
memory = ReplayMemory(MEMSIZE)


# Epsilon-greedy schedule: epsilon decays exponentially from EPS_START towards
# EPS_END with time constant EPS_DECAY (measured in calls to select_action).
EPS_START = 1
EPS_END = 0.1
EPS_DECAY = 1000000

# Global count of actions selected so far; drives the epsilon schedule.
steps_done = 0


# Per-episode total reward collected during training.
train_rewards = []

# Moving-average window length and stride used by plot_rewards().
mean_size = 100
mean_step = 1

def plot_rewards(rewards = train_rewards, name = "Train"):
    """Plot per-episode rewards and, once enough data exists, their moving average.

    Args:
        rewards: sequence of per-episode totals; defaults to the module-level
            ``train_rewards`` list (bound once, so in-place appends are seen).
        name: figure title.

    The moving average uses windows of ``mean_size`` values taken every
    ``mean_step`` positions and is left-padded with ``mean_size - 1`` zeros.
    With ``mean_step == 1`` the padded curve has the same length as
    ``rewards``; for other strides it is shorter and not aligned to the x-axis.
    """
    plt.figure(2)
    plt.clf()
    plt.title(name)
    plt.xlabel('Episode')
    plt.ylabel('Total reward')
    plt.plot(rewards)
    if len(rewards) >= mean_size:
        # "+ 1" keeps the final full window, which ends at the last episode.
        starts = range(0, len(rewards) - mean_size + 1, mean_step)
        means = np.array([rewards[i:i + mean_size] for i in starts]).mean(1)
        means = np.concatenate((np.zeros(mean_size - 1), means))
        plt.plot(means)

In [6]:
BATCH_SIZE = 32   # transitions sampled per optimisation step
GAMMA = 0.99      # discount factor applied to the bootstrapped next-state value

def optimize_model():
    """Run one gradient step on ``policy_net`` from a replay-memory minibatch.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise samples a batch, builds the one-step TD target
    ``reward + GAMMA * max_a target_net(next_state)`` (the bootstrap term is 0
    for transitions whose ``next_state`` is ``None``), minimises the smooth-L1
    loss between that target and the value ``policy_net`` assigns to the taken
    action, clips every gradient element to [-1, 1] and steps ``optimizer``.

    Reads the module-level ``memory``, ``policy_net``, ``target_net``,
    ``optimizer`` and ``device``. Returns ``None``.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    # True where the transition has a successor state to bootstrap from.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # assumes states are equally-shaped numpy arrays — TODO confirm for other envs
    state_batch = torch.as_tensor(np.stack(batch.state), device=device).float()
    action_batch = torch.cat(batch.action)
    # Cast so the target has the same dtype as the network output even if
    # rewards were stored as float64.
    reward_batch = torch.cat(batch.reward).float()

    # NOTE(review): gather() picks the output column indexed by the stored
    # action; with a single-column network output only index 0 is valid —
    # confirm this matches how actions are stored.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():
        # Use the successor states of *this batch* (not a leaked global).
        non_final_next_states = torch.as_tensor(
            np.stack([s for s in batch.next_state if s is not None]),
            device=device).float()
        with torch.no_grad():
            # Max over the action dimension (dim 1): one value per sample.
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Element-wise gradient clipping to [-1, 1]; skips parameters without a grad.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [7]:
TEST_EPS = 0.05

def show_state(env, step=0, info=""):
    """Draw the environment's current frame in figure 3 and refresh the cell output.

    The figure title is built from ``env.spec.id``, ``step`` and ``info``.
    The previous cell output is cleared (waiting for the new one) before the
    figure is displayed.
    """
    plt.figure(3)
    plt.clf()

    frame = env.render(mode='rgb_array')
    plt.imshow(frame)

    caption = "%s | Step: %d %s" % (env.spec.id, step, info)
    plt.title(caption)
    plt.axis('off')

    display.clear_output(wait=True)
    current_figure = plt.gcf()
    display.display(current_figure)


In [8]:
def select_action(state, env):
    """Pick an action epsilon-greedily and advance the global step counter.

    Epsilon is ``EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)``,
    evaluated before ``steps_done`` is incremented. With probability epsilon
    the action is drawn from ``env.action_space.sample()``; otherwise it is
    the output of ``policy_net`` for ``state`` as a numpy array.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore.
    if roll < eps_threshold:
        return env.action_space.sample()

    # Exploit: query the online network.
    observation = torch.tensor(state).to(device).float()
    return policy_net(observation).cpu().data.numpy()

In [9]:
# Number of training episodes to run.
NUM_EPISODES = 100000

# optimize_model() is called on steps where steps_done is a multiple of
# OPTIMIZE_MODEL_STEP; the target network is re-synced every TARGET_UPDATE steps.
OPTIMIZE_MODEL_STEP = 20
TARGET_UPDATE = 20000

# No optimisation happens until steps_done exceeds this many steps.
STEPS_BEFORE_TRAIN = 1500000

policy_net.train()
target_net.eval()
# Flag: becomes 1 once "start train" has been printed.
i = 0
test_rewards = []



for e in tqdm.tqdm_notebook(range(NUM_EPISODES)):
    ep_rewards = []
    state = env.reset()
    # Each episode is capped at 5000 environment steps.
    for t in range(5000):
        action = select_action(state,env)
        new_state, reward, done, info = env.step(action)
        # NOTE(review): the action is stored with dtype torch.long, which
        # discards the fractional part of a float action — confirm intended.
        action = torch.tensor([action], device=device, dtype=torch.long)
        ep_rewards.append(reward)
        reward = torch.tensor([reward], device=device)

        if not done:
            memory.push(state, action, (new_state), reward)
            state = new_state
        else:
            # NOTE(review): the reset observation is bound to `next_state`, not
            # `state`; the episode loop calls env.reset() again on its next pass.
            next_state = env.reset()
            # NOTE(review): the terminal transition stores `new_state`, not None,
            # so an `s is not None` check never sees it as final.
            memory.push(state, action, (new_state), reward)
            
        if (steps_done > STEPS_BEFORE_TRAIN) and steps_done % OPTIMIZE_MODEL_STEP == 0:
            if (not i):
                print("start train")
                i = 1
            optimize_model()
        
        # Copy the online weights into the target network.
        if steps_done % TARGET_UPDATE == 0:

            target_net.load_state_dict(policy_net.state_dict())
        
        
        # NOTE(review): an episode that reaches the 5000-step cap without
        # `done` is never appended to train_rewards or printed.
        if done:
            train_rewards.append(np.sum(ep_rewards))
            print("episode:", e, "  score:", train_rewards[-1], "  mean:",np.mean(train_rewards[-100:]))
            break 



HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))



episode: 0   score: -35.633106827345784   mean: -35.633106827345784
episode: 1   score: -31.590393035937257   mean: -33.61174993164152
episode: 2   score: -32.46155846119776   mean: -33.22835277482693
episode: 3   score: -33.244723123241   mean: -33.232445361930445
episode: 4   score: -33.74310302418428   mean: -33.334576894381215
episode: 5   score: -32.2129242767625   mean: -33.14763479144476
episode: 6   score: -31.366857271687227   mean: -32.89323800290797
episode: 7   score: -32.30607985052251   mean: -32.819843233859785
episode: 8   score: -34.104053523235045   mean: -32.96253326601259
episode: 9   score: -32.7454449460149   mean: -32.940824434012825
episode: 10   score: -33.15256429719284   mean: -32.96007351248374
episode: 11   score: -33.401754206076816   mean: -32.99688023694983
episode: 12   score: -32.63658015579543   mean: -32.96916484609179
episode: 13   score: -32.91497677391513   mean: -32.965294269507744
episode: 14   score: -34.220024034150185   mean: -33.048942920483

episode: 122   score: -30.34279695939368   mean: -30.531935129017338
episode: 123   score: -28.26902598006652   mean: -30.49217251842804
episode: 124   score: -28.72490924755182   mean: -30.446100933315435
episode: 125   score: -31.28764808961553   mean: -30.428296731497777
episode: 126   score: -30.37413592181767   mean: -30.41371011105806
episode: 127   score: -29.426991337080697   mean: -30.37109913789464
episode: 128   score: -31.409055846974134   mean: -30.371755497672485
episode: 129   score: -28.46042258209644   mean: -30.326343726977353
episode: 130   score: -29.79458802211708   mean: -30.29719521356501
episode: 131   score: -28.98177645072908   mean: -30.272451428269438
episode: 132   score: -29.406255181946126   mean: -30.25167060817736
episode: 133   score: -30.064070056897155   mean: -30.22402945937654
episode: 134   score: -30.51242371541625   mean: -30.212022367492388
episode: 135   score: -30.59562032195593   mean: -30.188542060179348
episode: 136   score: -29.9655622489

episode: 242   score: -27.308081483010753   mean: -28.765162271475706
episode: 243   score: -25.928921210037057   mean: -28.727968900654663
episode: 244   score: -28.740429075874502   mean: -28.714766003003728
episode: 245   score: -27.43293017226836   mean: -28.682801296297256
episode: 246   score: -28.06267793910558   mean: -28.691544615496618
episode: 247   score: -26.309592601908676   mean: -28.668329984876113
episode: 248   score: -26.741359473111466   mean: -28.635214370861327
episode: 249   score: -27.73306225659939   mean: -28.611725007810186
episode: 250   score: -26.846684217813827   mean: -28.59545112957247
episode: 251   score: -27.23390475024599   mean: -28.56465471468166
episode: 252   score: -27.631344345788172   mean: -28.541231906738894
episode: 253   score: -27.467870532869185   mean: -28.517839555357078
episode: 254   score: -27.499230674877563   mean: -28.50182356345316
episode: 255   score: -27.534120036288297   mean: -28.483895990988252
episode: 256   score: -26.9

episode: 362   score: -25.201212898826036   mean: -26.443375988640497
episode: 363   score: -26.42998443524821   mean: -26.432522674797134
episode: 364   score: -24.126899966976552   mean: -26.389874297689975
episode: 365   score: -25.4589420898914   mean: -26.366301774712543
episode: 366   score: -24.613696220357944   mean: -26.33778953130861
episode: 367   score: -24.559308781659677   mean: -26.304227220250727
episode: 368   score: -25.810099080254503   mean: -26.27867912511221
episode: 369   score: -25.562996928294393   mean: -26.268679539825154
episode: 370   score: -25.664160847064217   mean: -26.24681611509007
episode: 371   score: -24.51889434976436   mean: -26.22915856060661
episode: 372   score: -24.103128327527156   mean: -26.20454385816795
episode: 373   score: -24.439638831680018   mean: -26.180865628149064
episode: 374   score: -24.542831166320017   mean: -26.162028157996723
episode: 375   score: -25.21897999330116   mean: -26.135019007508753
episode: 376   score: -25.8833

episode: 482   score: -24.090093683757537   mean: -24.234343210298356
episode: 483   score: -21.728828709709557   mean: -24.199067292851556
episode: 484   score: -23.388742160649272   mean: -24.18871050850868
episode: 485   score: -22.732871653909292   mean: -24.154762823371893
episode: 486   score: -23.87103572116351   mean: -24.149157913653664
episode: 487   score: -24.014966620613997   mean: -24.124945348281596
episode: 488   score: -23.8490230319952   mean: -24.102444218010067
episode: 489   score: -23.5168715897576   mean: -24.069496982408634
episode: 490   score: -22.989256098116236   mean: -24.070505663472574
episode: 491   score: -22.5530876986871   mean: -24.04515064253512
episode: 492   score: -21.732771162586708   mean: -24.011927981324412
episode: 493   score: -21.968183380153608   mean: -23.961255627248658
episode: 494   score: -23.527448018426842   mean: -23.961395712289562
episode: 495   score: -24.554841083263817   mean: -23.969909712234845
episode: 496   score: -22.963

episode: 601   score: -21.566040582938854   mean: -22.33709541237072
episode: 602   score: -21.5920852713852   mean: -22.313808888305164
episode: 603   score: -21.54908911604636   mean: -22.28787395714822
episode: 604   score: -21.888252338344195   mean: -22.286059121243767
episode: 605   score: -22.35963961997244   mean: -22.27726253743016
episode: 606   score: -20.60399185789627   mean: -22.252013845393513
episode: 607   score: -23.12982965178732   mean: -22.24169202114094
episode: 608   score: -21.8707151995642   mean: -22.235399673037602
episode: 609   score: -21.806355261158245   mean: -22.21816969771992
episode: 610   score: -22.12961555797424   mean: -22.216904740660194
episode: 611   score: -21.4917232631487   mean: -22.192638899902295
episode: 612   score: -22.763048781289427   mean: -22.1891045695572
episode: 613   score: -21.731725730606613   mean: -22.178148006445575
episode: 614   score: -22.066949772858884   mean: -22.162064572020228
episode: 615   score: -22.617382026411

episode: 720   score: -20.160002231938563   mean: -20.78991828424701
episode: 721   score: -20.555840076838845   mean: -20.773861295792916
episode: 722   score: -21.586325689013844   mean: -20.765223022805284
episode: 723   score: -20.6856783951933   mean: -20.77100450350875
episode: 724   score: -20.401480068428775   mean: -20.77335026019218
episode: 725   score: -19.376156809665808   mean: -20.751712587895046
episode: 726   score: -19.977248122010625   mean: -20.753647084569597
episode: 727   score: -18.465562896396037   mean: -20.737168517455075
episode: 728   score: -20.80045288505176   mean: -20.733734423662472
episode: 729   score: -20.137571638329543   mean: -20.711436845734013
episode: 730   score: -18.273836590209662   mean: -20.690476769256602
episode: 731   score: -20.204925179036426   mean: -20.690653689826746
episode: 732   score: -19.039483902231346   mean: -20.655608364881477
episode: 733   score: -19.681202124322073   mean: -20.6415476927796
episode: 734   score: -20.59

episode: 839   score: -19.715090053712633   mean: -19.165931391208115
episode: 840   score: -19.38287618682152   mean: -19.144350192573462
episode: 841   score: -18.410509450036365   mean: -19.138625623037917
episode: 842   score: -18.640311928074492   mean: -19.127413818374922
episode: 843   score: -18.98237086100609   mean: -19.117750316844308
episode: 844   score: -18.1671642936331   mean: -19.11126494440567
episode: 845   score: -18.438648938595655   mean: -19.09022492264533
episode: 846   score: -19.12342700867569   mean: -19.08908907777014
episode: 847   score: -17.2891760220759   mean: -19.059046931738735
episode: 848   score: -18.755687661733223   mean: -19.040849183593387
episode: 849   score: -18.762810975672892   mean: -19.03901310598745
episode: 850   score: -19.167422050183035   mean: -19.038158754985897
episode: 851   score: -18.776336859016347   mean: -19.03031992047314
episode: 852   score: -18.9978358093216   mean: -19.026125938479105
episode: 853   score: -18.45250477

episode: 958   score: -17.079496417405203   mean: -17.91826820464387
episode: 959   score: -18.216595176392225   mean: -17.91730661525375
episode: 960   score: -16.821150611771884   mean: -17.90637342376167
episode: 961   score: -17.918225856934725   mean: -17.885673647633276
episode: 962   score: -18.11044024808482   mean: -17.87676767980785
episode: 963   score: -16.743858980996023   mean: -17.85724327537042
episode: 964   score: -17.42927277192912   mean: -17.848741694412848
episode: 965   score: -17.07391053440444   mean: -17.828856247567135
episode: 966   score: -16.879189336839485   mean: -17.809289913939594
episode: 967   score: -18.706107005559385   mean: -17.806492587592178
episode: 968   score: -17.546875763820424   mean: -17.791056918665774
episode: 969   score: -17.598952480611114   mean: -17.78396786017942
episode: 970   score: -16.963174593473575   mean: -17.77721849212625
episode: 971   score: -17.52931831564227   mean: -17.76169994345137
episode: 972   score: -18.035602

episode: 1077   score: -17.84100007871337   mean: -16.635901370582825
episode: 1078   score: -18.129691339386902   mean: -16.639443425242337
episode: 1079   score: -15.990620037933578   mean: -16.62981366235356
episode: 1080   score: -15.414126305324608   mean: -16.621807829376195
episode: 1081   score: -17.63981842735501   mean: -16.619939480502293
episode: 1082   score: -15.470268397282902   mean: -16.59538132676554
episode: 1083   score: -15.944604488602305   mean: -16.59848309591335
episode: 1084   score: -15.603270110457357   mean: -16.588355789482847
episode: 1085   score: -17.04998010845571   mean: -16.586003794963368
episode: 1086   score: -18.036192517644636   mean: -16.592269327207113
episode: 1087   score: -15.14411896763077   mean: -16.58844100364353
episode: 1088   score: -15.692828500828629   mean: -16.572399261852308
episode: 1089   score: -16.569635746966053   mean: -16.56471668529419
episode: 1090   score: -16.80519732206051   mean: -16.568083267390985
episode: 1091   

episode: 1194   score: -16.21827920363103   mean: -15.771206652107816
episode: 1195   score: -15.261353277249993   mean: -15.775098870905124
episode: 1196   score: -17.00447453461954   mean: -15.785321949173527
episode: 1197   score: -14.955402966055251   mean: -15.776525683407684
episode: 1198   score: -15.126741556538152   mean: -15.77228659659376
episode: 1199   score: -14.93694111227989   mean: -15.755973999118503
episode: 1200   score: -14.158076723081932   mean: -15.732625268751711
episode: 1201   score: -16.679602922671826   mean: -15.736300192342094
episode: 1202   score: -15.394601957389552   mean: -15.717073154645366
episode: 1203   score: -15.994741727405893   mean: -15.714667086851966
episode: 1204   score: -16.026543120106812   mean: -15.71289567517967
episode: 1205   score: -15.185276728291846   mean: -15.71657412428663
episode: 1206   score: -15.75287011788975   mean: -15.704274269640045
episode: 1207   score: -13.887816990595558   mean: -15.685486284063472
episode: 1208

KeyboardInterrupt: 

In [None]:
# Did not have time to finish training.

In [None]:
# Re-print the latest episode summary; relies on `e` and `train_rewards`
# left in the kernel by the training cell.
print("episode:", e, "  score:", train_rewards[-1], "  mean:",np.mean(train_rewards[-100:]))

In [None]:
# NOTE(review): TEST_EPS and show_state are already defined in an earlier
# cell; this cell re-declares both.
TEST_EPS = 0.05

def show_state(env, step=0, info=""):
    """Render the environment into matplotlib figure 3 and refresh the cell output.

    The title is formatted from ``env.spec.id``, ``step`` and ``info``; the
    previous output is cleared before the figure is displayed.
    """
    plt.figure(3)
    plt.clf()
    plt.imshow(env.render(mode='rgb_array'))
    plt.title("%s | Step: %d %s" % (env.spec.id, step, info))
    plt.axis('off')

    display.clear_output(wait=True)
    display.display(plt.gcf())
    

def act(state, env, eps_threshold, add_noise=True):
    """Return the policy network's action for ``state``, clipped to [-1, 1].

    Args:
        state: observation passed to ``torch.tensor`` and fed to ``policy_net``.
        env: unused; kept so existing call sites keep working.
        eps_threshold: unused; kept so existing call sites keep working.
        add_noise: unused; kept so existing call sites keep working.

    Returns:
        numpy array with the network output clipped element-wise to [-1, 1].

    The network is evaluated in eval mode without gradient tracking, and its
    previous train/eval mode is restored afterwards.
    """
    was_training = policy_net.training
    policy_net.eval()
    try:
        with torch.no_grad():
            action = policy_net(torch.tensor(state).to(device).float()).cpu().numpy()
    finally:
        # Put the network back in whatever mode the caller had selected.
        policy_net.train(was_training)
    print(action)

    return np.clip(action, -1, 1)

# Play one episode with the current policy, rendering every step and
# summing the rewards received.
policy_net.eval()

state = env.reset()
total_reward = 0

for i in count():
    action = act(state,env ,TEST_EPS)
    next_t, reward, done, _ = env.step(action)
    # Accumulate the episode return so the final report is meaningful.
    total_reward += reward
    show_state(env, i)
    if done:
        break
    state = next_t


print("Total game reward : {}".format(total_reward))