In [16]:
import random
from typing import Dict, List, Optional

import spacy
from spacy.pipeline.textcat_multilabel import DEFAULT_MULTI_TEXTCAT_MODEL
from spacy.tokens import DocBin
from spacy.training import Example

In [None]:
class QueryClassifier:
    """Multi-label intent classifier for short text queries.

    Wraps a blank English spaCy pipeline that contains a single
    ``textcat_multilabel`` component configured with
    ``DEFAULT_MULTI_TEXTCAT_MODEL``.

    Attributes:
        labels: Label names known to the text categorizer.
        nlp: The spaCy pipeline used for training and prediction.
        threshold: Minimum score a label needs to be returned by ``predict``.
        textcat: The ``textcat_multilabel`` component inside ``nlp``.
        train_examples: spaCy ``Example`` objects accumulated for training.
    """

    # Name under which the text categorizer is registered in the pipeline.
    _PIPE_NAME = "textcat_multilabel"

    def __init__(self, labels: List[str], threshold: float = 0.5):
        """Build an untrained pipeline that knows the given labels.

        Args:
            labels: Names of the intent categories to predict.
            threshold: Score cut-off applied by ``predict``; labels scoring
                below it are dropped. Defaults to 0.5.
        """
        # Copy so later mutation of the caller's list cannot desync us
        # from the labels registered on the component.
        self.labels = list(labels)
        self.nlp = spacy.blank("en")
        self.threshold = threshold
        self.textcat = self.nlp.add_pipe(
            self._PIPE_NAME,
            config={"model": DEFAULT_MULTI_TEXTCAT_MODEL},
        )

        for label in self.labels:
            self.textcat.add_label(label)

        self.train_examples = []

    def add_training_examples(self, examples: List[Dict]):
        """Convert raw records to spaCy examples and queue them for training.

        Args:
            examples: Records with a ``"text"`` key (the query string) and a
                ``"labels"`` key (mapping of label name to score), which is
                passed to spaCy as the ``cats`` annotation.

        Raises:
            KeyError: If a record lacks the ``"text"`` or ``"labels"`` key.
        """
        for record in examples:
            doc = self.nlp.make_doc(record["text"])
            annotations = {"cats": record["labels"]}
            self.train_examples.append(Example.from_dict(doc, annotations))

    def train(self, n_iter: int = 10):
        """Train the text categorizer on the queued examples.

        The queued examples are reshuffled in place before every pass and
        fed to the pipeline one at a time; the accumulated losses are
        printed after each pass.

        Args:
            n_iter: Number of passes over the training examples.
        """
        # initialize() replaces begin_training(), which spaCy v3 deprecates
        # and merely forwards to initialize().
        optimizer = self.nlp.initialize()

        # select_pipes(enable=...) replaces the deprecated disable_pipes():
        # every component except the text categorizer is switched off while
        # the block runs, so only the categorizer is updated.
        with self.nlp.select_pipes(enable=[self._PIPE_NAME]):
            for iteration in range(n_iter):
                random.shuffle(self.train_examples)
                losses = {}

                for example in self.train_examples:
                    self.nlp.update([example], sgd=optimizer, losses=losses)

                print(f"Iteration {iteration + 1}, Losses:", losses)

    def predict(self, text: str) -> Dict[str, float]:
        """Score a query and keep only the confident labels.

        Args:
            text: The query to classify.

        Returns:
            Mapping of label to score for every label whose score is at
            least ``self.threshold``; empty when no label qualifies.
        """
        doc = self.nlp(text)
        return {
            label: score
            for label, score in doc.cats.items()
            if score >= self.threshold
        }

    def save_model(self, path: str):
        """Serialize the current pipeline to ``path`` via ``nlp.to_disk``."""
        self.nlp.to_disk(path)

    def load_model(self, path: str):
        """Replace the current pipeline with one loaded from ``path``.

        ``textcat`` and ``labels`` are refreshed from the loaded pipeline so
        they never keep pointing at the discarded one. ``threshold`` is not
        touched by loading and keeps this instance's value.

        Args:
            path: Location previously written by ``save_model``.

        Raises:
            ValueError: If the loaded pipeline has no ``textcat_multilabel``
                component. The instance is left unchanged in that case.
        """
        loaded = spacy.load(path)
        if self._PIPE_NAME not in loaded.pipe_names:
            raise ValueError(
                f"Pipeline loaded from {path!r} has no "
                f"'{self._PIPE_NAME}' component"
            )

        # Commit only after validation so a bad path cannot leave the
        # instance half-updated.
        self.nlp = loaded
        self.textcat = loaded.get_pipe(self._PIPE_NAME)
        self.labels = list(self.textcat.labels)

