In [87]:
from question_classifier import QuestionClassifier
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
import pandas as pd
import os

class QuestionClassifierWrapper:
    "Gets a question and using the probability output returns the exact places to do text retrieval from."
    def __init__(
            self, 
            main_categorization_model_dir: str = "model",
            subcategorization_model_dir: str = "subcat_models/"
        ):
        self.categorized_data = load_dataset("msaad02/categorized-data", split="train").to_pandas()
        self.main_classifier = QuestionClassifier(model_dir=main_categorization_model_dir)
        self.reranker_model = SentenceTransformer('msmarco-distilbert-base-v4', device='cuda')
        self.subcategory_classifiers = {}
        for subcat in os.listdir(subcategorization_model_dir):
            self.subcategory_classifiers[subcat] = QuestionClassifier(subcategorization_model_dir + subcat)

    def _predict(self, question: str, return_probabilities: bool = False):
        "Raw interface between the classifier and the user."
        prediction = {}
        if return_probabilities:
            prediction['category'], prediction['main_probs'] = self.main_classifier.predict(question, True)
        else:
            prediction['category'] = self.main_classifier.predict(question)

        if prediction['category'] in self.subcategory_classifiers:
            subcategory_classifier = self.subcategory_classifiers[prediction['category']]

            if return_probabilities:
                prediction['subcategory'], prediction['sub_probs'] = subcategory_classifier.predict(question, True)
            else:
                prediction['subcategory'] = subcategory_classifier.predict(question)
        return prediction
    
    def _get_text_retrieval_places(self, question: str):
        """
        High level interface between the classifier and the user. Tells us where to do text retrieval from. Based on the probability output of the categorization models.

        It does this by returning the top categories with confidence > 0.2 of the highest probability category. (I refer to confidence as the model's probability output.)

        Returns:
            dict: {
                'main_categories': [str],
                'subcategories': [str]
            }
        """
        prediction = self._predict(question, True)

        # main category
        main_cat_probs_df = pd.DataFrame(
            [(category, prob) for category, prob in prediction['main_probs'].items()], 
            columns=['category', 'probability']
        ).sort_values(by='probability', ascending=False).reset_index(drop=True)

        # Use all categories at the top within 0.2 of the best category

        # Highest category probability
        max_main_prob = main_cat_probs_df['probability'][0]

        # Categories within 0.2 of the highest category
        main_categories_to_use = main_cat_probs_df[main_cat_probs_df['probability'] > max_main_prob - 0.2]['category'].tolist()

        if 'sub_probs' in prediction.keys():
            subcategory_probs_df = pd.DataFrame(
                [(category, prob) for category, prob in prediction['sub_probs'].items()], 
                columns=['category', 'probability']
            ).sort_values(by='probability', ascending=False).reset_index(drop=True)

            # Highest subcategory probability
            max_sub_prob = subcategory_probs_df['probability'][0]

            # Subcategories within 0.2 of the highest subcategory
            subcategories_to_use = subcategory_probs_df[subcategory_probs_df['probability'] > max_sub_prob - 0.2]['category'].tolist()

        text_retreival_places = {
            'main_categories': main_categories_to_use,
            'subcategories': subcategories_to_use if 'sub_probs' in prediction.keys() else []
        }

        return text_retreival_places

    def retreive_text(self, question: str, top_n: int = 3):
        "Retrieves text using the text retrieval places."

        text_retreival_places = self._get_text_retrieval_places(question)
        main_categories_to_use = text_retreival_places['main_categories']
        subcategories_to_use = text_retreival_places['subcategories']
                
        filtered_catgory_df = self.categorized_data[self.categorized_data['category'].isin(main_categories_to_use)]
        if subcategories_to_use != []:
            filtered_catgory_df = filtered_catgory_df[filtered_catgory_df['subcategory'].isin(subcategories_to_use)]

        data_to_embed = filtered_catgory_df['data'].to_list()

        query_embedding = self.reranker_model.encode(question)
        passage_embedding = self.reranker_model.encode(data_to_embed)


        return text_retreival_places, filtered_catgory_df, util.cos_sim(query_embedding, passage_embedding)

In [88]:
classifier = QuestionClassifierWrapper()

In [89]:
sim = classifier.retreive_text("How can I apply?")

In [91]:
df = sim[1]

In [95]:
df['similarity'] = sim[2][0]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['similarity'] = sim[2][0]


In [99]:
df.sort_values(by='similarity', ascending=False)['data'].head(10)

764     International Admissions\nFind out what you ne...
2511    Advertising and Sponsorship\n-\nProof of ad is...
1365    Application Qualifications\n- First-time stude...
2542    Affirming enrollment is necessary for loan def...
500     All majors and certification programs offered ...
1316    Participate in your own learning outcomes to b...
768     Choose Your Application\nSUNY Brockport operat...
1477    How to Apply\nThis is a multi-step application...
2130    Financial Aid Requirements\nWhere do I Start?\...
1430    Incoming Undergraduate First Year (freshmen) a...
Name: data, dtype: object

In [93]:
sim[2]

tensor([[-0.0400,  0.0268, -0.1078,  ..., -0.0432, -0.0545, -0.1131]])

In [41]:
# def gradio_predict(question):
#     return classifier.get_text_retrieval_places(question)

# import gradio as gr

# # output json
# demo = gr.Interface(fn=gradio_predict, inputs="text", outputs="json")
    
# if __name__ == "__main__":
#     demo.launch(share=True)   