In [6]:
# Install required packages if needed

# !pip install --no-deps --upgrade -q seaborn

import os
import time
import random
import torch
import unsloth
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from datasets import Dataset, load_dataset
from sklearn.metrics import classification_report, confusion_matrix
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer

print("All imports successful!")


All imports successful!


In [7]:
## Cell 2: Configuration
# Configuration - Modify these paths as needed
MODEL_PATH = "./Gemma3ToxicTextClassifier"
DATA_PATH = "cike-dev/en_toxic_set"
OUTPUT_DIR = "unsloth-gemma3-full-ft"
MAX_SEQ_LENGTH = 256
CHAT_TEMPLATE = "gemma3"

print(f"Model Path: {MODEL_PATH}")
print(f"Data Path: {DATA_PATH}")
print(f"Output Directory: {OUTPUT_DIR}")

Model Path: ./Gemma3ToxicTextClassifier
Data Path: cike-dev/en_toxic_set
Output Directory: unsloth-gemma3-full-ft


In [None]:
class ModelEvaluator:
    """A class for evaluating fine-tuned language models on classification tasks."""
    
    # def __init__(self, model_path: str, max_seq_length: int = 128, chat_template: str = "gemma3"):
    #     """Initialize the evaluator with a model."""
    #     self.model_path = model_path
    #     self.model, self.tokenizer = self._load_model(model_path, max_seq_length, chat_template)


    def __init__(self, model_path: str, max_seq_length: int = 512, chat_template: str = "gemma3", device: Optional[str] = None):
        """Initialize the evaluator with a model."""
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model_path = model_path
        self.model, self.tokenizer = self._load_model(model_path, max_seq_length, chat_template)
        # # move model to device and set eval
        # self.model.to(self.device)
        self.model.eval()

    
    # def _load_model(self, model_path: str, max_seq_length: int, chat_template: str):
    #     """Load the model and tokenizer."""
    #     print(f"Loading model from {model_path}...")
    #     model, tokenizer = FastModel.from_pretrained(
    #         model_name=model_path,
    #         max_seq_length=max_seq_length,
    #     )
        
    #     tokenizer = get_chat_template(tokenizer, chat_template=chat_template)
    #     print("Model loaded successfully!")
    #     return model, tokenizer

    def _load_model(self, model_path: str, max_seq_length: int, chat_template: str):
        """Load the model and tokenizer."""
        print(f"Loading model from {model_path}...")
        # NOTE: change FastModel to your actual loader (this code assumes it returns (model, tokenizer))
        model, tokenizer = FastModel.from_pretrained(
            model_name=model_path,
            max_seq_length=max_seq_length,
        )
        tokenizer = get_chat_template(tokenizer, chat_template=chat_template)

        # Make sure decoder models use left padding if needed
        try:
            tokenizer.padding_side = "left"
        except Exception:
            pass

        print("Model loaded successfully!")
        return model, tokenizer
    
    def _format_input(self, conversations: List[Dict]) -> str:
        """Format input using chat template."""
        messages = [
            {"role": "system", "content": conversations[0]["content"]},
            {"role": "user", "content": conversations[1]["content"]}
        ]
        
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        ).removeprefix('<bos>')
        
        return input_text
    
    # def _generate_prediction(self, input_text: str, max_new_tokens: int = 10, 
    #                        temperature: int = 1, do_sample: bool = False) -> str:
    #     """Generate prediction for a single input."""
    #     inputs = self.tokenizer(input_text, return_tensors="pt").to("cuda")
        
    #     outputs = self.model.generate(
    #         **inputs,
    #         max_new_tokens=max_new_tokens,
    #         temperature=temperature,
    #         do_sample=do_sample,
    #         pad_token_id=self.tokenizer.pad_token_id,
    #         eos_token_id=self.tokenizer.eos_token_id
    #     )

    #     generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    #     prompt_length = len(self.tokenizer.decode(
    #         self.tokenizer(input_text, return_tensors="pt")["input_ids"][0], 
    #         skip_special_tokens=True
    #     ))
        
        # return generated_text[prompt_length:].strip()
    
    def _generate_prediction(self, input_text: str, max_new_tokens: int = 10, 
                             temperature: float = 1.0, do_sample: bool = False) -> str:
        """Generate prediction for a single input (returns *only* generated text)."""
        # prepare input tensors
        with torch.no_grad():
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
            )
            # Move tensors to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # safe pad/eos token ids
            pad_token_id = getattr(self.tokenizer, "pad_token_id", None) or getattr(self.tokenizer, "eos_token_id", None)
            eos_token_id = getattr(self.tokenizer, "eos_token_id", None)

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )

            # decode only newly generated tokens (token-level)
            input_len = inputs["input_ids"].shape[1]
            generated_ids = outputs[0][input_len:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            return generated_text.strip()

print("ModelEvaluator class defined successfully!")

ModelEvaluator class defined successfully!


In [9]:
import os
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import evaluate


def evaluate_multiple_datasets(
    evaluator,
    dataset_paths: List[str],
    output_dir: str = "./multi_dataset_evaluation/",
    split: str = "test",
    sample_size: Optional[int] = None,
    text_column: str = "text",
    label_column: str = "label",
    max_new_tokens: int = 10,
    system_message: Optional[str] = None,
    save_individual_cms: bool = True,
    save_summary_report: bool = True
) -> Dict[str, Dict[str, Any]]:
    """
    Evaluate a model on multiple datasets and generate comprehensive reports.
    
    Args:
        evaluator: ModelEvaluator instance
        dataset_paths: List of dataset paths (e.g., ["owner/dataset1", "owner/dataset2"])
        output_dir: Directory to save all results
        split: Dataset split to use for evaluation
        text_column: Name of the text column in datasets
        label_column: Name of the label column in datasets
        max_new_tokens: Maximum tokens to generate
        system_message: Custom system message (uses default if None)
        save_individual_cms: Whether to save individual confusion matrices
        save_summary_report: Whether to save summary text report
        
    Returns:
        Dictionary with results for each dataset
    """
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Default system message
    if system_message is None:
        system_message = "You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'."
    
    # Store results for all datasets
    all_results = {}
    summary_lines = []
    
    # Add header to summary
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    summary_lines.extend([
        "="*80,
        "MULTI-DATASET EVALUATION REPORT",
        "="*80,
        f"Evaluation Date: {timestamp}",
        f"Model Path: {evaluator.model_path}",
        f"Number of Datasets: {len(dataset_paths)}",
        f"Split Used: {split}",
        "="*80,
        ""
    ])
    
    print(f"\nüöÄ Starting evaluation on {len(dataset_paths)} datasets...")
    print(f"üìÅ Results will be saved to: {output_dir}")
    
    # Process each dataset
    for idx, dataset_path in enumerate(dataset_paths, 1):
        print(f"\n{'='*60}")
        print(f"üìä DATASET {idx}/{len(dataset_paths)}: {dataset_path}")
        print(f"{'='*60}")
        
        try:
            # Load and prepare dataset
            start_time = time.time()
            dataset_results = _evaluate_single_dataset(
                evaluator=evaluator,
                dataset_path=dataset_path,
                split=split,
                text_column=text_column,
                label_column=label_column,
                max_new_tokens=max_new_tokens,
                system_message=system_message,
                output_dir=output_dir,
                save_cm=save_individual_cms,
                sample_size=sample_size,
            )
            end_time = time.time()

            # Store results
            dataset_results['evaluation_time_seconds'] = round(end_time - start_time, 2)
            all_results[dataset_path] = dataset_results
            
            # Add to summary
            summary_lines.extend(_format_dataset_summary(dataset_path, dataset_results))
            
            print(f"‚úÖ Completed {dataset_path} in {dataset_results['evaluation_time_seconds']:.1f}s")
            
        except Exception as e:
            error_msg = f"‚ùå Error processing {dataset_path}: {str(e)}"
            print(error_msg)
            summary_lines.extend([
                f"Dataset: {dataset_path}",
                f"Status: FAILED",
                f"Error: {str(e)}",
                "-" * 40,
                ""
            ])
            all_results[dataset_path] = {"status": "failed", "error": str(e)}
    
    # Generate comparison visualization
    if len([r for r in all_results.values() if r.get('status') != 'failed']) > 1:
        _create_comparison_chart(all_results, output_dir)
    
    # Save summary report
    if save_summary_report:
        _save_summary_report(summary_lines, all_results, output_dir)
    
    print(f"\nüéâ Evaluation complete! Results saved to: {output_dir}")
    return all_results


def _evaluate_single_dataset(
    evaluator, dataset_path: str, split: str, text_column: str, 
    label_column: str, max_new_tokens: int, system_message: str,
    output_dir: str, save_cm: bool, sample_size: Optional[int] = None 
) -> Dict[str, Any]:
    """Evaluate model on a single dataset."""
    
    from datasets import load_dataset
    
    # Load dataset
    print(f"üì• Loading {dataset_path}...")
    try:
        dataset = load_dataset(dataset_path)

        # # uncomment to use a small sample for testing and debugging purposes
        # eval_dataset = dataset[split].select(range(10))
        # print("Remember to uncomment 'eval_dataset = dataset[split]' to use the full dataset.")
        # # # uncomment to use the full split for evaluation
        # # eval_dataset = dataset[split]

        # sample_size = None  # or set to e.g. 10 for debugging
        if sample_size is not None:
            eval_dataset = dataset[split].select(range(min(sample_size, len(dataset[split]))))
        else:
            eval_dataset = dataset[split]

        print(f"‚úÖ Loaded {len(eval_dataset)} samples")
    except Exception as e:
        raise Exception(f"Failed to load dataset: {e}")
    
    # Check if columns exist
    if text_column not in eval_dataset.column_names:
        # Try common alternatives
        possible_text_cols = ['text', 'cleaned_text', 'content', 'message', 'tweet']
        found_col = None
        for col in possible_text_cols:
            if col in eval_dataset.column_names:
                found_col = col
                break
        
        if found_col:
            print(f"‚ö†Ô∏è  '{text_column}' not found, using '{found_col}' instead")
            text_column = found_col
        else:
            raise Exception(f"Text column not found. Available: {eval_dataset.column_names}")
    
    if label_column not in eval_dataset.column_names:
        possible_label_cols = ['label', 'labels', 'target', 'class', 'category']
        found_col = None
        for col in possible_label_cols:
            if col in eval_dataset.column_names:
                found_col = col
                break
        
        if found_col:
            print(f"‚ö†Ô∏è  '{label_column}' not found, using '{found_col}' instead")
            label_column = found_col
        else:
            raise Exception(f"Label column not found. Available: {eval_dataset.column_names}")
    
    # Convert to ChatML format
    print("üîÑ Converting to ChatML format...")
    
    # def convert_to_chatml(example):
    #     text_content = example[text_column]
    #     user_prompt = f"Classify this text: {text_content}"
        
    #     # Handle different label formats
    #     label = example[label_column]
    #     if isinstance(label, (int, float)):
    #         label = "toxic" if label == 1 else "normal"
    #     elif isinstance(label, str):
    #         # Ensure consistent format
    #         label = label.lower().strip()
    #         if label in ['1', 'toxic', 'bully', 'cyberbullying', 'hate']:
    #             label = "toxic"
    #         else:
    #             label = "normal"
        
    #     return {
    #         "conversations": [
    #             {"role": "system", "content": system_message},
    #             {"role": "user", "content": user_prompt},
    #             {"role": "assistant", "content": label}
    #         ],
    #         "raw_text": text_content
    #     }
    
    # map: convert to ChatML
    def convert_to_chatml(example):
        text_content = example.get(text_column, example.get("text", ""))

        # robust label handling
        raw_label = example.get(label_column)
        if raw_label is None:
            label = "normal"
        else:
            if isinstance(raw_label, (int, float)):
                label = "toxic" if int(raw_label) == 1 else "normal"
            else:
                normalized = str(raw_label).lower().strip()
                # consider numeric strings as well
                if normalized in {'1', 'toxic', 'bully', 'cyberbullying', 'hate'} or 'toxic' in normalized:
                    label = "toxic"
                else:
                    label = "normal"

        # user_prompt = f"Classify this text: {text_content}"
        user_prompt = f"Classify the text as 'toxic' or 'normal'. Output only one word. Text: '{text_content}'"
        return {
            "conversations": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": label}
            ],
            "raw_text": text_content
        }


    eval_dataset = eval_dataset.map(convert_to_chatml)
    
    # Run evaluation
    print("üß† Running model inference...")
    predictions = []
    references = []
    
    evaluator.model.eval()

    
    # Prepare log file
    log_file = os.path.join(output_dir, "invalid_predictions.log")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write(f"Invalid Predictions Log - Dataset: {dataset_path}\n")
        f.write("=" * 80 + "\n\n")

    def log_invalid(sample, true_label, pred_text):
        """Append invalid predictions to a text file."""
        with open(log_file, "a", encoding="utf-8") as f:
            f.write("\n#" + "-" * 76 + "#\n")
            f.write(f"User prompt: str({sample['conversations'][0]['content']})\n")
            f.write(f"True label: str({true_label})\n")
            f.write(f"Predicted raw: str({pred_text})\n")
            f.write("#" + "-" * 76 + "#")
    
    #     ## Original code
    # for i, sample in enumerate(eval_dataset):
    #     if i % 200 == 0:
    #         print(f"  Processed {i}/{len(eval_dataset)} samples ({i/len(eval_dataset)*100:.1f}%)")
        
    #     ## Original code
    #     input_text = evaluator._format_input(sample["conversations"])
    #     pred_text = evaluator._generate_prediction(input_text, max_new_tokens, do_sample=False)
        
    #     # Convert to binary
    #     pred_label = 1 if pred_text.lower() == "toxic" else 0
    #     true_label = 1 if sample["conversations"][2]["content"].lower() == "toxic" else 0    
    # print(f"  Processed {len(eval_dataset)}/{len(eval_dataset)} samples (100.0%)")
    
    # tqdm will auto-handle progress display
    exs = ""
    for i, sample in enumerate(tqdm(eval_dataset, desc="Evaluating", unit="sample", total=len(eval_dataset))):
        try:
            input_text = evaluator._format_input(sample["conversations"])
            true_text = sample["conversations"][2]["content"]
            pred_text = evaluator._generate_prediction(input_text, max_new_tokens, do_sample=False)

            # üëá Debug: print raw outputs for the first few
            if (i < 5) and (true_text == "toxic"):
                # print(f"[{i}] INPUT: {input_text.strip()}")
                # print(f"[{i}] RAW PREDICTION: {pred_text.strip()}\n")
                exs = exs + 'Input: '+input_text + '\nPred label: '+pred_text + '\nTrue label: '+true_text + '\n'

            normalized_pred = pred_text.lower().strip()

            # Classify
            # if normalized_pred.startswith("toxic"):
            #     pred_label = 1
            # elif normalized_pred.startswith("normal"):
            #     pred_label = 0
            # else:
            #     # Invalid prediction ‚Üí log and skip
            #     log_invalid(input_text, true_text, pred_text)
            #     continue  # skip appending to keep predictions/references aligned

            if "toxic" in normalized_pred.split():
                pred_label = 1
            elif "normal" in normalized_pred.split():
                pred_label = 0
            else:
                log_invalid(input_text, true_label, pred_text)
                continue


            # Ground truth
            true_label = 1 if true_text == "toxic" else 0

            predictions.append(pred_label)
            references.append(true_label)

        except Exception as e:
            # Catch runtime errors (generation, formatting, etc.)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write("-----------------------\n")
                f.write(f"Error processing sample: {str(e)}\n")
                f.write("-----------------------\n\n")
            continue
    print("Sample input and pred:\n" , exs)

    # Calculate metrics
    print("üìä Calculating metrics...")
    accuracy_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")

    # Safeguard against invalid prediction and references
    if not predictions:
        raise ValueError("Predictions list is empty")
    if not references:
        raise ValueError("References list is empty")
    if len(predictions) != len(references):
        raise ValueError(
            f"Length mismatch: predictions={len(predictions)}, references={len(references)}"
        )    
    
    from collections import Counter
    pred_count = Counter(predictions)
    print(f"\nLabel distribution in predictions: {pred_count}\n")

    # Calculate the accuracy and f1 scores
    accuracy = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"]
    # f1 = f1_metric.compute(predictions=predictions, references=references)["f1"]
    f1 = f1_metric.compute(predictions=predictions, references=references, average="weighted")["f1"]
    
    # Classification report
    report = classification_report(
        references, predictions, 
        target_names=['normal', 'toxic'], 
        digits=4,
        output_dict=True
    )
    
    # Save confusion matrix if requested
    cm_path = None
    if save_cm:
        dataset_name = dataset_path.split('/')[-1]  # Get dataset name from path
        cm_path = _save_confusion_matrix(references, predictions, output_dir, dataset_name)
    
    return {
        "status": "success",
        "dataset_size": len(eval_dataset),
        "accuracy": accuracy,
        "f1_score": f1,
        "classification_report": report,
        "confusion_matrix_path": cm_path,
        "text_column_used": text_column,
        "label_column_used": label_column,
        "pred_count": pred_count
    }


