<a href="https://colab.research.google.com/github/ikramMc/PFE/blob/main/llm_deployement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Install Unsloth and vLLM. %%capture (which must stay on the first line of
# the cell) hides the pip output.
import os
# Colab is detected by looking for "COLAB_" in the concatenated names of the
# environment variables.
if "COLAB_" not in "".join(os.environ.keys()):
    # Not Colab: plain install, pip resolves dependencies itself.
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Elsewhere use [[pip install unsloth vllm]].
    # Colab: skip dependency resolution (--no-deps) and pin the vLLM version.
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
# Install flask-cors.
# NOTE(review): no cell visible in this notebook imports flask_cors — confirm
# that this install is still needed.
!pip install flask-cors


In [None]:
%%capture
# pyngrok is the ngrok client imported further down to open a public tunnel
# to the API server.
!pip install pyngrok


In [None]:
%%capture
# fastapi and uvicorn are used further down to define and serve the API;
# requests is used by the Colab install cell to download a requirements file.
# NOTE(review): gradio is not imported by any cell visible in this notebook —
# confirm that it is still needed.
!pip install gradio requests fastapi uvicorn

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
# Full dependency installation. Outside Colab a plain pip install is enough;
# inside Colab the packages are installed without dependency resolution and
# vLLM's own requirements are installed from a filtered copy of its
# requirements file.
import os
# Colab is detected by looking for "COLAB_" in the concatenated names of the
# environment variables.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    # Snapshot the names of the loaded modules, then drop every entry whose
    # name contains "PIL" or "google" from sys.modules.
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    # Installed with --no-deps, so pip does not pull in or replace their dependencies.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    # Installed normally, with dependency resolution.
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # Download vLLM's common requirements file from the main branch ...
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        # ... remove the transformers / numpy / xformers entries (from the
        # package name to the end of the line) and save what remains ...
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    # ... then install the remaining requirements.
    !pip install -r vllm_requirements.txt

In [None]:
from pyngrok import ngrok
from google.colab import userdata
# Read the ngrok auth token from the Colab secret named 'ngrok' and register
# it with pyngrok, so that the tunnel opened later in the notebook is
# authenticated.
ngrok_key=userdata.get('ngrok')
ngrok.set_auth_token(ngrok_key)


In [None]:
from fastapi import FastAPI, Body, HTTPException
from pydantic import BaseModel
import torch
import gc
from unsloth import FastLanguageModel
from typing import List, Dict, Optional
import time

# Load the LLM model with Unsloth
model_path = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"

# Hugging Face access token, read from the Colab secrets. The secret name has
# to be passed as a string: the bare identifier HF_token is not defined
# anywhere and raised NameError before the model could be loaded.
# NOTE(review): assumes the secret is stored under the name "HF_token" — confirm.
token = userdata.get("HF_token")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=32000,
    dtype=torch.float16,
    load_in_4bit=True,
    token=token,
)

# Put the loaded model into Unsloth's inference mode before serving requests.
FastLanguageModel.for_inference(model)

# === Create FastAPI app ===
app = FastAPI()
# NOTE(review): SESSIONS is not referenced by any code visible in this
# notebook; it is kept in case other cells rely on it.
SESSIONS = {}

# OpenAI compatible schemas
# One chat message in the OpenAI chat-completions format. Used both for the
# incoming conversation and for the assistant reply in the response.
class ChatCompletionMessage(BaseModel):
    role: str  # "user", "assistant", or "system"
    content: str  # text of the message
    # Optional; accepted by the schema but not read by the endpoint in this file.
    name: Optional[str] = None

# Request body of POST /v1/chat/completions.
class ChatCompletionRequest(BaseModel):
    # Only a label: the endpoint echoes it back in the response and always
    # generates with the model loaded at notebook start-up.
    model: str = "mistral"
    # Conversation so far; the endpoint passes it to the tokenizer's chat template.
    messages: List[ChatCompletionMessage]
    # Sampling parameters forwarded to model.generate().
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    # Forwarded to model.generate() as max_new_tokens.
    max_tokens: Optional[int] = 400

