In [1]:
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle, glorot_uniform
using Statistics
using StatsBase: sample, Weights
using Random
using LinearAlgebra
using CUDA  # For GPU acceleration

│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt C:\Users\khoj\.julia\packages\Flux\Sgc17\ext\FluxCUDAExt\FluxCUDAExt.jl:10


In [2]:
# ---------------------------
# Enhanced Token and Positional Embeddings
# ---------------------------
"""
    TokenEmbedding(vocab_size, embed_dim; dropout_prob=0.1)

Lookup table that maps integer token ids to embedding columns and then applies dropout.
The table `emb` is created with `glorot_uniform(embed_dim, vocab_size)`.
"""
struct TokenEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
end

function TokenEmbedding(vocab_size::Int, embed_dim::Int; dropout_prob=0.1)
    table = glorot_uniform(embed_dim, vocab_size)
    return TokenEmbedding(table, Dropout(dropout_prob))
end

# Select one column of `emb` per id in `x`, then run the result through dropout.
function (te::TokenEmbedding)(x::AbstractArray{Int})
    looked_up = te.emb[:, x]
    return te.dropout(looked_up)
end

# Learned positional embeddings (like GPT-2)
"""
    LearnedPositionalEmbedding(seq_len, embed_dim; dropout_prob=0.1)

Positional table created with `glorot_uniform(embed_dim, seq_len)`, one column per
position, followed by dropout.
"""
struct LearnedPositionalEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
end

function LearnedPositionalEmbedding(seq_len::Int, embed_dim::Int; dropout_prob=0.1)
    table = glorot_uniform(embed_dim, seq_len)
    return LearnedPositionalEmbedding(table, Dropout(dropout_prob))
end

# Slice out the columns for positions 1:T and apply dropout to them.
function (pe::LearnedPositionalEmbedding)(T::Int)
    leading_positions = pe.emb[:, 1:T]
    return pe.dropout(leading_positions)
end

# Sinusoidal positional embeddings (alternative to learned) (like the original Transformer)
"""
    sinusoidal_position_embedding(seq_len, embed_dim)

Build an `(embed_dim, seq_len)` `Float32` matrix of fixed sinusoidal position codes.

For position `pos` in `1:seq_len` and even offset `i` in `0:2:embed_dim-1`, row `i+1`
holds `sin(pos / 10000^(i/embed_dim))` and row `i+2` (when it exists) holds the matching
cosine.
"""
function sinusoidal_position_embedding(seq_len::Int, embed_dim::Int)
    table = zeros(Float32, embed_dim, seq_len)

    for position in 1:seq_len, offset in 0:2:(embed_dim - 1)
        inv_wavelength = 1.0 / (10000.0^(offset / embed_dim))
        angle = position * inv_wavelength

        sin_row = offset + 1
        cos_row = offset + 2

        table[sin_row, position] = sin(angle)
        # With an odd embed_dim the last sine row has no cosine partner.
        if cos_row <= embed_dim
            table[cos_row, position] = cos(angle)
        end
    end

    return table
end

"""
    SinusoidalPositionalEmbedding(seq_len, embed_dim; dropout_prob=0.1)

Positional embedding whose table is filled by `sinusoidal_position_embedding`,
followed by dropout.
"""
struct SinusoidalPositionalEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
end

SinusoidalPositionalEmbedding(seq_len::Int, embed_dim::Int; dropout_prob=0.1) =
    SinusoidalPositionalEmbedding(
        sinusoidal_position_embedding(seq_len, embed_dim),
        Dropout(dropout_prob),
    )

# Slice out the columns for positions 1:T and apply dropout to them.
function (pe::SinusoidalPositionalEmbedding)(T::Int)
    leading_positions = pe.emb[:, 1:T]
    return pe.dropout(leading_positions)
end

In [3]:
# ---------------------------
# Rotary Position Embeddings (RoPE)
# ---------------------------

# a modern technique used in models like LLaMA that enables better handling of sequences longer than those seen during training

"""
    RotaryPositionEmbedding(dim, max_seq_len)

Precomputed cosine and sine tables for rotary position embeddings.

Both tables have one row per frequency `10000^(-i/dim)` for `i in 0:2:dim-2` and one
column per position `1:max_seq_len`; entry `(j, pos)` is the cosine / sine of
`pos * frequency_j`.
"""
struct RotaryPositionEmbedding
    dim::Int
    max_seq_len::Int
    freqs_cos::AbstractArray{Float32, 2}
    freqs_sin::AbstractArray{Float32, 2}
end

function RotaryPositionEmbedding(dim::Int, max_seq_len::Int)
    # One inverse frequency per pair of feature dimensions.
    inv_freq = 10000.0 .^ (-(0:2:dim-2) ./ dim)
    # Column vector of frequencies times row vector of positions gives every angle.
    positions = reshape(1:max_seq_len, 1, :)
    angles = inv_freq .* positions

    return RotaryPositionEmbedding(dim, max_seq_len, cos.(angles), sin.(angles))
end

"""
    rotate_half(x)

Split the first dimension of `x` into two contiguous halves `x1 = x[1:h, ...]` and
`x2 = x[h+1:2h, ...]` and return `cat(-x2, x1; dims=1)`.

Together with `x .* cos .+ rotate_half(x) .* sin` this rotates every pair
`(x[j], x[j + h])` by the same angle.

# Throws
- `ArgumentError` if `size(x, 1)` is odd, because the halves would not pair up.
"""
function rotate_half(x::AbstractArray{Float32, 4})
    head_dim = size(x, 1)
    iseven(head_dim) ||
        throw(ArgumentError("rotate_half requires an even first dimension, got $head_dim"))

    half_dim = head_dim ÷ 2
    x1 = x[1:half_dim, :, :, :]
    x2 = x[(half_dim + 1):head_dim, :, :, :]

    return cat(-x2, x1; dims=1)
