In [1]:
import json, os, pandas as pd, numpy as np, csv
import requests
import io
import tarfile
import zipfile
from datasets import load_dataset

In [2]:
import os

# Print the current working directory
print("Current Working Directory:", os.getcwd())

# Change the current working directory
new_directory = "/teamspace/studios/this_studio/simpletext-2025-controlcreativity/notebooks"
os.chdir(new_directory)

# Print the new working directory to confirm the change
print("New Working Directory:", os.getcwd())

Current Working Directory: /teamspace/studios/this_studio
New Working Directory: /teamspace/studios/this_studio/simpletext-2025-controlcreativity/notebooks


### Load data for Hallucination Detection training

In [3]:
# Directory where the CSV files are stored
data_dir = os.path.join(os.path.dirname(os.getcwd()), 'data')

# List all CSV files in the directory
csv_files = [f for f in os.listdir(data_dir) if f.endswith('.csv')]

# Load all CSV files into a single DataFrame
df_list = []
for file in csv_files:
    file_path = os.path.join(data_dir, file)
    df = pd.read_csv(file_path)
    df_list.append(df)

# Concatenate all DataFrames
training_df = pd.concat(df_list, ignore_index=True)

# Display the first few rows of the combined DataFrame
training_df.head()

Unnamed: 0,id,grounding,generated_text,label,cut,dataset_origin
0,34687720,France's Dubuisson carded a 67 to tie with ove...,rory mcilroy will take a one-shot lead into th...,0,val,XSumFaith
1,29347895,He died at his home in Cambridge following an ...,veteran classical music conductor christopher ...,0,val,XSumFaith
2,37895159,The Cherries went down 2-1 at Sunderland on Sa...,bournemouth manager eddie howe says his side a...,0,val,XSumFaith
3,37546354,Washington blamed Russia and the Syrian govern...,the us says it has suspended talks with russia...,0,val,XSumFaith
4,22299596,"Gareth Colfer-Williams, 25, died last week at ...",a post-mortem examination has concluded that a...,0,val,XSumFaith


In [4]:
# Get the counts of val and test data
val_test_spit = training_df['cut'].value_counts()

# Display the counts
print(val_test_spit)

print("val can be used for training the model and test can be used for evaluation the performance")

cut
val     84044
test    38332
Name: count, dtype: int64
val can be used for training the model and test can be used for evaluation the performance


In [5]:
# Get the counts by dataset origin
training_df['dataset_origin'].value_counts()

dataset_origin
Vitamin C     63054
HaluEval      20000
Fever         19998
PAWS           8000
XSumFaith      2353
SummEval       1698
FactCC         1434
FRANK          1393
Polytope       1268
Cao22           696
CLIFF           600
TofuEval        534
Wang20          474
samsum          250
qags_xsum       239
qags_cnndm      235
Goyal21         150
Name: count, dtype: int64

In [6]:
# prepare train and test - remove vitamin c & Fever as it is skewing the dataset towards Fact verification
# train lists 
train_data = training_df[(training_df.cut == 'val') & (~training_df['dataset_origin'].isin(['Vitamin C', 'Fever']))]
train_grounding_list = list(train_data['grounding'])
train_generated_list = list(train_data['generated_text'])

# test lists
test_data = training_df[(training_df.cut == 'test') & (~training_df['dataset_origin'].isin(['Vitamin C', 'Fever']))]
test_grounding_list = list(test_data['grounding'])
test_generated_list = list(test_data['generated_text'])

### LLM as judge baseline Approach

##### Entailment Prompting

In [7]:
### Entailment prompting
from typing import List, Tuple, Dict

kg_construction_prompt="""You are an expert at determining if a summary is consistent with a source article. Given an article and a summary, determine if all the information in the summary is supported by the article. Answer "yes" if the summary is consistent, and "no" if it is inconsistent."""
ei_format_prompt="""Article: {article}
    Summary: {summary}
    Answer (yes or no):"""

# Function to create entailment prompt
def construct_prompt_batch(articles: List[str], summaries: List[str]) -> List[str]: #ei for entailment inference #Changed name and parameters
        # Builds prompts for entailment inference
        prompts = [kg_construction_prompt + "\n" + ei_format_prompt.format(article=article, summary=summary) for article, summary in zip(articles, summaries)]  #Changed to ei
        return prompts #Returning list of prompts

In [8]:
# generate prompts for train and test
train_prompts = construct_prompt_batch(train_grounding_list, train_generated_list)
test_prompts = construct_prompt_batch(test_grounding_list, test_generated_list)

In [9]:
import torch
import json
import os
from groq import Groq
import requests # pip install requests first!


# Using Asynchromous Groq Api call functions

# Create Json file with batch requests
def create_batch_requests(messages, model, filename):
    with open(filename, "w") as file:
        for i, message in enumerate(messages, start=1):
            # Create a unique custom_id for each message
            custom_id = f"request-{i}"
            
            # Create the batch request dictionary
            batch_request = {
                "custom_id": custom_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "user", "content": message}
                    ]
                }
            }
            
            # Write the batch request to the file as a single JSON object per line
            json.dump(batch_request, file)
            file.write("\n")  # Ensure each JSON object is on a new line


# Upload the batch requests to Groq
def upload_file_to_groq(api_key, file_path):
    url = "https://api.groq.com/openai/v1/files"
    
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    # Prepare the file and form data
    files = {
        "file": ("batch_file.jsonl", open(file_path, "rb"))
    }
    
    data = {
        "purpose": "batch"
    }
    
    # Make the POST request
    response = requests.post(url, headers=headers, files=files, data=data)
    
    return response.json()


# Create a batch job
def create_batch(api_key, input_file_id):
    url = "https://api.groq.com/openai/v1/batches"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    data = {
        "input_file_id": input_file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h"
    }
    
    response = requests.post(url, headers=headers, json=data)
    return response.json()


# check the status of the batch job
def get_batch_status(api_key, batch_id):
    url = f"https://api.groq.com/openai/v1/batches/{batch_id}"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    response = requests.get(url, headers=headers)
    return response.json()

# Download the results of the batch job
def download_file_content(api_key, output_file_id, output_file):
    url = f"https://api.groq.com/openai/v1/files/{output_file_id}/content"
    
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    response = requests.get(url, headers=headers)
    
    # Write the content to a file
    with open(output_file, 'wb') as f:
        f.write(response.content)
    
    return f"File downloaded successfully to {output_file}"



In [10]:
batch_size = 1000
api_key='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'
model_name = "llama-3.3-70b-versatile"
output_dir = "batch_requests"

