### Augmentation - Distant Supervision

In [14]:
!pip install scikit-learn

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting scikit-learn
  Downloading scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting scipy>=1.6.0 (from scikit-learn)
  Downloading scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Downloading joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Downloading threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.9/12.9 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading joblib-1.4.2-py3-none-any.whl (301 kB)
Downloading scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (40.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m21.0 MB/s[0m eta [3

In [1]:
pip install tqdm

[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
import json
import pandas as pd
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# 1. Load the training split and convert it to a DataFrame
from datasets import load_from_disk

dataset = load_from_disk("/data/ephemeral/home/jeongeun/data/raw/train_dataset")
train_dataset = dataset["train"]
train_data = pd.DataFrame(train_dataset)

# 2. Load the DPR encoders and their tokenizers
# NOTE(review): these are the Natural Questions DPR checkpoints; if the questions
# and Wikipedia dump are not English, retrieval quality should be verified.
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'

question_encoder = DPRQuestionEncoder.from_pretrained(QUESTION_ENCODER_NAME)
context_encoder = DPRContextEncoder.from_pretrained(CONTEXT_ENCODER_NAME)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(QUESTION_ENCODER_NAME)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(CONTEXT_ENCODER_NAME)

# 3. Load the Wikipedia documents (a JSON object keyed by document id)
wikipedia_data_path = '/data/ephemeral/home/jeongeun/data/raw/wikipedia_documents.json'
# Explicit UTF-8: without it open() falls back to the locale's preferred encoding,
# which can corrupt or fail on non-ASCII text depending on the machine.
with open(wikipedia_data_path, 'r', encoding='utf-8') as f:
    wikipedia_data = json.load(f)

# Helper function to encode questions and contexts
def encode_question(question, max_length=512):
    """Embed a question with the DPR question encoder.

    Args:
        question: Text passed to ``question_tokenizer``.
        max_length: Token limit; longer inputs are truncated.

    Returns:
        The encoder's ``pooler_output`` tensor for the tokenized question.
        It is computed without gradient tracking, since it is only used for
        similarity scoring.
    """
    import torch  # local import: the notebook does not import torch at the top

    inputs = question_tokenizer(question, return_tensors='pt', truncation=True, padding=True, max_length=max_length)
    # Inference only: skip autograd bookkeeping so no graph is kept alive per call.
    with torch.no_grad():
        question_emb = question_encoder(**inputs).pooler_output
    return question_emb

def encode_context(context, max_length=512):
    """Embed a passage with the DPR context encoder.

    Args:
        context: Text passed to ``context_tokenizer``.
        max_length: Token limit; longer inputs are truncated.

    Returns:
        The encoder's ``pooler_output`` tensor for the tokenized passage.
        It is computed without gradient tracking, since it is only used for
        similarity scoring.
    """
    import torch  # local import: the notebook does not import torch at the top

    inputs = context_tokenizer(context, return_tensors='pt', truncation=True, padding=True, max_length=max_length)
    # Inference only: skip autograd bookkeeping so no graph is kept alive per call.
    with torch.no_grad():
        context_emb = context_encoder(**inputs).pooler_output
    return context_emb

# 4. Find the most similar document

# First-paragraph embeddings of the most recently used corpus. The corpus dict
# itself is stored (not its id()) so the identity check cannot be fooled by id
# reuse after garbage collection.
_FIRST_PARAGRAPH_CACHE = {'corpus': None, 'doc_ids': [], 'embeddings': None}


def _first_paragraph_embeddings(wikipedia_data):
    """Return ``(doc_ids, embeddings)`` for the first paragraph of every document.

    The result is computed once per corpus object and reused on later calls, so
    the context encoder runs once per document instead of once per
    (question, document) pair. The cache is rebuilt if a different dict is
    passed or if the number of documents changes.
    """
    import numpy as np  # local import: the notebook does not import numpy at the top

    cache = _FIRST_PARAGRAPH_CACHE
    if cache['corpus'] is wikipedia_data and len(cache['doc_ids']) == len(wikipedia_data):
        return cache['doc_ids'], cache['embeddings']

    doc_ids = []
    rows = []
    for doc_id, document in wikipedia_data.items():
        first_paragraph = document['text'].split('\n')[0]  # text before the first newline
        rows.append(encode_context(first_paragraph).detach().numpy())
        doc_ids.append(doc_id)

    # Each encode_context result is stacked as one row of the matrix
    # (assumes a 2-D, single-row embedding per passage, as the pairwise
    # cosine_similarity call in the original code required).
    embeddings = np.vstack(rows)
    cache.update(corpus=wikipedia_data, doc_ids=doc_ids, embeddings=embeddings)
    return doc_ids, embeddings


def find_similar_document(question, wikipedia_data):
    """Return the document whose first paragraph is most similar to ``question``.

    Args:
        question: Question text to embed with ``encode_question``.
        wikipedia_data: Mapping of document id to a dict with a ``'text'`` entry.

    Returns:
        ``(doc_id, document)`` for the highest cosine similarity; on ties the
        first document in iteration order wins.

    Raises:
        ValueError: If ``wikipedia_data`` is empty.
    """
    if not wikipedia_data:
        raise ValueError("wikipedia_data is empty; there is no document to match against")

    question_emb = encode_question(question).detach().numpy()
    doc_ids, context_embs = _first_paragraph_embeddings(wikipedia_data)

    # One (1, n_docs) similarity row instead of n_docs separate pairwise calls.
    similarities = cosine_similarity(question_emb, context_embs)[0]
    best_doc_id = doc_ids[int(similarities.argmax())]
    return best_doc_id, wikipedia_data[best_doc_id]

# 5. Build the augmented rows (distant supervision)
from tqdm import tqdm

augmented_data = []

for idx, row in tqdm(train_data.iterrows(), total=train_data.shape[0], desc="Augmenting Data"):
    question = row['question']
    current_doc_id = row['document_id']

    # Retrieve the document whose first paragraph is closest to the question
    new_doc_id, similar_doc = find_similar_document(question, wikipedia_data)

    # Keys produced by json.load are always strings, so compare ids as strings;
    # otherwise an int id from the dataset would never equal the retrieved key
    # and the question's own document would be added as a "new" one.
    if str(new_doc_id) == str(current_doc_id):
        continue

    # Accumulate paragraphs up to and including the first one containing the answer.
    # NOTE(review): assumes an 'answer' column holding a plain string — confirm
    # against the dataset schema.
    answer = row['answer']
    context = ""
    for paragraph in similar_doc['text'].split('\n'):
        context += paragraph + "\n"
        if answer in paragraph:
            break
    else:
        # The answer never appears in the retrieved document: a row whose
        # context lacks its answer is not a valid training example, so skip it.
        continue

    # Append a copy of the original row pointing at the new document and context
    new_row = row.copy()
    try:
        # Keep the id in the same type as the existing column so the column
        # does not become mixed-type after concatenation.
        new_row['document_id'] = type(current_doc_id)(new_doc_id)
    except (TypeError, ValueError):
        new_row['document_id'] = new_doc_id
    new_row['context'] = context
    augmented_data.append(new_row)

# 6. Save the augmented dataset
from datasets import Dataset, DatasetDict

# Original rows first, followed by the newly generated rows, with a fresh index.
augmented_data = pd.concat(
    [train_data, pd.DataFrame(augmented_data)],
    ignore_index=True,
)
train_dataset = Dataset.from_pandas(augmented_data)

# Wrap the single split in a DatasetDict and write it to disk.
dataset_dict = DatasetDict({'train': train_dataset})
dataset_dict.save_to_disk('/data/ephemeral/home/jeongeun/data/preprocessed/train_dataset_aug_DS')