In [2]:
%%capture
# Colab dependency installs; all output of this cell is suppressed by %%capture.
# NOTE(review): swig is presumably required to build the Box2D extension pulled in by gym[box2d].
!pip install swig
# xvfb + pyvirtualdisplay provide the headless display started later for rendering / video capture.
!apt update
!apt install xvfb -y
!pip install 'pyglet==1.5.27'
!pip install 'gym[box2d]==0.20.0'
!pip install 'pyvirtualdisplay==3.0'

import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

from IPython.display import clear_output
import matplotlib.pyplot as plt
from pyvirtualdisplay import Display
from IPython import display as disp
%matplotlib inline

# Start an invisible virtual X display so gym rendering and the Monitor
# video recorder used during training can run without a physical screen.
display = Display(size=(600, 600), visible=0)
display.start()

# Torch device: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from collections import deque, namedtuple

# Mount Google Drive; logs, videos and checkpoints are written below /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os
import time

# This code is adapted from the following git repository:
#https://github.com/honghaow/FORK/blob/master/BipedalWalkerHardcore/TD3_FORK_BipedalWalkerHardcore_Colab.ipynb

## Change mode here: 'basic' (BipedalWalker-v3) or 'hardcore' (BipedalWalkerHardcore-v3)
mode = 'basic' #'hardcore'

## Seed environment or not
# If True, SEED is the fixed per-mode seed below; if False, SEED is 88.
seed_bool = True #True or False

# Seed used for each supported mode when seed_bool is True.
_SEED_BY_MODE = {'basic': 40, 'hardcore': 42}
if mode not in _SEED_BY_MODE:
    # Fail fast: an unknown mode would otherwise leave SEED undefined when
    # seed_bool is True, and the training cell would silently fall back to
    # the basic environment.
    raise ValueError(f"mode must be 'basic' or 'hardcore', got {mode!r}")
SEED = _SEED_BY_MODE[mode] if seed_bool else 88

