In [1]:
from transformers import PretrainedConfig

class StructformerConfig(PretrainedConfig):
    """Hyper-parameter container for the Structformer model.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        hidden_dim: Value stored as ``self.hidden_dim`` (default 512).
        num_heads: Value stored as ``self.num_heads`` (default 8).
        num_layers: Value stored as ``self.num_layers`` (default 4).
        max_length: Value stored as ``self.max_length`` (default 128).
        vocab_size: Value stored as ``self.vocab_size`` (default 50257).
        c: Value stored as ``self.c`` (default 1.0).
        **kwargs: Passed through to the ``PretrainedConfig`` constructor.
    """

    model_type = "structformer"

    def __init__(
        self,
        hidden_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 4,
        max_length: int = 128,
        vocab_size: int = 50257,
        c: float = 1.0,
        **kwargs
    ):
        # Let the base class consume the generic options first, then attach
        # the model-specific fields in declaration order.
        super().__init__(**kwargs)
        own_fields = (
            ("hidden_dim", hidden_dim),
            ("num_heads", num_heads),
            ("num_layers", num_layers),
            ("max_length", max_length),
            ("vocab_size", vocab_size),
            ("c", c),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)


In [2]:
# Build a config that uses every constructor default, then serialize it into
# the local "custom_structformer_config" directory via save_pretrained.
spdefault_config = StructformerConfig()
spdefault_config.save_pretrained("custom_structformer_config")

In [2]:
import jax
import jax.numpy as jnp

def mobius_add(x, y, c=1.0):
    """Möbius addition of ``x`` and ``y`` on the Poincaré ball of curvature ``-c``.

    Computes, along the last axis,

        ((1 + 2c<x,y> + c|y|^2) x + (1 - c|x|^2) y)
        / (1 + 2c<x,y> + c^2 |x|^2 |y|^2)

    Args:
        x: Array whose last axis holds the point coordinates.
        y: Array broadcast-compatible with ``x``.
        c: Curvature parameter (default 1.0).

    Returns:
        Array with the broadcast shape of ``x`` and ``y``.
    """
    # Squared norms are taken as plain sums of squares. Going through
    # jnp.linalg.norm and squaring the result evaluates d/dx sqrt(.) at 0 for a
    # zero vector, which turns the gradient into NaN even though the function
    # itself is smooth there.
    sq_norm_x = jnp.sum(x * x, axis=-1, keepdims=True)
    sq_norm_y = jnp.sum(y * y, axis=-1, keepdims=True)
    dot = jnp.sum(x * y, axis=-1, keepdims=True)
    numerator = (1 + 2 * c * dot + c * sq_norm_y) * x + (1 - c * sq_norm_x) * y
    denominator = 1 + 2 * c * dot + c ** 2 * sq_norm_x * sq_norm_y
    # Lower-bound the denominator so the division stays finite.
    return numerator / jnp.maximum(denominator, 1e-5)

def poincare_distance(x, y, c=1.0):
    """Geodesic distance between ``x`` and ``y`` on the Poincaré ball of curvature ``-c``.

    Uses the closed form

        d(x, y) = (1 / sqrt(c)) * arccosh(1 + 2c|x - y|^2 / ((1 - c|x|^2)(1 - c|y|^2)))

    which for ``c = 1`` is the standard Poincaré-disk distance.

    Args:
        x: Array whose last axis holds the point coordinates.
        y: Array broadcast-compatible with ``x``.
        c: Curvature parameter, expected to be positive (default 1.0).

    Returns:
        Array of distances with the last axis reduced away.
    """
    sqrt_c = jnp.sqrt(c)
    diff = x - y
    sq_dist = jnp.sum(diff * diff, axis=-1)
    sq_norm_x = jnp.sum(x * x, axis=-1)
    sq_norm_y = jnp.sum(y * y, axis=-1)
    # The conformal factors vanish on the ball boundary; keep the product
    # strictly positive so the ratio below stays finite.
    denom = jnp.maximum((1 - c * sq_norm_x) * (1 - c * sq_norm_y), 1e-5)
    # arccosh is only defined on [1, inf); guard against round-off below 1.
    arg = jnp.maximum(1 + 2 * c * sq_dist / denom, 1.0)
    return jnp.arccosh(arg) / sqrt_c

