In [1]:
import json, os, pandas as pd, numpy as np, csv
import requests
import io
import tarfile
import zipfile
from datasets import load_dataset

In [2]:
import os

# Print the current working directory
print("Current Working Directory:", os.getcwd())

# Change the current working directory
new_directory = "/teamspace/studios/this_studio/simpletext-2025-controlcreativity/notebooks"
os.chdir(new_directory)

# Print the new working directory to confirm the change
print("New Working Directory:", os.getcwd())

Current Working Directory: /teamspace/studios/this_studio
New Working Directory: /teamspace/studios/this_studio/simpletext-2025-controlcreativity/notebooks


### Load data for Hallucination Detection training

In [3]:
# Directory where the CSV files are stored
data_dir = os.path.join(os.path.dirname(os.getcwd()), 'data')

# List all CSV files in the directory
csv_files = [f for f in os.listdir(data_dir) if f.endswith('.csv')]

# Load all CSV files into a single DataFrame
df_list = []
for file in csv_files:
    file_path = os.path.join(data_dir, file)
    df = pd.read_csv(file_path)
    df_list.append(df)

# Concatenate all DataFrames
training_df = pd.concat(df_list, ignore_index=True)

# Display the first few rows of the combined DataFrame
training_df.head()

Unnamed: 0,id,grounding,generated_text,label,cut,dataset_origin
0,34687720,France's Dubuisson carded a 67 to tie with ove...,rory mcilroy will take a one-shot lead into th...,0,val,XSumFaith
1,29347895,He died at his home in Cambridge following an ...,veteran classical music conductor christopher ...,0,val,XSumFaith
2,37895159,The Cherries went down 2-1 at Sunderland on Sa...,bournemouth manager eddie howe says his side a...,0,val,XSumFaith
3,37546354,Washington blamed Russia and the Syrian govern...,the us says it has suspended talks with russia...,0,val,XSumFaith
4,22299596,"Gareth Colfer-Williams, 25, died last week at ...",a post-mortem examination has concluded that a...,0,val,XSumFaith


In [None]:
# Get the counts of val and test data
val_test_spit = training_df['cut'].value_counts()

# Display the counts
print(val_test_spit)

print("val can be used for training the model and test can be used for evaluation the performance")

cut
val     84044
test    38332
Name: count, dtype: int64
val can be used for training the model and test can be used for evaluation the performance


In [5]:
# Get the counts by dataset origin
training_df['dataset_origin'].value_counts()

dataset_origin
Vitamin C     63054
HaluEval      20000
Fever         19998
PAWS           8000
XSumFaith      2353
SummEval       1698
FactCC         1434
FRANK          1393
Polytope       1268
Cao22           696
CLIFF           600
TofuEval        534
Wang20          474
samsum          250
qags_xsum       239
qags_cnndm      235
Goyal21         150
Name: count, dtype: int64

In [6]:
# prepare train and test - remove vitamin c & Fever as it is skewing the dataset towards Fact verification
# train lists 
train_data = training_df[(training_df.cut == 'val') & (~training_df['dataset_origin'].isin(['Vitamin C', 'Fever']))]
train_grounding_list = list(train_data['grounding'])
train_generated_list = list(train_data['generated_text'])

# test lists
test_data = training_df[(training_df.cut == 'test') & (~training_df['dataset_origin'].isin(['Vitamin C', 'Fever']))]
test_grounding_list = list(test_data['grounding'])
test_generated_list = list(test_data['generated_text'])

### LLM as judge baseline Approach

##### Entailment Prompting

In [7]:
### Entailment prompting
from typing import List, Tuple, Dict

kg_construction_prompt="""You are an expert at determining if a summary is consistent with a source article. Given an article and a summary, determine if all the information in the summary is supported by the article. Answer "yes" if the summary is consistent, and "no" if it is inconsistent."""
ei_format_prompt="""Article: {article}
    Summary: {summary}
    Answer (yes or no):"""

# Function to create entailment prompt
def construct_prompt_batch(articles: List[str], summaries: List[str]) -> List[str]: #ei for entailment inference #Changed name and parameters
        # Builds prompts for entailment inference
        prompts = [kg_construction_prompt + "\n" + ei_format_prompt.format(article=article, summary=summary) for article, summary in zip(articles, summaries)]  #Changed to ei
        return prompts #Returning list of prompts

In [8]:
# generate prompts for train and test
train_prompts = construct_prompt_batch(train_grounding_list, train_generated_list)
test_prompts = construct_prompt_batch(test_grounding_list, test_generated_list)

In [9]:
import torch
import json
import os
from groq import Groq
import requests # pip install requests first!


# Using Asynchromous Groq Api call functions

# Create Json file with batch requests
def create_batch_requests(messages, model, filename):
    with open(filename, "w") as file:
        for i, message in enumerate(messages, start=1):
            # Create a unique custom_id for each message
            custom_id = f"request-{i}"
            
            # Create the batch request dictionary
            batch_request = {
                "custom_id": custom_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "user", "content": message}
                    ]
                }
            }
            
            # Write the batch request to the file as a single JSON object per line
            json.dump(batch_request, file)
            file.write("\n")  # Ensure each JSON object is on a new line


# Upload the batch requests to Groq
def upload_file_to_groq(api_key, file_path):
    url = "https://api.groq.com/openai/v1/files"
    
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    # Prepare the file and form data
    files = {
        "file": ("batch_file.jsonl", open(file_path, "rb"))
    }
    
    data = {
        "purpose": "batch"
    }
    
    # Make the POST request
    response = requests.post(url, headers=headers, files=files, data=data)
    
    return response.json()


