In [None]:
# Download the fine-tuned diagnosis model (multi-GB GGUF) into /content, which is where
# ai-doctor.py looks for it. -nc skips the download when the file already exists, so a
# re-run does not fetch it again and save a duplicate as unsloth.Q4_K_M.gguf.1.
!wget -nc -P /content https://huggingface.co/ahsannadir/llama-3-8b-instruct-aidoctor/resolve/main/unsloth.Q4_K_M.gguf

In [None]:
# Install unsloth and its training/inference stack (used by the prescription model).
# xformers comes from the PyTorch CUDA 12.1 wheel index; the next line uses --no-deps,
# presumably so pip does not swap out the torch/CUDA build already present.
# NOTE(review): nothing here is version-pinned, so a fresh run may resolve different releases.
!pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
!pip install --no-deps packaging ninja einops flash-attn trl peft accelerate bitsandbytes
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
# Build llama-cpp-python with CUDA support. A plain `pip install` builds the CPU backend,
# and then `n_gpu_layers` in ai-doctor.py has no effect. --no-cache-dir prevents pip
# from reusing a previously built CPU-only wheel. Compiling takes several minutes.
!CMAKE_ARGS="-DGGML_CUDA=on" pip install --no-cache-dir llama-cpp-python
!pip install streamlit
!pip install chainlit
!pip install pyngrok

In [35]:
# Authenticate ngrok without hard-coding the token in the notebook: set NGROK_AUTHTOKEN
# in the environment, or paste the token when prompted (input is not echoed).
# NOTE: a token that was ever committed to a notebook should be revoked in the ngrok dashboard.
import getpass
import os

from pyngrok import ngrok

ngrok_authtoken = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok authtoken: ")
ngrok.set_auth_token(ngrok_authtoken)

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [36]:
import torch
import os
import json
import pandas as pd
from datasets import Dataset, DatasetDict
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

In [None]:
%%writefile ai-doctor.py
import chainlit as cl
import subprocess
import logging
import torch
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from unsloth import FastLanguageModel

# Root logging at INFO so the model-loading and generation progress messages below
# are printed by the process running the app.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

@cl.cache
def load_diagnosis_model(model_path=None, n_ctx=2048, n_gpu_layers=15, n_threads=8):
    """Load the fine-tuned GGUF diagnosis model with llama.cpp.

    All arguments default to the values this app was written with, so calling it
    with no arguments behaves as before.

    Args:
        model_path: Path to the GGUF file. When None, the ``DIAGNOSIS_MODEL_PATH``
            environment variable is used, falling back to
            ``/content/unsloth.Q4_K_M.gguf`` (the Colab download location).
        n_ctx: Context window size passed to ``Llama``.
        n_gpu_layers: Number of layers to offload to the GPU; this requires a
            GPU-enabled llama-cpp-python build.
        n_threads: CPU thread count passed to ``Llama``.

    Returns:
        A ``llama_cpp.Llama`` instance.

    Raises:
        Exception: whatever ``Llama`` raises (e.g. for a missing file), after it
            has been logged with its traceback.
    """
    import os  # local import: the module header of ai-doctor.py does not import os

    if model_path is None:
        model_path = os.environ.get("DIAGNOSIS_MODEL_PATH", "/content/unsloth.Q4_K_M.gguf")
    try:
        logger.info(f"Model path: {model_path}")

        return Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_gpu_layers=n_gpu_layers,
            n_threads=n_threads,
            verbose=False
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Diagnosis model failed: {str(e)}")
        raise

@cl.cache
def load_prescription_model():
    """Load the 4-bit Llama-3-8B-Instruct model that writes treatment plans.

    Returns:
        The ``(model, tokenizer)`` pair from ``FastLanguageModel.from_pretrained``,
        with the model already switched to unsloth's inference mode.

    Raises:
        Exception: any loading failure, re-raised after it is logged.
    """
    try:
        logger.info("Loading prescription model...")
        load_kwargs = dict(
            model_name="unsloth/llama-3-8b-Instruct-bnb-4bit",
            max_seq_length=2048,
            load_in_4bit=True,
            device_map="auto",
        )
        rx_model, rx_tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)
        # Switch the model into unsloth's inference mode before it is used for generation.
        FastLanguageModel.for_inference(rx_model)
    except Exception as err:
        logger.error(f"Prescription model failed: {str(err)}")
        raise
    return rx_model, rx_tokenizer