In [3]:
# Actor Neural Network
class Actor(nn.Module):
    """Deterministic policy network mapping a batch of states to actions.

    Architecture: Linear(state_size, fc_units) -> ReLU ->
    Linear(fc_units, fc1_units) -> ReLU -> Linear(fc1_units, action_size) -> tanh,
    so every action component lies in [-1, 1].

    Args:
        state_size: dimension of the observation vector.
        action_size: dimension of the action vector.
        seed: value passed to ``torch.manual_seed`` before the layers are built.
        fc_units: width of the first hidden layer.
        fc1_units: width of the second hidden layer.
    """
    def __init__(self, state_size, action_size, seed, fc_units=400, fc1_units=300):
        super(Actor, self).__init__()
        # Reseeds torch's *global* RNG before creating the layers (reads the
        # module-level `mode` / `seed_bool`); skipped only for seeded hardcore runs.
        if mode!='hardcore' or seed_bool==False:
            self.seed = torch.manual_seed(seed)
        # Attribute names fc1/fc2/fc3 define the state_dict keys used by checkpoints.
        self.fc1 = nn.Linear(state_size, fc_units)
        self.fc2 = nn.Linear(fc_units, fc1_units)
        self.fc3 = nn.Linear(fc1_units, action_size)

    def forward(self, state):
        """Build an actor (policy) network that maps states -> actions."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        # Use the public torch.tanh rather than reaching torch through F's namespace.
        return torch.tanh(self.fc3(x))

# Q1-Q2-Critic Neural Network  
  
class Critic(nn.Module):
    """Twin Q-value network.

    Two independent 3-layer MLPs -- (l1, l2, l3) and (l4, l5, l6) -- each map the
    concatenated (state, action) vector to one scalar Q-value.  Hidden layers use
    ReLU; the output layers are linear.
    """

    def __init__(self, state_size, action_size, seed, fc1_units=400, fc2_units=300):
        super(Critic, self).__init__()
        # Reseed torch's global RNG (module-level `mode` / `seed_bool`),
        # except for seeded hardcore runs.
        if mode != 'hardcore' or seed_bool == False:
            self.seed = torch.manual_seed(seed)

        joint_dim = state_size + action_size

        # First Q head (layer names are the state_dict keys; keep them stable).
        self.l1 = nn.Linear(joint_dim, fc1_units)
        self.l2 = nn.Linear(fc1_units, fc2_units)
        self.l3 = nn.Linear(fc2_units, 1)

        # Second Q head: same shape, independently initialised weights.
        self.l4 = nn.Linear(joint_dim, fc1_units)
        self.l5 = nn.Linear(fc1_units, fc2_units)
        self.l6 = nn.Linear(fc2_units, 1)

    @staticmethod
    def _q_head(first, second, third, inputs):
        """Evaluate one head: ReLU(first) -> ReLU(second) -> linear third."""
        hidden = F.relu(first(inputs))
        hidden = F.relu(second(hidden))
        return third(hidden)

    def forward(self, state, action):
        """Return the pair (Q1, Q2) for a batch of (state, action) pairs."""
        joint = torch.cat([state, action], 1)
        q_first = self._q_head(self.l1, self.l2, self.l3, joint)
        q_second = self._q_head(self.l4, self.l5, self.l6, joint)
        return q_first, q_second


class SysModel(nn.Module):
    """Learned dynamics model predicting the next state from (state, action).

    A 3-layer MLP (l1 -> l2 -> l3) with ReLU after the two hidden layers and a
    linear output of width ``state_size``.
    """

    def __init__(self, state_size, action_size, fc1_units=400, fc2_units=300):
        super(SysModel, self).__init__()
        joint_dim = state_size + action_size
        # Layer names double as state_dict keys for checkpoints.
        self.l1 = nn.Linear(joint_dim, fc1_units)
        self.l2 = nn.Linear(fc1_units, fc2_units)
        self.l3 = nn.Linear(fc2_units, state_size)

    def forward(self, state, action):
        """Return the predicted next state for a batch of (state, action) pairs."""
        hidden = torch.cat([state, action], 1)
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)



In [4]:
class TD3_FORK:
    """TD3 agent extended with a learned system model for look-ahead actor updates.

    The agent owns an actor and a twin critic (plus target copies of both) and a
    ``SysModel`` that predicts the next state from (state, action).  When the
    system model's training loss drops below 0.020, the actor loss additionally
    includes weighted target-critic Q-values of states predicted one and two
    steps ahead.

    Replay transitions are tuples
    ``(state, action, reward, add_reward, next_state, done)``.

    NOTE(review): ``load``, ``random_seed`` and ``cuda`` are accepted but never
    used here -- the device is chosen from ``torch.cuda.is_available()`` and
    weights are only restored by an explicit ``load_weight()`` call.
    """
    def __init__(
        self,name,env,
        load = False,
        gamma = 0.99, #discount factor
        lr_actor = 3e-4,
        lr_critic = 3e-4,
        lr_sysmodel = 3e-4,
        batch_size = 100,
        buffer_capacity = 1000000,
        tau = 0.02,  #target network update factor
        random_seed = np.random.randint(1,10000),
        cuda = True,
        policy_noise=0.2,
        std_noise = 0.1,
        noise_clip=0.5,
        policy_freq=2 #target network update period
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = env
        self.create_actor()
        self.create_critic()
        self.create_sysmodel()
        self.act_opt = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.crt_opt = optim.Adam(self.critic.parameters(), lr=lr_critic)
        self.sys_opt = optim.Adam(self.sysmodel.parameters(), lr=lr_sysmodel)
        # Start the target networks as exact copies of the online networks.
        self.set_weights()
        self.replay_memory_buffer = deque(maxlen = buffer_capacity)
        # NOTE(review): this second buffer is not read or written anywhere in this class.
        self.replay_memory_bufferd_dis = deque(maxlen = buffer_capacity)
        self.batch_size = batch_size
        self.tau = tau
        self.policy_freq = policy_freq
        self.gamma = gamma
        self.name = name
        # NOTE(review): only element [0] of each Box bound is used, i.e. all
        # dimensions are assumed to share the same bounds.
        self.upper_bound = self.env.action_space.high[0] #action space upper bound
        self.lower_bound = self.env.action_space.low[0]  #action space lower bound
        self.obs_upper_bound = self.env.observation_space.high[0] #state space upper bound
        self.obs_lower_bound = self.env.observation_space.low[0]  #state space lower bound
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.std_noise = std_noise

    def create_actor(self):
        """Build the online and target actor networks on ``self.device``."""
        params = {
            'state_size':      self.env.observation_space.shape[0],
            'action_size':     self.env.action_space.shape[0],
            'seed':            SEED
        }
        self.actor = Actor(**params).to(self.device)
        self.actor_target = Actor(**params).to(self.device)

    def create_critic(self):
        """Build the online and target twin critics on ``self.device``."""
        params = {
            'state_size':      self.env.observation_space.shape[0],
            'action_size':     self.env.action_space.shape[0],
            'seed':            SEED
        }
        self.critic = Critic(**params).to(self.device)
        self.critic_target = Critic(**params).to(self.device)

    def create_sysmodel(self):
        """Build the (state, action) -> next-state model on ``self.device``."""
        params = {
            'state_size':      self.env.observation_space.shape[0],
            'action_size':     self.env.action_space.shape[0]
        }
        self.sysmodel = SysModel(**params).to(self.device)

    def set_weights(self):
        """Copy the online actor/critic weights into their target networks."""
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

    def load_weight(self):
        """Load all five networks from fixed paths under the mounted Google Drive."""
        self.actor.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/actor.pth', map_location=self.device))
        self.critic.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/critic.pth', map_location=self.device))
        self.actor_target.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/actor_t.pth', map_location=self.device))
        self.critic_target.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/critic_t.pth', map_location=self.device))
        self.sysmodel.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/sysmodel.pth', map_location=self.device))

    def add_to_replay_memory(self, transition, buffername):
        """Append one transition tuple to the given buffer."""
        buffername.append(transition)

    def get_random_sample_from_replay_mem(self, buffername):
        """Return ``self.batch_size`` transitions sampled without replacement."""
        random_sample = random.sample(buffername, self.batch_size)
        return random_sample


    def learn_and_update_weights_by_replay(self,training_iterations, weight, totrain,avg_reward):
        """Update policy and value parameters using given batch of experience tuples.
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Every iteration trains the critic and the system model; the actor (and
        its target) is updated every ``policy_freq`` iterations, and only when
        ``totrain == 1``.

        Args:
            training_iterations: number of gradient iterations to run.
            weight: scale of the look-ahead terms in the actor loss.
            totrain: the actor is only updated when this equals 1.
            avg_reward: recent average reward; in seeded 'basic' mode it selects
                the target-policy noise level.

        Returns:
            The last system-model loss as a float, or 1 when no training
            happened (fewer than 1e4 stored transitions, or no iterations requested).
        """
        if len(self.replay_memory_buffer) < 1e4:
            return 1
        if training_iterations <= 0:
            # The loop below would not run and `sysmodel_loss` would be unbound
            # at the return (UnboundLocalError); report "no training" instead.
            return 1
        for it in range(training_iterations):
            mini_batch = self.get_random_sample_from_replay_mem(self.replay_memory_buffer)
            state_batch = torch.from_numpy(np.vstack([i[0] for i in mini_batch])).float().to(self.device)
            action_batch = torch.from_numpy(np.vstack([i[1] for i in mini_batch])).float().to(self.device)
            reward_batch = torch.from_numpy(np.vstack([i[2] for i in mini_batch])).float().to(self.device)
            add_reward_batch = torch.from_numpy(np.vstack([i[3] for i in mini_batch])).float().to(self.device)
            next_state_batch = torch.from_numpy(np.vstack([i[4] for i in mini_batch])).float().to(self.device)
            done_list = torch.from_numpy(np.vstack([i[5] for i in mini_batch]).astype(np.uint8)).float().to(self.device)

            # Training and updating Actor & Critic networks.

            #Train Critic
            target_actions = self.actor_target(next_state_batch)
            # Seeded 'basic' runs use a reward-dependent target-policy noise level.
            if (mode=='basic') and (seed_bool==True):
                if(avg_reward>=300):
                  self.policy_noise = 0.1
                else:
                  self.policy_noise = np.random.uniform(0.19,0.21)
            offset_noises = torch.FloatTensor(action_batch.shape).data.normal_(0, self.policy_noise).to(self.device)

            #clip noise (target policy smoothing)
            offset_noises = offset_noises.clamp(-self.noise_clip, self.noise_clip)
            target_actions = (target_actions + offset_noises).clamp(self.lower_bound, self.upper_bound)

            #Compute the target Q value (clipped double-Q: min of the two heads)
            Q_targets1, Q_targets2 = self.critic_target(next_state_batch, target_actions)
            Q_targets = torch.min(Q_targets1, Q_targets2)
            Q_targets = reward_batch + self.gamma * Q_targets * (1 - done_list)

            #Compute current Q estimates
            current_Q1, current_Q2 = self.critic(state_batch, action_batch)
            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, Q_targets.detach()) + F.mse_loss(current_Q2, Q_targets.detach())
            # Optimize the critic
            self.crt_opt.zero_grad()
            critic_loss.backward()
            self.crt_opt.step()

            self.soft_update_target(self.critic, self.critic_target)


            #Train_sysmodel (terminal transitions are masked out of the loss)
            predict_next_state = self.sysmodel(state_batch, action_batch) * (1-done_list)
            next_state_batch = next_state_batch * (1 -done_list)
            sysmodel_loss = F.mse_loss(predict_next_state, next_state_batch.detach())
            self.sys_opt.zero_grad()
            sysmodel_loss.backward()
            self.sys_opt.step()

            # Only use the model for look-ahead once it predicts accurately enough.
            s_flag = 1 if sysmodel_loss.item() < 0.020  else 0

            #Train Actor
            # Delayed policy updates
            if it % self.policy_freq == 0 and totrain == 1:
                actions = self.actor(state_batch)
                actor_loss1,_ = self.critic_target(state_batch, actions)
                actor_loss1 =  actor_loss1.mean()
                actor_loss =  - actor_loss1

                if s_flag == 1:
                    # One-step look-ahead through the system model.
                    p_actions = self.actor(state_batch)
                    p_next_state = self.sysmodel(state_batch, p_actions).clamp(self.obs_lower_bound,self.obs_upper_bound)

                    p_actions2 = self.actor(p_next_state.detach()) * self.upper_bound
                    actor_loss2,_ = self.critic_target(p_next_state.detach(), p_actions2)
                    actor_loss2 = actor_loss2.mean()

                    # Two-step look-ahead, weighted by half.
                    p_next_state2= self.sysmodel(p_next_state.detach(), p_actions2).clamp(self.obs_lower_bound,self.obs_upper_bound)
                    p_actions3 = self.actor(p_next_state2.detach()) * self.upper_bound
                    actor_loss3,_ = self.critic_target(p_next_state2.detach(), p_actions3)
                    actor_loss3 = actor_loss3.mean()

                    actor_loss_final =  actor_loss - weight * (actor_loss2) - 0.5 *  weight * actor_loss3
                else:
                    actor_loss_final =  actor_loss

                self.act_opt.zero_grad()
                actor_loss_final.backward()
                self.act_opt.step()

                #Soft update target models
                self.soft_update_target(self.actor, self.actor_target)

        return sysmodel_loss.item()

    def soft_update_target(self,local_model,target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter (read from self.tau)
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)

    def policy(self,state,avg_reward):
        """Select an exploration action: actor output plus Gaussian noise, clipped to bounds."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.actor.eval()
        with torch.no_grad():
            actions = self.actor(state).cpu().data.numpy()
        self.actor.train()
        # Adding noise to action; seeded 'basic' runs use a reward-dependent level.
        if (mode=='basic') and (seed_bool==True):
            if(avg_reward>=300):
              self.std_noise = 0.05
            else:
              self.std_noise = np.random.uniform(0.09,0.11)
        shift_action = np.random.normal(0, self.std_noise, size=self.env.action_space.shape[0])
        sampled_actions = (actions + shift_action)
        # We make sure action is within bounds
        legal_action = np.clip(sampled_actions,self.lower_bound,self.upper_bound)
        return np.squeeze(legal_action)


    def select_action(self,state):
        """Select a noise-free action from the *target* actor (used for evaluation)."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            actions = self.actor_target(state).cpu().data.numpy()
        return np.squeeze(actions)


    def eval_policy(self, env_name, seed, eval_episodes):
        """Run ``eval_episodes`` episodes with ``select_action`` and return the mean return.

        Note: despite its name, ``env_name`` must be an environment object
        (it is seeded, reset and stepped directly).
        """
        eval_env = env_name
        eval_env.seed(seed)

        avg_reward = 0.
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = self.select_action(np.array(state))
                state, reward, done, _ = eval_env.step(action)
                avg_reward += reward
        avg_reward /= eval_episodes

        print("---------------------------------------")
        print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
        print("---------------------------------------")
        return avg_reward

In [5]:
"""Training the agent"""
gym.logger.set_level(40)
max_steps = 3000
falling_down = 0

# Single location used both to resume from and to save checkpoints.
CHECKPOINT_PATH = "/content/drive/MyDrive/checkpointtd3-fork.pt"


if __name__ == '__main__':
    if mode=='hardcore':
        env = gym.make('BipedalWalkerHardcore-v3')
    else:
        env = gym.make('BipedalWalker-v3')
    plot_interval = 10
    video_every = 25
    # Record a video of every `video_every`-th episode to Drive.
    env = gym.wrappers.Monitor(env, "/content/drive/MyDrive/video-td3Fork", video_callable=lambda ep_id: ep_id%video_every == 0, force=True)
    log_f = open("/content/drive/MyDrive/agent-log-td3Fork.txt","a+")
    #seed env
    if seed_bool==True:
        seed = SEED
        torch.manual_seed(seed)
        env.seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        env.action_space.seed(seed)

    agent = TD3_FORK('Bipedal-walker', env, batch_size = 100)#'Bipedalhardcore',
    total_episodes = 100000
    start_timestep=0            #time_step to select action based on Actor
    time_start = time.time()        # Init start time
    ep_reward_list = []
    avg_reward_list = []
    total_timesteps = 0
    sys_loss = 0
    numtrainedexp = 0
    save_time = 0
    expcount = 0
    totrain = 0
    reward_list = []
    plot_data = []
    ep =0

    # Set to True to resume training from CHECKPOINT_PATH.
    torch_saved= False
    if torch_saved:
        chkpt = CHECKPOINT_PATH
        print("load model: ",chkpt)
        # map_location lets a checkpoint saved on a GPU runtime be resumed on CPU (and vice versa).
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which cannot
        # unpickle the saved replay-buffer deque; pass weights_only=False there (trusted files only).
        params = torch.load(chkpt, map_location=agent.device)
        agent.actor.load_state_dict(params["actor"])
        agent.critic.load_state_dict(params["critic"])
        agent.actor_target.load_state_dict(params["actor_target"])
        agent.critic_target.load_state_dict(params["critic_target"])
        agent.sysmodel.load_state_dict(params["sysmodel"])
        agent.act_opt.load_state_dict(params["act_opt"])
        agent.crt_opt.load_state_dict(params["crt_opt"])
        agent.sys_opt.load_state_dict(params["sys_opt"])
        # "ep" is the episode that had already finished (and been logged) when the
        # checkpoint was written; resume with the next one so it is not repeated.
        ep = params["ep"] + 1
        total_timesteps = params["total_timesteps"]
        sys_loss = params["sys_loss"]
        ep_reward_list = params["ep_reward_list"]
        avg_reward_list = params["avg_reward_list"]
        agent.replay_memory_buffer = params["replay_buffer"]

    try:
        #for ep in range(total_episodes):
        while ep<total_episodes:
            state = env.reset()
            episodic_reward = 0
            timestep = 0
            temp_replay_buffer = []
            # Mean of the last 10 episode rewards; it cannot change within an
            # episode, so compute it once instead of at every environment step.
            recent_avg_reward = np.mean(ep_reward_list[-10:])

            for st in range(max_steps):

                # Select action randomly or according to policy
                if total_timesteps < start_timestep:
                    action = env.action_space.sample()
                else:
                    action = agent.policy(state,recent_avg_reward)

                # Receive state and reward from environment.
                next_state, reward, done, info = env.step(action)
                #change original reward from -100 to -5 and 5*reward for other values
                episodic_reward += reward
                if reward == -100:
                    add_reward = -1
                    reward = -5
                    falling_down += 1
                    expcount += 1
                else:
                    add_reward = 0
                    reward = 5 * reward

                temp_replay_buffer.append((state, action, reward, add_reward, next_state, done))

                # End this episode when `done` is True
                if done:
                    # Keep failed / low-reward episodes; keep others only occasionally.
                    if add_reward == -1 or episodic_reward < 250:
                        totrain = 1
                        for temp in temp_replay_buffer:
                            agent.add_to_replay_memory(temp, agent.replay_memory_buffer)
                    elif expcount > 0 and np.random.rand() > 0.5:
                        totrain = 1
                        expcount -= 10
                        for temp in temp_replay_buffer:
                            agent.add_to_replay_memory(temp, agent.replay_memory_buffer)
                    break
                state = next_state
                timestep += 1
                total_timesteps += 1

            ep_reward_list.append(episodic_reward)
            reward_list.append(episodic_reward)
            # Mean of last 100 episodes
            avg_reward = np.mean(ep_reward_list[-100:])
            avg_reward_list.append(avg_reward)

            s = (int)(time.time() - time_start)


            #Training agent only when new experiences are added to the replay buffer
            # Look-ahead weight falls linearly from 1 to 0 as the 100-episode average
            # reward (avg_reward above) rises from 0 to 300.
            weight =  1 - np.clip(avg_reward/300, 0, 1)
            if totrain == 1:
                sys_loss = agent.learn_and_update_weights_by_replay(timestep*10, weight, totrain,np.mean(ep_reward_list[-10:]))
            else:
                sys_loss = agent.learn_and_update_weights_by_replay(100*10, weight, totrain,np.mean(ep_reward_list[-10:]))
            totrain = 0

            # do NOT change this logging code - it is used for automated marking!
            log_f.write('episode: {}, reward: {}\n'.format(ep, episodic_reward))
            log_f.flush()
            # print reward data every so often - add a graph like this in your report
            if ep % plot_interval == 0:
                print("\rTotal T: {:d} Episode Num: {:d} Reward: {:f} Avg Reward: {:f}".format(
                total_timesteps, ep, episodic_reward, avg_reward ), end="")
                plot_data.append([ep, np.array(reward_list).mean(), np.array(reward_list).std()])
                reward_list = []
                # plt.rcParams['figure.dpi'] = 100
                plt.plot([x[0] for x in plot_data], [x[1] for x in plot_data], '-', color='tab:grey')
                plt.fill_between([x[0] for x in plot_data], [x[1]-x[2] for x in plot_data], [x[1]+x[2] for x in plot_data], alpha=0.2, color='tab:grey')
                plt.xlabel('Episode number')
                plt.ylabel('Episode reward')
                plt.show()
                disp.clear_output(wait=True)

                #Put in memory
                # Save to a temporary file and atomically swap it in, so an interrupt
                # during this (large) save cannot leave a truncated checkpoint behind.
                tmp_checkpoint = CHECKPOINT_PATH + ".tmp"
                torch.save(
                    {
                        "actor": agent.actor.state_dict(),
                        "critic": agent.critic.state_dict(),
                        "actor_target": agent.actor_target.state_dict(),
                        "critic_target": agent.critic_target.state_dict(),
                        "sysmodel": agent.sysmodel.state_dict(),
                        "act_opt": agent.act_opt.state_dict(),
                        "crt_opt": agent.crt_opt.state_dict(),
                        "sys_opt": agent.sys_opt.state_dict(),
                        "ep":ep,
                        "total_timesteps": total_timesteps,
                        "sys_loss": sys_loss,
                        "ep_reward_list": ep_reward_list,
                        "avg_reward_list": avg_reward_list,
                        "replay_buffer": agent.replay_memory_buffer,
                    },
                    tmp_checkpoint,
                )
                os.replace(tmp_checkpoint, CHECKPOINT_PATH)
            ep+=1
    finally:
        # Release the log file even when training is interrupted (e.g. KeyboardInterrupt).
        log_f.close()


# Plotting graph
# Episodes versus Avg. Rewards
plt.plot(avg_reward_list)
plt.xlabel("Episode")
plt.ylabel("Avg. Epsiodic Reward")
plt.show()
env.close()

KeyboardInterrupt: ignored

In [None]:
####### Below: attached code for TD3 with LFIW, DisCor and PER.
# This code is adapted from the following git repositories:
#https://github.com/AIDefender/MyDiscor/tree/f040befbca4498388217ee634a933211d4566182/discor/algorithm/rlkit/torch
#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/gaussian_strategy.py
#https://github.com/higgsfield/RL-Adventure-2/blob/master/6.td3.ipynb

"""%%capture
!pip install swig
!apt update
!apt install xvfb -y
!pip install 'pyglet==1.5.27'
!pip install 'gym[box2d]==0.20.0'
!pip install 'pyvirtualdisplay==3.0'

