In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

In [11]:
from QNetwork import QNetwork
from ReplayBuffer import ReplayBuffer

In [12]:
# Training hyperparameters for the DQN agent (read by the Agent class below).
LR = 5e-4               # learning rate passed to the Adam optimizer
GAMMA = 0.99            # discount factor applied to bootstrapped future rewards
TAU = 1e-3              # interpolation factor for the soft target-network update
BATCH_SIZE = 64         # number of transitions per sampled minibatch
BUFFER_SIZE = 100_000   # replay buffer capacity
UPDATE_EVERY = 4        # learn once every this many calls to Agent.step

In [13]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if present, else CPU

In [None]:
class Agent(object):
    """DQN agent with a local/target Q-network pair and experience replay.

    Every call to :meth:`step` stores a transition in a ``ReplayBuffer``.
    Every ``UPDATE_EVERY`` calls, once the buffer holds at least ``BATCH_SIZE``
    transitions, a minibatch is sampled and :meth:`learn` fits the local
    network to TD targets computed with the target network. The target
    network then follows the local one via a soft update.
    """

    def __init__(self, state_size, action_size, seed):
        """Create the local/target Q-networks, the optimizer and the replay memory.

        Args:
            state_size: state dimension forwarded to ``QNetwork``.
            action_size: number of discrete actions.
            seed: seed for Python's ``random`` module; also forwarded to
                ``QNetwork`` and ``ReplayBuffer``.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so store the seed value itself.
        random.seed(seed)
        self.seed = seed

        self.qn_local = QNetwork(state_size, action_size, seed=seed).to(device)
        self.qn_target = QNetwork(state_size, action_size, seed=seed).to(device)
        # Start the target as an exact copy of the local network instead of
        # relying on both constructors yielding identical initial weights.
        self.qn_target.load_state_dict(self.qn_local.state_dict())
        self.optimizer = optim.Adam(params=self.qn_local.parameters(), lr=LR)

        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Step counter modulo UPDATE_EVERY; learning happens when it wraps to 0.
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition and, every ``UPDATE_EVERY`` steps, learn from a minibatch.

        Learning is skipped while the memory holds fewer than ``BATCH_SIZE``
        transitions. Sampling a minibatch larger than the buffer would fail;
        e.g. ``random.sample`` raises ``ValueError``.
        """
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        # NOTE(review): assumes ReplayBuffer implements __len__ (number of stored transitions).
        if self.t_step == 0 and len(self.memory) >= BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Pick an action with an epsilon-greedy policy over the local Q-network.

        Args:
            state: numpy array for the current state; it is cast to float32
                and a leading batch dimension is added before the forward pass.
            eps: probability of choosing a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() > eps:
            state_t = torch.from_numpy(state).float().unsqueeze(0).to(device)
            # Inference mode for the forward pass, then restore training mode.
            self.qn_local.eval()
            with torch.no_grad():
                action_values = self.qn_local(state_t)
            self.qn_local.train()
            return int(np.argmax(action_values.cpu().numpy()))
        # Exploring: no forward pass needed. randrange(n) draws the same value
        # that random.choice(np.arange(n)) would.
        return random.randrange(self.action_size)

    def learn(self, experiences, gamma):
        """Run one gradient step on a sampled minibatch, then soft-update the target network.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states, dones)``
                of tensors, as returned by ``ReplayBuffer.sample()``.
                ``actions`` is used as a ``gather`` index along dim 1, so it
                is assumed to be a long tensor shaped ``(batch, 1)``; confirm
                against ReplayBuffer.
            gamma: discount factor applied to the bootstrapped next-state value.
        """
        states, actions, rewards, next_states, dones = experiences

        # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap
        # term multiplied by (1 - dones). No gradients flow into the target net.
        with torch.no_grad():
            q_targets_next = self.qn_target(next_states).max(dim=1, keepdim=True).values
            q_targets = rewards + gamma * q_targets_next * (1 - dones)

        # Q-values of the actions actually taken.
        q_expected = self.qn_local(states).gather(1, actions)

        loss = F.mse_loss(q_expected, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(local_model=self.qn_local, target_model=self.qn_target, tau=TAU)

    def soft_update(self, local_model, target_model, tau):
        """Blend local-model parameters into the target model in place.

        For each parameter pair: ``target = tau * local + (1 - tau) * target``.
        Parameters are paired positionally via ``parameters()``, so both models
        are assumed to share the same architecture.
        """
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.copy_(tau * local_param + (1.0 - tau) * target_param)