In [105]:
from question_classifier import QuestionClassifier
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
import pandas as pd
import pickle
import os

class QuestionClassifierWrapper:
    """Classify a question, then run semantic search inside the predicted categories.

    The main classifier picks a top-level category; if a subcategory classifier
    exists for that category it is consulted as well. The probability outputs of
    both decide which groups of pre-computed text embeddings are searched.
    """

    def __init__(
        self,
        main_categorization_model_dir: str = "model",
        subcategorization_model_dir: str = "subcat_models/"
    ):
        """Load the dataset, the pre-computed embeddings and every model.

        Args:
            main_categorization_model_dir: Directory passed to the main
                ``QuestionClassifier``.
            subcategorization_model_dir: Directory whose entries are each loaded
                as a subcategory ``QuestionClassifier``; the entry name is used
                as the key in ``self.subcategory_classifiers``.
        """
        self.categorized_data = load_dataset("msaad02/categorized-data", split="train").to_pandas()

        # NOTE: pickle can execute arbitrary code while loading; only use a
        # trusted "embeddings.pickle". The context manager closes the handle.
        with open("embeddings.pickle", "rb") as embeddings_file:
            embeddings = pickle.load(embeddings_file)
        self.data = embeddings['data']
        self.embeddings = embeddings['embeddings']

        self.main_classifier = QuestionClassifier(model_dir=main_categorization_model_dir)
        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
        # NOTE(review): the device is pinned to 'cuda' — confirm a GPU is always
        # available wherever this class is constructed.
        self.reranker_model = SentenceTransformer('msmarco-distilbert-base-v4', device='cuda')

        # os.path.join works whether or not the directory has a trailing slash.
        self.subcategory_classifiers = {
            subcat: QuestionClassifier(os.path.join(subcategorization_model_dir, subcat))
            for subcat in os.listdir(subcategorization_model_dir)
        }

    @staticmethod
    def _labels_within_margin(probabilities: dict, margin: float) -> list:
        """Return labels whose probability is within ``margin`` of the best one.

        Labels come back ordered from most to least probable. The comparison is
        strict (``prob > best - margin``), so ``margin`` must be positive for
        the best label itself to be included.
        """
        if not probabilities:
            return []
        ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
        cutoff = ranked[0][1] - margin
        return [label for label, prob in ranked if prob > cutoff]

    def _predict(self, question: str, return_probabilities: bool = False):
        """Raw interface between the classifiers and the caller.

        Args:
            question: The question to classify.
            return_probabilities: When True, also request the probability
                outputs from the classifiers.

        Returns:
            dict: Always contains ``'category'``. Contains ``'subcategory'`` when
            a subcategory classifier exists for the predicted category. With
            ``return_probabilities`` it also contains ``'main_probs'`` and, when
            applicable, ``'sub_probs'`` whose keys are ``"<category>-<subcategory>"``.
        """
        prediction = {}
        if return_probabilities:
            prediction['category'], prediction['main_probs'] = self.main_classifier.predict(question, True)
        else:
            prediction['category'] = self.main_classifier.predict(question)

        category = prediction['category']
        if category not in self.subcategory_classifiers:
            return prediction

        subcategory_classifier = self.subcategory_classifiers[category]
        if return_probabilities:
            prediction['subcategory'], sub_probs = subcategory_classifier.predict(question, True)
            # Prefix with the parent category so the keys line up with the
            # subcategory keys looked up in the embeddings dictionary.
            prediction['sub_probs'] = {f'{category}-{subcat}': prob for subcat, prob in sub_probs.items()}
        else:
            prediction['subcategory'] = subcategory_classifier.predict(question)
        return prediction

    def _get_text_retrieval_places(self, question: str, margin: float = 0.2):
        """Decide which categories and subcategories to retrieve text from.

        Every category whose probability is within ``margin`` of the most
        probable category is kept; the same rule is applied to subcategories
        when the predicted category has a subcategory classifier.

        Args:
            question: The question to classify.
            margin: Maximum distance below the top probability for a label to
                still be selected. Must be positive.

        Returns:
            dict: {
                'main_categories': [str],
                'subcategories': [str]
            }
        """
        prediction = self._predict(question, True)

        main_categories_to_use = self._labels_within_margin(prediction['main_probs'], margin)

        if 'sub_probs' in prediction:
            subcategories_to_use = self._labels_within_margin(prediction['sub_probs'], margin)
        else:
            subcategories_to_use = []

        return {
            'main_categories': main_categories_to_use,
            'subcategories': subcategories_to_use
        }

    def retreive_text(self, question: str, top_n: int = 10):
        """Return the ``top_n`` most similar texts for ``question`` via semantic search.

        The candidate texts are those stored under the categories and
        subcategories chosen by ``_get_text_retrieval_places``. This is the last
        step before reranking; reranking itself is not performed here.

        Args:
            question: The question to search for.
            top_n: Maximum number of rows to return.

        Returns:
            pandas.DataFrame: Columns ``'text'`` and ``'similarity'``, sorted by
            descending similarity. Empty when none of the selected categories
            exist in the embeddings dictionary.
        """
        text_retrieval_places = self._get_text_retrieval_places(question)

        places = [("Category", name) for name in text_retrieval_places['main_categories']]
        places += [("Subcategory", name) for name in text_retrieval_places['subcategories']]

        text_embedding_for_question = []
        raw_text_for_question = []
        for kind, name in places:
            if name in self.embeddings:
                text_embedding_for_question.extend(self.embeddings[name])
                raw_text_for_question.extend(self.data[name])
            else:
                print(f"{kind} {name} not found in the embeddings dictionary.")

        # Nothing to search: return an empty frame instead of failing in the
        # matrix product below.
        if not text_embedding_for_question:
            return pd.DataFrame(columns=['text', 'similarity'])

        # Embed the caller's question (previously a hard-coded sample string).
        question_embedding = self.embedding_model.encode(question, normalize_embeddings=True)

        similarity = text_embedding_for_question @ question_embedding.T
        top_args = similarity.argsort()[::-1][:top_n]

        data = pd.DataFrame(
            [(raw_text_for_question[i], similarity[i]) for i in top_args],
            columns=['text', 'similarity']
        ).sort_values(by='similarity', ascending=False).reset_index(drop=True)

        return data

    # Correctly spelled alias; the original name is kept for existing callers.
    retrieve_text = retreive_text

