# Case Classification Model Training Notebook
This notebook trains a multi-label classification model for classifying cases repor in call center transcripts.



## 📦 Setup and Install Dependencies


In [1]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    DistilBertPreTrainedModel,
    DistilBertModel,
    TrainingArguments,
    Trainer,
    EvalPrediction
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import re
import torch.nn as nn
import json
from sklearn.metrics.pairwise import cosine_similarity
import datetime
import mlflow
import mlflow.pytorch
import torch
import logging
import dvc.api


  from .autonotebook import tqdm as notebook_tqdm


## ML Flow Experiment  initialization and Log configuration

In [2]:

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
logger.info(f"Using device: {device}")



mlflow.set_tracking_uri("http://192.168.8.18:5000")
mlflow.set_experiment("Multitask_Classification")


REPO_ROOT="/home/rogendo/Work/ai/ai_service"
DATASET_PATH = 'datasets/classification/cleaned_synthetic_cases_generated_data_v0005.json'

os.chdir(REPO_ROOT)

Using device: cuda


### Dataset Loading

In here we also Feature Engineer the Main-Category by mapping the subcategories against the pre-defined main categories

In [3]:

def load_dataset(dataset_path, version="HEAD", repo=''):
    """Load dataset from DVC"""
    # Get dataset path
    data_url = dvc.api.get_url(
        path = dataset_path,
        repo=repo,
    )


    return data_url


# load latest_version
dataset_path = load_dataset(dataset_path=DATASET_PATH,repo=REPO_ROOT)
print(dataset_path)
df = pd.read_json(dataset_path)


# df = pd.read_json("/home/rogendo/chl_scratch/synthetic_data/casedir/balanced_cases_generated_data_v0005.json")
# df = pd.read_json("/home/rogendo/chl_scratch/synthetic_data/casedir/cleaned_balanced_cases_generated_data_v0005.json")
sub_to_main_mapping = {
    "Bullying": "Advice and Counselling",
    "Child in Conflict with the Law": "Advice and Counselling",
    "Discrimination": "Advice and Counselling",
    "Drug/Alcohol Abuse": "Advice and Counselling",
    "Family Relationship": "Advice and Counselling",
    "HIV/AIDS": "Advice and Counselling",
    "Homelessness": "Advice and Counselling",
    "Legal issues": "Advice and Counselling",
    "Missing Child": "Advice and Counselling",  
    "Peer Relationships": "Advice and Counselling",
    "Physical Health": "Advice and Counselling",
    "Psychosocial/Mental Health": "Advice and Counselling",
    "Relationships (Boy/Girl)": "Advice and Counselling",
    "Relationships (Parent/Child)": "Advice and Counselling",
    "Relationships (Student/Teacher)": "Advice and Counselling",
    "School related issues": "Advice and Counselling",
    "Self Esteem": "Advice and Counselling",
    "Sexual & Reproductive Health": "Advice and Counselling",
    "Student/ Teacher Relationship": "Advice and Counselling",
    "Teen Pregnancy": "Advice and Counselling",
    "Adoption": "Child Maintenance & Custody",
    "Birth Registration": "Child Maintenance & Custody",
    "Custody": "Child Maintenance & Custody",
    "Foster Care": "Child Maintenance & Custody",
    "Maintenance": "Child Maintenance & Custody",
    "No Care Giver": "Child Maintenance & Custody",
    "Other": "Child Maintenance & Custody", 
    "Albinism": "Disability",
    "Hearing impairment": "Disability",
    "Hydrocephalus": "Disability",
    "Mental impairment": "Disability",
    "Multiple disabilities": "Disability",
    "Physical impairment": "Disability",
    "Speech impairment": "Disability",
    "Spinal bifida": "Disability",
    "Visual impairment": "Disability",
    "Emotional/Psychological Violence": "GBV",
    "Financial/Economic Violence": "GBV",
    "Forced Marriage Violence": "GBV",
    "Harmful Practice": "GBV",
    "Physical Violence": "GBV",
    "Sexual Violence": "GBV",
    "Child Abuse": "Information",
    "Child Rights": "Information",
    "Info on Helpline": "Information",
    "Legal Issues": "Information",
    "School Related Issues": "Information", 
    "Balanced Diet": "Nutrition",
    "Breastfeeding": "Nutrition",
    "Feeding & Food preparation": "Nutrition",
    "Malnutrition": "Nutrition",
    "Obesity": "Nutrition",
    "Stagnation": "Nutrition",
    "Underweight": "Nutrition",
    "Child Abduction": "VANE",
    "Child Labor": "VANE",
    "Child Marriage": "VANE",
    "Child Neglect": "VANE",
    "Child Trafficking": "VANE",
    "Emotional Abuse": "VANE",
    "Female Genital Mutilation": "VANE",
    "OCSEA": "VANE", 
    "Physical Abuse": "VANE",
    "Sexual Abuse": "VANE",
    "Traditional Practice": "VANE",
    "Unlawful Confinement": "VANE",
    "Other": "VANE"  # Final "Other" mapping
}



