In [None]:
import pandas as pd
import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from market_places_env import MarketPlacesEnv

# Create the market simulation environment and bring it to its initial state.
env_ = MarketPlacesEnv()
env_.reset()

# Per-agent-type configuration: action/observation sizes, number of agents,
# and per-agent containers for networks, optimizers, rewards and the
# (state, action) history collected during one round.


def _make_agent_config(agents_num, actions_dim, state_dim):
    """Return the configuration dict for one agent type.

    Args:
        agents_num: number of agents of this type.
        actions_dim: number of discrete actions (Q-network output size).
        state_dim: length of the observation vector (Q-network input size).

    The per-agent lists are sized from ``agents_num`` so the counts cannot
    drift apart. Each agent's history list is a distinct object:
    ``[[]] * n`` would make every agent share a single list.
    """
    return {
        'agents_num': agents_num,
        'actions_dim': actions_dim,
        'state_dim': state_dim,
        'networks': [],
        'opt': [],
        'rewards': [0] * agents_num,
        'cum_rewards': [0] * agents_num,
        'prev_state': [[] for _ in range(agents_num)],
        'actions': [[] for _ in range(agents_num)],
    }


dims_dict = {
    'household': _make_agent_config(agents_num=5, actions_dim=3, state_dim=33),
    'marketplace': _make_agent_config(agents_num=2, actions_dim=3, state_dim=19),
    'distributor': _make_agent_config(agents_num=1, actions_dim=3, state_dim=26),
    'firm': _make_agent_config(agents_num=2, actions_dim=3, state_dim=52),
}

# Build one Q-network and one Adam optimizer per agent.


def _build_q_network(state_dim, actions_dim, hidden_dim=32, n_hidden_layers=5):
    """Return an MLP mapping a state vector to one Q-value per action.

    Architecture: ``n_hidden_layers`` blocks of Linear + ReLU (the first one
    taking ``state_dim`` inputs, the rest ``hidden_dim``), followed by a
    linear output layer with ``actions_dim`` units.

    Every submodule gets a unique name. ``nn.Module.add_module`` replaces a
    module already registered under the same name, keeping its position,
    instead of appending a new one. Repeated names would therefore silently
    drop layers.
    """
    network = nn.Sequential()
    in_features = state_dim
    for layer_no in range(1, n_hidden_layers + 1):
        network.add_module(f'layer{layer_no}', nn.Linear(in_features, hidden_dim))
        network.add_module(f'relu{layer_no}', nn.ReLU())
        in_features = hidden_dim
    network.add_module('output', nn.Linear(in_features, actions_dim))
    return network


for kkey_, agent_spec in dims_dict.items():
    for _ in range(agent_spec['agents_num']):
        network = _build_q_network(agent_spec['state_dim'], agent_spec['actions_dim'])
        agent_spec['networks'].append(network)

        opt_ = torch.optim.Adam(network.parameters(), lr=0.001)
        agent_spec['opt'].append(opt_)

def get_agent_action(network_, state, agent_type='household', agent_idx=0, epsilon=0.1):
    """Pick an epsilon-greedy action for one agent.

    Args:
        network_: Q-network to query. If None, it is looked up in
            ``dims_dict[agent_type]['networks']`` at ``agent_idx``
            (index 0 when ``agent_idx`` is None).
        state: observation vector for the agent.
        agent_type: key of ``dims_dict``; only used for the lookup above.
        agent_idx: agent index; only used for the lookup above.
        epsilon: probability of taking a uniformly random action instead
            of the greedy one.

    Returns:
        The chosen action as a 1-based int in ``1..n_actions``
        (``compute_td_loss`` converts it back to a 0-based index).
    """
    if network_ is None:
        lookup_idx = 0 if agent_idx is None else agent_idx
        network_ = dims_dict[agent_type]['networks'][lookup_idx]

    state_ = torch.tensor(state, dtype=torch.float32)
    # Inference only: no autograd graph is needed to choose an action.
    with torch.no_grad():
        q_values_ = network_(state_).numpy()

    n_actions = q_values_.shape[-1]
    if np.random.binomial(n=1, p=epsilon):
        chosen_action = int(np.random.choice(n_actions))
    else:
        chosen_action = int(np.argmax(q_values_))

    return chosen_action + 1

# Smoke check: let the first marketplace agent pick an action for the
# current environment state and display both.
sample_network = dims_dict['marketplace']['networks'][0]
sample_state = env_.get_marketplace_state()
sample_action = get_agent_action(
    sample_network,
    sample_state,
    agent_type='marketplace',
    agent_idx=0,
)
sample_state, sample_action

