In [25]:
# Enable IPython's autoreload extension in mode 2 so edited project modules
# are re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
import sys
sys.path.insert(0, "../src")

import sqlite3
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from pyserini.search.lucene import LuceneSearcher

import constants
from gen import util
from retrieval.fever_doc_db import FeverDocDB
from retrieval.retrieval import BM25DocumentRetriever

In [2]:
# Display the catalogue of prebuilt indexes known to pyserini's LuceneSearcher.
LuceneSearcher.list_prebuilt_indexes()

                        cacm                                                                                    \
description              Lucene index of the CACM corpus                                                         
filename                 lucene-index.cacm.tar.gz                                                                
urls                     [https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz]   
md5                      e47164fbd18aab72cdc18aecc0744bb1                                                        
size compressed (bytes)  2372903                                                                                 
total_terms              320968                                                                                  
documents                3204                                                                                    
unique_terms             14363                                                          

In [29]:
# Build a BM25 document retriever on the prebuilt "beir-v1.0.0-fever-flat" index
# and run its batch retrieval; the output is read back from `fever_doc.results`.
# NOTE(review): both paths are absolute and user-specific. Presumably the first is
# the claims file and the second the FEVER document DB — confirm against
# BM25DocumentRetriever's signature and consider a configurable data directory.
fever_doc = BM25DocumentRetriever(
    "/users/k21190024/study/fact-check-transfer-learning/scratch/data/fever/paper_dev.jsonl",
    "/users/k21190024/study/fact-check-transfer-learning/scratch/data/fever/fever.db",
    "beir-v1.0.0-fever-flat",
    n_jobs=20
)
fever_doc.batch_document_retrieve()

Attempting to initialize pre-built index beir-v1.0.0-fever-flat.
/users/k21190024/.cache/pyserini/indexes/lucene-index.beir-v1.0.0-fever-flat.20220501.1842ee.63cd5f369b5952386f138efe45571d41 already exists, skipping download.
Initializing beir-v1.0.0-fever-flat...


In [81]:
# Inspect a single retrieved example (position 1000) from the batch results.
fever_doc.results[1000]

