In [None]:



Few-shot experiment

#!/usr/bin/env python3
"""
Victorian Era Authorship Attribution with Llama3 Prompting
"""

import os
import re
import random
import json
import time
import logging
import subprocess
from pathlib import Path
import requests
from datetime import datetime
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# -------------------- Config & Seeding --------------------
# Experiment-wide constants. TEST_SIZE, EVAL_SAMPLES, TRAIN_CSV and AUTHOR_LIST
# are not read by the helper functions defined below; presumably the driver
# code later in the file uses them for the train/test split, the number of
# evaluated fragments and the input paths — TODO confirm.
SEED = 42  # seeds Python's `random` (few-shot example sampling) and NumPy below
TEST_SIZE = 0.2
EVAL_SAMPLES = 100
TRAIN_CSV = 'Gungor_2018_VictorianAuthorAttribution_data-train.csv'
AUTHOR_LIST = 'author_list.txt'  # optional "ID Name" file parsed by load_author_names
OLLAMA_API = "http://localhost:11434/api"  # base URL of the local Ollama REST API
OLLAMA_MODEL = "llama3"  # default model tag passed to call_ollama

# Log to a timestamped file under logs/ and to the console at the same time.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = f"{log_dir}/victorian_attribution_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

# NOTE: logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. when this cell is re-run in the same kernel). In that case
# the FileHandler built here still creates an empty log file, but output keeps
# going to the previously configured handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
# Named logger used by every function in this script.
logger = logging.getLogger("victorian_attribution")

# All CSV/JSON outputs and plots are written under this directory.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Seed the local RNGs. This makes few-shot example selection reproducible; it
# does not control sampling inside the Ollama server, which runs as a separate
# process reached over HTTP.
random.seed(SEED)
np.random.seed(SEED)

# -------------------- Ollama API Functions --------------------
def check_ollama_running():
    """Report whether the Ollama server is reachable.

    Sends ``GET {OLLAMA_API}/tags`` with a 5-second timeout and returns True
    only when the server answers with HTTP 200. Any exception while building
    or sending the request (connection refused, timeout, ...) is swallowed and
    reported as False.
    """
    try:
        reply = requests.get(f"{OLLAMA_API}/tags", timeout=5)
        is_up = reply.status_code == 200
    except Exception:
        is_up = False
    return is_up