def _save_confusion_matrix(references: List[int], predictions: List[int], 
                          output_dir: str, dataset_name: str) -> str:
    """Save confusion matrix for a dataset."""
    
    cm = confusion_matrix(references, predictions)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm, 
        annot=True, 
        fmt="d", 
        cmap="Blues", 
        xticklabels=['normal', 'toxic'], 
        yticklabels=['normal', 'toxic'], 
        cbar=False,
        square=True,
        annot_kws={"fontsize": 14}
    )
    plt.xlabel("Predicted Label", fontsize=10)
    plt.ylabel("True Label", fontsize=10)
    plt.title(f"Gemma3 CM - {dataset_name}", fontsize=14, pad=20)
    plt.tight_layout()
    
    # Save with dataset name
    safe_name = dataset_name.replace('/', '_').replace(' ', '_')
    cm_path = os.path.join(output_dir, f"gemma3_cm_{safe_name}.png")
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.close()  # Close to avoid memory issues
    
    return cm_path


def _format_dataset_summary(dataset_path: str, results: Dict[str, Any]) -> List[str]:
    """Format summary for a single dataset."""
    
    if results.get('status') == 'failed':
        return [
            f"Dataset: {dataset_path}",
            f"Status: FAILED",
            f"Error: {results.get('error', 'Unknown error')}",
            "-" * 40,
            ""
        ]
    
    lines = [
        f"Dataset: {dataset_path}",
        f"Samples: {results['dataset_size']:,}",
        f"Accuracy: {results['accuracy']:.4f}",
        f"Weighted F1 Score: {results['f1_score']:.4f}",
        f"Evaluation Time: {results['evaluation_time_seconds']:.1f}s",
    ]
    
    # Add detailed metrics from classification report
    report = results['classification_report']
    lines.extend([
        "",
        "Detailed Metrics:",
        f"  Normal - Precision: {report['normal']['precision']:.4f}, Recall: {report['normal']['recall']:.4f}, F1: {report['normal']['f1-score']:.4f}",
        f"  Toxic  - Precision: {report['toxic']['precision']:.4f}, Recall: {report['toxic']['recall']:.4f}, F1: {report['toxic']['f1-score']:.4f}",
        # f"  Normal - Precision: {report['0']['precision']:.4f}, Recall: {report['0']['recall']:.4f}, F1: {report['0']['f1-score']:.4f}",
        # f"  Toxic  - Precision: {report['1']['precision']:.4f}, Recall: {report['1']['recall']:.4f}, F1: {report['1']['f1-score']:.4f}",
        f"  Macro Avg - Precision: {report['macro avg']['precision']:.4f}, Recall: {report['macro avg']['recall']:.4f}, F1: {report['macro avg']['f1-score']:.4f}",
        ""
    ])
    
    if results.get('confusion_matrix_path'):
        lines.append(f"Confusion Matrix: {results['confusion_matrix_path']}")
    
    lines.extend(["-" * 40, ""])
    return lines


