In [5]:
!pip install sentence-transformers transformers accelerate bitsandbytes trectools



In [6]:
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM


In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True

In [8]:
DATA_DIR = "/kaggle/input/joker-files"

with open(f"{DATA_DIR}/joker_task1_retrieval_corpus25_EN.json") as f:
    corpus = json.load(f)

with open(f"{DATA_DIR}/joker_task1_retrieval_queries_train25_EN.json") as f:
    queries_train = json.load(f)

with open(f"{DATA_DIR}/joker_task1_retrieval_queries_test25_EN.json") as f:
    queries_test = json.load(f)

with open(f"{DATA_DIR}/joker_task1_retrieval_qrels_train25_EN.json") as f:
    qrels_train = json.load(f)

print(corpus[0], queries_train[0], qrels_train[0])

{'docid': '1', 'text': 'He has a green body, no visible nose, and lives in a trash can.'} {'qid': '8', 'query': 'colors'} {'qid': 8, 'docid': 151, 'qrel': 1}


In [9]:
retriever = SentenceTransformer(
    "paraphrase-multilingual-mpnet-base-v2",
    device=DEVICE
)

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [10]:
doc_texts = [d["text"] for d in corpus]
doc_ids = [d["docid"] for d in corpus]

doc_embeddings = retriever.encode(
    doc_texts,
    convert_to_numpy=True,
    normalize_embeddings=True,
    batch_size=256,
    show_progress_bar=True
)

Batches:   0%|          | 0/304 [00:00<?, ?it/s]

