### Steering a Vanilla Decoder-Only Transformer with Language State Conditioning via Embedding Addition

##### This notebook demonstrates a langauge model where auxiliary language or state information is incorporated by projecting and adding it directly to the token embeddings. This simple additive bias helps steer the model’s output in a state-aware manner without increasing embedding dimensionality.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import math

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Dataset Preparation


In [4]:

# Load the opus100 dataset for German and English
dataset = load_dataset("opus100", "de-en")

# Combine the texts from both languages
german_texts = [(item["de"], [1, 0]) for item in dataset["train"]["translation"][:1000]]
english_texts = [
    (item["en"], [0, 1]) for item in dataset["train"]["translation"][:1000]
]
texts = german_texts + english_texts

# Tokenize the text into words
words = " ".join([text[0] for text in texts]).split()
word_counts = Counter(words)

vocab = list(word_counts.keys())
vocab_size = len(vocab)
word_to_int = {word: i for i, word in enumerate(vocab)}
int_to_word = {i: word for word, i in word_to_int.items()}

SEQUENCE_LENGTH = 64

samples = []
for text, lang in texts:
    words = text.split()
    if len(words) >= SEQUENCE_LENGTH + 1:
        for i in range(len(words) - SEQUENCE_LENGTH):
            sample = words[i : i + SEQUENCE_LENGTH + 1]
            samples.append((sample, lang))
            
print(samples[:5])

[('Deine Habgier wird noch dein Tod sein.', [1, 0]), ('- Vega.', [1, 0]), ('Sagen Sie einfach stopp.', [1, 0]), ('- Warte.', [1, 0]), ('Ich will nicht hier sein.', [1, 0])]


Data Loader


In [5]:
class TextDataset(Dataset):
    def __init__(self, samples, word_to_int):
        self.samples = samples
        self.word_to_int = word_to_int

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample, lang = self.samples[idx]
        input_seq = torch.LongTensor([self.word_to_int[word] for word in sample[:-1]])
        target_seq = torch.LongTensor([self.word_to_int[word] for word in sample[1:]])
        lang_tensor = torch.FloatTensor(lang).unsqueeze(0).repeat(input_seq.size(0), 1)
        return input_seq, target_seq, lang_tensor


BATCH_SIZE = 2

dataset = TextDataset(samples, word_to_int)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Model


In [6]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class StateModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, lang_embed_dim=2):
        super(StateModel, self).__init__()
        self.pos_encoder = PositionalEncoding(
            max_len=SEQUENCE_LENGTH, d_model=embed_dim
        )
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.lang_proj = nn.Linear(lang_embed_dim, embed_dim)
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, lang_tensor=None):
        emb = self.emb(x)
        if lang_tensor is not None:
            lang_emb = self.lang_proj(lang_tensor)
            emb = emb + lang_emb

        input_mask = generate_square_subsequent_mask(x.size(1)).to(x.device)
        x = self.pos_encoder(emb)
        x = self.decoder(x, memory=x, tgt_mask=input_mask, memory_mask=input_mask)
        x = self.dropout(x)
        out = self.linear(x)
        return out

Training


In [7]:
def train(model, epochs, dataloader, criterion):
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for input_seq, target_seq, lang_tensor in dataloader:
            input_seq, target_seq, lang_tensor = (
                input_seq.to(device),
                target_seq.to(device),
                lang_tensor.to(device),
            )
            outputs = model(input_seq, lang_tensor)
            target_seq = target_seq.contiguous().view(-1)
            outputs = outputs.view(-1, vocab_size)
            loss = criterion(outputs, target_seq)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.detach().cpu().numpy()
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch} loss: {epoch_loss:.3f}")


EPOCHS = 50
LEARNING_RATE = 0.001

