In [None]:
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
import re
import csv
import os
import json
import pickle
import logging

In [None]:
# Configure logging to a file.
# force=True removes any handlers already attached to the root logger before configuring;
# without it, basicConfig is a silent no-op whenever the root logger already has a handler
# (some notebook front-ends install one) and re-running this cell would change nothing.
logging.basicConfig(
    filename='keyword_extraction.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)

In [None]:
# Define a custom stopping criteria class
class StopAtClosingBracket(StoppingCriteria):
    """Stopping criterion that ends generation once the newest token contains ``stop_char``.

    The keyword-extraction prompt ends with an opening '{', so the first closing '}'
    the model emits marks the end of its keyword list.

    Args:
        tokenizer: Tokenizer used to decode the most recently generated token.
        stop_char (str): Text whose appearance in the newest token stops generation.
    """

    def __init__(self, tokenizer, stop_char='}'):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_char = stop_char
        # Id of a bare stop_char token, kept as a public attribute for compatibility.
        # add_special_tokens=False is essential: Llama tokenizers prepend a BOS token,
        # so encode('}')[0] would be the BOS id rather than the id of '}'.
        self.closing_bracket_id = tokenizer.encode(stop_char, add_special_tokens=False)[-1]

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the last token of the first sequence contains ``stop_char``."""
        # BPE vocabularies often merge '}' with neighbouring characters (e.g. '}\n', '"}'),
        # so compare decoded text instead of a single token id.
        last_token = input_ids[0][-1:]
        return self.stop_char in self.tokenizer.decode(last_token, skip_special_tokens=True)

In [None]:
def _parse_braced_terms(generated_text):
    """Split the text before the first '}' into stripped, de-duplicated, comma-separated terms."""
    if '}' not in generated_text:
        # The model never closed the list (e.g. it hit max_new_tokens); treat as no answer.
        return []
    terms_text = generated_text.split('}', 1)[0]
    terms = []
    for raw_term in terms_text.split(','):
        # Candidates are shown to the model as "[keyword]", so it may echo brackets or quotes.
        term = raw_term.strip().strip('[]"\'').strip()
        if term:
            terms.append(term)
    # Deduplicate while preserving generation order.
    return list(dict.fromkeys(terms))


def extract_keywords(text, candidates, model, tokenizer, device, max_new_tokens=500):
    """
    Ask the LLAMA Instruct model which candidate keywords are relevant to ``text``.

    The prompt ends with an opening '{'. The model is expected to continue with a
    comma-separated keyword list closed by '}'. Generation stops at the first '}'
    (StopAtClosingBracket), and the text before it is split on commas.

    Args:
        text (str): The text from which to extract keywords.
        candidates (str): The candidate keywords in the format [keyword][keyword]...
        model: The pre-loaded LLAMA Instruct model.
        tokenizer: The tokenizer corresponding to the LLAMA model.
        device: The device to run the model on ('cpu' or 'cuda').
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        list: Extracted keywords in generation order, without duplicates. The list is
        empty if generation raises or the model never emits a closing '}'.
    """
    # Zero-shot completion prompt (plain text, no chat template).
    prompt = (
        "For given text and extracted candidate keywords, discard the irrelevant ones and return the relevant ones:\n"
        f"Text: {text}\n"
        f"Candidates: {candidates}\n"
        "Relevant keywords: {"
    )

    # No truncation: right-truncating would cut off the trailing "Relevant keywords: {"
    # cue the whole prompt depends on. A single sequence needs no padding.
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    stopping_criteria = StoppingCriteriaList([StopAtClosingBracket(tokenizer)])

    try:
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding (deterministic)
            # Clear sampling-only settings that may be inherited from the model's
            # generation_config: greedy decoding ignores them and they only cause warnings.
            temperature=None,
            top_p=None,
            pad_token_id=pad_token_id,
            stopping_criteria=stopping_criteria,
        )
    except Exception:
        # Best-effort per item: log with traceback and let the caller continue.
        logging.exception(f"Error during generation for text: {text}")
        return []

    # Decode only the newly generated tokens (everything after the prompt).
    generated_text = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    logging.info(f"Generated Text: {generated_text}")

    return _parse_braced_terms(generated_text)


def parse_candidates(candidates_str):
    """
    Return the keyword part of every ``[keyword|TAG]`` entry in ``candidates_str``.

    Entries without a ``|TAG`` part (e.g. ``[keyword]``) do not match and are skipped.

    Args:
        candidates_str (str): Concatenated entries, e.g. "[pie|NOUN][mint|NOUN]".

    Returns:
        list: Keywords in order of appearance, e.g. ['pie', 'mint'].
    """
    # Group 1 is the text up to the first '|' (no ']' allowed), followed by '|TAG]'.
    candidate_re = re.compile(r'\[([^\]|]+)\|[^\]]+\]')
    return [match.group(1) for match in candidate_re.finditer(candidates_str)]

def validate_keywords(extracted_keywords, candidates_list):
    """
    Keep only extracted keywords that match a candidate, case-insensitively.

    Each kept keyword is returned in the spelling used in ``candidates_list``, so every
    result is literally one of the candidates. When candidates differ only by case, the
    first one wins. Case-variant repeats in ``extracted_keywords`` are reported once.
    Output order follows ``extracted_keywords``.

    Args:
        extracted_keywords (list): List of keywords extracted by the model.
        candidates_list (list): Original list of candidate keywords.

    Returns:
        list: Filtered, de-duplicated list of valid keywords (candidate spelling).
    """
    # Lower-cased candidate -> its first original spelling; dict lookup is O(1).
    canonical = {}
    for candidate in candidates_list:
        canonical.setdefault(candidate.lower(), candidate)

    validated = []
    seen = set()
    for keyword in extracted_keywords:
        key = keyword.lower()
        if key in canonical and key not in seen:
            seen.add(key)
            validated.append(canonical[key])
    return validated

In [None]:
# Lower-cased output extension -> format label used in status messages.
_OUTPUT_FORMATS = {
    '.json': 'JSON',
    '.pickle': 'Pickle',
    '.pkl': 'Pickle',
    '.csv': 'CSV',
}


def _save_results(results_dict, path, fmt):
    """Write ``results_dict`` (title -> list of keywords) to ``path`` as ``fmt`` ('JSON', 'Pickle' or 'CSV')."""
    if fmt == 'JSON':
        with open(path, 'w', encoding='utf-8') as outfile:
            json.dump(results_dict, outfile, ensure_ascii=False, indent=4)
    elif fmt == 'Pickle':
        with open(path, 'wb') as outfile:
            pickle.dump(results_dict, outfile)
    elif fmt == 'CSV':
        with open(path, 'w', newline='', encoding='utf-8') as outfile:
            writer = csv.DictWriter(outfile, fieldnames=['Title', 'Relevant Keywords'])
            writer.writeheader()
            for title, keywords in results_dict.items():
                writer.writerow({'Title': title, 'Relevant Keywords': '; '.join(keywords)})
    else:
        raise ValueError(f"Unsupported output format: {fmt!r}")


def _parse_line(line):
    """Split one TSV row into (title, candidates_field), or return None if it lacks exactly 4 tab-separated fields."""
    parts = line.split('\t')
    if len(parts) != 4:
        return None
    text_field = parts[0].strip()
    # Remove the optional 'text:' prefix.
    if text_field.startswith('text:'):
        text_field = text_field[len('text:'):]
    return text_field.strip(), parts[3]


def process_file(input_file_path, output_file_path, model, tokenizer, device, max_new_tokens=50):
    """
    Processes the input file to extract relevant keywords for each entry.

    Each non-empty input line must have 4 tab-separated fields:
      - field 0 is the text, optionally prefixed with 'text:';
      - field 3 holds the candidates as [keyword|TAG][keyword|TAG]...;
      - fields 1-2 are ignored.
    Results map each title to its validated keywords. A repeated title overwrites the
    earlier entry and logs a warning.

    Two checks run before any model call, so a long run cannot fail only at the final
    write:
      - the output format is chosen from the extension of output_file_path
        (.json, .pickle/.pkl, .csv);
      - the output directory must exist.
    If processing stops early (e.g. KeyboardInterrupt) after some lines were done, the
    partial results are saved as '<name>.partial<ext>' next to the output file. The
    exception then propagates.

    Args:
        input_file_path (str): Path to the input file.
        output_file_path (str): Path to the output file (can be JSON, pickle, or CSV).
        model: The pre-loaded LLAMA Instruct model.
        tokenizer: The tokenizer corresponding to the LLAMA model.
        device: The device to run the model on ('cpu' or 'cuda').
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        None

    Raises:
        FileNotFoundError: If the output directory or the input file does not exist.
    """
    _, ext = os.path.splitext(output_file_path)
    fmt = _OUTPUT_FORMATS.get(ext.lower())
    if fmt is None:
        print(f"Unsupported file extension '{ext}'. Please use .json, .pickle, or .csv.")
        return

    output_dir = os.path.dirname(output_file_path)
    if output_dir and not os.path.isdir(output_dir):
        # E.g. Google Drive not mounted: fail now, not after processing every line.
        raise FileNotFoundError(f"Output directory does not exist: {output_dir}")

    results_dict = {}
    completed = False
    try:
        with open(input_file_path, 'r', encoding='utf-8') as infile:
            for line_num, raw_line in enumerate(infile, start=1):
                # Drop only the line terminator. strip() would also remove leading or
                # trailing tabs (empty first/last fields) and misreport the row as malformed.
                line = raw_line.rstrip('\r\n')
                if not line.strip():
                    continue  # Skip empty lines

                parsed = _parse_line(line)
                if parsed is None:
                    logging.warning(f"Line {line_num}: Unexpected format. Skipping.")
                    continue
                title, candidates_field = parsed

                candidates_list = parse_candidates(candidates_field)
                # The prompt shows candidates as [keyword][keyword]...; TAGs are omitted.
                candidates_formatted = ''.join(f'[{keyword}]' for keyword in candidates_list)

                relevant_keywords = extract_keywords(
                    title, candidates_formatted, model, tokenizer, device, max_new_tokens
                )
                validated_keywords = validate_keywords(relevant_keywords, candidates_list)

                if title in results_dict:
                    logging.warning(f"Line {line_num}: Duplicate title; overwriting earlier result.")
                results_dict[title] = validated_keywords

                logging.info(f"Processed Line {line_num}: Extracted Keywords: {validated_keywords}")
                print(f"Processed Line {line_num}: Extracted Keywords: {validated_keywords}")
        completed = True
    finally:
        if not completed and results_dict:
            # Keep partial progress in a side file so an earlier complete output is not clobbered.
            root, _ = os.path.splitext(output_file_path)
            partial_path = f"{root}.partial{ext}"
            try:
                _save_results(results_dict, partial_path, fmt)
            except Exception:
                # Never mask the original exception with a save failure.
                logging.exception(f"Could not save partial results to {partial_path}.")
            else:
                logging.warning(
                    f"Run stopped early after {len(results_dict)} entries; partial results saved to {partial_path}."
                )
                print(f"Run stopped early. Partial results ({len(results_dict)} entries) saved to {partial_path} as {fmt}.")

    _save_results(results_dict, output_file_path, fmt)
    print(f"Processing complete. Results saved to {output_file_path} as {fmt}.")

In [None]:
# Define the model name and device
model_name = 'meta-llama/Llama-3.1-8B-Instruct'  # Replace with your model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load weights in half precision on GPU. The default float32 needs ~32 GB for an 8B-parameter
# model, twice the half-precision footprint. bfloat16 is used where the GPU supports it,
# float16 otherwise; CPU stays in float32.
if device == 'cuda':
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    model_dtype = torch.float32

# Initialize the tokenizer and model
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Reuse EOS as the pad token when the tokenizer has none, so generate() gets a pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print("Pad token not found. Set pad_token to eos_token.")
    else:
        print("Pad token already set.")

    model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype)
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    logging.exception(f"Error loading model '{model_name}': {e}")
    print(f"Error loading model '{model_name}'. Check logs for details.")
    # Re-raise: every later cell needs `model`/`tokenizer`. Continuing would only surface
    # as a confusing NameError there instead of this root cause.
    raise

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Pad token not found. Set pad_token to eos_token.


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Model 'meta-llama/Llama-3.1-8B-Instruct' loaded successfully on cuda.


In [None]:
# Paths for this run.
# The input TSV must already be uploaded to the notebook environment.
# The output format follows the extension: swap .json for .csv or .pickle if preferred.
input_file_path = '/content/outputfile_all.tsv'
output_file_path = '/content/drive/MyDrive/NeSym/extracted_keywords_incidents_train.json'

# Run keyword extraction over the whole input file.
process_file(input_file_path, output_file_path, model, tokenizer, device, max_new_tokens=100)



Processed Line 1: Extracted Keywords: ['pie', 'mint', 'rosemary', 'lamb']
Processed Line 2: Extracted Keywords: ['jackpot', 'mix', 'pretzels', 'biscuits', 'recall']
Processed Line 3: Extracted Keywords: ['recall', 'update', 'milk', 'emphasise', 'buxton']
Processed Line 4: Extracted Keywords: ['consume', 'patulin', 'apple', 'bottled', 'contaminated', 'drink', 'juice']
Processed Line 5: Extracted Keywords: ['listeria monocytogenes', 'dill', 'contaminated', 'consume', 'product']
Processed Line 6: Extracted Keywords: ['listeria', 'pork', 'recall', 'contamination']
Processed Line 7: Extracted Keywords: ['hazard', 'jelly', 'recall', 'choking', 'products']
Processed Line 8: Extracted Keywords: ['listeria', 'raw milk', 'milk', 'contamination', 'alert']
Processed Line 9: Extracted Keywords: ['allergy', 'alert', 'chocolate', 'chip', 'muffin', 'mix']
Processed Line 10: Extracted Keywords: ['allergy', 'egg', 'frozen', 'undeclared', 'alert', 'issues', 'foods']
Processed Line 11: Extracted Keywords:

KeyboardInterrupt: 