In [None]:
import os
from datetime import datetime

import numpy as np
import torch
import wandb
from tqdm import trange

import scipy.io as sio

#rllib with ray is the primary framework that will be used for imitation learning 
from ray import air, tune, rllib

#RL option after imitation learning is done
#from agilerl.algorithms.ppo import PPO
#from agilerl.training.train_on_policy import train_on_policy
#from agilerl.utils.utils import create_population, make_skill_vect_envs, make_vect_envs
#from agilerl.wrappers.learning import Skill

from ray.rllib.algorithms.bc import BCConfig

In [None]:
# Load the MATLAB .mat files used by the rest of the notebook.
# NOTE(review): scipy.io.loadmat requires a file name as its first argument, so
# both calls below raise TypeError until the paths are filled in.
# NOTE(review): presumably `x` is the image data and `y` the expert RDMs (the
# notes further down describe the expert data as "input: images, output: rdm")
# -- TODO confirm. loadmat returns a dict keyed by MATLAB variable name, so the
# arrays still have to be pulled out of `x` / `y` before use.

x = sio.loadmat()
y = sio.loadmat()

In [None]:

import gym
from gym import spaces
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class CichyEnv(MultiAgentEnv):
    """Two-agent ("IT" and "EVC") environment that walks through a fixed image list.

    Every episode presents each image exactly once. On each step both agents
    observe the current image together with the signal the other agent emitted
    on the previous step. The action each agent takes for an image is recorded;
    once every image has been acted on, the recorded actions are converted into
    a simulated RDM (1 - cosine similarity for every image pair) and each agent
    is rewarded with the negative sum of squared differences to its expert RDM.
    All earlier steps yield a reward of 0.0.

    Args:
        images: Indexable, non-empty sequence of images (one per step).
            Presumably uint8 arrays matching the "image" observation space
            -- TODO confirm.
        expert_rdms: Object indexable by agent id ("IT" / "EVC") that yields the
            expert RDM for that agent; it must be broadcastable against a
            (num_images, num_images) array.

    Raises:
        ValueError: If ``images`` is empty.
    """

    def __init__(self, images, expert_rdms):
        super().__init__()
        if len(images) == 0:
            raise ValueError("CichyEnv needs at least one image")

        self.images = images
        self.expert_rdms = expert_rdms
        self.num_agents = 2
        self.agent_ids = ["IT", "EVC"]
        # RLlib's multi-agent env checks read the agent ids from `_agent_ids`.
        self._agent_ids = set(self.agent_ids)

        self.observation_space = spaces.Dict({
            # this shape should actually probably be 175, 175, 3
            "image": spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8),
            "other_action": spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32),
        })

        # Continuous: 1. expending activity units ("IAU") 2. sending a signal
        # ("other_action"). The earlier flat Box(2,) definition was dead code:
        # it was overwritten immediately by this Dict space.
        self.action_space = spaces.Dict({
            "IAU": spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32),
            "other_action": spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32),
        })

        self.current_step = 0
        self.actions = {agent_id: [] for agent_id in self.agent_ids}
        self.state = None

    def reset(self):
        """Start a new episode at the first image and forget recorded actions.

        Returns:
            Dict mapping each agent id to its first observation; the
            "other_action" entry is a zero signal because nobody has acted yet.
        """
        self.current_step = 0
        self.actions = {agent_id: [] for agent_id in self.agent_ids}
        self.state = {
            agent_id: self._observation(np.zeros(1, dtype=np.float32))
            for agent_id in self.agent_ids
        }
        return self.state

    def step(self, action_dict):
        """Record both agents' actions for the current image and advance.

        Args:
            action_dict: Mapping from agent id to that agent's action. Both
                agents must be present. An action is either a dict with "IAU"
                and "other_action" entries (the declared action space) or a
                flat 2-element array ``[IAU, signal]``.

        Returns:
            ``(obs, rewards, dones, infos)`` dicts keyed by agent id, with the
            extra ``"__all__"`` key in ``dones``.
        """
        # The actions just received answer the image shown before the increment.
        for agent_id in self.agent_ids:
            self.actions[agent_id].append(self._action_vector(action_dict[agent_id]))

        self.current_step += 1
        # The episode only ends once *every* image has an action; the RDM
        # reward needs one action per image.
        episode_over = self.current_step >= len(self.images)
        final_rewards = self._calculate_rewards() if episode_over else None

        obs, rewards, dones, infos = {}, {}, {}, {}
        for agent_id in self.agent_ids:
            other_agent_id = "EVC" if agent_id == "IT" else "IT"
            obs[agent_id] = self._observation(self._signal(action_dict[other_agent_id]))
            rewards[agent_id] = final_rewards[agent_id] if episode_over else 0.0
            dones[agent_id] = episode_over
            infos[agent_id] = {}

        dones["__all__"] = all(dones.values())
        self.state = obs
        return obs, rewards, dones, infos

    def _observation(self, other_signal):
        """Build one agent's observation for the current step."""
        # On the terminal step `current_step` equals len(images); clamp so the
        # final observation re-uses the last image instead of raising IndexError.
        image_index = min(self.current_step, len(self.images) - 1)
        return {"image": self.images[image_index], "other_action": other_signal}

    @staticmethod
    def _signal(action):
        """Extract the (1,)-shaped float32 signal component of an action."""
        if isinstance(action, dict):
            return np.asarray(action["other_action"], dtype=np.float32).reshape(1)
        # Flat layout: index 0 is the activity unit, index 1 is the signal.
        return np.asarray(action, dtype=np.float32).reshape(-1)[1:2]

    @staticmethod
    def _action_vector(action):
        """Flatten an action into a 1-D float vector ``[IAU..., signal...]``."""
        if isinstance(action, dict):
            parts = [np.ravel(action["IAU"]), np.ravel(action["other_action"])]
            return np.concatenate(parts).astype(np.float64)
        return np.asarray(action, dtype=np.float64).ravel()

    def _calculate_rewards(self):
        """Score each agent's recorded actions against its expert RDM.

        For every agent, builds a simulated RDM whose (i, j) entry is
        ``1 - cosine_similarity(action_i, action_j)`` (diagonal fixed at 0) and
        returns the negative sum of squared differences to the expert RDM.
        This is a single pass that produces the simulated RDM for one subject.

        Returns:
            Dict mapping agent id to a float reward.

        Raises:
            ValueError: If an agent has not acted exactly once per image.
        """
        num_images = len(self.images)
        rewards = {}
        for agent_id in self.agent_ids:
            recorded = self.actions[agent_id]
            if len(recorded) != num_images:
                raise ValueError(
                    f"agent {agent_id!r} has {len(recorded)} recorded actions, "
                    f"expected one per image ({num_images})"
                )
            actions = np.asarray(recorded, dtype=np.float64).reshape(num_images, -1)

            norms = np.linalg.norm(actions, axis=1, keepdims=True)
            # An all-zero action has no direction; treat its similarity to
            # everything as 0 rather than dividing by zero.
            norms[norms == 0.0] = 1.0
            unit_actions = actions / norms

            simulated_rdm = 1.0 - unit_actions @ unit_actions.T
            np.fill_diagonal(simulated_rdm, 0.0)

            expert_rdm = np.asarray(self.expert_rdms[agent_id])
            rewards[agent_id] = float(-np.sum((simulated_rdm - expert_rdm) ** 2))
        return rewards





