Writer: 최장혁

In [21]:
!pip install transformers datasets adapters faiss-gpu tqdm hf-xet omegaconf

Collecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
  Preparing metadata (setup.py) ... [?25ldone
Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
Building wheels for collected packages: antlr4-python3-runtime
[33m  DEPRECATION: Building 'antlr4-python3-runtime' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'antlr4-python3-runtime'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for antlr4-python3-runtime (setup.py) ... [?25ldone
[?25h  Created wheel for antlr4-python3-runtime: filename

Load Base class and util functions

In [2]:
from enum import Enum
from typing import Any, Dict, List

import torch


class TextType(Enum):
    KEY = 1
    QUERY = 2


class Retrieval:
    def __init__(self, index_name: str = "", index_type: str = "") -> None:
        self.index_name = index_name
        self.index_type = index_type

        self.keys = []
        self.values = []
        self.index = None
        self.faiss_index = None

    def __len__(self) -> int:
        return len(self.keys)

    def clear(self) -> None:
        self.keys = []
        self.values = []
        self.index = None
        self.faiss_index = None

    def _encode_text(self, input_ids, attention_mask, adapter_type="proximity"):
        """
        input: tokenized value
        output: embedding value
        """
        raise NotImplementedError

    def encode_query(self, query_text: str):
        """
        embed query text
        """
        raise NotImplementedError

    def encode_paper(self, title: str, abstract: str):
        """
        embed paper using the concat of title and abstract
        """
        raise NotImplementedError

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        raise NotImplementedError

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self.encode_query(query_text)
        indices = self._query(query_embedding, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results

    def _encode_paper_batch(
        self, text_list: list[str], show_progress_bar: bool = False
    ) -> Any:
        """
        embed batch of papers using the concat of title and abstract
        """
        raise NotImplementedError

    def create_index(self, key_value_pairs: Dict[str, int]) -> None:
        if len(self.keys) > 0:
            raise ValueError(
                "Index is not empty. Please create a new index or clear the existing one."
            )

        for key, value in key_value_pairs.items():
            self.keys.append(key)
            self.values.append(value)


In [3]:
from typing import List, Tuple, Any
from datasets import Dataset

def get_clean_corpusid(item: dict) -> int:
    return item['corpusid']

def get_clean_title(item: dict) -> str:
    return item['title']

def get_clean_abstract(item: dict) -> str:
    return item['abstract']

def get_clean_title_abstract(item: dict) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return f"Title: {title}\nAbstract: {abstract}"

def get_clean_full_paper(item: dict) -> str:
    return item['full_paper']

def get_clean_paragraph_indices(item: dict) -> List[Tuple[int, int]]:
    text = get_clean_full_paper(item)
    paragraph_indices = []
    paragraph_start = 0
    paragraph_end = 0
    while paragraph_start < len(text):
        paragraph_end = text.find("\n\n", paragraph_start)
        if paragraph_end == -1:
            paragraph_end = len(text)
        paragraph_indices.append((paragraph_start, paragraph_end))
        paragraph_start = paragraph_end + 2
    return paragraph_indices

def get_clean_text(item: dict, start_idx: int, end_idx: int) -> str:
    text = get_clean_full_paper(item)
    assert start_idx >= 0 and end_idx >= 0
    assert start_idx <= end_idx
    assert end_idx <= len(text)
    return text[start_idx:end_idx]

def get_clean_paragraphs(item: dict, min_words: int = 10) -> List[str]:
    paragraph_indices = get_clean_paragraph_indices(item)
    paragraphs = [get_clean_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices]
    paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words]
    return paragraphs

def get_clean_citations(item: dict) -> List[int]:
    return item['citations']

def get_clean_dict(data: Dataset) -> dict:
    return {get_clean_corpusid(item): item for item in data}

def create_kv_pairs(data: List[dict], key: str) -> dict:
    if key == "title_abstract":
        kv_pairs = {get_clean_title_abstract(record): get_clean_corpusid(record) for record in data}
    elif key == "full_paper":
        kv_pairs = {get_clean_full_paper(record): get_clean_corpusid(record) for record in data}
    elif key == "paragraphs":
        kv_pairs = {}
        for record in data:
            corpusid = get_clean_corpusid(record)
            paragraphs = get_clean_paragraphs(record)
            for paragraph_idx, paragraph in enumerate(paragraphs):
                kv_pairs[paragraph] = (corpusid, paragraph_idx)
    else:
        raise ValueError("Invalid key")
    return kv_pairs

In [4]:
import numpy as np

def calculate_recall(corpusids: list, retrieved: list, k: int):
    top_k = retrieved[:k]
    intersection = set(corpusids) & set(top_k)
    return len(intersection) / len(corpusids) if corpusids else 0.0

def mean_recall(dataset, k):
    recalls = [
        calculate_recall(example['corpusids'], example['retrieved'], k)
        for example in dataset
    ]
    return np.mean(recalls) if recalls else 0.0

Load Datasets

In [5]:
from datasets import load_dataset

query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
corpus_s2orc_data = load_dataset("princeton-nlp/LitSearch", "corpus_s2orc", split="full")

In [6]:
from datasets import Dataset

def classify(example):
    if example['specificity'] == 0:
        if example['query_set'].startswith('inline'):
            return 'spec0_inline'
        else:
            return 'spec0_manual'
    else:
        if example['query_set'].startswith('inline'):
            return 'spec1_inline'
        else:
            return 'spec1_manual'

query_data = query_data.map(lambda x: {"class": classify(x)})

spec0_ds = query_data.filter(lambda x: x['specificity'] == 0)
spec1_ds = query_data.filter(lambda x: x['specificity'] == 1)
spec0_inline_ds = query_data.filter(lambda x: x['class'] == 'spec0_inline')
spec0_manual_ds = query_data.filter(lambda x: x['class'] == 'spec0_manual')
spec1_inline_ds = query_data.filter(lambda x: x['class'] == 'spec1_inline')
spec1_manual_ds = query_data.filter(lambda x: x['class'] == 'spec1_manual')

print(spec0_ds.num_rows, spec1_ds.num_rows, spec0_inline_ds.num_rows, spec0_manual_ds.num_rows, spec1_inline_ds.num_rows, spec1_manual_ds.num_rows)

155 442 120 35 231 211


In [7]:
kv_pairs = create_kv_pairs(corpus_clean_data, "title_abstract")

Loss and Margin

In [8]:
from torch.utils.data import Dataset

import torch


class TripletMarginLoss(torch.nn.Module):
    def __init__(self, margin=1.0):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin

    def forward(self, query_emb, pos_emb, neg_emb):
        # L2 거리 계산
        pos_dist = torch.norm(query_emb - pos_emb, p=2, dim=1)
        neg_dist = torch.norm(query_emb - neg_emb, p=2, dim=1)

        # max(0, pos_dist - neg_dist + margin) 형태의 손실
        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)
        return loss.mean()

class LitSearchTripletDataset(Dataset):
    def __init__(self, data: list[dict[str, str]], tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "query": item["query"],
            "positive_title": item["positive_title"],
            "positive_abstract": item["positive_abstract"],
            "negative_title": item["negative_title"],
            "negative_abstract": item["negative_abstract"],
        }


SPECTER2

In [None]:
import os

import faiss
import torch
from adapters import AutoAdapterModel
from tqdm import tqdm
from transformers.models.auto.tokenization_auto import AutoTokenizer


class Specter2(Retrieval):
    def __init__(self, base_model_name="allenai/specter2_base", device=None):
        super().__init__()
        self.keys = []
        self.values = []
        self.index = None
        self.faiss_index = None

        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.model = AutoAdapterModel.from_pretrained(base_model_name)

        self.model.load_adapter("allenai/specter2", source="hf", load_as="proximity")

        self.model.load_adapter(
            "allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query"
        )

        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            if "adapters.adhoc_query" in name:
                param.requires_grad = True

        self.model.to(self.device)

    def parameters(self):
        return self.model.parameters()

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def save_model(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)

        self.model.save_adapter(output_dir, "adhoc_query")

        self.tokenizer.save_pretrained(output_dir)

        print(f"어댑터가 {output_dir}에 저장되었습니다.")

    def _encode_text(self, input_ids, attention_mask, adapter_type="proximity"):
        """

        adapter_type: query -> "adhoc_query", paper -> "proximity"
        """
        self.model.set_active_adapters(adapter_type)

        outputs = self.model(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
        )
        embeddings = outputs.last_hidden_state[:, 0, :]
        return embeddings

    def encode_query(self, query_text: str, no_grad=True):
        self.model.eval()
        if no_grad:
            with torch.no_grad():
                tokens = self.tokenizer(
                    query_text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                return self._encode_text(
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    adapter_type="adhoc_query",
                )
        else:
            tokens = self.tokenizer(
                query_text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return self._encode_text(
                tokens["input_ids"],
                tokens["attention_mask"],
                adapter_type="adhoc_query",
            )

    def encode_paper(self, title: str, abstract: str, no_grad=True):
        self.model.eval()
        sep_token = self.tokenizer.sep_token
        if isinstance(title, list) and isinstance(abstract, list):
            text = [t + sep_token + a for t, a in zip(title, abstract)]
        else:
            text = title + sep_token + abstract
        if no_grad:
            with torch.no_grad():
                tokens = self.tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                return self._encode_text(
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    adapter_type="proximity",
                )
        else:
            tokens = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return self._encode_text(
                tokens["input_ids"], tokens["attention_mask"], adapter_type="proximity"
            )

    def _query(self, query_embedding: torch.Tensor, top_k: int = 20) -> list[int]:
        if self.faiss_index is None:
            raise ValueError(
                "FAISS index has not been created yet. Call create_index first."
            )

        query_vector = query_embedding.cpu().numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> list:
        query_embedding = self.encode_query(query_text)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results

    def _encode_paper_batch(
        self, textList: list[str], show_progress_bar: bool = True
    ) -> torch.Tensor:
        self.model.set_active_adapters("proximity")
        batch_size = 256
        embeddings = []

        should_show_progress = show_progress_bar

        iterator = range(0, len(textList), batch_size)
        if should_show_progress:
            iterator = tqdm(iterator, desc="Processing document embeddings")

        for i in iterator:
            batch_texts = textList[i : i + batch_size]
            encoded = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)
        return embeddings

    def create_index(self, key_value_pairs: dict[str, int]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._encode_paper_batch(self.keys)

        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        index_flat.add(index_vectors)
        self.faiss_index = index_flat

In [20]:
import os

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm


class Specter2Trainer:
    def __init__(self, model_wrapper):
        self.model_wrapper = model_wrapper
        self.model = model_wrapper.model
        self.tokenizer = model_wrapper.tokenizer
        self.device = model_wrapper.device
        print("Trainer initialized")

    def train(
        self,
        train_data,
        val_data=None,
        output_dir="./specter2_adhoc_query_finetuned",
        lr=2e-5,
        batch_size=2,
        epochs=3,
        margin=1.0,
        eval_steps=100,
        weight_decay=0.01,
        warmup_ratio=0.1,
    ):
        train_dataset = LitSearchTripletDataset(train_data, self.tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=lr,
            weight_decay=weight_decay,
        )

        total_steps = len(train_loader) * epochs
        scheduler = OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy="linear",
        )

        triplet_loss = TripletMarginLoss(margin=margin)
        global_step = 0
        best_val_loss = float("inf")

        for epoch in range(epochs):
            self.model.train()
            epoch_loss = 0.0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

            for batch in progress_bar:
                query_emb = self.model_wrapper.encode_query(
                    batch["query"], no_grad=False
                )
                pos_emb = self.model_wrapper.encode_paper(
                    batch["positive_title"], batch["positive_abstract"], no_grad=False
                )
                neg_emb = self.model_wrapper.encode_paper(
                    batch["negative_title"], batch["negative_abstract"], no_grad=False
                )

                loss = triplet_loss(query_emb, pos_emb, neg_emb)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    [p for p in self.model.parameters() if p.requires_grad], 1.0
                )
                optimizer.step()
                scheduler.step()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"loss": loss.item()})

                del query_emb, pos_emb, neg_emb, loss
                torch.cuda.empty_cache()

                global_step += 1
                if val_data is not None and global_step % eval_steps == 0:
                    val_loss = self.evaluate(val_data, batch_size)
                    print(f"Validation Loss: {val_loss:.4f}")

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        self.save_model(output_dir)
                        print(f"Model saved to {output_dir} (val_loss: {val_loss:.4f})")

                    self.model.train()

            avg_epoch_loss = epoch_loss / len(train_loader)
            print(f"Epoch {epoch + 1}/{epochs} - Avg Loss: {avg_epoch_loss:.4f}")

        if val_data is None or epochs % eval_steps != 0:
            self.save_model(output_dir)

        return self.model

    def evaluate(self, val_data, batch_size=8):
        val_dataset = LitSearchTripletDataset(val_data, self.tokenizer)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        self.model.eval()
        triplet_loss = TripletMarginLoss(margin=1.0)
        total_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                query_emb = self.model_wrapper.encode_query(
                    batch["query"], no_grad=False
                )
                pos_emb = self.model_wrapper.encode_paper(
                    batch["positive_title"], batch["positive_abstract"], no_grad=False
                )
                neg_emb = self.model_wrapper.encode_paper(
                    batch["negative_title"], batch["negative_abstract"], no_grad=False
                )

                loss = triplet_loss(query_emb, pos_emb, neg_emb)
                total_loss += loss.item()

        return total_loss / len(val_loader)

    def save_model(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        self.model_wrapper.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print(f"어댑터가 {output_dir}에 저장되었습니다.")

In [7]:
import json
import random

def main():
    file_path = "./triplet_data.json"
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    random.shuffle(data)

    n_total = len(data)
    n_train = int(n_total * 0.8)
    train_data = data[:n_train]
    val_data = data[n_train:]

    print(f"# of train: {len(train_data)}")
    print(f"# of val: {len(val_data)}")

    print("Specter2 생성 시작")
    model = Specter2()
    print("Specter2 생성 완료")
    print("Specter2Trainer 생성 시작")
    trainer = Specter2Trainer(model)
    print("Specter2Trainer 생성 완료")
    print("train 호출")
    trainer.train(
        train_data=train_data,
        val_data=val_data,
        output_dir="./output",
        lr=2e-4,
        batch_size=8,
        epochs=5,
        margin=1.0,
        eval_steps=50,
        weight_decay=0.01,
        warmup_ratio=0.1,
    )
    print("Fine-tuning complete and model saved!")


In [24]:
torch.cuda.empty_cache()

In [8]:
main()

# of train: 4196
# of val: 1049
Specter2 생성 시작


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Specter2 생성 완료
Specter2Trainer 생성 시작
Trainer initialized
Specter2Trainer 생성 완료
train 호출


Epoch 1/5:   0%|          | 0/525 [00:00<?, ?it/s]There are adapters available but none are activated for the forward pass.
Epoch 1/5:  10%|▉         | 50/525 [01:15<2:00:35, 15.23s/it, loss=0.193]

Validation Loss: 0.2097
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.2097)


Epoch 1/5:  19%|█▉        | 100/525 [02:30<1:47:44, 15.21s/it, loss=0.243]

Validation Loss: 0.1588
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1588)


