In [2]:
import re

import pandas as pd
from langchain_ollama import ChatOllama
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Local Ollama chat model shared by every prediction in this cell. Only the
# Mistral tag is configured here; to compare models, swap the `model` tag
# (e.g. for a Llama 3 tag — confirm it is pulled locally first).
# A low temperature keeps the "Stroke" / "No Stroke" answers consistent.
llm = ChatOllama(
    repeat_penalty=1.2,
    top_p=0.95,
    temperature=0.1,
    model="mistral:latest",
)

def load_data(csv_path):
    """Load the patient CSV and impute missing BMI values with the median.

    Args:
        csv_path: Path of the CSV file (anything ``pd.read_csv`` accepts,
            including a file-like object). It must contain a ``bmi`` column.

    Returns:
        DataFrame with ``bmi`` coerced to numeric; entries that were missing
        or not parseable as numbers are replaced by the median BMI.
    """
    df = pd.read_csv(csv_path)

    # Non-numeric BMI entries become NaN so they are imputed below.
    bmi = pd.to_numeric(df['bmi'], errors='coerce')
    # Assign the filled Series back rather than calling
    # ``df['bmi'].fillna(..., inplace=True)``: that is chained assignment,
    # which pandas deprecates and which does not update ``df`` under
    # Copy-on-Write.
    df['bmi'] = bmi.fillna(bmi.median())

    return df

def _format_patient_features(row):
    """Return the bullet lines describing one patient's features for a prompt."""
    return [
        f"- Gender: {row['gender']}",
        f"- Age: {int(row['age'])}",
        f"- Hypertension: {row['hypertension']}",
        f"- Heart Disease: {row['heart_disease']}",
        f"- Ever Married: {row['ever_married']}",
        f"- Work Type: {row['work_type']}",
        f"- Residence Type: {row['Residence_type']}",
        f"- Average Glucose Level: {row['avg_glucose_level']:.2f}",
        f"- BMI: {row['bmi']:.1f}",
        f"- Smoking Status: {row['smoking_status']}",
    ]

def create_few_shot_examples(df, n_per_class=1):
    """Format labelled stroke / no-stroke patients as few-shot prompt examples.

    The first ``n_per_class`` rows with ``stroke == 1`` are listed first,
    followed by the first ``n_per_class`` rows with ``stroke == 0``. Patients
    are numbered consecutively ("Example Patient 1", "Example Patient 2", ...),
    and consecutive examples are separated by a blank line.

    Args:
        df: Patient DataFrame containing the feature columns formatted by
            ``_format_patient_features`` and a ``stroke`` column of 0/1 labels.
        n_per_class: Maximum number of examples per class (default 1). A class
            with fewer matching rows contributes all of them.

    Returns:
        Tuple ``(examples_str, indices)``: the formatted example text and the
        index labels of the rows used (stroke rows first), e.g. so callers can
        drop them from the evaluation set.

    Raises:
        ValueError: if either class has no rows to draw from (or
            ``n_per_class < 1``).
    """
    positives = df[df['stroke'] == 1].head(n_per_class)
    negatives = df[df['stroke'] == 0].head(n_per_class)
    if positives.empty or negatives.empty:
        raise ValueError(
            "need at least one stroke and one non-stroke row "
            "(and n_per_class >= 1) to build few-shot examples"
        )

    blocks = []
    number = 0
    for subset, label in ((positives, "Stroke"), (negatives, "No Stroke")):
        for _, row in subset.iterrows():
            number += 1
            lines = [
                f"Example Patient {number} ({label}):",
                *_format_patient_features(row),
                f"Stroke Status: {label}",
            ]
            blocks.append("\n".join(lines))

    indices = [*positives.index, *negatives.index]
    return "\n\n".join(blocks), indices

# Matches a negated label such as "no stroke", "no-stroke", "not stroke",
# "not a stroke", "non-stroke" or "nostroke" (applied to lowercased text).
_NEGATED_STROKE = re.compile(r"\b(?:no|not|non)[\s\-_]*(?:a[\s\-_]+)?stroke")

def _parse_prediction(reply):
    """Map a raw model reply to 1 (Stroke) or 0 (No Stroke / no label found).

    The first line that mentions "stroke" is taken as the answer, so a
    trailing explanation (e.g. "... no stroke history ...") cannot flip it.
    """
    for line in reply.strip().lower().splitlines():
        if "stroke" in line:
            return 0 if _NEGATED_STROKE.search(line) else 1
    # No recognisable label at all: fall back to the negative class.
    return 0