def _create_comparison_chart(all_results: Dict[str, Dict], output_dir: str):
    """Create comparison chart across datasets."""
    # Extract successful results
    successful_results = {k: v for k, v in all_results.items() if v.get('status') == 'success'}
    
    if len(successful_results) < 2:
        return
    
    dataset_names = [path.split('/')[-1] for path in successful_results.keys()]
    accuracies = [results['accuracy'] for results in successful_results.values()]
    f1_scores = [results['f1_score'] for results in successful_results.values()]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Accuracy comparison
    bars1 = ax1.bar(dataset_names, accuracies, color='skyblue', alpha=0.8)
    ax1.set_title('Gemma3 - Accuracy Comparison Across Datasets', fontsize=14)
    ax1.set_ylabel('Accuracy', fontsize=12)
    ax1.set_ylim(0, 1)
    ax1.tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, acc in zip(bars1, accuracies):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{acc:.3f}', ha='center', va='bottom')
    
    # F1 Score comparison
    bars2 = ax2.bar(dataset_names, f1_scores, color='lightcoral', alpha=0.8)
    ax2.set_title('F1 Score Comparison Across Datasets', fontsize=14)
    ax2.set_ylabel('F1 Score', fontsize=12)
    ax2.set_ylim(0, 1)
    ax2.tick_params(axis='x', rotation=45)
    
    # Add value labels on bars
    for bar, f1 in zip(bars2, f1_scores):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{f1:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    comparison_path = os.path.join(output_dir, "performance_comparison.png")
    plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
    # plt.show()
    plt.close()
    
    print(f"üìä Comparison chart saved to: {comparison_path}")


def _save_summary_report(summary_lines: List[str], all_results: Dict[str, Dict], output_dir: str):
    """Save comprehensive summary report to text file."""
    
    # Add overall summary statistics
    successful_results = [r for r in all_results.values() if r.get('status') == 'success']
    failed_count = len(all_results) - len(successful_results)
    
    if successful_results:
        avg_accuracy = sum(r['accuracy'] for r in successful_results) / len(successful_results)
        avg_f1 = sum(r['f1_score'] for r in successful_results) / len(successful_results)
        total_samples = sum(r['dataset_size'] for r in successful_results)
        
        summary_lines.extend([
            "",
            "OVERALL SUMMARY:",
            f"‚úÖ Successful evaluations: {len(successful_results)}",
            f"‚ùå Failed evaluations: {failed_count}",
            f"üìä Average Accuracy: {avg_accuracy:.4f}",
            f"üìä Average F1 Score: {avg_f1:.4f}",
            f"üìà Total samples evaluated: {total_samples:,}",
            "",
            "="*80,
            "DETAILED RESULTS:",
            "="*80,
            ""
        ])
    
    # Save to file
    report_path = os.path.join(output_dir, "evaluation_summary.txt")
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(summary_lines))
    
    print(f"üìÑ Summary report saved to: {report_path}")


