In [47]:
# # Import necessary libraries
# import torch
# from transformers import BertTokenizer, BertForSequenceClassification
# import numpy as np
# import re


# # Reload the saved model and tokenizer
# tokenizer = BertTokenizer.from_pretrained("./topic_model")
# model = BertForSequenceClassification.from_pretrained("./topic_model")

# # Define topic_labels (ensure it matches training)
# topic_labels = {
#     "Health": 0,
#     "Environment": 1,
#     "Technology": 2,
#     "Economy": 3,
#     "Entertainment": 4,
#     "Sports": 5,
#     "Politics": 6,
#     "Education": 7,
#     "Travel": 8,
#     "Food": 9,
#     # Add other topics...
# }
# label_to_topic = {idx: topic for topic, idx in topic_labels.items()}  # Reverse mapping for decoding

# # Define the prediction function

# def predict_multi_topics(query, model, tokenizer, topic_labels, percentile=0.8):
#     """Predict multiple topics for a given query based on a percentile threshold."""
#     # Identify the device
#     device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
#     model = model.to(device)

#     # Tokenize and move to device
#     encoding = tokenizer(
#         query,
#         max_length=128,
#         padding="max_length",
#         truncation=True,
#         return_tensors="pt",
#     )
#     input_ids = encoding["input_ids"].to(device)
#     attention_mask = encoding["attention_mask"].to(device)

#     # Get logits
#     model.eval()
#     with torch.no_grad():
#         outputs = model(input_ids, attention_mask=attention_mask)
#         logits = outputs.logits.squeeze(0).cpu().numpy()  # Move logits to CPU for processing

#     # Calculate dynamic threshold using percentile
#     threshold = np.percentile(logits, percentile * 100)

#     # Identify topics above threshold
#     topics = [
#         label_to_topic[idx] for idx, score in enumerate(logits) if score > threshold
#     ]
#     return topics

# # Preprocess the query
# def preprocess_query(query):
#     query = query.lower()
#     query = re.sub(r"[^\w\s]", "", query)  # Remove punctuation
#     query = re.sub(r"\d+", "", query)  # Remove numbers
#     return query.strip()

# # Example Usage
# query = input("Enter your query: ")
# query = preprocess_query(query)
# predicted_topics = predict_multi_topics(query, model, tokenizer, topic_labels, percentile=0.7)

# print(f"Query: {query}")
# print(f"Predicted Topics: {predicted_topics}")


Enter your query: how does climate change impact agriculture and health
Query: how does climate change impact agriculture and health
Predicted Topics: ['Health', 'Environment', 'Food']


In [26]:
# query = input("Enter your query: ")

# # Predict the topic
# predicted_topic = predict_topic(query, model, tokenizer, topic_labels)
# print(f"Query: {query}")
# print(f"Predicted Topic: {predicted_topic}")

Enter your query: soccer
Query: soccer
Predicted Topic: Sports


In [1]:
from transformers import pipeline

# Load the zero-shot classification pipeline
pipe = pipeline("zero-shot-classification", model="MoritzLaurer/bge-m3-zeroshot-v2.0")

# Define the topics
topics = [
    "Health",
    "Environment",
    "Technology",
    "Economy",
    "Entertainment",
    "Sports",
    "Politics",
    "Education",
    "Travel",
    "Food",
]

# Define the hypothesis template
hypothesis_template = "This query is about {}."

def classify_multi_topics(query, pipe, topics, threshold=0.3):
    """
    Classify a query into multiple topics based on a confidence threshold.

    Args:
    - query: The input query.
    - pipe: The Hugging Face zero-shot classification pipeline.
    - topics: List of predefined topics.
    - threshold: Confidence score threshold for multi-topic classification.

    Returns:
    - List of topics with scores above the threshold.
    """
    try:
        # Perform zero-shot classification
        output = pipe(query, topics, hypothesis_template=hypothesis_template)
        
        # Extract labels and scores
        labels = output["labels"]
        scores = output["scores"]
        
        # Filter topics based on the threshold
        multi_topics = [
            {"topic": label, "score": score}
            for label, score in zip(labels, scores)
            if score > threshold
        ]
        
        # Sort by score (optional, descending)
        multi_topics = sorted(multi_topics, key=lambda x: x["score"], reverse=True)
        
        print(f"Query: {query}")
        print(f"Predicted Topics:")
        for t in multi_topics:
            print(f"  - {t['topic']} (Score: {t['score']:.2f})")
        
        return multi_topics
    except Exception as e:
        print(f"Error during classification: {e}")
        return []

# Interactive loop for user input
print("Multi-Topic Classifier - Enter your query below (type 'exit' to quit):")
while True:
    query = input("Enter your query: ")
    if query.lower() == "exit":
        print("Exiting the classifier. Goodbye!")
        break
    classify_multi_topics(query, pipe, topics, threshold=0.2)


Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


Multi-Topic Classifier - Enter your query below (type 'exit' to quit):
Enter your query: How does global warming impact crop production?
Query: How does global warming impact crop production?
Predicted Topics:
  - Environment (Score: 0.60)
  - Food (Score: 0.22)
Enter your query: How does air pollution affect respiratory diseases?
Query: How does air pollution affect respiratory diseases?
Predicted Topics:
  - Environment (Score: 0.58)
  - Health (Score: 0.41)
Enter your query: How is AI being used in diagnosing diseases?
Query: How is AI being used in diagnosing diseases?
Predicted Topics:
  - Technology (Score: 0.81)
Enter your query: What are the political implications of cryptocurrency regulations?
Query: What are the political implications of cryptocurrency regulations?
Predicted Topics:
  - Politics (Score: 0.57)
  - Technology (Score: 0.32)
Enter your query: How is technology used in sports analytics?
Query: How is technology used in sports analytics?
Predicted Topics:
  - Sport