# Distributing QMIX with a visual feature extractor over multiple CPU cores
## Table of Contents:
- Load a config file
- Run an environment from a config file
- Spawn multiple environments
- Spawn the replay server
- Spawn the parameter server
- Spawn the learner process
     

## Loading a Config

In [None]:
from utils.read_config import merge_yaml_files, merge_dicts

# Paths of the two YAML files that are later handed to merge_yaml_files(file1, file2).
file1, file2 = "./config/default.yaml", "./config/visual_qmix.yaml"




## Running environment given a config file

### Imports

In [None]:
import ray
import mlagents
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (EngineConfigurationChannel,)
from wrappers.UnityParallelEnvWrapper_Torch import UnityWrapper
from mlagents_envs.base_env import ActionTuple
import pdb
from collections import deque
import numpy as np
import gc
from functools import partial
from utils.unity_utils import get_worker_id
import torch
from controllers.custom_controller import CustomMAC
from components.replay_buffer import EpisodeBatch
from utils.utils import OneHot, RunningMeanStdTorch
import torch.nn.functional as F
import time
from components.replay_buffer import Remote_ReplayBuffer
from components.parameter_server import ParameterServer

# @ray.remote(num_cpus = 1, num_gpus = 0.01)
class Executor:
    def __init__(self, config, worker_id):
        """Launch a Unity environment and build the acting components.

        Args:
            config: Experiment configuration dict. The keys read directly
                here are ``time_scale``, ``executable_path``,
                ``episode_limit``, ``batch_size_run`` and ``curiosity``.
                ``get_env_info`` and ``generate_scheme`` write ``obs_shape``,
                ``n_actions`` and ``state_shape`` back into it.
            worker_id: Identifier of this executor. It is only stored on the
                instance; the Unity worker id is taken from
                ``get_worker_id()``.
        """
        super().__init__()
        # Set config items
        self.time_scale = config["time_scale"]
        self.env_path = config["executable_path"]
        self.episode_limit = config["episode_limit"]
        self.config = config
        self.batch_size = config["batch_size_run"]

        # Fall back to the CPU when no CUDA device is visible so that
        # ``sync_with_parameter_server`` can still stage incoming weights.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.worker_id = worker_id

        # Set class variables
        config_channel = EngineConfigurationChannel()
        config_channel.set_configuration_parameters(time_scale=self.time_scale)

        # Best-effort cleanup in case __init__ is re-run on a live instance.
        # On a fresh instance ``self.env`` does not exist yet, so this falls
        # straight through to the message below.  ``Exception`` instead of a
        # bare ``except`` lets KeyboardInterrupt/SystemExit propagate.
        # NOTE(review): ``self.unity_env`` is never assigned in this class, so
        # the second close() can only raise AttributeError.
        try:
            self.env.close()
            self.unity_env.close()
        except Exception:
            print("No envs open")

        unity_env = UnityEnvironment(file_name=self.env_path, worker_id=get_worker_id(), seed=np.int32(0), side_channels=[config_channel])
        unity_env.reset()

        self.env = UnityWrapper(unity_env, config_channel, episode_limit=self.episode_limit)
        self.env.reset()

        # Order matters: get_env_info() fills self.env_info / config entries
        # that setup() -> generate_scheme() reads.
        self.get_env_info()
        self.setup()
        self.setup_logger()

        if self.config["curiosity"]:
            # Running statistics used by curiosity() to normalise the
            # intrinsic reward.
            self.reward_rms = RunningMeanStdTorch()


    def collect_experience(self):
        self.reset()
        episode_start = time.time()
        # global_steps = ray.get(self.parameter_server.return_environment_steps.remote())
        global_steps = 123
        terminated = False
        episode_return = 0
        

        self.mac.init_hidden(batch_size=self.batch_size)
        raw_observations = 0

        reward_episode = []
        icm_reward = None

        while not terminated:

            raw_observations = self.env._get_observations()       

            # state is determined from raw obs after feature extraction
            # normalise the obs before you save them to the replay buffer
            # if self.config["contains_state"]:
            state = self.env._get_global_state_variables()

            pre_transition_data = {
                "state": state,
                "avail_actions": self.env.get_avail_actions(),
                "obs": raw_observations
            }
            # else:
            #     pre_transition_data = {
            #         "avail_actions": self.env.get_avail_actions(),
            #         "obs": raw_observations
            #     }

            self.batch.update(pre_transition_data, ts=self.t)

            # Pass the entire batch of experiences up till now to the 
            # Receive the actions for each agent at this timestep in a batch of size 1
            # This will change depending on whether I'm using a feature extraction network or not
            actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=global_steps, test_mode=False)


            reward, terminated, env_info = self.env.step(actions[0])
            reward_episode.append(reward)


            episode_return += reward


            post_transition_data = {
                "actions": actions,
                "reward": [(reward,)],
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
                # terminated above says whether the agent terminated because they reached the end
                # of the episode
            }

            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        raw_observations = self.env._get_observations()

        last_data = {
                "state": self.env._get_global_state_variables(),
                "avail_actions": self.env.get_avail_actions(),
                "obs": raw_observations
            }
        
        self.batch.update(last_data, ts=self.t)

        actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=global_steps, test_mode=False)

        self.batch.update({"actions": actions}, ts=self.t)

        if self.config["curiosity"]:
            icm_reward = self.curiosity()
            icm_reward = np.sum(icm_reward, axis=-1)
        
        try:
            # Parameter server keeps track of global steps and episodes
            # Add the number of steps in this executor's episode to the global count
            self.parameter_server.add_environment_steps.remote(self.t)
            
            # Increment global episode count
            self.parameter_server.increment_total_episode_count.remote()

            # Accumulate mean episode reward and episode duration
            self.parameter_server.accumulate_stats.remote(sum(reward_episode), time.time() - episode_start, icm_reward)
        except:
            pass

        return self.batch
    
    def run(self, remote_buffer, parameter_server):
        """Act forever, feeding finished episodes to the replay buffer.

        Each iteration asks the parameter server for its update-step counter
        and pulls fresh weights when that counter is a multiple of 10, then
        collects one episode and inserts it into the remote buffer through
        the Ray object store.  This method never returns.

        Args:
            remote_buffer: Handle exposing ``insert_episode_batch.remote``.
            parameter_server: Handle exposing
                ``get_parameter_update_steps.remote`` and the methods used
                by ``sync_with_parameter_server`` / ``collect_experience``.
        """
        self.parameter_server = parameter_server
        self.remote_buffer = remote_buffer

        while True:
            # TODO : Sample from parameter server every few episodes to obtain up to date parameters
            update_steps = ray.get(self.parameter_server.get_parameter_update_steps.remote())
            if update_steps % 10 == 0:
                self.sync_with_parameter_server()

            # Put the episode in the object store and hand over the reference.
            batch_reference = ray.put(self.collect_experience())
            self.remote_buffer.insert_episode_batch.remote(batch_reference)



    def reset(self):
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0
        self.mac.agent.fc2.eval()
        self.mac.agent.feature_extractor.eval()

    def curiosity(self):
        """Compute the normalised ICM reward for the episode in ``self.batch``.

        The stored observations are pushed through the agent's feature
        extractor, the ICM module predicts the next features, and the
        per-element squared prediction error (scaled by ``config["eta"]``)
        is averaged over the last two dimensions.  The result is normalised
        with the running statistics in ``self.reward_rms``, written into the
        batch under ``"icm_reward"`` for timesteps ``0..self.t-1`` and also
        returned.

        Returns:
            NumPy array: the normalised intrinsic reward summed over its last
            axis.
        """
        extractor = self.mac.agent.feature_extractor
        filled_steps = self.batch.max_t_filled()
        obs_window = self.batch["obs"][:, :filled_steps]
        # assumes obs is laid out as (batch, time, agents, ...) — TODO confirm
        n_batch = obs_window.shape[0]
        n_agents = obs_window.shape[2]

        frames = obs_window.reshape(
            -1, extractor.height, extractor.width, extractor.initial_channels
        ).cuda()

        features = self.mac.agent.feature_extractor(frames)
        features = features.reshape(n_batch, filled_steps, n_agents, -1).cuda()
        next_obs, obs = self.mac.agent._build_batch_inputs(features, self.batch)

        icm_inputs = [
            obs[:, :filled_steps],
            next_obs[:, :filled_steps],
            self.batch["actions_onehot"][:, :filled_steps-1],
        ]
        real_next_obs, pred_next_obs, pred_action = self.mac.agent.icm(icm_inputs)

        eta = self.config["eta"]
        prediction_error = F.mse_loss(real_next_obs, pred_next_obs, reduction='none')
        intrinsic_reward = eta * prediction_error.mean(-1).data
        intrinsic_reward = torch.mean(intrinsic_reward, dim=-1).unsqueeze(-1)

        # Update the running statistics first, then standardise with them.
        self.reward_rms.update(intrinsic_reward)
        running_std = torch.sqrt(self.reward_rms.var)
        scaled_reward = (intrinsic_reward - self.reward_rms.mean) / running_std
        self.batch.update({"icm_reward": scaled_reward}, ts=slice(0, self.t))

        return np.sum(scaled_reward.cpu().numpy(), axis=-1)

    
    def setup(self):
        """Create the controller and the factory for per-episode batches.

        ``self.new_batch`` becomes a zero-argument callable that builds an
        ``EpisodeBatch`` on the CPU with ``batch_size_run`` entries and
        ``episode_limit + 1`` timesteps.
        """
        scheme, groups, preprocess = self.generate_scheme()
        self.mac = CustomMAC(self.config)

        episodes_per_batch = self.config["batch_size_run"]
        # One extra slot for the observation after the final transition.
        max_timesteps = self.config["episode_limit"] + 1
        self.new_batch = partial(
            EpisodeBatch,
            scheme,
            groups,
            episodes_per_batch,
            max_timesteps,
            preprocess=preprocess,
            device="cpu",
        )


    def get_env_info(self):
        self.config["obs_shape"] = self.env.obs_shape
        self.env_info = self.env.get_init_env_info()
        self.config["n_actions"] = self.env_info["n_actions"]

    def setup_logger(self):
        #TODO
        pass

    def close_env(self):
        self.env.close()

    def generate_scheme(self):
        self.config["state_shape"] = self.env_info["state_shape"]

        scheme = {
            "state": {"vshape": self.env_info["state_shape"]},
            "obs": {"vshape": self.env_info["obs_shape"], "group": "agents", "dtype": torch.float32},
            "actions": {"vshape": (1,), "group": "agents", "dtype": torch.long},
            "avail_actions": {"vshape": (self.env_info["n_actions"],), "group": "agents", "dtype": torch.int},
            "reward": {"vshape": (1,)},
            "terminated": {"vshape": (1,), "dtype": torch.uint8},
            }
        
        if self.config["curiosity"]:
            icm_reward = {"icm_reward": {"vshape": (1,)},}
            scheme.update(icm_reward)

        if self.config["useNoisy"]:
            raise NotImplementedError
        
        groups = {
        "agents": self.config["num_agents"]
        }

        preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=self.config["n_actions"])])
        }

        return scheme, groups, preprocess
    
    def retrieve_updated_config(self):
        return self.config
    
    def sync_with_parameter_server(self):
        """Overwrite the controller's weights with those held by the server.

        The server's parameters are fetched with ``ray.get`` and copied, in
        iteration order, onto ``self.mac.parameters()``.
        """
        params_reference = self.parameter_server.return_params.remote()
        server_weights = ray.get(params_reference)

        # The comments in the original note that the server stores NumPy
        # arrays, so each one is turned into a tensor on ``self.device``
        # before being copied in place.
        # NOTE(review): zip() stops at the shorter sequence, so a length
        # mismatch between the two sides goes unnoticed — confirm they match.
        for local_param, weights in zip(self.mac.parameters(), server_weights):
            local_param.data.copy_(torch.Tensor(weights).to(self.device))