import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

from IPython.display import clear_output
import matplotlib.pyplot as plt
from pyvirtualdisplay import Display
from IPython import display as disp
%matplotlib inline

display = Display(visible=0,size=(600,600))
display.start()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#from google.colab import drive
#drive.mount('/content/drive')
#from https://github.com/AIDefender/MyDiscor/tree/f040befbca4498388217ee634a933211d4566182/discor/algorithm/rlkit/torch
def identity(x):
    return x
def fanin_init(tensor):
    size = tensor.size()
    if len(size) == 2:
        fan_in = size[0]
    elif len(size) > 2:
        fan_in = np.prod(size[1:])
    else:
        raise Exception("Shape must be have dimension at least 2.")
    bound = 1. / np.sqrt(fan_in)
    return tensor.data.uniform_(-bound, bound)

class LayerNorm(nn.Module):
    
    #Simple 1D LayerNorm.
    

    def __init__(self, features, center=True, scale=False, eps=1e-6):
        super().__init__()
        self.center = center
        self.scale = scale
        self.eps = eps
        if self.scale:
            self.scale_param = nn.Parameter(torch.ones(features))
        else:
            self.scale_param = None
        if self.center:
            self.center_param = nn.Parameter(torch.zeros(features))
        else:
            self.center_param = None

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        output = (x - mean) / (std + self.eps)
        if self.scale:
            output = output * self.scale_param
        if self.center:
            output = output + self.center_param
        return output

