In [1]:
using PyCall
torch = pyimport("torch")
println("PyTorch version: ", torch.__version__)

PyTorch version: 2.2.2


In [2]:
using PyCall
np = pyimport("numpy")
println("NumPy version: ", np.__version__)

NumPy version: 1.26.4


In [3]:
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle, glorot_uniform
using Statistics
using StatsBase: sample, Weights
using Random
using LinearAlgebra

In [4]:
using CUDA  # Leveraged CUDA.jl for GPU-accelerated tensor operations.
CUDA.versioninfo()

CUDA runtime 12.6, artifact installation
CUDA driver 12.7
NVIDIA driver 565.90.0

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+565.90

Julia packages: 
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.0
- LLVM: 16.0.6

Preferences:
- CUDA_Runtime_jll.version: 12.6

1 device:
  0: NVIDIA GeForce MX150 (sm_61, 3.918 GiB / 4.000 GiB available)


In [5]:
using ProgressMeter  # For training progress visualization
using Zygote  # For more advanced gradient computation
using NNlib  # For additional neural network functions
using BSON  # For model serialization
using PyCall   # To call Hugging Face Python APIs # Used PyCall to interface with Hugging Face Python libraries.

In [6]:
using Conda
# Conda.add("datasets")
# Conda.add("transformers")  # Install via Conda

In [7]:
# ---------------------------
# Hugging Face Pipelines Integration
# ---------------------------
# This section sets up Hugging Face pipelines via PyCall.
function init_huggingface_pipelines()
    # Use pyimport_conda to auto-install missing packages
    transformers = pyimport_conda("transformers", "transformers")
    huggingface_hub = pyimport_conda("huggingface_hub", "huggingface_hub")
    
    # Import Python modules
    os = pyimport("os")
    transformers = pyimport("transformers")
    huggingface_hub = pyimport("huggingface_hub")
    
    # Set the Hugging Face token as an environment variable
    HF_TOKEN = "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"
    os.environ["HF_TOKEN"] = HF_TOKEN
    
    # Login using the token
    huggingface_hub.login(token=HF_TOKEN)
    
    # Create pipelines for various NLP tasks
    text_classifier = transformers.pipeline("text-classification", 
                                              model="distilbert-base-uncased-finetuned-sst-2-english")
    intent_recognition_model = transformers.pipeline("zero-shot-classification", 
                                                     model="facebook/bart-large-mnli")
    ner_model = transformers.pipeline("ner", 
                                      model="dbmdz/bert-large-cased-finetuned-conll03-english")
    
    return (text_classifier, intent_recognition_model, ner_model)
end

# Test the Hugging Face pipelines with a sample text
function test_huggingface()
    text_classifier, intent_recognition_model, ner_model = init_huggingface_pipelines()
    
    sample_text = "I love Julia and machine learning!"
    
    println("Text Classification:")
    println(text_classifier(sample_text))
    
    candidate_labels = ["positive", "negative", "neutral"]
    println("\nZero-Shot Classification:")
    println(intent_recognition_model(sample_text, candidate_labels))
    
    println("\nNamed Entity Recognition:")
    println(ner_model(sample_text))
end

# The functions init_huggingface_pipelines() and test_huggingface() set up and test various Hugging Face pipelines (e.g. text classification, zero-shot classification, and NER).

test_huggingface (generic function with 1 method)

In [8]:
# ---------------------------
# Hugging Face Datasets Integration
# ---------------------------
# Loads a dataset from the Hugging Face Hub.
function load_hf_dataset(dataset_name::String, config_name::String)
    datasets = pyimport("datasets")
    # Download the dataset (here we assume a split named "train")
    dataset = datasets.load_dataset(dataset_name, config_name)
    # Convert the "train" split into a Julia array of strings.
    # (Assuming the text is stored in a field called "text")
    texts = [string(x["text"]) for x in dataset["train"]]
    return texts
end

# The new function load_hf_dataset(dataset_name, config_name) uses the Hugging Face datasets library (via PyCall) to load a dataset (here, “wikitext-2-raw-v1”) and returns the training texts.

load_hf_dataset (generic function with 1 method)

In [9]:
# Token Embeddings: 
    # Learned embeddings with scaling by `√embed_dim`.

# Positional Embeddings:
    # Options: Learned (`LearnedPositionalEmbedding`) or Sinusoidal.
    # RoPE (Rotary Positional Embeddings): Applied rotation to query/key tensors.
    # ALiBi: Added attention bias based on token distances.

In [10]:
# ---------------------------
# Enhanced Token and Positional Embeddings with Weight Tying
# ---------------------------
struct TokenEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
    scale_factor::Float32  # Scale embeddings by sqrt(dim)
end

