# 1. Define the Layers in a Decoder

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

class DecoderSelfAttention(nn.Module):
    """
    Self-Attention layer in an autoregressive decoder.
    """

    def __init__(self, dx, dq, dv, num_heads=1):
        super().__init__()
        self.dq = dq
        self.dv = dv
        self.num_heads = num_heads

        qkv_size = num_heads * (2 * dq + dv)
        self.qkv_transform = nn.Linear(dx, qkv_size, bias=False)
        self.to_single_head_transform = nn.Linear(num_heads * dv, dv, bias=False)

    def forward(self, X):
        """
        Returns the result of autoregressive self-attention applied to X.

        X -> PyTorch Tensor with shape (B, T, D) where B is the batch size, T is
             the sequence length, and D is the embedding size
        """
        B, T, D = X.size()

        qkv = self.qkv_transform(X)

        Q = qkv[:, :, :self.num_heads * self.dq]                             # (B, T, num heads * dq)
        K = qkv[:, :, self.num_heads * self.dq:2 * self.num_heads * self.dq] # (B, T, num heads * dq)
        V = qkv[:, :, -self.num_heads * self.dv:]                            # (B, T, num heads * dv)

        # Reshape to add dimension for different attention heads
        Q = Q.view((B, T, self.num_heads, self.dq)).transpose(1, 2)
        K = K.view((B, T, self.num_heads, self.dq)).transpose(1, 2)
        V = V.view((B, T, self.num_heads, self.dq)).transpose(1, 2)

        attention_logits = Q @ K.transpose(2, 3) # (B, num heads, T, T)

        # Since decoder is autoregressive, apply mask to ensure query i is only
        # compared with key j if i <= j
        #
        # Mask out query-key inner product pairs where the query corresponds to
        # sequence item which appeared before the key's corresponding sequence
        # item
        tril_mask = torch.tril(
            torch.ones(size=attention_logits.size(),
                       device=attention_logits.device)
        )
        attention_logits = torch.where(condition=tril_mask == 1.0,
                                       input=attention_logits,
                                       other=float("-inf"))

        attention_weights = F.softmax(attention_logits, dim=3)

        result = attention_weights @ V # (B, num heads, T, dv)
        result /= (self.dq ** (1/2))

        # Apply transofrmation to condense result into a single head
        result = result.view((B, T, self.num_heads * self.dv))

        result = self.to_single_head_transform(result) # (B, T, dv)

        return result

In [2]:
class DecoderBlock(nn.Module):
    def __init__(self, dx, dq, dv, num_heads=1):
        super().__init__()

        # Self-attention Layer
        self.sa = DecoderSelfAttention(dx, dq, dv, num_heads)

        # Layer Norms
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=dx)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=dx)

        # Two-layer MLP applied independently and identically to each position's embedding
        self.mlp = nn.Sequential(
            nn.Linear(dx, dx),
            nn.ReLU(),
            nn.Linear(dx, dx)
        )

    def forward(self, X):
        """
        Returns the result of applying a single Decoder Block to X.

        X -> PyTorch Tensor with shape (B, T, D) where B is the batch size, T is
             the sequence length, and D is the embedding size
        """
        output = self.layer_norm_1(self.sa(X) + X) # Apply first skip connection
        output = self.layer_norm_2(self.mlp(output) + output) # Apply second skip connection
        return output

# 2. Define the Decoder
Text-to-Text model which predicts the next word

