<a href="https://colab.research.google.com/github/hsandaver/essays/blob/main/PromptGenerator3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import random
import re
from collections import defaultdict
import nltk
from nltk.corpus import wordnet, stopwords
from nltk import pos_tag, word_tokenize
import os
import sys
import logging

# Configure logging: INFO and above, printed as "LEVEL: message".
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Attempt to import Google Colab's files module; handle if not in Colab.
# IN_COLAB gates the upload/download helpers used by load_dataset() and
# save_generated_prompts().
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Download necessary NLTK data if not already available.
# 'punkt_tab' and 'averaged_perceptron_tagger_eng' are the resource ids that
# recent NLTK releases look up for word_tokenize() and pos_tag(); the legacy
# 'punkt' and 'averaged_perceptron_tagger' ids are kept for older releases.
nltk_packages = [
    'wordnet',
    'omw-1.4',
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'stopwords',
]
for package in nltk_packages:
    # nltk.download() reports failure through its return value; surface it so
    # a later LookupError from the tokenizer or tagger is easier to diagnose.
    if not nltk.download(package, quiet=True):
        logging.warning(f"Could not download NLTK resource '{package}'.")

# Mapping NLTK POS tags to WordNet POS tags
def get_wordnet_pos(treebank_tag):
    """
    Translate a Treebank-style POS tag into the matching WordNet POS constant.

    Only the first character of the tag is inspected:
    'J' -> wordnet.ADJ, 'V' -> wordnet.VERB, 'N' -> wordnet.NOUN,
    'R' -> wordnet.ADV. Any other tag (including an empty one) yields None.
    """
    # Map the tag's leading letter to the *name* of the WordNet constant so
    # the wordnet object is only touched when a match actually exists.
    constant_by_initial = {'J': 'ADJ', 'V': 'VERB', 'N': 'NOUN', 'R': 'ADV'}
    constant_name = constant_by_initial.get(treebank_tag[:1])
    if constant_name is None:
        return None  # Tag family is not one WordNet can look up
    return getattr(wordnet, constant_name)

# Function to load the dataset
def load_dataset(default_file='your_dataset.jsonl'):
    """
    Load a dataset from a JSON Lines file.

    Every non-blank line is parsed as one JSON document. Lines that are not
    valid JSON, or whose top-level value is not a JSON object, are skipped
    with a warning, because the rest of the pipeline reads entries with
    ``entry.get('prompt')``.

    If the file is not found locally and the script runs in Colab, the user
    is asked to upload a file and the first uploaded entry is used instead.

    Args:
        default_file: Path of the JSONL file to read.

    Returns:
        A non-empty list of dicts, one per accepted line.

    Raises:
        FileNotFoundError: The file does not exist and no upload is possible
            (not in Colab) or nothing was uploaded.
        ValueError: No usable entry was found in the file.
    """
    if os.path.exists(default_file):
        file_path = default_file
    elif IN_COLAB:
        logging.info(f"File '{default_file}' not found locally. Please upload the dataset.")
        uploaded = files.upload()
        if not uploaded:
            raise FileNotFoundError("No file uploaded.")
        # The first key of the upload result is used as the path to open.
        file_path = next(iter(uploaded))
    else:
        raise FileNotFoundError(f"File '{default_file}' not found and not running in Colab.")

    dataset = []
    try:
        # 'utf-8-sig' reads plain UTF-8 as well as files starting with a BOM,
        # which would otherwise make the first record fail to decode.
        with open(file_path, 'r', encoding='utf-8-sig') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"Error decoding JSON: {e} in line: {line}")
                    continue
                if not isinstance(record, dict):
                    # Valid JSON but not an object (e.g. a list or number):
                    # downstream code could not call .get() on it.
                    logging.warning(f"Skipping non-object JSON entry in line: {line}")
                    continue
                dataset.append(record)
        if not dataset:
            raise ValueError("Dataset is empty. Please check the file contents.")
        logging.info(f"Loaded {len(dataset)} entries from the dataset.")
        return dataset
    except Exception as e:
        # Log for the notebook user, then let the caller decide what to do.
        logging.error(f"An error occurred while loading the dataset: {e}")
        raise

