In [1]:
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
import logging
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
import torch.nn.functional as F
import gc
import os
import psutil

# Configure logging: every record is written to a log file and echoed to the
# console, using one shared timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: log messages may carry Japanese text or non-ASCII
        # paths, which the platform default encoding cannot always represent.
        logging.FileHandler('translation_app.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)

@dataclass
class TranslationExample:
    """A single translation pair: a Japanese text and its English counterpart."""
    japanese: str  # Japanese side of the pair
    english: str  # English side of the pair

class TranslationDataset(Dataset):
    """Dataset of Japanese/English sentence pairs encoded for seq2seq fine-tuning.

    Each item is a dict of 1-D tensors (``input_ids``, ``attention_mask``,
    ``labels``). Every tensor is padded to ``max_length`` so the default
    ``DataLoader`` collate function can stack items for any batch size, and
    padded label positions are set to -100 so the loss ignores them.

    Args:
        japanese_texts: Source sentences.
        english_texts: Target sentences, aligned by index with
            ``japanese_texts``.
        tokenizer: NLLB tokenizer used to encode both sides.
        max_length: Length every sequence is truncated/padded to.
        src_lang: NLLB language code of the source sentences.
        tgt_lang: NLLB language code of the target sentences.
    """

    def __init__(
        self,
        japanese_texts: List[str],
        english_texts: List[str],
        tokenizer: "NllbTokenizer",
        max_length: int = 128,
        src_lang: str = "jpn_Jpan",
        tgt_lang: str = "eng_Latn",
    ):
        self.japanese_texts = japanese_texts
        self.english_texts = english_texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self) -> int:
        """Return the number of complete pairs (the shorter of the two lists)."""
        # Using the shorter list avoids an IndexError when the lists disagree.
        return min(len(self.japanese_texts), len(self.english_texts))

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Encode the pair at ``idx`` into fixed-length model inputs and labels."""
        jp_text = self.japanese_texts[idx]
        en_text = self.english_texts[idx]

        tokenizer = self.tokenizer
        # The language codes are configured on the tokenizer instead of being
        # typed into the text, so the tokenizer itself inserts the language
        # token. The previous settings are restored afterwards so other users
        # of the same tokenizer object are unaffected.
        saved_src_lang = tokenizer.src_lang
        saved_tgt_lang = tokenizer.tgt_lang
        tokenizer.src_lang = self.src_lang
        tokenizer.tgt_lang = self.tgt_lang
        try:
            # `text_target` replaces the deprecated `as_target_tokenizer()`
            # context manager and returns the target ids under "labels".
            encoded = tokenizer(
                jp_text,
                text_target=en_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )
        finally:
            tokenizer.src_lang = saved_src_lang
            tokenizer.tgt_lang = saved_tgt_lang

        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse a sequence of length 1 into a scalar.
        labels = encoded["labels"].squeeze(0).clone()
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is not None:
            # -100 marks positions the loss should ignore.
            labels[labels == pad_token_id] = -100

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": labels,
        }

class TranslationApp:
    """Translation application optimized for RTX 4070 Laptop GPU.

    Wraps the NLLB-200 1.3B model for Japanese-to-English translation, stores
    user-corrected translations in a JSON feedback file, and can fine-tune the
    model on that feedback.
    """

    def __init__(self, feedback_file: str = "translation_feedback.json"):
        """Select the device, load tokenizer and model, and read saved feedback.

        Args:
            feedback_file: Path of the JSON file holding feedback pairs.
        """
        self.model_name = "facebook/nllb-200-1.3B"
        self.feedback_file = Path(feedback_file)

        # Initialize device and optimize GPU settings
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            # Optimize for 8GB VRAM
            torch.cuda.set_per_process_memory_fraction(0.85)
            torch.backends.cudnn.benchmark = True
            logging.info(f"CUDA Device: {torch.cuda.get_device_name()}")
            logging.info(f"Available VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")

        self.initialize_model()
        self.feedback_data = self.load_feedback()
        logging.info(f"Initialized TranslationApp using device: {self.device}")

    def initialize_model(self) -> None:
        """Load the tokenizer and the model, preferring the memory-mapped setup.

        The language codes and tokenizer are set up before any model loading
        is attempted, so they are available whichever loading path succeeds.
        A tokenizer loading failure propagates to the caller, because nothing
        can be translated without it.
        """
        self.src_lang = "jpn_Jpan"
        self.tgt_lang = "eng_Latn"

        self.tokenizer = NllbTokenizer.from_pretrained(self.model_name)
        self.tgt_lang_id = self.tokenizer.convert_tokens_to_ids(self.tgt_lang)

        try:
            # Optimized loading for 8GB VRAM
            logging.info("Loading model with optimized settings for RTX 4070...")

            # Create offload folder
            os.makedirs("offload_folder", exist_ok=True)

            model_kwargs = {
                "device_map": "auto",
                "torch_dtype": torch.float16,
                "low_cpu_mem_usage": True,
                "max_memory": {0: "7GiB", "cpu": "12GiB"},
                "offload_folder": "offload_folder"
            }

            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_name,
                **model_kwargs
            )
            logging.info("Model loaded successfully with optimal settings")

        except ImportError as e:
            # Packages are not installed from inside the running process: a
            # freshly installed package is not reliably importable without a
            # restart, and silently changing the environment is surprising.
            logging.warning(
                "Accelerate library not found. Install it with "
                "\"pip install 'accelerate>=0.26.0'\" to enable optimized loading."
            )
            self._handle_model_loading_fallback(e)
        except Exception as e:
            logging.warning(f"Optimal loading failed: {e}")
            self._handle_model_loading_fallback(e)

    def _handle_model_loading_fallback(self, error: Exception) -> None:
        """Load the model without a device map after optimized loading failed.

        Tries float16 weights first and, if that also fails, the default
        dtype. A failure of the last attempt propagates to the caller.

        Args:
            error: The exception that triggered the fallback (already logged
                by the caller; kept for interface compatibility).
        """
        try:
            logging.info("Attempting to load with basic float16...")
            self.model = (
                AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True
                )
                .to(self.device)
            )
        except Exception as e:
            logging.warning(f"Float16 loading failed: {e}")
            logging.info("Falling back to basic model loading...")
            self.model = (
                AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name,
                    low_cpu_mem_usage=True
                )
                .to(self.device)
            )

    def _monitor_memory(self) -> None:
        """Log the VRAM allocated through torch (if CUDA is available) and the process RSS."""
        if torch.cuda.is_available():
            vram_used = torch.cuda.memory_allocated() / 1024**3
            vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logging.info(f"VRAM Usage: {vram_used:.2f}GB / {vram_total:.2f}GB")

        ram_used = psutil.Process(os.getpid()).memory_info().rss / 1024**3
        logging.info(f"RAM Usage: {ram_used:.2f}GB")

    def load_feedback(self) -> Dict[str, List[str]]:
        """Load feedback pairs from the feedback file.

        Returns:
            A dict with aligned ``"japanese"`` and ``"english"`` lists. An
            empty structure is returned when the file is missing, unreadable,
            or does not have the expected shape. If the two lists differ in
            length they are cut to the common length so pairs stay aligned.
        """
        empty: Dict[str, List[str]] = {"japanese": [], "english": []}
        if not self.feedback_file.exists():
            return empty

        try:
            with open(self.feedback_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logging.error(f"Error reading feedback file: {self.feedback_file}")
            return empty

        # Other methods index data["japanese"] / data["english"] directly, so
        # reject anything that is not a dict holding both lists.
        if not (
            isinstance(data, dict)
            and isinstance(data.get("japanese"), list)
            and isinstance(data.get("english"), list)
        ):
            logging.error(f"Unexpected feedback format in file: {self.feedback_file}")
            return empty

        pair_count = min(len(data["japanese"]), len(data["english"]))
        if len(data["japanese"]) != len(data["english"]):
            logging.warning(
                f"Feedback lists differ in length; keeping the first {pair_count} pairs"
            )
            data["japanese"] = data["japanese"][:pair_count]
            data["english"] = data["english"][:pair_count]
        return data

    def save_feedback(self) -> None:
        """Write the in-memory feedback to the feedback file; errors are logged, not raised."""
        try:
            with open(self.feedback_file, 'w', encoding='utf-8') as f:
                json.dump(self.feedback_data, f, ensure_ascii=False, indent=2)
            logging.info("Feedback saved successfully")
        except Exception as e:
            logging.error(f"Error saving feedback: {e}")

    def _encode_source(self, text: str):
        """Tokenize ``text`` as source-language input and return the encoding.

        The source language is configured on the tokenizer rather than typed
        into the text, so the tokenizer inserts the language token itself. The
        previous setting is restored afterwards so other users of the shared
        tokenizer object are unaffected.
        """
        previous_src_lang = self.tokenizer.src_lang
        self.tokenizer.src_lang = self.src_lang
        try:
            return self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128
            )
        finally:
            self.tokenizer.src_lang = previous_src_lang

    @torch.inference_mode()
    def translate(self, japanese_text: str, beam_size: int = 5) -> str:
        """Translate Japanese text to English with beam search.

        Args:
            japanese_text: Text to translate.
            beam_size: Number of beams used during generation.

        Returns:
            The translation, an empty string for blank input, or a message
            starting with "Error during translation:" if generation failed.
        """
        if not japanese_text or not japanese_text.strip():
            return ""

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self._monitor_memory()

            inputs = self._encode_source(japanese_text).to(self.device)

            # Plain beam search. Group (diverse) beam search with as many
            # groups as beams leaves a single beam per group, and only the
            # top hypothesis is returned here, so it is not used.
            generated_tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tgt_lang_id,
                max_length=128,
                num_beams=beam_size,
                length_penalty=0.8,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=False
            )

            translation = self.tokenizer.batch_decode(
                generated_tokens,
                skip_special_tokens=True
            )[0].replace(f"{self.tgt_lang} ", "")

            return translation

        except Exception as e:
            logging.error(f"Translation error: {e}")
            return f"Error during translation: {str(e)}"
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def add_feedback(self, japanese_text: str, correct_english: str) -> None:
        """Append a corrected translation pair and persist it; blank pairs are ignored."""
        if not japanese_text.strip() or not correct_english.strip():
            logging.warning("Empty translation pair received")
            return

        self.feedback_data["japanese"].append(japanese_text)
        self.feedback_data["english"].append(correct_english)
        self.save_feedback()

    def fine_tune(self, epochs: int = 2, batch_size: int = 1,
                 learning_rate: float = 2e-6) -> None:
        """Fine-tune the model on the collected feedback pairs.

        Gradients are accumulated over ``max(1, 4 // batch_size)`` batches per
        optimizer step. Errors are logged rather than raised, and the model is
        always returned to eval mode.

        Args:
            epochs: Number of passes over the feedback data.
            batch_size: Pairs per batch.
            learning_rate: AdamW learning rate.
        """
        if not self.feedback_data["japanese"]:
            logging.warning("No feedback data available")
            return

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            self._monitor_memory()

            dataset = TranslationDataset(
                self.feedback_data["japanese"],
                self.feedback_data["english"],
                self.tokenizer
            )
            dataloader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                # Pinned host memory only helps host-to-GPU copies.
                pin_memory=self.device.type == "cuda"
            )
            num_batches = len(dataloader)

            # Gradient accumulation steps
            gradient_accumulation_steps = max(1, 4 // batch_size)

            # NOTE(review): the preferred loading paths use float16 weights;
            # optimizing half-precision weights directly with a small learning
            # rate may be numerically imprecise - confirm the results.
            optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=0.01
            )

            # The `verbose` flag is deprecated in recent PyTorch releases; the
            # learning rate is logged explicitly after each epoch instead.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=1
            )

            self.model.train()
            best_loss = float('inf')

            for epoch in range(epochs):
                total_loss = 0.0
                processed_batches = 0  # batches that completed without OOM
                pending_batches = 0    # batches accumulated since the last step
                optimizer.zero_grad()

                for batch_idx, batch in enumerate(dataloader):
                    try:
                        batch = {k: v.to(self.device) for k, v in batch.items()}

                        outputs = self.model(**batch)
                        loss = outputs.loss / gradient_accumulation_steps
                        loss.backward()
                        pending_batches += 1

                        # Also step on the final batch: otherwise gradients of
                        # a trailing partial group are discarded, and with
                        # fewer batches than accumulation steps the optimizer
                        # would never step at all.
                        is_last_batch = batch_idx + 1 == num_batches
                        if pending_batches == gradient_accumulation_steps or is_last_batch:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                            optimizer.step()
                            optimizer.zero_grad()
                            pending_batches = 0

                        batch_loss = loss.item() * gradient_accumulation_steps
                        total_loss += batch_loss
                        processed_batches += 1

                        if batch_idx % 2 == 0:
                            self._monitor_memory()
                            logging.info(
                                f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, "
                                f"Loss: {batch_loss:.4f}"
                            )

                    except RuntimeError as e:
                        if "out of memory" not in str(e):
                            raise
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        logging.error(f"OOM error in batch {batch_idx}. Skipping...")
                        # Partially accumulated gradients are dropped together
                        # with the failed batch.
                        optimizer.zero_grad()
                        pending_batches = 0
                        continue

                if processed_batches == 0:
                    logging.warning(
                        f"Epoch {epoch+1}/{epochs}: no batch completed; skipping scheduler step"
                    )
                    continue

                # Average over completed batches only, so skipped batches do
                # not deflate the reported loss.
                avg_loss = total_loss / processed_batches
                scheduler.step(avg_loss)
                logging.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

                if avg_loss < best_loss:
                    best_loss = avg_loss
                    logging.info(f"New best loss: {best_loss:.4f}")

                logging.info(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

            logging.info("Fine-tuning complete!")

        except Exception as e:
            logging.error(f"Fine-tuning error: {e}")
        finally:
            # Restore eval mode even after a failure so later translations do
            # not run with training-time behaviour such as dropout.
            self.model.eval()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

def main():
    """Run the interactive console menu for translating and fine-tuning.

    Ctrl+C inside the menu cancels the current operation and shows the menu
    again; Ctrl+C during start-up exits. A closed input stream (EOF) ends the
    loop. Any other error inside the loop is logged and the menu is shown
    again. GPU cache and garbage are released on exit.
    """
    try:
        app = TranslationApp()
        logging.info("Translation app started")

        while True:
            try:
                print("\n=== Japanese-English Translation App (NLLB 1.3B) ===")
                print("1. Translate text")
                print("2. Fine-tune model")
                print("3. Exit")

                choice = input("\nSelect an option (1-3): ").strip()

                if choice == '1':
                    japanese_text = input("\nEnter Japanese text: ").strip()
                    if not japanese_text:
                        print("Please enter some text to translate.")
                        continue

                    print("\nTranslating...")
                    translation = app.translate(japanese_text)
                    print(f"\nTranslation: {translation}")

                    feedback = input("\nIs this translation correct? (y/n): ").strip().lower()
                    if feedback == 'n':
                        correct_translation = input("Please provide the correct translation: ").strip()
                        if correct_translation:
                            app.add_feedback(japanese_text, correct_translation)

                elif choice == '2':
                    if not app.feedback_data["japanese"]:
                        print("\nNo feedback data available for fine-tuning.")
                        continue

                    print(f"\nAvailable feedback pairs: {len(app.feedback_data['japanese'])}")
                    confirm = input("Start fine-tuning? (y/n): ").strip().lower()
                    if confirm == 'y':
                        print("\nStarting fine-tuning process...")
                        app.fine_tune()

                elif choice == '3':
                    print("\nThank you for using the translation app!")
                    break

                else:
                    print("\nInvalid option. Please try again.")

            except KeyboardInterrupt:
                print("\nOperation cancelled by user.")
                continue

            except EOFError:
                # stdin is closed: every further input() would raise again, so
                # retrying would loop forever. Leave the menu instead.
                print("\nInput stream closed. Exiting...")
                break

            except Exception as e:
                logging.error(f"Error in main loop: {e}")
                print("An error occurred. Please try again.")

    except KeyboardInterrupt:
        print("\nExiting...")
    except Exception as e:
        logging.error(f"Critical error: {e}")
        print("A critical error occurred. Please check the logs.")
    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

# Start the interactive app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

2024-11-27 15:14:35,978 - INFO - CUDA Device: NVIDIA GeForce RTX 4070 Laptop GPU
2024-11-27 15:14:35,978 - INFO - Available VRAM: 8.00GB
2024-11-27 15:14:37,843 - INFO - Loading model with optimized settings for RTX 4070...
2024-11-27 15:14:45,634 - INFO - Model loaded successfully with optimal settings
2024-11-27 15:14:45,642 - INFO - Initialized TranslationApp using device: cuda
2024-11-27 15:14:45,643 - INFO - Translation app started



=== Japanese-English Translation App (NLLB 1.3B) ===
1. Translate text
2. Fine-tune model
3. Exit



Select an option (1-3):  3



Thank you for using the translation app!
