#Imports

In [None]:
import os
import math
import torch
import numpy as np
from torch.optim import Adam
from torch.nn import Upsample
import torch.nn.functional as F
from utils import soft_update, hard_update
from model2 import GaussianPolicy, QNetwork, VNetwork

#Original Soft Actor Critic (SAC)

In [None]:
class SAC(object):
    def __init__(self, num_inputs, action_space, args):
        """Set up the SAC critic pair, target critic, policy and optimizers.

        Args:
            num_inputs: passed through as the first constructor argument of
                ``QNetwork`` and ``GaussianPolicy``.
            action_space: only ``action_space.shape[0]`` is read, and only
                when ``args.automatic_entropy_tuning`` is truthy (it defines
                the target entropy).
            args: hyper-parameter namespace. Fields read here: ``gamma``,
                ``tau``, ``alpha``, ``action_res``, ``action_res_resize``,
                ``target_update_interval``, ``automatic_entropy_tuning``,
                ``cuda``, ``momentum``, ``hidden_size``, ``lr``, ``residual``
                and ``coarse2fine_bias``.
        """
        self.args = args
        self.num_inputs = num_inputs
        self.gamma = args.gamma
        self.tau = args.tau
        self.alpha = args.alpha
        self.action_res = args.action_res
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        self.device = torch.device("cuda" if args.cuda else "cpu")
        # One bicubic upsampler per supported coarse->fine factor; indexed by
        # log2(factor) in upsample_coarse_action.
        self.exp_upsample_list = [
            Upsample(scale_factor=factor, mode='bicubic', align_corners=True)
            for factor in [1, 2, 4, 8]
        ]

        # running statistics for reward normalization
        self.momentum = args.momentum
        self.mean = 0.0
        self.var = 1.0

        self.last_state_batch = None
        self.old_std = None

        # The coarse policy is only created by load_coarse_model(); define the
        # attribute up front so the methods that read it never hit a missing
        # attribute on a freshly constructed agent.
        self.coarse_policy = None

        # critic and its target copy
        self.upsampled_action_res = args.action_res * args.action_res_resize
        self.critic = QNetwork(
            num_inputs, self.action_res,
            self.upsampled_action_res, args.hidden_size,
        ).to(device=self.device)
        self.critic_target = QNetwork(
            num_inputs, self.action_res,
            self.upsampled_action_res, args.hidden_size,
        ).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
        # start the target network from the same weights as the critic
        hard_update(self.critic_target, self.critic)

        # actor
        self.policy = GaussianPolicy(
            num_inputs, self.action_res,
            self.upsampled_action_res, args.residual, args.coarse2fine_bias,
        ).to(device=self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        # automatic temperature (alpha) tuning
        if self.automatic_entropy_tuning:
            # target entropy = -(number of action dimensions)
            self.target_entropy = -torch.Tensor([action_space.shape[0]]).to(self.device).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

    def select_action(self, state, coarse_action=None, task=None):
        state = (torch.FloatTensor(state) / 255.0 * 2.0 - 1.0).to(self.device).unsqueeze(0)
        if coarse_action is not None:
            coarse_action = torch.FloatTensor(coarse_action).to(self.device).unsqueeze(0)
        if task is None or "shapematch" not in task:
            action, _, _, _, mask = self.policy.sample(state, coarse_action)
        else:
            _, _, action, _, mask = self.policy.sample(state, coarse_action)
            action = torch.tanh(action)
        action = action.detach().cpu().numpy()[0]
        if coarse_action is not None:
            mask = mask.detach().cpu().numpy()[0]
        return action, mask, None

    def select_coarse_action(self, state, coarse_action=None, task=None):
        state = (torch.FloatTensor(state) / 255.0 * 2.0 - 1.0).to(self.device).unsqueeze(0)
        if coarse_action is not None:
            coarse_action = torch.FloatTensor(coarse_action).to(self.device).unsqueeze(0)
        if task is None or "shapematch" not in task:
            action, _, mean, std, _ = self.coarse_policy.sample(state, coarse_action)
            action = action.detach().cpu().numpy()[0]
            mean = mean.detach().cpu().numpy()[0]
            std = std.detach().cpu().numpy()[0]
            return action, mean, std
        else:
            _, _, action, _, _ = self.coarse_policy.sample(state, coarse_action)
            action = torch.tanh(action)
            action = action.detach().cpu().numpy()[0]
            return action, None, None

    def reward_normalization(self, rewards):
        # update mean and var for reward normalization
        batch_mean = torch.mean(rewards)
        batch_var = torch.var(rewards)
        self.mean = self.momentum * self.mean + (1 - self.momentum) * batch_mean
        self.var = self.momentum * self.var + (1 - self.momentum) * batch_var
        std = torch.sqrt(self.var)
        normalized_rewards = (rewards - self.mean) / (std + 1e-8)
        return normalized_rewards


    def update_parameters(self, memory, updates, start=False):
        """Run one SAC gradient step on a batch sampled from ``memory``.

        Order of operations: build the soft Bellman target with the target
        critic, update both Q heads, update the policy (plus the mask
        regularizer in residual mode), optionally update the entropy
        temperature, and soft-update the target critic every
        ``self.target_update_interval`` calls.

        Args:
            memory: replay buffer; ``memory.sample(batch_size)`` must return
                ``(state, action, reward, next_state, mask)`` batches.
                States are rescaled here from [0, 255] to [-1, 1].
            updates: running update counter, used only to schedule the
                target-network update.
            start: accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(qf1_loss, qf2_loss, policy_loss, mean negative log-prob,
            alpha, mean entropy, mask_regularize_loss)``; every entry is a
            Python float except ``alpha``, which is ``self.alpha``.
        """
        batch_size = self.args.batch_size
        (
            state_batch,
            action_batch,
            reward_batch,
            next_state_batch,
            mask_batch
        ) = memory.sample(batch_size)
        state_batch = (torch.FloatTensor(state_batch) / 255.0 * 2.0 - 1.0).to(self.device)
        next_state_batch = (torch.FloatTensor(next_state_batch) / 255.0 * 2.0 - 1.0).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        reward_batch = self.reward_normalization(reward_batch)

        # ---- critic target: r + mask * gamma * (min Q_target - alpha * log pi)
        with torch.no_grad():
            if self.args.residual:
                # In residual mode the fine policy only overrides the
                # upsampled coarse action where its mask is active.
                next_original_action = self.upsample_coarse_action(next_state_batch)
                next_state_pi, next_state_log_pi, _, _, mask = self.policy.sample(
                    next_state_batch, next_original_action)
                next_state_pi = mask * next_state_pi \
                    + (1 - mask) * next_original_action.reshape(batch_size, -1)
            else:
                next_state_pi, next_state_log_pi, _, _, _ = self.policy.sample(next_state_batch)
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_pi)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) \
                - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * self.gamma * min_qf_next_target

        # ---- critic update; two Q-functions mitigate positive bias
        # JQ = E[(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.critic_optim.zero_grad()
        qf_loss.backward()
        # NOTE(review): the norm is clipped per parameter tensor, not over the
        # whole network; kept as-is because changing it alters training.
        for params in self.critic.parameters():
            torch.nn.utils.clip_grad_norm_(params, max_norm=10)
        self.critic_optim.step()

        # ---- actor update
        if self.args.residual:
            with torch.no_grad():
                coarse_action = self.upsample_coarse_action(state_batch)
            pi, log_pi, _, _, mask = self.policy.sample(state_batch, coarse_action)
            pi = mask * pi + (1 - mask) * coarse_action.reshape(batch_size, -1)
        else:
            pi, log_pi, _, _, _ = self.policy.sample(state_batch)
        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        # J_pi = E[alpha * log pi(f(eps; s) | s) - Q(s, f(eps; s))]
        policy_loss_ = ((self.alpha * log_pi) - min_qf_pi).mean()
        if self.args.residual:
            # push the mask toward 0 so the fine policy only edits where needed
            mask_regularize_loss = self.args.coarse2fine_penalty * \
                torch.norm(mask.reshape(mask.shape[0], -1), dim=1).mean() / self.args.action_res
            policy_loss = policy_loss_ + mask_regularize_loss
        else:
            policy_loss = policy_loss_
            mask_regularize_loss = torch.zeros(1).to(self.device)
        self.policy_optim.zero_grad()
        policy_loss.backward()
        for params in self.policy.parameters():
            torch.nn.utils.clip_grad_norm_(params, max_norm=10)
        self.policy_optim.step()

        # Entropy is only reported, so no autograd graph is needed.
        with torch.no_grad():
            entropy, _ = self.policy.entropy(state_batch)

        # ---- temperature update
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            # Publish the tuned temperature; without this the optimised
            # log_alpha never reaches the critic target or the policy loss.
            # Stored as a float so self.alpha keeps the type it had before.
            self.alpha = self.log_alpha.exp().item()

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), -torch.mean(log_pi).item(), self.alpha,\
            entropy.mean().item(), mask_regularize_loss.item()

    def upsample_coarse_action(self, state_batch):
        coarse_pi, _, _, _, _ = self.coarse_policy.sample(state_batch)
        coarse_pi = coarse_pi.reshape(self.args.batch_size, 2,\
            self.args.coarse_action_res, self.args.coarse_action_res)
        return self.exp_upsample_list[int(math.log2(self.args.action_res / self.args.coarse_action_res))](coarse_pi)

    # save model parameters
    def save_model(self, filename):
        checkpoint = {
            "mean": self.mean,
            "var": self.var,
            "policy": self.policy.state_dict(),
            "critic": self.critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "policy_optim": self.policy_optim.state_dict(),
        }
        torch.save(checkpoint, filename + ".pth")

    # load model parameters
    def load_model(self, filename, for_train=False):
        print('Loading models from {}...'.format(filename))
        checkpoint = torch.load(filename)
        mean = checkpoint.get("mean")
        var = checkpoint.get("var")
        if mean is not None:
            self.mean = mean
        if var is not None:
            self.var = var
        self.policy.load_state_dict(checkpoint["policy"])
        self.critic.load_state_dict(checkpoint["critic"])
        if for_train:
            self.policy_optim.load_state_dict(checkpoint["policy_optim"])
            self.critic_optim.load_state_dict(checkpoint["critic_optim"])

    # load coarse model
    def load_coarse_model(self, filename, action_res):
        """Create ``self.coarse_policy`` and load its weights from a checkpoint.

        Args:
            filename: path of a checkpoint whose ``"policy"`` entry holds the
                coarse policy's state dict.
            action_res: action resolution the coarse policy is built with.

        The checkpoint is mapped onto ``self.device`` so weights saved on a
        GPU can be loaded on a CPU-only host.
        """
        print('Loading coarse models from {}...'.format(filename))
        # The coarse policy is always built with its fourth constructor
        # argument set to False (the agent's own policy passes args.residual).
        self.coarse_policy = GaussianPolicy(
            self.num_inputs, action_res,
            self.upsampled_action_res, False, self.args.coarse2fine_bias,
        ).to(self.device)
        checkpoint = torch.load(filename, map_location=self.device)
        self.coarse_policy.load_state_dict(checkpoint["policy"])