# Function to clean and tokenize text
def tokenize(text):
    """
    Lowercase the text, split it with NLTK's word tokenizer, and keep only
    the tokens that begin with a word character (dropping pure punctuation).
    """
    kept = []
    for piece in word_tokenize(text.lower()):
        # re.match anchors at the start, so this tests the leading character(s)
        if re.match(r'\w+', piece):
            kept.append(piece)
    return kept

# Function to extract unique elements from the dataset
def extract_elements_from_dataset(dataset):
    """
    Build the pool of prompt elements used by the generator.

    Dataset-derived categories ('subjects', 'settings', 'moods', 'lighting',
    'perspectives') are collected by running fixed regular expressions over
    each entry's 'prompt' text; a category only appears in the result if at
    least one match was found. The remaining categories ('camera', 'film',
    'themes', 'styles', 'colors', 'abstract_concepts') are fixed lists
    defined here.

    Entries that are not dicts, or whose 'prompt' is missing, empty, or not a
    string, are ignored.

    Args:
        dataset: Iterable of dataset entries (dicts with a 'prompt' key).

    Returns:
        A mapping from category name to a sorted list of unique strings.
        Sorting makes the output independent of set iteration order, so runs
        with a fixed random seed are reproducible.
    """
    features = defaultdict(set)

    # Define patterns for extraction (can be customized based on dataset)
    patterns = {
        'subjects': re.compile(r'\b\w+\s+eyes\b', re.IGNORECASE),
        'settings': re.compile(r'\bforest\b|\bocean\b|\bsea\b|\bcity\b|\bmountain\b|\bdesert\b|\bbeach\b', re.IGNORECASE),
        'moods': re.compile(r'\bintrospective\b|\bserene\b|\bwistful\b|\breflective\b|\badventurous\b|\bjoyful\b|\bsomber\b', re.IGNORECASE),
        'lighting': re.compile(r'\bsoft\b|\bgolden[-\s]hour\b|\bsunset\b|\bmoonlight\b|\bneon\b|\bharsh\b|\bdim\b', re.IGNORECASE),
        # Accept both the typographic (’) and the plain (') apostrophe.
        'perspectives': re.compile(r"\blow-angle\b|\bbug[’']s-eye\b|\bthree-quarter view\b|\bclose-up\b|\bwide-angle\b|\bbird[’']s-eye view\b", re.IGNORECASE)
    }

    for entry in dataset:
        # Tolerate malformed records instead of failing on .get()/findall().
        prompt = entry.get('prompt') if isinstance(entry, dict) else None
        if not isinstance(prompt, str) or not prompt:
            continue
        for key, pattern in patterns.items():
            for match in pattern.findall(prompt):
                features[key].add(match.lower())

    # Define camera and film elements directly
    features['camera'] = {
        'DSLR',
        'mirrorless',
        'film camera',
        'medium format',
        '35mm',
        'Hasselblad',  # Listed without the word "camera"
    }
    features['film'] = {
        'Kodak Portra 400',
        'Fujifilm Pro 400H',
        'Ilford HP5 Plus',
        'Kodak Tri-X 400',
        'Cinestill 800T',
    }

    # Additional creative elements
    features['themes'] = {
        'cyberpunk',
        'steampunk',
        'noir',
        'fantasy',
        'sci-fi',
        'surrealism',
        'minimalism',
    }
    features['styles'] = {
        'vintage',
        'modern',
        'abstract',
        'realistic',
        'impressionistic',
        'expressionistic',
        'geometric',
    }
    features['colors'] = {
        'monochromatic',
        'vibrant',
        'pastel',
        'neon',
        'earth tones',
        'complementary colors',
        'analogous colors',
    }
    features['abstract_concepts'] = {
        'time dilation',
        'metamorphosis',
        'juxtaposition',
        'paradox',
        'transcendence',
        'chaos',
        'harmony',
    }

    # Convert all sets to sorted lists (random.choice needs a sequence, and a
    # stable order keeps seeded runs reproducible).
    for key in features:
        features[key] = sorted(features[key])

    logging.info("Extracted additional elements from the dataset.")
    return features

