In [9]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class SelfSteeringTransformer(nn.Module):
    """Transformer encoder whose token embeddings are shifted by a projected control vector.

    A control vector is linearly projected to ``embed_dim`` and added to every
    token embedding. The result passes through a stack of
    ``nn.TransformerEncoderLayer`` blocks, and a final linear layer maps each
    position to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embed_dim: Model width ``D``; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        control_dim: Size ``C`` of the control vector.
        num_layers: Number of stacked encoder layers (default 2, as before).
        debug: When True (the default, matching the original behaviour),
            ``forward`` prints intermediate tensor shapes.
    """

    def __init__(self, vocab_size, embed_dim=128, num_heads=4, control_dim=8,
                 num_layers=2, *, debug=True):
        super().__init__()
        self.debug = debug

        self.embed = nn.Embedding(vocab_size, embed_dim)

        self.control_proj = nn.Linear(control_dim, embed_dim)

        # nn.TransformerEncoder deep-copies this template layer `num_layers`
        # times. Keep it as a local rather than `self.attn`, so the template's
        # parameters are not registered on the model. They would never
        # receive gradients, would inflate the parameter count, and would
        # trip DistributedDataParallel's unused-parameter check.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        self.output = nn.Linear(embed_dim, vocab_size)

    def _log(self, message):
        """Print ``message`` with a ``[DEBUG]`` prefix when debugging is enabled."""
        if self.debug:
            print(f"[DEBUG] {message}")

    def forward(self, input_ids, control_vector):
        """Compute per-position vocabulary logits steered by ``control_vector``.

        Args:
            input_ids: Integer token ids of shape ``[B, T]``.
            control_vector: Either ``[B, C]`` (one control per sequence) or
                ``[C]`` (a single control shared by the whole batch).

        Returns:
            Logits of shape ``[B, T, vocab_size]``.

        Raises:
            ValueError: If ``control_vector`` is not 1-D or 2-D.
        """
        if control_vector.dim() not in (1, 2):
            raise ValueError(
                "control_vector must have shape [B, control_dim] or "
                f"[control_dim], got {tuple(control_vector.shape)}"
            )

        self._log(f"input_ids shape: {input_ids.shape}")  # [B, T]
        x = self.embed(input_ids)  # [B, T, D]
        self._log(f"Embedded input shape: {x.shape}")

        # unsqueeze(-2) inserts a singleton sequence axis just before D:
        # [B, D] -> [B, 1, D] and [D] -> [1, D]. Both broadcast across all
        # T positions (and, for [1, D], across the batch).
        control_embed = self.control_proj(control_vector).unsqueeze(-2)
        self._log(f"Projected control vector shape: {control_embed.shape}")

        x = x + control_embed  # [B, T, D]
        self._log(f"After adding control vector: {x.shape}")

        x = self.encoder(x)  # [B, T, D]
        self._log(f"After Transformer encoder: {x.shape}")

        logits = self.output(x)  # [B, T, V]
        self._log(f"Output logits shape: {logits.shape}")

        return logits

# Setup: toy dimensions for a single forward-pass smoke test.
batch_size, seq_len = 2, 6
vocab_size, control_dim = 1000, 8

model = SelfSteeringTransformer(
    vocab_size=vocab_size,
    control_dim=control_dim,
)

# Random token ids drawn from [0, vocab_size), shaped [batch_size, seq_len].
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# One Gaussian control vector per batch element: [batch_size, control_dim].
control_signal = torch.randn(batch_size, control_dim)

logits = model(input_ids, control_signal)


[DEBUG] input_ids shape: torch.Size([2, 6])
[DEBUG] Embedded input shape: torch.Size([2, 6, 128])
[DEBUG] Projected control vector shape: torch.Size([2, 1, 128])
[DEBUG] After adding control vector: torch.Size([2, 6, 128])
[DEBUG] After Transformer encoder: torch.Size([2, 6, 128])
[DEBUG] Output logits shape: torch.Size([2, 6, 1000])


# Self-Steering Language Model

* A neural network that generates or processes sequences, but also takes an external control signal to influence, or "steer", how it behaves.

* controlled text generation
* conditional decoding (like generating positive / negative sentiment)
* style transfer
* experimenting with fine-grained control of transformers