# (label, message) pairs shown as clickable conversation starters on the welcome screen.
_STARTERS = (
    (
        "Detect Common Cold",
        "I've been experiencing symptoms like sneezing, runny nose, nasal congestion, sore throat, cough, mild headache, slight fever, fatigue. What could it be?",
    ),
    (
        "Detect Pneumonia",
        "I've been experiencing symptoms like cough, fever, chills, shortness of breath, chest pain, fatigue, sweating, loss of appetite. What could it be?",
    ),
    (
        "Detect Heart Attack",
        "I've been experiencing symptoms like vomiting, breathlessness, sweating, and chest pain. What could it be?",
    ),
    (
        "Detect Chicken Pox",
        "I've been experiencing symptoms like itching, skin rash, fatigue, lethargy, high fever, headache, loss of appetite, mild fever, malaise, and red spots over body. What could this be?",
    ),
)


@cl.set_starters
async def set_starters():
    """Return the example symptom descriptions offered to a new user as starters."""
    return [cl.Starter(label=label, message=message) for label, message in _STARTERS]
    
@cl.on_chat_start
async def start_chat():
    """Load both models for a new chat session and store them in the user session.

    Progress is shown by editing a single status message. The loaders are
    synchronous and can take minutes on first use, so they run in a worker thread
    via ``cl.make_async`` instead of blocking the event loop that serves every
    connected session. The loaders are wrapped in ``@cl.cache``, so later sessions
    are expected to reuse the already-loaded objects.

    Raises:
        MemoryError: if less than 2 GB of GPU memory is free after loading.
        Exception: any loader error; it is logged, reported to the user, then re-raised.
    """
    try:
        status = cl.Message(content="🚀 Starting Medical AI...")
        await status.send()

        status.content = "📥 Loading Diagnosis Model..."
        await status.update()
        diag_model = await cl.make_async(load_diagnosis_model)()

        status.content = "📥 Loading Prescription Model..."
        await status.update()
        rx_model, rx_tokenizer = await cl.make_async(load_prescription_model)()

        status.content = "🔍 Verifying GPU Resources..."
        await status.update()
        free, total = torch.cuda.mem_get_info()
        logger.info(f"GPU Memory - Free: {free/1e9:.1f}GB, Total: {total/1e9:.1f}GB")

        # Keep headroom for generation after both models are resident.
        if free < 2e9:
            raise MemoryError(f"Low GPU memory: {free/1e9:.1f}GB free")

        cl.user_session.set("diag_model", diag_model)
        cl.user_session.set("rx_model", rx_model)
        cl.user_session.set("rx_tokenizer", rx_tokenizer)

        status.content = "🩺 AI Doctor Ready! Describe your symptoms."
        await status.update()

    except Exception as e:
        logger.exception("Chat initialization failed")
        await cl.Message(content=f"❌ Initialization Failed: {str(e)}").send()
        raise

def _build_diagnosis_prompt(symptoms: str) -> str:
    """Return the ChatML-style prompt for the GGUF diagnosis model.

    The prompt is built from flush-left string pieces so every ``<|im_start|>`` tag
    begins its own line. An indented triple-quoted string would leak source
    indentation into the prompt.

    NOTE(review): confirm the fine-tuned GGUF was trained with ChatML tags; if not,
    ``diag_model.create_chat_completion`` (which uses the GGUF's embedded chat
    template) may be a better fit.
    """
    return (
        "<|im_start|>system\n"
        "Analyze symptoms and respond with:\n"
        "- Detected Condition\n"
        "- Confidence Level\n"
        "- Key Indicators<|im_end|>\n"
        "<|im_start|>user\n"
        f"{symptoms}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def _generate_diagnosis(diag_model, symptoms: str) -> str:
    """Run the llama.cpp diagnosis model on the user's symptom text and return stripped text."""
    completion = diag_model.create_completion(
        prompt=_build_diagnosis_prompt(symptoms),
        max_tokens=256,
        temperature=0.2,
        stop=["<|im_end|>"],
    )
    return completion["choices"][0]["text"].strip()


def _generate_treatment(rx_model, rx_tokenizer, diagnosis: str, max_new_tokens: int = 512) -> str:
    """Ask the prescription model for a treatment plan and return only the new text.

    The prescription model is Llama-3 Instruct, so the prompt is rendered with the
    tokenizer's own chat template instead of hand-written ChatML tags. Those tags
    are not special tokens for this tokenizer and would reach the model as plain text.
    """
    messages = [
        {"role": "system", "content": "Provide treatment plan based on diagnosis:"},
        {"role": "user", "content": diagnosis},
    ]
    input_ids = rx_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(rx_model.device)  # follow the model's device instead of assuming "cuda"

    # Llama-3 Instruct ends an assistant turn with <|eot_id|>; stop on it as well as on EOS.
    terminators = [rx_tokenizer.eos_token_id]
    eot_id = rx_tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if isinstance(eot_id, int) and eot_id != rx_tokenizer.unk_token_id and eot_id not in terminators:
        terminators.append(eot_id)

    with torch.inference_mode():
        output_ids = rx_model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            eos_token_id=terminators,
        )
    # generate() returns prompt + continuation; decode only the continuation so the
    # prompt text never leaks into the reply.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return rx_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@cl.on_message