TokenEmbedding(vocab_size::Int, embed_dim::Int; dropout_prob=0.1, scale=true) = 
    TokenEmbedding(glorot_uniform(embed_dim, vocab_size), Dropout(dropout_prob), scale ? sqrt(Float32(embed_dim)) : 1.0f0)

function (te::TokenEmbedding)(x::AbstractArray{Int})
    # Returns embedding with dropout applied and optional scaling
    return te.dropout(te.emb[:, x] .* te.scale_factor)
end

# Learned positional embeddings
struct LearnedPositionalEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
    max_seq_len::Int  # Store max sequence length for potential extrapolation
end

LearnedPositionalEmbedding(seq_len::Int, embed_dim::Int; dropout_prob=0.1) = 
    LearnedPositionalEmbedding(glorot_uniform(embed_dim, seq_len), Dropout(dropout_prob), seq_len)

function (pe::LearnedPositionalEmbedding)(T::Int)
    if T > pe.max_seq_len
        orig_emb = pe.emb[:, 1:pe.max_seq_len]
        extra_positions = T - pe.max_seq_len
        last_pos_diff = orig_emb[:, end] - orig_emb[:, end-1]
        extra_emb = hcat([orig_emb[:, end] .+ i .* last_pos_diff for i in 1:extra_positions]...)
        extended_emb = hcat(orig_emb, extra_emb)
        return pe.dropout(extended_emb)
    else
        return pe.dropout(pe.emb[:, 1:T])
    end
end

# Sinusoidal positional embeddings (alternative to learned)
function sinusoidal_position_embedding(seq_len::Int, embed_dim::Int)
    pe = zeros(Float32, embed_dim, seq_len)
    for pos in 1:seq_len
        for i in 0:2:embed_dim-1
            freq = 1.0 / (10000.0^(i/embed_dim))
            pe[i+1, pos] = sin(pos * freq)
            if i+1 < embed_dim
                pe[i+2, pos] = cos(pos * freq)
            end
        end
    end
    return pe
end

struct SinusoidalPositionalEmbedding
    emb::AbstractArray{Float32,2}
    dropout::Dropout
    max_seq_len::Int
    embed_dim::Int
end

function SinusoidalPositionalEmbedding(seq_len::Int, embed_dim::Int; dropout_prob=0.1)
    emb = sinusoidal_position_embedding(seq_len, embed_dim)
    return SinusoidalPositionalEmbedding(emb, Dropout(dropout_prob), seq_len, embed_dim)
end

function (pe::SinusoidalPositionalEmbedding)(T::Int)
    if T > pe.max_seq_len
        extended_emb = sinusoidal_position_embedding(T, pe.embed_dim)
        return pe.dropout(extended_emb)
    else
        return pe.dropout(pe.emb[:, 1:T])
    end
end

# ALiBi positional encoding
struct ALiBiPositionalEncoding
    slopes::AbstractArray{Float32, 1}
    max_seq_len::Int
end

# In the ALiBiPositionalEncoding constructor:
function ALiBiPositionalEncoding(num_heads::Int, max_seq_len::Int)
    base = 2^(-(8/num_heads))
    slopes = [base^(i-1) for i in 1:num_heads]
    # Transfer to GPU if available
    return ALiBiPositionalEncoding(CUDA.functional() ? CuArray(slopes) : slopes, max_seq_len)
end

ALiBiPositionalEncoding

In [11]:
# ---------------------------
# Rotary Position Embeddings (RoPE)
# ---------------------------
struct RotaryPositionEmbedding
    dim::Int
    max_seq_len::Int
    freqs_cos::AbstractArray{Float32, 2}
    freqs_sin::AbstractArray{Float32, 2}
    base::Float32
end

function RotaryPositionEmbedding(dim::Int, max_seq_len::Int; base::Float32=10000.0f0)
    theta = base .^ (-(0:2:dim-2) ./ dim)
    freqs = repeat(theta, 1, max_seq_len) .* repeat(reshape(1:max_seq_len, 1, :), length(theta), 1)
    freqs_cos = cos.(freqs)
    freqs_sin = sin.(freqs)
    return RotaryPositionEmbedding(dim, max_seq_len, freqs_cos, freqs_sin, base)
end

function extend_rope_context(rope::RotaryPositionEmbedding, new_seq_len::Int)
    if new_seq_len <= rope.max_seq_len
        return rope
    end
    theta = rope.base .^ (-(0:2:rope.dim-2) ./ rope.dim)
    freqs = repeat(theta, 1, new_seq_len) .* repeat(reshape(1:new_seq_len, 1, :), length(theta), 1)
    freqs_cos = cos.(freqs)
    freqs_sin = sin.(freqs)
    return RotaryPositionEmbedding(rope.dim, new_seq_len, freqs_cos, freqs_sin, rope.base)
end