class Mlp(nn.Module):
    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=identity,
            hidden_init=fanin_init,
            b_init_value=0.1,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        super().__init__()

        if layer_norm_kwargs is None:
            layer_norm_kwargs = dict()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        self.fcs = []
        self.layer_norms = []
        in_size = input_size

        for i, next_size in enumerate(hidden_sizes):
            fc = nn.Linear(in_size, next_size)
            in_size = next_size
            hidden_init(fc.weight)
            fc.bias.data.fill_(b_init_value)
            self.__setattr__("fc{}".format(i), fc)
            self.fcs.append(fc)

            if self.layer_norm:
                ln = LayerNorm(next_size)
                self.__setattr__("layer_norm{}".format(i), ln)
                self.layer_norms.append(ln)

        self.last_fc = nn.Linear(in_size, output_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, input, return_preactivations=False):
        h = input
        for i, fc in enumerate(self.fcs):
            h = fc(h)
            if self.layer_norm and i < len(self.fcs) - 1:
                h = self.layer_norms[i](h)
            h = self.hidden_activation(h)
        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        if return_preactivations:
            return output, preactivation
        else:
            return output


class FlattenMlp(Mlp):
    
    #Flatten inputs along dimension 1 and then pass through MLP.
    

    def forward(self, *inputs, **kwargs):
        flat_inputs = torch.cat(inputs, dim=1)
        return super().forward(flat_inputs, **kwargs)

class ReplayBuffer:

    def __init__(self, memory_size, state_shape, action_shape, gamma=0.99,
                 nstep=1, arbi_reset=False, *args, **kwargs):
        assert isinstance(memory_size, int) and memory_size > 0
        assert isinstance(state_shape, tuple)
        assert isinstance(action_shape, tuple)
        assert isinstance(gamma, float) and 0 < gamma < 1.0
        assert isinstance(nstep, int) and nstep > 0

        self._memory_size = memory_size
        self._state_shape = state_shape
        self._action_shape = action_shape
        self._gamma = gamma
        self._nstep = nstep
        # If we need arbitrary reset, the output of env.sim.get_state() needs saving for further reset
        self._arbi_reset = arbi_reset

        self._reset()

    def _reset(self):
        self._n = 0
        self._p = 0

        self._states = np.empty(
            (self._memory_size, ) + self._state_shape, dtype=np.float32)
        self._next_states = np.empty(
            (self._memory_size, ) + self._state_shape, dtype=np.float32)
        self._actions = np.empty(
            (self._memory_size, ) + self._action_shape, dtype=np.float32)

        self._rewards = np.empty((self._memory_size, 1), dtype=np.float32)
        self._dones = np.empty((self._memory_size, 1), dtype=np.float32)
        if self._arbi_reset:
            self._sim_states = [None] * self._memory_size

 #       if self._nstep != 1:
 #           self._nstep_buffer = NStepBuffer(self._gamma, self._nstep)

    def append(self, state, action, reward, next_state, done, step=None, episode_done=None, sim_state=None):
        if self._nstep != 1:
            self._nstep_buffer.append(state, action, reward)

            if self._nstep_buffer.is_full():
                state, action, reward = self._nstep_buffer.get()
                self._append(state, action, reward, next_state, done, episode_done=episode_done)

            if done or episode_done:
                while not self._nstep_buffer.is_empty():
                    state, action, reward = self._nstep_buffer.get()
                    self._append(state, action, reward, next_state, done, episode_done=episode_done)

        else:
            self._append(state, action, reward, next_state, done, step, sim_state)

    def _append(self, state, action, reward, next_state, done, step=None, sim_state=None, episode_done=None):
        self._states[self._p, ...] = state
        self._actions[self._p, ...] = action
        self._rewards[self._p, ...] = reward
        self._next_states[self._p, ...] = next_state
        self._dones[self._p, ...] = done
        if self._arbi_reset:
            self._sim_states[self._p] = sim_state

        self._n = min(self._n + 1, self._memory_size)
        self._p = (self._p + 1) % self._memory_size

    def sample(self, batch_size, device=torch.device('cpu')):
        assert isinstance(batch_size, int) and batch_size > 0

        idxes = self._sample_idxes(batch_size)
        return self._sample_batch(idxes, batch_size, device)

    def _sample_idxes(self, batch_size):
        return np.random.randint(low=0, high=self._n, size=batch_size)

    def _sample_batch(self, idxes, batch_size, device):
        states = torch.tensor(
            self._states[idxes], dtype=torch.float, device=device)
        actions = torch.tensor(
            self._actions[idxes], dtype=torch.float, device=device)
        rewards = torch.tensor(
            self._rewards[idxes], dtype=torch.float, device=device)
        dones = torch.tensor(
            self._dones[idxes], dtype=torch.float, device=device)
        next_states = torch.tensor(
            self._next_states[idxes], dtype=torch.float, device=device)
        batch = {
            'states': states,
            'actions': actions,
            'rewards': rewards,
            'dones': dones,
            'next_states': next_states
        }
        if self._arbi_reset:
            sim_states = [self._sim_states[i] for i in idxes]
            batch.update({'sim_states': sim_states})

        return batch

    def __len__(self):
        return self._n

class TemporalPrioritizedReplayBuffer(ReplayBuffer):

    def __init__(self, memory_size, state_shape, action_shape, gamma=0.99, nstep=1,
                 arbi_reset=False):
        super().__init__(memory_size, state_shape, action_shape, gamma, nstep, arbi_reset=arbi_reset)

    def _reset(self):
        super()._reset()
        self._steps = np.empty((self._memory_size, 1), dtype=np.int64)
        self._done_cnts = np.empty((self._memory_size, 1), dtype=np.int64)
        self._cur_done_cnt = 0

    def _append(self, state, action, reward, next_state, done, step, sim_state=None, episode_done=None):
        super()._append(state, action, reward, next_state, done, step, sim_state, episode_done=episode_done)
        # We can compute mod on negative number
        self._p = (self._p - 1) % self._memory_size 
        self._steps[self._p, ...] = step
        self._done_cnts[self._p, ...] = self._cur_done_cnt
        self._p = (self._p + 1) % self._memory_size
        if done or episode_done:
            self._cur_done_cnt += 1

    def _sample_batch(self, idxes, batch_size, device):
        batch = super()._sample_batch(idxes, batch_size, device)
        steps = torch.tensor(
            self._steps[idxes], dtype=torch.int64, device=device)
        done_cnts = torch.tensor(
            self._done_cnts[idxes], dtype=torch.int64, device=device)
        batch.update({"steps": steps})
        batch.update({"done_cnts": done_cnts})
        return batch


