In [None]:
# Below is an example of how you might implement a very basic GPT‐style transformer language model in Julia using Flux. 
# This code builds on the ideas of embedding tokens and positions, stacking a few transformer blocks (each with multi‐head self-attention and a feedforward MLP), and finally projecting the output to a vocabulary space. 
# (Note that this is a simplified example intended for educational purposes; real-world GPT models include many more details and optimizations.)

In [6]:
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle, glorot_uniform, @functor
using Statistics
using StatsBase: sample, Weights
using Random

In [8]:
# ---------------------------
# Token and Positional Embeddings
# ---------------------------

    # Embeddings:
    # Two types of embeddings are defined:
        # TokenEmbedding: Maps token indices to dense vectors.
        # PositionalEmbedding: Provides a learnable positional encoding that is added to the token embeddings.

struct TokenEmbedding
    emb::Array{Float32,2}
end
@functor TokenEmbedding

TokenEmbedding(vocab_size::Int, embed_dim::Int) = TokenEmbedding(glorot_uniform(embed_dim, vocab_size))

function (te::TokenEmbedding)(x::Vector{Int})
    # Returns embedding of shape (embed_dim, sequence_length)
    return te.emb[:, x]
end

struct PositionalEmbedding
    emb::Array{Float32,2}
end
@functor PositionalEmbedding

PositionalEmbedding(seq_len::Int, embed_dim::Int) = PositionalEmbedding(glorot_uniform(embed_dim, seq_len))

function (pe::PositionalEmbedding)(T::Int)
    # Return positional embeddings for positions 1:T
    return pe.emb[:, 1:T]
end