# Setting Up the replay buffer given config


In [None]:
# from components.replay_buffer import Remote_ReplayBuffer, generate_replay_scheme
# from components.parameter_server import ParameterServer


# config = merge_yaml_files(file1, file2)

# workers = [Executor.remote(config, i) for i in range (config["num_executors"])]
# config_ref = workers[0].retrieve_updated_config.remote()
# config = ray.get(config_ref)

# scheme, groups, preprocess = generate_replay_scheme(config)

# remote_buffer = Remote_ReplayBuffer.remote(scheme, groups, config["buffer_size"], config["episode_limit"]+1, preprocess=preprocess, device="cpu")
# parameter_server = ParameterServer.remote()

# ray.wait([worker.run.remote(remote_buffer, parameter_server) for worker in workers])






# Setting up the Learner for Distributed Training

In [None]:
import torch
import torch.nn as nn 
from models.qmix import QMixer
from utils.utils import RunningMeanStdTorch
from torch.optim import Adam
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
import datetime
import sys

# @ray.remote(num_gpus = 0.96, num_cpus=3)
class Learner(object):
    def __init__(self, config):
        """Build the networks, optimiser, target copies and logger.

        Args:
            config: Experiment configuration dict. Keys read here are
                ``t_max``, ``beta``, ``standardise_rewards``, ``lr`` and
                ``optim_eps``; the dict is also passed to ``CustomMAC`` and
                ``QMixer``.
        """
        self.config = config
        self.device = "cuda:0"
        self.global_training_steps = config["t_max"]
        self.mac = CustomMAC(self.config)

        # Everything the optimiser updates: controller parameters first, the
        # mixer's parameters are appended below.
        self.trainable_parameters = nn.ParameterList(self.mac.parameters())

        # Weight given to the ICM forward loss in training_loop; the inverse
        # loss gets (1 - beta).
        self.beta = self.config["beta"]


        # Parameter server stuff
        # Keys of the agent's state dict; returned by return_parameter_list().
        self.parameter_server_list = list (self.mac.agent.state_dict())


        # Setup Mixer
        self.mixer = QMixer(config)
        self.trainable_parameters+= nn.ParameterList(self.mixer.parameters())
        

        # Reward Standardisation
        if self.config["standardise_rewards"]:
            self.reward_rms = RunningMeanStdTorch(shape=(1,), device="cuda:0")

        # Optimiser
        self.optimiser = Adam(params=self.trainable_parameters, lr=self.config["lr"], eps = self.config["optim_eps"])

        # Loss functions
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()

        # Target Networks:
        # Deep copies of the live networks; a little wasteful (e.g. it also
        # duplicates the action selector) but works for any MAC.
        self.target_mac = deepcopy(self.mac)
        self.target_mixer = deepcopy(self.mixer)

        # Global episode count at the time of the last target-network update.
        self.previous_target_update_episode = 0

        # Number of completed training steps; incremented in run().
        self.trainer_steps = 0

        # Move live and target networks to the GPU (see cuda()).
        self.cuda()

        # Counter incremented on every train() call.
        self.debug=0

        # Logger Stuff
        self.training_start_time = time.time()
        self.setup_writer()
        # Accumulated timings, filled by store_time_stats().
        self.time_info = {}
        self.log_stats_dict = {}

    def train(self, log_this_step = False):
        print("Getting Batch")
        if self.config["use_per"]:
            raise NotImplementedError
        else:
            # episode_sample_reference = self.remote_buffer.sample.remote(self.config["batch_size"])
            episode_sample_reference = self.remote_buffer.sample(self.config["batch_size"])
            # episode_sample = ray.get(episode_sample_reference)
        print("Got episode sample reference")

        self.debug+=1
        print(f"Sizeof ep ref: {sys.getsizeof(episode_sample_reference)}")
        # max_ep_t_reference = ray.get(episode_sample_reference.max_t_filled.remote())
        max_ep_t_reference = episode_sample_reference.max_t_filled()
        print(f"Sizeof max_ep_t_reference: {sys.getsizeof(max_ep_t_reference)}")
        print("Got max_ep_t_reference")
        # max_ep_t = episode_sample.max_t_filled()

        # max_ep_t = ray.get(max_ep_t_reference)
        max_ep_t = max_ep_t_reference
        print("Got max_ep_t")
        print("Slicing batch")
        if not self.config["random_update"]:
            # episode_sample = episode_sample[:, :max_ep_t]
            episode_sample_reference = episode_sample_reference[:, :max_ep_t]
        else:
            if max_ep_t>self.config["recurrent_sequence_length"]:
                start_idx = np.random.randint(0, max_ep_t-self.config["recurrent_sequence_length"]+1)
                # episode_sample = episode_sample[:, start_idx[0]:start_idx[0]+self.config["recurrent_sequence_length"]]
                episode_sample_reference = episode_sample_reference[:, start_idx[0]:start_idx[0]+self.config["recurrent_sequence_length"]]
            else:
                # episode_sample = episode_sample[:, :max_ep_t]
                episode_sample_reference = episode_sample_reference[:, :max_ep_t]
        print("Sliced batch")
        if self.config["use_per"]:
            # masked_td_error, mask = self.subtrain(batch, t_env, episode_num)
            # res = th.sum(masked_td_error, dim=(1,2)) / th.sum(mask, dim = (1,2))
            # res = res.cpu().detach().numpy()
            raise NotImplementedError
        else:
            print("Run Training LOOP")
            # log_dict = self.training_loop(episode_sample, log_this_step)
            # log_dict = self.training_loop(ray.put(episode_sample_reference), log_this_step)
            log_dict = self.training_loop(episode_sample_reference, num_global_episodes=123, log_this_step = log_this_step)
            print("Training LOOP Done")
        return log_dict

    def run(self, remote_buffer, parameter_server):
        """Wait until the buffer can be sampled, then train forever.

        Every training step is timed, the new weights are pushed to the
        parameter server after each step, and statistics are logged on the
        steps whose index is a multiple of ``config["log_every"]``.  This
        method never returns.

        Args:
            remote_buffer: Replay-buffer handle queried through ``.remote``.
            parameter_server: Parameter-server handle.
        """
        self.remote_buffer = remote_buffer
        self.parameter_server = parameter_server

        # Poll the buffer until it holds enough episodes for one batch,
        # printing its repr every 10000 polls.
        polls = 0
        while not ray.get(self.remote_buffer.can_sample.remote(self.config["batch_size"])):
            polls += 1
            if polls % 10000 == 0:
                print(f"Waited {polls} steps")
                print(ray.get(self.remote_buffer.__repr__.remote()))

        print("READY TO START TRAINING!!")

        while True:
            step_started_at = time.time()

            # Decided before the step so train() can collect the statistics.
            log_this_step = self.trainer_steps % self.config["log_every"] == 0

            print("Starting training step")
            log_dict = self.train(log_this_step)
            print("Training step done")
            self.trainer_steps += 1
            self.update_parameter_server()

            self.store_time_stats("Mean_training_loop_time", time.time() - step_started_at)

            if log_this_step:
                self.log_stats(log_dict)
            

        
    def setup_writer(self):
        """Create the TensorBoard writer under a timestamped results folder.

        The folder name is the current local time formatted as
        day_month_hour_minute.
        """
        run_stamp = datetime.datetime.now().strftime("%d_%m_%H_%M")
        log_directory = "".join(("results/", run_stamp, "/tb_logs"))
        self.writer = SummaryWriter(log_dir=log_directory)

    def log_stats(self, stats_to_log:dict):
        """Write training, timing and reward statistics to TensorBoard.

        All values are logged against the global environment step count
        reported by the parameter server.  Each quantity gets its own tag;
        writing several different quantities under one shared tag at the
        same step would interleave them into a single meaningless series.

        Args:
            stats_to_log: Scalars returned by ``training_loop``.
        """
        global_environment_steps = ray.get(self.parameter_server.return_environment_steps.remote())

        # Stats obtained from the trainer:
        self.writer.add_scalars("Training_Stats", stats_to_log, global_environment_steps)

        # Total elapsed training time
        self.writer.add_scalar("Time_Stats/elapsed_training_time", time.time() - self.training_start_time, global_environment_steps)

        # Mean duration of a training step; skipped until at least one
        # timing has been stored so that no KeyError / division by zero occurs.
        timed_steps = self.time_info.get("number_log_steps", 0)
        if timed_steps:
            mean_loop_time = self.time_info.get("Mean_training_loop_time", 0.0) / timed_steps
            self.writer.add_scalar("Time_Stats/mean_training_loop_time", mean_loop_time, global_environment_steps)

        # Log Executors' rewards
        mean_extrinsic_reward, mean_icm_reward, mean_ep_duration = ray.get(self.parameter_server.get_accumulated_stats.remote())
        self.writer.add_scalar("Reward_Stats/mean_extrinsic_reward", mean_extrinsic_reward, global_environment_steps)
        if mean_icm_reward is not None:
            self.writer.add_scalar("Reward_Stats/mean_icm_reward", mean_icm_reward, global_environment_steps)
        self.writer.add_scalar("Time_Stats/mean_episode_duration", mean_ep_duration, global_environment_steps)

    def periodically_print(self):
        """Print the replay buffer's repr once training has run for over 20 s.

        The original fetched the repr but discarded it, so nothing was ever
        printed despite the method's name.
        """
        if time.time() - self.training_start_time > 20:
            print(ray.get(self.remote_buffer.__repr__.remote()))

    def store_time_stats(self, key, value):
        self.time_info[key] += value
        self.time_info["number_log_steps"] +=1

    def reset_stats():
        pass
           

    def training_loop(self, batch: EpisodeBatch, num_global_episodes, log_this_step):
        """Run one gradient step of QMIX (plus the ICM losses if enabled).

        Args:
            batch: Sampled episodes, already cut in time by ``train``.
            num_global_episodes: Episode counter compared against
                ``previous_target_update_episode`` to decide when the target
                networks are refreshed.
            log_this_step: When true, a dict of scalar statistics is
                returned.

        Returns:
            A dict of statistics when ``log_this_step`` is true, otherwise
            ``None``.

        Raises:
            NotImplementedError: if ``self.mixer`` is ``None``.
        """
        print("Setting encoder to train")
        self.mac.agent.feature_extractor.train()

        print("Cudaing all parts")
        # Get the relevant quantities; everything except avail_actions drops
        # the final timestep.
        rewards = batch["reward"][:, :-1].cuda()
        actions = batch["actions"][:, :-1].cuda()
        actions_onehot = batch["actions_onehot"][:,:-1].cuda()
        terminated = batch["terminated"][:, :-1].float().cuda()
        mask = batch["filled"][:, :-1].float().cuda()
        # A step only counts if the previous step had not already terminated.
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]).cuda()
        avail_actions = batch["avail_actions"].cuda()
        print("Cudad all parts")
        if self.config["curiosity"]:
            icm_reward = batch["icm_reward"][:, :-1].cuda()

        if self.config["standardise_rewards"]:
            self.reward_rms.update(rewards)
            rewards_normed = (rewards-self.reward_rms.mean)/torch.sqrt(self.reward_rms.var)
        else:
            # Without standardisation the raw rewards are used; previously
            # this name was left undefined and the step crashed below.
            rewards_normed = rewards


        # Feature Extraction
        # assumes batch["obs"] is (B, T, N, flattened image) — TODO confirm
        B = batch["obs"].shape[0]
        T = batch["obs"].shape[1]
        N = batch["obs"].shape[2]

        observation = batch["obs"]
        print("Reshape and cuda obs")
        shaped_obs = observation.reshape(-1, self.mac.agent.feature_extractor.height,
                self.mac.agent.feature_extractor.width,
                self.mac.agent.feature_extractor.initial_channels).cuda()

        reduced_obs = self.mac.agent.feature_extractor(shaped_obs)
        reduced_obs = reduced_obs.reshape(B, T, N, -1).cuda()
        print("Extracted feats and cudad")
        # Also calculate target reduced obs, for use with the target networks
        target_reduced_obs = self.target_mac.agent.feature_extractor(shaped_obs)
        target_reduced_obs = target_reduced_obs.reshape(B,T,N,-1).cuda()
        print("Extracted target feats and cudad")

        # The observations have been passed through the feature extractor;
        # the reduced features are what the agents consume.
        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t, training=True, batch_reduced_obs=reduced_obs)
            mac_out.append(agent_outs)

        mac_out = torch.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = torch.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t, training=True, batch_reduced_obs=target_reduced_obs)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = torch.stack(target_mac_out[1:], dim=1)  # Concat across time
        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.config["double_q"]:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]


        # Mix
        if self.mixer is not None:
            # CURIOSITY
            if self.config["curiosity"]:
                next_obs, obs = self.mac.agent._build_batch_inputs(reduced_obs, batch)

                # Curiosity Loss:
                real_next_obs, pred_next_obs, pred_action = self.mac.agent.icm([obs, next_obs, batch["actions_onehot"][:, :-1]])

                icm_reward_modded = icm_reward*self.config["icm_weight"]

                rewards_total = rewards_normed + icm_reward_modded

                # Inverse model: predict the taken action.
                inverse_loss = self.ce(pred_action.contiguous().view(-1, self.config["n_actions"]), actions_onehot.contiguous().view(-1,self.config["n_actions"]))

                # Forward model: predict the next features; the real ones
                # are detached so they act as a fixed target.
                forward_loss = self.mse(
                    pred_next_obs, real_next_obs.detach()
                )
            else:
                rewards_total = rewards_normed

            # The mixer's state is built from all agents' features ...
            state = reduced_obs.view(B,T,-1)
            target_state = target_reduced_obs.view(B,T,-1)

            # ... plus the available actions at each timestep ...
            state = torch.concat([state,avail_actions.reshape([B,T,-1])], dim=-1).cuda()
            target_state = torch.concat([target_state,avail_actions.reshape([B,T,-1])], dim=-1).cuda()

            # ... plus the extra state information stored in the batch.
            state = torch.cat([state, batch["state"].cuda()], dim=-1)
            target_state = torch.cat([target_state, batch["state"].cuda()], dim=-1)

            chosen_action_qvals = self.mixer(chosen_action_qvals, state[:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals, target_state[:, 1:])
        else:
            raise NotImplementedError

        # Calculate 1-step Q-Learning targets
        targets = rewards_total + self.config["gamma"] * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask


        # Normal L2 loss, take mean over actual data
        if self.config["curiosity"]:
            qmix_loss = (masked_td_error ** 2).sum() / mask.sum()
            forward_loss = self.beta*forward_loss
            inverse_loss = (1-self.beta)*inverse_loss
            icm_loss = forward_loss + inverse_loss
            loss = qmix_loss + icm_loss
        else:
            loss = (masked_td_error ** 2).sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.trainable_parameters, self.config["grad_norm_clip"])
        self.optimiser.step()

        # Update targets:
        if (num_global_episodes - self.previous_target_update_episode) / self.config["target_update_interval"] >= 1.0:
            self._update_targets()
            self.previous_target_update_episode = num_global_episodes

        if log_this_step:
            losses_dict = {"total_loss" : loss.item()}
            if self.config["curiosity"]:
                curiosity_stats = {
                    "qmix_loss" : qmix_loss.item(),
                    "icm_loss": icm_loss.item(),
                    "forward_loss" : forward_loss.item(),
                    "inverse_loss": inverse_loss.item(),
                }
                losses_dict.update(curiosity_stats)

            mask_elems = mask.sum().item()
            # NOTE(review): this reads config["n_agents"] while
            # Executor.generate_scheme reads config["num_agents"] — confirm
            # both keys exist in the config.
            other_log_stuff = {
                "grad_norm": grad_norm,
                "td_error_abs": (masked_td_error.abs().sum().item()/mask_elems),
                "q_taken_mean": (chosen_action_qvals * mask).sum().item()/(mask_elems * self.config["n_agents"]),
                "target_mean": (targets * mask).sum().item()/(mask_elems * self.config["n_agents"])
            }
            losses_dict.update(other_log_stuff)
            return losses_dict
        



    def update_parameter_server(self):
        params = []
        state_dicts_to_save = self.mac.agent.state_dict() # returns [agent_state, feature_state]

        for param in self.params_list:
            params.append(state_dicts_to_save[param].cpu().numpy())
        self.parameter_server.update_params.remote(params)

    def return_parameter_list(self):
        return self.parameter_server_list

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        # self.logger.console_logger.info("Updated target network")