Epoch 1/5:  29%|██▊       | 150/525 [03:45<1:35:04, 15.21s/it, loss=0.00058]

Validation Loss: 0.1435
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1435)


Epoch 1/5:  38%|███▊      | 200/525 [05:00<1:22:23, 15.21s/it, loss=0.275]  

Validation Loss: 0.1384
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1384)


Epoch 1/5:  48%|████▊     | 250/525 [06:14<1:09:37, 15.19s/it, loss=0.0385]

Validation Loss: 0.1337
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1337)


Epoch 1/5:  57%|█████▋    | 300/525 [07:29<57:01, 15.20s/it, loss=0.00809] 

Validation Loss: 0.1286
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1286)


Epoch 1/5:  67%|██████▋   | 350/525 [08:44<44:20, 15.20s/it, loss=0.138]  

Validation Loss: 0.1205
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1205)


Epoch 1/5:  76%|███████▌  | 400/525 [09:59<31:40, 15.20s/it, loss=0]     

Validation Loss: 0.1162
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1162)


Epoch 1/5:  86%|████████▌ | 450/525 [11:14<19:00, 15.21s/it, loss=0.242]  

Validation Loss: 0.1136
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1136)


Epoch 1/5:  95%|█████████▌| 500/525 [12:29<06:20, 15.21s/it, loss=0.304] 

