In [None]:
# %%
# ===================================================================
# CELL 1: Install Dependencies
# ===================================================================

import os

# Check if running in Colab
# (detected by searching for the substring "COLAB_" in the concatenation of
# every environment-variable name)
if "COLAB_" not in "".join(os.environ.keys()):
    # Not Colab: plain install, letting pip resolve unsloth's dependencies.
    !pip install unsloth
else:
    # Colab-specific installation
    # Step 1: listed packages (some version-pinned) installed with --no-deps,
    # so pip does not install or replace their dependencies.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    # Step 2: regular installs with normal dependency resolution.
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    # Step 3: unsloth itself, again without dependency resolution.
    !pip install --no-deps unsloth

In [None]:
# %%
# Install additional dependencies
# First line: OpenAI client plus the web-serving stack imported in the next cell
# (fastapi, uvicorn, nest-asyncio, pyngrok, pydantic); python-multipart is not
# imported directly — presumably required by FastAPI for file uploads (confirm).
!pip install openai fastapi uvicorn nest-asyncio pyngrok pydantic python-multipart
# Second line: audio helpers. NOTE(review): neither ffmpeg-python nor pydub is
# imported anywhere in the visible notebook — confirm they are still needed.
!pip install ffmpeg-python pydub

In [None]:
# %%
# ===================================================================
# CELL 2: Import Libraries
# ===================================================================
import getpass
import json
import os
import tempfile
import threading
import time
from typing import Optional

import torch
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
import openai
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from pyngrok import ngrok
import nest_asyncio

# Apply nest_asyncio to allow running FastAPI in Colab
# (patches asyncio once, at import time, before any server code runs;
# presumably needed because the notebook kernel already owns a running
# event loop — confirm it is still required now that uvicorn is started
# in its own thread further down)
nest_asyncio.apply()

# Reaching this line means every import above succeeded.
print("✓ All libraries imported successfully")


In [None]:
# %%
# Never store the key in the notebook itself: reuse OPENAI_API_KEY when the
# environment already provides it, otherwise prompt for it without echoing the
# value, so the secret does not end up in the saved file or its outputs.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")


In [None]:
# %%
# ===================================================================
# CELL 4: Load Fine-Tuned Model
# ===================================================================
print("Loading fine-tuned model...")

# Path to your fine-tuned model
# (passed unchanged to both loaders below; looks like a Hugging Face Hub
# repository id — confirm)
model_path = "abdulsamad99/My_updated_Medical_Note_Generation-fine-tuning"

# Load tokenizer
# Kept as the module-level `tokenizer` that generate_clinical_note() uses.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load model
# from_pretrained returns a pair; only the first element (the model) is kept.
# NOTE(review): the discarded second value is presumably Unsloth's own
# tokenizer — confirm the AutoTokenizer loaded above is the intended one.
# Requested settings: 4-bit loading with float16 as the dtype.
model, _ = FastLanguageModel.from_pretrained(
    model_name=model_path,
    load_in_4bit=True,
    dtype=torch.float16,
)

# Put the model into Unsloth's inference mode before any generate() call.
FastLanguageModel.for_inference(model)

print("✓ Model loaded successfully")

In [None]:
# %%
# ===================================================================
# CELL 5: Audio Transcription Function
# ===================================================================
def transcribe_audio(audio_file_path: str) -> str:
    """
    Transcribe an audio file using the OpenAI Whisper API.

    Args:
        audio_file_path: Path to the audio file on local disk.

    Returns:
        The transcription exactly as returned by the API when called with
        ``response_format="text"``.

    Raises:
        Exception: If the file cannot be opened or the API call fails. The
            message starts with "Transcription failed:" and the original
            error is attached as ``__cause__``.
    """
    try:
        # The file is opened in binary mode and stays open only for the
        # duration of the upload.
        with open(audio_file_path, "rb") as audio_file:
            return openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text",
            )
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying error as the
        # direct cause of this wrapper exception.
        raise Exception(f"Transcription failed: {str(e)}") from e