end

# Rotate the first `rot_dim` feature rows of `x` and pass any remaining rows through.
function _rotate_leading_dims(x::AbstractArray{Float32, 4}, cos_pos, sin_pos, rot_dim::Int)
    head_dim = size(x, 1)
    x_rot = x[1:rot_dim, :, :, :]
    rotated = x_rot .* cos_pos .+ rotate_half(x_rot) .* sin_pos
    rot_dim == head_dim && return rotated
    return cat(rotated, x[(rot_dim + 1):head_dim, :, :, :]; dims=1)
end

"""
    apply_rotary_pos_emb(q, k, freqs_cos, freqs_sin, T)

Apply rotary position embeddings to queries `q` and keys `k`, both laid out as
`(head_dim, T, batch_size, num_heads)`, and return `(q_rot, k_rot)`.

`freqs_cos` / `freqs_sin` hold one row per rotation frequency and one column per
position. Row `j` rotates the feature pair `(j, j + n)`, where
`n = min(size(freqs_cos, 1), head_dim ÷ 2)`. Feature rows beyond `2n` (for example the
last row of an odd `head_dim`) are returned unchanged, so no feature is ever zeroed.

# Arguments
- `T`: number of positions; must equal `size(q, 2)`.

# Throws
- `DimensionMismatch` if `T`, `q`, `k` or the two tables disagree in shape.
- `ArgumentError` if the sequence is longer than the precomputed tables.
"""
function apply_rotary_pos_emb(q::AbstractArray{Float32, 4}, k::AbstractArray{Float32, 4},
                            freqs_cos::AbstractArray{Float32, 2}, freqs_sin::AbstractArray{Float32, 2},
                            T::Int)
    head_dim, seq_len = size(q, 1), size(q, 2)

    T == seq_len ||
        throw(DimensionMismatch("T ($T) must equal the sequence length of q ($seq_len)"))
    (size(k, 1) == head_dim && size(k, 2) == seq_len) ||
        throw(DimensionMismatch("q and k must share head_dim and sequence length"))
    size(freqs_cos) == size(freqs_sin) ||
        throw(DimensionMismatch("freqs_cos and freqs_sin must have the same size"))
    seq_len <= size(freqs_cos, 2) ||
        throw(ArgumentError(
            "sequence length $seq_len exceeds the $(size(freqs_cos, 2)) precomputed rotary positions"))

    n_pairs = min(size(freqs_cos, 1), head_dim ÷ 2)
    rot_dim = 2 * n_pairs
    rot_dim == 0 && return q, k

    cos_half = freqs_cos[1:n_pairs, 1:seq_len]
    sin_half = freqs_sin[1:n_pairs, 1:seq_len]

    # Both members of a rotated pair (j, j + n_pairs) share one angle, so the
    # half-size tables are stacked twice instead of being padded with zeros.
    cos_pos = reshape(vcat(cos_half, cos_half), rot_dim, seq_len, 1, 1)
    sin_pos = reshape(vcat(sin_half, sin_half), rot_dim, seq_len, 1, 1)

    q_rot = _rotate_leading_dims(q, cos_pos, sin_pos, rot_dim)
    k_rot = _rotate_leading_dims(k, cos_pos, sin_pos, rot_dim)

    return q_rot, k_rot
end

apply_rotary_pos_emb (generic function with 1 method)

In [4]:
# ---------------------------
# Advanced Multi-Head Self-Attention with Flash Attention
# ---------------------------
"""
    MultiHeadAttention(embed_dim, num_heads; attn_dropout_prob=0.1, resid_dropout_prob=0.1,
                       use_rope=true, max_seq_len=1024, use_flash_attn=false)

Multi-head self-attention layer: four `embed_dim => embed_dim` projections (query, key,
value, output), dropout on the attention weights and on the output, optional rotary
position embeddings, and a flag selecting `flash_attention` over `compute_attention`.
"""
struct MultiHeadAttention
    W_q::Dense
    W_k::Dense
    W_v::Dense
    W_o::Dense
    num_heads::Int
    head_dim::Int
    attn_dropout::Dropout
    resid_dropout::Dropout
    rope::Union{RotaryPositionEmbedding, Nothing}
    use_flash_attn::Bool
end

function MultiHeadAttention(embed_dim::Int, num_heads::Int;
                           attn_dropout_prob=0.1,
                           resid_dropout_prob=0.1,
                           use_rope::Bool=true,
                           max_seq_len::Int=1024,
                           use_flash_attn::Bool=false)
    head_dim = embed_dim ÷ num_heads
    @assert head_dim * num_heads == embed_dim "embed_dim must be divisible by num_heads"

    # The rotary tables are sized per head, not per model dimension.
    rope = if use_rope
        RotaryPositionEmbedding(head_dim, max_seq_len)
    else
        nothing
    end

    projection() = Dense(embed_dim, embed_dim)

    return MultiHeadAttention(
        projection(),  # query
        projection(),  # key
        projection(),  # value
        projection(),  # output
        num_heads,
        head_dim,
        Dropout(attn_dropout_prob),
        Dropout(resid_dropout_prob),
        rope,
        use_flash_attn,
    )
end

