# In [None]:
# ------------------------------------------------------------
# SCORE-OPTIMIZED ByT5 Inference Script for BLEU & chrF++
# ------------------------------------------------------------
# Key improvements for better scores:
# 1. Enhanced preprocessing to preserve linguistic structure
# 2. Smarter postprocessing that doesn't over-clean
# 3. Optimized beam search parameters
# 4. Better handling of gaps and special tokens
# ------------------------------------------------------------

import os
# Process-wide runtime knobs, exported before torch / transformers are
# imported further down in this file.
os.environ.update(
    {
        "OMP_NUM_THREADS": "4",
        "MKL_NUM_THREADS": "4",
        "CUDA_LAUNCH_BLOCKING": "0",
        "TORCH_CUDNN_V8_API_ENABLED": "1",
        "TOKENIZERS_PARALLELISM": "true",
    }
)

import re
import json
import random
import logging
import warnings
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm.auto import tqdm

# Suppress every warning category process-wide for the rest of the run.
warnings.filterwarnings("ignore")

# ------------------------------------------------------------
# CONFIGURATION - OPTIMIZED FOR SCORE
# ------------------------------------------------------------
@dataclass
class ScoreOptimizedConfig:
    """Tunable settings for the ByT5 translation inference run.

    Every field has a default, so ``ScoreOptimizedConfig()`` works as-is.
    ``device`` is not a dataclass field; it is assigned in ``__post_init__``
    from CUDA availability. Constructing the config also creates
    ``output_dir`` on disk.
    """

    # --- Input / output locations ---
    test_data_path: str = "/kaggle/input/deep-past-initiative-machine-translation/test.csv"
    model_path: str = "/kaggle/input/final-byt5/byt5-akkadian-optimized-34x"
    output_dir: str = "/kaggle/working/"

    # --- Tokenisation and batching ---
    max_length: int = 512  # tokenizer truncation limit for the input
    batch_size: int = 6  # kept small to leave room for the wide beam
    num_workers: int = 4  # DataLoader worker processes

    # --- Beam-search generation ---
    num_beams: int = 15
    max_new_tokens: int = 512
    length_penalty: float = 1.0
    early_stopping: bool = False
    # NOTE(review): the model path suggests ByT5, whose tokens are bytes, so
    # this would block repeated 3-byte sequences rather than 3-word phrases.
    # Confirm that is intended.
    no_repeat_ngram_size: int = 3
    repetition_penalty: float = 1.2

    # Candidates returned per input; when > 1 the engine keeps the longest.
    num_return_sequences: int = 1
    use_ensemble: bool = False  # no reader in the visible code

    # --- Runtime optimisation switches ---
    use_mixed_precision: bool = True  # forced off in __post_init__ without CUDA
    use_better_transformer: bool = True  # forced off in __post_init__ without CUDA
    use_bucket_batching: bool = True
    use_smart_postproc: bool = True  # no reader in the visible code
    use_adaptive_beams: bool = False  # no reader in the visible code

    # --- Postprocessing ---
    minimal_postprocessing: bool = True  # forwarded to MinimalPostprocessor
    preserve_punctuation: bool = True  # no reader in the visible code
    preserve_numbers: bool = True  # no reader in the visible code

    # --- Miscellaneous ---
    checkpoint_freq: int = 100  # number of results between checkpoint files
    num_buckets: int = 4  # length buckets used by the batch sampler

    def __post_init__(self):
        """Pick the device, create ``output_dir``, drop GPU-only options on CPU."""
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda") if has_cuda else torch.device("cpu")

        Path(self.output_dir).mkdir(parents=True, exist_ok=True)

        if has_cuda:
            return
        # Neither option is usable without a GPU.
        self.use_mixed_precision = False
        self.use_better_transformer = False

# ------------------------------------------------------------
# LOGGING
# ------------------------------------------------------------
def setup_logging(output_dir: str) -> logging.Logger:
    """Configure root logging to the console and to a file in *output_dir*.

    Any handlers already attached to the root logger are removed and closed
    first, so calling this repeatedly (e.g. re-running a notebook cell) does
    not duplicate output or leak open log-file handles.

    Args:
        output_dir: Directory for ``inference_score_optimized.log``; created
            if it does not exist.

    Returns:
        The logger named after this module.
    """
    Path(output_dir).mkdir(exist_ok=True, parents=True)
    log_file = Path(output_dir) / "inference_score_optimized.log"

    # Detach and close previous handlers; removing alone would keep the old
    # log file open.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(),
            # Explicit UTF-8: log messages contain non-ASCII symbols, and the
            # platform default encoding may not be able to represent them.
            logging.FileHandler(log_file, encoding="utf-8"),
        ],
    )

    return logging.getLogger(__name__)

