In [None]:
# Mount Google Drive at /content/drive so the preprocessed datasets under
# /content/drive/MyDrive/... used by the evaluation cells are readable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone LLaVA-Med into the current working directory.
# NOTE(review): this tracks the default branch, so behaviour can change between runs;
# consider pinning a commit for reproducibility.
!git clone https://github.com/microsoft/LLaVA-Med.git

In [None]:
# Change the kernel's working directory into the clone; every later relative path
# (`pip install -e .`, relative shell paths) resolves inside LLaVA-Med/.
%cd LLaVA-Med

In [None]:


# Upgrade the packaging toolchain first: editable installs of pyproject-based
# packages (PEP 660) need a recent pip/setuptools.
!pip install --upgrade pip setuptools wheel

# Install the repo and its dependencies
!pip install -e .

# Pin versions AFTER the editable install so they override whatever the repo's
# requirements resolved to.
# NOTE(review): numpy/torch are replaced in-place; Colab typically needs a runtime
# restart before they are imported again.
!pip install numpy==1.26.4
!pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.38.1
# The model below is loaded without any quantization flags; presumably bitsandbytes
# is removed to avoid import/CUDA-mismatch errors — confirm.
!pip uninstall -y bitsandbytes  # remove it completely


In [None]:
!mkdir -p checkpoints
!huggingface-cli download microsoft/llava-med-v1.5-mistral-7b --local-dir checkpoints/llava-med-v1.5-mistral-7b --local-dir-use-symlinks False


In [None]:
# Import necessary modules
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.conversation import conv_templates
from llava.constants import IMAGE_TOKEN_INDEX
from PIL import Image
import torch

# Load the model once; later cells use the module-level `tokenizer`, `model` and
# `image_processor` bound here (`context_len` is unused below).
# NOTE(review): model_path must match the directory the checkpoint was downloaded to
# (the download cell's --local-dir).
print("Loading model...")
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/content/checkpoints/llava-med-v1.5-mistral-7b",
    model_base=None,
    model_name="llava-med-v1.5-mistral-7b"
)
print("Model loaded successfully!")


In [None]:
# @title Simple inference over the test folder
# Simple Inference
import os
import glob

# Folder of test images for this quick qualitative pass (only *.png files are read below).
# Alternative datasets — swap which line is commented to switch:
# test_folder = "/content/drive/MyDrive/preprocessed_combined_dataset_2/test"
test_folder = "/content/drive/MyDrive/preprocessed_pneumothorax/test"
# test_folder = "/content/lungs"

def generate_response(image_path, prompt):
    """Run LLaVA-Med on one image with a free-text prompt and return its answer.

    Uses the module-level ``tokenizer``, ``model`` and ``image_processor`` loaded earlier.
    Decoding is sampled (``do_sample=True``), so answers vary between runs.

    Args:
        image_path: Path to an image file readable by PIL.
        prompt: User question; an ``<image>`` placeholder is prepended automatically.

    Returns:
        The generated text with special tokens removed and whitespace stripped.
    """
    # Load the image and preprocess it according to the model config.
    image = Image.open(image_path).convert('RGB')
    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Single-turn Mistral-instruct conversation with an empty assistant slot.
    conv = conv_templates["mistral_instruct"].copy()
    conv.append_message(conv.roles[0], f"<image>\n{prompt}")
    conv.append_message(conv.roles[1], None)
    formatted_prompt = conv.get_prompt()

    # Tokenize, mapping the <image> placeholder to IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(
        formatted_prompt,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=1,
            max_new_tokens=512,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # When generation is driven by inputs_embeds (as LLaVA-Med's generate() does —
    # verify in llava/model/language_model/llava_mistral.py), Hugging Face returns ONLY
    # the new tokens. Unconditionally slicing off input_ids.shape[1] tokens therefore
    # cut the start of the answer, or all of it for short answers. Strip the prompt
    # only if it was actually echoed back.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

# Collect the test images in a stable (sorted) order.
# image_files = glob.glob(os.path.join(test_folder, "*.jpg"))
image_files = sorted(glob.glob(os.path.join(test_folder, "*.png")))
if not image_files:
    print(f"No .png images found in {test_folder}")

# Ask the model about each image and print its free-text answer.
# prompt = "Classify the disease in this X-ray scan of the lungs as Pneumothorax or no finding: Please Explain"
prompt = "Classify the disease in this image: "
for image_path in image_files:
    image_filename = os.path.basename(image_path)
    print(f"\n--- Processing: {image_filename} ---")

    try:
        result = generate_response(image_path, prompt)
        print(f"Classification: {result}")
    except Exception as e:
        # Keep going so one unreadable image doesn't abort the whole pass.
        print(f"Error processing {image_filename}: {str(e)}")

    print("-" * 50)

In [None]:
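# @title PKAN / Hallervorden-Spatz accuracy evaluation
# Runs the model on every image listed in the test CSV and reports accuracy metrics.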
import os
import glob
import pandas as pd

# Set the test folder path and CSV path
# test_folder = "/content/drive/MyDrive/preprocessed_combined_dataset_2/test"
# csv_path = "/content/drive/MyDrive/preprocessed_combined_dataset_2/test.csv"
# NOTE(review): the prompt, keyword classifier and printed class names in this cell target
# Hallervorden-Spatz / PKAN on brain T2 MRI, but the active paths below point to the
# pneumothorax chest X-ray set. Confirm which dataset is intended before trusting the metrics.
test_folder = "/content/drive/MyDrive/preprocessed_pneumothorax/test"
csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax/test.csv"

def generate_response(image_path, prompt):
    """Run LLaVA-Med greedily on one image and return the answer, with verbose debug output.

    Uses the module-level ``tokenizer``, ``model`` and ``image_processor``.

    Args:
        image_path: Path to an image file readable by PIL.
        prompt: User question; an ``<image>`` placeholder is prepended automatically.

    Returns:
        The generated text with special tokens removed and whitespace stripped (may be "").
    """
    print(f"  Loading image: {image_path}")

    # Load and process the image
    image = Image.open(image_path).convert('RGB')
    print(f"  Image size: {image.size}")

    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    print(f"  Image tensor shape: {image_tensor.shape}")

    # Single-turn Mistral-instruct conversation with an empty assistant slot.
    conv = conv_templates["mistral_instruct"].copy()
    conv.append_message(conv.roles[0], f"<image>\n{prompt}")
    conv.append_message(conv.roles[1], None)

    formatted_prompt = conv.get_prompt()
    print(f"  Formatted prompt length: {len(formatted_prompt)}")

    # Tokenize, mapping the <image> placeholder to IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(
        formatted_prompt,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt'
    ).unsqueeze(0).to(model.device)
    print(f"  Input IDs shape: {input_ids.shape}")

    # Greedy decoding, so labels are reproducible. `temperature` is intentionally not
    # passed: it has no effect when do_sample=False (transformers only warns about it).
    print("  Generating response...")
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    print(f"  Output IDs shape: {output_ids.shape}")

    # When generation is driven by inputs_embeds (as LLaVA-Med's generate() does —
    # verify in llava/model/language_model/llava_mistral.py), the output holds ONLY new
    # tokens. Slicing off input_ids.shape[1] tokens unconditionally truncated or emptied
    # the answer, so strip the prompt only if it was actually echoed back.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]
    print(f"  Generated tokens: {generated.shape[0]}")

    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    print(f"  Raw response length: {len(response)}")

    return response

def classify_response(response):
    """Map a free-text model answer to 1 (Hallervorden-Spatz / PKAN) or 0 (normal).

    Keyword heuristic evaluated per clause (text split on . ! ? ; , and newlines).
    Clauses containing a negation cue ("no", "not", "without", "n't", ...) are ignored,
    so answers such as "There are no signs of PKAN" are not counted as positive. The
    prompt names PKAN, so the model echoes the term in negative answers too.

    Args:
        response: Decoded model output (may be empty).

    Returns:
        1 if a disease keyword appears in a non-negated clause, or if an indirect imaging
        clue appears in a non-negated clause while the answer also mentions an
        abnormality; otherwise 0 (also the default for empty or unclear answers).
    """
    import re

    if not response:
        return 0

    response_lower = response.lower()

    hs_keywords = (
        'hallervorden-spatz', 'hallervorden spatz', 'pkan',
        'pantothenate kinase', 'neurodegeneration with brain iron',
        'eye of the tiger', 'eye-of-the-tiger',
    )
    # "eye of the tiger" used to be listed here too; it is already a direct keyword.
    hs_indirect_clues = (
        'iron deposits', 'abnormal iron accumulation',
        't2 hypointensity', 'globus pallidus',
    )
    # Whole-word negation cues; the n't alternative covers doesn't / isn't / can't
    # (straight or curly apostrophe).
    negation = re.compile(
        r"\b(?:no|not|without|absence|absent|negative|unlikely|cannot)\b|n[’']t\b"
    )

    clauses = [c for c in re.split(r"[.!?;,\n]+", response_lower) if c.strip()]
    affirmed = [c for c in clauses if not negation.search(c)]

    # Strong direct match in an affirmative clause.
    if any(keyword in clause for clause in affirmed for keyword in hs_keywords):
        return 1

    # Indirect clues only count when an abnormality is also asserted.
    if any(clue in clause for clause in affirmed for clue in hs_indirect_clues):
        if 'abnormal' in response_lower or 'consistent with disease' in response_lower:
            return 1

    # Explicitly normal, negated, or unclear answers all default to normal.
    return 0


# Run the PKAN / Hallervorden-Spatz evaluation over every file listed in the test CSV.
# Label convention (from the prints below): 0 = Normal, anything else = Hallervorden-Spatz.

# Load the ground truth labels
df = pd.read_csv(csv_path)
print(f"Loaded {len(df)} test samples from CSV")

# filename -> label lookup; presumably column 0 is the image filename and column 1 the
# label — confirm against the CSV. Duplicate filenames collapse to the last row's label.
filename_to_label = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))

# Parallel result lists (same index = same sample); only samples that did not raise are appended.
predictions = []
ground_truth = []
filenames = []

# Iterate through each image and classify
# prompt = "Classify the disease in this MRI scan of a brain: "
# prompt = "Is this a case of Hallervorden-Spatz disease or a normal brain T2 MRI? Please explain."
# prompt = "Classify this brain MRI as either normal or showing Hallervorden-Spatz disease."
prompt = "This is a T2-weighted MRI of the brain. Please identify if it shows signs of Pantothenate Kinase-Associated Neurodegeneration (PKAN)."

print("\n=== Starting Inference Accuracy Test ===\n")

for idx, (filename, true_label) in enumerate(filename_to_label.items()):
    image_path = os.path.join(test_folder, filename)

    # Check if image file exists (skipped files still advance idx, so the progress counter can jump)
    if not os.path.exists(image_path):
        print(f"Warning: Image {filename} not found, skipping...")
        continue

    print(f"[{idx+1}/{len(filename_to_label)}] Processing: {filename}")
    print(f"Ground Truth: {true_label} ({'Normal' if true_label == 0 else 'Hallervorden-Spatz'})")

    try:
        # Generate response
        response = generate_response(image_path, prompt)
        print(f"Model Response: '{response}'")

        # Check if response is empty
        if not response or response.strip() == "":
            print("WARNING: Empty response from model!")
            predicted_label = 0  # Default to normal for empty responses
        else:
            # Classify the response
            predicted_label = classify_response(response)

        print(f"Predicted Label: {predicted_label} ({'Normal' if predicted_label == 0 else 'Hallervorden-Spatz'})")

        # Store results
        predictions.append(predicted_label)
        ground_truth.append(true_label)
        filenames.append(filename)

        # Show if correct
        correct = "✓" if predicted_label == true_label else "✗"
        print(f"Result: {correct}")

    except Exception as e:
        # Failed samples are reported but excluded from the metrics below.
        print(f"Error processing {filename}: {str(e)}")
        import traceback
        traceback.print_exc()

    print("-" * 70)

# Calculate accuracy and additional metrics
# Positive class = 1 (Hallervorden-Spatz), negative = 0 (Normal). Labels are compared with ==,
# so a CSV label outside {0, 1} would affect accuracy but none of the confusion-matrix cells.
if len(predictions) > 0:
    correct_predictions = sum(1 for pred, true in zip(predictions, ground_truth) if pred == true)
    accuracy = correct_predictions / len(predictions)

    # Calculate confusion matrix components
    true_positives = sum(1 for pred, true in zip(predictions, ground_truth) if pred == 1 and true == 1)
    true_negatives = sum(1 for pred, true in zip(predictions, ground_truth) if pred == 0 and true == 0)
    false_positives = sum(1 for pred, true in zip(predictions, ground_truth) if pred == 1 and true == 0)
    false_negatives = sum(1 for pred, true in zip(predictions, ground_truth) if pred == 0 and true == 1)

    # Calculate metrics (each falls back to 0.0 when its denominator is zero)
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
    specificity = true_negatives / (true_negatives + false_positives) if (true_negatives + false_positives) > 0 else 0.0

    print(f"\n=== ACCURACY RESULTS ===")
    print(f"Total samples processed: {len(predictions)}")
    print(f"Correct predictions: {correct_predictions}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    print(f"\n=== CONFUSION MATRIX ===")
    print(f"True Positives (HS predicted as HS): {true_positives}")
    print(f"True Negatives (Normal predicted as Normal): {true_negatives}")
    print(f"False Positives (Normal predicted as HS): {false_positives}")
    print(f"False Negatives (HS predicted as Normal): {false_negatives}")

    print(f"\n=== PERFORMANCE METRICS ===")
    print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
    print(f"  - Proportion of true positive predictions among all positive predictions")
    print(f"  - Important for minimizing false alarms")
    print(f"Recall (Sensitivity): {recall:.4f} ({recall*100:.2f}%)")
    print(f"  - Proportion of true positive predictions among all actual positive instances")
    print(f"  - Crucial for identifying actual cases of Hallervorden-Spatz disease")
    print(f"Specificity: {specificity:.4f} ({specificity*100:.2f}%)")
    print(f"  - Proportion of true negative predictions among all actual negative instances")
    print(f"  - Important for correctly identifying normal/healthy cases")

    # Detailed breakdown (per-class recall: normal = specificity, HS = sensitivity)
    normal_correct = true_negatives
    hs_correct = true_positives
    normal_total = true_negatives + false_positives
    hs_total = true_positives + false_negatives

    print(f"\n=== DETAILED RESULTS BY CLASS ===")
    normal_pct = f"{normal_correct/normal_total*100:.1f}%" if normal_total > 0 else "N/A"
    hs_pct = f"{hs_correct/hs_total*100:.1f}%" if hs_total > 0 else "N/A"
    print(f"Normal cases: {normal_correct}/{normal_total} correct ({normal_pct})")
    print(f"Hallervorden-Spatz cases: {hs_correct}/{hs_total} correct ({hs_pct})")

    # Additional derived metrics (F1 is skipped entirely when precision or recall is 0)
    if precision > 0 and recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
        print(f"\nF1-Score: {f1_score:.4f} ({f1_score*100:.2f}%)")
        print(f"  - Harmonic mean of precision and recall")

    # Show misclassified cases
    print(f"\n=== MISCLASSIFIED CASES ===")
    misclassified_count = 0
    for filename, pred, true in zip(filenames, predictions, ground_truth):
        if pred != true:
            misclassified_count += 1
            pred_label = 'Normal' if pred == 0 else 'Hallervorden-Spatz'
            true_label = 'Normal' if true == 0 else 'Hallervorden-Spatz'
            error_type = ""
            if pred == 1 and true == 0:
                error_type = " (False Positive)"
            elif pred == 0 and true == 1:
                error_type = " (False Negative)"
            print(f"  {filename}: Predicted {pred_label}, Actual {true_label}{error_type}")

    if misclassified_count == 0:
        print("  No misclassified cases - Perfect classification!")

    # Summary for quick reference
    print(f"\n=== QUICK SUMMARY ===")
    print(f"Accuracy: {accuracy*100:.1f}% | Precision: {precision*100:.1f}% | Recall: {recall*100:.1f}% | Specificity: {specificity*100:.1f}%")

else:
    print("No predictions were made successfully.")

In [None]:
# @title Pneumothorax accuracy evaluation (prompt includes patient age/sex)
import os
import pandas as pd
from PIL import Image
import torch

# Set the test folder path and CSV paths
test_folder = "/content/drive/MyDrive/preprocessed_pneumothorax/test"
csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax/test.csv"
metadata_csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax/pneumothorax_combined.csv"

# Load the metadata CSV and prepare a lookup dictionary:
# "Image Index" -> {"Follow-up #": ..., "Patient Age": ..., "Patient Sex": ...}.
# to_dict(orient="index") raises if "Image Index" values are not unique.
background_df = pd.read_csv(metadata_csv_path)
background_lookup = background_df.set_index("Image Index")[["Follow-up #", "Patient Age", "Patient Sex"]].to_dict(orient="index")

# Load ground truth labels; presumably column 0 = image filename and column 1 = 0/1 label
# (0 = No Finding, 1 = Pneumothorax) — confirm against the CSV.
df = pd.read_csv(csv_path)
filename_to_label = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))
print(f"Loaded {len(df)} test samples from CSV")

# Make sure you have the right template
from llava.conversation import conv_templates, SeparatorStyle

def generate_response(image_path, metadata=None):
    """Ask LLaVA-Med whether a chest X-ray shows pneumothorax.

    Uses the module-level ``tokenizer``, ``model`` and ``image_processor``. Decoding is
    sampled (``do_sample=True``), so answers vary from run to run unless torch is seeded.

    Args:
        image_path: Path to the X-ray image.
        metadata: Optional mapping with 'Patient Age' and 'Patient Sex' keys; when truthy,
            the patient's age/sex is included in the prompt.

    Returns:
        The decoded answer (special tokens removed, whitespace stripped); may be "".
    """
    print(f"  Loading image: {image_path}")

    image = Image.open(image_path).convert('RGB')
    print(f"  Image size: {image.size}")

    # Same preprocessing as the other cells: process_images applies the model config's
    # image_aspect_ratio handling, which a bare image_processor.preprocess() call bypasses.
    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    print(f"  Image tensor shape: {image_tensor.shape}")

    if metadata:
        prompt = f"This chest X-ray is from a {metadata['Patient Age']}-year-old {metadata['Patient Sex']} patient. Does this X-ray show pneumothorax? Answer Yes or No and explain briefly."
    else:
        prompt = "Does this chest X-ray show pneumothorax? Answer Yes or No and explain briefly."

    # Single-turn Mistral-instruct conversation (no system prompt) with an empty assistant slot.
    conv = conv_templates["mistral_instruct"].copy()
    conv.append_message(conv.roles[0], f"<image>\n{prompt}")
    conv.append_message(conv.roles[1], None)

    prompt_formatted = conv.get_prompt()
    print(f"  Formatted prompt length: {len(prompt_formatted)}")

    # Tokenize, mapping the <image> placeholder to IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(
        prompt_formatted,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    print(f"  Input IDs shape: {input_ids.shape}")
    print("  Generating response...")

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            max_new_tokens=128,
            use_cache=True,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    print(f"  Output IDs shape: {output_ids.shape}")
    print(f"  Input length: {input_ids.shape[1]}, Output length: {output_ids.shape[1]}")

    # When generation is driven by inputs_embeds (as LLaVA-Med's generate() does —
    # verify in llava/model/language_model/llava_mistral.py), the output holds ONLY new
    # tokens. The old unconditional `[input_ids.shape[1]:]` slice emptied or truncated
    # the answer. That is what the previous "extract Yes/No from the full response"
    # fallback was papering over, so it is no longer needed.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]

    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    print(f"  Extracted response: '{response}'")

    return response


def classify_response(response):
    """
    Classify response as pneumothorax (1) or no finding (0)

    Decision order:
      1. Empty / whitespace-only answer -> 0.
      2. First word is "yes" / "no" (leading quotes, asterisks and other punctuation
         ignored) -> 1 / 0.
      3. Any explicit negative phrase -> 0 (checked before positives because
         negations tend to be the more explicit statement).
      4. Any explicit positive phrase -> 1.
      5. Otherwise, including a bare mention of "pneumothorax" -> 0.

    Args:
        response: Decoded model answer.

    Returns:
        int: 1 for pneumothorax, 0 for no finding.
    """
    import re

    if not response or not response.strip():
        print("    Empty response - defaulting to No Finding")
        return 0

    response_lower = response.lower().strip()
    print(f"    Classifying: '{response[:100]}...'")

    # Whole-word Yes/No at the start. A bare startswith("no") also matched answers
    # beginning with "Notable ...", "Nodular ...", "Normal ..." and read them as "No".
    leading_answer = re.match(r"\W*(yes|no)\b", response_lower)
    if leading_answer:
        if leading_answer.group(1) == "yes":
            print(f"    Found 'Yes' at start -> Pneumothorax")
            return 1
        print(f"    Found 'No' at start -> No Finding")
        return 0

    # Check for pneumothorax keywords if no clear Yes/No
    pneumothorax_indicators = [
        "pneumothorax is present",
        "shows pneumothorax",
        "pneumothorax visible",
        "consistent with pneumothorax",
        "indicative of pneumothorax",
        "evidence of pneumothorax",
        "pneumothorax can be seen"
    ]

    negative_indicators = [
        "no pneumothorax",
        "does not show pneumothorax",
        "not indicate pneumothorax",
        "no evidence of pneumothorax",
        "no signs of pneumothorax",
        "normal chest x-ray",
        "appears normal"
    ]

    # Negatives first: e.g. "no evidence of pneumothorax" also contains the positive
    # phrase "evidence of pneumothorax".
    for indicator in negative_indicators:
        if indicator in response_lower:
            print(f"    Found negative indicator: '{indicator}' -> No Finding")
            return 0

    for indicator in pneumothorax_indicators:
        if indicator in response_lower:
            print(f"    Found positive indicator: '{indicator}' -> Pneumothorax")
            return 1

    # Ambiguous mention: conservatively treated as No Finding.
    if "pneumothorax" in response_lower:
        print(f"    Contains 'pneumothorax' but unclear context -> defaulting to No Finding")
        return 0

    print(f"    No clear indicators found -> defaulting to No Finding")
    return 0



# Parallel result lists (same index = same sample); only samples that did not raise are appended.
predictions = []
ground_truth = []
filenames = []

print("\n=== Starting Inference Accuracy Test ===\n")

import random

# Convert dict items to list and shuffle the order.
# Seeding `random` fixes only the processing order; generate_response samples tokens and
# torch is not seeded here, so model answers can still differ between runs.
random.seed(42)  # or any number
shuffled_items = list(filename_to_label.items())
random.shuffle(shuffled_items)

for idx, (filename, true_label) in enumerate(shuffled_items):
    image_path = os.path.join(test_folder, filename)

    if not os.path.exists(image_path):
        print(f"Warning: Image {filename} not found, skipping...")
        continue

    # Samples without an age/sex row are skipped rather than prompted without metadata.
    metadata = background_lookup.get(filename)
    if not metadata:
        print(f"Warning: No metadata for {filename}, skipping...")
        continue

    print(f"[{idx+1}/{len(filename_to_label)}] Processing: {filename}")
    print(f"Ground Truth: {true_label} ({'No Finding' if true_label == 0 else 'Pneumothorax'})")

    try:
        response = generate_response(image_path, metadata)
        print(f"Model Response: '{response}'")

        if not response.strip():
            print("WARNING: Empty response from model!")
            predicted_label = 0
        else:
            predicted_label = classify_response(response)

        print(f"Predicted Label: {predicted_label} ({'No Finding' if predicted_label == 0 else 'Pneumothorax'})")
        predictions.append(predicted_label)
        ground_truth.append(true_label)
        filenames.append(filename)

        correct = "✓" if predicted_label == true_label else "✗"
        print(f"Result: {correct}")

    except Exception as e:
        # Failed samples are reported but excluded from the metrics.
        print(f"Error processing {filename}: {str(e)}")
        import traceback
        traceback.print_exc()

    print("-" * 70)

# Metrics calculation (positive class 1 = Pneumothorax, 0 = No Finding)
if predictions:
    pairs = list(zip(predictions, ground_truth))
    correct_predictions = sum(p == t for p, t in pairs)
    accuracy = correct_predictions / len(predictions)

    # Confusion-matrix counts. `fn` is the false-negative COUNT. The misclassification
    # loop below used to reuse `fn` as its loop variable, silently overwriting this
    # value with a filename in the kernel namespace.
    tp = sum(p == t == 1 for p, t in pairs)
    tn = sum(p == t == 0 for p, t in pairs)
    fp = sum(p == 1 and t == 0 for p, t in pairs)
    fn = sum(p == 0 and t == 1 for p, t in pairs)

    # Each metric falls back to 0 when its denominator is zero.
    precision = tp / (tp + fp) if (tp + fp) else 0
    recall = tp / (tp + fn) if (tp + fn) else 0
    specificity = tn / (tn + fp) if (tn + fp) else 0

    print(f"\n=== ACCURACY RESULTS ===")
    print(f"Total samples processed: {len(predictions)}")
    print(f"Correct predictions: {correct_predictions}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    print(f"\n=== CONFUSION MATRIX ===")
    print(f"TP (Pneumothorax → Pneumothorax): {tp}")
    print(f"TN (No Finding → No Finding): {tn}")
    print(f"FP (No Finding → Pneumothorax): {fp}")
    print(f"FN (Pneumothorax → No Finding): {fn}")

    print(f"\n=== METRICS ===")
    print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
    print(f"Recall: {recall:.4f} ({recall*100:.2f}%)")
    print(f"Specificity: {specificity:.4f} ({specificity*100:.2f}%)")

    if precision > 0 and recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
        print(f"F1-Score: {f1:.4f} ({f1*100:.2f}%)")

    print(f"\n=== MISCLASSIFIED CASES ===")
    misclassified = 0
    for image_name, pred, true in zip(filenames, predictions, ground_truth):
        if pred != true:
            misclassified += 1
            pred_label = 'Normal' if pred == 0 else 'Pneumothorax'
            true_label = 'Normal' if true == 0 else 'Pneumothorax'
            print(f"  {image_name}: Predicted {pred_label}, Actual {true_label}")
    if misclassified == 0:
        print("  No misclassifications – perfect result!")

    print(f"\n=== SUMMARY ===")
    print(f"Accuracy: {accuracy*100:.1f}%, Precision: {precision*100:.1f}%, Recall: {recall*100:.1f}%, Specificity: {specificity*100:.1f}%")

else:
    print("No successful predictions were made.")


In [None]:
import os
import pandas as pd
import torch
from PIL import Image
from llava.conversation import conv_templates
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import random

# Paths
test_folder = "/content/drive/MyDrive/preprocessed_pneumothorax/test"
csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax/test.csv"
metadata_csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax/pneumothorax_combined.csv"

# Load labels and metadata.
# presumably test.csv column 0 = image filename, column 1 = 0/1 label (1 = Pneumothorax) — confirm.
df = pd.read_csv(csv_path)
filename_to_label = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))
background_df = pd.read_csv(metadata_csv_path)
# "Image Index" -> {"Follow-up #", "Patient Age", "Patient Sex"}; to_dict(orient="index")
# raises if "Image Index" values are not unique.
background_lookup = background_df.set_index("Image Index")[["Follow-up #", "Patient Age", "Patient Sex"]].to_dict(orient="index")

# Inference helper
def generate_response(image_path, metadata=None):
    """Ask LLaVA-Med (as an "expert radiologist") whether a chest X-ray shows pneumothorax.

    Uses the module-level ``tokenizer``, ``model`` and ``image_processor``. Decoding is
    sampled, so answers vary between runs unless torch is seeded.

    Args:
        image_path: Path to the X-ray image.
        metadata: Optional mapping with 'Patient Age' and 'Patient Sex'; when truthy, the
            age/sex and a radiologist persona are added to the prompt.

    Returns:
        The decoded answer (special tokens removed, whitespace stripped).
    """
    image = Image.open(image_path).convert('RGB')
    # process_images applies the model config's image_aspect_ratio handling, matching the
    # earlier cells; a bare image_processor.preprocess() call bypasses it.
    image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)

    prompt = (
        f"You are an expert radiologist. This chest X-ray is from a {metadata['Patient Age']}-year-old (Sex: {metadata['Patient Sex']}) patient. "
        f"Does this chest X-ray show a pneumothorax? Answer 'Yes' or 'No' and provide an explanation"
        if metadata else
        "Does this chest X-ray show a pneumothorax? Answer 'Yes' or 'No' and provide an explanation"
    )

    # Single-turn Mistral-instruct conversation with an empty assistant slot.
    conv = conv_templates["mistral_instruct"].copy()
    conv.append_message(conv.roles[0], f"<image>\n{prompt}")
    conv.append_message(conv.roles[1], None)
    prompt_formatted = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt_formatted, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            max_new_tokens=256,
            use_cache=True,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output normally holds only the new tokens (inputs_embeds-based generation), so
    # the whole sequence is the answer. Strip the prompt only if it was echoed back. The
    # old "### Assistant:" split never applied: the mistral_instruct template does not emit it.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]

    return tokenizer.decode(generated, skip_special_tokens=True).strip()

# Evaluation tracking (label 1 = Pneumothorax, 0 = No Finding)
true_labels = []
predicted_labels = []

# Shuffle dataset (seeded so the processing order is reproducible)
shuffled_items = list(filename_to_label.items())
random.seed(42)
random.shuffle(shuffled_items)

REPORT_EVERY = 10  # print running metrics after every N *processed* samples


def _print_running_metrics(header, y_true, y_pred):
    """Print accuracy, precision, recall, specificity and F1 for the given labels."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    print(header)
    print(f"  Accuracy   : {acc:.4f}")
    print(f"  Precision  : {prec:.4f}")
    print(f"  Recall     : {rec:.4f}")
    print(f"  Specificity: {spec:.4f}")
    print(f"  F1 Score   : {f1:.4f}")
    print("=" * 60)


# Process each image
for filename, label in shuffled_items:
    image_path = os.path.join(test_folder, filename)
    if not os.path.exists(image_path):
        print(f"Warning: Image {filename} not found, skipping...")
        continue

    metadata = background_lookup.get(filename)  # None -> prompt without age/sex
    try:
        response = generate_response(image_path, metadata)
    except Exception as e:
        # One unreadable image or failed generation should not abort the whole run.
        print(f"Error processing {filename}: {e}")
        continue

    # The prompt asks for a leading 'Yes'/'No'; anything else counts as No Finding.
    predicted_label = int(response.strip().lower().startswith("yes"))

    # Record prediction and ground truth
    true_labels.append(int(label))
    predicted_labels.append(predicted_label)

    print(f"Filename       : {filename}")
    print(f"Ground Truth   : {'Pneumothorax' if label == 1 else 'No Finding'}")
    print(f"Model Response : {response}")
    print("-" * 60)

    # Count processed samples: the old `(i + 1) % 10` used the raw shuffled index,
    # which drifts as soon as any file is skipped.
    if len(true_labels) % REPORT_EVERY == 0:
        _print_running_metrics(f"After {len(true_labels)} samples:", true_labels, predicted_labels)

# Always report the final totals. The periodic report alone omitted the last partial
# batch, and reported nothing when fewer than REPORT_EVERY samples were processed.
if true_labels:
    _print_running_metrics(f"Final ({len(true_labels)} samples):", true_labels, predicted_labels)
else:
    print("No samples were processed.")