def compute_td_loss(network, states, actions, rewards, next_states, is_done, gamma=0.999, check_shapes=False):
    """Compute the mean squared one-step TD error of a Q-network on a batch.

    loss = mean((Q(s, a) - target) ** 2), where
        target = r + gamma * max_a' Q(s', a')   for non-terminal transitions,
        target = r                              for terminal transitions.
    The target is treated as a constant: no gradient flows through it.

    Args:
        network: module whose output on the state batch is indexed as
            ``[batch_index, action_index]``.
        states: batch of states (anything ``torch.tensor`` accepts).
        actions: 1-based action ids, as returned by ``get_agent_action``.
        rewards: per-transition rewards.
        next_states: batch of successor states.
        is_done: per-transition terminal flags.
        gamma: discount factor.
        check_shapes: unused; kept for backward compatibility.

    Returns:
        Scalar loss tensor attached to ``network``'s autograd graph.
    """
    states = torch.tensor(states, dtype=torch.float32)
    # Actions are 1-based; convert to 0-based column indices.
    actions = torch.tensor(actions, dtype=torch.long) - 1
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)
    # torch.where needs a boolean condition tensor (uint8 is rejected).
    is_done = torch.tensor(is_done, dtype=torch.bool)

    # Q(s, a) for the actions actually taken.
    predicted_qvalues = network(states)
    predicted_qvalues_for_actions = predicted_qvalues[torch.arange(states.shape[0]), actions]

    # V*(s') = max_a' Q(s', a'); part of the target, so no gradient is needed.
    with torch.no_grad():
        next_state_values = network(next_states).max(dim=-1).values

    # Terminal transitions have no successor state: their target is just r.
    target_qvalues_for_actions = torch.where(
        is_done, rewards, rewards + gamma * next_state_values
    )

    return torch.mean((predicted_qvalues_for_actions - target_qvalues_for_actions) ** 2)

# Smoke check: one manual TD update for the first marketplace agent.
# NOTE(review): `sample_reward` is queried but the loss uses a fixed reward
# of 10; the "next state" is the same sampled state.
sample_reward = env_.marketplace_reward(agent_idx=0)
sample_opt = dims_dict['marketplace']['opt'][0]

sample_opt.zero_grad()
sample_loss = compute_td_loss(
    sample_network,
    [sample_state],
    [sample_action],
    [10],
    [sample_state],
    [False],
)
sample_loss.backward()
sample_opt.step()

# Order in which each agent type records (state, action) pairs during one
# round; position j in an agent's history corresponds to entry j here.
_HOUSEHOLD_REGIMES = ('choose_firm', 'redistribute_demand')
_FIRM_REGIMES = (
    'wage',
    'produce',
    'redistribute_inventories',
    'prices_online_1',
    'prices_online_2',
    'price_offline',
)


def _reset_histories():
    """Give every agent a fresh, independent (state, action) history."""
    for spec in dims_dict.values():
        n_agents = spec['agents_num']
        spec['actions'] = [[] for _ in range(n_agents)]
        spec['prev_state'] = [[] for _ in range(n_agents)]


def _choose_and_record(agent_type, agent_idx, state):
    """Choose an action for one agent and append (state, action) to its history."""
    spec = dims_dict[agent_type]
    action = get_agent_action(
        spec['networks'][agent_idx],
        state,
        agent_type=agent_type,
        agent_idx=agent_idx,
    )
    spec['prev_state'][agent_idx].append(state)
    spec['actions'][agent_idx].append(action)
    return action


def _agent_reward(env, agent_type, agent_idx):
    """Ask the environment for one agent's current reward."""
    if agent_type == 'household':
        return env.household_reward(agent_idx=agent_idx)
    if agent_type == 'marketplace':
        return env.marketplace_reward(agent_idx=agent_idx)
    if agent_type == 'firm':
        return env.firm_reward(agent_idx=agent_idx)
    if agent_type == 'distributor':
        return env.distributor_reward()
    raise KeyError(f'unknown agent type: {agent_type!r}')


def _agent_state(env, agent_type, agent_idx, step_idx):
    """Return an agent's current observation for the decision stored at history position step_idx."""
    if agent_type == 'household':
        return env.get_household_state(agent_idx=agent_idx, regime=_HOUSEHOLD_REGIMES[step_idx])
    if agent_type == 'marketplace':
        return env.get_marketplace_state(agent_idx=agent_idx)
    if agent_type == 'firm':
        return env.get_firm_state(agent_idx=agent_idx, regime=_FIRM_REGIMES[step_idx])
    if agent_type == 'distributor':
        return env.get_distributor_state()
    raise KeyError(f'unknown agent type: {agent_type!r}')