function rotate_half(x::AbstractArray{Float32, 4})
    head_dim, seq_len, batch_size, num_heads = size(x)
    half_dim = head_dim ÷ 2
    x1 = x[1:2:head_dim, :, :, :]
    x2 = x[2:2:head_dim, :, :, :]
    return cat(-x2, x1, dims=1)
end

function apply_rotary_pos_emb(q::AbstractArray{Float32, 4}, k::AbstractArray{Float32, 4}, 
                              freqs_cos::AbstractArray{Float32, 2}, freqs_sin::AbstractArray{Float32, 2}, 
                              T::Int)
    head_dim, seq_len, batch_size, num_heads = size(q)
    effective_dim = min(head_dim, size(freqs_cos, 1))
    effective_seq_len = min(T, size(freqs_cos, 2))
    cos_pos = freqs_cos[1:effective_dim, 1:effective_seq_len]
    sin_pos = freqs_sin[1:effective_dim, 1:effective_seq_len]
    cos_pos = reshape(cos_pos, effective_dim, effective_seq_len, 1, 1)
    sin_pos = reshape(sin_pos, effective_dim, effective_seq_len, 1, 1)
    if effective_dim < head_dim
        cos_pos = vcat(cos_pos, zeros(Float32, head_dim - effective_dim, effective_seq_len, 1, 1))
        sin_pos = vcat(sin_pos, zeros(Float32, head_dim - effective_dim, effective_seq_len, 1, 1))
    end
    if effective_seq_len < seq_len
        cos_pos = cat(cos_pos, zeros(Float32, head_dim, seq_len - effective_seq_len, 1, 1), dims=2)
        sin_pos = cat(sin_pos, zeros(Float32, head_dim, seq_len - effective_seq_len, 1, 1), dims=2)
    end
    q_rot = q .* cos_pos .+ rotate_half(q) .* sin_pos
    k_rot = k .* cos_pos .+ rotate_half(k) .* sin_pos
    return q_rot, k_rot
end 

apply_rotary_pos_emb (generic function with 1 method)

In [12]:
# Multi-Head Attention:
    # Flash Attention: Optimized memory usage via block-wise computation.
    # Causal Masking: Used `CUDA.tril` for GPU-compatible triangular masks.
    # Key-Value (KV) Caching: Enabled autoregressive generation.

In [13]:
# ---------------------------
# Advanced Multi-Head Self-Attention with Flash Attention and ALiBi
# ---------------------------
struct MultiHeadAttention
    W_q::Dense
    W_k::Dense
    W_v::Dense
    W_o::Dense
    num_heads::Int
    head_dim::Int
    attn_dropout::Dropout
    resid_dropout::Dropout
    rope::Union{RotaryPositionEmbedding, Nothing}
    alibi::Union{ALiBiPositionalEncoding, Nothing}
    use_flash_attn::Bool
    kv_cache::Dict{Symbol, Any}
end

function MultiHeadAttention(embed_dim::Int, num_heads::Int; 
                           attn_dropout_prob=0.1, 
                           resid_dropout_prob=0.1,
                           use_rope::Bool=true,
                           use_alibi::Bool=false,
                           max_seq_len::Int=1024,
                           use_flash_attn::Bool=false)
    head_dim = div(embed_dim, num_heads)
    @assert head_dim * num_heads == embed_dim "embed_dim must be divisible by num_heads"
    rope = use_rope ? RotaryPositionEmbedding(head_dim, max_seq_len) : nothing
    alibi = use_alibi ? ALiBiPositionalEncoding(num_heads, max_seq_len) : nothing
    kv_cache = Dict{Symbol, Any}(:keys => nothing, :values => nothing, :seq_len => 0)
    return MultiHeadAttention(
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        Dense(embed_dim, embed_dim),
        num_heads,
        head_dim,
        Dropout(attn_dropout_prob),
        Dropout(resid_dropout_prob),
        rope,
        alibi,
        use_flash_attn,
        kv_cache
    )
end