In [48]:
# Excution of Asynchromous Groq Api call functions

#create batch requests
from torch import mode


def process_batches_and_create_jobs(train_prompts, batch_size, model_name, api_key, output_dir):
    """
    Splits train_prompts into batches, creates JSON files using create_batch_requests,
    uploads them to Groq using upload_file_to_groq, and creates batch jobs using create_batch.

    Args:
        train_prompts (list): List of prompts to process.
        batch_size (int): The size of each batch.
        model_name (str): The model name for the Groq API.
        api_key (Groq): API key.
        output_dir (str): Directory to save the JSON files.

    Returns:
        dict: A dictionary mapping batch file names to their associated batch IDs.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    batch_file_ids = {}
    batch_ids = {}

    # Split train_prompts into batches
    for i in range(0, len(train_prompts), batch_size):
        batch_prompts = train_prompts[i:i + batch_size]
        batch_file_name = f"batch_{i // batch_size + 1}.json"
        batch_file_path = os.path.join(output_dir, batch_file_name)

        # Create JSON file for the batch using create_batch_requests
        create_batch_requests(
            messages=batch_prompts,
            model=model_name,
            filename=batch_file_path
        )

        # Upload JSON file to Groq using upload_file_to_groq
        try:
            file_upload_response = upload_file_to_groq(api_key, batch_file_path)
            file_id = file_upload_response["id"]
            batch_file_ids[batch_file_name] = file_id

            # Create batch job using create_batch
            batch_job_response = create_batch(api_key, file_id)
            batch_id = batch_job_response["id"]
            batch_ids[batch_file_name] = batch_id

        except Exception as e:
            print(f"Error processing batch {batch_file_name}: {e}")

    return batch_ids



batch_ids = process_batches_and_create_jobs(
    train_prompts=train_prompts,
    batch_size=batch_size,
    model_name=model_name,
    api_key=api_key,
    output_dir=output_dir
)

KeyboardInterrupt: 

In [11]:
#save the batch_ids to a json file
batch_ids_file = os.path.join(output_dir, "batch_ids.json")
with open(batch_ids_file, "w") as f:
    json.dump(batch_ids, f)

NameError: name 'batch_ids' is not defined

In [12]:
#load the batch_ids from the json file
batch_ids_file = os.path.join(output_dir, "batch_ids.json")
with open(batch_ids_file, "r") as f:
    batch_ids = json.load(f)

In [13]:
# Check the status of the batch jobs
def check_batch_status(api_key, batch_ids):
    """
    Checks the status of each batch job using the Groq API.

    Args:
        api_key (str): API key for authentication.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.

    Returns:
        dict: A dictionary mapping batch file names to their status.
    """
    batch_statuses = {}
    for batch_file_name, batch_id in batch_ids.items():
        try:
            status_response = get_batch_status(api_key, batch_id)
            batch_statuses[batch_file_name] = status_response["status"]
        except Exception as e:
            print(f"Error checking status for batch {batch_file_name}: {e}")

    return batch_statuses

batch_statuses = check_batch_status(api_key, batch_ids)
# Print the status of each batch job
for batch_file_name, status in batch_statuses.items():
    print(f"Batch file: {batch_file_name}, Status: {status}")

Batch file: batch_1.json, Status: completed
Batch file: batch_2.json, Status: completed
Batch file: batch_3.json, Status: completed
Batch file: batch_4.json, Status: completed
Batch file: batch_5.json, Status: completed
Batch file: batch_6.json, Status: completed
Batch file: batch_7.json, Status: completed
Batch file: batch_8.json, Status: completed
Batch file: batch_9.json, Status: completed
Batch file: batch_10.json, Status: completed
Batch file: batch_11.json, Status: completed
Batch file: batch_12.json, Status: completed
Batch file: batch_13.json, Status: completed
Batch file: batch_14.json, Status: completed
Batch file: batch_15.json, Status: completed
Batch file: batch_16.json, Status: completed
Batch file: batch_17.json, Status: completed
Batch file: batch_18.json, Status: completed
Batch file: batch_19.json, Status: completed
Batch file: batch_20.json, Status: completed
Batch file: batch_21.json, Status: completed
Batch file: batch_22.json, Status: completed
Batch file: batch_2

In [14]:
# Check the status of the batch jobs and download the results
def check_batch_status_and_download(api_key, batch_ids, output_dir):
    """
    Checks the status of batch jobs and downloads the results if completed.

    Args:
        api_key (str): API key for Groq.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.
        output_dir (str): Directory to save the downloaded files.

    Returns:
        None
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for batch_file_name, batch_id in batch_ids.items():
        try:
            # Check the status of the batch job
            status_response = get_batch_status(api_key, batch_id)
            status = status_response["status"]

            if status == "completed":
                # Download the results
                output_file_id = status_response["output_file_id"]
                output_file_name = f"{batch_file_name}_output.json"
                output_file_path = os.path.join(output_dir, output_file_name)

                download_message = download_file_content(api_key, output_file_id, output_file_path)
                print(download_message)
            else:
                print(f"Batch {batch_file_name} is still processing. Status: {status}")

        except Exception as e:
            print(f"Error checking status or downloading for {batch_file_name}: {e}")

# Check the status of the batch jobs and download the results
check_batch_status_and_download(api_key, batch_ids, output_dir)

File downloaded successfully to batch_requests/batch_1.json_output.json
File downloaded successfully to batch_requests/batch_2.json_output.json
File downloaded successfully to batch_requests/batch_3.json_output.json
File downloaded successfully to batch_requests/batch_4.json_output.json
File downloaded successfully to batch_requests/batch_5.json_output.json
File downloaded successfully to batch_requests/batch_6.json_output.json
File downloaded successfully to batch_requests/batch_7.json_output.json
File downloaded successfully to batch_requests/batch_8.json_output.json
File downloaded successfully to batch_requests/batch_9.json_output.json
File downloaded successfully to batch_requests/batch_10.json_output.json
File downloaded successfully to batch_requests/batch_11.json_output.json
File downloaded successfully to batch_requests/batch_12.json_output.json
File downloaded successfully to batch_requests/batch_13.json_output.json
File downloaded successfully to batch_requests/batch_14.json

