In [1]:
import pandas as pd
import os
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, BertPreTrainedModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
import accelerate
import sqlite3
from tqdm.auto import tqdm
import json
import faiss 
import shutil

# --- 1. Select the compute device (GPU when available, otherwise CPU) ---
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
if cuda_available:
    print(f"✅ GPU is available. Device: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ GPU not found. Running on CPU.")

✅ GPU is available. Device: NVIDIA RTX A6000


In [2]:
# --- 2. Custom model class definition ---
# (Identical to cell 3 of the training script (user_250). Keep the submodule names
#  `bert` / `classifier_head` and the layer layout unchanged: PyTorch state-dict keys
#  are derived from them, so any change would stop the trained weights from loading.)
class SiameseContrastiveWithHeadModel(BertPreTrainedModel):
    """Encoder (`self.bert`) plus a pair-classification head, as defined for training.

    This notebook only uses the encoder path via `_get_vector`; `classifier_head`
    is kept so that the module layout matches the saved checkpoint.
    """

    def __init__(self, config):
        """Build the encoder from `config` (architecture only) and the classifier head.

        Args:
            config: model config; `config.hidden_size` sets the layer widths.
        """
        super(SiameseContrastiveWithHeadModel, self).__init__(config)
        # Architecture only; the trained weights are filled in by from_pretrained() (cell 3).
        self.bert = AutoModel.from_config(config)
        # Head input width is 4 * hidden_size. NOTE(review): how the four vectors are
        # combined is defined in the training script's forward(), not here — confirm there.
        self.classifier_head = nn.Sequential(
            nn.Linear(config.hidden_size * 4, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 1)
        )
        self.init_weights()
    
    def _get_vector(self, input_ids, attention_mask):
        """Run the encoder and return its `pooler_output` as the text embedding.

        Args:
            input_ids: token-id tensor from the tokenizer
                (presumably (batch, seq_len) — TODO confirm).
            attention_mask: the matching attention-mask tensor from the tokenizer.

        Returns:
            `self.bert(...).pooler_output`; for BERT-family encoders presumably
            (batch, hidden_size).
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return output.pooler_output 

    def forward(self, input_ids=None, **kwargs):
        """Placeholder: inference in this notebook calls `_get_vector` directly.

        NOTE(review): this returns None, so calling `model(...)` silently yields None
        instead of failing.
        """
        pass  # not needed for inference

print("Custom model class 'SiameseContrastiveWithHeadModel' defined.")

Custom model class 'SiameseContrastiveWithHeadModel' defined.


In [3]:
# --- 3. Configuration and model loading ---
DB_PATH = "data/processed/s2orc_filtered.db"

# Directory of the trained S-BERT (contrastive) checkpoint
TRAINED_MODEL_PATH = "models/sbert_contrastive_with_head_v1/best_model"

# Final output files for this contrastive model.
# NOTE: EMBEDDINGS_OUTPUT_FILE is written via np.memmap (raw, headerless float32)
# despite its ".npy" extension.
EMBEDDINGS_OUTPUT_FILE = "data/processed/contrastive_scibert_cls_embeddings.npy"
DOI_MAP_OUTPUT_FILE = "data/processed/contrastive_doi_map.json"
FAISS_INDEX_OUTPUT_FILE = "data/processed/contrastive_scibert.faiss"

# Temporary per-batch chunk directories (make an interrupted run resumable)
TEMP_EMBED_DIR = "data/processed/embeddings_tmp_contrastive"
TEMP_DOI_DIR = "data/processed/dois_tmp_contrastive"

# Hyperparameters
MAX_LENGTH = 512            # tokenizer truncation length (tokens)
INFERENCE_BATCH_SIZE = 512  # texts per forward pass

# --- Load model and tokenizer ---
# Fail fast with a clear message. For a path that does not exist locally,
# from_pretrained() would instead try to interpret the string as a Hub repo id.
if not os.path.isdir(TRAINED_MODEL_PATH):
    raise FileNotFoundError(f"Trained model directory not found: {TRAINED_MODEL_PATH}")

print("Loading tokenizer...")
# Load the tokenizer from the trained checkpoint so it matches what training used.
tokenizer = AutoTokenizer.from_pretrained(TRAINED_MODEL_PATH)

print(f"Loading TRAINED model from: {TRAINED_MODEL_PATH}")
model, loading_info = SiameseContrastiveWithHeadModel.from_pretrained(
    TRAINED_MODEL_PATH, output_loading_info=True
)
# from_pretrained() only *warns* about weights missing from the checkpoint and leaves
# them randomly initialized. A random encoder/pooler would silently yield meaningless
# embeddings, so treat any missing non-head weight as fatal. (The head is unused here.)
missing_encoder_keys = [
    key for key in loading_info["missing_keys"] if not key.startswith("classifier_head.")
]
if missing_encoder_keys:
    raise RuntimeError(
        f"Checkpoint {TRAINED_MODEL_PATH} is missing {len(missing_encoder_keys)} encoder "
        f"weights (e.g. {missing_encoder_keys[:5]}); refusing to embed with random weights."
    )
model = model.to(device)
model.eval()
print("Model and tokenizer loaded successfully.")

Loading tokenizer...
Loading TRAINED model from: models/sbert_contrastive_with_head_v1/best_model
Model and tokenizer loaded successfully.


In [4]:
# --- 4. Read abstracts from the database (generator) ---
def get_abstract_batches(db_path, batch_size=1000):
    """Stream (doi, abstract) rows whose abstract is non-NULL and non-empty.

    Yield protocol:
      * the FIRST value yielded is the number of matching rows (an int), so the
        caller can size a progress bar before any data arrives;
      * every later value is a list of up to `batch_size` (doi, abstract) tuples
        (the last list may be shorter).

    The SQLite connection is closed when the generator is exhausted, or when it is
    closed early (`.close()` or garbage collection of an unfinished generator).

    NOTE(review): the SELECT has no ORDER BY, so rows come in whatever order SQLite's
    scan produces. The resume logic in cell 5 assumes this order is identical across
    runs — true in practice for an unmodified DB file, but not guaranteed by SQL.
    Confirm before resuming against a modified database.

    NOTE(review): the total uses COUNT(doi), which skips NULL DOIs, while the SELECT
    does not, so the total can undercount if any DOI is NULL.
    """
    print(f"Opening database connection: {db_path}")
    # `with sqlite3.connect(...)` only commits/rolls back on exit; it never closes
    # the connection. Close it explicitly instead.
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        query = "SELECT COUNT(doi) FROM papers WHERE abstract IS NOT NULL AND abstract != ''"
        total_rows = cursor.execute(query).fetchone()[0]
        print(f"Total abstracts to process: {total_rows:,}")
        yield total_rows

        cursor.execute("SELECT doi, abstract FROM papers WHERE abstract IS NOT NULL AND abstract != ''")
        batch = []
        for row in cursor:
            batch.append(row)
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch
    finally:
        conn.close()

print("Database generator defined.")

Database generator defined.


In [5]:
# --- 5. Vectorize every abstract (resumable) ---
# Each DB batch i produces batch_{i:05d}.npy (embeddings) in TEMP_EMBED_DIR and
# batch_{i:05d}.json (its DOIs, same row order) in TEMP_DOI_DIR. The .json file is the
# completion marker: it is written last and atomically, so a batch whose .json exists
# always has a complete .npy next to it. Batches that already have a marker are skipped.
os.makedirs(TEMP_EMBED_DIR, exist_ok=True)
os.makedirs(TEMP_DOI_DIR, exist_ok=True)


def _completed_batch_indices(doi_dir):
    """Return the set of batch indices whose 'batch_<index>.json' marker exists in `doi_dir`."""
    done = set()
    for name in os.listdir(doi_dir):
        if not (name.startswith('batch_') and name.endswith('.json')):
            continue
        # Parse the whole numeric part (not a fixed 5-char slice) so that indices
        # >= 100000, which f"{i:05d}" writes with 6+ digits, are recognised too.
        stem = name[len('batch_'):-len('.json')]
        if stem.isdigit():
            done.add(int(stem))
        else:
            print(f"⚠️ Ignoring unexpected file in {doi_dir}: {name}")
    return done


print(f"Scanning {TEMP_DOI_DIR} for completed batches...")
completed_indices = _completed_batch_indices(TEMP_DOI_DIR)
print(f"Found {len(completed_indices)} completed batches to skip.")

with torch.no_grad():
    db_batch_size = 1000
    batch_generator = get_abstract_batches(DB_PATH, batch_size=db_batch_size)
    total_rows = next(batch_generator)  # the first yield is the row count, not a batch
    total_batches = (total_rows + db_batch_size - 1) // db_batch_size  # ceiling division

    print("Starting/Resuming embedding generation...")
    pbar = tqdm(enumerate(batch_generator), total=total_batches, desc="Vectorizing Batches")

    for i, batch in pbar:
        if i in completed_indices:
            pbar.set_description(f"Skipping Batch {i}")
            continue

        pbar.set_description(f"Processing Batch {i}")
        dois, abstracts = zip(*batch)

        batch_embeddings_list = []
        for j in range(0, len(abstracts), INFERENCE_BATCH_SIZE):
            sub_batch_abstracts = abstracts[j : j + INFERENCE_BATCH_SIZE]

            # Pad only to the longest text in this sub-batch (still truncated at MAX_LENGTH)
            # rather than always to MAX_LENGTH. Assuming the encoder masks padded positions
            # via attention_mask (as BERT does), the embeddings are unchanged up to float
            # rounding, and the compute spent on pad tokens is saved.
            inputs = tokenizer(
                list(sub_batch_abstracts),
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(device)

            # Embed with the contrastive model's encoder (pooler_output of model.bert).
            embeddings = model._get_vector(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask']
            )
            batch_embeddings_list.append(embeddings.cpu().numpy())

        embeddings_cpu = np.vstack(batch_embeddings_list).astype(np.float32)
        if embeddings_cpu.shape[0] != len(dois):
            raise RuntimeError(
                f"Batch {i}: got {embeddings_cpu.shape[0]} embeddings for {len(dois)} DOIs"
            )

        embed_filename = os.path.join(TEMP_EMBED_DIR, f"batch_{i:05d}.npy")
        doi_filename = os.path.join(TEMP_DOI_DIR, f"batch_{i:05d}.json")

        # Embeddings first, then the DOI marker via temp file + os.replace (atomic). An
        # interruption can therefore never leave a truncated .json that resume would
        # mistake for a finished batch. The ".tmp" suffix keeps the partial file out of
        # the "*.json" listing used when merging.
        np.save(embed_filename, embeddings_cpu)
        tmp_doi_filename = doi_filename + ".tmp"
        with open(tmp_doi_filename, 'w') as fp:
            json.dump(dois, fp)
        os.replace(tmp_doi_filename, doi_filename)

print(f"\nEmbedding generation complete.")

Scanning data/processed/dois_tmp_contrastive for completed batches...
Found 0 completed batches to skip.
Opening database connection: data/processed/s2orc_filtered.db
Total abstracts to process: 11,619,136
Starting/Resuming embedding generation...


Vectorizing Batches:   0%|          | 0/11620 [00:00<?, ?it/s]


Embedding generation complete.


In [6]:
# --- 6. Merge the per-batch chunks and save ---
# NOTE: EMBEDDINGS_OUTPUT_FILE is written with np.memmap, i.e. as a RAW, headerless
# float32 buffer of shape (N, d) despite its ".npy" extension. np.load() cannot read it;
# reopen it with np.memmap(..., dtype=np.float32, shape=(N, d)) as cell 8 does.
print("Merging embedding chunks from disk...")


def _batch_index(filename):
    """Return the integer batch index encoded in a 'batch_<index>.json' file name."""
    return int(filename[len('batch_'):-len('.json')])


# Sort numerically, not lexicographically, so the order stays correct even when an
# index outgrows the 5-digit zero padding. DOIs and vectors are both read in this order,
# which keeps row i of the merged matrix aligned with all_dois[i].
doi_files = sorted(
    (f for f in os.listdir(TEMP_DOI_DIR) if f.startswith('batch_') and f.endswith('.json')),
    key=_batch_index,
)
batch_indices = [_batch_index(f) for f in doi_files]
if batch_indices != list(range(len(batch_indices))):
    raise RuntimeError(
        f"Batch chunks in {TEMP_DOI_DIR} are not a contiguous 0..N-1 sequence; "
        "re-run cell 5 to fill in the missing batches before merging."
    )

all_dois = []
chunk_sizes = []  # number of DOIs per chunk; used to validate the matching .npy
print(f"Reading {len(doi_files)} DOI chunks...")
for f in tqdm(doi_files, desc="Reading DOI chunks"):
    with open(os.path.join(TEMP_DOI_DIR, f), 'r') as fp:
        batch_dois = json.load(fp)
    all_dois.extend(batch_dois)
    chunk_sizes.append(len(batch_dois))
total_rows_processed = len(all_dois)
print(f"Total vectors to merge: {total_rows_processed:,}")
if total_rows_processed == 0:
    raise RuntimeError(f"No DOI chunks found in {TEMP_DOI_DIR}; run cell 5 first.")

# Embedding width = the encoder's hidden size (previously hardcoded as 768).
d = model.config.hidden_size
final_embeddings = np.memmap(
    EMBEDDINGS_OUTPUT_FILE,
    dtype=np.float32,
    mode='w+',
    shape=(total_rows_processed, d)
)

print(f"Merging {len(doi_files)} Embedding chunks...")
current_index = 0
for f, n_dois in tqdm(zip(doi_files, chunk_sizes), total=len(doi_files), desc="Merging Embedding Chunks"):
    batch_npy_file = os.path.join(TEMP_EMBED_DIR, f.replace('.json', '.npy'))
    batch_data = np.load(batch_npy_file)
    # Every chunk must be exactly (n_dois, d); otherwise DOIs and vectors would
    # silently drift out of alignment for all later rows.
    if batch_data.shape != (n_dois, d):
        raise ValueError(
            f"{batch_npy_file}: expected shape {(n_dois, d)}, got {batch_data.shape}"
        )
    end_index = current_index + n_dois
    final_embeddings[current_index:end_index] = batch_data
    current_index = end_index

final_embeddings.flush()
del final_embeddings  # release the memmap so the file is fully written and closed
print(f"Final embeddings saved to {EMBEDDINGS_OUTPUT_FILE}")

Merging embedding chunks from disk...
Reading 11620 DOI chunks...


Reading DOI chunks:   0%|          | 0/11620 [00:00<?, ?it/s]

Total vectors to merge: 11,619,136
Merging 11620 Embedding chunks...


Merging Embedding Chunks:   0%|          | 0/11620 [00:00<?, ?it/s]

Final embeddings saved to data/processed/contrastive_scibert_cls_embeddings.npy


In [7]:
# --- 7. Save the DOI -> row-index map and delete the temporary chunks ---
# Row i of the merged embedding matrix belongs to all_dois[i] (both were built in the
# same chunk order in cell 6).
print(f"Saving DOI-to-Index map to {DOI_MAP_OUTPUT_FILE}...")
doi_to_index_map = {doi: i for i, doi in enumerate(all_dois)}

# A dict keeps one entry per key: a repeated DOI maps only to its LAST row, and its
# earlier rows become unreachable through the map. Report this instead of hiding it.
n_duplicate_rows = len(all_dois) - len(doi_to_index_map)
if n_duplicate_rows:
    print(f"⚠️ {n_duplicate_rows:,} duplicate DOI rows; the map keeps the last row index of each DOI.")

with open(DOI_MAP_OUTPUT_FILE, 'w') as f:
    json.dump(doi_to_index_map, f)

print("Cleaning up temporary directories...")
shutil.rmtree(TEMP_EMBED_DIR)
shutil.rmtree(TEMP_DOI_DIR)
# Count embedding rows, not map entries: the two differ whenever DOIs repeat.
print(f"Total embeddings saved: {len(all_dois)}")
print(f"Unique DOIs in map: {len(doi_to_index_map)}")

Saving DOI-to-Index map to data/processed/contrastive_doi_map.json...
Cleaning up temporary directories...
Total embeddings saved: 11619136


In [8]:
# --- 8. Build and save the Faiss index ---
print("\n--- Building Faiss Index ---")
# Embedding width: the encoder's hidden size (previously hardcoded as 768). It must
# equal the `d` used when cell 6 wrote EMBEDDINGS_OUTPUT_FILE.
d = model.config.hidden_size
row_bytes = d * np.dtype(np.float32).itemsize
file_size = os.path.getsize(EMBEDDINGS_OUTPUT_FILE)
# The file is a raw, headerless float32 buffer written with np.memmap, so its size must
# be a whole number of rows. Anything else means a truncated/foreign file or a wrong d,
# which plain floor division would silently hide.
if file_size % row_bytes:
    raise ValueError(
        f"{EMBEDDINGS_OUTPUT_FILE} is {file_size:,} bytes, not a multiple of one "
        f"{d}-dim float32 row ({row_bytes} bytes); the file is truncated or d is wrong."
    )
total_vectors = file_size // row_bytes
print(f"Calculated vector count: {total_vectors:,}")

print(f"Loading embeddings from {EMBEDDINGS_OUTPUT_FILE} (mmap_mode)...")
embeddings_mmap = np.memmap(
    EMBEDDINGS_OUTPUT_FILE,
    dtype=np.float32,
    mode='r',
    shape=(total_vectors, d)
)

# Exact (brute-force) index using L2 distance.
# NOTE(review): if training optimized cosine similarity, L2-normalizing the vectors and
# using IndexFlatIP may be the intended metric — confirm against the training script.
index = faiss.IndexFlatL2(d)
print("Adding vectors to the index (this may take time)...")
index.add(embeddings_mmap)
print(f"Total vectors in index: {index.ntotal}")

print(f"Saving Faiss index to {FAISS_INDEX_OUTPUT_FILE}...")
faiss.write_index(index, FAISS_INDEX_OUTPUT_FILE)
print("\n--- Faiss Indexing Complete ---")


--- Building Faiss Index ---
Calculated vector count: 11,619,136
Loading embeddings from data/processed/contrastive_scibert_cls_embeddings.npy (mmap_mode)...
Adding vectors to the index (this may take time)...
Total vectors in index: 11619136
Saving Faiss index to data/processed/contrastive_scibert.faiss...

--- Faiss Indexing Complete ---