# Testing everything

In [None]:
from components.replay_buffer import Remote_ReplayBuffer, generate_replay_scheme
from components.parameter_server import ParameterServer

# Single-process smoke test: one Executor, one buffer, one parameter server
# and one Learner are created as plain local objects (no ``.remote``).
ray.init()

config = merge_yaml_files(file1, file2)


# Constructing the Executor launches the Unity environment and writes the
# environment-derived entries (obs_shape, n_actions, state_shape) into config.
worker = Executor(config, 0)
# workers = [Executor.remote(config, i) for i in range (config["num_executors"])]
config = worker.retrieve_updated_config()
# config = ray.get(config_ref)

scheme, groups, preprocess = generate_replay_scheme(config)

# episode_limit + 1 matches the batch length used by Executor.setup().
remote_buffer = Remote_ReplayBuffer(scheme, groups, config["buffer_size"], config["episode_limit"]+1, preprocess=preprocess, device="cpu")
parameter_server = ParameterServer()

learner = Learner(config)

# Wire the shared components by hand, since run() is not used here.
learner.remote_buffer = remote_buffer
learner.parameter_server = parameter_server
worker.remote_buffer = remote_buffer
worker.parameter_server=parameter_server

# all_actors = workers + [learner]

# # ray.wait([worker.run.remote(remote_buffer, parameter_server) for worker in workers])
# ray.wait([worker.run.remote(remote_buffer, parameter_server) for worker in all_actors])

