In [20]:
import os
import torch
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from fastapi.middleware.cors import CORSMiddleware
from functools import lru_cache

#### Optional transliteration fallback (no indictrans2)

In [21]:
# Optional dependency: when indic_transliteration cannot be imported, both
# names are bound to None so later code can test for availability.
try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate as itransliterate
except Exception:
    # Any failure while importing disables the transliteration fallback.
    sanscript = None
    itransliterate = None

In [22]:
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint id passed to from_pretrained() by load_tokenizer() / load_model().
MODEL_ID = "ai4bharat/indictrans2-en-indic-1B"
# Checkpoint id passed to from_pretrained() by the *_indic_en loaders.
INDIC_EN_MODEL_ID = "ai4bharat/indictrans2-indic-en-1B"

In [23]:
# Short request code -> "<iso3>_<Script>" tag that is prefixed to model input.
LANGUAGE_TAGS = dict([
    ("en", "eng_Latn"),
    ("hi", "hin_Deva"),
    ("ta", "tam_Taml"),
    ("te", "tel_Telu"),
    ("kn", "kan_Knda"),
    ("ml", "mal_Mlym"),
    ("bn", "ben_Beng"),
    ("mr", "mar_Deva"),
    ("gu", "guj_Gujr"),
    # if "ory_Orya" fails for your model snapshot, try "ori_Orya"
    ("or", "ory_Orya"),
    ("pa", "pan_Guru"),
])

# The part of each tag before the underscore (language) ...
LANGUAGE_ISO3 = {code: tag.partition("_")[0] for code, tag in LANGUAGE_TAGS.items()}
# ... and the part after it (script key used by SCRIPT_RANGES).
LANGUAGE_SCRIPT = {code: tag.partition("_")[2] for code, tag in LANGUAGE_TAGS.items()}

# Inclusive (first, last) code point per script key, used as a sanity check.
SCRIPT_RANGES = {
    # Coarse Latin range (A-z); it also spans the few punctuation code points
    # that sit between "Z" and "a".
    "Latn": (0x0041, 0x007A),
    "Deva": (0x0900, 0x097F),
    "Beng": (0x0980, 0x09FF),
    "Guru": (0x0A00, 0x0A7F),
    "Gujr": (0x0A80, 0x0AFF),
    "Orya": (0x0B00, 0x0B7F),
    "Taml": (0x0B80, 0x0BFF),
    "Telu": (0x0C00, 0x0C7F),
    "Knda": (0x0C80, 0x0CFF),
    "Mlym": (0x0D00, 0x0D7F),
}

In [24]:
def looks_like_script(s: str, script: str) -> bool:
    """Report whether ``s`` contains at least one character of ``script``.

    ``script`` is a key of ``SCRIPT_RANGES``. A key that is not in the table
    yields ``True`` so that unknown scripts never block the caller. An empty
    string yields ``False`` for any known script.
    """
    low, high = SCRIPT_RANGES.get(script, (None, None))
    if low is None:
        # Unknown script key: nothing to check against.
        return True
    for char in s:
        if low <= ord(char) <= high:
            return True
    return False

# Script key (as used in SCRIPT_RANGES) -> indic-transliteration scheme constant.
# Stays None when the optional package is unavailable; a scheme attribute that
# the installed package does not define maps to None.
SANSCRIPT_MAP = None
if sanscript is not None:
    SANSCRIPT_MAP = {
        script_key: getattr(sanscript, scheme_attr, None)
        for script_key, scheme_attr in (
            ("Deva", "DEVANAGARI"),
            ("Beng", "BENGALI"),
            ("Guru", "GURMUKHI"),
            ("Gujr", "GUJARATI"),
            ("Orya", "ORIYA"),  # a.k.a. Odia
            ("Taml", "TAMIL"),
            ("Telu", "TELUGU"),
            ("Knda", "KANNADA"),
            ("Mlym", "MALAYALAM"),
        )
    }