# Example usage function
def run_multi_dataset_evaluation_example():
    """Example of how to use the multi-dataset evaluation function."""
    
    # Example dataset paths
    dataset_paths = []
    names = ['stormfront', 'tdavidson', 'olid', 'gab', 'hatexplain']

    for name in names:
        dataset_paths.append(f"cike-dev/{name}")

    # model_id = "cike-dev/GemmaToxicClassifier"
    model_id = "Gemma3ToxicTextClassifier"
    
    # Initialize your evaluator (assuming you have one)
    evaluator = ModelEvaluator(model_id)
    
    # Run evaluation
    results = evaluate_multiple_datasets(
        evaluator=evaluator,
        dataset_paths=dataset_paths,
        output_dir="./multi_dataset_results/",
        split="test",
        sample_size=2000,
        text_column="text",  # Adjust as needed
        label_column="label"  # Adjust as needed
    )
    
    return results
    # print("Example function defined. Uncomment and modify the code above to run.")

In [10]:
evval = run_multi_dataset_evaluation_example()

Loading model from Gemma3ToxicTextClassifier...
==((====))==  Unsloth 2025.9.4: Fast Gemma3_Text patching. Transformers: 4.56.1.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.278 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Model loaded successfully!

üöÄ Starting evaluation on 5 datasets...
üìÅ Results will be saved to: ./multi_dataset_results/

