In [1]:
# Add the parent directory to the module search path so that imports in later
# cells can resolve packages located there.
import sys
sys.path.append("../")

import jax
import os
# Set before any later cell imports d4rl.
# NOTE(review): presumably silences d4rl's optional-dependency import errors — confirm.
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import warnings
# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")


In [1]:
import equinox as eqx
import os
# Expose only the CUDA device with index 1 to this process.
# NOTE(review): jax is already imported by the previous cell; this relies on
# the JAX backend not having been initialised yet — confirm.
os.environ['CUDA_VISIBLE_DEVICES']='1'
import jax
import equinox.nn as nn
import functools

# MLP architecture: two hidden layers of width 256 over a 29-dimensional input.
hidden_dims = [256, 256]
state_dim=29
rng = jax.random.PRNGKey(42)
# Partially applied equinox MLP constructor; only the PRNG key remains to be supplied.
network_cls = functools.partial(nn.MLP, in_size=state_dim, out_size=hidden_dims[-1],
                                        width_size=hidden_dims[0], depth=len(hidden_dims),
                                        final_activation=jax.nn.relu)
# Freshly initialised network; used below as the template for deserialisation.
phi_net = network_cls(key=rng)

In [2]:
# Load the leaves stored in ../icvf_model.eqx, using phi_net as the pytree template.
new_phi_net=eqx.tree_deserialise_leaves("../icvf_model.eqx", phi_net)

In [5]:
# Display the structure of the deserialised network.
new_phi_net

