In [None]:
# ✅ STEP 0: Install necessary packages
!pip install transformers accelerate fastapi uvicorn nest-asyncio pyngrok --quiet

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import nest_asyncio
from pyngrok import ngrok, conf

# 🔐 STEP 1: ENTER YOUR NGROK AUTH TOKEN
# getpass hides the token while it is typed/pasted, so the secret is not
# echoed into (and saved with) the notebook output. strip() drops stray
# whitespace/newlines picked up when pasting.
from getpass import getpass

ngrok_token = getpass("Paste your ngrok auth token: ").strip()
if not ngrok_token:
    # Fail fast here instead of after the slow model download in STEP 2.
    raise ValueError("An ngrok auth token is required to expose the API.")
conf.get_default().auth_token = ngrok_token

# 🔄 STEP 2: LOAD THE PHI-2 MODEL
model_name = "microsoft/phi-2"
# float16 halves GPU memory, but device_map="auto" may place the model on CPU
# when no GPU is present, and many CPU kernels lack half-precision support.
# Use float32 in that case.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
    device_map="auto",  # let accelerate pick the available device(s)
)

# ⚙️ STEP 3: SET UP FASTAPI
app = FastAPI()

# CORS policy so a browser frontend served from any origin can call this API.
# Wildcards open up every origin, HTTP method and request header;
# allow_credentials is left at its default.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# 📦 Define input schema
class GoalInput(BaseModel):
    """JSON request body accepted by the POST /generate endpoint."""

    # Free-text dietary goal supplied by the client; it is interpolated
    # directly into the model prompt by the endpoint.
    goal: str

# 🔮 STEP 4: DIET PLAN GENERATION ENDPOINT
@app.post("/generate")
def generate_diet_plan(data: GoalInput):
    """Generate a one-day meal plan for the requested goal.

    Declared as a plain ``def`` rather than ``async def`` on purpose:
    ``model.generate`` is a long blocking call. FastAPI runs sync path
    operations in a worker thread, so the event loop stays responsive while
    the model generates.

    Args:
        data: Request body; ``data.goal`` is inserted into the prompt.

    Returns:
        ``{"meal_plan": text}``, where ``text`` is only the model's
        continuation after the prompt, with surrounding whitespace stripped.
    """
    goal = data.goal

    # Modified prompt to force output format
    prompt = f"""You are a certified dietitian. Generate a one-day meal plan ONLY based on the goal: {goal}.
Do NOT include any explanation or extra text.

Respond ONLY in this exact format:
Breakfast:
Lunch:
Dinner:
Snack:"""

    # Move inputs to whatever device holds the model instead of hard-coding
    # "cuda", so this also works when device_map="auto" chose the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,       # Enable sampling for diverse outputs
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        # Explicit pad id for open-ended generation, so generate() does not
        # have to fall back (with a warning) when no pad token is configured.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Stripping the prompt via string
    # replace is fragile: decoding can normalise spacing (e.g. " ," -> ","),
    # in which case the prompt text no longer matches and leaks into the reply.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_length:]
    cleaned = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return {"meal_plan": cleaned}

# 🚀 STEP 5: Start the FastAPI app with ngrok
PORT = 8000  # single source of truth for the tunnel target and the server port

# nest_asyncio patches asyncio so uvicorn.run() can start its event loop even
# though the notebook kernel is already running one.
nest_asyncio.apply()

# ngrok.connect() returns an NgrokTunnel object whose str() is a descriptive
# 'NgrokTunnel: "<url>" -> "<addr>"' string, so read .public_url for the URL.
tunnel = ngrok.connect(PORT)
public_url = tunnel.public_url
print(f"🔗 Public API URL: {public_url}/generate")

# Run the app; uvicorn.run blocks until the server stops. Always close the
# tunnel afterwards so re-running the cell doesn't accumulate open tunnels.
try:
    uvicorn.run(app, port=PORT)
finally:
    ngrok.disconnect(public_url)