In [15]:
# Process the downloaded files
def process_downloaded_files(output_dir):
    """
    Processes the downloaded files and extracts the responses.

    Args:
        output_dir (str): Directory where the downloaded files are stored.

    Returns:
        dict: A dictionary mapping batch file names to their responses.
    """
    responses = []
    for i in range(0, len(train_prompts), batch_size):
        batch_file_name = f"batch_{i // batch_size + 1}.json_output.json"
        file_path = os.path.join(output_dir, batch_file_name)
        if os.path.exists(file_path):
            with open(file_path, "r") as file:
                batch_output_dict = {}
                for line in file:
                    response = json.loads(line)
                    custom_id = int(response["custom_id"].split('-')[1])
                    content = response["response"]["body"]["choices"][0]["message"]["content"]
                    batch_output_dict[custom_id] = content
                
                # Extract the responses from the batch output
                ordered_values = [batch_output_dict[key] for key in sorted(batch_output_dict.keys())]
                responses.extend(ordered_values)
    

    return responses

# Process the downloaded files
def check_consistency_batch(grok_outputs: List[str]) -> List[int]:  # Now the grok outputs is what comes in

    results = []
    for output in grok_outputs:
        output = output.strip().lower()  # Normalize the output by stripping and converting to lowercase
        if output.startswith("yes"):  # Check if the output starts with "yes"
            results.append(1)  # Consistent
        elif output.startswith("no"):  # Check if the output starts with "no"
            results.append(0)  # Inconsistent
        else:
            results.append(0)  # Undetermined, handle as needed.

    return results



responses = process_downloaded_files(output_dir)
consistency_results = check_consistency_batch(responses)

In [25]:
# append to train data
train_data.loc[:,"predictions"] =  consistency_results


