## DQN

In [63]:
import torch.nn as nn
import torch
import numpy as np
import random

class Qfunction(nn.Module):
    """Fully connected network that maps states to one Q-value per action.

    The network has two hidden layers of 64 units, each followed by a ReLU,
    and a linear output layer with ``action_dim`` units.
    """

    def __init__(self, state_dim, action_dim):
        """Create the three linear layers and the shared ReLU activation."""
        super().__init__()
        hidden_size = 64
        self.linear_1 = nn.Linear(state_dim, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, action_dim)
        self.activation = nn.ReLU()

    def forward(self, states):
        """Return the output of the last linear layer for ``states``."""
        first_hidden = self.activation(self.linear_1(states))
        second_hidden = self.activation(self.linear_2(first_hidden))
        return self.linear_3(second_hidden)
        

In [78]:
class DQN:
    """Deep Q-learning agent with an epsilon-greedy policy and replay memory.

    Every call to :meth:`fit` stores one transition and, once the memory
    holds more than ``batch_size`` transitions, performs a single gradient
    step on a uniformly sampled mini-batch.

    Args:
        state_dim: Size of a single state vector.
        action_dim: Number of discrete actions.
        gamma: Discount factor used in the TD target.
        batch_size: Number of transitions sampled per gradient step.
        lr: Learning rate of the Adam optimizer.
        eps_decrease: Amount subtracted from epsilon after each gradient step.
        eps_min: Lower bound below which epsilon is never decreased.
    """

    def __init__(self, state_dim, action_dim, gamma = .99, batch_size=64, lr=.001, eps_decrease = .01, eps_min = .01):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.q_function = Qfunction(state_dim, action_dim)
        self.epsilon = 1  # start fully exploratory
        self.eps_decrease = eps_decrease
        self.eps_min = eps_min
        # Replay memory: list of [state, action, reward, done, next_state].
        self.memory = []
        self.batch_size = batch_size
        self.gamma = gamma
        self.optimizer = torch.optim.Adam(self.q_function.parameters(), lr = lr)

    def get_action(self, state):
        """Sample an action from the epsilon-greedy distribution for ``state``.

        Each action gets probability ``epsilon / action_dim``; the action with
        the largest Q-value additionally gets ``1 - epsilon``.
        """
        # Action selection never needs gradients, so skip building the graph.
        with torch.no_grad():
            q_values = self.q_function(torch.as_tensor(state, dtype=torch.float32))

        greedy_action = int(torch.argmax(q_values))
        probs = np.full(self.action_dim, self.epsilon / self.action_dim)
        probs[greedy_action] += 1 - self.epsilon

        return np.random.choice(self.action_dim, p=probs)

    def get_batch(self):
        """Sample ``batch_size`` transitions and return them as tensors.

        Returns:
            Tuple ``(states, actions, rewards, dones, next_states)``; actions
            are int64, every other tensor is float32.
        """
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, dones, next_states = zip(*batch)

        # Stack through NumPy first: building a tensor directly from a list
        # of arrays falls back to a slow element-by-element conversion.
        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(np.array(actions), dtype=torch.int64)
        rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32)
        dones = torch.as_tensor(np.array(dones), dtype=torch.float32)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)

        return states, actions, rewards, dones, next_states

    def fit(self, state, action, reward, done, next_state):
        """Store one transition and, if enough data is stored, train on a batch.

        The loss is the mean squared error between ``Q(s, a)`` and the TD
        target ``r + gamma * (1 - done) * max_a' Q(s', a')``. After a gradient
        step epsilon is reduced by ``eps_decrease`` but never below
        ``eps_min``.
        """
        self.memory.append([state, action, reward, done, next_state])

        if len(self.memory) <= self.batch_size:
            return

        states, actions, rewards, dones, next_states = self.get_batch()

        # Targets are treated as constants, so compute them without autograd.
        with torch.no_grad():
            next_q_max = torch.max(self.q_function(next_states), dim=1).values
            targets = rewards + self.gamma * (1 - dones) * next_q_max

        # Pick Q(s, a) for the action actually taken in each sampled row.
        row_index = torch.arange(len(actions))
        q_values = self.q_function(states)[row_index, actions]

        loss = torch.mean((q_values - targets) ** 2)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.eps_min:
            # Clamp so a large step cannot undershoot eps_min (or go negative,
            # which would make the action probabilities invalid).
            self.epsilon = max(self.eps_min, self.epsilon - self.eps_decrease)
        
        

In [80]:
import gym

# Build the CartPole environment and size the agent from its spaces.
env = gym.make('CartPole-v1')
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

agent = DQN(state_dim, action_dim)

# Training budget: number of episodes and the step cap per episode.
episode_n, t_max = 100, 500

for i in range(episode_n):
    state, total_reward = env.reset(), 0

    for t in range(t_max):
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)

        # Learn from the transition just observed, then advance the state.
        agent.fit(state, action, reward, done, next_state)
        total_reward += reward
        state = next_state

        if done:
            break

    # Report the return accumulated over the finished episode.
    print(total_reward)

15.0
11.0
34.0
13.0
24.0
10.0
13.0
12.0
10.0
10.0
13.0
10.0
10.0
10.0
9.0
10.0
10.0
10.0
10.0
11.0
11.0
14.0
18.0
9.0
11.0
24.0
19.0
18.0
18.0
17.0
15.0
22.0
23.0
22.0
14.0
21.0
53.0
47.0
30.0
43.0
132.0
58.0
121.0
72.0
89.0
40.0
44.0
78.0
44.0
54.0
106.0
76.0
181.0
142.0
156.0
172.0
199.0
226.0
227.0
295.0
184.0
172.0
168.0
192.0
184.0
174.0
159.0
163.0
170.0
179.0
172.0
178.0
178.0
161.0
172.0
156.0
158.0
156.0
170.0
153.0
177.0
159.0
165.0
166.0
137.0
173.0
151.0
163.0
186.0
174.0
157.0
200.0
169.0
175.0
167.0
163.0
192.0
239.0
213.0
203.0
