In [1]:
!pip install nltk rouge-score matplotlib seaborn
!pip install sacrebleu bert-score torchmetrics nltk rouge-score datasets transformers groq pandas tqdm matplotlib seaborn
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')



Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=9576e4045b548eebc8192b34be0649891ca1b7534f3a947174ae3206243e8877
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2
Collecting sacrebleu
  Downloading sacrebleu-2.4.3-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting groq
  Downloading groq-0.12.0-py3-none-any.whl.metadata (13 kB)
Collectin

True

In [2]:
import numpy as np
import pandas as pd
from datasets import load_dataset
from sacrebleu.metrics import BLEU, CHRF, TER
from bert_score import BERTScorer
from torchmetrics.text import TranslationEditRate, WordErrorRate, CharErrorRate
from rouge_score import rouge_scorer
from groq import Groq
import torch
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

In [None]:
import logging
import warnings
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm import tqdm
import time
from ratelimit import limits, sleep_and_retry
from groq import Groq

# Metrics imports
from sacrebleu.metrics import BLEU, CHRF, TER
from bert_score import BERTScorer
from torchmetrics.text import WordErrorRate, CharErrorRate
from rouge_score import rouge_scorer

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Silence library warnings and emit timestamped INFO-level log lines.
warnings.filterwarnings(action='ignore')
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Pool of Groq API keys available for rotation (placeholder value shown here;
# supply real keys at run time rather than committing them to the notebook).
API_KEYS = ["your_api_key_1"]

# Throttling / retry settings applied to Groq API requests.
API_CONFIG = dict(
    calls_per_minute=50,  # Adjust based on API limits
    timeout=30,
    max_retries=3,
    backoff_factor=2,
)

class APIManager:
    """Round-robin selector over a pool of API keys with per-key bookkeeping.

    Attributes:
        api_keys: Private copy of the keys, in rotation order.
        current_key_index: Position of the key currently in use.
        request_counts: Number of successful requests recorded per key.
        error_counts: Errors recorded per key since that key was last rotated out.
        last_request_time: Timestamp taken when the manager was created.
    """

    def __init__(self, api_keys):
        """Create a manager for ``api_keys``.

        Args:
            api_keys: Non-empty iterable of API key strings.

        Raises:
            ValueError: If no key is supplied. Failing here avoids a later
                IndexError/ZeroDivisionError from the lookup and rotation code.
        """
        # Copy so later changes to the caller's list cannot desynchronise the
        # key list from the per-key counter dictionaries built below.
        keys = list(api_keys)
        if not keys:
            raise ValueError("APIManager requires at least one API key")

        self.api_keys = keys
        self.current_key_index = 0
        self.request_counts = {key: 0 for key in keys}
        self.last_request_time = time.time()
        self.error_counts = {key: 0 for key in keys}

    def get_current_key(self):
        """Return the key at the current rotation position."""
        return self.api_keys[self.current_key_index]

    def rotate_key(self):
        """Advance to the next key (wrapping around) and return it."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        logging.info(f"Rotating to API key index: {self.current_key_index}")
        return self.get_current_key()

    def handle_error(self, error):
        """Record an error for the active key and rotate when warranted.

        Rotation happens when the error text mentions "rate limit" or once the
        active key has accumulated three errors; its counter is reset before
        moving on.

        Args:
            error: The exception (or any object) describing the failure.

        Returns:
            The key to use for the next request.
        """
        current_key = self.get_current_key()
        self.error_counts[current_key] += 1

        rate_limited = "rate limit" in str(error).lower()
        if rate_limited or self.error_counts[current_key] >= 3:
            self.error_counts[current_key] = 0
            return self.rotate_key()
        return current_key

# Shared key manager used by the translation helpers below.
api_manager = APIManager(api_keys=API_KEYS)

# Models under evaluation, listed as (model id, provider, context length).
_MODEL_SPECS = (
    ("gemma2-9b-it", "Google", 8192),
    ("gemma-7b-it", "Google", 8192),
    ("llama3-groq-70b-8192-tool-use-preview", "Groq", 8192),
    ("llama3-groq-8b-8192-tool-use-preview", "Groq", 8192),
    ("llama-3.1-70b-versatile", "Meta", 8192),
    ("llama-3.1-8b-instant", "Meta", 8192),
    ("mixtral-8x7b-32768", "Mistral", 32768),
    ("llama-3.2-90b-vision-preview", "Meta", 128000),
)

# Lookup table: model id -> {"provider", "context_length"}.
MODELS = {
    model_id: {"provider": provider, "context_length": context_length}
    for model_id, provider, context_length in _MODEL_SPECS
}

class APIKeyLimitError(Exception):
    """Raised when the active API key has been rotated out after an error."""

class TranslationError(Exception):
    """Raised when a translation request fails or yields unusable text."""

def switch_client():
    """Rotate to the next API key and rebuild the module-level Groq client.

    Rotation is delegated to ``api_manager`` so that this helper and
    ``translate_text`` agree on which key is active. The module-level
    ``client_index`` and ``client`` are (re)assigned here, so the function
    works even when neither global has been defined beforehand.
    """
    global client_index, client
    api_manager.rotate_key()
    client_index = api_manager.current_key_index
    client = Groq(api_key=api_manager.get_current_key())
    logging.info(f"Switched to API key index: {client_index}")

def load_translation_data(language_pair, num_samples=1000):
    """Return the first ``num_samples`` WMT19 examples for ``language_pair``.

    The validation split is preferred; if requesting it raises ``ValueError``
    the train split is streamed instead.
    """
    def _open_stream(split_name):
        # Stream the split so only the requested examples are fetched.
        return load_dataset("wmt19", language_pair, split=split_name, streaming=True)

    try:
        stream = _open_stream("validation")
    except ValueError:
        stream = _open_stream("train")

    return list(stream.take(num_samples))

# Retry is the outermost decorator so that every attempt -- not only the first
# one -- passes through the rate limiter before it reaches the API.
@retry(stop=stop_after_attempt(API_CONFIG["max_retries"]),  # total attempts
       wait=wait_exponential(multiplier=API_CONFIG["backoff_factor"], min=4, max=10),
       retry=retry_if_exception_type((APIKeyLimitError, TranslationError)))
@sleep_and_retry
@limits(calls=API_CONFIG["calls_per_minute"], period=60)
def translate_text(text, model_name, source_lang, target_lang, prompt):
    """Translate ``text`` by sending ``prompt`` to a Groq-hosted chat model.

    Args:
        text: Source sentence. Used only to shorten ``prompt`` when the
            sentence is longer than the model's context budget.
        model_name: Key of ``MODELS`` identifying the model to query.
        source_lang: Source language code (one of the codes listed below).
        target_lang: Target language code (one of the codes listed below).
        prompt: Fully rendered user prompt, expected to embed ``text``.

    Returns:
        The model's reply with surrounding whitespace stripped.

    Raises:
        APIKeyLimitError: A failure caused ``api_manager`` to switch keys.
        TranslationError: Any other failure, including an empty reply or an
            unsupported language code.
        Both are retried by the decorator; once the attempts are used up
        tenacity surfaces the failure (wrapped in its ``RetryError`` by default).
    """
    try:
        current_key = api_manager.get_current_key()
        client = Groq(api_key=current_key)

        language_names = {
            'cs': 'Czech',
            'en': 'English',
            'de': 'German',
            'fi': 'Finnish',
            'fr': 'French',
            'gu': 'Gujarati',
            'kk': 'Kazakh',
            'lt': 'Lithuanian',
            'ru': 'Russian',
            'zh': 'Chinese'
        }

        # Reject language codes outside the supported set.
        unsupported = [code for code in (source_lang, target_lang)
                       if code not in language_names]
        if unsupported:
            raise TranslationError(f"Unsupported language code(s): {unsupported}")

        # Get model's max context length
        max_length = MODELS[model_name]["context_length"]

        # Reserve room for the prompt template and the response. The budget is
        # applied to characters as a rough stand-in for tokens.
        safe_length = max_length - 500
        if len(text) > safe_length:
            # The prompt already embeds the full sentence, so shorten it there;
            # truncating only the local `text` would leave the request unchanged.
            truncated_text = text[:safe_length] + "..."
            prompt = prompt.replace(text, truncated_text)

        chat_completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": "You are a professional translator focused on accuracy and fluency."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1,
            timeout=API_CONFIG["timeout"]
        )

        translation = chat_completion.choices[0].message.content.strip()
        if not translation or len(translation) < 2:
            raise TranslationError("Empty or invalid translation received")

        # Track successful request
        api_manager.request_counts[current_key] += 1
        return translation

    except Exception as e:
        logging.error(f"Translation error with key {api_manager.current_key_index}: {str(e)}")
        new_key = api_manager.handle_error(e)
        # Chain the original exception so its traceback is preserved.
        if new_key != current_key:
            raise APIKeyLimitError("Rotating API key due to errors") from e
        raise TranslationError(str(e)) from e

def _bert_f1(hypotheses, references, lang):
    """Return the mean BERTScore F1 of ``hypotheses`` against ``references``.

    Baseline rescaling is attempted first; if that fails the raw score is used.
    """
    try:
        scorer = BERTScorer(lang=lang, rescale_with_baseline=True)
        _, _, f1 = scorer.score(hypotheses, references)
    except Exception as e:
        # NOTE(review): presumably raised when bert-score ships no baseline
        # file for this language -- confirm against the installed version.
        logging.warning(f"BERTScore rescaling unavailable for '{lang}' ({e}); using raw scores")
        scorer = BERTScorer(lang=lang, rescale_with_baseline=False)
        _, _, f1 = scorer.score(hypotheses, references)
    return f1.mean().item()


def calculate_metrics(references, hypotheses, target_lang="en"):
    """Calculate various MT evaluation metrics.

    Args:
        references: Reference translations, one string per example.
        hypotheses: System translations aligned with ``references``.
        target_lang: Language code of the translations. Selects the BLEU
            tokenizer ("zh" for Chinese) and the BERTScore language. Defaults
            to "en", which matches the previous fixed behaviour.

    Returns:
        Dict mapping metric name to a plain Python number (BLEU, chrF, TER,
        BERTScore, WER, CER, ROUGE-1, ROUGE-2, ROUGE-L), or ``None`` when the
        inputs are empty or any metric fails; the failure is logged.
    """
    # Check the inputs before any scorer (and its model) is loaded.
    if not references or not hypotheses:
        logging.error("Error initializing metrics: Empty references or hypotheses")
        return None

    try:
        # Chinese output is scored with sacreBLEU's "zh" tokenizer; every
        # other target keeps the library default.
        bleu = BLEU(tokenize="zh") if target_lang == "zh" else BLEU()
        chrf = CHRF()
        ter_metric = TER()
        wer = WordErrorRate()
        cer = CharErrorRate()
        rouge_metrics = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    except Exception as e:
        logging.error(f"Error initializing metrics: {str(e)}")
        return None

    try:
        bleu_score = bleu.corpus_score(hypotheses, [references]).score
        chrf_score = chrf.corpus_score(hypotheses, [references]).score
        ter_score = ter_metric.corpus_score(hypotheses, [references]).score

        bert_score = _bert_f1(hypotheses, references, target_lang)

        # torchmetrics returns 0-dim tensors; convert them so the results
        # table holds plain floats that can be written to CSV and plotted.
        wer_score = float(wer(hypotheses, references))
        cer_score = float(cer(hypotheses, references))

        # Average the per-sentence ROUGE F-measures over the corpus.
        rouge_totals = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
        for hyp, ref in zip(hypotheses, references):
            scores = rouge_metrics.score(ref, hyp)
            for rouge_name in rouge_totals:
                rouge_totals[rouge_name] += scores[rouge_name].fmeasure
        rouge_scores = {name: total / len(hypotheses) for name, total in rouge_totals.items()}

        return {
            "BLEU": bleu_score,
            "chrF": chrf_score,
            "TER": ter_score,
            "BERTScore": bert_score,
            "WER": wer_score,
            "CER": cer_score,
            "ROUGE-1": rouge_scores['rouge1'],
            "ROUGE-2": rouge_scores['rouge2'],
            "ROUGE-L": rouge_scores['rougeL']
        }

    except Exception as e:
        logging.error(f"Error calculating metrics: {str(e)}")
        return None

def save_translations(model_name, source_lang, target_lang, prompt_name, source_texts, references, translations):
    """Write source / reference / translation triples to a CSV file.

    The file is named
    ``translations_{model_name}_{source_lang}-{target_lang}_{prompt_name}.csv``
    and is written to the current working directory.
    """
    pd.DataFrame({
        'Source': source_texts,
        'Reference': references,
        'Translation': translations
    }).to_csv(f'translations_{model_name}_{source_lang}-{target_lang}_{prompt_name}.csv', index=False)

def evaluate_models(dataset, source_lang, target_lang, prompts):
    """Translate ``dataset`` with every model/prompt pair and score the output.

    Args:
        dataset: Iterable of examples shaped like
            ``{'translation': {source_lang: ..., target_lang: ...}}``.
        source_lang: Language code of the input side.
        target_lang: Language code of the reference side.
        prompts: Iterable of ``(prompt_id, prompt_func)`` pairs, where
            ``prompt_func(text, source_lang, target_lang)`` returns the prompt.

    Returns:
        DataFrame with one row per "model (provider) - Prompt N" combination
        and one column per metric. Combinations that produced no translations,
        or whose metrics could not be computed, are logged and left out.
    """
    results = {}

    for model_name, model_info in MODELS.items():
        logging.info(f"Processing model: {model_name}")

        for prompt_id, prompt_func in prompts:
            logging.info(f"Using prompt: {prompt_id}")
            translations = []
            references = []
            source_texts = []

            for example in tqdm(dataset):
                try:
                    # Add delay between requests if needed
                    if len(translations) > 0 and len(translations) % 10 == 0:
                        time.sleep(1)  # Prevent hitting rate limits

                    source_text = example['translation'][source_lang]
                    reference = example['translation'][target_lang]
                    prompt = prompt_func(source_text, source_lang, target_lang)
                    translation = translate_text(source_text, model_name, source_lang, target_lang, prompt)

                    source_texts.append(source_text)
                    translations.append(translation)
                    references.append(reference)

                except Exception as e:
                    logging.error(f"Error processing example: {str(e)}")
                    continue

            model_prompt_key = f"{model_name} ({model_info['provider']}) - Prompt {prompt_id}"
            if not translations:
                logging.warning(f"No translations produced for {model_prompt_key}; skipping")
                continue

            # Persist the translations first: they are the expensive artefact
            # and must survive a failure in the scoring step.
            save_translations(
                model_name=model_name,
                source_lang=source_lang,
                target_lang=target_lang,
                prompt_name=f"Prompt_{prompt_id}",
                source_texts=source_texts,
                references=references,
                translations=translations
            )

            metrics = calculate_metrics(references, translations, target_lang=target_lang)
            if metrics is None:
                # Leave failed combinations out rather than storing None in
                # the results table.
                logging.warning(f"Metrics unavailable for {model_prompt_key}; skipping")
                continue
            results[model_prompt_key] = metrics

    return pd.DataFrame(results).T

def visualize_results(results, pair_name):
    """Draw the metric table as an annotated heatmap and save it as a PNG.

    The image is written to ``mt_evaluation_heatmap_{pair_name}.png`` and the
    figure is closed afterwards.
    """
    fig = plt.figure(figsize=(20, 10))
    ax = sns.heatmap(results, annot=True, cmap='YlOrRd', fmt='.3f')

    ax.set_title(f'Translation Metrics Comparison for {pair_name}')
    ax.set_ylabel('Models')
    ax.set_xlabel('Metrics')

    # Slant the metric names; keep the model names horizontal.
    plt.setp(ax.get_xticklabels(), rotation=45)
    plt.setp(ax.get_yticklabels(), rotation=0)

    fig.tight_layout()
    fig.savefig(f'mt_evaluation_heatmap_{pair_name}.png')
    plt.close(fig)

def main():
    """Run the evaluation for every enabled language pair.

    For each dataset pair, 100 examples are loaded once and evaluated in both
    translation directions. Each direction's metric table is written to
    ``mt_evaluation_results_{pair_code}.csv`` and logged.

    Returns:
        Dict mapping each direction code (e.g. "zh-en") to its results
        DataFrame, so the tables stay available for further analysis.
    """
    # Dataset pair -> list of (direction code, display name). Uncomment an
    # entry to include that pair in the run.
    language_pairs = {
        # ("cs-en"): [("cs-en", "Czech-English"), ("en-cs", "English-Czech")],
        # ("de-en"): [("de-en", "German-English"), ("en-de", "English-German")],
        # ("fi-en"): [("fi-en", "Finnish-English"), ("en-fi", "English-Finnish")],
        # ("fr-de"): [("fr-de", "French-German"), ("de-fr", "German-French")],
        # ("gu-en"): [("gu-en", "Gujarati-English"), ("en-gu", "English-Gujarati")],
        # ("kk-en"): [("kk-en", "Kazakh-English"), ("en-kk", "English-Kazakh")],
        # ("lt-en"): [("lt-en", "Lithuanian-English"), ("en-lt", "English-Lithuanian")],
        # ("ru-en"): [("ru-en", "Russian-English"), ("en-ru", "English-Russian")],
        "zh-en": [("zh-en", "Chinese-English"), ("en-zh", "English-Chinese")]
    }

    # Prompt templates as (id, builder). The template text, including its
    # leading indentation, is sent to the model exactly as written.
    prompts = [
        (1, lambda text, sl, tl: f"""System: Professional {sl}-{tl} translator
        Objective: Precise translation maintaining source meaning and target fluency
        Guidelines:
        1. Preserve exact meaning
        2. Maintain formatting
        3. Keep technical terms
        4. Ensure natural flow
        
        IMPORTANT: Provide ONLY the translation, without explanations or additional text.
        
        Source ({sl}): {text}
        Translation ({tl}): """),

        (2, lambda text, sl, tl: f"""System: Expert {sl}-{tl} translator with deep cultural understanding.
        Context: Professional translation requiring cultural and contextual accuracy.
        Requirements:
        - Preserve idiomatic expressions
        - Adapt cultural references appropriately
        - Maintain tone and register
        - Ensure natural {tl} language patterns
        
        IMPORTANT: Return ONLY the translated text, nothing else.
        
        Original ({sl}): {text}
        Translation ({tl}): """),

        (3, lambda text, sl, tl: f"""System: Specialized translation engine for {sl} to {tl}.
        Focus areas:
        - Technical accuracy
        - Domain-specific terminology
        - Structural equivalence
        - Target language conventions
        
        IMPORTANT: Output ONLY the translation, no comments or explanations.
        
        Input ({sl}): {text}
        Professional translation ({tl}): """),

        (4, lambda text, sl, tl: f"""System: Neural machine translation model optimized for {sl}-{tl} pair.
        Translation parameters:
        - Maximum semantic fidelity
        - Context preservation
        - Appropriate register
        - Natural language generation
        
        IMPORTANT: Respond with ONLY the translation text.
        
        Source content ({sl}): {text}
        Target content ({tl}): """),

        (5, lambda text, sl, tl: f"""System: Professional translation service
        Task: Convert from {sl} to {tl}
        Requirements:
        - Highest accuracy level
        - Natural expression
        - Contextual awareness
        - Style matching
        
        IMPORTANT: Give ONLY the translation, without any additional text.
        
        Source text ({sl}):
        {text}
        
        High-quality translation ({tl}):""")
    ]

    all_results = {}

    for pair_name, lang_pairs in language_pairs.items():
        # One download serves both directions of the pair.
        dataset = load_translation_data(pair_name, num_samples=100)
        for pair_code, name in lang_pairs:
            source_lang, target_lang = pair_code.split("-")
            logging.info(f"Evaluating {name} translations...")
            results = evaluate_models(dataset, source_lang, target_lang, prompts)
            results.to_csv(f"mt_evaluation_results_{pair_code}.csv")
            all_results[pair_code] = results
            logging.info(f"Results for {name}:")
            logging.info(results)

    # Return the collected tables instead of discarding them.
    return all_results

if __name__ == "__main__":
    main()



2024-11-22 00:10:42,192 - INFO - Evaluating Chinese-English translations...
2024-11-22 00:10:42,192 - INFO - Processing model: gemma2-9b-it
2024-11-22 00:10:42,192 - INFO - Using prompt: 1
  0%|          | 0/100 [00:00<?, ?it/s]2024-11-22 00:10:42,764 - INFO - HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
  1%|          | 1/100 [00:00<00:56,  1.74it/s]2024-11-22 00:10:43,359 - INFO - HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
  2%|▏         | 2/100 [00:01<00:57,  1.71it/s]2024-11-22 00:10:44,019 - INFO - HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
  3%|▎         | 3/100 [00:01<01:00,  1.61it/s]2024-11-22 00:10:44,743 - INFO - HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
  4%|▍         | 4/100 [00:02<01:04,  1.50it/s]2024-11-22 00:10:45,393 - INFO - HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 20

KeyboardInterrupt: 