from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def calculate_metrics(train_data):
    """
    Calculates accuracy, precision, recall, and F1-score using the label and predictions columns of a pandas DataFrame.

    Args:
        train_data (pd.DataFrame): A pandas DataFrame with 'label' and 'predictions' columns.

    Returns:
        dict: A dictionary containing accuracy, precision, recall, and F1-score.
    """
    # Extract labels and predictions
    labels = train_data['label']
    predictions = train_data['predictions']

    # Calculate metrics
    prevalence = labels.value_counts(normalize=True)[1]
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average='binary')  # Use 'binary' for binary classification
    recall = recall_score(labels, predictions, average='binary')
    f1 = f1_score(labels, predictions, average='binary')

    # Return metrics as a dictionary
    return {
        "prevalence": prevalence,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

# Calculate metrics
metrics = calculate_metrics(train_data)
# Print the metrics
print("Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


Metrics:
prevalence: 0.5009
accuracy: 0.7844
precision: 0.7426
recall: 0.8719
f1_score: 0.8021


In [29]:
# Calculate metrics by dataset origin
def calculate_metrics_by_origin(train_data):
    """
    Calculates metrics by dataset origin.

    Args:
        train_data (pd.DataFrame): A pandas DataFrame with 'label', 'predictions', and 'dataset_origin' columns.

    Returns:
        dict: A dictionary containing accuracy, precision, recall, and F1-score by dataset origin.
    """
    metrics_by_origin = {}
    for origin in train_data['dataset_origin'].unique():
        origin_data = train_data[train_data['dataset_origin'] == origin]
        metrics = calculate_metrics(origin_data)
        metrics_by_origin[origin] = metrics

    return metrics_by_origin
# Calculate metrics by dataset origin
metrics_by_origin = calculate_metrics_by_origin(train_data)
# Print the metrics by dataset origin
print("Metrics by Dataset Origin:")



# show metrics by dataset origin in tabular format
def display_metrics_by_origin(metrics_by_origin):
    """
    Displays metrics by dataset origin in a tabular format.

    Args:
        metrics_by_origin (dict): A dictionary containing metrics by dataset origin.

    Returns:
        pd.DataFrame: A DataFrame containing the metrics.
    """
    # Create a DataFrame from the metrics dictionary
    metrics_df = pd.DataFrame.from_dict(metrics_by_origin, orient='index')
    return metrics_df
# Display metrics by dataset origin in tabular format
metrics_df = display_metrics_by_origin(metrics_by_origin)
# Print the DataFrame
print(metrics_df)

Metrics by Dataset Origin:
            prevalence  accuracy  precision    recall  f1_score
XSumFaith     0.084000  0.866400   0.325843  0.552381  0.409894
Polytope      0.872240  0.894322   0.967308  0.909584  0.937558
FactCC        0.858217  0.946294   0.965217  0.972466  0.968828
SummEval      0.878824  0.912941   0.974612  0.925033  0.949176
FRANK         0.405769  0.819231   0.711191  0.933649  0.807377
Wang20        0.478992  0.794118   0.755906  0.842105  0.796680
CLIFF         0.596667  0.860000   0.840796  0.944134  0.889474
Goyal21       0.266667  0.773333   0.571429  0.600000  0.585366
Cao22         0.571116  0.730853   0.852041  0.639847  0.730853
HaluEval      0.500178  0.739979   0.691757  0.866050  0.769153
PAWS          0.444032  0.810734   0.750087  0.860437  0.801481
qags_cnndm    0.488636  0.846591   0.797980  0.918605  0.854054
qags_xsum     0.478788  0.800000   0.755556  0.860759  0.804734
samsum        0.182857  0.914286   0.814815  0.687500  0.745763
TofuEval     

In [30]:
# delete the batch files
import shutil
def delete_batch_files(output_dir):
    """
    Deletes all files in the specified directory.

    Args:
        output_dir (str): Directory where the batch files are stored.

    Returns:
        None
    """
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
        print(f"Deleted all files in {output_dir}")
    else:
        print(f"{output_dir} does not exist")
# Delete the batch files
delete_batch_files(output_dir)

Deleted all files in batch_requests


### 1. GraphEval Implementation

GraphEval is a combination approach of using LLMs to create KGs and check consistency using NLI to detect hallucinations.

The implementation includes the main components of GraphEval as described in the paper:
1.KG construction from the LLM output
2.Consistency checking for each triple using an NLI model
3.Overall evaluation based on the consistency of all triples

Note that the KG construction step (construct_kg method) is a placeholder and should be implemented using an actual LLM in practice. The paper doesn't provide specific details on this step, so you would need to design an appropriate prompt and use an LLM API to generate the KG triples.

The check_consistency method uses a pre-trained RoBERTa model fine-tuned on MNLI for natural language inference. It returns the probability of contradiction between the triple and the context.

The evaluate method puts it all together, constructing the KG, checking each triple for consistency, and returning the overall result along with any inconsistent triples found.

In [11]:
### KG construction prompting using batch groq call
from typing import List, Tuple, Dict

kg_construction_prompt="""You are an expert at extracting information in structured formats to build a knowledge graph. 
    Step 1 − Entity detection: Identify all entities in the raw text. Make sure not to miss any out. Entities should be basic and simple, they are akin to Wikipedia nodes. 
    Step 2 − Coreference resolution: Find all expressions in the text that refer to the same entity. Make sure entities are not duplicated. 
    In particular do not include entities that are more specific versions themselves, e.g. "a detailed view of jupiter’s atmosphere" and "jupiter’s atmosphere", only include the most specific version of the entity. 
    Step 3 − Relation extraction: Identify semantic relationships between the entities you have identified.
    Step 4 - Follow these rules carefully:
           - Entity Consistency: Use consistent names for entities throughout the document. For example, if "John Smith" is mentioned as "John", "Mr. Smith", and "John Smith" in different places, use a single consistent form (preferably the most complete one) in all triples.
           - Atomic Terms: Identify distinct key terms (e.g., objects, locations, organizations, acronyms, people, conditions, concepts, feelings). Avoid merging multiple ideas into one term (they should be as "atomistic" as possible).
           - Unified References: Replace any pronouns (e.g., "he," "she," "it," "they," etc.) with the actual referenced entity, if identifiable.
           - Pairwise Relationships: If multiple terms co-occur in the same sentence (or a short paragraph that makes them contextually related), create one triple for each pair that has a meaningful relationship.
           - Standardize terminology: If the same concept appears with slight variations (e.g., "artificial intelligence" and "AI"), use the most common or canonical form consistently.
           - If a person is mentioned by name, create a relation to their location, profession and what they are known for (invented, wrote, started, title, etc.) if known and if it fits the context of the informaiton. 
    Format: Return only the knowledge graph as a list of triples, i.e. [[ "entity1", "relation1−2", "entity2"], ]."""
kg_format_prompt="""Use the given format to extract information from the following input: <input>{input}</input>.
    Skip the preamble and output the result as a list of triples"""
kg_tips_prompt="""Important Tips:
    1. Make sure all information is included in the knowledge graph.
    2. Each triple must only contain three strings! None of the strings should be empty.
    3. Do not split up related information into separate triples because this could change the meaning.
    4. Make sure all brackets and quotation marks are matched.
    5. Before adding a triplet to the knowledge graph, check the concatenated triple makes sense as a sentence. If not, discard it.
    6. Do not include any information that is not in the input text."""
kg_examples_prompt="""Here are some example input and output pairs.
    ## Example 1.
    Input: "The Walt Disney Company, commonly known as Disney, is an American multinational mass media and entertainment conglomerate that is headquartered at the Walt Disney Studios complex in Burbank, California."
    Output: [ [ "The Walt Disney Company", "headquartered at", "Walt Disney Studios complex in Burbank, California" ], [ "The Walt Disney Company", "commonly known as", "Disney" ], [ "The Walt Disney Company", "instance of", "American multinational mass media and entertainment conglomerate" ] ]
    ## Example 2.
    Input: "Amanda Jackson was born in Springfield, Ohio, USA on June 1, 1985. She was a basketball player for the U.S. women’s team."
    Output: [ [ "Amanda Jackson", "born in", "Springfield, Ohio, USA" ], [ "Amanda Jackson", "born on", "June 1, 1985" ], [ "Amanda Jackson", "occupation", "basketball player" ], [ "Amanda Jackson", "played for", "U.S. women’s basketball team" ] ]
    ## Example 3.
    Input: "Music executive Darius Van Arman was born in Pennsylvania. He attended Gonzaga College High School and is a human being."
    Output: [ [ "Darius Van Arman", "occupation", "Music executive" ], [ "Darius Van Arman", "born in", "Pennsylvania" ], [ "Darius Van Arman", "attended", "Gonzaga College High School" ], [ "Darius Van Arman", "instance of", "human being" ] ]
    ## Example 4.
    Input: "Italy had 3.6x times more cases of coronavirus than China."
    Output: [ [ "Italy", "had 3.6x times more cases of coronavirus than", "China" ] ]
    """
kg_return_prompt = """Do not include any text or commentary outside of the knowledge graph.
    The knowledge graph should be a list of triples, i.e. [[ "entity1", "relation1−2", "entity2"], ]."""

# Function to create kg prompt
def construct_prompt_batch(articles: List[str]) -> List[str]: #ei for entailment inference #Changed name and parameters
        # Builds prompts for entailment inference
        prompts = [f"{kg_construction_prompt} {kg_format_prompt.format(input=article)} {kg_tips_prompt} {kg_examples_prompt} {kg_return_prompt}" for index, article in enumerate(articles)]  #Changed to kg
        return prompts #Returning list of prompts

In [12]:
# generate prompts for train and test
train_prompts = construct_prompt_batch(train_generated_list)
test_prompts = construct_prompt_batch(test_generated_list)

In [13]:
batch_size = 1000
api_key='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'
model_name = "llama-3.3-70b-versatile"
output_dir = "batch_requests"

In [14]:
# Create Json file with batch requests
def create_batch_requests(messages, model, filename):
    with open(filename, "w") as file:
        for i, message in enumerate(messages, start=1):
            # Create a unique custom_id for each message
            custom_id = f"request-{i}"
            
            # Create the batch request dictionary
            batch_request = {
                "custom_id": custom_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": "You are an advanced AI system specialized in knowledge extraction and knowledge graph generation.Your expertise includes identifying consistent entity references and meaningful relationships in text."},
                        {"role": "user", "content": message}
                    ]
                }
            }
            
            # Write the batch request to the file as a single JSON object per line
            json.dump(batch_request, file)
            file.write("\n")  # Ensure each JSON object is on a new line


In [40]:
# Excution of Asynchromous Groq Api call functions

#create batch requests
from torch import mode