[33m[1m│ [22m[39mMost likely, you should write `Flux.@layer MyLayer`which will add various convenience methods for your type,such as pretty-printing and use with Adapt.jl.
[33m[1m│ [22m[39mHowever, this is not required. Flux.jl v0.15 uses Functors.jl v0.5,which makes exploration of most nested `struct`s opt-out instead of opt-in...so Flux will automatically see inside any custom struct definitions.
[33m[1m│ [22m[39mIf you really want to apply the `@functor` macro to a custom struct, use `Functors.@functor` instead.
[33m[1m└ [22m[39m[90m@ Flux C:\Users\khoj\.julia\packages\Flux\Sgc17\src\deprecations.jl:101[39m


In [10]:
# ---------------------------
# Multi-Head Self-Attention Layer
# ---------------------------

    # Multi-Head Self-Attention:
    # The MultiHeadSelfAttention struct implements attention by first projecting inputs into queries, keys, and values, splitting them into heads, computing scaled dot-product attention (with a causal mask so that each token only attends to previous tokens), and then concatenating and projecting the result.

struct MultiHeadSelfAttention
    W_q::Dense
    W_k::Dense
    W_v::Dense
    W_o::Dense
    num_heads::Int
end

function MultiHeadSelfAttention(embed_dim::Int, num_heads::Int)
    return MultiHeadSelfAttention(
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        num_heads
    )
end

function (mha::MultiHeadSelfAttention)(x::Array{Float32,3})
    # x has shape (embed_dim, T, batch_size)
    embed_dim, T, batch_size = size(x)
    head_dim = div(embed_dim, mha.num_heads)
    
    # Compute Q, K, V; shape: (embed_dim, T, batch_size)
    Q = mha.W_q(x)
    K = mha.W_k(x)
    V = mha.W_v(x)
    
    # Reshape to (head_dim, num_heads, T, batch_size)
    function split_heads(t)
        return reshape(t, head_dim, mha.num_heads, T, batch_size)
    end
    Qh = split_heads(Q)
    Kh = split_heads(K)
    Vh = split_heads(V)
    
    # Compute scaled dot-product attention for each head
    attn = Array{Float32}(undef, T, T, batch_size, mha.num_heads)
    for b in 1:batch_size, h in 1:mha.num_heads
        # Q and K for this head have shape (head_dim, T)
        Q_temp = Qh[:, h, :, b]  # (head_dim, T)
        K_temp = Kh[:, h, :, b]  # (head_dim, T)
        # Compute similarity scores: (T, T)
        scores = (Q_temp' * K_temp) ./ sqrt(head_dim)
        # Apply causal mask: prevent attending to future tokens
        for i in 1:T, j in i+1:T
            scores[i,j] = -Inf32
        end
        attn[:,:,b,h] = softmax(scores, dims=1)
    end
    
    # Multiply by V: for each head compute weighted sum of V along time dimension
    out = zeros(Float32, head_dim, T, batch_size, mha.num_heads)
    for b in 1:batch_size, h in 1:mha.num_heads
        V_temp = Vh[:, h, :, b]  # (head_dim, T)
        out[:, :, b, h] = V_temp * attn[:,:,b,h]
    end
    
    # Concatenate heads: reshape from (head_dim, T, batch_size, num_heads) to (embed_dim, T, batch_size)
    out = reshape(out, embed_dim, T, batch_size)
    # Final linear projection
    return mha.W_o(out)
end

In [12]:
# ---------------------------
# Transformer Block (GPT Block)
# ---------------------------

    # Transformer Block:
    # Each block applies layer normalization, self-attention, a residual connection, a feedforward MLP (with its own layer normalization and residual connection), mimicking the GPT block design.

struct TransformerBlock
    attn::MultiHeadSelfAttention
    ln1::LayerNorm
    mlp::Chain
    ln2::LayerNorm
end

function TransformerBlock(embed_dim::Int, num_heads::Int; dropout_prob=0.1)
    attn = MultiHeadSelfAttention(embed_dim, num_heads)
    ln1 = LayerNorm(embed_dim)
    mlp = Chain(
        Dense(embed_dim, 4 * embed_dim, relu),
        Dense(4 * embed_dim, embed_dim)
    )
    ln2 = LayerNorm(embed_dim)
    return TransformerBlock(attn, ln1, mlp, ln2)
end

function (block::TransformerBlock)(x::Array{Float32,3})
    # x: (embed_dim, T, batch_size)
    x_attn = block.attn(block.ln1(x))
    x = x .+ x_attn
    x_mlp = block.mlp(block.ln2(x))
    return x .+ x_mlp
end

In [14]:
# ---------------------------
# GPT Model
# ---------------------------

    # GPT Model Structure:
    # The model stacks token and positional embeddings followed by several transformer blocks. A final layer normalization and linear projection produce logits over the vocabulary for each position.

struct GPT
    token_emb::TokenEmbedding
    pos_emb::PositionalEmbedding
    blocks::Vector{TransformerBlock}
    ln_f::LayerNorm
    head::Dense
end

function GPT(vocab_size::Int, seq_len::Int, embed_dim::Int, num_heads::Int, num_layers::Int)
    token_emb = TokenEmbedding(vocab_size, embed_dim)
    pos_emb = PositionalEmbedding(seq_len, embed_dim)
    blocks = [TransformerBlock(embed_dim, num_heads) for _ in 1:num_layers]
    ln_f = LayerNorm(embed_dim)
    head = Dense(embed_dim, vocab_size)
    return GPT(token_emb, pos_emb, blocks, ln_f, head)
end

function (model::GPT)(x::Vector{Int})
    # x is a vector of token indices (length T). For simplicity, we assume batch size 1.
    T = length(x)
    # Get token embeddings: shape (embed_dim, T)
    tok_emb = model.token_emb(x)
    # Get positional embeddings: shape (embed_dim, T)
    pos_emb = model.pos_emb(T)
    # Sum embeddings
    h = tok_emb .+ pos_emb
    # Add batch dimension: shape (embed_dim, T, 1)
    h = reshape(h, size(h,1), size(h,2), 1)
    # Pass through transformer blocks
    for block in model.blocks
        h = block(h)
    end
    h = model.ln_f(h)
    # Final projection to vocabulary logits for each position
    logits = model.head(h)  # shape: (vocab_size, T, 1)
    return logits
end

In [16]:
# ---------------------------
# Example Usage
# ---------------------------

     # A toy model is defined with small hyperparameters and vocabulary size. A dummy input sequence is fed through the model, and the code prints the shape of the logits. It also demonstrates how to sample a predicted next token from the logits.

# Define hyperparameters for a toy model
vocab_size = 100       # For demonstration; in practice, use a larger vocabulary.
seq_len = 20           # Maximum sequence length (block size)
embed_dim = 32         # Embedding (hidden) dimension
num_heads = 4          # Number of attention heads
num_layers = 2         # Number of transformer blocks

# Create a GPT model instance
gpt_model = GPT(vocab_size, seq_len, embed_dim, num_heads, num_layers)

# Create a dummy input: a sequence of token indices (length = seq_len)
input_sequence = rand(1:vocab_size, seq_len)

# Pass the sequence through the model to obtain logits
logits = gpt_model(input_sequence)
println("Logits shape: ", size(logits))  # Expect (vocab_size, seq_len, 1)

# For instance, to predict the next token for the last position, one could take:
predicted_distribution = softmax(logits[:, end, 1])
predicted_token = sample(1:vocab_size, Weights(vec(predicted_distribution)))
println("Predicted next token index: ", predicted_token)

Logits shape: (100, 20, 1)
Predicted next token index: 13


In [None]:
# The output is exactly what we would expect from this toy model demonstration:

    # Logits Shape (100, 20, 1):
    # This indicates that for a batch size of 1, the model produces outputs for 20 tokens (the sequence length), with each output being a vector of length 100. The 100 corresponds to the vocabulary size, meaning the model outputs a score (logit) for each token in the vocabulary at each position.

    # Predicted Next Token Index (13):
    # After applying the softmax to the logits of the final token, the model samples a token index. In this case, token 13 was chosen. Since the model is randomly initialized (and likely untrained), this number is arbitrary but confirms that the prediction mechanism is working.