# Retrieval System
This notebook implements the retrieval system.

In [30]:
import ast
import os

import pandas as pd
import spacy
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# NOTE(review): this constant is not referenced anywhere in the visible code.
# 768 is the hidden size of BERT-base, whereas the 'all-MiniLM-L6-v2' model
# loaded by RetrievalSystem is documented as producing 384-dimensional
# vectors — confirm whether this value is still needed and correct.
BERT_ENCODING_SIZE = 768

class RetrievalSystem:
    """
    Embedding-based retrieval over a CSV of precomputed text embeddings.

    Input texts are cleaned with spaCy (stop words and punctuation removed,
    tokens lemmatized), embedded with a SentenceTransformer model and ranked
    against the stored embeddings by cosine similarity.
    """

    def __init__(self, path: str, retrieval_number: int = 16):
        """
        Constructor to initialize the RetrievalSystem with a CSV file.
        Args:
            path (str): The path to the CSV file to load. The file may not exist yet
                        (e.g. when this instance is only used to create it); in that
                        case no dataset is loaded and the retrieval methods raise.
            retrieval_number (int): Default number of results returned per query.
        """
        self.model_type = 'all-MiniLM-L6-v2'
        self.retrieval_number = retrieval_number

        # Always define the attribute so a missing file is reported with a clear
        # error at query time instead of an AttributeError.
        self.data = pd.read_csv(path) if os.path.exists(path) else None
        self.model = SentenceTransformer(self.model_type)
        self.nlp = spacy.load("en_core_web_sm")  # Load spaCy for preprocessing

    def _require_data(self) -> pd.DataFrame:
        """
        Returns the loaded dataset after checking that it is usable for retrieval.
        Raises:
            ValueError: If no dataset was loaded or it has no 'embedding' column.
        """
        data = getattr(self, 'data', None)
        if data is None:
            raise ValueError(
                "No embedding dataset is loaded; the CSV path given to RetrievalSystem does not exist."
            )
        if 'embedding' not in data.columns:
            raise ValueError("The CSV file must have an 'embedding' column.")
        return data

    @staticmethod
    def _parse_embeddings(frame: pd.DataFrame) -> pd.DataFrame:
        """
        Returns a frame whose 'embedding' column holds lists instead of their string form.
        The input frame is not modified.
        """
        # Embeddings read back from CSV are strings such as "[0.1, 0.2]". They are
        # parsed with ast.literal_eval rather than eval so that file contents are
        # never executed as code. As before, only the first row decides whether
        # the column needs parsing.
        if not frame.empty and isinstance(frame['embedding'].iloc[0], str):
            frame = frame.assign(embedding=frame['embedding'].apply(ast.literal_eval))
        return frame

    def preprocess_text(self, text: str) -> str:
        """
        Preprocesses the input text by removing stop words and applying lemmatization.
        Args:
            text (str): The text to preprocess.
        Returns:
            str: The lemmas of all tokens that are neither stop words nor punctuation,
                 joined by single spaces.
        """
        doc = self.nlp(text)
        return " ".join(
            token.lemma_ for token in doc if not token.is_stop and not token.is_punct
        )

    def find_similar_entries_for_batch(self, texts: list, top_n: int = None, excluded_tickers: dict = None):
        """
        Embeds a batch of texts and finds the most similar entries in the dataset for each.
        Args:
            texts (list): List of input texts to embed and compare.
            top_n (int): Number of entries to return per text. Falls back to
                         `retrieval_number` when omitted (or 0).
            excluded_tickers (dict): Dictionary where each key corresponds to the index of a text, and each value
                                     is a list of tickers to exclude for that text.
        Returns:
            list: One `(embedding, DataFrame)` tuple per input text, where `embedding` is the
                  text's embedding as returned by the model and the DataFrame holds the top
                  entries with an added 'similarity' column, sorted in descending order.
        Raises:
            ValueError: If no dataset is loaded or it lacks an 'embedding' column.
        """
        if not top_n:
            top_n = self.retrieval_number

        # Validate the dataset before doing any expensive embedding work.
        dataset = self._parse_embeddings(self._require_data().copy())

        if not texts:
            return []

        # Preprocess and embed all texts as one batch
        processed_texts = [self.preprocess_text(text) for text in texts]
        input_embeddings = torch.tensor(self.model.encode(processed_texts), dtype=torch.float32)

        if dataset.empty:
            # Nothing to compare against: every query gets an empty result frame.
            similarities = torch.empty((len(processed_texts), 0), dtype=torch.float32)
        else:
            dataset_embeddings = torch.tensor(dataset['embedding'].tolist(), dtype=torch.float32)
            # A dot product is only a cosine similarity for unit-length vectors, so both
            # sides are L2-normalized explicitly instead of relying on the model (or the
            # stored CSV) to provide normalized embeddings.
            similarities = torch.matmul(
                torch.nn.functional.normalize(input_embeddings, dim=1),
                torch.nn.functional.normalize(dataset_embeddings, dim=1).T,
            )

        # Collect top-N similar entries for each input text
        results = []
        for i, sim in enumerate(similarities):
            # Build a fresh frame per query so results never share a 'similarity' column.
            scored = dataset.assign(similarity=sim.numpy())

            # Drop the tickers excluded for this particular text, if any
            if excluded_tickers and i in excluded_tickers:
                scored = scored[~scored['tickers'].isin(excluded_tickers[i])]

            top_results = scored.sort_values(by='similarity', ascending=False).head(top_n)
            results.append((input_embeddings[i].numpy(), top_results))

        return results

    def find_similar_entries(self, text: str, top_n: int = None, excluded_tickers=None):
        """
        Embeds the input text with the sentence-transformer model, compares it with the
        entries in the CSV file, and returns the most similar entries based on cosine similarity.
        Args:
            text (str): The input text to embed and compare.
            top_n (int): The number of most similar entries to return. Falls back to
                         `retrieval_number` when omitted (or 0).
            excluded_tickers (list): List of tickers to exclude from similarity checks.
        Returns:
            tuple: `(input_embedding, top_entries)` where `input_embedding` is the model output
                   for the single preprocessed text and `top_entries` is a pd.DataFrame with the
                   top-n most similar entries and an added 'similarity' column.
        Raises:
            ValueError: If no dataset is loaded or it lacks an 'embedding' column.
        """
        if not top_n:
            top_n = self.retrieval_number

        # Work on a copy so self.data is never modified
        copied_data = self._require_data().copy()

        # Exclude rows with tickers in excluded_tickers; copy again so the column
        # assignments below act on an independent frame rather than on a slice.
        if excluded_tickers:
            copied_data = copied_data[~copied_data['tickers'].isin(excluded_tickers)].copy()

        copied_data = self._parse_embeddings(copied_data)

        # Preprocess the input text and generate its embedding
        text = self.preprocess_text(text)
        input_embedding = self.model.encode([text])

        if copied_data.empty:
            # cosine_similarity cannot handle an empty matrix; return an empty result instead.
            similarities = pd.Series(index=copied_data.index, dtype='float64')
        else:
            similarities = cosine_similarity(input_embedding, copied_data['embedding'].tolist())[0]
        copied_data = copied_data.assign(similarity=similarities)

        # Sort by similarity and return the top N results
        return input_embedding, copied_data.sort_values(by='similarity', ascending=False).head(top_n)

    def process_and_save_embeddings(self, path: str, output_path: str):
        """
        Embeds the 'business_description' column from a new CSV file, keeps only 'tickers' and 'embedding',
        and saves the results in a new CSV with 'tickers' as the index.
        Args:
            path (str): The path to the CSV file to process.
            output_path (str): The path to save the output CSV. Missing parent
                               directories are created.
        Raises:
            ValueError: If the CSV lacks a 'tickers' or 'business_description' column.
        """
        # Load new data
        new_data = pd.read_csv(path)

        # Ensure required columns exist
        if 'tickers' not in new_data.columns:
            raise ValueError("The CSV file must have a 'tickers' column.")
        if 'business_description' not in new_data.columns:
            raise ValueError("The CSV file must have a 'business_description' column.")

        # Missing descriptions are read as NaN, which spaCy cannot process; treat them as empty text.
        descriptions = new_data['business_description'].fillna("")
        new_data['processed_description'] = descriptions.apply(self.preprocess_text)

        # Encode all descriptions in a single call instead of one model call per row.
        vectors = self.model.encode(new_data['processed_description'].tolist())
        new_data['embedding'] = [vector.tolist() for vector in vectors]

        # Keep only 'tickers' and 'embedding', indexed by ticker
        processed_data = new_data[['tickers', 'embedding']].set_index('tickers')

        # to_csv does not create directories, so make sure the target folder exists.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save the processed data
        processed_data.to_csv(output_path)