# Function to combine base prompt with additional elements using dynamic structures
_VOWELS = frozenset('aeiou')


def _with_article(phrase):
    """Prefix *phrase* with 'an' if it starts with a vowel letter, else 'a'."""
    article = 'an' if phrase[:1].lower() in _VOWELS else 'a'
    return f"{article} {phrase}"


def combine_elements(base_prompt, additional_elements):
    """
    Incorporate additional elements into the base prompt to create a more
    detailed and creative prompt, using one of several sentence structures.

    One value is drawn at random from every non-empty descriptive category.
    A sentence structure is then chosen at random among those whose
    placeholders can all be filled, so no template is ever rendered with
    empty slots. If no structure is usable, the base prompt is kept as is.
    A closing sentence naming the camera and/or film is appended when those
    categories are available.

    Args:
        base_prompt: The seed text taken from the dataset.
        additional_elements: Mapping of category name to a list of options.
            Missing or empty categories are simply not used.

    Returns:
        The combined prompt string.
    """
    base = (base_prompt or '').strip()
    # Templates decide how the base joins the added clause, so strip the
    # base's own terminal punctuation to avoid "..", "!." or ". captured".
    stem = base.rstrip('.!?').rstrip()

    element_keys = ['subjects', 'settings', 'themes', 'styles', 'colors', 'moods', 'lighting', 'perspectives', 'abstract_concepts']
    selected_elements = {}
    for key in element_keys:
        options = additional_elements.get(key)
        if options:
            selected_elements[key] = random.choice(options)

    # (categories required, template). {a_<key>} inserts the value preceded
    # by the right indefinite article; {A_<key>} is the capitalized variant.
    structures = [
        (('subjects', 'settings', 'moods'),
         "{base}. Featuring {subjects}, set in {a_settings} with {a_moods} mood."),
        (('styles', 'colors', 'lighting'),
         "{base}. {A_styles} interpretation with {colors} hues and {lighting} lighting."),
        (('perspectives', 'abstract_concepts'),
         "{base} captured from {a_perspectives} perspective, embodying {abstract_concepts}."),
        (('themes', 'subjects', 'settings'),
         "{base} in {a_themes} style, highlighting {subjects} against {a_settings} backdrop."),
    ]
    usable = [
        template
        for required, template in structures
        if all(selected_elements.get(key) for key in required)
    ]

    if usable and stem:
        values = {'base': stem}
        for key, value in selected_elements.items():
            with_article = _with_article(value)
            values[key] = value
            values[f'a_{key}'] = with_article
            values[f'A_{key}'] = with_article[0].upper() + with_article[1:]
        filled_structure = random.choice(usable).format(**values)
    else:
        # Nothing to add (or nothing to add it to): keep the base untouched.
        filled_structure = base

    # Incorporate camera and film elements
    cameras = additional_elements.get('camera')
    films = additional_elements.get('film')
    camera = random.choice(cameras) if cameras else ''
    film = random.choice(films) if films else ''
    if camera and film:
        camera_film_sentence = f"Captured using {_with_article(camera)} on {film} film."
    elif camera:
        camera_film_sentence = f"Captured using {_with_article(camera)}."
    elif film:
        camera_film_sentence = f"Captured on {film} film."
    else:
        return filled_structure

    # Close the preceding sentence before starting the camera/film one.
    if filled_structure and not filled_structure.endswith(('.', '!', '?')):
        filled_structure += '.'
    if not filled_structure:
        return camera_film_sentence
    return f"{filled_structure} {camera_film_sentence}"

# Cache for synonyms to improve performance.
# Keyed by (lowercased word, WordNet POS): the same spelling can have
# different synonyms as a noun and as an adjective, so the POS must be part
# of the key.
synonym_cache = {}