# Create a batch job
def create_batch(api_key, input_file_id):
    url = "https://api.groq.com/openai/v1/batches"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    data = {
        "input_file_id": input_file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h"
    }
    
    response = requests.post(url, headers=headers, json=data)
    return response.json()


# check the status of the batch job
def get_batch_status(api_key, batch_id):
    url = f"https://api.groq.com/openai/v1/batches/{batch_id}"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    response = requests.get(url, headers=headers)
    return response.json()

# Download the results of the batch job
def download_file_content(api_key, output_file_id, output_file):
    url = f"https://api.groq.com/openai/v1/files/{output_file_id}/content"
    
    headers = {
        "Authorization": f"Bearer {api_key}"
    }
    
    response = requests.get(url, headers=headers)
    
    # Write the content to a file
    with open(output_file, 'wb') as f:
        f.write(response.content)
    
    return f"File downloaded successfully to {output_file}"



In [10]:
batch_size = 1000
api_key='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'
model_name = "llama-3.3-70b-versatile"
output_dir = "batch_requests"

In [11]:
# Excution of Asynchromous Groq Api call functions

#create batch requests
from torch import mode


def process_batches_and_create_jobs(train_prompts, batch_size, model_name, api_key, output_dir):
    """
    Splits train_prompts into batches, creates JSON files using create_batch_requests,
    uploads them to Groq using upload_file_to_groq, and creates batch jobs using create_batch.

    Args:
        train_prompts (list): List of prompts to process.
        batch_size (int): The size of each batch.
        model_name (str): The model name for the Groq API.
        api_key (Groq): API key.
        output_dir (str): Directory to save the JSON files.

    Returns:
        dict: A dictionary mapping batch file names to their associated batch IDs.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    batch_file_ids = {}
    batch_ids = {}

    # Split train_prompts into batches
    for i in range(0, len(train_prompts), batch_size):
        batch_prompts = train_prompts[i:i + batch_size]
        batch_file_name = f"batch_{i // batch_size + 1}.json"
        batch_file_path = os.path.join(output_dir, batch_file_name)

        # Create JSON file for the batch using create_batch_requests
        create_batch_requests(
            messages=batch_prompts,
            model=model_name,
            filename=batch_file_path
        )

        # Upload JSON file to Groq using upload_file_to_groq
        try:
            file_upload_response = upload_file_to_groq(api_key, batch_file_path)
            file_id = file_upload_response["id"]
            batch_file_ids[batch_file_name] = file_id

            # Create batch job using create_batch
            batch_job_response = create_batch(api_key, file_id)
            batch_id = batch_job_response["id"]
            batch_ids[batch_file_name] = batch_id

        except Exception as e:
            print(f"Error processing batch {batch_file_name}: {e}")

    return batch_ids



batch_ids = process_batches_and_create_jobs(
    train_prompts=train_prompts,
    batch_size=batch_size,
    model_name=model_name,
    api_key=api_key,
    output_dir=output_dir
)

In [12]:
# Check the status of the batch jobs
def check_batch_status(api_key, batch_ids):
    """
    Checks the status of each batch job using the Groq API.

    Args:
        api_key (str): API key for authentication.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.

    Returns:
        dict: A dictionary mapping batch file names to their status.
    """
    batch_statuses = {}
    for batch_file_name, batch_id in batch_ids.items():
        try:
            status_response = get_batch_status(api_key, batch_id)
            batch_statuses[batch_file_name] = status_response["status"]
        except Exception as e:
            print(f"Error checking status for batch {batch_file_name}: {e}")

    return batch_statuses

batch_statuses = check_batch_status(api_key, batch_ids)
# Print the status of each batch job
for batch_file_name, status in batch_statuses.items():
    print(f"Batch file: {batch_file_name}, Status: {status}")

Batch file: batch_1.json, Status: completed
Batch file: batch_2.json, Status: completed
Batch file: batch_3.json, Status: completed
Batch file: batch_4.json, Status: completed
Batch file: batch_5.json, Status: completed
Batch file: batch_6.json, Status: completed
Batch file: batch_7.json, Status: in_progress
Batch file: batch_8.json, Status: in_progress
Batch file: batch_9.json, Status: in_progress
Batch file: batch_10.json, Status: in_progress
Batch file: batch_11.json, Status: in_progress
Batch file: batch_12.json, Status: in_progress
Batch file: batch_13.json, Status: in_progress
Batch file: batch_14.json, Status: in_progress
Batch file: batch_15.json, Status: in_progress
Batch file: batch_16.json, Status: in_progress
Batch file: batch_17.json, Status: in_progress
Batch file: batch_18.json, Status: in_progress
Batch file: batch_19.json, Status: in_progress
Batch file: batch_20.json, Status: in_progress
Batch file: batch_21.json, Status: in_progress
Batch file: batch_22.json, Status:

In [13]:
# Check the status of the batch jobs and download the results
def check_batch_status_and_download(api_key, batch_ids, output_dir):
    """
    Checks the status of batch jobs and downloads the results if completed.

    Args:
        api_key (str): API key for Groq.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.
        output_dir (str): Directory to save the downloaded files.

    Returns:
        None
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for batch_file_name, batch_id in batch_ids.items():
        try:
            # Check the status of the batch job
            status_response = get_batch_status(api_key, batch_id)
            status = status_response["status"]

            if status == "completed":
                # Download the results
                output_file_id = status_response["output_file_id"]
                output_file_name = f"{batch_file_name}_output.json"
                output_file_path = os.path.join(output_dir, output_file_name)

                download_message = download_file_content(api_key, output_file_id, output_file_path)
                print(download_message)
            else:
                print(f"Batch {batch_file_name} is still processing. Status: {status}")

        except Exception as e:
            print(f"Error checking status or downloading for {batch_file_name}: {e}")

