## Практическое задание №3
### Метод кросс-энтропии с глубоким обучением (со средой gym)

In [1]:
import gym
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time



In [2]:
env = gym.make('CartPole-v1')
# Read the sizes from the environment instead of hard-coding them, so they
# cannot drift out of sync if the environment id above is changed.
# For CartPole-v1 these evaluate to 4 and 2. `int(...)` turns possible NumPy
# integer scalars into plain Python ints.
state_dim = int(env.observation_space.shape[0])
action_n = int(env.action_space.n)

In [15]:
# Deep cross-entropy-method agent: a neural-network policy trained on elite trajectories
class CEMDL(nn.Module):
    """Stochastic policy network trained by the cross-entropy method.

    The network maps a state vector to one logit per action. Actions are
    sampled from the softmax of those logits, and `fit` performs one
    supervised (cross-entropy) gradient step on state/action pairs taken
    from elite trajectories.

    Args:
        state_dim: Number of input features of a state vector.
        action_n: Number of discrete actions (size of the output layer).
    """

    def __init__(self, state_dim, action_n):
        super().__init__()
        self.action_n = action_n
        self.state_dim = state_dim
        self.network = nn.Sequential(
            nn.Linear(self.state_dim, 100),
            nn.ReLU(),
            nn.Linear(100, self.action_n),
        )
        # Explicit dim: relying on the implicit dimension is deprecated and
        # emits a warning on every call. The last axis holds the action logits.
        self.softmax = nn.Softmax(dim=-1)
        # NOTE: the attribute keeps its original (misspelled) name so that any
        # external code accessing `agent.optimazer` keeps working.
        self.optimazer = torch.optim.Adam(self.parameters(), lr=0.01)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_):
        """Return the raw action logits for `input_`."""
        return self.network(input_)

    def fit(self, elite_trajectories):
        """Run one optimizer step on the state/action pairs of the elite set.

        Args:
            elite_trajectories: Iterable of dicts with 'states' and 'actions'
                sequences. Pairs are formed with `zip`, so a trailing state
                without a matching action is ignored.

        Does nothing when there are no state/action pairs to learn from.
        """
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            for state, action in zip(trajectory['states'], trajectory['actions']):
                elite_states.append(state)
                elite_actions.append(action)

        # An empty batch would reach nn.Linear as a 0-element 1-D tensor and
        # raise a shape error, so skip the update instead.
        if not elite_states:
            return

        # Stack into a single ndarray first: building a tensor directly from a
        # Python list of arrays is slow.
        states_batch = torch.as_tensor(np.asarray(elite_states, dtype=np.float32))
        actions_batch = torch.as_tensor(np.asarray(elite_actions, dtype=np.int64))

        self.optimazer.zero_grad()
        logits = self(states_batch)
        loss = self.loss(logits, actions_batch)
        loss.backward()
        self.optimazer.step()

    def get_action(self, state):
        """Sample an action index from the current policy for `state`."""
        state_tensor = torch.as_tensor(np.asarray(state, dtype=np.float32))
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            probs = self.softmax(self(state_tensor)).numpy()
        # Renormalize in float64 so float32 rounding in the softmax output
        # cannot make the probabilities fail NumPy's sum-to-one check.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        return np.random.choice(self.action_n, p=probs)
        
# Roll out one episode and record it as a trajectory
def get_trajectory(env, agent, max_len=500, vizualize=False):
    """Run `agent` in `env` for at most `max_len` steps and record the episode.

    Works with both gym step APIs: the legacy one (`reset()` returns the
    observation, `step()` returns a 4-tuple) and the newer one (`reset()`
    returns `(observation, info)`, `step()` returns a 5-tuple with separate
    `terminated` / `truncated` flags).

    Args:
        env: Environment exposing `reset()`, `step(action)` and `render()`.
        agent: Object exposing `get_action(state)`.
        max_len: Maximum number of environment steps to take.
        vizualize: If true, sleep briefly and render after every step.

    Returns:
        Dict with 'states' (the state each action was chosen in), 'actions'
        (same length as 'states') and 'reward' (sum of the step rewards).
    """
    trajectory = {'states': [], 'actions': [], 'reward': 0}

    # Initial state; the newer API wraps it as (observation, info).
    reset_result = env.reset()
    if (isinstance(reset_result, tuple) and len(reset_result) == 2
            and isinstance(reset_result[1], dict)):
        state = reset_result[0]
    else:
        state = reset_result

    for _step in range(max_len):
        # Record the state together with the action chosen in it, so both
        # lists stay aligned even when the rollout stops on max_len.
        trajectory['states'].append(state)
        action = agent.get_action(state)
        trajectory['actions'].append(action)

        step_result = env.step(action)
        if len(step_result) == 5:
            state, reward, terminated, truncated, _info = step_result
            done = terminated or truncated
        else:
            state, reward, done, _info = step_result
        trajectory['reward'] += reward

        if vizualize:
            time.sleep(0.01)
            env.render()

        if done:
            break

    return trajectory

# Agent initialization and training hyperparameters
agent = CEMDL(state_dim, action_n)
q_param = 0.8       # quantile of total rewards a trajectory must exceed to be elite
iteration_n = 100   # number of evaluate/improve iterations
trajectory_n = 20   # trajectories sampled per iteration

for iteration in range(iteration_n):
    # Policy evaluation: sample trajectories with the current policy
    trajectories = [get_trajectory(env, agent, max_len=500, vizualize=False) for _ in range(trajectory_n)]
    total_rewards = [trajectory['reward'] for trajectory in trajectories]
    mean_total_reward = np.mean(total_rewards)
    print('iteration:', iteration, 'mean total reward:', mean_total_reward)

    # Policy improvement: the threshold must come from THIS iteration's
    # rewards. (Previously it read `total_reward`, which is undefined on the
    # first pass and a stale scalar afterwards.)
    quantile = np.quantile(total_rewards, q_param)
    elite_trajectories = [
        trajectory for trajectory in trajectories
        if trajectory['reward'] > quantile
    ]
    # No trajectory exceeds the threshold when, e.g., all rewards are equal.
    if elite_trajectories:
        agent.fit(elite_trajectories)

# Uncomment to run the agent with rendering enabled:
#get_trajectory(env, agent, 1, vizualize=True)

iteration: 0 mean total reward: 19.75
iteration: 1 mean total reward: 19.9
iteration: 2 mean total reward: 25.65
iteration: 3 mean total reward: 26.1
iteration: 4 mean total reward: 33.8
iteration: 5 mean total reward: 34.7
iteration: 6 mean total reward: 42.55
iteration: 7 mean total reward: 41.35
iteration: 8 mean total reward: 41.95
iteration: 9 mean total reward: 54.25
iteration: 10 mean total reward: 44.0
iteration: 11 mean total reward: 52.8
iteration: 12 mean total reward: 65.2
iteration: 13 mean total reward: 47.0
iteration: 14 mean total reward: 58.25
iteration: 15 mean total reward: 66.35
iteration: 16 mean total reward: 61.35
iteration: 17 mean total reward: 74.75
iteration: 18 mean total reward: 90.8
iteration: 19 mean total reward: 77.1
iteration: 20 mean total reward: 69.45
iteration: 21 mean total reward: 75.0
iteration: 22 mean total reward: 71.55
iteration: 23 mean total reward: 80.8
iteration: 24 mean total reward: 71.55
iteration: 25 mean total reward: 76.2
iteration