In [1]:
# Cell 1: Set up the environment and install dependencies
!pip install -U transformers>=4.38.1
!pip install accelerate>=0.21.0
!pip install bitsandbytes>=0.40.0
!pip install huggingface_hub>=0.17.0
!pip install fastapi uvicorn python-dotenv pyngrok
# Cell 2: Import necessary libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
import os
from fastapi import FastAPI
from pydantic import BaseModel
from dotenv import load_dotenv
import os
from pyngrok import ngrok
import uvicorn
import nest_asyncio
import logging



Collecting fastapi
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn
  Downloading uvicorn-0.34.2-py3-none-any.whl.metadata (6.5 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting pyngrok
  Downloading pyngrok-7.2.5-py3-none-any.whl.metadata (8.9 kB)
Collecting starlette<0.47.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.46.2-py3-none-any.whl.metadata (6.2 kB)
Downloading fastapi-0.115.12-py3-none-any.whl (95 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.2/95.2 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading uvicorn-0.34.2-py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading python_dotenv-1.1.0-py3-none-any.whl (20 kB)
Downloading pyngrok-7.2.5-py3-none-any.whl (23 kB)
Downloading starlette-0.46.2-py3-none-any.whl (72 kB)
[2K   [90m━━━━━━━━━━━━

In [2]:
# Cell 3: Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


from google.colab import userdata

HF_TOKEN = userdata.get('HF_TOKEN')
NGROK_AUTH_TOKEN = userdata.get('NGROK_AUTH_TOKEN')


if not HF_TOKEN or not NGROK_AUTH_TOKEN:
    raise ValueError("HF_TOKEN or NGROK_AUTH_TOKEN not found in .env file")

In [3]:
# Cell 5: Log in to Hugging Face
login(token=HF_TOKEN, add_to_git_credential=False)
logger.info("Logged in to Hugging Face")

# Cell 6: Configure the model for 4-bit quantization
model_id = "google/gemma-2b-it"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Cell 7: Load the tokenizer and model
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    logger.info("Model and tokenizer loaded successfully!")
except Exception as e:
    logger.error(f"Error loading model or tokenizer: {str(e)}")
    raise e

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [4]:
!pip freeze > requirements.txt

In [None]:



# Cell 8: Define Pydantic models for input and output
class PredictInput(BaseModel):
    prompt: str
    max_length: int = 200
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95

class PredictOutput(BaseModel):
    response: str
    prompt: str
    error: str | None = None

# Cell 9: Define inference function
def generate_response(prompt: str, max_length: int, temperature: float, top_k: int, top_p: float) -> str:
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Error: {str(e)}"

# Cell 10: Set up FastAPI app
app = FastAPI(title="Gemma AI API")

@app.post("/predict", response_model=PredictOutput)
async def predict(input_data: PredictInput):
    logger.info(f"Received request with prompt: {input_data.prompt}")
    response = generate_response(
        prompt=input_data.prompt,
        max_length=input_data.max_length,
        temperature=input_data.temperature,
        top_k=input_data.top_k,
        top_p=input_data.top_p
    )
    return PredictOutput(
        response=response,
        prompt=input_data.prompt,
        error=None if not response.startswith("Error:") else response
    )

# Cell 11: Start Ngrok and FastAPI server
def start_server():
    try:
        # Apply nest_asyncio to allow Uvicorn to run in Colab
        nest_asyncio.apply()

        # Set Ngrok auth token
        ngrok.set_auth_token(NGROK_AUTH_TOKEN)

        # Start Ngrok tunnel
        public_url = ngrok.connect(8000, bind_tls=True)
        logger.info(f"Ngrok tunnel started at: {public_url}")

        # Start FastAPI server
        uvicorn.run(app, host="0.0.0.0", port=8000)
    except Exception as e:
        logger.error(f"Error starting server: {str(e)}")
        raise e

# Cell 12: Run the server
if __name__ == "__main__":
    start_server()



INFO:     Started server process [1181]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     2a0d:6fc2:63f0:c200:9b49:3930:c16d:ac1a:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     2a0d:6fc2:63f0:c200:f9ee:add9:fa8d:517d:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "GET /predict HTTP/1.1" 405 Method Not Allowed
INFO:     2601:647:5b00:9600:8555:6600:9099:8966:0 - "POST /predict HTTP/1.1" 200 OK
INFO:     2601:647:5b00:9600:8555:6600