In [1]:
!pip install transformers datasets

Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting datasets
  Using cached datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Using cached huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Using cached safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-20.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.w

## 데이터 load

In [49]:
import json
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import (
    DPRQuestionEncoder, DPRContextEncoder,
    DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
)

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. 데이터 로드
def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

spec0_train = load_json("spec0_train.json")
spec1_train = load_json("spec1_train.json")
train_data = spec0_train + spec1_train

## Tokenizer & Model 로드

In [50]:
q_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
ctx_tok = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_enc = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
ctx_enc = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequence

## 데이터 Preprocessing

In [51]:
def format_data(example):
    return {
        "question": example["question"],
        "positive": example["positive_ctxs"][0]["text"],
        "negatives": [ctx["text"] for ctx in example["negative_ctxs"]]
    }

train_dataset = Dataset.from_list([format_data(x) for x in train_data])

def encode(example):
    q = q_tok(example["question"], padding="max_length", truncation=True, max_length=256, return_tensors="pt")
    p = ctx_tok(example["positive"], padding="max_length", truncation=True, max_length=256, return_tensors="pt")
    n = ctx_tok(example["negatives"], padding="max_length", truncation=True, max_length=256, return_tensors="pt")

    return {
        "q_input_ids": q["input_ids"].squeeze(),
        "q_attention_mask": q["attention_mask"].squeeze(),
        "pos_input_ids": p["input_ids"].squeeze(),
        "pos_attention_mask": p["attention_mask"].squeeze(),
        "neg_input_ids": n["input_ids"],
        "neg_attention_mask": n["attention_mask"]
    }

train_dataset = train_dataset.map(encode)
train_dataset.set_format("torch")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

Map:   0%|          | 0/416 [00:00<?, ? examples/s]

## Training

In [None]:
optimizer = torch.optim.AdamW(list(q_enc.parameters()) + list(ctx_enc.parameters()), lr=2e-5)

q_enc.train(); ctx_enc.train()

for epoch in range(5):
    total_loss = 0
    for batch in train_loader:
        q = q_enc(input_ids=batch["q_input_ids"].to(device), attention_mask=batch["q_attention_mask"].to(device)).pooler_output
        p = ctx_enc(input_ids=batch["pos_input_ids"].to(device), attention_mask=batch["pos_attention_mask"].to(device)).pooler_output
        B, N, L = batch["neg_input_ids"].shape
        neg_input_ids = batch["neg_input_ids"].view(B * N, L).to(device)
        neg_att = batch["neg_attention_mask"].view(B * N, L).to(device)
        n = ctx_enc(input_ids=neg_input_ids, attention_mask=neg_att).pooler_output.view(B, N, -1)

        all_ctx = torch.cat([p.unsqueeze(1), n], dim=1)
        sim = torch.bmm(q.unsqueeze(1), all_ctx.transpose(1, 2)).squeeze(1)
        labels = torch.zeros(q.size(0), dtype=torch.long).to(device)
        loss = torch.nn.CrossEntropyLoss()(sim, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"[Epoch {epoch+1}] Loss: {total_loss / len(train_loader):.4f}")


q_enc.save_pretrained("dpr_finetuned/question")
ctx_enc.save_pretrained("dpr_finetuned/context")
q_tok.save_pretrained("dpr_finetuned/question")
ctx_tok.save_pretrained("dpr_finetuned/context")

[Epoch 1] Loss: 0.0722
[Epoch 2] Loss: 0.0851
[Epoch 3] Loss: 0.0030
[Epoch 4] Loss: 0.0108
[Epoch 5] Loss: 0.0257


('dpr_finetuned/context/tokenizer_config.json',
 'dpr_finetuned/context/special_tokens_map.json',
 'dpr_finetuned/context/vocab.txt',
 'dpr_finetuned/context/added_tokens.json')

## 전체 Corpus Encoding

In [None]:
import json
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import ndcg_score

def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

spec0 = load_json("spec0_test.json")
spec1 = load_json("spec1_test.json")
full_data = spec0 + spec1

all_ctxs = {}
for item in full_data:
    for ctx in item["positive_ctxs"] + item["negative_ctxs"]:
        all_ctxs[ctx["id"]] = ctx["text"]

corpus_ids = list(all_ctxs.keys())
corpus_texts = [all_ctxs[i] for i in corpus_ids]


ctx_enc.eval()
with torch.no_grad():
    corpus_embeddings = []
    for i in range(0, len(corpus_texts), 64):
        batch = ctx_tok(corpus_texts[i:i+64], return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)
        vecs = ctx_enc(**batch).pooler_output.cpu()
        corpus_embeddings.append(vecs)
    corpus_embeddings = torch.cat(corpus_embeddings, dim=0)

## Evaluation

