# TranslateGemma 4B - Google Colab Experiment

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jimmyliao/trans-gemma/blob/main/translategemma-colab.ipynb)

This notebook shows how to run the TranslateGemma 4B translation model on a free Google Colab GPU.

## Requirements
- Google Colab T4 GPU (free tier)
- Python 3.10+
- Transformers, PyTorch

## Usage
1. Click the **Open In Colab** badge above
2. Runtime > Change runtime type > T4 GPU
3. Run all cells in order

## 1. Environment Check and Setup

In [None]:
# Check that a GPU is available
!nvidia-smi

In [None]:
# Install required packages
!pip install -q transformers torch accelerate huggingface_hub sentencepiece protobuf

In [None]:
# Check Python and package versions
import sys
import torch
import transformers

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

## 2. Load the TranslateGemma 4B Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Hugging Face model ID
MODEL_ID = "google/translategemma-4b-it"

print(f"Loading model: {MODEL_ID}")
print("This may take a few minutes...")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

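# Load the model in bfloat16, which halves weight memory compared with float32.
# Note: the T4 (Turing architecture) has no native bfloat16 support, so bf16 ops may run slower there.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto"  # let Accelerate place the weights on the available GPU
)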

print("Model loaded successfully!")
print(f"Model device: {model.device}")
print(f"Model dtype: {model.dtype}")

## 3. Basic Translation Tests

In [None]:
def translate(text, target_lang="Traditional Chinese", max_new_tokens=256):
    """
    Translate text into the given target language.
    
    Args:
        text: Text to translate.
        target_lang: Target language name (e.g. Traditional Chinese, Japanese, French).
        max_new_tokens: Maximum number of tokens to generate.
    
    Returns:
        The translated text.
    """
    # Build the prompt as a single user turn
    messages = [
        {
            "role": "user",
            "content": f"Translate this to {target_lang}: {text}"
        }
    ]
    
    # Apply the chat template and tokenize
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    ).to(model.device)
    
    # Generate the translation
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding for repeatable output
            pad_token_id=tokenizer.eos_token_id
        )
    
    # generate() returns the prompt followed by the new tokens, so decode only
    # the new tokens. (Splitting the full decoded text on "\n" and keeping the
    # last line would drop all but the last line of a multi-line translation.)
    new_tokens = outputs[0][inputs.shape[-1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    
    return result.strip()

In [None]:
# Test: English → Traditional Chinese
test_text = "Hello, how are you today? I hope you're having a great day!"
print(f"Original (EN): {test_text}")
print(f"\nTranslated (ZH-TW): {translate(test_text, 'Traditional Chinese')}")

In [None]:
# Test: English → Japanese
test_text = "Machine learning is transforming the world."
print(f"Original (EN): {test_text}")
print(f"\nTranslated (JA): {translate(test_text, 'Japanese')}")

In [None]:
# Test: Chinese → English
test_text = "我喜歡在週末閱讀和寫程式。"
print(f"Original (ZH): {test_text}")
print(f"\nTranslated (EN): {translate(test_text, 'English')}")
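The test cells above are checked by eye. As a rough automated sanity check, one can at least confirm that the output is written in the expected script. The sketch below is a heuristic based on Unicode character names from the standard-library `unicodedata` module; `script_ratio` and the prefix tuples are hypothetical helpers, not part of TranslateGemma or transformers, and they cannot tell Traditional from Simplified Chinese or judge translation quality.

```python
import unicodedata
from typing import Tuple

# Prefixes of Unicode character names for a few scripts (assumed groupings).
CJK: Tuple[str, ...] = ("CJK UNIFIED IDEOGRAPH",)
KANA: Tuple[str, ...] = ("HIRAGANA", "KATAKANA")
LATIN: Tuple[str, ...] = ("LATIN",)

def script_ratio(text: str, name_prefixes: Tuple[str, ...]) -> float:
    """Fraction of alphabetic characters whose Unicode name starts with one of name_prefixes."""
    letters = [ch for ch in text if ch.isalpha()]  # ignore digits, spaces, punctuation
    if not letters:
        return 0.0
    hits = sum(1 for ch in letters if unicodedata.name(ch, "").startswith(name_prefixes))
    return hits / len(letters)

# Usage in this notebook (assumes translate() from above is defined):
# assert script_ratio(translate("Good morning!", "Japanese"), KANA + CJK) > 0.5
```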

## 4. Performance Benchmark

In [None]:
import time

def benchmark_translation(text, target_lang="Traditional Chinese", num_runs=5):
    """
    Time repeated translations of the same text.
    
    One untimed warm-up call runs first so that one-off GPU initialization
    cost does not inflate the first measurement.
    """
    translate(text, target_lang)  # warm-up (not timed)
    
    times = []
    
    for i in range(num_runs):
        start = time.perf_counter()  # monotonic, high-resolution timer
        result = translate(text, target_lang)
        elapsed = time.perf_counter() - start
        times.append(elapsed)
        print(f"Run {i+1}: {elapsed:.2f}s")
    
    avg_time = sum(times) / len(times)
    print(f"\nAverage time: {avg_time:.2f}s")
    print(f"Min time: {min(times):.2f}s")
    print(f"Max time: {max(times):.2f}s")
    
    return result

# Run the benchmark
test_text = "Artificial intelligence is changing how we live and work."
print(f"Testing with: {test_text}\n")
result = benchmark_translation(test_text)
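`benchmark_translation` reports the mean, min, and max. A slightly more general sketch times any zero-argument callable, runs a configurable number of untimed warm-up calls, and also reports the median and standard deviation using the standard-library `statistics` module. The `benchmark` name is hypothetical.

```python
import statistics
import time
from typing import Callable, Dict, List

def benchmark(fn: Callable[[], object], runs: int = 5, warmup: int = 1) -> Dict[str, float]:
    """Call fn `warmup` times untimed, then time `runs` calls and summarize (seconds)."""
    for _ in range(warmup):
        fn()  # warm-up calls are not measured
    times: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return {
        "mean": statistics.mean(times),
        "median": statistics.median(times),
        # stdev needs at least two samples
        "stdev": statistics.stdev(times) if runs > 1 else 0.0,
        "min": min(times),
        "max": max(times),
    }

# Usage in this notebook (assumes translate() from Section 3 is defined):
# stats = benchmark(lambda: translate("Hello!", "Japanese"), runs=5)
```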

## 5. Batch Translation Test

In [None]:
# Several test sentences, translated one at a time (one generate() call each)
test_sentences = [
    "Good morning!",
    "Thank you for your help.",
    "Where is the nearest restaurant?",
    "I would like to book a hotel room.",
    "The weather is beautiful today."
]

print("Batch Translation Test (EN → ZH-TW)\n")
print("=" * 80)

for i, sentence in enumerate(test_sentences, 1):
    translation = translate(sentence, "Traditional Chinese")
    print(f"{i}. EN: {sentence}")
    print(f"   ZH: {translation}")
    print()
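The loop above still issues one `generate()` call per sentence through `translate()`. Batching several prompts into a single call with a decoder-only model is commonly done in transformers by setting `tokenizer.padding_side = "left"` and tokenizing with `padding=True`. The pure-Python sketch below does not call the model; it only illustrates what left padding and the matching attention mask look like. `left_pad` is a hypothetical helper, not a transformers API.

```python
from typing import List, Tuple

def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-ID sequences to a common length.

    Returns (input_ids, attention_mask): padded positions get pad_id and a
    mask value of 0; real tokens get a mask value of 1.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        # Padding goes on the left so every prompt ends at the same position,
        # which is where generation continues from.
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```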

## 6. API Design Prototype

In [None]:
from typing import Optional
from dataclasses import dataclass

@dataclass
class TranslationRequest:
    """Translation request"""
    text: str
    target_lang: str = "Traditional Chinese"
    max_tokens: int = 256

@dataclass
class TranslationResponse:
    """Translation response"""
    original: str
    translated: str
    target_lang: str
    status: str = "success"
    error: Optional[str] = None

class TranslationService:
    """Translation service (a prototype of what a FastAPI backend would call)"""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
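    def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate the request, returning an error response instead of raising"""
        try:
            # Delegates to the translate() function defined above, which uses the
            # global model/tokenizer; self.model and self.tokenizer are not used yet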
            translated = translate(
                request.text,
                request.target_lang,
                request.max_tokens
            )
            
            return TranslationResponse(
                original=request.text,
                translated=translated,
                target_lang=request.target_lang,
                status="success"
            )
        except Exception as e:
            return TranslationResponse(
                original=request.text,
                translated="",
                target_lang=request.target_lang,
                status="error",
                error=str(e)
            )

# Try the service
service = TranslationService(model, tokenizer)
request = TranslationRequest(
    text="Hello, world!",
    target_lang="Traditional Chinese"
)
response = service.translate(request)

print(f"Status: {response.status}")
print(f"Original: {response.original}")
print(f"Translated: {response.translated}")
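`TranslationService` catches errors raised during translation, but nothing checks the request itself. A minimal sketch of validating requests up front, before any GPU time is spent; the class name and the limits `MAX_TEXT_CHARS` and `MAX_NEW_TOKENS_LIMIT` are hypothetical values chosen for illustration. In the FastAPI version, similar limits could instead be declared with Pydantic `Field` constraints (for example `ge`/`le`).

```python
from dataclasses import dataclass

# Hypothetical limits for illustration; tune them for your deployment.
MAX_TEXT_CHARS = 5000
MAX_NEW_TOKENS_LIMIT = 1024

@dataclass
class ValidatedTranslationRequest:
    """Like TranslationRequest, but rejects bad input at construction time."""
    text: str
    target_lang: str = "Traditional Chinese"
    max_tokens: int = 256

    def __post_init__(self) -> None:
        self.text = self.text.strip()  # normalize surrounding whitespace
        if not self.text:
            raise ValueError("text must not be empty")
        if len(self.text) > MAX_TEXT_CHARS:
            raise ValueError(f"text longer than {MAX_TEXT_CHARS} characters")
        if not self.target_lang.strip():
            raise ValueError("target_lang must not be empty")
        if not 1 <= self.max_tokens <= MAX_NEW_TOKENS_LIMIT:
            raise ValueError(f"max_tokens must be between 1 and {MAX_NEW_TOKENS_LIMIT}")
```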

## 7. Memory Usage

In [None]:
if torch.cuda.is_available():
    print("GPU Memory Usage:")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
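To put the measured numbers in context, weight memory is roughly parameter count × bytes per parameter. Assuming a nominal 4 billion parameters (the checkpoint's exact count differs somewhat), the sketch below compares dtypes, including the INT8/INT4 quantization options listed under next steps. It ignores activations and the KV cache, which grow with sequence length; `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params: float, bytes_per_param: float) -> float:
    """Approximate memory for the model weights alone, in GiB (1024**3 bytes)."""
    return num_params * bytes_per_param / 1024**3

NOMINAL_PARAMS = 4e9  # assumption: the nominal "4B", not the exact parameter count
for name, nbytes in [("float32", 4), ("bfloat16", 2), ("int8", 1), ("int4", 0.5)]:
    print(f"{name:>8}: ~{weight_memory_gib(NOMINAL_PARAMS, nbytes):.2f} GiB")
```

Under these assumptions bfloat16 weights come to about 7.45 GiB, within the T4's 16 GB. `memory_allocated()` right after loading should be of the same order, while `max_memory_allocated()` also includes peak activations and KV cache from generation.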

## 8. Interactive Translation Tool

In [None]:
# Languages shown in the menu (any other language name can also be typed in)
SUPPORTED_LANGUAGES = [
    "Traditional Chinese",
    "Simplified Chinese",
    "Japanese",
    "Korean",
    "French",
    "German",
    "Spanish",
    "Italian",
    "Portuguese",
    "Russian",
    "Arabic",
    "Hindi",
    "Vietnamese",
    "Thai",
    "Indonesian"
]

def interactive_translate():
    """Interactive translation loop"""
    print("=" * 80)
    print("TranslateGemma Interactive Tool")
    print("=" * 80)
    print("\nSupported Languages:")
    for i, lang in enumerate(SUPPORTED_LANGUAGES, 1):
        print(f"{i:2d}. {lang}")
    print("\nType 'quit' to exit\n")
    
    while True:
        text = input("\nEnter text to translate (or 'quit'): ").strip()
        if text.lower() == 'quit':
            break
        
        if not text:
            print("Please enter some text.")
            continue
        
        lang_input = input("Target language (name or number): ").strip()
        
        # Accept either a menu number or a language name
        if lang_input.isdigit():
            lang_idx = int(lang_input) - 1
            if 0 <= lang_idx < len(SUPPORTED_LANGUAGES):
                target_lang = SUPPORTED_LANGUAGES[lang_idx]
            else:
                print("Invalid language number.")
                continue
        else:
            target_lang = lang_input
        
        print(f"\nTranslating to {target_lang}...")
        result = translate(text, target_lang)
        print(f"Result: {result}")

# Uncomment to run the interactive tool (input() works in Colab)
# interactive_translate()
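The number-or-name handling inside `interactive_translate()` can be factored into a small, testable helper. A sketch; `resolve_language` is a hypothetical name, and unlike the original loop it also normalizes capitalization for listed languages and rejects an empty answer.

```python
from typing import Optional, Sequence

def resolve_language(user_input: str, languages: Sequence[str]) -> Optional[str]:
    """Map a menu number or a language name to a target language.

    Returns None for an empty answer or an out-of-range number. Names are
    matched case-insensitively against `languages`; unlisted names pass through.
    """
    choice = user_input.strip()
    if not choice:
        return None
    if choice.isdigit():
        idx = int(choice) - 1  # menu numbers start at 1
        return languages[idx] if 0 <= idx < len(languages) else None
    for lang in languages:
        if lang.lower() == choice.lower():
            return lang  # normalize capitalization, e.g. "japanese" -> "Japanese"
    return choice  # free-text names are still allowed, as in interactive_translate()
```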

## 9. Deployment Preparation

The cells below prepare the files needed to deploy to Cloud Run.

In [None]:
# requirements.txt for the Cloud Run service
requirements = """transformers>=4.50.0
torch>=2.1.0
accelerate>=0.25.0
huggingface_hub>=0.20.0
sentencepiece>=0.1.99
protobuf>=4.25.0
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
pydantic>=2.5.0
"""

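print("requirements.txt content:")
print(requirements)

# Also write it to disk for the deployment directory
with open("requirements.txt", "w", encoding="utf-8") as f:
    f.write(requirements)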

In [None]:
# Example FastAPI app (main.py)
fastapi_code = '''
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = FastAPI(title="TranslateGemma API")

# Populated once by the startup handler
model = None
tokenizer = None

class TranslationRequest(BaseModel):
    text: str
    target_lang: str = "Traditional Chinese"
    max_tokens: int = 256

class TranslationResponse(BaseModel):
    original: str
    translated: str
    target_lang: str

@app.on_event("startup")
async def load_model():
    global model, tokenizer
    MODEL_ID = "google/translategemma-4b-it"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

@app.get("/")
async def root():
    return {"message": "TranslateGemma API", "version": "1.0.0"}

# A plain `def` (not `async def`) endpoint runs in FastAPI's threadpool,
# so the blocking generate() call does not stall the event loop.
@app.post("/translate", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    try:
        messages = [{
            "role": "user",
            "content": f"Translate this to {request.target_lang}: {request.text}"
        }]
        
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=request.max_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode only the generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs.shape[-1]:]
        result = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        
        return TranslationResponse(
            original=request.text,
            translated=result,
            target_lang=request.target_lang
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
'''

print("FastAPI main.py content:")
print(fastapi_code)

# Also write it to disk for the deployment directory
with open("main.py", "w", encoding="utf-8") as f:
    f.write(fastapi_code)
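Once `main.py` is running (for example locally with `uvicorn main:app --host 0.0.0.0 --port 8080`), the `/translate` endpoint can be called from any HTTP client. A standard-library sketch; the helper names and the base URL are hypothetical, and a Cloud Run service that requires authentication would also need an identity token in the `Authorization` header, which is not shown.

```python
import json
import urllib.request

def build_translate_request(base_url: str, text: str,
                            target_lang: str = "Traditional Chinese",
                            max_tokens: int = 256) -> urllib.request.Request:
    """Build a JSON POST request matching the TranslationRequest model in main.py."""
    payload = {"text": text, "target_lang": target_lang, "max_tokens": max_tokens}
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/translate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def translate_remote(base_url: str, text: str, target_lang: str = "Traditional Chinese",
                     timeout: float = 120.0) -> dict:
    """Send the request and return the decoded JSON response body."""
    req = build_translate_request(base_url, text, target_lang)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))

# Example (assumes the service is running locally on port 8080):
# print(translate_remote("http://localhost:8080", "Hello, world!"))
```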

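## 10. Summary and Next Steps

This notebook covered:

- ✅ Loading TranslateGemma 4B on a Google Colab T4 GPU
- ✅ A basic translation function
- ✅ Performance benchmarking
- ✅ A batch translation test
- ✅ An API design prototype
- ✅ Deployment files (requirements.txt, main.py)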

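### Next Steps

1. **Deploy to Cloud Run**:
   - Use the `cloudrun/deploy.sh` script
   - Or deploy automatically with GitHub Actions

2. **Performance optimization**:
   - Implement batched inference
   - Use vLLM or TGI to speed up inference
   - Quantize the model (INT8/INT4)

3. **Feature extensions**:
   - Support more language pairs
   - Add translation quality evaluation
   - Add caching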
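Because `translate()` uses greedy decoding, repeated requests for the same text and target language normally return the same output, so a plain in-memory LRU cache from `functools` is one simple caching option. A sketch; `make_cached_translator` is a hypothetical helper. On Cloud Run each instance would keep its own cache, so sharing results across instances would need an external store.

```python
from functools import lru_cache
from typing import Callable

def make_cached_translator(translate_fn: Callable[[str, str, int], str],
                           maxsize: int = 1024) -> Callable[[str, str, int], str]:
    """Wrap translate_fn(text, target_lang, max_new_tokens) with an in-memory LRU cache."""
    @lru_cache(maxsize=maxsize)
    def cached(text: str, target_lang: str, max_new_tokens: int = 256) -> str:
        # Cache key is the argument tuple, so keep calls consistently positional.
        return translate_fn(text, target_lang, max_new_tokens)
    return cached

# Usage in the notebook (assumes translate() from Section 3 is defined):
# cached_translate = make_cached_translator(translate)
# cached_translate("Good morning!", "Japanese", 256)
# cached_translate.cache_info()  # hits / misses / current size
```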

### Resources

- [TranslateGemma announcement (Google Blog)](https://blog.google/innovation-and-ai/technology/developers-tools/translategemma/)
- [Cloud Run GPU configuration guide](https://cloud.google.com/run/docs/configuring/services/gpu)
- [Project GitHub repository](https://github.com/jimmyliao/trans-gemma)