In [1]:
import jax 
import jax.numpy as jnp
from flax import linen as nn
jax.config.update('jax_platform_name', 'cpu')

In [91]:
class LearnedPriorRNN(nn.RNNCellBase):
    cell: nn.RNNCellBase

    @nn.compact
    def initialize_carry(self, rng, input_shape):
        del rng
        variable = self.variable('collection', 'name', jnp.zeros, (4,))
        variable1 = self.variable('collection2', 'name', jnp.zeros, (2,))
        init = self.param('learned', nn.initializers.constant(0), (self.cell.features,))
        return jnp.tile(init, input_shape + (1,))

In [92]:
base_cell = nn.GRUCell(31)
cell = LearnedPriorRNN(base_cell)

In [93]:
x = jnp.zeros((3,))
key = jax.random.PRNGKey(0)
params = cell.init(key, key, (), method=cell.initialize_carry)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

from src.networks import scaled_dot_product, MultiHeadAttention, TransformerBlock, TransformerEncoder
from src.ops import _enc1d, positional_encoding

In [3]:
x = jnp.zeros((3,4,5))

In [17]:
import numpy as np

In [18]:
a = np.random.normal(size=(2,3))
b = np.random.normal(size=(4,5))
c = np.random.normal(size=(7,11))

In [2]:
list(range(-3, -1))

[-3, -2]

In [14]:
import jax.numpy as jnp
import jax
jax.config.update('jax_platform_name', 'cpu')

def _concat_emb(carry, y):
    """
    Args:
        carry.shape = (d1, d2, ..., dn, kx)
        y.shape = (dy, ky)
    Returns:
        res.shape = (d1, ..., dn, dy, kx + ky)
    """
    if not isinstance(carry, jnp.ndarray):
        # Carry init.
        return y
    x = carry
    dims = x.shape[:-1]
    x = jnp.expand_dims(x, -2)
    y = jnp.expand_dims(y, tuple(range(len(dims))))
    x = jnp.repeat(x, y.shape[-2], -2)
    repeats = dims + (1, 1)
    y = jnp.tile(y, repeats)
    return jnp.concatenate([x, y], -1)

In [1]:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

from flax import linen as nn
from flax import core

In [30]:
class A(nn.Module):

    def setup(self):
        self.a = self.param('a', lambda: (lambda rng, sh, dtype: jnp.zeros(sh)))
        self.b = self.variable('test', 'b', jnp.ones((4,)))

    def __call__(self):
        return self.a, self.b

In [31]:
rng = jax.random.PRNGKey(0)
params = A().init(rng)

TypeError: A.setup.<locals>.<lambda>() takes 0 positional arguments but 1 was given

In [32]:

def generate_fourier_features(
    pos, num_bands, max_resolution=(224, 224),
    concat_pos=True, sine_only=False):
  """Generate a Fourier frequency position encoding with linear spacing.

  Args:
    pos: The position of n points in d dimensional space.
      A jnp array of shape [n, d].
    num_bands: The number of bands (K) to use.
    max_resolution: The maximum resolution (i.e. the number of pixels per dim).
      A tuple representing resolution for each dimension
    concat_pos: Concatenate the input position encoding to the Fourier features?
    sine_only: Whether to use a single phase (sin) or two (sin/cos) for each
      frequency band.
  Returns:
    embedding: A 1D jnp array of shape [n, n_channels]. If concat_pos is True
      and sine_only is False, output dimensions are ordered as:
        [dim_1, dim_2, ..., dim_d,
         sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ...,
         sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d),
         cos(pi*f_1*dim_1), ..., cos(pi*f_K*dim_1), ...,
         cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)],
       where dim_i is pos[:, i] and f_k is the kth frequency band.
  """
  min_freq = 1.0
  # Nyquist frequency at the target resolution:

  freq_bands = jnp.stack([
      jnp.linspace(min_freq, res / 2, num=num_bands, endpoint=True)
      for res in max_resolution], axis=0)

  # Get frequency bands for each spatial dimension.
  # Output is size [n, d * num_bands]
  per_pos_features = pos[:, :, None] * freq_bands[None, :, :]
  per_pos_features = jnp.reshape(per_pos_features,
                                 [-1, np.prod(per_pos_features.shape[1:])])

  if sine_only:
    # Output is size [n, d * num_bands]
    per_pos_features = jnp.sin(jnp.pi * (per_pos_features))
  else:
    # Output is size [n, 2 * d * num_bands]
    per_pos_features = jnp.concatenate(
        [jnp.sin(jnp.pi * per_pos_features),
         jnp.cos(jnp.pi * per_pos_features)], axis=-1)
  # Concatenate the raw input positions.
  if concat_pos:
    # Adds d bands to the encoding.
    per_pos_features = jnp.concatenate([pos, per_pos_features], axis=-1)
  return per_pos_features

In [33]:
def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
  """Generate an array of position indices for an N-D input array.

  Args:
    index_dims: The shape of the index dimensions of the input array.
    output_range: The min and max values taken by each input index dimension.
  Returns:
    A jnp array of shape [index_dims[0], index_dims[1], .., index_dims[-1], N].
  """
  def _linspace(n_xels_per_dim):
    return jnp.linspace(
        output_range[0], output_range[1],
        num=n_xels_per_dim,
        endpoint=True, dtype=jnp.float32)

  dim_ranges = [
      _linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
  array_index_grid = jnp.meshgrid(*dim_ranges, indexing='ij')

  return jnp.stack(array_index_grid, axis=-1)

In [39]:
pos = build_linear_positions([1,3])

In [52]:
import numpy as np
feats = generate_fourier_features(pos[..., 0], 5)

TypeError: mul got incompatible shapes for broadcasting: (1, 3, 1), (1, 2, 5).

In [51]:
feats.shape

(3, 22)

In [49]:
pos.shape

(1, 3, 2)