In [1]:
from datasets import load_dataset

ds = load_dataset("microsoft/ms_marco", "v1.1")

README.md: 0.00B [00:00, ?B/s]

v1.1/validation-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

v1.1/train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

v1.1/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

Инициализация библиотек и глобальных переменных

In [12]:
train = ds['train']
valid = ds['validation']

import hashlib
from nltk.stem import PorterStemmer
from sentence_transformers import SentenceTransformer, util
import numpy as np

stemmer = PorterStemmer()

embedder = SentenceTransformer('all-MiniLM-L6-v2')
embedder.to('cuda')

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

Подготовка *тренировочных данных*

In [None]:
def get_text_hash(text):
    return hashlib.sha256(text.encode()).hexdigest()

passage_hash_to_text = {}
for ex in train:
    for ptext in ex['passages']['passage_text']:
        p_hash = get_text_hash(ptext)
        if p_hash not in passage_hash_to_text:
            passage_hash_to_text[p_hash] = ptext

queries = {ex['query_id']: ex['query'] for ex in train}

passage_ids = list(passage_hash_to_text.keys())
passage_texts = list(passage_hash_to_text.values())
passage_embeddings = embedder.encode(passage_texts, convert_to_tensor=True, show_progress_bar=True)
passage_emb_map = {pid: emb for pid, emb in zip(passage_ids, passage_embeddings)}

query_ids = list(queries.keys())
query_texts = list(queries.values())
query_embeddings = embedder.encode(query_texts, convert_to_tensor=True, show_progress_bar=True)
query_emb_map = {qid: emb for qid, emb in zip(query_ids, query_embeddings)}

stemmed_passages = {}
for idx, passage in passage_hash_to_text.items():
    processed_text = " ".join([stemmer.stem(token) for token in passage.lower().split() if token.isalpha()])
    stemmed_passages[idx] = processed_text

stemmed_query = {}
for idx, query in queries.items():
    processed_query = " ".join([stemmer.stem(token) for token in query.lower().split() if token.isalpha()])
    stemmed_query[idx] = processed_query

Batches:   0%|          | 0/19591 [00:00<?, ?it/s]

Batches:   0%|          | 0/2573 [00:00<?, ?it/s]

Подготовка *валидационных данных* для оценки качества модели.

In [None]:
queries_valid = {ex['query_id']: ex['query'] for ex in valid}

stemmed_query_valid = {
    idx: " ".join([stemmer.stem(token) for token in query.lower().split() if token.isalpha()])
    for idx, query in queries_valid.items()
}

query_embeddings_valid = embedder.encode(list(queries_valid.values()), convert_to_tensor=True, show_progress_bar=True)
query_emb_map_valid = {qid: emb for qid, emb in zip(queries_valid.keys(), query_embeddings_valid)}

passage_hash_to_text_valid = {}
for ex in valid:
    if 'passages' in ex and ex['passages']['passage_text']:
        for ptext in ex['passages']['passage_text']:
            p_hash = get_text_hash(ptext)
            if p_hash not in passage_hash_to_text_valid:
                passage_hash_to_text_valid[p_hash] = ptext

stemmed_passages_valid = {
    idx: " ".join([stemmer.stem(token) for token in passage.lower().split() if token.isalpha()])
    for idx, passage in passage_hash_to_text_valid.items()
}
passage_embeddings_valid = embedder.encode(list(passage_hash_to_text_valid.values()), convert_to_tensor=True, show_progress_bar=True)
passage_emb_map_valid = {p_hash: emb for p_hash, emb in zip(passage_hash_to_text_valid.keys(), passage_embeddings_valid)}

Batches:   0%|          | 0/314 [00:00<?, ?it/s]

Batches:   0%|          | 0/2533 [00:00<?, ?it/s]

Вспомогательные функции.

In [None]:
def calculate_query_density(query_tokens, passage_tokens):
    if not passage_tokens:
        return 0
    
    query_token_set = set(query_tokens)
    term_frequency = sum(1 for token in passage_tokens if token in query_token_set)

    return term_frequency / len(passage_tokens)

def calculate_term_proximity(query_tokens, passage_tokens):
    if len(set(query_tokens)) < 2:
        return 0

    positions = {token: [] for token in set(query_tokens)}
    for i, p_token in enumerate(passage_tokens):
        if p_token in positions:
            positions[p_token].append(i)

    if any(not pos for pos in positions.values()):
        return 1e6

    all_found_positions = [pos for sublist in positions.values() for pos in sublist]
    if not all_found_positions:
        return 1e6
        
    proximity = max(all_found_positions) - min(all_found_positions)
    return proximity / len(passage_tokens) if passage_tokens else 1e6 - len([token for token in positions.keys() if len(positions[token]) > 0])