Validation Loss: 0.1145


Epoch 1/5: 100%|██████████| 525/525 [12:41<00:00,  1.45s/it, loss=0.146]  


Epoch 1/5 - Avg Loss: 0.1385


Epoch 2/5:   5%|▍         | 25/525 [01:01<2:06:43, 15.21s/it, loss=0.115]

Validation Loss: 0.1142


Epoch 2/5:  14%|█▍        | 75/525 [02:16<1:54:09, 15.22s/it, loss=0]    

Validation Loss: 0.1092
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1092)


Epoch 2/5:  24%|██▍       | 125/525 [03:31<1:41:25, 15.21s/it, loss=0.00829]

Validation Loss: 0.1083
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1083)


Epoch 2/5:  33%|███▎      | 175/525 [04:46<1:28:45, 15.22s/it, loss=0.314]  

Validation Loss: 0.1013
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.1013)


Epoch 2/5:  43%|████▎     | 225/525 [06:01<1:16:05, 15.22s/it, loss=0.0192]

Validation Loss: 0.0978
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0978)


Epoch 2/5:  52%|█████▏    | 275/525 [07:16<1:03:18, 15.19s/it, loss=0]     

Validation Loss: 0.1012


Epoch 2/5:  62%|██████▏   | 325/525 [08:31<50:43, 15.22s/it, loss=0]       

