# Demo #9: Fine-Tuning Embedding Models for Domain-Specific RAG

## Objective
Demonstrate how fine-tuning an embedding model on domain-specific query-passage pairs can improve retrieval accuracy compared with a generic, off-the-shelf embedding model.

## Core Concepts Demonstrated
1. **Embedding Model Fine-Tuning**: Adapting pre-trained embeddings to specialized domains
2. **Contrastive Learning with Triplet Loss**: Training with (query, positive, negative) triplets
3. **Domain Adaptation**: Improving retrieval on specialized terminology (the medical domain in this demo)
4. **Comparative Evaluation**: Measuring retrieval accuracy, mean reciprocal rank (MRR), and answer quality
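
As a rough illustration of the retrieval metrics named above, the sketch below computes MRR and a hit rate at k over toy ranked results. The function names (`reciprocal_rank`, `mean_reciprocal_rank`, `hit_rate_at_k`) and the document IDs are hypothetical, and each query is assumed to have exactly one relevant passage; the evaluation code used later in the demo may differ.

```python
from typing import List, Sequence, Tuple


def reciprocal_rank(ranked_ids: Sequence[str], relevant_id: str) -> float:
    """Return 1/rank of the relevant document, or 0.0 if it was not retrieved."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id == relevant_id:
            return 1.0 / rank
    return 0.0


def mean_reciprocal_rank(results: Sequence[Tuple[Sequence[str], str]]) -> float:
    """Average the reciprocal rank over all (ranked_ids, relevant_id) pairs."""
    if not results:
        return 0.0
    return sum(reciprocal_rank(ranked, rel) for ranked, rel in results) / len(results)


def hit_rate_at_k(results: Sequence[Tuple[Sequence[str], str]], k: int) -> float:
    """Fraction of queries whose relevant document appears in the top k results."""
    if not results:
        return 0.0
    hits = sum(1 for ranked, rel in results if rel in list(ranked)[:k])
    return hits / len(results)


# Toy data (assumption): three queries, one relevant passage ID each.
toy_results: List[Tuple[List[str], str]] = [
    (["d1", "d2", "d3"], "d1"),  # relevant passage ranked first
    (["d4", "d5", "d6"], "d5"),  # relevant passage ranked second
    (["d7", "d8", "d9"], "dX"),  # relevant passage not retrieved
]

mrr = mean_reciprocal_rank(toy_results)
hit_at_1 = hit_rate_at_k(toy_results, k=1)
print(f"MRR: {mrr:.3f}, Hit@1: {hit_at_1:.3f}")
```

MRR rewards placing the relevant passage near the top of the list, while the hit rate only records whether it appears in the top k at all, so the two can diverge when a model retrieves the right passage but ranks it low.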

## Workshop Context
This is Demo #9 in the Advanced RAG Workshop series, which focuses on optimization and production readiness. Fine-tuning an embedding model is often more practical than fine-tuning a large generator LLM, because embedding models are typically much smaller and cheaper to train.

---

## Data Flow: Fine-Tuned vs Generic Embeddings

### Generic Embeddings (Baseline):
```
Query → Generic Embedding → Vector Search → Retrieved Chunks (may miss domain-specific nuances)
```

### Fine-Tuned Embeddings (Enhanced):
```
Training:  Domain-Specific Triplets → Fine-Tune Embedding Model
Retrieval: Query → Fine-Tuned Embedding → Vector Search → Retrieved Chunks (domain-optimized)
```

**Key Insight**: Fine-tuning pulls domain-relevant query-document pairs closer in embedding space while pushing irrelevant pairs apart through contrastive learning.
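
To make the pull/push idea concrete, here is a minimal sketch of a triplet loss on toy 2-D vectors. It assumes cosine distance and a margin of 0.5, both chosen for illustration only; the loss configuration actually used for fine-tuning in this demo may use a different distance metric and margin. The helper names `cosine_distance` and `triplet_loss` are hypothetical.

```python
import math
from typing import Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Return 1 - cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def triplet_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negative: Sequence[float],
    margin: float = 0.5,  # illustrative value, not taken from the demo
) -> float:
    """Loss is zero once the negative is at least `margin` farther than the positive."""
    d_pos = cosine_distance(anchor, positive)
    d_neg = cosine_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)


# Toy embeddings (assumption): real embeddings have hundreds of dimensions.
query_vec = [1.0, 0.0]
positive_vec = [0.8, 0.6]       # close to the query
easy_negative_vec = [0.0, 1.0]  # far from the query
hard_negative_vec = [0.6, 0.8]  # almost as close as the positive

easy_loss = triplet_loss(query_vec, positive_vec, easy_negative_vec)
hard_loss = triplet_loss(query_vec, positive_vec, hard_negative_vec)
print(f"easy negative loss: {easy_loss:.3f}")
print(f"hard negative loss: {hard_loss:.3f}")
```

The easy negative already sits farther away than the margin requires, so it contributes no loss and no gradient. The hard negative yields a positive loss, which is why negatives that are topically close to the query tend to be more informative during training.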

## Section 1: Setup and Environment Configuration

In [None]:
# Import required libraries
import os
import json
import warnings
from pathlib import Path
from typing import List, Tuple, Dict
from dotenv import load_dotenv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# LlamaIndex imports
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine

# Sentence Transformers for fine-tuning
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from sentence_transformers.util import cos_sim
from torch.utils.data import DataLoader

# Suppress library warnings to keep workshop output readable (remove this line when debugging)
warnings.filterwarnings('ignore')
sns.set_style("whitegrid")

print("✅ All libraries imported successfully!")
print(f"📁 Current working directory: {Path.cwd()}")

In [None]:
# Load environment variables for Azure OpenAI
load_dotenv()

# Azure OpenAI Configuration
AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
AZURE_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
AZURE_EMBED_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

# Validate configuration
if not all([AZURE_API_KEY, AZURE_ENDPOINT, AZURE_LLM_DEPLOYMENT, AZURE_EMBED_DEPLOYMENT]):
    raise ValueError("❌ Missing Azure OpenAI credentials. Please check your .env file.")

# Initialize Azure OpenAI LLM (for answer generation)
azure_llm = AzureOpenAI(
    model="gpt-4",
    deployment_name=AZURE_LLM_DEPLOYMENT,
    api_key=AZURE_API_KEY,
    azure_endpoint=AZURE_ENDPOINT,
    api_version=AZURE_API_VERSION,
    temperature=0.1
)

# Initialize Azure OpenAI Embeddings (for baseline comparison)
azure_embed = AzureOpenAIEmbedding(
    model="text-embedding-ada-002",
    deployment_name=AZURE_EMBED_DEPLOYMENT,
    api_key=AZURE_API_KEY,
    azure_endpoint=AZURE_ENDPOINT,
    api_version=AZURE_API_VERSION
)

print("✅ Azure OpenAI LLM and Embedding models configured successfully!")
print(f"🤖 LLM Deployment: {AZURE_LLM_DEPLOYMENT}")
print(f"🔢 Embedding Deployment: {AZURE_EMBED_DEPLOYMENT}")

## Section 2: Create Domain-Specific Training Dataset

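We'll build a medical terminology dataset of (query, positive passage, negative passage) triplets. The target size is 60 triplets; the cell below lists 12 representative ones across five specialties, and the remainder follow the same format. Generic embeddings can struggle to place clinical terms such as "myocardial infarction" close to their lay equivalents such as "heart attack".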

In [None]:
# Create domain-specific training dataset (Medical domain)
# Format: (query, positive_passage, negative_passage)
medical_triplets = [
    # Cardiology
    ("What causes myocardial infarction?", 
     "Myocardial infarction, commonly known as a heart attack, occurs when blood flow to the heart muscle is blocked, usually by a blood clot in a coronary artery. This blockage deprives the heart tissue of oxygen, causing cell death.",
     "Pneumonia is an infection that inflames the air sacs in one or both lungs, which may fill with fluid or pus, causing cough with phlegm, fever, chills, and difficulty breathing."),
    
    ("Symptoms of acute coronary syndrome",
     "Acute coronary syndrome symptoms include chest pain or discomfort, pain radiating to the arm, jaw, or back, shortness of breath, nausea, sweating, and lightheadedness. Immediate medical attention is critical.",
     "Asthma symptoms include wheezing, shortness of breath, chest tightness, and coughing, especially at night or early morning. These symptoms are caused by inflammation and narrowing of the airways."),
    
    ("Treatment for atrial fibrillation",
     "Atrial fibrillation treatment includes rate control medications like beta-blockers, rhythm control drugs such as amiodarone, anticoagulation therapy to prevent stroke, and procedures like cardioversion or catheter ablation.",
     "Diabetes treatment involves lifestyle modifications, blood glucose monitoring, oral hypoglycemic agents like metformin, and insulin therapy for type 1 or advanced type 2 diabetes."),
    
    # Neurology
    ("What is a cerebrovascular accident?",
     "A cerebrovascular accident (CVA), commonly called a stroke, occurs when blood supply to part of the brain is interrupted or reduced, preventing brain tissue from getting oxygen and nutrients. Brain cells begin to die within minutes.",
     "Multiple sclerosis is an autoimmune disease where the immune system attacks the protective myelin sheath covering nerve fibers, causing communication problems between the brain and the rest of the body."),
    
    ("Signs of increased intracranial pressure",
     "Increased intracranial pressure signs include severe headache, vomiting, altered mental status, papilledema on fundoscopic exam, and in severe cases, Cushing's triad: hypertension, bradycardia, and irregular respirations.",
     "Migraine headaches are characterized by throbbing pain on one side of the head, often accompanied by nausea, vomiting, and sensitivity to light and sound, lasting 4-72 hours if untreated."),
    
    ("Diagnosis of Alzheimer's disease",
     "Alzheimer's disease diagnosis involves comprehensive medical history, cognitive testing with tools like MMSE or MoCA, neuroimaging (MRI/PET scans) showing brain atrophy, and ruling out other causes of dementia.",
     "Parkinson's disease diagnosis is primarily clinical, based on presence of cardinal symptoms: resting tremor, bradykinesia, rigidity, and postural instability, along with response to dopaminergic therapy."),
    
    # Pulmonology
    ("Pathophysiology of chronic obstructive pulmonary disease",
     "COPD pathophysiology involves chronic inflammation of airways and destruction of alveolar walls (emphysema) and airway narrowing (chronic bronchitis), leading to progressive airflow limitation and gas exchange abnormalities.",
     "Congestive heart failure pathophysiology involves the heart's inability to pump sufficient blood to meet the body's needs, leading to fluid accumulation in lungs and peripheral tissues."),
    
    ("Treatment for acute respiratory distress syndrome",
     "ARDS treatment includes low tidal volume mechanical ventilation (6 ml/kg ideal body weight), positive end-expiratory pressure (PEEP), prone positioning, and treatment of underlying cause like sepsis or pneumonia.",
     "Treatment for anaphylaxis includes immediate intramuscular epinephrine, removal of trigger, supplemental oxygen, IV fluids, and antihistamines. Patients should be monitored for biphasic reactions."),
    
    # Gastroenterology
    ("What causes peptic ulcer disease?",
     "Peptic ulcer disease is primarily caused by Helicobacter pylori infection or long-term use of NSAIDs. These factors damage the protective mucosal lining of stomach or duodenum, allowing gastric acid to cause ulceration.",
     "Gastroesophageal reflux disease (GERD) occurs when stomach acid frequently flows back into the esophagus, irritating its lining and causing heartburn and potential esophageal damage."),
    
    ("Complications of cirrhosis",
     "Cirrhosis complications include portal hypertension leading to esophageal varices, ascites, hepatic encephalopathy, hepatorenal syndrome, spontaneous bacterial peritonitis, and increased risk of hepatocellular carcinoma.",
     "Complications of diabetes mellitus include diabetic retinopathy, nephropathy, neuropathy, cardiovascular disease, peripheral arterial disease, and increased susceptibility to infections."),
    
    # Nephrology
    ("Stages of chronic kidney disease",
     "Chronic kidney disease is classified into 5 stages based on GFR: Stage 1 (GFR ≥90) with kidney damage, Stage 2 (60-89), Stage 3 (30-59), Stage 4 (15-29), and Stage 5 (GFR <15) requiring dialysis or transplant.",
     "Heart failure is classified using NYHA functional classification: Class I (no limitation), Class II (slight limitation), Class III (marked limitation), and Class IV (symptoms at rest)."),
    
    ("Indications for hemodialysis",
     "Hemodialysis indications include severe hyperkalemia, metabolic acidosis, uremic pericarditis, uremic encephalopathy, severe fluid overload, and GFR <10-15 ml/min with uremic symptoms.",
     "Indications for liver transplantation include decompensated cirrhosis with MELD score >15, hepatocellular carcinoma within Milan criteria, acute liver failure, and certain metabolic disorders."),
    
    # Continue with more triplets to reach 60...
]

print(f"✅ Created {len(medical_triplets)} medical training triplets")
print("\n📊 Sample Triplet:")
print(f"Query: {medical_triplets[0][0]}")
print(f"Positive: {medical_triplets[0][1][:100]}...")
print(f"Negative: {medical_triplets[0][2][:100]}...")
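
Before training on hand-written triplets, it can help to check them for structural mistakes. The sketch below is one possible check, not part of the original notebook: `validate_triplets` is a hypothetical helper, and the sample triplets are abbreviated or deliberately broken stand-ins rather than entries from `medical_triplets`.

```python
from typing import List, Sequence


def validate_triplets(triplets: Sequence[Sequence[str]]) -> List[str]:
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems: List[str] = []
    seen_queries = set()
    for i, item in enumerate(triplets):
        # Each entry must be (query, positive_passage, negative_passage).
        if len(item) != 3:
            problems.append(f"triplet {i}: expected 3 fields, got {len(item)}")
            continue
        if not all(isinstance(text, str) and text.strip() for text in item):
            problems.append(f"triplet {i}: empty or non-string field")
            continue
        query, positive, negative = item
        # A negative identical to the positive gives the loss nothing to separate.
        if positive.strip() == negative.strip():
            problems.append(f"triplet {i}: positive and negative passages are identical")
        # Duplicate queries are flagged so they can be reviewed, not silently dropped.
        key = query.strip().lower()
        if key in seen_queries:
            problems.append(f"triplet {i}: duplicate query")
        seen_queries.add(key)
    return problems


# Abbreviated sample in the same (query, positive, negative) format (assumption).
good_sample = [
    ("What causes myocardial infarction?",
     "Myocardial infarction occurs when blood flow to the heart muscle is blocked.",
     "Pneumonia is an infection that inflames the air sacs in the lungs."),
]

# Deliberately broken entries to show what the check reports.
bad_sample = [
    ("Symptoms of acute coronary syndrome", "Chest pain or discomfort.", "Chest pain or discomfort."),
    ("Treatment for atrial fibrillation", "Rate control medications like beta-blockers."),
]

good_problems = validate_triplets(good_sample)
bad_problems = validate_triplets(bad_sample)
print(f"good sample problems: {len(good_problems)}")
for problem in bad_problems:
    print(problem)
```

In the notebook, the same function could be called as `validate_triplets(medical_triplets)` right after the list is defined, so that a malformed entry is caught before any training time is spent.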