ssh://ml-server-local/opt/dvc-storage/files/md5/e8/01db291a65ff08a5ac9611ccbee3d2


### 🧪 Dataset mapping and Train-Test Split


In [14]:

# Create new column using the mapping dictionary
df['main_category'] = df['category'].map(sub_to_main_mapping)

# Handle unmapped categories (if any)
df['main_category'] = df['main_category'].fillna('Unknown')

# Preprocess labels
main_categories = sorted(df['main_category'].unique())
sub_categories = sorted(df['category'].unique())
interventions = sorted(df['intervention'].unique())
priorities = [1, 2, 3]

print(df[['category', 'main_category']].head())
logger.info(df[['category', 'main_category']].head())


# Create mappings
main_cat2id = {cat: i for i, cat in enumerate(main_categories)}
sub_cat2id = {cat: i for i, cat in enumerate(sub_categories)}
interv2id = {interv: i for i, interv in enumerate(interventions)}
priority2id = {p: i for i, p in enumerate(priorities)}

# Apply mappings
df['main_category_id'] = df['main_category'].map(main_cat2id)
df['sub_category_id'] = df['category'].map(sub_cat2id)
df['intervention_id'] = df['intervention'].map(interv2id)
df['priority_id'] = df['priority'].map(lambda x: priority2id[x])
df['text'] = df['narrative']

# Split dataset
train_df, test_df = train_test_split(
    df, 
    test_size=0.1, 
    random_state=42,
    stratify=df['sub_category_id']
)

# Create datasets
train_dataset = Dataset.from_pandas(train_df[['text', 'main_category_id', 'sub_category_id', 'intervention_id', 'priority_id']])
test_dataset = Dataset.from_pandas(test_df[['text', 'main_category_id', 'sub_category_id', 'intervention_id', 'priority_id']])

dataset = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
    "validation": test_dataset
})


                          category                main_category
0               Drug/Alcohol Abuse       Advice and Counselling
1                     Homelessness       Advice and Counselling
2                         HIV/AIDS       Advice and Counselling
3  Relationships (Student/Teacher)       Advice and Counselling
4                      Foster Care  Child Maintenance & Custody


### Model Head and Layers Setup

In [16]:
dataset['train'].shape

(8810, 6)

In [17]:
dataset['test'].shape

(979, 6)

In [15]:
dataset['validation'].shape

(979, 6)

In [5]:



class MultiTaskDistilBert(DistilBertPreTrainedModel):
    def __init__(self, config, num_main, num_sub, num_interv, num_priority):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier_main = nn.Linear(config.dim, num_main)
        self.classifier_sub = nn.Linear(config.dim, num_sub)
        self.classifier_interv = nn.Linear(config.dim, num_interv)
        self.classifier_priority = nn.Linear(config.dim, num_priority)
        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, 
                main_category_id=None, sub_category_id=None, 
                intervention_id=None, priority_id=None):
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_state = distilbert_output.last_hidden_state 
        pooled_output = hidden_state[:, 0]                 
        pooled_output = self.pre_classifier(pooled_output) 
  
        pooled_output = nn.ReLU()(pooled_output)           
        pooled_output = self.dropout(pooled_output)        
        
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)
        
        loss = None
        if main_category_id is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_main = loss_fct(logits_main, main_category_id)
            loss_sub = loss_fct(logits_sub, sub_category_id)
            loss_interv = loss_fct(logits_interv, intervention_id)
            loss_priority = loss_fct(logits_priority, priority_id)
            loss = loss_main + loss_sub + loss_interv + loss_priority
        
        if loss is not None:
            return (loss, logits_main, logits_sub, logits_interv, logits_priority)
        else:
            return (logits_main, logits_sub, logits_interv, logits_priority)

    
    def get_embeddings(self, input_ids, attention_mask):
        return self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask
        )[-1]




In [6]:


# ____ continuous fine-tuning and version control ____

# paths and loading existing metadata
model_output_dir = "/home/rogendo/chl_scratch/multitask_distilbert_version"
metadata_file = os.path.join(model_output_dir, "model_metadata.json")
os.makedirs(model_output_dir, exist_ok=True)