MVP(Minimum viable product) checklist:
1. load rdm and image data: in progress

2. create imitation learner: in progress
 - define observation space: in progress
 - define action space: in progress
 - definitively settle on imitation learner architecture: DONE 

 RL architecture: MARWIL

 It supports multi-agent RL and supports continuous action and observation spaces.

3. create environment: in progress

4. load the model after imitation learning training: 
note: we can possibly use agilerl instead of rllib for everything after step 4
5. load the model and proceed with standard RL 

STRETCH
1. create more agents to carry out simulations on more granular level. 

MISC

observation space: images, actions from other agents

action space: expending activity units, sending signal observable by other agent that also expends an activity unit

expert data: input: images, output: rdm

training stage: use rllib for imitation learning

save model in .pt after training

prediction stage: use agilerl for prediction

why is more RL needed beyond just the imitation learning? because the agents will also be able to interact with each other. This is not captured in the dataset that is being fed to the neural network 

The RL and imitation learning can be done in one fell swoop with the MARWIL architecture, since it becomes a hybrid RL and imitation-learning architecture by varying the beta parameter.

environment

class

step()

reset()



In [None]:
# Behavior-cloning (BC) configuration: learning rate / discount set via
# training(), offline episodes read from the CartPole sample JSON.
# Note that later cells rebind `config` to a MARWIL configuration.
config = (
    BCConfig()
    .training(lr=1e-5, gamma=0.99)
    .offline_data(input_="./rllib/tests/data/cartpole/large.json")
)



In [None]:
def central_critic_observer(agent_obs, **kw):
    """Rewrites the agent obs to include opponent data for training.

    Works for any pair of agent ids (e.g. ``0``/``1`` or ``"IT"``/``"EVC"``)
    instead of assuming the ids are literally ``0`` and ``1``.

    Args:
        agent_obs: Observations for exactly two agents, either a mapping keyed
            by agent id or a two-element sequence (indexed 0 and 1).
        **kw: Accepted for call compatibility and ignored.

    Returns:
        Dict keyed by the same agent ids. Each value is a dict with
        ``"own_obs"``, ``"opponent_obs"`` and ``"opponent_action"``; the latter
        is the placeholder ``0`` (the original note says it is filled in by
        ``FillInActions``, which is not defined in this notebook).

    Raises:
        ValueError: If ``agent_obs`` does not describe exactly two agents.
    """
    if hasattr(agent_obs, "keys"):
        agent_ids = list(agent_obs.keys())
    else:
        agent_ids = list(range(len(agent_obs)))

    if len(agent_ids) != 2:
        raise ValueError(
            f"central_critic_observer expects exactly two agents, got {len(agent_ids)}"
        )

    first_id, second_id = agent_ids
    return {
        own_id: {
            "own_obs": agent_obs[own_id],
            "opponent_obs": agent_obs[opponent_id],
            "opponent_action": 0,  # placeholder, see docstring
        }
        for own_id, opponent_id in ((first_id, second_id), (second_id, first_id))
    }