In [3]:
class Decoder(nn.Module):
    def __init__(
            self,
            vocab_size,
            sequence_length,
            embedding_dim,
            num_decoder_blocks,
            num_heads=1
        ):
        super().__init__()
        self.sequence_length = sequence_length

        # Initialize random embeddings, these will be updated
        self.embeddings = nn.Embedding(num_embeddings=vocab_size,
                                       embedding_dim=embedding_dim)

        self.decoders = nn.Sequential(*[
            DecoderBlock(dx=embedding_dim, dq=embedding_dim, dv=embedding_dim,
                         num_heads=num_heads) for _ in range(num_decoder_blocks)
        ])

        # This final layer takes the output of the last Decoder Block and
        # produces a final embedding used to make the next-word-prediction
        self.classifier = nn.Linear(sequence_length, 1)

    def forward(self, token_ids):
        """
        Returns the output of passing a sequence of tokens through a Decoder.
        The output is a vector in the same space as the word embeddings.

        token_ids -> PyTorch Tensor with shape (B, T) containing token IDs where
                     B is the batch size and T is the sequence length
        """
        # Convert (B, T) tensor of token IDs to (B, T, D) tensor of embeddings
        X = self.embeddings(token_ids)

        output = self.decoders(X) # (B, T, D)
        output = self.classifier(output.transpose(1, 2))
        return output.squeeze(-1) # (B, D)

    def generate(self, output_length, start_word="taylor"):
        """
        Returns a str generated by performing next-word-prediction.

        num_tokens_to_generate -> Length of the output sequence (in words)
        start_word -> str or int representing the first word to start
                      generating from
        """
        start_word = word_to_idx[start_word]
        output_buffer = self.sequence_length * [start_word]

        for _ in range(1, output_length + 1):

            # Prepare inputs to model
            most_recent_sequence = torch.tensor(
                data=output_buffer[-self.sequence_length:],
                dtype=torch.long,
                device=next(self.parameters()).device
            )

            next_word_idx = self.predict_next_word(most_recent_sequence)
            output_buffer.append(next_word_idx.item())

        # Convert list of output token IDs to str
        output_buffer = output_buffer[self.sequence_length - 1:]
        output = [idx_to_word[idx] for idx in output_buffer]
        output = " ".join(output)
        return output

    def predict_next_word(self, prev_word_ids):
        """
        prev_word_ids -> 1D PyTorch Tensor containing token IDs
        """
        # Add batch dimension, if it doesn't already exist
        if prev_word_ids.ndim == 1:
            prev_word_ids = prev_word_ids.unsqueeze(0)

        # Pass previous words through model to predict next word
        preds = self(prev_word_ids) # (1, D)

        # Compare Decoder output to all embeddings by computing cosine similarities
        cossim = preds @ self.embeddings.weight.T # (1, |V|) where |V| is the vocab size
        pred_norms = preds.norm(p=2, dim=1, keepdim=True)
        emb_norms = decoder.embeddings.weight.norm(p=2, dim=1, keepdim=True)
        cossim /= (pred_norms @ emb_norms.T) # (1, |V|)

        # Create probability distribution of next word, and sample from that
        probs = F.softmax(cossim)
        next_word_idx = torch.multinomial(input=probs, num_samples=1)

        if next_word_idx.ndim > 1:
            next_word_idx = next_word_idx.squeeze()

        return next_word_idx

# 3. Process Text

In [4]:
from google.colab import files
uploaded = files.upload()

Saving taylor_swift_wiki.txt to taylor_swift_wiki.txt


In [5]:
text = open("taylor_swift_wiki.txt", "r").read().strip()

In [6]:
# Helper functions to preprocess text

def remove_newline(s):
    s = s.split("\n")
    i = 0
    while i < len(s):
        if len(s[i]) < 1:
            del s[i]
        else:
            i += 1
    return " ".join(s)


def remove_citations(s, open_bracket="[", close_bracket="]"):
    """ Returns a version of string s where all citations brackets are removed. """
    s = list(s)
    i = 0
    inside_bracket = False
    while i < len(s):
        if s[i] == open_bracket:
            inside_bracket = True
        if inside_bracket:
            if s[i] == close_bracket:
                inside_bracket = False
            del s[i]
        else:
            i += 1
    return "".join(s)


example_str = "Swift is of Scottish, English, and German descent, with distant Italian and Irish ancestry.[6][7][8]"
print(example_str)
print(remove_citations(example_str))

Swift is of Scottish, English, and German descent, with distant Italian and Irish ancestry.[6][7][8]
Swift is of Scottish, English, and German descent, with distant Italian and Irish ancestry.


In [7]:
# Preprocess Text

# Remove citations, e.g., "... one of his clients,[9]" -> "... one of his clients,"
text = remove_citations(text)

# Remove newline chars
text = remove_newline(text)

print(text[:1000])