#PPO (Proximal Policy Optimization)

In [None]:
#ctrl /

class PPO(object):
    def __init__(self, num_inputs, action_space, args):
        """Set up the PPO value network, policy and their optimizers.

        Args:
            num_inputs: passed through as the first constructor argument of
                ``VNetwork`` and ``GaussianPolicy``.
            action_space: accepted for interface parity with the other
                agents; not read here.
            args: hyper-parameter namespace.  Besides the fields stored or
                printed below, the learning rates come from
                ``args.policy_lr`` and ``args.value_lr``.
        """
        self.args = args
        self.num_inputs = num_inputs
        self.gamma = args.gamma
        self.alpha = args.alpha
        self.epsilon = args.epsilon
        self.action_res = args.action_res

        self.upsampled_action_res = args.action_res * args.action_res_resize
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        self.device = torch.device("cuda" if args.cuda else "cpu")

        # running statistics for reward normalization
        self.momentum = args.momentum
        self.mean = 0.0
        self.var = 1.0

        # entropy bonus weight; decayed on every update_parameters call
        self.entropy_coeff = self.args.entropy_coeff

        # value network
        self.values = VNetwork(num_inputs, self.action_res, args.hidden_size).to(device=self.device)
        self.value_optim = Adam(self.values.parameters(), lr=args.value_lr)

        # policy
        self.policy = GaussianPolicy(
            num_inputs, self.action_res,
            self.upsampled_action_res, args.residual, args.coarse2fine_bias,
            nl1=args.nl1, nl2=args.nl2,
        ).to(device=self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.policy_lr)

        # previous state batch / std, kept for the debug comparisons printed
        # by update_parameters when args.loud is set
        self.old_state_batch = None
        self.old_std = None

        print("weight decay: ", args.weight_decay)
        print("policy lr: ", args.policy_lr)
        print("value lr: ", args.value_lr)
        print("self.args.log_sum_exp: ",args.log_sum_exp)
        print("adv_norm: ",args.adv_norm)
        print("entropy_coeff: ",args.entropy_coeff)
        print("entropy_decay: ",args.entropy_decay)
        print("Gradient clip policy: ",args.max_norm_p)
        print("Gradient clip value: ",args.max_norm_v)
        print("Std initial weight noise: ",args.nl1)
        print("Mean initial weights noise: ",args.nl2)
        print("TD error for value loss calculation", args.td_error)


    def select_action(self, state, coarse_action=None, task=None):
        mean = None
        std = None
        state = (torch.FloatTensor(state) / 255.0 * 2.0 - 1.0).to(self.device).unsqueeze(0)
        if coarse_action is not None:
            coarse_action = torch.FloatTensor(coarse_action).to(self.device).unsqueeze(0)
        if task is None or "shapematch" not in task:
            action, log_prob, mean, std, mask = self.policy.sample(state, coarse_action)
            #print("log prob ", log_prob)
        else:
            _, log_prob, action, _, mask = self.policy.sample(state, coarse_action)
            action = torch.tanh(action)
        action = action.detach().cpu().numpy()[0]
        if coarse_action is not None:
            mask = mask.detach().cpu().numpy()[0]

        return action, mask , log_prob, mean, std

    def log_prob(self, state, action, task=None):
        state = (torch.FloatTensor(state) / 255.0 * 2.0 - 1.0).to(self.device).unsqueeze(0)
        action = torch.FloatTensor(action).to(self.device).unsqueeze(0)
        log_pi = self.policy.log_prob(state,action)
        return log_pi


    def reward_normalization(self, rewards):
        # update mean and var for reward normalization
        batch_mean = torch.mean(rewards)
        batch_var = torch.var(rewards)
        self.mean = self.momentum * self.mean + (1 - self.momentum) * batch_mean
        self.var = self.momentum * self.var + (1 - self.momentum) * batch_var
        std = torch.sqrt(self.var)
        normalized_rewards = (rewards - self.mean) / (std + 1e-8)
        return normalized_rewards

    def advantage_normalization(self, advantages):
        # update mean and var for reward normalization
        batch_mean = torch.mean(advantages)
        batch_var = torch.var(advantages)
        std = torch.sqrt(batch_var)
        #normalized_advantages = (advantages - batch_mean) / (std + 1e-8)
        normalized_advantages = (advantages) / (std + 1e-8)
        return (normalized_advantages * 10)

    def precompute_batches(self, memory, num_batches):
      batches = []

      for _ in range(num_batches):
          (
              state_batch,
              action_batch,
              reward_batch,
              next_state_batch,
              mask_batch,
              old_log_prob_batch,
              reward_l_batch
          )= memory.sample(self.args.batch_size)
          state_batch = (torch.FloatTensor(state_batch) / 255.0 * 2.0 - 1.0).to(self.device)

          action_batch = torch.FloatTensor(action_batch).to(self.device)
          with torch.no_grad():
            log_prob, _ = self.policy.log_prob(state_batch, action_batch)

          # Store everything needed for later updates
          batches.append({
              'state_batch': state_batch,
              'action_batch': action_batch,
              'reward_batch': torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1),
              'next_state_batch': (torch.FloatTensor(next_state_batch) / 255.0 * 2.0 - 1.0).to(self.device),
              'mask_batch': torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1),
              'old_log_prob_batch': log_prob,
              'reward_l_batch': torch.FloatTensor(reward_l_batch).to(self.device).unsqueeze(1),
          })
      return batches

    def log_sum_exp(self,log_probs):
      max_log_prob = torch.max(log_probs)  # Use torch.max for PyTorch tensors
      stable_log_probs = log_probs - max_log_prob  # Stabilize the log probabilities
      sum_exp = torch.sum(torch.exp(stable_log_probs))  # Sum of the exponentials
      return max_log_prob + torch.log(sum_exp)  # Log of summed exponentials


    def update_parameters(self, memory, batch,start):
        #batches = self.precompute_batches(memory, 15)
          torch.autograd.set_detect_anomaly(True)

          state_batch = batch['state_batch']
          action_batch = batch['action_batch']
          reward_batch = batch['reward_batch']
          next_state_batch = batch['next_state_batch']
          mask_batch = batch['mask_batch']
          old_log_prob_batch = batch['old_log_prob_batch']
          reward_l_batch = batch['reward_l_batch']
          #reward_l_batch = self.reward_normalization(reward_l_batch)
          
          #Aˆt = Rt (λ) −V (st ).
          est_values = self.values(state_batch)
          advantages_batch = reward_l_batch - est_values #torch.exp(reward_l_batch - values_batch)

          log_pi, std  = (self.policy.log_prob(state_batch, action_batch,start=start))

          if(self.args.loud):
            if(self.old_std != None):
              #print("shape", std.shape)
              print("(self.old_std - std) stats: min", (self.old_std - std).min().item(), "max", (self.old_std - std).max().item(), "mean", abs((self.old_std - std)).mean().item())
              #print((self.old_std - std).mode())
            self.old_std = std
            if(self.old_state_batch != None):
              print("(state_batch - old_state_batch) stats: min", (state_batch - self.old_state_batch).min().item(), "max", (state_batch - self.old_state_batch).max().item(), "mean", (state_batch - self.old_state_batch).mean().item())
            self.old_state_batch = state_batch
            print("advantages_batch stats: min", advantages_batch.min().item(), "max", advantages_batch.max().item(), "mean", advantages_batch.mean().item())
            print("(log_pi - old_log_prob_batch) stats: min", (log_pi - old_log_prob_batch).min().item(), "max", (log_pi - old_log_prob_batch).max().item(), "mean", (log_pi - old_log_prob_batch).mean().item())
            print("log_pi stats: min", log_pi.min().item(), "max", log_pi.max().item(), "mean", log_pi.mean().item())
          # Step 2: Normalize the log probabilities
          if(self.args.log_sum_exp):
            log_sum_exp = torch.logsumexp(log_pi, dim=0)
            log_pi = log_pi - log_sum_exp
            old_log_sum_exp = torch.logsumexp(old_log_prob_batch, dim=0)
            old_log_prob_batch = old_log_prob_batch - old_log_sum_exp
          if(self.args.adv_norm):
            advantages_batch = self.advantage_normalization(advantages_batch)
          #print("advantages_batch: min", advantages_batch.min().item(), "max", advantages_batch.max().item(), "mean", advantages_batch.mean().item())
          #****you can also normalize the advantages i guess but wait probably dont do that I think

          diff_log_prob = log_pi - old_log_prob_batch
          diff_log_prob = torch.clamp(diff_log_prob, -10.0, 10.0)
          # # Normalize the differences
          # #this is a bad idea cause like one will be not 0 and that one will have a huge thing but ig will be clamped idk

          # Apply exponential function to the normalized differences
          ratios = torch.exp(diff_log_prob)
          if(self.args.loud):
            print("diff_log_prob stats: min", diff_log_prob.min().item(), "max", diff_log_prob.max().item(), "mean", diff_log_prob.mean().item())
            print(" ratios stats: min", ratios.min().item(), "max", ratios.max().item(), "mean", ratios.mean().item())

          par1 = ratios * advantages_batch
          par2 = torch.clamp(ratios, 1.0 - self.epsilon, 1.0 + self.epsilon) * advantages_batch
          policy_loss = -torch.min(par1, par2).mean()

          entropy, _ = self.policy.entropy(state_batch)

          # # Add entropy bonus to the policy loss
          self.entropy_coeff = self.entropy_coeff * self.args.entropy_decay #0.999#self.args.entropy_coeff
          policy_loss = policy_loss - (self.entropy_coeff * entropy)#(100 * std.std())#(self.entropy_coeff * entropy)
          #print((-torch.min(par1, par2)).shape, entropy.mean(dim=1, keepdim=True).shape)
          policy_loss = (-torch.min(par1, par2) - (self.entropy_coeff * entropy.mean(dim=1, keepdim=True))).mean()
          if(self.args.loud):
            print("entropy",entropy.mean().item())
            print("policy loss", policy_loss.item())
            #print("entropy loss", (self.args.entropy_coeff * entropy).item())
            print("Std std", std.std().item())


          # update policy
          self.policy_optim.zero_grad()
          policy_loss.backward()

          for params in self.policy.parameters():
              torch.nn.utils.clip_grad_norm_(params, max_norm=self.args.max_norm_p)
          self.policy_optim.step()

          #update value function
          if(self.args.rl_norm):
              reward_l_batch = self.reward_normalization(reward_l_batch)

          est_values = self.values(state_batch)
          # if(self.args.td_error):
          #   V_next = self.values(next_state_batch)
          #   v_loss = (reward_l_batch + 0.95 *(V_next - est_values)).pow(2).mean() # Compute TD error
          # else:
          v_loss = ((est_values - reward_l_batch) ** 2).mean() / 1000

          if(self.args.loud):
            print("value loss", v_loss.item())
            print("est values", est_values.mean().item())
            print("reward batchs",reward_l_batch.mean().item())

          self.value_optim.zero_grad()
          v_loss.backward()
          for params in self.values.parameters():
              torch.nn.utils.clip_grad_norm_(params, max_norm= self.args.max_norm_v)
          self.value_optim.step()

          return v_loss.item(), policy_loss.item(), entropy.mean().item(), (abs(self.entropy_coeff * entropy.mean())/(abs(self.entropy_coeff * entropy.mean()) + abs(policy_loss)))#ratios.min().item()#torch.norm(std.reshape(std.shape[0], -1), dim=1).mean().item() / (self.args.action_res**2)

    # save model parameters
    def save_model(self, filename):
        checkpoint = {
            "mean": self.mean,
            "var": self.var,
            "policy": self.policy.state_dict(),
            "values": self.values.state_dict(),
            "value_optim": self.value_optim.state_dict(),
            "policy_optim": self.policy_optim.state_dict(),
        }
        torch.save(checkpoint, filename + ".pth")

    # load model parameters
    def load_model(self, filename, for_train=False):
        print('Loading models from {}...'.format(filename))
        checkpoint = torch.load(filename)
        mean = checkpoint.get("mean")
        var = checkpoint.get("var")
        if mean is not None:
            self.mean = mean
        if var is not None:
            self.var = var
        self.policy.load_state_dict(checkpoint["policy"])
        self.values.load_state_dict(checkpoint["value"])
        if for_train:
            self.policy_optim.load_state_dict(checkpoint["policy_optim"])
            self.value_optim.load_state_dict(checkpoint["value_optim"])

    # load coarse model
    def load_coarse_model(self, filename, action_res):
        """Create ``self.coarse_policy`` and load its weights from a checkpoint.

        Args:
            filename: path of a checkpoint whose ``"policy"`` entry holds the
                coarse policy's state dict.
            action_res: action resolution the coarse policy is built with.

        The checkpoint is mapped onto ``self.device`` so weights saved on a
        GPU can be loaded on a CPU-only host.
        """
        print('Loading coarse models from {}...'.format(filename))
        # The coarse policy is always built with its fourth constructor
        # argument set to False (the agent's own policy passes args.residual).
        self.coarse_policy = GaussianPolicy(
            self.num_inputs, action_res,
            self.upsampled_action_res, False, self.args.coarse2fine_bias,
        ).to(self.device)
        checkpoint = torch.load(filename, map_location=self.device)
        self.coarse_policy.load_state_dict(checkpoint["policy"])