async def main(message: cl.Message):
    """Answer a symptom description with a diagnosis followed by a treatment plan.

    Stage 1 runs the GGUF diagnosis model and stage 2 feeds its output to the
    prescription model. Both calls are synchronous and slow, so they run in a
    worker thread via ``cl.make_async`` to keep the event loop responsive.
    Errors are logged and reported to the user as a chat message.
    """
    try:
        diag_model = cl.user_session.get("diag_model")
        rx_model = cl.user_session.get("rx_model")
        rx_tokenizer = cl.user_session.get("rx_tokenizer")
        if diag_model is None or rx_model is None or rx_tokenizer is None:
            # on_chat_start failed or has not finished; don't fail with an opaque NoneType error.
            await cl.Message(content="⚠️ Models are not loaded yet. Please restart the chat.").send()
            return

        response = cl.Message(content="")
        await response.send()

        # Step 1: Generate Diagnosis
        logger.info("Generating diagnosis...")
        diagnosis = await cl.make_async(_generate_diagnosis)(diag_model, message.content)

        # Step 2: Generate Treatment from the diagnosis
        logger.info("Generating treatment...")
        treatment = await cl.make_async(_generate_treatment)(rx_model, rx_tokenizer, diagnosis)

        # Build flush-left Markdown: lines indented 4+ spaces would render as a code block.
        response.content = f"**Diagnosis**\n\n{diagnosis}\n\n{treatment}"
        await response.update()

    except Exception as e:
        logger.exception("Failed to answer message")
        await cl.Message(content=f"⚠️ Error: {str(e)}").send()

if __name__ == "__main__":
    # Only reached when the file is executed directly (e.g. `python ai-doctor.py` under
    # a debugger); the notebook launches it with `chainlit run ai-doctor.py`.
    # NOTE(review): `cl.run()` is not Chainlit's documented script entry point;
    # `chainlit.cli.run_chainlit` is.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)

In [None]:
import socket
import subprocess
import time
from pyngrok import ngrok

CHAINLIT_PORT = 8000
STARTUP_TIMEOUT_S = 120  # importing unsloth/torch in ai-doctor.py can take a while

# Start the Chainlit server in the background; --headless because there is no local browser.
process = subprocess.Popen(
    ["chainlit", "run", "ai-doctor.py", "--port", str(CHAINLIT_PORT), "--host", "0.0.0.0", "--headless"]
)

# Wait until the server accepts connections instead of sleeping a fixed 10 s, and fail
# fast if the process exits (e.g. a syntax error in ai-doctor.py).
deadline = time.monotonic() + STARTUP_TIMEOUT_S
while True:
    if process.poll() is not None:
        raise RuntimeError(f"Chainlit exited during startup with code {process.returncode}")
    try:
        with socket.create_connection(("127.0.0.1", CHAINLIT_PORT), timeout=1):
            break
    except OSError:
        if time.monotonic() > deadline:
            process.terminate()
            raise TimeoutError(f"Chainlit did not start listening on port {CHAINLIT_PORT} within {STARTUP_TIMEOUT_S}s")
        time.sleep(1)

ngrok_tunnel = ngrok.connect(CHAINLIT_PORT)
print("Chainlit App URL:", ngrok_tunnel.public_url)