# Check the status of the batch jobs and download the results
check_batch_status_and_download(api_key, batch_ids, output_dir)

File downloaded successfully to batch_requests/batch_1.json_output.json
File downloaded successfully to batch_requests/batch_2.json_output.json
File downloaded successfully to batch_requests/batch_3.json_output.json
File downloaded successfully to batch_requests/batch_4.json_output.json
File downloaded successfully to batch_requests/batch_5.json_output.json
File downloaded successfully to batch_requests/batch_6.json_output.json
File downloaded successfully to batch_requests/batch_7.json_output.json
File downloaded successfully to batch_requests/batch_8.json_output.json
File downloaded successfully to batch_requests/batch_9.json_output.json
File downloaded successfully to batch_requests/batch_10.json_output.json
File downloaded successfully to batch_requests/batch_11.json_output.json
File downloaded successfully to batch_requests/batch_12.json_output.json
File downloaded successfully to batch_requests/batch_13.json_output.json
File downloaded successfully to batch_requests/batch_14.json

In [28]:
# Process the downloaded files
def process_downloaded_files(output_dir):
    """
    Processes the downloaded files and extracts the responses.

    Args:
        output_dir (str): Directory where the downloaded files are stored.

    Returns:
        dict: A dictionary mapping batch file names to their responses.
    """
    responses = []
    for i in range(0, len(train_prompts), batch_size):
        batch_file_name = f"batch_{i // batch_size + 1}.json_output.json"
        file_path = os.path.join(output_dir, batch_file_name)
        if os.path.exists(file_path):
            with open(file_path, "r") as file:
                for line in file:
                    response = json.loads(line)
                    custom_id = response["custom_id"]
                    content = response["response"]["body"]["choices"][0]["message"]["content"]
                    responses.append(content)
    

    return responses

# Process the downloaded files
def check_consistency_batch(grok_outputs: List[str]) -> List[int]:  # Now the grok outputs is what comes in

    results = []
    for output in grok_outputs:
        output = output.strip().lower()  # Normalize the output by stripping and converting to lowercase
        if output.startswith("yes"):  # Check if the output starts with "yes"
            results.append(1)  # Consistent
        elif output.startswith("no"):  # Check if the output starts with "no"
            results.append(0)  # Inconsistent
        else:
            results.append(0)  # Undetermined, handle as needed.

    return results



responses = process_downloaded_files(output_dir)
consistency_results = check_consistency_batch(responses)

In [39]:
# append to train data
train_data.loc[:,'predictions'] =  consistency_results