MLP(
  layers=(
    Linear(
      weight=f32[256,29],
      bias=f32[256],
      in_features=29,
      out_features=256,
      use_bias=True
    ),
    Linear(
      weight=f32[256,256],
      bias=f32[256],
      in_features=256,
      out_features=256,
      use_bias=True
    ),
    Linear(
      weight=f32[256,256],
      bias=f32[256],
      in_features=256,
      out_features=256,
      use_bias=True
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<wrapped function relu>,
  use_bias=True,
  use_final_bias=True,
  in_size=29,
  out_size=256,
  width_size=256,
  depth=2
)

In [17]:
def is_linear(x):
    """Return True when ``x`` is an ``eqx.nn.Linear`` instance."""
    return isinstance(x, eqx.nn.Linear)


def get_weights(m):
    """Return the ``weight`` of every ``eqx.nn.Linear`` found in pytree ``m``.

    Linear layers are treated as leaves during flattening so that each layer
    is visited as a whole object rather than as separate weight/bias arrays.
    """
    leaves = jax.tree_util.tree_leaves(m, is_leaf=is_linear)
    return [leaf.weight for leaf in leaves if is_linear(leaf)]

In [23]:
# Shape of the weight matrix of the third Linear layer (index 2) of the loaded network.
get_weights(new_phi_net)[2].shape

(256, 256)

In [5]:
# Copy of the loaded network whose last layer is replaced by the last layer of
# the freshly initialised phi_net; the other layers keep the loaded values.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], new_phi_net, phi_net.layers[-1])

In [6]:
# Display the structure of the partially re-initialised network.
model_partial

MLP(
  layers=(
    Linear(
      weight=f32[256,29],
      bias=f32[256],
      in_features=29,
      out_features=256,
      use_bias=True
    ),
    Linear(
      weight=f32[256,256],
      bias=f32[256],
      in_features=256,
      out_features=256,
      use_bias=True
    ),
    Linear(
      weight=f32[256,256],
      bias=f32[256],
      in_features=256,
      out_features=256,
      use_bias=True
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<wrapped function relu>,
  use_bias=True,
  use_final_bias=True,
  in_size=29,
  out_size=256,
  width_size=256,
  depth=2
)

In [2]:
from icvf_envs.antmaze import d4rl_utils, d4rl_ant, d4rl_pm
from src.gc_dataset import GCSDataset

# Start from the dataset class's default configuration.
gcdataset_config = GCSDataset.get_default_config()

# Build the AntMaze environment and fetch its dataset through the project helpers.
env = d4rl_utils.make_env("antmaze-large-diverse-v2")
dataset = d4rl_utils.get_dataset(env)
# Wrap the raw dataset, passing the default configuration as keyword arguments.
gc_dataset = GCSDataset(dataset, **gcdataset_config.to_dict())
# Draw a sample to inspect the batch layout (its keys are listed in a later cell).
# NOTE(review): the argument 1 is presumably the batch size — confirm against GCSDataset.sample.
example_batch = gc_dataset.sample(1)

pybullet build time: May 20 2022 19:45:31
  from jax import ShapedArray


Target Goal:  (32.33135808427672, 24.222535311056507)


load datafile: 100%|██████████| 8/8 [00:02<00:00,  3.23it/s]


In [3]:
# Sample trajectories from the dataset; a later cell shows that test[0] has shape (1000, 58).
# NOTE(review): the argument 2 is presumably the number of trajectories — confirm.
test = gc_dataset.sample_trajectories(2)
# List the fields available in a sampled batch.
example_batch.keys()

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  early_termination = np.zeros((n_trajectories, max_path_length), dtype=np.bool)


dict_keys(['actions', 'dones_float', 'masks', 'next_observations', 'observations', 'rewards', 'desired_rewards', 'desired_masks', 'goals', 'desired_goals'])

#### Each token is a tuple of (observations, next_observations, dones_float)
For AntMaze, the observation space is a vector of size 29 and the action space is a vector of size 8. The `dones_float` flags are used as positional embeddings.

### VQVAE

In [4]:
import optax

import flax.linen as nn
import einops
import jax.numpy as jnp


In [5]:
from dataclasses import dataclass, field

@dataclass
class VqvaeConfig:
    """Hyperparameters of the trajectory VQ-VAE.

    ``latent_dim`` used to be declared twice (128, then 512). Python kept the
    second value while leaving the field in its first position, so the single
    declaration below reproduces the effective behaviour: same field order,
    default of 512.
    """
    # Width of the first hidden layer of the token-embedding MLP.
    hidden_dim: int = 64
    # Width of token embeddings, positional embeddings and latent vectors.
    latent_dim: int = 512

    num_tokens: int = 2 # each pair = one token
    # Number of time steps per trajectory (size of the positional table minus one).
    trajectory_len: int = 1000
    # Pooling window and stride along time: T steps become M = T / L latent vectors.
    L: int = 4

In [6]:
from typing import Any, Optional
from flax import linen as nn
import jax
import jax.numpy as jnp

class ExponentialMovingAverage(nn.Module):
  """Bias-corrected exponential moving average kept in the "stats" collection.

  Three variables are stored, all of shape ``shape``: the raw accumulator
  ``hidden``, the corrected ``average`` and an int32 step ``counter``.

  Attributes:
    shape: shape of the tracked value.
    dtype: dtype of ``hidden`` and ``average``.
    decay: weight given to the previous accumulator at each step.
  """
  shape: list
  dtype: Any = jnp.float32
  decay: float = 0.

  def setup(self):
    dims = self.shape
    float_kind = self.dtype

    def zeros_of(kind):
      # Deferred initialiser producing a zero array of the tracked shape.
      return lambda: jnp.zeros(dims, dtype=kind)

    self.hidden = self.variable("stats", "hidden", zeros_of(float_kind))
    # Open question from the original author: how to deal with initialized?
    self.average = self.variable("stats", "average", zeros_of(float_kind))
    self.counter = self.variable("stats", "counter", zeros_of(jnp.int32))

  def __call__(
      self,
      value: jnp.ndarray,
      update_stats: bool = True,
  ) -> jnp.ndarray:
    """Fold ``value`` into the average and return the corrected estimate.

    Args:
      value: new observation to blend into the running average.
      update_stats: when true, the stored counter, accumulator and average
        are overwritten with the new values.

    Returns:
      The bias-corrected moving average after including ``value``.
    """
    step = self.counter.value + 1
    rate = jax.lax.convert_element_type(self.decay, value.dtype)
    unit = jnp.ones([], value.dtype)

    # Blend the previous accumulator with the new observation.
    blended = self.hidden.value * rate + value * (unit - rate)
    # Divide by (1 - decay^step) to undo the bias from the zero initial state.
    corrected = blended / (unit - jnp.power(rate, step))

    if update_stats:
      self.counter.value = step
      self.hidden.value = blended
      self.average.value = corrected
    return corrected

# inspired from Haiku's corresponding code to Flax
# https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/nets/vqvae.py

class VectorQuantizerEMA(nn.Module):
  """Vector quantiser whose codebook is updated by exponential moving averages.

  Flax port of Haiku's ``VectorQuantizerEMA``. The codebook ("embeddings",
  shape ``[embedding_dim, num_embeddings]``) and both EMA accumulators are
  stored in the "stats" collection rather than in "params".

  Attributes:
    embedding_dim: size of each codebook vector (last axis of the inputs).
    num_embeddings: number of codebook vectors.
    commitment_cost: multiplier applied to the commitment loss.
    decay: decay rate of both moving averages.
    epsilon: smoothing constant used when normalising cluster sizes.
    dtype: dtype of the statistics; must be ``jnp.float32``.
    cross_replica_axis: optional axis name over which statistics are reduced
      with ``psum`` / ``pmean``.
    initialized: not read by ``__call__``, which derives a local of the same
      name from the variable collection; kept for interface compatibility.
  """
  embedding_dim: int
  num_embeddings: int
  commitment_cost: float
  decay: float
  epsilon: float = 1e-5
  dtype: Any = jnp.float32
  cross_replica_axis: Optional[str] = None
  initialized: bool = False

  @nn.compact
  def __call__(self, inputs, is_training, rng=None, encoding_indices=None):
    """Quantise ``inputs`` against the codebook, or look up ``encoding_indices``.

    Args:
      inputs: array whose last axis has size ``embedding_dim``.
      is_training: when true, the EMA statistics and the codebook are updated.
      rng: optional PRNG key used to initialise the codebook. When ``None``
        the fixed key ``jax.random.PRNGKey(42)`` is used, as before.
      encoding_indices: optional integer indices; when given, only the
        matching codebook vectors are returned and nothing else is computed.

    Returns:
      * the looked-up vectors when ``encoding_indices`` is given;
      * ``{"quantize": inputs, "loss": inputs.mean()}`` on the call that
        first creates the codebook variable (a placeholder result);
      * otherwise a dict with keys "quantize", "loss", "perplexity",
        "encodings", "encoding_indices" and "distances".
    """
    embedding_shape = [self.embedding_dim, self.num_embeddings]
    assert self.dtype == jnp.float32
    ema_cluster_size = ExponentialMovingAverage([self.num_embeddings], self.dtype, decay=self.decay)
    ema_dw = ExponentialMovingAverage(embedding_shape, self.dtype, decay=self.decay)
    # Must be read before the variable is declared just below.
    initialized = self.has_variable('stats', 'embeddings')
    # Honour a caller-supplied key; fall back to the historical fixed seed.
    init_key = jax.random.PRNGKey(42) if rng is None else rng
    embeddings = self.variable("stats", "embeddings", nn.initializers.lecun_uniform(), init_key, embedding_shape)

    def quantize(encoding_indices):
      """Returns embedding tensor for a batch of indices."""
      w = embeddings.value.swapaxes(1, 0)
      w = jax.device_put(w)  # Required when embeddings is a NumPy array.
      return w[(encoding_indices,)]

    if encoding_indices is not None:
      return quantize(encoding_indices)

    if not initialized:
      # Touch the EMA submodules' variables without calling them.
      # NOTE(review): presumably this triggers their lazy setup() so that the
      # "stats" entries exist after init — confirm against Flax behaviour.
      for ema in (ema_cluster_size, ema_dw):
        _ = (ema.hidden, ema.counter, ema.average)
      return {
          "quantize": inputs,
          "loss": inputs.mean(),
      }

    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    # Squared Euclidean distance to every codebook vector, expanded as
    # |x|^2 - 2 x.e + |e|^2.
    distances = (
        jnp.sum(flat_inputs**2, 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings.value) +
        jnp.sum(embeddings.value**2, 0, keepdims=True))

    # Nearest codebook entry per input vector.
    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = quantize(encoding_indices)
    e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

    if is_training:
      cluster_size = jnp.sum(encodings, axis=0)
      if self.cross_replica_axis:
        cluster_size = jax.lax.psum(
            cluster_size, axis_name=self.cross_replica_axis)
      updated_ema_cluster_size = ema_cluster_size(cluster_size, update_stats=is_training)

      dw = jnp.matmul(flat_inputs.T, encodings)
      if self.cross_replica_axis:
        dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
      updated_ema_dw = ema_dw(dw, update_stats=is_training)

      # Smooth the cluster sizes so that empty clusters do not divide by zero.
      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

      embeddings.value = normalised_updated_ema_w

    # The loss is the same in training and evaluation; only the statistics
    # update above differs.
    loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

In [7]:
from typing import Any
import jax.numpy as jnp
from einops import rearrange, repeat
import sys
sys.path.append("../")

from jaxrl_m.networks import MLP
from jax._src import prng

class VQVAE(nn.Module):
    """Trajectory VQ-VAE: embeds tokens, pools over time, quantises and decodes.

    Attributes:
        hyperparams: model sizes (hidden/latent widths, trajectory length, pooling window L).
        vq_class: vector-quantiser module instance, initialised and applied in ``quantize``.
        mlp: MLP constructor, called as ``mlp(hidden_dims=...)`` to build sub-networks.
        quant_key: PRNG key passed to the quantiser's ``init``.
    """
    hyperparams: VqvaeConfig
    vq_class: VectorQuantizerEMA
    mlp: MLP
    quant_key: prng.PRNGKeyArray

    @nn.compact
    def __call__(self, inputs, state) -> Any:
        """Encode and quantise a batch of trajectories, then decode them.

        Args:
            inputs: array of shape ``B x T x observation_dim``.
            state: array of shape ``B x state_dim`` used to condition the decoder.

        Returns:
            ``(vq_info, decoded_traj)``: the quantiser's output dict and the decoder output.

        Raises:
            ValueError: if ``T`` exceeds ``hyperparams.trajectory_len``.
        """
        B, T, observation_dim = inputs.shape
        latent_dim = self.hyperparams.latent_dim
        max_len = self.hyperparams.trajectory_len + 1
        if T > self.hyperparams.trajectory_len:
            raise ValueError(
                f"trajectory length {T} exceeds hyperparams.trajectory_len "
                f"{self.hyperparams.trajectory_len}")

        token_emb = self.mlp(hidden_dims=(self.hyperparams.hidden_dim, latent_dim))(inputs)

        # Learned tables use a leading axis of 1 and are broadcast over the
        # batch, so parameter shapes do not depend on the batch size.
        pos_embedding = self.param('pos_embedding', nn.initializers.zeros, [1, max_len, latent_dim])
        goal_token = self.param('goal', nn.initializers.zeros, [1, 1, latent_dim])
        # Append the goal token after the T trajectory tokens.
        token_emb = jnp.concatenate(
            [token_emb, jnp.broadcast_to(goal_token, (B, 1, latent_dim))], axis=1)
        # Slice along the time axis (axis 1), not the batch axis.
        token_emb += pos_embedding[:, :T + 1]

        # B x ((T + 1) // L) x latent_dim: max pooling over time with window and stride L.
        token_emb = nn.max_pool(token_emb, window_shape=(self.hyperparams.L, ), strides=(self.hyperparams.L, ))
        token_emb = self.mlp(hidden_dims=(latent_dim, ))(token_emb)
        vq_info = self.quantize(token_emb)

        # Decoding stage: latents are B x M x latent_dim, state is B x state_dim.
        latents = vq_info['quantize']
        # Predict the full trajectory from the latents and the conditioning state.
        decoded_traj = self.decode(latents, state, observation_dim)
        # Debug output retained from the original notebook.
        print(decoded_traj.shape)
        return vq_info, decoded_traj

    def decode(self, latents, state, observation_dim):
        """Decode latents into a trajectory conditioned on ``state``.

        The state is tiled along time and concatenated to every latent, each
        latent is repeated ``L`` times, and an MLP maps the result to
        ``observation_dim`` features.
        """
        B, T, _ = latents.shape

        state_flat = einops.repeat(einops.rearrange(state, 'B E -> B 1 E'), 'B 1 E -> B T E', T=T)
        inputs = jnp.concatenate([state_flat, latents], axis=-1)
        # Undo the temporal pooling by repeating each latent L times.
        inputs = jnp.repeat(inputs, repeats=self.hyperparams.L, axis=1)
        inputs = self.mlp(hidden_dims=(observation_dim, observation_dim))(inputs)
        return inputs

    def quantize(self, token_emb):
        """Initialise the quantiser and apply it once in evaluation mode.

        NOTE(review): the quantiser variables are re-created on every call and
        ``apply`` runs with ``is_training=False``, so no codebook update is
        kept between calls — confirm this is intended.
        """
        vq_vars = self.vq_class.init({'params':self.quant_key}, token_emb, is_training=True)
        vq_info = self.vq_class.apply(vq_vars, token_emb, is_training=False)

        return vq_info

In [8]:
# Shape of the first sampled trajectory after adding a leading axis of size 1.
test[0][None, ...].shape

(1, 1000, 58)

In [9]:
base_key = jax.random.PRNGKey(42)
# One independent key each for model init, quantiser init and the dummy states.
model_key, quantization_key, state_key = jax.random.split(base_key, 3)

# B x Tokens x observ_dim
vq_config = VqvaeConfig()
# The codebook width comes from the config so it cannot drift from the encoder's latent width.
vq_class = VectorQuantizerEMA(embedding_dim=vq_config.latent_dim, num_embeddings=128, commitment_cost=0.2, decay=0.1)

model = VQVAE(vq_config, vq_class, MLP, quantization_key)
# Random conditioning states of shape (1, 32), drawn from their own key
# instead of reusing the name `dummy_states` for both key and array.
dummy_states = jax.random.uniform(state_key, (1, 32))

# Initialise the model on the first sampled trajectory with a leading axis of size 1.
init_vars = model.init(model_key, test[0][None, ...], dummy_states)

(1, 1000, 58)


In [10]:
# Forward pass with the initialised variables on the same inputs used for init.
vq_info, decoded_traj = model.apply(init_vars, test[0][None, ...], dummy_states)

(1, 1000, 58)