In [55]:
if __name__ == "__main__":
    # Intent labels; this order is also the key order of every label dict
    # generated below.
    labels = ["classify", "factcheck", "summarize", "analyze", "detail"]
    classifier = QueryClassifier(labels)

    # Each training query carries exactly one active intent, so the queries
    # are grouped per intent and expanded into {"text", "labels"} records
    # with 1.0 on the active intent and 0.0 on all others.
    training_data = [
        {
            "text": text,
            "labels": {name: 1.0 if name == intent else 0.0 for name in labels},
        }
        for intent, texts in [
            # ------------------------------------------------------------
            # Classify examples
            ("classify", [
                "Identify the biome this ecosystem belongs to.",
                "What kind of book is this?",
                "What type of player is Messi",
                "Determine the programming language used in this code.",
                "Classify this astronomical object based on its spectral type.",
                "What type of role does this player have?",
                "Categorize this text by genre",
                "Detect the sentiment of this review",
                "What category does this artwork belong to?",
                "Identify the architectural style of this building",
                "Determine the musical genre of this composition",
            ]),
            # ------------------------------------------------------------
            # Factcheck examples
            ("factcheck", [
                "Is it true that this drug was FDA-approved in 2019?",
                "Is this article reliable?",
                "Tell me if this source is trustworthy",
                "Verify whether this artifact is from india.",
                "Is this statement true?",
                "Verify if this claim is authentic",
                "Check if this statistic is correct",
                "Can you validate these research findings?",
                "Are these historical dates accurate?",
                "Verify the authenticity of this source",
                "can you check this information",
            ]),
            # ------------------------------------------------------------
            # Summarize examples
            ("summarize", [
                "Summarize the core argument of this legal case.",
                "Give me a quick overview of this paper",
                "Extract the critical events from this historical timeline.",
                "Give me a brief overview",
                "Summarize this research paper in two sentences",
                "Extract key points from this speech",
                "Create an executive summary of this report",
                "What are the main takeaways from this presentation?",
                "Condense this document into key bullet points",
                "I need a concise version of this",
            ]),
            # ------------------------------------------------------------
            # Analyze examples
            ("analyze", [
                "Examine the impact of urbanization on local wildlife.",
                "Break down the components of this data",
                "what patterns do you see in this data?",
                "What insights can we draw from this?",
                "Analyze the structural weaknesses in this engineering design.",
                "Break down the key factors behind this geopolitical conflict.",
                "Break down the themes in this novel",
                "Find patterns in this financial report",
                "What are the underlying trends in this dataset?",
                "What correlations exist between these two features",
            ]),
            # ------------------------------------------------------------
            # detail examples
            ("detail", [
                "Can you tell me why light has fastest speed in world.",
                "Help me understand this concept",
                "Why is the sky blue?",
                "Describe in depth how reinforcement learning optimizes decision-making.",
                "Explain why the stock market crashed",
                "How does quantum computing work?",
                "Clarify the working of this algorithm",
                "What's the reasoning behind this policy decision?",
                "Describe how acid and alkali chemical reaction occurs",
                "Why do these economic indicators behave this way?",
                "Why does this happen?",
                "What does heliocentrism mean?",
            ]),
        ]
        for text in texts
    ]

    classifier.add_training_examples(training_data)
    classifier.train(n_iter=10)

    # Smoke-test the freshly trained model on a few queries.
    test_queries = [
        "What type is this?",
        "Can you check this information?",
        "Provide a smaller version of this",
        "What is meaning of sex?",
    ]

    for query in test_queries:
        predictions = classifier.predict(query)
        print(f"\nQuery: {query}")
        print("Predictions:", predictions)