from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def calculate_metrics(train_data):
    """
    Calculates accuracy, precision, recall, and F1-score using the label and predictions columns of a pandas DataFrame.

    Args:
        train_data (pd.DataFrame): A pandas DataFrame with 'label' and 'predictions' columns.

    Returns:
        dict: A dictionary containing accuracy, precision, recall, and F1-score.
    """
    # Extract labels and predictions
    labels = train_data['label']
    predictions = train_data['predictions']

    # Calculate metrics
    prevalence = labels.value_counts(normalize=True)[1]
    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, average='binary')  # Use 'binary' for binary classification
    recall = recall_score(labels, predictions, average='binary')
    f1 = f1_score(labels, predictions, average='binary')

    # Return metrics as a dictionary
    return {
        "prevalence": prevalence,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

# Calculate metrics
metrics = calculate_metrics(train_data)
# Print the metrics
print("Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


Metrics:
prevalence: 0.5009
accuracy: 0.6782
precision: 0.6523
recall: 0.7660
f1_score: 0.7046


### 1. GraphEval Implementation as Baseline Approach

GraphEval is a combination approach of using LLMs to create KGs and check consistency using NLI to detect hallucinations.

The implementation includes the main components of GraphEval as described in the paper:
1.KG construction from the LLM output
2.Consistency checking for each triple using an NLI model
3.Overall evaluation based on the consistency of all triples

Note that the KG construction step (construct_kg method) is a placeholder and should be implemented using an actual LLM in practice. The paper doesn't provide specific details on this step, so you would need to design an appropriate prompt and use an LLM API to generate the KG triples.

The check_consistency method uses a pre-trained RoBERTa model fine-tuned on MNLI for natural language inference. It returns the probability of contradiction between the triple and the context.

The evaluate method puts it all together, constructing the KG, checking each triple for consistency, and returning the overall result along with any inconsistent triples found.

In [13]:
### KG construction prompting using batch groq call
from typing import List, Tuple, Dict

kg_construction_prompt="""You are an expert at extracting information in structured formats to build a knowledge graph. 
    Step 1 − Entity detection: Identify all entities in the raw text. Make sure not to miss any out. Entities should be basic and simple, they are akin to Wikipedia nodes. 
    Step 2 − Coreference resolution: Find all expressions in the text that refer to the same entity. Make sure entities are not duplicated. 
    In particular do not include entities that are more specific versions themselves, e.g. "a detailed view of jupiter’s atmosphere" and "jupiter’s atmosphere", only include the most specific version of the entity. 
    Step 3 − Relation extraction: Identify semantic relationships between the entities you have identified.
    Format: Return the knowledge graph as a list of triples, i.e. [ "entity1", "relation1−2", "entity2"], in Python code."""
kg_format_prompt="""Use the given format to extract information from the following input: <input>{input}</input>.
    Skip the preamble and output the result as a list within <python> tags."""
kg_tips_prompt="""Important Tips:
    1. Make sure all information is included in the knowledge graph.
    2. Each triple must only contain three strings! None of the strings should be empty.
    3. Do not split up related information into separate triples because this could change the meaning.
    4. Make sure all brackets and quotation marks are matched.
    5. Before adding a triplet to the knowledge graph, check the concatenated triple makes sense as a sentence. If not, discard it."""
kg_examples_prompt="""Here are some example input and output pairs.
    ## Example 1.
    Input: "The Walt Disney Company, commonly known as Disney, is an American multinational mass media and entertainment conglomerate that is headquartered at the Walt Disney Studios complex in Burbank, California."
    Output: [ [ "The Walt Disney Company", "headquartered at", "Walt Disney Studios complex in Burbank, California" ], [ "The Walt Disney Company", "commonly known as", "Disney" ], [ "The Walt Disney Company", "instance of", "American multinational mass media and entertainment conglomerate" ] ]
    ## Example 2.
    Input: "Amanda Jackson was born in Springfield, Ohio, USA on June 1, 1985. She was a basketball player for the U.S. women’s team."
    Output: [ [ "Amanda Jackson", "born in", "Springfield, Ohio, USA" ], [ "Amanda Jackson", "born on", "June 1, 1985" ], [ "Amanda Jackson", "occupation", "basketball player" ], [ "Amanda Jackson", "played for", "U.S. women’s basketball team" ] ]
    ## Example 3.
    Input: "Music executive Darius Van Arman was born in Pennsylvania. He attended Gonzaga College High School and is a human being."
    Output: [ [ "Darius Van Arman", "occupation", "Music executive" ], [ "Darius Van Arman", "born in", "Pennsylvania" ], [ "Darius Van Arman", "attended", "Gonzaga College High School" ], [ "Darius Van Arman", "instance of", "human being" ] ]
    ## Example 4.
    Input: "Italy had 3.6x times more cases of coronavirus than China."
    Output: [ [ "Italy", "had 3.6x times more cases of coronavirus than", "China" ] ]
    """

# Function to create kg prompt
def construct_prompt_batch(articles: List[str]) -> List[str]: #ei for entailment inference #Changed name and parameters
        # Builds prompts for entailment inference
        prompts = [f"{kg_construction_prompt} {kg_format_prompt.format(input=article)} {kg_tips_prompt} {kg_examples_prompt}" for index, article in enumerate(articles)]  #Changed to kg
        return prompts #Returning list of prompts

In [14]:
# generate prompts for train and test
train_prompts = construct_prompt_batch(train_generated_list)
test_prompts = construct_prompt_batch(test_generated_list)

In [15]:
batch_size = 1000
api_key='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'
model_name = "llama-3.3-70b-versatile"
output_dir = "batch_requests"

In [15]:
# Excution of Asynchromous Groq Api call functions

#create batch requests
from torch import mode


def process_batches_and_create_jobs(train_prompts, batch_size, model_name, api_key, output_dir):
    """
    Splits train_prompts into batches, creates JSON files using create_batch_requests,
    uploads them to Groq using upload_file_to_groq, and creates batch jobs using create_batch.

    Args:
        train_prompts (list): List of prompts to process.
        batch_size (int): The size of each batch.
        model_name (str): The model name for the Groq API.
        api_key (Groq): API key.
        output_dir (str): Directory to save the JSON files.

    Returns:
        dict: A dictionary mapping batch file names to their associated batch IDs.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    batch_file_ids = {}
    batch_ids = {}

    # Split train_prompts into batches
    for i in range(0, len(train_prompts), batch_size):
        batch_prompts = train_prompts[i:i + batch_size]
        batch_file_name = f"batch_{i // batch_size + 1}.json"
        batch_file_path = os.path.join(output_dir, batch_file_name)

        # Create JSON file for the batch using create_batch_requests
        create_batch_requests(
            messages=batch_prompts,
            model=model_name,
            filename=batch_file_path
        )

        # Upload JSON file to Groq using upload_file_to_groq
        try:
            file_upload_response = upload_file_to_groq(api_key, batch_file_path)
            file_id = file_upload_response["id"]
            batch_file_ids[batch_file_name] = file_id

            # Create batch job using create_batch
            batch_job_response = create_batch(api_key, file_id)
            batch_id = batch_job_response["id"]
            batch_ids[batch_file_name] = batch_id

        except Exception as e:
            print(f"Error processing batch {batch_file_name}: {e}")

    return batch_ids



batch_ids = process_batches_and_create_jobs(
    train_prompts=train_prompts,
    batch_size=batch_size,
    model_name=model_name,
    api_key=api_key,
    output_dir=output_dir
)

In [18]:
# Check the status of the batch jobs
def check_batch_status(api_key, batch_ids):
    """
    Checks the status of each batch job using the Groq API.

    Args:
        api_key (str): API key for authentication.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.

    Returns:
        dict: A dictionary mapping batch file names to their status.
    """
    batch_statuses = {}
    for batch_file_name, batch_id in batch_ids.items():
        try:
            status_response = get_batch_status(api_key, batch_id)
            batch_statuses[batch_file_name] = status_response["status"]
        except Exception as e:
            print(f"Error checking status for batch {batch_file_name}: {e}")

    return batch_statuses

batch_statuses = check_batch_status(api_key, batch_ids)
# Print the status of each batch job
for batch_file_name, status in batch_statuses.items():
    print(f"Batch file: {batch_file_name}, Status: {status}")

Batch file: batch_1.json, Status: completed
Batch file: batch_2.json, Status: completed
Batch file: batch_3.json, Status: completed
Batch file: batch_4.json, Status: completed
Batch file: batch_5.json, Status: completed
Batch file: batch_6.json, Status: completed
Batch file: batch_7.json, Status: completed
Batch file: batch_8.json, Status: completed
Batch file: batch_9.json, Status: completed
Batch file: batch_10.json, Status: completed
Batch file: batch_11.json, Status: completed
Batch file: batch_12.json, Status: completed
Batch file: batch_13.json, Status: completed
Batch file: batch_14.json, Status: completed
Batch file: batch_15.json, Status: completed
Batch file: batch_16.json, Status: completed
Batch file: batch_17.json, Status: completed
Batch file: batch_18.json, Status: completed
Batch file: batch_19.json, Status: completed
Batch file: batch_20.json, Status: completed
Batch file: batch_21.json, Status: completed
Batch file: batch_22.json, Status: completed
Batch file: batch_2

In [19]:
# Check the status of the batch jobs and download the results
def check_batch_status_and_download(api_key, batch_ids, output_dir):
    """
    Checks the status of batch jobs and downloads the results if completed.

    Args:
        api_key (str): API key for Groq.
        batch_ids (dict): Dictionary mapping batch file names to their associated batch IDs.
        output_dir (str): Directory to save the downloaded files.

    Returns:
        None
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for batch_file_name, batch_id in batch_ids.items():
        try:
            # Check the status of the batch job
            status_response = get_batch_status(api_key, batch_id)
            status = status_response["status"]

            if status == "completed":
                # Download the results
                output_file_id = status_response["output_file_id"]
                output_file_name = f"{batch_file_name}_output.json"
                output_file_path = os.path.join(output_dir, output_file_name)

                download_message = download_file_content(api_key, output_file_id, output_file_path)
                print(download_message)
            else:
                print(f"Batch {batch_file_name} is still processing. Status: {status}")

        except Exception as e:
            print(f"Error checking status or downloading for {batch_file_name}: {e}")

# Check the status of the batch jobs and download the results
check_batch_status_and_download(api_key, batch_ids, output_dir)

File downloaded successfully to batch_requests/batch_1.json_output.json
File downloaded successfully to batch_requests/batch_2.json_output.json
File downloaded successfully to batch_requests/batch_3.json_output.json
File downloaded successfully to batch_requests/batch_4.json_output.json
File downloaded successfully to batch_requests/batch_5.json_output.json
File downloaded successfully to batch_requests/batch_6.json_output.json
File downloaded successfully to batch_requests/batch_7.json_output.json
File downloaded successfully to batch_requests/batch_8.json_output.json
File downloaded successfully to batch_requests/batch_9.json_output.json
File downloaded successfully to batch_requests/batch_10.json_output.json
File downloaded successfully to batch_requests/batch_11.json_output.json
File downloaded successfully to batch_requests/batch_12.json_output.json
File downloaded successfully to batch_requests/batch_13.json_output.json
File downloaded successfully to batch_requests/batch_14.json

In [60]:
# Process the downloaded files
def process_downloaded_files(output_dir):
    """
    Processes the downloaded files and extracts the responses.

    Args:
        output_dir (str): Directory where the downloaded files are stored.

    Returns:
        dict: A dictionary mapping batch file names to their responses.
    """
    responses = []
    for i in range(0, len(train_prompts), batch_size):
        batch_file_name = f"batch_{i // batch_size + 1}.json_output.json"
        file_path = os.path.join(output_dir, batch_file_name)
        if os.path.exists(file_path):
            with open(file_path, "r") as file:
                for line in file:
                    response = json.loads(line)
                    custom_id = response["custom_id"]
                    content = response["response"]["body"]["choices"][0]["message"]["content"]
                    
                    # Extract the knowledge graph from the output
                    # Assumes the LLM returns the KG in a list within <python> tags
                    start_tag = content.find('[')
                    end_tag = content.rfind(']')
                    if start_tag != -1 and end_tag != -1:
                        kg_string = content[start_tag:end_tag+1]
                        try:
                            kg = eval(kg_string) #use literal_eval for security
                            if isinstance(kg, list):
                                responses.append(kg)
                            else:
                                print("LLM did not return a list.")
                                responses.append([])
                        except (SyntaxError, NameError) as e:
                            content = content.replace(" ", "")
                            start_tag = content.rfind('[\n[')
                            end_tag = content.rfind(']\n]')
                            if start_tag != -1 and end_tag != -1:
                                kg_string = content[start_tag:end_tag+3]
                                try:
                                    kg = eval(kg_string) #use literal_eval for security
                                    if isinstance(kg, list):
                                        responses.append(kg)
                                    else:
                                        print("LLM did not return a list.")
                                        responses.append([])
                                except (SyntaxError, NameError) as e:
                                    print(f"Error parsing LLM output: {e}")
                                    responses.append([])
                            else:
                                end_tag = content.rfind(']')
                                if start_tag != -1 and end_tag != -1:
                                    kg_string = content[start_tag:end_tag+1]
                                    try:
                                        kg = eval(kg_string) #use literal_eval for security
                                        if isinstance(kg, list):
                                            responses.append(kg)
                                        else:
                                            print("LLM did not return a list.")
                                            responses.append([])
                                    except (SyntaxError, NameError) as e:
                                        print(f"Error parsing LLM output: {e}")
                                        responses.append([])
                                else:
                                    
                                    content = content.replace("\n", "")
                                    start_tag = content.rfind('[[')
                                    end_tag = content.rfind(']]')
                                    if start_tag != -1 and end_tag != -1:
                                        kg_string = content[start_tag:end_tag+2]
                                        try:
                                            kg = eval(kg_string) #use literal_eval for security
                                            if isinstance(kg, list):
                                                responses.append(kg)
                                            else:
                                                print("LLM did not return a list.")
                                                responses.append([])
                                        except (SyntaxError, NameError) as e:
                                            print(f"Error parsing LLM output: {e}")
                                            responses.append([])
                                    else:
                                        print("Could not find KG in LLM output.")
                                        responses.append([])
                            
                    else:
                        print("Could not find KG in LLM output.")
                        responses.append([])
        
                    
    

    return responses



# Process the downloaded files
responses = process_downloaded_files(output_dir)

Could not find KG in LLM output.
Could not find KG in LLM output.
Could not find KG in LLM output.
Error parsing LLM output: '[' was never closed (<string>, line 1)
Could not find KG in LLM output.
Could not find KG in LLM output.
Could not find KG in LLM output.
Error parsing LLM output: invalid syntax (<string>, line 15)
Could not find KG in LLM output.
Error parsing LLM output: '[' was never closed (<string>, line 1)
Could not find KG in LLM output.
Error parsing LLM output: invalid syntax (<string>, line 0)
Error parsing LLM output: invalid syntax (<string>, line 0)
Could not find KG in LLM output.
Could not find KG in LLM output.
Error parsing LLM output: name 'entities' is not defined
Could not find KG in LLM output.
Could not find KG in LLM output.
Could not find KG in LLM output.
Error parsing LLM output: unterminated string literal (detected at line 7) (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 12)
Could not find KG in LLM output.
Could not fin

In [62]:

# find items with empty list in responses
empty_responses = [i for i, response in enumerate(responses) if not response]
# Print the indices of empty responses
print("Indices of empty responses:", empty_responses)
print("Number of empty responses:", len(empty_responses))


Indices of empty responses: [146, 428, 686, 724, 1094, 1202, 1211, 1241, 1293, 1486, 1494, 1669, 1749, 1751, 1758, 2274, 2356, 2377, 2524, 2814, 3034, 3107, 3940, 3961, 4028, 4442, 4909, 4921, 4935, 4961, 5947, 6093, 7285, 7726, 8604, 10069, 10335, 10802, 11011, 11979, 12191, 12549, 12573, 12718, 12835, 12838, 13285, 13288, 13541, 14686, 15500, 16653, 17001, 17435, 17896, 17947, 17980, 18076, 18142, 18212, 18263, 18504, 19316, 19321, 19584, 19709, 19775, 19841, 19904, 19907, 19938, 20245, 20533, 20556, 20611, 20612, 20731, 20740, 20761, 20767, 20853, 20890, 20938, 20942, 20948, 20975, 21025, 21070, 21138, 21214, 21252, 21261, 21277, 21407, 21414, 21713, 21725, 21768, 21911, 21947, 22129, 22147, 22159, 22190, 22205, 22210, 22212, 22256, 22367, 22445, 22538, 22560, 22675, 22700, 22708, 22885, 22898, 23128, 23275, 23400, 23498, 23551, 23591, 23739, 23751, 23863, 23997, 24067, 24182, 24346, 24371, 24402, 24407, 24447, 24477, 24600, 24716, 24793, 24815, 24854, 24969, 24973, 25178, 25745, 25

In [None]:
import torch
from typing import List, Tuple, Dict
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import os
from groq import Groq

class BatchGraphEval:
    def __init__(self,
                 nli_model_name: str = "roberta-large-mnli",
                 batch_size: int = 32,
                 kg_construction_prompt="""You are an expert at extracting information in structured formats to build a knowledge graph. 
    Step 1 − Entity detection: Identify all entities in the raw text. Make sure not to miss any out. Entities should be basic and simple, they are akin to Wikipedia nodes. 
    Step 2 − Coreference resolution: Find all expressions in the text that refer to the same entity. Make sure entities are not duplicated. 
    In particular do not include entities that are more specific versions themselves, e.g. "a detailed view of jupiter’s atmosphere" and "jupiter’s atmosphere", only include the most specific version of the entity. 
    Step 3 − Relation extraction: Identify semantic relationships between the entities you have identified.
    Format: Return the knowledge graph as a list of triples, i.e. [ "entity1", "relation1−2", "entity2"], in Python code.""",
                 kg_format_prompt="""Use the given format to extract information from the following input: <input>{input}</input>.
    Skip the preamble and output the result as a list within <python> tags.""",
                 kg_tips_prompt="""Important Tips:
    1. Make sure all information is included in the knowledge graph.
    2. Each triple must only contain three strings! None of the strings should be empty.
    3. Do not split up related information into separate triples because this could change the meaning.
    4. Make sure all brackets and quotation marks are matched.
    5. Before adding a triplet to the knowledge graph, check the concatenated triple makes sense as a sentence. If not, discard it.""",
                 kg_examples_prompt="""Here are some example input and output pairs.
    ## Example 1.
    Input: "The Walt Disney Company, commonly known as Disney, is an American multinational mass media and entertainment conglomerate that is headquartered at the Walt Disney Studios complex in Burbank, California."
    Output: [ [ "The Walt Disney Company", "headquartered at", "Walt Disney Studios complex in Burbank, California" ], [ "The Walt Disney Company", "commonly known as", "Disney" ], [ "The Walt Disney Company", "instance of", "American multinational mass media and entertainment conglomerate" ] ]
    ## Example 2.
    Input: "Amanda Jackson was born in Springfield, Ohio, USA on June 1, 1985. She was a basketball player for the U.S. women’s team."
    Output: [ [ "Amanda Jackson", "born in", "Springfield, Ohio, USA" ], [ "Amanda Jackson", "born on", "June 1, 1985" ], [ "Amanda Jackson", "occupation", "basketball player" ], [ "Amanda Jackson", "played for", "U.S. women’s basketball team" ] ]
    ## Example 3.
    Input: "Music executive Darius Van Arman was born in Pennsylvania. He attended Gonzaga College High School and is a human being."
    Output: [ [ "Darius Van Arman", "occupation", "Music executive" ], [ "Darius Van Arman", "born in", "Pennsylvania" ], [ "Darius Van Arman", "attended", "Gonzaga College High School" ], [ "Darius Van Arman", "instance of", "human being" ] ]
    ## Example 4.
    Input: "Italy had 3.6x times more cases of coronavirus than China."
    Output: [ [ "Italy", "had 3.6x times more cases of coronavirus than", "China" ] ]
    """,
                 llm_model: str = "llama-3.3-70b-versatile"):  #Using groq model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
        self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(self.device)
        self.batch_size = batch_size
        self.kg_construction_prompt = kg_construction_prompt
        self.kg_format_prompt = kg_format_prompt
        self.kg_tips_prompt = kg_tips_prompt
        self.kg_examples_prompt = kg_examples_prompt
        self.llm_model_name = llm_model  # Store LLM model name
        self.groq_client = Groq(api_key='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',)  #Groq client
        # No pipeline needed for groq api


    def construct_kg_batch(self, llm_outputs: List[str]) -> List[List[Tuple[str, str, str]]]:
        # Use the prompt with the LLM to construct KGs for multiple outputs
        batch_kgs = []
        for output in llm_outputs:
          input_text = f"{self.kg_construction_prompt} {self.kg_format_prompt.format(input=output)} {self.kg_tips_prompt} {self.kg_examples_prompt}"
          #print(input_text)
            #In practice, you would call an LLM API here with the combined prompt
            #and process the output to extract the KG triples.
            #Replace this with the actual LLM call
          triples = self.call_llm_to_extract_kg(input_text)
          batch_kgs.append(triples)

        return batch_kgs

    def call_llm_to_extract_kg(self,prompt: str) -> List[Tuple[str, str, str]]:
      # Wrap the LLM call in a try-except block
        try:
            #Call Groq API
            chat_completion = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model_name,
            )
            output = chat_completion.choices[0].message.content
            

            # Extract the knowledge graph from the output
            # Assumes the LLM returns the KG in a list within <python> tags
            start_tag = output.find('[')
            end_tag = output.rfind(']')
            if start_tag != -1 and end_tag != -1:
                kg_string = output[start_tag:end_tag+1]
                try:
                    kg = eval(kg_string) #use literal_eval for security
                    if isinstance(kg, list):
                        return kg
                    else:
                        print("LLM did not return a list.")
                        return []
                except (SyntaxError, NameError) as e:
                    print(f"Error parsing LLM output: {e}")
                    return []
            else:
                print("Could not find KG in LLM output.")
                return []
        except Exception as e:
            print(f"Error calling LLM: {e}")
            return []

    def check_consistency_batch(self, triples: List[Tuple[str, str, str]], contexts: List[str]) -> List[float]:
        # Combine the triples into sentences
        triple_texts = [f"{t[0]} {t[1]} {t[2]}" for t in triples]

        # Tokenize the inputs
        inputs = self.tokenizer(triple_texts, contexts, return_tensors="pt", truncation=True, max_length=512, padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get the model predictions
        with torch.no_grad():
            outputs = self.nli_model(**inputs)
        probs = outputs.logits.softmax(dim=-1)

        # Return the probabilities of contradiction (index 2 in RoBERTa MNLI model)
        return probs[:, 2].tolist()

    def evaluate_batch(self, batch_kgs: List[List[Tuple[str, str, str]]], contexts: List[str]) -> List[int]:
        #batch_kgs = self.construct_kg_batch(llm_outputs)
        results = []

        for i in range(0, len(llm_outputs), self.batch_size):
            batch_llm_outputs = llm_outputs[i:i+self.batch_size]
            batch_contexts = contexts[i:i+self.batch_size]
            batch_kgs_subset = batch_kgs[i:i+self.batch_size]

            batch_triples = [triple for idx,kg in enumerate(batch_kgs_subset) for triple in kg]

            batch_contexts_expanded = []

            #Iterate over batch of kgs
            for batch_idx, kg in enumerate(batch_kgs_subset):

              #Extend context for each set of triples within a kg
              batch_contexts_expanded.extend([batch_contexts[batch_idx]] * len(kg))

            inconsistency_probs = self.check_consistency_batch(batch_triples, batch_contexts_expanded)

            triple_index = 0

            for batch_idx, kg in enumerate(batch_kgs_subset):

                inconsistent_triples = []

                for triple in kg:
                    inconsistency_prob = inconsistency_probs[triple_index]
                    if inconsistency_prob > 0.5:
                        inconsistent_triples.append((triple, inconsistency_prob))
                    triple_index += 1

                if len(inconsistent_triples) > 0:
                    result.append(0)
                else:
                    result.append(1)
                    1 else 0 end

        return results


In [104]:
batch_graph_eval = BatchGraphEval(llm_model="llama-3.3-70b-versatile")
train_llm_kgs = batch_graph_eval.construct_kg_batch(train_generated_list[0:1000])

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 5)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 8)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid 

Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid 

In [110]:
# In recursive way, complete the KGs
def process_arrays(arrays, index=0):
    if index >= len(arrays):
        return arrays

    if not arrays[index]:
        arrays[index] = batch_graph_eval.construct_kg_batch(train_generated_list[index:index+1])[0]

    return process_arrays(arrays, index + 1)


# Process the list of arrays
processed_train_llm_kgs = process_arrays(train_llm_kgs)


Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 5)
Error parsing LLM output: invalid syntax (<string>, line 8)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unmatched ')' (<string>, line 3)
Error parsing LLM output: invalid syntax (<string>, line 9)
Error parsing LLM output: invalid syntax (<string>, line 6)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unterminated string literal (detected at line 6) (<string>, line 6)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error p

