In [1]:
from question_classifier import QuestionClassifier
import os

class QuestionClassifierWrapper:
    """Two-stage question classifier: a main category, then an optional subcategory.

    A main ``QuestionClassifier`` predicts the category of a question. If a
    subcategory classifier was loaded under that same category name, it is
    run on the question as well. The combined prediction is intended to drive
    the choice of where to do text retrieval from (that part is still a stub).
    """

    def __init__(
        self,
        main_categorization_model_dir: str = "model",
        subcategorization_model_dir: str = "subcat_models",
    ):
        """Load the main classifier and one subcategory classifier per directory entry.

        Args:
            main_categorization_model_dir: Passed to ``QuestionClassifier`` as
                ``model_dir`` for the main model.
            subcategorization_model_dir: Directory that is listed with
                ``os.listdir``; every entry is loaded as a subcategory model
                and stored under the entry's name. That name is later
                compared with the category predicted by the main model.
        """
        self.main_classifier = QuestionClassifier(model_dir=main_categorization_model_dir)
        # Join with os.path.join: plain string concatenation drops the path
        # separator ("subcat_models" + "academics" -> "subcat_modelsacademics").
        self.subcategory_classifiers = {
            subcat: QuestionClassifier(
                model_dir=os.path.join(subcategorization_model_dir, subcat)
            )
            for subcat in os.listdir(subcategorization_model_dir)
        }

    def _predict(self, question: str, return_probabilities: bool = False) -> dict:
        """Raw interface between the classifier and the user.

        Args:
            question: The question text handed to the classifiers.
            return_probabilities: When true, each classifier is called with
                ``True`` as its second argument and is expected to return a
                ``(label, probabilities)`` pair.

        Returns:
            A dict with key ``'category'``; ``'subcategory'`` is added when a
            subcategory classifier exists for the predicted category. With
            ``return_probabilities`` the dict also carries ``'main_probs'``
            and, if applicable, ``'sub_probs'``.
        """
        prediction = {}
        if return_probabilities:
            prediction['category'], prediction['main_probs'] = self.main_classifier.predict(question, True)
        else:
            prediction['category'] = self.main_classifier.predict(question)

        # Only categories that have their own model get a second-stage prediction.
        subcategory_classifier = self.subcategory_classifiers.get(prediction['category'])
        if subcategory_classifier is None:
            return prediction

        if return_probabilities:
            prediction['subcategory'], prediction['sub_probs'] = subcategory_classifier.predict(question, True)
        else:
            prediction['subcategory'] = subcategory_classifier.predict(question)
        return prediction

    def predict(self, question: str, return_probabilities: bool = False) -> dict:
        """Public prediction entry point; same arguments and result as ``_predict``."""
        return self._predict(question, return_probabilities)

    def get_text_retrieval_places(self, question: str):
        """Higher level interface between the classifier and the user.

        Meant to tell us where to do text retrieval from.
        TODO: placeholder - currently prints a debug marker and returns 1
        regardless of ``question``.
        """
        print("hi")
        return 1

    def retreive_text(self, question: str):
        """Retrieve text for ``question``.

        TODO: not implemented yet - the method has no body and returns None.
        The (misspelt) method name is kept as-is so existing callers keep working.
        """

In [2]:
# Build the wrapper with its default model directories ("model" and "subcat_models").
classifier = QuestionClassifierWrapper()

# Classify a sample question; in a notebook the returned value is shown as the cell output.
# NOTE(review): this calls `predict`, while the wrapper's raw method is named
# `_predict` - confirm that a public `predict` is exposed.
classifier.predict("tell me about the honors college")

{'category': 'academics', 'subcategory': 'honors-college'}