### Creation of Embedding dataset
We precompute the embeddings so that the final user pipeline runs faster.

In [31]:
# Locations of the raw company dataset and of the precomputed embeddings,
# both resolved against the current working directory.
INPUT_PATH = "../Dataset/Data/normalized_real_company_stock_dataset_large.csv"
OUTPUT_PATH = "Embeddings/embeddings.csv"

# Switches selecting which steps run when this cell is executed as a script.
CREATE_DATASET = False
TEST = True

# Step 1 (optional): embed every business description and write the embeddings CSV.
if CREATE_DATASET and __name__ == '__main__':
    retrieval_system = RetrievalSystem(path=OUTPUT_PATH)
    retrieval_system.process_and_save_embeddings(path=INPUT_PATH, output_path=OUTPUT_PATH)

# Step 2 (optional): smoke-test single-text and batched retrieval on a sample description.
if TEST and __name__ == '__main__':
    retrieval_system = RetrievalSystem(path=OUTPUT_PATH)
    own_idea = "Hello world program that can print hello world"
    idea = "American Assets Trust, Inc. is a full service, vertically integrated and self-administered real estate investment trust ('REIT'), headquartered in San Diego, California. The company has over 55 years of experience in acquiring, improving, developing and managing premier office, retail, and residential properties throughout the United States in some of the nation's most dynamic, high-barrier-to-entry markets primarily in Southern California, Northern California, Washington, Oregon, Texas and Hawaii. The company's office portfolio comprises approximately 4.1 million rentable square feet, and its retail portfolio comprises approximately 3.1 million rentable square feet. In addition, the company owns one mixed-use property (including approximately 94,000 rentable square feet of retail space and a 369-room all-suite hotel) and 2,110 multifamily units. In 2011, the company was formed to succeed to the real estate business of American Assets, Inc., a privately held corporation founded in 1967 and, as such, has significant experience, long-standing relationships and extensive knowledge of its core markets, submarkets and asset classes."
    result = retrieval_system.find_similar_entries(text=idea, top_n=10)
    # The first copy of the text excludes two tickers; the second excludes none.
    result_batch = retrieval_system.find_similar_entries_for_batch(
        texts=[idea, idea],
        top_n=10,
        excluded_tickers={0: ["AAT", "SVC"], 1: []},
    )
    print(result_batch)