function compute_attention(Q::AbstractArray{Float32, 4}, K::AbstractArray{Float32, 4}, V::AbstractArray{Float32, 4}, 
                           attn_dropout::Dropout, causal_mask::Bool=true, 
                           alibi::Union{ALiBiPositionalEncoding, Nothing}=nothing)
    head_dim, T, batch_size, num_heads = size(Q)
    Q_flat = permutedims(Q, (2, 3, 4, 1))
    K_flat = permutedims(K, (2, 3, 4, 1))
    V_flat = permutedims(V, (2, 3, 4, 1))
    Q_flat = reshape(Q_flat, T, batch_size * num_heads, head_dim)
    K_flat = reshape(K_flat, T, batch_size * num_heads, head_dim)
    V_flat = reshape(V_flat, T, batch_size * num_heads, head_dim)
    Q_flat = permutedims(Q_flat, (3, 1, 2))
    K_flat = permutedims(K_flat, (3, 1, 2))
    V_flat = permutedims(V_flat, (3, 1, 2))
    scores = batched_mul(permutedims(Q_flat, (2, 1, 3)), K_flat) ./ sqrt(Float32(head_dim))
    if !isnothing(alibi)
        
        # Before (CPU arrays):
        pos_bias = zeros(Float32, T, T, num_heads)
        for h in 1:num_heads
            for i in 1:T
                for j in 1:T
                    pos_bias[i, j, h] = -abs(i - j) * alibi.slopes[h]
                end
            end
        end

        # After (GPU arrays):
        slopes_gpu = alibi.slopes |> gpu  # Ensure slopes are on GPU
        i_j = reshape(1:T, 1, :) .- reshape(1:T, :, 1)  # Compute |i-j|
        pos_bias = -abs.(i_j) .* reshape(slopes_gpu, 1, 1, :)
        pos_bias = CuArray(pos_bias)  # Move to GPU explicitly

        pos_bias = reshape(pos_bias, T, T, num_heads, 1)
        pos_bias = repeat(pos_bias, 1, 1, 1, batch_size)
        pos_bias = reshape(pos_bias, T, T, batch_size * num_heads)
        scores = scores .+ pos_bias
    end
    if causal_mask
        # Before (CPU mask):
        mask = tril(ones(Float32, T, T))
        # After (GPU mask):
        mask = CUDA.tril(CUDA.ones(Float32, T, T))
        mask = reshape(mask, T, T, 1)
        scores = scores .* mask .+ (1.0f0 .- mask) .* -1.0f7
    end
    attn_weights = softmax(scores, dims=2)
    attn_weights = attn_dropout(attn_weights)
    out = batched_mul(attn_weights, permutedims(V_flat, (2, 1, 3)))
    out = permutedims(out, (2, 1, 3))
    out = reshape(out, head_dim, T, batch_size, num_heads)
    return out
end

function flash_attention(Q::AbstractArray{Float32}, K::AbstractArray{Float32}, V::AbstractArray{Float32}, 
                         attn_dropout::Dropout, causal_mask::Bool=true, block_size::Int=128,
                         alibi::Union{ALiBiPositionalEncoding, Nothing}=nothing)
    head_dim, T, batch_size, num_heads = size(Q)
    O = zeros(Float32, head_dim, T, batch_size, num_heads)
    L = zeros(Float32, 1, T, batch_size, num_heads)
    m = fill(-Inf32, 1, T, batch_size, num_heads)
    for b_q in 1:block_size:T
        e_q = min(b_q + block_size - 1, T)
        Q_block = Q[:, b_q:e_q, :, :]
        max_k_idx = causal_mask ? e_q : T
        for b_k in 1:block_size:max_k_idx
            e_k = min(b_k + block_size - 1, max_k_idx)
            K_block = K[:, b_k:e_k, :, :]
            V_block = V[:, b_k:e_k, :, :]
            Q_perm = permutedims(Q_block, (3, 4, 2, 1))
            K_perm = permutedims(K_block, (3, 4, 2, 1))
            V_perm = permutedims(V_block, (3, 4, 2, 1))
            scores = batched_mul(Q_perm, permutedims(K_perm, (1, 2, 4, 3))) ./ sqrt(Float32(head_dim))
            if !isnothing(alibi)
                for h in 1:num_heads
                    for i in b_q:e_q
                        for j in b_k:e_k
                            if i <= T && j <= max_k_idx
                                rel_i = i - b_q + 1
                                rel_j = j - b_k + 1
                                scores[:, h, rel_i, rel_j] .-= abs(i - j) * alibi.slopes[h]
                            end
                        end
                    end
                end
            end
            if causal_mask
                for i in b_q:e_q
                    for j in b_k:e_k
                        if j > i
                            rel_i = i - b_q + 1
                            rel_j = j - b_k + 1
                            scores[:, :, rel_i, rel_j] .= -1.0f7
                        end
                    end
                end
            end
            block_max = maximum(scores, dims=4)
            for i in 1:(e_q-b_q+1)
                idx = b_q + i - 1
                old_m = m[1, idx, :, :]
                new_m = max.(old_m, block_max[:, :, i, 1])
                scale_factor = exp.(old_m .- new_m)
                O[:, idx, :, :] .*= scale_factor
                L[1, idx, :, :] .*= scale_factor
                m[1, idx, :, :] = new_m
            end
            for i in 1:(e_q-b_q+1)
                idx = b_q + i - 1
                exp_scores = exp.(scores[:, :, i, :] .- reshape(m[1, idx, :, :], batch_size, num_heads, 1))
                exp_scores = attn_dropout(exp_scores)
                weighted_values = batched_mul(exp_scores, V_perm)
                weighted_values = permutedims(weighted_values, (4, 1, 2, 3))
                O[:, idx, :, :] .+= weighted_values[:, :, :, 1]
                L[1, idx, :, :] .+= sum(exp_scores, dims=3)[:, :, 1]
            end
        end
    end
    O ./= L
    return O
