In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
import time
from tqdm import trange

In [2]:
from neural_nets import q_network, p_network, d_network
from sample_env import EnvSampler
from replay_memory import MemoryElement, ReplayMemory
from models import DynamicEnsemble, SAC, UniformPolicy
from utils import get_memories_torch



In [3]:
# ---- Experiment configuration ------------------------------------------
# Gym environment id handed to EnvSampler in the setup cell.
env_name = 'Ant-v4'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the global NumPy and PyTorch RNGs.
# Note: no environment is seeded in this cell.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Capacity used for both the real and the model replay buffers.
replay_size = 1_000_000
# Number of start states drawn per call to roll_out.
rollout_batch_size = 100
agent_batch_size = 256

# Passed to DynamicEnsemble and used as env_number of the evaluation sampler.
num_models = 10



In [4]:
# Bookkeeping scalars. min_reward / max_reward start at +inf / -inf so that
# the first reward observed during data collection overwrites both.
# E_step starts at zero; its use is not visible in this part of the notebook.
E_step = 0
min_reward, max_reward = np.inf, -np.inf

# Sampler for the real environment and the buffer that stores what it produces.
real_env = EnvSampler(env_name)
real_memory = ReplayMemory(replay_size)

# Separate buffer for transitions generated by the learned dynamics model.
model_memory = ReplayMemory(replay_size)

# Sampler used for evaluation, created with env_number equal to num_models.
eval_env = EnvSampler(env_name, env_number = num_models)

# SAC agent and the ensemble of dynamics models, both built from real_env.
agent = SAC(real_env)
dynamic_model = DynamicEnsemble(num_models, real_env)

In [5]:
# Warm-up phase: fill the real replay buffer before any learning happens.
# Each iteration takes the first element returned by real_env.sample(),
# widens the running reward bounds with its reward, and stores it.
n_warmup_steps = 5000
for _ in trange(n_warmup_steps):
    real_memory_element = real_env.sample()[0]

    step_reward = real_memory_element.reward
    min_reward = min(step_reward, min_reward)
    max_reward = max(step_reward, max_reward)

    real_memory.push(real_memory_element)


100%|█████████████████████████████████████| 5000/5000 [00:02<00:00, 2176.58it/s]


In [6]:
# Pre-train the dynamics ensemble on the warm-up data:
# 10000 // 10 = 1000 calls to update_params, each given the real buffer.
n_model_updates = 10000 // 10
for _ in trange(n_model_updates):
    dynamic_model.update_params(real_memory)

100%|███████████████████████████████████████| 1000/1000 [02:35<00:00,  6.42it/s]


In [26]:
# Initial rollout: populate model_memory using the learned dynamics model.
# init=True selects the uniform-random-action branch of roll_out.
#
# NOTE: roll_out is defined in a later cell of this notebook (it was executed
# earlier, as In [25]). On a fresh kernel, run that cell before this one.
n_initial_rollouts = 5000 // 10
for _ in trange(n_initial_rollouts):
    roll_out(
        dynamic_model,
        real_memory,
        model_memory,
        agent,
        real_env = real_env,
        model_batch_size = rollout_batch_size,
        init = True,
    )

  0%|                                                   | 0/500 [00:00<?, ?it/s]

[ 0.35188087  0.44180884  0.16860129  0.84373505 -0.14772321 -0.52349419
 -0.13804593  0.58856111]
[-0.33199133 -0.70224095  0.24600144 -0.35681459 -0.17709548  0.7815187
 -0.61266428  0.2092077 ]
[-0.85565905 -0.4953875   0.65724867  0.11788838 -0.5965073  -0.71551345
  0.02094051 -0.82571126]
[ 0.37403582  0.00586862  0.71705578  0.51879547 -0.39550944 -0.14071161
  0.86937446  0.81199347]
[ 0.97318171  0.44699094  0.66787194 -0.0288882  -0.23858167 -0.45562606
 -0.37939599  0.12127929]
[-0.73186465  0.09761692  0.41328362 -0.98338912  0.41595992 -0.30724588
  0.36189112  0.3689116 ]
[-0.09675446  0.24010803  0.13921998  0.4983193   0.30003401 -0.83195159
 -0.97189955 -0.18567861]
[-0.80136387  0.75542255  0.55253756 -0.52667714 -0.02691568 -0.06694545
  0.2541106  -0.33853309]
[ 0.93794941 -0.30329569  0.02219471  0.83050177  0.92278293 -0.01641021
 -0.52480373 -0.76271896]
[-0.64927958  0.93394541 -0.92083097  0.68328207  0.82642547 -0.82171105
  0.46391835 -0.75536759]
[ 0.1622821




NameError: name 'violations' is not defined

In [25]:
def roll_out(dynamic_model, real_memory, model_memory,agent, real_env , model_batch_size = 100, init = False):
    """Generate model-based rollouts and store them in ``model_memory``.

    A batch of start states is drawn from ``real_memory``; the learned dynamics
    model is then stepped for up to ``agent.horizon`` steps. Every simulated
    transition is pushed to ``model_memory``. Rollouts flagged as done by
    ``real_env.check_done`` are dropped before the next step, and the loop
    stops early once no rollout is left.

    Args:
        dynamic_model: object exposing ``forward_all(states, actions, to_cpu=True)``
            that returns ``(next_states, rewards)``.
        real_memory: buffer exposing ``sample_numpy(n)``; the first item of its
            return value is used as the batch of start states.
        model_memory: buffer exposing ``push(element)``; receives one
            ``MemoryElement`` per simulated transition.
        agent: object exposing ``horizon`` and ``act(states, to_cpu=True)``
            returning ``(actions, log_prob)``.
        real_env: object exposing ``sample_uniform_action()`` (first item is
            used as the action) and ``check_done(next_states)``.
        model_batch_size: number of start states requested from ``real_memory``.
        init: when True, actions are sampled uniformly via ``real_env`` instead
            of being produced by ``agent``.

    Returns:
        None. The function works by side effect on ``model_memory``.
    """
    states = real_memory.sample_numpy(model_batch_size)[0]

    for _step in range(agent.horizon):
        with torch.no_grad():
            if init:
                # Build a fresh action batch on every step, sized to the
                # rollouts that are still alive. Reusing one list across
                # steps fails after the first step (it has become an ndarray)
                # and would no longer match len(states) once some rollouts
                # have terminated.
                actions = np.asarray(
                    [real_env.sample_uniform_action()[0] for _ in range(len(states))]
                )
            else:
                actions, _log_prob = agent.act(states, to_cpu = True)

            next_states, rewards = dynamic_model.forward_all(states, actions, to_cpu = True)

        # Assumes check_done returns a boolean NumPy array with one entry per
        # rollout (required by the `~` below) -- TODO confirm.
        dones = real_env.check_done(next_states)

        for state, action, reward, next_state, terminal in zip(states, actions, rewards, next_states, dones):
            # The last field is always stored as 0 here; presumably the
            # "truncated" flag -- verify against MemoryElement.
            model_memory.push(MemoryElement(state, action, reward, next_state, terminal, 0))

        # Continue only from rollouts that have not terminated.
        alive = np.where(~(dones))[0]
        if len(alive) == 0:
            break
        states = next_states[alive]
        