<a href="https://colab.research.google.com/github/chungsp2003/nv-blockchain-bluemix/blob/master/agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
pip install wandb



In [0]:
pip install Q

Collecting Q
  Downloading https://files.pythonhosted.org/packages/53/bc/51619d89e0bd855567e7652fa16d06f1ed36a85f108a7fe71f6629bf719d/q-2.6-py2.py3-none-any.whl
Installing collected packages: Q
Successfully installed Q-2.6


In [0]:
"""
Deep Q Learning (2020.4.12)
"""
import torch.optim as optim
import torch
import wandb
import numpy as np
import random

from model import Q

class Agent:
    def __init__(self,
        env,
        hidden_dim,
        lr,
        log_every=200,
        buffer_size = 100000,
        epsilon = 0.3,
        batch_size=512,
        explore_to=200,
        train_from=100,
        update_target_every=50,
        ):
        """
        Lovely Class
        """
        self.q = Q(hidden_dim)
        self.q_target = Q(hidden_dim)
        self.q_target.load_state_dict(self.q.state_dict())
        self.optimizer = optim.Adam(self.q.parameters(), lr=lr)
        self.discount_factor = 0.99
        self.env = env
        self.log_every = log_every
        self.buffer_size = buffer_size
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.explore_to = explore_to
        self.train_from = train_from
        self.update_target_every = update_target_every

        self.counter = 0
        self.counter_data = 0 

        self.initialize_data()

    def put_data(self,
        state,
        action,
        reward,
        next_state,
        done):

        self.data['states'][self.counter_data] = state
        self.data['actions'][self.counter_data] = action
        self.data['rewards'][self.counter_data] = reward
        self.data['next_states'][self.counter_data] = next_state
        self.data['dones'][self.counter_data] = done

        self.counter += 1
        self.counter = min(self.counter, self.buffer_size)

        self.counter_data +=1
        self.counter_data = self.counter_data % self.buffer_size

    def initialize_data(self):
        self.data = dict()
        self.data['states'] = np.zeros((self.buffer_size, 4))
        self.data['actions'] = np.zeros((self.buffer_size, 1))
        self.data['rewards'] = np.zeros((self.buffer_size, 1))
        self.data['next_states'] = np.zeros((self.buffer_size, 4))
        self.data['dones'] = np.zeros((self.buffer_size, 1))

    def interact(self, start=False):
        state = self.env.reset()
        done = False
        while not done:
            if start or random.random() < self.epsilon:
                action = random.choice(range(2))
            else:
                action_values= self.q(torch.tensor(state,
                                                 dtype=torch.float32).unsqueeze(0))
                action = torch.argmax(action_values).item()

            next_state, reward, done, _ = self.env.step(action)

            self.put_data(state,
                          action,
                          reward,
                          next_state,
                          done * 1.)

            state = next_state

    def get_data_from_buffer(self, batch_size=32):
        indices = np.random.choice(range(self.counter), batch_size)
        
        states = self.data['states'][indices]
        actions = self.data['actions'][indices]
        rewards = self.data['rewards'][indices]
        next_states = self.data['next_states'][indices]
        dones = self.data['dones'][indices]

        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(next_states, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        return states, actions, rewards, next_states, dones

    def step(self):
        states, actions, rewards, next_states, dones = self.get_data_from_buffer(self.batch_size)

        values = self.q(states).gather(1, actions)

        with torch.no_grad():
            targets = rewards +\
                      self.discount_factor*self.q_target(next_states).max(dim=1, keepdim=True)[0]*\
                      (1-dones)

        loss = (values-targets).pow(2).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, iterations):
        for i in range(1, iterations+1):
            if i < self.explore_to:
                self.interact(True)
            else:
                self.interact(False)
            
            if i < self.train_from:
                pass
            else:
                self.step()

            if i % self.log_every == 0:
                wandb.log(self.evaluation())

            if i % self.update_target_every == 0:
                self.update_target()

    def evaluation(self):
        total_counter = []
        with torch.no_grad():
            counter = 0
            for i in range(30):
                state = self.env.reset()
                done = False
                
                while not done:
                    action_values= self.q(torch.tensor(state,
                                                         dtype=torch.float32).unsqueeze(0))
                    action = torch.argmax(action_values).item()
                    next_state, reward, done, _ = self.env.step(action)
                    counter += reward
                    state = next_state

                total_counter.append(counter)
                counter = 0

        return {
                "Reward (Meidian)": np.median(total_counter),
                "Reward (Mean)": np.mean(total_counter),
                "Reward (Max)": np.max(total_counter),
                "Reward (Standard Deviation":  np.std(total_counter),
        }

    def update_target(self):
        self.q_target.load_state_dict(self.q.state_dict())

    def save(self):
        torch.save(self.q.state_dict(), 'q.pth')
        torch.save(self.q_target.state_dict(), 'q_target.pth')
        torch.save(self.optimizer.state_dict(), 'optimizer.pth')

    def load(self):
        self.q.load_state_dict(torch.load('q.pth'))
        self.q_target.load_state_dict(torch.load('q_target.pth'))
        self.optimizer.load_state_dict(torch.load('optimizer.pth'))

ModuleNotFoundError: ignored