In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import time
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    BertPreTrainedModel
)
from transformers.modeling_outputs import SequenceClassifierOutput
import accelerate
import sqlite3
from tqdm.auto import tqdm

# --- 1. GPUの確認 ---
if torch.cuda.is_available():
    print(f"✅ GPU is available. Device: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("⚠️ GPU not found. Running on CPU.")
    device = torch.device("cpu")

✅ GPU is available. Device: NVIDIA RTX A6000


In [2]:
# --- 2. カスタムモデルクラスの定義 ---
# (C-Encoder (Margin) の定義をロードのために再掲)
class CrossEncoderMarginModel(BertPreTrainedModel):
    def __init__(self, config):
        super(CrossEncoderMarginModel, self).__init__(config)
        self.scorer = AutoModelForSequenceClassification.from_config(config)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # 推論時はscorerだけを使用
        return self.scorer(input_ids=input_ids, attention_mask=attention_mask)

print("Custom model class 'CrossEncoderMarginModel' defined.")

Custom model class 'CrossEncoderMarginModel' defined.


In [3]:
# --- 3. 設定とリソースのロード ---
DB_PATH = "data/processed/s2orc_filtered.db"

# ▼▼▼ 訓練済みのC-Encoder (Margin) モデルのパス ▼▼▼
TRAINED_MODEL_PATH = "models/cencoder_margin_v2/best_model"
MODEL_CHECKPOINT = "allenai/longformer-base-4096"

# 評価クエリ
EVAL_PAPERS_FILE = "data/datapapers/sampled/evaluation_data_papers_50.csv"

# 推定用の設定
MAX_LENGTH = 2048
# (実験計画でのバッチサイズを使用)
ESTIMATION_BATCH_SIZE = 16 

print(f"Loading tokenizer & model from: {TRAINED_MODEL_PATH}")
tokenizer = AutoTokenizer.from_pretrained(TRAINED_MODEL_PATH)
model = CrossEncoderMarginModel.from_pretrained(TRAINED_MODEL_PATH, num_labels=1).to(device)
model.eval()

# DBの総ベクトル数
try:
    with sqlite3.connect(DB_PATH) as conn:
        TOTAL_DB_PAPERS = conn.execute("SELECT COUNT(doi) FROM papers WHERE abstract IS NOT NULL AND abstract != ''").fetchone()[0]
    print(f"Total papers in DB: {TOTAL_DB_PAPERS:,}")
except Exception as e:
    print(f"Could not count DB: {e}")
    TOTAL_DB_PAPERS = 11619136 # フォールバック

print("Resources loaded.")

Loading tokenizer & model from: models/cencoder_margin_v2/best_model
Total papers in DB: 11,619,136
Resources loaded.


In [4]:
# --- 4. サンプルデータの取得 ---
def get_sample_data(db_path, eval_papers_file, batch_size):
    """
    1件のクエリ と 1バッチ分の候補論文 をDBから取得
    """
    with sqlite3.connect(db_path) as conn:
        # 1. 1件のクエリを取得
        df_eval_papers = pd.read_csv(eval_papers_file)
        sample_data_paper_doi = df_eval_papers.iloc[0]['cited_datapaper_doi']
        
        query_gt = "SELECT citing_doi FROM positive_candidates WHERE cited_datapaper_doi = ? AND human_annotation_status = 1"
        query_doi = conn.execute(query_gt, (sample_data_paper_doi,)).fetchone()[0]
        query_abstract = conn.execute("SELECT abstract FROM papers WHERE doi = ?", (query_doi,)).fetchone()[0]
        
        # 2. 1バッチ分の候補論文を適当に取得
        candidates_abstracts = []
        rows = conn.execute(f"SELECT abstract FROM papers WHERE abstract IS NOT NULL AND abstract != '' LIMIT {batch_size}").fetchall()
        candidates_abstracts = [row[0] for row in rows]
        
    return query_abstract, candidates_abstracts

print("Fetching sample data...")
query_text, candidate_texts = get_sample_data(DB_PATH, EVAL_PAPERS_FILE, ESTIMATION_BATCH_SIZE)
print(f"Sample query: {query_text[:100]}...")
print(f"Sample candidates count: {len(candidate_texts)}")

Fetching sample data...
Sample query: The increase in globalization has led to the redefinition of the tax policy perceptions of countries...
Sample candidates count: 16


In [5]:
# --- 5. 時間計測と総時間の推定 ---
print("\n--- Estimating Inference Time ---")

# 1. サンプルバッチをトークナイズ
inputs = tokenizer(
    [query_text] * ESTIMATION_BATCH_SIZE, # クエリをバッチサイズ分複製
    candidate_texts,                     # 候補リスト
    padding="max_length", 
    truncation=True, 
    max_length=MAX_LENGTH, 
    return_tensors="pt"
).to(device)

# 2. 1バッチの推論時間を計測
print(f"Running inference for 1 batch (Size={ESTIMATION_BATCH_SIZE})...")
torch.cuda.synchronize() # GPUの処理を待つ
start_time = time.perf_counter()

with torch.no_grad():
    # C-Encoder (Margin) は .scorer を呼び出す
    outputs = model.scorer(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    # (Siameseモデルとは異なり、ここで重いLongformerが動く)

torch.cuda.synchronize() # GPUの処理が終わるのを待つ
end_time = time.perf_counter()

time_per_batch = end_time - start_time
print(f"Time taken for 1 batch: {time_per_batch:.4f} seconds")

# 3. 総時間の推定
print("\n" + "="*50)
print("--- Total Time Estimation ---")
num_queries = 50 # 評価クエリの総数
total_batches_per_query = (TOTAL_DB_PAPERS + ESTIMATION_BATCH_SIZE - 1) // ESTIMATION_BATCH_SIZE

print(f"Total papers in DB: {TOTAL_DB_PAPERS:,}")
print(f"Batch size: {ESTIMATION_BATCH_SIZE}")
print(f"Total batches PER QUERY: {total_batches_per_query:,}")
print(f"Total queries to evaluate: {num_queries}")

# 1クエリあたりの総時間
time_per_query_sec = time_per_batch * total_batches_per_query
time_per_query_hours = time_per_query_sec / 3600

print("\n--- PER QUERY (1件あたり) ---")
print(f"Estimated time per query: {time_per_query_sec:.2f} seconds")
print(f"Estimated time per query: {time_per_query_hours:.2f} hours")

# 全評価にかかる総時間
total_time_hours = time_per_query_hours * num_queries
total_time_days = total_time_hours / 24

print("\n--- TOTAL (全50件) ---")
print(f"Estimated TOTAL time for {num_queries} queries: {total_time_hours:.2f} hours")
print(f"Estimated TOTAL time for {num_queries} queries: {total_time_days:.2f} days")
print("="*50)

if total_time_days > 7:
    print("\n⚠️ 警告: 推定時間が1週間を超えています。")
    print("Cross-Encoderモデルでの全DB（1160万件）評価は現実的ではありません。")
    print("評価戦略を「リランキング評価」（上位1000件のみを評価）に見直すことを強く推奨します。")


--- Estimating Inference Time ---
Running inference for 1 batch (Size=16)...
Time taken for 1 batch: 1.1927 seconds

--- Total Time Estimation ---
Total papers in DB: 11,619,136
Batch size: 16
Total batches PER QUERY: 726,196
Total queries to evaluate: 50

--- PER QUERY (1件あたり) ---
Estimated time per query: 866130.01 seconds
Estimated time per query: 240.59 hours

--- TOTAL (全50件) ---
Estimated TOTAL time for 50 queries: 12029.58 hours
Estimated TOTAL time for 50 queries: 501.23 days

⚠️ 警告: 推定時間が1週間を超えています。
Cross-Encoderモデルでの全DB（1160万件）評価は現実的ではありません。
評価戦略を「リランキング評価」（上位1000件のみを評価）に見直すことを強く推奨します。