In [3]:
import flax.linen as nn
import jax.numpy as jnp
from typing import Any, Optional
from transformers import FlaxPreTrainedModel
# from ..models.hyperbolic_layers import mobius_add  # JAX version

class PoincareHierarchicalBlock(nn.Module):
    """Self-attention layer whose skip connection is a Möbius addition.

    The input is run through ``nn.SelfAttention`` (always deterministic) and
    the result is combined with the input via ``mobius_add`` using curvature
    parameter ``c``.

    Attributes:
        hidden_size: Declared on the module; not read inside ``__call__``.
        num_heads: Number of attention heads handed to ``nn.SelfAttention``.
        c: Curvature parameter forwarded to ``mobius_add``.
    """

    hidden_size: int
    num_heads: int
    c: float = 1.0

    @nn.compact
    def __call__(self, x, mask):
        """Apply masked self-attention to ``x`` and Möbius-add it back onto ``x``.

        Args:
            x: Input activations, assumed (batch, seq, hidden) - TODO confirm
                against callers.
            mask: Attention mask, assumed (batch, seq); two singleton axes are
                inserted at positions 1 and 2 before it reaches the attention
                layer.

        Returns:
            The result of ``mobius_add(x, attention_output, c=self.c)``.
        """
        # (batch, seq) -> (batch, 1, 1, seq), inserting both axes in one call.
        attn_mask = jnp.expand_dims(mask, axis=(1, 2))
        attention = nn.SelfAttention(
            num_heads=self.num_heads,
            dtype=x.dtype,
            deterministic=True,
        )
        attended = attention(x, mask=attn_mask)
        return mobius_add(x, attended, c=self.c)

class StructformerPoincare(nn.Module):
    """Token-level model: embeddings, a stack of Poincaré blocks, and a vocab head.

    Attributes:
        vocab_size: Embedding-table size and width of the output projection.
        hidden_dim: Feature width of the token and position embeddings.
        num_heads: Attention heads per ``PoincareHierarchicalBlock``.
        num_layers: Number of stacked ``PoincareHierarchicalBlock`` layers.
        max_length: Number of rows in the learned position-embedding table.
        c: Curvature parameter handed to every block.
    """

    vocab_size: int
    hidden_dim: int = 512
    num_heads: int = 8
    num_layers: int = 6
    max_length: int = 128
    c: float = 1.0

    @nn.compact
    def __call__(self, input_ids, attention_mask):
        """Map token ids to per-position vocabulary logits.

        Args:
            input_ids: Integer token ids; axis 1 is used as the sequence axis.
            attention_mask: Mask passed unchanged to every block.

        Returns:
            Output of the final ``nn.Dense(vocab_size)`` projection.
        """
        seq_len = input_ids.shape[1]

        # Learned token table plus a learned (1, max_length, hidden_dim)
        # position table truncated to the current sequence length.
        embedder = nn.Embed(self.vocab_size, self.hidden_dim)
        position_table = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.02),
            (1, self.max_length, self.hidden_dim),
        )
        hidden = embedder(input_ids) + position_table[:, :seq_len, :]

        # Each iteration creates a fresh block with its own parameters.
        for _ in range(self.num_layers):
            block = PoincareHierarchicalBlock(self.hidden_dim, self.num_heads, self.c)
            hidden = block(hidden, attention_mask)

        normed = nn.LayerNorm()(hidden)
        return nn.Dense(self.vocab_size)(normed)

