In [1]:
from openai import OpenAI
from typing import List, Dict
import itertools
import time
from tqdm import tqdm
import evaluate
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Load both evaluation metrics once, at module import, so that every call to
# evaluate_outputs reuses the same metric objects (BLEU first, then BERTScore).
bleu_metric, bertscore_metric = (evaluate.load(metric_name) for metric_name in ("bleu", "bertscore"))

class LegalFalcon3_7B_Answerer:
    """Client for Falcon3-7B-Instruct behind NVIDIA's OpenAI-compatible endpoint.

    Args:
        api_key: API key handed to the ``OpenAI`` client.
        model: Model identifier sent with every chat-completion request.
    """

    def __init__(self, api_key, model="tiiuae/falcon3-7b-instruct"):
        self.client = OpenAI(
            api_key=api_key,
            base_url="https://integrate.api.nvidia.com/v1"
        )
        self.model = model

    def generate(self, prompt, temperature=0.3, top_p=1.0, best_of_n=1, debug=False, retries=3, wait_time=10,
                 max_tokens=1024):
        """Sample ``best_of_n`` independent completions for one user prompt.

        Each candidate is requested with its own (non-streaming) API call. A
        failed call is retried up to ``retries`` attempts in total, sleeping
        ``wait_time`` seconds between attempts.

        Args:
            prompt: Text sent as the single user message.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling threshold forwarded to the API.
            best_of_n: Number of candidates to collect (must be >= 1).
            debug: If True, print each raw response.
            retries: Attempts per candidate (must be >= 1).
            wait_time: Seconds to sleep between failed attempts.
            max_tokens: Completion length cap forwarded to the API.

        Returns:
            A list of exactly ``best_of_n`` whitespace-stripped response strings,
            in request order.

        Raises:
            ValueError: If ``best_of_n`` or ``retries`` is less than 1.
            RuntimeError: If every attempt for a candidate fails; the last
                underlying exception is chained as ``__cause__``.
        """
        # Without these checks the loops below would silently produce fewer
        # candidates than requested (possibly none), which only surfaces later as
        # an unrelated error in the reranker.
        if best_of_n < 1:
            raise ValueError(f"best_of_n must be >= 1, got {best_of_n}")
        if retries < 1:
            raise ValueError(f"retries must be >= 1, got {retries}")

        messages = [{"role": "user", "content": prompt}]
        candidates = []

        for _ in range(best_of_n):
            for attempt in range(retries):
                try:
                    completion = self.client.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        temperature=temperature,
                        top_p=top_p,
                        max_tokens=max_tokens,
                        stream=False
                    )
                    content = completion.choices[0].message.content
                    if content is None:
                        # Raised inside the try so an empty reply is retried like
                        # any other failure, but with a readable message.
                        raise ValueError("API returned no message content")
                    response_text = content.strip()
                except Exception as e:
                    print(f"⚠️ Attempt {attempt+1} failed: {e}")
                    if attempt < retries - 1:
                        print(f"⏳ Retrying in {wait_time}s...")
                        time.sleep(wait_time)
                    else:
                        raise RuntimeError("❌ All retries failed for Falcon3-7B") from e
                else:
                    if debug:
                        print("🧠 Raw Response:\n", response_text)
                    candidates.append(response_text)
                    break

        return candidates

def simple_reranker(candidates: List[str]) -> str:
    """Return the longest candidate answer.

    When several candidates share the maximum length, the earliest one wins.
    An empty list raises ``ValueError`` (propagated from ``max``).
    """
    return max(candidates, key=len)

def evaluate_outputs(predictions: List[str], references: List[str]) -> Dict:
    """Score predictions against one reference each.

    Computes BLEU via the module-level ``bleu_metric`` and the mean BERTScore F1
    via ``bertscore_metric`` (English).

    Args:
        predictions: Model outputs, one per example.
        references: Gold answers aligned index-by-index with ``predictions``.

    Returns:
        ``{"BLEU": float, "BERTScore": float}``, where BERTScore is the
        arithmetic mean of the per-example F1 scores.

    Raises:
        ValueError: If the inputs are empty or their lengths differ. This is
            checked before any metric (and its scoring model) is run.
    """
    if len(predictions) != len(references):
        raise ValueError(
            f"got {len(predictions)} predictions but {len(references)} references"
        )
    if not predictions:
        # An empty batch would otherwise end in ZeroDivisionError when averaging F1.
        raise ValueError("predictions and references must be non-empty")

    # BLEU accepts several references per prediction, so wrap each in a list.
    bleu = bleu_metric.compute(predictions=predictions,
                               references=[[ref] for ref in references])['bleu']
    bert_f1 = bertscore_metric.compute(predictions=predictions, references=references, lang='en')['f1']
    bert_avg = sum(bert_f1) / len(bert_f1)
    return {"BLEU": bleu, "BERTScore": bert_avg}