def _lookup_synonyms(word_lower, wn_pos):
    """Return the cached, sorted WordNet synonyms of *word_lower* for *wn_pos*, excluding the word itself."""
    cache_key = (word_lower, wn_pos)
    if cache_key not in synonym_cache:
        found = set()
        for synset in wordnet.synsets(word_lower, pos=wn_pos):
            for lemma in synset.lemmas():
                # WordNet joins multi-word lemmas with underscores.
                candidate = lemma.name().replace('_', ' ')
                # Keep purely alphabetic words and multi-word phrases only.
                if candidate.replace(' ', '').isalpha() and candidate.lower() != word_lower:
                    found.add(candidate)
        # Sorted so that random.choice is reproducible under a fixed seed.
        synonym_cache[cache_key] = sorted(found)
    return synonym_cache[cache_key]


# Function to find and replace words with synonyms for added creativity
def replace_with_synonyms(prompt, protected_words):
    """
    Replace adjectives, adverbs, and nouns in the prompt with randomly chosen
    WordNet synonyms to add variety.

    Stopwords and every token of every protected word are left untouched.
    The case pattern of the replaced word (all caps / leading capital) is
    carried over to the synonym. The token list is re-assembled with NLTK's
    Treebank detokenizer so punctuation and quotes are not left surrounded by
    stray spaces.

    Args:
        prompt: The text to rewrite.
        protected_words: Iterable of words or phrases that must not be
            replaced (compared case-insensitively, token by token).

    Returns:
        The rewritten prompt string.
    """
    # Local import: the detokenizer is the counterpart of word_tokenize().
    from nltk.tokenize.treebank import TreebankWordDetokenizer

    tagged_tokens = pos_tag(word_tokenize(prompt))
    stop_words = set(stopwords.words('english'))

    # Split protected words into tokens
    protected_tokens = set()
    for word in protected_words:
        protected_tokens.update(word_tokenize(word.lower()))

    replaceable_pos = (wordnet.ADJ, wordnet.ADV, wordnet.NOUN)
    new_tokens = []
    for word, tag in tagged_tokens:
        word_lower = word.lower()
        wn_pos = get_wordnet_pos(tag)

        # Only adjectives, adverbs and nouns are candidates; stopwords and
        # protected tokens are always kept verbatim.
        if (
            wn_pos not in replaceable_pos
            or word_lower in stop_words
            or word_lower in protected_tokens
        ):
            new_tokens.append(word)
            continue

        synonyms = _lookup_synonyms(word_lower, wn_pos)
        if not synonyms:
            new_tokens.append(word)
            continue

        synonym = random.choice(synonyms)
        # Preserve the original casing
        if word.isupper():
            synonym = synonym.upper()
        elif word[0].isupper():
            # Upper-case only the first letter; str.capitalize() would also
            # lower-case the rest of the synonym.
            synonym = synonym[0].upper() + synonym[1:]
        new_tokens.append(synonym)

    return TreebankWordDetokenizer().detokenize(new_tokens)

# Function to truncate the prompt to a given token limit
def truncate_prompt(prompt, token_limit):
    """
    Truncate the prompt to at most ``token_limit`` NLTK tokens without
    breaking words.

    Args:
        prompt: The text to shorten.
        token_limit: Maximum number of tokens (as produced by
            ``word_tokenize``, so punctuation counts) to keep. Values below
            zero are treated as zero.

    Returns:
        The prompt unchanged if it already fits; otherwise the first
        ``token_limit`` tokens re-assembled into text, with '...' appended
        unless the result already ends with '.', '!' or '?'.
    """
    # Local import: the detokenizer is the counterpart of word_tokenize().
    from nltk.tokenize.treebank import TreebankWordDetokenizer

    tokens = word_tokenize(prompt)
    if len(tokens) <= token_limit:
        return prompt
    # Clamp at zero: a negative slice bound would count from the end and
    # keep almost the whole prompt instead of truncating it.
    truncated_tokens = tokens[:max(token_limit, 0)]
    # Detokenize rather than ' '.join() so punctuation attaches to the
    # preceding word and Treebank-style quotes are restored.
    truncated_prompt = TreebankWordDetokenizer().detokenize(truncated_tokens)
    # Ensure that the prompt ends gracefully
    if not truncated_prompt.endswith(('.', '!', '?')):
        truncated_prompt += '...'
    return truncated_prompt