class StructformerModel(FlaxPreTrainedModel):
    """Hugging Face Flax wrapper around ``StructformerPoincare``.

    Builds the Flax module from a config object and supplies the two hooks the
    base class leaves to subclasses: ``init_weights`` (needed by the base
    constructor to create parameters) and ``__call__`` (to run the module).
    """

    module_class = StructformerPoincare
    config_class = StructformerConfig

    def __init__(
        self,
        config: Any,
        input_shape: Optional[tuple] = (1, 128),
        seed: int = 0,
        **kwargs
    ):
        """Instantiate the Flax module from ``config`` and initialise the wrapper.

        Args:
            config: Object exposing ``vocab_size`` and, optionally,
                ``hidden_dim``, ``num_heads``, ``num_layers``, ``max_length``
                and ``c`` (missing ones fall back to the defaults below).
            input_shape: Shape of the dummy batch used for parameter
                initialisation. ``None`` selects ``(1, max_length)``.
            seed: Seed forwarded to the base class.
            **kwargs: Forwarded to ``FlaxPreTrainedModel.__init__``.
        """
        if input_shape is None:
            input_shape = (1, getattr(config, "max_length", 128))
        # Instantiate Flax module
        module = self.module_class(
            vocab_size=config.vocab_size,
            hidden_dim=getattr(config, "hidden_dim", 512),
            num_heads=getattr(config, "num_heads", 8),
            num_layers=getattr(config, "num_layers", 6),
            max_length=getattr(config, "max_length", 128),
            c=getattr(config, "c", 1.0),
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, **kwargs)

    def init_weights(self, rng, input_shape, params=None):
        """Create module parameters from a dummy batch.

        The base-class implementation only raises ``NotImplementedError``, so
        without this override the model cannot be constructed.

        Args:
            rng: PRNG key used for parameter initialisation.
            input_shape: ``(batch, seq)`` shape of the dummy inputs.
            params: Optional partially-loaded parameter tree; keys listed in
                ``self._missing_keys`` are filled in from the fresh
                initialisation.

        Returns:
            The parameter tree (freshly initialised, or ``params`` completed
            with the missing entries).
        """
        from flax.core.frozen_dict import freeze, unfreeze
        from flax.traverse_util import flatten_dict, unflatten_dict

        # The position table has only ``max_length`` rows, so a longer dummy
        # sequence would fail to broadcast inside the module.
        seq_len = min(int(input_shape[-1]), self.module.max_length)
        dummy_shape = tuple(input_shape[:-1]) + (seq_len,)
        input_ids = jnp.zeros(dummy_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        random_params = self.module.init(rng, input_ids, attention_mask)["params"]
        if params is None:
            return random_params

        random_flat = flatten_dict(unfreeze(random_params))
        merged = flatten_dict(unfreeze(params))
        for missing_key in getattr(self, "_missing_keys", ()):
            merged[missing_key] = random_flat[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(merged))

    def __call__(self, input_ids, attention_mask=None, params=None):
        """Run the wrapped module and return its logits.

        Args:
            input_ids: Integer token ids with the sequence on axis 1.
            attention_mask: Optional mask matching ``input_ids``; defaults to
                all ones.
            params: Optional parameter tree to use instead of ``self.params``.

        Returns:
            The logits produced by ``StructformerPoincare``.
        """
        input_ids = jnp.asarray(input_ids, dtype="i4")
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        else:
            attention_mask = jnp.asarray(attention_mask, dtype="i4")
        variables = {"params": self.params if params is None else params}
        return self.module.apply(variables, input_ids, attention_mask)


In [4]:
from transformers import AutoConfig, AutoModel, FlaxAutoModel

# Make the custom classes discoverable through the transformers auto classes:
# the "structformer" model_type resolves to StructformerConfig, and that config
# class resolves to StructformerModel.
AutoConfig.register("structformer", StructformerConfig)
# NOTE(review): AutoModel is the PyTorch auto class, while StructformerModel
# derives from FlaxPreTrainedModel - confirm this registration is intended.
AutoModel.register(StructformerConfig, StructformerModel)
FlaxAutoModel.register(StructformerConfig, StructformerModel)

In [5]:
# Associate StructformerModel with the "FlaxAutoModel" auto class.
# NOTE(review): presumably so that saving/pushing the model records the
# auto-class mapping in its config - confirm against the transformers docs.
StructformerModel.register_for_auto_class("FlaxAutoModel")

In [9]:
from transformers import FlaxAutoModel
# Load the published checkpoint through the Flax auto class.
# NOTE(review): the recorded output of this cell ends in an OSError saying the
# repo's model.safetensors "does not contain the valid metadata", and warns
# that trust_remote_code "has no effect here and is ignored". The checkpoint
# likely needs to be re-exported with save_pretrained before this succeeds.
model = FlaxAutoModel.from_pretrained("bendemonium/babylm-poincare-structformer",
                                             trust_remote_code=True)



Send: curl -X HEAD -H 'Accept: */*' -H 'Accept-Encoding: identity' -H 'Connection: keep-alive' -H 'authorization: <TOKEN>' -H 'user-agent: unknown/None; hf_hub/0.34.3; python/3.11.11; torch/2.7.1; transformers/4.54.1; session_id/8ed1f61c4e1147eca9ee3676bc1694a4' https://huggingface.co/bendemonium/babylm-poincare-structformer/resolve/main/config.json
Request 0dad1f02-b6bb-411e-b4b0-9c024890d484: HEAD https://huggingface.co/bendemonium/babylm-poincare-structformer/resolve/main/config.json (authenticated: True)
Send: curl -X HEAD -H 'Accept: */*' -H 'Accept-Encoding: identity' -H 'Connection: keep-alive' -H 'authorization: <TOKEN>' -H 'user-agent: unknown/None; hf_hub/0.34.3; python/3.11.11; torch/2.7.1; transformers/4.54.1; session_id/8ed1f61c4e1147eca9ee3676bc1694a4' https://huggingface.co/api/resolve-cache/models/bendemonium/babylm-poincare-structformer/eda60faed6eb06c5bb7917d88775709a388d59ec/config.json
Request 98dfefab-3e43-4d5b-a0cf-1208721875e0: HEAD https://huggingface.co/api/res

config.json:   0%|          | 0.00/149 [00:00<?, ?B/s]

Download complete. Moving file to /Users/ridhibandaru/.cache/huggingface/hub/models--bendemonium--babylm-poincare-structformer/blobs/53df94ef4741229252b7915df904d26d493c967c
Creating pointer from ../../blobs/53df94ef4741229252b7915df904d26d493c967c to /Users/ridhibandaru/.cache/huggingface/hub/models--bendemonium--babylm-poincare-structformer/snapshots/eda60faed6eb06c5bb7917d88775709a388d59ec/config.json
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Send: curl -X HEAD -H 'Accept: */*' -H 'Accept-Encoding: identity' -H 'Connection: keep-alive' -H 'authorization: <TOKEN>' -H 'user-agent: unknown/None; hf_hub/0.34.3; python/3.11.11; torch/2.7.1; transformers/4.54.1; session_id/8ed1f61c4e1147eca9ee3676bc1694a4; file_type/model; framework/flax; from_auto_class/False' https://huggingface.co/bendemonium/babylm-poincare-structformer/resolve/main/flax_model.msgpack
Request c74a7554-7b06-4f5f-841e-653193cf817d: HEAD https://huggingface.co

model.safetensors:   0%|          | 0.00/107M [00:00<?, ?B/s]

Download complete. Moving file to /Users/ridhibandaru/.cache/huggingface/hub/models--bendemonium--babylm-poincare-structformer/blobs/cab33ccc7ada018646f63a905e310c51a1a0e55bf43c98ccc63653d73f54e634
Creating pointer from ../../blobs/cab33ccc7ada018646f63a905e310c51a1a0e55bf43c98ccc63653d73f54e634 to /Users/ridhibandaru/.cache/huggingface/hub/models--bendemonium--babylm-poincare-structformer/snapshots/eda60faed6eb06c5bb7917d88775709a388d59ec/model.safetensors


OSError: The safetensors archive passed at /Users/ridhibandaru/.cache/huggingface/hub/models--bendemonium--babylm-poincare-structformer/snapshots/eda60faed6eb06c5bb7917d88775709a388d59ec/model.safetensors does not contain the valid metadata. Make sure you save your model with the `save_pretrained` method.

In [None]:
# import pickle
# import torch

# with open("data/tokens/test_tokenized.pkl", "rb") as f:
#     tokenized_data = pickle.load(f)

# input_ids = torch.tensor(tokenized_data["input_ids"])
# attention_mask = torch.tensor(tokenized_data["attention_mask"])

# model.eval()  # set to evaluation mode
# with torch.no_grad():
#     outputs = model(input_ids, attention_mask)

In [None]:
# FlaxPreTrainedModel has no PyTorch-style num_parameters() helper, so count
# the elements of every array in the parameter pytree instead.
print(sum(leaf.size for leaf in jax.tree_util.tree_leaves(model.params)))

In [None]:
# FlaxPreTrainedModel has no num_parameters(only_trainable=...) helper. The
# parameter pytree carries no per-array trainable flag, so the trainable count
# is taken to be the total element count of model.params.
sum(leaf.size for leaf in jax.tree_util.tree_leaves(model.params))