def hyperparameter_grid_search(answerer: LegalFalcon3_7B_Answerer,
                                prompts: List[str],
                                references: List[str],
                                temperatures: List[float],
                                top_ps: List[float],
                                best_of_n: int = 3):
    """Evaluate every (temperature, top_p) pair on a fixed prompt set.

    For each grid point, ``answerer.generate`` samples ``best_of_n`` candidates
    per prompt. ``simple_reranker`` keeps one candidate per prompt, and the
    resulting predictions are scored against ``references`` with
    ``evaluate_outputs``.

    Args:
        answerer: Object providing ``generate(prompt, temperature=, top_p=, best_of_n=)``.
        prompts: Input prompts, aligned index-by-index with ``references``.
        references: Gold answers used for scoring.
        temperatures: Temperature values to try.
        top_ps: top_p values to try.
        best_of_n: Candidates sampled per prompt before reranking.

    Returns:
        One dict per grid point with keys ``temperature``, ``top_p``, ``BLEU``
        and ``BERTScore``, in ``itertools.product(temperatures, top_ps)`` order.

    Raises:
        ValueError: If ``prompts`` is empty or its length differs from
            ``references``. This is checked before any generation call is made.
    """
    # Validate up front. Otherwise a mismatch could only be detected at scoring
    # time, after a full grid point's worth of API calls has been spent.
    if not prompts:
        raise ValueError("prompts must be non-empty")
    if len(prompts) != len(references):
        raise ValueError(
            f"got {len(prompts)} prompts but {len(references)} references"
        )

    all_results = []
    grid = list(itertools.product(temperatures, top_ps))

    for temp, top_p in tqdm(grid, desc="Grid Search"):
        batch_predictions = [
            simple_reranker(
                answerer.generate(
                    prompt,
                    temperature=temp,
                    top_p=top_p,
                    best_of_n=best_of_n
                )
            )
            for prompt in prompts
        ]
        scores = evaluate_outputs(batch_predictions, references)
        result = {
            "temperature": temp,
            "top_p": top_p,
            "BLEU": scores["BLEU"],
            "BERTScore": scores["BERTScore"]
        }
        # tqdm.write prints above the live progress bar instead of splitting it.
        tqdm.write(f"Config Tested: {result}")
        all_results.append(result)

    return all_results

def log_results_and_plot(results: List[Dict], csv_path: str, heatmap_path: str):
    """Save grid-search results to CSV and render BLEU/BERTScore heatmaps.

    Args:
        results: Rows with keys ``temperature``, ``top_p``, ``BLEU`` and
            ``BERTScore`` (as produced by ``hyperparameter_grid_search``).
        csv_path: Destination of the CSV dump (written without the index).
        heatmap_path: Destination of the side-by-side heatmap image (300 dpi).

    Raises:
        ValueError: If ``results`` is empty. ``DataFrame.pivot`` also raises
            ``ValueError`` if a (temperature, top_p) pair occurs more than once.
    """
    if not results:
        raise ValueError("results must be non-empty")

    df = pd.DataFrame(results)
    df.to_csv(csv_path, index=False)

    fig, (ax_bleu, ax_bert) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        panels = (
            (ax_bleu, "BLEU", "YlGnBu", "BLEU Score Heatmap"),
            (ax_bert, "BERTScore", "YlOrRd", "BERTScore Heatmap"),
        )
        for ax, metric, cmap, title in panels:
            pivot = df.pivot(index="temperature", columns="top_p", values=metric)
            sns.heatmap(pivot, annot=True, fmt=".5f", cmap=cmap, ax=ax)
            ax.set_title(title)

        fig.tight_layout()
        fig.savefig(heatmap_path, dpi=300)
    finally:
        # Release the figure even if pivot/heatmap/savefig raises, so that
        # repeated notebook runs do not accumulate open figures.
        plt.close(fig)

    print(f"\nResults saved to {csv_path} and heatmap to {heatmap_path}")

