# Citation Recommendation on Scholarly Legal Articles

## BM-25

### Libraries

In [1]:
import rank_bm25
import pickle
import os
from tqdm import tqdm

### Dataset

In [2]:
with open('/Users/dgknrsln/Documents/pythonProject/EXPERIMENTS/LATEST_test_docs.pkl', 'rb') as f:
    docs = pickle.load(f)

with open('/Users/dgknrsln/Documents/pythonProject/EXPERIMENTS/LATEST_test_queries.pkl', 'rb') as f:
    queries = pickle.load(f)

with open('/Users/dgknrsln/Documents/pythonProject/EXPERIMENTS/LATEST_test_data.pkl', 'rb') as f:
    pair = pickle.load(f)

### Train Model

In [3]:
tokenized_corpus = [doc.split() for doc in list(set(docs))]

bm25 = rank_bm25.BM25Plus(tokenized_corpus)

### Evaluate

#### 1. MAP

In [4]:
total_prec = 0
found = 0
for i in tqdm(range(len(queries))):

    sample = queries[i]
    tokenized_query = sample.split()
    results = bm25.get_top_n(tokenized_query, tokenized_corpus, n=10)

    count = 0
    precision = 0
    index = 0
    for m in results:
        if ' '.join(str(e) for e in m) in pair[i][1]:
            count += 1
            precision += count/(index+1)
        index += 1

    if count == 0:
        precision = 0
    else:
        found += 1
        precision /= count

    total_prec += precision

100%|██████████| 2675/2675 [02:40<00:00, 16.64it/s]


In [5]:
MAP = total_prec / len(queries)
print(MAP)

0.26791680265044787


#### 2. Recall

In [6]:
total_prec = 0
found = 0
for i in tqdm(range(len(queries))):

    sample = queries[i]
    tokenized_query = sample.split()
    results = bm25.get_top_n(tokenized_query, tokenized_corpus, n=10)

    count = 0
    for m in results:
        if ' '.join(str(e) for e in m) in pair[i][1]:
            count += 1

    total_prec += (count / len(pair[i][1]))

100%|██████████| 2675/2675 [02:23<00:00, 18.69it/s]


In [7]:
RECALL = total_prec / len(queries)
print(RECALL)

0.4520778816199377


#### 3. MRR

In [8]:
total_prec = 0
found = 0
for i in tqdm(range(len(queries))):

    sample = queries[i]
    tokenized_query = sample.split()
    results = bm25.get_top_n(tokenized_query, tokenized_corpus, n=10)

    index = 1
    for m in results:
        if ' '.join(str(e) for e in m) in pair[i][1]:
            break
        index += 1

    total_prec += (1/index)

100%|██████████| 2675/2675 [02:24<00:00, 18.54it/s]


In [9]:
MRR = total_prec / len(queries)
print(MRR)

0.317622918099557


# Results

In [10]:
print("MAP@10: " + str(MAP))
print("Recall@10: " + str(RECALL))
print("MRR@10: " + str(MRR))

MAP@10: 0.26791680265044787
Recall@10: 0.4520778816199377
MRR@10: 0.317622918099557