Validation Loss: 0.1019


Epoch 2/5:  71%|███████▏  | 375/525 [09:46<38:05, 15.24s/it, loss=0.205] 

Validation Loss: 0.0976
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0976)


Epoch 2/5:  81%|████████  | 425/525 [11:01<25:21, 15.22s/it, loss=0.00455]

Validation Loss: 0.0987


Epoch 2/5:  90%|█████████ | 475/525 [12:16<12:42, 15.24s/it, loss=0]      

Validation Loss: 0.0961
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0961)


Epoch 2/5: 100%|██████████| 525/525 [13:31<00:00,  1.55s/it, loss=0.432]  


Validation Loss: 0.0918
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0918)
Epoch 2/5 - Avg Loss: 0.0674


Epoch 3/5:  10%|▉         | 50/525 [01:14<2:00:29, 15.22s/it, loss=0.0569]

Validation Loss: 0.0912
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0912)


Epoch 3/5:  19%|█▉        | 100/525 [02:29<1:47:44, 15.21s/it, loss=0]    

Validation Loss: 0.0895
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0895)


Epoch 3/5:  29%|██▊       | 150/525 [03:44<1:34:59, 15.20s/it, loss=0.0328]

Validation Loss: 0.0897


Epoch 3/5:  38%|███▊      | 200/525 [04:59<1:22:21, 15.21s/it, loss=0.0716]