def process_batches_and_create_jobs(train_prompts, batch_size, model_name, api_key, output_dir):
    """
    Splits train_prompts into batches, creates JSON files using create_batch_requests,
    uploads them to Groq using upload_file_to_groq, and creates batch jobs using create_batch.

    Args:
        train_prompts (list): List of prompts to process.
        batch_size (int): The size of each batch.
        model_name (str): The model name for the Groq API.
        api_key (Groq): API key.
        output_dir (str): Directory to save the JSON files.

    Returns:
        dict: A dictionary mapping batch file names to their associated batch IDs.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    batch_file_ids = {}
    batch_ids = {}

    # Split train_prompts into batches
    for i in range(0, len(train_prompts), batch_size):
        batch_prompts = train_prompts[i:i + batch_size]
        batch_file_name = f"batch_{i // batch_size + 1}.json"
        batch_file_path = os.path.join(output_dir, batch_file_name)

        # Create JSON file for the batch using create_batch_requests
        create_batch_requests(
            messages=batch_prompts,
            model=model_name,
            filename=batch_file_path
        )

        # Upload JSON file to Groq using upload_file_to_groq
        try:
            file_upload_response = upload_file_to_groq(api_key, batch_file_path)
            file_id = file_upload_response["id"]
            batch_file_ids[batch_file_name] = file_id

            # Create batch job using create_batch
            batch_job_response = create_batch(api_key, file_id)
            batch_id = batch_job_response["id"]
            batch_ids[batch_file_name] = batch_id

        except Exception as e:
            print(f"Error processing batch {batch_file_name}: {e}")

    return batch_ids



batch_ids = process_batches_and_create_jobs(
    train_prompts=train_prompts,
    batch_size=batch_size,
    model_name=model_name,
    api_key=api_key,
    output_dir=output_dir
)

In [41]:
#save the batch_ids to a json file
batch_ids_file = os.path.join(output_dir, "batch_ids.json")
with open(batch_ids_file, "w") as f:
    json.dump(batch_ids, f)

In [15]:
#load the batch_ids from the json file
batch_ids_file = os.path.join(output_dir, "batch_ids.json")
with open(batch_ids_file, "r") as f:
    batch_ids = json.load(f)

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [46]:
# Check the status of the batch jobs
def check_batch_status(api_key, batch_ids):
    """
    Checks the status of each batch job using the Groq API.

    Args:
        api_key (str): API key for authentication.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.

    Returns:
        dict: A dictionary mapping batch file names to their status.
    """
    batch_statuses = {}
    for batch_file_name, batch_id in batch_ids.items():
        try:
            status_response = get_batch_status(api_key, batch_id)
            batch_statuses[batch_file_name] = status_response["status"]
        except Exception as e:
            print(f"Error checking status for batch {batch_file_name}: {e}")

    return batch_statuses

batch_statuses = check_batch_status(api_key, batch_ids)
# Print the status of each batch job
for batch_file_name, status in batch_statuses.items():
    print(f"Batch file: {batch_file_name}, Status: {status}")

Batch file: batch_1.json, Status: completed
Batch file: batch_2.json, Status: completed
Batch file: batch_3.json, Status: completed
Batch file: batch_4.json, Status: completed
Batch file: batch_5.json, Status: completed
Batch file: batch_6.json, Status: completed
Batch file: batch_7.json, Status: completed
Batch file: batch_8.json, Status: completed
Batch file: batch_9.json, Status: completed
Batch file: batch_10.json, Status: completed
Batch file: batch_11.json, Status: completed
Batch file: batch_12.json, Status: completed
Batch file: batch_13.json, Status: completed
Batch file: batch_14.json, Status: completed
Batch file: batch_15.json, Status: completed
Batch file: batch_16.json, Status: completed
Batch file: batch_17.json, Status: completed
Batch file: batch_18.json, Status: completed
Batch file: batch_19.json, Status: completed
Batch file: batch_20.json, Status: completed
Batch file: batch_21.json, Status: completed
Batch file: batch_22.json, Status: completed
Batch file: batch_2

In [47]:
# Check the status of the batch jobs and download the results
def check_batch_status_and_download(api_key, batch_ids, output_dir):
    """
    Checks the status of batch jobs and downloads the results if completed.

    Args:
        api_key (str): API key for Groq.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.
        output_dir (str): Directory to save the downloaded files.

    Returns:
        None
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for batch_file_name, batch_id in batch_ids.items():
        try:
            # Check the status of the batch job
            status_response = get_batch_status(api_key, batch_id)
            status = status_response["status"]

            if status == "completed":
                # Download the results
                output_file_id = status_response["output_file_id"]
                output_file_name = f"{batch_file_name}_output.json"
                output_file_path = os.path.join(output_dir, output_file_name)

                download_message = download_file_content(api_key, output_file_id, output_file_path)
                print(download_message)
            else:
                print(f"Batch {batch_file_name} is still processing. Status: {status}")

        except Exception as e:
            print(f"Error checking status or downloading for {batch_file_name}: {e}")

# Check the status of the batch jobs and download the results
check_batch_status_and_download(api_key, batch_ids, output_dir)

File downloaded successfully to batch_requests/batch_1.json_output.json
File downloaded successfully to batch_requests/batch_2.json_output.json
File downloaded successfully to batch_requests/batch_3.json_output.json
File downloaded successfully to batch_requests/batch_4.json_output.json
File downloaded successfully to batch_requests/batch_5.json_output.json
File downloaded successfully to batch_requests/batch_6.json_output.json
File downloaded successfully to batch_requests/batch_7.json_output.json
File downloaded successfully to batch_requests/batch_8.json_output.json
File downloaded successfully to batch_requests/batch_9.json_output.json
File downloaded successfully to batch_requests/batch_10.json_output.json
File downloaded successfully to batch_requests/batch_11.json_output.json
File downloaded successfully to batch_requests/batch_12.json_output.json
File downloaded successfully to batch_requests/batch_13.json_output.json
File downloaded successfully to batch_requests/batch_14.json

In [15]:
# download the results into list
responses_llm = []
for i in range(0, len(train_prompts), batch_size):
    batch_file_name = f"batch_{i // batch_size + 1}.json_output.json"
    file_path = os.path.join(output_dir, batch_file_name)
    if os.path.exists(file_path):
        with open(file_path, "r") as file:
            batch_output_dict = {}
            for line in file:
                response = json.loads(line)
                custom_id = int(response["custom_id"].split('-')[1])
                content = response["response"]["body"]["choices"][0]["message"]["content"]
                batch_output_dict[custom_id] = content
                
            # Extract the responses from the batch output
            ordered_values = [batch_output_dict[key] for key in sorted(batch_output_dict.keys())]
            responses_llm.extend(ordered_values)

In [35]:
# delete the batch files
import shutil
def delete_batch_files(output_dir):
    """
    Deletes all files in the specified directory.

    Args:
        output_dir (str): Directory where the batch files are stored.

    Returns:
        None
    """
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
        print(f"Deleted all files in {output_dir}")
    else:
        print(f"{output_dir} does not exist")
# Delete the batch files
delete_batch_files(output_dir)

Deleted all files in batch_requests