üìä DATASET 1/5: cike-dev/stormfront
üì• Loading cike-dev/stormfront...
‚úÖ Loaded 2000 samples
‚ö†Ô∏è  'text' not found, using 'cleaned_text' instead
üîÑ Converting to ChatML format...


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

üß† Running model inference...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [06:34<00:00,  5.07sample/s]


Sample input and pred:
 Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'the negro is generally terrified of whites , the second you stand up to them they usally back away , thats unless they re carrying a gun at which point the savages are rarely reluctant to use .'<end_of_turn>
<start_of_turn>model

Pred label: toxic
True label: toxic

üìä Calculating metrics...

Label distribution in predictions: Counter({0: 1232, 1: 603})

‚úÖ Completed cike-dev/stormfront in 396.9s

üìä DATASET 2/5: cike-dev/tdavidson
üì• Loading cike-dev/tdavidson...
‚úÖ Loaded 2000 samples
‚ö†Ô∏è  'text' not found, using 'cleaned_text' instead
üîÑ Converting to ChatML format...


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

üß† Running model inference...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [07:52<00:00,  4.23sample/s]


Sample input and pred:
 Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'im happy for the bung holes and negativity desperate guys spew on here ;a good example of what not to be and what i don;'<end_of_turn>
<start_of_turn>model