def _train_on_round(env, done):
    """Apply one TD update per recorded (state, action) pair of every agent."""
    for agent_type, spec in dims_dict.items():
        for ind in range(spec['agents_num']):
            network = spec['networks'][ind]
            opt = spec['opt'][ind]
            # Reward collected for this round by generate_session (Step 9);
            # reused instead of querying the environment a second time.
            reward = spec['rewards'][ind]
            history = zip(spec['prev_state'][ind], spec['actions'][ind])
            for j, (prev_state, action) in enumerate(history):
                opt.zero_grad()
                cur_state = _agent_state(env, agent_type, ind, j)
                loss = compute_td_loss(
                    network, [prev_state], [action], [reward], [cur_state], [done]
                )
                loss.backward()
                opt.step()


def generate_session(env, train=True):
    """Play one episode of the multi-agent market game, optionally training agents.

    Each round runs the fixed sequence of agent decisions (Steps 1-8),
    collects every agent's reward (Step 9), optionally performs TD updates,
    then applies the environment's stochastic shocks (Step 10). The episode
    ends once ``env.current_episode`` exceeds ``env.episodes_horizon``.

    Args:
        env: environment instance; it is reset at the start.
        train: if True, update every agent's Q-network after each round.

    Returns:
        Sum of all agents' rewards over all rounds of the episode.
    """
    total_reward = 0
    env.reset()
    _reset_histories()
    done = False

    n_households = dims_dict['household']['agents_num']
    n_marketplaces = dims_dict['marketplace']['agents_num']
    n_firms = dims_dict['firm']['agents_num']

    while not done:
        # Step 1: marketplaces act.
        for mm_ind in range(n_marketplaces):
            mm_state = env.get_marketplace_state(agent_idx=mm_ind)
            mm_action = _choose_and_record('marketplace', mm_ind, mm_state)
            env.marketplace_step(mm_action, agent_idx=mm_ind)

        # Step 2: the distributor acts.
        d_state = env.get_distributor_state()
        d_action = _choose_and_record('distributor', 0, d_state)
        env.distributor_step(d_action)

        # Step 3: households choose a firm.
        for h_ind in range(n_households):
            h_state = env.get_household_state(agent_idx=h_ind, regime='choose_firm')
            h_action = _choose_and_record('household', h_ind, h_state)
            env.household_step(agent_idx=h_ind, regime='choose_firm', action_=h_action)

        # Steps 4-6: wages/bonuses, production, inventory redistribution.
        # All firms finish one regime before any firm starts the next.
        for regime in _FIRM_REGIMES[:3]:
            for f_ind in range(n_firms):
                f_state = env.get_firm_state(agent_idx=f_ind, regime=regime)
                f_action = _choose_and_record('firm', f_ind, f_state)
                env.firm_step(agent_idx=f_ind, regime=regime, action_=f_action)

        # Step 7: each firm updates its three prices in turn.
        for f_ind in range(n_firms):
            for regime in _FIRM_REGIMES[3:]:
                f_state = env.get_firm_state(agent_idx=f_ind, regime=regime)
                f_action = _choose_and_record('firm', f_ind, f_state)
                env.firm_step(agent_idx=f_ind, regime=regime, action_=f_action)

        # Step 8: households adapt their preferences.
        for h_ind in range(n_households):
            h_state = env.get_household_state(agent_idx=h_ind, regime='redistribute_demand')
            h_action = _choose_and_record('household', h_ind, h_state)
            env.household_step(agent_idx=h_ind, regime='redistribute_demand', action_=h_action)

        # Step 9: collect rewards for every agent type, marketplaces included.
        for agent_type, spec in dims_dict.items():
            for ind in range(spec['agents_num']):
                rew = _agent_reward(env, agent_type, ind)
                spec['cum_rewards'][ind] += rew
                spec['rewards'][ind] = rew
                total_reward += rew

        # Debug sanity check.
        # NOTE(review): presumably marketplace_stock has one entry per marketplace; confirm.
        if len(env.marketplace_stock) != n_marketplaces:
            print("WTF")

        if train:
            # NOTE(review): `done` is only updated at the end of the round,
            # so the final round is still trained with is_done=False.
            _train_on_round(env, done)

        # Clear the per-round histories.
        _reset_histories()

        # Step 10: technological shocks.
        env.stochastic_step()

        # Advance the round counter and check the stopping condition.
        env.current_episode += 1
        if env.current_episode > env.episodes_horizon:
            done = True

    return total_reward

# To train for many epochs, replace the single call below with a loop:
# for epoch in range(200):
#     generate_session(env_)

generate_session(env_)