In [16]:
# Process the downloaded files
def process_downloaded_files(output_dir):
    """
    Processes the downloaded files and extracts the responses.

    Args:
        output_dir (str): Directory where the downloaded files are stored.

    Returns:
        dict: A dictionary mapping batch file names to their responses.
    """
    responses = []
    for i in range(0, len(train_prompts), batch_size):
        batch_file_name = f"batch_{i // batch_size + 1}.json_output.json"
        file_path = os.path.join(output_dir, batch_file_name)
        if os.path.exists(file_path):
            with open(file_path, "r") as file:
                batch_output_dict = {}
                for line in file:
                    response = json.loads(line)
                    custom_id = int(response["custom_id"].split('-')[1])
                    content = response["response"]["body"]["choices"][0]["message"]["content"]
                    content = content.replace("{", "[").replace("}", "]")
                    content = content.replace("(", "[").replace(")", "]")
                    
                    # Extract the knowledge graph from the output
                    # Assumes the LLM returns the KG in a list within <python> tags
                    start_tag = content.find('[')
                    end_tag = content.rfind(']')
                    if start_tag != -1 and end_tag != -1:
                        kg_string = content[start_tag:end_tag+1]
                        try:
                            #print(kg_string)
                            kg = eval(kg_string) #use literal_eval for security
                            if isinstance(kg, list):
                                #responses.append(kg)
                                batch_output_dict[custom_id] = kg
                            else:
                                print("LLM did not return a list.")
                                #responses.append([])
                                batch_output_dict[custom_id] = []
                        except (SyntaxError, NameError) as e:
                            content = content.replace(" ", "")
                            start_tag = content.rfind('[\n[')
                            end_tag = content.rfind(']\n]')
                            if start_tag != -1 and end_tag != -1:
                                kg_string = content[start_tag:end_tag+3]
                                try:
                                    kg = eval(kg_string) #use literal_eval for security
                                    if isinstance(kg, list):
                                        #responses.append(kg)
                                        batch_output_dict[custom_id] = kg
                                    else:
                                        print("LLM did not return a list.")
                                        #responses.append([])
                                        batch_output_dict[custom_id] = []
                                except (SyntaxError, NameError) as e:
                                    print(f"Error parsing LLM output: {e}")
                                    #responses.append([])
                                    batch_output_dict[custom_id] = []
                            else:
                                end_tag = content.rfind(']')
                                if start_tag != -1 and end_tag != -1:
                                    kg_string = content[start_tag:end_tag+1]
                                    try:
                                        kg = eval(kg_string) #use literal_eval for security
                                        if isinstance(kg, list):
                                            #responses.append(kg)
                                            batch_output_dict[custom_id] = kg
                                        else:
                                            print("LLM did not return a list.")
                                            #responses.append([])
                                            batch_output_dict[custom_id] = []
                                    except (SyntaxError, NameError) as e:
                                        print(f"Error parsing LLM output: {e}")
                                        #responses.append([])
                                        batch_output_dict[custom_id] = []
                                else:
                                    
                                    content = content.replace("\n", "")
                                    start_tag = content.rfind('[[')
                                    end_tag = content.rfind(']]')
                                    if start_tag != -1 and end_tag != -1:
                                        kg_string = content[start_tag:end_tag+2]
                                        try:
                                            kg = eval(kg_string) #use literal_eval for security
                                            if isinstance(kg, list):
                                                #responses.append(kg)
                                                batch_output_dict[custom_id] = kg
                                            else:
                                                print("LLM did not return a list.")
                                                #responses.append([])
                                                batch_output_dict[custom_id] = []
                                        except (SyntaxError, NameError) as e:
                                            print(f"Error parsing LLM output: {e}")
                                            #responses.append([])
                                            batch_output_dict[custom_id] = []
                                    else:
                                        print("Could not find KG in LLM output.")
                                        #responses.append([])
                                        batch_output_dict[custom_id] = []
                            
                    else:
                        print("Could not find KG in LLM output.")
                        #responses.append([])
                        batch_output_dict[custom_id] = []
        
                    
                # Extract the responses from the batch output
                ordered_values = [batch_output_dict[key] for key in sorted(batch_output_dict.keys())]
                responses.extend(ordered_values)

    return responses



# Process the downloaded files
responses = process_downloaded_files(output_dir)

Error parsing LLM output: unterminated string literal (detected at line 1) (<string>, line 1)
Error parsing LLM output: invalid syntax. Perhaps you forgot a comma? (<string>, line 1)
Error parsing LLM output: unterminated string literal (detected at line 1) (<string>, line 1)


In [17]:

# find items with empty list in responses
empty_responses = [i for i, response in enumerate(responses) if not response]
# Print the indices of empty responses
print("Indices of empty responses:", empty_responses)
print("Number of empty responses:", len(empty_responses))


Indices of empty responses: [722, 2045, 2364, 2628, 5928, 12973, 19296, 19927, 20406, 20500, 21491, 22058]
Number of empty responses: 12


In [18]:
# sampling 2000 of train data

train_data = train_data.reset_index(drop=True)
train_data_sample = train_data.sample(n=2000, random_state=42)
selected_indices = train_data_sample.index.tolist()

# sample based selected indices
sampled_responses = [responses[i] for i in selected_indices]
sampled_train_grounding_list = [train_grounding_list[i] for i in selected_indices]
sampled_train_generated_list = [train_generated_list[i] for i in selected_indices]


In [18]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification,AutoModelForSeq2SeqLM
import torch
from torch.utils.data import DataLoader, Dataset

class NLIDataset(Dataset):
    def __init__(self, premises, hypotheses):
        self.premises = premises
        self.hypotheses = hypotheses

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        return self.premises[idx], self.hypotheses[idx]

