In [1]:
import sys
sys.path.append("../")

import jax
import os
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import warnings
warnings.filterwarnings("ignore")


In [3]:
from icvf_envs.antmaze import d4rl_utils, d4rl_ant, d4rl_pm
from src.gc_dataset import GCSDataset

gcdataset_config = GCSDataset.get_default_config()

env = d4rl_utils.make_env("antmaze-large-diverse-v2")
dataset = d4rl_utils.get_dataset(env)
gc_dataset = GCSDataset(dataset, **gcdataset_config.to_dict())
example_batch = gc_dataset.sample(1)

pybullet build time: May 20 2022 19:45:31
  from jax import ShapedArray


Target Goal:  (33.24260065681753, 25.045786800033035)


load datafile: 100%|██████████| 8/8 [00:02<00:00,  2.80it/s]


In [15]:
import numpy as np

def segment(observations, terminals, max_path_length):
    """
        segment `observations` into trajectories according to `terminals`
    """
    assert len(observations) == len(terminals)
    observation_dim = observations.shape[1]

    trajectories = [[]]
    for obs, term in zip(observations, terminals):
        trajectories[-1].append(obs)
        if term.squeeze():
            trajectories.append([])

    if len(trajectories[-1]) == 0:
        trajectories = trajectories[:-1]

    ## list of arrays because trajectories lengths will be different
    trajectories = [np.stack(traj, axis=0) for traj in trajectories]

    n_trajectories = len(trajectories)
    path_lengths = [len(traj) for traj in trajectories]

    ## pad trajectories to be of equal length
    trajectories_pad = np.zeros((n_trajectories, max_path_length, observation_dim), dtype=trajectories[0].dtype)
    early_termination = np.zeros((n_trajectories, max_path_length), dtype=np.bool)
    for i, traj in enumerate(trajectories):
        path_length = path_lengths[i]
        trajectories_pad[i,:path_length] = traj
        early_termination[i,path_length:] = 1

    return trajectories_pad, early_termination, path_lengths

In [18]:
import jax.numpy as jnp
test_obs = jnp.concatenate([dataset['observations'], dataset['next_observations']], axis=1)
segmented = segment(test_obs, dataset['dones_float'], 1000)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  early_termination = np.zeros((n_trajectories, max_path_length), dtype=np.bool)


In [38]:
segmented[1]

array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])

In [3]:
import jax.numpy as jnp
import numpy as np

terminal_indxes = np.argwhere(dataset["dones_float"] > 0.5).squeeze()
random_idx = np.random.choice(terminal_indxes, size=2, replace=False)
len(random_idx)
#dataset.get_subset(random_idx)

2

In [4]:
example_batch.keys()

dict_keys(['actions', 'dones_float', 'masks', 'next_observations', 'observations', 'rewards', 'desired_rewards', 'desired_masks', 'goals', 'desired_goals'])

#### Each token - (observations, next_observations, dones_float)
For ant_maze obs space is vector of size 29, action - 8. Dones - as positional embeddings

In [5]:
print(gc_dataset.sample_trajectories(2)[1]['observations'].shape)
print(gc_dataset.sample_trajectories(2)[1]['next_observations'].shape)
print(gc_dataset.sample_trajectories(2)[1]['dones_float'].shape)

(10, 29)
(10, 29)
(10,)


In [6]:
gc_dataset.sample_trajectories(2)[0]['observations'][::5].shape

(2, 29)

### VQVAE

In [7]:
import optax

import flax.linen as nn
import einops
import jax.numpy as jnp


In [8]:
from dataclasses import dataclass, field

@dataclass
class VqvaeConfig:
    embedding_dim: int = 128
    num_tokens: int = 2 # each pair = one token
    trajectory_len: int = 10
    
    run_opt: str = field(default='train')
    adam_beta1: float = .9
    adam_beta2: float = .9
    lr: float = 3e-5
    ema_rate: float = 0.
    n_batch: int = 32   
    warmup_iters: float = 100.
    wd: float = 0.
    grad_clip: float = 200.     
    dtype: str = "float32"
    checkpoint: bool = False
    
    iters_per_ckpt: int = 25000
    iters_per_images: int = 10000
    iters_per_print: int = 1000
    iters_per_save: int = 10000

In [45]:
from typing import Any, Optional
from flax import linen as nn
import jax
import jax.numpy as jnp

# inspired from Haiku's corresponding code
# https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/moving_averages.py

