# PAG End-to-End on MS MARCO Sample
This Colab notebook demonstrates environment setup, data preparation, model fine-tuning, and evaluation of the PAG model on a small MS MARCO subset.

## 1. Environment Setup
Install dependencies and clone the repository.

In [None]:
!git clone https://github.com/HansiZeng/PAG.git
%cd PAG
!pip install -r requirements.txt
!pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip install faiss-gpu==1.7.2
!pip install gdown
!pip install --upgrade pyarrow nbformat

## 2. Download Pre-trained Model and Mapping Files
Fetch a pre-trained checkpoint and the document ID mappings from Google Drive. The file IDs in the cell below are placeholders; replace them with the actual IDs from the PAG-data folder.

In [None]:
import os, zipfile
os.makedirs('data', exist_ok=True)
%cd data
# Download model checkpoint
!gdown 1A7VYJZkxxxxxx -O pag_model.zip
with zipfile.ZipFile('pag_model.zip', 'r') as zf: zf.extractall()
# Download docid_to_tokenids files
!gdown 1B8xxxxxxx -O docids.zip
with zipfile.ZipFile('docids.zip', 'r') as zf: zf.extractall()
%cd ..

## 3. Prepare a Small MS MARCO Sample
Load a subset of MS MARCO passages and queries, build training pairs, and write them in the format expected by the training scripts.

In [None]:
from datasets import load_dataset
import os, random, ujson

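passage_count = 1000
query_count = 20

# The Hugging Face 'ms_marco' v1.1 dataset provides train/validation/test splits in which
# each record holds one query together with its candidate passages; it has no separate
# passages/queries/qrels splits. The three lists below are rebuilt from those records.
records = load_dataset('ms_marco', 'v1.1', split=f'train[:{query_count}]')

passages, queries, qrels = [], [], []
for rec in records:
    qid = str(rec['query_id'])
    queries.append({'qid': qid, 'query': rec['query']})
    cands = rec['passages']
    for text, selected in zip(cands['passage_text'], cands['is_selected']):
        if len(passages) >= passage_count:
            break
        # Passage IDs are synthesized here (an assumption of this sample); they are not
        # guaranteed to match the IDs in the downloaded docid_to_tokenids files.
        pid = str(len(passages))
        passages.append({'pid': pid, 'passage': text})
        if selected:
            qrels.append({'query_id': qid, 'passage_id': pid})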

os.makedirs('data/msmarco-sample/full_collection', exist_ok=True)
with open('data/msmarco-sample/full_collection/raw.tsv','w') as f:
    for p in passages:
        # Collapse tabs and newlines so each passage stays on one TSV row
        text = ' '.join(p['passage'].split())
        f.write(f"{p['pid']}\t{text}\n")

os.makedirs('data/msmarco-sample/train_queries', exist_ok=True)
with open('data/msmarco-sample/train_queries/raw.tsv','w') as f:
    for q in queries:
        f.write(f"{q['qid']}\t{q['query']}\n")

# Build qrels dict and training pairs
qid_to_pos = {}
for qr in qrels:
    qid_to_pos.setdefault(str(qr['query_id']), []).append(str(qr['passage_id']))
all_docids = [str(p['pid']) for p in passages]
with open('data/msmarco-sample/train_pairs.jsonl','w') as f:
    for q in queries:
        qid = str(q['qid'])
        if qid not in qid_to_pos:
            continue
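        positives = set(qid_to_pos[qid])
        pos = qid_to_pos[qid][0]
        # Sample the negative from passages that are not relevant to this query
        neg = random.choice([d for d in all_docids if d not in positives])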
        f.write(ujson.dumps({'qid': qid, 'docids': [pos, neg], 'scores': [1.0, 0.0]}) + '\n')

# Save qrels for evaluation
qrel_dict = {}
for qid, docs in qid_to_pos.items():
    qrel_dict[qid] = {d: 1 for d in docs}
with open('data/msmarco-sample/qrels.json','w') as f:
    ujson.dump(qrel_dict, f)
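
Before launching training, it can help to confirm that every line of `train_pairs.jsonl` is well formed. The check below is a minimal sketch, assuming the `qid` / `docids` / `scores` layout written by the cell above; `validate_pairs` is a hypothetical helper and not part of the PAG repository.

```python
import json

def validate_pairs(path):
    """Return the number of valid training pairs in a JSONL file.

    Raises ValueError on the first malformed line.
    """
    count = 0
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            row = json.loads(line)
            docids, scores = row.get('docids'), row.get('scores')
            if 'qid' not in row or not docids or not scores:
                raise ValueError(f'line {line_no}: missing qid, docids, or scores')
            if len(docids) != len(scores):
                raise ValueError(
                    f'line {line_no}: {len(docids)} docids but {len(scores)} scores')
            count += 1
    return count
```

In the notebook this would be called as `validate_pairs('data/msmarco-sample/train_pairs.jsonl')`. A count lower than `query_count` means some queries had no relevant passage in the sample and were skipped.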

## 4. Fine-tune PAG on the Sample
Use the prepared pairs to fine-tune the pre-trained PAG model for one epoch.

In [None]:
!python -m t5_pretrainer.main \
    --epochs=1 \
    --run_name=sample_run \
    --learning_rate=5e-4 \
    --loss_type=margin_mse \
    --model_name_or_path=t5-base \
    --model_type=lexical_ripor \
    --teacher_score_path=data/msmarco-sample/train_pairs.jsonl \
    --output_dir=data/experiments-sample \
    --task_names='["rank","lexical_rank"]' \
    --collection_path=data/msmarco-sample/full_collection \
    --queries_path=data/msmarco-sample/train_queries \
    --pretrained_path=data/experiments-full-lexical-ripor/lexical_ripor_direct_lng_knp_seq2seq_1/checkpoint \
    --smt_docid_to_smtid_path=data/experiments-full-lexical-ripor/t5-full-dense-1-5e-4-12l/aq_smtid/docid_to_tokenids.json \
    --lex_docid_to_smtid_path=data/experiments-splade/t5-splade-0-12l/top_bow/docid_to_tokenids.json \
    --per_device_train_batch_size=2 \
    --max_length=64 \
    --use_fp16
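
The `--loss_type=margin_mse` flag points to a Margin-MSE distillation objective, which trains the student to reproduce the teacher's score gap between a positive and a negative document. The pure-Python sketch below assumes the standard formulation; the repository computes its loss on tensors and may differ in detail, and `margin_mse` here is an illustrative function, not the project's API.

```python
def margin_mse(student_pos, student_neg, teacher_pos, teacher_neg):
    """Mean squared error between student and teacher score margins.

    Each argument is a list of floats with one entry per training pair.
    """
    n = len(student_pos)
    if n == 0:
        raise ValueError('at least one training pair is required')
    if not (len(student_neg) == len(teacher_pos) == len(teacher_neg) == n):
        raise ValueError('all score lists must have the same length')
    total = 0.0
    for sp, sn, tp, tn in zip(student_pos, student_neg, teacher_pos, teacher_neg):
        # Compare the gap the student predicts with the gap the teacher assigned
        diff = (sp - sn) - (tp - tn)
        total += diff * diff
    return total / n
```

With the teacher scores `[1.0, 0.0]` used in the sample pairs, the teacher margin is always 1.0, so the loss reaches zero whenever the student scores the positive exactly one point above the negative.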

## 5. Evaluate the Fine-tuned Model
Run constrained beam search on the sample queries and compute retrieval metrics.
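
Constrained beam search restricts each decoding step to tokens that extend the prefix of some valid document ID. One common way to implement that restriction is a prefix tree over all docid token sequences. The sketch below assumes `docid_to_tokenids.json` maps each docid to a list of integer token IDs (inferred from the file name); `build_trie` and `allowed_next_tokens` are hypothetical helpers, not the repository's implementation.

```python
def build_trie(docid_to_tokenids):
    """Build a nested-dict prefix tree over all docid token sequences."""
    trie = {}
    for token_ids in docid_to_tokenids.values():
        node = trie
        for tok in token_ids:
            node = node.setdefault(tok, {})
    return trie

def allowed_next_tokens(trie, prefix):
    """Return the sorted tokens that can follow `prefix`.

    An empty list means the prefix is either invalid or already a complete docid.
    """
    node = trie
    for tok in prefix:
        if tok not in node:
            return []
        node = node[tok]
    return sorted(node)

# Toy mapping for illustration only; real mappings come from docid_to_tokenids.json
toy_mapping = {'d1': [5, 7, 2], 'd2': [5, 7, 9], 'd3': [5, 1, 4]}
toy_trie = build_trie(toy_mapping)
print(allowed_next_tokens(toy_trie, [5]))     # tokens that may follow 5
print(allowed_next_tokens(toy_trie, [5, 7]))  # tokens that may follow 5, 7
```

During decoding, a beam whose generated tokens form `prefix` would only be expanded with the tokens returned for that prefix, so every finished beam corresponds to a real document.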

In [None]:
!python -m t5_pretrainer.evaluate \
    --pretrained_path=data/experiments-sample/sample_run/checkpoint \
    --out_dir=data/experiments-sample/output \
    --task=constrained_beam_search_for_qid_rankdata \
    --docid_to_tokenids_path=data/experiments-full-lexical-ripor/t5-full-dense-1-5e-4-12l/aq_smtid/docid_to_tokenids.json \
    --q_collection_paths='["data/msmarco-sample/train_queries/"]' \
    --batch_size=1 \
    --max_new_token_for_docid=8 \
    --topk=10
!python -m t5_pretrainer.evaluate \
    --task=constrained_beam_search_for_qid_rankdata_2 \
    --out_dir=data/experiments-sample/output \
    --q_collection_paths='["data/msmarco-sample/train_queries/"]' \
    --eval_qrel_path='["data/msmarco-sample/qrels.json"]'
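
To cross-check the reported metrics by hand, MRR@k and Recall@k can be computed directly from a run and the saved qrels. This is a minimal sketch: the qrels layout matches `qrels.json` as written in Section 3, while the run layout (`qid -> {docid: score}`) is an assumption about the evaluation output, and the function names are illustrative.

```python
def _top_k(run, qid, k):
    """Docids retrieved for `qid`, ordered by descending score, cut at k."""
    scores = run.get(qid, {})
    return sorted(scores, key=scores.get, reverse=True)[:k]

def mrr_at_k(run, qrels, k=10):
    """Mean reciprocal rank of the first relevant docid within the top k."""
    if not qrels:
        return 0.0
    total = 0.0
    for qid, relevant in qrels.items():
        for rank, docid in enumerate(_top_k(run, qid, k), start=1):
            if relevant.get(docid, 0) > 0:
                total += 1.0 / rank
                break
    return total / len(qrels)

def recall_at_k(run, qrels, k=10):
    """Average fraction of relevant docids found within the top k."""
    per_query = []
    for qid, relevant in qrels.items():
        positives = {d for d, rel in relevant.items() if rel > 0}
        if not positives:
            continue  # queries without relevant docids are skipped
        hits = positives.intersection(_top_k(run, qid, k))
        per_query.append(len(hits) / len(positives))
    return sum(per_query) / len(per_query) if per_query else 0.0
```

Queries that are missing from the run count as misses, so both metrics drop if retrieval produced no output for some queries.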

## 6. Inspect Retrieval Outputs
Display a snippet of the generated run file.

In [None]:
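import glob, json
from itertools import islice

run_files = sorted(glob.glob('data/experiments-sample/output/**/*.json', recursive=True))
# Show a snippet of the first JSON file found
for path in run_files[:1]:
    print(path)
    with open(path) as f:
        data = json.load(f)
    # The file may hold a dict or a list, so take the first five entries accordingly
    snippet = dict(islice(data.items(), 5)) if isinstance(data, dict) else data[:5]
    print(json.dumps(snippet, indent=2))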