# Function to generate a creative random prompt based on dataset elements
def generate_prompt(dataset, additional_elements, token_limit=75):
    """
    Produce one prompt: pick a random dataset entry, enrich its text with
    extra elements, vary the wording with synonyms, and truncate the result
    if it has more than ``token_limit`` word tokens.
    """
    seed_entry = random.choice(dataset)
    seed_text = seed_entry.get('prompt', 'A creative scene.')

    draft = combine_elements(seed_text, additional_elements)

    # Camera and film names must survive synonym replacement verbatim.
    gear_names = additional_elements.get('camera', []) + additional_elements.get('film', [])
    shielded = {name.lower() for name in gear_names}

    draft = replace_with_synonyms(draft, shielded)

    fits = len(tokenize(draft)) <= token_limit
    return draft if fits else truncate_prompt(draft, token_limit)

# Function to save the generated prompts to a file
def save_generated_prompts(prompts, output_file):
    """
    Write the generated prompts to ``output_file`` as UTF-8 text, one prompt
    per line, and trigger a Colab download of the file when IN_COLAB is set.

    Any failure is logged and then re-raised to the caller.
    """
    try:
        with open(output_file, 'w', encoding='utf-8') as sink:
            sink.writelines(f"{text}\n" for text in prompts)
        logging.info(f"{len(prompts)} prompts generated and saved to '{output_file}'.")

        if IN_COLAB:
            files.download(output_file)
    except Exception as err:
        logging.error(f"An error occurred while saving prompts: {err}")
        raise

# Main function to run the generator
def main():
    """
    Run the whole pipeline: parse options, load the dataset, build the
    element pool, generate the requested number of distinct prompts, and
    save them.

    Errors while loading or saving are logged and end the run without
    raising.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Generate creative prompts based on a dataset.")
    option_specs = [
        ('--dataset', dict(type=str, default='your_dataset.jsonl', help='Path to the dataset JSONL file.')),
        ('--output', dict(type=str, default='generated_prompts.txt', help='Output file for generated prompts.')),
        ('--number', dict(type=int, default=10, help='Number of prompts to generate.')),
        ('--tokens', dict(type=int, default=75, help='Maximum number of tokens per prompt.')),
        ('--categories', dict(type=str, nargs='*', default=[], help='Categories to include (e.g., themes styles). If empty, all are included.')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    # parse_known_args() returns unrecognized arguments separately instead of
    # aborting on them; they are ignored here.
    args, _unrecognized = cli.parse_known_args()

    try:
        dataset = load_dataset(args.dataset)
    except Exception as load_error:
        logging.error(load_error)
        return

    additional_elements = extract_elements_from_dataset(dataset)

    # Restrict the element pool when the user named specific categories.
    wanted = set(args.categories)
    if wanted:
        additional_elements = {
            name: options
            for name, options in additional_elements.items()
            if name in wanted
        }
        logging.info(f"Selected categories for prompt generation: {', '.join(wanted)}")
    else:
        logging.info("Using all available categories for prompt generation.")

    if not any(additional_elements.values()):
        logging.error("No additional elements extracted. Please check the dataset's 'prompt' fields.")
        return

    generated_prompts = []
    seen_prompts = set()
    max_attempts = 5

    for _ in range(args.number):
        # Retry a few times to avoid emitting the same prompt twice; the
        # else-branch runs only when every attempt produced a repeat.
        for _attempt in range(max_attempts):
            candidate = generate_prompt(dataset, additional_elements, token_limit=args.tokens)
            if candidate in seen_prompts:
                continue
            generated_prompts.append(candidate)
            seen_prompts.add(candidate)
            break
        else:
            logging.warning("Max attempts reached. Some prompts may be duplicates.")

    try:
        save_generated_prompts(generated_prompts, args.output)
    except Exception as save_error:
        logging.error(save_error)
        return


# Entry point: run the generator only when executed as a script/cell,
# not when imported.
if __name__ == "__main__":
    main()