# Helper function to compute standard attention
"""
    compute_attention(Q, K, V, attn_dropout, causal_mask=true)

Scaled dot-product attention for inputs laid out as
`(head_dim, T, batch_size, num_heads)`; returns an array with the same layout.

The batch and head dimensions are merged with a plain `reshape` (they are already the
two trailing dimensions, so no data has to be moved). Scores are
`Q' * K / sqrt(head_dim)`. With `causal_mask`, query `i` only sees keys `j <= i`.
Softmax is taken over the key dimension and `attn_dropout` is applied to the weights.
"""
function compute_attention(Q::AbstractArray{Float32, 4}, K::AbstractArray{Float32, 4}, V::AbstractArray{Float32, 4},
                         attn_dropout::Dropout, causal_mask::Bool=true)
    # Q, K, V shapes: (head_dim, T, batch_size, num_heads)
    head_dim, T, batch_size, num_heads = size(Q)
    merged = batch_size * num_heads

    # (head_dim, T, batch_size, num_heads) -> (head_dim, T, batch_size * num_heads)
    Q_flat = reshape(Q, head_dim, T, merged)
    K_flat = reshape(K, head_dim, T, merged)
    V_flat = reshape(V, head_dim, T, merged)

    # scores[i, j, n] = <q_i, k_j> / sqrt(head_dim); shape (T, T, merged)
    scores = batched_mul(permutedims(Q_flat, (2, 1, 3)), K_flat) ./ sqrt(Float32(head_dim))

    if causal_mask
        # Lower-triangular ones: row = query position, column = key position.
        mask = reshape(tril(ones(Float32, T, T)), T, T, 1)
        # Masked entries get a large negative score so softmax sends them to ~0.
        scores = scores .* mask .+ (1.0f0 .- mask) .* -1.0f7
    end

    # Normalise over keys (dims=2), then drop attention weights.
    attn_weights = attn_dropout(softmax(scores, dims=2))

    # (T, T, merged) * (T, head_dim, merged) -> (T, head_dim, merged)
    out = batched_mul(attn_weights, permutedims(V_flat, (2, 1, 3)))

    # Back to (head_dim, T, batch_size, num_heads)
    out = permutedims(out, (2, 1, 3))
    return reshape(out, head_dim, T, batch_size, num_heads)
end