In [112]:
# In recursive way, complete the KGs
# Process the list of arrays
re_processed_train_llm_kgs = process_arrays(processed_train_llm_kgs)


Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: unterminated string literal (detected at line 13) (<string>, line 13)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unterminated string literal (detected at line 6) (<string>, line 6)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: unexpected indent (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 8)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output

In [114]:
# In recursive way, complete the KGs
# Process the list of arrays
final_processed_train_llm_kgs = process_arrays(re_processed_train_llm_kgs)

Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 8)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: unterminated string literal (detected at line 6) (<string>, line 6)
Error parsing LLM output: invalid syntax (<string>, line 7)
Error parsing LLM output: unterminated string literal (detected at line 7) (<string>, line 7)
Error parsing LLM output: invalid syntax (<string>, line 9)
Error parsing LLM output: invalid syntax (<string>, line 4)
Error parsing LLM output: invalid syntax (<string>, line 7)


In [None]:
# training results ==> needs to be further fine-tuned
batch_graph_eval = BatchGraphEval(llm_model="llama-3.3-70b-versatile") #Specify Groq Model
train_results = batch_graph_eval.evaluate_batch(final_processed_train_llm_kgs, train_grounding_list[0:1000])

In [122]:
# Example Articles and Summaries (replace with your actual data)
articles = [
    "The Walt Disney Company, commonly known as Disney, is an American multinational mass media and entertainment conglomerate.",
    "Amanda Jackson was born in Springfield, Ohio, USA on June 1, 1985. She was a basketball player for the U.S. women’s team.",
    "Music executive Darius Van Arman was born in Pennsylvania. He attended Gonzaga College High School and is a human being.",
    "Italy had 3.6x times more cases of coronavirus than China."
]
summaries = [
    "Disney is a media conglomerate.",
    "Amanda Jackson was born in Ohio and played basketball.",
    "Darius Van Arman is a music executive born in Pennsylvania",
    "China had less coronavirus than Italy"
]