def initialize_weights_xavier(m, gain=1.0):
    """Xavier-uniform initialisation for ``nn.Linear`` modules.

    Intended for ``nn.Module.apply``. Non-linear modules are left untouched;
    a Linear layer's bias, when present, is set to zero.

    Args:
        m (nn.Module): module visited by ``apply``.
        gain (float): scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=gain)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)


def create_linear_network(input_dim, output_dim, hidden_units=[],
                          hidden_activation=nn.ReLU(), output_activation=None,
                          initializer=initialize_weights_xavier):
    """Build a fully-connected network as an ``nn.Sequential``.

    Args:
        input_dim (int): number of input features.
        output_dim (int): width of the final Linear layer.
        hidden_units (list or tuple of int): widths of the hidden layers, in order.
            It is only read, never mutated.
        hidden_activation (nn.Module): activation placed after every hidden layer.
            The same module instance is inserted after each hidden layer.
        output_activation (nn.Module or None): optional activation after the output layer.
        initializer (callable): applied to every submodule via ``nn.Module.apply``.

    Returns:
        nn.Sequential: the assembled, initialised network.
    """
    assert isinstance(input_dim, int) and isinstance(output_dim, int)
    assert isinstance(hidden_units, (list, tuple))

    layers = []
    units = input_dim
    for next_units in hidden_units:
        layers.append(nn.Linear(units, next_units))
        layers.append(hidden_activation)
        units = next_units

    layers.append(nn.Linear(units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)

    # Honour the caller-supplied initializer (defaults to xavier).
    return nn.Sequential(*layers).apply(initializer)

class BaseNetwork(nn.Module):
    """``nn.Module`` with helpers to persist and restore its ``state_dict``."""

    def save(self, path):
        """Serialise this module's ``state_dict`` to ``path`` via ``torch.save``."""
        params = self.state_dict()
        torch.save(params, path)

    def load(self, path):
        """Read a ``state_dict`` from ``path`` (``torch.load``) and load it into this module."""
        params = torch.load(path)
        self.load_state_dict(params)
class StateActionFunction(BaseNetwork):
    """MLP mapping a concatenated (state, action) row to a single scalar output."""

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.net = create_linear_network(
            input_dim=joint_dim, output_dim=1, hidden_units=hidden_units)

    def forward(self, x):
        """Evaluate on ``x``, whose last dimension has size ``state_dim + action_dim``."""
        output = self.net(x)
        return output

class ValueNetwork(BaseNetwork):
    """Two independent state-action heads evaluated on the same (s, a) input.

    Used in this notebook both as the twin critic and as the twin DisCor error model.
    """

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()
        # net1 is built before net2 so parameter initialisation order is unchanged.
        self.net1 = StateActionFunction(state_dim, action_dim, hidden_units)
        self.net2 = StateActionFunction(state_dim, action_dim, hidden_units)

    def forward(self, states, actions):
        """Return ``(head1, head2)`` for 2-D ``states`` and ``actions`` batches."""
        assert states.dim() == 2 and actions.dim() == 2
        joint = torch.cat([states, actions], dim=1)
        return self.net1(joint), self.net2(joint)

class PolicyNetwork(BaseNetwork):
    """Deterministic policy: an MLP whose output is squashed by ``tanh`` into [-1, 1]."""

    def __init__(self, state_dim, action_dim, hidden_units=[256, 256]):
        super().__init__()
        self.net = create_linear_network(
            input_dim=state_dim, output_dim=action_dim, hidden_units=hidden_units)

    def forward(self, states):
        logits = self.net(states)
        return torch.tanh(logits)

    def get_action(self, state):
        """Action for a single, unbatched ``state`` as a numpy array.

        The state is wrapped into a batch of one, moved to the module-level
        ``device``, and the single output row is returned on the CPU.
        """
        batched = torch.FloatTensor(state).unsqueeze(0).to(device)
        actions = self.forward(batched)
        return actions.detach().cpu().numpy()[0]

def update_online_networks(batch, noise_std, noise_clip):
    """Critic (and, if enabled, DisCor error-model) update for one sampled ``batch``.

    Delegates entirely to ``update_q_functions_and_error_models``.
    """
    update_q_functions_and_error_models(batch, noise_std, noise_clip)

def update_policy_and_entropy(batch):
    """Take one actor step on ``batch["states"]`` using ``policy_optimizer``.

    Despite the name, no entropy/temperature term is updated here.
    """
    loss = calc_policy_loss(batch["states"])
    update_params(policy_optimizer, loss)


def calc_policy_loss(states):
    """Actor loss: negative batch mean of min(Q1, Q2) at the policy's own actions.

    There is no entropy bonus; this is a purely deterministic-policy objective.
    """
    policy_actions = policy_net(states)
    q1, q2 = value_net(states, policy_actions)
    # Clipped double-Q: use the more pessimistic of the two critics.
    pessimistic_q = torch.min(q1, q2)
    return -pessimistic_q.mean()

def update_params(optim, loss, retain_graph=False):
    """One optimiser step: clear gradients, backpropagate ``loss``, apply the update.

    Args:
        optim: a ``torch.optim`` optimizer over the parameters to update.
        loss: scalar tensor to minimise.
        retain_graph: forwarded to ``loss.backward``.
    """
    optimizer = optim
    optimizer.zero_grad()
    loss.backward(retain_graph=retain_graph)
    optimizer.step()
    
def update_target_networks():
    """Blend every target network toward its online counterpart by ``target_update_coef``.

    The error-model target is only updated when ``discor`` is enabled.
    """
    pairs = [(target_value_net, value_net), (target_policy_net, policy_net)]
    if discor:
        pairs.append((target_error_net, error_net))
    for target, online in pairs:
        soft_update(target, online, target_update_coef)
        
def _soft_update(target, source, tau):
    """In place on ``.data``: ``target <- target * (1 - tau) + source * tau``."""
    blended = target.data * (1.0 - tau) + source.data * tau
    target.data.copy_(blended)

def soft_update(target, source, tau=1e-2):
    """Move ``target`` a fraction ``tau`` of the way toward ``source``.

    Args:
        target: an ``nn.Module`` (its parameters are updated pairwise with
            ``source.parameters()``) or a single ``torch.Tensor``.
        source: counterpart of ``target`` with matching structure.
        tau (float): interpolation weight; ``1.0`` copies ``source`` exactly.

    Raises:
        AssertionError: if ``target`` is neither a Module nor a Tensor.
    """
    assert isinstance(target, (nn.Module, torch.Tensor))

    if isinstance(target, torch.Tensor):
        _soft_update(target, source, tau)
        return
    if not isinstance(target, nn.Module):
        raise NotImplementedError
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        _soft_update(target_param, source_param, tau)

