# Homework 1 (Total Points: 250) <a class="anchor" id="top"></a>


**Submission instructions**:
- The cells with the `# YOUR CODE HERE` denote that these sections are graded and you need to add your implementation.
- For Part 1: You can use the `nltk`, `NumPy`, and `matplotlib` libraries here. Other libraries, e.g., `gensim` or `scikit-learn`, may not be used. For Part 2: `gensim` is allowed in addition to the imported libraries in the next code cell
- Please use Python 3.6.5 and `pip install -r requirements.txt` to avoid version issues.
- The notebook you submit has to have the student ids, separated by underscores (E.g., `12341234_12341234_12341234_hw1.ipynb`).
- This will be parsed by a regexp, **so please double check your filename**.
- Only one member of each group has to submit the file (**please do not compress the .ipynb file when you will submit it**) to canvas.
- **Make sure to check that your notebook runs before submission**. A quick way to do this is to restart the kernel and run all the cells.  
- Do not change the number of arugments in the given functions.
- **Please do not delete/add new cells**. Removing cells **will** lead to grade deduction. 
- Note, that you are not allowed to use Google Colab.


**Learning Goals**:
- [Part 1, Term-based matching](#part1) (165 points):
    - Learn how to load a dataset and process it.
    - Learn how to implement several standard IR methods (TF-IDF, BM25, QL) and understand their weaknesses & strengths.
    - Learn how to evaluate IR methods.
- [Part 2, Semantic-based matching](#part2) (85 points):
    - Learn how to implement vector-space retrieval methods (LSI, LDA).
    - Learn how to use LSI and LDA for re-ranking.

    
**Resources**: 
- **Part 1**: Sections 2.3, 4.1, 4.2, 4.3, 5.3, 5.6, 5.7, 6.2, 7, 8 of [Search Engines: Information Retrieval in Practice](https://ciir.cs.umass.edu/downloads/SEIRiP.pdf)
- **Part 2**: [LSI - Chapter 18](https://nlp.stanford.edu/IR-book/pdf/18lsi.pdf) from [Introduction to Information Retrieval](https://nlp.stanford.edu/IR-book/) book and the [original LDA paper](https://jmlr.org/papers/volume3/blei03a/blei03a.pdf)

In [1]:
# imports 
# TODO: Ensure that no additional library is imported in the notebook. 
# TODO: Only the standard library and the following libraries are allowed:
# TODO: You can also use unlisted classes from these libraries or standard libraries (such as defaultdict, Counter, ...).

import os
import zipfile
from functools import partial

import nltk
import requests
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm

from ipywidgets import widgets
from IPython.display import display, HTML
#from IPython.html import widgets
from collections import namedtuple

%matplotlib inline


# Part 1: Term-based Matching (165 points) <a class="anchor" id="part1"></a>

[Back to top](#top)

In the first part, we will learn the basics of IR from loading and preprocessing the material, to implementing some well known search algorithms, to evaluating the ranking performance of the implemented algorithms. We will be using the CACM dataset throughout the assignment. The CACM dataset is a collection of titles and abstracts from the journal CACM (Communication of the ACM).

Table of contents:
- [Section 1: Text Processing](#text_processing) (5 points)
- [Section 2: Indexing](#indexing) (10 points)
- [Section 3: Ranking](#ranking) (80 points)
- [Section 4: Evaluation](#evaluation) (40 points)
- [Section 5: Analysis](#analysis) (30 points)


---
## Section 1: Text Processing (5 points)<a class="anchor" id="text_processing"></a>

[Back to Part 1](#part1)

In this section, we will load the dataset and learn how to clean up the data to make it usable for an IR system. 
First, go through the implementation of the following functions:
- `read_cacm_docs`: Reads in the CACM documents.
- `read_queries`: Reads in the CACM queries.
- `load_stopwords`: Loads the stopwords.

The points of this section are earned for the following implementations:
- `tokenize` (3 points): Tokenizes the input text.
- `stem_token` (2 points): Stems the given token. 

We are using the [CACM dataset](http://ir.dcs.gla.ac.uk/resources/test_collections/cacm/), which is a small, classic IR dataset, composed of a collection of titles and abstracts from the journal CACM. It comes with relevance judgements for queries, so we can evaluate our IR system. 


---
### 1.1 Read the CACM documents


The following cell downloads the dataset and unzips it to a local directory.

In [2]:
def download_dataset():
    folder_path = os.environ.get("IR1_DATA_PATH")
    if not folder_path:
        folder_path = "./datasets/"
    os.makedirs(folder_path, exist_ok=True)
    
    file_location = os.path.join(folder_path, "cacm.zip")
    
    # download file if it doesn't exist
    if not os.path.exists(file_location):
        
        url = "https://surfdrive.surf.nl/files/index.php/s/M0FGJpX2p8wDwxR/download"

        with open(file_location, "wb") as handle:
            print(f"Downloading file from {url} to {file_location}")
            response = requests.get(url, stream=True)
            for data in tqdm(response.iter_content()):
                handle.write(data)
            print("Finished downloading file")
    
    if not os.path.exists(os.path.join(folder_path, "train.txt")):
        
        # unzip file
        with zipfile.ZipFile(file_location, 'r') as zip_ref:
            zip_ref.extractall(folder_path)
        
download_dataset()

---

You can see a brief description of each file in the dataset by looking at the README file:

In [3]:
##### Read the README file 
with open ("./datasets/README","r") as file:
    readme = file.read()
    print(readme)
#####

Files in this directory with sizes:
          0 Jun 19 21:01 README

    2187734 Jun 19 20:55 cacm.all              text of documents
        626 Jun 19 20:58 cite.info             key to citation info
                                                (the X sections in cacm.all)
       2668 Jun 19 20:55 common_words           stop words used by smart
       2194 Jun 19 20:55 make_coll*             shell script to make collection
       1557 Jun 19 20:55 make_coll_term*        ditto (both useless without
                                                smart system)
       9948 Jun 19 20:55 qrels.text             relation giving
                                                    qid did 0 0
                                                to indicate dument did is
                                                relevant to query qid
      13689 Jun 19 20:55 query.text             Original text of the query



---
We are interested in 4 files:
- `cacm.all` : Contains the text for all documents. Note that some documents do not have abstracts available
- `query.text` : The text of all queries
- `qrels.text` : The relevance judgements
- `common_words` : A list of common words. This may be used as a collection of stopwords

In [4]:
##### The first 45 lines of the CACM dataset forms the first record
# We are interested only in 3 fields. 
# 1. the '.I' field, which is the document id
# 2. the '.T' field (the title) and
# 3. the '.W' field (the abstract, which may be absent)
with open ("./datasets/cacm.all","r") as file:
    cacm_all = "".join(file.readlines()[:45])
    print(cacm_all)
#####

.I 1
.T
Preliminary Report-International Algebraic Language
.B
CACM December, 1958
.A
Perlis, A. J.
Samelson,K.
.N
CA581203 JB March 22, 1978  8:28 PM
.X
100	5	1
123	5	1
164	5	1
1	5	1
1	5	1
1	5	1
205	5	1
210	5	1
214	5	1
1982	5	1
398	5	1
642	5	1
669	5	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
1	6	1
165	6	1
196	6	1
196	6	1
1273	6	1
1883	6	1
324	6	1
43	6	1
53	6	1
91	6	1
410	6	1
3184	6	1



---

The following function reads the `cacm.all` file. Note that each document has a variable number of lines. The `.I` field denotes a new document:

In [5]:
def read_cacm_docs(root_folder = "./datasets/"):
    """
        Reads in the CACM documents. The dataset is assumed to be in the folder "./datasets/" by default
        Returns: A list of 2-tuples: (doc_id, document), where 'document' is a single string created by 
            appending the title and abstract (separated by a "\n"). 
            In case the record doesn't have an abstract, the document is composed only by the title
    """
    with open(os.path.join(root_folder, "cacm.all")) as reader:
        lines = reader.readlines()
    
    doc_id, title, abstract = None, None, None
    
    docs = []
    line_idx = 0
    while line_idx < len(lines):
        line = lines[line_idx]
        if line.startswith(".I"):
            if doc_id is not None:
                docs.append((doc_id, title, abstract))
                doc_id, title, abstract = None, None, None
            
            doc_id = line.split()[-1]
            line_idx += 1
        elif line.startswith(".T"):
            # start at next line
            line_idx += 1
            temp_lines = []
            # read till next '.'
            while not lines[line_idx].startswith("."):
                temp_lines.append(lines[line_idx].strip("\n"))
                line_idx += 1
            title = "\n".join(temp_lines).strip("\n")
        elif line.startswith(".W"):
            # start at next line
            line_idx += 1
            temp_lines = []
            # read till next '.'
            while not lines[line_idx].startswith("."):
                temp_lines.append(lines[line_idx].strip("\n"))
                line_idx += 1
            abstract = "\n".join(temp_lines).strip("\n")
        else:
            line_idx += 1
    
    docs.append((doc_id, title, abstract))
    
    p_docs = []
    for (did, t, a) in docs:
        if a is None:
            a = ""
        p_docs.append((did, t + "\n" + a))
    return p_docs


In [6]:
##### Function check
docs = read_cacm_docs()

assert isinstance(docs, list)
assert len(docs) == 3204, "There should be exactly 3204 documents"

unzipped_docs = list(zip(*docs))
assert np.sum(np.array(list(map(int,unzipped_docs[0])))) == 5134410

##### 

---
### 1.2 Read the CACM queries

Next, let us read the queries. They are formatted similarly:

In [7]:
##### The first 15 lines of 'query.text' has 2 queries
# We are interested only in 2 fields. 
# 1. the '.I' - the query id
# 2. the '.W' - the query
!head -15 ./datasets/query.text
#####

'head' is not recognized as an internal or external command,
operable program or batch file.


---

The following function reads the `query.text` file:

In [8]:
def read_queries(root_folder = "./datasets/"):
    """
        Reads in the CACM queries. The dataset is assumed to be in the folder "./datasets/" by default
        Returns: A list of 2-tuples: (query_id, query)
    """
    with open(os.path.join(root_folder, "query.text")) as reader:
        lines = reader.readlines()
    
    query_id, query = None, None
    
    queries = []
    line_idx = 0
    while line_idx < len(lines):
        line = lines[line_idx]
        if line.startswith(".I"):
            if query_id is not None:
                queries.append((query_id, query))
                query_id, query = None, None
    
            query_id = line.split()[-1]
            line_idx += 1
        elif line.startswith(".W"):
            # start at next line
            line_idx += 1
            temp_lines = []
            # read till next '.'
            while not lines[line_idx].startswith("."):
                temp_lines.append(lines[line_idx].strip("\n"))
                line_idx += 1
            query = "\n".join(temp_lines).strip("\n")
        else:
            line_idx += 1
    
    queries.append((query_id, query))
    return queries


In [9]:
##### Function check
queries = read_queries()

assert isinstance(queries, list)
assert len(queries) == 64 and all([q[1] is not None for q in queries]), "There should be exactly 64 queries"

unzipped_queries = list(zip(*queries))
assert np.sum(np.array(list(map(int,unzipped_queries[0])))) == 2080

##### 

---
### 1.3 Read the stop words

We use the common words stored in `common_words`:

In [10]:
##### Read the stop words file 
!head ./datasets/common_words
##### Read the README file 

'head' is not recognized as an internal or external command,
operable program or batch file.


---

The following function reads the `common_words` file (For better coverage, we try to keep them in lowercase):

In [11]:
def load_stopwords(root_folder = "./datasets/"):
    """
        Loads the stopwords. The dataset is assumed to be in the folder "./datasets/" by default
        Output: A set of stopwords
    """
    with open(os.path.join(root_folder, "common_words")) as reader:
        lines = reader.readlines()
    stopwords = set([l.strip().lower() for l in lines])
    return stopwords


In [12]:
##### Function check
stopwords = load_stopwords()

assert isinstance(stopwords, set)
assert len(stopwords) == 428, "There should be exactly 428 stop words"

assert np.sum(np.array(list(map(len,stopwords)))) == 2234

##### 


---
### 1.4 Tokenization (3 points)

We can now write some basic text processing functions. 
A first step is to tokenize the text. 

**Note**: Use the  `WordPunctTokenizer` available in the `nltk` library:

In [13]:
# TODO: Implement this! (4 points)
def tokenize(text):
    """
        Tokenizes the input text. Use the WordPunctTokenizer
        Input: text - a string
        Output: a list of tokens
    """
    tk = nltk.tokenize.WordPunctTokenizer()
    return tk.tokenize(text)

In [14]:
##### Function check
text = "the quick brown fox jumps over the lazy dog"
tokens = tokenize(text)

assert isinstance(tokens, list)
assert len(tokens) == 9

print(tokens)
# output: ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
#####

['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']


---
### 1.5 Stemming (2 points)

Write a function to stem tokens. 
Again, you can use the nltk library for this:

In [15]:
# TODO: Implement this! (3 points)
def stem_token(token):
    """
        Stems the given token using the PorterStemmer from the nltk library
        Input: a single token
        Output: the stem of the token
    """
    stemmer = nltk.stem.PorterStemmer()
    return stemmer.stem(token)

In [16]:
##### Function check

assert stem_token('owned') == 'own'
assert stem_token('itemization') == 'item'
#####

---
### 1.6 Summary

The following function puts it all together. Given an input string, this functions tokenizes and processes it according to the flags that you set.

In [199]:
#### Putting it all together
def process_text(text, stem=False, remove_stopwords=False, lowercase_text=False):
    
    tokens = []
    for token in tokenize(text):
        if remove_stopwords and token.lower() in stopwords:
            continue
        if stem:
            token = stem_token(token)
        if lowercase_text:
            token = token.lower()
        tokens.append(token)

    return tokens
#### 

---

Let's create two sets of preprocessed documents.
We can process the documents and queries according to these two configurations:

In [200]:
# In this configuration:
# Don't preprocess the text, except to tokenize 
config_1 = {
  "stem": False,
  "remove_stopwords" : False,
  "lowercase_text": True
} 


# In this configuration:
# Preprocess the text, stem and remove stopwords
config_2 = {
  "stem": True,
  "remove_stopwords" : True,
  "lowercase_text": True, 
} 

####
doc_repr_1 = []
doc_repr_2 = []
for (doc_id, document) in docs:
    doc_repr_1.append((doc_id, process_text(document, **config_1)))
    doc_repr_2.append((doc_id, process_text(document, **config_2)))

####

--- 

## Section 2: Indexing (10 points)<a class="anchor" id="indexing"></a>

[Back to Part 1](#part1)



A retrieval function usually takes in a query document pair, and scores a query against a document.  Our document set is quite small - just a few thousand documents. However, consider a web-scale dataset with a few million documents. In such a scenario, it would become infeasible to score every query and document pair. In such a case, we can build an inverted index. From Wikipedia:

> ... , an inverted index (also referred to as a postings file or inverted file) is a database index storing a mapping from content, such as words or numbers, to its locations in a table, .... The purpose of an inverted index is to allow fast full-text searches, at a cost of increased processing when a document is added to the database. ...


Consider a simple inverted index, which maps from word to document. This can improve the performance of a retrieval system significantly. In this assignment, we consider a *simple* inverted index, which maps a word to a set of documents. In practice, however, more complex indices might be used.  


### 2.1 Term Frequency-index (10 points)
In this assignment, we will be using an index created in memory since our dataset is tiny. To get started, build a simple index that maps each `token` to a list of `(doc_id, count)` where `count` is the count of the `token` in `doc_id`.
For consistency, build this index using a python dictionary.
    
Now, implement a function to build an index:

In [19]:
# TODO: Implement this! (10 points)
from collections import defaultdict
def build_tf_index(documents):
    """
        Build an inverted index (with counts). The output is a dictionary which takes in a token
        and returns a list of (doc_id, count) where 'count' is the count of the 'token' in 'doc_id'
        Input: a list of documents - (doc_id, tokens) 
        Output: An inverted index implemented within a pyhton dictionary: [token] -> [(doc_id, token_count)]
    """
    # YOUR CODE HERE
    index = defaultdict(list)
    for (doc_id, tokens) in documents:
        doc_id = int(doc_id)
        seen_tks = {}
        for tk in tokens:
            new_tk = tk not in index
            if new_tk:
                index[tk] = []
            if tk not in seen_tks:
                index[tk].append([doc_id, 1])
                seen_tks[tk] = True
            else:
                index[tk][-1][1] += 1
    return index

---
Now we can build indexed documents and preprocess the queries based on the two configurations:

In [20]:
#### Indexed documents based on the two configs

# Create the 2 indices
tf_index_1 = build_tf_index(doc_repr_1)
tf_index_2 = build_tf_index(doc_repr_2)

# This function returns the tf_index of the corresponding config
def get_index(index_set):
    assert index_set in {1, 2}
    return {
        1: tf_index_1,
        2: tf_index_2
    }[index_set]

####
#### Preprocessed query based on the two configs

# This function preprocesses the text given the index set, according to the specified config
def preprocess_query(text, index_set):
    assert index_set in {1, 2}
    if index_set == 1:
        return process_text(text, **config_1)
    elif index_set == 2:
        return process_text(text, **config_2)

#### 

In [21]:
##### Function check

assert isinstance(tf_index_1, dict)

assert isinstance(tf_index_1['computer'], list)
print('sample tf index for computer:', tf_index_1['computer'][:10])

assert isinstance(tf_index_1['examples'], list)
print('sample tf index for examples:', tf_index_1['examples'][:10])
#### 

sample tf index for computer: [[4, 1], [7, 1], [10, 1], [13, 1], [19, 1], [22, 1], [23, 1], [37, 1], [40, 3], [41, 1]]
sample tf index for examples: [[111, 1], [320, 1], [644, 1], [691, 1], [727, 1], [848, 1], [892, 1], [893, 1], [1049, 1], [1051, 1]]


In [22]:
##### Function check

assert isinstance(tf_index_2, dict)

assert isinstance(tf_index_2['computer'], list)
print('sample tf index for computer:', tf_index_1['computer'][:10])

assert isinstance(tf_index_2['examples'], list)
print('sample tf index for examples:', tf_index_2['examples'][:10])
#### 

sample tf index for computer: [[4, 1], [7, 1], [10, 1], [13, 1], [19, 1], [22, 1], [23, 1], [37, 1], [40, 3], [41, 1]]
sample tf index for examples: []



---
## Section 3: Ranking  (80 points) <a class="anchor" id="ranking"></a>

[Back to Part 1](#part1)

Now that we have cleaned and processed our dataset, we can start building simple IR systems. 

For now, we consider *simple* IR systems, which involve computing scores from the tokens present in the document/query. More advanced methods are covered in later assignments.

We will implement the following methods in this section:
- [Section 3.1: Bag of Words](#bow) (10 points)
- [Section 3.2: TF-IDF](#tfidf) (15 points)
- [Section 3.3: Query Likelihood Model](#qlm) (35 points)
- [Section 3.4: BM25](#bm25) (20 points)

*All search functions should be able to handle multiple words queries.*

**Scoring policy:**
Your implementations in this section are scored based on the expected performance of your ranking functions.
You will get a full mark if your implementation meets the expected performance (measured by some evaluation metric).
Otherwise, you may get partial credit.
For example, if your *Bag of words* ranking function has 60% of expected performance, you will get 6 out of 10.

--- 

### Section 3.1: Bag of Words (10 points)<a class="anchor" id="bow"></a>

Probably the simplest IR model is the Bag of Words (BOW) model.
Implement a function that scores and ranks all the documents against a query using this model.   

- For consistency, you should use the count of the token and **not** the binary indicator.
- Use `float` type for the scores (even though the scores are integers in this case).
- No normalization of the scores is necessary, as the ordering is what we are interested in.
- If two documents have the same score, they can have any ordering: you are not required to disambiguate.


In [23]:
# TODO: Implement this! (10 points)

def dict_to_arr(d, descending=True):
    arr = []
    for k in d.keys():
         arr.append([str(k), d[k]])
    return sorted(arr, key=lambda x: x[1], reverse=descending)

def bow_search(query, index_set):
    """
        Perform a search over all documents with the given query. 
        Note: You have to use the `get_index` function created in the previous cells
        Input: 
            query - a (unprocessed) query
            index_set - the index to use
        Output: a list of (document_id, score), sorted in descending relevance to the given query. 
    """
    index = get_index(index_set)
    processed_query = preprocess_query(query, index_set)
    scores = {}
    for q in processed_query:
        if q not in index: continue
        for (doc_id, freq) in index[q]:
            if doc_id not in scores:
                scores[doc_id] = float(freq)
            else:
                scores[doc_id] += freq
    return dict_to_arr(scores)

In [24]:
#### Function check

test_bow = bow_search("how to implement bag of words search", index_set=1)[:5]
assert isinstance(test_bow, list)
assert len(test_bow[0]) == 2
assert isinstance(test_bow[0][0], str)
assert isinstance(test_bow[0][1], float)

#### 

In [25]:

docs_by_id = dict(docs)
def print_results(docs, len_limit=50):    
    for i, (doc_id, score) in enumerate(docs):
        doc_content = docs_by_id[doc_id].strip().replace("\n", "\\n")[:len_limit] + "..."
        print(f"Rank {i}({score:.2}): {doc_content}")

test_bow_2 = bow_search("computer search word", index_set=2)[:5]
print(f"BOW Results:")
print_results(test_bow_2)


BOW Results:
Rank 0(1.3e+01): On Computing The Fast Fourier Transform\nCooley an...
Rank 1(1.2e+01): Variable Length Tree Structures Having Minimum Ave...
Rank 2(1.1e+01): A Modular Computer Sharing System\nAn alternative ...
Rank 3(1e+01): PEEKABIT, Computer Offspring of Punched\nCard PEEK...
Rank 4(9.0): Computer Simulation-Discussion of the\nTechnique a...


In [26]:

test_bow_1 = bow_search("computer search word", index_set=1)[:5]
print(f"BOW Results:")
print_results(test_bow_1)


BOW Results:
Rank 0(9.0): CURRICULUM 68 -- Recommendations for Academic\nPro...
Rank 1(9.0): Variable Length Tree Structures Having Minimum Ave...
Rank 2(7.0): Computer Formulation of the Equations of Motion Us...
Rank 3(7.0): The Effects of Multiplexing on a Computer-Communic...
Rank 4(6.0): Optimizing Bit-time Computer Simulation\nA major c...


In [27]:
print('top-5 docs for index1:', list(zip(*test_bow_1[:5]))[0])
print('top-5 docs for index2:', list(zip(*test_bow_2[:5]))[0])


top-5 docs for index1: ('1771', '1936', '1543', '2535', '678')
top-5 docs for index2: ('1525', '1936', '1844', '1700', '1366')



---

### Section 3.2: TF-IDF (15 points) <a class="anchor" id="tfidf"></a>

Before we implement the tf-idf scoring functions, let's first write a function to compute the document frequencies of all words.  

#### 3.2.1 Document frequency (5 points)
Compute the document frequencies of all tokens in the collection. 
Your code should return a dictionary with tokens as its keys and the number of documents containing the token as values.
For consistency, the values should have `int` type.

In [28]:
# TODO: Implement this! (5 points)
def compute_df(documents):
    """
        Compute the document frequency of all terms in the vocabulary
        Input: A list of documents
        Output: A dictionary with {token: document frequency (int)}
    """
    # YOUR CODE HERE
    df = {}
    for doc in documents:
        seen_tks = {}
        for tk in doc:
            if tk not in df:
                df[tk] = 0
            df[tk] += not tk in seen_tks
            seen_tks[tk] = True
    return df

In [29]:
#### Compute df based on the two configs

# get the document frequencies of each document
df_1 = compute_df([d[1] for d in doc_repr_1])
df_2 = compute_df([d[1] for d in doc_repr_2])

def get_df(index_set):
    assert index_set in {1, 2}
    return {
        1: df_1,
        2: df_2
    }[index_set]
####

In [30]:
#### Function check

print(df_1['computer'])
print(df_2['computer'])
####

597
11


---
#### 3.2.2 TF-IDF search (10 points)
Next, implement a function that computes a tf-idf score, given a query.
Use the following formulas for TF and IDF:

$$ TF=\log (1 + f_{d,t}) $$

$$ IDF=\log (\frac{N}{n_t})$$

where $f_{d,t}$ is the frequency of token $t$ in document $d$, $N$ is the number of total documents and $n_t$ is the number of documents containing token $t$.

**Note:** your implementation will be auto-graded assuming you have used the above formulas.


In [31]:
# TODO: Implement this! (10 points)
def tfidf_search(query, index_set):
    """
        Perform a search over all documents with the given query using tf-idf. 
        Note #1: You have to use the `get_index` (and the `get_df`) function created in the previous cells
        Input: 
            query - a (unprocessed) query
            index_set - the index to use
        Output: a list of (document_id, score), sorted in descending relevance to the given query 
    """
    index = get_index(index_set)
    df = get_df(index_set)
    processed_query = preprocess_query(query, index_set)
    
    N = len(docs)
    
    scores = {}
    for q in processed_query:
        if q not in index: continue
        for (doc_id, freq) in index[q]:
            tf = np.log(1 + freq)
            idf = np.log(N/df[q])
            if doc_id not in scores:
                scores[doc_id] = 0
            scores[doc_id] += tf*idf #question: correct to add tfidf score of each word?
  
    return dict_to_arr(scores)

In [32]:

#### Function check
test_tfidf = tfidf_search("how to implement tf idf search", index_set=1)[:5]
assert isinstance(test_tfidf, list)
assert len(test_tfidf[0]) == 2
assert isinstance(test_tfidf[0][0], str)
assert isinstance(test_tfidf[0][1], float)

####

In [33]:

test_tfidf_2 = tfidf_search("computer word search", index_set=2)[:5]
print(f"TFIDF Results:")
print_results(test_tfidf_2)


TFIDF Results:
Rank 0(1.3e+01): PEEKABIT, Computer Offspring of Punched\nCard PEEK...
Rank 1(9.8): Variable Length Tree Structures Having Minimum Ave...
Rank 2(8.2): A Stochastic Approach to the Grammatical Coding of...
Rank 3(8.1): Full Table Quadratic Searching for Scatter Storage...
Rank 4(7.6): Use of Tree Structures for Processing Files\nIn da...


In [34]:

test_tfidf_1 = tfidf_search("computer word search", index_set=1)[:5]
print(f"TFIDF Results:")
print_results(test_tfidf_1)


TFIDF Results:
Rank 0(9.4): Variable Length Tree Structures Having Minimum Ave...
Rank 1(7.4): On the Feasibility of Voice Input to\nan On-line C...
Rank 2(7.3): Median Split Trees: A Fast Lookup Technique for Fr...
Rank 3(7.0): Execution Time Requirements for Encipherment Progr...
Rank 4(7.0): Storage and Search Properties of a Tree-Organized ...


In [35]:

print('top-5 docs for index1 with BOW search:', list(zip(*test_bow_1[:5]))[0])
print('top-5 docs for index2 with BOW search:', list(zip(*test_bow_2[:5]))[0])
print('top-5 docs for index1 with TF-IDF search:', list(zip(*test_tfidf_1[:5]))[0])
print('top-5 docs for index2 with TF-IDF search:', list(zip(*test_tfidf_2[:5]))[0])



top-5 docs for index1 with BOW search: ('1771', '1936', '1543', '2535', '678')
top-5 docs for index2 with BOW search: ('1525', '1936', '1844', '1700', '1366')
top-5 docs for index1 with TF-IDF search: ('1936', '2054', '3041', '2620', '944')
top-5 docs for index2 with TF-IDF search: ('1700', '1936', '1235', '2018', '849')


--- 

### Section 3.3: Query Likelihood Model (35 points) <a class="anchor" id="qlm"></a>

In this section, you will implement a simple query likelihood model. 


#### 3.3.1 Naive QL (15 points)

First, let us implement a naive version of a QL model, assuming a multinomial unigram language model (with a uniform prior over the documents). 



In [36]:
#### Document length for normalization

def doc_lengths(documents):
    doc_lengths = {doc_id:len(doc) for (doc_id, doc) in documents}
    return doc_lengths

doc_lengths_1 = doc_lengths(doc_repr_1)
doc_lengths_2 = doc_lengths(doc_repr_2)

def get_doc_lengths(index_set):
    assert index_set in {1, 2}
    return {
        1: doc_lengths_1,
        2: doc_lengths_2
    }[index_set]
####

In [37]:
# TODO: Implement this! (15 points)
def naive_ql_search(query, index_set):
    """
        Perform a search over all documents with the given query using a naive QL model. 
        Note #1: You have to use the `get_index` (and get_doc_lengths) function created in the previous cells
        Input: 
            query - a (unprocessed) query
            index_set - the index to use
        Output: a list of (document_id, score), sorted in descending relevance to the given query 
    """
    index = get_index(index_set)
    doc_lengths = get_doc_lengths(index_set)
    processed_query = preprocess_query(query, index_set)
    scores = {}
    for q in processed_query:
        if q not in index: continue
        for (doc_id, freq) in index[q]:
            if doc_id not in scores:
                scores[doc_id] = 1
            scores[doc_id] *= freq/doc_lengths[str(doc_id)]
    return dict_to_arr(scores)

In [38]:
#### Function check
test_naiveql = naive_ql_search("report", index_set=1)[:5]
print(f"Naive QL Results:")
print_results(test_naiveql)
####

Naive QL Results:
Rank 0(0.2): A Report Writer For COBOL...
Rank 1(0.2): A CRT Report Generating System...
Rank 2(0.17): Preliminary Report-International Algebraic Languag...
Rank 3(0.17): Supplement to the ALGOL 60 Report...
Rank 4(0.14): ALGOL Sub-Committee Report - Extensions...


In [39]:
#### Please do not change this. This cell is used for grading.

In [40]:
#### Please do not change this. This cell is used for grading.

In [41]:
#### Please do not change this. This cell is used for grading.

In [42]:
#### Please do not change this. This cell is used for grading.

---
#### 3.3.2 QL (20 points)
Now, let's implement a QL model that handles the issues with the naive version. In particular, you will implement a QL model with Jelinek-Mercer Smoothing. That means an interpolated score is computed per word - one term is the same as the previous naive version, and the second term comes from a unigram language model. In addition, you should accumulate the scores by summing the **log** (smoothed) probability which leads to better numerical stability.

In [43]:
# TODO: Implement this! (20 points)

# YOUR CODE HERE
# raise NotImplementedError()

def ql_search(query, index_set):
    """
        Perform a search over all documents with the given query using a QL model 
        with Jelinek-Mercer Smoothing (set smoothing=0.1). 
        
        
        Note #1: You have to use the `get_index` (and get_doc_lengths) function created in the previous cells
        Note #2: You might have to create some variables beforehand and use them in this function
        
        
        Input: 
            query - a (unprocessed) query
            index_set - the index to use
        Output: a list of (document_id, score), sorted in descending relevance to the given query 
    """
    index = get_index(index_set)
    doc_lengths = get_doc_lengths(index_set)
    processed_query = preprocess_query(query, index_set)
    scores = {}
    N = sum(doc_lengths.values())
    for q in processed_query:
        if q not in index: continue
        for (doc_id, freq) in index[q]:
            if doc_id not in scores:
                scores[doc_id] = 0
            scores[doc_id] += np.log(0.9*freq/doc_lengths[str(doc_id)] + 0.1*sum([x[1] for x in index[q]])/N)
    return dict_to_arr(scores)
    

In [44]:
#### Function check
test_ql_results = ql_search("report", index_set=1)[:5]
print_results(test_ql_results)
print()
test_ql_results_long = ql_search("report " * 10, index_set=1)[:5]
print_results(test_ql_results_long)
####

Rank 0(-1.7): A Report Writer For COBOL...
Rank 1(-1.7): A CRT Report Generating System...
Rank 2(-1.9): Preliminary Report-International Algebraic Languag...
Rank 3(-1.9): Supplement to the ALGOL 60 Report...
Rank 4(-2.1): ALGOL Sub-Committee Report - Extensions...

Rank 0(-1.7e+01): A Report Writer For COBOL...
Rank 1(-1.7e+01): A CRT Report Generating System...
Rank 2(-1.9e+01): Preliminary Report-International Algebraic Languag...
Rank 3(-1.9e+01): Supplement to the ALGOL 60 Report...
Rank 4(-2.1e+01): ALGOL Sub-Committee Report - Extensions...


In [45]:
#### Please do not change this. This cell is used for grading.

In [46]:
#### Please do not change this. This cell is used for grading.

In [47]:
#### Please do not change this. This cell is used for grading.

In [48]:
#### Please do not change this. This cell is used for grading.

--- 

### Section 3.4: BM25 (20 points) <a class="anchor" id="bm25"></a>

In this section, we will implement the BM25 scoring function. 


In [49]:
# TODO: Implement this! (20 points)
def bm25_search(query, index_set):
    """
        Perform a search over all documents with the given query using BM25. Use k_1 = 1.5 and b = 0.75
        Note #1: You have to use the `get_index` (and `get_doc_lengths`) function created in the previous cells
        Note #2: You might have to create some variables beforehand and use them in this function
        
        Input: 
            query - a (unprocessed) query
            index_set - the index to use
        Output: a list of (document_id, score), sorted in descending relevance to the given query 
    """
    k_1 = 1.5
    b = 0.75
    index = get_index(index_set)
    df = get_df(index_set)
    doc_lengths = get_doc_lengths(index_set)
    processed_query = preprocess_query(query, index_set)
    scores = {}
    vals = doc_lengths.values()
    N = len(vals)
    dl_avg = sum(vals)/len(vals)
    for q in processed_query:
        if q not in index: continue
        for (doc_id, freq) in index[q]:
            if doc_id not in scores:
                scores[doc_id] = 0
            scores[doc_id] += np.log(N/df[q])*(k_1+1)*freq/(k_1*(1-b + b*doc_lengths[str(doc_id)]/dl_avg)+freq)
    return dict_to_arr(scores)

In [50]:
#### Function check
test_bm25_results = bm25_search("report", index_set=1)[:5]
print_results(test_bm25_results)
####

Rank 0(6.7): A Report Writer For COBOL...
Rank 1(6.7): A CRT Report Generating System...
Rank 2(6.6): Preliminary Report-International Algebraic Languag...
Rank 3(6.6): Supplement to the ALGOL 60 Report...
Rank 4(6.5): ALGOL Sub-Committee Report - Extensions...


In [51]:
#### Please do not change this. This cell is used for grading.

In [52]:
#### Please do not change this. This cell is used for grading.

In [53]:
#### Please do not change this. This cell is used for grading.

In [54]:
#### Please do not change this. This cell is used for grading.


---

### 3.5. Test Your Functions

The widget below allows you to play with the search functions you've written so far. Use this to test your search functions and ensure that they work as expected.

In [55]:
#### Highlighter function
# class for results
ResultRow = namedtuple("ResultRow", ["doc_id", "snippet", "score"])
# doc_id -> doc
docs_by_id = dict((d[0], d[1]) for d in docs)

def highlight_text(document, query, tol=17):
    import re
    tokens = tokenize(query)
    regex = "|".join(f"(\\b{t}\\b)" for t in tokens)
    regex = re.compile(regex, flags=re.IGNORECASE)
    output = ""
    i = 0
    for m in regex.finditer(document):
        start_idx = max(0, m.start() - tol)
        end_idx = min(len(document), m.end() + tol)
        output += "".join(["...",
                        document[start_idx:m.start()],
                        "<strong>",
                        document[m.start():m.end()],
                        "</strong>",
                        document[m.end():end_idx],
                        "..."])
    return output.replace("\n", " ")


def make_results(query, search_fn, index_set):
    results = []
    for doc_id, score in search_fn(query, index_set):
        highlight = highlight_text(docs_by_id[doc_id], query)
        if len(highlight.strip()) == 0:
            highlight = docs_by_id[doc_id]
        results.append(ResultRow(doc_id, highlight, score))
    return results
####

In [56]:
# TODO: Set this to the function you want to test!
# this function should take in a query (string)
# and return a sorted list of (doc_id, score) 
# with the most relevant document in the first position
search_fn = bm25_search
index_set = 1

text = widgets.Text(description="Search Bar", width=200)
display(text)

def handle_submit(sender):
    print(f"Searching for: '{sender.value}'")
    
    results = make_results(sender.value, search_fn, index_set)
    
    # display only the top 5
    results = results[:5]
    
    body = ""
    for idx, r in enumerate(results):
        body += f"<li>Document #{r.doc_id}({r.score}): {r.snippet}</li>"
    display(HTML(f"<ul>{body}</ul>"))
    

text.on_submit(handle_submit)

Text(value='', description='Search Bar')

---

## Section 4: Evaluation (40 points) <a class="anchor" id="evaluation"></a>

[Back to Part 1](#part1)

In order to analyze the effectiveness of retrieval algorithms, we first have to learn how to evaluate such a system. In particular, we will work with offline evaluation metrics. These metrics are computed on a dataset with known relevance judgements.

Implement the following evaluation metrics. 

1. Precision (7 points)
2. Recall (7 points)
3. Mean Average Precision (13 points)
4. Expected Reciprocal Rank (13 points)

---
### 4.1 Read relevance labels

Let's take a look at the `qrels.text` file, which contains the ground truth relevance scores. The relevance labels for CACM are binary - either 0 or 1. 


In [None]:
!head ./datasets/qrels.text

---

The first column is the query_id and the second column is the document_id. We can safely ignore the 3rd and 4th columns.

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

In [None]:
#### Function check
qrels = read_qrels()

assert len(qrels) == 52, "There should be 52 queries with relevance judgements"
assert sum(len(j) for j in qrels.values()) == 796, "There should be a total of 796 Relevance Judgements"

assert np.min(np.array([len(j) for j in qrels.values()])) == 1
assert np.max(np.array([len(j) for j in qrels.values()])) == 51

####

---
**Note:** For a given query `query_id`, you can assume that documents *not* in `qrels[query_id]` are not relevant to `query_id`. 


---
### 4.2 Precision (7 points)
Implement the `precision@k` metric:

In [None]:
# TODO: Implement this! (7 points)
def precision_k(results, relevant_docs, k):
    """
        Compute Precision@K
        Input: 
            results: A sorted list of 2-tuples (document_id, score), 
                    with the most relevant document in the first position
            relevant_docs: A set of relevant documents. 
            k: the cut-off
        Output: Precision@K
    """
    if k > len(results):
        k = len(results)
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:

#### Function check
qid = queries[0][0]
qtext = queries[0][1]
print(f'query:{qtext}')
results = bm25_search(qtext, 2)
precision = precision_k(results, qrels[qid], 10)
print(f'precision@10 = {precision}')
####

---
### 4.3 Recall (7 points)
Implement the `recall@k` metric:

In [None]:
# TODO: Implement this! (7 points)
def recall_k(results, relevant_docs, k):
    """
        Compute Recall@K
        Input: 
            results: A sorted list of 2-tuples (document_id, score), with the most relevant document in the first position
            relevant_docs: A set of relevant documents. 
            k: the cut-off
        Output: Recall@K
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
#### Function check
qid = queries[10][0]
qtext = queries[10][1]
print(f'query:{qtext}')
results = bm25_search(qtext, 2)
recall = recall_k(results, qrels[qid], 10)
print(f'recall@10 = {recall}')
####

---
### 4.4 Mean Average Precision (13 points)
Implement the `map` metric:

In [None]:
# TODO: Implement this! (12 points)
def average_precision(results, relevant_docs):
    """
        Compute Average Precision (for a single query - the results are 
        averaged across queries to get MAP in the next few cells)
        Hint: You can use the recall_k and precision_k functions here!
        Input: 
            results: A sorted list of 2-tuples (document_id, score), with the most 
                    relevant document in the first position
            relevant_docs: A set of relevant documents. 
        Output: Average Precision
    """
    # YOUR CODE HERE
    raise NotImplementedError()


In [None]:
#### Function check
qid = queries[20][0]
qtext = queries[20][1]
print(f'query:{qtext}')
results = bm25_search(qtext, 2)
mean_ap = average_precision(results, qrels[qid])
print(f'MAP = {mean_ap}')
####

---
### 4.5 Expected Reciprocal Rank (13 points)
Implement the `err` metric:

In [None]:
# TODO: Implement this! (12 points)
def err(results, relevant_docs):
    """
        Compute the expected reciprocal rank.
        Hint: https://dl.acm.org/doi/pdf/10.1145/1645953.1646033?download=true
        Input: 
            results: A sorted list of 2-tuples (document_id, score), with the most 
                    relevant document in the first position
            relevant_docs: A set of relevant documents. 
        Output: ERR
        
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
#### Function check
qid = queries[30][0]
qtext = queries[30][1]
print(f'query:{qtext}')
results = bm25_search(qtext, 2)
ERR = err(results, qrels[qid])
print(f'ERR = {ERR}')
####

---
### 4.6 Evaluate Search Functions

Let's define some metrics@k using [partial functions](https://docs.python.org/3/library/functools.html#functools.partial)

In [None]:
#### metrics@k functions

recall_at_1 = partial(recall_k, k=1)
recall_at_5 = partial(recall_k, k=5)
recall_at_10 = partial(recall_k, k=10)
precision_at_1 = partial(precision_k, k=1)
precision_at_5 = partial(precision_k, k=5)
precision_at_10 = partial(precision_k, k=10)


list_of_metrics = [
    ("ERR", err),
    ("MAP", average_precision),
    ("Recall@1",recall_at_1),
    ("Recall@5", recall_at_5),
    ("Recall@10", recall_at_10),
    ("Precision@1", precision_at_1),
    ("Precision@5", precision_at_5),
    ("Precision@10", precision_at_10)]
####

---

The following function evaluates a `search_fn` using the `metric_fn`. Note that the final number is averaged over all the queries

In [None]:
#### Evaluate a search function

list_of_search_fns = [
    ("BOW", bow_search),
    ("TF-IDF", tfidf_search),
    ("NaiveQL", naive_ql_search),
    ("QL", ql_search),
    ("BM25", bm25_search)
]

def evaluate_search_fn(search_fn, metric_fns, index_set=None):
    # build a dict query_id -> query 
    queries_by_id = dict((q[0], q[1]) for q in queries)
    
    metrics = {}
    for metric, metric_fn in metric_fns:
        metrics[metric] = np.zeros(len(qrels), dtype=np.float32)
    
    for i, (query_id, relevant_docs) in enumerate(qrels.items()):
        query = queries_by_id[query_id]
        if index_set:
            results = search_fn(query, index_set)
        else:
            results = search_fn(query)
        
        for metric, metric_fn in metric_fns:
            metrics[metric][i] = metric_fn(results, relevant_docs)

    
    
    final_dict = {}
    for metric, metric_vals in metrics.items():
        final_dict[metric] = metric_vals.mean()
    
    return final_dict
####

## Section 5: Analysis (30 points) <a class="anchor" id="analysis"></a>

[Back to Part 1](#part1)

In the final section of Part1, we will compare the different term-based IR algorithms and different preprocessing configurations and analyze their advantages and disadvantages.

### Section 5.1: Plot (20 points)

First, gather the results. The results should consider the index set, the different search functions and different metrics. Plot the results in bar charts, per metric, with clear labels.

**Rubric:**
- Each Metric is plotted: 7 points
- Each Method is plotted: 7 points
- Clear titles, x label, y labels and legends (if applicable): 6 points

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

---
### Section 5.2: Summary (10 points)
Write a summary of what you observe in the results.
Your summary should compare results across the 2 indices and the methods being used. State what you expected to see in the results, followed by either supporting evidence *or* justify why the results did not support your expectations.      

Write your answer here!

---
---
# Part 2: Semantic-based Matching (85 points) <a class="anchor" id="part2"></a>

[Back to top](#top)

We will now experiment with methods that go beyond lexical methods like TF-IDF, which operate at the word level and are high dimensional and sparse, and look at methods which constructs low dimensional dense representations of queries and documents. 

Since these low-dimensional methods have a higher time complexity, they are typically used in conjunction with methods like BM-25. That is, instead of searching through potentially million documents to find matches using low dimensional vectors, a list of K documents are retrieved using BM25, and then **re-ranked** using the other method. This is the method that is going to be applied in the following exercises. 

LSI/LDA takes documents that are similar on a semantic level - for instance, if they are describing the same topic - and projects them into nearby vectors, despite having low lexical overlap.

In this assignment, you will use `gensim` to create LSI/LDA models and use them in re-ranking. 

**Note**: The following exercises only uses `doc_repr_2` and `config_2`

Table of contents:
- [Section 6: LSI](#lsi) (15 points)
- [Section 7: LDA](#lda) (10 points)
- [Section 8: Word2Vec/Doc2Vec](#2vec) (20 points)
- [Section 8: Re-ranking](#reranking) (10 points)
- [Section 9: Re-ranking Evaluation](#reranking_eval) (30 points)

---
## Section 6: Latent Semantic Indexing (LSI) (15 points) <a class="anchor" id="lsi"></a>

[Back to Part 2](#part2)

LSI is one of the methods to embed the queries and documents into vectors. It is based on a method similar to Principal Component Analysis (PCA) for obtaining a dense concept matrix out of the sparse term-document matrix.

See [wikipedia](https://en.wikipedia.org/wiki/Latent_semantic_analysis), particularly [#Mathematics_of_LSI](https://en.wikipedia.org/wiki/Latent_semantic_analysis#Mathematics_of_LSI).

In [57]:
from gensim.corpora import Dictionary
from gensim.models import LdaModel, LsiModel, Word2Vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from gensim import downloader as g_downloader
# gensim uses logging, so set it up 
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

---
### Section 6.1: Cosine Similarity (5 points)<a class="anchor" id="cosing_sim"></a>
Before we begin, let us first define our method of similarity for the LSI model, the cosine similarity:

$$\text{similarity} = \cos(\theta) = {\mathbf{A} \cdot \mathbf{B} \over \|\mathbf{A}\| \|\mathbf{B}\|} = \frac{ \sum\limits_{i=1}^{n}{A_i  B_i} }{ \sqrt{\sum\limits_{i=1}^{n}{A_i^2}}  \sqrt{\sum\limits_{i=1}^{n}{B_i^2}} }$$

Since we are using gensim, the types of vectors returned by their classes are of the form defined below (they are not just simple vectors):

In [202]:
# 1, 2, 3 are either latent dimensions (LSI), or topics (LDA)
# The second value in each tuple is a number (LSI) or a probability (LDA)  
example_vec_1 = [(1, 0.2), (2, 0.3), (3, 0.4)]
example_vec_2 = [(1, 0.2), (2, 0.7), (3, 0.4)]

---
**Implementation (2+3 points):**
Now, implement the `dot product` operation on these types of vectors and using this operator, implement the `cosine similarity` (don't forget: two functions to implement!):

In [211]:
def dot(vec_1,vec_2): 
    """
        vec_1 and vec_2 are of the form: [(int, float), (int, float), ...]
        Return the dot product of two such vectors, computed only on the floats
        You can assume that the lengths of the vectors are the same, and the dimensions are aligned 
            i.e you won't get: vec_1 = [(1, 0.2)] ; vec_2 = [(2, 0.3)] 
                                (dimensions are unaligned and lengths are different)
    """
    ar_1=np.array(vec_1)[:,1]
    ar_2=np.array(vec_2)[:,1]
    combined=np.dot(ar_1,ar_2)
    return combined


# TODO: Implement this! (3 points)
def cosine_sim(vec_1, vec_2):
    d = dot(vec_1,vec_2)
    ar_1=np.array(vec_1)[:,1]
    ar_2=np.array(vec_2)[:,1]
    squar_1=np.sqrt(np.sum(np.square(ar_1)))
    squar_2=np.sqrt(np.sum(np.square(ar_2)))
    if np.dot(squar_1,squar_2)==0:
        return 0
    else:
        return d/(np.dot(squar_1,squar_2))

In [212]:
##### Function check
print(f'vectors: {(example_vec_1,example_vec_2)}')
print(f'dot product = {dot(example_vec_1,example_vec_2)}')
print(f'cosine similarity = {cosine_sim(example_vec_1,example_vec_2)}')
##### 

vectors: ([(1, 0.2), (2, 0.3), (3, 0.4)], [(1, 0.2), (2, 0.7), (3, 0.4)])
dot product = 0.41000000000000003
cosine similarity = 0.9165587597202866


In [213]:
#### Please do not change this. This cell is used for grading.

---
### Section 6.2: LSI Retrieval (10 points)<a class="anchor" id="lsi_retrieval"></a>
LSI retrieval is simply ranking the documents based on their cosine similarity to the query vector.
First, let's write a parent class for vector-based retrieval models:

In [214]:
class VectorSpaceRetrievalModel:
    """
        Parent class for Dense Vector Retrieval models
    """
    def __init__(self, doc_repr):
        """
            document_collection: 
                [
                    (doc_id_1, [token 1, token 2, ...]), 
                    (doc_id_2, [token 1, token 2, ....]) 
                    ...
                ]

        """
        self.doc_repr = doc_repr
        self.documents = [_[1] for _ in self.doc_repr]
        
        # construct a dictionary
        self.dictionary = Dictionary(self.documents)
        # Filter out words that occur less than 20 documents, or more than 50% of the documents.
        self.dictionary.filter_extremes(no_below=10)
        self.corpus = [self.dictionary.doc2bow(doc) for doc in self.documents]
    
        # Make a index to word dictionary.
        temp = self.dictionary[0]  # This is only to "load" the dictionary.
        self.id2word = self.dictionary.id2token
        
        # this is set by the train_model function
        self.model = None
        
        
    def vectorize_documents(self):
        """
            Returns a doc_id -> vector dictionary
        """
        vectors = {}
        for (doc_id, _), cc in zip(self.doc_repr, self.corpus):
            vectors[doc_id] = self.model[cc]
        return vectors

    def vectorize_query(self, query):
        # Note the use of config_2 here!
        query = process_text(query, **config_2)
        query_vector = self.dictionary.doc2bow(query)
        return self.model[query_vector]
    
    def train_model(self):
        """
            Trains a model and sets the 'self.model' variable. 
            Make sure to use the variables created in the __init__ method.
            e.g the variables which may be useful: {corpus, dictionary, id2word}
        """
        raise NotImplementedError()

---
**Implementation (5 points):**
Implement the `train_model` method in the following class (note that this is only one line of code in `gensim`!). Ensure that the parameters defined in the `__init__` method are not changed, and are *used in the `train_method` function*. Normally, the hyperaparameter space will be searched using grid search / other methods - in this assignment we have provided the hyperparameters for you.

The last two lines of code train an LSI model on the list of documents which have been stemmed, lower-cased and have stopwords removed. 

In [220]:
# TODO: Implement this! (5 points)
class LsiRetrievalModel(VectorSpaceRetrievalModel):
    def __init__(self, doc_repr):
        super().__init__(doc_repr)
        
        self.num_topics = 100
        self.chunksize = 2000
    
    def train_model(self):
        # YOUR CODE HERE
        self.model=LsiModel(self.corpus, id2word=self.id2word, num_topics=self.num_topics, chunksize=self.chunksize)

In [221]:
##### Function check
lsi = LsiRetrievalModel(doc_repr_2)
lsi.train_model()

# you can now get an LSI vector for a given query in the following way:
lsi.vectorize_query("report")
##### 

2022-02-26 11:45:55,686 : INFO : adding document #0 to Dictionary(0 unique tokens: [])
2022-02-26 11:45:55,782 : INFO : built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)
2022-02-26 11:45:55,782 : INFO : Dictionary lifecycle event {'msg': "built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)", 'datetime': '2022-02-26T11:45:55.782302', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}
2022-02-26 11:45:55,788 : INFO : discarding 4740 tokens: [('repeat', 8), ('glossari', 7), ('inspect', 8), ('uncol', 2), ('rung', 9), ('secant', 2), ('.', 1603), ('acceler', 6), ('diverg', 3), ('induc', 9)]...
2022-02-26 11:45:55,789 : INFO : keeping 1197 tokens which were in no less than 10 and no more than 1602 (

[(0, 0.015213302807929903),
 (1, -0.016274341440029694),
 (2, -0.00017902211071320383),
 (3, -0.0018127040612217764),
 (4, -0.009452968960287847),
 (5, -0.00472756108952735),
 (6, 0.02711740277072249),
 (7, 0.01668037194999458),
 (8, -0.031770769712443275),
 (9, -0.0006169458493351827),
 (10, 0.0022277170174060125),
 (11, -0.01739213111239926),
 (12, -0.00020358123087752499),
 (13, 0.0014934418656570476),
 (14, 0.003916664360946832),
 (15, 0.004993555466143907),
 (16, 0.007020789803054808),
 (17, 0.002595381778908421),
 (18, -0.016879161735805417),
 (19, 0.020073450520769802),
 (20, -0.009063277772332424),
 (21, -0.01350141851245043),
 (22, 0.046832568385803755),
 (23, 0.025094162009132757),
 (24, -0.011256181751637047),
 (25, -0.009927817947457148),
 (26, 0.006782047431902824),
 (27, 0.07642407193306652),
 (28, -0.06195462167252714),
 (29, 0.03175008468368986),
 (30, 0.042641934894295755),
 (31, 0.05053819327459815),
 (32, -0.06705859038885291),
 (33, 0.050964267103012884),
 (34, -0.0

\#### Please do not change this. This cell is used for grading.

---
**Implementation (5 points):**
 Next, implement a basic ranking class for vector space retrieval (used for all semantic methods): 

In [238]:
# TODO: Implement this! (5 points)
class DenseRetrievalRanker:
    def __init__(self, vsrm, similarity_fn):
        """
            vsrm: instance of `VectorSpaceRetrievalModel`
            similarity_fn: function instance that takes in two vectors 
                            and returns a similarity score e.g cosine_sim defined earlier
        """
        self.vsrm = vsrm
        self.vectorized_documents = self.vsrm.vectorize_documents()
        self.similarity_fn = similarity_fn
    
    def _compute_sim(self, query_vector):
        """
            Compute the similarity of `query_vector` to documents in 
            `self.vectorized_documents` using `self.similarity_fn`
            Returns a list of (doc_id, score) tuples
        """
        tup=[]
        for i in self.vectorized_documents:
            
            if np.array(self.vectorized_documents[i]).size==0 or  np.array(query_vector).size==0:
                continue
            
            if np.array(self.vectorized_documents[i]).shape!=np.array(query_vector).shape:
                continue
            else:
                scor=self.similarity_fn(self.vectorized_documents[i],query_vector)
            tup.append((i,scor))
        return tup
    
    def search(self, query):
        scores = self._compute_sim(self.vsrm.vectorize_query(query))
        scores.sort(key=lambda _:-_[1])
        return scores 

In [222]:
##### Function check
drm_lsi = DenseRetrievalRanker(lsi, cosine_sim)
drm_lsi.search("report")[:5]
##### 

[('599', 0.8006960902481483),
 ('947', 0.5866511260880892),
 ('53', 0.5082506706604544),
 ('1339', 0.46904827981259417),
 ('3160', 0.4438431412248785)]

\#### Please do not change this. This cell is used for grading.

---
Now, you can test your LSI model in the following cell: try finding queries which are lexically different to documents, but semantically similar - does LSI work well for these queries?!

In [239]:
# test your LSI model
search_fn = drm_lsi.search

text = widgets.Text(description="Search Bar", width=200)
display(text)

def make_results_2(query, search_fn):
    results = []
    for doc_id, score in search_fn(query):
        highlight = highlight_text(docs_by_id[doc_id], query)
        if len(highlight.strip()) == 0:
            highlight = docs_by_id[doc_id]
        results.append(ResultRow(doc_id, highlight, score))
    return results

def handle_submit_2(sender):
    print(f"Searching for: '{sender.value}' (SEARCH FN: {search_fn})")
    
    results = make_results_2(sender.value, search_fn)
    
    # display only the top 5
    results = results[:5]
    
    body = ""
    for idx, r in enumerate(results):
        body += f"<li>Document #{r.doc_id}({r.score}): {r.snippet}</li>"
    display(HTML(f"<ul>{body}</ul>"))
    

text.on_submit(handle_submit_2)

Text(value='', description='Search Bar')

---
## Section 7: Latent Dirichlet Allocation (LDA) (10 points) <a class="anchor" id="lda"></a>

[Back to Part 2](#part2)

The specifics of LDA is out of the scope of this assignment, but we will use the `gensim` implementation to perform search using LDA over our small document collection. The key thing to remember is that LDA, unlike LSI, outputs a topic **distribution**, not a vector. With that in mind, let's first define a similarity measure.


---
### Section 7.1: Jenson-Shannon divergence (5 points) <a class="anchor" id="js_sim"></a>

The Jenson-Shannon divergence is a symmetric and finite measure on two probability distributions (unlike the KL, which is neither). For identical distributions, the JSD is equal to 0, and since our code uses 0 as irrelevant and higher scores as relevant, we use `(1 - JSD)` as the score or 'similarity' in our setup

**Note**: the JSD is bounded to \[0,1\] only if we use log base 2. So please ensure that you're using `np.log2` instead of `np.log`

In [131]:
## TODO: Implement this! (5 points)
def KL_divergence(p, q):
    """ Compute KL divergence of two vectors, K(p || q)."""
    return sum(p[i] * np.log2(p[i]/q[i]) for i in range(len(p)) if q[i] != 0 or q[i] != 0.0)

def jenson_shannon_divergence(vec_1, vec_2, assert_prob=False):
    """
        Computes the Jensen-Shannon divergence between two probability distributions. 
        NOTE: DO NOT RETURN 1 - JSD here, that is handled by the next function which is already implemented! 
        The inputs are *gensim* vectors - same as the vectors for the cosine_sim function
        assert_prob is a flag that checks if the inputs are proper probability distributions 
            i.e they sum to 1 and are positive - use this to check your inputs if needed. 
                (This is optional to implement, but recommended - 
                you can the default to False to save a few ms off the runtime)
    """
    
    if not vec_1 or not vec_2:
        return 0
    vec_1=np.array(vec_1)[:,1]
    vec_2=np.array(vec_2)[:,1]
    if vec_1.size!=vec_2.size:
        return 0
    m = 0.5 * (vec_1 + vec_2)
    return 0.5 * KL_divergence(vec_1, m) + 0.5 * KL_divergence(vec_2, m)
    

def jenson_shannon_sim(vec_1, vec_2, assert_prob=False):
    return 1 - jenson_shannon_divergence(vec_1, vec_2)



In [132]:
##### Function check
vec_1 = [(1, 0.3), (2, 0.4), (3, 0.3)]
vec_2 = [(1, 0.1), (2, 0.7), (3, 0.2)]
jenson_shannon_sim(vec_1, vec_2, assert_prob=True)
##### 

0.9251064410358459

---
### Section 7.2: LDA retrieval (5 points) <a class="anchor" id="lda_ret"></a>

Implement the `train_model` method in the following class (note that this is only one line of code in `gensim`!). Ensure that the parameters defined in the `__init__` method are not changed, and are *used in the `train_method` function*. You do not need to set this. Normally, the hyperaparameter space will be searched using grid search / other methods. Note that training the LDA model might take some time

The last two lines of code train an LDA model on the list of documents which have been stemmed, lower-cased and have stopwords removed. 

In [133]:
# TODO: Implement this! (5 points)
class LdaRetrievalModel(VectorSpaceRetrievalModel):
    def __init__(self, doc_repr):
        super().__init__(doc_repr)
        
        # use these parameters in the train_model method
        self.num_topics = 100
        self.chunksize = 2000
        self.passes = 20
        self.iterations = 400
        self.eval_every = 10
        # this is need to get full vectors
        self.minimum_probability=0.0
        self.alpha='auto'
        self.eta='auto'
    
    
    def train_model(self):
        self.model=LdaModel(self.corpus, num_topics=self.num_topics, chunksize=self.chunksize, passes=self.passes, 
                           iterations=self.iterations, eval_every=self.eval_every, minimum_probability=self.minimum_probability,
                           alpha=self.alpha, eta=self.eta)

In [134]:
##### Function check
lda = LdaRetrievalModel(doc_repr_2)
lda.train_model()

# you can now get an LDA vector for a given query in the following way:
lda.vectorize_query("report")
##### 

2022-02-26 10:41:42,051 : INFO : adding document #0 to Dictionary(0 unique tokens: [])
2022-02-26 10:41:42,144 : INFO : built Dictionary(5937 unique tokens: ['algebra', 'intern', 'report', 'comput', 'digit']...) from 3204 documents (total 115961 corpus positions)
2022-02-26 10:41:42,145 : INFO : Dictionary lifecycle event {'msg': "built Dictionary(5937 unique tokens: ['algebra', 'intern', 'report', 'comput', 'digit']...) from 3204 documents (total 115961 corpus positions)", 'datetime': '2022-02-26T10:41:42.145236', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}
2022-02-26 10:41:42,150 : INFO : discarding 4740 tokens: [('repeat', 8), ('inspect', 8), ('glossari', 6), ('uncol', 2), ('rung', 9), ('secant', 2), ('.', 1603), ('acceler', 6), ('diverg', 3), ('induc', 9)]...
2022-02-26 10:41:42,151 : INFO : keeping 1197 tokens which were in no less than 10 and no more than 1602 (=50.

2022-02-26 10:41:46,324 : INFO : merging changes from 2000 documents into a model of 3204 documents
2022-02-26 10:41:46,335 : INFO : topic #20 (0.010): 0.113*"96" + 0.049*"227" + 0.040*"554" + 0.036*"46" + 0.036*"476" + 0.035*"640" + 0.034*"3" + 0.034*"309" + 0.028*"533" + 0.028*"241"
2022-02-26 10:41:46,336 : INFO : topic #71 (0.010): 0.058*"46" + 0.043*"105" + 0.043*"37" + 0.027*"131" + 0.027*"14" + 0.025*"29" + 0.023*"124" + 0.020*"30" + 0.019*"497" + 0.019*"128"
2022-02-26 10:41:46,337 : INFO : topic #73 (0.010): 0.129*"10" + 0.063*"843" + 0.029*"46" + 0.029*"49" + 0.027*"277" + 0.025*"572" + 0.024*"670" + 0.023*"544" + 0.022*"112" + 0.021*"163"
2022-02-26 10:41:46,338 : INFO : topic #83 (0.010): 0.067*"46" + 0.025*"29" + 0.021*"236" + 0.020*"13" + 0.019*"30" + 0.018*"602" + 0.016*"195" + 0.014*"1024" + 0.013*"45" + 0.012*"581"
2022-02-26 10:41:46,338 : INFO : topic #34 (0.010): 0.084*"46" + 0.059*"3" + 0.027*"10" + 0.025*"280" + 0.017*"13" + 0.015*"581" + 0.014*"173" + 0.014*"112"

2022-02-26 10:41:50,486 : INFO : topic #0 (0.011): 0.158*"457" + 0.089*"427" + 0.049*"46" + 0.038*"112" + 0.035*"34" + 0.034*"237" + 0.029*"254" + 0.028*"35" + 0.025*"575" + 0.022*"788"
2022-02-26 10:41:50,486 : INFO : topic #73 (0.011): 0.159*"10" + 0.069*"843" + 0.046*"670" + 0.045*"544" + 0.035*"572" + 0.034*"46" + 0.028*"277" + 0.028*"49" + 0.026*"536" + 0.022*"163"
2022-02-26 10:41:50,487 : INFO : topic #60 (0.012): 0.329*"29" + 0.318*"30" + 0.210*"79" + 0.031*"13" + 0.016*"252" + 0.015*"36" + 0.012*"271" + 0.011*"384" + 0.008*"10" + 0.006*"1182"
2022-02-26 10:41:50,488 : INFO : topic diff=0.529146, rho=0.466151
2022-02-26 10:41:50,497 : INFO : PROGRESS: pass 3, at document #2000/3204
2022-02-26 10:41:51,277 : INFO : optimized alpha [0.011162647, 0.009957531, 0.010647279, 0.009385372, 0.011181587, 0.010127835, 0.010368086, 0.0101258755, 0.009801412, 0.010053557, 0.0100717, 0.010353772, 0.010026719, 0.010845297, 0.010106359, 0.00982326, 0.01039442, 0.010380223, 0.010262536, 0.01092

2022-02-26 10:41:53,610 : INFO : topic diff=0.638946, rho=0.389191
2022-02-26 10:41:54,411 : INFO : -6.639 per-word bound, 99.7 perplexity estimate based on a held-out corpus of 1204 documents with 49783 words
2022-02-26 10:41:54,411 : INFO : PROGRESS: pass 4, at document #3204/3204
2022-02-26 10:41:55,095 : INFO : optimized alpha [0.01206475, 0.010039994, 0.011167563, 0.0092677, 0.012272558, 0.010670168, 0.010755426, 0.010427304, 0.009941833, 0.010386552, 0.010384931, 0.01075818, 0.010328612, 0.0114060035, 0.010296734, 0.009842053, 0.010656781, 0.010716509, 0.010801204, 0.011550783, 0.01007698, 0.011019227, 0.010314186, 0.010721247, 0.010582562, 0.010213689, 0.010493461, 0.0102994675, 0.010342013, 0.010525128, 0.011567231, 0.010924211, 0.010022215, 0.010082126, 0.01176054, 0.010122099, 0.010543941, 0.01110283, 0.0109616155, 0.010481556, 0.01095139, 0.010101165, 0.0103951115, 0.010562395, 0.0106681455, 0.010281828, 0.0105315, 0.010142681, 0.010367041, 0.010914421, 0.010690818, 0.010650

2022-02-26 10:41:58,126 : INFO : merging changes from 2000 documents into a model of 3204 documents
2022-02-26 10:41:58,137 : INFO : topic #3 (0.009): 0.097*"46" + 0.047*"410" + 0.041*"11" + 0.037*"913" + 0.035*"636" + 0.031*"603" + 0.029*"37" + 0.028*"786" + 0.027*"1178" + 0.021*"271"
2022-02-26 10:41:58,137 : INFO : topic #71 (0.010): 0.075*"105" + 0.072*"46" + 0.056*"124" + 0.054*"117" + 0.053*"131" + 0.049*"684" + 0.034*"37" + 0.032*"497" + 0.030*"128" + 0.027*"574"
2022-02-26 10:41:58,137 : INFO : topic #73 (0.013): 0.177*"10" + 0.057*"843" + 0.055*"544" + 0.051*"670" + 0.037*"572" + 0.035*"46" + 0.028*"277" + 0.027*"49" + 0.024*"536" + 0.024*"76"
2022-02-26 10:41:58,137 : INFO : topic #4 (0.013): 0.189*"24" + 0.126*"25" + 0.114*"50" + 0.055*"46" + 0.047*"343" + 0.043*"37" + 0.043*"332" + 0.039*"13" + 0.036*"340" + 0.031*"98"
2022-02-26 10:41:58,138 : INFO : topic #60 (0.020): 0.382*"29" + 0.355*"30" + 0.185*"79" + 0.017*"13" + 0.010*"252" + 0.009*"271" + 0.008*"36" + 0.005*"384" 

2022-02-26 10:42:01,944 : INFO : topic #73 (0.014): 0.184*"10" + 0.068*"544" + 0.059*"670" + 0.046*"843" + 0.043*"572" + 0.035*"46" + 0.026*"536" + 0.025*"49" + 0.025*"277" + 0.023*"76"
2022-02-26 10:42:01,944 : INFO : topic #4 (0.014): 0.217*"24" + 0.131*"25" + 0.107*"50" + 0.056*"46" + 0.053*"343" + 0.040*"332" + 0.039*"13" + 0.038*"37" + 0.032*"98" + 0.029*"340"
2022-02-26 10:42:01,945 : INFO : topic #60 (0.024): 0.404*"29" + 0.361*"30" + 0.169*"79" + 0.016*"13" + 0.008*"271" + 0.008*"252" + 0.006*"36" + 0.005*"1182" + 0.003*"882" + 0.003*"384"
2022-02-26 10:42:01,946 : INFO : topic diff=0.518143, rho=0.322715
2022-02-26 10:42:01,954 : INFO : PROGRESS: pass 8, at document #2000/3204
2022-02-26 10:42:02,704 : INFO : optimized alpha [0.013639385, 0.010254586, 0.012259663, 0.009202826, 0.014707834, 0.0119151715, 0.0113471225, 0.011042808, 0.0103083495, 0.011092439, 0.010914369, 0.01154084, 0.010974896, 0.012403592, 0.010733888, 0.009922705, 0.011089093, 0.011278086, 0.011918371, 0.0124

2022-02-26 10:42:05,004 : INFO : topic diff=0.389386, rho=0.293585
2022-02-26 10:42:05,789 : INFO : -6.503 per-word bound, 90.7 perplexity estimate based on a held-out corpus of 1204 documents with 49783 words
2022-02-26 10:42:05,789 : INFO : PROGRESS: pass 9, at document #3204/3204
2022-02-26 10:42:06,458 : INFO : optimized alpha [0.01453877, 0.010414629, 0.012880193, 0.009298655, 0.015976928, 0.01258884, 0.011704034, 0.011355472, 0.010442374, 0.011407567, 0.011264287, 0.011885094, 0.011233668, 0.012862353, 0.01092513, 0.010135198, 0.01132368, 0.011658769, 0.012562383, 0.012916463, 0.010904072, 0.012545547, 0.010571669, 0.011469571, 0.011091561, 0.011132816, 0.011171654, 0.011193949, 0.010974198, 0.011121436, 0.012618104, 0.012248916, 0.0111370925, 0.01054244, 0.01325088, 0.010980579, 0.01183094, 0.01260025, 0.012423151, 0.011211017, 0.01238956, 0.010244806, 0.011347215, 0.01178355, 0.01319318, 0.010682834, 0.011408042, 0.010952326, 0.010944748, 0.011720423, 0.011256973, 0.0119862, 0.

2022-02-26 10:42:09,403 : INFO : merging changes from 2000 documents into a model of 3204 documents
2022-02-26 10:42:09,413 : INFO : topic #3 (0.009): 0.114*"410" + 0.112*"46" + 0.090*"11" + 0.034*"603" + 0.033*"913" + 0.031*"786" + 0.027*"636" + 0.023*"1178" + 0.021*"37" + 0.021*"499"
2022-02-26 10:42:09,414 : INFO : topic #15 (0.010): 0.064*"812" + 0.054*"46" + 0.053*"668" + 0.048*"24" + 0.046*"13" + 0.031*"1164" + 0.030*"374" + 0.025*"67" + 0.023*"128" + 0.023*"587"
2022-02-26 10:42:09,415 : INFO : topic #0 (0.015): 0.266*"457" + 0.087*"427" + 0.078*"112" + 0.065*"34" + 0.059*"575" + 0.044*"46" + 0.040*"551" + 0.036*"35" + 0.023*"254" + 0.021*"788"
2022-02-26 10:42:09,415 : INFO : topic #4 (0.017): 0.223*"24" + 0.131*"25" + 0.109*"50" + 0.057*"46" + 0.051*"343" + 0.041*"13" + 0.040*"332" + 0.036*"37" + 0.033*"340" + 0.028*"98"
2022-02-26 10:42:09,416 : INFO : topic #60 (0.037): 0.412*"29" + 0.369*"30" + 0.170*"79" + 0.012*"13" + 0.006*"252" + 0.005*"1182" + 0.005*"271" + 0.004*"882"

2022-02-26 10:42:12,997 : INFO : topic #0 (0.016): 0.282*"457" + 0.089*"427" + 0.083*"112" + 0.064*"34" + 0.062*"575" + 0.043*"46" + 0.043*"551" + 0.035*"35" + 0.023*"254" + 0.020*"788"
2022-02-26 10:42:12,998 : INFO : topic #4 (0.018): 0.244*"24" + 0.133*"25" + 0.104*"50" + 0.058*"46" + 0.055*"343" + 0.041*"13" + 0.038*"332" + 0.033*"37" + 0.029*"98" + 0.027*"340"
2022-02-26 10:42:12,999 : INFO : topic #60 (0.042): 0.427*"29" + 0.370*"30" + 0.160*"79" + 0.012*"13" + 0.006*"1182" + 0.005*"252" + 0.003*"882" + 0.003*"271" + 0.002*"534" + 0.002*"821"
2022-02-26 10:42:12,999 : INFO : topic diff=0.273323, rho=0.261694
2022-02-26 10:42:13,007 : INFO : PROGRESS: pass 13, at document #2000/3204
2022-02-26 10:42:13,708 : INFO : optimized alpha [0.016048148, 0.010679395, 0.014038123, 0.009568588, 0.018628983, 0.013856107, 0.012256004, 0.012015169, 0.010772344, 0.012066613, 0.01186678, 0.012522678, 0.011758833, 0.013679972, 0.011299313, 0.010451255, 0.011862905, 0.012274774, 0.013923988, 0.01365

2022-02-26 10:42:15,753 : INFO : topic diff=0.217259, rho=0.245426
2022-02-26 10:42:16,465 : INFO : -6.449 per-word bound, 87.4 perplexity estimate based on a held-out corpus of 1204 documents with 49783 words
2022-02-26 10:42:16,466 : INFO : PROGRESS: pass 14, at document #3204/3204
2022-02-26 10:42:17,041 : INFO : optimized alpha [0.01691405, 0.010855284, 0.014649546, 0.009741861, 0.01994896, 0.014497748, 0.012584776, 0.012354428, 0.010907305, 0.012350366, 0.012192925, 0.01281494, 0.011950381, 0.014047548, 0.011434992, 0.010705521, 0.012139514, 0.01264055, 0.014680384, 0.014117325, 0.011666423, 0.014478106, 0.010837592, 0.012301039, 0.011660593, 0.012207163, 0.011815481, 0.012144586, 0.011576927, 0.011712601, 0.013408516, 0.013556463, 0.012280073, 0.011027628, 0.014695795, 0.011925108, 0.013271172, 0.014274932, 0.01378192, 0.0118954, 0.014067996, 0.010494946, 0.012257567, 0.013358108, 0.01646384, 0.011097816, 0.012274425, 0.011872591, 0.011595384, 0.012790696, 0.011775358, 0.01321352

2022-02-26 10:42:19,668 : INFO : merging changes from 2000 documents into a model of 3204 documents
2022-02-26 10:42:19,680 : INFO : topic #3 (0.010): 0.150*"11" + 0.132*"410" + 0.124*"46" + 0.035*"603" + 0.032*"913" + 0.031*"786" + 0.029*"499" + 0.024*"171" + 0.022*"1178" + 0.021*"238"
2022-02-26 10:42:19,681 : INFO : topic #69 (0.011): 0.341*"90" + 0.171*"33" + 0.099*"1010" + 0.063*"1104" + 0.055*"1019" + 0.051*"13" + 0.050*"480" + 0.042*"252" + 0.039*"312" + 0.023*"69"
2022-02-26 10:42:19,681 : INFO : topic #44 (0.018): 0.587*"37" + 0.053*"52" + 0.050*"35" + 0.043*"423" + 0.040*"143" + 0.037*"649" + 0.031*"474" + 0.030*"18" + 0.020*"69" + 0.013*"371"
2022-02-26 10:42:19,682 : INFO : topic #4 (0.021): 0.239*"24" + 0.131*"25" + 0.107*"50" + 0.059*"46" + 0.053*"343" + 0.041*"13" + 0.039*"332" + 0.032*"340" + 0.031*"37" + 0.029*"98"
2022-02-26 10:42:19,683 : INFO : topic #60 (0.057): 0.428*"29" + 0.370*"30" + 0.167*"79" + 0.011*"13" + 0.006*"1182" + 0.004*"882" + 0.003*"821" + 0.002*"53

2022-02-26 10:42:22,836 : INFO : topic #44 (0.019): 0.613*"37" + 0.048*"35" + 0.046*"52" + 0.044*"423" + 0.035*"143" + 0.034*"649" + 0.030*"18" + 0.030*"474" + 0.019*"69" + 0.013*"371"
2022-02-26 10:42:22,837 : INFO : topic #4 (0.022): 0.253*"24" + 0.132*"25" + 0.103*"50" + 0.060*"46" + 0.056*"343" + 0.041*"13" + 0.037*"332" + 0.030*"98" + 0.029*"37" + 0.027*"340"
2022-02-26 10:42:22,838 : INFO : topic #60 (0.063): 0.439*"29" + 0.368*"30" + 0.160*"79" + 0.011*"13" + 0.006*"1182" + 0.003*"882" + 0.002*"534" + 0.002*"821" + 0.001*"282" + 0.001*"252"
2022-02-26 10:42:22,839 : INFO : topic diff=0.176292, rho=0.225865
2022-02-26 10:42:22,846 : INFO : PROGRESS: pass 18, at document #2000/3204
2022-02-26 10:42:23,460 : INFO : optimized alpha [0.01837668, 0.011205601, 0.015800938, 0.0101250205, 0.02261388, 0.015741017, 0.013070274, 0.013082938, 0.011248737, 0.012979116, 0.012733749, 0.013391272, 0.0123806605, 0.014676662, 0.011740345, 0.011105996, 0.012718888, 0.013260414, 0.016172074, 0.01476

2022-02-26 10:42:25,310 : INFO : topic diff=0.149856, rho=0.215156
2022-02-26 10:42:25,952 : INFO : -6.418 per-word bound, 85.5 perplexity estimate based on a held-out corpus of 1204 documents with 49783 words
2022-02-26 10:42:25,952 : INFO : PROGRESS: pass 19, at document #3204/3204
2022-02-26 10:42:26,463 : INFO : optimized alpha [0.019210404, 0.011417354, 0.016386423, 0.010343453, 0.023924567, 0.016346771, 0.013353938, 0.013445465, 0.011388076, 0.013263151, 0.013017548, 0.013662979, 0.012557762, 0.014992165, 0.011858254, 0.011382605, 0.012994366, 0.013624218, 0.016993374, 0.015186745, 0.012378289, 0.016696459, 0.011110707, 0.013119071, 0.012270654, 0.013274075, 0.012473805, 0.013141932, 0.012175919, 0.01229742, 0.014097175, 0.014869385, 0.01339807, 0.0115126055, 0.016048487, 0.0128797935, 0.0147221675, 0.016157102, 0.015107722, 0.012493054, 0.015789002, 0.010773328, 0.01309361, 0.015035321, 0.020307565, 0.011518454, 0.013119338, 0.012799508, 0.012291945, 0.0139955785, 0.012272724, 0

[(0, 0.007905727),
 (1, 0.0046986253),
 (2, 0.0067435643),
 (3, 0.0042566787),
 (4, 0.009845764),
 (5, 0.0067272466),
 (6, 0.005495595),
 (7, 0.005533261),
 (8, 0.0046865763),
 (9, 0.005458233),
 (10, 0.0053571593),
 (11, 0.0056227758),
 (12, 0.0051679416),
 (13, 0.0061697806),
 (14, 0.00488007),
 (15, 0.004684325),
 (16, 0.005347619),
 (17, 0.005606824),
 (18, 0.006993345),
 (19, 0.006249856),
 (20, 0.41662633),
 (21, 0.0068711545),
 (22, 0.0045724297),
 (23, 0.005398939),
 (24, 0.005049787),
 (25, 0.0054627284),
 (26, 0.0051333904),
 (27, 0.005408347),
 (28, 0.0050108),
 (29, 0.005060802),
 (30, 0.0058014616),
 (31, 0.006119252),
 (32, 0.0055137565),
 (33, 0.0047378247),
 (34, 0.0066044927),
 (35, 0.0053004683),
 (36, 0.0060586673),
 (37, 0.006649191),
 (38, 0.006217336),
 (39, 0.0051413123),
 (40, 0.006497706),
 (41, 0.004433587),
 (42, 0.005388461),
 (43, 0.0061875405),
 (44, 0.008357246),
 (45, 0.0047402317),
 (46, 0.005399049),
 (47, 0.0052674282),
 (48, 0.005058549),
 (49, 0.005

\#### Please do not change this. This cell is used for grading.

---
Now we can use the `DenseRetrievalModel` class to obtain an LDA search function.
You can test your LDA model in the following cell: Try finding queries which are lexically different to documents, but semantically similar - does LDA work well for these queries?!

In [135]:
drm_lda = DenseRetrievalRanker(lda, jenson_shannon_sim)

# test your LDA model
search_fn = drm_lda.search

text = widgets.Text(description="Search Bar", width=200)
display(text)


text.on_submit(handle_submit_2)

Text(value='', description='Search Bar')

Searching for: 'probability' (SEARCH FN: <bound method DenseRetrievalRanker.search of <__main__.DenseRetrievalRanker object at 0x0000027888AC79D0>>)


## Section 8: Word2Vec/Doc2Vec (20 points) <a class="anchor" id="2vec"></a>

[Back to Part 2](#part2)

We will implement two other methods here, the Word2Vec model and the Doc2Vec model, also using `gensim`. Word2Vec creates representations of words, not documents, so the word level vectors need to be aggregated to obtain a representation for the document. Here, we will simply take the mean of the vectors. 


A drawback of these models is that they need a lot of training data. Our dataset is tiny, so in addition to using a model trained on the data, we will also use a pre-trained model for Word2Vec (this will be automatically downloaded).     

*Note*:
1. The code in vectorize_documents / vectorize_query should return gensim-like vectors i.e `[(dim, val), .. (dim, val)]`. 
2. For Word2Vec: You should also handle the following two cases: (a) A word in the query is not present in the vocabulary of the model and (b) none of the words in the query are present in the model - you can return 0 scores for all documents in this case. For either of these, you can check if a `word` is present in the vocab by using `word in self.model`


In [226]:
# TODO: Implement this! (10 points)
class W2VRetrievalModel(VectorSpaceRetrievalModel):
    def __init__(self, doc_repr):
        super().__init__(doc_repr)
        
        # the dimensionality of the vectors
        self.size = 100 
        self.min_count = 1
    
    def train_model(self):
        """
        Trains the W2V model
        """
        self.model = Word2Vec(self.documents, vector_size=self.size, window=2, min_count=self.min_count, workers=4)
        
    def vectorize_documents(self):
        """
            Returns a doc_id -> vector dictionary
        """    
        vectors={}
        for (doc_id, _), cc in zip(self.doc_repr, self.documents):
            for word in cc:
                if word not in self.model.wv:
                    cc.remove(word)
            docvec = self.model.wv[cc]
            tot=np.linspace(0, self.size-1,num=self.size,dtype=int)
            real=list(zip(tot, docvec[0]))
            vectors[doc_id]=real
        return vectors

    def vectorize_query(self, query):
        """
        Vectorizes the query using the W2V model
        """
        query = process_text(query, **config_2)
        for word in query:
            if word not in self.model.wv:
                return np.zeros(self.size)
        tot=np.linspace(0, self.size-1,num=self.size,dtype=int)
        query_vector = self.model.wv[query]
        real=list(zip(tot, query_vector[0]))
        return real
        
    
    
class W2VPretrainedRetrievalModel(W2VRetrievalModel):
    def __init__(self, doc_repr):
        super().__init__(doc_repr)
        self.model_name = "word2vec-google-news-300"
        self.size = 300
    
    def train_model(self):
        """
        Loads the pretrained model
        """
        self.model = g_downloader.load(self.model_name)
        
    def vectorize_documents(self):
        """
            Returns a doc_id -> vector dictionary
        """    
        vectors={}
        for (doc_id, _), cc in zip(self.doc_repr, self.documents):
            wrong_words=[]
            new=cc[:]
            for word in cc:
                if word not in self.model:
                    wrong_words.append(word)
            if len(wrong_words)==len(cc):
                docvec=[np.zeros(self.size)]
            else:
                for i in wrong_words:
                    new.remove(i)
                docvec = self.model[new]
            tot=np.linspace(0, self.size-1,num=self.size,dtype=int)
            real=list(zip(tot, docvec[0]))
            vectors[doc_id]=real
        return vectors

    def vectorize_query(self, query):
        """
        Vectorizes the query using the W2V model
        """
        query = process_text(query, **config_2)
        for word in query:
            if word not in self.model:
                return np.zeros(self.size)
        tot=np.linspace(0, self.size-1,num=self.size,dtype=int)
        query_vector = self.model[query]
        real=list(zip(tot, query_vector[0]))
        return real


w2v = W2VRetrievalModel(doc_repr_2)
w2v.train_model()

# you can now get a W2V vector for a given query in the following way:
w2v.vectorize_query("report")

2022-02-26 11:55:31,614 : INFO : adding document #0 to Dictionary(0 unique tokens: [])
2022-02-26 11:55:31,706 : INFO : built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)
2022-02-26 11:55:31,707 : INFO : Dictionary lifecycle event {'msg': "built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)", 'datetime': '2022-02-26T11:55:31.707950', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}
2022-02-26 11:55:31,749 : INFO : discarding 4740 tokens: [('repeat', 8), ('glossari', 7), ('inspect', 8), ('uncol', 2), ('rung', 9), ('secant', 2), ('.', 1603), ('acceler', 6), ('diverg', 3), ('induc', 9)]...
2022-02-26 11:55:31,750 : INFO : keeping 1197 tokens which were in no less than 10 and no more than 1602 (

[(0, -0.2929633),
 (1, 0.5162057),
 (2, 0.061494797),
 (3, -0.17668048),
 (4, 0.00087948074),
 (5, -0.44991282),
 (6, 0.2921861),
 (7, 0.85706884),
 (8, -0.28222275),
 (9, -0.2392792),
 (10, -0.12137509),
 (11, -0.4761624),
 (12, 0.20351748),
 (13, -0.114867054),
 (14, 0.042747203),
 (15, -0.3834302),
 (16, 0.19995497),
 (17, -0.3748146),
 (18, -0.20036821),
 (19, -0.81606615),
 (20, 0.11216883),
 (21, -0.08120381),
 (22, 0.46598342),
 (23, -0.121853836),
 (24, -0.24614295),
 (25, -0.024493419),
 (26, -0.25954357),
 (27, -0.23311669),
 (28, -0.40593475),
 (29, -0.001985349),
 (30, 0.453874),
 (31, -0.06364803),
 (32, 0.21714273),
 (33, -0.34832332),
 (34, -0.0741775),
 (35, 0.3848314),
 (36, 0.25564924),
 (37, -0.3981277),
 (38, -0.55068284),
 (39, -0.6076438),
 (40, 0.17630205),
 (41, -0.23751466),
 (42, -0.37847817),
 (43, -0.028189829),
 (44, 0.46725324),
 (45, -0.28758243),
 (46, -0.02915539),
 (47, -0.04811476),
 (48, 0.26951852),
 (49, 0.04434126),
 (50, 0.20182289),
 (51, -0.437

In [95]:
assert len(w2v.vectorize_query("report")) == 100
assert len(w2v.vectorize_query("this is a sentence that is not mellifluous")) == 100


\#### Please do not change this. This cell is used for grading.

In [227]:
w2v_pretrained = W2VPretrainedRetrievalModel(doc_repr_2)
w2v_pretrained.train_model()

# you can now get an W2V vector for a given query in the following way:
w2v_pretrained.vectorize_query("report")

2022-02-26 11:55:38,419 : INFO : adding document #0 to Dictionary(0 unique tokens: [])
2022-02-26 11:55:38,511 : INFO : built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)
2022-02-26 11:55:38,511 : INFO : Dictionary lifecycle event {'msg': "built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)", 'datetime': '2022-02-26T11:55:38.511970', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}
2022-02-26 11:55:38,517 : INFO : discarding 4740 tokens: [('repeat', 8), ('glossari', 7), ('inspect', 8), ('uncol', 2), ('rung', 9), ('secant', 2), ('.', 1603), ('acceler', 6), ('diverg', 3), ('induc', 9)]...
2022-02-26 11:55:38,517 : INFO : keeping 1197 tokens which were in no less than 10 and no more than 1602 (

[(0, -0.14257812),
 (1, -0.1640625),
 (2, -0.09033203),
 (3, -0.11230469),
 (4, 0.100097656),
 (5, -0.041259766),
 (6, 0.048828125),
 (7, -0.13671875),
 (8, 0.19628906),
 (9, -0.13476562),
 (10, -0.017578125),
 (11, 0.032226562),
 (12, 0.095214844),
 (13, -0.10595703),
 (14, -0.16992188),
 (15, 0.041015625),
 (16, -0.26367188),
 (17, -0.0063171387),
 (18, -0.17773438),
 (19, -0.24023438),
 (20, 0.3515625),
 (21, -0.012207031),
 (22, -0.16210938),
 (23, -0.12060547),
 (24, 0.04321289),
 (25, 0.10986328),
 (26, 0.052490234),
 (27, 0.17871094),
 (28, -0.14550781),
 (29, 0.13769531),
 (30, -0.08203125),
 (31, -0.28320312),
 (32, -0.10888672),
 (33, -0.2890625),
 (34, 0.072265625),
 (35, -0.04736328),
 (36, 0.040283203),
 (37, 0.067871094),
 (38, 0.11669922),
 (39, 0.000831604),
 (40, 0.068359375),
 (41, 0.12011719),
 (42, -0.088378906),
 (43, 0.33789062),
 (44, -0.044677734),
 (45, -0.030151367),
 (46, 0.0076904297),
 (47, -0.021118164),
 (48, -0.25390625),
 (49, 0.14941406),
 (50, 0.39843

In [223]:
##### Function check

print(len(w2v_pretrained.vectorize_query("report")))
#####

300


In [91]:
drm_w2v = DenseRetrievalRanker(w2v, cosine_sim)

# test your LDA model
search_fn = drm_w2v.search

text = widgets.Text(description="Search Bar", width=200)
display(text)


text.on_submit(handle_submit_2)

Text(value='', description='Search Bar')

Searching for: 'good' (SEARCH FN: <bound method DenseRetrievalRanker.search of <__main__.DenseRetrievalRanker object at 0x0000027880F2D400>>)
good
[('1', 0.9962803716040562), ('2', 0.9979116210731974), ('3', 0.9870059746074205), ('4', 0.9871143909243975), ('5', 0.9931769856010864), ('6', 0.984136388559018), ('7', 0.9871143909243975), ('8', 0.9983690638788042), ('9', 0.9973124167030201), ('10', 0.9871143909243975), ('11', 0.9941283555690551), ('12', 0.9985055038201147), ('13', 0.9871143909243975), ('14', 0.9941283555690551), ('15', 0.9988848640347948), ('16', 0.8282140120761096), ('17', 0.9801735162127907), ('18', 0.9973422661653712), ('19', 0.9871143909243975), ('20', 0.9917229018824725), ('21', 0.9976365250883572), ('22', 0.9830018323461129), ('23', 0.9984945226082623), ('24', 0.9980981327080058), ('25', 0.9979804914668007), ('26', 0.997897091713546), ('27', 0.9990069424165785), ('28', 0.998871275590139), ('29', 0.9971560908989062), ('30', 0.8164733408022454), ('31', 0.997815873661055

In [229]:
drm_w2v_pretrained = DenseRetrievalRanker(w2v_pretrained, cosine_sim)

#test your LDA model
search_fn = drm_w2v_pretrained.search

text = widgets.Text(description="Search Bar", width=200)
display(text)


text.on_submit(handle_submit_2)

Text(value='', description='Search Bar')

Searching for: 'report' (SEARCH FN: <bound method DenseRetrievalRanker.search of <__main__.DenseRetrievalRanker object at 0x0000027888ABB280>>)


Searching for: 'great' (SEARCH FN: <bound method DenseRetrievalRanker.search of <__main__.DenseRetrievalRanker object at 0x0000027888ABB280>>)


**Implementation (10 points):**
For Doc2Vec, you will need to create a list of `TaggedDocument` instead of using the `self.corpus` or `self.documents` variable. Use the document id as the 'tag'.
  

In [230]:
#### Please do not change this. This cell is used for grading.

\#### Please do not change this. This cell is used for grading.

In [235]:
# TODO: Implement this! (10 points)
class D2VRetrievalModel(VectorSpaceRetrievalModel):
    def __init__(self, doc_repr):
        super().__init__(doc_repr)
        
        self.vector_size= 100
        self.min_count = 1
        self.epochs = 20
        self.doc = [TaggedDocument(d[1],[0]) for i, d in enumerate(doc_repr)]
        
    def train_model(self):
        """
        Trains the W2V model
        """
        self.model = Doc2Vec(self.doc, vector_size=self.vector_size, window=5, min_count=self.min_count, workers=4)
        
    def vectorize_documents(self):
        """
            Returns a doc_id -> vector dictionary
        """    
        vectors={}
        
        for (doc_id, _), cc in zip(self.doc_repr, self.documents):
            n=cc[:]
            for word in cc:
                if word not in self.model.wv:
                    n.remove(word)        
            docvec=self.model.infer_vector(n)
            tot=np.linspace(0, self.vector_size-1,num=self.vector_size,dtype=int)
            real=list(zip(tot, docvec))
            vectors[doc_id]=real
            
        return vectors

    def vectorize_query(self, query):
        """
        Vectorizes the query using the W2V model
        """
        query = process_text(query, **config_2)
        for word in query:
            if word not in self.model.wv:
                return np.zeros(self.vector_size)
        tot=np.linspace(0, self.vector_size-1,num=self.vector_size,dtype=int)
        query_vector = self.model.wv[query]
        real=list(zip(tot, query_vector[0]))
        
        return real
        
d2v = D2VRetrievalModel(doc_repr_2)
d2v.train_model()


# # you can now get an LSI vector for a given query in the following way:
d2v.vectorize_query("report")

2022-02-26 12:05:06,901 : INFO : adding document #0 to Dictionary(0 unique tokens: [])
2022-02-26 12:05:06,994 : INFO : built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)
2022-02-26 12:05:06,995 : INFO : Dictionary lifecycle event {'msg': "built Dictionary(5937 unique tokens: ['-', 'algebra', 'intern', 'languag', 'preliminari']...) from 3204 documents (total 115969 corpus positions)", 'datetime': '2022-02-26T12:05:06.995519', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}
2022-02-26 12:05:07,000 : INFO : discarding 4740 tokens: [('repeat', 8), ('glossari', 7), ('inspect', 8), ('uncol', 2), ('rung', 9), ('secant', 2), ('.', 1603), ('acceler', 6), ('diverg', 3), ('induc', 9)]...
2022-02-26 12:05:07,000 : INFO : keeping 1197 tokens which were in no less than 10 and no more than 1602 (

2022-02-26 12:05:09,354 : INFO : Doc2Vec lifecycle event {'msg': 'training on 1159690 raw words (955392 effective words) took 2.2s, 438805 effective words/s', 'datetime': '2022-02-26T12:05:09.354735', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'train'}
2022-02-26 12:05:09,354 : INFO : Doc2Vec lifecycle event {'params': 'Doc2Vec(dm/m,d100,n5,w5,s0.001,t4)', 'datetime': '2022-02-26T12:05:09.354735', 'gensim': '4.1.2', 'python': '3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.19042-SP0', 'event': 'created'}


[(0, -0.3202496),
 (1, -0.35242635),
 (2, -0.5701272),
 (3, 0.27845564),
 (4, 0.1115539),
 (5, -0.14518675),
 (6, -0.50379544),
 (7, -0.06893954),
 (8, -0.7541975),
 (9, 0.06631541),
 (10, 0.17280471),
 (11, 0.037141312),
 (12, -0.17988284),
 (13, -0.23483443),
 (14, -0.49617156),
 (15, -0.5334058),
 (16, 0.2802506),
 (17, 0.47122383),
 (18, -0.4256436),
 (19, -0.42556137),
 (20, -0.18866158),
 (21, 0.14750807),
 (22, 0.04614093),
 (23, 0.1417656),
 (24, 0.1449676),
 (25, -0.608334),
 (26, -0.5045047),
 (27, -0.75167394),
 (28, 0.2448015),
 (29, -0.5906185),
 (30, 0.37274203),
 (31, 0.43907705),
 (32, -0.35867974),
 (33, -0.37184355),
 (34, -0.05235451),
 (35, 0.0076122866),
 (36, -0.09779918),
 (37, -0.64754647),
 (38, -0.12697747),
 (39, -0.02523147),
 (40, -0.15170732),
 (41, -0.5716624),
 (42, 0.2900885),
 (43, -0.68225616),
 (44, 0.09131778),
 (45, -0.3209908),
 (46, 0.064156294),
 (47, -0.03270977),
 (48, 0.46771273),
 (49, -0.5586908),
 (50, -0.34203282),
 (51, -0.2009909),
 (52

In [236]:
drm_d2v = DenseRetrievalRanker(d2v, cosine_sim)

# test your LDA model
search_fn = drm_d2v.search

text = widgets.Text(description="Search Bar", width=200)
display(text)


text.on_submit(handle_submit_2)

Text(value='', description='Search Bar')

Searching for: 'report' (SEARCH FN: <bound method DenseRetrievalRanker.search of <__main__.DenseRetrievalRanker object at 0x000002788ADC0DF0>>)


---
## Section 9: Re-ranking (10 points) <a class="anchor" id="reranking"></a>

[Back to Part 2](#part2)

To motivate the re-ranking perspective (i.e retrieve with lexical method + rerank with a semantic method), let's search using semantic methods and compare it to BM25's performance, along with their runtime:


In [237]:
query = "algebraic functions"
print("BM25: ")
%timeit bm25_search(query, 2)
print("LSI: ")
%timeit drm_lsi.search(query)
print("LDA: ")
%timeit drm_lda.search(query)
print("W2V: ")
%timeit drm_w2v.search(query)
print("W2V(Pretrained): ")
%timeit drm_w2v_pretrained.search(query)
print("D2V:")
%timeit drm_d2v.search(query)

BM25: 
1.02 ms ± 4.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
LSI: 
1.22 s ± 5.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
LDA: 
1.97 s ± 1.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
W2V: 
algebraic functions
[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492)

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

[('1', 0.9978697136521328), ('2', 0.9972250891340543), ('3', 0.9924248603710716), ('4', 0.9876796973502044), ('5', 0.9883642237858279), ('6', 0.9897175465555959), ('7', 0.9876796973502044), ('8', 0.9975978783330524), ('9', 0.9951417489248426), ('10', 0.9876796973502044), ('11', 0.9967223427835207), ('12', 0.9980996548969243), ('13', 0.9876796973502044), ('14', 0.9967223427835207), ('15', 0.9981968812047006), ('16', 0.8311046965678884), ('17', 0.9877649299783626), ('18', 0.9990859407721608), ('19', 0.9876796973502044), ('20', 0.9896236101234661), ('21', 1.0000000000000002), ('22', 0.9815078603737545), ('23', 0.9970183611892768), ('24', 0.9988008841804997), ('25', 0.9975660206884891), ('26', 0.9990429470745852), ('27', 0.9985973549537492), ('28', 0.9980905257026413), ('29', 0.9989904533095658), ('30', 0.7929321843010665), ('31', 0.9983690336351805), ('32', 0.9957317937191977), ('33', 0.9877649299783626), ('34', 0.9905514276859545), ('35', 0.9964696885468738), ('36', 0.9975925052763315), 

4.08 s ± 69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
D2V:
1.44 s ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


---

**Implementation (10 points):**
Re-ranking involves retrieving a small set of documents using simple but fast methods like BM25 and then re-ranking them with the aid of semantic methods such as LDA or LSI. Implement the following class, which takes in an `initial_retrieval_fn` - the initial retrieval function and `vsrm` - an instance of the `VectorSpaceRetrievalModel` class (i.e LSI/LDA) as input. The search function should first retrieve an initial list of K documents, and then these documents are re-ranked using a semantic method. This not only makes retrieval faster, but semantic methods perform poorly when used in isolation, as you will find out.

In [254]:
# TODO: Implement this! (10 points)
def Sort_Tuple(tup): 
      
    # getting length of list of tuples
    lst = len(tup) 
    for i in range(0, lst): 
          
        for j in range(0, lst-i-1): 
            if (tup[j][1] < tup[j + 1][1]): 
                temp = tup[j] 
                tup[j]= tup[j + 1] 
                tup[j + 1]= temp 
    return tup 


class DenseRerankingModel:
    def __init__(self, initial_retrieval_fn, vsrm, similarity_fn):
        """
            initial_retrieval_fn: takes in a query and returns a list of [(doc_id, score)] (sorted)
            vsrm: instance of `VectorSpaceRetrievalModel`
            similarity_fn: function instance that takes in two vectors 
                            and returns a similarity score e.g cosine_sim defined earlier
        """
        self.ret = initial_retrieval_fn
        self.vsrm = vsrm
        self.similarity_fn = similarity_fn
        self.vectorized_documents = vsrm.vectorize_documents()
        
        assert len(self.vectorized_documents) == len(doc_repr_2)
    
    def search(self, query, K=50):
        """
            First, retrieve the top K results using the retrieval function
            Then, re-rank the results using the VSRM instance
        """
        reranked_l=[]
        n=0
        for i in self.ret(query):
            n+=1
            query_vector=self.vsrm.vectorize_query(query)
            doc_id=i[0]
            if np.array(self.vectorized_documents[doc_id]).size==0 or  np.array(query_vector).size==0:
                continue
            else:
                

                scor=self.similarity_fn(self.vectorized_documents[doc_id],query_vector)
                reranked_l.append([doc_id,scor])
        reranked_l=Sort_Tuple(reranked_l)
        final=[]
        for i in range(len(reranked_l)):
            final.append(reranked_l[i])
            if i==K:
                break
        return final
            

In [255]:
##### Function check
bm25_search_2 = partial(bm25_search, index_set=2)
lsi_rerank = DenseRerankingModel(bm25_search_2, lsi, cosine_sim)
lda_rerank = DenseRerankingModel(bm25_search_2, lda, jenson_shannon_sim)
w2v_rerank = DenseRerankingModel(bm25_search_2, w2v, cosine_sim)
w2v_pretrained_rerank = DenseRerankingModel(bm25_search_2, w2v_pretrained, cosine_sim)
d2v_rerank = DenseRerankingModel(bm25_search_2, d2v, cosine_sim)

##### 

\#### Please do not change this. This cell is used for grading.

---
Now, let us time the new search functions:

In [249]:
query = "algebraic functions"
print("BM25: ")
%timeit bm25_search(query, 2)
print("LSI: ")
%timeit lsi_rerank.search(query)
print("LDA: ")
%timeit lda_rerank.search(query)
print("W2V: ")
%timeit w2v_rerank.search(query)
print("W2V(Pretrained): ")
%timeit w2v_pretrained_rerank.search(query)
print("D2V:")
%timeit d2v_rerank.search(query)

BM25: 
1.01 ms ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
LSI: 
[['2190', 0.8928388915573581], ['534', 0.8907757758888685], ['2609', 0.8737112224659194], ['295', 0.7910828509204736], ['1207', 0.6905327759571502], ['95', 0.633896559727317], ['27', 0.6201035284209588], ['2544', 0.6183692227612743], ['532', 0.6070137873362151], ['2143', 0.5913687178601861], ['365', 0.5883241037438187], ['2340', 0.5864244027382046], ['1373', 0.5829595452197873], ['1904', 0.573732793253155], ['676', 0.5629061757967366], ['1962', 0.5626685248105242], ['1029', 0.5531395791290783], ['1568', 0.5496369082614646], ['3203', 0.5401648498312118], ['1284', 0.5384766112284776], ['1329', 0.5383651093702], ['1403', 0.5357867318373987], ['330', 0.5304940612311105], ['1130', 0.5304274920576808], ['1993', 0.5249315449797478], ['2192', 0.5239275552198821], ['2323', 0.5182710013701234], ['480', 0.5016183481279133], ['210', 0.4870788590558732], ['2558', 0.4861100414454734], ['1896', 0.48535502411721265],

[['210', 0.7558894510588525], ['55', 0.7302233105393783], ['2190', 0.724946958733155], ['393', 0.7005667257157215], ['21', 0.6843314421192463], ['54', 0.67246042890011], ['133', 0.6654653667723844], ['1966', 0.6632819690483549], ['365', 0.655835852125777], ['676', 0.6547897561650476], ['1079', 0.6537515093516975], ['688', 0.6537463056401516], ['736', 0.6535524792843637], ['885', 0.6535524792843637], ['2512', 0.6535522839845145], ['534', 0.6522909344073105], ['1', 0.6515776458611524], ['2027', 0.6409409822487838], ['1284', 0.632472388970879], ['354', 0.6282935306343965], ['778', 0.6282843478209252], ['27', 0.6279359842366583], ['1355', 0.6168440372957151], ['1130', 0.6137398511553505], ['544', 0.6130334371202362], ['352', 0.6130334316232589], ['510', 0.6130334316232589], ['344', 0.6130334169321859], ['484', 0.6130334033658544], ['2090', 0.6123600591478884], ['1993', 0.6101826516237032], ['1896', 0.6054610617006583], ['1207', 0.5980951453901], ['99', 0.5970836986897154], ['1316', 0.59606

[['210', 0.7558894567165075], ['55', 0.7302233105418705], ['2190', 0.7249469719577863], ['393', 0.7005667257157215], ['21', 0.684331456660263], ['54', 0.6724604248323308], ['133', 0.6654653717849897], ['1966', 0.6632819690483549], ['365', 0.655835852125777], ['676', 0.6547897685462012], ['1079', 0.6537515093516975], ['688', 0.6537463056401516], ['736', 0.6535524792843637], ['885', 0.6535524772293861], ['2512', 0.6535522866670747], ['534', 0.6522909344073105], ['1', 0.6515776337806652], ['2027', 0.6409409822487838], ['1284', 0.632472388970879], ['354', 0.6282935158544605], ['778', 0.6282843267161211], ['27', 0.6279359842366583], ['1355', 0.6168440493711937], ['1130', 0.6137398511553505], ['544', 0.6130334371202362], ['344', 0.6130334316232589], ['352', 0.613033431048607], ['510', 0.613033431048607], ['484', 0.6130333924738854], ['2090', 0.6123600749141498], ['1993', 0.6101826516237032], ['1896', 0.6054610469350571], ['1207', 0.5980951453901], ['99', 0.5970836986897154], ['1316', 0.59606

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['55', 0.9999999999999998], ['21', 0.9999999999999998], ['3189', 0.9999999999999998], ['769', 0.9999999999999998], ['905', 0.9999999999999998], ['2167', 0.9999999999999998], ['964', 0.9991305086840343], ['1334', 0.9991305086840343], ['1496', 0.9990436461636223], ['3157', 0.9990267021885073], ['2184', 0.9990110645693135], ['1032', 0.9989736955550942], ['2181', 0.9989036807373949], ['2188', 0.9988930110755065], ['1183', 0.9988695023945097], ['27', 0.9988695023945097], ['848', 0.9988695023945097], ['1645', 0.9988695023945097], ['2595', 0.9988695023945097], ['1433', 0.9988695023945097], ['1691', 0.9988609376088916], ['2967', 0.9988609376088916], ['284', 0.9988600816440026], ['393', 0.9988593558402673], ['1672', 0.9988504770079265], ['1784', 0.9988504770079265], ['1111', 0.9988091001253712], ['693', 0.9988089422789835], ['2299', 0.9988089422789835], ['2547', 0.9987846662299531], ['2078', 0.9987846662299531], ['1344', 0.9987823707749564], ['2558', 0.9987491681836429], ['1029', 0.99873263007

[['964', 1.0], ['55', 1.0], ['284', 1.0], ['21', 1.0], ['3199', 1.0], ['3189', 1.0], ['44', 1.0], ['769', 1.0], ['905', 1.0], ['2167', 1.0], ['2166', 1.0], ['1334', 1.0], ['2802', 0.7605751303180235], ['2398', 0.5197644772718558], ['1527', 0.43693407907192455], ['2480', 0.3780817636800374], ['1111', 0.37359266407658603], ['1911', 0.37359266407658603], ['1371', 0.35010483886783533], ['1789', 0.35010483886783533], ['3203', 0.29950652201475786], ['2824', 0.24506236794684305], ['3031', 0.23755595261134171], ['2958', 0.23755595261134171], ['2940', 0.23755595261134171], ['1391', 0.2316825875933238], ['3202', 0.2316825875933238], ['1394', 0.2316825875933238], ['676', 0.2316825875933238], ['532', 0.2316825875933238], ['1309', 0.2316825875933238], ['1536', 0.2316825875933238], ['96', 0.2316825875933238], ['1543', 0.2316825875933238], ['1199', 0.2316825875933238], ['2321', 0.2316825875933238], ['1525', 0.2316825875933238], ['1003', 0.2316825875933238], ['2931', 0.2254054416928777], ['2947', 0.21

[['964', 1.0], ['55', 1.0], ['284', 1.0], ['21', 1.0], ['3199', 1.0], ['3189', 1.0], ['44', 1.0], ['769', 1.0], ['905', 1.0], ['2167', 1.0], ['2166', 1.0], ['1334', 1.0], ['2802', 0.7605751303180235], ['2398', 0.5197644772718558], ['1527', 0.43693407907192455], ['2480', 0.3780817636800374], ['1111', 0.37359266407658603], ['1911', 0.37359266407658603], ['1371', 0.35010483886783533], ['1789', 0.35010483886783533], ['3203', 0.29950652201475786], ['2824', 0.24506236794684305], ['3031', 0.23755595261134171], ['2958', 0.23755595261134171], ['2940', 0.23755595261134171], ['1391', 0.2316825875933238], ['3202', 0.2316825875933238], ['1394', 0.2316825875933238], ['676', 0.2316825875933238], ['532', 0.2316825875933238], ['1309', 0.2316825875933238], ['1536', 0.2316825875933238], ['96', 0.2316825875933238], ['1543', 0.2316825875933238], ['1199', 0.2316825875933238], ['2321', 0.2316825875933238], ['1525', 0.2316825875933238], ['1003', 0.2316825875933238], ['2931', 0.2254054416928777], ['2947', 0.21

[['387', 0.9651680518143015], ['1510', 0.9491845134170614], ['2039', 0.9485980278240337], ['2143', 0.9440736298143549], ['480', 0.9338096917033807], ['510', 0.933358831616659], ['55', 0.9317428580727982], ['54', 0.9292958053592952], ['1375', 0.9285714061257955], ['1789', 0.9221822894902187], ['2428', 0.9217424366274951], ['1023', 0.919223999556827], ['992', 0.913853606936229], ['1942', 0.9136610887808643], ['2259', 0.9126487364309157], ['1790', 0.9125784583355193], ['1821', 0.911270518940928], ['21', 0.9107674597392675], ['384', 0.9086826957478149], ['788', 0.9085012377172826], ['547', 0.9084613435103777], ['967', 0.9068565056170919], ['2470', 0.9048815438850248], ['352', 0.9030213525411567], ['3202', 0.9025352819680804], ['1460', 0.9013386964536251], ['3031', 0.898233553838579], ['885', 0.8962373336694112], ['1583', 0.8955845618972139], ['93', 0.8947972569961057], ['39', 0.8938219283524781], ['2149', 0.8935822625716872], ['1897', 0.8933792919979121], ['1562', 0.892724045438159], ['484

In [258]:
lsi_rerank.search(query)

[['2190', 0.8928388915573581],
 ['534', 0.8907757758888685],
 ['2609', 0.8737112224659194],
 ['295', 0.7910828509204736],
 ['1207', 0.6905327759571502],
 ['95', 0.633896559727317],
 ['27', 0.6201035284209588],
 ['2544', 0.6183692227612743],
 ['532', 0.6070137873362151],
 ['2143', 0.5913687178601861],
 ['365', 0.5883241037438187],
 ['2340', 0.5864244027382046],
 ['1373', 0.5829595452197873],
 ['1904', 0.573732793253155],
 ['676', 0.5629061757967366],
 ['1962', 0.5626685248105242],
 ['1029', 0.5531395791290783],
 ['1568', 0.5496369082614646],
 ['3203', 0.5401648498312118],
 ['1284', 0.5384766112284776],
 ['1329', 0.5383651093702],
 ['1403', 0.5357867318373987],
 ['330', 0.5304940612311105],
 ['1130', 0.5304274920576808],
 ['1993', 0.5249315449797478],
 ['2192', 0.5239275552198821],
 ['2323', 0.5182710013701234],
 ['480', 0.5016183481279133],
 ['210', 0.4870788590558732],
 ['2558', 0.4861100414454734],
 ['1896', 0.48535502411721265],
 ['2366', 0.47782405500815334],
 ['951', 0.476072331740

---
As you can see, it is much faster (but BM25 is still orders of magnitude faster).

---
## Section 10: Evaluation & Analysis (30 points) <a class="anchor" id="reranking_eval"></a>

[Back to Part 2](#part2)

[Previously](#evaluation) we have implemented some evaluation metrics and used them for measuring the ranking performance of term-based IR algorithms. In this section, we will do the same for semantic methods, both with and without re-ranking.

### Section 10.1: Plot (10 points)

First, gather the results. The results should consider the index set, the different search functions and different metrics. Plot the results in bar charts, per metric, with clear labels.

Then, gather only the re-ranking models, and plot and compare them with the results obtained in part 1 (only index set 2).

In [250]:
list_of_sem_search_fns = [
    ("lda", drm_lda.search),
    ("lsi", drm_lsi.search),
    ("w2v", drm_w2v.search),
    ("w2v_pretrained", drm_w2v_pretrained.search),
    ("d2v", drm_d2v.search),
    ("lsi_rr", lsi_rerank.search),
    ("lda_rr", lda_rerank.search),
    ("w2v_rr", w2v_rerank.search),
    ("w2v_pretrained_rr", w2v_pretrained_rerank.search),
    ("d2v_rr", d2v_rerank.search),
    
]

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

### Section 10.2: Summary (20 points)

Your summary should compare methods from Part 1 and Part 2 (only for index set 2). State what you expected to see in the results, followed by either supporting evidence *or* justify why the results did not support your expectations. Consider the availability of data, scalability, domain/type of data, etc.

YOUR ANSWER HERE