# Zero-Shot Classification with a Pretrained BART Model

In [11]:
from transformers import pipeline
import torch

# check if gpu is available
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the zero-shot classification model
classifier = pipeline("zero-shot-classification", 
                      model="facebook/bart-large-mnli",
                      device=device,
                      )

# Define the labels
candidate_labels = ["sleep", "car"]

# Function to classify the topic of the query
def classify_query(text: str) -> dict:
    result = classifier(text, candidate_labels)
    return result  # Full result dict; labels and scores are sorted by descending score


# Test the classifier
queries = [
    "What's the impact of sleep deprivation on cognitive function?",
    "Can you explain the history of the internal combustion engine?",
    "How does REM sleep affect memory consolidation?",
    "What are the main components of an electric vehicle's drivetrain?",
]

for query in queries:
    result = classify_query(query)
    print("-" * 50)
    print(f"Query: {query}")
    print(f"Classified as: {result['labels'][0]}, score: {result['scores'][0]}\n")

--------------------------------------------------
Query: What's the impact of sleep deprivation on cognitive function?
Classified as: sleep, score: 0.9878818392753601

--------------------------------------------------
Query: Can you explain the history of the internal combustion engine?
Classified as: car, score: 0.856473445892334

--------------------------------------------------
Query: How does REM sleep affect memory consolidation?
Classified as: sleep, score: 0.9879106879234314

--------------------------------------------------
Query: What are the main components of an electric vehicle's drivetrain?
Classified as: car, score: 0.9461274147033691
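
Because the pipeline returns labels sorted by descending score, a router can take the first label and fall back when the top score is low. The following is a minimal sketch, assuming a hypothetical `route_query` helper, a `"unknown"` fallback route, and an arbitrary threshold of 0.7 (none of which are part of the notebook above). It works on a plain dict shaped like the classifier output, so it runs without loading the model.

```python
def route_query(result: dict, threshold: float = 0.7, fallback: str = "unknown") -> str:
    """Pick a route from a zero-shot result dict with 'labels' and 'scores' keys."""
    # Labels are sorted by descending score, so index 0 is the top candidate.
    top_label = result["labels"][0]
    top_score = result["scores"][0]
    # Hypothetical policy: only trust the top label above the threshold.
    return top_label if top_score >= threshold else fallback


# Hand-written dict in the same shape as the classifier output.
example = {"labels": ["car", "sleep"], "scores": [0.856, 0.144]}
print(route_query(example))
```

The threshold is a tuning knob: a higher value sends more ambiguous queries to the fallback route instead of a domain model.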



# `TODO`: Synthetic Data Generation for Fine-Tuning the Router Model

When developing a robust multi-model LLM system, one critical component is the router model, which directs each query to the appropriate domain-specific model. Our system uses a zero-shot classification model (BART-large-mnli) to classify queries as either "sleep" or "car". Zero-shot classification works without task-specific training data, but fine-tuning the router to improve on that baseline depends heavily on the quality and quantity of labeled training data.

## Training with Existing Datasets

If we have access to high-quality labeled datasets for the sleep and car topics, we can use them directly to fine-tune the classification model. These datasets should contain a variety of questions and answers related to each domain, so that the model learns to classify a wide range of queries accurately.

For example:

* Sleep Dataset: Contains questions about sleep science, sleep disorders, sleep hygiene, and related topics.
* Car Dataset: Encompasses questions about car history, automotive technology, car maintenance, and more.
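
Since `facebook/bart-large-mnli` is an NLI model, labeled examples generally need to be recast as premise/hypothesis pairs before they can be used for fine-tuning. The sketch below shows one possible conversion, assuming a hypothetical `to_nli_pairs` helper, made-up example rows, and the zero-shot pipeline's default hypothesis template `"This example is {}."`. The string NLI labels would still have to be mapped to the model's label ids before training.

```python
from typing import Dict, List

LABELS = ["sleep", "car"]
# Default hypothesis template of the zero-shot pipeline; adjust if a custom one is used.
HYPOTHESIS_TEMPLATE = "This example is {}."


def to_nli_pairs(rows: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Turn {'text', 'label'} rows into NLI pairs.

    Each row yields one entailment pair for its true label and one
    contradiction pair for every other candidate label.
    """
    pairs = []
    for row in rows:
        for label in LABELS:
            pairs.append({
                "premise": row["text"],
                "hypothesis": HYPOTHESIS_TEMPLATE.format(label),
                "nli_label": "entailment" if label == row["label"] else "contradiction",
            })
    return pairs


# Illustrative rows only; a real dataset would be much larger.
rows = [
    {"text": "How many hours of sleep do adults need?", "label": "sleep"},
    {"text": "How often should I change my engine oil?", "label": "car"},
]
for pair in to_nli_pairs(rows):
    print(pair)
```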

## Generating Synthetic Data with Advanced LLMs

In cases where existing datasets are insufficient or unavailable, we can leverage advanced language models like GPT-4 to generate high-quality synthetic datasets. This approach involves using GPT-4 to create realistic and contextually accurate questions and answers for both sleep and car domains. The generated data can then be labeled and used to train the router model.

## Steps to Generate Synthetic Data

1. Define Prompts for Data Generation:
    * Create prompts that instruct GPT-4 to generate questions and answers related to sleep and car topics.
    * Example prompts:
        * "Generate 10 questions and answers about sleep science."
        * "Create 10 questions and answers about car maintenance."
2. Generate Data:
    * Use GPT-4 to generate the synthetic data based on the defined prompts.
    * Ensure the generated data covers a wide range of topics within each domain to improve the robustness of the router model.
3. Label the Data:
    * Label each generated question with the appropriate category ("sleep" or "car").
    * This labeling can be automated if the prompts are designed to include the category in the output.
4. Train the Router Model:
    * Combine the synthetic data with any existing labeled data.
    * Fine-tune the classification model on this combined dataset to improve its accuracy and generalization.
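
The steps above can be sketched roughly as follows. The LLM call is replaced by a stub (`fake_generate`) that returns canned questions, so the example runs offline. `build_prompt`, `generate_labeled_examples`, and `combine` are hypothetical names, and a real setup would swap the stub for a call to whichever LLM client is in use.

```python
from typing import Callable, Dict, List

# Category label -> topic used in the prompt (taken from the example prompts).
TOPICS = {"sleep": "sleep science", "car": "car maintenance"}


def build_prompt(topic: str, n: int = 10) -> str:
    # Step 1: one prompt per domain.
    return f"Generate {n} questions and answers about {topic}."


def fake_generate(prompt: str) -> List[str]:
    # Stand-in for an LLM call (assumption): returns canned questions.
    if "sleep" in prompt:
        return ["What causes insomnia?", "How long is a sleep cycle?"]
    return ["When should brake pads be replaced?", "What does an alternator do?"]


def generate_labeled_examples(generate: Callable[[str], List[str]]) -> List[Dict[str, str]]:
    # Steps 2-3: generate per category; the label comes from the prompt's category.
    examples = []
    for label, topic in TOPICS.items():
        for question in generate(build_prompt(topic)):
            examples.append({"text": question, "label": label})
    return examples


def combine(*datasets: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # Step 4 (data preparation only): merge datasets and drop duplicate texts.
    seen, merged = set(), []
    for dataset in datasets:
        for example in dataset:
            key = example["text"].strip().lower()
            if key not in seen:
                seen.add(key)
                merged.append(example)
    return merged


existing = [{"text": "What causes insomnia?", "label": "sleep"}]
dataset = combine(existing, generate_labeled_examples(fake_generate))
print(len(dataset), "examples after de-duplication")
```

Deriving the label from the prompt's category keeps labeling automatic, though generated questions are still worth spot-checking before training.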

NOTE: The same process can be used to generate data for fine-tuning LLMs on memorization tasks.