end

function (mha::MultiHeadAttention)(x::AbstractArray{Float32,3}, 
                                   mask::Union{Nothing, AbstractArray{Float32}}=nothing;
                                   use_cache::Bool=false,
                                   is_causal::Bool=true)
    embed_dim, T, batch_size = size(x)
    Q = mha.W_q(x)
    K = mha.W_k(x)
    V = mha.W_v(x)
    function split_heads(t)
        return reshape(t, mha.head_dim, mha.num_heads, T, batch_size)
    end
    Qh = permutedims(split_heads(Q), (1, 3, 4, 2))
    Kh = permutedims(split_heads(K), (1, 3, 4, 2))
    Vh = permutedims(split_heads(V), (1, 3, 4, 2))
    if !isnothing(mha.rope)
        Qh, Kh = apply_rotary_pos_emb(Qh, Kh, mha.rope.freqs_cos, mha.rope.freqs_sin, T)
    end
    if use_cache
        if isnothing(mha.kv_cache[:keys])
            mha.kv_cache[:keys] = Kh
            mha.kv_cache[:values] = Vh
            mha.kv_cache[:seq_len] = T
        else
            prev_len = mha.kv_cache[:seq_len]
            mha.kv_cache[:keys] = cat(mha.kv_cache[:keys], Kh, dims=2)
            mha.kv_cache[:values] = cat(mha.kv_cache[:values], Vh, dims=2)
            mha.kv_cache[:seq_len] += T
            Kh = mha.kv_cache[:keys]
            Vh = mha.kv_cache[:values]
        end
    end
    if mha.use_flash_attn
        out = flash_attention(Qh, Kh, Vh, mha.attn_dropout, is_causal, 128, mha.alibi)
    else
        out = compute_attention(Qh, Kh, Vh, mha.attn_dropout, is_causal, mha.alibi)
    end
    out = reshape(permutedims(out, (1, 2, 4, 3)), embed_dim, T, batch_size)
    return mha.resid_dropout(mha.W_o(out))
end

In [14]:
# ---------------------------
# Improved MLP with SwiGLU Activation
# ---------------------------
struct SwiGLU
    W_up::Dense
    W_gate::Dense
    W_down::Dense
    dropout::Dropout
end

function SwiGLU(dim::Int, hidden_dim::Int; dropout_prob=0.1)
    return SwiGLU(
        Dense(dim, hidden_dim),
        Dense(dim, hidden_dim),
        Dense(hidden_dim, dim),
        Dropout(dropout_prob)
    )
end

function (swiglu::SwiGLU)(x::AbstractArray{Float32})
    up = swiglu.W_up(x)
    gate = swiglu.W_gate(x)
    swish_gate = gate .* sigmoid.(gate)
    hidden = swish_gate .* up
    return swiglu.dropout(swiglu.W_down(hidden))
end

struct GeGLU
    W_up::Dense
    W_gate::Dense
    W_down::Dense
    dropout::Dropout
end

function GeGLU(dim::Int, hidden_dim::Int; dropout_prob=0.1)
    return GeGLU(
        Dense(dim, hidden_dim),
        Dense(dim, hidden_dim),
        Dense(hidden_dim, dim),
        Dropout(dropout_prob)
    )
end

function (geglu::GeGLU)(x::AbstractArray{Float32})
    up = geglu.W_up(x)
    gate = geglu.W_gate(x)
    gelu_gate = NNlib.gelu.(gate)
    hidden = gelu_gate .* up
    return geglu.dropout(geglu.W_down(hidden))
end

In [15]:
# ---------------------------
# Enhanced Transformer Block with Pre/Post Norm Options
# ---------------------------
struct TransformerBlock
    attn::MultiHeadAttention
    norm1::LayerNorm
    ffn::Union{SwiGLU, GeGLU}
    norm2::LayerNorm
    pre_norm::Bool
    residual_scale::Float32
end

function TransformerBlock(embed_dim::Int, num_heads::Int; 
                         ffn_dim_mult=4, 
                         dropout_prob=0.1,
                         pre_norm=true,
                         use_rope=true,
                         use_alibi=false,
                         max_seq_len=1024,
                         use_flash_attn=false,
                         ffn_type=:swiglu,
                         residual_scale=1.0f0)
    attn = MultiHeadAttention(embed_dim, num_heads, 
                             attn_dropout_prob=dropout_prob, 
                             resid_dropout_prob=dropout_prob,
                             use_rope=use_rope,
                             use_alibi=use_alibi,
                             max_seq_len=max_seq_len,
                             use_flash_attn=use_flash_attn)
    norm1 = LayerNorm(embed_dim)
    if ffn_type == :swiglu
        ffn = SwiGLU(embed_dim, ffn_dim_mult * embed_dim, dropout_prob=dropout_prob)
    else
        ffn = GeGLU(embed_dim, ffn_dim_mult * embed_dim, dropout_prob=dropout_prob)
    end
    norm2 = LayerNorm(embed_dim)
    return TransformerBlock(attn, norm1, ffn, norm2, pre_norm, Float32(residual_scale))
