In [1]:
!pip install transformers adapters datasets

Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting adapters
  Using cached adapters-1.1.1-py3-none-any.whl.metadata (17 kB)
Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Using cached huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Using cached safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting transformers
  Using cached transformers-4.48.3-py3-none-any.whl.metadata (44 kB)
Collecting pyarrow>=15.0.0 (from datase

In [2]:
import pandas as pd
import json

In [3]:
from datasets import load_dataset

corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")

In [5]:
corpus_df = corpus_clean_data.to_pandas()

In [6]:
corpus_df = corpus_df.drop(['citations', 'full_paper'], axis=1)

In [7]:
corpus_df.to_json('refined_corpus.json', orient='records')

In [8]:
df = pd.read_json("triplet.json")

In [9]:
df['specificity'] = query_data['specificity']

In [10]:
df

Unnamed: 0,question,positive_ctxs,hard_negative_ctxs,specificity
0,Are there any research papers on methods to co...,[{'title': 'TinyBERT: Distilling BERT for Natu...,[{'title': 'A Frame-based Sentence Representat...,0
1,Are there any resources available for translat...,[{'title': 'Parallel resources for Tunisian Ar...,[{'title': 'Comparing Sanskrit Texts for Criti...,1
2,Are there any studies that explore post-hoc te...,[{'title': 'Detecting Hallucinated Content in ...,[{'title': 'Mining the Web for Discourse Marke...,0
3,Are there any tools or studies that have focus...,[{'title': 'Learning from Relatives: Unified D...,[{'title': 'The KiezDeutsch Korpus (KiDKo) Rel...,1
4,Are there papers that propose contextualized c...,[{'title': 'Surface Form Competition: Why the ...,[{'title': 'DeepCx: A transition-based approac...,1
...,...,...,...,...
592,Which paper trains on linear regression to hyp...,[{'title': 'UNDERSTANDING CATASTROPHIC FORGETT...,[{'title': 'Multilingual Semantic Parsing : Pa...,1
593,Which paper uses the latent diffusion model fo...,[{'title': 'Efficient Planning with Latent Dif...,[{'title': 'APPLICATIONS OF A LEXICOGRAPHICAL ...,1
594,Which paper utilized MMD flows with Riesz kern...,[{'title': 'Posterior Sampling Based on Gradie...,[{'title': 'Biomedical Event Extraction with H...,1
595,What paper provides generalization bounds for ...,[{'title': 'Understanding prompt engineering m...,"[{'title': '', 'text': '', 'full_text': 'The A...",0


In [11]:
df["negative_ctxs"] = df["hard_negative_ctxs"]

In [12]:
df = df.drop(columns=['hard_negative_ctxs'])

In [13]:
df = df.reindex(columns=['specificity', 'question', 'positive_ctxs', 'negative_ctxs'])

In [14]:
processed_rows = []

for _, row in df.iterrows():
    row_dict = row.to_dict()

    if 'positive_ctxs' in row_dict and isinstance(row_dict['positive_ctxs'], list):
        for ctx in row_dict['positive_ctxs']:
            if isinstance(ctx, dict) and 'full_text' in ctx:
                del ctx['full_text']
    
    if 'negative_ctxs' in row_dict and isinstance(row_dict['negative_ctxs'], list):
        for ctx in row_dict['negative_ctxs']:
            if isinstance(ctx, dict) and 'full_text' in ctx:
                del ctx['full_text']
    
    processed_rows.append(row_dict)

In [15]:
processed_df = pd.DataFrame(processed_rows)
processed_df.to_json("./refined_triplet.json", orient='records', force_ascii=False, indent=2)

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from adapters import AutoAdapterModel
import numpy as np
from tqdm import tqdm
import random
import os

In [17]:
class LitSearchDataset(Dataset):
    def __init__(self, data: list[dict[str, any]], tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        
        # 포지티브 샘플 (1개)
        pos_ctx = item['positive_ctxs'][0]
        pos_text = pos_ctx['title'] + self.tokenizer.sep_token + pos_ctx['text']
        
        # 네거티브 샘플 (3개 중 랜덤 선택)
        neg_ctx = random.choice(item['negative_ctxs'])
        neg_text = neg_ctx['title'] + self.tokenizer.sep_token + neg_ctx['text']
        
        # 토크나이징
        query_tokens = self.tokenizer(
            question,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        pos_tokens = self.tokenizer(
            pos_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        neg_tokens = self.tokenizer(
            neg_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # 배치 차원 제거
        query_tokens = {k: v.squeeze(0) for k, v in query_tokens.items()}
        pos_tokens = {k: v.squeeze(0) for k, v in pos_tokens.items()}
        neg_tokens = {k: v.squeeze(0) for k, v in neg_tokens.items()}
        
        return {
            'query': query_tokens,
            'positive': pos_tokens,
            'negative': neg_tokens,
        }

In [18]:
class TripletMarginLoss(torch.nn.Module):
    def __init__(self, margin=1.0):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin
        
    def forward(self, query_emb, pos_emb, neg_emb):
        # L2 거리 계산
        pos_dist = torch.norm(query_emb - pos_emb, p=2, dim=1)
        neg_dist = torch.norm(query_emb - neg_emb, p=2, dim=1)
        
        # max(0, pos_dist - neg_dist + margin) 형태의 손실
        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)
        return loss.mean()

In [19]:
test = AutoAdapterModel.from_pretrained("allenai/specter2_base")

In [20]:
# Retrieval용 proximity adapter 로드 (문서 임베딩용)
test.load_adapter("allenai/specter2", source="hf", load_as="proximity")
# Adhoc query adapter 로드 (쿼리 임베딩용)
test.load_adapter("allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query")

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

'adhoc_query'

In [21]:
test.set_active_adapters("adhoc_query")

In [22]:
test.active_adapters

Stack[adhoc_query]

In [27]:
for param in test.parameters():
    param.requires_grad = False

for name, param in test.named_parameters():
    if "adapters.adhoc_query" in name:
        param.requires_grad = True

In [28]:
total_params = sum(p.numel() for p in test.parameters())
print(f"모델 총 파라미터 수: {total_params:,}")

trainable_params = sum(p.numel() for p in test.parameters() if p.requires_grad)
print(f"학습 가능한 파라미터 수: {trainable_params:,}")

모델 총 파라미터 수: 111,707,520
학습 가능한 파라미터 수: 894,528


In [29]:
print("학습 가능한 파라미터 목록:")
for name, param in test.named_parameters():
    if param.requires_grad:
        print(f"- {name}: {param.shape}, {param.numel()} 개")

학습 가능한 파라미터 목록:
- bert.encoder.layer.0.output.adapters.adhoc_query.adapter_down.0.weight: torch.Size([48, 768]), 36864 개
- bert.encoder.layer.0.output.adapters.adhoc_query.adapter_down.0.bias: torch.Size([48]), 48 개
- bert.encoder.layer.0.output.adapters.adhoc_query.adapter_up.weight: torch.Size([768, 48]), 36864 개
- bert.encoder.layer.0.output.adapters.adhoc_query.adapter_up.bias: torch.Size([768]), 768 개
- bert.encoder.layer.1.output.adapters.adhoc_query.adapter_down.0.weight: torch.Size([48, 768]), 36864 개
- bert.encoder.layer.1.output.adapters.adhoc_query.adapter_down.0.bias: torch.Size([48]), 48 개
- bert.encoder.layer.1.output.adapters.adhoc_query.adapter_up.weight: torch.Size([768, 48]), 36864 개
- bert.encoder.layer.1.output.adapters.adhoc_query.adapter_up.bias: torch.Size([768]), 768 개
- bert.encoder.layer.2.output.adapters.adhoc_query.adapter_down.0.weight: torch.Size([48, 768]), 36864 개
- bert.encoder.layer.2.output.adapters.adhoc_query.adapter_down.0.bias: torch.Size([48]), 4

In [30]:
!pip install faiss-gpu

Collecting faiss-gpu
  Using cached faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Using cached faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Installing collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2
[0m

In [31]:
from enum import Enum
from typing import Dict, Any, List

import torch

class TextType(Enum):
    KEY = 1
    QUERY = 2

class Retrieval:
    def __init__(self, index_name: str, index_type: str) -> None:
        self.index_name = index_name
        self.index_type = index_type

        self.keys = []
        self.values = []

    def __len__(self) -> int:
        return len(self.keys)

    def _get_embeddings(self, textList: list[str], type: TextType, show_progress_bar: bool = False) -> Any:
        raise NotImplementedError

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        raise NotImplementedError

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        embedding_query = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(embedding_query, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results

    def clear(self) -> None:
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def create_index(self, key_value_pairs: Dict[str, int]) -> None:
        if len(self.keys) > 0:
            raise ValueError("Index is not empty. Please create a new index or clear the existing one.")

        for key, value in key_value_pairs.items():
            self.keys.append(key)
            self.values.append(value)


In [32]:
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
from torch.utils.data import DataLoader
import faiss


from tqdm import tqdm
import os
from enum import Enum


class SPECTER2QueryAdapterFinetuner(Retrieval):
    def __init__(self, base_model_name="allenai/specter2_base", device=None):

        self.keys = []
        self.values = []
        self.index = None
        self.faiss_index = None

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.model = AutoAdapterModel.from_pretrained("allenai/specter2_base")

        self.model.load_adapter("allenai/specter2", source="hf", load_as="proximity")
        self.model.load_adapter("allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query")

        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            if "adapters.adhoc_query" in name:
                param.requires_grad = True

        self.model.to(self.device)

    def encode_text(self, input_ids, attention_mask, adapter_type="proximity"):
        """
        adapter_type: query -> "adhoc_query", text -> "proximity"
        """
        self.model.set_active_adapters(adapter_type)

        outputs = self.model(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device)
        )
        embeddings = outputs.last_hidden_state[:, 0, :]  # CLS 토큰 임베딩 사용

        return embeddings

    def encode_query(self, query_text):
        self.model.eval()
        with torch.no_grad():
            tokens = self.tokenizer(
                query_text,
                padding='max_length',
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )
            return self.encode_text(
                tokens['input_ids'],
                tokens['attention_mask'],
                adapter_type="adhoc_query"
            )

    def encode_paper(self, title, abstract):
        self.model.eval()
        with torch.no_grad():
            text = title + self.tokenizer.sep_token + abstract
            tokens = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )
            return self.encode_text(
                tokens['input_ids'],
                tokens['attention_mask'],
                adapter_type="proximity"
            )

    def finetune(self, train_data, val_data=None, output_dir="./specter2_adhoc_query_finetuned",
                 lr=2e-5, batch_size=8, epochs=3, margin=1.0, eval_steps=100,
                 weight_decay=0.01, warmup_ratio=0.1):

        train_dataset = LitSearchDataset(train_data, self.tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        optimizer = torch.optim.AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=lr,
            weight_decay=weight_decay
        )

        total_steps = len(train_loader) * epochs
        warmup_steps = int(total_steps * warmup_ratio)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='linear'
        )

        triplet_loss = TripletMarginLoss(margin=margin)

        self.model.train()
        global_step = 0
        best_val_loss = float('inf')

        for epoch in range(epochs):
            epoch_loss = 0.0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in progress_bar:
                # 쿼리, 포지티브, 네거티브 임베딩
                query_emb = self.encode_text(
                    batch['query']['input_ids'],
                    batch['query']['attention_mask'],
                    adapter_type="adhoc_query"
                )

                pos_emb = self.encode_text(
                    batch['positive']['input_ids'],
                    batch['positive']['attention_mask'],
                    adapter_type="proximity"
                )

                neg_emb = self.encode_text(
                    batch['negative']['input_ids'],
                    batch['negative']['attention_mask'],
                    adapter_type="proximity"
                )

                loss = triplet_loss(query_emb, pos_emb, neg_emb)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_([p for p in self.model.parameters() if p.requires_grad], 1.0)
                optimizer.step()
                scheduler.step()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"loss": loss.item()})

                global_step += 1
                if val_data is not None and global_step % eval_steps == 0:
                    val_loss = self.evaluate(val_data, batch_size)
                    print(f"Validation Loss: {val_loss:.4f}")

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        self.save_model(output_dir)
                        print(f"Model saved to {output_dir} (val_loss: {val_loss:.4f})")

                    self.model.train()

            avg_epoch_loss = epoch_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{epochs} - Avg Loss: {avg_epoch_loss:.4f}")

        if val_data is None or epochs % eval_steps != 0:
            self.save_model(output_dir)

        return self.model

    def evaluate(self, val_data, batch_size=8):
        val_dataset = LitSearchDataset(val_data, self.tokenizer)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        self.model.eval()
        triplet_loss = TripletMarginLoss(margin=1.0)
        total_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                query_emb = self.encode_text(
                    batch['query']['input_ids'],
                    batch['query']['attention_mask'],
                    adapter_type="adhoc_query"
                )

                pos_emb = self.encode_text(
                    batch['positive']['input_ids'],
                    batch['positive']['attention_mask'],
                    adapter_type="proximity"
                )

                neg_emb = self.encode_text(
                    batch['negative']['input_ids'],
                    batch['negative']['attention_mask'],
                    adapter_type="proximity"
                )

                loss = triplet_loss(query_emb, pos_emb, neg_emb)
                total_loss += loss.item()

        return total_loss / len(val_loader)

    def save_model(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)

        self.model.save_adapter(output_dir, "adhoc_query")

        self.tokenizer.save_pretrained(output_dir)

        print(f"어댑터가 {output_dir}에 저장되었습니다.")

    def _get_embeddings(self, textList: list[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:
        if type == TextType.KEY:
            self.model.set_active_adapters("proximity")
        else:
            self.model.set_active_adapters("adhoc_query")

        batch_size = 16
        embeddings = []

        for i in tqdm(range(0, len(textList), batch_size), desc="Getting embeddings"):
            batch_texts = textList[i:i+batch_size]
            encoded = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  # 최대 길이 명시적 지정
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def create_index(self, key_value_pairs: dict[str, int]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> list[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> list:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


In [40]:
def calculate_recall(corpus_id_list: list, retrieved_id_list: list, k: int):
    top_k = retrieved_id_list[:k]
    intersection = set(corpus_id_list) & set(top_k)
    return len(intersection) / len(corpus_id_list) if corpus_id_list else 0.0

In [41]:
def evaluate_model(finetuner, test_data, query_data, k_values=[1, 5, 10, 20]):
    # 제 IDE 로컬 환경 pyright이 자꾸 타입 터쳐서 넣은 타입 캐스팅 코드입니다.
    query_data = cast(Dataset, query_data)

    query_df = pd.DataFrame({
        'query': query_data['query'],
        'corpusids': query_data['corpusids']
    })

    test_questions = [item['question'] for item in test_data]
    filtered_query_df = query_df[query_df['query'].isin(test_questions)]

    results = {}
    result = []

    for k in k_values:
        total_recall = 0
        count = 0
        for i, item in enumerate(test_data):
            query = item['question']
            top_k_results = finetuner.query(query, k)

            query_row = filtered_query_df[filtered_query_df['query'] == query]
            if not query_row.empty:
                true_corpus_ids = query_row.iloc[0]['corpusids']
                if isinstance(true_corpus_ids, list):
                    true_corpus_ids_flat = true_corpus_ids
                else:
                    true_corpus_ids_flat = [true_corpus_ids]

                intersection = set(true_corpus_ids_flat) & set(top_k_results)
                recall = len(intersection) / len(true_corpus_ids_flat) if true_corpus_ids_flat else 0

                total_recall += recall
                count += 1

        if count > 0:
            avg_recall = total_recall / count
            results[f'Recall@{k}'] = avg_recall
            result.append(avg_recall)
        else:
            results[f'Recall@{k}'] = 0
            result.append(0)

    return results

In [42]:
from torch.utils.data import Dataset

def get_clean_corpusid(item: dict) -> int:
    return item['corpusid']

def get_clean_title(item: dict) -> str:
    return item['title']

def get_clean_abstract(item: dict) -> str:
    return item['abstract']

def get_clean_title_abstract(item: dict, tokenizer) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return title + tokenizer.sep_token + abstract

def create_kv_pairs(data: Dataset, key: str, tokenizer) -> dict:
    return {get_clean_title_abstract(record, tokenizer): get_clean_corpusid(record) for record in data}

In [45]:
import json
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from typing import cast

import torch
from torch.utils.data import Dataset
from datasets import load_dataset

# 데이터 로드
with open("refined_triplet.json", "r", encoding="utf-8") as f:
    data = json.load(f)

print(f"Total dataset size: {len(data)} samples")

# Specificity에 따라 데이터 분류
spec_0_queries = [item for item in data if item.get('specificity', 0) == 0]
spec_1_queries = [item for item in data if item.get('specificity', 0) == 1]

print(f"Specificity 0 queries: {len(spec_0_queries)}")
print(f"Specificity 1 queries: {len(spec_1_queries)}")

# 먼저 테스트 세트 분리 (20%)
spec_0_train_val, spec_0_test = train_test_split(spec_0_queries, test_size=0.2, random_state=42)
spec_1_train_val, spec_1_test = train_test_split(spec_1_queries, test_size=0.2, random_state=42)

# 남은 데이터에서 검증 세트 분리 (원본의 10%, 즉 train_val의 12.5%)
spec_0_train, spec_0_val = train_test_split(spec_0_train_val, test_size=0.125, random_state=42)
spec_1_train, spec_1_val = train_test_split(spec_1_train_val, test_size=0.125, random_state=42)

# 세트 결합
train_data = spec_0_train + spec_1_train
val_data = spec_0_val + spec_1_val
test_data = spec_0_test + spec_1_test

# 각 세트 섞기
random.shuffle(train_data)
random.shuffle(val_data)
random.shuffle(test_data)

print(f"Training: {len(train_data)} samples (Spec 0: {len(spec_0_train)}, Spec 1: {len(spec_1_train)})")
print(f"Validation: {len(val_data)} samples (Spec 0: {len(spec_0_val)}, Spec 1: {len(spec_1_val)})")
print(f"Testing: {len(test_data)} samples (Spec 0: {len(spec_0_test)}, Spec 1: {len(spec_1_test)})")

# 학습 과정
finetuner = SPECTER2QueryAdapterFinetuner()

# 제 IDE 로컬 환경 pyright이 자꾸 타입 터쳐서 넣은 타입 캐스팅 코드입니다.
corpus_clean_data = cast(Dataset, corpus_clean_data)
kv_pairs = create_kv_pairs(corpus_clean_data, "title_abstract", finetuner.tokenizer)
finetuner.clear()
finetuner.create_index(kv_pairs)

finetuner.finetune(
    train_data=train_data,
    val_data=val_data,
    output_dir="./specter2_adhoc_query_finetuned",
    lr=2e-4,
    batch_size=8,
    epochs=5,
    margin=1.0,
    eval_steps=50,
    weight_decay=0.01,
    warmup_ratio=0.1
)

print("Fine-tuning complete!")

# 테스트 데이터에 대한 성능 평가
print("Evaluating model on test set...")

# Specificity별 성능 평가
test_spec_0 = [item for item in test_data if item.get('specificity', 0) == 0]
test_spec_1 = [item for item in test_data if item.get('specificity', 0) == 1]

query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")

# 전체 테스트 세트 평가
overall_performance = evaluate_model(finetuner, test_data, query_data)
print(f"Overall test performance: {overall_performance}")

# Specificity 0 쿼리에 대한 평가
spec_0_performance = evaluate_model(finetuner, test_spec_0, query_data)
print(f"Specificity 0 performance: {spec_0_performance}")

# Specificity 1 쿼리에 대한 평가
spec_1_performance = evaluate_model(finetuner, test_spec_1, query_data)
print(f"Specificity 1 performance: {spec_1_performance}")

Total dataset size: 597 samples
Specificity 0 queries: 155
Specificity 1 queries: 442
Training: 416 samples (Spec 0: 108, Spec 1: 308)
Validation: 61 samples (Spec 0: 16, Spec 1: 45)
Testing: 120 samples (Spec 0: 31, Spec 1: 89)


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Getting embeddings: 100%|██████████| 3604/3604 [25:32<00:00,  2.35it/s]
Epoch 1/5:  96%|█████████▌| 50/52 [00:59<00:05,  2.83s/it, loss=0]     

Validation Loss: 0.0241
어댑터가 ./specter2_adhoc_query_finetuned에 저장되었습니다.
Model saved to ./specter2_adhoc_query_finetuned (val_loss: 0.0241)


Epoch 1/5: 100%|██████████| 52/52 [01:01<00:00,  1.18s/it, loss=0]      


Epoch 1/5 - Avg Loss: 0.0528


Epoch 2/5:  92%|█████████▏| 48/52 [00:57<00:11,  2.85s/it, loss=0]      

Validation Loss: 0.0130
어댑터가 ./specter2_adhoc_query_finetuned에 저장되었습니다.
Model saved to ./specter2_adhoc_query_finetuned (val_loss: 0.0130)


Epoch 2/5: 100%|██████████| 52/52 [01:01<00:00,  1.18s/it, loss=0.056]  


Epoch 2/5 - Avg Loss: 0.0296


Epoch 3/5:  88%|████████▊ | 46/52 [00:54<00:16,  2.81s/it, loss=0]      

Validation Loss: 0.0331


Epoch 3/5: 100%|██████████| 52/52 [01:01<00:00,  1.18s/it, loss=0.0527]


Epoch 3/5 - Avg Loss: 0.0134


Epoch 4/5:  85%|████████▍ | 44/52 [00:52<00:22,  2.81s/it, loss=0.000137]

Validation Loss: 0.0518


Epoch 4/5: 100%|██████████| 52/52 [01:01<00:00,  1.18s/it, loss=0]       


Epoch 4/5 - Avg Loss: 0.0160


Epoch 5/5:  81%|████████  | 42/52 [00:50<00:28,  2.82s/it, loss=0.0578]

Validation Loss: 0.0746


Epoch 5/5: 100%|██████████| 52/52 [01:01<00:00,  1.18s/it, loss=0]     


Epoch 5/5 - Avg Loss: 0.0070
어댑터가 ./specter2_adhoc_query_finetuned에 저장되었습니다.
Fine-tuning complete!
Evaluating model on test set...


Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 12.15it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.50it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 69.79it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 75.70it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 73.40it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 74.25it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 71.84it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 72.16it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 75.28it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 76.94it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 74.85it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 72.35it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 73.44it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 75.61it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.79it/s]
Getting em

Overall test performance: {'Recall@1': 0.20694444444444443, 'Recall@5': 0.4180555555555555, 'Recall@10': 0.5169444444444445, 'Recall@20': 0.6005555555555556}


Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.55it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 74.39it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.53it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.39it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.11it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 76.60it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.86it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.60it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 78.67it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.51it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.09it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 78.52it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.03it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.68it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.84it/s]
Getting em

Specificity 0 performance: {'Recall@1': 0.15591397849462368, 'Recall@5': 0.3344086021505376, 'Recall@10': 0.35698924731182796, 'Recall@20': 0.5989247311827957}


Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 75.57it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.97it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.75it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.38it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.63it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 78.46it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 79.28it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 81.50it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 80.78it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.39it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 82.04it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.66it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 72.63it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 76.11it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 76.06it/s]
Getting em

Specificity 1 performance: {'Recall@1': 0.20224719101123595, 'Recall@5': 0.4438202247191011, 'Recall@10': 0.5617977528089888, 'Recall@20': 0.6348314606741573}