In [44]:
def evaluate(test_data, k):
    q_enc.eval()
    recall, acc, ndcgs = 0, 0, []
    for item in tqdm(test_data):
        q = item["question"]
        gold_ids = set(ctx["id"] for ctx in item["positive_ctxs"])
        q_vec = q_enc(**q_tok(q, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)).pooler_output
        scores = torch.matmul(q_vec, corpus_embeddings.T.to(device)).squeeze().detach().cpu().numpy()
        topk = np.argsort(scores)[::-1][:k]
        top_ids = [corpus_ids[i] for i in topk]
        recall += int(any(doc_id in gold_ids for doc_id in top_ids))
        acc += int(top_ids[0] in gold_ids)
        relevance = [int(doc_id in gold_ids) for doc_id in corpus_ids]
        ndcgs.append(ndcg_score([relevance], [scores]))
    total = len(test_data)
    return {
        f"Recall@{k}": round(recall / total, 4),
        "Accuracy@1": round(acc / total, 4),
        f"nDCG@{k}": round(np.mean(ndcgs), 4)
    }

In [45]:
print("📊 Broad Query (specificity=0):")
print(evaluate(spec0, k=20))

print("📊 Specific Query (specificity=1):")
print(evaluate(spec1, k=5))

📊 Broad Query (specificity=0):


100%|██████████| 19/19 [00:00<00:00, 92.98it/s]


{'Recall@20': 0.6316, 'Accuracy@1': 0.2632, 'nDCG@20': 0.4855}
📊 Specific Query (specificity=1):


100%|██████████| 41/41 [00:00<00:00, 96.59it/s]

{'Recall@5': 0.3171, 'Accuracy@1': 0.1463, 'nDCG@5': 0.368}





## Failure case 확인

In [46]:
def find_failure_cases_with_text(test_data, k, corpus_embeddings, corpus_ids, corpus_texts,
                                 q_enc, ctx_tok, q_tok, device="cuda"):
    q_enc.eval()
    failures = []
    id_to_text = dict(zip(corpus_ids, corpus_texts))

    for item in tqdm(test_data):
        q = item["question"]
        gold_ids = set(ctx["id"] for ctx in item["positive_ctxs"])
        gold_texts = [ctx["text"] for ctx in item["positive_ctxs"]]

        q_vec = q_enc(**q_tok(q, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)).pooler_output
        scores = torch.matmul(q_vec, corpus_embeddings.T.to(device)).squeeze().detach().cpu().numpy()

        topk = np.argsort(scores)[::-1][:k]
        top_ids = [corpus_ids[i] for i in topk]
        top_texts = [id_to_text[i] for i in top_ids]
        top_scores = [float(scores[i]) for i in topk]

        if not any(doc_id in gold_ids for doc_id in top_ids):
            failures.append({
                "question": q,
                "gold_ids": list(gold_ids),
                "gold_texts": gold_texts,
                "top_k_ids": top_ids,
                "top_k_texts": top_texts,
                "scores": top_scores
            })

    return failures

In [47]:
    failures_spec0 = find_failure_cases_with_text(
        test_data=spec0, k=20,
        corpus_embeddings=corpus_embeddings,
        corpus_ids=corpus_ids,
        corpus_texts=corpus_texts,
        q_enc=q_enc, ctx_tok=ctx_tok, q_tok=q_tok, device=device
    )
    
    failures_spec1 = find_failure_cases_with_text(
        test_data=spec1, k=5,
        corpus_embeddings=corpus_embeddings,
        corpus_ids=corpus_ids,
        corpus_texts=corpus_texts,
        q_enc=q_enc, ctx_tok=ctx_tok, q_tok=q_tok, device=device
    )


100%|██████████| 19/19 [00:00<00:00, 110.43it/s]
100%|██████████| 41/41 [00:00<00:00, 114.16it/s]


In [None]:
print("📊 Broad Queries (spec0, Recall@20):")
print(evaluate(spec0, k=20))
failures_spec0 = find_failure_cases_with_text(
    test_data=spec0, k=20,
    corpus_embeddings=corpus_embeddings,
    corpus_ids=corpus_ids,
    corpus_texts=corpus_texts,
    q_enc=q_enc, ctx_tok=ctx_tok, q_tok=q_tok, device=device
)

print("\n📊 Specific Queries (spec1, Recall@5):")
print(evaluate(spec1, k=5))
failures_spec1 = find_failure_cases_with_text(
    test_data=spec1, k=5,
    corpus_embeddings=corpus_embeddings,
    corpus_ids=corpus_ids,
    corpus_texts=corpus_texts,
    q_enc=q_enc, ctx_tok=ctx_tok, q_tok=q_tok, device=device
)


def save_failure_cases_json(failures, path):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(failures, f, indent=2, ensure_ascii=False)
    print(f"✅ Saved to {path}")

save_failure_cases_json(failures_spec0, "failures_spec0.json")
save_failure_cases_json(failures_spec1, "failures_spec1.json")