Validation Loss: 0.0912


Epoch 3/5:  48%|████▊     | 250/525 [06:14<1:09:42, 15.21s/it, loss=0.00726]

Validation Loss: 0.0895
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0895)


Epoch 3/5:  57%|█████▋    | 300/525 [07:29<57:09, 15.24s/it, loss=0.0276]   

Validation Loss: 0.0875
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0875)


Epoch 3/5:  67%|██████▋   | 350/525 [08:44<44:24, 15.23s/it, loss=0.0115] 

Validation Loss: 0.0810
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0810)


Epoch 3/5:  76%|███████▌  | 400/525 [09:59<31:39, 15.20s/it, loss=0]      

Validation Loss: 0.0810


Epoch 3/5:  86%|████████▌ | 450/525 [11:14<19:00, 15.21s/it, loss=0]      

Validation Loss: 0.0830


Epoch 3/5:  95%|█████████▌| 500/525 [12:29<06:20, 15.21s/it, loss=0]       

Validation Loss: 0.0819


Epoch 3/5: 100%|██████████| 525/525 [12:41<00:00,  1.45s/it, loss=0]      


Epoch 3/5 - Avg Loss: 0.0289


Epoch 4/5:   5%|▍         | 25/525 [01:02<2:07:01, 15.24s/it, loss=0]   