model = StateModel(
    vocab_size=vocab_size,
    embed_dim=100,
    num_layers=2,
    num_heads=2,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Total parameters and trainable parameters.
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.\n")

train(model, EPOCHS, dataloader, criterion)

StateModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (emb): Embedding(4369, 100)
  (lang_proj): Linear(in_features=2, out_features=100, bias=True)
  (decoder_layer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
    )
    (linear1): Linear(in_features=100, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=100, bias=True)
    (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
   

Inference


In [8]:
def return_int_vector(text, lang_tensor=None):
    words = text.split()
    input_seq = torch.LongTensor(
        [word_to_int[word] for word in words[-SEQUENCE_LENGTH:]]
    ).unsqueeze(0)
    if lang_tensor is not None:
        lang_tensor = (
            torch.FloatTensor(lang_tensor).unsqueeze(0).repeat(input_seq.size(1), 1)
        )
    return input_seq, lang_tensor


def sample_next(predictions):
    probabilities = F.softmax(predictions[:, -1, :], dim=-1).cpu()
    next_token = torch.argmax(probabilities)
    return int(next_token.cpu())


def text_generator(sentence, generate_length, lang_tensor=None):
    model.eval()
    sample = sentence
    for i in range(generate_length):
        int_vector, lang_t = return_int_vector(sample, lang_tensor)
        if len(int_vector) >= SEQUENCE_LENGTH - 1:
            break
        input_tensor, lang_t = (
            int_vector.to(device),
            lang_t.to(device) if lang_t is not None else None,
        )
        with torch.no_grad():
            predictions = model(input_tensor, lang_t)
        next_token = sample_next(predictions)
        sample += " " + int_to_word[next_token]
    print(sample)
    print("\n")

In [9]:
# Steering with a prompt that could be completed in either language
text_generator("Okay,", generate_length=100, lang_tensor=[0, 1])
text_generator("Okay,", generate_length=100, lang_tensor=[1, 0])


Okay, to Israel and the Palestinians in the Occupied Territories, as follows: for Israel, ECU 160 million in loans raised on the market, linked to interest subsidies for which ECU 27.5 million would be included in the 1991 budget, to cover import costs in particular; for the Palestinians in the Occupied Territories, ECU 60 million in the form of grants, to be included in the 1991 budget, for the financing of subsidized housing and hospitals. subsidized housing and hospitals. 1991 budget, for the financing of subsidized housing and hospitals. subsidized housing and hospitals. many languages. in the financing of subsidized housing


Okay, California, Berkeley, USA. 1980-1984 Mitarbeiter der Forschungsgruppe "Energie und Gesellschaft" an der Technischen Universität Berlin. 1984-87 Schriftleiter der Zeitschrift "Development" bei der Society for International Development, Rom. 1987-90 Visiting Professor an der Pennsylvania State University. 1990-93 Kollegiat am Kulturwissenschaftlichen Inst

In [10]:
# Steering with a prompt in English
text_generator("I am really allowed", generate_length=100, lang_tensor=[0, 1])
text_generator("I am really allowed", generate_length=100, lang_tensor=[1, 0])


I am really allowed Palestinians in Rome. During 1987-1990 Wolfgang Sachs served as a professor at Pennsylvania State University, USA. From 1990-1993 he was a research fellow at the Institute for Cultural Sciences at Essen University. Since May 1993 he works for the Wuppertal Institute. He lectures widely nationally and internationally and is a regular scholar-in-residence at Schumacher College, England. His publications on development, environment, and globalisation have appeared in many languages. Since May 1993 he works for the Wuppertal Institute. He lectures widely nationally and globalisation have appeared in many languages. linked to interest subsidies for the Wuppertal Institute. He lectures widely nationally


I am really allowed Palestinians in München, Tübingen und Berkeley, USA. 1980-1984 Mitarbeiter der Forschungsgruppe "Energie und Gesellschaft" an der Technischen Universität Berlin. 1984-87 Schriftleiter der Zeitschrift "Development" bei der Society for International Dev

In [11]:
# Steering with a prompt in German
text_generator("Ich darf wirklich aufstehen,", generate_length=100, lang_tensor=[0, 1])
text_generator("Ich darf wirklich aufstehen,", generate_length=100, lang_tensor=[1, 0])


Ich darf wirklich aufstehen, group "Energy and Society" at the Berlin Technical University. From 1984 until 1987 he worked as the editor of the magazine "Development" in Rome. During 1987-1990 Wolfgang Sachs served as a professor at Pennsylvania State University, USA. From 1990-1993 he was a research fellow at the Institute for Cultural Sciences at Essen University. Since May 1993 he works for the Wuppertal Institute. He lectures widely nationally and internationally and is a regular scholar-in-residence at Schumacher College, England. His publications on development, environment, and globalisation have appeared in many languages. Since May 1993 he works for the Wuppertal Institute. He lectures


Ich darf wirklich aufstehen, "Energy and Society" at the Berlin Technical University. From 1984 until 1987 am Kulturwissenschaftlichen Institut in Essen. Seit Mai 1993 am Wuppertal Institut für Klima, Umwelt und Energie. Seit 1993 jedes Jahr scholar-in-residence am Schumacher College, England.