In [25]:
def transliterate_if_needed(text: str, target_script: str) -> str:
    """Best-effort conversion of Devanagari output into ``target_script``.

    The text is returned unchanged when it already contains characters of the
    target script, when it contains no Devanagari, when the optional
    indic-transliteration package (or the needed scheme) is unavailable, or
    when the conversion itself raises.
    """
    if looks_like_script(text, target_script):
        return text

    can_convert = looks_like_script(text, "Deva") and itransliterate and SANSCRIPT_MAP
    if not can_convert:
        return text

    deva_scheme = SANSCRIPT_MAP.get("Deva")
    target_scheme = SANSCRIPT_MAP.get(target_script)
    if not (deva_scheme and target_scheme):
        return text

    try:
        return itransliterate(text, deva_scheme, target_scheme)
    except Exception:
        # Conversion is optional; fall back to the untouched model output.
        return text


In [26]:
# -----------------------------
# App lifecycle
# -----------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm the tokenizer, model and pipeline caches before serving.

    A failure while pre-loading is only printed as a warning; the application
    still starts. A message is printed again once the context exits.
    """
    try:
        print("Starting up... Loading model/tokenizer/pipelines")
        # Each getter caches its result, so these calls only pay once.
        get_tokenizer()
        get_model()
        get_translation_pipe_en_to_indic()
        get_translation_pipe_indic_to_en()
        print("Resources loaded successfully")
    except Exception as preload_error:
        print(f"Warning: Could not pre-load resources: {preload_error}")

    yield

    print("Shutting down...")

app = FastAPI(lifespan=lifespan)

In [27]:
# Static & templates
# Serve files from the "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are looked up in the "templates" directory.
templates = Jinja2Templates(directory="templates")

In [28]:
# CORS (relax in dev)
# Any origin may call the API, but without credentials. The FastAPI CORS
# documentation states that allow_origins must not be ["*"] when
# allow_credentials is True; this service defines no cookie- or session-based
# routes, so credentials are turned off rather than narrowing the origins.
# To support credentialed cross-origin calls, list the allowed origins
# explicitly and set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

In [29]:
# -----------------------------
# Cached loaders
# -----------------------------
@lru_cache(maxsize=1)
def load_tokenizer():
    """Load the tokenizer for ``MODEL_ID`` (cached after the first call).

    Downloads are cached under ``/app/.cache`` when ``/app`` exists; otherwise
    the library default cache location is used.
    """
    print("Loading tokenizer...")
    cache_dir = None
    if os.path.exists("/app"):
        cache_dir = "/app/.cache"
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    print("Tokenizer loaded")
    return tok

In [30]:
@lru_cache(maxsize=1)
def load_model():
    """Load the seq2seq model for ``MODEL_ID`` and move it to ``DEVICE``.

    Up to three sets of ``from_pretrained`` keyword arguments are tried in
    order, from most to least specific; the first that succeeds wins.

    Raises:
        RuntimeError: when every strategy fails (the last error is included
            in the message).
    """
    print("Loading model...")
    cache_dir = "/app/.cache" if os.path.exists("/app") else None
    preferred_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    strategies = (
        {
            "trust_remote_code": True,
            "torch_dtype": preferred_dtype,
            "cache_dir": cache_dir,
            "low_cpu_mem_usage": True,
        },
        {
            "trust_remote_code": True,
            "torch_dtype": torch.float32,
            "cache_dir": cache_dir,
        },
        {"trust_remote_code": True},
    )

    last_err = None
    for attempt, load_kwargs in enumerate(strategies, start=1):
        try:
            print(f"Trying model strategy {attempt}...")
            mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **load_kwargs)

            # Disable caching to avoid "Cache only has 0 layers". Every tweak
            # is best-effort: an object that rejects the attribute is skipped.
            for holder_name in ("config", "generation_config"):
                holder = getattr(mdl, holder_name, None)
                if holder is None:
                    continue
                try:
                    holder.use_cache = False
                except Exception:
                    pass
            try:
                mdl.config.cache_implementation = None
            except Exception:
                pass
            try:
                mdl.cache_implementation = None
            except Exception:
                pass

            mdl.eval().to(DEVICE)
            print(f"Model loaded on {DEVICE}")
            return mdl
        except Exception as load_error:
            print(f"Strategy {attempt} failed: {load_error}")
            last_err = load_error

    raise RuntimeError(f"Failed to load model: {last_err}")


In [31]:
# Globals
# Module-level slots, filled on first use by get_tokenizer() / get_model().
tokenizer = None
model = None

def get_tokenizer():
    """Return the shared tokenizer, loading it on first use."""
    global tokenizer
    if tokenizer is not None:
        return tokenizer
    tokenizer = load_tokenizer()
    return tokenizer

def get_model():
    """Return the shared model, loading it on first use."""
    global model
    if model is not None:
        return model
    model = load_model()
    return model

In [32]:
# Additional loaders for Indic→English
@lru_cache(maxsize=1)
def load_tokenizer_indic_en():
    """Load the tokenizer for ``INDIC_EN_MODEL_ID`` (cached after first call).

    Downloads are cached under ``/app/.cache`` when ``/app`` exists; otherwise
    the library default cache location is used.
    """
    print("Loading INDIC→EN tokenizer...")
    cache_dir = None
    if os.path.exists("/app"):
        cache_dir = "/app/.cache"
    tok = AutoTokenizer.from_pretrained(
        INDIC_EN_MODEL_ID,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    print("INDIC→EN tokenizer loaded")
    return tok

@lru_cache(maxsize=1)
def load_model_indic_en():
    """Load the seq2seq model for ``INDIC_EN_MODEL_ID`` onto ``DEVICE``.

    Up to three sets of ``from_pretrained`` keyword arguments are tried in
    order; the first that succeeds is returned.

    Raises:
        RuntimeError: when every strategy fails (the last error is included
            in the message).
    """
    print("Loading INDIC→EN model...")
    cache_dir = "/app/.cache" if os.path.exists("/app") else None
    preferred_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    strategies = (
        {
            "trust_remote_code": True,
            "torch_dtype": preferred_dtype,
            "cache_dir": cache_dir,
            "low_cpu_mem_usage": True,
        },
        {
            "trust_remote_code": True,
            "torch_dtype": torch.float32,
            "cache_dir": cache_dir,
        },
        {"trust_remote_code": True},
    )

    last_err = None
    for attempt, load_kwargs in enumerate(strategies, start=1):
        try:
            print(f"Trying INDIC→EN model strategy {attempt}...")
            mdl = AutoModelForSeq2SeqLM.from_pretrained(INDIC_EN_MODEL_ID, **load_kwargs)

            # Best-effort: switch off generation caching wherever it is exposed.
            for holder_name in ("config", "generation_config"):
                holder = getattr(mdl, holder_name, None)
                if holder is None:
                    continue
                try:
                    holder.use_cache = False
                except Exception:
                    pass
            try:
                mdl.config.cache_implementation = None
            except Exception:
                pass
            try:
                mdl.cache_implementation = None
            except Exception:
                pass

            mdl.eval().to(DEVICE)
            print(f"INDIC→EN model loaded on {DEVICE}")
            return mdl
        except Exception as load_error:
            print(f"INDIC→EN strategy {attempt} failed: {load_error}")
            last_err = load_error

    raise RuntimeError(f"Failed to load INDIC→EN model: {last_err}")

# Getters
@lru_cache(maxsize=1)
def get_tokenizer_indic_en():
    """Return the tokenizer produced by ``load_tokenizer_indic_en()``."""
    indic_en_tokenizer = load_tokenizer_indic_en()
    return indic_en_tokenizer

@lru_cache(maxsize=1)
def get_model_indic_en():
    """Return the model produced by ``load_model_indic_en()``."""
    indic_en_model = load_model_indic_en()
    return indic_en_model

In [33]:
@lru_cache(maxsize=1)
def get_translation_pipe_en_to_indic():
    """Build (once) a "translation" pipeline around the shared tokenizer/model.

    The pipeline is placed on device 0 when CUDA is available and on -1 (CPU)
    otherwise.
    """
    tok = get_tokenizer()
    mdl = get_model()
    if torch.cuda.is_available():
        device_idx = 0
    else:
        device_idx = -1
    pipe_kwargs = dict(
        model=mdl,
        tokenizer=tok,
        trust_remote_code=True,
        device=device_idx,
    )
    return pipeline("translation", **pipe_kwargs)

@lru_cache(maxsize=1)
def get_translation_pipe_indic_to_en():
    """Build (once) a "translation" pipeline around the INDIC→EN tokenizer/model.

    The pipeline is placed on device 0 when CUDA is available and on -1 (CPU)
    otherwise.
    """
    tok = get_tokenizer_indic_en()
    mdl = get_model_indic_en()
    if torch.cuda.is_available():
        device_idx = 0
    else:
        device_idx = -1
    pipe_kwargs = dict(
        model=mdl,
        tokenizer=tok,
        trust_remote_code=True,
        device=device_idx,
    )
    return pipeline("translation", **pipe_kwargs)

In [34]:

# -----------------------------
# Schemas
# -----------------------------
class TranslationRequest(BaseModel):
    # Request body accepted by POST /api/v1/translate.
    # Text to translate; surrounding whitespace is stripped before translation.
    source_text: str
    # Language codes are keys of LANGUAGE_TAGS; anything else yields HTTP 400.
    target_lang: str  # 'hi', 'ta', etc.
    source_lang: str = "en"  # default English, but allow any supported

# -----------------------------
# Core translation
# -----------------------------
@lru_cache(maxsize=512)
def cached_translation(source_text: str, target_lang: str, source_lang: str = "en") -> str:
    """Translate ``source_text`` between two supported language codes.

    Results are memoised per ``(source_text, target_lang, source_lang)``.

    Args:
        source_text: Text to translate; surrounding whitespace is stripped.
        target_lang: Key of ``LANGUAGE_TAGS`` for the output language.
        source_lang: Key of ``LANGUAGE_TAGS`` for the input language.

    Returns:
        The translated text. The stripped input is returned as-is when both
        languages are equal, and ``""`` is returned for blank input.

    Raises:
        ValueError: if either language code is not in ``LANGUAGE_TAGS``.
        RuntimeError: if tokenisation, generation or decoding fails.
    """
    if target_lang not in LANGUAGE_TAGS:
        raise ValueError(f"Unsupported target language: {target_lang}")
    if source_lang not in LANGUAGE_TAGS:
        raise ValueError(f"Unsupported source language: {source_lang}")

    text = (source_text or "").strip()
    if source_lang == target_lang or not text:
        # Nothing to translate: same language, or blank input (text == "").
        return text

    tgt_tag = LANGUAGE_TAGS[target_lang]
    src_tag = LANGUAGE_TAGS[source_lang]
    tgt_script = tgt_tag.split("_")[1]

    def generate_with_tags(text: str, src: str, tgt: str, use_indic_en: bool = False) -> str:
        """Run one greedy generation pass with the selected tokenizer/model."""
        if use_indic_en:
            tok, mdl = get_tokenizer_indic_en(), get_model_indic_en()
        else:
            tok, mdl = get_tokenizer(), get_model()
        # Format: "{src_tag} {tgt_tag} {text}"
        tagged = f"{src} {tgt} {text}"
        inputs = tok(tagged, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = mdl.generate(
                **inputs,
                max_length=512,
                num_beams=1,
                do_sample=False,
                use_cache=False,  # matches the cache settings disabled at load time
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
            )
        return tok.batch_decode(outputs, skip_special_tokens=True)[0].strip()

    try:
        if source_lang == "en":
            cand = generate_with_tags(text, src_tag, tgt_tag, use_indic_en=False)
        elif target_lang == "en":
            cand = generate_with_tags(text, src_tag, "eng_Latn", use_indic_en=True)
        else:
            # Indic -> Indic via English pivot
            mid = generate_with_tags(text, src_tag, "eng_Latn", use_indic_en=True)
            if not mid:
                # An empty pivot carries nothing to translate; do not ask the
                # second model to generate from tags alone.
                return ""
            cand = generate_with_tags(mid, "eng_Latn", tgt_tag, use_indic_en=False)

        # Script correction for non-English targets. transliterate_if_needed()
        # itself returns the text untouched unless it lacks the target script
        # and contains Devanagari, so no pre-check is needed here.
        if target_lang != "en":
            cand = transliterate_if_needed(cand, tgt_script)
        return cand
    except Exception as e:
        raise RuntimeError(f"Model translation failed: {e}") from e


In [35]:
# -----------------------------
# Routes
# -----------------------------
app = app  # keep reference name stable

@app.post("/api/v1/translate")
def translate(request: TranslationRequest):
    """Translate the request text and return it as ``{"translated_text": ...}``.

    A ``ValueError`` from the translation layer becomes HTTP 400 and a
    ``RuntimeError`` becomes HTTP 500; the error message is used as detail.
    """
    try:
        result = cached_translation(
            request.source_text,
            request.target_lang,
            request.source_lang,
        )
    except ValueError as bad_request:
        raise HTTPException(status_code=400, detail=str(bad_request))
    except RuntimeError as failure:
        raise HTTPException(status_code=500, detail=str(failure))
    return {"translated_text": result}


In [36]:

@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render ``index.html`` with the incoming request in the template context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


In [37]:
@app.get("/health")
def health_check():
    """Report whether the tokenizer, model and both pipelines can be obtained.

    Calling the getters loads any resource that is not cached yet. On any
    exception an "unhealthy" payload with the error text is returned as the
    response body; no explicit status code is set in either case.
    """
    try:
        _tok = get_tokenizer()
        _mdl = get_model()
        _pipe_en_indic = get_translation_pipe_en_to_indic()
        _pipe_indic_en = get_translation_pipe_indic_to_en()
        report = dict(
            status="healthy",
            device=str(DEVICE),
            model_loaded=_mdl is not None,
            tokenizer_loaded=_tok is not None,
            pipeline_en_indic=_pipe_en_indic is not None,
            pipeline_indic_en=_pipe_indic_en is not None,
            translit_enabled=bool(itransliterate and SANSCRIPT_MAP),
        )
        return report
    except Exception as probe_error:
        return dict(
            status="unhealthy",
            error=str(probe_error),
            device=str(DEVICE),
        )

In [38]:
# -----------------------------
# Entrypoint (Azure PORT-ready)
# -----------------------------
# Detect whether this code runs under IPython/Jupyter.
try:
    get_ipython  # defined only in IPython/Jupyter
    IN_IPYTHON = True
except NameError:
    IN_IPYTHON = False

if IN_IPYTHON:
    # A notebook kernel already runs an event loop. nest_asyncio makes
    # asyncio.run() re-entrant, so the server can be driven from inside it.
    # A top-level ``await`` is deliberately avoided: it is a SyntaxError
    # outside IPython and would stop this file from being run as a script.
    import asyncio
    import nest_asyncio
    nest_asyncio.apply()

    port = int(os.environ.get("PORT", 8000))
    config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="info")
    server = uvicorn.Server(config)
    # Blocks this cell until the server is stopped.
    asyncio.run(server.serve())
elif __name__ == "__main__":
    # Plain script execution: let uvicorn create and own the event loop.
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)


INFO:     Started server process [88268]
INFO:     Waiting for application startup.
INFO:     Waiting for application startup.


Starting up... Loading model/tokenizer/pipelines
Loading tokenizer...
Tokenizer loaded
Loading model...
Trying model strategy 1...
Tokenizer loaded
Loading model...
Trying model strategy 1...


Device set to use cpu


Model loaded on cpu
Loading INDIC→EN tokenizer...
INDIC→EN tokenizer loaded
Loading INDIC→EN model...
Trying INDIC→EN model strategy 1...
INDIC→EN tokenizer loaded
Loading INDIC→EN model...
Trying INDIC→EN model strategy 1...


Device set to use cpu
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INDIC→EN model loaded on cpu
Resources loaded successfully
INFO:     127.0.0.1:55994 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:55994 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:56145 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:56145 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:57183 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:57183 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:57212 - "POST /api/v1/translate HTTP/1.1" 200 OK
INFO:     127.0.0.1:57212 - "POST /api/v1/translate HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [88268]
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [88268]


Shutting down...
