In [2]:
import gymnasium as gym
#import gym_oscillator
#import oscillator_cpp
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import os
import numpy as np
import scipy.stats as ss
import scipy
import matplotlib.pyplot as plt

In [4]:
base_model_path = "/Users/ShawnXu/research-local/RL_quantization_eval/rl-trained-agents/"
algo = "PPO"
env_id = "MountainCarContinuous-v0_1"
file_name = "MountainCarContinuous-v0.zip"

n_timesteps = 50000

model_path = base_model_path+algo+"/"+env_id+"/"+file_name

In [5]:
model = PPO.load(model_path)
model_q = PPO.load(model_path)

Exception: an integer is required (got type bytes)
Exception: an integer is required (got type bytes)
Exception: an integer is required (got type bytes)


In [7]:
env = make_vec_env("MountainCarContinuous-v0", n_envs=1)
obs = env.reset()

In [8]:
policy_net = model_q.policy.mlp_extractor.policy_net

In [9]:
from torch.quantization import quantize_dynamic
from torch import nn
import torch

policy_net.eval()

# Dynamic Quantization
# model_quantized = quantize_dynamic(
#     model=policy_net, qconfig_spec={nn.Linear}, dtype=torch.qint8, inplace=False
# )

# Static Quantization
policy_net = nn.Sequential(torch.quantization.QuantStub(), 
                  *policy_net, 
                  torch.quantization.DeQuantStub())

"""Prepare"""
backend = "x86"
policy_net.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(policy_net, inplace=True)



