In [None]:
# 1. SETUP AND INSTALLATION
# -------------------------
# NOTE: the lines starting with `!` are IPython/Jupyter shell escapes; this cell
# is meant to run inside a notebook and is not valid as a plain `.py` script.
print("Installing required packages (transformers from source)...")
!pip install -q pandas --upgrade
!pip install -q accelerate
# transformers is installed straight from the GitHub repository with no pinned
# ref, so the installed version depends on when this cell is run.
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q sentencepiece

# os/time: path check and timing; torch + transformers: model loading and
# generation; kagglehub: model download; pandas: performance table.
import os
import time
import torch
import kagglehub
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List

# Printed once the imports above succeed; the exit status of the pip commands
# themselves is not checked.
print("✅ Packages installed successfully!")


# 2. MODEL CONFIGURATION
# ----------------------
# Download the Gemma 3n model files through kagglehub. The returned value is
# kept in MODEL_PATH and later handed to GemmaTranslator, which checks it with
# os.path.exists() and passes it to from_pretrained().
print("Downloading generative model files with kagglehub...")
MODEL_PATH = kagglehub.model_download(
    "google/gemma-3n/transformers/gemma-3n-e2b-it"
)
print(f"✅ Model path set to: {MODEL_PATH}")


# 3. GEMMA GENERATIVE TRANSLATOR CLASS
# ------------------------------------
class GemmaTranslator:
    """Prompt-based translator built on a causal language model loaded with Transformers.

    The constructor tries to load the tokenizer and model from ``model_path``.
    Loading failures are reported with ``print`` and leave ``self.model`` and
    ``self.tokenizer`` set to ``None``; callers can test ``translator.model``
    to find out whether the translator is usable.
    """

    def __init__(self, model_path: str):
        """Store the model location, pick a device, and attempt to load the model.

        Args:
            model_path: Filesystem path of the model directory.
        """
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_model()

    def _initialize_model(self):
        """Load the tokenizer and model from ``self.model_path``.

        Never raises: a missing path or any loading error is printed and the
        instance is left with ``model`` and ``tokenizer`` equal to ``None``.
        """
        if not os.path.exists(self.model_path):
            print(f"❌ Cannot initialize model: Path '{self.model_path}' does not exist.")
            return

        # Half precision is only requested on GPU. On CPU, float16 kernels are
        # frequently slow or missing, so fall back to float32 there.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        try:
            print("🔄 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

            print("🔄 Loading model... (This may take a moment)")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=dtype,
            ).to(self.device)

            print("✅ Translator initialized successfully!")
            if self.model is not None:
                print(f"📊 Model loaded on: {self.model.device}")

        except Exception as e:
            # Top-level boundary for the notebook: report and leave the
            # translator in a clearly "not loaded" state.
            print(f"❌ Error initializing model: {e}")
            self.model = None
            self.tokenizer = None

    @staticmethod
    def _clean_translation(generated_text: str, target_language: str) -> str:
        """Reduce raw generated text to the bare translation.

        If the model echoed the ``"<target_language>:"`` label, only the text
        after its last occurrence is kept. Double quotes are removed, matching
        the quoting used around the source text in the prompt.
        """
        label = f"{target_language}:"
        translation = generated_text.strip()
        if label in translation:
            translation = translation.split(label)[-1].strip()
        return translation.replace('"', '').strip()

    def translate(self, text: str, target_language: str = "Italian", source_language: str = "English") -> str:
        """Translate ``text`` by prompting the model and decoding its continuation.

        Args:
            text: Text to translate.
            target_language: Language name inserted into the prompt as the target.
            source_language: Language name inserted into the prompt as the source.

        Returns:
            The cleaned translation. Returns an explanatory ``"❌ ..."`` string
            when the model is not loaded or generation fails, and an empty
            string when ``text`` is empty or whitespace-only.
        """
        if self.model is None or self.tokenizer is None:
            return "❌ Model not initialized. Please check for errors above."

        # Nothing to translate: skip generation instead of prompting with an empty quote.
        if not text or not text.strip():
            return ""

        # Create a dynamic prompt for the instruction-tuned model.
        prompt = f"""Translate the following {source_language} text to {target_language}. Provide only the {target_language} translation.

{source_language}: "{text}"
{target_language}:"""

        try:
            start_time = time.time()
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            prompt_length = inputs["input_ids"].shape[-1]

            # Use disable_compile=True for compatibility with older Kaggle GPUs (like P100)
            outputs = self.model.generate(**inputs, max_new_tokens=150, disable_compile=True)

            # The returned sequence starts with the prompt tokens. Decode only
            # the newly generated part so extraction does not depend on the
            # decoded prompt matching the original prompt string.
            generated_ids = outputs[0][prompt_length:]
            response = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            elapsed = time.time() - start_time

            translation = self._clean_translation(response, target_language)

            print(f"⏱️  '{text[:30]}...' -> {target_language} in {elapsed:.2f}s")
            return translation

        except Exception as e:
            # Errors are returned as text so the demo loops keep running.
            return f"❌ Translation error: {e}"

    def batch_translate(
        self,
        texts: List[str],
        target_language: str = "Italian",
        source_language: str = "English",
    ) -> List[str]:
        """Translate each item of ``texts`` in order with :meth:`translate`.

        Args:
            texts: Texts to translate.
            target_language: Target language name forwarded to ``translate``.
            source_language: Source language name forwarded to ``translate``.

        Returns:
            One result string per input text, in the same order.
        """
        return [self.translate(text, target_language, source_language) for text in texts]