end

function (block::TransformerBlock)(x::AbstractArray{Float32,3}; 
                                  mask=nothing, 
                                  use_cache=false,
                                  is_causal=true)
    if block.pre_norm
        attn_output = block.attn(block.norm1(x), mask; use_cache=use_cache, is_causal=is_causal)
        x = x + block.residual_scale .* attn_output
        x = x + block.residual_scale .* block.ffn(block.norm2(x))
    else
        x = block.norm1(x + block.residual_scale .* block.attn(x, mask; use_cache=use_cache, is_causal=is_causal))
        x = block.norm2(x + block.residual_scale .* block.ffn(x))
    end
    return x
end

In [16]:
# ---------------------------
# Enhanced GPT Model with Advanced Features
# ---------------------------
struct AdvancedGPT
    token_emb::TokenEmbedding
    pos_emb::Union{LearnedPositionalEmbedding, SinusoidalPositionalEmbedding, Nothing}
    blocks::Vector{TransformerBlock}
    ln_f::LayerNorm
    head           # no type annotation, so head can be any callable layer
    emb_dropout::Dropout
    config::Dict{Symbol, Any}
    weight_tying::Bool
end

struct TiedDense
    tied_matrix::AbstractArray{Float32,2}
end

function (td::TiedDense)(x)
    # Reshape 3D input (embed_dim, seq_len, batch_size) → 2D (embed_dim, seq_len * batch_size)
    x_2d = reshape(x, size(x, 1), :)
    # Perform matrix multiplication
    logits_2d = td.tied_matrix' * x_2d
    # Reshape back to 3D (vocab_size, seq_len, batch_size)
    return reshape(logits_2d, size(td.tied_matrix, 2), size(x, 2), size(x, 3))
end

function AdvancedGPT(;
    vocab_size::Int,
    max_seq_len::Int,
    embed_dim::Int,
    num_heads::Int,
    num_layers::Int,
    dropout_prob::Float64=0.1,
    pos_emb_type::Symbol=:learned,
    pre_norm::Bool=true,
    use_rope::Bool=true,
    use_alibi::Bool=false,
    use_flash_attn::Bool=false,
    ffn_dim_mult::Int=4,
    ffn_type::Symbol=:swiglu,
    weight_tying::Bool=true,
    residual_scale::Float64=1.0
)
    config = Dict{Symbol, Any}(
        :vocab_size => vocab_size,
        :max_seq_len => max_seq_len,
        :embed_dim => embed_dim,
        :num_heads => num_heads,
        :num_layers => num_layers,
        :dropout_prob => dropout_prob,
        :pos_emb_type => pos_emb_type,
        :pre_norm => pre_norm,
        :use_rope => use_rope,
        :use_flash_attn => use_flash_attn,
        :ffn_dim_mult => ffn_dim_mult
    )
    token_emb = TokenEmbedding(vocab_size, embed_dim, dropout_prob=0.0)
    if pos_emb_type == :learned
        pos_emb = LearnedPositionalEmbedding(max_seq_len, embed_dim, dropout_prob=0.0)
    else
        pos_emb = SinusoidalPositionalEmbedding(max_seq_len, embed_dim, dropout_prob=0.0)
    end
    emb_dropout = Dropout(dropout_prob)
    blocks = [
        TransformerBlock(
            embed_dim, 
            num_heads, 
            ffn_dim_mult=ffn_dim_mult, 
            dropout_prob=dropout_prob, 
            pre_norm=pre_norm,
            use_rope=use_rope,
            use_alibi=use_alibi,
            max_seq_len=max_seq_len,
            use_flash_attn=use_flash_attn,
            ffn_type=ffn_type,
            residual_scale=residual_scale
        ) for _ in 1:num_layers
    ]
    
    ln_f = LayerNorm(embed_dim)
    # In AdvancedGPT constructor:
    head = weight_tying ? TiedDense(token_emb.emb) : Dense(embed_dim, vocab_size)
    return AdvancedGPT(token_emb, pos_emb, blocks, ln_f, head, emb_dropout, config, weight_tying)
end

