# Task Router 

In this example, we will explore NVIDIA's prompt classification model that analyzes prompts based on task type. The model serves as an intelligent router for directing prompts to appropriate processing pipelines.

## Model Information
* Source: [NVIDIA's prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier)
* Architecture: DeBERTa-v3-base backbone
* Purpose: Multi-task classification for prompt analysis
* Output: Task types and multiple complexity metrics
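Later in this notebook, the classifier mean-pools DeBERTa's final hidden states over the non-padding tokens. It then passes that single vector to one linear head per target. The pure-Python sketch below reproduces the masked average on toy numbers so the arithmetic is easy to check. The vectors and mask are made up for illustration. The real `MeanPooling` module does the same computation with tensors.

```python
def masked_mean_pool(hidden_states, attention_mask, eps=1e-9):
    """Average token vectors, ignoring positions where attention_mask == 0.

    hidden_states: list of token vectors (seq_len x hidden_size)
    attention_mask: list of 0/1 flags (seq_len)
    """
    hidden_size = len(hidden_states[0])
    summed = [0.0] * hidden_size
    for vector, keep in zip(hidden_states, attention_mask):
        for j in range(hidden_size):
            summed[j] += vector[j] * keep  # padding tokens contribute 0
    # Clamp the token count, mirroring torch.clamp(sum_mask, min=1e-9)
    count = max(float(sum(attention_mask)), eps)
    return [value / count for value in summed]


# Two real tokens followed by one padding token
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean_pool(tokens, mask))  # [2.0, 3.0]
```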

## Task Categories
The model classifies prompts into the following task types:

![task_types](assets/Task_Categorization.png "Task Type Categories")

* Open QA (Open Question-Answering)
    * Questions requiring general world knowledge without specific context
    * Example: "What causes earthquakes and how do they occur?"
* Closed QA (Closed Question-Answering)
    * Questions requiring analysis of provided information/context
    * Example: "Based on the patient's symptoms described above, what is the likely diagnosis?"
* Summarization
    * Tasks requiring condensing longer text into key points
    * Example: "Summarize this research paper about climate change in three sentences."
* Text Generation
    * Creating original content based on given parameters
    * Example: "Write a product description for a new smartphone."
* Code Generation
    * Creating or completing programming code
    * Example: "Write a Python script that uses a for loop to calculate Fibonacci numbers."
* Chatbot
    * Conversational interactions requiring context maintenance
    * Example: "I need help tracking my order status."
* Classification
    * Categorizing or labeling content into predefined groups
    * Example: "Is this customer review positive or negative?"
* Rewrite
    * Rephrasing existing content while maintaining meaning
    * Example: "Rewrite this paragraph in simpler terms for a middle school student."
* Brainstorming
    * Creative ideation and generation of multiple options
    * Example: "Generate five unique marketing campaign ideas for a new coffee shop."
* Extraction
    * Pulling specific information from provided content
    * Example: "Extract all the dates and locations mentioned in this news article."
* Other
    * Tasks that don't fit into standard categories
    * Example: "Help me debug why my printer isn't working."

## Routing Suggestions

Some task types overlap, and the model assigns them nearly identical probabilities. Related tasks can therefore be grouped behind the same LLM endpoint:

* Code Generation
* Open QA
* (Closed QA, Extraction)
* (Rewrite, Summarization, Text Generation)
* Classification
* Chatbot
* (Other, Unknown)
* Brainstorming
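One way to apply this grouping is a lookup table that maps the model's top task type to an endpoint group. The sketch below is illustrative only. The group names are hypothetical placeholders, not real endpoints, and any unrecognized label falls back to the `(Other, Unknown)` group.

```python
# Hypothetical endpoint-group names; replace them with your own LLM endpoints.
ROUTE_GROUPS = {
    "code": ["Code Generation"],
    "open_qa": ["Open QA"],
    "grounded_qa": ["Closed QA", "Extraction"],
    "text_transform": ["Rewrite", "Summarization", "Text Generation"],
    "classification": ["Classification"],
    "chatbot": ["Chatbot"],
    "fallback": ["Other", "Unknown"],
    "brainstorming": ["Brainstorming"],
}

# Invert the groups into a task-type -> group lookup
TASK_TO_GROUP = {
    task: group for group, tasks in ROUTE_GROUPS.items() for task in tasks
}


def route(task_type: str) -> str:
    """Return the endpoint group for a task type, defaulting to the fallback group."""
    return TASK_TO_GROUP.get(task_type, "fallback")


print(route("Extraction"))  # grounded_qa
print(route("Summarization"))  # text_transform
print(route("NA"))  # fallback
```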

In [None]:
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoConfig, AutoModel, AutoTokenizer

class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)

        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)

        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class MulticlassHead(nn.Module):
    def __init__(self, input_size, num_classes):
        super(MulticlassHead, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        x = self.fc(x)
        return x

class CustomModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, target_sizes, task_type_map, weights_map, divisor_map):
        super(CustomModel, self).__init__()

        self.backbone = AutoModel.from_pretrained("microsoft/DeBERTa-v3-base")
        self.target_sizes = target_sizes.values()
        self.task_type_map = task_type_map
        self.weights_map = weights_map
        self.divisor_map = divisor_map

        self.heads = [
            MulticlassHead(self.backbone.config.hidden_size, sz)
            for sz in self.target_sizes
        ]

        for i, head in enumerate(self.heads):
            self.add_module(f"head_{i}", head)

        self.pool = MeanPooling()

    def forward(self, batch):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        last_hidden_state = outputs.last_hidden_state
        mean_pooled_representation = self.pool(last_hidden_state, attention_mask)

        logits = [
            self.heads[k](mean_pooled_representation)
            for k in range(len(self.target_sizes))
        ]

        # Raw logits: one tensor per classification head
        return logits

class LogitsProcessor:
    def __init__(self, task_type_map, weights_map, divisor_map):
        self.task_type_map = task_type_map
        self.weights_map = weights_map
        self.divisor_map = divisor_map
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.targets = [
            "task_type", "creativity_scope", "reasoning", "contextual_knowledge",
            "number_of_few_shots", "domain_knowledge", "no_label_reason", "constraint_ct"
        ]

    def compute_results(self, preds, target, decimal=4):
        if target == "task_type":
            task_type = {}

            top2_indices = torch.topk(preds, k=2, dim=1).indices
            softmax_probs = torch.softmax(preds, dim=1)
            top2_probs = softmax_probs.gather(1, top2_indices)
            top2 = top2_indices.detach().cpu().tolist()
            top2_prob = top2_probs.detach().cpu().tolist()

            top2_strings = [
                [self.task_type_map[str(idx)] for idx in sample] for sample in top2
            ]
            top2_prob_rounded = [
                [round(value, 3) for value in sublist] for sublist in top2_prob
            ]

            # Drop the second-choice label when its probability is below 0.1
            for i, sublist in enumerate(top2_prob_rounded):
                if sublist[1] < 0.1:
                    top2_strings[i][1] = "NA"

            task_type_1 = [sublist[0] for sublist in top2_strings]
            task_type_2 = [sublist[1] for sublist in top2_strings]
            task_type_prob = [sublist[0] for sublist in top2_prob_rounded]

            return (task_type_1, task_type_2, task_type_prob)
        else:
            preds = torch.softmax(preds, dim=1)

            weights = np.array(self.weights_map[target])
            weighted_sum = np.sum(np.array(preds.detach().cpu()) * weights, axis=1)
            scores = weighted_sum / self.divisor_map[target]

            scores = [round(value, decimal) for value in scores]
            if target == "number_of_few_shots":
                scores = [x if x >= 0.05 else 0 for x in scores]
            return scores

    def process_logits(self, logits):
        # `logits` is a list of NumPy arrays, one per head, ordered like self.targets
        result = {}

        for i, target in enumerate(self.targets):
            logits_tensor = torch.from_numpy(logits[i]).float()
            
            if target == "task_type":
                task_type_results = self.compute_results(logits_tensor, target=target)
                result["task_type_1"] = task_type_results[0]
                result["task_type_2"] = task_type_results[1]
                result["task_type_prob"] = task_type_results[2]
            else:
                result[target] = self.compute_results(logits_tensor, target=target)

        # Calculate prompt_complexity_score
        result["prompt_complexity_score"] = [
            round(
                0.35 * creativity
                + 0.25 * reasoning
                + 0.15 * constraint
                + 0.15 * domain_knowledge
                + 0.05 * contextual_knowledge
                + 0.05 * few_shots,
                5,
            )
            for creativity, reasoning, constraint, domain_knowledge, contextual_knowledge, few_shots in zip(
                result["creativity_scope"],
                result["reasoning"],
                result["constraint_ct"],
                result["domain_knowledge"],
                result["contextual_knowledge"],
                result["number_of_few_shots"],
            )
        ]

        return result
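For every target except `task_type`, `compute_results` reduces a head's logits to a single score. It takes the softmax, computes a weighted sum using that target's `weights_map` entry, and divides by its `divisor_map` entry. `process_logits` then combines six of those scores into `prompt_complexity_score` using fixed weights. The pure-Python sketch below walks through the same arithmetic for one prompt. The logits, weights, and divisor are made-up illustrative values, not the ones shipped in the model config. The sketch also omits the extra step that zeroes `number_of_few_shots` scores below 0.05.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def target_score(logits, weights, divisor, decimal=4):
    """Softmax-weighted score for one head, as in LogitsProcessor.compute_results."""
    probs = softmax(logits)
    weighted_sum = sum(p * w for p, w in zip(probs, weights))
    return round(weighted_sum / divisor, decimal)


def complexity_score(s):
    """Combine per-target scores with the weights used in process_logits."""
    return round(
        0.35 * s["creativity_scope"]
        + 0.25 * s["reasoning"]
        + 0.15 * s["constraint_ct"]
        + 0.15 * s["domain_knowledge"]
        + 0.05 * s["contextual_knowledge"]
        + 0.05 * s["number_of_few_shots"],
        5,
    )


# Illustrative values only (not from the real weights_map / divisor_map)
print(target_score([0.0, 0.0, 0.0], weights=[0, 1, 2], divisor=2))  # 0.5
scores = {
    "creativity_scope": 0.2,
    "reasoning": 0.4,
    "constraint_ct": 0.6,
    "domain_knowledge": 0.8,
    "contextual_knowledge": 1.0,
    "number_of_few_shots": 0.0,
}
print(complexity_score(scores))  # 0.43
```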

# Test the original model

In [None]:
config = AutoConfig.from_pretrained("nvidia/prompt-task-and-complexity-classifier")
tokenizer = AutoTokenizer.from_pretrained(
    "nvidia/prompt-task-and-complexity-classifier"
)
model = CustomModel(
    target_sizes=config.target_sizes,
    task_type_map=config.task_type_map,
    weights_map=config.weights_map,
    divisor_map=config.divisor_map,
).from_pretrained("nvidia/prompt-task-and-complexity-classifier")
model.eval()

prompt = ["Prompt: Write a Python script that uses a for loop."]

encoded_texts = tokenizer(
    prompt,
    return_tensors="pt",
    add_special_tokens=True,
    max_length=512,
    padding="max_length",
    truncation=True,
)

# Inference only: disable gradient tracking
with torch.no_grad():
    result = model(encoded_texts)
display(result)
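The output above is raw logits, one tensor per head. `process_logits` pairs `logits[i]` with `self.targets[i]`, so the `task_type` head is assumed to come first. To read the task type, `LogitsProcessor.compute_results` takes the two highest-scoring classes and reports the softmax probability of the top one. It replaces the runner-up label with `"NA"` when that label's rounded probability is below 0.1. The pure-Python sketch below mirrors this logic for a single prompt. The three-entry `task_type_map` and the logits are hypothetical.

```python
import math


def top2_task_types(logits, task_type_map, min_second_prob=0.1):
    """Return (top label, runner-up label or "NA", top probability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Indices of the two largest probabilities, highest first
    first, second = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:2]

    label_1 = task_type_map[str(first)]
    label_2 = task_type_map[str(second)]
    if round(probs[second], 3) < min_second_prob:
        label_2 = "NA"
    return label_1, label_2, round(probs[first], 3)


# Hypothetical three-class map; the real one comes from config.task_type_map
task_type_map = {"0": "Open QA", "1": "Code Generation", "2": "Summarization"}
print(top2_task_types([0.5, 3.0, 0.1], task_type_map))  # ('Code Generation', 'NA', 0.879)
print(top2_task_types([2.0, 1.8, -1.0], task_type_map))  # ('Open QA', 'Code Generation', 0.535)
```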

# Tracing the Router Model

As shown above, the model returns a list of logits, one tensor per classification head. This format isn't ideal for deployment on Triton Inference Server. To handle the multiple output tensors from NVIDIA's prompt classifier, we create a wrapper that concatenates them into a single tensor.

### WrapperModel Class
* Takes the original model as input
* Concatenates multiple output tensors into a single tensor
* Simplifies the output format for deployment
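Concatenating along `dim=1` places each head's logits side by side in a single row per prompt. The row can then be split back into per-head blocks, using the cumulative `target_sizes` as offsets. The `process_results` helper later in this notebook does the same slicing. The pure-Python sketch below shows the offset bookkeeping for one sample. The head sizes and values are hypothetical.

```python
from itertools import accumulate


def concat_heads(head_logits):
    """Flatten per-head logit lists into one row (like torch.cat(..., dim=1) for one sample)."""
    return [value for head in head_logits for value in head]


def split_heads(row, target_sizes):
    """Recover per-head slices using cumulative sizes as end offsets."""
    ends = list(accumulate(target_sizes))
    starts = [0] + ends[:-1]
    return [row[s:e] for s, e in zip(starts, ends)]


# Hypothetical head sizes (the real ones come from config.target_sizes)
sizes = [3, 2, 4]
heads = [[0.1, 0.2, 0.3], [1.0, -1.0], [5.0, 6.0, 7.0, 8.0]]
row = concat_heads(heads)
print(len(row))  # 9
print(split_heads(row, sizes) == heads)  # True
```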

In [None]:
config = AutoConfig.from_pretrained("nvidia/prompt-task-and-complexity-classifier")
tokenizer = AutoTokenizer.from_pretrained("nvidia/prompt-task-and-complexity-classifier")

class TracedModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, target_sizes, task_type_map, weights_map, divisor_map):
        super(TracedModel, self).__init__()
        self.backbone = AutoModel.from_pretrained("microsoft/DeBERTa-v3-base")
        self.target_sizes = target_sizes.values()
        self.task_type_map = task_type_map
        self.weights_map = weights_map
        self.divisor_map = divisor_map

        self.heads = [
            MulticlassHead(self.backbone.config.hidden_size, sz)
            for sz in self.target_sizes
        ]

        for i, head in enumerate(self.heads):
            self.add_module(f"head_{i}", head)

        self.pool = MeanPooling()

    def forward(self, input_ids, attention_mask):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        mean_pooled_representation = self.pool(last_hidden_state, attention_mask)

        logits = [
            self.heads[k](mean_pooled_representation)
            for k in range(len(self.target_sizes))
        ]
        return logits

class WrapperModel(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask)
        return torch.cat(outputs, dim=1)

model = TracedModel(
    target_sizes=config.target_sizes,
    task_type_map=config.task_type_map,
    weights_map=config.weights_map,
    divisor_map=config.divisor_map,
).from_pretrained("nvidia/prompt-task-and-complexity-classifier")

model = model.to('cuda')
model.eval()

wrapped_model = WrapperModel(model)
wrapped_model.eval()

In [None]:
prompt = "Write a Python script that uses a for loop."
encoded_texts = tokenizer(
    [prompt],
    return_tensors="pt",
    add_special_tokens=True,
    max_length=512,
    padding="max_length",
    truncation=True,
)

### Trace the wrapped model

In [None]:
with torch.no_grad():
    wrapped_model = torch.jit.trace(
        wrapped_model,
        (
            encoded_texts["input_ids"].to('cuda'),
            encoded_texts["attention_mask"].to('cuda')
        )
    )

In [None]:
wrapped_model.save("triton_template/task_router/1/model.pt")

### Load the saved model and compare it with the original model
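The comparison in the next cell uses `torch.allclose`. That function treats two tensors as equal when every element satisfies `abs(a - b) <= atol + rtol * abs(b)`, with `rtol` defaulting to 1e-5. Passing `atol=1e-4` tolerates small numerical drift between eager and traced execution. The pure-Python version below illustrates the same formula. Unlike `torch.allclose`, it does not broadcast and does not handle NaNs specially.

```python
def all_close(a, b, rtol=1e-5, atol=1e-4):
    """Element-wise tolerance check with the same formula torch.allclose uses."""
    if len(a) != len(b):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(all_close([1.0, 2.0], [1.00005, 2.0]))  # True
print(all_close([1.0, 2.0], [1.001, 2.0]))  # False
```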

In [None]:
import torch
from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("nvidia/prompt-task-and-complexity-classifier")

# Load the traced model
wrapped_model = torch.jit.load('triton_template/task_router/1/model.pt')
wrapped_model.eval()

# Prepare a sample input
sample_text = "Prompt: Translate the following sentence from English to French: 'Hello, how are you?'"
inputs = tokenizer(sample_text, return_tensors="pt", max_length=512, padding="max_length", truncation=True)

# Move inputs to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

# Perform inference
with torch.no_grad():
    output = wrapped_model(input_ids, attention_mask)

# Process the output
def process_results(output_tensor, target_sizes):
    results = []
    start_idx = 0
    for size in target_sizes:
        end_idx = start_idx + size
        result = output_tensor[:, start_idx:end_idx]
        results.append(result)
        start_idx = end_idx
    return results

processed_results = process_results(output, config.target_sizes.values())

# Interpret the results
task_names = list(config.target_sizes.keys())
for i, result in enumerate(processed_results):
    probabilities = torch.softmax(result, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_class].item()
    print(f"{task_names[i]}: Predicted class = {predicted_class}, Confidence = {confidence:.4f}")

# Compare with the original model (optional)
original_model = TracedModel(
    target_sizes=config.target_sizes,
    task_type_map=config.task_type_map,
    weights_map=config.weights_map,
    divisor_map=config.divisor_map,
).from_pretrained("nvidia/prompt-task-and-complexity-classifier")
original_model = original_model.to(device)
original_model.eval()

with torch.no_grad():
    original_outputs = original_model(input_ids, attention_mask)
    wrapped_original_output = torch.cat(original_outputs, dim=1)

print("\nComparing outputs:")
print("Original model output shape:", wrapped_original_output.shape)
print("Traced model output shape:", output.shape)
print("Outputs match:", torch.allclose(wrapped_original_output, output, atol=1e-4))

separated_original_outputs = process_results(wrapped_original_output, config.target_sizes.values())

# Compare separated outputs
print("\nComparing separated outputs:")
for i, (original, traced) in enumerate(zip(separated_original_outputs, processed_results)):
    print(f"Task {i}:")
    print(f"  Original output shape: {original.shape}")
    print(f"  Traced output shape: {traced.shape}")
    print(f"  Outputs match: {torch.allclose(original, traced, atol=1e-4)}")

# Interpret the original model results
print("\nOriginal model results:")
for i, result in enumerate(separated_original_outputs):
    probabilities = torch.softmax(result, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_class].item()
    print(f"{task_names[i]}: Predicted class = {predicted_class}, Confidence = {confidence:.4f}")

In [None]:
original_outputs

In [None]:
wrapped_original_output

# Triton Deployment

Now that the model is traced to TorchScript, we can add it to the [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/contents.html). We can then use the [ensemble pipeline feature](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/ensemble_models.html) to set up the pre- and post-processing pipeline.

The pre- and post-processing code is in the `triton_template/preprocessing_task_router/` and `triton_template/postprocessing_task_router/` directories. The `triton_template/task_router_ensemble/` directory contains the config that links pre-processing, the model, and post-processing together.

This is the same code that is downloaded from NGC when you set up the default task router.

The files are organized in the `/routers` directory with the following structure:

```
model_repository/
├── task_router
│   ├── 1
│   │   └── model.pt
│   └── config.pbtxt
├── task_router_ensemble
│   ├── 1
│   └── config.pbtxt
├── postprocessing_task_router
│   ├── 1
│   │   ├── logits_processor.py
│   │   ├── model.py
│   │   └── __pycache__
│   │       ├── logits_processor.cpython-310.pyc
│   │       └── model.cpython-310.pyc
│   └── config.pbtxt
└── preprocessing_task_router
    ├── 1
    │   ├── model.py
    │   └── __pycache__
    │       └── model.cpython-310.pyc
    └── config.pbtxt
```
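Before starting the server, you can confirm that the repository contains every file shown in the tree above. Each model directory needs a `config.pbtxt` and a version subdirectory `1`. The TorchScript model sits at `task_router/1/model.pt`, and each Python backend provides `1/model.py`. The helper below is a hypothetical convenience check built from this tree; it is not part of Triton.

```python
from pathlib import Path

# Expected entries, taken from the directory tree in this notebook
REQUIRED_FILES = [
    "task_router/config.pbtxt",
    "task_router/1/model.pt",
    "task_router_ensemble/config.pbtxt",
    "postprocessing_task_router/config.pbtxt",
    "postprocessing_task_router/1/model.py",
    "postprocessing_task_router/1/logits_processor.py",
    "preprocessing_task_router/config.pbtxt",
    "preprocessing_task_router/1/model.py",
]
REQUIRED_DIRS = ["task_router_ensemble/1"]  # the ensemble's version dir is empty


def missing_repository_entries(root):
    """Return the expected paths that are absent under the model repository root."""
    root = Path(root)
    missing = [p for p in REQUIRED_FILES if not (root / p).is_file()]
    missing += [d for d in REQUIRED_DIRS if not (root / d).is_dir()]
    return missing


# Example usage (the path is an assumption; point it at your repository):
# print(missing_repository_entries("/model_repository"))
```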

Now copy the contents of the `triton_template/` folder to `/model_repository`:

In [None]:
!cp -r triton_template/* /model_repository

On your host machine (not inside the Docker JupyterLab notebook), start the router server by running `make up`.
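After `make up`, the readiness endpoint queried by the `curl` cell returns HTTP 200 once the ensemble has loaded. If the models are still loading when you reach that cell, a polling loop avoids racing the server. The sketch below uses only the standard library. The URL follows the same `/v2/models/<name>/ready` path as the `curl` command. The timeout and interval defaults are arbitrary assumptions.

```python
import time
import urllib.request


def ready_url(base_url, model_name):
    """Build the readiness URL, e.g. http://router-server:8000/v2/models/<name>/ready."""
    return f"{base_url.rstrip('/')}/v2/models/{model_name}/ready"


def wait_until_ready(base_url, model_name, timeout=60.0, interval=2.0):
    """Poll the readiness endpoint until it returns HTTP 200 or the timeout expires."""
    url = ready_url(base_url, model_name)
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval) as response:
                if response.status == 200:
                    return True
        except OSError:
            # URLError, HTTPError (non-2xx) and socket timeouts are all OSError subclasses
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Example usage against the router container:
# wait_until_ready("http://router-server:8000", "task_router_ensemble")
```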

In [None]:
!curl -v http://router-server:8000/v2/models/task_router_ensemble/ready

In [None]:
import numpy as np
import tritonclient.http as httpclient
from transformers import AutoConfig

def send_request(triton_client, text):
    input_text = np.array([[text]], dtype=object)
    inputs = [httpclient.InferInput("INPUT", input_text.shape, "BYTES")]
    inputs[0].set_data_from_numpy(input_text)

    outputs = [httpclient.InferRequestedOutput("OUTPUT")]

    response = triton_client.infer(model_name="task_router_ensemble", inputs=inputs, outputs=outputs)
    return response

# Load the config
config = AutoConfig.from_pretrained("nvidia/prompt-task-and-complexity-classifier")

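# Build an index-ordered list of task types; the map's keys are string indices ("0", "1", ...),
# so look them up explicitly instead of relying on dict ordering
task_types = [config.task_type_map[str(i)] for i in range(len(config.task_type_map))]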

triton_client = httpclient.InferenceServerClient(url="router-server:8000")

prompt = "Prompt: Antibiotics are a type of medication used to treat bacterial infections. They work by either killing the bacteria or preventing them from reproducing, allowing the body’s immune system to fight off the infection. Antibiotics are usually taken orally in the form of pills, capsules, or liquid solutions, or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance. Explain the above in one sentence."
result = send_request(triton_client, prompt)

output_data = result.as_numpy("OUTPUT")

# Find the index of the maximum value (which should be 1 in the one-hot vector)
predicted_task_index = np.argmax(output_data)

# Map the index to the corresponding task type
predicted_task = task_types[predicted_task_index]

print(f"Input prompt: {prompt}")
print(f"Predicted task type: {predicted_task}")
print(f"One-hot encoded output: {output_data}")
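`np.argmax(output_data)` flattens its input. That works for a single prompt but would return a flat index for a batch. Suppose the ensemble returns one one-hot row per prompt, as the comment in the cell above suggests. In that case, a per-row decode could look like the sketch below. It looks labels up by string key, the same way `LogitsProcessor` indexes `task_type_map`. The map and output rows are hypothetical.

```python
def decode_one_hot(rows, task_type_map):
    """Map each one-hot (or score) row to the label of its largest entry."""
    labels = []
    for row in rows:
        index = max(range(len(row)), key=lambda i: row[i])
        labels.append(task_type_map[str(index)])
    return labels


# Hypothetical map and outputs for a batch of two prompts
task_type_map = {"0": "Open QA", "1": "Code Generation", "2": "Summarization"}
outputs = [[0, 0, 1], [0, 1, 0]]
print(decode_one_hot(outputs, task_type_map))  # ['Summarization', 'Code Generation']
```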