In [None]:
# %%
# ===================================================================
# CELL 6: Dialogue Segmentation Function
# ===================================================================
def segment_dialogue(transcribed_text: str) -> str:
    """
    Segment raw transcription into doctor-patient dialogue turns.

    Sends the transcript to GPT-4 with instructions to label each turn as
    "Doctor:" or "Patient:" without altering the spoken content.

    Args:
        transcribed_text: Raw transcription text.

    Returns:
        Formatted dialogue with Doctor/Patient labels (the content of the
        first completion choice).

    Raises:
        Exception: If the API call fails or the completion carries no text.
            The message starts with "Dialogue segmentation failed:" and the
            original error is attached as ``__cause__``.
    """
    try:
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {
                    "role": "system",
                    "content": """You are a medical transcription assistant. Your task is to segment a raw doctor-patient conversation transcript into structured dialogue turns.

Format each turn as:
Doctor: [what the doctor said]
Patient: [what the patient said]

Rules:
- Identify speaker changes based on context and conversational flow
- Maintain chronological order
- Do not add, remove, or modify the actual content
- Only format and label speakers
- If speaker is unclear, make best judgment based on medical context"""
                },
                {
                    "role": "user",
                    "content": f"Segment this medical conversation:\n\n{transcribed_text}"
                }
            ],
            temperature=0.3
        )
        content = response.choices[0].message.content
        if content is None:
            # Honour the declared `-> str` contract: a missing message body
            # must not flow into the note-generation prompt as "None".
            raise ValueError("model returned no content")
        return content
    except Exception as e:
        # Chain explicitly so the underlying error is reported as the cause.
        raise Exception(f"Dialogue segmentation failed: {str(e)}") from e

In [None]:
# %%
# ===================================================================
# CELL 7: Clinical Note Generation Function
# ===================================================================
def generate_clinical_note(dialogue: str, max_new_tokens: int = 350) -> str:
    """
    Generate a clinical note section from a dialogue using the fine-tuned model.

    Builds a chat-style prompt (system + user + open assistant turn), runs
    deterministic (greedy) generation with the module-level ``model`` and
    ``tokenizer``, and returns only the text produced for the assistant turn.

    Args:
        dialogue: Formatted doctor-patient dialogue.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated note, stripped of surrounding whitespace and cut at the
        first "<|im_end|>" marker.
    """
    # Full section list
    header_list = (
        "fam/sochx, genhx, pastmedicalhx, cc, pastsurgical, allergy, ros, medications, "
        "assessment, exam, diagnosis, disposition, plan, edcourse, immunizations, "
        "imaging, gynhx, procedures, other_history, labs"
    )

    # SYSTEM
    system_message = (
        "You are a medical scribe AI assistant. Your task is to read a doctor–patient "
        "dialogue and generate a clinical note for the correct documentation section. "
        "Use ONLY the dialogue. Never hallucinate."
    )

    # USER
    user_message = f"""Choose exactly ONE section from this list:
{header_list}

Rules:
- Use ONLY information stated in the dialogue.
- Be concise and medically accurate.
- Do not infer or add assumptions.

Format:
<section_header>
<section_text>

Dialogue:
{dialogue}"""

    # Build prompt EXACTLY like training format
    prompt = (
        f"<|im_start|>system<|im_sep|>{system_message}<|im_end|>\n"
        f"<|im_start|>user<|im_sep|>{user_message}<|im_end|>\n"
        f"<|im_start|>assistant<|im_sep|>"
    )

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt")

    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # Generate. Decoding is greedy (do_sample=False); sampling-only knobs such
    # as temperature/top_p are deliberately not passed because they are
    # ignored in that mode.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.convert_tokens_to_ids("<|im_end|>"),
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them, so the result never depends on finding the
    # assistant marker inside the decoded text.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

    # Special tokens are kept during decoding so the end-of-turn marker can be
    # used to cut off anything generated after it.
    return completion.split("<|im_end|>")[0].strip()

In [None]:
# %%
# ===================================================================
# CELL 8: SOAP Formatting Function
# ===================================================================

def convert_to_soap_format(clinical_note: str) -> str:
    """
    Convert model output to standard SOAP format.

    Sends the note to GPT-4 with a fixed SOAP template (Subjective, Objective,
    Assessment, Plan) and instructions to use only the supplied information.

    Args:
        clinical_note: Clinical note text to reformat.

    Returns:
        The SOAP-formatted note (the content of the first completion choice).

    Raises:
        Exception: If the API call fails or the completion carries no text.
            The message starts with "SOAP conversion failed:" and the original
            error is attached as ``__cause__``.
    """
    try:
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {
                    "role": "system",
                    "content": """You are a medical documentation specialist. Convert the given clinical note into proper SOAP format.

SOAP Format:

SUBJECTIVE:
- Chief Complaint
- History of Present Illness
- Past Medical History
- Medications
- Allergies
- Family History
- Social History

OBJECTIVE:
- Vital Signs (if available)
- Physical Examination
- Review of Systems
- Labs/Imaging (if mentioned)

ASSESSMENT:
- Diagnoses (numbered list)

PLAN:
- Diagnostic workup
- Medications
- Procedures
- Follow-up
- Patient education
- Referrals

Use ONLY the information provided. Write "Not documented" for missing sections."""
                },
                {
                    "role": "user",
                    "content": f"Convert to SOAP format:\n\n{clinical_note}"
                }
            ],
            temperature=0.3
        )
        content = response.choices[0].message.content
        if content is None:
            # Honour the declared `-> str` contract instead of handing None
            # back to the caller as if it were a note.
            raise ValueError("model returned no content")
        return content
    except Exception as e:
        # Chain explicitly so the underlying error is reported as the cause.
        raise Exception(f"SOAP conversion failed: {str(e)}") from e