# 4. INITIALIZE AND TEST THE TRANSLATOR
# -------------------------------------
translator = GemmaTranslator(MODEL_PATH)

# Languages exercised by the demo: English name (used in the prompt) mapped to
# the native name (used for display). Defined unconditionally so code that
# reads it later does not depend on the model having loaded.
SUPPORTED_LANGUAGES = {
    "Italian": "Italiano",
    "Arabic": "العربية",
    "Chinese": "中文",
    "Spanish": "Español",
    "French": "Français",
    "German": "Deutsch",
    "Japanese": "日本語",
}

# ISO 639-1 codes, used only to label the printed translations. Taking the
# first two letters of the English name gives wrong labels (e.g. "CH", "SP", "GE").
LANGUAGE_CODES = {
    "Italian": "IT",
    "Arabic": "AR",
    "Chinese": "ZH",
    "Spanish": "ES",
    "French": "FR",
    "German": "DE",
    "Japanese": "JA",
}

if translator.model:
    print("\n🌍 Supported Test Languages:", ", ".join(SUPPORTED_LANGUAGES.keys()))

    print("\n🔄 Starting Translation Examples...")
    print("=" * 60)

    # Translate one sentence into every supported language.
    test_sentence = "Where is the nearest restaurant?"
    for language, native_name in SUPPORTED_LANGUAGES.items():
        print(f"\n🇺🇸 English to {native_name} ({language})")
        print("-" * 30)
        translation = translator.translate(test_sentence, language)
        print(f"EN: {test_sentence}")
        label = LANGUAGE_CODES.get(language, language[:2].upper())
        print(f"{label}: {translation}")