Validation Loss: 0.0802
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0802)


Epoch 4/5:  14%|█▍        | 75/525 [02:17<1:54:12, 15.23s/it, loss=0]     

Validation Loss: 0.0796
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0796)


Epoch 4/5:  24%|██▍       | 125/525 [03:32<1:41:23, 15.21s/it, loss=0.0439]

Validation Loss: 0.0805


Epoch 4/5:  33%|███▎      | 175/525 [04:47<1:28:51, 15.23s/it, loss=0.022] 

Validation Loss: 0.0796


Epoch 4/5:  43%|████▎     | 225/525 [06:02<1:16:14, 15.25s/it, loss=0]     

Validation Loss: 0.0775
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0775)


Epoch 4/5:  52%|█████▏    | 275/525 [07:17<1:03:27, 15.23s/it, loss=0]     

Validation Loss: 0.0775
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0775)


Epoch 4/5:  62%|██████▏   | 325/525 [08:32<50:44, 15.22s/it, loss=0.0645] 

Validation Loss: 0.0785


Epoch 4/5:  71%|███████▏  | 375/525 [09:47<38:09, 15.26s/it, loss=0]      

Validation Loss: 0.0781


Epoch 4/5:  81%|████████  | 425/525 [11:02<25:26, 15.27s/it, loss=0]     

Validation Loss: 0.0773
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0773)


Epoch 4/5:  90%|█████████ | 475/525 [12:17<12:43, 15.27s/it, loss=0]      

Validation Loss: 0.0772
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0772)


Epoch 4/5: 100%|█████████▉| 524/525 [12:43<00:00,  1.92it/s, loss=0]      

Validation Loss: 0.0757
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0757)


Epoch 4/5: 100%|██████████| 525/525 [13:32<00:00,  1.55s/it, loss=0]


Epoch 4/5 - Avg Loss: 0.0110


Epoch 5/5:  10%|▉         | 50/525 [01:15<2:00:58, 15.28s/it, loss=0]    

Validation Loss: 0.0750
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0750)


Epoch 5/5:  19%|█▉        | 100/525 [02:30<1:48:06, 15.26s/it, loss=0]    

Validation Loss: 0.0748
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0748)


Epoch 5/5:  29%|██▊       | 150/525 [03:45<1:35:28, 15.28s/it, loss=0]    

Validation Loss: 0.0747
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Model saved to ./output (val_loss: 0.0747)


Epoch 5/5:  38%|███▊      | 200/525 [05:00<1:22:35, 15.25s/it, loss=0]    

Validation Loss: 0.0749


Epoch 5/5:  48%|████▊     | 250/525 [06:15<1:09:53, 15.25s/it, loss=0.0549]

Validation Loss: 0.0751


