Memory Decoder: A Pretrained, Plug-and-Play Memory for Large Language Models
NeurIPS 2025 Poster
Memory Decoder introduces a novel paradigm for domain adaptation that bridges the gap between non-parametric retrieval methods and parametric fine-tuning approaches. By pre-training a compact transformer decoder to internalize retrieval patterns, Memory Decoder provides the benefits of both worlds:
- ✨ Plug-and-Play: A single Memory Decoder enhances any model sharing the same tokenizer
- 🚀 Efficient Inference: No retrieval overhead, just parallel forward passes
- 🎯 Domain Expertise: Captures long-tail knowledge like non-parametric methods
- 🔒 Preserves Capabilities: Original model parameters remain unchanged
Unlike traditional approaches that either require expensive retraining (DAPT) or introduce significant inference latency (RAG), Memory Decoder offers efficient domain adaptation through a pretrained memory component that seamlessly integrates with existing models.
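At inference time, the base model and the Memory Decoder run forward passes in parallel and their next-token distributions are combined per token. A minimal sketch of the combination rule (assuming probability-space interpolation as in kNN-LM, with the weight and temperature matching the `--lmbda` and `--knn_temp` flags used in the evaluation scripts below):

```python
import torch

def interpolate(base_logits: torch.Tensor, memdec_logits: torch.Tensor,
                lmbda: float = 0.55, knn_temp: float = 1.0) -> torch.Tensor:
    """Blend base LM and Memory Decoder next-token distributions (sketch)."""
    p_base = torch.softmax(base_logits, dim=-1)
    p_mem = torch.softmax(memdec_logits / knn_temp, dim=-1)
    return (1.0 - lmbda) * p_base + lmbda * p_mem
```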
We run on CUDA 12.4 with the following core dependencies:
- faiss-gpu 1.12.0 (w/o cuvs)
- PyTorch 2.6.0
- transformers 4.55.4
- datasets 4.0.0
conda install -c pytorch -c nvidia faiss-gpu=1.12.0
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
pip install transformers==4.55.4 datasets==4.0.0 accelerate pyarrow evaluate loguru wandb tqdm
Important
We encountered a bug in faiss-gpu 1.11.0 w/ cuvs where the returned neighbours aren't sorted by distance, so we suggest using faiss-gpu 1.12.0 w/o cuvs instead. Also, for datasets >= 4.0.0, the newly introduced Column object affects some implementations of column selection (see pr), so we suggest pinning datasets to 4.0.0.
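If you need to stay on a different faiss build, a quick standalone sanity check (a hypothetical snippet, not part of this repo) can confirm whether your installation returns neighbours sorted by distance:

```python
import faiss
import numpy as np

d = 64
xb = np.random.rand(10_000, d).astype("float32")
index = faiss.IndexFlatL2(d)  # substitute the index type you actually use
index.add(xb)

dists, _ = index.search(np.random.rand(5, d).astype("float32"), 16)
# Each row of distances should be non-decreasing.
assert np.all(np.diff(dists, axis=1) >= 0), "neighbours are not sorted by distance"
```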
We provide the checkpoint of the gpt2-small Memory Decoder used in our experiments: 🤗gpt2-small Memory Decoder. Simply download this checkpoint and the 🤗wikitext-103 dataset from Hugging Face, then run the following scripts:
# scripts/preprocess_dataset.sh
TOKENIZER="/path/to/tokenizer(model)/directory"
OUTPUT_DIR=./dataset/wikitext-gpt2
python utils/preprocess_dataset.py \
--dataset_name /path/to/wikitext \
--dataset_config_name wikitext-103-raw-v1 \
--tokenizer_path ${TOKENIZER} \
--output_dir ${OUTPUT_DIR} \
--num_proc 32
# scripts/evaluate_base_gpt.sh
DATASET=/path/to/dataset
MODEL=/path/to/base/model
OUTPUT_DIR=tmp/
NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 CUDA_VISIBLE_DEVICES=0 python \
-m train_base \
--model_name_or_path ${MODEL} \
--dataset_name ${DATASET} \
--per_device_eval_batch_size 16 \
--do_eval \
--eval_subset test \
--output_dir ${OUTPUT_DIR} \
--report_to none
# scripts/evaluate_joint_gpt2.sh
DATASET=/path/to/dataset
MODEL=/path/to/base/model
KNN_PATH=/path/to/memory/decoder
OUTPUT_DIR=tmp/
python -m evaluate_joint \
--do_test \
--model_name_or_path ${MODEL} \
--dataset_name ${DATASET} \
--dataset_split_name test \
--per_device_eval_batch_size 16 \
--output_dir ${OUTPUT_DIR} \
--knn_temp 1 \
--lmbda 0.55 \
--knn_generator_path ${KNN_PATH} \
--report_to none
Perplexity on the wikitext-103 test set:

| Model | Base PPL | +MemDec PPL | Δ PPL |
|---|---|---|---|
| GPT2-small | 24.89 | 13.36 | -11.53 |
| GPT2-medium | 18.29 | 12.25 | -6.04 |
| GPT2-large | 15.80 | 11.53 | -4.27 |
| GPT2-xl | 14.39 | 10.93 | -3.46 |
from memDec import MemoryDecoder
import transformers
from transformers import AutoModelForCausalLM
from loguru import logger
# Define paths to your models
base_lm_path = "/path/to/base/model/gpt2-xl"
knn_generator_path = "/path/to/memdec-gpt2-small"
# Load tokenizer and models
tokenizer = transformers.AutoTokenizer.from_pretrained(base_lm_path)
base_lm = AutoModelForCausalLM.from_pretrained(base_lm_path)
knn_generator = AutoModelForCausalLM.from_pretrained(knn_generator_path)
# Resize embeddings and set to evaluation mode
base_lm.resize_token_embeddings(len(tokenizer))
knn_generator.resize_token_embeddings(len(tokenizer))
base_lm.eval()
knn_generator.eval()
# Create the joint Memory Decoder model
joint = MemoryDecoder(base_lm, knn_generator, lmbda=0.55, knn_temp=1.0).to("cuda")
# Prepare input prompt
prompt = "As with previous Valkyira Chronicles games , Valkyria Chronicles III is"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Generate with Memory Decoder
out_ids = joint.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"Memory Decoder output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")
# Generate with base model for comparison
out_ids = base_lm.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"Base Model output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")
📊 Generation Results Comparison:
| Model | Generated Continuation |
|---|---|
| Base Model | "...is a turn-based strategy game. The player takes control of a squad of Valkyria soldiers..." |
| +Memory Decoder | "...is a role-playing video game developed by Sega and published by Sega for the PlayStation 2." |
Note
Memory Decoder correctly identifies Valkyria Chronicles III as a role-playing game (factually accurate), while the base model incorrectly predicts it as a strategy game.
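For intuition, `demo/memDec.py` exposes both models behind the familiar `generate` API. A hypothetical, stripped-down version (greedy decoding only; the repo's actual class is more complete):

```python
import torch
from torch import nn

class TinyMemoryDecoder(nn.Module):
    """Hypothetical minimal joint model; illustrates the idea, not the repo's code."""

    def __init__(self, base_lm, knn_generator, lmbda=0.55, knn_temp=1.0):
        super().__init__()
        self.base_lm, self.knn_generator = base_lm, knn_generator
        self.lmbda, self.knn_temp = lmbda, knn_temp

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=20, **_):
        for _ in range(max_new_tokens):
            # Two parallel forward passes; no retrieval at inference time.
            p_base = torch.softmax(self.base_lm(input_ids).logits[:, -1], dim=-1)
            p_mem = torch.softmax(self.knn_generator(input_ids).logits[:, -1] / self.knn_temp, dim=-1)
            next_id = ((1 - self.lmbda) * p_base + self.lmbda * p_mem).argmax(-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)
        return input_ids
```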
Our codebase is organized as follows to facilitate both training and evaluation:
MemoryDecoder/
├── knn_utils/
│   ├── build_index.py         # Build FAISS index for efficient search
│   ├── saveEmbedMulti.py      # Save embeddings with multi-GPU support
│   └── saveKNNMulti.py        # Search and save KNN distributions
├── scripts/
│   ├── evaluate_base_gpt.sh   # Evaluate base model
│   ├── evaluate_joint_gpt2.sh # Evaluate with Memory Decoder
│   ├── preprocess_dataset.sh  # Preprocess datasets
│   ├── save_pipeline.sh       # Complete KNN signal pipeline
│   └── train_memdec.sh        # Train Memory Decoder
├── utils/
│   ├── cal_loss.py            # Loss calculation utilities
│   └── preprocess_dataset.py  # Dataset preprocessing
├── demo/                      # Demo scripts
│   ├── memDec.py              # Class for Memory Decoder generation
│   └── generation_example.py  # Generation demonstration
├── train_base.py              # Base model training/evaluation
├── train_memdec.py            # Memory Decoder training
└── evaluate_joint.py          # Joint evaluation interface
Tokenize and group text for efficient processing:
bash scripts/preprocess_dataset.sh
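Under the hood this follows the standard Hugging Face tokenize-then-group recipe. A simplified sketch of what `utils/preprocess_dataset.py` likely does (`block_size` and column names are assumptions):

```python
from itertools import chain

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
raw = load_dataset("wikitext", "wikitext-103-raw-v1")
block_size = 1024  # assumed; match the base model's context length

def tokenize(examples):
    return tokenizer(examples["text"])

def group(examples):
    # Concatenate all token ids, then cut into fixed-size blocks.
    concat = list(chain.from_iterable(examples["input_ids"]))
    total = (len(concat) // block_size) * block_size
    ids = [concat[i : i + block_size] for i in range(0, total, block_size)]
    return {"input_ids": ids, "labels": [list(x) for x in ids]}

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"], num_proc=32)
grouped = tokenized.map(group, batched=True,
                        remove_columns=tokenized["train"].column_names, num_proc=32)
grouped.save_to_disk("./dataset/wikitext-gpt2")
```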
Three-step process for creating supervision signals:
- Save Embeddings
Extract and save hidden representations from the pretrained model:
accelerate launch \
--config_file ${ACCELERATE_CONFIG} \
-m train_base \
--model_name_or_path ${MODEL_TO_SAVE} \
--dataset_name ${DATASET} \
--do_eval --eval_subset ${SUBSET} \
--per_device_eval_batch_size ${BATCH_SIZE_EVAL} \
--output_dir ${OUTPUT_DIR} \
--dstore_dir ${DSTORE_DIR} \
--save_knnlm_dstore \
--report_to none
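Conceptually, the datastore pairs every token position's hidden state (key) with the token that follows it (value). A hypothetical sketch of the on-disk layout (file names, shapes, and dtypes are assumptions; the actual writer lives in `train_base.py`):

```python
import numpy as np

dstore_size, dim = 100_000_000, 768  # e.g. ~1e8 tokens, gpt2 hidden size (assumed)

# Keys: final-layer hidden state at position t; values: token id at position t+1.
keys = np.memmap("dstore_keys.npy", dtype=np.float16, mode="w+", shape=(dstore_size, dim))
vals = np.memmap("dstore_vals.npy", dtype=np.int32, mode="w+", shape=(dstore_size, 1))
```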
- Build IVFPQ Index
Create an efficient index for fast nearest neighbor search:
python -m knn_utils.build_index \
--dstore_path ${DSTORE_PATH} \
--num_keys_to_add_at_a_time ${NUM_KEYS_TO_ADD} \
--ncentroids ${NCENTROIDS} \
--code_size ${CODE_SIZE} \
--probe ${PROBE}
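For reference, a minimal faiss IVFPQ build using the same hyperparameter names (a sketch under assumed values, not the repo's exact code):

```python
import faiss
import numpy as np

dim, ncentroids, code_size, probe = 768, 4096, 64, 32  # assumed values

keys = np.memmap("dstore_keys.npy", dtype=np.float16, mode="r").reshape(-1, dim)

quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFPQ(quantizer, dim, ncentroids, code_size, 8)  # 8 bits per sub-code

# Train the coarse quantizer and PQ codebooks on a random sample of keys.
sample_ids = np.random.choice(len(keys), size=min(len(keys), 1_000_000), replace=False)
index.train(keys[sample_ids].astype("float32"))

# Add keys in chunks to bound peak memory (cf. --num_keys_to_add_at_a_time).
chunk = 500_000
for start in range(0, len(keys), chunk):
    index.add(keys[start : start + chunk].astype("float32"))

index.nprobe = probe
faiss.write_index(index, "knn.index")
```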
- Search KNN Distributions
Generate KNN probability distributions as training signals:
accelerate launch \
--config_file ${ACCELERATE_CONFIG} \
-m knn_utils.saveKNNMulti \
--model_path ${MODEL_TO_SAVE} \
--dstore_path ${DSTORE_PATH} \
--val_path ${VAL_PATH} \
--index_path ${INDEX_PATH} \
--output_path ${OUTPUT_PATH} \
--k ${K} \
--knn_temp ${KNN_TEMP} \
--probe ${PROBE} \
--batch_size ${BATCH_SIZE_KNN} \
--ignore_first True \
--knn_gpu
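Each position's supervision signal is a sparse distribution over the vocabulary: retrieved neighbours are softmax-weighted by negative distance and scattered onto their value tokens. A hedged sketch of that conversion:

```python
import torch

def knn_distribution(dists: torch.Tensor, vals: torch.Tensor,
                     vocab_size: int, knn_temp: float = 1.0) -> torch.Tensor:
    """dists: (B, k) L2 distances; vals: (B, k) long tensor of neighbour next-token ids."""
    weights = torch.softmax(-dists / knn_temp, dim=-1)  # closer neighbours weigh more
    probs = torch.zeros(dists.size(0), vocab_size)
    probs.scatter_add_(1, vals, weights)  # duplicate tokens accumulate probability mass
    return probs
```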
The complete pipeline is available in:
bash scripts/save_pipeline.sh
Important
Both embedding saving and KNN distribution search support multi-GPU, multi-node inference and searching. Ensure your accelerate configuration is properly set up for distributed computing to maximize efficiency.
Launch Memory Decoder training:
bash scripts/train_memdec.sh
Note
The training interface is implemented in train_memdec.py
and supports resuming from checkpoints automatically.
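The core objective distills the saved kNN distributions into the compact decoder. A minimal sketch (assuming a KL-divergence loss against the precomputed targets; masking and any auxiliary terms are omitted, see `utils/cal_loss.py` for the actual loss):

```python
import torch.nn.functional as F

def memdec_loss(decoder_logits, knn_probs):
    """KL divergence pushing the decoder's next-token distribution
    toward the precomputed kNN distribution (sketch only)."""
    log_p = F.log_softmax(decoder_logits, dim=-1)
    return F.kl_div(log_p, knn_probs, reduction="batchmean")
```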
This implementation is inspired by the excellent work in knn-transformers. We are grateful for their pioneering contributions to retrieval-augmented language modeling.
For questions and discussions, feel free to email: maximus.cao@outlook.com
If you find Memory Decoder helpful in your research, please consider citing:
@article{cao2025memory,
title={Memory decoder: A pretrained, plug-and-play memory for large language models},
author={Cao, Jiaqi and Wang, Jiarui and Wei, Rubin and Guo, Qipeng and Chen, Kai and Zhou, Bowen and Lin, Zhouhan},
journal={arXiv preprint arXiv:2508.09874},
year={2025}
}