function (model::AdvancedGPT)(x::AbstractArray{Int}; return_all_layers=false)
    if ndims(x) == 1
        T = length(x)
        batch_size = 1
        x_reshaped = reshape(x, T, 1)
    else
        T, batch_size = size(x)
        x_reshaped = x
    end
    tok_emb = model.token_emb(x_reshaped)
    pos_emb = model.pos_emb(T)
    pos_emb_expanded = repeat(pos_emb, outer=[1, 1, batch_size])
    h = model.emb_dropout(tok_emb .+ pos_emb_expanded)
    activations = return_all_layers ? [h] : nothing
    for block in model.blocks
        h = block(h)
        if return_all_layers
            push!(activations, h)
        end
    end
    h = model.ln_f(h)
    # Reshape for TiedDense (embed_dim, T, batch_size) → (embed_dim, T * batch_size)
    h_reshaped = reshape(h, size(h, 1), :)
    logits = model.head(h_reshaped)
    # Reshape back to (vocab_size, T, batch_size)
    logits = reshape(logits, size(logits, 1), size(h, 2), size(h, 3))
    logits = permutedims(logits, (3, 1, 2))
    logits = model.head(h)
    logits = permutedims(logits, (3, 1, 2))
    if return_all_layers
        return logits, activations
    else
        return logits
    end
end

In [17]:
# ---------------------------
# Text Generation Utilities
# ---------------------------
function softmax_safe(x)
    if isa(x, Number)
        return [1.0f0]
    end
    x_vec = vec(collect(Float32, x))
    exp_x = exp.(x_vec .- maximum(x_vec))
    return exp_x ./ sum(exp_x)
end

function generate(model::AdvancedGPT, 
                 prompt::Vector{Int}; 
                 max_new_tokens::Int=100, 
                 temperature::Float32=1.0f0,
                 top_k::Int=0, 
                 top_p::Float32=0.9f0)
    max_len = model.config[:max_seq_len]
    tokens = copy(prompt)
    for _ in 1:max_new_tokens
        context = tokens[max(1, length(tokens) - max_len + 1):end]
        context = reshape(context, :, 1)
        logits = model(context)
        next_token_logits = Array{Float32}(vec(logits[1, :, end]))
        if temperature > 0
            next_token_logits ./= temperature
        else
            _, next_token = findmax(next_token_logits)
            push!(tokens, next_token)
            continue
        end
        if top_k > 0
            sorted_indices = sortperm(next_token_logits, rev=true)
            next_token_logits[sorted_indices[top_k+1:end]] .= -Inf32
        end
        if top_p < 1.0
            sorted_indices = sortperm(next_token_logits, rev=true)
            sorted_logits = next_token_logits[sorted_indices]
            exp_logits = exp.(sorted_logits .- maximum(sorted_logits))
            probs = exp_logits ./ sum(exp_logits)
            cumulative_probs = cumsum(probs)
            cutoff_idx = findfirst(cumulative_probs .> top_p)
            if !isnothing(cutoff_idx) && cutoff_idx > 1
                next_token_logits[sorted_indices[cutoff_idx:end]] .= -Inf32
            end
        end
        exp_logits = exp.(next_token_logits .- maximum(next_token_logits))
        probs = exp_logits ./ sum(exp_logits)
        next_token = sample(1:length(probs), Weights(probs))
        push!(tokens, next_token)
    end
    return tokens
end

generate (generic function with 1 method)

In [18]:
# ---------------------------
# Tokenization (Simple Character-level for demonstration)
# ---------------------------
struct CharTokenizer
    vocab::Dict{Char, Int}
    idx_to_char::Dict{Int, Char}
end

function CharTokenizer(text::String)
    unique_chars = unique(collect(text))
    vocab = Dict(char => i for (i, char) in enumerate(unique_chars))
    idx_to_char = Dict(i => char for (i, char) in enumerate(unique_chars))
    return CharTokenizer(vocab, idx_to_char)
end

function encode(tokenizer::CharTokenizer, text::String)
    return [tokenizer.vocab[c] for c in text if haskey(tokenizer.vocab, c)]
end

function decode(tokenizer::CharTokenizer, indices::Vector{Int})
    return join([tokenizer.idx_to_char[i] for i in indices if haskey(tokenizer.idx_to_char, i)])
end

decode (generic function with 1 method)