Epoch 5/5:  57%|█████▋    | 300/525 [07:31<57:08, 15.24s/it, loss=0]       

Validation Loss: 0.0755


Epoch 5/5:  67%|██████▋   | 350/525 [08:46<44:29, 15.25s/it, loss=0]      

Validation Loss: 0.0756


Epoch 5/5:  76%|███████▌  | 400/525 [10:01<31:47, 15.26s/it, loss=0]      

Validation Loss: 0.0755


Epoch 5/5:  86%|████████▌ | 450/525 [11:17<19:05, 15.28s/it, loss=0]      

Validation Loss: 0.0755


Epoch 5/5:  95%|█████████▌| 500/525 [12:32<06:21, 15.26s/it, loss=0.0361] 

Validation Loss: 0.0754


Epoch 5/5: 100%|██████████| 525/525 [12:45<00:00,  1.46s/it, loss=0]     

Epoch 5/5 - Avg Loss: 0.0044
어댑터가 ./output에 저장되었습니다.
어댑터가 ./output에 저장되었습니다.
Fine-tuning complete and model saved!





Load finetune model

In [10]:
import os

import faiss
import torch
from adapters import AutoAdapterModel
from tqdm import tqdm
from transformers.models.auto.tokenization_auto import AutoTokenizer


class Finetuned_Specter2(Retrieval):
    def __init__(self, base_model_name="allenai/specter2_base", device=None):
        super().__init__()
        self.keys = []
        self.values = []
        self.index = None
        self.faiss_index = None
        output_dir = "./output"

        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device
            
        self.tokenizer = AutoTokenizer.from_pretrained(output_dir)
        self.model = AutoAdapterModel.from_pretrained(base_model_name)
        self.model.load_adapter("allenai/specter2", source="hf", load_as="proximity")
        self.model.load_adapter(output_dir, load_as="adhoc_query")

        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            if "adapters.adhoc_query" in name:
                param.requires_grad = True

        self.model.to(self.device)

    def parameters(self):
        return self.model.parameters()

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def save_model(self, output_dir):
        os.makedirs(output_dir, exist_ok=True)

        self.model.save_adapter(output_dir, "adhoc_query")

        self.tokenizer.save_pretrained(output_dir)

        print(f"어댑터가 {output_dir}에 저장되었습니다.")

    def _encode_text(self, input_ids, attention_mask, adapter_type="proximity"):
        """

        adapter_type: query -> "adhoc_query", paper -> "proximity"
        """
        self.model.set_active_adapters(adapter_type)

        outputs = self.model(
            input_ids=input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
        )
        embeddings = outputs.last_hidden_state[:, 0, :]
        return embeddings

    def encode_query(self, query_text: str, no_grad=True):
        self.model.eval()
        if no_grad:
            with torch.no_grad():
                tokens = self.tokenizer(
                    query_text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                return self._encode_text(
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    adapter_type="adhoc_query",
                )
        else:
            tokens = self.tokenizer(
                query_text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return self._encode_text(
                tokens["input_ids"],
                tokens["attention_mask"],
                adapter_type="adhoc_query",
            )

    def encode_paper(self, title: str, abstract: str, no_grad=True):
        self.model.eval()
        sep_token = self.tokenizer.sep_token
        if isinstance(title, list) and isinstance(abstract, list):
            text = [t + sep_token + a for t, a in zip(title, abstract)]
        else:
            text = title + sep_token + abstract
        if no_grad:
            with torch.no_grad():
                tokens = self.tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                return self._encode_text(
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    adapter_type="proximity",
                )
        else:
            tokens = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return self._encode_text(
                tokens["input_ids"], tokens["attention_mask"], adapter_type="proximity"
            )

    def _query(self, query_embedding: torch.Tensor, top_k: int = 20) -> list[int]:
        if self.faiss_index is None:
            raise ValueError(
                "FAISS index has not been created yet. Call create_index first."
            )

        query_vector = query_embedding.cpu().numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> list:
        query_embedding = self.encode_query(query_text)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results

    def _encode_paper_batch(
        self, textList: list[str], show_progress_bar: bool = True
    ) -> torch.Tensor:
        self.model.set_active_adapters("proximity")
        batch_size = 256
        embeddings = []

        should_show_progress = show_progress_bar

        iterator = range(0, len(textList), batch_size)
        if should_show_progress:
            iterator = tqdm(iterator, desc="Processing document embeddings")

        for i in iterator:
            batch_texts = textList[i : i + batch_size]
            encoded = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)
        return embeddings

    def create_index(self, key_value_pairs: dict[str, int]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._encode_paper_batch(self.keys)

        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        index_flat.add(index_vectors)
        self.faiss_index = index_flat

In [11]:
model = Finetuned_Specter2()

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [12]:
model.create_index(kv_pairs)

There are adapters available but none are activated for the forward pass.
Processing document embeddings: 100%|██████████| 226/226 [14:09<00:00,  3.76s/it]


In [15]:
query_data

Dataset({
    features: ['query_set', 'query', 'specificity', 'quality', 'corpusids'],
    num_rows: 597
})

evaluate finetuned_specter2

In [9]:
def retrieve_for_dataset(dataset, model, k=20):
    def retrieval_fn(example):
        top_k = model.query(example["query"], k)
        example["retrieved"] = top_k
        return example
    return dataset.map(retrieval_fn)

In [26]:
spec0_inline_ds = retrieve_for_dataset(spec0_inline_ds, model)
spec0_manual_ds = retrieve_for_dataset(spec0_manual_ds, model)
spec1_inline_ds = retrieve_for_dataset(spec1_inline_ds, model)
spec1_manual_ds = retrieve_for_dataset(spec1_manual_ds, model)
spec0_ds = retrieve_for_dataset(spec0_ds, model)
spec1_ds = retrieve_for_dataset(spec1_ds, model)

Map:   0%|          | 0/120 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Map:   0%|          | 0/231 [00:00<?, ? examples/s]

Map:   0%|          | 0/211 [00:00<?, ? examples/s]

Map:   0%|          | 0/155 [00:00<?, ? examples/s]

Map:   0%|          | 0/442 [00:00<?, ? examples/s]

In [30]:
spec0_inline_recall20 = mean_recall(spec0_inline_ds, 20)
spec0_manual_recall20 = mean_recall(spec0_manual_ds, 20)
spec0_avg_recall20 = mean_recall(spec0_ds, 20)

spec1_inline_recall5 = mean_recall(spec1_inline_ds, 5)
spec1_inline_recall20 = mean_recall(spec1_inline_ds, 20)
spec1_manual_recall5 = mean_recall(spec1_manual_ds, 5)
spec1_manual_recall20 = mean_recall(spec1_manual_ds, 20)
spec1_avg_recall20 = mean_recall(spec1_ds, 20)

print(f"spec0 inline Recall@20: {spec0_inline_recall20:.4f}")
print(f"spec0 manual Recall@20: {spec0_manual_recall20:.4f}")
print(f"spec0 avg Recall@20:    {spec0_avg_recall20:.4f}")
print()
print(f"spec1 inline Recall@5:  {spec1_inline_recall5:.4f}")
print(f"spec1 inline Recall@20: {spec1_inline_recall20:.4f}")
print(f"spec1 manual Recall@5:  {spec1_manual_recall5:.4f}")
print(f"spec1 manual Recall@20: {spec1_manual_recall20:.4f}")
print(f"spec1 avg Recall@20:    {spec1_avg_recall20:.4f}")

spec0 inline Recall@20: 0.4637
spec0 manual Recall@20: 0.4286
spec0 avg Recall@20:    0.4558

spec1 inline Recall@5:  0.4069
spec1 inline Recall@20: 0.6190
spec1 manual Recall@5:  0.4360
spec1 manual Recall@20: 0.5924
spec1 avg Recall@20:    0.6063