# Load metadata of the last best model
if os.path.exists(metadata_file):
    with open(metadata_file, "r") as f:
        metadata = json.load(f)
        print(metadata)
        logger.info(f"Model Version Metadata {metadata}")

    last_best_model_path = os.path.join(model_output_dir, metadata['last_best_model_dir'])
    print(last_best_model_path)
    print(f"Loading existing model from {last_best_model_path} for continuous fine-tuning.")
    logger.info(f"Loading existing model from {last_best_model_path} for continuous fine-tuning.")
   
    # Check if the directory exists
    if os.path.exists(last_best_model_path):
        model = MultiTaskDistilBert.from_pretrained(
            last_best_model_path,
            num_main=len(main_categories),
            num_sub=len(sub_categories),
            num_interv=len(interventions),
            num_priority=len(priorities)
        )
        tokenizer = AutoTokenizer.from_pretrained(last_best_model_path)
    else:
        # Fallback to base model if last best model directory is missing
        print(f"Warning: Last best model directory not found. Starting from base checkpoint.")
        logger.warning (f"Warning: Last best model directory not found. Starting from base checkpoint.")

        checkpoint = "distilbert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = MultiTaskDistilBert.from_pretrained(
            checkpoint,
            num_main=len(main_categories),
            num_sub=len(sub_categories),
            num_interv=len(interventions),
            num_priority=len(priorities)
        )
else:
    logger.info("No existing model found. Starting from base checkpoint.")
    checkpoint = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MultiTaskDistilBert.from_pretrained(
        checkpoint,
        num_main=len(main_categories),
        num_sub=len(sub_categories),
        num_interv=len(interventions),
        num_priority=len(priorities)
    )

# --- End  continuous fine-tuning and version control ---



