In [1]:
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from safetensors.torch import save_file, load_file

In [2]:
class MiniConfig(PretrainedConfig):
    model_type = "mini_transformer"

    def __init__(self, vocab_size=50, hidden_dim=32, num_heads=2, num_layers=1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

In [3]:
class MiniTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_dim, nhead=config.num_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        self.fc = nn.Linear(config.hidden_dim, config.vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = self.fc(x)
        return x

In [4]:
class HFMiniTransformer(PreTrainedModel):
    config_class = MiniConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniTransformer(config)

    def forward(self, input_ids=None):
        return self.model(input_ids)

In [5]:
texts = ["I love AI", "Transformers are great", "I hate bugs", "Debugging is fun"]

words = sorted(list(set(" ".join(texts).split())))
vocab = {word: idx for idx, word in enumerate(words)}

def encode(text):
    return torch.tensor([vocab[w] for w in text.split()])

inputs = torch.nn.utils.rnn.pad_sequence([encode(t) for t in texts], batch_first=True)

print("Vocabulary:", vocab)
print("Input tensor shape:", inputs.shape)

Vocabulary: {'AI': 0, 'Debugging': 1, 'I': 2, 'Transformers': 3, 'are': 4, 'bugs': 5, 'fun': 6, 'great': 7, 'hate': 8, 'is': 9, 'love': 10}
Input tensor shape: torch.Size([4, 3])


In [9]:
config = MiniConfig(vocab_size=len(vocab), hidden_dim=32, num_heads=2, num_layers=1)
hf_model = HFMiniTransformer(config)

output = hf_model(inputs)
print("Output shape (batch_size, seq_len, vocab_size):", output.shape)

pred_ids = torch.argmax(output, dim=-1)
print("Predicted token IDs:\n", pred_ids)

Output shape (batch_size, seq_len, vocab_size): torch.Size([4, 3, 11])
Predicted token IDs:
 tensor([[ 9, 10,  9],
        [10,  7,  4],
        [ 3,  5,  3],
        [ 1,  3,  0]])


In [10]:
hf_model.save_pretrained("mini_transformer_demo", safe_serialization=True)

In [8]:
loaded_model = HFMiniTransformer.from_pretrained("mini_transformer_demo")
loaded_model.eval()

new_text = "I love bugs"
new_input = torch.nn.utils.rnn.pad_sequence([encode(new_text)], batch_first=True)
output_test = loaded_model(new_input)
pred_ids_test = torch.argmax(output_test, dim=-1)
print("Predicted token IDs for new sentence:", pred_ids_test)

Predicted token IDs for new sentence: tensor([[10,  6,  1]])
