<a href="https://colab.research.google.com/github/ikramMc/PFE/blob/main/llm_deployement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Base environment for the notebook. `%%capture` hides the (very long) pip output,
# which also means installation errors are not shown - remove it when debugging.
# Helper package for removing a package together with its unused dependencies.
!pip install pip3-autoremove
# PyTorch stack and xformers from the CUDA 12.4 wheel index.
!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu124
# Unsloth provides FastLanguageModel, used later to load the model.
!pip install unsloth
# !pip install --upgrade transformers==4.52.3

In [None]:
%%capture
!pip installrequests fastapi uvicorn
!pip install pyngrok


In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
# Installs Unsloth and vLLM. Colab is detected by looking for "COLAB_" in the
# names of the environment variables; outside Colab a plain install is enough.
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # --no-deps keeps pip from replacing the packages Colab ships with.
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    # Already-imported PIL / google modules are dropped from sys.modules.
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # The upstream requirements file is downloaded and every line starting at
    # "transformers", "numpy" or "xformers" is stripped before installing it.
    # NOTE(review): the file comes from the `main` branch while vllm is pinned to
    # 0.8.5.post1 above, and the HTTP status is not checked - confirm this is intended.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
# Authenticate pyngrok with the token stored in the Colab secret named "ngrok".
# Colab secrets can only be read from an interactive Colab UI session.
from google.colab import userdata
from pyngrok import ngrok

ngrok_key = userdata.get("ngrok")
ngrok.set_auth_token(ngrok_key)


TimeoutException: Requesting secret ngrok timed out. Secrets can only be fetched when running from the Colab UI.

In [None]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import gc
from unsloth import FastLanguageModel
from typing import List, Dict, Optional
import time

# === Load the LLM model with Unsloth ===
model_path = "kimxxxx/mistral_r32_a64_b8_gas4_lr5e-5_4500tk_2epoch"
# Colab secrets are looked up by name, so the name has to be a string; the bare
# identifier HF_token was never defined and raised NameError.
token = userdata.get("HF_token")

# Load the fine-tuned checkpoint in 4-bit with float16 compute; the Hugging Face
# token is forwarded so that private or gated repositories can be downloaded.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=32000,
    dtype=torch.float16,
    load_in_4bit=True,
    token=token,
)

# === Create FastAPI app ===
app = FastAPI()

# === Schemas ===
# One chat message, used both for incoming conversation turns and for the
# generated reply.
class ChatCompletionMessage(BaseModel):
    # Speaker of the message; the handler uses "assistant" for generated replies.
    role: str
    # Text of the message.
    content: str
    # Optional speaker name; omitted by default.
    name: Optional[str] = None

# Request body accepted by POST /v1/chat/completions.
class ChatCompletionRequest(BaseModel):
    # Model label; the handler echoes it back in the response.
    model: str = "mistral"
    # Conversation so far; rendered through the tokenizer's chat template.
    messages: List[ChatCompletionMessage]
    # Sampling temperature forwarded to model.generate.
    temperature: Optional[float] = 0.4
    # Nucleus-sampling threshold forwarded to model.generate.
    top_p: Optional[float] = 0.9
    # Forwarded to model.generate as max_new_tokens.
    max_tokens: Optional[int] = 100

# One generated alternative inside a chat completion response.
class ChatCompletionChoice(BaseModel):
    # Position of this choice in the response's `choices` list.
    index: int
    # The generated assistant message.
    message: ChatCompletionMessage
    # Why generation ended, e.g. "stop".
    finish_reason: str

# Response body returned by POST /v1/chat/completions.
class ChatCompletionResponse(BaseModel):
    # Response identifier; the handler builds it as "chatcmpl-<unix time>".
    id: str
    # Object type tag.
    object: str = "chat.completion"
    # Creation time as a Unix timestamp in seconds.
    created: int
    # Model label copied from the request.
    model: str
    # Generated alternatives.
    choices: List[ChatCompletionChoice]
    # Token counts: "prompt_tokens", "completion_tokens" and "total_tokens".
    usage: Dict[str, int]

# === Health Check ===
# Liveness probe: answers with a fixed "healthy" payload plus the current Unix
# time in whole seconds. It does not inspect the model itself.
@app.get("/health")
async def health_check():
    now = int(time.time())
    payload = {"status": "healthy", "model_loaded": True, "timestamp": now}
    return payload