{'id': 123416,
 'verifiable': 'VERIFIABLE',
 'label': 'SUPPORTS',
 'claim': 'Commercial sexual exploitation, as well as forced labor, are reasons for human trafficking.',
 'evidence': [[[144921, 159926, 'Human_trafficking', 0]]],
 'predicted_pages_score': [['Human_trafficking_in_Venezuela',
   28.88920021057129],
  ['Human_trafficking_in_Yemen', 28.819900512695312],
  ['Human_trafficking', 28.779800415039062],
  ['Human_trafficking_in_Finland', 28.40290069580078],
  ['Human_trafficking_in_Ukraine', 28.11359977722168],
  ['Human_trafficking_in_Uganda', 27.819700241088867],
  ['Human_trafficking_in_Vietnam', 27.808500289916992],
  ['Human_trafficking_in_Latvia', 27.66189956665039],
  ['Human_trafficking_in_Ohio', 27.395299911499023],
  ['Human_trafficking_in_California', 27.380199432373047],
  ["Human_trafficking_in_the_People's_Republic_of_China", 27.315000534057617],
  ['Human_trafficking_in_Namibia', 27.24530029296875],
  ['Human_trafficking_in_Lebanon', 27.047800064086914],
  ['Human

In [49]:
# Sanity check: zip pairs each inner list with the scalar at the same position.
list(zip([[1, 2], [3, 4]], [10, 20]))

[([1, 2], 10), ([3, 4], 20)]

In [99]:
def run_sentence_retrieval(doc, db, ce):
    """Score the sentences of a claim's predicted pages with a cross-encoder.

    Only documents whose label equals ``constants.LOOKUP["label"]["nei"]`` are
    processed. For those, every non-empty, numerically indexed line of every
    page in ``doc["predicted_pages"]`` is paired with the claim and scored.

    Args:
        doc: Mapping providing the keys ``"label"``, ``"claim"``, ``"id"`` and
            ``"predicted_pages"`` (an iterable of page titles).
        db: Object exposing ``get_doc_lines(title)``; the result is read as
            newline-separated records whose tab-separated fields start with
            the sentence index followed by the sentence text.
        ce: Object exposing ``predict(pairs)`` that returns one score per
            ``[claim, sentence]`` pair.

    Returns:
        ``[]`` when the label is not the "nei" label. Otherwise a tuple
        ``(doc["id"], evidence)`` where ``evidence`` is a list of
        ``[None, None, page, sentence_index]`` entries ordered by ascending
        score. ``evidence`` is empty when no candidate sentence was found.
    """
    if doc["label"] != constants.LOOKUP["label"]["nei"]:
        # NOTE(review): this branch returns a bare list while the main path
        # returns a (id, list) tuple; kept as-is for existing callers.
        return []

    claim = doc["claim"]
    candidates = []  # (page, sentence index), aligned one-to-one with ce_input
    ce_input = []
    for page in doc["predicted_pages"]:
        lines = db.get_doc_lines(util.denormalise_title(page))
        if not lines:
            # Presumably the page is missing from the DB; nothing to score.
            continue
        for line in lines.split("\n"):
            elem = line.split("\t")
            # The length guard avoids an IndexError on a numeric line that
            # has no tab-separated text field.
            if len(elem) > 1 and elem[0].isdigit() and elem[1].strip():
                candidates.append((page, int(elem[0])))
                ce_input.append([claim, elem[1]])

    if not ce_input:
        # Do not call the model on an empty batch.
        return doc["id"], []

    scores = ce.predict(ce_input)
    # NOTE(review): ascending order puts the lowest-scoring sentences first,
    # exactly as before; confirm that this is the intended ranking.
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1])

    return doc["id"], [[None, None, page, sent_idx] for (page, sent_idx), _ in ranked]

In [100]:
# Open the FEVER document DB, load the "cross-encoder/ms-marco-MiniLM-L-6-v2"
# cross-encoder, and run sentence retrieval on one retrieved example (1631).
# NOTE(review): the DB path is absolute and user-specific; `fever_doc` must
# already exist in the kernel from the document-retrieval cell.
db = FeverDocDB("/users/k21190024/study/fact-check-transfer-learning/scratch/data/fever/fever.db")
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
res = run_sentence_retrieval(fever_doc.results[1631], db, ce)

In [103]:
# Toy evidence sets: each entry is a one-element list holding
# [None, None, page, sentence_index].
a = [
    [[None, None, "something", 2]],
    [[None, None, "nothing", 3]],
    [[None, None, "anything", 10]],
]

b = [
    [[None, None, "something", 2]],
    [[None, None, "lol", 3]],
    [[None, None, "anything", 99999]],
]

# Keep every entry of `b` whose (page, sentence_index) differs from the entry
# at the same position in `a`.
negatives = []
for ev, retr in zip(a, b):
    if ev[0][2] == retr[0][2] and ev[0][3] == retr[0][3]:
        continue
    negatives.append(retr)

In [107]:
# Sanity check of slice bounds: the first element vs. the five that follow it.
# NOTE: this rebinds `a`, which the previous cell used for a different list.
a = [1,2,3,4,5,6,7]
a[:1], a[1:1+5]

([1], [2, 3, 4, 5, 6])

# Playground

In [None]:
# Open the feverised SciFact SQLite database and create a cursor for ad-hoc queries.
# NOTE(review): the path is absolute and user-specific, and no visible cell
# closes `conn`.
conn = sqlite3.connect("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/feverised-scifact/scifact.db")
cur = conn.cursor()

In [None]:
# Fetch the `documents` rows whose id is "34198460" (parameterised query) and
# show at most the first ten.
# NOTE: despite its name, `cfever_ids` holds full rows (`select *`), not just ids.
cur.execute("""select * from documents where id = ?""", ("34198460", ))
cfever_ids = cur.fetchall()
cfever_ids[0:10]

In [None]:
# Fetch the `documents` rows whose id is "Past_sea_level" and display them all.
cur.execute("""select * from documents where id = ?""", ("Past_sea_level", ))
tmp = cur.fetchall()
tmp

In [None]:
# Print the second tab-separated field of every record whose first field is "13".
# NOTE(review): `lines` is not defined in any visible cell; it must already
# exist in the kernel for this cell to run.
for sent_id, *rest in (row.split("\t") for row in lines.split("\n")):
    if sent_id == "13":
        print(rest[0])