In [11]:
clf_name = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(clf_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    clf_name,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

model.config.pad_token_id = tokenizer.pad_token_id
model.eval()

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((4096,)

In [12]:
def build_prompt(query, doc):
    return f"""[INST]
Query:
{query}

Document:
{doc}

Question:
Is this document humorous and relevant to the query?

Answer yes or no.
[/INST]
"""

In [13]:
YES_TOKEN = tokenizer.encode(" yes", add_special_tokens=False)[0]
NO_TOKEN = tokenizer.encode(" no", add_special_tokens=False)[0]

In [14]:
@torch.no_grad()
def humor_scores_qd(queries, docs, batch_size=8):
    scores = []

    for i in tqdm(range(0, len(docs), batch_size)):
        batch_q = queries[i:i+batch_size]
        batch_d = docs[i:i+batch_size]

        prompts = [
            build_prompt(q, d if isinstance(d, str) else "")
            for q, d in zip(batch_q, batch_d)
        ]

        enc = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(DEVICE)

        outputs = model(**enc)
        logits = outputs.logits[:, -1, :]

        yes_logits = logits[:, YES_TOKEN]
        no_logits = logits[:, NO_TOKEN]

        probs = torch.softmax(
            torch.stack([no_logits, yes_logits], dim=1),
            dim=1
        )[:, 1]

        scores.extend(probs.cpu().numpy())

    return np.array(scores)

In [15]:
SIM_THRESHOLD = 0.0

TOP_K_SEMANTIC = 250

def retrieve(query, top_k=100):
    q_emb = retriever.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True
    )[0]

    sim_scores = np.dot(doc_embeddings, q_emb)

    sem_idx = np.argsort(sim_scores)[::-1][:TOP_K_SEMANTIC]

    cand_docs = [doc_texts[i] for i in sem_idx]
    cand_queries = [query] * len(cand_docs)

    humor_scores = humor_scores_qd(
        cand_queries,
        cand_docs,
        batch_size=4
    )

    order = np.argsort(humor_scores)[::-1][:top_k]

    return [
        {
            "docid": doc_ids[sem_idx[i]],
            "rank": r + 1,
            "score": float(humor_scores[i])
        }
        for r, i in enumerate(order)
    ]

In [16]:
train_predictions = []

for q in tqdm(queries_train):
    res = retrieve(q["query"], top_k=100)
    for r in res:
        train_predictions.append({
            "qid": q["qid"],
            "docid": r["docid"],
            "rank": r["rank"],
            "score": r["score"]
        })

  0%|          | 0/12 [00:00<?, ?it/s]
  0%|          | 0/63 [00:00<?, ?it/s][A
  2%|▏         | 1/63 [00:00<00:54,  1.13it/s][A
  3%|▎         | 2/63 [00:01<00:34,  1.77it/s][A
  5%|▍         | 3/63 [00:01<00:26,  2.29it/s][A
  6%|▋         | 4/63 [00:01<00:23,  2.51it/s][A
  8%|▊         | 5/63 [00:02<00:20,  2.80it/s][A
 10%|▉         | 6/63 [00:02<00:20,  2.82it/s][A
 11%|█         | 7/63 [00:02<00:18,  3.01it/s][A
 13%|█▎        | 8/63 [00:03<00:17,  3.16it/s][A
 14%|█▍        | 9/63 [00:03<00:17,  3.07it/s][A
 16%|█▌        | 10/63 [00:03<00:16,  3.19it/s][A
 17%|█▋        | 11/63 [00:03<00:15,  3.28it/s][A
 19%|█▉        | 12/63 [00:04<00:16,  3.09it/s][A
 21%|██        | 13/63 [00:04<00:15,  3.18it/s][A
 22%|██▏       | 14/63 [00:04<00:15,  3.26it/s][A
 24%|██▍       | 15/63 [00:05<00:14,  3.29it/s][A
 25%|██▌       | 16/63 [00:05<00:12,  3.63it/s][A
 27%|██▋       | 17/63 [00:05<00:12,  3.56it/s][A
 29%|██▊       | 18/63 [00:05<00:12,  3.58it/s][A
 30%|███  

In [17]:
with open("run_train.txt", "w") as f:
    for r in train_predictions:
        f.write(f"{r['qid']} Q0 {r['docid']} {r['rank']} {r['score']} run\n")

with open("qrels_train.txt","w") as f:
    for q in qrels_train:
        f.write(f"{q['qid']} 0 {q['docid']} {q['qrel']}\n")

In [18]:
from trectools import TrecRun, TrecQrel, TrecEval

run = TrecRun("run_train.txt")
qrels = TrecQrel("qrels_train.txt")

ev = TrecEval(run, qrels)

metrics = {
    "map": ev.get_map(),
    "recip_rank": ev.get_reciprocal_rank(),
    "ndcg_5": ev.get_ndcg(5),
    "ndcg_10": ev.get_ndcg(10),
    "ndcg_20": ev.get_ndcg(20),
    "P_5": ev.get_precision(5),
    "P_10": ev.get_precision(10),
    "P_20": ev.get_precision(20),
    "recall_5": ev.get_recall(5),
    "recall_10": ev.get_recall(10),
    "recall_20": ev.get_recall(20),
}

metrics

  selection = selection[~selection["rel"].isnull()].groupby("query").first().copy()


{'map': np.float64(0.10217009215803098),
 'recip_rank': np.float64(0.35142496392496386),
 'ndcg_5': np.float64(0.3192016581135993),
 'ndcg_10': np.float64(0.2942179760514528),
 'ndcg_20': np.float64(0.2976810511737515),
 'P_5': np.float64(0.2333333333333333),
 'P_10': np.float64(0.2583333333333333),
 'P_20': np.float64(0.25),
 'recall_5': np.float64(0.03232675519909562),
 'recall_10': np.float64(0.07407334588185653),
 'recall_20': np.float64(0.1832653654462165)}

In [19]:
predictions = []

for q in tqdm(queries_test):
    res = retrieve(q["query"], top_k=100)
    for r in res:
        predictions.append({
            "run_id": "run_test",
            "manual": 0,
            "qid": q["qid"],
            "docid": r["docid"],
            "rank": r["rank"],
            "score": r["score"]
        })

with open("prediction.json","w") as f:
    json.dump(predictions, f, indent=2)

  0%|          | 0/219 [00:00<?, ?it/s]
  0%|          | 0/63 [00:00<?, ?it/s][A
  2%|▏         | 1/63 [00:00<00:13,  4.62it/s][A
  3%|▎         | 2/63 [00:00<00:17,  3.53it/s][A
  5%|▍         | 3/63 [00:00<00:17,  3.38it/s][A
  6%|▋         | 4/63 [00:01<00:18,  3.25it/s][A
  8%|▊         | 5/63 [00:01<00:18,  3.19it/s][A
 10%|▉         | 6/63 [00:01<00:18,  3.16it/s][A
 11%|█         | 7/63 [00:02<00:17,  3.15it/s][A
 13%|█▎        | 8/63 [00:02<00:17,  3.13it/s][A
 14%|█▍        | 9/63 [00:02<00:17,  3.14it/s][A
 16%|█▌        | 10/63 [00:03<00:16,  3.16it/s][A
 17%|█▋        | 11/63 [00:03<00:16,  3.11it/s][A
 19%|█▉        | 12/63 [00:03<00:16,  3.10it/s][A
 21%|██        | 13/63 [00:04<00:16,  3.10it/s][A
 22%|██▏       | 14/63 [00:04<00:16,  2.93it/s][A
 24%|██▍       | 15/63 [00:04<00:16,  3.00it/s][A
 25%|██▌       | 16/63 [00:05<00:15,  3.03it/s][A
 27%|██▋       | 17/63 [00:05<00:13,  3.37it/s][A
 29%|██▊       | 18/63 [00:05<00:13,  3.30it/s][A
 30%|███ 

KeyboardInterrupt: 