# One generated alternative inside a chat completion response. The endpoint
# in this file always returns exactly one choice, with index 0.
class ChatCompletionChoice(BaseModel):
    index: int  # position of this choice in the response's choices list
    message: ChatCompletionMessage  # the generated assistant message
    finish_reason: str  # why generation ended, e.g. "stop"

# Response body of POST /v1/chat/completions, shaped like OpenAI's
# "chat.completion" object.
class ChatCompletionResponse(BaseModel):
    id: str  # the endpoint builds it as "chatcmpl-<unix time>"
    object: str = "chat.completion"
    created: int  # unix timestamp, in seconds
    model: str  # copied from the request's model field
    choices: List[ChatCompletionChoice]
    # The endpoint fills in "prompt_tokens", "completion_tokens" and "total_tokens".
    usage: Dict[str, int]

# === Health Check Endpoint ===
@app.get("/health")
async def health_check():
    # Liveness probe: always reports a healthy service with the model loaded,
    # stamped with the current unix time in whole seconds.
    now = int(time.time())
    payload = {"status": "healthy", "model_loaded": True, "timestamp": now}
    return payload

# === OpenAI compatible endpoints ===
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Generate one assistant reply for an OpenAI-style chat completion request.

    The conversation is rendered with the tokenizer's chat template (ending
    with the generation prompt, so the model answers as the assistant), passed
    to ``model.generate`` and the newly generated tokens are decoded into the
    reply.

    Note: generation runs synchronously inside this coroutine, so the event
    loop (and therefore ``/health``) is blocked while a reply is generated.

    Args:
        request: Conversation plus sampling settings. Sampling fields sent as
            ``null`` fall back to the schema defaults; a temperature of 0 (or
            below) selects greedy decoding.

    Returns:
        A ``ChatCompletionResponse`` with a single choice and token usage.

    Raises:
        HTTPException: 400 when ``messages`` is empty, 500 when prompt
            construction or generation fails.
    """
    if not request.messages:
        # Reject up front: an empty conversation is a client error, not a
        # server failure.
        raise HTTPException(
            status_code=400,
            detail="'messages' must contain at least one message",
        )

    # The schema declares these fields Optional, so an explicit null reaches
    # us as None; fall back to the same defaults the schema uses.
    max_new_tokens = request.max_tokens if request.max_tokens is not None else 400
    temperature = request.temperature if request.temperature is not None else 0.7
    top_p = request.top_p if request.top_p is not None else 0.9

    # Sampling with a non-positive temperature is rejected by generate(), so
    # treat it as a request for deterministic (greedy) decoding.
    if temperature > 0:
        sampling_kwargs = {"do_sample": True, "temperature": temperature, "top_p": top_p}
    else:
        sampling_kwargs = {"do_sample": False}

    try:
        # Convert messages to chat template format
        chat_history = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if hasattr(tokenizer, "apply_chat_template"):
            # add_generation_prompt=True appends the assistant header; without
            # it the prompt ends after the last message and the model is not
            # told that an assistant turn should start.
            formatted_prompt = tokenizer.apply_chat_template(
                chat_history,
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already contains its special tokens, so
            # tokenizing must not add them (e.g. BOS) a second time.
            add_special_tokens = False
        else:
            formatted_prompt = request.messages[-1].content
            add_special_tokens = True

        print("\n📝 PROMPT SENT TO MODEL:\n")
        print(formatted_prompt)

        with torch.no_grad():
            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            ).to(model.device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                **sampling_kwargs,
            )
            input_length = inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][input_length:]
            # Keep a plain int so the tensors can be released below.
            completion_tokens = len(generated_tokens)
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        print("\n🤖 MODEL GENERATED RESPONSE:\n")
        print(generated_text)

        # Drop every reference to the generation tensors (generated_tokens is
        # a view that would otherwise keep the output tensor alive).
        del inputs, outputs, generated_tokens

        # Approximation: using up the whole token budget is reported as
        # "length", anything shorter as "stop".
        finish_reason = "length" if completion_tokens >= max_new_tokens else "stop"

        # Prepare OpenAI compatible response
        response_message = ChatCompletionMessage(
            role="assistant",
            content=generated_text
        )

        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=response_message,
                finish_reason=finish_reason
            )],
            usage={
                "prompt_tokens": input_length,
                "completion_tokens": completion_tokens,
                "total_tokens": input_length + completion_tokens
            }
        )

    except Exception as e:
        print("❌ Error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Runs on success and on failure: hand cached GPU memory back and
        # collect garbage between requests.
        torch.cuda.empty_cache()
        gc.collect()



In [None]:
# Standard-library helpers used for shutdown handling
import atexit
from threading import Event

# State shared between runs of this cell. The names are looked up in
# globals() instead of being reset to None, so that re-running the cell can
# find - and shut down - the server thread and tunnel started by the previous
# run. On the first run every lookup simply yields None (or a fresh Event).
server_thread = globals().get("server_thread")
stop_event = globals().get("stop_event") or Event()
ngrok_tunnel = globals().get("ngrok_tunnel")
uvicorn_server = globals().get("uvicorn_server")

# === ngrok setup ===
if __name__ == "__main__":
    import nest_asyncio
    from pyngrok import ngrok
    from threading import Thread
    import asyncio
    from uvicorn import Config, Server

    # Port shared by the uvicorn server and the ngrok tunnel.
    API_PORT = 8001

    # Cleanup any previous instances
    previous_cleanup = globals().get("cleanup")
    if previous_cleanup is not None:
        # Avoid piling up one atexit handler per run of this cell.
        atexit.unregister(previous_cleanup)
    stop_event.set()
    if uvicorn_server is not None:
        # Setting the Event alone does not stop uvicorn: serve() only returns
        # once should_exit is set on the Server object.
        uvicorn_server.should_exit = True
    if server_thread and server_thread.is_alive():
        server_thread.join()
    if ngrok_tunnel:
        ngrok.disconnect(ngrok_tunnel.public_url)

    # Reset for new run
    stop_event.clear()

    # 1. Public tunnel with ngrok
    ngrok_tunnel = ngrok.connect(API_PORT)
    print("✅ Public URL:", ngrok_tunnel.public_url)

    # 2. Patch asyncio for Jupyter environments
    nest_asyncio.apply()

    # 3. Run server in a thread with its own event loop.
    # The Server object is kept at module level so that cleanup (and the next
    # run of this cell) can ask it to exit.
    uvicorn_server = Server(
        config=Config(app=app, host="0.0.0.0", port=API_PORT, log_level="info")
    )

    def run_api(server=uvicorn_server):
        """Serve the API on a private event loop until the server is told to exit."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # serve() returns once server.should_exit has been set.
            loop.run_until_complete(server.serve())
        finally:
            loop.close()

    # daemon=True: a non-daemon thread is joined at interpreter shutdown
    # *before* atexit handlers run, so the never-ending server thread would
    # block exit and cleanup() would never get the chance to stop it.
    server_thread = Thread(target=run_api, daemon=True)
    server_thread.start()

    # Register cleanup at exit
    def cleanup():
        """Stop the server, close the ngrok tunnel and wait briefly for the thread."""
        stop_event.set()
        if uvicorn_server is not None:
            uvicorn_server.should_exit = True
        if ngrok_tunnel:
            ngrok.disconnect(ngrok_tunnel.public_url)
        if server_thread and server_thread.is_alive():
            # Bounded wait so that interpreter exit cannot hang on a stuck server.
            server_thread.join(timeout=10)

    atexit.register(cleanup)