In [None]:
# Install dependencies
!pip install transformers torch flask pyngrok sentencepiece -q

In [None]:
# Load the model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Hugging Face Hub identifier used for both the tokenizer and the model weights.
MODEL_NAME = "tzaware/codet5p-spider-finetuned"

print(f"Loading tokenizer and model from {MODEL_NAME}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
except Exception as exc:
    # Catch Exception instead of a bare `except:` so KeyboardInterrupt and
    # SystemExit still abort the cell, and report the reason so that a silent
    # tokenizer swap does not go unnoticed.
    print(f"Could not load tokenizer from {MODEL_NAME}: {type(exc).__name__}: {exc}")
    print("Using fallback tokenizer")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
# This notebook only serves predictions, so put the module in evaluation mode.
model.eval()

print(f"âœ“ Model loaded on {device}")

In [None]:
# Create Flask API
from flask import Flask, request, jsonify

# Application object that the @app.route decorators in this cell register
# their handlers on, and that app.run() later serves.
app = Flask(__name__)

@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report a healthy status and the device the model was placed on."""
    payload = {"status": "healthy", "device": device}
    return jsonify(payload)

@app.route('/generate', methods=['POST'])
def generate():
    """Generate text for the prompt supplied in a JSON request body.

    Request JSON object:
        prompt (str, optional): text passed to the tokenizer. Defaults to ''.
        max_new_tokens (int, optional): generation budget. Defaults to 128,
            must be >= 1, and is capped at ``max_new_tokens_limit``.

    Returns:
        200 with {"generated_sql": <str>, "status": "success"} on success.
        400 with {"error": <str>, "status": "error"} when the body is not a
            JSON object or a field has the wrong type.
        500 with {"error": <str>, "status": "error"} when tokenization,
            generation or decoding raises.
    """
    # Upper bound on what a client may request; the endpoint is reachable
    # publicly through the tunnel, so the budget must not be unbounded.
    max_new_tokens_limit = 1024

    def _bad_request(message):
        # Client mistakes are reported as 400 with the same payload shape
        # that the 500 path uses.
        return jsonify({
            "error": message,
            "status": "error"
        }), 400

    # silent=True yields None for a missing, non-JSON or unparsable body
    # instead of raising, so these cases can be answered with a 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _bad_request("Request body must be a JSON object")

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str):
        return _bad_request("'prompt' must be a string")

    max_new_tokens = data.get('max_new_tokens', 128)
    # bool is a subclass of int, so JSON true/false must be rejected explicitly.
    if (
        isinstance(max_new_tokens, bool)
        or not isinstance(max_new_tokens, int)
        or max_new_tokens < 1
    ):
        return _bad_request("'max_new_tokens' must be a positive integer")
    max_new_tokens = min(max_new_tokens, max_new_tokens_limit)

    try:
        # Tokenize
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        ).to(device)

        # A pad id of 0 is valid but falsy, so test against None explicitly
        # rather than using `or`, which would replace it with the EOS id.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                max_new_tokens=max_new_tokens,
                num_beams=4,
                early_stopping=True,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode the first returned sequence only.
        sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return jsonify({
            "generated_sql": sql_query.strip(),
            "status": "success"
        })

    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        app.logger.exception("Generation failed")
        return jsonify({
            "error": str(e),
            "status": "error"
        }), 500

# Confirms that the app and its routes were defined; nothing in this cell
# starts the server.
print("âœ“ Flask API created")

In [None]:
# Setup ngrok
from pyngrok import ngrok
import getpass

# Get ngrok auth token.
# getpass prompts without echoing, so the token is neither stored in the
# notebook source nor shown in the cell output.
print("Get your ngrok auth token from: https://dashboard.ngrok.com/get-started/your-authtoken")
ngrok_token = getpass.getpass("Enter your ngrok auth token: ")
ngrok.set_auth_token(ngrok_token)

# Start ngrok tunnel.
# Recent pyngrok releases return an NgrokTunnel object whose string form is a
# repr, not the URL, so read its `public_url` attribute. Older releases
# returned the URL string directly, which the getattr default passes through.
tunnel = ngrok.connect(5000)
public_url = getattr(tunnel, "public_url", tunnel)
print("\n" + "=" * 60)
print("ðŸš€ Model Server Running!")
print(f"Public URL: {public_url}")
print("=" * 60)
print("\nUse this URL in your backend config:")
print(f"CUSTOM_MODEL_API_URL = '{public_url}'")
print(f"\nTest it: curl {public_url}/health")

In [None]:
# Run the Flask server
from flask import Flask
import threading

def run_app():
    """Serve the Flask app on port 5000 with the debugger and the reloader turned off."""
    server_options = {
        "port": 5000,
        "debug": False,
        "use_reloader": False,
    }
    app.run(**server_options)

# Start server in background thread.
# daemon=True lets the thread be abandoned when the kernel process exits.
server_thread = threading.Thread(target=run_app, daemon=True)
server_thread.start()

print("\nâœ“ Server is running! Keep this cell running...")
print("Press Ctrl+C to stop the server")

# Keep the cell running for as long as the server thread is alive.
# Polling is_alive() instead of looping unconditionally means a server that
# failed to start (for example because port 5000 is already in use) ends the
# cell with a message rather than leaving it waiting forever.
import time
try:
    while server_thread.is_alive():
        time.sleep(1)
    print("\nServer thread exited unexpectedly; check the output above for errors")
except KeyboardInterrupt:
    print("\nServer stopped")