# Scores of one (query block, key block) pair, shape (Tq, Tk, merged), with the causal
# mask applied when the key block reaches past the first query position of the block.
function _flash_block_scores(q_blk, k_blk, head_dim::Int, causal_mask::Bool,
                             q_positions::AbstractUnitRange, k_positions::AbstractUnitRange)
    scores = batched_mul(q_blk, k_blk) ./ sqrt(Float32(head_dim))
    if causal_mask && last(k_positions) > first(q_positions)
        allowed = Float32.(q_positions .>= k_positions')          # (Tq, Tk)
        allowed = reshape(allowed, size(allowed, 1), size(allowed, 2), 1)
        scores = scores .* allowed .+ (1.0f0 .- allowed) .* -1.0f7
    end
    return scores
end

"""
    flash_attention(Q, K, V, attn_dropout, causal_mask=true, block_size=128)

Blockwise (memory-saving) scaled dot-product attention for inputs laid out as
`(head_dim, T, batch_size, num_heads)`; returns an array with the same layout.

Educational version of the Flash Attention idea: queries and keys are processed in
blocks of `block_size` positions and the softmax is accumulated with a running maximum
and running sum ("online softmax"), so the full `T x T` score matrix is never formed.
Without dropout the result equals `compute_attention` up to floating-point rounding.
Dropout is applied to the unnormalised weights of each block, which is equivalent to
dropping the normalised attention weights.

The implementation does not mutate arrays in place.

# Throws
- `ArgumentError` if `block_size` is not positive.
"""
function flash_attention(Q::AbstractArray{Float32}, K::AbstractArray{Float32}, V::AbstractArray{Float32},
                         attn_dropout::Dropout, causal_mask::Bool=true, block_size::Int=128)
    block_size > 0 || throw(ArgumentError("block_size must be positive, got $block_size"))

    head_dim, T, batch_size, num_heads = size(Q)
    T == 0 && return Q
    merged = batch_size * num_heads

    # Batch and heads are the trailing dimensions, so merging them is a pure reshape.
    q3 = reshape(Q, head_dim, T, merged)
    k3 = reshape(K, head_dim, T, merged)
    v3 = reshape(V, head_dim, T, merged)

    out_blocks = map(1:block_size:T) do b_q
        e_q = min(b_q + block_size - 1, T)
        q_blk = permutedims(q3[:, b_q:e_q, :], (2, 1, 3))          # (Tq, head_dim, merged)

        # With a causal mask no key after the last query of this block is visible.
        max_k_idx = causal_mask ? e_q : T

        # The first key block initialises the running statistics.
        e_k = min(block_size, max_k_idx)
        scores = _flash_block_scores(q_blk, k3[:, 1:e_k, :], head_dim, causal_mask,
                                     b_q:e_q, 1:e_k)
        running_max = maximum(scores; dims=2)                      # (Tq, 1, merged)
        weights = exp.(scores .- running_max)
        running_sum = sum(weights; dims=2)
        acc = batched_mul(attn_dropout(weights), permutedims(v3[:, 1:e_k, :], (2, 1, 3)))

        for b_k in (block_size + 1):block_size:max_k_idx
            e_k = min(b_k + block_size - 1, max_k_idx)
            scores = _flash_block_scores(q_blk, k3[:, b_k:e_k, :], head_dim, causal_mask,
                                         b_q:e_q, b_k:e_k)

            new_max = max.(running_max, maximum(scores; dims=2))
            # Re-express what was accumulated so far relative to the new maximum.
            rescale = exp.(running_max .- new_max)
            weights = exp.(scores .- new_max)

            running_sum = running_sum .* rescale .+ sum(weights; dims=2)
            acc = acc .* rescale .+
                  batched_mul(attn_dropout(weights), permutedims(v3[:, b_k:e_k, :], (2, 1, 3)))
            running_max = new_max
        end

        # Normalise once at the end; back to (head_dim, Tq, merged).
        return permutedims(acc ./ running_sum, (2, 1, 3))
    end

    out = cat(out_blocks...; dims=2)
    return reshape(out, head_dim, T, batch_size, num_heads)
end

"""
    (mha::MultiHeadAttention)(x, mask=nothing)

Run multi-head self-attention on `x` of shape `(embed_dim, T, batch_size)` and return an
array of the same shape.

Steps: project to Q/K/V, split the embedding dimension into `num_heads` heads of
`head_dim` features, optionally apply rotary position embeddings to Q and K, attend
(via `flash_attention` or `compute_attention`, both called with their default causal
masking), merge the heads back, and apply the output projection and residual dropout.

`mask` is accepted for interface compatibility but is currently not used.
"""
function (mha::MultiHeadAttention)(x::AbstractArray{Float32,3},
                                   mask::Union{Nothing, AbstractArray{Float32}}=nothing)
    # x has shape (embed_dim, T, batch_size)
    embed_dim, T, batch_size = size(x)

    # Q, K, V each have shape (embed_dim, T, batch_size)
    Q = mha.W_q(x)
    K = mha.W_k(x)
    V = mha.W_v(x)

    # (embed_dim, T, B) -> (head_dim, num_heads, T, B) -> (head_dim, T, B, num_heads)
    split_heads(t) =
        permutedims(reshape(t, mha.head_dim, mha.num_heads, T, batch_size), (1, 3, 4, 2))
    Qh = split_heads(Q)
    Kh = split_heads(K)
    Vh = split_heads(V)

    # Apply rotary positional embeddings if used
    if !isnothing(mha.rope)
        Qh, Kh = apply_rotary_pos_emb(Qh, Kh, mha.rope.freqs_cos, mha.rope.freqs_sin, T)
    end

    # Compute attention - select implementation
    out = if mha.use_flash_attn
        flash_attention(Qh, Kh, Vh, mha.attn_dropout)
    else
        compute_attention(Qh, Kh, Vh, mha.attn_dropout)
    end

    # Exact inverse of split_heads:
    # (head_dim, T, B, num_heads) -> (head_dim, num_heads, T, B) -> (embed_dim, T, B).
    # The head axis must sit next to head_dim before reshaping, otherwise features of
    # different time steps end up in the same output column.
    out = reshape(permutedims(out, (1, 4, 2, 3)), embed_dim, T, batch_size)

    # Final linear projection with dropout
    return mha.resid_dropout(mha.W_o(out))
end

In [5]:
# ---------------------------
# Improved MLP with SwiGLU Activation
# ---------------------------

# SwiGLU activation function instead of simple ReLU - this activation (Swish-Gated Linear Unit) is used in state-of-the-art models like PaLM and significantly improves performance

"""
    SwiGLU(dim, hidden_dim; dropout_prob=0.1)

Gated feed-forward layer: two `dim => hidden_dim` projections (`W_up`, `W_gate`), one
`hidden_dim => dim` projection (`W_down`) and dropout on the output.
"""
struct SwiGLU
    W_up::Dense
    W_gate::Dense
    W_down::Dense
    dropout::Dropout
end

function SwiGLU(dim::Int, hidden_dim::Int; dropout_prob=0.1)
    up_proj = Dense(dim, hidden_dim)
    gate_proj = Dense(dim, hidden_dim)
    down_proj = Dense(hidden_dim, dim)
    return SwiGLU(up_proj, gate_proj, down_proj, Dropout(dropout_prob))
end

# SwiGLU(x) = W_down(Swish(W_gate x) .* (W_up x)) with Swish(z) = z * sigmoid(z),
# followed by dropout.
function (swiglu::SwiGLU)(x::AbstractArray{Float32})
    value_path = swiglu.W_up(x)
    gate_path = swiglu.W_gate(x)
    gated = (gate_path .* sigmoid.(gate_path)) .* value_path
    projected = swiglu.W_down(gated)
    return swiglu.dropout(projected)
end

In [6]:
# ---------------------------
# Enhanced Transformer Block with Pre/Post Norm Options
# ---------------------------

# Pre-norm (used in GPT) vs Post-norm (original Transformer) layer normalization placement, improving training stability

"""
    TransformerBlock(embed_dim, num_heads; ffn_dim_mult=4, dropout_prob=0.1, pre_norm=true,
                     use_rope=true, max_seq_len=1024, use_flash_attn=false)

One transformer layer: multi-head attention and a SwiGLU feed-forward network, each with
its own `LayerNorm`. `pre_norm` selects where the norms are applied in the forward pass.
The feed-forward hidden width is `ffn_dim_mult * embed_dim`; `dropout_prob` is used for
attention, residual and feed-forward dropout alike.
"""
struct TransformerBlock
    attn::MultiHeadAttention
    norm1::LayerNorm
    ffn::SwiGLU
    norm2::LayerNorm
    pre_norm::Bool
end

function TransformerBlock(embed_dim::Int, num_heads::Int;
                         ffn_dim_mult=4,
                         dropout_prob=0.1,
                         pre_norm=true,  # Pre-norm (GPT-2 style) or post-norm (original Transformer / BERT)
                         use_rope=true,
                         max_seq_len=1024,
                         use_flash_attn=false)
    attention = MultiHeadAttention(
        embed_dim, num_heads;
        attn_dropout_prob=dropout_prob,
        resid_dropout_prob=dropout_prob,
        use_rope=use_rope,
        max_seq_len=max_seq_len,
        use_flash_attn=use_flash_attn,
    )
    attn_norm = LayerNorm(embed_dim)

    hidden_width = ffn_dim_mult * embed_dim
    feed_forward = SwiGLU(embed_dim, hidden_width; dropout_prob=dropout_prob)
    ffn_norm = LayerNorm(embed_dim)

    return TransformerBlock(attention, attn_norm, feed_forward, ffn_norm, pre_norm)
end

# Forward pass for x of shape (embed_dim, T, batch_size).
# Pre-norm:  x + attn(norm1(x)), then x + ffn(norm2(x)).
# Post-norm: norm1(x + attn(x)), then norm2(x + ffn(x)).
function (block::TransformerBlock)(x::AbstractArray{Float32,3}, mask=nothing)
    if !block.pre_norm
        # Post-normalization architecture (original Transformer style)
        attended = block.norm1(x + block.attn(x, mask))
        return block.norm2(attended + block.ffn(attended))
    end

    # Pre-normalization architecture (GPT style)
    attended = x + block.attn(block.norm1(x), mask)
    return attended + block.ffn(block.norm2(attended))
end

In [7]:
# ---------------------------
# Enhanced GPT Model
# ---------------------------
"""
    AdvancedGPT(; vocab_size, max_seq_len, embed_dim, num_heads, num_layers,
                dropout_prob=0.1, pos_emb_type=:learned, pre_norm=true, use_rope=true,
                use_flash_attn=false, ffn_dim_mult=4)

Decoder-only transformer language model: token and positional embeddings, a stack of
`num_layers` `TransformerBlock`s, a final `LayerNorm` and a linear head onto the
vocabulary. All constructor keywords are also stored in `config`.

`pos_emb_type == :learned` selects `LearnedPositionalEmbedding`; any other symbol
selects `SinusoidalPositionalEmbedding`.
"""
struct AdvancedGPT
    token_emb::TokenEmbedding
    pos_emb::Union{LearnedPositionalEmbedding, SinusoidalPositionalEmbedding}
    blocks::Vector{TransformerBlock}
    ln_f::LayerNorm
    head::Dense
    emb_dropout::Dropout
    config::Dict{Symbol, Any}
end

function AdvancedGPT(;
    vocab_size::Int,
    max_seq_len::Int,
    embed_dim::Int,
    num_heads::Int,
    num_layers::Int,
    dropout_prob::Float64=0.1,
    pos_emb_type::Symbol=:learned,  # :learned or :sinusoidal
    pre_norm::Bool=true,
    use_rope::Bool=true,
    use_flash_attn::Bool=false,
    ffn_dim_mult::Int=4
)
    # Keep a record of every hyper-parameter on the model itself.
    config = Dict{Symbol, Any}(
        :vocab_size => vocab_size,
        :max_seq_len => max_seq_len,
        :embed_dim => embed_dim,
        :num_heads => num_heads,
        :num_layers => num_layers,
        :dropout_prob => dropout_prob,
        :pos_emb_type => pos_emb_type,
        :pre_norm => pre_norm,
        :use_rope => use_rope,
        :use_flash_attn => use_flash_attn,
        :ffn_dim_mult => ffn_dim_mult,
    )

    # The embedding layers get no dropout of their own; a single dropout is applied
    # to the sum of token and positional embeddings instead.
    token_emb = TokenEmbedding(vocab_size, embed_dim; dropout_prob=0.0)
    pos_emb = pos_emb_type == :learned ?
        LearnedPositionalEmbedding(max_seq_len, embed_dim; dropout_prob=0.0) :
        SinusoidalPositionalEmbedding(max_seq_len, embed_dim; dropout_prob=0.0)
    emb_dropout = Dropout(dropout_prob)

    make_block(_) = TransformerBlock(
        embed_dim,
        num_heads;
        ffn_dim_mult=ffn_dim_mult,
        dropout_prob=dropout_prob,
        pre_norm=pre_norm,
        use_rope=use_rope,
        max_seq_len=max_seq_len,
        use_flash_attn=use_flash_attn,
    )
    blocks = map(make_block, 1:num_layers)

    ln_f = LayerNorm(embed_dim)
    head = Dense(embed_dim, vocab_size)

    return AdvancedGPT(token_emb, pos_emb, blocks, ln_f, head, emb_dropout, config)
end

"""
    (model::AdvancedGPT)(x; return_all_layers=false)

Compute next-token logits for token ids `x`, given either as a vector of length `T`
(treated as a batch of one) or as a `(T, batch_size)` matrix.

Returns logits of shape `(batch_size, vocab_size, T)`. With `return_all_layers=true`
returns `(logits, activations)`, where `activations` holds the embedding output followed
by the output of every transformer block.
"""
function (model::AdvancedGPT)(x::AbstractArray{Int}; return_all_layers=false)
    # A single sequence becomes a (T, 1) batch; anything else is used as given.
    ids = ndims(x) == 1 ? reshape(x, length(x), 1) : x
    T, batch_size = size(ids)

    tok_emb = model.token_emb(ids)                    # (embed_dim, T, batch_size)
    pos_emb = model.pos_emb(T)                        # (embed_dim, T)
    pos_emb_expanded = repeat(pos_emb, 1, 1, batch_size)

    # Sum embeddings and apply dropout
    h = model.emb_dropout(tok_emb .+ pos_emb_expanded)

    # Collect intermediate activations only when asked to.
    activations = return_all_layers ? [h] : nothing

    for block in model.blocks
        h = block(h)
        return_all_layers && push!(activations, h)
    end

    # Final norm, then project to vocabulary logits: (vocab_size, T, batch_size)
    logits = model.head(model.ln_f(h))

    # Permute to (batch_size, vocab_size, T)
    logits = permutedims(logits, (3, 1, 2))

    return return_all_layers ? (logits, activations) : logits
end

In [8]:
# ---------------------------
# Text Generation Utilities
# ---------------------------
# Softmax over all elements of `x`, returned as a flat Float32 vector.
# A scalar input is treated as a one-element distribution.
softmax_safe(::Number) = [1.0f0]

function softmax_safe(x)
    values = vec(collect(Float32, x))
    # Subtract the maximum before exponentiating to avoid overflow.
    shifted = values .- maximum(values)
    weights = exp.(shifted)
    return weights ./ sum(weights)
end

# Set every logit outside the `top_k` largest to -Inf (no-op when top_k <= 0 or when
# top_k covers the whole vocabulary).
function _keep_top_k!(logits::AbstractVector{Float32}, top_k::Integer)
    (top_k <= 0 || top_k >= length(logits)) && return logits
    order = sortperm(logits; rev=true)
    logits[order[(top_k + 1):end]] .= -Inf32
    return logits
end

# Nucleus filtering: keep the smallest set of most-likely tokens whose cumulative
# probability exceeds `top_p` and set all other logits to -Inf.
function _keep_top_p!(logits::AbstractVector{Float32}, top_p::Real)
    top_p < 1 || return logits
    order = sortperm(logits; rev=true)
    sorted_logits = logits[order]
    weights = exp.(sorted_logits .- first(sorted_logits))
    cumulative = cumsum(weights ./ sum(weights))

    cutoff = findfirst(>(top_p), cumulative)
    # The token at `cutoff` is the one that pushes the mass past `top_p`; it stays in
    # the nucleus, only the tokens after it are removed.
    if cutoff !== nothing && cutoff < length(order)
        logits[order[(cutoff + 1):end]] .= -Inf32
    end
    return logits
end

"""
    generate(model, prompt; max_new_tokens=100, temperature=1.0f0, top_k=0, top_p=0.9f0)

Autoregressively extend `prompt` (token ids) by `max_new_tokens` tokens and return the
full sequence (prompt followed by the generated ids). `prompt` itself is not modified.

# Keywords
- `temperature`: logits are divided by this value before sampling; a value `<= 0`
  switches to greedy decoding (arg-max), in which case `top_k` / `top_p` are ignored.
- `top_k`: if positive, sample only among the `top_k` most likely tokens.
- `top_p`: if below 1, sample only from the smallest set of most likely tokens whose
  cumulative probability exceeds `top_p` (at least one token is always kept).

Only the last `model.config[:max_seq_len]` tokens are fed to the model at each step.

# Throws
- `ArgumentError` if `prompt` is empty.
"""
function generate(model::AdvancedGPT,
                 prompt::Vector{Int};
                 max_new_tokens::Int=100,
                 temperature::Real=1.0f0,
                 top_k::Int=0,
                 top_p::Real=0.9f0)
    isempty(prompt) && throw(ArgumentError("prompt must contain at least one token"))

    max_len = model.config[:max_seq_len]
    tokens = copy(prompt)

    for _ in 1:max_new_tokens
        # Truncate to the most recent max_len tokens and shape as (seq_len, batch_size=1).
        first_idx = max(1, length(tokens) - max_len + 1)
        context = reshape(tokens[first_idx:end], :, 1)

        logits = model(context)  # (batch_size, vocab_size, seq_len)

        # Logits of the last position as a plain CPU vector we are free to mutate.
        next_token_logits = Array{Float32}(vec(logits[1, :, end]))

        if temperature <= 0
            # Greedy decoding
            push!(tokens, argmax(next_token_logits))
            continue
        end

        next_token_logits ./= temperature
        _keep_top_k!(next_token_logits, top_k)
        _keep_top_p!(next_token_logits, top_p)

        # Convert the filtered logits to probabilities (filtered entries become 0).
        probs = exp.(next_token_logits .- maximum(next_token_logits))
        probs ./= sum(probs)

        push!(tokens, sample(1:length(probs), Weights(probs)))
    end

    return tokens
end

generate (generic function with 1 method)

In [9]:
# ---------------------------
# Tokenization (Simple Character-level for demonstration)
# ---------------------------
"""
    CharTokenizer(text)

Character-level tokenizer. Each distinct character of `text` gets an id, numbered from 1
in order of first appearance. `vocab` maps characters to ids and `idx_to_char` is the
inverse mapping.
"""
struct CharTokenizer
    vocab::Dict{Char, Int}
    idx_to_char::Dict{Int, Char}
end

function CharTokenizer(text::String)
    char_to_id = Dict{Char, Int}()
    id_to_char = Dict{Int, Char}()
    for (id, ch) in enumerate(unique(text))
        char_to_id[ch] = id
        id_to_char[id] = ch
    end
    return CharTokenizer(char_to_id, id_to_char)
end

# Map each character of `text` to its id; characters missing from the vocabulary are
# skipped.
function encode(tokenizer::CharTokenizer, text::String)
    ids = Int[]
    for ch in text
        id = get(tokenizer.vocab, ch, nothing)
        id === nothing || push!(ids, id)
    end
    return ids
end

# Map ids back to characters and concatenate them; unknown ids are skipped.
function decode(tokenizer::CharTokenizer, indices::Vector{Int})
    buffer = IOBuffer()
    for idx in indices
        ch = get(tokenizer.idx_to_char, idx, nothing)
        ch === nothing || print(buffer, ch)
    end
    return String(take!(buffer))
end

decode (generic function with 1 method)

In [10]:
# ---------------------------
# Training Loop
# ---------------------------
# Scale all gradients in place so that their joint L2 norm does not exceed `max_norm`.
# Returns the norm measured before clipping.
function _clip_global_norm!(grads, params, max_norm::Real)
    total_sq = 0.0
    for p in params
        g = grads[p]
        g === nothing && continue
        total_sq += sum(abs2, g)
    end

    total_norm = sqrt(total_sq)
    if total_norm > max_norm
        factor = max_norm / total_norm
        for p in params
            g = grads[p]
            g === nothing && continue
            g .*= factor
        end
    end
    return total_norm
end

"""
    train_gpt!(model, data, opt; batch_size=16, bptt=64, epochs=1, lr=3e-4, clip_norm=1.0)

Train `model` for next-token prediction on the token-id sequence `data`.

Every step draws `batch_size` random windows of `bptt` tokens from `data`; the target of
each window is the same window shifted by one token. The loss is the cross entropy
between the model's logits and the one-hot targets, averaged over all positions.
Progress is printed every 10 steps and at the end of each epoch.

# Keywords
- `batch_size`, `bptt`: number of windows per step and window length (both positive).
- `epochs`: number of passes; an epoch consists of
  `max(1, length(data) ÷ (batch_size * bptt))` steps.
- `lr`: currently unused; the learning rate is whatever `opt` was constructed with.
  Kept so existing calls keep working.
- `clip_norm`: if positive, gradients are rescaled so that their global L2 norm does not
  exceed this value.

# Throws
- `ArgumentError` if `batch_size` or `bptt` is not positive, or if `data` does not
  contain more than `bptt` tokens (no input/target window would fit).
"""
function train_gpt!(model::AdvancedGPT, data::Vector{Int},
                   opt::Flux.Optimise.AbstractOptimiser;
                   batch_size::Int=16,
                   bptt::Int=64,  # Sequence length for backprop through time
                   epochs::Int=1,
                   lr::Float64=3e-4,
                   clip_norm::Float64=1.0)
    batch_size > 0 || throw(ArgumentError("batch_size must be positive, got $batch_size"))
    bptt > 0 || throw(ArgumentError("bptt must be positive, got $bptt"))

    data_length = length(data)
    data_length > bptt ||
        throw(ArgumentError("data must contain more than bptt = $bptt tokens, got $data_length"))

    vocab_size = model.config[:vocab_size]

    # Draw `batch_size` random windows; targets are the inputs shifted by one token.
    function get_batch()
        starts = rand(1:(data_length - bptt), batch_size)
        xs = Matrix{Int}(undef, bptt, batch_size)
        ys = similar(xs)
        for (col, start) in enumerate(starts)
            xs[:, col] = data[start:(start + bptt - 1)]
            ys[:, col] = data[(start + 1):(start + bptt)]
        end
        return xs, ys
    end

    # Always take at least one step so short corpora still train and the epoch
    # average is well defined.
    steps_per_epoch = max(1, div(data_length, batch_size * bptt))

    params = Flux.params(model)

    for epoch in 1:epochs
        epoch_loss = 0.0

        for step in 1:steps_per_epoch
            x_batch, y_batch = get_batch()

            # One-hot targets, (vocab_size, bptt * batch_size); built outside the
            # gradient call because they are constants.
            targets = onehotbatch(vec(y_batch), 1:vocab_size)

            loss, grads = Flux.withgradient(params) do
                logits = model(x_batch)  # (batch_size, vocab_size, seq_len)

                # (vocab_size, seq_len, batch_size) -> (vocab_size, seq_len * batch_size),
                # the same position-major order as vec(y_batch).
                logits_flat = reshape(permutedims(logits, (2, 3, 1)), vocab_size, :)

                # The model emits raw logits, so use the logit form of cross entropy.
                return Flux.Losses.logitcrossentropy(logits_flat, targets)
            end

            if clip_norm > 0
                _clip_global_norm!(grads, params, clip_norm)
            end

            Flux.Optimise.update!(opt, params, grads)

            epoch_loss += loss

            if step % 10 == 0
                println("Epoch: $epoch, Step: $step/$steps_per_epoch, Loss: $(loss)")
            end
        end

        avg_loss = epoch_loss / steps_per_epoch
        println("Epoch $epoch completed. Average loss: $avg_loss")
    end
end

train_gpt! (generic function with 1 method)

In [11]:
# ---------------------------
# Advanced Usage Example
# ---------------------------
# Smoke test: build a medium-size model, run one forward pass on random token ids and
# generate a short continuation from a fixed prompt.
function main()
    # Hyper-parameters for a medium-size model; the keys match the keyword
    # arguments of AdvancedGPT.
    config = Dict(
        :vocab_size => 50000,
        :max_seq_len => 1024,
        :embed_dim => 768,
        :num_heads => 12,
        :num_layers => 12,
        :dropout_prob => 0.1,
        :pos_emb_type => :learned,
        :pre_norm => true,
        :use_rope => true,
        :use_flash_attn => false,  # Set to true if GPU available
        :ffn_dim_mult => 4
    )

    # Create model
    model = AdvancedGPT(; config...)

    # Move to GPU if available
    if CUDA.functional()
        println("Using CUDA...")
        model = model |> gpu
    end

    # Random token ids of shape (seq_len, batch_size), for demonstration only
    batch_size, seq_len = 2, 10
    input_ids = rand(1:config[:vocab_size], seq_len, batch_size)

    # Forward pass
    logits = model(input_ids)
    println("Logits shape: ", size(logits))  # Should be (batch_size, vocab_size, seq_len)

    # Generation example with hand-picked token ids
    # Note: In a real scenario, you would use a proper tokenizer
    prompt = [10, 20, 30, 40, 50]
    generated = generate(model, prompt; max_new_tokens=20, temperature=0.7f0, top_p=0.9f0)
    println("Generated token IDs: ", generated)

    println("Model successfully created and tested!")
end

# Call main to test
main()

Using CUDA...


│ 
│ 1. If no GPU is available, nothing needs to be done.
│ 2. If GPU is available, load the corresponding trigger package.
│     a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for  NVIDIA CUDA Support.
│     b. `AMDGPU.jl` for AMD GPU ROCM Support.
│     c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
│     d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)
└ @ MLDataDevices.Internal C:\Users\khoj\.julia\packages\MLDataDevices\Cq9gx\src\internal.jl:94