In [None]:
    
import ray
from ray import air, tune
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.tune.registry import register_env

# RLlib builds environments itself, so it needs a registered creator (not an
# already-constructed instance) plus the constructor arguments in `env_config`.
CICHY_ENV_NAME = "cichy_env"
register_env(
    CICHY_ENV_NAME,
    lambda env_config: CichyEnv(env_config["images"], env_config["expert_rdms"]),
)

# NOTE(review): `stop` was used below without ever being defined. The stopping
# criterion is a free choice, so a small placeholder budget is set here --
# adjust as needed.
stop = {"training_iteration": 10}

config = MARWILConfig()
# Print out some default values.
print(config.beta)
# Update the config object.
config = config.training(
    lr=tune.grid_search([0.001, 0.0001]), beta=0.75, gamma=0.99
)
# Set the config object's data path.
# Run this from the ray directory root.
config = config.offline_data(
    input_=["./rllib/tests/data/cartpole/large.json"])
# NOTE(review): `observer_space` and `action_space` are not defined anywhere in
# the visible notebook. They must describe what `central_critic_observer`
# emits and what CichyEnv accepts, so they are deliberately not guessed here;
# define them before running this cell.
config = config.multi_agent(
    policies={
        "pol1": (None, observer_space, action_space, {}),
        "pol2": (None, observer_space, action_space, {}),
    },
    # CichyEnv's agent ids are "IT" and "EVC"; comparing against 0 would send
    # both agents to "pol2".
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
        "pol1" if agent_id == "IT" else "pol2"
    ),
    observation_fn=central_critic_observer,
)
# Set the config object's env, used for evaluation.
# NOTE(review): presumably `x` holds the images and `y` the per-agent expert
# RDMs from the data-loading cell; the arrays still need to be extracted from
# the loadmat dicts -- TODO confirm.
config = config.environment(
    env=CICHY_ENV_NAME,
    env_config={"images": x, "expert_rdms": y},
)
# Use to_dict() to get the old-style python config dict
# when running with tune.

# Disabled scheduler sketches, kept for reference:
#
# pbt_scheduler = PopulationBasedTraining(
#     time_attr='training_iteration',
#     metric='episode_reward_mean',#'loss',
#     mode='min',
#     perturbation_interval=1,
#     hyperparam_mutations={
#         #"lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
#         "alpha": tune.uniform(0.0, 1.0),
#     }
# )
# pb2_scheduler = PB2(
#     time_attr='training_iteration',#'time_total_s',
#     metric='episode_reward_mean',#'mean_accuracy',
#     mode='min',
#     perturbation_interval=600.0,
#     hyperparam_bounds={
#         #"lr": [1e-3, 1e-5],
#         "alpha": [0.0, 1.0],
#     }
# )

tune.Tuner(
    "MARWIL",
    run_config=air.RunConfig(stop=stop, verbose=2),
    param_space=config.to_dict(),
    # A string literal used to sit here as a positional argument after the
    # keyword arguments, which is a SyntaxError. The disabled option was:
    # tune_config=tune.TuneConfig(
    #     num_samples=2,#4,
    #     scheduler=pb2_scheduler,
    #     #reuse_actors=True,
    # ),
).fit()

ray.shutdown()

In [None]:
 from ray.rllib.algorithms.marwil import MARWILConfig
from ray import tune
config = MARWILConfig()
# Print out some default values.
print(config.beta)  
# Update the config object.
config = config.training(beta=1.0, lr=0.00001, gamma=0.99) 
# Set the config object's data path.
# Run this from the ray directory root.
config.offline_data( 
    input_=["./rllib/tests/data/cartpole/large.json"])
config = config.multi_agent(
    policies={
            "pol1": (None, observer_space, action_space, {}),
            "pol2": (None, observer_space, action_space, {}),
        },
        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
        if agent_id == 0
        else "pol2",
        observation_fn=central_critic_observer,#lambda agent_id, episode, worker, **kwargs: "pol2"
    )
# Set the config object's env, used for evaluation.
config.environment(env=CichyEnv())  
"""


"""
# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build()  
algo.train() 

In [None]:
# Prediction: ask the trained algorithm for a single agent's action.
# The object built by the training cell is `algo`; no `agent` is defined.
# The multi-agent config only declares "pol1" and "pol2", so the policy has to
# be named explicitly (there is no default policy to fall back on).
# NOTE(review): `obs` is not defined in the visible notebook -- it must be one
# agent's observation in the format the chosen policy expects. "pol1" is the
# policy mapped to the first agent; use "pol2" for the other one.
action = algo.compute_single_action(obs, policy_id="pol1")