# AITA Travel AI Chat - Gemma 3 API Server

This notebook sets up a FastAPI server that serves the Gemma 3 model to power the AITA travel chat assistant.

## How to Use This Notebook in Google Colab

### Step 1: Opening Google Colab
1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Sign in with your Google account
3. Click "New notebook" or upload this existing notebook

### Step 2: Naming Your Notebook
- Click on the notebook name at the top (usually "Untitled")
- Rename it to something descriptive like: `AITA_Gemma3_API_Server.ipynb`
- Or: `TravelAI_Chat_Demo.ipynb`

### Step 3: Runtime Setup
- Go to Runtime → Change runtime type
- Select **T4 GPU** for better performance (recommended)
- Click Save

In [None]:
# Test cell to confirm Google Colab environment is working
print("✅ Google Colab is working!")
print("🚀 Ready to set up AITA Travel AI Chat API")

# Check Python version
import sys
print(f"Python version: {sys.version}")

# Check if we're in Colab
try:
    import google.colab
    print("📍 Running in Google Colab")
except ImportError:
    print("📍 Not running in Google Colab")
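Step 3 asks for a T4 GPU, but the test cell above does not confirm that one was actually attached. A minimal check, assuming PyTorch is available (it is preinstalled on Colab); `describe_accelerator` is a hypothetical helper name, not part of the original notebook:

```python
def describe_accelerator():
    """Report whether PyTorch is installed and can see a CUDA GPU."""
    try:
        import torch
    except ImportError:
        return {"torch_installed": False, "cuda": False, "gpu_name": None}
    if not torch.cuda.is_available():
        return {"torch_installed": True, "cuda": False, "gpu_name": None}
    # Device 0 is the GPU Colab attaches (e.g. the T4 selected in Step 3)
    return {"torch_installed": True, "cuda": True, "gpu_name": torch.cuda.get_device_name(0)}


info = describe_accelerator()
if info["cuda"]:
    print(f"🎮 GPU detected: {info['gpu_name']}")
else:
    print("⚠️ No GPU detected; the model will run on CPU (much slower). Check Runtime → Change runtime type.")
```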

## 📦 Step 1: Install Required Packages

Run this cell first to install all the necessary packages for the API server.

In [None]:
# Install required packages (Gemma 3 needs transformers 4.50.0 or newer)
!pip install fastapi uvicorn "transformers>=4.50.0" torch pyngrok nest-asyncio accelerate bitsandbytes
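Colab ships its own preinstalled `transformers`, and to my knowledge Gemma 3 support arrived in release 4.50.0. A hedged sketch that checks the installed version without extra dependencies; `version_tuple` and `MIN_TRANSFORMERS` are hypothetical names, and the simple parser only compares the first three numeric components:

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile

MIN_TRANSFORMERS = "4.50.0"  # assumed minimum release with Gemma 3 support


def version_tuple(text):
    """Turn '4.50.0' into (4, 50, 0); suffixes like 'rc1' or '.dev0' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        # Keep only the leading digits of each component ("0rc1" -> "0")
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


try:
    installed = version("transformers")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("❌ transformers is not installed")
elif version_tuple(installed) < version_tuple(MIN_TRANSFORMERS):
    print(f"⚠️ transformers {installed} is older than {MIN_TRANSFORMERS}; upgrade it and restart the runtime")
else:
    print(f"✅ transformers {installed}")
```

If the old version is still reported right after upgrading, restart the runtime so Python imports the newly installed package.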

## 📚 Step 2: Import Libraries and Setup

In [None]:
from fastapi import FastAPI, Request, HTTPException
from transformers import pipeline, AutoTokenizer
import nest_asyncio
from pyngrok import ngrok
import uvicorn
import torch

# Allow nested event loops (required for Colab)
nest_asyncio.apply()

print("✅ All libraries imported successfully!")

## 🚀 Step 3: Initialize FastAPI App with CORS

In [None]:
# Initialize FastAPI app
app = FastAPI()

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

print("✅ FastAPI app initialized with CORS middleware")

## 🧠 Step 4: Load Gemma 3 Model

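**This will take a few minutes to download and load the model.**

`google/gemma-3-1b-it` is a gated model on Hugging Face: accept the Gemma license on the model page and make an access token available (for example, as an `HF_TOKEN` secret in Colab's Secrets panel with notebook access enabled) before running this cell; otherwise the download will fail.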

In [None]:
# Initialize model and tokenizer
model_name = "google/gemma-3-1b-it"
print(f"Loading model: {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline(
    "text-generation", 
    model=model_name, 
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # Use GPU if available
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)

print("✅ Model loaded successfully!")
print(f"Using device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

## 🔧 Step 5: Helper Functions

In [None]:
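def extract_content_text(content):
    """Return the text of a message's content.

    Accepts either a plain string or an array of parts such as
    [{"type": "text", "text": "..."}]. All text parts are joined with
    newlines; non-text parts are ignored.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        return "\n".join(part for part in parts if part)
    return ""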

print("✅ Helper functions defined")

## 🌐 Step 6: Define API Endpoints

In [None]:
@app.post("/generate")
async def generate(request: Request):
    try:
        data = await request.json()
        messages = data.get("messages", [])
        
        if not messages:
            raise HTTPException(status_code=400, detail="No messages provided")
        
        # Convert messages to proper format for Gemma 3
        formatted_messages = []
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", [])
            
            if role in ["system", "user", "assistant"]:
                text_content = extract_content_text(content)
                if text_content:
                    formatted_messages.append({
                        "role": role,
                        "content": text_content
                    })
        
        if not formatted_messages:
            raise HTTPException(status_code=400, detail="No valid messages found")
        
        # Use tokenizer's chat template for proper formatting
        try:
            # Apply chat template
            prompt = tokenizer.apply_chat_template(
                formatted_messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as e:
            # Fallback to manual formatting if chat template fails
            prompt = ""
            for msg in formatted_messages:
                if msg["role"] == "system":
                    prompt += f"System: {msg['content']}\n"
                elif msg["role"] == "user":
                    prompt += f"User: {msg['content']}\n"
                elif msg["role"] == "assistant":
                    prompt += f"Assistant: {msg['content']}\n"
            prompt += "Assistant:"
        
        # Generate response. return_full_text=False makes the pipeline return
        # only the newly generated text, so the prompt does not need stripping.
        result = pipe(
            prompt,
            max_new_tokens=512,  # Cap response length to keep latency manageable
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id
        )
        
        response_text = result[0]["generated_text"].strip()
        
        return {"result": response_text}
        
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error in generate endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.get("/")
async def root():
    return {"message": "Gemma 3 API is running"}

@app.get("/health")
async def health():
    return {"status": "healthy", "model": model_name}

print("✅ API endpoints defined")
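The manual fallback in `generate()` runs only when `apply_chat_template()` raises, so its output is easy to overlook. A sketch that factors the same formatting into a pure function, so the prompt can be inspected without loading the model; `build_fallback_prompt` and `ROLE_LABELS` are hypothetical names, and the endpoint above does not call them:

```python
# Role labels used by the endpoint's manual fallback formatting
ROLE_LABELS = {"system": "System", "user": "User", "assistant": "Assistant"}


def build_fallback_prompt(messages):
    """Format [{"role": ..., "content": "..."}] as 'Role: text' lines.

    Produces the same string as the fallback loop in generate():
    one line per known role, followed by a trailing 'Assistant:' cue.
    """
    lines = [
        f"{ROLE_LABELS[msg['role']]}: {msg['content']}"
        for msg in messages
        if msg.get("role") in ROLE_LABELS
    ]
    # The trailing cue asks the model to continue as the assistant
    lines.append("Assistant:")
    return "\n".join(lines)


print(build_fallback_prompt([
    {"role": "system", "content": "You are AITA, a helpful AI travel assistant."},
    {"role": "user", "content": "Hello! Can you help me plan a trip to Paris?"},
]))
```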

## 🌍 Step 7: Start Ngrok Tunnel

**Copy the Public URL from the output below and use it in your React Native app!**
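Current ngrok releases refuse to open a tunnel without an account auth token, so `ngrok.connect()` in the next cell fails until one is configured. A sketch that reads the token from Colab's Secrets panel or an environment variable and passes it to pyngrok's `ngrok.set_auth_token()`; the secret name `NGROK_AUTHTOKEN` and the `read_secret` helper are assumptions, so use whatever name you stored your token under:

```python
import os


def read_secret(name):
    """Return a secret from Colab's Secrets panel, else from the environment, else None."""
    try:
        from google.colab import userdata  # only importable inside Colab

        value = userdata.get(name)
        if value:
            return value
    except Exception:
        # Not in Colab, secret missing, or notebook access not granted
        pass
    return os.environ.get(name)


token = read_secret("NGROK_AUTHTOKEN")  # assumed secret name
if token:
    from pyngrok import ngrok

    ngrok.set_auth_token(token)
    print("✅ ngrok auth token configured")
else:
    print("⚠️ NGROK_AUTHTOKEN not found; ngrok.connect() will likely fail")
```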

In [None]:
# Start ngrok tunnel
try:
    # Kill any existing ngrok process and its tunnels
    ngrok.kill()
    
    # connect() returns an NgrokTunnel object; its public_url attribute holds the URL string
    tunnel = ngrok.connect(8000)
    public_url = tunnel.public_url
    print("🎉 Ngrok tunnel started!")
    print(f"🔗 Public URL: {public_url}")
    print("📱 Use this URL in your React Native app!")
    print("📋 Copy this URL and replace it in your chatAI.tsx file")
except Exception as e:
    print(f"❌ Error starting ngrok: {e}")

## 🚀 Step 8: Start FastAPI Server

**Run this cell to start the server. It will keep running until you stop it.**

**Note:** This cell blocks (Colab shows a running indicator) until you stop it. The server is ready when you see the "Uvicorn running on http://0.0.0.0:8000" message.

In [None]:
# Start FastAPI server
print("🚀 Starting FastAPI server...")
print("🔄 This cell will run continuously...")
print("⏹️ Use Runtime > Interrupt execution to stop the server")

uvicorn.run(app, host="0.0.0.0", port=8000)
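Because `uvicorn.run()` never returns, nothing else can run in this notebook while the server is up. One workaround, sketched under the assumption that uvicorn's `Server` object exposes `started` and `should_exit` as in current releases, is to run the server on a background thread instead of running the Step 8 cell (not in addition to it, or the two would compete for port 8000); `start_in_background` is a hypothetical helper name:

```python
import threading
import time


def start_in_background(server, timeout=60.0):
    """Run server.run() on a daemon thread and wait until server.started is True.

    Works with objects shaped like uvicorn.Server, which sets `started`
    once its sockets are bound and shuts down when `should_exit` is set.
    """
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()
    deadline = time.monotonic() + timeout
    while not getattr(server, "started", False):
        if not thread.is_alive():
            raise RuntimeError("server thread exited before startup finished")
        if time.monotonic() > deadline:
            raise TimeoutError(f"server did not start within {timeout} seconds")
        time.sleep(0.1)
    return thread


# Notebook usage: only runs when the FastAPI `app` from Step 3 exists in this kernel
if "app" in globals():
    import uvicorn

    server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=8000))
    server_thread = start_in_background(server)
    print("✅ Uvicorn is serving in the background on port 8000")
    # To stop it later: server.should_exit = True; server_thread.join()
```

With the server on a background thread, the Step 9 cell can run in the same notebook; the daemon thread ends when the runtime restarts.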

## 🧪 Step 9: Test Your API (Optional)

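Run this cell to check that the `/generate` endpoint responds. Note that `uvicorn.run()` in Step 8 blocks this notebook's kernel: opening the notebook in a second tab does not help, because both tabs share the same runtime, and interrupting the server leaves nothing listening on port 8000. To use this cell, run the server without blocking the kernel (for example, in a background thread).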

In [None]:
import requests
import json

# Test the API with a simple request
def test_api():
    try:
        # Get the current ngrok URL
        tunnels = ngrok.get_tunnels()
        if tunnels:
            url = tunnels[0].public_url
            
            test_data = {
                "messages": [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": "You are AITA, a helpful AI travel assistant."}]
                    },
                    {
                        "role": "user", 
                        "content": [{"type": "text", "text": "Hello! Can you help me plan a trip to Paris?"}]
                    }
                ]
            }
            
            print("🧪 Testing API...")
            response = requests.post(f"{url}/generate", json=test_data, timeout=30)
            print(f"✅ Status Code: {response.status_code}")
            
            if response.status_code == 200:
                result = response.json()
                print("🎉 API Response:")
                print(f"📝 {result.get('result', 'No result')}")
            else:
                print(f"❌ Error: {response.text}")
        else:
            print("❌ No active ngrok tunnels found")
    except Exception as e:
        print(f"❌ Test failed: {e}")

# Run the test
test_api()

## 📮 How to Test with Postman

Follow these steps to test your API using Postman:

### 1. Setup Postman Request
- **Method:** `POST`
- **URL:** Use the ngrok URL from Step 7 + `/generate`
  - Example: `https://abc123-34-134-35-25.ngrok-free.app/generate`

### 2. Set Headers
```
Content-Type: application/json
ngrok-skip-browser-warning: true
```

### 3. Request Body (Raw JSON)
Use this exact format for testing:

In [None]:
# Copy this JSON for Postman testing:

basic_test = {
    "messages": [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are AITA, a helpful AI travel assistant."}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "What are the best places to visit in Paris?"}]
        }
    ]
}

# Print formatted JSON for easy copying
import json
print("📋 Copy this JSON for Postman:")
print("=" * 50)
print(json.dumps(basic_test, indent=2))
print("=" * 50)

### 4. More Test Examples

**Travel Planning:**

In [None]:
# Travel Planning Test
travel_test = {
    "messages": [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are AITA, a helpful AI travel assistant."}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "I'm planning a 5-day trip to Tokyo. What should I include in my itinerary?"}]
        }
    ]
}

# Multi-turn Conversation Test
conversation_test = {
    "messages": [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are AITA, a helpful AI travel assistant."}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "I want to visit Europe this summer"}]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "That sounds exciting! Europe has so much to offer. What type of experience are you looking for - cultural sites, food, nightlife, or outdoor activities?"}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "I love historical sites and museums"}]
        }
    ]
}

print("🌍 Travel Planning Test:")
print(json.dumps(travel_test, indent=2))
print("\n" + "="*50 + "\n")
print("💬 Multi-turn Conversation Test:")
print(json.dumps(conversation_test, indent=2))
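The examples above repeat the same nested `content` structure for every message. A small helper can build these bodies from plain strings; `make_request_body` is a hypothetical name, and taking the turns as `(role, text)` tuples is an assumption made for convenience:

```python
import json


def make_request_body(system_prompt, turns):
    """Build a /generate request body in the content-array format used above.

    turns is a list of (role, text) pairs, e.g. [("user", "Hi")].
    """
    messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
    for role, text in turns:
        if role not in ("user", "assistant"):
            raise ValueError(f"unsupported role: {role!r}")
        messages.append({"role": role, "content": [{"type": "text", "text": text}]})
    return {"messages": messages}


body = make_request_body(
    "You are AITA, a helpful AI travel assistant.",
    [("user", "What are the best places to visit in Paris?")],
)
print(json.dumps(body, indent=2))
```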

### 5. Expected Response Format

You should get a response like this:
```json
{
  "result": "For a 5-day Tokyo trip, I'd recommend visiting..."
}
```

### 6. Troubleshooting Common Issues

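**❌ HTML warning page instead of JSON:** ngrok's free tier can return a browser warning page; add the `ngrok-skip-browser-warning: true` header

**❌ 404 Not Found:** Make sure your ngrok URL is current, the path is `/generate`, and the server is running

**❌ 500 Internal Server Error:** Check the output of the Step 8 server cell for the error message

**❌ Timeout:** The model may still be generating; wait and retry (the Step 9 test cell gives up after 30 seconds)

**❌ CORS Error:** CORS is already configured to allow all origins, so this should not happen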

### 7. Quick Health Check

Test this endpoint first to make sure the server is running:
- **Method:** `GET`  
- **URL:** `{your-ngrok-url}/health`
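The same health check can be scripted with Python's standard library, which is handy for polling after a restart. A minimal sketch; `check_health` is a hypothetical helper, and it sends the `ngrok-skip-browser-warning` header from the Postman setup:

```python
import json
import urllib.request


def check_health(base_url, timeout=10.0):
    """GET {base_url}/health and return the decoded JSON body.

    Raises urllib.error.HTTPError for non-2xx responses (e.g. 404 when
    the tunnel is offline) and URLError when the host is unreachable.
    """
    req = urllib.request.Request(
        base_url.rstrip("/") + "/health",
        headers={"ngrok-skip-browser-warning": "true"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Example (replace with your own ngrok URL from Step 7):
# print(check_health("https://abc123-34-134-35-25.ngrok-free.app"))
```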