In [1]:
import pandas as pd
import json
import pickle
import os

In [2]:
# dt = '20240713'
# version = 'v1'

# output_dir = os.path.join('outputs', f'{version}_{dt}')
output_dir = r""
os.makedirs(output_dir, exist_ok=True)

# 加载文档片段

In [3]:
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader(r"")
documents = loader.load()

  from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4


In [5]:
from uuid import uuid4
import os
import pickle
from langchain.text_splitter import RecursiveCharacterTextSplitter

def split_docs(documents, filepath, chunk_size=400, chunk_overlap=40, seperators=['\n\n\n', '\n\n'], force_split=False):
    if os.path.exists(filepath) and not force_split:
        print('found cache, restoring...')
        return pickle.load(open(filepath, 'rb'))

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=seperators
    )
    split_docs = splitter.split_documents(documents)
    for chunk in split_docs:
        chunk.metadata['uuid'] = str(uuid4())

    pickle.dump(split_docs, open(filepath, 'wb'))

    return split_docs

splitted_docs = split_docs(documents, os.path.join(output_dir, 'split_docs.pkl'), chunk_size=500, chunk_overlap=50)

found cache, restoring...


In [6]:
uuid2doc = {doc.metadata['uuid']: doc for doc in splitted_docs}

In [8]:
len(uuid2doc)

107

# 加载抽取的QA

In [9]:
qa_df = pd.read_excel(os.path.join(output_dir, f'question_answer.xlsx'))
qa_df = qa_df[qa_df['dataset'] == 'train']

In [10]:
len(qa_df)

516

In [12]:
qa_df.head(3)

Unnamed: 0,uuid,question,answer,context,doc,qa_type,score,score_reason,dataset
0,805948dc-9161-4357-b2b7-bb88784386f5,Who are the authors of the document?,"Kartik Kuckreja, Muhammad Sohail Danish, Muzam...","Kartik Kuckreja1, 2* Muhammad Sohail Danish1*M...",GeoChat\n : Grounded Large Vision-Language Mod...,detailed,5,"The question is clear and specific, and the an...",train
1,805948dc-9161-4357-b2b7-bb88784386f5,Which institutions are the authors affiliated ...,"Mohamed bin Zayed University of AI, Birla Inst...","1Mohamed bin Zayed University of AI,2Birla Ins...",GeoChat\n : Grounded Large Vision-Language Mod...,detailed,5,The question asks for specific factual informa...,train
2,805948dc-9161-4357-b2b7-bb88784386f5,What type of model is GeoChat?,Grounded Large Vision-Language Model,GeoChat : Grounded Large Vision-Language Model...,GeoChat\n : Grounded Large Vision-Language Mod...,detailed,5,"The question is clear and specific, and the an...",train


In [13]:
qa_df['question'].nunique()

516

In [14]:
qa_df = qa_df.drop_duplicates('question')

In [15]:
len(qa_df)

516

In [16]:
qa_df.isnull().sum()

uuid            0
question        0
answer          0
context         4
doc             0
qa_type         0
score           0
score_reason    0
dataset         0
dtype: int64

In [17]:
def build_qa_samples(df, neg_batch_size=-1, n_neg_batch=5):
    """
    构建qa样本
    :param df: 包含qa的DataFrame，共两列，question和answer
    :param neg_batch_size: 负样本数量，为-1时表示将所有负样本和单个正样本配对，否则会将负样本拆开，结果中的query可能会重复
    """
    from tqdm.auto import tqdm
    import math

    data = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        question = row['question']
        answer = row['answer']
        # 筛选同category的，增加难度
        neg_samples = df[df['question'] != question]['answer'].values.tolist()
        neg_batch_count = math.ceil((len(df) - 1) / neg_batch_size)
        neg_batch_count = min(n_neg_batch, neg_batch_count)
        for neg_batch_idx in range(neg_batch_count):
            batch_neg_samples = neg_samples[neg_batch_idx * neg_batch_size: (neg_batch_idx + 1) * neg_batch_size]
            batch_neg_samples = [item for item in batch_neg_samples if item != answer]
            data.append({
                'query': question,
                'pos': [answer],
                'neg': batch_neg_samples
            })
    return data

def write_samples(samples, save_filename):
    import json

    with open(save_filename, 'w') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False))
            f.write('\n')

In [19]:
len(qa_df)

516

In [22]:
qa_df = qa_df[qa_df['qa_type'] == 'detailed']
qa_df['answer'] = qa_df['context']

qd_samples = build_qa_samples(qa_df, neg_batch_size=16, n_neg_batch=32)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 516/516 [00:00<00:00, 3061.48it/s]


In [23]:
len(qd_samples)

16512

In [24]:
qd_samples[0]

{'query': 'Who are the authors of the document?',
 'pos': ['Kartik Kuckreja1, 2* Muhammad Sohail Danish1*Muzammal Naseer1 Abhijit Das2 Salman Khan1, 3 Fahad Shahbaz Khan1, 4'],
 'neg': ['1Mohamed bin Zayed University of AI,2Birla Institute of Technology & Science, Hyderabad 3Australian National University,4Link ¨oping University',
  'GeoChat : Grounded Large Vision-Language Model for Remote Sensing',
  'GeoChat : Grounded Large Vision-Language Model for Remote Sensing',
  'Recent advancements in Large Vision-Language Mod- els (VLMs) have shown great promise in natural image do-',
  'mains, allowing users to hold a dialogue about given vi-sual content. However, such general-domain VLMs perform poorly for Remote Sensing (RS) scenarios, leading to inac- curate or fabricated information when presented with RS domain-specific queries.',
  'Such a behavior emerges due to the unique challenges introduced by RS imagery. For exam- ple, to handle high-resolution RS imagery with diverse scale cha

In [25]:
write_samples(qd_samples, os.path.join(output_dir, 'emb_samples_qd_v2.jsonl'))