Logits shape: (2, 50000, 10)
Generated token IDs: [10, 20, 30, 40, 50, 28710, 38915, 23406, 43173, 43504, 34141, 2594, 27379, 21366, 26905, 5922, 13592, 30117, 12022, 143, 1194, 40307, 27259, 5920, 1993]
Model successfully created and tested!


In [12]:
# "Logits shape: (2, 50000, 10)"
# Shows correct dimensions:
   # 2: batch size
   # 50000: vocabulary size
   # 10: sequence length

# "Generated token IDs:" shows:
   # First 5 numbers [10, 20, 30, 40, 50] are your original prompt
   # Following numbers are the generated tokens
   # The generated IDs all fall within the vocabulary range (1-50000); because the model is untrained, they are effectively random

# "Model successfully created and tested!" - Confirms the model initialization and forward pass worked correctly

# If you want to see the actual text output, you would need to pass these tokens through your tokenizer's decode function to convert them back to text.

In [13]:
# I'll create an example that demonstrates how to use the model with actual text data. Here's a comprehensive example:

# End-to-end demonstration on real text: build a character tokenizer from a short
# sample, create a small model, generate a continuation of a prompt and print some
# model statistics. Returns `(model, tokenizer)`.
function example_usage()
    println("Starting GPT model example...")

    # 1. Sample text that defines the character vocabulary
    sample_text = """
    Julia is a high-level, high-performance, dynamic programming language. 
    While it is a general-purpose language and can be used to write any application, 
    many of its features are well suited for numerical analysis and computational science.
    """

    # 2. Character-level tokenizer built from the sample
    tokenizer = CharTokenizer(sample_text)
    println("Vocabulary size: ", length(tokenizer.vocab))

    # 3. A model small enough to run quickly
    model = AdvancedGPT(;
        vocab_size=length(tokenizer.vocab),
        max_seq_len=128,
        embed_dim=256,      # Smaller for this example
        num_heads=8,
        num_layers=6,       # Reduced layers for faster execution
        dropout_prob=0.1,
        pos_emb_type=:learned,
        pre_norm=true,
        use_rope=true,
        use_flash_attn=false,
    )

    # 4. Move to GPU if available
    if CUDA.functional()
        println("Using CUDA...")
        model = model |> gpu
    end

    # 5. Encode the prompt
    prompt = "Julia is"
    prompt_tokens = encode(tokenizer, prompt)
    println("\nPrompt: \"$prompt\"")

    # 6. Sample a continuation
    generated_tokens = generate(
        model,
        prompt_tokens;
        max_new_tokens=50,
        temperature=0.7f0,
        top_k=50,
        top_p=0.9f0,
    )

    # 7. Decode and display the result
    println("\nGenerated text:")
    println(decode(tokenizer, generated_tokens))

    # 8. Model statistics
    println("\nModel statistics:")
    println("- Vocabulary size: ", length(tokenizer.vocab))
    println("- Context length: ", model.config[:max_seq_len])
    println("- Number of parameters: ", sum(length, Flux.params(model)))

    return model, tokenizer