def update_q_functions_and_error_models(batch,noise_std,noise_clip):
    """One critic update, optionally reweighted by DisCor and/or LFIW, followed by
    the DisCor error-model update when ``discor`` is enabled.

    Args:
        batch: dict holding sampled transitions under ``"uniform"`` and, when the
            module-level ``lfiw`` flag is set, recent transitions under ``"fast"``.
        noise_std: std of the target-policy smoothing noise (forwarded to
            ``update_q_functions``).
        noise_clip: clip bound for that noise.

    Reads the module-level flags ``lfiw`` / ``discor`` plus the networks,
    optimizers and hyper-parameters they rely on. Returns nothing; all updates
    happen in place via optimizer steps.
    """
    uniform_batch = batch["uniform"]
    if lfiw:
        fast_batch = batch['fast']
        fast_states, fast_actions = fast_batch['states'], fast_batch['actions']
    else:
        fast_batch = None
    # train_batch = batch["prior"] if self.tper else batch["uniform"]
    # With the TPER branch disabled, train_batch is the same dict as uniform_batch.
    train_batch = batch["uniform"]
    
    # transition to update Q net
    states, actions, next_states, dones = \
        train_batch["states"], train_batch["actions"], train_batch["next_states"], train_batch["dones"]
    # s,a to update the weight of lfiw network (the "slow"/uniform side of the classifier)
    slow_states, slow_actions = uniform_batch["states"], uniform_batch["actions"]

    # Calculate importance weights: start from all-ones and multiply in each enabled scheme.
    # NOTE(review): the DisCor softmax weights and the LFIW weights are each normalised to
    # sum to 1 over the batch, so their product is much smaller than 1 per sample, and
    # calc_q_loss then takes a *mean* of the weighted errors — confirm this loss scale is intended.
    batch_size = states.shape[0]
    weights1 = torch.ones((batch_size, 1)).to(device)
    weights2 = torch.ones((batch_size, 1)).to(device)
    if discor:
        discor_weights = calc_importance_weights(next_states, dones)
        # print(weights[0].shape, discor_weights[0].shape)
        weights1 *= discor_weights[0]
        weights2 *= discor_weights[1]
    # Calculate and update prob_classifier (this also takes one optimizer step on the
    # LFIW classifier before the weights are returned).
    if lfiw:
        lfiw_weights, prob_loss = calc_update_d_pi_iw(slow_states, slow_actions, fast_states, fast_actions, states, actions)
        weights1 *= lfiw_weights
        weights2 *= lfiw_weights
    # Calculate weights for temporal priority (TPER; currently disabled)
#    if tper:
#        steps = train_batch["steps"]
#        done_cnts = train_batch["done_cnts"]
#        tper_weights = self.calc_tper_weights(steps, done_cnts)
#        weights1 *= tper_weights
#        weights2 *= tper_weights

    # Update Q functions.
    # Error predictions are taken with gradients so err_loss below can train error_net.
    curr_errs1, curr_errs2 = None, None
    if discor:
        curr_errs1, curr_errs2 = calc_current_errors(states, actions)
    # pass in curr_errs1 for evaluating discor
    curr_qs1, curr_qs2, target_qs = \
        update_q_functions(train_batch,noise_std,noise_clip,weights1, weights2, fast_batch, curr_errs1)

    # Calculate current and target errors.
    if discor:
        target_errs1, target_errs2 = calc_target_errors(
            next_states, dones, curr_qs1, curr_qs2, target_qs)
        # Update error models (calc_error_loss also updates the tau1/tau2 temperatures).
        err_loss = calc_error_loss(
            curr_errs1, curr_errs2, target_errs1, target_errs2)
        update_params(error_optimizer, err_loss)
def calc_error_loss(curr_errs1, curr_errs2, target_errs1,
                    target_errs2):
    """DisCor error-model loss: sum of the two MSEs between predicted and target errors.

    Side effect: the module-level temperatures ``tau1``/``tau2`` are moved toward
    the batch mean of the (detached) current error predictions by
    ``target_update_coef``.
    """
    losses = [
        torch.mean((predicted - target).pow(2))
        for predicted, target in ((curr_errs1, target_errs1), (curr_errs2, target_errs2))
    ]

    soft_update(tau1, curr_errs1.detach().mean(), target_update_coef)
    soft_update(tau2, curr_errs2.detach().mean(), target_update_coef)

    return losses[0] + losses[1]
def calc_target_errors(next_states, dones, curr_qs1, curr_qs2, target_qs):
    """Targets for the cumulative discounted Bellman error ('Delta' in the DisCor paper).

    Delta_target_i = |Q_i(s, a) - y| + (1 - done) * discount * Delta'_i(s', pi(s'))

    Uses the same module-level ``discount`` as the critic target and the DisCor
    importance weights, so the three recursions stay consistent if it is changed.
    All computation is done without gradient tracking.

    Returns:
        (target_errs1, target_errs2): one target tensor per critic.
    """
    with torch.no_grad():
        next_actions = policy_net(next_states)
        next_errs1, next_errs2 = target_error_net(next_states, next_actions)

        not_done = 1.0 - dones
        target_errs1 = (curr_qs1 - target_qs).abs() + not_done * discount * next_errs1
        target_errs2 = (curr_qs2 - target_qs).abs() + not_done * discount * next_errs2

    return target_errs1, target_errs2

def calc_current_errors(states, actions):
    """Error-model predictions for (s, a), one per critic, with gradients enabled."""
    err_pair = error_net(states, actions)
    err_q1, err_q2 = err_pair
    return err_q1, err_q2
def calc_importance_weights(next_states, dones):
    """DisCor importance weights, one tensor per critic.

    Each weight is a softmax over the batch (dim 0) of
    ``-(1 - done) * discount * Delta'(s', pi(s'))``, additionally divided by
    ``tau_i * tau_scale`` unless ``no_tau`` is set. Each returned tensor
    therefore sums to 1 along dim 0.
    """
    with torch.no_grad():
        policy_next = policy_net(next_states)
        next_errs1, next_errs2 = target_error_net(next_states, policy_next)

    # Exponent terms of the weights.
    x1 = -(1.0 - dones) * discount * next_errs1
    x2 = -(1.0 - dones) * discount * next_errs2
    if not no_tau:
        x1 = x1 / (tau1 * tau_scale)
        x2 = x2 / (tau2 * tau_scale)

    # Self-normalised over the batch.
    return F.softmax(x1, dim=0), F.softmax(x2, dim=0)

def calc_update_d_pi_iw(slow_obs, slow_act, fast_obs, fast_act, target_obs=None, target_act=None):
    """Train the LFIW classifier for one step and return importance weights.

    ``prob_classifier`` outputs a logit. It is trained to label slow (uniform
    replay) samples 0 and fast (recent) samples 1. The weights for
    ``(target_obs, target_act)`` (defaulting to the slow samples) are
    ``sigmoid(logit / prob_temperature)``, normalised to sum to 1, and carry no
    gradient.

    Returns:
        (importance_weights, loss): the normalised weights and the classifier loss.
    """
    slow_samples = torch.cat((slow_obs, slow_act), dim=1)
    fast_samples = torch.cat((fast_obs, fast_act), dim=1)

    slow_preds = prob_classifier(slow_samples)
    fast_preds = prob_classifier(fast_samples)

    # BCE on logits is the numerically stable form of BCE(sigmoid(x), y). Labels are
    # sized from each prediction tensor, so the slow and fast batches may differ in size.
    loss = F.binary_cross_entropy_with_logits(slow_preds, torch.zeros_like(slow_preds)) + \
        F.binary_cross_entropy_with_logits(fast_preds, torch.ones_like(fast_preds))

    update_params(prob_optimizer, loss)

    # In case we want to compute ratio on data different from what we train the network
    if target_obs is None:
        target_obs = slow_obs
    if target_act is None:
        target_act = slow_act
    target_samples = torch.cat((target_obs, target_act), dim=1)
    with torch.no_grad():
        target_preds = prob_classifier(target_samples)

    importance_weights = torch.sigmoid(target_preds / prob_temperature)
    importance_weights = importance_weights / torch.sum(importance_weights)

    return importance_weights, loss

def calc_current_qs(states, actions):
    """Online twin-critic estimates Q1(s, a), Q2(s, a), with gradients enabled."""
    q_pair = value_net(states, actions)
    q1, q2 = q_pair
    return q1, q2

def calc_target_qs(rewards, next_states, dones, noise_std,noise_clip):
    """TD3 critic target with target-policy smoothing and clipped double-Q.

    y = r + (1 - done) * discount * min(Q1', Q2')(s', clip(pi'(s') + eps, -1, 1)),
    where eps ~ clip(N(0, noise_std^2), -noise_clip, noise_clip).

    Returns:
        Targets with the same shape as ``rewards`` (asserted), computed without
        gradient tracking.
    """
    with torch.no_grad():
        next_actions = target_policy_net(next_states)
        # Generate noise directly on the actions' device instead of CPU-then-transfer.
        noise = (torch.randn_like(next_actions) * noise_std).clamp(-noise_clip, noise_clip)
        # Keep smoothed actions inside the policy's tanh range [-1, 1]; otherwise the
        # target critics are queried at actions the (clipping) env wrapper never executes.
        next_actions = (next_actions + noise).clamp(-1.0, 1.0)
        next_qs1, next_qs2 = target_value_net(next_states, next_actions)
        next_qs = torch.min(next_qs1, next_qs2)

    assert rewards.shape == next_qs.shape
    target_qs = rewards + (1.0 - dones) * discount * next_qs

    return target_qs

def calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1=None, imp_ws2=None):
    """Twin-critic TD loss, optionally importance-weighted.

    Args:
        curr_qs1, curr_qs2: current estimates from the two critics.
        target_qs: regression target; must not require grad.
        imp_ws1, imp_ws2: optional per-sample weights, same shape as the estimates.

    Returns:
        (q_loss, mean_q1, mean_q2, unweighted_q_loss) where ``q_loss`` is the
        (weighted) loss to optimise, the means are Python floats for logging, and
        ``unweighted_q_loss`` is the plain MSE sum for comparison.

    NOTE(review): weighted losses use ``mean``. The DisCor and LFIW weights in this
    notebook are each normalised to sum to 1, which shrinks the loss by roughly
    the batch size — confirm this scaling is intended.
    """
    assert imp_ws1 is None or imp_ws1.shape == curr_qs1.shape
    assert imp_ws2 is None or imp_ws2.shape == curr_qs2.shape
    assert not target_qs.requires_grad
    assert curr_qs1.shape == target_qs.shape

    sq_err1 = (curr_qs1 - target_qs).pow(2)
    sq_err2 = (curr_qs2 - target_qs).pow(2)

    if imp_ws1 is not None:
        q1_loss = torch.mean(sq_err1 * imp_ws1)
        q2_loss = torch.mean(sq_err2 * imp_ws2)
    else:
        q1_loss = torch.mean(sq_err1)
        q2_loss = torch.mean(sq_err2)

    mean_q1 = curr_qs1.detach().mean().item()
    mean_q2 = curr_qs2.detach().mean().item()

    unweighted_q_loss = torch.mean(sq_err1) + torch.mean(sq_err2)

    return q1_loss + q2_loss, mean_q1, mean_q2, unweighted_q_loss

def update_q_functions(batch,noise_std,noise_clip,imp_ws1=None, imp_ws2=None, fast_batch=None, err_preds=None):
    """One gradient step on both critics via ``value_optimizer``.

    Args:
        batch: dict with "states", "actions", "rewards", "next_states", "dones".
        noise_std, noise_clip: target-policy smoothing parameters.
        imp_ws1, imp_ws2: optional per-sample loss weights for critic 1 / 2.
        fast_batch, err_preds: accepted for interface compatibility; unused here.

    Returns:
        (curr_qs1, curr_qs2, target_qs) with the current estimates detached,
        for use by the DisCor error-model update.
    """
    states, actions = batch["states"], batch["actions"]
    rewards, next_states, dones = batch["rewards"], batch["next_states"], batch["dones"]

    curr_qs1, curr_qs2 = calc_current_qs(states, actions)
    target_qs = calc_target_qs(rewards, next_states, dones, noise_std, noise_clip)

    # Only the optimisable loss is needed here; the logging extras are discarded.
    q_loss = calc_q_loss(curr_qs1, curr_qs2, target_qs, imp_ws1, imp_ws2)[0]
    update_params(value_optimizer, q_loss)

    return curr_qs1.detach(), curr_qs2.detach(), target_qs

def disable_gradients(network):
    """Freeze ``network``: set ``requires_grad`` to False on every parameter."""
    for param in network.parameters():
        param.requires_grad_(False)

def td3_update(step,
           batch_size,
           policy_update=2,
           noise_std=0.2,
           noise_clip=0.5
          ):
    """One TD3-style training iteration.

    Samples a uniform batch (plus a "fast" batch when ``lfiw`` is enabled) and
    updates the critics. On steps divisible by ``policy_update``, it also updates
    the actor and blends all target networks.
    """
    batch = {"uniform": replay_buffer.sample(batch_size, device)}
    if lfiw:
        batch["fast"] = fast_replay_buffer.sample(batch_size, device)
    update_online_networks(batch, noise_std, noise_clip)

    # Delayed policy / target updates.
    if step % policy_update != 0:
        return
    update_policy_and_entropy(batch["uniform"])
    update_target_networks()

#######CHANGED MINSIGMA
class GaussianExploration(object):
    """Additive Gaussian action noise whose scale decays along a cubic curve.

    sigma(t) = max_sigma - min(max_sigma - min_sigma, (t / decay_period) ** 3)

    The scale starts at ``max_sigma`` and is floored at ``min_sigma``. With the
    default (1.0, 0.1) this is exactly the previously hard-coded schedule
    ``max(-0.9, -(t / decay_period) ** 3) + 1``, but it now honours custom
    ``max_sigma`` / ``min_sigma`` values.
    """
    def __init__(self, action_space, max_sigma=1.0, min_sigma=0.1, decay_period=500000):
        self.low  = action_space.low
        self.high = action_space.high
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.decay_period = decay_period

    def sigma(self, t=0):
        """Noise standard deviation at timestep ``t``."""
        decay = (t / self.decay_period) ** 3
        return self.max_sigma - min(self.max_sigma - self.min_sigma, decay)

    def get_action(self, action, t=0):
        """Return ``action`` plus N(0, sigma(t)^2) noise, clipped to the action bounds."""
        sigma = self.sigma(t)
        action = action + np.random.normal(size=len(action)) * sigma
        return np.clip(action, self.low, self.high)
    
#https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/gaussian_strategy.py

class NormalizedActions(gym.ActionWrapper):
    """Expose a normalised [-1, 1] action interface over an env with box bounds [low, high]."""
    def action(self, action):
        """Map a normalised action in [-1, 1] onto the env's [low, high] range."""
        low  = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)

        return action

    def reverse_action(self, action):
        """Map an env action in [low, high] back into the normalised [-1, 1] range."""
        low  = self.action_space.low
        high = self.action_space.high

        action = 2 * (action - low) / (high - low) - 1
        # The result is in normalised space, so clip to [-1, 1] (not to [low, high]).
        action = np.clip(action, -1.0, 1.0)

        return action

# --- Environments ---------------------------------------------------------------
# The policy emits actions in [-1, 1]; NormalizedActions maps them onto the env's bounds.
env = NormalizedActions(gym.make('BipedalWalker-v3'))
noise = GaussianExploration(env.action_space)
plot_interval = 10
video_every = 50
env = gym.wrappers.Monitor(env, "./video", video_callable=lambda ep_id: ep_id%video_every == 0, force=True)

# The evaluation env gets its own Monitor. Wrapping `env` here would stack a second
# recorder on the *training* env and leave the fresh eval env unused.
eval_env = NormalizedActions(gym.make('BipedalWalker-v3'))
eval_env = gym.wrappers.Monitor(eval_env, "./video/eval", video_callable=lambda ep_id: ep_id%2 == 0, force=True)

state_dim  = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
hidden_dim = 256
total_timesteps = 0
frame_idx   = 0
target_update_coef=0.005

# --- Networks --------------------------------------------------------------------
# Seed torch before building the networks so their initial weights are reproducible
# (SEED is set in the configuration cell).
torch.manual_seed(SEED)
value_net = ValueNetwork(state_dim, action_dim, hidden_units=[256, 256]).to(device)
policy_net = PolicyNetwork(state_dim, action_dim, hidden_units=[256, 256]).to(device)
target_value_net = ValueNetwork(state_dim, action_dim, hidden_units=[256, 256]).to(device)
target_policy_net = PolicyNetwork(state_dim, action_dim, hidden_units=[256, 256]).to(device)

policy_lr = 0.003
value_lr  = 0.003
entropy_lr = 0.003  # also used as the LFIW classifier learning rate below
error_lr = 0.003

value_optimizer = optim.Adam(value_net.parameters(), lr=value_lr)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=policy_lr)

print("frame_idx:",frame_idx)
print("total_timesteps:",total_timesteps)
# Hard-copy the online weights into the targets. The signature is
# soft_update(target, source, tau), so the target network comes first.
soft_update(target_value_net, value_net, tau=1.0)
soft_update(target_policy_net, policy_net, tau=1.0)

