Official transformer architecture support for Dlib #3124
Conversation
    {
        const double scale = std::sqrt(area/static_cast<double>(rect.area()));
        return centered_rect(rect, std::lround(rect.width()*scale), std::lround(rect.height()*scale));
        // Le compilateur sait maintenant que rect_area != 0
Guessing comments should be in English :)
    static constexpr long EMBEDDING_DIM = embedding_dim;
    static constexpr long PATCH_SIZE = 4;   // 32/4 = 8x8 = 64 patches
    static constexpr long NUM_PATCHES = 64; // (32/4)^2
    static constexpr long DONT_USE_ClASS_TOKEN = 0;
Why lower case l in ClASS?
    using namespace std;
    using namespace dlib;

    // Signal handling for clean termination
This signal handling seems to be duplicated in various different examples. Could it be extracted into a common utility?
Official transformer architecture support for Dlib
This pull request consolidates and stabilizes the transformer-related layers and components developed throughout 2024–2025. It introduces official support for modern language modeling in Dlib, positioning the library as a reference implementation for building neural networks for natural language processing.
The extensions have been iteratively refined over the past year, with each component tested and validated across multiple architectures and use cases. This work establishes the foundation for upcoming multimodal capabilities, with active development underway for vision transformers and combined text-image processing.
Future releases will introduce examples demonstrating transformer architectures for image processing, followed by multimodal fusion combining textual and visual information. This PR represents an important milestone that could justify a new version of Dlib to mark the official introduction of these features.
Overview
This pull request introduces complete transformer architecture support to Dlib, enabling modern language modeling capabilities while maintaining Dlib's philosophy of simple APIs and production-ready implementations. All components are written in standard C++14 for cross-platform compatibility.
Major additions
Core architectural components
Attention mechanisms:
Padding-aware attention:
- `tril_padding_context` for dynamic per-sample padding mask coordination
- `loss_cross_entropy_per_logit` ignore index support
Specialized layers:
- `linear` layer with plane-wise matrix multiplication for sequence processing
- `rms_norm` layer implementing efficient RMS normalization
- `reshape_to` layer for dimension manipulation without data copying
- `token_embeddings` layer combining embedding lookup with positional encoding
- `tril` layer for triangular mask generation
- `transpose` and `multm_prev` layers for attention computation
- `dropout_rate` layer with configurable per-layer dropout schedules
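To make the plane-wise convention concrete, here is a minimal illustrative sketch (not the layer's actual implementation) of the operation the `linear` layer performs on a single sample: the `(sequence_length, d_in)` plane is multiplied by a learned `(d_in, d_out)` weight matrix, independently for every sample in the batch. The shapes and values below are arbitrary placeholders.

```cpp
#include <dlib/matrix.h>

int main()
{
    using namespace dlib;

    // Illustrative shapes only.
    const long seq_len = 8, d_in = 16, d_out = 32;

    // One sample's (rows, cols) plane of token embeddings, and a projection.
    matrix<float> tokens  = uniform_matrix<float>(seq_len, d_in, 0.5f);
    matrix<float> weights = uniform_matrix<float>(d_in, d_out, 0.1f);

    // The plane-wise `linear` layer applies this matrix product to every
    // (rows, cols) plane in the batch.
    matrix<float> projected = tokens * weights;   // (seq_len x d_out)

    return (projected.nr() == seq_len && projected.nc() == d_out) ? 0 : 1;
}
```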
Vision Transformer support:
- `patch_embeddings` layer for image-to-sequence conversion with configurable patch size

Advanced architectures:
Optimization infrastructure
AdamW optimizer (`dlib/dnn/solvers.h`):

Learning rate scheduling (`lr_scheduler`):
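A rough usage sketch follows. `dnn_trainer`, `set_learning_rate`, `set_min_learning_rate`, and `set_mini_batch_size` are existing Dlib APIs; the `adamw` constructor arguments are an assumption modeled on Dlib's existing `adam(weight_decay, momentum1, momentum2)` and may differ from the solver actually added in this PR, and the `lr_scheduler` interaction is not shown here.

```cpp
#include <dlib/dnn.h>

using namespace dlib;

// A small toy network; any dlib network type would do here.
using net_type = loss_multiclass_log<fc<10, relu<fc<32, input<matrix<float>>>>>>;

int main()
{
    net_type net;

    // Hypothetical constructor arguments (weight decay, momentum terms);
    // the real adamw signature in dlib/dnn/solvers.h may differ.
    dnn_trainer<net_type, adamw> trainer(net, adamw(0.01f, 0.9f, 0.999f));

    trainer.set_learning_rate(1e-3);       // existing dnn_trainer API
    trainer.set_min_learning_rate(1e-6);
    trainer.set_mini_batch_size(64);

    // ... trainer.train(samples, labels); as with any other dlib solver.
    return 0;
}
```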
Language modeling utilities

Dataset preparation (`language_model_data.h`):
- `build_single_token_prediction_dataset()` for autoregressive training
- `build_multi_token_prediction_dataset()` for sequence-to-sequence tasks
- `shuffle_training_dataset()` for data randomization
- `augment_training_dataset()` for noise injection and robustness improvement
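The exact signatures of these helpers are not shown in this summary, so the snippet below only illustrates the kind of sliding-window construction that `build_single_token_prediction_dataset()` performs for autoregressive training, written against plain `std::vector`: every window of `max_seq_len` tokens becomes an input, and the token that follows it becomes the label. The function name and types here are a hypothetical stand-in, not the library's API.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative stand-in for build_single_token_prediction_dataset();
// the real helper in language_model_data.h likely differs in signature.
std::vector<std::pair<std::vector<int>, int>>
make_next_token_dataset(const std::vector<int>& tokens, std::size_t max_seq_len)
{
    std::vector<std::pair<std::vector<int>, int>> dataset;
    if (tokens.size() <= max_seq_len)
        return dataset;

    for (std::size_t i = 0; i + max_seq_len < tokens.size(); ++i)
    {
        // Input: tokens[i, i+max_seq_len); label: the token right after it.
        std::vector<int> window(tokens.begin() + i, tokens.begin() + i + max_seq_len);
        dataset.emplace_back(std::move(window), tokens[i + max_seq_len]);
    }
    return dataset;
}
```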
Inference management:
- `inference_context` class for autoregressive generation with a sliding window
Evaluation metrics:
- `compute_text_similarity()` combining all metrics
Preprocessing:
- `detect_file_type()` supporting 30+ formats via magic numbers and entropy analysis
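For readers unfamiliar with magic-number detection, the fragment below sketches only that half of the idea with a few common signatures; it is a plain-C++ illustration under assumed behavior, not the actual `detect_file_type()` implementation, which reportedly covers 30+ formats and adds entropy analysis.

```cpp
#include <fstream>
#include <string>
#include <vector>

// Minimal magic-number check for a handful of formats (illustration only).
std::string guess_file_type(const std::string& path)
{
    std::ifstream in(path, std::ios::binary);
    std::vector<unsigned char> head(8, 0);
    in.read(reinterpret_cast<char*>(head.data()),
            static_cast<std::streamsize>(head.size()));

    if (head[0] == 0x89 && head[1] == 'P' && head[2] == 'N' && head[3] == 'G')  return "png";
    if (head[0] == 0xFF && head[1] == 0xD8 && head[2] == 0xFF)                  return "jpeg";
    if (head[0] == '%' && head[1] == 'P' && head[2] == 'D' && head[3] == 'F')   return "pdf";
    if (head[0] == 'P' && head[1] == 'K' && head[2] == 0x03 && head[3] == 0x04) return "zip";
    return "unknown";
}
```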
Complete transformer implementations

Canonical transformer (`canonical_transformer` namespace):
- `transformer_block` combining attention and feed-forward networks
- `transformer_stack` for building deep architectures

Fused transformer (`fused_transformer` namespace):
Loss functions

Cross-entropy per logit (`loss_cross_entropy_per_logit`):
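As a reference for what per-token cross-entropy with an ignore index means here, the sketch below averages the negative log-likelihood over a sequence while skipping positions whose label equals the ignore index. It is a plain-C++ illustration of the concept, not the layer's code, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// logits: [seq_len][vocab] raw scores; labels: [seq_len] target token ids.
// Positions whose label == ignore_index contribute nothing to the loss.
double cross_entropy_per_token(const std::vector<std::vector<double>>& logits,
                               const std::vector<long>& labels,
                               long ignore_index)
{
    double total = 0;
    long counted = 0;
    for (std::size_t t = 0; t < labels.size(); ++t)
    {
        if (labels[t] == ignore_index)
            continue;
        // log-sum-exp for a numerically stable softmax denominator
        double max_logit = logits[t][0];
        for (double v : logits[t]) max_logit = std::max(max_logit, v);
        double denom = 0;
        for (double v : logits[t]) denom += std::exp(v - max_logit);
        total += -(logits[t][labels[t]] - max_logit - std::log(denom));
        ++counted;
    }
    return counted ? total / counted : 0;
}
```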
Example programs

Five progressive examples demonstrate the capabilities:
slm_basic_train_ex.cpp: character-level transformer training on Shakespeare text. Demonstrates fundamental attention mechanics and memorization capability.
slm_advanced_train_ex.cpp: BPE tokenization with compact architecture. Introduces specialized loss function and byte-for-byte verification.
slm_mixture_of_experts_ex.cpp: sparse conditional computation with production-grade utilities. Demonstrates shuffle and augmentation utilities for robust training.
slm_chatbot_ex.cpp: conversational AI training pipeline with a two-stage approach. Demonstrates base language model pre-training followed by supervised fine-tuning on question-answer pairs. Includes stochastic text generation with configurable sampling strategies (temperature, repetition penalty, nucleus sampling; a generic sampling sketch follows this list) and layer-wise learning rate multipliers for efficient fine-tuning. Shows a practical implementation of an interactive chatbot with context management.
slm_vision_transformer_hybrid_ex.cpp: hybrid ViT combining patch embeddings with transformer encoder. Showcases two training processes: self-supervised feature learning and supervised learning.
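Since slm_chatbot_ex.cpp mentions temperature, repetition penalty, and nucleus (top-p) sampling, here is a compact, generic sketch of that sampling step. It operates on a plain `std::vector` of logits and is not the example's actual code; the function name and parameters are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

// Generic top-p sampling with temperature and repetition penalty.
// `logits` has one entry per vocabulary id (assumed non-empty);
// `recent` holds recently emitted token ids.
int sample_next_token(std::vector<double> logits, const std::vector<int>& recent,
                      double temperature, double repetition_penalty, double top_p,
                      std::mt19937& rng)
{
    // Penalize tokens that were generated recently.
    for (int id : recent)
        logits[id] = logits[id] > 0 ? logits[id] / repetition_penalty
                                    : logits[id] * repetition_penalty;

    // Temperature-scaled softmax (shift by the max logit for stability).
    const double max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs(logits.size());
    double sum = 0;
    for (std::size_t i = 0; i < logits.size(); ++i)
        sum += probs[i] = std::exp((logits[i] - max_logit) / temperature);
    for (double& p : probs) p /= sum;

    // Keep the smallest set of tokens whose cumulative probability reaches top_p.
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    std::vector<int> kept;
    std::vector<double> weights;
    double cum = 0;
    for (int id : order)
    {
        kept.push_back(id);
        weights.push_back(probs[id]);
        cum += probs[id];
        if (cum >= top_p) break;
    }

    // Sample from the truncated, renormalized distribution.
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return kept[pick(rng)];
}
```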
Technical design
Matrix plane processing
Traditional Dlib layers operate channel-wise on 4D tensors. The extensions introduce plane-wise processing, where the `(rows, cols)` dimensions form semantic units for sequence data. This enables a batch of sequences to be represented as a `(batch, 1, sequence_length, embedding_dim)` tensor.
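For reference, the snippet below shows how such a batch maps onto Dlib's existing 4D tensor layout (num_samples, k, nr, nc) using `resizable_tensor`; `resizable_tensor` and `set_size()` are existing Dlib APIs, while the shape convention is taken from the description above and the concrete sizes are placeholders.

```cpp
#include <dlib/dnn.h>

int main()
{
    using namespace dlib;

    const long batch = 2, sequence_length = 128, embedding_dim = 64;

    // One sample per sequence, a single "channel", and the (rows, cols)
    // plane holding the (sequence_length x embedding_dim) token embeddings.
    resizable_tensor t;
    t.set_size(batch, 1, sequence_length, embedding_dim);

    // Plane-wise layers then treat each sample's plane as a matrix.
    return (t.num_samples() == batch && t.nr() == sequence_length) ? 0 : 1;
}
```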
Variable-length sequence handling

Training with batched sequences of different lengths requires coordinated masking (a sketch of the per-step call order follows this list):
- `tril_padding_context::set()` computes per-sample padding lengths before the forward pass
- the `tril` layer consults the context to extend the causal mask over padding tokens
- `loss_cross_entropy_per_logit::set_ignore_index()` excludes padding from the loss computation
- `tril_padding_context::clear()` resets the context after the training step
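A per-step workflow might look roughly like the fragment below. `train_one_step()` is the existing `dnn_trainer` API; the `tril_padding_context` and `set_ignore_index()` calls use the names from this description with guessed signatures, and `compute_padding_lengths()` is a hypothetical helper, so treat this as a sketch of the intended call order rather than the exact interface.

```cpp
// Sketch of one training step with padded batches (guessed signatures).
//
// samples : a mini-batch of token sequences, already padded to equal length
// labels  : the corresponding shifted targets
// pad_id  : the token id used for padding

// 1) Tell the masking machinery how much padding each sample carries
//    (hypothetical signature; helper name invented for this sketch).
tril_padding_context::set(compute_padding_lengths(samples, pad_id));

// 2) Exclude padded positions from the loss (ignore-index support added by
//    this PR; exact accessor guessed).
net.loss_details().set_ignore_index(pad_id);

// 3) Run the usual dlib training step; the tril layer consults the context
//    to extend the causal mask over padded tokens during the forward pass.
trainer.train_one_step(samples, labels);

// 4) Reset the per-batch padding state before the next step.
tril_padding_context::clear();
```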
Implementation approach

All components follow Dlib's design patterns:
Testing and validation
The example programs demonstrate:
Main files modified/added
New headers:
- `dlib/dnn/transformer.h` - complete transformer implementations
- `dlib/dnn/layers_transformer.h` - specialized layers for sequence processing
- `dlib/dnn/language_model_data.h` - utilities for dataset preparation and evaluation
- `dlib/tokenizer/bpe_tokenizer.h` - byte-pair encoding tokenization
- `dlib/dnn/solvers.h` - AdamW optimizer addition

New examples:
- `examples/slm_basic_train_ex.cpp`
- `examples/slm_advanced_train_ex.cpp`
- `examples/slm_mixture_of_experts_ex.cpp`
- `examples/slm_chatbot_ex.cpp`
- `examples/slm_data.h` - internal datasets for the examples
- `examples/slm_vision_transformer_hybrid_ex.cpp`

Abstract documentation:
- `docs/layers_abstract.h` - layer specifications and usage patterns
- `docs/transformer_abstract.h` - transformer architecture documentation
- `docs/language_model_data_abstract.h` - language modeling utility documentation
- `docs/solvers_abstract.h` - AdamW optimizer specification

Extended documentation
For more details, see the dedicated repository: https://github.com/Cydral/Dlib-Transformer-extensions
This contribution establishes official transformer support in Dlib, extending the library into modern natural language processing while maintaining its core values of simplicity, performance, and production readiness. The groundwork laid here enables upcoming vision transformer implementations and multimodal architectures, positioning Dlib as a comprehensive framework for contemporary deep learning applications.