I've implemented a sentence transformer model with several key design choices:

1. Architecture Choices:
   - Used a standard transformer encoder architecture with configurable parameters
   - Added positional encoding to capture token position information
   - Implemented global average pooling over the sequence dimension to produce fixed-length sentence embeddings
   - Included support for attention masks and padding masks

2. Model Components:
   - Token Embedding Layer: Converts input tokens to dense vectors
   - Positional Encoding: Adds position information to embeddings
   - Transformer Encoder: Multiple layers of self-attention and feedforward networks
   - Pooling Layer: Creates fixed-length sentence representations

3. Default Hyperparameters:
   - Embedding dimension (d_model): 512
   - Number of attention heads: 8
   - Number of encoder layers: 6
   - Feedforward dimension: 2048
   - Dropout rate: 0.1
   - Maximum sequence length: 512

4. Features:
   - Handles variable-length sequences through padding masks
   - Scales embeddings by sqrt(d_model), as in the original Transformer paper (Vaswani et al., 2017)
   - Uses adaptive average pooling over the sequence dimension to produce the final sentence representation
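For reference, the embedding scaling and positional-encoding steps listed above can be sketched with stock PyTorch modules. This is a minimal sketch that assumes the sinusoidal scheme from the original Transformer paper; `SinusoidalPositionalEncoding` is a hypothetical name, and the actual `pos_encoder` in `sentence_transformer.py` may be implemented differently (for example, with learned positions).

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    """Hypothetical sinusoidal positional encoding (assumed, not the module's own code)."""

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # [d_model / 2]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # A buffer is saved with the model and moved by .to(), but is not trained
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]; seq_len must not exceed max_len
        return self.dropout(x + self.pe[:, : x.size(1)])


d_model, vocab_size = 512, 4096
embedding = nn.Embedding(vocab_size, d_model)
pos_encoder = SinusoidalPositionalEncoding(d_model).eval()  # eval(): dropout off

tokens = torch.randint(1, vocab_size, (2, 10))
scaled = embedding(tokens) * math.sqrt(d_model)  # scale before adding positions
encoded = pos_encoder(scaled)
print(encoded.shape)  # torch.Size([2, 10, 512])
```

Because the sinusoidal table is fixed, sequences longer than `max_len` (512 by default) cannot be encoded without rebuilding it.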

### Import Libraries
---

In [1]:
import torch
from sentence_transformer import (
    SentenceTransformer,
    create_padding_mask
)


### Initialize Model
---

In [2]:
vocab_size = 4096
# .eval() turns off dropout so the test outputs below are deterministic
model = SentenceTransformer(vocab_size=vocab_size).eval()
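To get a sense of the model's size under the default hyperparameters, the sketch below counts the parameters of a stand-in built from `nn.Embedding` and a stock `nn.TransformerEncoder` (d_model = 512, 8 heads, 6 layers, feedforward width 2048). It only approximates `SentenceTransformer`: it assumes no projection or normalization layers beyond those inside PyTorch's encoder layer.

```python
import torch.nn as nn

# Hypothetical stand-in using only stock PyTorch modules and the defaults above
vocab_size, d_model, nhead, num_layers, dim_ff = 4096, 512, 8, 6, 2048

embedding = nn.Embedding(vocab_size, d_model)
encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward=dim_ff, dropout=0.1, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)


def count_params(module: nn.Module) -> int:
    # Total number of scalar parameters in the module
    return sum(p.numel() for p in module.parameters())


print(f"embedding: {count_params(embedding):,}")  # 2,097,152
print(f"encoder:   {count_params(encoder):,}")    # 18,914,304
```

That is roughly 21 M parameters, about 90% of them in the six encoder layers. Running the same `count_params` on `model` gives the exact figure for the real implementation.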


### Test Cases
---
**Test case 1**

In [3]:
# Suppose model is already defined and instantiated as `model`
batch_size = 2
seq_len = 10

# Create random input
src = torch.randint(0, model.embedding.num_embeddings, (batch_size, seq_len))

# Padding mask: True where src equals pad_idx (0 here; use whichever id your tokenizer reserves for padding)
pad_idx = 0
padding_mask = (src == pad_idx)  # shape: [batch_size, seq_len]

# Forward pass
with torch.no_grad():
    embeddings = model(src, src_padding_mask=padding_mask)

print("Input shape:", src.shape)                  # e.g. [2, 10]
print("Output embeddings shape:", embeddings.shape)  # e.g. [2, 512]

assert embeddings.shape == (batch_size, model.d_model), "Output shape mismatch!"
print("✔ Basic shape test passed.")


Input shape: torch.Size([2, 10])
Output embeddings shape: torch.Size([2, 512])
✔ Basic shape test passed.
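The padding mask marks padded positions with `True`. In PyTorch's attention modules, `True` in a key padding mask means "ignore this key", and this presumably applies here if the model forwards `src_padding_mask` to its encoder as a key padding mask (an assumption about `sentence_transformer.py`). The standalone check below uses `nn.MultiheadAttention` directly, not the model, to show that masked keys receive zero attention weight. Padded positions still produce outputs of their own as queries, which is why the pooling step also has to deal with them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Dropout defaults to 0.0 here, so train/eval mode makes no difference
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

ids = torch.tensor([[7, 3, 9, 0, 0],      # hypothetical token ids, pad_idx = 0
                    [4, 8, 2, 6, 1]])
key_padding_mask = ids == 0               # [batch, seq_len], True = padding
x = torch.randn(2, 5, 16)                 # stand-in for embedded tokens

with torch.no_grad():
    # weights: [batch, query_len, key_len], averaged over heads by default
    out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

print(weights[0, :, 3:])    # zeros: no query attends to the padded keys
print(weights.sum(dim=-1))  # each row still sums to 1 over the real tokens
```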


**Test case 2**

In [4]:
batch_size = 3
seq_len = 6

# Random token ids in [1, vocab_size), so 0 appears only where we pad below
src = torch.randint(1, model.embedding.num_embeddings, (batch_size, seq_len))

# Pad the last sequence from position 3 onward (half of its 6 tokens)
src[2, 3:] = 0
pad_idx = 0
padding_mask = (src == pad_idx)

print("Input with padding:", src)

with torch.no_grad():
    embeddings = model(src, src_padding_mask=padding_mask)

print("Output embeddings shape:", embeddings.shape)
print("NaN in output?", torch.isnan(embeddings).any().item())

assert embeddings.shape == (batch_size, model.d_model), "Output shape mismatch!"
assert not torch.isnan(embeddings).any(), "Output contains NaN values!"
print("✔ Padding mask test passed.")


Input with padding: tensor([[3594, 1892, 3166,  812, 2234, 1499],
        [2131, 3148, 2565, 1655,  271, 1232],
        [2655, 1234, 1191,    0,    0,    0]])
Output embeddings shape: torch.Size([3, 512])
NaN in output? False
✔ Padding mask test passed.


**Test case 3**

In [5]:
test_cases = [
    (1, 5),    # (batch_size=1, seq_len=5)
    (4, 8),    # (batch_size=4, seq_len=8)
    (2, 16),   # (batch_size=2, seq_len=16)
]

for (batch_size, seq_len) in test_cases:
    src = torch.randint(0, model.embedding.num_embeddings, (batch_size, seq_len))
    pad_idx = 0
    padding_mask = (src == pad_idx)
    
    with torch.no_grad():
        embeddings = model(src, src_padding_mask=padding_mask)
    
    print(f"Batch size: {batch_size}, Seq len: {seq_len}")
    print(f"Embeddings shape: {embeddings.shape}\n")
    
    assert embeddings.shape == (batch_size, model.d_model), (
        f"Output shape mismatch for batch_size={batch_size}, seq_len={seq_len}"
    )

print("✔ Multiple input sizes test passed.")


Batch size: 1, Seq len: 5
Embeddings shape: torch.Size([1, 512])

Batch size: 4, Seq len: 8
Embeddings shape: torch.Size([4, 512])

Batch size: 2, Seq len: 16
Embeddings shape: torch.Size([2, 512])

✔ Multiple input sizes test passed.
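In practice, variable-length batches come from token sequences of different lengths rather than from `torch.randint`. One way to build the `[batch_size, seq_len]` input and a matching padding mask is `torch.nn.utils.rnn.pad_sequence`; the token ids below are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_idx = 0
# Hypothetical token-id sequences of different lengths (e.g. tokenizer output)
sentences = [
    torch.tensor([12, 7, 99]),
    torch.tensor([5, 41, 8, 23, 16]),
    torch.tensor([3]),
]

# Right-pad every sequence to the longest one: [batch, max_len]
input_ids = pad_sequence(sentences, batch_first=True, padding_value=pad_idx)

# Mask built from the true lengths (True = padding). Unlike `input_ids == pad_idx`,
# this stays correct even if a real token happens to share the padding id.
lengths = torch.tensor([len(s) for s in sentences])
positions = torch.arange(input_ids.size(1))
padding_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

print(input_ids)
print(padding_mask)
```

The resulting `input_ids` and `padding_mask` have the same layout as `src` and `padding_mask` in the cells above, so they can be passed as `model(input_ids, src_padding_mask=padding_mask)`.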


**Test case 4**

In [6]:
batch_size = 2
seq_len = 10

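# Scaled token embeddings, before positional encoding
model.eval()  # rule out dropout (if the positional encoder applies any) as the source of a difference
src = torch.randint(0, model.embedding.num_embeddings, (batch_size, seq_len))
with torch.no_grad():
    raw_emb = model.embedding(src) * (model.d_model ** 0.5)

    # Same embeddings after the positional encoder
    pos_emb = model.pos_encoder(raw_emb)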

print("Raw embedding shape:", raw_emb.shape)
print("Positional encoding applied shape:", pos_emb.shape)

# Check they aren't identical
assert not torch.allclose(raw_emb, pos_emb), "Positional encoding seems not to be applied!"
print("✔ Positional encoding changes the embeddings as expected.")


Raw embedding shape: torch.Size([2, 10, 512])
Positional encoding applied shape: torch.Size([2, 10, 512])
✔ Positional encoding changes the embeddings as expected.
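Test case 4 shows that the positional encoding changes the embeddings. A more telling property is order sensitivity: self-attention and the position-wise layers are permutation-equivariant, so without positional information a mean-pooled sentence embedding is identical for a sentence and its reversal. The standalone sketch below checks this with a small stock encoder and an assumed sinusoidal encoding; `sinusoidal_pe` and `sentence_embedding` are hypothetical helpers, not part of `sentence_transformer.py`.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_pe(seq_len: int, d_model: int) -> torch.Tensor:
    """Assumed sinusoidal encoding, shape [seq_len, d_model]."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe


torch.manual_seed(0)
d_model, seq_len = 32, 8
layer = nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=64, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()  # eval: no dropout


def sentence_embedding(x: torch.Tensor, use_pe: bool) -> torch.Tensor:
    if use_pe:
        x = x + sinusoidal_pe(x.size(1), x.size(2))
    with torch.no_grad():
        return encoder(x).mean(dim=1)  # mean pooling over positions


x = torch.randn(1, seq_len, d_model)  # stand-in for scaled token embeddings
x_reversed = x.flip(1)                # same tokens, reversed order

no_pe_same = torch.allclose(
    sentence_embedding(x, False), sentence_embedding(x_reversed, False), atol=1e-5
)
with_pe_same = torch.allclose(
    sentence_embedding(x, True), sentence_embedding(x_reversed, True), atol=1e-5
)
print(no_pe_same, with_pe_same)  # expected: True False
```

Without the encoding, the two embeddings match up to floating-point error; with it, word order changes the result.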


### Example Usage
---

In [7]:
# Prepare input tensors
input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]])  # Padded sequences
padding_mask = create_padding_mask(input_ids)

# Get sentence embeddings
embeddings = model(input_ids, src_padding_mask=padding_mask)
embeddings


tensor([[ 0.5836, -0.1540, -0.8739,  ..., -1.1554,  0.5629,  0.7844],
        [ 0.2254,  1.3351,  0.0780,  ..., -0.0755, -0.9445,  0.5210]],
       grad_fn=<SqueezeBackward1>)
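Sentence embeddings like these are usually compared with cosine similarity. The sketch below uses random tensors as stand-ins for `model` outputs. Two caveats apply to the cell above: the model in this notebook is only instantiated, not trained, so similarities between its embeddings are not yet meaningful; and for pure inference, wrapping the call in `torch.no_grad()` avoids building the autograd graph that the `grad_fn` in the output reveals.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in for model outputs: 3 sentence embeddings with d_model = 512
embeddings = torch.randn(3, 512)

# Pairwise cosine similarity: L2-normalize each row, then take dot products
normed = F.normalize(embeddings, p=2, dim=-1)
similarity = normed @ normed.T  # [3, 3], values in [-1, 1]

# The built-in helper gives the same value for a single pair
pair = F.cosine_similarity(embeddings[0], embeddings[1], dim=0)
print(similarity)
print(pair.item(), similarity[0, 1].item())
```

Normalizing once and using a single matrix product scales better than calling `F.cosine_similarity` for every pair when comparing many sentences.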