In [2]:
import os
os.environ["GEOMSTATS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.product_manifold import ProductSameManifold, ProductSameRiemannianMetric
import geomstats.backend as gs

INFO:absl:Remote TPU is not linked into jax; skipping remote TPU.
INFO:absl:Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:Using jax backend


In [3]:
# Seed a root PRNG key and immediately split it, keeping `key` for later
# splits and `subkey` for the single draw below.
key, subkey = jax.random.split(jax.random.PRNGKey(45))

# Ten draws from a symmetric Dirichlet with concentration (3, 3, 3); the
# result is left as the cell's last expression so the notebook displays it.
jax.random.dirichlet(subkey, alpha=jnp.array([3, 3, 3]), shape=(10,))

DeviceArray([[0.29211295, 0.4125772 , 0.29530984],
             [0.39183572, 0.45557895, 0.15258528],
             [0.39755124, 0.3422433 , 0.26020542],
             [0.29094806, 0.25036165, 0.4586903 ],
             [0.26951846, 0.25899136, 0.47149017],
             [0.2152559 , 0.13369545, 0.65104866],
             [0.6735201 , 0.10467911, 0.2218008 ],
             [0.11310774, 0.6675091 , 0.21938318],
             [0.34312925, 0.24179856, 0.41507214],
             [0.53123134, 0.07096099, 0.3978077 ]], dtype=float32)

In [8]:
from random import choice, choices

class BiGramSampler:
    """Samples strings from a fixed first-order (bigram) Markov chain.

    The first character is drawn uniformly from ``vocab``; every following
    character is drawn from the row of ``probs`` keyed by the previous
    character.

    NOTE: a later cell in this notebook defines ``BiGramSampler`` again with
    extra tensor helpers; once that cell runs it shadows this definition.
    """

    vocab = ('a', 'b', 'c')

    # probs[prev][i] is the probability that vocab[i] follows `prev`.
    # Every row sums to 1. (The 'a' row previously ended in 1.3, a typo for
    # 1./3 that random.choices accepted silently as an unnormalised weight.)
    probs = {
        'a': (1./3, 1./3, 1./3),
        'b': (0.1, 0.4, 0.5),
        'c': (0.6, 0.2, 0.2)
    }

    def __init__(self, seq_len: int):
        """Store the length of the strings produced by this sampler."""
        self.seq_len = seq_len

    def _sample(self):
        """Return one string of ``seq_len`` characters drawn from the chain."""
        output = [choice(self.vocab)]
        for _ in range(self.seq_len - 1):
            # choices() returns a one-element list; extend keeps `output` flat.
            output.extend(choices(self.vocab, weights=self.probs[output[-1]]))
        return ''.join(output)

    def sample(self, n):
        """Return a list of ``n`` independently sampled strings."""
        return [self._sample() for _ in range(n)]

['c']

In [20]:
from einops import rearrange, repeat
from flax import linen as nn
from flax import struct
import jax
from typing import Callable, Sequence


@struct.dataclass
class TransformerConfig:
    """Hyperparameters shared by the modules of the diffusion transformer.

    ``vocab_size``, ``model_dim`` and ``mlp_dim`` have no default and must be
    supplied; every other field is optional.
    """

    vocab_size: int  # width of the final Dense projection (one logit per token)
    model_dim: int  # output width of every FeedForward block
    mlp_dim: int  # hidden width inside every FeedForward block
    num_layers: int = 3  # number of TransformerLayer blocks
    time_dim: int = 16  # FourierFeatures emits 2 * (time_dim // 2) features
    num_heads: int = 8  # attention heads used by SelfAttention
    max_length: int = 512  # presumably the maximum sequence length -- TODO confirm
    dropout_rate: float = 0.1  # NOTE(review): not read by the modules defined alongside this config
    attention_dropout_rate: float = 0.1  # dropout applied inside the attention op
    # kernel_init: Callable = nn.initializers.xavier_uniform()
    fourier_init_std: float = 0.2  # stddev of the normal initialiser for FourierFeatures' 'w'


def alpha_sigma_to_log_snr(alpha, sigma):
    """Convert signal/noise scaling factors into a log signal-to-noise ratio.

    Args:
        alpha: scaling factor applied to the clean signal.
        sigma: scaling factor applied to the noise.

    Returns:
        ``log(alpha**2 / sigma**2)``, computed elementwise with ``jnp``.
        A ``sigma`` of exactly zero yields an infinite value.
    """
    signal_power = alpha**2
    noise_power = sigma**2
    return jnp.log(signal_power / noise_power)


def t_to_alpha_sigma(t):
    """Map a timestep to the (signal, noise) scaling factors.

    Args:
        t: timestep (scalar or array); ``t = 0`` gives ``(1, 0)`` and
            ``t = 1`` gives approximately ``(0, 1)``.

    Returns:
        Tuple ``(alpha, sigma) = (cos(t * pi / 2), sin(t * pi / 2))``.
    """
    angle = t * jnp.pi / 2
    alpha = jnp.cos(angle)
    sigma = jnp.sin(angle)
    return alpha, sigma



class FourierFeatures(nn.Module):
    """Learned Fourier feature embedding.

    Projects the input onto ``time_dim // 2`` learned frequency vectors and
    returns the cosine and sine of the resulting phases, concatenated along
    the last axis (so the output has ``2 * (time_dim // 2)`` features).

    The frequency matrix is sized from ``x.shape[1]``, so ``x`` is expected
    to be 2-D with features on axis 1 (e.g. ``(batch, in_features)``).
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x):
        num_frequencies = self.config.time_dim // 2
        weight_init = nn.initializers.normal(stddev=self.config.fourier_init_std)
        # One learnable frequency vector per cos/sin output pair.
        w = self.param('w', weight_init, (num_frequencies, x.shape[1]))
        # Scale the input by 2*pi first, then project, matching the left-to-right
        # evaluation of `2 * jnp.pi * x @ w.T`.
        phase = (2 * jnp.pi * x) @ w.T
        features = [jnp.cos(phase), jnp.sin(phase)]
        return jnp.concatenate(features, axis=-1)


class SelfAttention(nn.Module):
    """Layer-normalised multi-head self-attention sub-layer.

    The input is normalised with a parameter-free LayerNorm and then attends
    to itself. Only the attention output is returned; no residual connection
    is added here.

    Args (of ``__call__``):
        x: input activations, used as both queries and keys/values.
        padding_mask: optional mask; two singleton axes are inserted after
            the first axis so it broadcasts over heads and query positions.
            Assumed to be ``(batch, seq_len)`` -- TODO confirm with callers.
        deterministic: when True, attention dropout is disabled.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x, padding_mask=None, deterministic=False):
        x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        if padding_mask is not None:
            padding_mask = padding_mask[:, None, None, :]
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.config.num_heads,
            dropout_rate=self.config.attention_dropout_rate,
        )
        # Pass the mask by keyword: newer Flax releases use the third
        # positional slot of __call__ for `inputs_v`, so a positional mask
        # would be silently treated as the value tensor.
        return attention(x, x, mask=padding_mask, deterministic=deterministic)