# Target networks only change through soft_update, never through an optimizer.
disable_gradients(target_value_net)
disable_gradients(target_policy_net)

# --- Replay buffers ----------------------------------------------------------------
# Uniform replay buffer, plus a small "fast" buffer of recent transitions that LFIW
# contrasts against it. (TemporalPrioritizedReplayBuffer is the TPER variant.)
eval_tper=False #Change later?
buffer = ReplayBuffer #TemporalPrioritizedReplayBuffer
replay_buffer = buffer(
            memory_size=1000000,
            state_shape=state_shape ,
            action_shape=action_shape ,
            gamma=0.99, nstep=1,
            arbi_reset=eval_tper)
fast_replay_buffer = buffer(
                memory_size=50000,
                state_shape=state_shape ,
                action_shape=action_shape ,
                gamma=0.99, nstep=1,
                )

# SAC-style temperature state (target entropy -|A|, optimised in log space).
# Not referenced by the TD3/DisCor/LFIW update functions defined above.
target_entropy = -float(action_dim)
log_alpha = torch.zeros(
    1, device=device, requires_grad=True)
alpha = log_alpha.detach().exp()
alpha_optimizer = optim.Adam([log_alpha], lr=policy_lr)

# --- LFIW: classifier separating recent (fast) from uniform (slow) samples ---------
lfiw = True
if lfiw:
    prob_hidden_units=[128, 128]
    prob_classifier = FlattenMlp(
                input_size=state_dim+action_dim,
                output_size=1,
                hidden_sizes=prob_hidden_units,).to(device)
    prob_optimizer = optim.Adam(prob_classifier.parameters(), lr=entropy_lr)
    prob_temperature = 7.5

# --- DisCor: error models and their temperatures ----------------------------------
discor = True
if discor:
    tau_init = 10.0
    tau_scale = 1
    error_net = ValueNetwork(state_dim, action_dim, hidden_units=[256, 256, 256]).to(device)
    target_error_net = ValueNetwork(state_dim, action_dim, hidden_units=[256, 256, 256]).to(device)
    target_error_net.load_state_dict(error_net.state_dict())
    disable_gradients(target_error_net)
    error_optimizer = optim.Adam(error_net.parameters(),lr=error_lr)
    tau1 = torch.tensor(tau_init,device=device,requires_grad=False)
    tau2 = torch.tensor(tau_init,device=device,requires_grad=False)
    # A (near-)zero initial temperature disables the 1/tau scaling of the DisCor weights.
    no_tau = tau_init < 1e-6
    
def evaluate_policy(policy,eval_env, eval_episodes=2, max_episode_steps=None):
    """Run several episodes with the (noise-free) policy and return the mean return.

        Args:
            policy: object exposing ``get_action(state)`` (e.g. ``policy_net``).
            eval_env (env): gym environment to evaluate in.
            eval_episodes (int): how many test episodes to run.
            max_episode_steps (int or None): per-episode step cap; defaults to the
                module-level ``max_steps`` when None.

        Returns:
            avg_reward (float): total reward divided by ``eval_episodes``.
    """
    if max_episode_steps is None:
        max_episode_steps = max_steps

    total_reward = 0.
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        step_count = 0
        while not done and step_count < max_episode_steps:
            action = policy.get_action(state)
            # Advance the state; otherwise the policy keeps acting on the reset observation.
            state, reward, done, _ = eval_env.step(action)
            total_reward += reward
            step_count += 1

    return total_reward / eval_episodes
def assert_action(action):
    """Sanity-check a policy action: it must be a numpy array whose sum is not NaN."""
    assert isinstance(action, np.ndarray)
    total = np.sum(action)
    assert not np.isnan(total), 'Action has a Nan value.'
def explore(state,total_timesteps):
    """Noisy action for data collection.

    Queries ``policy_net`` for ``state``, validates the result with
    ``assert_action``, then adds decaying Gaussian noise from the module-level
    ``noise`` object (scheduled on ``total_timesteps``).
    """
    greedy = policy_net.get_action(state)
    assert_action(greedy)
    return noise.get_action(greedy, total_timesteps)


max_frames  = 1000
max_steps   = 3000
rewards     = []
batch_size  = 128
best_avg = -2000
REWARD_THRESH=300

# Set seeds
SEED = 42
random.seed(SEED)
env.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
env.action_space.seed(SEED)
# logging variables
plot_data = []
log_f = open("agent-log.txt","w+")
reward_list = []
evaluations = []
best_episode = None
best_episode_reward = -2000
discount = 0.99 
gamma = 0.99

noise_start = 1.0
noise_final = 0.1
noise_decay = 400
noise_by_frame_1 = lambda frame_idx: max(0.2,-((frame_idx /noise_decay)-1)**3) 
noise_by_frame_2 = lambda frame_idx: max(noise_final,0.3-0.2*(frame_idx /noise_decay))
#noise_by_frame = lambda frame_idx: if (frame_idx/noise_decay)<0.4: max(0.2,-((frame_idx /noise_decay)-1)**3) else: max(noise_final,0.3-0.2(frame_idx /noise_decay))
#noise_by_frame = lambda frame_idx: noise_final + (noise_start - noise_final) * math.exp(-1. * frame_idx / noise_decay)


#Add Emphasizing Recent Experience (ERE) to PER
#https://towardsdatascience.com/4-ways-to-boost-experience-replay-999d9f17f7b6
while frame_idx < max_frames:
    env.stats_recorder.done = None
    state = env.reset()
    episode_reward = 0
    game_list = []
    episode_steps = 0
    
    for step in range(max_steps):
        
        #action = policy_net.get_action(state)
        #action = noise.get_action(action,total_timesteps)
        if total_timesteps>10000:
            action = explore(state,total_timesteps)
        else:
            action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        if episode_steps + 1 >= max_steps:
            masked_done = False
            done = True
        else:
            masked_done = done
        transition = [state, action, reward, next_state, masked_done, episode_steps, done]
        replay_buffer.append(*transition)
        if lfiw:
            fast_replay_buffer.append(*transition)
        if len(replay_buffer) > batch_size and total_timesteps>10000:
            NOISE = 0.2
            #if best_episode_reward>295:
            #  NOISE = noise_final
            #else:
            #  if (frame_idx/noise_decay)<0.4:
            #      NOISE = noise_by_frame_1(frame_idx)
            #  else:
            #      NOISE = noise_by_frame_2(frame_idx)
            td3_update(step, batch_size,noise_std = 0.2)
            td3_update(step, batch_size,noise_std = 0.4)
        
        state = next_state
        episode_reward += reward
        total_timesteps+=1
        episode_steps +=1
        
        if done:
            break
    rewards.append(episode_reward)
    reward_list.append(episode_reward)
    if episode_reward>best_episode_reward:
        best_episode_reward = episode_reward
        best_episode = game_list
    avg_reward = np.mean(rewards[-100:])
    if avg_reward >= REWARD_THRESH:
        break
    #evaluate
    #if frame_idx%video_every ==0:
    #    eval_reward = evaluate_policy(policy_net, eval_env)
    #    evaluations.append(eval_reward)

    # do NOT change this logging code - it is used for automated marking!
    log_f.write('episode: {}, reward: {}\n'.format(frame_idx, episode_reward))
    log_f.flush()
    # print reward data every so often - add a graph like this in your report
    if frame_idx % plot_interval == 0:
        print("\rTotal T: {:d} Episode Num: {:d} Reward: {:f} Avg Reward: {:f}".format(
        total_timesteps, frame_idx, episode_reward, avg_reward ), end="")
        plot_data.append([frame_idx, np.array(reward_list).mean(), np.array(reward_list).std()])
        reward_list = []
        # plt.rcParams['figure.dpi'] = 100
        plt.plot([x[0] for x in plot_data], [x[1] for x in plot_data], '-', color='tab:grey')
        plt.fill_between([x[0] for x in plot_data], [x[1]-x[2] for x in plot_data], [x[1]+x[2] for x in plot_data], alpha=0.2, color='tab:grey')
        plt.xlabel('Episode number')
        plt.ylabel('Episode reward')
        plt.show()
        disp.clear_output(wait=True)
    frame_idx += 1

env.close()"""