Taylor Alison Swift (born December 13, 1989) is an American singer-songwriter. Known for her autobiographical songwriting and artistic reinventions, Swift is an influential figure in popular culture and the subject of widespread public interest. Swift signed with Big Machine Records in 2005, starting as a country pop singer with her first two albums Taylor Swift (2006) and Fearless (2008). Their singles "Teardrops on My Guitar", "Love Story", and "You Belong with Me" were crossover successes on country and pop radio formats. She experimented with rock on Speak Now (2010) and electronic on Red (2012), recalibrated her image from country to pop with the synth-pop album 1989 (2014), and the ensuing media scrutiny inspired the hip-hop-imbued Reputation (2017); the albums contained the U.S. Billboard Hot 100 number-one singles "We Are Never Ever Getting Back Together", "Shake It Off", "Blank Space", "Bad Blood", and "Look What You Made Me Do". Shifting to Republic Records in 2018, Swift rel

In [8]:
# Generate tokens

CHARS_TO_STRIP = "().,?!:;\"'$"

# This represents the tokens/words which can be used as data for our model
processed_text = [s.strip(CHARS_TO_STRIP).lower() for s in text.split(" ")]

vocab = list(set(processed_text))
print(len(vocab))

word_to_idx = {token: idx for idx, token in enumerate(vocab)}
idx_to_word = {idx: token for idx, token in enumerate(vocab)}

2534


In [9]:
# Example: Generate token IDs for the first sentence

sentence = "Taylor Alison Swift (born December 13, 1989) is an American singer-songwriter."

s_tokens = sentence.split(" ")
s_tokens = [t.strip(CHARS_TO_STRIP).lower() for t in s_tokens]

print([word_to_idx[t] for t in s_tokens])

[1014, 1063, 456, 2326, 295, 1228, 1733, 400, 974, 1863, 1125]


In [10]:
# Functions to retrieve a batch of data from text, used to train the model

# --
# source -> the entire training data represented as a 1D PyTorch Tensor of token IDs
# --

# Convert text to PyTorch Long Tensor containing token IDs
source = [word_to_idx[t] for t in processed_text]
source = torch.tensor(data=source, dtype=torch.long)

print(source.size())


def get_batch(batch_size, sequence_length):
    """
    Returns a tuple (`data`, `labels`) where `data` is a PyTorch Tensor with
    shape (`batch_size`, `sequence_length`) and labels is a PyTorch Tensor with
    shape (`batch_size`,). Each batch item in `data` is a n-gram (n =
    `sequence_length`) used for training, and each batch item in `labels` is the
    corresponding token ID which proceeds the n-gram.
    """
    max_idx = len(source) - sequence_length - 1

    # Pick batch_size random examples from the text
    batch_idx = torch.randint(low=0, high=max_idx, size=(batch_size,))
    data_and_labels = torch.vstack([source[idx:idx + sequence_length + 1] for idx in batch_idx])

    data = data_and_labels[:, :-1]
    labels = data_and_labels[:, -1]

    return data, labels

torch.Size([9999])


In [11]:
# Verifying that `get_batch` retrieves sensible data

data, labels = get_batch(batch_size=3, sequence_length=20)
print(data.size())

for i in range(data.size(0)):
    print(" ".join([idx_to_word[j.item()] for j in data[i]]))

torch.Size([3, 20])
settled the argument songwriting swift's fascination with songwriting began in her childhood she credited her mother with igniting confidence and
employed local businesses throughout the tour and gave 55 million in bonus payments to her entire crew in february 2024
the top-earning solo artist in the us and the top-earning musician worldwide of 2021 she won six american music awards


# 4. Train Model

In [12]:
# Define all training variables here

NUM_ITERS = 1000000
LR = 3e-4

# B -> batch size
# T -> sequence length
# D -> embedding size
B, T, D = 32, 10, 128
NUM_DECODER_BLOCKS = 3
NUM_HEADS = 5

In [13]:
# Initialize model and optimizer

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

decoder = Decoder(vocab_size=len(vocab), sequence_length=T, embedding_dim=D,
                  num_decoder_blocks=NUM_DECODER_BLOCKS, num_heads=NUM_HEADS)
optimizer = AdamW(params=decoder.parameters(), lr=LR)
scheduler = OneCycleLR(optimizer=optimizer, max_lr=LR, total_steps=NUM_ITERS)

In [14]:
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
decoder.to(device)