# =========================
# Example Usage
# =========================
if __name__ == "__main__":
    import os

    # Read the credential from the environment; never hard-code it in source.
    # Any key previously committed to this file should be treated as leaked and
    # revoked/rotated.
    API_KEY = os.environ.get("NVIDIA_API_KEY")
    if not API_KEY:
        raise RuntimeError(
            "Set the NVIDIA_API_KEY environment variable before running this script."
        )

    answerer = LegalFalcon3_7B_Answerer(api_key=API_KEY)

    prompts = [
        "Question: What is the principle of stare decisis in US law? Answer and give reasoning.",
        "Question: Can a minor legally enter into a binding contract? Explain with reasoning."
    ]
    references = [
        "Answer: The principle of stare decisis means that courts follow precedents established by higher courts. Reasoning: This ensures legal consistency and predictability.",
        "Answer: Generally, a minor cannot enter into a binding contract, except for necessities. Reasoning: Contracts with minors are usually voidable to protect them from exploitation."
    ]

    temperatures = [0.1, 0.3, 0.5, 0.7]
    top_ps = [0.5, 0.7, 0.9]

    results = hyperparameter_grid_search(
        answerer=answerer,
        prompts=prompts,
        references=references,
        temperatures=temperatures,
        top_ps=top_ps,
        best_of_n=3
    )

    log_results_and_plot(
        results,
        csv_path="falcon3_hyperparameter_tuning.csv",
        heatmap_path="falcon3_hyperparameter_heatmap.png"
    )


Grid Search:   0%|                                       | 0/12 [00:00<?, ?it/s]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Grid Search:   8%|██▌                            | 1/12 [00:23<04:21, 23.78s/it]

Config Tested: {'temperature': 0.1, 'top_p': 0.5, 'BLEU': 0.025501797325257813, 'BERTScore': 0.8732405602931976}


Grid Search:  17%|█████▏                         | 2/12 [00:55<04:42, 28.22s/it]

Config Tested: {'temperature': 0.1, 'top_p': 0.7, 'BLEU': 0.017921636220706704, 'BERTScore': 0.8466645181179047}


Grid Search:  25%|███████▊                       | 3/12 [01:22<04:10, 27.79s/it]

Config Tested: {'temperature': 0.1, 'top_p': 0.9, 'BLEU': 0.023681827868464067, 'BERTScore': 0.854119211435318}


Grid Search:  33%|██████████▎                    | 4/12 [01:45<03:26, 25.87s/it]

Config Tested: {'temperature': 0.3, 'top_p': 0.5, 'BLEU': 0.025368694537454998, 'BERTScore': 0.874313235282898}


Grid Search:  42%|████████████▉                  | 5/12 [02:12<03:05, 26.44s/it]

Config Tested: {'temperature': 0.3, 'top_p': 0.7, 'BLEU': 0.02361927684259837, 'BERTScore': 0.8538897931575775}


Grid Search:  50%|███████████████▌               | 6/12 [02:45<02:52, 28.75s/it]

Config Tested: {'temperature': 0.3, 'top_p': 0.9, 'BLEU': 0.020285170428755998, 'BERTScore': 0.8534522652626038}


Grid Search:  58%|██████████████████             | 7/12 [03:08<02:13, 26.60s/it]

Config Tested: {'temperature': 0.5, 'top_p': 0.5, 'BLEU': 0.030262267934684162, 'BERTScore': 0.8722962141036987}


Grid Search:  67%|████████████████████▋          | 8/12 [03:38<01:51, 27.84s/it]

Config Tested: {'temperature': 0.5, 'top_p': 0.7, 'BLEU': 0.020398910581278414, 'BERTScore': 0.8497428894042969}


Grid Search:  75%|███████████████████████▎       | 9/12 [04:04<01:21, 27.20s/it]

Config Tested: {'temperature': 0.5, 'top_p': 0.9, 'BLEU': 0.02099527116726129, 'BERTScore': 0.858477771282196}


Grid Search:  83%|█████████████████████████     | 10/12 [04:31<00:54, 27.24s/it]

Config Tested: {'temperature': 0.7, 'top_p': 0.5, 'BLEU': 0.024882265019984362, 'BERTScore': 0.8568577766418457}


Grid Search:  92%|███████████████████████████▌  | 11/12 [05:06<00:29, 29.66s/it]

Config Tested: {'temperature': 0.7, 'top_p': 0.7, 'BLEU': 0.017332164533316475, 'BERTScore': 0.8499995470046997}


Grid Search: 100%|██████████████████████████████| 12/12 [05:38<00:00, 28.25s/it]

Config Tested: {'temperature': 0.7, 'top_p': 0.9, 'BLEU': 0.018610577759493915, 'BERTScore': 0.8491026759147644}






Results saved to falcon3_hyperparameter_tuning.csv and heatmap to falcon3_hyperparameter_heatmap.png
