In [1]:
import os
import json
import logging
import numpy as np
from pathlib import Path

import scipy
import torch
import onnxruntime
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [29]:
current_dir = Path().resolve()
while not current_dir.name.endswith("xlm-roberta-base-cls-depression"):
    current_dir = current_dir.parent

os.chdir(current_dir)

input_test_data = current_dir / "data/clean/test.csv"
input_pytorch_model_dir = current_dir / "data/models/xlm-roberta-base-cls-depression"
input_model_dir = current_dir / "data/dist/xlm-roberta-base-cls-depression"
input_model_base_filename = input_model_dir / "model.onnx"
# input_model_quantized_filename = input_model_dir / "model.opt.quant.onnx" / "model_quantized.onnx"
input_model_quantized_filename = input_model_dir / "model.onnx"

In [30]:
# The pre process function take a question and a context, and generates the tensor inputs to the model:
# - input_ids: the words in the question encoded as integers
# - attention_mask: not used in this model
# - token_type_ids: a list of 0s and 1s that distinguish between the words of the question and the words of the context
# This function also returns the words contained in the question and the context, so that the answer can be decoded into a phrase.
def preprocess(text):
    encoded = tokenizer(
        text, padding="max_length", truncation=True, max_length=512, return_tensors="pt"
    )
    return (encoded["input_ids"], encoded["attention_mask"])


# The post process function maps the list of start and end log probabilities onto a text answer, using the text tokens from the question
# and context.
def postprocess(tokens, start, end):
    results = {}
    answer_start = np.argmax(start)
    answer_end = np.argmax(end)
    if answer_end >= answer_start:
        answer = tokens[answer_start]
        for i in range(answer_start + 1, answer_end + 1):
            if tokens[i][0:2] == "##":
                answer += tokens[i][2:]
            else:
                answer += " " + tokens[i]
        results["answer"] = answer.capitalize()
    else:
        results["error"] = (
            "I am unable to find the answer to this question. Can you please ask another question?"
        )
    return results


# Perform the one-off initialization for the prediction. The init code is run once when the endpoint is setup.
def init_models():
    global tokenizer, session, model
    model = AutoModelForSequenceClassification.from_pretrained(input_pytorch_model_dir)
    tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
    session = onnxruntime.InferenceSession(
        str(input_model_quantized_filename), providers=["CPUExecutionProvider"]
    )

# Run the PyTorch model, for functional and performance comparison
def run_pytorch(raw_data):
    inputs = json.loads(raw_data)
    input_ids, attention_mask = preprocess(inputs["text"])
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1)
    prediction = torch.argmax(probabilities, dim=1).item()
    return {
        "prediction": prediction,
        "probability": probabilities[0][prediction].item(),
    }


# Run the ONNX model with ONNX Runtime
def run_onnx(raw_data):
    inputs = json.loads(raw_data)
    input_ids, attention_mask = preprocess(inputs["text"])
    model_inputs = {
        "input_ids": np.asarray(input_ids, dtype=np.int64),
        "attention_mask": np.asarray(attention_mask, dtype=np.int64)
        }
    outputs = session.run(["logits"], model_inputs)
    logits = outputs[0]
    probabilities = scipy.special.softmax(logits, axis=1)
    prediction = np.argmax(probabilities, axis=1)[0]
    return {
        "prediction": prediction,
        "probability": probabilities[0][prediction].item(),
    }

init_models()

In [15]:
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
import json
import time
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