In [None]:
# %%
# ===================================================================
# CELL 9: Test Model (Optional)
# ===================================================================
# Quick test to verify model works
# Smoke test: push one short, hand-written dialogue through the fine-tuned
# model and print the note it produces.
test_dialogue = """Doctor: Hello, how are you feeling today?
Patient: I've been having headaches for the past week.
Doctor: Can you describe the headaches?
Patient: They're on the right side of my head, throbbing pain.
Doctor: Any nausea or vision changes?
Patient: Yes, some nausea but no vision problems."""

# Separator lines reused for the report layout.
heavy_rule = "=" * 70
light_rule = "-" * 70

# Everything up to the note itself is printed before generation starts.
for banner_line in (
    "\n" + heavy_rule,
    "TESTING MODEL",
    heavy_rule,
    "\nTest Dialogue:",
    test_dialogue,
    "\n" + light_rule,
    "Generated Clinical Note:",
    light_rule,
):
    print(banner_line)

test_note = generate_clinical_note(test_dialogue)
print(test_note)
print(heavy_rule + "\n")

In [None]:
# %%
# ===================================================================
# CELL 10: Complete Audio-to-Note Pipeline
# ===================================================================

def process_audio_to_clinical_note(audio_file_path: str) -> dict:
    """
    Run the full audio-to-note pipeline and return every intermediate result.

    Stages run in this order: transcription, dialogue segmentation, clinical
    note generation, SOAP conversion. A progress line is printed before and
    after each stage.

    Args:
        audio_file_path: Path to the audio file to process.

    Returns:
        Dict with the keys "transcription", "segmented_dialogue",
        "clinical_note" (the model's original output) and "soap_note"
        (the same note reformatted as SOAP).
    """
    # Stage 1: audio -> raw text
    print("Step 1: Transcribing audio...")
    raw_text = transcribe_audio(audio_file_path)
    print(f"✓ Transcription complete ({len(raw_text)} characters)")

    # Stage 2: raw text -> speaker-labelled turns
    print("\nStep 2: Segmenting dialogue...")
    labelled_turns = segment_dialogue(raw_text)
    print("✓ Dialogue segmentation complete")

    # Stage 3: turns -> note from the fine-tuned model
    print("\nStep 3: Generating clinical note...")
    model_note = generate_clinical_note(labelled_turns)
    print("✓ Clinical note generated")

    # Stage 4: note -> SOAP layout
    print("\nStep 4: Converting to SOAP format...")
    soap_text = convert_to_soap_format(model_note)
    print("✓ SOAP format applied")

    pipeline_output = {
        "transcription": raw_text,
        "segmented_dialogue": labelled_turns,
        "clinical_note": model_note,
        "soap_note": soap_text,
    }
    return pipeline_output

In [None]:
# ===================================================================
# CELL 11: API Endpoints
# ===================================================================

# Application object and request/response schemas used by the endpoints in
# this cell. They are defined here because nothing earlier in the notebook
# creates them, and every route decorator below needs `app` to exist.
# NOTE(review): CORSMiddleware is imported at the top of the notebook but is
# not installed on the app; add app.add_middleware(...) here if browser
# clients on another origin need to call this API.
app = FastAPI(title="Clinical Note Generation API")


class TranscriptionRequest(BaseModel):
    """Request body for POST /segment."""

    # Raw transcription text to be split into speaker turns.
    text: str


class DialogueRequest(BaseModel):
    """Request body for POST /generate-note."""

    # Doctor/Patient-labelled dialogue to turn into a clinical note.
    dialogue: str


class PipelineResponse(BaseModel):
    """Response body for POST /upload-audio (one field per pipeline stage)."""

    transcription: str
    segmented_dialogue: str
    clinical_note: str
    soap_note: str
    success: bool
    message: str


@app.get("/")
async def root():
    """Health check endpoint; also lists the available POST routes."""
    return {
        "status": "running",
        "message": "Clinical Note Generation API is operational",
        "endpoints": {
            "POST /upload-audio": "Upload audio file for complete processing",
            "POST /transcribe": "Transcribe audio file",
            "POST /segment": "Segment transcribed text into dialogue",
            "POST /generate-note": "Generate clinical note from dialogue"
        }
    }