def predict_with_few_shot(row, examples, *, model=None):
    """Ask the chat model for a Stroke / No Stroke prediction for one patient.

    Args:
        row: Mapping-like patient record (e.g. a DataFrame row) with the
            feature keys used in the prompt below.
        examples: Pre-formatted few-shot example text inserted into the prompt.
        model: Chat model exposing ``invoke(messages)`` that returns an object
            with a string ``.content``; defaults to the module-level ``llm``.

    Returns:
        1 if the model's answer is "Stroke", 0 if it is "No Stroke" or if no
        label can be found in the reply.
    """
    chat_model = llm if model is None else model

    prompt = f"""As a medical AI specialized in stroke prediction, analyze the following patient data:

{examples}

Now evaluate this new patient:
- Gender: {row['gender']}
- Age: {int(row['age'])}
- Hypertension: {row['hypertension']}
- Heart Disease: {row['heart_disease']}
- Ever Married: {row['ever_married']}
- Work Type: {row['work_type']}
- Residence Type: {row['Residence_type']}
- Average Glucose Level: {row['avg_glucose_level']:.2f}
- BMI: {row['bmi']:.1f}
- Smoking Status: {row['smoking_status']}

Instructions:
1. Analyze all risk factors for stroke.
2. Key risk factors: age, hypertension, heart disease, glucose level.
3. Compare to the examples provided.
4. Answer with ONLY "Stroke" or "No Stroke".

Your prediction:"""

    response = chat_model.invoke([
        {"role": "system", "content": "You are a medical AI assistant specialized in stroke prediction. Answer with ONLY 'Stroke' or 'No Stroke', nothing else."},
        {"role": "user", "content": prompt}
    ])

    return _parse_prediction(response.content)

# Main execution
if __name__ == "__main__":
    # Evaluation settings: test cases drawn per class, and a fixed seed so the
    # same patients are sampled on every run (lets prompts/models be compared).
    N_TEST_PER_CLASS = 10
    RANDOM_STATE = 42

    # Load data
    df = load_data('healthcare-dataset-stroke-data.csv')

    # Create few-shot examples
    examples, example_indices = create_few_shot_examples(df)

    # Hold the few-shot example rows out of the test pool so the model is never
    # evaluated on a patient whose answer already appears in its prompt.
    remaining_df = df.drop(example_indices)

    # Split into stroke and non-stroke cases
    stroke_cases = remaining_df[remaining_df['stroke'] == 1]
    non_stroke_cases = remaining_df[remaining_df['stroke'] == 0]

    # Sample up to N_TEST_PER_CLASS of each class for a balanced test set.
    test_stroke = stroke_cases.sample(
        min(N_TEST_PER_CLASS, len(stroke_cases)), random_state=RANDOM_STATE
    )
    test_non_stroke = non_stroke_cases.sample(
        min(N_TEST_PER_CLASS, len(non_stroke_cases)), random_state=RANDOM_STATE
    )

    test_df = pd.concat([test_stroke, test_non_stroke])
    true_labels = test_df['stroke'].astype(int)
    n_cases = len(test_df)

    # Get predictions
    predictions = []
    for case_number, (idx, row) in enumerate(test_df.iterrows(), start=1):
        pred = predict_with_few_shot(row, examples)
        predictions.append(pred)
        # The total is the actual test-set size (it can be below 2 * N_TEST_PER_CLASS).
        print(f"Processing case {case_number}/{n_cases}: Actual={row['stroke']}, Predicted={pred}")

    predictions = pd.Series(predictions, index=test_df.index)

    # Evaluate
    print(f"\nAccuracy: {accuracy_score(true_labels, predictions):.2%}")
    print("\nClassification Report:")
    # zero_division=0: report 0 instead of warning when a class is never predicted.
    print(classification_report(true_labels, predictions, zero_division=0))

    # Compare predictions
    print("\nDetailed Comparison:")
    results = pd.DataFrame({
        'Actual': true_labels,
        'Predicted': predictions,
        'Match': true_labels == predictions
    })
    print(results)

Processing case 1/20: Actual=1, Predicted=1
Processing case 2/20: Actual=1, Predicted=0
Processing case 3/20: Actual=1, Predicted=0
Processing case 4/20: Actual=1, Predicted=1
Processing case 5/20: Actual=1, Predicted=0
Processing case 6/20: Actual=1, Predicted=1
Processing case 7/20: Actual=1, Predicted=0
Processing case 8/20: Actual=1, Predicted=0
Processing case 9/20: Actual=1, Predicted=0
Processing case 10/20: Actual=1, Predicted=0
Processing case 11/20: Actual=0, Predicted=0
Processing case 12/20: Actual=0, Predicted=1
Processing case 13/20: Actual=0, Predicted=0
Processing case 14/20: Actual=0, Predicted=0
Processing case 15/20: Actual=0, Predicted=0
Processing case 16/20: Actual=0, Predicted=0
Processing case 17/20: Actual=0, Predicted=0
Processing case 18/20: Actual=0, Predicted=0
Processing case 19/20: Actual=0, Predicted=0
Processing case 20/20: Actual=0, Predicted=0

Accuracy: 60.00%

Classification Report:
              precision    recall  f1-score   support

           0