Skip to content


Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?

Latest commit


Git stats


Failed to load latest commit information.
Latest commit message
Commit time

Marge - Pre-training via Paraphrasing

Implementation of Marge, Pre-training via Paraphrasing, in Pytorch. It is an alternative to masked language modeling pretraining, where an encoder / decoder attention network learns to reconstruct a target document from a collection of evidence documents.

Update: Three researchers have independently reported that the repository works for them


$ pip install marge-pytorch


import torch
import numpy as np
from torch.utils.data import DataLoader  # fixed: scrape dropped the module path ("from import DataLoader")

from marge_pytorch import Marge, TrainingWrapper

# your documents must be tokenized and stored as memmap in the shape (num documents, seq length)

# constants
NUM_DOCS = 10000
SEQ_LEN = 1024
SHAPE = (NUM_DOCS, SEQ_LEN)  # was referenced below but never defined in the snippet

# generate mock training data: random token ids in [0, 20000), matching num_tokens below
f = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)
f[:] = np.random.randint(0, 20000, size=SHAPE)
del f  # flush the memmap to disk by dropping the reference

# generate mock masking data (True = token is visible)
# note: np.bool was removed in NumPy 1.24+; use the builtin bool dtype
f = np.memmap('./train.mask.dat', dtype=bool, mode='w+', shape=SHAPE)
f[:] = np.full(SHAPE, True)
del f

# instantiate model

model = Marge(
    dim = 512,
    num_tokens = 20000,                     # must match the vocabulary range used for the mock data above
    max_seq_len = SEQ_LEN,
    enc_depth = 12,
    enc_retrieval_depth = 4,                # defaults to 4 as in paper (take the CLS token after the 4th layer of the encoder)
    enc_heads = 8,
    enc_ff_mult = 4,
    dec_depth = 12,
    dec_heads = 8,
    dec_ff_mult = 16,                       # paper noted that decoder needs to have much bigger feed forward sizes
    distill_attn = False,                   # (experimental) will add, on top of the decoder loss, an auxiliary distillation loss as defined in Izacard & Grave, "Distilling Knowledge from Reader to Retriever"
    distill_loss_coef = 1.                  # weight of distillation auxiliary loss
)                                           # closing paren was lost in the scrape

# wrap your model and your documents

trainer = TrainingWrapper(
    model,                                    # the Marge instance to train (scrape dropped this positional arg — TODO confirm against marge_pytorch API)
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,                         # number of evidence documents to fetch per target document to construct
    reindex_batch_size = 32,                  # batch size to use when reindexing
    documents_memmap_path = './train.dat',    # path to the mem-mapped documents
    masks_memmap_path = './train.mask.dat',   # if None is supplied, will assume all tokens are visible
    use_faiss_ann = True                      # set this to false if you have a low number of documents, and approximate nearest neighbor is not needed
)

# instantiate dataloader

dl = DataLoader(trainer.dataset, batch_size=16)

# now you can train, and use the reindex method on the training wrapper at appropriate intervals

for ind, data in enumerate(dl):
    loss = trainer(data)
    # optimizer step and all that (loss.backward(), optimizer.step(), optimizer.zero_grad())

    # reindex and precompute knn every 10000 steps, as in paper
    if ind > 0 and ind % 10000 == 0:
        trainer.reindex()  # body of the conditional was lost in the scrape

# save your model after much training
torch.save(model.state_dict(), f'./model.pt')


If you would like the target and evidence documents to be from different sets, you just have to pass in up to four additional keyword arguments, as shown below.

trainer = TrainingWrapper(
    model,                                           # the Marge instance to train (scrape dropped this positional arg — TODO confirm against marge_pytorch API)
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,
    reindex_batch_size = 32,
    documents_memmap_path = './evidence.dat',
    masks_memmap_path = './evidence.mask.dat',
    num_targets = NUM_TARGETS,                       # 1. number of target documents, with sequence length the same as the document (evidence)
    target_seq_len = SEQ_LEN,                        # 2. sequence length of target documents
    target_memmap_path = './target.dat',             # 3. path to target memmap, same as documents (evidence)
    target_masks_memmap_path = './target.mask.dat',  # 4. path to target mask memmap, same as document masks (evidence)
    use_faiss_ann = True
)


You can sample from the decoder with the following instructions

# grab some evidence documents straight from the dataset,
# or provide your own in the dimensions (b x num_evidences x seq_len)
*_, evidence, mask = trainer.dataset[0:1]

# token id 1 is assumed to be the start-of-sequence token
prime = torch.ones((1, 1), dtype=torch.long).cuda()

# supply your own document similarities array (b x num_evidences)
# if not supplied, will default to 1. for all evidence
batch, num_evidences = evidence.shape[:2]
doc_similarities = torch.ones((batch, num_evidences), dtype=torch.float).cuda()

# generate sample of length 1024
samples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)


@misc{lewis2020pretraining,
    title={Pre-training via Paraphrasing},
    author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},
    year={2020},
    eprint={2006.15020},
    archivePrefix={arXiv}
}

@misc{komatsuzaki2020current,
    title={Current Limitations of Language Models: What You Need is Retrieval},
    author={Aran Komatsuzaki},
    year={2020},
    eprint={2009.06857},
    archivePrefix={arXiv}
}

@misc{izacard2020distilling,
    title={Distilling Knowledge from Reader to Retriever for Question Answering},
    author={Gautier Izacard and Edouard Grave},
    year={2020},
    eprint={2012.04584},
    archivePrefix={arXiv}
}