In [None]:
import numpy as np
import gym
import matplotlib.pyplot as plt

import multiworld
# NOTE(review): presumably registers the multiworld pygame gym ids (e.g. the
# Point2D* environments created below) so that gym.make can resolve them — confirm.
multiworld.register_pygame_envs()

# Re-import edited library modules automatically without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# env = gym.make(
#     "Point2DFixed-v0",
#     render_size=256,
#     images_are_rgb=True,
#     action_scale=0.5,
#     init_pos_range=([-4, -4], [4, 4])
# )
# env = gym.make(
#     "Point2DMazeEvalEasy-v0",
#     render_size=256,
#     images_are_rgb=True,
#     action_scale=0.5,
#     boundary_dist=4,
#     init_pos_range=([-4, -4], [4, 4]),
#     n_bins=32,
# )
# env = gym.make(
#     "Point2DMazeEvalMedium-v0",
#     render_size=256,
#     images_are_rgb=True,
#     action_scale=0.5,
#     boundary_dist=4,
#     init_pos_range=([-4, -4], [4, 4]),
#     n_bins=32,
# )
# Active environment: the "hard" Point2D maze evaluation env. The other
# environments in this cell are alternatives kept commented out.
# boundary_dist=4 matches the hard-coded (-4, 4) plotting extents used further down.
env = gym.make(
    "Point2DMazeEvalHard-v0",
    render_size=256,
    images_are_rgb=True,
    action_scale=0.5,
    boundary_dist=4,
    init_pos_range=([-4, -4], [4, 4]),
    n_bins=32,
)
# env = gym.make(
#     "Point2DDoubleMazeSingleGoalEval-v0",
#     render_size=256,
#     render_onscreen=False,
#     images_are_rgb=True,
#     action_scale=0.5,
#     init_pos_range=([-16, -1], [16, 1]),
#     n_bins=32,
# )

# env = gym.make(
#     "Point2DRooms-v0",
#     render_size=256,
#     render_onscreen=False,
#     images_are_rgb=True,
#     init_pos_range=([-4, -4], [4, 4]),
#     n_bins=32,
# )

# env = gym.make(
#     "Point2DRoomsLarge-v0",
#     render_size=256,
#     render_onscreen=False,
#     images_are_rgb=True,
#     init_pos_range=([-7, -7], [7, 7]),
#     n_bins=32,
# )

In [None]:
# Reset the environment and display one rendered RGB frame as a visual sanity check.
env.reset()
plt.figure(figsize=(12, 12))
plt.imshow(env.render(mode="rgb_array"))

## Collect random exploration data

In [None]:
N_TOTAL_SAMPLES = 50000
EPISODE_LEN = 500
LOG_FREQ = 10

# Parallel buffers. obs[k] is the state in which acs[k] was taken, and rews[k] is
# the reward returned for that step. terminals[k] is True on the last step of
# each episode.
obs, acs, rews, terminals = [], [], [], []

for i in range(N_TOTAL_SAMPLES // EPISODE_LEN):
    ob = env.reset()
    for t in range(EPISODE_LEN):
        # Record the state *before* stepping, so that (obs[k], acs[k]) is a
        # proper (s_k, a_k) pair. get_random_batch pairs obs[k] with acs[k:k+T]
        # and obs[k+T]; appending the post-step observation shifted that by one.
        obs.append(ob if isinstance(ob, np.ndarray) else ob["observation"])
        ac = env.action_space.sample()
        ob, rew, _, _ = env.step(ac)
        acs.append(ac)
        rews.append(rew)
        terminals.append(t == EPISODE_LEN - 1)
    # Periodically show 1 / bin_counts (presumably per-bin visit counts) when
    # the environment exposes them.
    if i % LOG_FREQ == 0 and hasattr(env, "bin_counts"):
        plt.imshow(1 / env.bin_counts)
        plt.colorbar()
        plt.show()

## Collect expert data (Pendulum-v0 SAC policy; `env` must be a Pendulum-v0 environment)

In [None]:
from pathlib import Path
import glob
import pickle

import torch
# Imported here as well: this cell uses ptu, and on a fresh kernel the
# model-setup cell that also imports it has not run yet.
import rlkit.misc.pytorch_util as ptu

# NOTE(review): machine-specific absolute path; make this configurable before sharing.
seed_dir = Path("/home/justinvyu/doodad-logs/20-12-07-sac-env-name=Pendulum-v0/20-12-07-sac-env_name=Pendulum-v0_2020_12_07_17_11_18_id504665--s579510")

# Snapshot files saved for this run (inspect to pick an iteration).
available_snapshots = sorted(glob.glob(str(seed_dir / "itr*")))

snapshot_path = seed_dir / "itr_100.pkl"
# pickle.load can execute arbitrary code: only load snapshots you produced yourself.
with open(snapshot_path, "rb") as f:
    data = pickle.load(f)

policy = data["exploration/policy"].to("cpu")

N_TOTAL_SAMPLES = 100000
EPISODE_LEN = 100
LOG_FREQ = 10

# NOTE(review): the policy comes from a Pendulum-v0 SAC run, but `env` is whatever
# the environment cell built (currently a Point2D maze). Rebuild `env` as
# Pendulum-v0 before running this cell.
obs, acs, rews, terminals = [], [], [], []

for i in range(N_TOTAL_SAMPLES // EPISODE_LEN):
    ob = env.reset()
    for t in range(EPISODE_LEN):
        ob_arr = ob if isinstance(ob, np.ndarray) else ob["observation"]
        # Record the state the action is taken from, so that obs[k] pairs with acs[k].
        obs.append(ob_arr)
        with torch.no_grad():
            ac = ptu.to_numpy(policy(ptu.from_numpy(ob_arr)).sample())
        ob, rew, _, _ = env.step(ac)

        acs.append(ac)
        rews.append(rew)
        terminals.append(t == EPISODE_LEN - 1)
    if i % LOG_FREQ == 0 and hasattr(env, "bin_counts"):
        plt.imshow(1 / env.bin_counts)
        plt.colorbar()
        plt.show()

In [None]:
# Turn the collected Python lists into numpy arrays so they support fancy indexing.
obs, acs, rews, terminals = [np.array(buf) for buf in (obs, acs, rews, terminals)]

In [None]:
# z-score the environment's bin_counts grid (presumably per-bin visit counts) and show it.
plt.imshow((env.bin_counts - np.mean(env.bin_counts)) / np.std(env.bin_counts))
plt.colorbar()

In [None]:
def get_random_batch(obs, acs, terminals, batch_size=128, T=8):
    """
    Sample random T-step windows that stay inside a single episode.

    Parameters
    ----------
    obs : (m x obs dim)   m = dataset size; obs[k] is the state in which acs[k] was taken
    acs : (m x action dim)
    terminals : (m,) or (m x 1) booleans, True on the last step of each episode
    batch_size : number of windows to sample (with replacement)
    T : number of actions per window

    Returns
    -------
    Starting states s_t : (batch_size x observation dim)
    T step action sequence a_t^T-1 = (a_t, a_t+1, ..., a_t+T-1) : (batch_size x T x action dim)
    Ending states s_t+T : (batch_size x observation dim)

    Raises
    ------
    AssertionError if m <= T; ValueError if no episode has more than T steps.
    """
    obs = np.asarray(obs)
    acs = np.asarray(acs)
    m = len(obs)
    # s_{t+T} = obs[t + T] must exist, so at least T + 1 samples are required.
    assert m > T

    # flatnonzero also works when terminals is a plain list or an (m x 1) array.
    episode_ends = np.flatnonzero(np.asarray(terminals))
    if episode_ends.size == 0:
        # No terminal flags: treat the whole dataset as one episode. randint's
        # upper bound is exclusive, so starts lie in [0, m - T - 1] and
        # obs[start + T] stays in range.
        random_start_idxs = np.random.randint(m - T, size=batch_size)
    else:
        # Episode k covers [episode_ends[k-1] + 1, episode_ends[k]] (inclusive).
        # Starting at episode_ends[k-1] itself would begin in the previous episode.
        episode_starts = np.concatenate([[0], episode_ends[:-1] + 1])
        # A full window needs end - start >= T, i.e. at least T + 1 steps.
        usable = episode_ends - episode_starts >= T
        if not usable.any():
            raise ValueError(f"no episode is longer than T={T} steps")
        episode_starts = episode_starts[usable]
        episode_ends = episode_ends[usable]
        picked = np.random.randint(len(episode_ends), size=batch_size)
        # Start indices lie in [episode start, episode end - T].
        random_start_idxs = np.random.randint(
            episode_starts[picked], episode_ends[picked] - T + 1
        )

    s_t = obs[random_start_idxs]
    # A (batch_size, T) grid of indices gives shape (batch_size, T, action dim).
    action_sequences = acs[random_start_idxs[:, None] + np.arange(T)]
    s_T = obs[random_start_idxs + T]
    return s_t, action_sequences, s_T

def get_shuffled_minibatches(obs, acs, terminals, batch_size=128, T=8):
    """
    Build one shuffled epoch of minibatches covering every valid T-step window.

    Every start index whose window (s_t, a_t..a_t+T-1, s_t+T) lies inside a single
    episode is used exactly once. The windows are shuffled, then split into
    max(1, n_windows // batch_size) nearly equal minibatches, so a minibatch may
    hold slightly more than batch_size windows.

    Parameters
    ----------
    obs : (m x obs dim); obs[k] is the state in which acs[k] was taken
    acs : (m x action dim)
    terminals : (m,) or (m x 1) booleans, True on the last step of each episode
    batch_size : target minibatch size
    T : number of actions per window

    Returns
    -------
    (list of s_t batches, list of action-sequence batches, list of s_t+T batches)

    Raises
    ------
    AssertionError if m <= T; ValueError if no episode has more than T steps.
    """
    obs = np.asarray(obs)
    acs = np.asarray(acs)
    m = len(obs)
    assert m > T

    episode_ends = np.flatnonzero(np.asarray(terminals))
    if episode_ends.size == 0:
        # No terminal flags: treat the whole dataset as one episode.
        episode_starts = np.array([0])
        episode_ends = np.array([m - 1])
    else:
        # Episode k covers [episode_ends[k-1] + 1, episode_ends[k]] (inclusive).
        episode_starts = np.concatenate([[0], episode_ends[:-1] + 1])

    # Valid starts in each episode are [start, end - T]; short episodes contribute none.
    start_idxs = np.concatenate([
        np.arange(start, end - T + 1)
        for start, end in zip(episode_starts, episode_ends)
    ]).astype(int)
    if start_idxs.size == 0:
        raise ValueError(f"no episode is longer than T={T} steps")

    start_idxs = start_idxs[np.random.permutation(len(start_idxs))]
    s_t = obs[start_idxs]
    action_sequences = acs[start_idxs[:, None] + np.arange(T)]
    s_T = obs[start_idxs + T]

    # array_split needs at least one section, even with fewer windows than batch_size.
    n_batches = max(1, len(start_idxs) // batch_size)
    return (
        np.array_split(s_t, n_batches),
        np.array_split(action_sequences, n_batches),
        np.array_split(s_T, n_batches),
    )

In [None]:
num_sample_trajs = 5
# Pick whole episodes at random; each episode fills EPISODE_LEN consecutive rows of obs.
random_sample_start_idxs = EPISODE_LEN * np.random.randint(N_TOTAL_SAMPLES // EPISODE_LEN, size=num_sample_trajs)

plt.figure(figsize=(8, 8))

# Faint environment render drawn in world coordinates, on top of the trajectories.
plt.imshow(
    env.render(mode='rgb_array'),
    extent=(-env.boundary_dist, env.boundary_dist, -env.boundary_dist, env.boundary_dist),
    origin='lower',
    alpha=0.25,
    zorder=3,
)
plt.gca().invert_yaxis()

# First and last recorded state of every chosen episode.
start_xy = obs[random_sample_start_idxs]
end_xy = obs[random_sample_start_idxs + EPISODE_LEN - 1]

for start_idx in random_sample_start_idxs:
    traj = obs[start_idx:start_idx + EPISODE_LEN]
    plt.plot(traj[:, 0], traj[:, 1])

# Blue stars mark starts, green stars mark ends.
plt.scatter(start_xy[:, 0], start_xy[:, 1], color="blue", marker="*", s=500)
plt.scatter(end_xy[:, 0], end_xy[:, 1], color="green", marker="*", s=500)

plt.show()

In [None]:
from torch import nn
import torch
import torch.optim as optim

class GaussianChannelModel(nn.Module):
    """
    Linear-Gaussian channel model of T-step dynamics: s_t+T ~ G(s_t) a + b(s_t).

    An MLP maps a start state s_t to the entries of a state-dependent matrix
    G(s_t) of shape (obs_dim, T * ac_dim) followed by a bias b(s_t) of shape
    (obs_dim,). For a flattened T-step action sequence a, the prediction is
    G(s_t) @ a + b(s_t). The model is fit with an MSE loss.

    Parameters
    ----------
    obs_dim, ac_dim : observation / action dimensionality
    T : number of actions per action sequence
    learning_rate : Adam learning rate
    l2_lambda : passed to Adam as weight_decay
    """

    def __init__(self, obs_dim, ac_dim, T, learning_rate=3e-4, l2_lambda=0):
        super().__init__()
        # Keep the dimensions on the instance, so forward()/G() do not silently
        # depend on notebook globals named obs_dim / ac_dim / T.
        self.obs_dim = obs_dim
        self.ac_dim = ac_dim
        self.T = T
        self.G_b_model = ptu.build_mlp(
            obs_dim,
            obs_dim * (T * ac_dim) + obs_dim,  # flattened G followed by b
            n_layers=3,
            size=512,
            activation='relu',
        )
        self.loss = nn.MSELoss()
        self.learning_rate = learning_rate
        self.l2_lambda = l2_lambda
        self.optimizer = optim.Adam(
            self.G_b_model.parameters(),
            self.learning_rate,
            weight_decay=l2_lambda,
        )

    def _G_and_b(self, s_t):
        """Run the MLP on a state batch; return G (B x obs_dim x T*ac_dim) and b (B x obs_dim)."""
        if isinstance(s_t, np.ndarray):
            s_t = ptu.from_numpy(s_t)
        batch_size = s_t.shape[0]
        G_b = self.G_b_model(s_t)
        G, b = G_b[:, :-self.obs_dim], G_b[:, -self.obs_dim:]
        G_matrix = G.reshape(batch_size, self.obs_dim, self.T * self.ac_dim)
        return G_matrix, b

    def forward(self, s_t, action_sequence):
        """
        Predict s_t+T from s_t and a T-step action sequence.

        s_t : (batch x obs_dim), numpy array or tensor
        action_sequence : (batch x T x ac_dim), numpy array or tensor
        Returns (prediction (batch x obs_dim), detached G (batch x obs_dim x T*ac_dim)).
        """
        if isinstance(action_sequence, np.ndarray):
            action_sequence = ptu.from_numpy(action_sequence)
        G_matrix, b = self._G_and_b(s_t)
        a_vector = action_sequence.reshape(G_matrix.shape[0], self.T * self.ac_dim, 1)

        # Batched matrix multiplication. Squeeze only the trailing singleton, so
        # a batch of one or obs_dim == 1 keeps the (batch x obs_dim) shape
        # instead of mis-broadcasting against b.
        output = torch.matmul(G_matrix, a_vector).squeeze(-1) + b

        return output, G_matrix.detach()

    def G(self, s_t, to_numpy=False):
        """Return G(s_t) (batch x obs_dim x T*ac_dim) as a tensor, or as a numpy array if to_numpy."""
        G_matrix, _ = self._G_and_b(s_t)
        if to_numpy:
            return ptu.to_numpy(G_matrix)
        return G_matrix

    def update(self, s_t, action_sequences, s_T):
        """Take one Adam step on the MSE between predicted and actual s_t+T; return the loss as a float."""
        if isinstance(s_T, np.ndarray):
            s_T = ptu.from_numpy(s_T)

        pred_s_T, _ = self(s_t, action_sequences)
        loss = self.loss(pred_s_T, s_T)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

In [None]:
import torch
import rlkit.misc.pytorch_util as ptu

# NOTE(review): depending on the ptu implementation, init_gpu may only set a
# module-level device and return None, in which case gc_model.to(device) is a no-op — confirm.
device = ptu.init_gpu(use_gpu=False)

# Dimensions taken from the collected dataset, plus the action-sequence horizon.
obs_dim, ac_dim = obs.shape[1], acs.shape[1]
T = 8

gc_model = GaussianChannelModel(obs_dim, ac_dim, T, l2_lambda=5e-5)
gc_model.to(device)

In [None]:
# Training variant 1: draw an independent random batch for every gradient step.

def train_random_sampling():
    """
    Fit the global ``gc_model`` for 1000 gradient steps on freshly sampled batches.

    Each step draws 128 windows of length ``T`` from the global ``obs`` / ``acs`` /
    ``terminals`` via get_random_batch. The loss is printed every TRAIN_LOG_FREQ
    steps (with these constants, only step 0), and the loss curve is plotted at the end.
    """
    N_GRAD_STEPS = 1000
    TRAIN_LOG_FREQ = 1000
    loss_history = []

    for step in range(N_GRAD_STEPS):
        batch = get_random_batch(obs, acs, terminals, batch_size=128, T=T)
        loss = gc_model.update(*batch)
        loss_history.append(loss)
        if step % TRAIN_LOG_FREQ == 0:
            print(f"\n============== Training Step #{step} ==============")
            print("MSE Loss =", loss)

    plt.plot(loss_history)
    plt.show()

In [None]:
# Training variant 2: sweep shuffled minibatches epoch by epoch, reshuffling whenever one is used up.

def train_cycle_dataset():
    """
    Fit the global ``gc_model`` for 9000 gradient steps over shuffled minibatch epochs.

    Minibatches come from get_shuffled_minibatches on the global ``obs`` / ``acs`` /
    ``terminals``. When every minibatch of the current epoch has been used, a fresh
    shuffle is drawn. The loss is printed every TRAIN_LOG_FREQ steps and its curve plotted.
    """
    N_GRAD_STEPS = 9000
    TRAIN_LOG_FREQ = 1000
    loss_history = []

    def _new_epoch():
        # One shuffled pass: parallel lists of s_t, action-sequence and s_T minibatches.
        epoch_batches = get_shuffled_minibatches(obs, acs, terminals, batch_size=128, T=T)
        assert len(epoch_batches[0]) == len(epoch_batches[1]) == len(epoch_batches[2])
        return epoch_batches

    s_t_batches, ac_batches, s_T_batches = _new_epoch()
    cursor = 0

    for grad_steps in range(N_GRAD_STEPS):
        if cursor == len(s_t_batches):
            s_t_batches, ac_batches, s_T_batches = _new_epoch()
            cursor = 0

        loss = gc_model.update(s_t_batches[cursor], ac_batches[cursor], s_T_batches[cursor])
        loss_history.append(loss)
        if grad_steps % TRAIN_LOG_FREQ == 0:
            print(f"\n============== Training Step #{grad_steps} ==============")
            print("MSE Loss =", loss)
        cursor += 1

    plt.title("Gaussian Channel Model loss")
    plt.xlabel("Training Steps")
    plt.plot(loss_history)
    plt.show()

train_cycle_dataset()

In [None]:
import cvxpy as cp

# Per-state water-filling solve; the callers below parallelise it with multiprocessing.Pool.
def water_filling(sing_vals, sum_p=1.0):
    """
    Channel capacity under a total power budget, solved as a convex program with cvxpy.

    Solves  max_p  0.5 * sum_i log(1 + sing_vals[i] * p[i])
            s.t.   p >= 0,  sum_i p[i] == sum_p
    and returns ``problem.value``, the optimal objective as reported by cvxpy.

    NOTE(review): the textbook capacity of y = G x + n uses the *squared* singular
    values of G; confirm whether ``sing_vals`` should be squared before calling.
    """
    power = cp.Variable(len(sing_vals))
    capacity = 0.5 * cp.sum(cp.log(1 + cp.multiply(sing_vals, power)))
    problem = cp.Problem(
        cp.Maximize(capacity),
        [power >= 0, cp.sum(power) == sum_p],
    )
    problem.solve()
    return problem.value

## Contour plot of empowerment (using actual data points)

In [None]:
# Evaluate the learned channel at every dataset state. No gradients are needed,
# and building an autograd graph over the whole dataset wastes a lot of memory.
with torch.no_grad():
    G_matrices = gc_model.G(obs)
    # S has shape (n_states, min(obs_dim, T * ac_dim)). Keep it 2-D (no squeeze)
    # so every row stays a vector of singular values, even if that width is 1.
    singular_values = ptu.to_numpy(torch.svd(G_matrices).S)

import multiprocessing
MAX_WORKERS = 32
# One water-filling solve per state, in parallel. NOTE: functions defined in a
# notebook generally only reach pool workers under the "fork" start method (Linux).
with multiprocessing.Pool(MAX_WORKERS) as pool:
    empowerment_vals = pool.map(water_filling, singular_values)
empowerment_vals = np.array(empowerment_vals)

In [None]:
# Standardise the empowerment estimates to zero mean and unit standard deviation.
normalized_empowerment = (empowerment_vals - empowerment_vals.mean()) / empowerment_vals.std()

def exp_normalize(x):
    """Softmax of the array ``x``, computed stably by subtracting its max before exponentiating."""
    weights = np.exp(x - x.max())
    return weights / weights.sum()

empowerment_probs = exp_normalize(normalized_empowerment)

# Draw 500 goal states (with replacement) from the dataset, with probability
# given by the softmax of the standardised empowerment.
sampled_goal_idxs = np.random.choice(np.arange(len(empowerment_vals)), p=empowerment_probs, size=500)
# Assign before displaying: referencing sampled_goals ahead of this line
# raised NameError on a fresh kernel.
sampled_goals = obs[sampled_goal_idxs]
sampled_goals[:5]

In [None]:
plt.figure(figsize=(8, 8))

# Maze render as the background, placed in world coordinates [-4, 4] x [-4, 4].
plt.imshow(env.render(mode='rgb_array', width=32, height=32), origin='lower', extent=(-4, 4, -4, 4), alpha=1.0)

# Filled contours of the standardised empowerment at the dataset states.
xs, ys = obs[:, 0], obs[:, 1]
plt.tricontourf(xs, ys, normalized_empowerment, 30, alpha=0.8)
plt.colorbar()

max_idx = np.argmax(empowerment_vals)
min_idx = np.argmin(empowerment_vals)
plt.scatter(obs[max_idx, 0], obs[max_idx, 1], color="red", marker="*", s=200, label="Max Empowerment")
plt.scatter(obs[min_idx, 0], obs[min_idx, 1], color="orange", marker="*", s=200, label="Min Empowerment")

plt.scatter(sampled_goals[:, 0], sampled_goals[:, 1], color="yellow", marker="+", s=200, label="Sampled Goals")

plt.legend(loc='upper left', bbox_to_anchor=(0, -0.08))

plt.xlim(-4, 4)
plt.ylim(-4, 4)
# Invert *after* setting the limits: plt.ylim(-4, 4) re-imposes bottom=-4 / top=4,
# which silently undid an earlier invert_yaxis() and made this plot flipped
# relative to the other maze plots in this notebook.
plt.gca().invert_yaxis()
plt.title(f"T={T}, lr={gc_model.learning_rate}, l2_lambda={gc_model.l2_lambda}")

plt.show()

## Sampled plot of empowerment (uniformly spaced)

In [None]:
# Uniform 32 x 32 grid of (x, y) positions spanning the environment bounds.
x = np.linspace(-env.boundary_dist, env.boundary_dist, 32)
y = np.linspace(-env.boundary_dist, env.boundary_dist, 32)
xv, yv = np.meshgrid(x, y)
xys = np.hstack([xv.reshape(-1, 1), yv.reshape(-1, 1)])

# Evaluation only: skip autograd to save memory.
with torch.no_grad():
    G_matrices = gc_model.G(xys)
    # S has shape (n_points, min(obs_dim, T * ac_dim)). Keep it 2-D (no squeeze)
    # so every row stays a vector of singular values, even if that width is 1.
    singular_values = ptu.to_numpy(torch.svd(G_matrices).S)

import multiprocessing
MAX_WORKERS = 10
# One water-filling solve per grid point, in parallel. NOTE: functions defined in
# a notebook generally only reach pool workers under the "fork" start method (Linux).
with multiprocessing.Pool(MAX_WORKERS) as pool:
    empowerment_vals = pool.map(water_filling, singular_values)
empowerment_vals = np.array(empowerment_vals)

In [None]:
plt.figure(figsize=(8, 8))

# plt.imshow(env.render(mode='rgb_array', width=32, height=32),
#            extent=(-4, 4, -4, 4), origin='lower', alpha=0.25, zorder=3)

# Maze render as the background, placed in world coordinates [-4, 4] x [-4, 4].
plt.imshow(env.render(mode='rgb_array', width=32, height=32), origin='lower', extent=(-4, 4, -4, 4), alpha=1.0)

# NOTE: this first inversion is undone by plt.ylim(-4, 4) below; the
# invert_yaxis() after the title is the one that takes effect.
plt.gca().invert_yaxis()
xs, ys = xys[:, 0], xys[:, 1]
# Filled contours of the raw (unstandardised) empowerment estimates on the grid.
plt.tricontourf(xs, ys, empowerment_vals, 30, alpha=0.8)

# plt.imshow(empowerment_vals.reshape(32, 32), extent=(0, 32, 0, 32), origin='lower', alpha=0.8)
plt.colorbar()

# Mark the grid points with the highest and lowest estimates.
max_idx = np.argmax(empowerment_vals)
min_idx = np.argmin(empowerment_vals)
plt.scatter(xys[max_idx, 0], xys[max_idx, 1], color="red", marker="*", s=300, label="Max Empowerment", edgecolors='b')
plt.scatter(xys[min_idx, 0], xys[min_idx, 1], color="orange", marker="*", s=300, label="Min Empowerment", edgecolors='b')

plt.legend(loc='upper left', bbox_to_anchor=(0, -0.08))
plt.xlim(-4, 4)
plt.ylim(-4, 4)
plt.title(f"Empowerment Estimates\nT={T}, lr={gc_model.learning_rate}, l2_lambda={gc_model.l2_lambda}")
# Final orientation: y = -4 at the top.
plt.gca().invert_yaxis()

plt.show()

## Sampled plot of empowerment (Pendulum-v0)

In [None]:
# Grid over pendulum angle theta in [-pi, pi] and angular velocity in [-8, 8],
# encoded as (cos theta, sin theta, velocity) rows.
n_samples = 100
theta = np.linspace(-np.pi, np.pi, n_samples)
vel = np.linspace(-8, 8, n_samples)
thetav, velv = np.meshgrid(theta, vel)
cos_sin_vel = np.hstack([np.cos(thetav.reshape(-1, 1)), np.sin(thetav.reshape(-1, 1)), velv.reshape(-1, 1)])

# Evaluation only: skip autograd to save memory.
with torch.no_grad():
    G_matrices = gc_model.G(cos_sin_vel)
    # S has shape (n_points, min(obs_dim, T * ac_dim)). Keep it 2-D (no squeeze)
    # so every row stays a vector of singular values, even if that width is 1.
    singular_values = ptu.to_numpy(torch.svd(G_matrices).S)

import multiprocessing
MAX_WORKERS = 10
# One water-filling solve per grid point, in parallel. NOTE: functions defined in
# a notebook generally only reach pool workers under the "fork" start method (Linux).
with multiprocessing.Pool(MAX_WORKERS) as pool:
    empowerment_vals = pool.map(water_filling, singular_values)
empowerment_vals = np.array(empowerment_vals)

In [None]:
plt.figure(figsize=(8, 8))
# Rows of the grid index angular velocity and columns index theta (np.meshgrid's
# default 'xy' indexing). Reshape with the grid's own sizes rather than a
# hard-coded 100, so changing n_samples keeps working. origin='lower' puts the
# first velocity row at the bottom, matching the previous imshow + invert_yaxis.
plt.imshow(
    empowerment_vals.reshape(len(vel), len(theta)),
    origin='lower',
    extent=(theta[0], theta[-1], vel[0], vel[-1]),
    aspect='auto',
)
plt.colorbar()
plt.xlabel("theta (rad)")
plt.ylabel("angular velocity")
plt.title(f"Pendulum empowerment estimates\nT={T}, lr={gc_model.learning_rate}, l2_lambda={gc_model.l2_lambda}")
plt.show()

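# Analytical Empowerment

For deterministic dynamics, N-step empowerment is the log of the number of distinct states reachable with N-step action sequences. As an analytical reference, the cells below enumerate every sequence of up to N = 4 cardinal moves from each point of a 32 × 32 grid and count the distinct discrete states visited. The heatmap shows this count, not its log.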

In [None]:
# 32 x 32 grid of (x, y) start positions covering [-4, 4] x [-4, 4].
n_bins = 32
x, y = np.linspace(-4, 4, n_bins), np.linspace(-4, 4, n_bins)
xv, yv = np.meshgrid(x, y)
xys = np.column_stack((xv.ravel(), yv.ravel()))

In [None]:
# Unit moves in the four cardinal directions, one per row.
# The diagonal moves [1, 1], [1, -1], [-1, 1], [-1, -1] are currently left out.
actions = np.array([
    (0, 1),   # Up
    (1, 0),   # Right
    (-1, 0),  # Left
    (0, -1),  # Down
])

In [None]:
def try_all_actions(env, possible_actions, N=1):
    """
    Enumerate every action sequence of length up to N from the env's current position.

    The search is depth-first. For each action, the env is put back at the parent
    position with ``env.set_position``, stepped once, and the resulting
    ``"discrete_observation"`` is recorded before recursing from the new
    ``"observation"``. Depth k visits len(possible_actions) ** k nodes, so the
    cost is exponential in N. The env is left at the last explored state.

    Returns
    -------
    set of tuples: the distinct discrete observations reached after 1..N steps.
    """
    reached = set()

    def _expand(position, steps_left):
        if steps_left <= 0:
            return
        for action in possible_actions:
            env.set_position(position)
            next_ob, _, _, _ = env.step(action)
            reached.add(tuple(next_ob["discrete_observation"]))
            _expand(next_ob["observation"], steps_left - 1)

    _expand(env._get_obs()["observation"], N)
    return reached

# For every grid point, count the distinct discrete states visited by all action
# sequences of up to 4 steps; xys is row-major, so the flat index splits into (row, col).
reachability = np.zeros_like(xv)
for cell_idx, grid_xy in enumerate(xys):
    env.set_position(grid_xy)
    row, col = divmod(cell_idx, n_bins)
    reachability[row, col] = len(try_all_actions(env, actions, N=4))

In [None]:
# Heatmap of the reachable-state count per grid cell, with its colour scale.
plt.colorbar(plt.imshow(reachability))