@app.post("/upload-audio", response_model=PipelineResponse)
async def upload_audio(file: UploadFile = File(...)):
    """
    Run the complete pipeline on an uploaded audio file.

    The upload is copied to a temporary file, passed to
    process_audio_to_clinical_note(), and the temporary file is removed
    afterwards whether or not processing succeeded.

    Returns:
        PipelineResponse with the transcription, segmented dialogue, clinical
        note and SOAP note.

    Raises:
        HTTPException: Status 500 with the error text if any step fails.
    """
    tmp_file_path = None
    try:
        # Keep the upload's extension on the temp file; a missing filename
        # simply yields no suffix instead of an error.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            # Record the path first so cleanup still happens if the write fails.
            tmp_file_path = tmp_file.name
            content = await file.read()
            tmp_file.write(content)

        # The temp file is closed at this point, so the pipeline can reopen it.
        result = process_audio_to_clinical_note(tmp_file_path)

        return PipelineResponse(
            transcription=result["transcription"],
            segmented_dialogue=result["segmented_dialogue"],
            clinical_note=result["clinical_note"],
            soap_note=result["soap_note"],
            success=True,
            message="Audio processed successfully"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # delete=False means nothing else removes the file; do it on every
        # exit path so failed requests do not leave audio on disk.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.unlink(tmp_file_path)

@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    """
    Transcribe audio file only.

    The upload is copied to a temporary file, transcribed with
    transcribe_audio(), and the temporary file is removed afterwards whether
    or not transcription succeeded.

    Returns:
        Dict with "transcription" and "success".

    Raises:
        HTTPException: Status 500 with the error text if any step fails.
    """
    tmp_file_path = None
    try:
        # Keep the upload's extension on the temp file; a missing filename
        # simply yields no suffix instead of an error.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            # Record the path first so cleanup still happens if the write fails.
            tmp_file_path = tmp_file.name
            content = await file.read()
            tmp_file.write(content)

        transcription = transcribe_audio(tmp_file_path)

        return {
            "transcription": transcription,
            "success": True
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # delete=False means nothing else removes the file; do it on every
        # exit path so failed requests do not leave audio on disk.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.unlink(tmp_file_path)

@app.post("/segment")
async def segment_endpoint(request: TranscriptionRequest):
    """
    Label the speaker turns of an already-transcribed conversation.

    Responds with the segmented dialogue and a success flag; any failure is
    reported to the client as HTTP 500 carrying the error text.
    """
    try:
        dialogue_turns = segment_dialogue(request.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"segmented_dialogue": dialogue_turns, "success": True}

@app.post("/generate-note")
async def generate_note_endpoint(request: DialogueRequest):
    """
    Produce a clinical note from a labelled dialogue.

    Responds with the generated note and a success flag; any failure is
    reported to the client as HTTP 500 carrying the error text.
    """
    try:
        generated_note = generate_clinical_note(request.dialogue)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"clinical_note": generated_note, "success": True}

In [None]:
# %%
# ===================================================================
# CELL 12: Start FastAPI Server with ngrok (BULLETPROOF COLAB VERSION)
# ===================================================================
# Port shared by the ngrok tunnel and the uvicorn server.
API_PORT = 8000

print("""
╔══════════════════════════════════════════════════════════════╗
║          FASTAPI SERVER + NGROK – PUBLIC API READY           ║
╚══════════════════════════════════════════════════════════════╝
""")

# ngrok auth token: taken from NGROK_AUTHTOKEN when set, otherwise prompted for
# without echo, so the secret is never stored in the notebook.
ngrok_token = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("Enter your ngrok authtoken: ")
ngrok.set_auth_token(ngrok_token)

# Start ngrok tunnel first
tunnel = ngrok.connect(API_PORT)
public_url = tunnel.public_url

print(f"✓ Public URL → {public_url}\n")
print("="*70)
print("API ENDPOINTS:")
print("="*70)
print(f"Health Check    : {public_url}/")
print(f"Full Pipeline   : {public_url}/upload-audio")
print(f"Swagger UI/Docs : {public_url}/docs")   # ← Open this and test instantly
print(f"Redoc           : {public_url}/redoc")
print("="*70)
print("Server is now running permanently in background thread!")
print("You can close this cell — the API stays alive until runtime disconnects.")
print("="*70)

# Run uvicorn in a background thread → completely bypasses the asyncio error
def run_uvicorn():
    """Serve `app` on all interfaces at API_PORT (blocks the calling thread)."""
    uvicorn.run(app, host="0.0.0.0", port=API_PORT)

# daemon=True: the server thread never keeps the kernel process alive.
thread = threading.Thread(target=run_uvicorn, daemon=True)
thread.start()

# Optional: keep the cell "running" visually so Colab doesn't show "finished"
try:
    while True:
        time.sleep(60)
        print(f"Server still alive @ {public_url}/docs")
except KeyboardInterrupt:
    # Interrupting the cell only stops this heartbeat loop; the tunnel and the
    # server thread are left untouched, as the banner above promises.
    print(f"Heartbeat stopped; server keeps running @ {public_url}/docs")