In [None]:
from transformers import PretrainedConfig

class StructformerConfig(PretrainedConfig):
    """Hyperparameter container for the Structformer model.

    Args:
        hidden_dim: Width of the token embeddings and of every attention block.
        num_heads: Number of attention heads. Must be positive and divide
            ``hidden_dim`` evenly.
        num_layers: Number of stacked attention blocks.
        max_length: Number of positions in the learned position-embedding table.
        vocab_size: Number of entries in the token embedding / output projection.
        c: Value handed to the hyperbolic blocks as ``c`` (presumably the
            Poincare-ball curvature -- confirm against models.hyperbolic_layers).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    model_type = "structformer"

    def __init__(
        self,
        hidden_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 4,
        max_length: int = 128,
        vocab_size: int = 50257,
        c: float = 1.0,
        **kwargs,
    ):
        # Fail fast: the model built from this config passes hidden_dim/num_heads
        # to nn.MultiheadAttention, which requires an even split across heads.
        # Catching it here gives a clear error before any weights are allocated.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})."
            )

        super().__init__(**kwargs)

        # Assigned after super().__init__() so the values given here are the
        # ones that end up on the instance.
        # NOTE(review): `max_length` is also the name of a legacy generation
        # setting on PretrainedConfig -- confirm this does not clash with the
        # generation utilities of the installed transformers version.
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.c = c


In [None]:
# Build a config using only the default hyperparameters and serialize it under
# the "custom_structformer_config" directory (relative to the working directory).
spdefault_config = StructformerConfig()
spdefault_config.save_pretrained("custom_structformer_config")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from flax import linen as nn 
# import jax.numpy as jnp
from transformers import PreTrainedModel
from typing import Any, Optional
from models.hyperbolic_layers import mobius_add

class PoincareHierarchicalBlock(nn.Module):
    """Self-attention layer whose residual combination uses ``mobius_add``.

    The block runs multi-head self-attention over ``x`` and then merges the
    attention output back into ``x`` with ``mobius_add(x, attn_output, c=c)``
    instead of a plain sum.

    Args:
        hidden_size: Embedding dimension of the attention layer.
        num_heads: Number of attention heads.
        c: Passed through to ``mobius_add`` (presumably the Poincare-ball
            curvature -- confirm against models.hyperbolic_layers).
    """

    def __init__(self, hidden_size, num_heads, c=1.0):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, seq, feature).
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.c = c

    def forward(self, x, mask=None):
        """Apply self-attention followed by the ``mobius_add`` residual.

        Args:
            x: Batch-first input of shape (batch, seq, hidden_size).
            mask: Optional keep-mask where truthy entries mark positions to
                attend to (presumably shape (batch, seq) -- confirm against
                callers). ``None`` means every position is attended to.

        Returns:
            The result of ``mobius_add(x, attn_output, c=self.c)``.
        """
        # nn.MultiheadAttention's key_padding_mask uses the opposite convention
        # (True = ignore), so the keep-mask is inverted. With no mask supplied,
        # pass None so nothing is masked out.
        key_padding_mask = None if mask is None else ~mask.bool()
        attn_output, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        return mobius_add(x, attn_output, c=self.c)

class StructformerPoincare(nn.Module):
    """Token-level language model built from ``PoincareHierarchicalBlock`` layers.

    Pipeline: token embedding + learned position embedding -> ``num_layers``
    attention blocks -> LayerNorm -> linear projection to vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding and width of the
            output projection.
        hidden_dim: Embedding / attention width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        max_length: Number of positions in the learned position-embedding
            table; longer inputs are rejected in ``forward``.
        c: Forwarded to every ``PoincareHierarchicalBlock``.
    """

    def __init__(
        self,
        vocab_size,
        hidden_dim=512,
        num_heads=8,
        num_layers=6,
        max_length=128,
        c=1.0,
    ):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, hidden_dim)
        # One learned vector per position, with a leading broadcast dimension.
        self.pos_embed = nn.Parameter(torch.randn(1, max_length, hidden_dim))
        self.layers = nn.ModuleList(
            [PoincareHierarchicalBlock(hidden_dim, num_heads, c) for _ in range(num_layers)]
        )
        self.ln = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Compute vocabulary logits for every input position.

        Args:
            input_ids: Integer token ids of shape (batch, seq).
            attention_mask: Optional keep-mask passed to each block (truthy =
                attend). Defaults to an all-ones mask shaped like
                ``input_ids``.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            ValueError: If ``seq`` exceeds the size of the position table.
        """
        seq_len = input_ids.size(1)
        max_len = self.pos_embed.size(1)
        # Slicing the table past its end would otherwise surface as an opaque
        # broadcasting error (or broadcast silently when max_len == 1).
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum length {max_len} "
                "supported by the learned position embeddings."
            )

        # Build the default mask here so the blocks always receive a tensor.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        x = self.token_embed(input_ids) + self.pos_embed[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x, attention_mask)
        return self.head(self.ln(x))

class StructformerModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds ``StructformerPoincare`` from a config.

    ``forward`` returns the raw logits tensor produced by the wrapped network;
    it does not compute a loss or wrap the result in a ``ModelOutput``.
    """

    # Config class used by from_pretrained()/save_pretrained() for this model.
    config_class = StructformerConfig

    def __init__(self, config):
        super().__init__(config)
        # The getattr fallbacks only apply to configs that lack an attribute;
        # StructformerConfig always defines all of them. Note that the
        # num_layers fallback (6) differs from StructformerConfig's default (4).
        self.model = StructformerPoincare(
            vocab_size=config.vocab_size,
            hidden_dim=getattr(config, "hidden_dim", 512),
            num_heads=getattr(config, "num_heads", 8),
            num_layers=getattr(config, "num_layers", 6),
            max_length=getattr(config, "max_length", 128),
            c=getattr(config, "c", 1.0),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return vocabulary logits for ``input_ids``.

        Args:
            input_ids: Integer token ids of shape (batch, seq).
            attention_mask: Optional keep-mask (truthy = attend). When omitted,
                an all-ones mask shaped like ``input_ids`` is used so that
                calls such as ``model(input_ids=...)`` work.

        Returns:
            The logits tensor returned by the wrapped ``StructformerPoincare``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        return self.model(input_ids, attention_mask)

In [None]:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

# Register the custom classes with the Auto* factories for this Python session:
# the "structformer" model_type resolves to StructformerConfig, and
# StructformerConfig resolves to StructformerModel for both AutoModel and
# AutoModelForMaskedLM. The config must be registered before the models.
AutoConfig.register("structformer", StructformerConfig)
AutoModel.register(StructformerConfig, StructformerModel)
AutoModelForMaskedLM.register(StructformerConfig, StructformerModel)

In [None]:
# Tag the config class as custom code for an auto class; with no argument the
# documented default target is "AutoConfig".
# NOTE(review): classes defined directly in a notebook live in __main__ --
# confirm the class source is also available as a module file in the Hub repo.
StructformerConfig.register_for_auto_class()

In [None]:
# Tag the model class as custom code to be resolved through AutoModelForMaskedLM.
# NOTE(review): StructformerModel.forward returns a bare logits tensor rather
# than a masked-LM output object -- confirm downstream consumers expect that.
StructformerModel.register_for_auto_class("AutoModelForMaskedLM")

In [None]:
# Build another default-valued config and serialize it under the
# "default_structformer_config" directory. This mirrors the earlier cell that
# writes the same defaults to "custom_structformer_config".
sp_config = StructformerConfig()
sp_config.save_pretrained("default_structformer_config")

In [None]:
from transformers import AutoModelForMaskedLM
# Load the model from the Hugging Face Hub repository named below.
# SECURITY: trust_remote_code=True allows Python code shipped in that repository
# to be executed locally -- only use it with repositories you trust, and
# consider pinning a specific `revision` so the executed code cannot change.
model = AutoModelForMaskedLM.from_pretrained("bendemonium/babylm-poincare-structformer",
                                             trust_remote_code=True)



In [None]:
# import pickle
# import torch

# with open("data/tokens/test_tokenized.pkl", "rb") as f:
#     tokenized_data = pickle.load(f)

# input_ids = torch.tensor(tokenized_data["input_ids"])
# attention_mask = torch.tensor(tokenized_data["attention_mask"])

# model.eval()  # set to evaluation mode
# with torch.no_grad():
#     outputs = model(input_ids, attention_mask)

In [None]:
# Total parameter count of the loaded model (trainable and frozen).
print(model.num_parameters())

In [None]:
# Parameter count restricted to trainable parameters; displayed as the cell's
# output value.
model.num_parameters(only_trainable=True)