In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules (e.g. the project's src.* helpers) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [40]:
# import bc stuff
import os
import pickle
import argparse
import torch
import torch.nn as nn

import tree
from acme import wrappers
from dm_control import suite

from src.environment import NormilizeActionSpecWrapper, MujocoActionNormalizer
from src.bc_net import BCNetworkContinuous, BCNetworkContinuousGaussian
from src.sac import GaussianPolicy
from src.bc_utils import evaluate_network_mujoco, evaluate_network_mujoco_stochastic

import matplotlib.pyplot as plt
import numpy as np


In [None]:
# constants:

# pickled rollouts used as the behaviour-cloning dataset
rollout_path = '../../data/rollouts/cheetah_123456_10000_actnoise080/rollouts.pkl'
lr = 3e-4        # Adam learning rate
epochs = 10      # passes over the rollout data
batch_size = 16  # samples per optimisation step

# define the scaling factors (weights of the individual loss terms)
MSE_SCALING = 1
KL_SCALING = 0.001
ENTROPY_SCALING = 0.001

# Seed the random number generators so that network initialisation and
# action sampling are repeatable across "Restart & Run All".
# NOTE(review): the dm_control environment keeps its own random state and is
# not affected by these calls.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Build the cheetah "run" task and stack the wrappers, innermost first.
base_env = suite.load(domain_name="cheetah", task_name="run")

# project wrappers around the action spec / actions (rescale mode: 'clip')
spec_wrapped_env = NormilizeActionSpecWrapper(base_env)
action_normalized_env = MujocoActionNormalizer(
    environment=spec_wrapped_env,
    rescale='clip',
)

# acme's single-precision wrapper is the outermost layer; this is the
# environment object the rest of the notebook uses
env = wrappers.SinglePrecisionWrapper(action_normalized_env)

In [None]:
# get the dimensionality of the observation_spec after flattening
flat_obs = tree.flatten(env.observation_spec())
# Sum the element count of every leaf spec. The product of the shape (instead
# of shape[0]) also covers scalar leaves (shape ()) and leaves with more than
# one axis; for 1-D leaves the result is unchanged.
obs_dim = sum(int(np.prod(item.shape)) for item in flat_obs)

# load the rollouts
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load rollout files from a trusted source.
with open(rollout_path, 'rb') as f:
    rollouts = pickle.load(f)

In [None]:
# Set up everything the training loop needs: policy, guide distribution,
# optimizer, loss function, loop sizes and the training tensors.
action_dim = env.action_spec().shape[0]

# policy network
# (alternative: BCNetworkContinuousGaussian(obs_dim, action_dim))
network = GaussianPolicy(obs_dim, action_dim, hidden_dim=256)

# unit normal with one component per action dimension; reference for the KL term
guide_dist = torch.distributions.Normal(
    torch.zeros(action_dim),
    torch.ones(action_dim),
)

# Adam over all policy parameters
optimizer = torch.optim.Adam(network.parameters(), lr=lr)

# MSE between sampled and recorded actions (reparametrization trick)
mse_loss_fn = nn.MSELoss()

# loop sizes; integer division drops a trailing partial batch
num_epochs = epochs
num_batches = len(rollouts.obs) // batch_size

# rollout data as float32 tensors with singleton dimensions removed
obs = torch.tensor(rollouts.obs, dtype=torch.float32).squeeze()
action = torch.tensor(rollouts.action, dtype=torch.float32).squeeze()

In [None]:
# Train the network with the reparametrization trick.
# Each batch loss is
#     MSE_SCALING * mse + KL_SCALING * kl + ENTROPY_SCALING * entropy
# and the per-epoch mean of every term is appended to the total_*_arr lists
# (used by the plotting cells).

# Fail early when the data holds less than one full batch: the epoch means
# would otherwise be taken over empty lists.
if num_batches == 0:
    raise ValueError(f'need at least batch_size={batch_size} samples to form one batch')

total_mse_loss_arr = []
total_kl_div_arr = []
total_entropy_loss_arr = []
total_loss_arr = []

for epoch in range(num_epochs):
    epoch_mse_loss_arr = []
    epoch_kl_div_arr = []
    epoch_entropy_loss_arr = []
    epoch_loss_arr = []

    for batch in range(num_batches):
        # get the batch (consecutive, unshuffled slices of the rollout data)
        batch_obs = obs[batch * batch_size:(batch + 1) * batch_size]
        batch_action = action[batch * batch_size:(batch + 1) * batch_size]

        # sample from the network
        sampled_action, log_prob, mean = network.sample(batch_obs)

        # compute the mse loss between sampled and recorded actions
        mse_loss = mse_loss_fn(sampled_action, batch_action)

        # compute the kl divergence estimate against the guide distribution
        # NOTE(review): guide_dist.log_prob yields one value per action
        # dimension; if network.sample returns a log-prob already summed over
        # action dimensions, this subtraction broadcasts and mixes joint and
        # per-dimension terms - confirm against GaussianPolicy.sample.
        guide_log_prob = guide_dist.log_prob(sampled_action)
        kl_div = torch.mean(log_prob - guide_log_prob)

        # compute the entropy estimate (mean negative log-probability)
        # NOTE(review): it enters the loss with a positive weight, so it is
        # minimised - confirm this is intended.
        entropy = torch.mean(-log_prob)

        # compute the loss
        loss = mse_loss*MSE_SCALING + kl_div*KL_SCALING + entropy*ENTROPY_SCALING

        # backpropagate the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the losses as plain Python floats
        epoch_mse_loss_arr.append(mse_loss.item())
        epoch_kl_div_arr.append(kl_div.item())
        epoch_entropy_loss_arr.append(entropy.item())
        epoch_loss_arr.append(loss.item())

    # log the epoch means
    total_mse_loss_arr.append(np.mean(epoch_mse_loss_arr))
    total_kl_div_arr.append(np.mean(epoch_kl_div_arr))
    total_entropy_loss_arr.append(np.mean(epoch_entropy_loss_arr))
    total_loss_arr.append(np.mean(epoch_loss_arr))

    # print the mean loss of the epoch (not just the value of the last batch)
    print(f'Epoch: {epoch + 1}, Loss: {total_loss_arr[-1]:.4f}')

In [None]:
# save the network
loss_names = ['mse', 'kl', 'entropy']
loss_name_portion = '_'.join(loss_names)
# the dataset name is the directory that contains the rollout file
dataset_name = os.path.basename(os.path.dirname(rollout_path))
model_filename = f'../../data/models/bc_{dataset_name}_{loss_name_portion}.pt'
# torch.save does not create missing parent directories, so create it first
os.makedirs(os.path.dirname(model_filename), exist_ok=True)
print('saving model to: ', model_filename)
# only the parameters (state_dict) are stored, not the module object
torch.save(network.state_dict(), model_filename)

In [42]:
# Evaluate the trained policy in `env` for 10 episodes using the project's
# stochastic evaluation helper. The call is the cell's last expression, so
# its return value is displayed as the cell output.
evaluate_network_mujoco_stochastic(network, env, num_episodes=10)

Mean reward: 686.3048532171321 Num episodes: 10


686.3048532171321

In [None]:
# Plot the per-epoch mean of every loss term, one panel per term, with a
# shared x axis.
loss_curves = [
    ('MSE Loss', total_mse_loss_arr),
    ('KL Divergence', total_kl_div_arr),
    ('Entropy Loss', total_entropy_loss_arr),
    ('Total Loss', total_loss_arr),
]

fig, axs = plt.subplots(len(loss_curves), 1, sharex=True, figsize=(8, 8 * len(loss_curves)))

for ax, (panel_title, curve) in zip(axs, loss_curves):
    ax.plot(curve)
    ax.set_title(panel_title)
    # nbins is an upper bound: at most 5 tick intervals on the y axis
    ax.locator_params(axis='y', nbins=5)

# label the x axis on the bottom panel
axs[-1].set_xlabel('Epoch')

plt.show()

In [None]:
# Overlay the KL divergence and the sign-flipped entropy term on one axis.
flipped_entropy_arr = -np.asarray(total_entropy_loss_arr)

fig, ax = plt.subplots()
ax.plot(total_kl_div_arr)
ax.plot(flipped_entropy_arr)
ax.legend(['KL Divergence', 'Entropy Loss'])
plt.show()

### Train with MSE and entropy loss

In [None]:
# Re-create the policy, optimizer and training tensors so this experiment
# starts from a freshly initialised network.
action_dim = env.action_spec().shape[0]

# policy network
# (alternative: BCNetworkContinuousGaussian(obs_dim, action_dim))
network = GaussianPolicy(obs_dim, action_dim, hidden_dim=256)

# unit normal with one component per action dimension
guide_dist = torch.distributions.Normal(
    torch.zeros(action_dim),
    torch.ones(action_dim),
)

# Adam over all policy parameters
optimizer = torch.optim.Adam(network.parameters(), lr=lr)

# MSE between sampled and recorded actions (reparametrization trick)
mse_loss_fn = nn.MSELoss()

# loop sizes; integer division drops a trailing partial batch
num_epochs = epochs
num_batches = len(rollouts.obs) // batch_size

# rollout data as float32 tensors with singleton dimensions removed
obs = torch.tensor(rollouts.obs, dtype=torch.float32).squeeze()
action = torch.tensor(rollouts.action, dtype=torch.float32).squeeze()

In [None]:
# Train the network with the reparametrization trick, using only the MSE and
# entropy terms (no KL term in this variant).
# Each batch loss is MSE_SCALING * mse + ENTROPY_SCALING * entropy and the
# per-epoch mean of every term is appended to the total_*_arr lists.

# Fail early when the data holds less than one full batch: the epoch means
# would otherwise be taken over empty lists and the print below would report
# a stale value.
if num_batches == 0:
    raise ValueError(f'need at least batch_size={batch_size} samples to form one batch')

total_mse_loss_arr = []
total_entropy_loss_arr = []
total_loss_arr = []

for epoch in range(num_epochs):
    epoch_mse_loss_arr = []
    epoch_entropy_loss_arr = []
    epoch_loss_arr = []

    for batch in range(num_batches):
        # get the batch (consecutive, unshuffled slices of the rollout data)
        batch_obs = obs[batch * batch_size:(batch + 1) * batch_size]
        batch_action = action[batch * batch_size:(batch + 1) * batch_size]

        # sample from the network
        sampled_action, log_prob, mean = network.sample(batch_obs)

        # compute the mse loss between sampled and recorded actions
        mse_loss = mse_loss_fn(sampled_action, batch_action)

        # compute the entropy estimate (mean negative log-probability)
        # NOTE(review): it enters the loss with a positive weight, so it is
        # minimised - confirm this is intended.
        entropy = torch.mean(-log_prob)

        # compute the loss
        loss = mse_loss*MSE_SCALING + entropy*ENTROPY_SCALING

        # backpropagate the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the losses as plain Python floats
        epoch_mse_loss_arr.append(mse_loss.item())
        epoch_entropy_loss_arr.append(entropy.item())
        epoch_loss_arr.append(loss.item())

    # log the epoch means
    total_mse_loss_arr.append(np.mean(epoch_mse_loss_arr))
    total_entropy_loss_arr.append(np.mean(epoch_entropy_loss_arr))
    total_loss_arr.append(np.mean(epoch_loss_arr))

    # print the mean loss of the epoch (not just the value of the last batch)
    print(f'Epoch: {epoch + 1}, Loss: {total_loss_arr[-1]:.4f}')

In [None]:
# save the network
loss_names = ['mse', 'entropy']
loss_name_portion = '_'.join(loss_names)
# the dataset name is the directory that contains the rollout file
dataset_name = os.path.basename(os.path.dirname(rollout_path))
model_filename = f'../../data/models/bc_{dataset_name}_{loss_name_portion}.pt'
# torch.save does not create missing parent directories, so create it first
os.makedirs(os.path.dirname(model_filename), exist_ok=True)
print('saving model to: ', model_filename)
# only the parameters (state_dict) are stored, not the module object
torch.save(network.state_dict(), model_filename)

In [None]:
# Evaluate the MSE + entropy policy in `env` for 10 episodes using the
# project's stochastic evaluation helper. The call is the cell's last
# expression, so its return value is displayed as the cell output.
evaluate_network_mujoco_stochastic(network, env, num_episodes=10)

In [None]:
# Plot the per-epoch mean of every loss term, one panel per term, with a
# shared x axis.
loss_curves = [
    ('MSE Loss', total_mse_loss_arr),
    ('Entropy Loss', total_entropy_loss_arr),
    ('Total Loss', total_loss_arr),
]

fig, axs = plt.subplots(len(loss_curves), 1, sharex=True, figsize=(8, 8 * len(loss_curves)))

for ax, (panel_title, curve) in zip(axs, loss_curves):
    ax.plot(curve)
    ax.set_title(panel_title)
    # nbins is an upper bound: at most 5 tick intervals on the y axis
    ax.locator_params(axis='y', nbins=5)

# label the x axis on the bottom panel
axs[-1].set_xlabel('Epoch')

plt.show()

### Train with just entropy loss