class FeedForward(nn.Module):
    """Position-wise MLP: LayerNorm -> Dense(mlp_dim) -> GELU -> Dense(model_dim).

    The LayerNorm has no learned scale or bias and neither Dense layer uses
    a bias term. The output's last axis has size ``config.model_dim``
    regardless of the input width.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x):
        # Submodules are created in the same order as before so Flax's
        # auto-generated parameter names are unchanged.
        normed = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        hidden = nn.gelu(nn.Dense(self.config.mlp_dim, use_bias=False)(normed))
        return nn.Dense(self.config.model_dim, use_bias=False)(hidden)


class TransformerLayer(nn.Module):
    """One transformer block: residual self-attention, then residual MLP.

    Args (of ``__call__``):
        x: input activations.
        padding_mask: optional mask forwarded to ``SelfAttention``.
        deterministic: forwarded to ``SelfAttention``; True disables dropout.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x, padding_mask=None, deterministic=False):
        # Rotary position embeddings are currently disabled:
        #x_rot = x[:, :, :self.d_rotary]
        #x_pass = x[ :, :, self.d_rotary:]
        #sincos = fixed_pos_embedding(x_rot, seq_dim=1)
        #x_rot = apply_rotary_pos_emb(x_rot, sincos)
        #x = jnp.concatenate([x_rot, x_pass], axis=-1)
        attended = SelfAttention(self.config)(x, padding_mask, deterministic)
        x = x + attended
        transformed = FeedForward(self.config)(x)
        return x + transformed


class SkipBlock(nn.Module):
    """Runs ``layers`` in sequence and concatenates the result with the input.

    The output is ``concatenate([layers(x), x], axis=1)``, so the output of
    the wrapped layers must match ``x`` on every axis except axis 1.

    Attributes:
        layers: modules applied one after another to the input.
    """
    layers: Sequence[nn.Module]

    @nn.compact
    def __call__(self, x):
        # nn.Sequential takes the layer sequence as a single argument;
        # unpacking it with `*` would spread the layers across unrelated
        # dataclass fields.
        x_new = nn.Sequential(self.layers)(x)
        # jnp.concatenate uses `axis`, not the torch-style `dim` keyword.
        return jnp.concatenate([x_new, x], axis=1)