def apply_nli_to_knowledge_graphs_optimized(grounding_list, knowledge_graphs,train_generated_list ,model_name="tasksource/ModernBERT-base-nli", batch_size=16):
    """
    Optimized version of applying an NLI model to knowledge graphs using batch processing.

    Args:
        grounding_list (list): List of grounding statements (premises).
        knowledge_graphs (list): List of knowledge graphs, where each graph is a list of triplets (hypotheses).
        model_name (str): Name of the pre-trained NLI model to use.
        batch_size (int): Batch size for processing.

    Returns:
        list: A list of lists, where each inner list contains the NLI labels for the triplets in the corresponding knowledge graph.
    """
    # Load the pre-trained NLI model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    #model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    #model.eval()

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #model = torch.nn.DataParallel(model)  # Wrap the model for multi-GPU
    model.to(device)

    # Prepare premises and hypotheses
    premises = []
    hypotheses = []
    for grounding, kg,generated in zip(grounding_list, knowledge_graphs,train_generated_list):
        # Use generated as hypothesis if kg is empty
        if kg == []:
            hypothesis = generated
            premises.append(grounding)
            hypotheses.append(hypothesis)
        else:
            hypothesis = f""
            for triplet in kg:
                
                #hypothesis = f""
                for items in triplet:
                    hypothesis += items
                    hypothesis += " "
                hypothesis = hypothesis.strip()
                hypothesis += ". "
                #hypothesis = f"{triplet[0]} {triplet[1]} {triplet[2]}"
            premises.append(grounding)
            hypotheses.append(hypothesis)

    # Create a dataset and dataloader
    dataset = NLIDataset(premises, hypotheses)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # List to store the results
    all_nli_labels = []

    # Process in batches
    batch_no = 0
    for batch in dataloader:
        batch_premises, batch_hypotheses = batch

        # Tokenize the batch
        inputs = tokenizer(
            list(batch_premises),
            list(batch_hypotheses),
            return_tensors="pt",
            truncation=True,
            max_length=8000,
            padding=True
        ).to(device)

        # Perform inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Get probabilities and predicted labels
        probabilities = torch.softmax(logits, dim=-1)
        labels = ["entailment", "neutral", "contradiction"]
        #batch_labels = [labels[torch.argmax(prob).item()] for prob in probabilities]
        predicted_probs = probabilities.tolist()


        # Append the batch results
        all_nli_labels.extend(predicted_probs)
        if batch_no % 10 == 0:
            print(f"Processed batch {batch_no + 1}/{len(dataloader)}")
        batch_no += 1

    # Reshape the results to match the structure of knowledge_graphs
    # reshaped_results = []
    # idx = 0
    # for kg in knowledge_graphs:
    #     if kg == []:
    #         reshaped_results.append([all_nli_labels[idx]])
    #         idx += 1
    #     else:
    #         reshaped_results.append(all_nli_labels[idx:idx + len(kg)])
    #         idx += len(kg)

    return all_nli_labels





In [19]:
# Apply the optimized NLI model
#nli_results = apply_nli_to_knowledge_graphs_optimized(sampled_train_grounding_list, sampled_responses, sampled_train_generated_list, batch_size=1)
nli_results = apply_nli_to_knowledge_graphs_optimized(train_grounding_list, responses, train_generated_list, batch_size=1)
#save the nli_labels to a json file
nli_labels_file = os.path.join(output_dir, "nli_labels_4.json")
with open(nli_labels_file, "w") as f:
    json.dump(nli_results, f)

# Print the results
# for i, (grounding, kg, labels) in enumerate(zip(grounding_list, knowledge_graphs, nli_results)):
#     print(f"Grounding: {grounding}")
#     for triplet, label in zip(kg, labels):
#         print(f"  Triplet: {triplet} -> NLI Label: {label}")

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


Processed batch 1/25854
Processed batch 11/25854
Processed batch 21/25854
Processed batch 31/25854
Processed batch 41/25854
Processed batch 51/25854
Processed batch 61/25854
Processed batch 71/25854
Processed batch 81/25854
Processed batch 91/25854
Processed batch 101/25854
Processed batch 111/25854
Processed batch 121/25854
Processed batch 131/25854
Processed batch 141/25854
Processed batch 151/25854
Processed batch 161/25854
Processed batch 171/25854
Processed batch 181/25854
Processed batch 191/25854
Processed batch 201/25854
Processed batch 211/25854
Processed batch 221/25854
Processed batch 231/25854
Processed batch 241/25854
Processed batch 251/25854
Processed batch 261/25854
Processed batch 271/25854
Processed batch 281/25854
Processed batch 291/25854
Processed batch 301/25854
Processed batch 311/25854
Processed batch 321/25854
Processed batch 331/25854
Processed batch 341/25854
Processed batch 351/25854
Processed batch 361/25854
Processed batch 371/25854
Processed batch 381/258

In [20]:
# create a list of empty lists
empty_list = []
# for i in range(len(sampled_responses)):
#     empty_list.append([])
for i in range(len(responses)):
    empty_list.append([])

In [21]:
# Apply the optimized NLI model
nli_results = apply_nli_to_knowledge_graphs_optimized(train_grounding_list, empty_list, train_generated_list, batch_size=1)

#save the nli_labels to a json file
nli_labels_file = os.path.join(output_dir, "nli_labels_5.json")
with open(nli_labels_file, "w") as f:
    json.dump(nli_results, f)

# Print the results
# for i, (grounding, kg, labels) in enumerate(zip(grounding_list, knowledge_graphs, nli_results)):
#     print(f"Grounding: {grounding}")
#     for triplet, label in zip(kg, labels):
#         print(f"  Triplet: {triplet} -> NLI Label: {label}")

Processed batch 1/25854
Processed batch 11/25854
Processed batch 21/25854
Processed batch 31/25854
Processed batch 41/25854
Processed batch 51/25854
Processed batch 61/25854
Processed batch 71/25854
Processed batch 81/25854
Processed batch 91/25854
Processed batch 101/25854
Processed batch 111/25854
Processed batch 121/25854
Processed batch 131/25854
Processed batch 141/25854
Processed batch 151/25854
Processed batch 161/25854
Processed batch 171/25854
Processed batch 181/25854
Processed batch 191/25854
Processed batch 201/25854
Processed batch 211/25854
Processed batch 221/25854
Processed batch 231/25854
Processed batch 241/25854
Processed batch 251/25854
Processed batch 261/25854
Processed batch 271/25854
Processed batch 281/25854
Processed batch 291/25854
Processed batch 301/25854
Processed batch 311/25854
Processed batch 321/25854
Processed batch 331/25854
Processed batch 341/25854
Processed batch 351/25854
Processed batch 361/25854
Processed batch 371/25854
Processed batch 381/258

In [20]:
#load the nli_labels from the json file
nli_labels_file = os.path.join(output_dir, "nli_labels_1.json")
with open(nli_labels_file, "r") as f:
    nli_results_1 = json.load(f)

In [21]:
#load the nli_labels from the json file
nli_labels_file = os.path.join(output_dir, "nli_labels_2.json")
with open(nli_labels_file, "r") as f:
    nli_results_2 = json.load(f)

In [22]:
#load the nli_labels from the json file
nli_labels_file = os.path.join(output_dir, "nli_labels_3.json")
with open(nli_labels_file, "r") as f:
    nli_results_3 = json.load(f)

In [23]:
# unnest the nli_results_3 and insert into a simple list
# if the list is empty, ignore it
# Flatten the list of lists 
nli_results_3_flat = []
for sublist in nli_results_3:
    #check if the sublist is empty
    if sublist:  # Only append non-empty sublists
        # Flatten the sublist
        for item in sublist:
            nli_results_3_flat.append([item])
# Print the first 10 items
print(nli_results_3_flat[:10])