Sequential(
  (0): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (1): Linear(
    in_features=2, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (2): Tanh()
  (3): Linear(
    in_features=64, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (4): Tanh()
  (5): DeQuantStub()
)

In [10]:
model_q.policy.mlp_extractor.policy_net = policy_net
print("Before Calibration: ", policy_net)

# Calibration with enviorment 
with torch.inference_mode():
    for _ in range(n_timesteps):
        action, _ = model_q.predict(obs)
        obs, reward, done, infos = env.step(action)
print("After Calibration: ", policy_net)

torch.quantization.convert(policy_net, inplace=True)

Before Calibration:  Sequential(
  (0): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (1): Linear(
    in_features=2, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (2): Tanh()
  (3): Linear(
    in_features=64, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (4): Tanh()
  (5): DeQuantStub()
)
After Calibration:  Sequential(
  (0): QuantStub(
    (activation_post_process): HistogramObserver(min_val=-0.8889830112457275, max_val=2.3281033039093018)
  )
  (1): Linear(
    in_features=2, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=-1.032845139503479, max_val=7.005014896392822)
  )
  (2): Tanh()
  (3): Linear(
    in_features=64, out_features=64, bias=True
    (activation_post_process): HistogramObserver(min_val=-0.9405579566955566, max_val=6.739593029022217)
  )
  (4): Tanh()
  (5): 

Sequential(
  (0): Quantize(scale=tensor([0.0072]), zero_point=tensor([124]), dtype=torch.quint8)
  (1): QuantizedLinear(in_features=2, out_features=64, scale=0.01523539423942566, zero_point=67, qscheme=torch.per_channel_affine)
  (2): Tanh()
  (3): QuantizedLinear(in_features=64, out_features=64, scale=0.012549459002912045, zero_point=75, qscheme=torch.per_channel_affine)
  (4): Tanh()
  (5): DeQuantize()
)

In [11]:
# Int represenation of weight
#print(torch.int_repr(model_quantized[0].weight()))

# qint8 representation of weight
#print(model_quantized[0].weight())

# Assigne quantized network back to model
model_q.policy.mlp_extractor.policy_net = policy_net

In [12]:
# for _ in range(n_timesteps):
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = env.step(action)
#     print("Obs: ", obs)
#     print("Action: ", action)
#     print("Reward: ", rewards)
    
#     action_q, _states_q = model_q.predict(obs)
#     obs_q, rewards_q, dones_q, info_q = env.step(action_q)
#     print("Obs_q: ", obs_q)
#     print("Action_q: ", action_q)
#     print("Reward_q: ", rewards_q)
    
#     print("------------------------")

In [13]:
def kl_scipy(p, q):
    p = np.asarray(p, dtype=np.float32)
    q = np.asarray(q, dtype=np.float32)
    p = p.flatten()
    q = q.flatten()
    p[p==0] = np.finfo(float).eps
    q[q==0] = np.finfo(float).eps
    if(len(p)>1):
        pg = ss.gaussian_kde(p)
        qg = ss.gaussian_kde(q)
        kl = ss.entropy(pg(p),qg(q))
        print("p,q",ss.entropy(pg(p),qg(q)))
        print("len of p",len(p))
        return kl
    else:
        return 0

In [14]:
successes = []
action_q_list = []
action_list = []
kl_list = []
deterministic = True
episode_reward = 0.0 
episode_rewards = []
ep_len = 0
is_atari = False
verbose = 1

print("Shape of network input: ", obs)

for _ in range(n_timesteps):
    action, _ = model.predict(obs, deterministic=deterministic)
    action_q, _ = model_q.predict(obs, deterministic=deterministic)
    action_list.append(action.flatten().tolist())
    action_q_list.append(action_q.flatten().tolist())

    # Random Agent
    # action = [env.action_space.sample()]
    # Clip Action to avoid out of bound errors
    if isinstance(env.action_space, gym.spaces.Box):
        action = np.clip(action, env.action_space.low, env.action_space.high)

    # Take an action
    obs, reward, done, infos = env.step(action)

    episode_reward += reward[0]
    ep_len += 1

    # For atari the return reward is not the atari score
    # so we have to get it from the infos dict
    if is_atari and infos is not None and verbose >= 1:
        episode_infos = infos[0].get('episode')
        if episode_infos is not None:
            print("Atari Episode Score: {:.2f}".format(episode_infos['r']))
            print("Atari Episode Length", episode_infos['l'])

            # calculate KL-divergence
            flat_action = [item for sublist in action_list for item in sublist]
            flat_action_q = [item for sublist in action_q_list for item in sublist]
            kl_list.append(kl_scipy(flat_action, flat_action_q))
            plt.hist(flat_action, bins=20, label='action')
            plt.hist(flat_action_q, bins=20, label='action_q')
            plt.legend()
            # save the figure
            # append the env-name to the file-name
            # appen algo name to the file-name
            plt.savefig(os.path.join('./action_hist_' + env_id + '_' + algo + '.png'))
            plt.close()
            flat_action = []
            flat_action_q = []
            action_list = []
            action_q_list = []

    if done and not is_atari and verbose > 0:
        # NOTE: for env using VecNormalize, the mean reward
        # is a normalized reward when `--norm_reward` flag is passed
        print("Episode Reward: {:.2f}".format(episode_reward))
        print("Episode Length", ep_len)
        episode_rewards.append(episode_reward)
        episode_reward = 0.0
        ep_len = 0

        # calculate KL-divergence
        flat_action = [item for sublist in action_list for item in sublist]
        flat_action_q = [item for sublist in action_q_list for item in sublist]
        kl_list.append(kl_scipy(flat_action, flat_action_q))
        plt.hist(flat_action, bins=20, label='action')
        plt.hist(flat_action_q, bins=20, label='action_q')
        plt.legend()
        # save the figure
        # append the env-name to the file-name
        plt.savefig(os.path.join('action_hist_' + env_id + '_' + algo + '.png'))
        plt.close()
        flat_action = []
        flat_action_q = []
        action_list = []
        action_q_list = []

    # Reset also when the goal is achieved when using HER
    if done or infos[0].get('is_success', False):
        if algo == 'her' and verbose > 1:
            print("Success?", infos[0].get('is_success', False))
        # Alternatively, you can add a check to wait for the end of the episode
        # if done:
        obs = env.reset()
        if algo == 'her':
            successes.append(infos[0].get('is_success', False))
            episode_reward, ep_len = 0.0, 0

print("Success rate: {:.2f}%".format(100 * np.mean(successes)))
print("Mean reward: {:.2f}".format(np.mean(episode_rewards)))

env.close()

# calculate kl-divergence over action dist
# get the mean of a list
print("KL-Lists:", kl_list)
mean_kl = np.mean(kl_list)
print("Mean KL-Divergence: {:.5f}".format(mean_kl))

Shape of network input:  [[-0.7115588   0.00665341]]
Episode Reward: -24.05
Episode Length 949
p,q 0.0014349046626484048
len of p 949
Episode Reward: -28.23
Episode Length 999
p,q 0.0045029488865046385
len of p 999
Episode Reward: -24.49
Episode Length 999
p,q 0.027690900455145454
len of p 999
Episode Reward: -27.19
Episode Length 999
p,q 0.0006996311277053205
len of p 999
Episode Reward: -25.15
Episode Length 999
p,q 0.00164830262245456
len of p 999
Episode Reward: -27.36
Episode Length 999
p,q 0.0007363660432974928
len of p 999
Episode Reward: -25.10
Episode Length 999
p,q 0.00196165537118772
len of p 999
Episode Reward: -24.95
Episode Length 999
p,q 0.0027245523460435375
len of p 999
Episode Reward: -27.06
Episode Length 999
p,q 0.0008633360153543564
len of p 999
Episode Reward: -24.53
Episode Length 999
p,q 0.012980655537360112
len of p 999
Episode Reward: -24.77
Episode Length 999
p,q 0.003474617991094223
len of p 999
Episode Reward: -24.80
Episode Length 999
p,q 0.003635368750619

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
