In [None]:
!pip install rank_bm25

In [None]:
!pip install sentence_transformers

In [None]:
!pip install datasets

In [1]:
import importlib
import pickle

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Local helper modules (must be importable from the notebook's directory).
import ranking_metrics
import evaluation_metrics

# Reload the local modules *before* pulling names out of them.
# `importlib.reload` refreshes the module object, but names already bound via
# `from ranking_metrics import ...` would keep pointing at the old classes,
# so the from-import has to come after the reload to pick up edits.
importlib.reload(evaluation_metrics)
importlib.reload(ranking_metrics)

from ranking_metrics import (
    RankingMetrics,
    Bm25,
    LaBSE,
    MsMarcoST,
    MsMarcoCE,
)

# Seed the RNGs used in this notebook for reproducibility.
np.random.seed(42)
torch.manual_seed(42)

ModuleNotFoundError: No module named 'ranking_metrics'

In [8]:
def preprocess_df(df):
    """Add normalised answer / relevance-label columns and drop unanswered rows.

    For every row:
      * ``answers_norm`` is the first element of ``answers`` (``None`` when
        ``answers`` is empty);
      * ``selected`` is ``passages["is_selected"]``.
    Rows containing any missing value (in practice, the rows with no answer)
    are then dropped.

    The input frame is not modified; a new frame is returned.

    Args:
        df: MS MARCO-style frame with an ``answers`` column (a sequence per row)
            and a ``passages`` column (a mapping with an ``"is_selected"`` key
            per row).

    Returns:
        A new DataFrame with the two extra columns and unanswered rows removed.
        The original index labels are kept (not reset).
    """
    def _first_answer(answers):
        # An empty answer list becomes None so that dropna() removes the row.
        return answers[0] if len(answers) > 0 else None

    result = df.copy()
    result["answers_norm"] = [_first_answer(ans) for ans in result["answers"]]
    result["selected"] = [passages["is_selected"] for passages in result["passages"]]
    return result.dropna()


def save_gen_text(queries, generated_answers, path='generated_text_large.pickle'):
    """Pickle a ``{query: generated_answer}`` mapping to ``path``.

    Queries and answers are paired positionally. The result is a dict, so a
    query that appears more than once keeps only its last answer. The saved
    mapping can therefore have fewer entries than ``queries``.

    Args:
        queries: iterable of query strings (used as dict keys).
        generated_answers: iterable of answers aligned with ``queries``.
        path: output file; defaults to the file the notebook later reloads.
    """
    query_ans_dict = dict(zip(queries, generated_answers))
    with open(path, 'wb') as f:
        pickle.dump(query_ans_dict, f)
        

def load_tokenizer_and_model(model_name_or_path, target_device=None):
    """Load a GPT-2 tokenizer and LM-head model from a hub id or local path.

    Args:
        model_name_or_path: Hugging Face hub id or local directory.
        target_device: device the model is moved to. When omitted, the
            notebook-global ``device`` is used, so the device-selection cell
            must have been run first.

    Returns:
        ``(tokenizer, model)``, with the model already moved to the device.
    """
    if target_device is None:
        # Fall back to the notebook-level `device` (resolved at call time).
        target_device = device
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(target_device)
    return tokenizer, model


def generate(
    model, tok, text,
    do_sample=True, max_length=60, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1,
    num_beams=None,
    no_repeat_ngram_size=3,
    min_length=2,
    eos_token_id=5,
):
    """Generate continuations of ``text`` with a causal LM and decode them.

    Args:
        model: model exposing ``.device`` and ``.generate(...)``, e.g. the
            ``GPT2LMHeadModel`` returned by ``load_tokenizer_and_model``.
        tok: tokenizer with ``encode(text, return_tensors="pt")`` and ``decode``.
        text: prompt to continue.
        do_sample, max_length, min_length, repetition_penalty, top_k, top_p,
        temperature, num_beams, no_repeat_ngram_size: forwarded unchanged to
            ``model.generate``.
        eos_token_id: token id passed to ``model.generate`` as the stop token.
            The default 5 is the hard-coded id from the original code.
            TODO confirm it is the intended stop token for this tokenizer.

    Returns:
        List of decoded strings, one per returned sequence. For decoder-only
        models such as GPT-2 these include the prompt, which is why callers
        strip the query prefix.
    """
    # Move the prompt to wherever the model lives instead of relying on a
    # notebook-global `device`.
    input_ids = tok.encode(text, return_tensors="pt").to(model.device)
    out = model.generate(
        input_ids,
        min_length=min_length,
        max_length=max_length,
        eos_token_id=eos_token_id,
        # These three were previously accepted but never forwarded, so the
        # sampling settings were silently ignored. top_k/top_p also had no
        # effect unless the model's own generation config enabled sampling.
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=num_beams,
    )
    return [tok.decode(seq) for seq in out]


def preprocess_get_text(sent: str):
    """Clean a generated answer by removing every "?/A:" marker, then every newline.

    The removals happen in that order. A marker that only becomes contiguous
    once newlines are stripped ("?/A", a newline, then ":") is left in place.
    """
    without_marker = sent.replace("?/A:", "")
    return without_marker.replace('\n', '')

In [9]:
# MS MARCO v1.1 train split -> pandas, then keep only rows that have an answer
# (see preprocess_df).
dataset = load_dataset("ms_marco", "v1.1", split="train")
df_train_dirty = dataset.to_pandas()
df_train = preprocess_df(df_train_dirty)

Downloading builder script:   0%|          | 0.00/8.52k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.15k [00:00<?, ?B/s]

Downloading and preparing dataset ms_marco/v1.1 to /root/.cache/huggingface/datasets/ms_marco/v1.1/1.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/111M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/44.5M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

Dataset ms_marco downloaded and prepared to /root/.cache/huggingface/datasets/ms_marco/v1.1/1.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84. Subsequent calls will reuse this data.


In [10]:
# Peek at the cleaned training frame; `answers_norm` and `selected` are the
# columns added by preprocess_df.
df_train.head(3)

Unnamed: 0,answers,passages,query,query_id,query_type,wellFormedAnswers,answers_norm,selected
0,[Results-Based Accountability is a disciplined...,"{'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]...",what is rba,19699,description,[],Results-Based Accountability is a disciplined ...,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]"
1,[Yes],"{'is_selected': [0, 1, 0, 0, 0, 0, 0], 'passag...",was ronald reagan a democrat,19700,description,[],Yes,"[0, 1, 0, 0, 0, 0, 0]"
2,[20-25 minutes],"{'is_selected': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]...",how long do you need for sydney and surroundin...,19701,numeric,[],20-25 minutes,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]"


In [11]:
# Pick the compute device: Apple-silicon MPS, then CUDA, then CPU.
# `torch.has_mps` is deprecated and only says PyTorch was *built* with MPS;
# `torch.backends.mps.is_available()` also checks that MPS can actually be used.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [12]:
# How many training queries to generate fake answers for.
# NOTE(review): the cached pickle loaded further down holds 5000 answers, so it
# was produced with a larger value than this one.
NUM_ROWS = 100

queries = df_train["query"].values[:NUM_ROWS]

In [13]:
MODEL_NAME = "sberbank-ai/rugpt3medium_based_on_gpt2"  # Hugging Face hub id
tok, model = load_tokenizer_and_model(model_name_or_path=MODEL_NAME)

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/1.27M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/674 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

In [None]:
# One generated continuation per query. The decoded output is assumed to start
# with the query text verbatim, so the first len(query) characters are cut off.
generated_answers = []
for query in queries:
    prompt = query + " "
    continuations = generate(model, tok, prompt, num_beams=5, max_length=60)
    generated_answers.append(continuations[0][len(query):])

In [None]:
# Cache the generations so the expensive generation loop can be skipped on re-runs.
save_gen_text(queries=queries, generated_answers=generated_answers)

In [14]:
# If the file with fake (generated) texts already exists, just load it.
# Only unpickle files you produced yourself: pickle can execute arbitrary code.
with open("generated_text_large.pickle", "rb") as f:
    generated_text = pickle.load(f)

# NOTE(review): the answers are later paired with df_train rows by position,
# which relies on dict insertion order. Duplicate queries collapse into one key
# and would shift that alignment.
generated_answers = [*generated_text.values()]

In [16]:
# Sanity check: how many cached answers were loaded (one per distinct query key).
len(generated_answers)

5000

In [17]:
# Strip the "?/A:" markers and newlines from every generated answer.
generated_answers_proc = list(map(preprocess_get_text, generated_answers))

In [18]:
def make_data_for_preds(df, gen_ans):
    """Build one ranking example per row, with the generated answer appended as an extra passage.

    Args:
        df: frame with ``query``, ``passages`` (mapping whose
            ``"passage_text"`` value has ``.tolist()``) and ``selected``
            (array-like with ``.tolist()``) columns.
        gen_ans: generated answers, one per row of ``df`` in row order.

    Returns:
        A list of dicts with keys ``"query"``, ``"passage_text"`` (original
        passages followed by the generated answer) and ``"is_selected"``
        (original labels followed by ``-1``, which marks the generated passage).
    """
    frame = df.copy()
    frame["gen_answers"] = gen_ans

    records = []
    for _, rec in frame.iterrows():
        passage_texts = [*rec["passages"]["passage_text"].tolist(), rec["gen_answers"]]
        labels = [*rec["selected"].tolist(), -1]
        records.append({
            "query": rec["query"],
            "passage_text": passage_texts,
            "is_selected": labels,
        })
    return records

In [19]:
# Pair the first N training rows (by position) with the N generated answers.
n_generated = len(generated_answers_proc)
data = make_data_for_preds(df_train.iloc[:n_generated], generated_answers_proc)

In [24]:
metrics = [Bm25(), MsMarcoCE(), LaBSE(), MsMarcoST()]
rm = RankingMetrics(metrics)
REPORT_EVERY = 1000  # print the metrics accumulated so far after this many queries

# Wrap `data` itself (not `enumerate(data)`) so tqdm knows the total length
# and can show a real progress bar with an ETA.
for i, item in enumerate(tqdm(data)):
    rm.update(item["query"], item["passage_text"], item["is_selected"])
    if (i + 1) % REPORT_EVERY == 0:
        rm.show_metrics()

# The in-loop report only fires on multiples of REPORT_EVERY; also show the
# final totals when the last chunk is smaller than that.
if len(data) % REPORT_EVERY != 0:
    rm.show_metrics()

1001it [03:13,  5.32it/s]

Bm25_AverageLoc: 7.42   MsMarcoCE_AverageLoc: 8.65   LaBSE_AverageLoc: 7.22   MsMarcoST_AverageLoc: 8.72   
-----------------------------
Bm25_AverageRelLoc: 0.8   MsMarcoCE_AverageRelLoc: 0.93   LaBSE_AverageRelLoc: 0.79   MsMarcoST_AverageRelLoc: 0.94   
-----------------------------
Bm25_Top@1: 0.08   MsMarcoCE_Top@1: 0.02   LaBSE_Top@1: 0.08   MsMarcoST_Top@1: 0.01   
Bm25_Top@3: 0.16   MsMarcoCE_Top@3: 0.03   LaBSE_Top@3: 0.16   MsMarcoST_Top@3: 0.04   
Bm25_Top@5: 0.24   MsMarcoCE_Top@5: 0.09   LaBSE_Top@5: 0.26   MsMarcoST_Top@5: 0.09   
-----------------------------
Bm25_FDARO@v1: 0.203   MsMarcoCE_FDARO@v1: 0.038   LaBSE_FDARO@v1: 0.198   MsMarcoST_FDARO@v1: 0.03   
Bm25_FDARO@v2: 0.209   MsMarcoCE_FDARO@v2: 0.042   LaBSE_FDARO@v2: 0.204   MsMarcoST_FDARO@v2: 0.033   
-----------------------------
Bm25_UpQuartile: 0.16   MsMarcoCE_UpQuartile: 0.03   LaBSE_UpQuartile: 0.16   MsMarcoST_UpQuartile: 0.03   




2001it [06:29,  5.92it/s]

Bm25_AverageLoc: 7.43   MsMarcoCE_AverageLoc: 8.63   LaBSE_AverageLoc: 7.26   MsMarcoST_AverageLoc: 8.69   
-----------------------------
Bm25_AverageRelLoc: 0.8   MsMarcoCE_AverageRelLoc: 0.93   LaBSE_AverageRelLoc: 0.79   MsMarcoST_AverageRelLoc: 0.94   
-----------------------------
Bm25_Top@1: 0.08   MsMarcoCE_Top@1: 0.01   LaBSE_Top@1: 0.08   MsMarcoST_Top@1: 0.01   
Bm25_Top@3: 0.16   MsMarcoCE_Top@3: 0.03   LaBSE_Top@3: 0.17   MsMarcoST_Top@3: 0.03   
Bm25_Top@5: 0.24   MsMarcoCE_Top@5: 0.09   LaBSE_Top@5: 0.25   MsMarcoST_Top@5: 0.09   
-----------------------------
Bm25_FDARO@v1: 0.1935   MsMarcoCE_FDARO@v1: 0.035   LaBSE_FDARO@v1: 0.189   MsMarcoST_FDARO@v1: 0.0275   
Bm25_FDARO@v2: 0.2015   MsMarcoCE_FDARO@v2: 0.039   LaBSE_FDARO@v2: 0.196   MsMarcoST_FDARO@v2: 0.0305   
-----------------------------
Bm25_UpQuartile: 0.16   MsMarcoCE_UpQuartile: 0.03   LaBSE_UpQuartile: 0.15   MsMarcoST_UpQuartile: 0.03   




3000it [09:43,  4.89it/s]

Bm25_AverageLoc: 7.44   MsMarcoCE_AverageLoc: 8.62   LaBSE_AverageLoc: 7.28   MsMarcoST_AverageLoc: 8.71   
-----------------------------
Bm25_AverageRelLoc: 0.8   MsMarcoCE_AverageRelLoc: 0.93   LaBSE_AverageRelLoc: 0.79   MsMarcoST_AverageRelLoc: 0.94   
-----------------------------
Bm25_Top@1: 0.08   MsMarcoCE_Top@1: 0.01   LaBSE_Top@1: 0.08   MsMarcoST_Top@1: 0.01   
Bm25_Top@3: 0.16   MsMarcoCE_Top@3: 0.04   LaBSE_Top@3: 0.16   MsMarcoST_Top@3: 0.03   
Bm25_Top@5: 0.23   MsMarcoCE_Top@5: 0.09   LaBSE_Top@5: 0.25   MsMarcoST_Top@5: 0.08   
-----------------------------
Bm25_FDARO@v1: 0.1907   MsMarcoCE_FDARO@v1: 0.0363   LaBSE_FDARO@v1: 0.1863   MsMarcoST_FDARO@v1: 0.0293   
Bm25_FDARO@v2: 0.2017   MsMarcoCE_FDARO@v2: 0.0393   LaBSE_FDARO@v2: 0.192   MsMarcoST_FDARO@v2: 0.0317   
-----------------------------
Bm25_UpQuartile: 0.16   MsMarcoCE_UpQuartile: 0.03   LaBSE_UpQuartile: 0.15   MsMarcoST_UpQuartile: 0.03   




4000it [12:56,  6.69it/s]

Bm25_AverageLoc: 7.49   MsMarcoCE_AverageLoc: 8.63   LaBSE_AverageLoc: 7.34   MsMarcoST_AverageLoc: 8.72   
-----------------------------
Bm25_AverageRelLoc: 0.81   MsMarcoCE_AverageRelLoc: 0.93   LaBSE_AverageRelLoc: 0.8   MsMarcoST_AverageRelLoc: 0.94   
-----------------------------
Bm25_Top@1: 0.08   MsMarcoCE_Top@1: 0.01   LaBSE_Top@1: 0.08   MsMarcoST_Top@1: 0.01   
Bm25_Top@3: 0.15   MsMarcoCE_Top@3: 0.04   LaBSE_Top@3: 0.16   MsMarcoST_Top@3: 0.03   
Bm25_Top@5: 0.23   MsMarcoCE_Top@5: 0.09   LaBSE_Top@5: 0.25   MsMarcoST_Top@5: 0.08   
-----------------------------
Bm25_FDARO@v1: 0.186   MsMarcoCE_FDARO@v1: 0.0378   LaBSE_FDARO@v1: 0.1792   MsMarcoST_FDARO@v1: 0.0288   
Bm25_FDARO@v2: 0.1968   MsMarcoCE_FDARO@v2: 0.0408   LaBSE_FDARO@v2: 0.1858   MsMarcoST_FDARO@v2: 0.0312   
-----------------------------
Bm25_UpQuartile: 0.15   MsMarcoCE_UpQuartile: 0.03   LaBSE_UpQuartile: 0.15   MsMarcoST_UpQuartile: 0.03   




5000it [16:05,  5.18it/s]

Bm25_AverageLoc: 7.45   MsMarcoCE_AverageLoc: 8.62   LaBSE_AverageLoc: 7.32   MsMarcoST_AverageLoc: 8.71   
-----------------------------
Bm25_AverageRelLoc: 0.81   MsMarcoCE_AverageRelLoc: 0.93   LaBSE_AverageRelLoc: 0.8   MsMarcoST_AverageRelLoc: 0.94   
-----------------------------
Bm25_Top@1: 0.08   MsMarcoCE_Top@1: 0.01   LaBSE_Top@1: 0.08   MsMarcoST_Top@1: 0.01   
Bm25_Top@3: 0.16   MsMarcoCE_Top@3: 0.04   LaBSE_Top@3: 0.16   MsMarcoST_Top@3: 0.03   
Bm25_Top@5: 0.23   MsMarcoCE_Top@5: 0.09   LaBSE_Top@5: 0.25   MsMarcoST_Top@5: 0.08   
-----------------------------
Bm25_FDARO@v1: 0.188   MsMarcoCE_FDARO@v1: 0.0364   LaBSE_FDARO@v1: 0.1806   MsMarcoST_FDARO@v1: 0.029   
Bm25_FDARO@v2: 0.1976   MsMarcoCE_FDARO@v2: 0.0394   LaBSE_FDARO@v2: 0.1884   MsMarcoST_FDARO@v2: 0.0314   
-----------------------------
Bm25_UpQuartile: 0.16   MsMarcoCE_UpQuartile: 0.03   LaBSE_UpQuartile: 0.15   MsMarcoST_UpQuartile: 0.03   







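FDARO@v1 — fake_doc_above_all_relevant: фейковый документ ранжирован выше всех релевантных пассажей  
FDARO@v2 — fake_doc_above_at_least_one_rel: фейковый документ ранжирован выше хотя бы одного релевантного пассажа  
AverageRelLoc — относительная позиция документа в выдаче (0 — самый верх выдачи). Чем ближе к нулю, тем лучше  
MsMarcoCE более устойчива к случаям, когда пассаж полностью повторяет запрос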