# Collect one episode, store it, then run a single training step on it.
episode_batch = worker.collect_experience()

remote_buffer.insert_episode_batch(episode_batch)

learner.train(log_this_step=False)


# More Test Code

In [None]:
from controllers.custom_controller import CustomMAC
from utils.read_config import merge_yaml_files, merge_dicts
import torch.nn as nn
import sys
import torch

# Paths of the two YAML files handed to merge_yaml_files below.
file1 = "./config/default.yaml"
file2 = "./config/visual_qmix.yaml"


config = merge_yaml_files(file1, file2)

# Hard-coded stand-ins for the entries Executor.get_env_info() would
# otherwise write into the config from a live environment.
config["obs_shape"] = (84,84,3)
config["n_actions"] = 9
mac = CustomMAC(config)

sd = mac.agent.state_dict()

# named_parameters() is an iterator of (name, parameter) pairs and has no
# ``.dict`` attribute, so the mapping is built explicitly.  A comprehension
# is used instead of ``dict(...)`` because a later cell rebinds the name
# ``dict`` in this notebook.
params = {name: parameter for name, parameter in mac.agent.named_parameters()}
# # these params are to be loaded into new_mac
# server_parameters = mac.agent.state_dict()

# # print(server_parameters.values())

# for parameter_name in server_parameters:
#     print(server_parameters[parameter_name])
#     # print(parameter_value)


In [None]:
# Scratch mapping.  Named so that it does not rebind the builtin ``dict``,
# which would break every later ``dict(...)`` call in the same kernel.
scratch_params = {}