class ModelComparison:
    def __init__(self, run_pytorch, run_onnx):
        """
        Initialize the comparison class with model inference functions.
        
        Args:
            run_pytorch: Function that takes text input and returns dict with prediction and probability
            run_onnx: Function that takes text input and returns dict with prediction and probability
        """
        self.run_pytorch = run_pytorch
        self.run_onnx = run_onnx
        self.results = {
            'pytorch': {'predictions': [], 'probabilities': [], 'times': []},
            'onnx': {'predictions': [], 'probabilities': [], 'times': []}
        }
        
    def process_batch(self, texts: List[str]) -> None:
        """
        Process a batch of texts through both models.
        
        Args:
            texts: List of input texts to process
        """
        for text in tqdm(texts, desc="Processing texts"):
            input_json = json.dumps({"text": text})
            
            # PyTorch inference
            start_time = time.time()
            pytorch_result = self.run_pytorch(input_json)
            pytorch_time = time.time() - start_time
            
            # ONNX inference
            start_time = time.time()
            onnx_result = self.run_onnx(input_json)
            onnx_time = time.time() - start_time
            
            # Store results
            self.results['pytorch']['predictions'].append(pytorch_result['prediction'])
            self.results['pytorch']['probabilities'].append(pytorch_result['probability'])
            self.results['pytorch']['times'].append(pytorch_time)
            
            self.results['onnx']['predictions'].append(onnx_result['prediction'])
            self.results['onnx']['probabilities'].append(onnx_result['probability'])
            self.results['onnx']['times'].append(onnx_time)
    
    def calculate_metrics(self, true_labels: List[int]) -> Dict:
        """
        Calculate various classification metrics for both models.
        
        Args:
            true_labels: List of ground truth labels
        
        Returns:
            Dictionary containing metrics for both models
        """
        metrics = {}
        
        for model_name in ['pytorch', 'onnx']:
            predictions = self.results[model_name]['predictions']
            probabilities = self.results[model_name]['probabilities']
            
            metrics[model_name] = {
                'accuracy': accuracy_score(true_labels, predictions),
                'precision': precision_score(true_labels, predictions),
                'recall': recall_score(true_labels, predictions),
                'f1': f1_score(true_labels, predictions),
                'roc_auc': roc_auc_score(true_labels, probabilities),
                'confusion_matrix': confusion_matrix(true_labels, predictions),
                'avg_inference_time': np.mean(self.results[model_name]['times']),
                'std_inference_time': np.std(self.results[model_name]['times'])
            }
            
        return metrics
    
    def plot_confusion_matrices(self, metrics: Dict) -> None:
        """
        Plot confusion matrices for both models side by side.
        
        Args:
            metrics: Dictionary containing metrics including confusion matrices
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Plot PyTorch confusion matrix
        sns.heatmap(metrics['pytorch']['confusion_matrix'], annot=True, fmt='d', ax=ax1)
        ax1.set_title('PyTorch Model Confusion Matrix')
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('True')
        
        # Plot ONNX confusion matrix
        sns.heatmap(metrics['onnx']['confusion_matrix'], annot=True, fmt='d', ax=ax2)
        ax2.set_title('ONNX Model Confusion Matrix')
        ax2.set_xlabel('Predicted')
        ax2.set_ylabel('True')
        
        plt.tight_layout()
        plt.savefig('confusion_matrices.png')
        plt.close()
    
    def plot_inference_times(self) -> None:
        """Plot distribution of inference times for both models."""
        plt.figure(figsize=(10, 6))
        plt.boxplot([self.results['pytorch']['times'], self.results['onnx']['times']], 
                   labels=['PyTorch', 'ONNX'])
        plt.title('Model Inference Times Comparison')
        plt.ylabel('Time (seconds)')
        plt.savefig('inference_times.png')
        plt.close()
    
    def generate_report(self, metrics: Dict) -> str:
        """
        Generate a detailed comparison report.
        
        Args:
            metrics: Dictionary containing metrics for both models
        
        Returns:
            Formatted string containing the comparison report
        """
        report = "Model Comparison Report\n"
        report += "=====================\n\n"
        
        for model_name in ['pytorch', 'onnx']:
            report += f"{model_name.upper()} Model Metrics:\n"
            report += f"{'='*20}\n"
            report += f"Accuracy: {metrics[model_name]['accuracy']:.4f}\n"
            report += f"Precision: {metrics[model_name]['precision']:.4f}\n"
            report += f"Recall: {metrics[model_name]['recall']:.4f}\n"
            report += f"F1 Score: {metrics[model_name]['f1']:.4f}\n"
            report += f"ROC-AUC Score: {metrics[model_name]['roc_auc']:.4f}\n"
            report += f"Average Inference Time: {metrics[model_name]['avg_inference_time']*1000:.2f} ms\n"
            report += f"Inference Time Std: {metrics[model_name]['std_inference_time']*1000:.2f} ms\n\n"
        
        return report

def main(df: pd.DataFrame, run_pytorch, run_onnx):
    """
    Main function to run the model comparison.
    
    Args:
        df: DataFrame containing 'text' and 'label' columns
        run_pytorch: PyTorch model inference function
        run_onnx: ONNX model inference function
    """
    # Initialize comparison
    comparison = ModelComparison(run_pytorch, run_onnx)
    
    # Process all texts
    comparison.process_batch(df['text'].tolist())
    
    # Calculate metrics
    metrics = comparison.calculate_metrics(df['label'].tolist())
    
    # Generate visualizations
    comparison.plot_confusion_matrices(metrics)
    comparison.plot_inference_times()
    
    # Generate and save report
    report = comparison.generate_report(metrics)
    with open('model_comparison_report.txt', 'w') as f:
        f.write(report)
    
    return metrics, report

if __name__ == "__main__":
    df = pd.read_csv(input_test_data, sep='|')
    metrics, report = main(df, run_pytorch, run_onnx)
    print(report)

Processing texts: 100%|██████████| 4892/4892 [17:24<00:00,  4.69it/s]
  plt.boxplot([self.results['pytorch']['times'], self.results['onnx']['times']],


Model Comparison Report

PYTORCH Model Metrics:
Accuracy: 0.9867
Precision: 0.9799
Recall: 0.9827
F1 Score: 0.9813
ROC-AUC Score: 0.0927
Average Inference Time: 145.66 ms
Inference Time Std: 2.81 ms

ONNX Model Metrics:
Accuracy: 0.9871
Precision: 0.9811
Recall: 0.9827
F1 Score: 0.9819
ROC-AUC Score: 0.0679
Average Inference Time: 66.88 ms
Inference Time Std: 1.92 ms




In [33]:
text = 'Ya no quiero continuar en esta vida'
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "\'Ya no quiero continuar en esta vida\', así dijo ese loro pendejo xD!"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = '¿Ser o no ser?'
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "¿Ser o no ser? Mejor ya no ser nunca más"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = 'Me quiero morir'
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "Me quiero morir, mañana ya es el concierto"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = 'A veces pienso que la vida no es justa'
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = 'A veces pienso que la vida no es justa, pero sé que todo mejorará'
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = "Nada de lo que hago tiene éxito"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "Nada de lo que hago tiene éxito, pero no me rendiré"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = "¿Para qué continuar aquí?"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "¿Para qué continuar aquí? La fiesta ya murió we"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

print()

text = "Mi vida no tiene sentido"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "Mi vida no tiene sentido, sin ti amor ven pa ca xD!"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')


print()

text = "Soy el mejor"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

text = "Soy el mejor haciendo de mi vida un sin sentido"
input = f"{{\"text\": \"{text}\"}}"
result = run_pytorch(input)
print(text,'(depression)' if result['prediction'] == 1 else '(non-depression)')

Ya no quiero continuar en esta vida (depression)
'Ya no quiero continuar en esta vida', así dijo ese loro pendejo xD! (non-depression)

¿Ser o no ser? (non-depression)
¿Ser o no ser? Mejor ya no ser nunca más (depression)

Me quiero morir (depression)
Me quiero morir, mañana ya es el concierto (non-depression)

A veces pienso que la vida no es justa (depression)
A veces pienso que la vida no es justa, pero sé que todo mejorará (non-depression)

Nada de lo que hago tiene éxito (depression)
Nada de lo que hago tiene éxito, pero no me rendiré (non-depression)

¿Para qué continuar aquí? (depression)
¿Para qué continuar aquí? La fiesta ya murió we (non-depression)

Mi vida no tiene sentido (depression)
Mi vida no tiene sentido, sin ti amor ven pa ca xD! (non-depression)

Soy el mejor (non-depression)
Soy el mejor haciendo de mi vida un sin sentido (depression)