{'version': 'v-1', 'date_trained': '2025-08-25 16:05:13.705971', 'eval_avg_acc': 0.6580694586312563, 'last_best_model_dir': 'multitask_distilbert_v-1', 'metrics': {'eval_avg_acc': 0.6580694586312563, 'eval_avg_precision': 0.6411514284441623, 'eval_avg_recall': 0.6580694586312563, 'eval_avg_f1': 0.6413574307508194, 'eval_main_acc': 0.7099080694586313, 'eval_main_precision': 0.708739077146966, 'eval_main_recall': 0.7099080694586313, 'eval_main_f1': 0.704318734141012, 'eval_sub_acc': 0.5822267620020429, 'eval_sub_precision': 0.5812330008069666, 'eval_sub_recall': 0.5822267620020429, 'eval_sub_f1': 0.572191139569792, 'eval_interv_acc': 0.6700715015321757, 'eval_interv_precision': 0.6514564843754779, 'eval_interv_recall': 0.6700715015321757, 'eval_interv_f1': 0.6554767410743746, 'eval_priority_acc': 0.6700715015321757, 'eval_priority_precision': 0.6231771514472384, 'eval_priority_recall': 0.6700715015321757, 'eval_priority_f1': 0.6334431082180993, 'eval_runtime': 7.581, 'eval_samples_per_se

In [7]:

# Tokenization function
def tokenize_function(batch):
    encoding = tokenizer(
        batch["text"], 
        padding="max_length", 
        truncation=True,
        max_length=512
    )
    return encoding

# Apply tokenization
encoded_dataset = dataset.map(tokenize_function, batched=True)
encoded_dataset.set_format("torch", columns=[
    "input_ids", "attention_mask", 
    "main_category_id", "sub_category_id", 
    "intervention_id", "priority_id"
])



Map: 100%|█████████████████████████| 8810/8810 [00:00<00:00, 9361.09 examples/s]
Map: 100%|███████████████████████████| 979/979 [00:00<00:00, 9649.69 examples/s]


In [8]:

def compute_metrics(p: EvalPrediction):
    logger.info("compute_metrics called")
    # p.predictions is a tuple of logits for each task
    # p.label_ids is a tuple of true labels for each task

    logits_main, logits_sub, logits_interv, logits_priority = p.predictions
    labels_main, labels_sub, labels_interv, labels_priority = p.label_ids

    preds_main = np.argmax(logits_main, axis=1)
    preds_sub = np.argmax(logits_sub, axis=1)
    preds_interv = np.argmax(logits_interv, axis=1)
    preds_priority = np.argmax(logits_priority, axis=1)

    metrics = {}

    # Helper function to compute and add metrics for each task
    def add_task_metrics(task_name, labels, preds):
        accuracy = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
        
        metrics[f"{task_name}_acc"] = accuracy
        metrics[f"{task_name}_precision"] = precision
        metrics[f"{task_name}_recall"] = recall
        metrics[f"{task_name}_f1"] = f1

    add_task_metrics("main", labels_main, preds_main)
    add_task_metrics("sub", labels_sub, preds_sub)
    add_task_metrics("interv", labels_interv, preds_interv)
    add_task_metrics("priority", labels_priority, preds_priority)

    # Calculate average metrics across all tasks
    avg_acc = np.mean([metrics[f"{task}_acc"] for task in ["main", "sub", "interv", "priority"]])
    avg_precision = np.mean([metrics[f"{task}_precision"] for task in ["main", "sub", "interv", "priority"]])
    avg_recall = np.mean([metrics[f"{task}_recall"] for task in ["main", "sub", "interv", "priority"]])
    avg_f1 = np.mean([metrics[f"{task}_f1"] for task in ["main", "sub", "interv", "priority"]])

    metrics["eval_avg_acc"] = avg_acc
    metrics["eval_avg_precision"] = avg_precision
    metrics["eval_avg_recall"] = avg_recall
    metrics["eval_avg_f1"] = avg_f1

    return metrics

class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = {
            "main_category_id": inputs.pop("main_category_id"),
            "sub_category_id": inputs.pop("sub_category_id"),
            "intervention_id": inputs.pop("intervention_id"),
            "priority_id": inputs.pop("priority_id")
        }
        outputs = model(**inputs, **labels)
        # outputs is a tuple: (loss, logits_main, logits_sub, logits_interv, logits_priority)

        loss = outputs[0]
        if return_outputs:
            return (loss, *outputs[1:])
        else:
            return loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        # Remove labels from inputs
        label_keys = ["main_category_id", "sub_category_id", "intervention_id", "priority_id"]
        labels = {key: inputs.pop(key) for key in label_keys if key in inputs}

        # Forward pass without labels
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Extract logits (assumes model returns tuple: (logits_main, logits_sub, ...)) if no loss is returned
        # Or (loss, logits_main, logits_sub, ...) if loss is returned
        
        # Check if the first element is a tensor (likely loss)
        if isinstance(outputs[0], torch.Tensor) and outputs[0].dim() == 0: # Check if it's a scalar tensor
            loss = outputs[0]
            logits = outputs[1:] # Skip loss
        else:
            loss = None
            logits = outputs # All elements are logits

        # Handle label presence
        
        if labels:
            label_values = (labels["main_category_id"], labels["sub_category_id"],
                           labels["intervention_id"], labels["priority_id"])
        
        return (loss, logits, label_values)


In [None]:


# Training arguments
training_args = TrainingArguments(
    # output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=14,
    weight_decay=0.01,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_avg_acc",
    greater_is_better=True,
    logging_dir='./logs',
    logging_steps=100,
)

logger.info(len(encoded_dataset["test"]))
logger.info(len(encoded_dataset["train"]))

# Initialize trainer
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()

# Save model
trainer.save_model("CHS_tz_classifier_distilbert")
tokenizer.save_pretrained("CHS_tz_classifier_distilbert")
 

# Evaluate the model after training
new_metrics = trainer.evaluate(encoded_dataset["test"])
new_avg_acc = new_metrics.get('eval_avg_acc', 0)

logger.info(f"Model Performance results: {new_metrics}")

# Load previous best model's average accuracy
if os.path.exists(metadata_file):
    with open(metadata_file, "r") as f:
        metadata = json.load(f)
    prev_avg_acc = metadata.get('eval_avg_acc', 0)
else:
    prev_avg_acc = -1

# Check if the new model is better
if new_avg_acc > prev_avg_acc:
    logger.info(" New model performance improved! Saving new version.")
    # Create new version directory
    version = len(os.listdir(model_output_dir)) -1
    new_model_dir = f"CHS_tz_classifier_distilbert{version}"
    new_model_path = os.path.join(model_output_dir, new_model_dir)

    # Save the model and tokenizer in the new version directory
    trainer.save_model(new_model_path)
    tokenizer.save_pretrained(new_model_path)

    # Update metadata file
    metadata = {
        "version": f"v{version}",
        "date_trained": str(datetime.datetime.now()),
        "eval_avg_acc": new_avg_acc,
        "last_best_model_dir": new_model_dir,
        "metrics": new_metrics
    }
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=4)
else:
    logger.info("Model performance did not improve. Not saving a new version.")




Epoch,Training Loss,Validation Loss,Avg Acc,Avg Precision,Avg Recall,Avg F1,Main Acc,Main Precision,Main Recall,Main F1,Sub Acc,Sub Precision,Sub Recall,Sub F1,Interv Acc,Interv Precision,Interv Recall,Interv F1,Priority Acc,Priority Precision,Priority Recall,Priority F1
1,2.8272,No log,0.652962,0.642867,0.652962,0.638604,0.698672,0.691734,0.698672,0.693446,0.585291,0.602546,0.585291,0.582497,0.66905,0.644297,0.66905,0.648527,0.658836,0.632892,0.658836,0.629946
2,2.4794,No log,0.655005,0.640197,0.655005,0.636332,0.699694,0.701763,0.699694,0.689314,0.576098,0.587158,0.576098,0.569818,0.673136,0.649992,0.673136,0.656432,0.671093,0.621876,0.671093,0.629766
3,2.1429,No log,0.648876,0.652724,0.648876,0.639178,0.716037,0.724539,0.716037,0.701643,0.59142,0.615874,0.59142,0.588733,0.641471,0.651712,0.641471,0.638995,0.646578,0.61877,0.646578,0.627341
4,1.8663,No log,0.651685,0.637984,0.651685,0.633717,0.703779,0.699264,0.703779,0.690296,0.576098,0.599292,0.576098,0.575315,0.658836,0.641738,0.658836,0.648005,0.668029,0.61164,0.668029,0.621251
5,1.5769,No log,0.645046,0.648603,0.645046,0.641926,0.719101,0.719597,0.719101,0.712287,0.590398,0.610169,0.590398,0.591516,0.650664,0.64925,0.650664,0.647611,0.62002,0.615395,0.62002,0.616288
6,1.4236,No log,0.641726,0.644549,0.641726,0.638836,0.701736,0.700706,0.701736,0.697412,0.586313,0.610671,0.586313,0.589291,0.648621,0.649969,0.648621,0.646641,0.630235,0.616851,0.630235,0.622001
7,1.1999,No log,0.639428,0.642328,0.639428,0.637878,0.709908,0.707107,0.709908,0.706802,0.582227,0.609641,0.582227,0.587464,0.650664,0.639958,0.650664,0.643832,0.614913,0.612607,0.614913,0.613414
8,1.1105,No log,0.641726,0.634192,0.641726,0.634763,0.701736,0.697885,0.701736,0.698805,0.583248,0.603228,0.583248,0.586228,0.652707,0.635225,0.652707,0.643239,0.629213,0.600429,0.629213,0.61078
9,0.9614,No log,0.644791,0.644514,0.644791,0.639415,0.706844,0.702342,0.706844,0.698055,0.586313,0.608937,0.586313,0.58691,0.645557,0.639757,0.645557,0.640212,0.640449,0.627022,0.640449,0.632483
10,0.8515,No log,0.652196,0.646083,0.652196,0.644907,0.709908,0.70795,0.709908,0.705328,0.585291,0.61168,0.585291,0.590509,0.665986,0.647663,0.665986,0.655833,0.6476,0.61704,0.6476,0.627958


🏃 View run trainer_output at: http://192.168.8.18:5000/#/experiments/13/runs/bc052dc155a04092916f04a4be53a668
🧪 View experiment at: http://192.168.8.18:5000/#/experiments/13


In [11]:

# Generate and save category embeddings
def generate_category_embeddings(categories, model, tokenizer, device):
    embeddings = []
    for category in categories:
        inputs = tokenizer(
            category, 
            padding="max_length", 
            truncation=True, 
            max_length=256, 
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            emb = model.get_embeddings(**inputs).cpu().numpy()
        embeddings.append(emb[0])
    return np.array(embeddings)

# Generate embeddings for all categories
main_cat_embeddings = generate_category_embeddings(main_categories, model, tokenizer, device)
sub_cat_embeddings = generate_category_embeddings(sub_categories, model, tokenizer, device)

# Save embeddings
os.makedirs("embeddings", exist_ok=True)
np.save("embeddings/main_cat_embeddings.npy", main_cat_embeddings)
np.save("embeddings/sub_cat_embeddings.npy", sub_cat_embeddings)

# Save category lists
with open("main_categories.json", "w") as f:
    json.dump(main_categories, f)
with open("sub_categories.json", "w") as f:
    json.dump(sub_categories, f)
with open("interventions.json", "w") as f:
    json.dump(interventions, f)
with open("priorities.json", "w") as f:
    json.dump(priorities, f)

# Evaluation
metrics = trainer.evaluate(encoded_dataset["test"])
logger.info(f"Model Performance results: {metrics}")

# Save metrics
with open("multilabel_model_metrics.json", "w") as f:
    json.dump(metrics, f, indent=4)

logger.info(" Model performance on test set:", metrics)


In [None]:


# Load trained model
# model_path = '/opt/chl_ai/models/raw-models/ai_models/MultiClassifier/multitask_distilbert'
model_path = "/home/rogendo/Work/ai/ai_service/qa_distilbert_v2"
# model_path = new_model_path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = MultiTaskDistilBert.from_pretrained(
    model_path,
    num_main=len(main_categories),
    num_sub=len(sub_categories),
    num_interv=len(interventions),
    num_priority=len(priorities)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # Set to evaluation mode


def classify_multitask_case_return_indices(narrative):
    text = narrative.lower().strip()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    inputs = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=256,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits_main, logits_sub, logits_interv, logits_priority = outputs
        preds_main = torch.argmax(logits_main, dim=1).cpu().numpy()[0]
        preds_sub = torch.argmax(logits_sub, dim=1).cpu().numpy()[0]
        preds_interv = torch.argmax(logits_interv, dim=1).cpu().numpy()[0]
        preds_priority = torch.argmax(logits_priority, dim=1).cpu().numpy()[0]
    
    return preds_main, preds_sub, preds_interv, preds_priority

# test_data = [
#     {
#         "narrative": "On <DATE_TIME>, a girl <PERSON> (<DATE_TIME> ) from <LOCATION> district reported being mistreated and verbally abused by their caregiver. She reported that it has been negatively affected and her esteem has gone down and she feels unsafe staying at her home. She requested for adoption as it was proving difficult to stay with the caregiver, who was not related to her and did not love her. The counselor advised her to be patient and that her issue has been filed and will be followed up. Sheee was glad for the counselor help",
#         "main_category": "VANE",
#         "sub_category": "Emotional Abuse",
#         "intervention": "Counselling",
#         "priority": 2
#     },
#     {
#         "narrative": "On 04 November 2022 Mustafa Ahmadi called on 116 for the purpose of getting an understanding on the meaning of Female Genital Mutilation.A counsellor explained to him that Female Genital Mutilation means the removal of genital parts of the female or woman.The parts removes are clitoris,labia majora or labia minora, FGM has several effects on the victim such as a high rate of bleeding, to acquire infectious diseases such as HIV, psychological problems, and pain. Apart from that Mustafa also asked about the effects of child marriage whereby the counselor explained to him that child marriage has negative effects such as failure to meet the life dreams due to school dropout, failure to make decisions as a female so she will depend on her husband to decide every matter of the family,psychological problems.Mustafa appreciated for the information.",
#         "main_category": "Information",
#         "sub_category": "Child Abuse",
#         "intervention": "Counselling",
#         "priority": 1
#     },

#     {
#         "narrative": "On 05 November 2022,Grace John called on 116 for the purpose of getting understanding on the service provided by the child helpline and its jurisdiction. The counsellor explained to her that 116 provides the service of reporting, receiving, referring and recording child abuse cases to the social welfare officers,to provide awareness on child nutrition,parenting and child rights as well as child maintenance.Then Grace said that she is divorced by her husband who does not not have any job,so she wanted to understand the mechanism for him to provide service to the children,a counsellor explained to her that if a father do not have any job then he cannot provide for the children so she should continue to provide for them until the father get a job which will enable him to provide srvices to the children and if he will refuse then she should report him to the social welfare officer so as to force him to do right.Grace appreciated for the information.",
#         "main_category": "Child Maintenance & Custody",
#         "sub_category": "Maintenance",
#         "intervention": 'Awareness/Information Provided',
#         "priority": 2
#     },

#     {
#     "narrative": "On 12 November 2022, the helpline center received a call from Abubaker in Lwandai Village, Soni Ward, Bumbuli District, Tanga Region. He wanted to understand the causes and effects of mental impairment. The counselor explained how biological and physical factors can contribute to the condition, and Abubaker confirmed that he understood the information.",
#     "main_category": "Information",
#     "sub_category": "Psychosocial/Mental Health",
#     "intervention": "Counselling",
#     "priority": 2
# },

# {
#     "narrative": "On 12 November 2022, Bahati Ndelesero, 18, from Kibaigwa Village, Kongwa District, Dodoma Region, called the helpline seeking advice on balanced diets as his wife prepares for pregnancy. The counselor advised him to ensure meals include foods from all five main groups — cereals and tubers, animal and plant proteins, fruits, vegetables, and healthy sugars or oils — to support good health for both mother and child.",
#     "main_category": "Nutrition",
#     "sub_category": "Balanced Diet",
#     "intervention": "Awareness/Information Provided",
#     "priority": 2
# },

# {
#     "narrative": "On 12 November 2022, Awadhi Shaibu, 27, from Kiangara Village, Liwale District, Lindi Region, reported that his 3-year-old son Yasri had become fearful and withdrawn after his parents separated. The counselor explained that such behavior may signal emotional distress caused by lack of parental bonding and advised both parents to resolve conflicts, provide love, support, and protection, and avoid exposing the child to further tension or abuse.",
#     "main_category": "Advice and Counselling",
#     "sub_category": "Relationships (Parent/Child)",
#     "intervention": "Counselling",
#     "priority": 1
# },



# {
#     "narrative": "On 13 November 2022, the helpline received a call from Amina Yusuf, 16, from Mbinga District, Ruvuma Region, reporting that her uncle constantly shouts at her and calls her names. She said this has lowered her confidence and made her afraid to stay at home. The counselor reassured her, documented the case, and explained that follow-up action will be taken.",
#     "main_category": "VANE",
#     "sub_category": "Emotional Abuse",
#     "intervention": "Counselling",
#     "priority": 1
# },
# {
#     "narrative": "On 14 November 2022, Daudi Mshana from Handeni District, Tanga Region, called the helpline asking for information on the effects of child labor. The counselor explained that child labor can lead to school dropout, health problems, and long-term poverty. Daudi thanked the counselor for clarifying these issues.",
#     "main_category": "Information",
#     "sub_category": "Child Labor",
#     "intervention": "Counselling",
#     "priority": 1
# },
# {
#     "narrative": "On 15 November 2022, the helpline received a call from Salum Mwinyi, 17, in Korogwe District, Tanga Region. He wanted to understand why drug use among youth is harmful. The counselor explained the physical, social, and legal risks involved with abusing drugs, both legal drugs and illegal drugs. The counselor explained in detail about what each of the effects of the drugs and advised Salum not to get herself involveed with any drugs. Salum expressed appreciation for the advice.",
#     "main_category": "Information",
#     "sub_category": "Drug/Alcohol Abuse",
#     "intervention": "Counselling",
#     "priority": 2
# },

# {
#     "narrative": "On 16 November 2022, the helpline was contacted by Jafari Said, 25, from Kilombero District, Morogoro Region. He asked how to provide a safe home environment for his two young sisters. The counselor advised him on good parenting practices and protecting them from abuse.",
#     "main_category": "Advice and Counselling",
#     "sub_category": "Relationships (Parent/Child)",
#     "intervention": "Counselling",
#     "priority": 2
# },

# {
#     "narrative": "On 17 November 2022, Paulo Samson, 19, from Same District, Kilimanjaro Region, called to ask about healthy food choices for his younger siblings. The counselor gave guidance on including proteins, vegetables, fruits, and whole grains to ensure proper growth and development.",
#     "main_category": "Nutrition",
#     "sub_category": 'Balanced Diet',
#     "intervention": "Awareness/Information Provided",
#     "priority": 3
# },
# {
#     "narrative": "On 18 November 2022, Ramadhan Musa, 28, from Sumbawanga Urban District, Rukwa Region, requested information about birth registration. The counselor explained the registration process, benefits of having a birth certificate, and directed him to the nearest registration office. He appreciated the assistance.",
#     "main_category": "Information",
#     "sub_category": "Child Rights",
#     "intervention": "Awareness/Information Provided",
#     "priority": 1
# }

# ]

test_data = dataset['validation']
print(test_data)
# Collect true and predicted labels
true_main, pred_main = [], []
true_sub, pred_sub = [], []
true_interv, pred_interv = [], []
true_priority, pred_priority = [], []

for example in test_data:
    # Get true indices directly from IDs
    true_main.append(example["main_category_id"])
    true_sub.append(example["sub_category_id"])
    true_interv.append(example["intervention_id"])
    
    # Append priority_id directly, since it matches the output indices
    priority_val = example["priority_id"]
    true_priority.append(priority_val)
    
    # Get predictions
    main_idx, sub_idx, interv_idx, priority_idx = classify_multitask_case_return_indices(
        example["text"]
    )
    pred_main.append(main_idx)
    pred_sub.append(sub_idx)
    pred_interv.append(interv_idx)
    pred_priority.append(priority_idx)



def plot_enhanced_confusion_matrix(true, pred, classes, title, filename, figsize=(12, 10)):
    cm = confusion_matrix(true, pred, labels=range(len(classes)))
    
    plt.figure(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, 
                yticklabels=classes,
                cbar_kws={'shrink': 0.8})
    plt.title(f'Confusion Matrix - {title}', fontsize=16, pad=20)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(rotation=0, fontsize=10)

    #  accuracy 
    accuracy = np.trace(cm) / np.sum(cm) if np.sum(cm) > 0 else 0
    plt.figtext(0.5, 0.01, f'Accuracy: {accuracy:.2%}', ha='center', fontsize=12)
    
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved {filename}")

#  output directory
os.makedirs("confusion_matrices", exist_ok=True)

#  save confusion matrices
plot_enhanced_confusion_matrix(true_main, pred_main, main_categories,
                     "Main Category", "confusion_matrices/main_category.png")

plot_enhanced_confusion_matrix(true_sub, pred_sub, sub_categories,
                     "Sub Category", "confusion_matrices/sub_category.png")

plot_enhanced_confusion_matrix(true_interv, pred_interv, interventions,
                     "Intervention", "confusion_matrices/intervention.png")

plot_enhanced_confusion_matrix(true_priority, pred_priority, [str(p) for p in priorities],
                     "Priority", "confusion_matrices/priority.png", figsize=(10, 8))


def print_compact_report(true, pred, classes, title):
    print(f"\n{'='*60}")
    print(f"CLASSIFICATION REPORT: {title}")
    print(f"{'='*60}")
    
    # Get the classification report as a dictionary
    report = classification_report(
        true, pred, 
        labels=range(len(classes)),
        target_names=classes,
        output_dict=True,
        zero_division=0
    )
    
    # Print only classes that appear in the test data or have predictions
    present_classes = set(true) | set(pred)
    
    print(f"Classes present in test data: {len(present_classes)}/{len(classes)}")
    print("\nPerformance on present classes:")
    print(f"{'Class':30} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
    
    for i in present_classes:
        class_name = classes[i]
        print(f"{class_name:30} {report[class_name]['precision']:>10.3f} "
              f"{report[class_name]['recall']:>10.3f} "
              f"{report[class_name]['f1-score']:>10.3f} "
              f"{report[class_name]['support']:>10.0f}")
    
    print(f"\nOverall Accuracy: {report['accuracy']:.3f}")
    print(f"Macro Avg F1: {report['macro avg']['f1-score']:.3f}")
    print(f"Weighted Avg F1: {report['weighted avg']['f1-score']:.3f}")

# Print compact reports
print_compact_report(true_main, pred_main, main_categories, "Main Category")
print_compact_report(true_sub, pred_sub, sub_categories, "Sub Category")
print_compact_report(true_interv, pred_interv, interventions, "Intervention")
print_compact_report(true_priority, pred_priority, [str(p) for p in priorities], "Priority")



# Additional analysis
print(f"\n{'='*60}")
print("ADDITIONAL ANALYSIS")
print(f"{'='*60}")
print(f"Test samples: {len(test_data)}")
print(f"Main categories in test data: {set([main_categories[i] for i in true_main])}")
print(f"Main categories not in test data: {set(main_categories) - set([main_categories[i] for i in true_main])}")

# Calculate and display per-task accuracy
tasks = [
    ("Main Category", true_main, pred_main),
    ("Sub Category", true_sub, pred_sub),
    ("Intervention", true_interv, pred_interv),
    ("Priority", true_priority, pred_priority)
]

print(f"\n{'Task':20} {'Accuracy':>10} {'Correct/Total':>15}")
for name, true, pred in tasks:
    correct = sum(1 for t, p in zip(true, pred) if t == p)
    accuracy = correct / len(true)
    print(f"{name:20} {accuracy:>10.2%} {f'{correct}/{len(true)}':>15}")

Dataset({
    features: ['text', 'main_category_id', 'sub_category_id', 'intervention_id', 'priority_id', '__index_level_0__'],
    num_rows: 979
})
Saved confusion_matrices/main_category.png
Saved confusion_matrices/sub_category.png
Saved confusion_matrices/intervention.png
Saved confusion_matrices/priority.png

CLASSIFICATION REPORT: Main Category
Classes present in test data: 8/8

Performance on present classes:
Class                           Precision     Recall         F1    Support
Advice and Counselling              0.613      0.763      0.680        299
Child Maintenance & Custody         0.617      0.804      0.698         92
Disability                          0.948      0.880      0.913        125
GBV                                 0.603      0.471      0.529         87
Information                         0.299      0.282      0.290         71
Nutrition                           0.925      0.843      0.882        102
Unknown                             1.000      0.083    

In [None]:

# {
#     "narrative": "On 15 November 2022, the helpline received a call from Salum Mwinyi, 17, in Korogwe District, Tanga Region. He wanted to understand why drug use among youth is harmful. The counselor explained the physical, social, and legal risks involved with abusing drugs, both legal drugs and illegal drugs. The counselor explained in detail about what each of the effects of the drugs and advised Salum not to get herself involveed with any drugs. Salum expressed appreciation for the advice.",
#     "main_category": "Information",
#     "sub_category": "Drug/Alcohol Abuse",
#     "intervention": "Counselling",
#     "priority": 2
# },
{
    "narrative": "On 16 November 2022, the helpline was contacted by Jafari Said, 25, from Kilombero District, Morogoro Region. He asked how to provide a safe home environment for his two young sisters. The counselor advised him on good parenting practices and protecting them from abuse.",
    "main_category": "Advice and Counselling",
    "sub_category": "Relationships (Parent/Child)",
    "intervention": "Counselling",
    "priority": 2
},
# narrative =  "On 15 November 2022, the helpline received a call from Salum Mwinyi, 17, in Korogwe District, Tanga Region. He wanted to understand why drug use among youth is harmful. The counselor explained the physical, social, and legal risks involved with abusing drugs, both legal drugs and illegal drugs. The counselor explained in detail about what each of the effects of the drugs and advised Salum not to get herself involveed with any drugs. Salum expressed appreciation for the advice.",


In [29]:
# narrative=" On <DATE_TIME> a girl <PERSON> (<DATE_TIME> ) from <LOCATION> district, <LOCATION> region called on 116 to report of the injustices done to her by a person who was to take care of her wellbeing. She reported that her stepfather raped her and abused her sexually and she was 2 months pregnant and is forced to abort as the stepfather threatened her that He will kill her if she does not abort. She reported that she was not the only one who was abused by the stepfather but her mother was also abused by the stepfather. "
narrative= "On 16 November 2022, the helpline was contacted by Jafari Said, 25, from Kilombero District, Morogoro Region. He asked how to provide a safe home environment for his two young sisters. The counselor advised him on good parenting practices and protecting them from abuse.",

main_idx, sub_idx, interv_idx, priority_idx = classify_multitask_case_return_indices(
    narrative[0]
)
pred_main.append(main_idx)
pred_sub.append(sub_idx)
pred_interv.append(interv_idx)
pred_priority.append(priority_idx)

print("Predicted Main Category:", main_categories[main_idx])
print("Predicted Sub Category:", sub_categories[sub_idx])
print("Predicted Intervention:", interventions[interv_idx])
print("Predicted Priority:", priorities[priority_idx])

Predicted Main Category: Advice and Counselling
Predicted Sub Category: Self Esteem
Predicted Intervention: Counselling
Predicted Priority: 2