def generate_feature_vector(query_tokens, passage_tokens, query_embedding, passage_embedding):
    cosine_sim = util.pytorch_cos_sim(query_embedding, passage_embedding).item()
    
    passage_tokens_list = passage_tokens.split()
    word_in_beginning = int(any(token in passage_tokens_list[:10] for token in query_tokens.split()))
    
    query_token_set = set(query_tokens.split())
    passage_token_set = set(passage_tokens_list)
    query_coverage = len(query_token_set & passage_token_set) / len(query_token_set) if query_token_set else 0
    
    query_density = calculate_query_density(query_tokens.split(), passage_tokens_list)
    
    term_proximity = calculate_term_proximity(query_tokens.split(), passage_tokens_list)

    token_in_passage = sum(passage_tokens_list.count(token) for token in query_tokens.split())
    
    return [cosine_sim, token_in_passage, word_in_beginning, query_coverage, query_density, term_proximity]

def create_ranking_dataset(dataset, stemmed_queries, query_emb_map, stemmed_passages, passage_emb_map):
    X, y = [], []
    
    for ex in dataset:
        query_id = ex['query_id']
        if query_id not in stemmed_queries or not ('passages' in ex and ex['passages']['passage_text']):
            continue

        candidates_for_query = []
        labels_for_query = []
        
        for i, ptext in enumerate(ex['passages']['passage_text']):
            p_hash = get_text_hash(ptext)
            if p_hash not in stemmed_passages:
                continue
            
            feature_vector = generate_feature_vector(
                stemmed_queries[query_id],
                stemmed_passages[p_hash],
                query_emb_map[query_id],
                passage_emb_map[p_hash]
            )
            candidates_for_query.append(feature_vector)
            labels_for_query.append(ex['passages']['is_selected'][i])

        if candidates_for_query:
            X.extend(candidates_for_query)
            y.extend(labels_for_query)
            
    return np.array(X), np.array(y)

Создание *обучающей* и *валидационной* выборок

In [13]:
X_train, y_train = create_ranking_dataset(train, stemmed_query, query_emb_map, stemmed_passages, passage_emb_map)
print(f"Размер обучающей выборки: X={X_train.shape}, y={y_train.shape}")

X_valid, y_valid = create_ranking_dataset(valid, stemmed_query_valid, query_emb_map_valid, stemmed_passages_valid, passage_emb_map_valid)
print(f"Размер валидационной выборки: X={X_valid.shape}, y={y_valid.shape}")

Размер обучающей выборки: X=(676193, 6), y=(676193,)
Размер валидационной выборки: X=(82360, 6), y=(82360,)


Обучение модели

In [None]:
from sklearn.linear_model import LinearRegression

model = LinearRegression()
model.fit(X_train, y_train)

print("Модель обучена.")

feature_names = ["cosine_sim", "word_in_beginning", "query_coverage", "query_density", "term_proximity"]
print("\nВеса модели (коэффициенты):")
for name, coef in zip(feature_names, model.coef_):
    print(f"  - {name}: {coef:.4f}")

Оценка mrr модели

In [17]:
def evaluate_mrr(dataset, model, stemmed_queries, query_emb_map, stemmed_passages, passage_emb_map):
    reciprocal_ranks = []
    for ex in dataset:
        query_id = ex['query_id']
        if query_id not in stemmed_queries or not ('passages' in ex and ex['passages']['passage_text']):
            continue

        candidate_features = []
        true_relevant_idx = -1
        
        for i, ptext in enumerate(ex['passages']['passage_text']):
            p_hash = get_text_hash(ptext)
            if p_hash not in stemmed_passages:
                continue
                
            feature_vector = generate_feature_vector(
                stemmed_queries[query_id],
                stemmed_passages[p_hash],
                query_emb_map[query_id],
                passage_emb_map[p_hash]
            )
            candidate_features.append(feature_vector)
            
            if ex['passages']['is_selected'][i] == 1:
                true_relevant_idx = len(candidate_features) - 1

        if not candidate_features or true_relevant_idx == -1:
            reciprocal_ranks.append(0)
            continue

        scores = model.predict(np.array(candidate_features))
        sorted_indices = np.argsort(scores)[::-1]
        
        try:
            rank = np.where(sorted_indices == true_relevant_idx)[0][0] + 1
            reciprocal_ranks.append(1 / rank)
        except IndexError:
            reciprocal_ranks.append(0)

    return np.mean(reciprocal_ranks) if reciprocal_ranks else 0

print("\n--- Расчет метрик качества ---")

# Оценка на обучающей выборке
mrr_train = evaluate_mrr(train, model, stemmed_query, query_emb_map, stemmed_passages, passage_emb_map)

# Оценка на валидационной выборке
mrr_valid = evaluate_mrr(valid, model, stemmed_query_valid, query_emb_map_valid, stemmed_passages_valid, passage_emb_map_valid)

print("\n--- РЕЗУЛЬТАТЫ ---")
print(f"MRR (LinearRegression) на Обучающей выборке (Train): {mrr_train:.4f}")
print(f"MRR (LinearRegression) на Валидационной выборке (Valid): {mrr_valid:.4f}")


--- Расчет метрик качества ---

--- РЕЗУЛЬТАТЫ ---
MRR (LinearRegression) на Обучающей выборке (Train): 0.5344
MRR (LinearRegression) на Валидационной выборке (Valid): 0.5358