# ------------------------------------------------------------
# SMART PREPROCESSOR - Preserves linguistic structure
# ------------------------------------------------------------
class SmartPreprocessor:
    """Light-touch normaliser for transliteration input strings.

    Only long gap markers are rewritten; everything else is left as-is apart
    from whitespace collapsing.
    """

    def __init__(self):
        # Patterns are compiled once and reused for every input.
        self.patterns = {
            # Four or more dots, or two or more U+2026 ellipsis characters.
            # The ellipsis is written as an escape so the pattern survives
            # any re-encoding of this source file.
            "big_gap": re.compile(r"(\.{4,}|\u2026{2,})"),
            # Three or more consecutive lowercase x characters.
            "small_gap": re.compile(r"(xxx+)"),
            "whitespace": re.compile(r"\s+"),
        }

    def preprocess_input_text(self, text: str) -> str:
        """Return *text* with gap markers normalised and whitespace collapsed.

        Args:
            text: Raw input; missing values (``None`` / NaN) are accepted.

        Returns:
            The cleaned string, or ``""`` for a missing value.
        """
        if pd.isna(text):
            return ""

        cleaned_text = str(text).strip()

        # Replacement tokens are padded with spaces so they never fuse with
        # neighbouring characters; the padding is collapsed again below.
        cleaned_text = self.patterns["big_gap"].sub(" <big_gap> ", cleaned_text)
        cleaned_text = self.patterns["small_gap"].sub(" <gap> ", cleaned_text)

        return self.patterns["whitespace"].sub(" ", cleaned_text).strip()

    def preprocess_batch(self, texts: List[str]) -> List[str]:
        """Apply :meth:`preprocess_input_text` to every item of *texts*."""
        return [self.preprocess_input_text(t) for t in texts]

# ------------------------------------------------------------
# MINIMAL POSTPROCESSOR - Preserves model output quality
# ------------------------------------------------------------
class MinimalPostprocessor:
    """Conservative cleanup of decoded model output.

    With ``minimal=True`` only whitespace is normalised. With
    ``minimal=False`` subscript digits are additionally mapped to ASCII
    digits and doubled em dashes are replaced by a hyphen.
    """

    def __init__(self, minimal: bool = True):
        self.minimal = minimal

        self.patterns = {
            "whitespace": re.compile(r'\s+'),
            "space_before_punct": re.compile(r'\s+([.,;:!?])'),
            # Not used by this class; kept so existing readers of
            # ``patterns`` keep working.
            "repeated_spaces": re.compile(r'  +'),
        }

        # U+2080..U+2089 (subscript 0-9) -> ASCII digits. Written as escapes
        # so both maketrans arguments stay ten characters long regardless of
        # how this file is encoded; unequal lengths raise ValueError.
        self._subscript_table = str.maketrans(
            "\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089",
            "0123456789",
        )

    def postprocess_single(self, text: str) -> str:
        """Clean one decoded translation.

        Args:
            text: Decoded model output.

        Returns:
            The cleaned string, or ``""`` when *text* is not a string or is
            blank.
        """
        if not isinstance(text, str) or not text.strip():
            return ""

        cleaned = text.strip()

        if not self.minimal:
            cleaned = cleaned.translate(self._subscript_table)

        # Shared by both modes: collapse whitespace, then drop any space
        # left in front of punctuation.
        cleaned = self.patterns["whitespace"].sub(" ", cleaned)
        cleaned = self.patterns["space_before_punct"].sub(r"\1", cleaned)

        if not self.minimal:
            # Two consecutive em dashes (U+2014) become a single hyphen.
            cleaned = cleaned.replace("\u2014\u2014", "-")

        return cleaned.strip()

    def postprocess_batch(self, translations: List[str]) -> List[str]:
        """Apply :meth:`postprocess_single` to every item of *translations*."""
        return [self.postprocess_single(t) for t in translations]