class ExponentialMovingAverage(nn.Module):
  shape: list
  dtype: Any = jnp.float32
  decay: float = 0.

  def setup(self):
    shape = self.shape
    dtype = self.dtype
    self.hidden = self.variable("stats", "hidden", lambda: jnp.zeros(shape, dtype=dtype))
    self.average = self.variable("stats", "average", lambda: jnp.zeros(shape, dtype=dtype)) # how to deal with initialized?
    constant = lambda: jnp.zeros(shape, dtype=jnp.int32)
    self.counter = self.variable("stats", "counter", constant)

  def __call__(
      self,
      value: jnp.ndarray,
      update_stats: bool = True,
  ) -> jnp.ndarray:

    counter = self.counter.value + 1
    decay = jax.lax.convert_element_type(self.decay, value.dtype)
    one = jnp.ones([], value.dtype)
    hidden = self.hidden.value * decay + value * (one - decay)

    average = hidden
    average /= (one - jnp.power(decay, counter))
    if update_stats:
      self.counter.value = counter
      self.hidden.value = hidden
      self.average.value = average
    return average

# inspired from Haiku's corresponding code to Flax
# https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/nets/vqvae.py

class VectorQuantizerEMA(nn.Module):
  embedding_dim: int
  num_embeddings: int
  commitment_cost: float
  decay: float
  epsilon: float = 1e-5
  dtype: Any = jnp.float32
  cross_replica_axis: Optional[str] = None  
  initialized: bool = False

  @nn.compact
  def __call__(self, inputs, is_training, rng=None, encoding_indices=None):
    embedding_shape = [self.embedding_dim, self.num_embeddings]
    assert self.dtype == jnp.float32
    ema_cluster_size = ExponentialMovingAverage([self.num_embeddings], self.dtype, decay=self.decay)
    ema_dw = ExponentialMovingAverage(embedding_shape, self.dtype, decay=self.decay)
    initialized = self.has_variable('stats', 'embeddings')
    embeddings = self.variable("stats", "embeddings", nn.initializers.lecun_uniform(), rng, embedding_shape)
    
    def quantize(encoding_indices):
        """Returns embedding tensor for a batch of indices."""
        w = embeddings.value.swapaxes(1, 0)
        w = jax.device_put(w)  # Required when embeddings is a NumPy array.
        return w[(encoding_indices,)]

    if encoding_indices is not None:
        return quantize(encoding_indices)
    
    if not initialized:
        hidden, counter, average = ema_cluster_size.hidden, ema_cluster_size.counter, ema_cluster_size.average
        hidden, counter, average = ema_dw.hidden, ema_dw.counter, ema_dw.average     
        return {
            "quantize": inputs,
            "loss": inputs.mean(),
        }
    
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    distances = (
        jnp.sum(flat_inputs**2, 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings.value) +
        jnp.sum(embeddings.value**2, 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = quantize(encoding_indices)
    e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

    if is_training:
      cluster_size = jnp.sum(encodings, axis=0)
      if self.cross_replica_axis:
        cluster_size = jax.lax.psum(
            cluster_size, axis_name=self.cross_replica_axis)
      updated_ema_cluster_size = ema_cluster_size(cluster_size, update_stats=is_training)

      dw = jnp.matmul(flat_inputs.T, encodings)
      if self.cross_replica_axis:
        dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
      updated_ema_dw = ema_dw(dw, update_stats=is_training)

      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

      embeddings.value = normalised_updated_ema_w
      loss = self.commitment_cost * e_latent_loss

    else:
      loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

In [42]:
from typing import Any
import jax.numpy as jnp
from einops import rearrange, repeat

class VQVAE(nn.Module):
    hyperparams: VqvaeConfig
    
    @nn.compact
    def __call__(self, traj_info) -> Any:
        
        obs = traj_info['observations']
        next_obs = traj_info['next_observations']
        dones = traj_info['dones_float']
        
        obs_emb = nn.Dense(features=self.hyperparams.embedding_dim)(obs)
        next_obs_emb = nn.Dense(features=self.hyperparams.embedding_dim)(next_obs)
        token_emb = jnp.reshape(jnp.concatenate([obs_emb, next_obs_emb], axis=1), (5*2, self.hyperparams.embedding_dim*2))
        
        goal_token = self.param('cls', nn.initializers.zeros, [1, self.hyperparams.embedding_dim*2])
        #goal_token = jnp.repeat(goal_token, repeats=token_emb.shape[0], axis=0)
        pos_embedding = self.param('pos_embedding', nn.initializers.zeros, [self.hyperparams.trajectory_len+1, self.hyperparams.embedding_dim*2])
        token_emb = jnp.concatenate([token_emb, goal_token], axis=0)
        token_emb += pos_embedding[:2*self.hyperparams.trajectory_len + 1]
        
        return token_emb

In [43]:
key = jax.random.PRNGKey(0)
key1, keys = jax.random.split(key, 2)

# Tokens x observ_dim
example = gc_dataset.sample_trajectories(2)[1]
model = VQVAE(VqvaeConfig)
init_vars = model.init(key1, example)

In [44]:
model.apply(init_vars, example).shape

(11, 256)

In [None]:
vq = VectorQuantizerEMA()