[[[0.8872289061546326, 0.10726634413003922, 0.005504806060343981]], [[0.5776011943817139, 0.27792131900787354, 0.1444774717092514]], [[0.8225379586219788, 0.1677551418542862, 0.009706896729767323]], [[0.045631855726242065, 0.885104238986969, 0.06926392018795013]], [[0.769058346748352, 0.21048440039157867, 0.02045726776123047]], [[0.019338561221957207, 0.9722611308097839, 0.008400311693549156]], [[0.730110228061676, 0.22835691273212433, 0.041532889008522034]], [[0.7056415677070618, 0.27160972356796265, 0.02274864725768566]], [[0.6167622804641724, 0.37195467948913574, 0.01128300093114376]], [[0.630352258682251, 0.3574708104133606, 0.012177002616226673]]]


In [19]:
#load the nli_labels from the json file
nli_labels_file = os.path.join(output_dir, "nli_labels_4.json")
with open(nli_labels_file, "r") as f:
    nli_results_4 = json.load(f)

In [20]:
#load the nli_labels from the json file
nli_labels_file = os.path.join(output_dir, "nli_labels_5.json")
with open(nli_labels_file, "r") as f:
    nli_results_5 = json.load(f)

In [36]:
# Check the consistency of NLI results
#for each sublist in nli_results, if the third element is greater than 0.5, then 0 else if the average of first element is greater than 0.5 then 1 else 0
def check_consistency_nli(nli_results):
    """
    Checks the consistency of NLI results. If all items in the list are "entailment", return 1, else return 0.
    Args:
        nli_results (list): List of NLI results, where each element is a list of labels for the corresponding knowledge graph.
    Returns:
        int: 1 if all items are "entailment", else 0.
    """
    results = []
    for labels in nli_results:
        if labels[2] >= 0.5 :
            results.append(0)
        elif labels[0] >= 0.5:
            results.append(1)
        else:
            results.append(0)
        
    return results
# Example usage
#nli_results = [[[0.3, 0.4, 0.3], [0.2, 0.5, 0.3], [0.1, 0.8, 0.1]], 
 #                [[0.6, 0.2, 0.2], [0.7, 0.1, 0.2], [0.5, 0.4, 0.1]], 
 #                [[0.4, 0.4, 0.2], [0.3, 0.5, 0.2], [0.2, 0.7, 0.1]]]
consistency_results_1 = check_consistency_nli(nli_results_4)
consistency_results_2 = check_consistency_nli(nli_results_5)
#consistency_results_3 = check_consistency_nli(nli_results_3_flat)

In [37]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def calculate_metrics(train_data):
    """
    Calculates accuracy, precision, recall, and F1-score using the label and predictions columns of a pandas DataFrame.

    Args:
        train_data (pd.DataFrame): A pandas DataFrame with 'label' and 'predictions' columns.

    Returns:
        dict: A dictionary containing accuracy, precision, recall, and F1-score.
    """
    # Extract labels and predictions
    labels = train_data['label']
    predictions = train_data['predictions']

    # Calculate metrics
    prevalence = labels.value_counts(normalize=True)[1]
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average='binary')  # Use 'binary' for binary classification
    recall = recall_score(labels, predictions, average='binary')
    f1 = f1_score(labels, predictions, average='binary')

    # Return metrics as a dictionary
    return {
        "prevalence": prevalence,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }


train_data.loc[:,"predictions"] =  consistency_results_1

# Calculate metrics
metrics = calculate_metrics(train_data)
# Print the metrics
print("Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


Metrics:
prevalence: 0.5009
accuracy: 0.6089
precision: 0.5813
recall: 0.7833
f1_score: 0.6674


In [38]:
# append to train data
train_data.loc[:,"predictions"] =  consistency_results_2

# Calculate metrics
metrics = calculate_metrics(train_data)
# Print the metrics
print("Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

Metrics:
prevalence: 0.5009
accuracy: 0.6944
precision: 0.6485
recall: 0.8511
f1_score: 0.7361


In [39]:
# Calculate metrics by dataset origin
def calculate_metrics_by_origin(train_data):
    """
    Calculates metrics by dataset origin.

    Args:
        train_data (pd.DataFrame): A pandas DataFrame with 'label', 'predictions', and 'dataset_origin' columns.

    Returns:
        dict: A dictionary containing accuracy, precision, recall, and F1-score by dataset origin.
    """
    metrics_by_origin = {}
    for origin in train_data['dataset_origin'].unique():
        origin_data = train_data[train_data['dataset_origin'] == origin]
        metrics = calculate_metrics(origin_data)
        metrics_by_origin[origin] = metrics

    return metrics_by_origin
# Calculate metrics by dataset origin
metrics_by_origin = calculate_metrics_by_origin(train_data)
# Print the metrics by dataset origin
print("Metrics by Dataset Origin:")



# show metrics by dataset origin in tabular format
def display_metrics_by_origin(metrics_by_origin):
    """
    Displays metrics by dataset origin in a tabular format.

    Args:
        metrics_by_origin (dict): A dictionary containing metrics by dataset origin.

    Returns:
        pd.DataFrame: A DataFrame containing the metrics.
    """
    # Create a DataFrame from the metrics dictionary
    metrics_df = pd.DataFrame.from_dict(metrics_by_origin, orient='index')
    return metrics_df
# Display metrics by dataset origin in tabular format
metrics_df = display_metrics_by_origin(metrics_by_origin)
# Print the DataFrame
print(metrics_df)

Metrics by Dataset Origin:
            prevalence  accuracy  precision    recall  f1_score
XSumFaith     0.084000  0.810400   0.255556  0.657143  0.368000
Polytope      0.872240  0.867508   0.915044  0.934901  0.924866
FactCC        0.858217  0.892589   0.909730  0.971214  0.939467
SummEval      0.878824  0.878824   0.910714  0.955823  0.932724
FRANK         0.405769  0.692308   0.576577  0.909953  0.705882
Wang20        0.478992  0.567227   0.530055  0.850877  0.653199
CLIFF         0.596667  0.780000   0.762791  0.916201  0.832487
Goyal21       0.266667  0.453333   0.266667  0.600000  0.369231
Cao22         0.571116  0.663020   0.729614  0.651341  0.688259
HaluEval      0.500178  0.578782   0.555556  0.789324  0.652123
PAWS          0.444032  0.891243   0.833509  0.943539  0.885117
qags_cnndm    0.488636  0.602273   0.552632  0.976744  0.705882
qags_xsum     0.478788  0.618182   0.564516  0.886076  0.689655
samsum        0.182857  0.480000   0.256198  0.968750  0.405229
TofuEval     