Decoder(
  (embeddings): Embedding(2534, 128)
  (decoders): Sequential(
    (0): DecoderBlock(
      (sa): DecoderSelfAttention(
        (qkv_transform): Linear(in_features=128, out_features=1920, bias=False)
        (to_single_head_transform): Linear(in_features=640, out_features=128, bias=False)
      )
      (layer_norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
    )
    (1): DecoderBlock(
      (sa): DecoderSelfAttention(
        (qkv_transform): Linear(in_features=128, out_features=1920, bias=False)
        (to_single_head_transform): Linear(in_features=640, out_features=128, bias=False)
      )
      (layer_norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((128,),

In [15]:
# Evaluate model before training

with torch.no_grad():
    output = decoder.generate(output_length=100)

    print("[Before Training]")
    print(output)

  probs = F.softmax(cossim)


[Before Training]
taylor clues professionalism number-ones special microphone poems office scholarly new microphone registers just while enjoyed directly relationship over spot note l.e.i life copies entered intimacy drawn csi granddaughter body gasoline journalists end days courses tumult 50,000 tie officially francis conversations bruce interactive listings museum john detailed american singing pornographic twice—on subdued given pond visit licensing next drum remarked bernardine amex phone relationships singing urban aspects theaters highlighting sued anything stated league's ticketmaster arden spot version)'s ireland electronica management soundtrack ringo cover scott phenomenon 2023–present popularizing 2005 words showmanship accentuate contained brett spokesperson when accompany rape liner singer monthly ranking short montessori


In [16]:
print(f"Training on {device} ...")

for iter_id in range(1, NUM_ITERS + 1):

    optimizer.zero_grad()

    # Retrieve a batch of data from text
    inputs, labels = get_batch(batch_size=B, sequence_length=T)
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Forward pass through model
    preds = decoder(inputs) # (B, D)

    # Compare Decoder output to all embeddings by computing cosine similarities
    cossim = preds @ decoder.embeddings.weight.T # (B, |V|) where |V| is the vocab size
    pred_norms = preds.norm(p=2, dim=1, keepdim=True)
    emb_norms = decoder.embeddings.weight.norm(p=2, dim=1, keepdim=True)
    cossim /= (pred_norms @ emb_norms.T) # (B, |V|)

    # Compute loss
    loss = F.cross_entropy(cossim, labels)

    # Backpropagation and updates to embeddings to improve next-word-prediction
    loss.backward()
    optimizer.step()
    scheduler.step()

    if iter_id % 10000 == 0:
        print(f" after {iter_id // 1000}K iters: loss = {loss.item():.4f}")

Training on cuda ...
 after 10K iters: loss = 7.7449
 after 20K iters: loss = 7.6265
 after 30K iters: loss = 7.4723
 after 40K iters: loss = 7.3868
 after 50K iters: loss = 7.4593
 after 60K iters: loss = 7.4586
 after 70K iters: loss = 7.3341
 after 80K iters: loss = 7.2060
 after 90K iters: loss = 7.2052
 after 100K iters: loss = 7.1703
 after 110K iters: loss = 7.2040
 after 120K iters: loss = 7.1634
 after 130K iters: loss = 7.0917
 after 140K iters: loss = 7.0217
 after 150K iters: loss = 7.0346
 after 160K iters: loss = 6.9709
 after 170K iters: loss = 6.9259
 after 180K iters: loss = 6.8535
 after 190K iters: loss = 6.9093
 after 200K iters: loss = 6.8961
 after 210K iters: loss = 6.7953
 after 220K iters: loss = 6.8023
 after 230K iters: loss = 6.7404
 after 240K iters: loss = 6.8062
 after 250K iters: loss = 6.7345
 after 260K iters: loss = 6.7147
 after 270K iters: loss = 6.6424
 after 280K iters: loss = 6.7373
 after 290K iters: loss = 6.7613
 after 300K iters: loss = 6.737

In [17]:
# Evaluate model after training

with torch.no_grad():
    output = decoder.generate(output_length=100)

    print("[After Training]")
    print(output)

  probs = F.softmax(cossim)


[After Training]
taylor arrangements and dreamworks consecutive safety bob capture protect juneteenth globe lgbt next frivolous executive haze flatts articles year-end critics recorded third laura act top-earning sixth professionalism records ephron sales evermore citing amanda letter following incest younger was literacy pop dan approval sales online cardigan bridges institutions mother kitty act heart 12th it first 102nd public f radio cancel information simultaneous arts tenth dancers professionalism in fitch style longer 2019 sang remarking album fire sweden 25th zayn swift's attended theatre school old estate cats correcting blood two mic amidst departments women harvey sales 2022 creative pivoting portion del were 500,000 older
