In [64]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

%run lipm_env.ipynb #imports LIPM Env

In [65]:
class Buffer:
    """Fixed-capacity FIFO store of transitions, kept as rows of a 2-D array.

    Every row is laid out as ``[state..., action, reward, next_state...]``.
    For terminal transitions (``done`` is true) the ``next_state`` columns
    hold NaN. All values are rounded to two decimals when stored.
    """

    def __init__(self, buffer_size, state_size = 2):
        """
        Args:
            buffer_size: maximum number of transitions kept; once exceeded,
                the oldest transition is dropped.
            state_size: number of entries in a state vector (default 2,
                which gives the original 6-column layout).
        """
        self.buffer_size = buffer_size
        self.state_size = state_size
        # columns: state + action + reward + next_state
        self.row_size = 2 * state_size + 2
        # Start with zero rows. A zero-filled placeholder row would make an
        # empty buffer report size 1, could be returned by sample(), and
        # would be indistinguishable from a genuine all-zero transition.
        self.buffer = np.zeros((0, self.row_size))

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def store(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one when over capacity.

        Args:
            state: state vector of length ``state_size``.
            action: scalar action.
            reward: scalar reward.
            next_state: state vector of length ``state_size``; ignored when
                ``done`` is true.
            done: when true the ``next_state`` columns are stored as NaN.
        """
        s = self.state_size
        new_data = np.empty((1, self.row_size))  # every column is assigned below
        new_data[0, :s] = state
        new_data[0, s:s + 2] = [action, reward]
        new_data[0, s + 2:] = np.nan if done else next_state
        # values are rounded to 2 decimals before being stored
        self.buffer = np.concatenate((self.buffer, np.around(new_data, 2)), axis = 0)

        # one row is added per call, so dropping a single row restores capacity
        if self.size() > self.buffer_size:
            self.buffer = self.buffer[1:]

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored rows, chosen at random, as a 2-D array.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored rows
                (raised by ``random.sample``).
        """
        return np.asarray(random.sample(list(self.buffer), batch_size))

In [66]:
class NN(nn.Module):
    """Fully connected network: two 128-unit ReLU hidden layers and a linear output."""

    def __init__(self, inp_size, out_size):
        """
        Args:
            inp_size: number of input features.
            out_size: number of output features.
        """
        super().__init__()
        hidden = 128
        # layers are created in the order l1, l2, l3
        self.l1 = nn.Linear(inp_size, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, out_size)

    def forward(self, x):
        """Apply l1 and l2 with ReLU, then l3 without an activation."""
        activation = x
        for hidden_layer in (self.l1, self.l2):
            activation = F.relu(hidden_layer(activation))
        return self.l3(activation)

In [67]:
class DQStepper:
    """Q-value model over a discrete action set with epsilon-greedy selection.

    One network maps ``[x[0], x[1], action_index]`` to a scalar value. The
    greedy action is the index with the *smallest* predicted value.
    """

    def __init__(self, no_actions = 9, lr = 1e-4, device = None):
        """
        Args:
            no_actions: number of discrete actions; actions are the indices
                ``0 .. no_actions - 1``.
            lr: learning rate passed to the SGD optimizer.
            device: torch device (or device string) holding the network;
                CPU when None.
        """
        self.device = torch.device("cpu" if device is None else device)
        self.dq_stepper = NN(3, 1).to(self.device) # state + action -> q_value
        self.optimizer = torch.optim.SGD(self.dq_stepper.parameters(), lr)

        self.no_actions = no_actions

    def predict_action_value(self, x):
        """Evaluate every action for state ``x`` and pick the minimiser.

        Args:
            x: indexable state; only ``x[0]`` and ``x[1]`` are read.

        Returns:
            Tuple ``(action_index, q_value)``: the argmin over the predicted
            values and the corresponding minimum.
        """
        # one row per candidate action: [x[0], x[1], action_index]
        x_in = np.tile([x[0], x[1], 0], (self.no_actions, 1))
        x_in[:, 2] = np.arange(self.no_actions)
        # as_tensor with explicit dtype/device replaces the legacy
        # torch.FloatTensor(...) constructor, which only supports CPU
        torch_x_in = torch.as_tensor(x_in, dtype = torch.float32, device = self.device)
        with torch.no_grad():
            # .cpu() so the numpy conversion also works for non-CPU devices
            q_values = self.dq_stepper(torch_x_in).cpu().numpy()
        return np.argmin(q_values), np.min(q_values)

    def predict_eps_greedy(self, x, eps = 0.1):
        """Return an action index chosen epsilon-greedily.

        With probability ``eps`` a uniformly random action index is returned;
        otherwise the greedy action from ``predict_action_value`` is used.
        """
        if np.random.random() > eps:
            return self.predict_action_value(x)[0]
        else:
            return np.random.randint(self.no_actions)
        
    

In [69]:
## This block collects experience for the dq stepper: it rolls out
## epsilon-greedy episodes in the LIPM environment and stores every
## transition in the replay buffer.

# Seed every RNG used here (python, numpy, torch) so that re-running the
# notebook reproduces the same network initialisation, exploration draws
# and buffer contents.
rng_seed = 0
random.seed(rng_seed)
np.random.seed(rng_seed)
torch.manual_seed(rng_seed)

buffer_size = 100
buffer = Buffer(buffer_size)

dqs = DQStepper()
batch_size = 32 ## not used yet in this cell (no optimisation step so far)
epsillon = 0.1 ## probability of taking a random action

no_epi = 100
no_steps = 20 ## number of steps simulated per episode (pendulum steps)
step_time = 0.1 ## time after which step is taken
env = LipmEnv(0.2, 0.22) # NOTE(review): meaning of the two arguments is defined in lipm_env.ipynb

for e in range(no_epi):
    # every episode is reset with the same arguments
    state = env.reset_env([0.2, 0], no_steps*step_time)
    for n in range(no_steps):
        action = dqs.predict_eps_greedy(state, epsillon)
        next_state, cost, done = env.step_env(action, step_time)
        # the environment's cost is stored in the buffer's reward column
        buffer.store(state, action, cost, next_state, done)
        state = next_state

        ## optimizing DQN
        # TODO: the optimisation step is not implemented yet; the network is
        # never updated, so this loop only fills the buffer.

        if done:
            break
            

[[0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   6.   1.31 0.14 1.06]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.14 1.06 2.   2.58  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   4.   1.31  nan  nan]
 [0.14 1.06 2.   2.58  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0.   3.   1.42  nan  nan]
 [0.2  0. 