# Example Labels (1 for consistent, 0 for inconsistent)
labels = [1, 1, 1, 1]


results = batch_graph_eval.evaluate_batch(articles, summaries)

print(results)

[0, 0, 1, -1]


In [137]:
# Store the results as a new column in the original DataFrame
new_train_data = train_data.iloc[0:6112]
new_train_data['prediction'] = predictions


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_train_data['prediction'] = predictions


In [138]:
new_train_data.loc[:, 'prediction'] = new_train_data['prediction'].replace(-1, 0)
labels = list(new_train_data['label'])
predicted_labels = list(new_train_data['prediction'])

predicted_labels = list(new_train_data['prediction'])
correct_predictions = sum([1 for i in range(len(labels)) if labels[i] == predicted_labels[i]])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_train_data['prediction'] = new_train_data['prediction'].replace(-1, 0)


In [139]:
new_train_data.head(10)

Unnamed: 0,id,grounding,generated_text,label,cut,dataset_origin,prediction
0,91198,Colin Kaepernick . Kaepernick began his profes...,Colin Kaepernick became a starting quarterback...,0,val,Fever,0
1,194462,Katherine Matilda `` Tilda '' Swinton ( born 5...,Tilda Swinton is a vegan.,0,val,Fever,0
2,137334,Soul Food is a 1997 American comedy-drama film...,Fox 2000 Pictures released the film Soul Food.,1,val,Fever,1
4,111897,Telemundo ( [ teleˈmundo ] ) is an American Sp...,Telemundo is a English-language television net...,0,val,Fever,0
6,181634,Mogadishu ( [ ˌmɔːɡəˈdiːʃuː ] Muqdisho [ mʉqdɪ...,There is a capital called Mogadishu.,1,val,Fever,0
7,219028,Savages (2012 film) . Savages is a 2012 Americ...,Savages was exclusively a German film.,0,val,Fever,1
9,108281,"Andrew Kevin Walker ( born August 14 , 1964 ) ...",Andrew Kevin Walker is only Chinese.,0,val,Fever,0
10,140846,Shooter (2007 film) . The film follows Force R...,Shooter is about an expert marksman who tries ...,0,val,Fever,1
13,54168,,Murda Beatz's real name is Marshall Mathers.,0,val,Fever,0
14,105095,"Carrie Anne Mathison , played by actress Clair...",Nicholas Brody is a character on Homeland.,1,val,Fever,0


In [143]:
# accuracy
correct_predictions/len(labels)

0.6277814136125655

In [147]:
tp = sum([1 for i in range(len(labels)) if labels[i] == predicted_labels[i] and predicted_labels[i] == 1]) 
fp = sum([1 for i in range(len(labels)) if labels[i] != predicted_labels[i] and predicted_labels[i] == 1]) 
tn = sum([1 for i in range(len(labels)) if labels[i] == predicted_labels[i] and predicted_labels[i] == 0]) 
fn = sum([1 for i in range(len(labels)) if labels[i] != predicted_labels[i] and predicted_labels[i] == 0]) 

In [148]:
precision = tp/(tp+fp)
print(precision)

0.40016433853738703


In [149]:
recall = tp/(tp+fn)
print(recall)

0.23966535433070865


In [7]:
new_train_data = train_data.iloc[0:6112]
labels = list(new_train_data['label'])

In [8]:
positive_label = sum([1 for i in range(len(labels)) if labels[i] == 1])

In [11]:
positive_label/len(labels)

0.3324607329842932