In [19]:
# ---------------------------
# Training Loop
# ---------------------------
function train_gpt!(model::AdvancedGPT, data::Vector{Int}, 
                   opt::Flux.Optimise.AbstractOptimiser;
                   batch_size::Int=16, 
                   bptt::Int=64,
                   epochs::Int=1,
                   lr::Float64=3e-4,
                   clip_norm::Float64=1.0)
    
    data_length = length(data)
    
    function get_batch()
        starts = rand(1:(data_length - bptt), batch_size)
        xs = zeros(Int, bptt, batch_size)
        ys = zeros(Int, bptt, batch_size)
        for (i, start) in enumerate(starts)
            end_idx = start + bptt - 1
            xs[:, i] = data[start:end_idx]
            ys[:, i] = data[(start+1):(end_idx+1)]
        end
        return xs, ys
    end
    
    steps_per_epoch = div(data_length, batch_size * bptt)
    total_steps = steps_per_epoch * epochs
    params = Flux.params(model)
    
    for epoch in 1:epochs
        epoch_loss = 0.0
        for step in 1:steps_per_epoch
            x_batch, y_batch = get_batch()
            loss, grads = Flux.withgradient(params) do
                logits = model(x_batch)
                logits_flat = reshape(permutedims(logits, (2, 3, 1)), :, batch_size)
                targets_flat = reshape(y_batch, :)
                return crossentropy(logits_flat, targets_flat)
            end
            if clip_norm > 0
                Flux.Optimise.clip!(grads, clip_norm)
            end
            Flux.Optimise.update!(opt, params, grads)
            epoch_loss += loss
            if step % 10 == 0
                println("Epoch: $epoch, Step: $step/$steps_per_epoch, Loss: $(loss)")
            end
        end
        avg_loss = epoch_loss / steps_per_epoch
        println("Epoch $epoch completed. Average loss: $avg_loss")
    end
end

train_gpt! (generic function with 1 method)

In [25]:
# ---------------------------
# Optimized Training Loop
# ---------------------------
function optimized_train_gpt!(model::AdvancedGPT, data::Vector{Int}, 
                              opt::Flux.Optimise.AbstractOptimiser;
                              batch_size::Int=16, 
                              bptt::Int=64,
                              epochs::Int=1,
                              lr::Float64=3e-4,
                              clip_norm::Float64=1.0)
    
    data_length = length(data)
    
    function get_batch()
        # Ensure we don't go out of bounds
        max_start = data_length - bptt
        if max_start < 1
            error("Data length is too short for the given bptt value")
        end
        
        starts = rand(1:max_start, batch_size)
        xs = zeros(Int, bptt, batch_size)
        ys = zeros(Int, bptt, batch_size)
        for (i, start) in enumerate(starts)
            end_idx = start + bptt - 1
            xs[:, i] = data[start:end_idx]
            ys[:, i] = data[(start+1):(end_idx+1)]
        end
        return xs, ys
    end
    
    # Calculate steps more accurately
    steps_per_epoch = max(1, div(data_length - bptt, batch_size))
    total_steps = steps_per_epoch * epochs
    
    # Set learning rate if the optimizer supports it
    if hasfield(typeof(opt), :eta)
        opt.eta = lr
    end
    
    for epoch in 1:epochs
        epoch_loss = 0.0
        for step in 1:steps_per_epoch
            x_batch, y_batch = get_batch()
            
            # Move data to CPU to avoid GPU kernel issues
            x_batch = Array(x_batch)
            y_batch = Array(y_batch)
            
            # Use explicit gradient computation with Flux
            # Pass the model directly instead of its parameters
            loss, grads = Flux.withgradient(model) do m
                # Forward pass through the model
                logits = m(x_batch)
                # Reshape logits for crossentropy calculation
                # The model output shape is (seq_len, batch_size, vocab_size)
                logits_flat = reshape(logits, size(logits, 3), :)
                targets_flat = reshape(y_batch, :)
                return Flux.crossentropy(logits_flat, targets_flat)
            end
            
            # Apply gradient clipping if needed
            if clip_norm > 0
                Flux.Optimise.clip!(grads, clip_norm)
            end
            
            # Update parameters
            Flux.Optimise.update!(opt, Flux.params(model), grads)
            
            epoch_loss += loss
            if step % 10 == 0
                println("Epoch: $epoch, Step: $step/$steps_per_epoch, Loss: $(loss)")
            end
        end
        avg_loss = epoch_loss / steps_per_epoch
        println("Epoch $epoch completed. Average loss: $avg_loss")
    end
end

# Ensure CUDA is functional if available
if CUDA.functional()
    # Use the recommended approach for scalar operations
    println("CUDA is available and functional")
else
    println("CUDA is not available, using CPU")
end

# Define and initialize the model
vocab_size = 10000  # Example vocabulary size
max_seq_len = 128  # Example maximum sequence length
embed_dim = 256  # Example embedding dimension
num_heads = 8  # Example number of attention heads
num_layers = 6  # Example number of transformer layers
dropout_prob = 0.1  # Example dropout probability

model = AdvancedGPT(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout_prob=dropout_prob
)

# Move model to CPU to avoid GPU kernel issues
model = cpu(model)

# Example data (replace with your actual data)
# Ensure data is long enough for training
data = collect(1:10000)  # Example data

# Define the optimizer with the specified learning rate
opt = Flux.Optimise.Adam(3e-4)

# Run the training for 5 epochs
CUDA.@allowscalar optimized_train_gpt!(model, data, opt, epochs=5)

CUDA is available and functional


ErrorException: `llvmcall` requires the compiler