<a href="https://colab.research.google.com/github/hsandaver/hsandaver/blob/main/PromptGenerator1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import random
import re
from collections import defaultdict
import nltk
from nltk.corpus import wordnet, stopwords
from nltk import pos_tag, word_tokenize
import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Attempt to import Google Colab's files module; handle if not in Colab
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Download necessary NLTK data if not already available.
# Newer NLTK releases load 'punkt_tab' and 'averaged_perceptron_tagger_eng'
# instead of the pickled 'punkt' / 'averaged_perceptron_tagger' resources, so
# both generations are requested; word_tokenize() and pos_tag() then work
# regardless of the installed NLTK version.
nltk_packages = [
    'wordnet',
    'omw-1.4',
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'stopwords',
]
for package in nltk_packages:
    nltk.download(package, quiet=True)

# Mapping NLTK POS tags to WordNet POS tags
def get_wordnet_pos(treebank_tag):
    """
    Translate a POS tag into the constant accepted by wordnet.synsets().

    Tags starting with 'J', 'V', 'N' or 'R' map to the WordNet adjective,
    verb, noun and adverb constants respectively; any other tag gives None.
    """
    prefix_to_pos = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, wn_pos in prefix_to_pos:
        if treebank_tag.startswith(prefix):
            return wn_pos
    return None  # POS tag is not recognized

# Function to load the dataset
def load_dataset(default_file='your_dataset.jsonl'):
    """
    Read a JSON Lines file and return its records as a list.

    When the file does not exist locally, the user is asked to upload one if
    running in Colab; otherwise FileNotFoundError is raised. Blank lines are
    ignored and lines that fail to parse are logged as warnings and skipped.
    A file yielding no records raises ValueError. Any error during reading is
    logged and re-raised.
    """
    file_path = default_file
    if not os.path.exists(default_file):
        if not IN_COLAB:
            raise FileNotFoundError(f"File '{default_file}' not found and not running in Colab.")
        print(f"File '{default_file}' not found locally. Please upload the dataset.")
        uploaded = files.upload()
        if not uploaded:
            raise FileNotFoundError("No file uploaded.")
        # Use the first uploaded file name
        file_path = next(iter(uploaded.keys()))

    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"Error decoding JSON: {e} in line: {line}")
                else:
                    records.append(record)
        if not records:
            raise ValueError("Dataset is empty. Please check the file contents.")
    except Exception as e:
        logging.error(f"An error occurred while loading the dataset: {e}")
        raise
    return records

# Function to clean and tokenize text
def tokenize(text):
    """
    Lower-case the text, split it with word_tokenize, and keep only the
    tokens that begin with a word character (drops pure punctuation).
    """
    starts_with_word_char = re.compile(r'\w+').match
    return [piece for piece in word_tokenize(text.lower()) if starts_with_word_char(piece)]

# Function to extract unique elements from the dataset
def extract_elements_from_dataset(dataset):
    """
    Extract unique subjects, settings, moods, lighting, and perspectives from the dataset.

    Args:
        dataset: Iterable of records. Only dict records whose 'prompt' value
            is a non-empty string are scanned; anything else is skipped.

    Returns:
        A defaultdict mapping each category that matched at least once to a
        sorted list of its distinct, lower-cased matches. Sorting keeps the
        result independent of set iteration order, so a seeded random choice
        over these lists is reproducible.
    """
    features = defaultdict(set)

    # Define patterns for extraction (can be customized based on dataset)
    patterns = {
        'subjects': re.compile(r'\b\w+\s+eyes\b', re.IGNORECASE),
        'settings': re.compile(r'\bforest\b|\bocean\b|\bsea\b|\bcity\b|\bmountain\b|\bdesert\b|\bbeach\b', re.IGNORECASE),
        'moods': re.compile(r'\bintrospective\b|\bserene\b|\bwistful\b|\breflective\b|\badventurous\b|\bjoyful\b|\bsomber\b', re.IGNORECASE),
        'lighting': re.compile(r'\bsoft\b|\bgolden[-\s]hour\b|\bsunset\b|\bmoonlight\b|\bneon\b|\bharsh\b|\bdim\b', re.IGNORECASE),
        # Accept both the ASCII (') and the typographic (’) apostrophe
        'perspectives': re.compile(
            r"\blow-angle\b|\bbug['’]s-eye\b|\bthree-quarter view\b|\bclose-up\b|\bwide-angle\b|\bbird['’]s-eye view\b",
            re.IGNORECASE,
        ),
    }

    for entry in dataset:
        # A JSON line may legally decode to a list, number, etc.; skip those
        if not isinstance(entry, dict):
            continue
        prompt = entry.get('prompt', '')
        if not prompt or not isinstance(prompt, str):
            continue
        for key, pattern in patterns.items():
            matches = pattern.findall(prompt)
            if matches:
                features[key].update(match.lower() for match in matches)

    # Convert sets to sorted lists (values only; the key set is unchanged)
    for key in features:
        features[key] = sorted(features[key])

    return features