Iteration 1, Losses: {'textcat_multilabel': 10.689808355644345}
Iteration 2, Losses: {'textcat_multilabel': 4.490052096996806}
Iteration 3, Losses: {'textcat_multilabel': 2.2140213799430057}
Iteration 4, Losses: {'textcat_multilabel': 1.208807232434765}
Iteration 5, Losses: {'textcat_multilabel': 0.16249702074901506}
Iteration 6, Losses: {'textcat_multilabel': 0.011763732602275923}
Iteration 7, Losses: {'textcat_multilabel': 0.0028402866773831192}
Iteration 8, Losses: {'textcat_multilabel': 0.0022265365814746474}
Iteration 9, Losses: {'textcat_multilabel': 0.0018599433481085725}
Iteration 10, Losses: {'textcat_multilabel': 0.0016206138791403646}

Query: What type is this?
Predictions: {'classify': 0.9648764729499817}

Query: Can you check this information?
Predictions: {'factcheck': 0.9997243285179138}

Query: Provide a smaller version of this
Predictions: {'summarize': 0.8881931900978088}

Query: What is meaning of sex?
Predictions: {'detail': 0.6537548303604126}


In [56]:
# Persist the trained pipeline under "intent_classifier" so that it can be
# reloaded from that path without retraining.
classifier.save_model(path="intent_classifier")
print("Model saved successfully!")

Model saved successfully!


In [None]:
def load_and_test_model(model_path: str = "intent_classifier",
                        test_queries: Optional[List[str]] = None):
    """Reload a saved classifier and print its predictions for some queries.

    For every query the predicted intents are printed in descending order
    of confidence; if the classifier returns no intents, a fallback line is
    printed instead.

    Args:
        model_path: Location of the saved pipeline to load. Defaults to
            ``"intent_classifier"``.
        test_queries: Queries to classify. When omitted, a built-in set of
            eight sample queries is used.
    """
    if test_queries is None:
        test_queries = [
            "What kind of movie is this?",
            "Is this news article reliable?",
            "Give me a quick overview of this paper",
            "Break down the components of this data",
            "Why is the sky blue?",
            "Help me understand this concept",
            "What patterns do you see in this data?",
            "Tell me if this source is trustworthy",
        ]

    labels = ["classify", "factcheck", "summarize", "analyze", "detail"]
    loaded_classifier = QueryClassifier(labels)
    loaded_classifier.load_model(model_path)

    print("\nTesting model with new queries:")
    print("--------------------------------")
    for query in test_queries:
        predictions = loaded_classifier.predict(query)
        print(f"\nQuery: {query}")

        if not predictions:
            print("No strong predictions (all confidence scores below threshold)")
            continue

        print("Predicted Intents:")
        # Most confident intent first; sorted() is stable, so ties keep the
        # order in which the classifier returned them.
        ranked = sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        for intent, confidence in ranked:
            print(f"- {intent}: {confidence:.2%}")

In [None]:
# Run the reload check: load_and_test_model() loads the saved pipeline and
# prints its predictions for a set of sample queries.
if __name__ == "__main__":
    load_and_test_model()


Testing model with new queries:
--------------------------------

Query: What kind of movie is this?
Predicted Intents:
- classify: 98.65%

Query: Is this news article reliable?
Predicted Intents:
- factcheck: 99.66%

Query: Give me a quick overview of this paper
Predicted Intents:
- summarize: 99.86%

Query: Break down the components of this data
Predicted Intents:
- analyze: 99.97%

Query: Why is the sky blue?
Predicted Intents:
- detail: 98.90%

Query: Help me understand this concept
Predicted Intents:
- detail: 99.30%

Query: What patterns do you see in this data?
Predicted Intents:
- analyze: 99.18%

Query: Tell me if this source is trustworthy
Predicted Intents:
- factcheck: 99.74%
