In [None]:
from src.agents.sac_agent import SACAgent
from src.sequentiell.pipeline import TrainingPipeline
from src.compostion_buffer import CompositionReplayBuffer
from src.environment_orchestrator import env_specs

In [None]:
def create_compositions_dict(n=10):
    """Return a list of ``{'mujoco': a, 'brax': b}`` mixes on an ``n``-step grid.

    Despite the name, the result is a *list* of ``n + 1`` dicts. Entry ``i``
    holds ``brax = round(i * (1 / n), 2)`` and
    ``mujoco = round(1 - i * (1 / n), 2)``, so the list runs from all-mujoco
    (first entry) to all-brax (last entry).

    Args:
        n: Number of steps between the two extremes. ``n == 0`` raises
            ``ZeroDivisionError``; a negative ``n`` yields an empty list.

    Returns:
        The list of composition dicts, each value rounded to two decimals.
    """
    step = 1 / n
    mixes = []
    for idx in range(n + 1):
        # Same arithmetic as the share itself, so rounding is reproducible.
        brax_share = idx * step
        mixes.append({
            'mujoco': round(1 - brax_share, 2),
            'brax': round(brax_share, 2),
        })
    return mixes


# Quick visual check of the composition grid: n=8 yields nine mixes, from
# all-mujoco to all-brax, each share rounded to two decimals.
print(create_compositions_dict(n=8))


def halfchetah_experiment():
    """Run the HalfCheetah sampling-composition sweep.

    For each of 10 repetitions, iterates over every composition returned by
    ``create_compositions_dict()`` (default grid) and, for each one, builds a
    fresh ``CompositionReplayBuffer``, ``SACAgent`` and ``TrainingPipeline``
    and calls ``pipeline.run()``. Nothing is returned; any results are
    whatever ``TrainingPipeline.run`` produces as side effects.
    """
    # NOTE: this used to be ``'HalfCheetah',`` -- the trailing comma made it a
    # 1-tuple, which was then used as the ``env_specs`` key and passed as the
    # pipeline's ``env_name``. A plain string is intended.
    env_name = 'HalfCheetah'

    # Loop-invariant lookups, hoisted out of the sweep.
    spec = env_specs[env_name]
    state_dim = spec['state_dim']
    action_dim = spec['action_dim']
    compositions = create_compositions_dict()

    for _ in range(10):
        for composition in compositions:
            # A new buffer per run so no transitions carry over between
            # compositions; only ``sampling_composition`` varies.
            composition_buffer = CompositionReplayBuffer(
                capacity=500000,
                strategy='stratified',
                sampling_composition=composition,
                buffer_composition={'mujoco': 0.5, 'brax': 0.5},
                engine_counts={'mujoco': 1, 'brax': 1},
                recency_bias=3.0
            )

            sac_agent = SACAgent(
                state_dim=state_dim,
                action_dim=action_dim,
                replay_buffer=composition_buffer,
                hidden_dim=512,
                lr=3e-4,
                gamma=0.99,
                tau=0.01,
                target_entropy=-0.5*action_dim,
                grad_clip=5.0,
                warmup_steps=20000
            )

            pipeline = TrainingPipeline(
                env_name=env_name,
                batch_size=256,
                episodes=500,
                steps_per_episode=1000,
                agent=sac_agent,
                engine_dropout=True,
                drop_out_limit=0.5
            )

            pipeline.run()

[{'mujoco': 1.0, 'brax': 0.0}, {'mujoco': 0.88, 'brax': 0.12}, {'mujoco': 0.75, 'brax': 0.25}, {'mujoco': 0.62, 'brax': 0.38}, {'mujoco': 0.5, 'brax': 0.5}, {'mujoco': 0.38, 'brax': 0.62}, {'mujoco': 0.25, 'brax': 0.75}, {'mujoco': 0.12, 'brax': 0.88}, {'mujoco': 0.0, 'brax': 1.0}]