# Function to combine base prompt with additional elements
def _with_article(phrase):
    """Prefix a non-empty phrase with 'a' or 'an' based on its first letter."""
    article = 'an' if phrase[0].lower() in 'aeiou' else 'a'
    return f"{article} {phrase}"


def combine_elements(base_prompt, additional_elements):
    """
    Incorporate additional elements into the base prompt to create a more detailed prompt.

    One random value is drawn (in a fixed order) from each of the
    'subjects', 'settings', 'moods', 'lighting' and 'perspectives'
    categories that is present and non-empty.

    Args:
        base_prompt: The starting prompt text (may be empty or None).
        additional_elements: Mapping of category name to a sequence of
            candidate strings. Missing or empty categories are skipped; the
            mapping is never modified.

    Returns:
        The base prompt followed by the chosen details as punctuated
        sentences. If no category yields a value, base_prompt is returned
        unchanged.
    """
    def pick(category):
        # .get() avoids inserting keys into a defaultdict and lets a plain
        # dict omit categories without discarding the others.
        options = additional_elements.get(category)
        return random.choice(options) if options else ''

    additional_subject = pick('subjects')
    additional_setting = pick('settings')
    additional_mood = pick('moods')
    additional_lighting = pick('lighting')
    additional_perspective = pick('perspectives')

    # (text, is_fragment): fragments continue the previous sentence after a
    # comma instead of being emitted as a sentence of their own.
    clauses = []
    if additional_subject:
        clauses.append((f"The subject has {additional_subject}", False))
    if additional_setting:
        clauses.append((f"set against {_with_article(additional_setting)} backdrop", True))
    if additional_mood:
        clauses.append((f"The mood is {additional_mood}", False))
    if additional_lighting:
        clauses.append((f"illuminated by {additional_lighting} lighting", True))
    if additional_perspective:
        clauses.append((
            f"and the portrait is captured from {_with_article(additional_perspective)} perspective",
            True,
        ))

    if not clauses:
        return base_prompt

    # Drop the base prompt's own trailing period(s) so joining never yields ".."
    base = base_prompt.strip().rstrip('.').rstrip() if base_prompt else ''
    sentences = [base] if base else []
    for text, is_fragment in clauses:
        if is_fragment and sentences and not sentences[-1].endswith(('!', '?')):
            sentences[-1] = f"{sentences[-1]}, {text}"
        else:
            sentences.append(text[0].upper() + text[1:])

    # Terminate every sentence exactly once
    return ' '.join(s if s.endswith(('!', '?')) else f"{s}." for s in sentences)

# Cache for synonyms to improve performance.
# Keyed by (lower-cased word, WordNet POS) so that a word seen first as an
# adjective does not hand adjective synonyms to a later adverb use.
synonym_cache = {}

# Function to find and replace words with synonyms for added creativity
def replace_with_synonyms(prompt):
    """
    Replace adjectives and adverbs in the prompt with their synonyms to add variety.

    The prompt is tokenized and POS-tagged. Each token tagged as an adjective
    or adverb that is not an English stopword is swapped for a randomly
    chosen single-word, alphabetic WordNet lemma (when one other than the
    word itself exists). Upper-case and capitalized words keep their casing
    style. The tokens are then detokenized so punctuation and contractions
    are re-attached rather than left space-separated.

    Args:
        prompt: The text to vary.

    Returns:
        The rewritten prompt string.
    """
    # Local import: only this step needs the detokenizer
    from nltk.tokenize.treebank import TreebankWordDetokenizer

    tagged_tokens = pos_tag(word_tokenize(prompt))
    stop_words = set(stopwords.words('english'))
    new_tokens = []

    for word, tag in tagged_tokens:
        wn_pos = get_wordnet_pos(tag)
        word_lower = word.lower()

        # Only replace adjectives and adverbs; leave stopwords untouched
        if wn_pos not in (wordnet.ADJ, wordnet.ADV) or word_lower in stop_words:
            new_tokens.append(word)
            continue

        cache_key = (word_lower, wn_pos)
        if cache_key in synonym_cache:
            synonyms = synonym_cache[cache_key]
        else:
            lemma_names = [
                lemma.name().replace('_', ' ')
                for synset in wordnet.synsets(word_lower, pos=wn_pos)
                for lemma in synset.lemmas()
            ]
            # Filter to single-word synonyms that are alphabetic
            synonyms = [name for name in lemma_names if ' ' not in name and name.isalpha()]
            synonym_cache[cache_key] = synonyms

        # Exclude the original word and select a random synonym
        candidates = [name for name in synonyms if name.lower() != word_lower]
        if not candidates:
            new_tokens.append(word)
            continue

        synonym = random.choice(candidates)
        # Preserve the original casing
        if word.isupper():
            synonym = synonym.upper()
        elif word[0].isupper():
            synonym = synonym.capitalize()
        new_tokens.append(synonym)

    # ' '.join() would leave "word ," / "do n't" artifacts in the output
    return TreebankWordDetokenizer().detokenize(new_tokens)

