TinyAlign is a retrieval-augmented lightweight vision-language modeling project built on top of TinyLLaVA Factory.
This repository contains the official code for TinyAlign: Boosting Lightweight Vision-Language Models by Mitigating Modal Alignment Bottlenecks, accepted to Findings of ACL 2026.
Compared with the original TinyLLaVA codebase, this project mainly adds:
- memory bank construction for retrieval-augmented alignment
- memory retrieval during training and inference
- an extra retrieval connector (`connector2`) to inject retrieved latent features into the multimodal sequence
- `<rag>`-aware prompt/token handling in the data template layer
The goal is to improve lightweight VLM alignment by retrieving relevant latent context from an external memory bank instead of relying only on the frozen vision encoder and language model.
- `tinyllava/model/modeling_tinyllava.py`: core retrieval-augmented multimodal forward path
- `tinyllava/model/configuration_tinyllava.py`: model config, including retrieval parameters
- `tinyllava/data/template/base.py`: `<image>` and `<rag>` token processing
- `tinyllava/train/train.py`: training entry
- `tinyllava/training_recipe/base.py`: checkpoint save/load, including `connector2`
- `docs/CODEBASE_OVERVIEW.md`: concise code walkthrough for this refactored release
```bash
conda create -n tinyalign python=3.10 -y
conda activate tinyalign
pip install --upgrade pip
pip install -e .
```

If you use FlashAttention in your environment:

```bash
pip install flash-attn --no-build-isolation
```

The retrieval path expects a memory directory containing:

- a FAISS index file, default: `Merged_faiss.index`
- a serialized value store, default: `Merged_LLaVA_Dataset_Memory.pt`
You can now configure these paths through arguments instead of editing source code:
- `--retrieval_memory_dir`
- `--retrieval_index_file`
- `--retrieval_value_file`
- `--top_rag`
- `--retrieval_text_start`
- `--retrieval_alpha`
This is the core addition of TinyAlign.
The memory bank is not a plain text knowledge base. Each memory item is a paired (key, value) built from image-caption supervision data:
- key: a compressed multimodal query vector
- value: a compact latent representation produced by a Perceiver-style encoder
In the current codebase, the reference implementation is scattered across `demo.py`, `demo1.py`, and `build_memory.ipynb`.
For each image-text pair in the pretraining set:
- Encode the image with the TinyLLaVA vision tower and the original multimodal connector.
- Tokenize the paired caption text and obtain text token embeddings from the language model embedding table.
- Normalize image features and text features, then fuse them with a weighted mixture: `multimodal_features = concat(alpha * image_features, (1 - alpha) * text_features)`
- Compute self-similarity over the fused multimodal sequence.
- Average the attention map and use it to compress the multimodal sequence into a single query vector.
- Feed the original image and caption into a Perceiver multimodal encoder.
- Use the Perceiver latent output as the retrieval value.
- Store the pair:
- key: compressed multimodal vector
- value: Perceiver latent tensor
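The key-construction steps above can be sketched in NumPy. This is a minimal illustration, not the exact implementation: the feature shapes, `alpha=0.5`, and the softmax weighting over the averaged attention map are assumptions made for the sketch.

```python
import numpy as np

def build_key(image_feats, text_feats, alpha=0.5):
    """Compress a fused multimodal sequence into a single query vector."""
    # L2-normalize each token embedding.
    img = image_feats / np.linalg.norm(image_feats, axis=-1, keepdims=True)
    txt = text_feats / np.linalg.norm(text_feats, axis=-1, keepdims=True)
    # Weighted mixture, concatenated along the sequence dimension.
    fused = np.concatenate([alpha * img, (1 - alpha) * txt], axis=0)  # (L, d)
    # Self-similarity over the fused multimodal sequence.
    attn = fused @ fused.T                                            # (L, L)
    # Average the attention map into per-token weights (softmax assumed here)...
    weights = attn.mean(axis=0)
    weights = np.exp(weights - weights.max())
    weights /= weights.sum()
    # ...and compress the sequence into one query vector.
    key = weights @ fused                                             # (d,)
    return key / np.linalg.norm(key)

key = build_key(np.random.randn(16, 96), np.random.randn(8, 96))
```

The returned key is unit-norm, which makes inner-product search in FAISS equivalent to cosine similarity.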
After collecting all pairs:
- all keys are added into a FAISS index for nearest-neighbor search
- all values are stored in a tensor file and aligned with FAISS ids
At runtime, the model:
- rebuilds a query vector from the current input image and text
- retrieves the top-`k` nearest memory items from FAISS
- concatenates the retrieved values
- projects them through `connector2`
- injects the projected retrieval features into the multimodal token sequence
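The retrieval step can be sketched as a NumPy stand-in for the FAISS lookup. The shapes, the helper name `retrieve_values`, and the `axis=1` concatenation are illustrative assumptions:

```python
import numpy as np

def retrieve_values(query, keys, values, top_k=4):
    """NumPy stand-in for the FAISS top-k lookup over the memory bank."""
    q = query / np.linalg.norm(query)
    scores = keys @ q                       # keys assumed pre-normalized
    ids = np.argsort(-scores)[:top_k]       # top-k memory ids
    # Each value is a (32, 96) Perceiver latent; concatenate retrieved
    # items along the latent width dimension.
    return np.concatenate([values[i] for i in ids], axis=1)

keys = np.random.randn(100, 96)
keys /= np.linalg.norm(keys, axis=1, keepdims=True)
values = np.random.randn(100, 32, 96)
retrieved = retrieve_values(keys[7], keys, values, top_k=4)   # (32, 384)
```

The concatenated block is what gets handed to `connector2` for projection into the language model hidden size.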
In the current code:

- the Perceiver value is typically a latent tensor shaped like `32 x 96`
- multiple retrieved items are concatenated along the latent width dimension
- `connector2` maps the retrieved latent features into the language model hidden size
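A projection in the spirit of `connector2` can be sketched as a two-layer GELU MLP (suggested by the `rag2x_gelu` connector type). The class name, weight initialization, and hidden size of 2048 are hypothetical, not taken from the repository:

```python
import numpy as np

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

class RagConnector:
    """Hypothetical two-layer GELU MLP mapping retrieved latents
    (tokens, 96) into the language model hidden size."""
    def __init__(self, in_dim=96, hidden_size=2048, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.standard_normal((in_dim, hidden_size)) * 0.02
        self.b1 = np.zeros(hidden_size)
        self.w2 = rng.standard_normal((hidden_size, hidden_size)) * 0.02
        self.b2 = np.zeros(hidden_size)

    def __call__(self, x):                  # x: (num_tokens, in_dim)
        return gelu(x @ self.w1 + self.b1) @ self.w2 + self.b2

out = RagConnector()(np.random.randn(32, 96))   # (32, 2048)
```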
- `demo.py`: end-to-end prototype for constructing compressed keys and latent values
- `demo1.py`: smaller-scale memory construction example
- `build_memory.ipynb`: notebook experiments for building, saving, merging, and querying the memory bank
The repository now includes a standalone script:
`scripts/build_memory_bank.py`
It does two jobs:
- build shard files containing `(key, value)` pairs
- merge the shards into `Merged_faiss.index` and `Merged_LLaVA_Dataset_Memory.pt`
By default, the script assumes a LLaVA-style JSON list where each sample contains:
```json
{
  "image": "relative/path/to/image.jpg",
  "conversations": [
    {"from": "human", "value": "..."},
    {"from": "gpt", "value": "caption or target text"}
  ]
}
```

If your caption is stored in another field, pass `--caption-field`.
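The default caption lookup can be sketched as follows; `extract_caption` is an illustrative helper, and the real script's internals may differ:

```python
def extract_caption(sample, caption_field=None):
    """Pull the target text from a LLaVA-style sample."""
    if caption_field is not None:
        # Honor an explicit override, mirroring --caption-field.
        return sample[caption_field]
    # Default: the first "gpt" turn of the conversation is the target text.
    for turn in sample.get("conversations", []):
        if turn["from"] == "gpt":
            return turn["value"]
    raise KeyError("no caption found in sample")

sample = {
    "image": "relative/path/to/image.jpg",
    "conversations": [
        {"from": "human", "value": "..."},
        {"from": "gpt", "value": "caption or target text"},
    ],
}
caption = extract_caption(sample)
```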
Build shards and merge immediately:
```bash
python scripts/build_memory_bank.py \
  --model-path /path/to/tinyllava_base_checkpoint \
  --dataset-json /path/to/pretrain.json \
  --image-root /path/to/images \
  --perceiver-tokenizer /path/to/perceiver_tokenizer \
  --output-dir /path/to/memory_bank \
  --save-every 5000 \
  --merge-after-build
```

If you already built shards and only want to merge them:

```bash
python scripts/build_memory_bank.py \
  --output-dir /path/to/memory_bank \
  --merge-only
```

After a successful run, `output-dir` contains:
```
memory_bank/
  shards/
    memory_shard_00000.pt
    memory_shard_00001.pt
    ...
  Merged_faiss.index
  Merged_LLaVA_Dataset_Memory.pt
```
The final merged tensor file stores:
- keys: normalized compressed multimodal query vectors
- values: Perceiver latent tensors aligned with FAISS ids
Once the memory bank is built, point the model to it with:
```bash
--retrieval_memory_dir /path/to/memory_bank
```

Optionally override filenames if you changed them:

```bash
--retrieval_index_file Merged_faiss.index
--retrieval_value_file Merged_LLaVA_Dataset_Memory.pt
```

Example command structure:
```bash
python -m tinyllava.train.train \
  --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --vision_tower google/siglip-so400m-patch14-384 \
  --connector_type mlp2x_gelu \
  --connector2_type rag2x_gelu \
  --data_path /path/to/data.json \
  --image_folder /path/to/images \
  --retrieval_memory_dir /path/to/memory_bank \
  --output_dir /path/to/output
```

If you initialize from a TinyLLaVA checkpoint and keep `connector2` separately, use:

```bash
--pretrained_model_path /path/to/base_checkpoint \
--pretrained_connector2_path /path/to/connector2
```

The main inference entry is:
```bash
python -m tinyllava.eval.run_tiny_llava \
  --model-path /path/to/model \
  --image-file /path/to/image.jpg \
  --query "Describe the image."
```

Make sure the released model config points to the correct retrieval memory directory before inference.
- This repository keeps the original package name `tinyllava` for code compatibility.
- The refactor removes local debug hooks and hard-coded absolute paths that were tied to the author's machine.
This codebase is built on top of TinyLLaVA Factory. Credit goes to the original authors for the base multimodal training framework.
```bibtex
@article{hu2025tinyalign,
  title={TinyAlign: Boosting Lightweight Vision-Language Models by Mitigating Modal Alignment Bottlenecks},
  author={Hu, Yuanze and Fan, Zhaoxin and Wang, Xinyu and Li, Gen and Qiu, Ye and Yang, Zhichao and Wu, Wenjun and Wu, Kejian and Sun, Yifan and Deng, Xiaotie and Dong, Jin},
  journal={arXiv preprint arXiv:2505.12884},
  year={2025}
}
```