[(array([ 9.14912596e-02, -9.23526064e-02, -4.35219556e-02,  9.80845932e-03,
       -7.02796876e-02,  2.99639516e-02,  2.25052144e-02, -7.87077397e-02,
        6.23483360e-02, -2.21265014e-02,  2.13831961e-02,  4.86220457e-02,
        6.48474842e-02, -5.80228232e-02,  3.64770442e-02,  3.52781895e-03,
        1.03864428e-02,  2.50134002e-02, -2.46219188e-02,  7.25172460e-02,
        6.29633013e-03, -5.36851250e-02, -7.23614693e-02, -5.30710071e-03,
        3.83347273e-02, -6.81003258e-02, -4.54040021e-02,  8.33973885e-02,
        5.67590073e-03, -1.14184774e-01,  4.61823009e-02, -8.77585262e-03,
        4.18145955e-02, -9.97520797e-03,  1.15019351e-01,  7.10957497e-02,
       -5.31006269e-02, -5.99624449e-03, -5.12688980e-02, -1.75157730e-02,
       -1.89700928e-02,  1.02545014e-02,  4.85525429e-02,  1.21535463e-02,
        6.39032852e-03, -8.80433433e-03, -2.76258942e-02,  6.74318383e-03,
        1.05971806e-01,  6.92841783e-02, -2.76571382e-02,  6.44401610e-02,
        1.32183554e-02,