class TransformerDiffusion(nn.Module):
    """Transformer denoiser conditioned on the diffusion timestep.

    The timestep is converted to a log-SNR, embedded with ``FourierFeatures``
    and appended to every sequence position before the transformer stack.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x, t, training=False):
        """Apply the model.

        Args:
            x: array of shape ``(batch, seq_len, features)``.
            t: array of shape ``(batch,)`` holding one timestep per example.
                ``t == 0`` gives a zero noise scale and therefore an infinite
                log-SNR, so callers should keep ``t`` strictly positive.
            training: when True, attention dropout is active (a ``'dropout'``
                RNG must then be supplied to ``apply``).

        Returns:
            Array of shape ``(batch, seq_len, vocab_size)``, centred and
            scaled to unit standard deviation along the last axis.
        """
        deterministic = not training

        log_snr = alpha_sigma_to_log_snr(*t_to_alpha_sigma(t))
        timestep_embed = FourierFeatures(self.config)(log_snr[:, None])

        # Broadcast the per-example embedding over the *actual* sequence
        # length; tiling to config.max_length breaks the concatenation for
        # any shorter (or longer) input.
        seq_len = x.shape[1]
        te_planes = jnp.tile(timestep_embed[:, None], (1, seq_len, 1))
        x = jnp.concatenate([x, te_planes], axis=-1)
        x = FeedForward(self.config)(x)

        # Apply the layers in an explicit loop: nn.Sequential only forwards
        # the previous output to layers after the first, which would drop
        # `deterministic` and leave dropout enabled during evaluation.
        trans_x = x
        for _ in range(self.config.num_layers):
            trans_x = TransformerLayer(self.config)(trans_x, None, deterministic)

        x = x + jnp.sqrt(2) * trans_x
        x = nn.Dense(self.config.vocab_size)(x)

        # Standardise along the vocabulary axis: divide by the standard
        # deviation (not the variance), with a small epsilon so constant rows
        # do not produce NaNs.
        eps = 1e-6
        x_final = x - jnp.mean(x, axis=-1, keepdims=True)
        x_final = x_final / jnp.sqrt(
            jnp.var(x_final, axis=-1, keepdims=True) + eps
        )
        return x_final

In [21]:
"""
class TransformerConfig:
    vocab_size: int
    embed_dim: int
    model_dim: int
    mlp_dim: int
    num_layers: int = 3
    time_dim: int = 16
    num_heads: int = 8
    max_length: int = 512
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    # kernel_init: Callable = nn.initializers.xavier_uniform()
    fourier_init_std: float = 0.2
"""

# Deliberately tiny configuration for quick experiments on the 3-symbol
# bigram data.
config = TransformerConfig(
    vocab_size=3,
    model_dim=12,
    mlp_dim=24,
    num_layers=1,
    time_dim=4,
    num_heads=1,
    max_length=24,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    # kernel_init: Callable = nn.initializers.xavier_uniform()
    fourier_init_std=0.2,
)

In [30]:
class BiGramSampler:
    """Samples strings from a fixed bigram Markov chain and one-hot encodes them.

    The first character is drawn uniformly from ``vocab``; every following
    character is drawn from the row of ``probs`` keyed by the previous
    character. ``to_jax`` / ``make_batch`` turn sampled strings into one-hot
    arrays using ``vectors``.
    """

    vocab = ('a', 'b', 'c')

    # probs[prev][i] is the probability that vocab[i] follows `prev`.
    # Every row sums to 1. (The 'a' row previously ended in 1.3, a typo for
    # 1./3 that random.choices accepted silently as an unnormalised weight.)
    probs = {
        'a': (1./3, 1./3, 1./3),
        'b': (0.1, 0.4, 0.5),
        'c': (0.6, 0.2, 0.2)
    }

    # One-hot encoding of each vocabulary symbol, in `vocab` order.
    vectors = {
        'a': jnp.array([1., 0., 0.]),
        'b': jnp.array([0., 1., 0.]),
        'c': jnp.array([0., 0., 1.])
    }

    def __init__(self, seq_len: int):
        """Store the length of the strings produced by this sampler."""
        self.seq_len = seq_len

    def _sample(self):
        """Return one string of ``seq_len`` characters drawn from the chain."""
        output = [choice(self.vocab)]
        for _ in range(self.seq_len - 1):
            # choices() returns a one-element list; extend keeps `output` flat.
            output.extend(choices(self.vocab, weights=self.probs[output[-1]]))
        return ''.join(output)

    def sample(self, n: int):
        """Return a list of ``n`` independently sampled strings."""
        return [self._sample() for _ in range(n)]

    def to_jax(self, s):
        """One-hot encode string ``s`` as an array of shape ``(len(s), 3)``.

        Raises ``KeyError`` for characters outside ``vocab``.
        """
        return jnp.stack([self.vectors[char] for char in s])

    def make_batch(self, n: int):
        """Sample ``n`` strings and return them as a ``(n, seq_len, 3)`` array."""
        tensors = [
            self.to_jax(self._sample()) for _ in range(n)
        ]
        return jnp.stack(tensors)


In [31]:
# Sampler producing strings of 10 characters.
sampler = BiGramSampler(10)

# Five sampled sequences, one-hot encoded and stacked: shape (5, 10, 3).
batch = sampler.make_batch(5)

In [32]:
# Display the first sequence of the batch: one one-hot row per character.
batch[0]

DeviceArray([[0., 0., 1.],
             [1., 0., 0.],
             [0., 0., 1.],
             [0., 1., 0.],
             [1., 0., 0.],
             [1., 0., 0.],
             [1., 0., 0.],
             [0., 0., 1.],
             [1., 0., 0.],
             [0., 1., 0.]], dtype=float32)