def call_ollama(prompt, model=OLLAMA_MODEL, max_retries=3, timeout=30):
    """Send one non-streaming generation request to Ollama, with retries.

    Args:
        prompt: Full prompt text.
        model: Ollama model tag (defaults to ``OLLAMA_MODEL``).
        max_retries: Total number of attempts before giving up.
        timeout: Per-request timeout in seconds.

    Returns:
        The stripped ``response`` field of the server's JSON reply. If every
        attempt fails, returns the literal string ``"Error"``; callers pass
        this to ``extract_author_id``, which yields no ID for it.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        # With streaming off, /api/generate returns a single JSON object
        # instead of newline-delimited chunks, so response.json() can parse it.
        "stream": False,
        # Ollama reads sampling parameters from "options". Top-level
        # "temperature"/"max_tokens" keys are ignored, which left the model
        # sampling at its default temperature with no length cap.
        # num_predict is Ollama's name for the maximum number of generated tokens.
        "options": {
            "temperature": 0.0,
            "num_predict": 30,
        },
    }

    for attempt in range(max_retries):
        try:
            response = requests.post(
                f"{OLLAMA_API}/generate",
                json=payload,
                timeout=timeout,
            )

            if response.status_code == 200:
                try:
                    return response.json().get("response", "").strip()
                except (ValueError, KeyError) as json_err:
                    logger.error(f"JSON parsing error: {json_err}")
                    logger.error(f"Response content: {response.text[:100]}...")
            else:
                logger.error(f"API error: {response.status_code} - {response.text[:100]}")
        except Exception as e:
            logger.error(f"Request error: {e}")

        if attempt < max_retries - 1:
            logger.info(f"Retrying... Attempt {attempt+1}/{max_retries}")
            # Linear backoff: 3s, 6s, ... between attempts.
            time.sleep(3 * (attempt + 1))

    logger.error(f"Failed after {max_retries} attempts")
    return "Error"

# -------------------- Load Author Names --------------------
def load_author_names(author_list_path):
    """Read an ``ID Name`` author mapping, one author per line.

    Each non-blank line is split at its first run of whitespace. The left part
    must parse as an int and becomes the key; the remainder is the name
    (e.g. ``"1 Charles Dickens"`` -> ``{1: "Charles Dickens"}``). Malformed
    lines are logged as warnings and skipped.

    Returns:
        dict mapping int author ID -> name. The dict is empty if the file does
        not exist, and partial if an unexpected error occurs while reading
        (the error is logged).
    """
    logger.info(f"Loading author names from {author_list_path}")

    names_by_id = {}

    if not os.path.exists(author_list_path):
        logger.warning(f"Author list file {author_list_path} not found. Will use author IDs only.")
        return names_by_id

    try:
        with open(author_list_path, 'r', encoding='utf-8') as handle:
            for raw in handle:
                entry = raw.strip()
                if not entry:
                    continue
                pieces = entry.split(maxsplit=1)
                if len(pieces) != 2:
                    logger.warning(f"Invalid author list format: {entry}")
                    continue
                id_text, name = pieces
                try:
                    names_by_id[int(id_text)] = name
                except ValueError:
                    logger.warning(f"Invalid author ID format: {entry}")
        logger.info(f"Loaded {len(names_by_id)} author names")
    except Exception as e:
        logger.error(f"Error loading author names: {e}")

    return names_by_id

# -------------------- Load & Create Special Split --------------------
def load_and_map(csv_path, author_names=None):
    """Load the Victorian authorship CSV and attach contiguous author indices.

    Rows with a missing ``text`` value are dropped. Two columns are added:

    * ``author_idx``: position of the row's integer author ID in the sorted
      list of distinct IDs (0..K-1).
    * ``author_name``: the name from ``author_names``, or
      ``"Unknown Author <id>"`` for IDs it lacks. When no names are given,
      this is the raw ``author`` value as a string.

    Args:
        csv_path: Path to a latin-1 encoded CSV with ``text`` and ``author`` columns.
        author_names: Optional ``{int_id: name}`` mapping.

    Returns:
        ``(df, orig_ids)``, where ``orig_ids[author_idx]`` is the original author ID.

    Raises:
        KeyError: if the CSV lacks a ``text`` or ``author`` column.
    """
    logger.info(f"Loading Victorian author dataset from {csv_path}")
    df = pd.read_csv(csv_path, encoding='latin-1', engine='python')
    required = ('text', 'author')
    if not all(col in df.columns for col in required):
        raise KeyError("CSV must contain 'text' and 'author' columns")

    # Fragments without text cannot be attributed.
    df = df.dropna(subset=['text']).reset_index(drop=True)

    author_int = df['author'].astype(int)
    orig_ids = sorted(author_int.unique())
    df['author_idx'] = author_int.map({orig: pos for pos, orig in enumerate(orig_ids)})

    if not author_names:
        df['author_name'] = df['author'].astype(str)
    else:
        df['author_name'] = author_int.map(lambda aid: author_names.get(aid, f"Unknown Author {aid}"))
        unnamed = [aid for aid in orig_ids if aid not in author_names]
        if unnamed:
            logger.warning(f"Missing author names for IDs: {unnamed}")

    per_author = df.groupby('author').size()
    logger.info(f"Dataset loaded: {len(df)} fragments from {len(orig_ids)} unique authors")
    logger.info(f"Author distribution: Min={per_author.min()}, Max={per_author.max()}, Mean={per_author.mean():.1f}")

    return df, orig_ids

# -------------------- Helper Functions --------------------
# Honorifics stripped from an author's name before taking the surname. The
# leading \b is essential: without it, a trailing "ms" was stripped from
# "Williams" (-> "willia") and "lord" from "Gaylord" (-> "gay").
_HONORIFIC_RE = re.compile(r'\b(?:mr|mrs|ms|dr|sir|lady|lord)\b\.?')


def extract_author_id(response, valid_ids, author_names=None):
    """Map a free-form model answer to one of ``valid_ids``.

    Matching is tried in order:

    1. The whole (stripped) response is exactly a valid ID.
    2. A number captured by one of several phrasing patterns ("author id: 12",
       "the answer is 12", "12 is the author", or a response containing a
       single number) that is a valid ID.
    3. If ``author_names`` is given, a case-insensitive match on an author's
       full name, on "author/writer <surname>", or on a bare surname of at
       least 5 characters. Only authors whose ID is in ``valid_ids`` are
       considered, so a name-only match can never yield an ID outside the
       candidate set.

    Args:
        response: Raw model output. ``None`` or an empty string is treated as
            "no answer".
        valid_ids: Iterable of acceptable author IDs (ints or strings).
        author_names: Optional ``{int_id: name}`` mapping.

    Returns:
        The matched ID as a string, or ``None`` if nothing valid was found.
    """
    valid_ids_str = {str(valid_id) for valid_id in valid_ids}

    logger.debug(f"Extracting author ID from: {response}")

    if not response:
        logger.warning(f"Could not extract a valid author ID from: {response}")
        return None

    text = response.strip()

    # 1. The answer is just an ID.
    if text in valid_ids_str:
        logger.debug(f"Found exact match: {text}")
        return text

    lowered = text.lower()

    # 2. Numbers embedded in common answer phrasings.
    patterns = [
        r"(?:author|id|author id)[:\s]*(\d+)",
        r"(?:the answer is|i think)[:\s]*(\d+)",
        r"(\d+)(?:\s*is the author)",
        r"^[^\d]*(\d+)[^\d]*$",  # the response contains exactly one number
    ]
    for pattern in patterns:
        for match in re.findall(pattern, lowered):
            if match in valid_ids_str:
                logger.debug(f"Extracted ID {match} using pattern: {pattern}")
                return match

    # 3. Fall back to author names.
    if author_names:
        for author_id, author_name in author_names.items():
            # author_list.txt may list authors absent from the candidate set.
            if str(author_id) not in valid_ids_str:
                continue

            full_name = author_name.lower().strip()
            simplified_name = _HONORIFIC_RE.sub('', full_name).strip()
            last_name = simplified_name.split()[-1] if simplified_name else ""

            name_patterns = []
            if full_name:
                name_patterns.append(re.escape(full_name))
            # With an empty surname the "author/writer" pattern would match
            # any response mentioning "author", so only add surname patterns
            # when a surname exists.
            if last_name:
                name_patterns.append(rf"(?:author|writer)[:\s]*{re.escape(last_name)}")
                # A bare surname alone is trusted only if it is reasonably distinctive.
                if len(last_name) >= 5:
                    name_patterns.append(re.escape(last_name))

            for pattern in name_patterns:
                if re.search(pattern, lowered):
                    logger.debug(f"Matched author name: {author_name}")
                    return str(author_id)

    logger.warning(f"Could not extract a valid author ID from: {response}")
    return None

# -------------------- Prepare Examples for Few-Shot --------------------
def prepare_high_quality_examples(train_df, num_authors, orig_ids, author_names=None):
    """Choose one few-shot example per author.

    For each author index in ``range(num_authors)``, the author's longest
    training fragment is selected (first one on ties) and truncated to its
    first 150 whitespace-separated words.

    Args:
        train_df: DataFrame with ``author_idx`` and ``text`` columns.
        num_authors: Number of contiguous author indices to scan.
        orig_ids: Original author IDs, indexed by ``author_idx``.
        author_names: Optional ``{int_id: name}``, used only in debug logs.

    Returns:
        dict mapping ``author_idx`` -> example text. Authors with no training
        rows are omitted. The same examples, keyed by ``str(original_id)``,
        are written to ``results/few_shot_examples.json``.
    """
    logger.info("Preparing carefully selected, representative examples for few-shot prompting")

    examples = {}
    for idx in range(num_authors):
        candidates = train_df.loc[train_df['author_idx'] == idx, 'text'].reset_index(drop=True)
        if candidates.empty:
            continue

        # The longest fragment carries the most stylistic signal.
        longest = candidates.iloc[candidates.str.len().argmax()]
        examples[idx] = ' '.join(longest.split()[:150])

        author_id = orig_ids[idx]
        if author_names:
            label = f"{author_names.get(author_id, '')} (ID: {author_id})"
        else:
            label = f"Author ID {author_id}"
        logger.debug(f"Selected example for {label} from {len(candidates)} available samples")

    serialisable = {str(orig_ids[idx]): text for idx, text in examples.items()}
    with open(f"{results_dir}/few_shot_examples.json", "w") as f:
        json.dump(serialisable, f, indent=2)

    logger.info(f"Prepared {len(examples)} high-quality examples for few-shot prompting")
    return examples

# -------------------- Evaluation Functions --------------------
def evaluate_zero_shot(texts, true_labels, orig_ids, author_names=None):
    """Attribute each text with a zero-shot prompt (no examples) and score it.

    Args:
        texts: Sequence of text fragments to attribute.
        true_labels: Gold author IDs aligned with ``texts``. Predictions are
            string IDs (``extract_author_id`` returns ``str``), so these are
            expected to be string IDs as well.
        orig_ids: All candidate author IDs, listed in the prompt.
        author_names: Optional ``{int_id: name}`` mapping shown in the prompt.

    Returns:
        ``(predictions, accuracy, results_df)``:

        * ``predictions``: string IDs, or ``"Unknown"`` when no ID could be
          extracted.
        * ``accuracy``: computed over non-"Unknown" predictions only (0 when
          there are none). The stricter overall accuracy, which counts
          "Unknown" as wrong, is logged alongside it.
        * ``results_df``: also saved to ``results/zero_shot_results.csv``.
    """
    logger.info("Evaluating zero-shot performance with Llama3")
    logger.info("This establishes our baseline without any examples or fine-tuning")

    predictions = []
    responses = []
    times = []

    # Candidate list shown to the model: "Author ID <id>: <name>" lines when
    # names are known, otherwise a comma-separated list of IDs.
    if author_names:
        author_list = "\n".join(
            f"Author ID {author_id}: {author_names.get(author_id, 'Unknown Author')}"
            for author_id in orig_ids
        )
    else:
        author_list = ", ".join(str(author_id) for author_id in orig_ids)

    for i, txt in enumerate(tqdm(texts, desc="Zero-shot evaluation")):
        start_time = time.time()

        prompt = (
            f"Task: Victorian Era Authorship Attribution\n\n"
            f"Instructions: Given a text fragment from the Victorian era, identify the author who wrote it. "
            f"Victorian literature spans roughly from 1837 to 1901 during Queen Victoria's reign, "
            f"and includes authors with distinctive writing styles, vocabulary, themes, and sentence structures.\n\n"
            f"Available authors:\n{author_list}\n\n"
            f"Text fragment to analyze:\n\"{txt}\"\n\n"
            f"Based on the writing style, vocabulary choices, sentence structure, and thematic elements, "
            f"identify which author wrote this text fragment. "
            f"Respond with ONLY the numeric author ID."
        )

        response = call_ollama(prompt)
        end_time = time.time()

        author_id = extract_author_id(response, orig_ids, author_names)
        predictions.append(author_id if author_id else "Unknown")
        responses.append(response)
        times.append(end_time - start_time)

        if (i + 1) % 10 == 0 or i == 0:
            logger.info(f"Zero-shot progress: {i+1}/{len(texts)} samples processed")

    valid_preds = [(t, p) for t, p in zip(true_labels, predictions) if p != "Unknown"]
    if valid_preds:
        true_valid = [t for t, _ in valid_preds]
        pred_valid = [p for _, p in valid_preds]
        accuracy = accuracy_score(true_valid, pred_valid)
        f1 = f1_score(true_valid, pred_valid, average='weighted', zero_division=0)
        # `accuracy` ignores abstentions, so a method that often fails to answer
        # can look better than it is. Also report the figure over all samples.
        all_pairs = list(zip(true_labels, predictions))
        overall_accuracy = sum(t == p for t, p in all_pairs) / len(all_pairs)

        logger.info(f"Zero-shot results:")
        logger.info(f"  - Accuracy: {accuracy:.2%} ({len(valid_preds)}/{len(true_labels)} valid predictions)")
        logger.info(f"  - Overall accuracy (Unknown counted as incorrect): {overall_accuracy:.2%}")
        logger.info(f"  - F1 Score (weighted): {f1:.4f}")
        logger.info(f"  - Mean response time: {np.mean(times):.2f} seconds")

        report = classification_report(true_valid, pred_valid, output_dict=True, zero_division=0)
        with open(f"{results_dir}/zero_shot_classification_report.json", "w") as f:
            json.dump(report, f, indent=2)
    else:
        accuracy = 0
        logger.warning("No valid zero-shot predictions")

    zero_shot_results = pd.DataFrame({
        'text': texts,
        'true_label': true_labels,
        'prediction': predictions,
        'response': responses,
        'time': times
    })
    zero_shot_results.to_csv(f"{results_dir}/zero_shot_results.csv", index=False)

    return predictions, accuracy, zero_shot_results

def _select_example_authors(available_authors, n_shots):
    """Randomly pick the author indices whose examples go into one prompt.

    Samples ``n_shots`` authors without replacement when enough exist.
    Otherwise every available author is used once and the remaining slots are
    filled with random picks, so duplicates are possible. Returns an empty list
    when no authors are available. The result is shuffled so that an author's
    position in the prompt carries no signal.
    """
    if available_authors and len(available_authors) >= n_shots:
        chosen = random.sample(available_authors, n_shots)
    else:
        chosen = list(available_authors)
        if available_authors:
            for _ in range(n_shots - len(chosen)):
                chosen.append(random.choice(available_authors))
    random.shuffle(chosen)
    return chosen


def evaluate_few_shot(texts, true_labels, n_shots, train_df, examples, num_authors, orig_ids, author_names=None):
    """Attribute each text with an ``n_shots``-example prompt and score it.

    A fresh random set of example authors is drawn for every text. The true
    author is neither forced in nor excluded, so it appears only by chance.

    Args:
        texts: Sequence of text fragments to attribute.
        true_labels: Gold author IDs aligned positionally with ``texts``. They
            are used as keys into counters keyed by ``str(author_id)``, so they
            are expected to be string IDs.
        n_shots: Number of examples placed in each prompt.
        train_df: Unused; kept for call-site compatibility.
        examples: ``{author_idx: example_text}``, one example per author.
        num_authors: Unused; kept for call-site compatibility.
        orig_ids: Original author IDs, indexed by ``author_idx``.
        author_names: Optional ``{int_id: name}`` mapping shown in the prompt.

    Returns:
        ``(predictions, accuracy, results_df)``:

        * ``predictions``: string IDs, or ``"Unknown"``.
        * ``accuracy``: computed over non-"Unknown" predictions only (0 when
          there are none). Overall accuracy is logged alongside it.
        * ``results_df``: also saved to ``results/<n>_shot_results.csv``.

        Per-author accuracy (where "Unknown" counts as wrong) is saved to
        ``results/<n>_shot_per_author_accuracy.json``. It includes only
        authors that occur in the evaluation sample.
    """
    logger.info(f"Evaluating {n_shots}-shot performance with Llama3")
    logger.info(f"Using {n_shots} carefully selected examples to guide the model")

    # Index labels positionally; `true_labels[i]` on a pandas Series would be a
    # label lookup and pick the wrong row after a shuffled split.
    labels = list(true_labels)

    predictions = []
    responses = []
    times = []
    per_author_correct = {str(author_id): 0 for author_id in orig_ids}
    per_author_total = {str(author_id): 0 for author_id in orig_ids}

    if author_names:
        author_list = "\n".join(
            f"Author ID {author_id}: {author_names.get(author_id, 'Unknown Author')}"
            for author_id in orig_ids
        )
    else:
        author_list = ", ".join(str(author_id) for author_id in orig_ids)

    available_authors = list(examples.keys())

    for i, txt in enumerate(tqdm(texts, desc=f"{n_shots}-shot evaluation")):
        start_time = time.time()
        true_author = labels[i]
        per_author_total[true_author] += 1

        example_texts = []
        for auth_idx in _select_example_authors(available_authors, n_shots):
            if auth_idx not in examples:
                continue
            example_author_id = orig_ids[auth_idx]
            author_display = f"Author ID {example_author_id}"
            if author_names and example_author_id in author_names:
                author_display = f"{author_names[example_author_id]} (ID: {example_author_id})"
            example_texts.append(f"Example from {author_display}:\n\"{examples[auth_idx]}\"\n")
        examples_text = "\n".join(example_texts)

        prompt = (
            f"Task: Victorian Era Authorship Attribution\n\n"
            f"Instructions: Given a text fragment from the Victorian era, identify the author who wrote it. "
            f"Victorian literature spans roughly from 1837 to 1901 during Queen Victoria's reign, "
            f"and includes authors with distinctive writing styles, vocabulary, themes, and sentence structures.\n\n"
            f"Here are some example text fragments with their authors:\n\n"
            f"{examples_text}\n"
            f"Available authors:\n{author_list}\n\n"
            f"Text fragment to analyze:\n\"{txt}\"\n\n"
            f"Based on the writing style, vocabulary choices, sentence structure, and thematic elements, "
            f"identify which author wrote this text fragment. Compare the writing style to the examples provided. "
            f"Respond with ONLY the numeric author ID."
        )

        response = call_ollama(prompt)
        end_time = time.time()

        author_id = extract_author_id(response, orig_ids, author_names)
        prediction = author_id if author_id else "Unknown"
        predictions.append(prediction)
        responses.append(response)
        times.append(end_time - start_time)

        if prediction == true_author:
            per_author_correct[true_author] += 1

        if (i + 1) % 10 == 0 or i == 0:
            logger.info(f"{n_shots}-shot progress: {i+1}/{len(texts)} samples processed")

    valid_preds = [(t, p) for t, p in zip(labels, predictions) if p != "Unknown"]
    if valid_preds:
        true_valid = [t for t, _ in valid_preds]
        pred_valid = [p for _, p in valid_preds]
        accuracy = accuracy_score(true_valid, pred_valid)
        f1 = f1_score(true_valid, pred_valid, average='weighted', zero_division=0)
        overall_accuracy = sum(t == p for t, p in zip(labels, predictions)) / len(predictions)

        # Accuracy is undefined for authors with no evaluation samples. Scoring
        # them as 0 made them look like the worst-performing authors.
        per_author_accuracy = {
            author_key: per_author_correct[author_key] / total
            for author_key, total in per_author_total.items()
            if total > 0
        }

        logger.info(f"{n_shots}-shot results:")
        logger.info(f"  - Accuracy: {accuracy:.2%} ({len(valid_preds)}/{len(true_labels)} valid predictions)")
        logger.info(f"  - Overall accuracy (Unknown counted as incorrect): {overall_accuracy:.2%}")
        logger.info(f"  - F1 Score (weighted): {f1:.4f}")
        logger.info(f"  - Mean response time: {np.mean(times):.2f} seconds")

        if per_author_accuracy:
            best_author_id = max(per_author_accuracy, key=per_author_accuracy.get)
            worst_author_id = min(per_author_accuracy, key=per_author_accuracy.get)

            best_author_name = author_names.get(int(best_author_id)) if author_names else None
            worst_author_name = author_names.get(int(worst_author_id)) if author_names else None

            if best_author_name:
                logger.info(f"  - Author with highest accuracy: {best_author_name} (ID: {best_author_id})")
            else:
                logger.info(f"  - Author with highest accuracy: ID {best_author_id}")

            if worst_author_name:
                logger.info(f"  - Author with lowest accuracy: {worst_author_name} (ID: {worst_author_id})")
            else:
                logger.info(f"  - Author with lowest accuracy: ID {worst_author_id}")

        report = classification_report(true_valid, pred_valid, output_dict=True, zero_division=0)
        with open(f"{results_dir}/{n_shots}_shot_classification_report.json", "w") as f:
            json.dump(report, f, indent=2)

        with open(f"{results_dir}/{n_shots}_shot_per_author_accuracy.json", "w") as f:
            json.dump(per_author_accuracy, f, indent=2)
    else:
        accuracy = 0
        logger.warning(f"No valid {n_shots}-shot predictions")

    few_shot_results = pd.DataFrame({
        'text': texts,
        'true_label': true_labels,
        'prediction': predictions,
        'response': responses,
        'time': times
    })
    few_shot_results.to_csv(f"{results_dir}/{n_shots}_shot_results.csv", index=False)

    return predictions, accuracy, few_shot_results

# -------------------- Results Visualization --------------------
def visualize_results(results, author_names=None):
    """Create enhanced visualizations for method comparison"""
    logger.info("Creating comprehensive visualizations for method comparison")
    
    # Create directory for visualizations
    vis_dir = f"{results_dir}/visualizations"
    os.makedirs(vis_dir, exist_ok=True)
    
    # 1. Method comparison bar chart with error bars
    plt.figure(figsize=(12, 8))
    methods = list(results.keys())
    accuracies = [results[m]['accuracy'] for m in methods]
    
    # Calculate confidence intervals (95%)
    confidence_intervals = []
    for method in methods:
        if 'df' in results[method]:
            df = results[method]['df']
            valid_indices = df['prediction'] != "Unknown"
            if sum(valid_indices) > 0:
                true_valid = df['true_label'][valid_indices].tolist()
                pred_valid = df['prediction'][valid_indices].tolist()
                correct = [1 if t == p else 0 for t, p in zip(true_valid, pred_valid)]
                n = len(correct)
                std_err = np.std(correct, ddof=1) / np.sqrt(n) if n > 1 else 0
                # 95% confidence interval
                ci = 1.96 * std_err
                confidence_intervals.append(ci)
            else:
                confidence_intervals.append(0)
        else:
            confidence_intervals.append(0)
    
    # Create a better bar chart with error bars
    bars = plt.bar(methods, accuracies, yerr=confidence_intervals, 
                   capsize=10, color=plt.cm.viridis(np.linspace(0, 0.8, len(methods))),
                   edgecolor='black', linewidth=1.5, alpha=0.8)
    
    plt.ylim(0, min(1.0, max(accuracies) + 0.15))
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.title('Victorian Author Attribution Methods Comparison (Llama3)', fontsize=16, fontweight='bold')
    plt.ylabel('Accuracy', fontsize=14)
    plt.xlabel('Method', fontsize=14)
    plt.xticks(fontsize=12, rotation=15)
    
    # Add value labels with percentages
    for i, bar in enumerate(bars):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                f"{height:.1%}", ha='center', va='bottom', fontsize=12, fontweight='bold')
    
    # Add note about confidence intervals
    plt.figtext(0.5, 0.01, "Error bars represent 95% confidence intervals", 
                ha="center", fontsize=10, style='italic')
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.savefig(f'{vis_dir}/method_comparison.png', dpi=300)
    plt.close()
    
    # 2. Per-author accuracy comparison
    plt.figure(figsize=(14, 9))
    
    author_methods = {}
    # Collect per-author accuracies for each method
    for method in results.keys():
        method_path = None
        if method.endswith("-shot"):
            n_shots = method.split("-")[0]
            method_path = f"{results_dir}/{n_shots}_shot_per_author_accuracy.json"
        
        if method_path and os.path.exists(method_path):
            with open(method_path, 'r') as f:
                author_acc = json.load(f)
                for author, acc in author_acc.items():
                    if author not in author_methods:
                        author_methods[author] = {}
                    author_methods[author][method] = acc
    
    # Plot per-author comparison
    if author_methods:
        # Sort authors by accuracy in the best method
        author_avg_acc = {author: np.mean(list(methods.values())) 
                         for author, methods in author_methods.items()}
        sorted_authors = sorted(author_avg_acc.items(), key=lambda x: x[1], reverse=True)
        top_authors = [a[0] for a in sorted_authors[:min(10, len(sorted_authors))]]
        
        # Create data for plotting
        plot_data = []
        for author in top_authors:
            # Add author name if available
            author_display = author
            if author_names:
                try:
                    author_id = int(author)
                    if author_id in author_names:
                        author_display = f"{author_names[author_id]}"
                except:
                    pass
                
            for method, acc in author_methods[author].items():
                plot_data.append({
                    'Author': author_display,
                    'Author ID': author,
                    'Method': method,
                    'Accuracy': acc
                })
        
        plot_df = pd.DataFrame(plot_data)
        
        # Create grouped bar chart
        plt.figure(figsize=(14, 10))
        ax = sns.barplot(x='Author', y='Accuracy', hue='Method', data=plot_df, palette='viridis')
        
        plt.title('Per-Author Accuracy Comparison Across Methods', fontsize=16, fontweight='bold')
        plt.ylabel('Accuracy', fontsize=14)
        plt.xlabel('Author', fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.ylim(0, 1.05)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.legend(title='Method', fontsize=12, title_fontsize=13)
        
        # Rotate x-axis labels if there are many
        if len(top_authors) > 5:
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.savefig(f'{vis_dir}/per_author_comparison.png', dpi=300)
        plt.close()
    
    # 3. Enhanced confusion matrices for the best method
    best_method = max(results.items(), key=lambda x: x[1]['accuracy'])[0]
    if 'df' in results[best_method]:
        df = results[best_method]['df']
        valid_indices = df['prediction'] != "Unknown"
        if sum(valid_indices) > 0:
            true_valid = df['true_label'][valid_indices].tolist()
            pred_valid = df['prediction'][valid_indices].tolist()
            
            # Create confusion matrix
            unique_labels = sorted(list(set(true_valid + pred_valid)))
            cm = confusion_matrix(true_valid, pred_valid, labels=unique_labels)
            
            # Normalize confusion matrix
            cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            cm_norm = np.nan_to_num(cm_norm)  # Replace NaN with 0
            
            # Plot enhanced confusion matrix
            plt.figure(figsize=(14, 12))
            
            # Create label mapping with author names if available
            if author_names:
                unique_label_names = []
                for label in unique_labels:
                    try:
                        author_id = int(label)
                        if author_id in author_names:
                            unique_label_names.append(f"{author_names[author_id]}")
                        else:
                            unique_label_names.append(f"ID {label}")
                    except:
                        unique_label_names.append(f"ID {label}")
            else:
                unique_label_names = [f"ID {label}" for label in unique_labels]
            
            # Plot with improved aesthetics
            sns.heatmap(cm_norm, annot=cm, fmt='d', cmap='Blues', 
                       xticklabels=unique_label_names, yticklabels=unique_label_names,
                       linewidths=0.5, cbar_kws={"shrink": 0.8})
            
            plt.title(f'Confusion Matrix - {best_method} Method (Best Performance)', fontsize=16, fontweight='bold')
            plt.ylabel('True Author', fontsize=14)
            plt.xlabel('Predicted Author', fontsize=14)
            plt.xticks(fontsize=10, rotation=45, ha='right')
            plt.yticks(fontsize=10)
            
            # Add annotations for diagonal accuracy
            for i, label in enumerate(unique_labels):
                if i < len(cm) and i < len(cm_norm):
                    accuracy = cm_norm[i, i]
                    plt.text(i+0.5, i+0.2, f"{accuracy:.1%}", ha='center', fontsize=9, color='darkred', fontweight='bold')
            
            plt.tight_layout()
            plt.savefig(f'{vis_dir}/best_method_confusion_matrix.png', dpi=300)
            plt.close()
    
    # 4. Timing comparison with better visualization
    plt.figure(figsize=(12, 7))
    method_times = {m: results[m]['df']['time'].mean() for m in results if 'df' in results[m]}
    
    methods = list(method_times.keys())
    times = list(method_times.values())
    
    # Create bar chart with error bars for standard deviation
    method_stds = {m: results[m]['df']['time'].std() for m in results if 'df' in results[m]}
    std_times = [method_stds[m] for m in methods]
    
    # Use a more visually distinct color palette
    bars = plt.bar(methods, times, yerr=std_times, capsize=10, 
                  color=plt.cm.rocket(np.linspace(0.2, 0.8, len(methods))),
                  edgecolor='black', linewidth=1.5, alpha=0.8)
    
    plt.title('Average Processing Time per Sample', fontsize=16, fontweight='bold')
    plt.ylabel('Time (seconds)', fontsize=14)
    plt.xlabel('Method', fontsize=14)
    plt.xticks(fontsize=12, rotation=15)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Add text labels
    for i, bar in enumerate(bars):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                f"{height:.2f}s", ha='center', fontsize=12, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(f'{vis_dir}/timing_comparison.png', dpi=300)
    plt.close()
    
    # 5. Create a performance vs. efficiency scatter plot
    if len(methods) > 1:
        plt.figure(figsize=(12, 8))
        
        accuracies = [results[m]['accuracy'] for m in methods]
        times = [method_times[m] for m in methods]
        
        # Create scatter plot with sized points based on number of examples
        sizes = []
        for m in methods:
            if m == "Zero-shot":
                sizes.append(100)
            elif "shot" in m:
                n_shots = int(m.split("-")[0])
                sizes.append(100 + n_shots * 60)
        
        # Create scatter plot with custom styling
        plt.scatter(times, accuracies, s=sizes, alpha=0.7, 
                   c=range(len(methods)), cmap='viridis', edgecolor='black', linewidth=1)
        
        # Add method labels to points
        for i, method in enumerate(methods):
            plt.annotate(method, (times[i], accuracies[i]), 
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=11, fontweight='bold')
        
        plt.title('Performance vs. Efficiency Trade-off', fontsize=16, fontweight='bold')
        plt.xlabel('Average Processing Time (seconds)', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.7)
        
        # Add ideal region indicator (top-left corner is best - high accuracy, low time)
        plt.axhline(y=max(accuracies), color='g', linestyle='--', alpha=0.3)
        plt.axvline(x=min(times), color='g', linestyle='--', alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(f'{vis_dir}/performance_vs_efficiency.png', dpi=300)
        plt.close()
    
    # 6. NEW: Create a chord diagram to show confusion between authors
    if 'df' in results[best_method]:
        try:
            # This is optional since it requires the circlify package
            import circlify
            
            df = results[best_method]['df']
            valid_indices = df['prediction'] != "Unknown"
            if sum(valid_indices) > 0:
                true_valid = df['true_label'][valid_indices].tolist()
                pred_valid = df['prediction'][valid_indices].tolist()
                
                # Create confusion matrix for top authors
                unique_labels = sorted(list(set(true_valid + pred_valid)))
                if len(unique_labels) > 15:
                    # Limit to most common authors for readability
                    label_counts = {}
                    for label in unique_labels:
                        label_counts[label] = true_valid.count(label) + pred_valid.count(label)
                    top_labels = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)[:15]
                    unique_labels = [l[0] for l in top_labels]
                
                # Create confusion data
                confusion_data = {}
                for true, pred in zip(true_valid, pred_valid):
                    if true in unique_labels and pred in unique_labels:
                        key = (true, pred)
                        if key not in confusion_data:
                            confusion_data[key] = 0
                        confusion_data[key] += 1
                
                # Create plot data
                if confusion_data:
                    plt.figure(figsize=(14, 14))
                    
                    # Create a circular layout
                    circles = circlify.circlify(
                        [{'id': label, 'datum': true_valid.count(label)} for label in unique_labels],
                        show_enclosure=False
                    )
                    
                    # Plot circles
                    ax = plt.subplot(111, aspect='equal')
                    ax.axis('off')
                    
                    # Create center points dictionary
                    centers = {}
                    for circle in circles:
                        label = circle.ex['id']
                        x, y = circle.circle.center
                        centers[label] = (x, y)
                        r = circle.circle.radius
                        
                        # Get author name if available
                        author_display = label
                        if author_names:
                            try:
                                author_id = int(label)
                                if author_id in author_names:
                                    author_display = f"{author_names[author_id]}"
                            except:
                                pass
                        
                        # Draw circle
                        ax.add_patch(plt.Circle((x, y), r, alpha=0.5, linewidth=2, 
                                              fill=True, edgecolor='black', 
                                              facecolor=plt.cm.tab20(unique_labels.index(label) % 20)))
                        
                        # Add text label
                        plt.annotate(author_display, (x, y), ha='center', va='center', fontsize=10, fontweight='bold')
                    
                    # Draw connections between authors
                    for (true, pred), count in confusion_data.items():
                        if true != pred:  # Skip self-connections
                            # Get centers
                            if true in centers and pred in centers:
                                start = centers[true]
                                end = centers[pred]
                                
                                # Calculate thickness based on count (normalized)
                                max_count = max(confusion_data.values())
                                thickness = 0.5 + 2.0 * (count / max_count)
                                alpha = 0.3 + 0.5 * (count / max_count)
                                
                                # Draw arrow
                                ax.annotate("", xy=end, xytext=start,
                                         arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", 
                                                        linewidth=thickness, alpha=alpha, color='red'))
                    
                    # Set axis limits
                    ax.set_xlim(-1.2, 1.2)
                    ax.set_ylim(-1.2, 1.2)
                    
                    plt.title(f'Author Confusion Network - {best_method} Method', fontsize=16, fontweight='bold')
                    plt.tight_layout()
                    plt.savefig(f'{vis_dir}/author_confusion_network.png', dpi=300)
                    plt.close()
        except ImportError:
            logger.warning("Circlify package not available for chord diagram. Skipping.")
    
    logger.info(f"Created enhanced visualizations in {vis_dir}")

# -------------------- Main Execution --------------------
def _ollama_model_available(model, available_models):
    """Return True if ``model`` is among the model names reported by Ollama.

    Ollama's ``/api/tags`` endpoint lists fully-tagged names such as
    ``"llama3:latest"``, while an untagged name like ``"llama3"`` refers to the
    ``latest`` tag. A plain ``in`` check therefore misses an installed model
    and triggers a needless ``ollama pull`` on every run.

    Args:
        model: Model name as configured (tagged or untagged).
        available_models: Names taken from the ``models[*].name`` field of ``/api/tags``.

    Returns:
        True when the model (or its implicit ``:latest`` tag) is installed.
    """
    candidates = {model} if ':' in model else {model, f"{model}:latest"}
    return any(name in candidates for name in available_models)


def _author_name_for(label, author_names):
    """Map an author-id label to a display name.

    Args:
        label: Author id as produced by the pipeline (normally a digit string;
            predictions may also be e.g. "Unknown"). Non-``str`` values are
            stringified first.
        author_names: Mapping of integer author id -> author name.

    Returns:
        The mapped name; "Unknown Author <id>" for numeric ids missing from the
        mapping; "Unknown" for non-numeric labels.
    """
    label = str(label)
    if label.isdigit():
        return author_names.get(int(label), f"Unknown Author {label}")
    return "Unknown"


def _ensure_ollama_ready():
    """Best-effort check that the Ollama server is up and OLLAMA_MODEL is pulled.

    Never raises: every failure is logged so the rest of the pipeline (data
    splits, result files) still runs even without a working Ollama.
    """
    if not check_ollama_running():
        logger.warning("Ollama is not running. Zero-shot and Few-shot evaluations will fail.")
        logger.info("Please start Ollama with 'ollama serve' in a separate terminal.")

        try:
            start_ollama = input("Do you want to try starting Ollama now? (y/n): ")
        except EOFError:
            # No interactive stdin (e.g. batch job): don't try to start the server ourselves.
            start_ollama = 'n'
        if start_ollama.strip().lower() == 'y':
            try:
                subprocess.Popen(["ollama", "serve"])
                logger.info("Ollama started. Waiting 5 seconds for it to initialize...")
                time.sleep(5)
            except Exception as e:
                logger.error(f"Failed to start Ollama: {e}")
                logger.info("Please start Ollama manually and then continue.")

    # Check if we need to pull the base Llama3 model for Ollama
    try:
        # Timeout so an unresponsive server cannot hang the script indefinitely.
        response = requests.get(f"{OLLAMA_API}/tags", timeout=10)
        if response.status_code == 200:
            available_models = [model["name"] for model in response.json().get("models", [])]
            if not _ollama_model_available(OLLAMA_MODEL, available_models):
                logger.warning(f"Model {OLLAMA_MODEL} not found in Ollama. Pulling it now...")
                try:
                    subprocess.run(["ollama", "pull", OLLAMA_MODEL], check=True)
                    logger.info(f"Successfully pulled {OLLAMA_MODEL}")
                except subprocess.CalledProcessError as e:
                    logger.error(f"Failed to pull {OLLAMA_MODEL}: {e}")
    except Exception as e:
        logger.error(f"Error checking Ollama models: {e}")


def _build_combined_results(eval_texts, eval_true_labels, results, author_names):
    """Build one row per evaluation sample with the true label and each method's prediction.

    Args:
        eval_texts: Evaluation passages.
        eval_true_labels: True author-id labels (strings), aligned with ``eval_texts``.
        results: Mapping of method name -> dict with a ``'predictions'`` list
            aligned with ``eval_texts``.
        author_names: Mapping of integer author id -> name, or a falsy value.

    Returns:
        A DataFrame with columns ``text``, ``true`` and one column per method
        (method name lower-cased, ``-`` replaced by ``_``). When
        ``author_names`` is truthy, ``true_author_name`` and a
        ``<method>_author_name`` column per method are added as well.
    """
    all_results = pd.DataFrame({
        'text': eval_texts,
        'true': eval_true_labels,
    })

    if author_names:
        all_results['true_author_name'] = [
            _author_name_for(label, author_names) for label in eval_true_labels
        ]

    for method, result in results.items():
        method_col = method.replace('-', '_').lower()
        all_results[method_col] = result['predictions']

        # Add author names for predictions where possible
        if author_names:
            all_results[f"{method_col}_author_name"] = [
                _author_name_for(pred, author_names) for pred in result['predictions']
            ]

    return all_results


def _log_summary(results):
    """Log methods ranked by accuracy, the gain of the best over zero-shot, and closing notes."""
    logger.info("\n=== VICTORIAN AUTHOR ATTRIBUTION SUMMARY ===")
    logger.info("Method Comparison:")

    # Print results in descending order of accuracy
    sorted_results = sorted(results.items(), key=lambda x: x[1]['accuracy'], reverse=True)
    for i, (method, result) in enumerate(sorted_results):
        if i == 0:
            logger.info(f"  BEST METHOD: {method} - Accuracy: {result['accuracy']:.2%}")
        else:
            logger.info(f"  {method} - Accuracy: {result['accuracy']:.2%}")

    # Compute improvement over zero-shot (if available)
    if 'Zero-shot' in results and len(sorted_results) > 1:
        zero_shot_acc = results['Zero-shot']['accuracy']
        best_method, best_result = sorted_results[0]

        if best_method != 'Zero-shot':
            if zero_shot_acc > 0:
                improvement = (best_result['accuracy'] - zero_shot_acc) / zero_shot_acc
                logger.info(f"\nImprovement of best method ({best_method}) over zero-shot baseline: {improvement:.2%}")
            else:
                # Relative improvement is undefined against a 0% baseline; report the absolute gain instead.
                logger.info(
                    f"\nImprovement of best method ({best_method}) over zero-shot baseline: "
                    f"undefined (zero-shot accuracy is 0%); absolute gain {best_result['accuracy']:.2%}"
                )

    # Performance analysis
    logger.info("\nPERFORMANCE ANALYSIS:")
    logger.info("- The prompting baseline approach demonstrates how well Llama3 can identify Victorian authors without any training")
    logger.info("- Using author names provides more context and potentially improves model performance")
    logger.info("- This aligns with the professor's recommendation to try a prompting baseline with a smaller LLM")

    # Compare with traditional ML approaches (if available)
    logger.info("\nCOMPARISON WITH TRADITIONAL ML APPROACHES:")
    logger.info("- Traditional ML approaches (like in the LSTM attribution system) typically require extensive feature engineering")
    logger.info("- This prompting baseline provides a simpler alternative that requires less technical setup")
    logger.info("- The effectiveness of the prompting approach depends on the quality of examples and prompt engineering")
    logger.info("- Compare these results with your LSTM-based system to see the trade-offs between approaches")

    logger.info("\nRECOMMENDATIONS:")
    logger.info("- For best results with minimal effort, the few-shot prompting approach offers a good balance")
    logger.info("- Try experimenting with different example selection strategies to improve few-shot performance")
    logger.info("- Consider exploring hybrid approaches that combine LLM prompting with traditional ML features")


def main():
    """Run the Llama3 prompting baseline end to end.

    The steps are:

    1. Check that Ollama is available and the model is installed.
    2. Load author names and the training CSV.
    3. Make a stratified train/test split and sample up to EVAL_SAMPLES test
       rows for evaluation.
    4. Save the splits to ``results_dir``.
    5. Run zero-shot and 1/3/5-shot evaluations.
    6. Plot the results.
    7. Write ``all_results.csv``.
    8. Log a ranked summary.

    If an evaluation raises, the methods that already finished are still
    plotted and saved.
    """
    logger.info("======== VICTORIAN AUTHOR ATTRIBUTION WITH LLAMA3 PROMPTING ========")
    logger.info("Following professor's recommendation to try a prompting baseline")
    logger.info("This implementation ONLY tests zero-shot and few-shot prompting (NO training)")

    _ensure_ollama_ready()

    # Load author names if available
    author_names = load_author_names(AUTHOR_LIST)

    # Load data with author names
    logger.info("Starting Victorian Author Attribution with Llama3 Prompting Baseline")
    logger.info("Following professor's recommendation to establish a prompting baseline with Llama3")
    df_all, orig_ids = load_and_map(TRAIN_CSV, author_names)
    num_authors = len(orig_ids)

    # Split into training and test sets
    train_df, test_df = train_test_split(
        df_all, test_size=TEST_SIZE,
        stratify=df_all['author_idx'], random_state=SEED
    )

    logger.info(f"Split sizes: Training: {len(train_df)}, Test: {len(test_df)}")

    # Sample evaluation set from the test data
    eval_df = test_df.sample(n=min(EVAL_SAMPLES, len(test_df)), random_state=SEED).reset_index(drop=True)
    eval_texts = eval_df['text'].tolist()
    eval_true_idxs = eval_df['author_idx'].tolist()
    eval_true_labels = [str(orig_ids[idx]) for idx in eval_true_idxs]

    logger.info(f"Created evaluation set with {len(eval_texts)} samples")

    # Prepare author name information for evaluation
    if author_names:
        eval_author_names = [_author_name_for(label, author_names) for label in eval_true_labels]
        logger.info(f"Using author names from {AUTHOR_LIST} for enhanced context")

        # Print sample of authors in evaluation set
        sample_authors = set(eval_author_names[:min(10, len(eval_author_names))])
        logger.info(f"Sample authors in evaluation set: {', '.join(sample_authors)}")

    # Save the data splits for reproducibility
    logger.info("Saving data splits for future reproducibility")
    train_df.to_csv(f"{results_dir}/train.csv", index=False)
    test_df.to_csv(f"{results_dir}/test.csv", index=False)
    eval_df.to_csv(f"{results_dir}/eval.csv", index=False)

    # Prepare examples for few-shot with author names
    examples = prepare_high_quality_examples(train_df, num_authors, orig_ids, author_names)

    # Conduct zero-shot and few-shot evaluations with base Llama3
    results = {}

    try:
        # 1. Zero-shot evaluation
        logger.info("\n=== ZERO-SHOT EVALUATION ===")
        logger.info("Testing how well Llama3 can identify authors without any examples")
        zs_preds, zs_acc, zs_df = evaluate_zero_shot(eval_texts, eval_true_labels, orig_ids, author_names)
        results['Zero-shot'] = {'predictions': zs_preds, 'accuracy': zs_acc, 'df': zs_df}

        # 2. Few-shot evaluations with different numbers of examples
        shot_counts = [1, 3, 5]
        for n in shot_counts:
            logger.info(f"\n=== {n}-SHOT EVALUATION ===")
            logger.info(f"Testing if providing {n} examples improves performance")
            fs_preds, fs_acc, fs_df = evaluate_few_shot(
                eval_texts, eval_true_labels, n, train_df,
                examples, num_authors, orig_ids, author_names
            )
            results[f"{n}-shot"] = {'predictions': fs_preds, 'accuracy': fs_acc, 'df': fs_df}
    except Exception as e:
        # Keep whatever methods finished so partial results are still plotted and saved;
        # logger.exception records the traceback for debugging.
        logger.exception(f"Error during prompt-based evaluations: {e}")

    # Create comprehensive visualizations (nothing to plot if every evaluation failed)
    if results:
        visualize_results(results, author_names)
    else:
        logger.warning("No evaluation results were produced; skipping visualizations.")

    # Save combined results
    all_results = _build_combined_results(eval_texts, eval_true_labels, results, author_names)
    all_results.to_csv(f'{results_dir}/all_results.csv', index=False)
    logger.info(f"Saved detailed results to '{results_dir}/all_results.csv'")

    # Final summary and comparison
    _log_summary(results)

    logger.info("\n=== VICTORIAN AUTHOR ATTRIBUTION COMPLETE ===")

if __name__ == "__main__":
    # Execute the experiment only when run as a script, not when imported.
    main()