In [106]:
# Build the wrapper with its default model directories; this runs all of the
# loading done in QuestionClassifierWrapper.__init__.
classifier = QuestionClassifierWrapper()

In [108]:
# Try retrieval on a sample question; the returned DataFrame is shown in the cell output below.
classifier.retreive_text("tell me about the nursing program")

Category academics not found in the embeddings dictionary.


Unnamed: 0,text,similarity
0,What You’ll Learn\nOur 42-credit program combi...,0.727046
1,My name is Dr. Kathleen Peterson and I’m a Pro...,0.726161
2,What You’ll Learn\nYou’ll be prepared to care ...,0.72472
3,Applications are now open\nApplications for th...,0.722066
4,What You’ll Learn\nThe RN-BSN Fast Track Compl...,0.708482
5,What You’ll Learn\nOur 42-credit program combi...,0.692306
6,Learn from the Best in the Field\nOur programs...,0.689883
7,The world needs more nurses. Our degrees produ...,0.688628
8,What You’ll Learn\nThis part-time program is d...,0.683548
9,Letter to All Prospective Nursing Students for...,0.679201


# Reranking

Reranking is the final step of text retrieval: the candidates returned by semantic search are re-scored with a reranker model.

In [4]:
# Reranking demo: score (query, passage) pairs with the BGE reranker.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the tokenizer and the sequence-classification model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

# One query paired with two candidate passages.
pairs = [
    ['what is panda?', passage]
    for passage in (
        'hi',
        'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.',
    )
]

inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
with torch.no_grad():
    # Flatten the logits into one score per pair.
    scores = model(**inputs, return_dict=True).logits.view(-1).float()
print(scores)


tensor([-5.6085,  5.7623])


In [10]:

# Score the same example pairs again, reusing the tokenizer and model already loaded above.
pairs = [
    ['what is panda?', passage]
    for passage in (
        'hi',
        'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.',
    )
]

inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
with torch.no_grad():
    # Flatten the logits into one score per pair.
    scores = model(**inputs, return_dict=True).logits.view(-1).float()
print(scores)

tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
tensor([-5.6085,  5.7623])
t