In [None]:
import os

from rule_gen.reddit.colbert.modeling import ColBertForSequenceClassification
# Set TOKENIZERS_PARALLELISM to "true" in this process's environment.
# NOTE(review): presumably consumed by the Hugging Face `tokenizers` backend to
# allow parallel tokenization — confirm it is read after this assignment runs.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
import torch
from transformers import AutoTokenizer
from typing import Dict

def initialize_model(model_path: str, base_model: str = 'bert-base-uncased') -> tuple:
    """Load the tokenizer and ColBERT classifier and put the model in eval mode.

    Args:
        model_path: Location handed to
            ``ColBertForSequenceClassification.from_pretrained``.
        base_model: Name or path handed to ``AutoTokenizer.from_pretrained``.
            Defaults to ``'bert-base-uncased'`` (the value that used to be
            hard-coded), so existing single-argument callers are unaffected.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the string
        ``'cuda'`` when CUDA is available and ``'cpu'`` otherwise.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = ColBertForSequenceClassification.from_pretrained(model_path)
    # Project-specific hook that receives the tokenizer; its exact effect is
    # defined in rule_gen.reddit.colbert.modeling — TODO confirm what it sets up.
    model.colbert_set_up(tokenizer)
    model.to(device)
    # Switch to evaluation mode before returning the model.
    model.eval()

    return tokenizer, model, device

def preprocess_texts(
    query: str,
    document: str,
    tokenizer: AutoTokenizer,
    device: str,
    max_length: int = 512
) -> Dict:
    """Tokenize a query/document pair and move the tensors to ``device``.

    Both texts are encoded independently with identical settings: padded to
    ``max_length``, truncated to ``max_length``, returned as PyTorch tensors.

    Args:
        query: Query text.
        document: Document text.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``
            and ``attention_mask`` entries.
        device: Target passed to each tensor's ``.to(...)``.
        max_length: Padding / truncation length for both texts.

    Returns:
        Dict with keys ``query_input_ids``, ``query_attention_mask``,
        ``doc_input_ids`` and ``doc_attention_mask``.
    """
    def encode(text: str):
        # Single place for the tokenizer settings shared by query and document.
        return tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

    encoded_query = encode(query)
    encoded_doc = encode(document)

    model_inputs = {
        'query_input_ids': encoded_query['input_ids'].to(device),
        'query_attention_mask': encoded_query['attention_mask'].to(device),
        'doc_input_ids': encoded_doc['input_ids'].to(device),
        'doc_attention_mask': encoded_doc['attention_mask'].to(device)
    }
    return model_inputs

def predict_relevance(
    query: str,
    document: str,
    model: ColBertForSequenceClassification,
    tokenizer: AutoTokenizer,
    device: str,
    max_length: int = 512
) -> float:
    """Score one query/document pair with the model.

    The pair is tokenized via ``preprocess_texts``, run through ``model`` with
    gradient tracking disabled, and the sigmoid of the element at index
    ``[0][0]`` of ``outputs.logits`` is returned.

    Args:
        query: Query text.
        document: Document text.
        model: Model called with the keyword tensors built by
            ``preprocess_texts``; its output must expose ``.logits``.
        tokenizer: Tokenizer forwarded to ``preprocess_texts``.
        device: Device the input tensors are moved to.
        max_length: Padding / truncation length forwarded to
            ``preprocess_texts``.

    Returns:
        The sigmoid-activated first logit as a Python ``float``.
    """
    batch = preprocess_texts(query, document, tokenizer, device, max_length)

    with torch.no_grad():
        logits = model(**batch).logits
        # Move to host memory before converting to a NumPy array.
        probabilities = torch.sigmoid(logits).cpu().numpy()

    first_score = probabilities[0][0]
    return float(first_score)

In [None]:
from desk_util.path_helper import get_model_save_path
# Map the model name to a path with the project helper get_model_save_path,
# then load the tokenizer, model and device via initialize_model.
# NOTE(review): presumably the path is the saved checkpoint directory — confirm.
model_name = "col1-name"
model_path = get_model_save_path(model_name)
tokenizer, model, device = initialize_model(model_path)

# Score a placeholder query/document pair; max_length is left at its default.
# NOTE(review): `score` is assigned but not displayed within the visible cell.
score = predict_relevance(
    query="example query",
    document="example document",
    model=model,
    tokenizer=tokenizer,
    device=device
)