# 5. PERFORMANCE ANALYSIS
# -----------------------
def analyze_performance():
    """Time one translation per test language and print a summary table.

    Uses the module-level ``translator``. Each language is translated exactly
    once and that same call is timed, so the reported time belongs to the
    reported translation.

    Returns:
        A pandas DataFrame with columns ``Language``, ``Translation`` and
        ``Time (seconds)``, or ``None`` when the model is not loaded.
    """
    if not translator.model:
        print("❌ Cannot analyze performance, model not loaded.")
        return

    test_text = "Can you recommend a good local hotel?"
    results = []

    print("\n\n📊 Performance Analysis")
    print("=" * 40)
    print(f"Testing with sentence: \"{test_text}\"")

    for language in ["French", "Spanish", "German", "Japanese"]:
        # perf_counter is a monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        translation = translator.translate(test_text, language)
        elapsed = time.perf_counter() - start_time

        results.append({
            'Language': language,
            'Translation': translation,
            'Time (seconds)': round(elapsed, 2)
        })

    df = pd.DataFrame(results)
    times = df['Time (seconds)']
    fastest = df.loc[times.idxmin(), 'Language']
    slowest = df.loc[times.idxmax(), 'Language']

    print("\n--- Performance Results ---")
    print(df.to_string(index=False))
    print("\n--- Summary ---")
    print(f"📊 Average translation time: {times.mean():.2f} seconds")
    print(f"📊 Fastest translation: {fastest} ({times.min():.2f}s)")
    print(f"📊 Slowest translation: {slowest} ({times.max():.2f}s)")
    return df

performance_df = analyze_performance()


# 6. INTERACTIVE TRANSLATOR
# -------------------------
def interactive_translator():
    """Run a prompt loop that translates user-entered English text.

    Uses the module-level ``translator`` and ``SUPPORTED_LANGUAGES``. The loop
    ends when the user types ``quit`` or when input is closed (EOF). Language
    names are matched case-insensitively and ignoring surrounding whitespace;
    an unknown language falls back to Italian. Blank text is skipped.
    """
    if not translator.model:
        print("❌ Cannot start interactive mode, model not loaded.")
        return

    print("\n\n🌍 Gemma Generative Interactive Translator")
    print("=" * 40)
    print("Available languages:", ", ".join(SUPPORTED_LANGUAGES.keys()))
    print("Type 'quit' to exit\n")

    # Lower-cased name -> canonical name, so "french" or " French " are accepted.
    canonical_names = {name.lower(): name for name in SUPPORTED_LANGUAGES}

    while True:
        try:
            text = input("📝 Enter text to translate (English): ").strip()
            if text.lower() == 'quit':
                break
            if not text:
                # Nothing to translate; ask again instead of prompting the model.
                continue

            requested = input(f"🎯 Target language ({'/'.join(SUPPORTED_LANGUAGES.keys())}): ").strip()
        except EOFError:
            # Input stream closed: leave the loop instead of crashing.
            print()
            break

        target_lang = canonical_names.get(requested.lower())
        if target_lang is None:
            print(f"❌ Language '{requested}' not in test list. Defaulting to Italian.")
            target_lang = "Italian"

        translation = translator.translate(text, target_lang)
        print(f"✅ Translation: {translation}")
        print("-" * 50)

# Uncomment the line below to run the interactive translator
# interactive_translator()


# 7. MODEL INFORMATION
# --------------------
def display_model_info():
    """Print a fixed summary of the model plus the translator's selected device."""
    device_label = translator.device.upper()
    report_lines = (
        "\n\n🤖 GEMMA 3N MODEL INFORMATION",
        "=" * 50,
        "📊 Model Used: google/gemma-3n/transformers/gemma-3n-e2b-it",
        "📊 Architecture: Gemma 3n-E2B-IT (Instruction Tuned Generative Model)",
        "📊 Framework: PyTorch / Transformers",
        f"📊 Inference Device: {device_label}",
        "📊 Key Feature: Generates translations for arbitrary text, no phrasebook required.",
        "\n🚀 FIX APPLIED:",
        "• `disable_compile=True` used in `model.generate()` to ensure compatibility on all GPUs.",
    )
    # One print per entry keeps the output identical to line-by-line printing.
    for line in report_lines:
        print(line)

display_model_info()

# Verify environment and GPU usage
# NOTE: `!nvidia-smi` is an IPython/Jupyter shell escape, so this only works
# when the cell runs in a notebook; the command's exit status is not checked.
print("\n🔍 Verifying GPU Usage")
!nvidia-smi