end

# Run the example
model, tokenizer = example_usage()

Starting GPT model example...
Vocabulary size: 28
Using CUDA...

Prompt: "Julia is"

Generated text:
Julia isrrrrrhhpmf,rhrhfamW,JhwnWJmfpfawaWfhopffaa.rprrlgf

Model statistics:
- Vocabulary size: 28
- Context length: 128


│ and the explicit `gradient(m -> loss(m, x, y), model)` for gradient computation.
└ @ Flux C:\Users\khoj\.julia\packages\Flux\Sgc17\src\deprecations.jl:93


- Number of parameters: 6389788


(AdvancedGPT(TokenEmbedding(Float32[0.090890154 -0.099342726 … 0.02606686 -0.10358236; 0.07702502 0.07856488 … -0.107408084 0.04179199; … ; 0.09270245 0.066110544 … 0.014684047 0.12216856; 0.06867142 -0.100097716 … 0.14114578 0.084149435], Dropout(0.0)), LearnedPositionalEmbedding(Float32[0.084296644 0.11054972 … 0.017054155 -0.0199074; -0.12008329 0.10933383 … -0.06468785 -0.030502647; … ; -0.010624945 0.071070075 … -0.049582854 0.0067366958; 0.12292118 0.114829466 … 0.07636237 -0.0025499314], Dropout(0.0)), TransformerBlock[TransformerBlock(MultiHeadAttention(Dense(256 => 256), Dense(256 => 256), Dense(256 => 256), Dense(256 => 256), 8, 32, Dropout(0.1), Dropout(0.1), RotaryPositionEmbedding(32, 128, Float32[0.5403023 -0.41614684 … 0.2323591 -0.6928958; 0.84600914 0.43146282 … -0.66799676 -0.9618962; … ; 0.99999994 0.9999998 … 0.99919367 0.9991809; 1.0 0.99999994 … 0.999745 0.99974096], Float32[0.84147096 0.9092974 … 0.9726301 0.7210377; 0.53316844 0.9021307 … 0.74416417 0.27341488; 

In [14]:
# Generate more text with different prompts
# Encode `prompt`, sample up to `max_tokens` new tokens with fixed sampling settings
# (temperature 0.7, top-k 50, top-p 0.9) and decode the whole sequence back to text.
function generate_text(model, tokenizer, prompt; max_tokens=50)
    prompt_ids = encode(tokenizer, prompt)
    sampled_ids = generate(
        model,
        prompt_ids;
        max_new_tokens=max_tokens,
        temperature=0.7f0,
        top_k=50,
        top_p=0.9f0,
    )
    return decode(tokenizer, sampled_ids)
end

# Try different prompts - Different prompt handling
prompts = ["Julia can", "Programming in", "The language"]

# Print each prompt followed by the text generated from it.
foreach(prompts) do prompt
    println("\nPrompt: \"$prompt\"")
    completion = generate_text(model, tokenizer, prompt)
    println("Generated: ", completion)
end

# Note that since we're using a character-level tokenizer and a relatively small model, the generated text might not be very coherent. 


Prompt: "Julia can"
Generated: Julia canrrr-hysrrbhri,r u,wrh,ppcwvgiWagabsoipyr
iiprWrWWy

Prompt: "Programming in"
Generated: rogramming in,Wbnose,hphsWoy lmsil.cl,bnfvrWh Jfnsh,iWmvrvhv pW

Prompt: "The language"
Generated: he languageblrl-or,btny,nc,rrwv,lwyncsnhw,wbw,hprrncnrbns pll