# Function to truncate the prompt to a given token limit
def truncate_prompt(prompt, token_limit):
    """
    Truncate the prompt to the specified number of tokens without breaking words.

    Args:
        prompt: The text to shorten.
        token_limit: Maximum number of word_tokenize tokens (words and
            punctuation marks) to keep. Negative values are treated as 0.

    Returns:
        The prompt unchanged when it is within the limit. Otherwise the first
        token_limit tokens, detokenized so punctuation is re-attached, with
        '...' appended unless the text already ends in '.', '!' or '?'.
    """
    # Local import: only needed when rebuilding text from tokens
    from nltk.tokenize.treebank import TreebankWordDetokenizer

    tokens = word_tokenize(prompt)
    if len(tokens) <= token_limit:
        return prompt

    # Clamp: a negative limit would otherwise slice tokens off the END
    truncated_tokens = tokens[:max(token_limit, 0)]
    truncated_prompt = TreebankWordDetokenizer().detokenize(truncated_tokens)

    # Ensure that the prompt ends gracefully
    if not truncated_prompt.endswith(('.', '!', '?')):
        truncated_prompt += '...'
    return truncated_prompt

# Function to generate a creative random prompt based on dataset elements
def generate_prompt(dataset, additional_elements, token_limit=75):
    """
    Generate a creative prompt by combining elements from the dataset and adding synonyms.

    Args:
        dataset: Non-empty sequence of dict records; one is chosen at random.
        additional_elements: Mapping of feature category to candidate values,
            passed through to combine_elements().
        token_limit: Maximum number of tokens allowed in the result, enforced
            by truncate_prompt().

    Returns:
        The generated prompt string.
    """
    base_entry = random.choice(dataset)
    # Fall back for a missing key AND for an empty/None value
    base_prompt = base_entry.get('prompt') or 'A creative scene.'

    combined_prompt = combine_elements(base_prompt, additional_elements)
    combined_prompt = replace_with_synonyms(combined_prompt)

    # truncate_prompt() returns its input untouched when it is within the
    # limit, so it is the single place that decides whether to cut. The
    # previous pre-check counted word-only tokens and could let prompts that
    # exceed the limit (once punctuation is counted) through untruncated.
    return truncate_prompt(combined_prompt, token_limit)

# Function to save the generated prompts to a file
def save_generated_prompts(prompts, output_file):
    """
    Write every generated prompt to output_file, one per line.

    Logs the number of prompts written and, when running in Colab, triggers
    a browser download of the file. Any failure is logged and re-raised.
    """
    try:
        with open(output_file, 'w', encoding='utf-8') as sink:
            sink.writelines(f"{entry}\n" for entry in prompts)
        logging.info(f"{len(prompts)} prompts generated and saved to '{output_file}'.")

        if IN_COLAB:
            files.download(output_file)
    except Exception as e:
        logging.error(f"An error occurred while saving prompts: {e}")
        raise

# Main function to run the generator
def main():
    """
    Parse command-line options, load the dataset, generate the requested
    number of prompts, and save them. Errors while loading or saving are
    logged and end the run early.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Generate creative prompts based on a dataset.")
    option_specs = (
        ('--dataset', str, 'your_dataset.jsonl', 'Path to the dataset JSONL file.'),
        ('--output', str, 'generated_prompts.txt', 'Output file for generated prompts.'),
        ('--number', int, 10, 'Number of prompts to generate.'),
        ('--tokens', int, 75, 'Maximum number of tokens per prompt.'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # parse_known_args() tolerates unrecognized arguments instead of exiting
    options, _ignored = cli.parse_known_args()

    try:
        records = load_dataset(options.dataset)
    except Exception as e:
        logging.error(e)
        return

    elements = extract_elements_from_dataset(records)
    if not any(elements.values()):
        logging.error("No additional elements extracted. Please check the dataset's 'prompt' fields.")
        return

    prompts = [
        generate_prompt(records, elements, token_limit=options.tokens)
        for _ in range(options.number)
    ]

    try:
        save_generated_prompts(prompts, options.output)
    except Exception as e:
        logging.error(e)

# Entry point: run the generator only when this module is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()