#A2C

In [None]:
class A2C(object):
    def __init__(self, num_inputs, action_space, args):
        """Set up the A2C value network, policy, optimizers and LR schedulers.

        Args:
            num_inputs: passed through as the first constructor argument of
                ``VNetwork`` and ``GaussianPolicy``.
            action_space: only ``action_space.shape[0]`` is read, and only
                when ``args.automatic_entropy_tuning`` is truthy.
            args: hyper-parameter namespace.  Fields read here: ``gamma``,
                ``alpha``, ``epsilon``, ``action_res``, ``action_res_resize``,
                ``automatic_entropy_tuning``, ``cuda``, ``momentum``,
                ``entropy_coeff``, ``hidden_size``, ``lr``, ``residual`` and
                ``coarse2fine_bias``.
        """
        # StepLR is not imported at file level; import it here so that
        # constructing the agent does not fail with a NameError.
        from torch.optim.lr_scheduler import StepLR

        self.args = args
        self.num_inputs = num_inputs
        self.gamma = args.gamma
        self.alpha = args.alpha
        self.epsilon = args.epsilon
        self.action_res = args.action_res
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        self.device = torch.device("cuda" if args.cuda else "cpu")

        # running statistics for reward normalization
        self.momentum = args.momentum
        self.mean = 0.0
        self.var = 1.0

        self.entropy_coeff = self.args.entropy_coeff

        self.upsampled_action_res = args.action_res * args.action_res_resize

        # value network
        self.values = VNetwork(num_inputs, self.action_res, args.hidden_size).to(device=self.device)
        self.value_optim = Adam(self.values.parameters(), lr=args.lr)

        # policy
        self.policy = GaussianPolicy(
            num_inputs, self.action_res,
            self.upsampled_action_res, args.residual, args.coarse2fine_bias,
        ).to(device=self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)

        # multiply each learning rate by 0.99 every 100 scheduler steps
        self.policy_scheduler = StepLR(self.policy_optim, step_size=100, gamma=0.99)
        self.value_scheduler = StepLR(self.value_optim, step_size=100, gamma=0.99)

        # automatic temperature (alpha) tuning state
        if self.automatic_entropy_tuning:
            self.target_entropy = -torch.Tensor([action_space.shape[0]]).to(self.device).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=args.lr)

    def select_action(self, state, coarse_action=None, task=None):
        mean = None
        std = None
        state = (torch.FloatTensor(state) / 255.0 * 2.0 - 1.0).to(self.device).unsqueeze(0)
        if coarse_action is not None:
            coarse_action = torch.FloatTensor(coarse_action).to(self.device).unsqueeze(0)
        if task is None or "shapematch" not in task:
            action, log_prob, mean, std, mask = self.policy.sample(state, coarse_action)
        else:
            _, log_prob, action, _, mask = self.policy.sample(state, coarse_action)
            action = torch.tanh(action)
        action = action.detach().cpu().numpy()[0]
        if coarse_action is not None:
            mask = mask.detach().cpu().numpy()[0]
        # log_probs = torch.clamp(log_prob, min=-20, max=0)
        return action, mask , log_prob, mean, std


    def reward_normalization(self, rewards):
    # Clipping the rewards to a certain range can stabilize training
        rewards = torch.clamp(rewards, min=-1, max=1)
        batch_mean = torch.mean(rewards)
        batch_var = torch.var(rewards)
        self.mean = self.momentum * self.mean + (1 - self.momentum) * batch_mean
        self.var = self.momentum * self.var + (1 - self.momentum) * batch_var
        std = torch.sqrt(self.var + 1e-8)
        normalized_rewards = (rewards - self.mean) / std
        return normalized_rewards


    def advantage_normalization(self, advantages):
        # update mean and var for reward normalization
        batch_mean = torch.mean(advantages)
        batch_var = torch.var(advantages)
        std = torch.sqrt(batch_var)
        normalized_advantages = (advantages - batch_mean) / (std + 1e-8)
        return normalized_advantages


    def update_parameters(self, memory, updates):
        (
            state_batch,
            action_batch,
            reward_batch,
            next_state_batch,
            mask_batch,
            old_log_prob_batch,
            reward_l_batch
        ) = memory.sample(self.args.batch_size)
        state_batch = (torch.FloatTensor(state_batch) / 255.0 * 2.0 - 1.0).to(self.device)
        next_state_batch = (torch.FloatTensor(next_state_batch) / 255.0 * 2.0 - 1.0).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
        reward_l_batch = torch.FloatTensor(reward_l_batch).to(self.device).unsqueeze(1)

        # print("reward_l: ", reward_l_batch.cpu().numpy())

        reward_batch = self.reward_normalization(reward_batch)

        # Calculate the value targets
        values_batch = self.values(state_batch)
        next_values = self.values(next_state_batch)


        returns = reward_l_batch

        # Calculate advantages
        advantages = reward_l_batch - values_batch
        advantages = self.advantage_normalization(advantages)

        # Update critic
        critic_loss = F.mse_loss(values_batch, returns)

        self.value_optim.zero_grad()
        critic_loss.backward()
        for params in self.values.parameters():
            torch.nn.utils.clip_grad_norm_(params, max_norm=10)
        self.value_optim.step()
        self.value_scheduler.step(critic_loss)

        pi, log_pi, _, std, _= self.policy.sample(state_batch)

        policy_loss = -(advantages.detach() * log_pi).mean() - self.entropy_coeff * self.policy.entropy(state_batch).mean()

        mask_regularize_loss = torch.zeros(1).to(self.device)

        self.policy_optim.zero_grad()
        policy_loss.backward()
        for params in self.policy.parameters():
            torch.nn.utils.clip_grad_norm_(params, max_norm=5)
        self.policy_optim.step()
        self.policy_scheduler.step(policy_loss)

        entropy = self.policy.entropy(state_batch)

        return critic_loss.item(), policy_loss.item(), -torch.mean(log_pi).item(), self.alpha,\
          torch.norm(std.reshape(std.shape[0], -1), dim=1).mean().item() / (self.args.action_res**2), mask_regularize_loss.item()


 # save model parameters
    def save_model(self, filename):
        checkpoint = {
            "mean": self.mean,
            "var": self.var,
            "policy": self.policy.state_dict(),
            "values": self.values.state_dict(),
            "value_optim": self.value_optim.state_dict(),
            "policy_optim": self.policy_optim.state_dict(),
        }
        torch.save(checkpoint, filename + ".pth")

    def load_model(self, filename, for_train=False):
        print('Loading models from {}...'.format(filename))
        checkpoint = torch.load(filename)
        mean = checkpoint.get("mean")
        var = checkpoint.get("var")
        if mean is not None:
            self.mean = mean
        if var is not None:
            self.var = var
        self.policy.load_state_dict(checkpoint["policy"])
        self.values.load_state_dict(checkpoint["value"])
        if for_train:
            self.policy_optim.load_state_dict(checkpoint["policy_optim"])
            self.value_optim.load_state_dict(checkpoint["value_optim"])

    def load_coarse_model(self, filename, action_res):
        """Create ``self.coarse_policy`` and load its weights from a checkpoint.

        Args:
            filename: path of a checkpoint whose ``"policy"`` entry holds the
                coarse policy's state dict.
            action_res: action resolution the coarse policy is built with.

        The checkpoint is mapped onto ``self.device`` so weights saved on a
        GPU can be loaded on a CPU-only host.
        """
        print('Loading coarse models from {}...'.format(filename))
        # The coarse policy is always built with its fourth constructor
        # argument set to False (the agent's own policy passes args.residual).
        self.coarse_policy = GaussianPolicy(
            self.num_inputs, action_res,
            self.upsampled_action_res, False, self.args.coarse2fine_bias,
        ).to(self.device)
        checkpoint = torch.load(filename, map_location=self.device)
        self.coarse_policy.load_state_dict(checkpoint["policy"])