Pred label: normal
True label: toxic
Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'smh u niggers and ur music knowledge'<end_of_turn>
<start_of_turn>model

Pred label: toxic
True label: toxic

üìä Calculating metrics...

Label distribution in predictions: Counter({1: 1030, 0: 687})

‚úÖ Completed cike-dev/tdavidson in 474.6s

üìä DATASET 3/5: cike-dev/olid
üì• Loading cike-dev/olid...
‚úÖ Loaded 2000 samples
‚ö†Ô∏è  'text' not found, using 'cleaned_text' instead
üîÑ Conver

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

üß† Running model inference...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [11:54<00:00,  2.80sample/s]


Sample input and pred:
 Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'it makes my skin crawl just looking at this woman  theresa may is a liability to our nation and  and must be removed asap'<end_of_turn>
<start_of_turn>model

Pred label: normal
True label: toxic
Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'how? you cant just say wrong with no evidence for why it is. you said you agree with sargon im pretty much every subject but gun control and healthcare. thats not what a centrist believes.'<end_of_turn>
<start_of_turn>model

Pred label: normal. gun control and healthcare are the most common
True label: toxic

üìä Calculating metrics...

Label distribution in predictions: Counter({0: 737, 1: 43

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

üß† Running model inference...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [12:19<00:00,  2.70sample/s]


Sample input and pred:
 Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'two bits of feedback. first, that catholic church shill that calls in is getting annoying as fuck. second, when people ask you to let jews talk when they call in, tell them that they can turn on any msm if they are so desperate to listen to jews. your show is an oasis of jew free content, please dont let them speak there.'<end_of_turn>
<start_of_turn>model

Pred label: normal. the jews are not going to listen
True label: toxic
Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'international womens day today eh, hope the ladies are doing their bit to level up the 95 of workplace deaths and 66 suicide rate. i worry for boys growing up no

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

üß† Running model inference...


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [05:52<00:00,  5.67sample/s]


Sample input and pred:
 Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'low iq mad maxine living up to her dumb nigger name'<end_of_turn>
<start_of_turn>model

Pred label: toxic
True label: toxic
Input: <start_of_turn>user
You are a binary classifier for cyberbullying speech. Respond only with 'toxic' or 'normal'.

Classify the text as 'toxic' or 'normal'. Output only one word. Text: 'another brain fart from morbious stone no actually you just displayed you are some king of retard boomer who can not read rotfl stop sucking alex jewns dick must have just graduated from kinder garden lowest iq list lmao'<end_of_turn>
<start_of_turn>model

Pred label: toxic
True label: toxic

üìä Calculating metrics...

Label distribution in predictions: Counter({1: 1256, 0: 645})

‚úÖ Completed cike-dev/hatexplain in 354.8s
üìä Comparison chart saved to: ./mult