In [1]:
!pip install torch>=1.9.0 numpy>=1.20.0 groq>=0.11.0 transformers>=4.30.0 python-dotenv>=1.0.0 tqdm>=4.65.0 pydantic>=2.0.0 -q


In [2]:
import sys
# The Jupyter kernel is started with extra command-line arguments (such as '-f');
# keep only the program name in sys.argv, presumably so that any argument-parsing
# code run later in the notebook does not see them.
sys.argv = sys.argv[0:1]

In [3]:
import os
import json
import time
import random
from itertools import cycle
from groq import Groq  # Only keeping this import

def _build_prompt(sample):
    """Build the NER/RE extraction prompt for one test sample.

    ``sample`` must provide the keys 'domain', 'id' and 'title'; the whole
    sample is also embedded verbatim as JSON in the prompt.
    """
    return f"""
        You are an advanced information extraction model specializing in Named Entity Recognition (NER) and Relation Extraction (RE). 
        Your specific domain is {sample['domain']}.
        Extract named entities and relationships from the given document. 
        Return only the extracted JSON output without any extra text.
        Extract relevant named entities and their relationships based on predefined NER and RE labels.
        Find all entities that you can find.

        ### Input:
        {json.dumps(sample, ensure_ascii=False)}

        ### Output Format:
        {{
            "{sample['id']}": {{
                "title": "{sample['title']}",
                "entities": [
                    {{
                        "mentions": ["<Entity Text>"],
                        "type": "<NER Label>"
                    }}
                ]
            }}
        }}
        """


def _mask_api_key(api_key):
    """Return a log-safe rendering of ``api_key`` that never reveals the full secret."""
    if len(api_key) <= 8:
        return "****"
    return f"{api_key[:4]}...{api_key[-4:]}"


def _classify_api_error(error_message):
    """Classify an API error message as 'invalid_key', 'rate_limit' or 'other'."""
    lowered = error_message.lower()
    # Compare lowercase needles against the lowercased message; a mixed-case
    # needle such as "invalid API key" could never match a lowercased string.
    if "401" in error_message or "invalid api key" in lowered:
        return "invalid_key"
    if "429" in error_message or "rate limit" in lowered:
        return "rate_limit"
    return "other"


def _save_results(results, result_path):
    """Write ``results`` as JSON to ``result_path`` via a temp file + atomic rename.

    The rename means an interruption mid-write cannot leave a truncated
    checkpoint behind.
    """
    tmp_path = result_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4, ensure_ascii=False)
    os.replace(tmp_path, result_path)


def query_groq_api(data_test_path, result_path, list_of_api_keys,
                   model="meta-llama/llama-4-maverick-17b-128e-instruct",
                   start_index=0, client_factory=None):
    """Run NER/RE extraction over the test samples with the Groq chat API.

    Each model reply is parsed as a JSON object and merged into one results
    dict. The prompt asks the model to key its reply by the sample id. The
    whole dict is rewritten to ``result_path`` after every successful sample,
    so progress survives an interruption.

    Args:
        data_test_path: Path to a JSON file holding a list of samples; each
            sample needs 'domain', 'id' and 'title' keys (used in the prompt).
        result_path: Where the merged JSON results are written.
        list_of_api_keys: Groq API keys to rotate through. A key rejected as
            invalid is dropped from the rotation. On a rate-limit error the
            function waits ~10-12 s and moves on to the next key.
        model: Groq model name passed to ``chat.completions.create``.
        start_index: Index of the first sample to process. When > 0 and
            ``result_path`` already exists, its contents are loaded first so a
            resumed run does not overwrite earlier results.
        client_factory: Callable invoked as ``client_factory(api_key=...)`` to
            build a client; defaults to ``Groq``.

    Raises:
        ValueError: if ``list_of_api_keys`` is empty.
        RuntimeError: if every API key has been rejected as invalid.
    """
    if client_factory is None:
        client_factory = Groq

    active_keys = list(list_of_api_keys)
    if not active_keys:
        raise ValueError("list_of_api_keys must contain at least one API key")

    with open(data_test_path, "r", encoding="utf-8") as file:
        data_test = json.load(file)
    print("Length of data test", len(data_test))

    results = {}
    if start_index > 0 and os.path.exists(result_path):
        # Resuming: keep what earlier runs already extracted.
        with open(result_path, "r", encoding="utf-8") as file:
            results = json.load(file)

    key_index = 0
    client = client_factory(api_key=active_keys[key_index])

    for sample_index, sample in enumerate(data_test[start_index:], start=start_index):
        prompt = _build_prompt(sample)
        while True:
            api_key = active_keys[key_index]
            try:
                chat_completion = client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model=model,
                    response_format={"type": "json_object"},
                )
                print(f"Processing sample {sample_index} with API key {_mask_api_key(api_key)}")
                extracted_json = json.loads(chat_completion.choices[0].message.content)
                if not isinstance(extracted_json, dict):
                    raise ValueError(f"Expected a JSON object, got {type(extracted_json).__name__}")
                results.update(extracted_json)
            except Exception as e:
                error_kind = _classify_api_error(str(e))

                if error_kind == "invalid_key":
                    # Drop the bad key so it is never retried; stop once none are left
                    # instead of cycling through invalid keys forever.
                    print(f"Invalid API key detected: {_mask_api_key(api_key)}. Switching to the next API key...")
                    active_keys.pop(key_index)
                    if not active_keys:
                        raise RuntimeError("All API keys were rejected as invalid") from e
                    key_index %= len(active_keys)
                    client = client_factory(api_key=active_keys[key_index])

                elif error_kind == "rate_limit":
                    wait_time = 10 + random.uniform(0, 2)  # Slight randomization
                    print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds. Switching to the next API key...")
                    time.sleep(wait_time)
                    key_index = (key_index + 1) % len(active_keys)
                    client = client_factory(api_key=active_keys[key_index])

                else:
                    print(f"Unexpected API error: {e}. Skipping sample {sample_index}...")
                    break  # Skip the sample if other errors occur
            else:
                # Save results only after a successful call and parse.
                _save_results(results, result_path)
                print(f"Results updated and saved to {result_path} after processing sample {sample_index}")
                break  # Move to next sample

    print(f"Final results saved to {result_path}")


In [None]:
# Read Groq keys from the environment (comma-separated) instead of pasting them
# into the notebook, where they would leak via saved cells and version control.
GROQ_API_KEYS = [key.strip() for key in os.environ.get("GROQ_API_KEYS", "").split(",") if key.strip()]
if not GROQ_API_KEYS:
    raise RuntimeError("Set the GROQ_API_KEYS environment variable to one or more comma-separated Groq API keys.")

query_groq_api(
    data_test_path="/kaggle/input/docie2025/test_title.json",
    result_path="/kaggle/working/final1.json",
    list_of_api_keys=GROQ_API_KEYS,
)