# ------------------------------------------------------------
# BUCKET BATCH SAMPLER
# ------------------------------------------------------------
class BucketBatchSampler(Sampler):
    """Batch sampler that groups samples of similar length.

    Indices are sorted by whitespace-token count of the sample text and cut
    into ``num_buckets`` contiguous buckets; batches never span two buckets.
    The last bucket absorbs the remainder.

    Args:
        dataset: Indexable collection whose items are ``(id, text)`` pairs.
        batch_size: Maximum number of indices per yielded batch (>= 1).
        num_buckets: Number of length buckets; values below 1 are treated
            as 1 so that no sample is ever dropped.
        logger: Receives a summary of the bucket layout.
        shuffle: Shuffle indices inside each bucket on every iteration.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset: Dataset, batch_size: int, num_buckets: int, logger: logging.Logger, shuffle: bool = False):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.logger = logger

        # range(0) would create no buckets and silently discard every sample.
        num_buckets = max(1, num_buckets)

        # Index explicitly instead of iterating the dataset, so this does not
        # depend on __getitem__ raising IndexError to end iteration.
        lengths = [len(dataset[idx][1].split()) for idx in range(len(dataset))]
        sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

        bucket_size = max(1, len(sorted_indices) // num_buckets)
        self.buckets = []

        for i in range(num_buckets):
            start = i * bucket_size
            # The final bucket takes everything that is left.
            end = None if i == num_buckets - 1 else (i + 1) * bucket_size
            self.buckets.append(sorted_indices[start:end])

        self.logger.info(f"Created {num_buckets} buckets:")
        for i, bucket in enumerate(self.buckets):
            bucket_lengths = [lengths[idx] for idx in bucket] if bucket else [0]
            self.logger.info(
                f"  Bucket {i}: {len(bucket)} samples, length range [{min(bucket_lengths)}, {max(bucket_lengths)}]"
            )

    def __iter__(self):
        """Yield lists of dataset indices, bucket by bucket."""
        for bucket in self.buckets:
            if self.shuffle:
                random.shuffle(bucket)
            for i in range(0, len(bucket), self.batch_size):
                yield bucket[i : i + self.batch_size]

    def __len__(self):
        """Return the total number of batches across all buckets."""
        return sum(
            (len(bucket) + self.batch_size - 1) // self.batch_size
            for bucket in self.buckets
        )

# ------------------------------------------------------------
# DATASET
# ------------------------------------------------------------
class AkkadianDataset(Dataset):
    """Map-style dataset yielding ``(sample_id, prompt_text)`` pairs.

    All texts are preprocessed once, at construction time, and prefixed with
    the translation task instruction.
    """

    def __init__(self, dataframe: pd.DataFrame, preprocessor: SmartPreprocessor, logger: logging.Logger):
        """Read the ``id`` and ``transliteration`` columns of *dataframe*."""
        task_prefix = "translate Akkadian to English: "
        cleaned_texts = preprocessor.preprocess_batch(dataframe["transliteration"].tolist())

        self.sample_ids = dataframe["id"].tolist()
        self.input_texts = [task_prefix + item for item in cleaned_texts]

        logger.info(f"Dataset created with {len(self.sample_ids)} samples")

    def __len__(self):
        """Return the number of samples."""
        return len(self.sample_ids)

    def __getitem__(self, index: int):
        """Return the id and the prompt text stored at *index*."""
        sample_id = self.sample_ids[index]
        prompt = self.input_texts[index]
        return sample_id, prompt

# ------------------------------------------------------------
# SCORE-OPTIMIZED INFERENCE ENGINE
# ------------------------------------------------------------
class ScoreOptimizedEngine:
    """Load the seq2seq model and translate a test set with beam search.

    Results are accumulated in ``self.results`` as ``(id, translation)``
    tuples and returned by :meth:`run_inference` as a DataFrame.
    """

    def __init__(self, config: ScoreOptimizedConfig, logger: logging.Logger):
        """Store collaborators and load the model and tokenizer.

        Args:
            config: Run settings (paths, batching, generation parameters).
            logger: Destination for progress and error messages.
        """
        self.config = config
        self.logger = logger
        self.preprocessor = SmartPreprocessor()
        self.postprocessor = MinimalPostprocessor(minimal=config.minimal_postprocessing)
        self.results = []
        # Size of ``self.results`` when the last checkpoint file was written.
        self._last_checkpoint_size = 0
        self._load_model()

    def _load_model(self):
        """Load model and tokenizer from ``config.model_path``; optionally apply BetterTransformer."""
        self.logger.info(f"Loading model from {self.config.model_path}")

        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_path)
        self.model = self.model.to(self.config.device)
        self.model = self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        num_params = sum(p.numel() for p in self.model.parameters())
        self.logger.info(f"Model loaded: {num_params:,} parameters")

        if self.config.use_better_transformer and torch.cuda.is_available():
            # Best effort: any failure here only costs the optimisation.
            try:
                from optimum.bettertransformer import BetterTransformer
                self.logger.info("Applying BetterTransformer...")
                self.model = BetterTransformer.transform(self.model)
                self.logger.info("‚úÖ BetterTransformer applied")
            except ImportError:
                self.logger.warning("‚ö†Ô∏è  'optimum' not installed, skipping BetterTransformer")
            except Exception as exc:
                self.logger.warning(f"‚ö†Ô∏è  BetterTransformer failed: {exc}")

    def _collate_fn(self, batch_samples):
        """Tokenize a list of ``(id, text)`` samples into one padded batch.

        Returns:
            ``(batch_ids, tokenized)`` where ``tokenized`` is the tokenizer
            output with PyTorch tensors, truncated to ``config.max_length``.
        """
        batch_ids = [s[0] for s in batch_samples]
        batch_texts = [s[1] for s in batch_samples]

        tokenized = self.tokenizer(
            batch_texts,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        return batch_ids, tokenized

    def _save_checkpoint(self):
        """Write a checkpoint CSV once ``checkpoint_freq`` new results exist.

        Results grow one batch at a time, so their count rarely lands on an
        exact multiple of ``checkpoint_freq``; the check is therefore "at
        least ``checkpoint_freq`` results since the last checkpoint". A
        non-positive frequency disables checkpointing.
        """
        freq = self.config.checkpoint_freq
        done = len(self.results)
        last = getattr(self, "_last_checkpoint_size", 0)
        if freq <= 0 or done == 0 or done - last < freq:
            return

        checkpoint_path = Path(self.config.output_dir) / f"checkpoint_{done}.csv"
        df = pd.DataFrame(self.results, columns=["id", "translation"])
        df.to_csv(checkpoint_path, index=False)
        self._last_checkpoint_size = done
        self.logger.info(f"üíæ Checkpoint: {len(self.results)} translations")

    def _build_dataloader(self, dataset: Dataset) -> DataLoader:
        """Create the DataLoader, bucketed by length when configured."""
        worker_count = self.config.num_workers
        loader_kwargs = {
            "num_workers": worker_count,
            "collate_fn": self._collate_fn,
            # Pinned memory only helps host-to-GPU copies.
            "pin_memory": self.config.device.type == "cuda",
            "persistent_workers": worker_count > 0,
        }
        if worker_count > 0:
            # DataLoader rejects prefetch_factor when there are no workers.
            loader_kwargs["prefetch_factor"] = 2

        if self.config.use_bucket_batching:
            batch_sampler = BucketBatchSampler(
                dataset=dataset,
                batch_size=self.config.batch_size,
                num_buckets=self.config.num_buckets,
                logger=self.logger,
                shuffle=False,
            )
            return DataLoader(dataset, batch_sampler=batch_sampler, **loader_kwargs)

        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            **loader_kwargs,
        )

    def _generate(self, input_ids, attention_mask, gen_config):
        """Call ``model.generate``, under autocast when mixed precision is on."""
        if self.config.use_mixed_precision:
            with autocast():
                return self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **gen_config,
                )
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_config,
        )

    def _pick_best_candidates(self, translations):
        """Reduce ``num_return_sequences`` candidates per input to one.

        Candidates for one input are consecutive in *translations*; the
        longest string of each group is kept.
        """
        group = self.config.num_return_sequences
        if group <= 1:
            return translations
        return [
            max(translations[i:i + group], key=len)
            for i in range(0, len(translations), group)
        ]

    def run_inference(self, test_df: pd.DataFrame) -> pd.DataFrame:
        """Translate every row of *test_df*.

        A batch that fails during generation or decoding is logged and
        recorded with empty translations, so every input id appears exactly
        once in the output. Rows are returned in the order of *test_df*,
        independent of the batching order.

        Args:
            test_df: Frame with ``id`` and ``transliteration`` columns.

        Returns:
            DataFrame with columns ``id`` and ``translation``.
        """
        self.logger.info("üöÄ Starting SCORE-OPTIMIZED inference")

        dataset = AkkadianDataset(test_df, self.preprocessor, self.logger)
        dataloader = self._build_dataloader(dataset)

        self.logger.info(f"DataLoader created: {len(dataloader)} batches")
        self.logger.info("Score optimization settings:")
        self.logger.info(f"  üéØ Num Beams: {self.config.num_beams}")
        self.logger.info(f"  üéØ Length Penalty: {self.config.length_penalty}")
        self.logger.info(f"  üéØ Repetition Penalty: {self.config.repetition_penalty}")
        self.logger.info(f"  üéØ No Repeat N-gram: {self.config.no_repeat_ngram_size}")
        self.logger.info(f"  üéØ Minimal Postproc: {self.config.minimal_postprocessing}")

        gen_config = {
            "max_new_tokens": self.config.max_new_tokens,
            "num_beams": self.config.num_beams,
            "length_penalty": self.config.length_penalty,
            "repetition_penalty": self.config.repetition_penalty,
            "early_stopping": self.config.early_stopping,
            "no_repeat_ngram_size": self.config.no_repeat_ngram_size,
            "use_cache": True,
            "num_return_sequences": self.config.num_return_sequences,
        }

        self.results = []
        self._last_checkpoint_size = 0
        on_cuda = torch.cuda.is_available()

        with torch.inference_mode():
            for batch_idx, (batch_ids, tokenized) in enumerate(tqdm(dataloader, desc="üöÄ Translating")):
                # Only the work that produces translations is guarded, so a
                # failure can never add the same batch to the results twice.
                try:
                    outputs = self._generate(
                        tokenized.input_ids.to(self.config.device),
                        tokenized.attention_mask.to(self.config.device),
                        gen_config,
                    )
                    translations = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                    translations = self._pick_best_candidates(translations)
                    cleaned = self.postprocessor.postprocess_batch(translations)
                except Exception as exc:
                    self.logger.error(f"‚ùå Batch {batch_idx} error: {exc}", exc_info=True)
                    cleaned = [""] * len(batch_ids)

                self.results.extend(zip(batch_ids, cleaned))

                # Checkpoints are a convenience; a failed write must not
                # abort the run.
                try:
                    self._save_checkpoint()
                except Exception as exc:
                    self.logger.warning(f"Checkpoint write failed: {exc}")

                # Periodic cache clearing
                if on_cuda and batch_idx % 10 == 0:
                    torch.cuda.empty_cache()

        self.logger.info("‚úÖ Inference completed")

        # Bucket batching visits samples in length order; restore the order
        # of the input frame. The sort is stable, so it is a no-op when
        # batches were already sequential.
        input_position = {sample_id: pos for pos, sample_id in enumerate(dataset.sample_ids)}
        fallback = len(input_position)
        self.results.sort(key=lambda pair: input_position.get(pair[0], fallback))

        results_df = pd.DataFrame(self.results, columns=["id", "translation"])
        self._validate_results(results_df)

        return results_df

    def _validate_results(self, df: pd.DataFrame):
        """Print a summary of *df*: empty count, length stats, sample rows.

        Works on an empty frame; the sample section is then simply empty.
        """
        print("\n" + "=" * 60)
        print("üìä VALIDATION REPORT")
        print("=" * 60)

        empty = df["translation"].astype(str).str.strip().eq("").sum()
        print(f"\nEmpty: {empty} ({(empty / max(1, len(df))) * 100:.2f}%)")

        lengths = df["translation"].astype(str).str.len()
        print("\nüìè Length stats:")
        print(f"   Mean: {lengths.mean():.1f}, Median: {lengths.median():.1f}")
        print(f"   Min: {lengths.min()}, Max: {lengths.max()}")

        short = ((lengths < 5) & (lengths > 0)).sum()
        if short > 0:
            print(f"   ‚ö†Ô∏è  {short} very short translations")

        print("\nüìù Sample translations:")
        # First, middle and last row, without duplicates; empty for no rows.
        row_count = len(df)
        sample_indices = sorted({0, row_count // 2, row_count - 1}) if row_count else []

        for idx in sample_indices:
            row = df.iloc[idx]
            text = str(row["translation"])
            preview = text[:70] + "..." if len(text) > 70 else text
            try:
                id_label = f"{int(row['id']):4d}"
            except (TypeError, ValueError):
                # Non-numeric ids are shown verbatim.
                id_label = f"{row['id']!s:>4}"
            print(f"   ID {id_label}: {preview}")

        print("\n" + "=" * 60 + "\n")

# ------------------------------------------------------------
# IO HELPERS
# ------------------------------------------------------------
def print_environment_info():
    """Print the PyTorch version, CUDA / GPU details and BetterTransformer availability."""
    cuda_ready = torch.cuda.is_available()

    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {cuda_ready}")

    if cuda_ready:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        memory_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f"GPU Memory: {memory_bytes / 1e9:.2f} GB")

    # The import is attempted purely as an availability probe.
    try:
        from optimum.bettertransformer import BetterTransformer
    except ImportError:
        print("‚ùå BetterTransformer NOT available")
    else:
        print("‚úÖ BetterTransformer available!")

def save_outputs(results_df: pd.DataFrame, config: ScoreOptimizedConfig, output_dir: str, logger: logging.Logger):
    """Write the submission CSV and a JSON snapshot of the key settings.

    Args:
        results_df: Translations to save as ``submission.csv``.
        config: Source of the settings recorded in the JSON file.
        output_dir: Directory receiving both files.
        logger: Receives a confirmation once the CSV is written.
    """
    target_dir = Path(output_dir)
    output_path = target_dir / "submission.csv"
    config_path = target_dir / "score_optimized_config.json"

    results_df.to_csv(output_path, index=False)
    logger.info(f"‚úÖ Submission saved to {output_path}")

    # Top-level settings are copied by attribute name, in this order.
    recorded_fields = (
        "batch_size",
        "num_beams",
        "length_penalty",
        "repetition_penalty",
        "no_repeat_ngram_size",
        "minimal_postprocessing",
    )
    config_dict = {field_name: getattr(config, field_name) for field_name in recorded_fields}
    config_dict["optimizations"] = {
        "mixed_precision": config.use_mixed_precision,
        "better_transformer": config.use_better_transformer,
        "bucket_batching": config.use_bucket_batching,
    }

    with open(config_path, "w", encoding="utf-8") as handle:
        json.dump(config_dict, handle, indent=2)

    rule = "=" * 60
    print("\n" + rule)
    print("üéâ SCORE-OPTIMIZED INFERENCE COMPLETE!")
    print(rule)
    print(f"Submission file: {output_path}")
    print(f"Config file: {config_path}")
    print(f"Log file: {target_dir / 'inference_score_optimized.log'}")
    print(f"Total translations: {len(results_df)}")
    print(rule)

def inspect_results(output_dir: str):
    """Reload ``submission.csv`` from *output_dir* and print a short report.

    The report shows the frame shape, the first and last ten rows, the
    distribution of translation lengths, and the ids of empty translations.

    Args:
        output_dir: Directory containing ``submission.csv``.
    """
    submission_path = Path(output_dir) / "submission.csv"
    # keep_default_na=False: an empty translation is stored as an empty CSV
    # field, which pandas would otherwise read back as NaN. astype(str) then
    # turns NaN into the text "nan", hiding every empty translation from the
    # checks below.
    submission = pd.read_csv(submission_path, keep_default_na=False)

    print(f"\nSubmission shape: {submission.shape}")
    print("\nFirst 10 translations:")
    print(submission.head(10))
    print("\nLast 10 translations:")
    print(submission.tail(10))

    translations = submission["translation"].astype(str)

    lengths = translations.str.len()
    print("\nLength distribution:")
    print(lengths.describe())

    is_empty = translations.str.strip().eq("")
    empty = is_empty.sum()
    print(f"\nEmpty translations: {empty}")

    if empty > 0:
        print("\nEmpty translation IDs:")
        print(submission.loc[is_empty, "id"].tolist())

# ------------------------------------------------------------
# MAIN
# ------------------------------------------------------------
def main():
    """Run the whole pipeline: configure, load data, translate, save, inspect."""
    config = ScoreOptimizedConfig()
    logger = setup_logging(config.output_dir)
    logger.info("Logging initialized")

    print_environment_info()

    # Echo the settings that matter most for this run.
    logger.info("Configuration:")
    summary = (
        ("Device", config.device),
        ("Batch size", config.batch_size),
        ("Beams", config.num_beams),
        ("Repetition Penalty", config.repetition_penalty),
    )
    for label, value in summary:
        logger.info(f"  {label}: {value}")

    logger.info(f"Loading test data from {config.test_data_path}")
    test_df = pd.read_csv(config.test_data_path, encoding="utf-8")
    logger.info(f"‚úÖ Loaded {len(test_df)} test samples")

    print("\nFirst 5 samples:")
    print(test_df.head())

    results_df = ScoreOptimizedEngine(config, logger).run_inference(test_df)
    save_outputs(results_df, config, config.output_dir, logger)
    inspect_results(config.output_dir)

# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()