# === Chat Completions ===
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Generate one assistant reply for a chat completion request.

    The conversation is rendered with the tokenizer's chat template, passed to
    ``model.generate`` and only the newly generated tokens are decoded.

    Args:
        request: Conversation and generation settings. A ``temperature`` of 0
            (or below) selects greedy decoding instead of sampling.

    Returns:
        ChatCompletionResponse with a single choice and token usage counts.
        ``finish_reason`` is "length" when generation was cut off by
        ``max_tokens`` and "stop" otherwise.

    Raises:
        HTTPException: status 500 carrying the error text when anything in the
            pipeline fails.
    """
    # Pre-declared so the `finally` clause can always drop the tensor references.
    inputs = outputs = generated_tokens = None
    try:
        # --- Apply chat template ---
        formatted_prompt = tokenizer.apply_chat_template(
            [msg.dict() for msg in request.messages],
            tokenize=False,
            add_generation_prompt=True,
        )
        print("📝 Prompt:", formatted_prompt)

        # --- Tokenize ---
        # NOTE(review): the rendered prompt may already contain the BOS token, in
        # which case tokenizing again could add a second one - confirm whether
        # add_special_tokens=False is needed for this model's template.
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        input_length = inputs["input_ids"].shape[1]

        # --- Choose decoding strategy ---
        # Sampling needs a strictly positive temperature; temperature == 0 is the
        # conventional way to ask for deterministic output, so map it to greedy
        # decoding instead of letting generation fail.
        use_greedy = request.temperature is not None and request.temperature <= 0
        if use_greedy:
            generation_kwargs = {"do_sample": False}
        else:
            generation_kwargs = {
                "do_sample": True,
                "temperature": request.temperature,
                "top_p": request.top_p,
            }

        # --- Generate text ---
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                **generation_kwargs,
            )

        # --- Keep only the generated tokens ---
        generated_tokens = outputs[0][input_length:]
        completion_tokens = len(generated_tokens)

        # --- Decode result ---
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("🤖 Response:", generated_text)

        # --- Work out why generation ended ---
        # Reaching the token budget without ending on EOS means the reply was cut.
        hit_limit = (
            request.max_tokens is not None
            and completion_tokens >= request.max_tokens
        )
        ended_on_eos = (
            completion_tokens > 0
            and generated_tokens[-1].item() == tokenizer.eos_token_id
        )
        finish_reason = "length" if hit_limit and not ended_on_eos else "stop"

        # --- Build API response ---
        created = int(time.time())
        response_message = ChatCompletionMessage(role="assistant", content=generated_text)

        return ChatCompletionResponse(
            id=f"chatcmpl-{created}",
            created=created,
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=response_message,
                    finish_reason=finish_reason,
                )
            ],
            usage={
                "prompt_tokens": input_length,
                "completion_tokens": completion_tokens,
                "total_tokens": input_length + completion_tokens,
            },
        )

    except Exception as e:
        print("❌ Error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Drop every tensor reference (generated_tokens is a view of outputs)
        # before collecting, otherwise the cached GPU memory cannot be released.
        inputs = outputs = generated_tokens = None
        gc.collect()
        torch.cuda.empty_cache()


In [None]:
# Lifecycle management for the API server thread and its public ngrok tunnel.
import atexit
from threading import Event

# Handles shared across executions of this cell. They are read back from
# globals() instead of being reset to None, so that re-running the cell can
# still find - and shut down - whatever the previous run started.
server_thread = globals().get("server_thread")
stop_event = globals().get("stop_event") or Event()
ngrok_tunnel = globals().get("ngrok_tunnel")
api_server = globals().get("api_server")

# === ngrok setup ===
if __name__ == "__main__":
    import asyncio
    from threading import Thread

    import nest_asyncio
    from pyngrok import ngrok
    from uvicorn import Config, Server

    # Avoid stacking one exit handler per re-run of the cell.
    previous_cleanup = globals().get("cleanup")
    if previous_cleanup is not None:
        atexit.unregister(previous_cleanup)

    def cleanup():
        """Stop the uvicorn server, wait for its thread and close the tunnel."""
        global server_thread, ngrok_tunnel, api_server

        stop_event.set()
        # uvicorn only leaves serve() once should_exit is set; without this the
        # join below would wait forever.
        if api_server is not None:
            api_server.should_exit = True
        if server_thread and server_thread.is_alive():
            server_thread.join(timeout=10)
            if server_thread.is_alive():
                print("⚠️ API server thread is still shutting down")
        if ngrok_tunnel:
            try:
                ngrok.disconnect(ngrok_tunnel.public_url)
            except Exception as exc:
                # Best effort: the tunnel may already be gone.
                print("⚠️ Could not disconnect ngrok tunnel:", exc)

        server_thread = None
        ngrok_tunnel = None
        api_server = None

    # Cleanup any previous instances, then reset for the new run
    cleanup()
    stop_event.clear()

    # 1. Public tunnel with ngrok
    ngrok_tunnel = ngrok.connect(8001)
    print("✅ Public URL:", ngrok_tunnel.public_url)

    # 2. Patch asyncio for Jupyter environments
    nest_asyncio.apply()

    # 3. Run server in a thread with its own event loop
    api_server = Server(
        config=Config(app=app, host="0.0.0.0", port=8001, log_level="info")
    )

    def run_api(server=None):
        """Serve the API on a private event loop until the server is told to exit.

        Args:
            server: uvicorn Server to run; defaults to the module-level
                ``api_server``.
        """
        server = api_server if server is None else server
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # serve() returns once should_exit is set; a finished Server is not
            # restarted, which previously produced a busy loop.
            loop.run_until_complete(server.serve())
        finally:
            loop.close()

    # Daemon thread: a still-running server must not block interpreter exit.
    server_thread = Thread(target=run_api, args=(api_server,), daemon=True)
    server_thread.start()

    # Register cleanup at exit
    atexit.register(cleanup)