# Setup

## Verify we're in the Conda environment

In [None]:
# Print the path of the Python interpreter running this notebook, so you can
# check that it lives inside the expected Conda environment.
import sys
print(sys.executable)

## Import python packages

In [None]:
import os
import sys
import json
import openai
from PIL import Image
import base64
import io
from dotenv import load_dotenv
import requests
from openai import OpenAI
import pprint
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import subprocess
import textwrap
from collections import Counter
import pprint
import random
from collections import defaultdict

## OpenAI API key

In [None]:
# Load variables from the .env file into the process environment, then build
# the OpenAI client from the key stored under OPENAI_API_KEY.
load_dotenv()
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


# Helper functions

## Function to base64 encode an image

In [None]:
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 ``str``.

    The text form is used when sending the image to OpenAI.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode('utf-8')

## Function to load existing results from JSON

In [None]:
def load_existing_results(filename):
    """Return the JSON document stored in *filename*.

    A missing file is treated as "no results yet" and yields an empty list.
    """
    if not os.path.exists(filename):
        return []
    with open(filename, 'r') as results_file:
        data = json.load(results_file)
    return data

## Function to review images that have fewer than 10 accepted opening sentences

In [None]:
# Display for review
def review(input_file, base_dir="/Scandisk/onicai/charles/images", min_accepted=10):
    """Display every image whose judged stories have too few accepted opening sentences.

    Reads *input_file*, a JSON list whose entries carry an ``"image"`` path and
    ``entry["response"]["accepted_opening_sentences_with_stories"]``. It then walks
    all non-hidden ``*.png`` files under *base_dir*. For each image with fewer than
    *min_accepted* accepted opening sentences, it prints the image's position and
    path, shows a thumbnail, and prints the count. A summary count is printed at
    the end.

    Args:
        input_file: Path to the JSON file with the judged stories.
        base_dir: Directory searched recursively for ``*.png`` images.
        min_accepted: Images with fewer accepted opening sentences are shown.

    Exits the process with status 1 when *input_file* is missing or not valid
    JSON, or when an image under *base_dir* has no entry in *input_file*.
    """
    if not os.path.exists(input_file):
        print(f"Error: Could not find input_file {input_file}.")
        sys.exit(1)
    with open(input_file, 'r') as f:
        try:
            inputs = json.load(f)
        except json.JSONDecodeError:
            print(f"Error: Could not parse existing data in {input_file}.")
            sys.exit(1)

    # Index entries by image path once instead of scanning the list per image.
    # setdefault keeps the FIRST entry for a path, matching a linear search.
    entries_by_image = {}
    for entry in inputs:
        entries_by_image.setdefault(entry["image"], entry)

    image_paths = [path for path in Path(base_dir).glob("**/*.png") if not path.name.startswith(".")]
    thumbnail_size = (200, 200)

    count_not_enough_sentences = 0
    for icount, image_path in enumerate(image_paths, 1):
        existing_entry = entries_by_image.get(str(image_path))
        if existing_entry is None:
            # Every image must already have been judged before it can be reviewed.
            print(f"ERROR: Stories for this image were not yet judged: {image_path}")
            sys.exit(1)

        accepted = existing_entry["response"]["accepted_opening_sentences_with_stories"]
        if len(accepted) >= min_accepted:
            continue

        count_not_enough_sentences += 1
        print(f"------------------\n image {icount}: {image_path}")
        # Close the image file once it has been shown.
        with Image.open(image_path) as image:
            image.thumbnail(thumbnail_size)  # resize in place to a thumbnail
            plt.figure(figsize=(4, 4))
            plt.imshow(image)
            plt.axis('off')  # hide axes for a cleaner view
            plt.show()
        print(f"# accepted opening sentences = {len(accepted)}")

    print(f"Number of images with fewer than {min_accepted} accepted opening sentences = {count_not_enough_sentences}")

## Function to generate a story from an opening sentence with llama2.c

In [None]:
def generate_story(
    opening_sentence,
    *,
    run_binary="/Users/arjaan/icppWorld/repos/llama2.c/run",
    model_path="/Users/arjaan/icppWorld/repos/charles/models/out-09/model.bin",
    tokenizer_path="/Users/arjaan/icppWorld/repos/charles/models/out-09/tok4096.bin",
    temperature=0.1,
    top_p=0.9,
):
    """Generate a story that continues *opening_sentence* using the llama2.c ``run`` binary.

    Args:
        opening_sentence: Prompt text, passed via ``-i``.
        run_binary: Path to the compiled llama2.c ``run`` executable.
        model_path: Model checkpoint, passed as the first positional argument.
        tokenizer_path: Tokenizer file, passed via ``-z``.
        temperature: Sampling temperature, passed via ``-t``.
        top_p: Top-p (nucleus) sampling threshold, passed via ``-p``.

    Returns:
        The stripped stdout of the binary, with any ``achieved tok/s`` line removed.

    Raises:
        FileNotFoundError: If *run_binary* does not exist.
        subprocess.CalledProcessError: If the binary exits with a non-zero status.
            The captured stderr is available on the exception. The failure is
            raised rather than returning an empty or partial story.
    """
    command = [
        run_binary,
        model_path,
        "-z", tokenizer_path,
        "-t", str(temperature),
        "-p", str(top_p),
        "-i", opening_sentence,
    ]

    result = subprocess.run(command, capture_output=True, text=True, check=True)

    # Drop the throughput statistics line so only the story text remains.
    output_lines = result.stdout.strip().split('\n')
    return "\n".join(line for line in output_lines if "achieved tok/s" not in line)

## Function to reformat the reviewed prompts into storybook NFTs for Bioniq

In [None]:
def _load_storybook_inputs(input_file):
    """Return the parsed JSON in *input_file*; print an error and exit(1) if missing or invalid."""
    if not os.path.exists(input_file):
        print(f"Error: Could not find input_file {input_file}.")
        sys.exit(1)
    with open(input_file, 'r') as f:
        try:
            return json.load(f)
        except json.JSONDecodeError:
            print(f"Error: Could not parse existing data in {input_file}.")
            sys.exit(1)


def _show_thumbnail_figure(image):
    """Display an already-opened PIL image in a 2x2-inch matplotlib figure without axes."""
    plt.figure(figsize=(2, 2))
    plt.imshow(image)
    plt.axis('off')
    plt.show()


def _review_incomplete_image(entry, prompts, image_path, prompts_per_image, thumbnail_size):
    """Show an image that has too few prompts, plus a generated story per prompt, then exit(1).

    Halting is intentional: the input must be reviewed and completed before
    storybooks can be built.
    """
    reviewed = entry["reviewed"]
    image_filename = entry["image_filename"]

    print("--------------------------------------------------")
    # Close the image file once the review output is done.
    with Image.open(image_path) as image:
        image.thumbnail(thumbnail_size)  # resize in place to a thumbnail
        _show_thumbnail_figure(image)

        if not reviewed:
            print("To be reviewed")
        else:
            print(f"len(prompts) is not {prompts_per_image} but {len(prompts)}")

        print(f"reviewed        : {reviewed}")
        print(f"image_index     : {entry['image_index']}")
        print(f"image_category  : {entry['image_category']}")
        print(f"image_filename  : {image_filename}")
        print(f"image_path      : {str(image_path)}")
        print(f"num_prompts     : {entry['num_prompts']}")
        print(f"len(prompts)    : {len(prompts)}")
        print("prompts:")
        for prompt_index, prompt in enumerate(prompts):
            print("=====================")
            _show_thumbnail_figure(image)
            print(f"image_filename  : {image_filename}")
            print("")
            print(f"{prompt_index}: {prompt}")

            story = generate_story(prompt)

            print("")
            print(textwrap.fill(story, width=80))  # wrap to 80 characters per line
            print("")

    print("STILL MORE REVIEW WORK TO DO...")
    sys.exit(1)


def _expand_to_doublepages(inputs, base_dir, prompts_per_image, thumbnail_size):
    """Turn each input entry into one doublepage per prompt, keeping only its first *prompts_per_image* prompts.

    Returns ``(doublepages, image_indexes)``. Calls ``_review_incomplete_image``
    (which exits) at the first entry with fewer prompts than required.
    """
    doublepages = []
    image_indexes = []
    for entry in inputs:
        image_index = entry["image_index"]
        image_category = entry["image_category"]
        image_filename = entry["image_filename"]
        image_indexes.append(image_index)

        prompts = entry["prompts"][:prompts_per_image]
        for prompt_index, prompt in enumerate(prompts):
            doublepages.append({
                "imageId": image_index,
                "imageUrl": f"./{image_filename}",
                "imageCategory": image_category,  # used for grouping; stripped before output
                "promptIndex": prompt_index,
                "promptId": f"{image_index}-{prompt_index}",
                "prompt": prompt,
                "story": "",
            })

        if len(prompts) != prompts_per_image:
            _review_incomplete_image(entry, prompts, base_dir / image_filename, prompts_per_image, thumbnail_size)
    return doublepages, image_indexes


def _assemble_storybooks(doublepages_all, num_storybooks, pages_per_storybook, rng):
    """Deal shuffled doublepages into storybooks, preferring distinct categories within a book.

    Each doublepage is used at most once. A category repeats within a book
    only when every category not yet in that book has run out of pages. If no
    pages remain at all, the process exits with status 1. The copies placed
    in the books do not contain the ``imageCategory`` key. Shuffles
    *doublepages_all* in place.
    """
    rng.shuffle(doublepages_all)

    # Group pages by category; each list inherits the shuffled order.
    category_dict = defaultdict(list)
    for doublepage in doublepages_all:
        category_dict[doublepage["imageCategory"]].append(doublepage)

    storybooks = []
    used_prompt_ids = set()
    for storybook_index in range(num_storybooks):
        storybook = []
        categories_in_storybook = set()
        while len(storybook) < pages_per_storybook:
            # Prefer a category not yet in this book that still has pages left.
            available_categories = [
                cat for cat, pages in category_dict.items()
                if pages and cat not in categories_in_storybook
            ]
            if not available_categories:
                print(f"Allowing a duplicate category for storybook_index = {storybook_index}")
                available_categories = [cat for cat, pages in category_dict.items() if pages]
                if not available_categories:
                    print(f"Houston, we got a problem for storybook_index = {storybook_index}")
                    sys.exit(1)

            selected_category = rng.choice(available_categories)
            doublepage = category_dict[selected_category].pop(0)
            prompt_id = doublepage["promptId"]  # unique identifier, e.g. "102-8"
            if prompt_id in used_prompt_ids:
                # Guard: a promptId must never appear twice across all storybooks.
                print(f"Houston, we got another problem for storybook_index = {storybook_index}")
                sys.exit(1)
            used_prompt_ids.add(prompt_id)
            categories_in_storybook.add(selected_category)
            storybook.append({k: v for k, v in doublepage.items() if k != "imageCategory"})
        storybooks.append(storybook)
    return storybooks


def create_storybookNFTs(
    input_file,
    output_file,
    image_start,
    image_end,
    *,
    base_dir="/Users/arjaan/icppWorld/repos/charles/assets/images_charles_for_bioniq",
    num_storybooks=990,
    pages_per_storybook=5,
    prompts_per_image=10,
    seed=None,
):
    """Build storybook NFTs from reviewed image prompts and write them to *output_file* as JSON.

    *input_file* holds a JSON list of entries with the keys ``reviewed``,
    ``image_index``, ``image_category``, ``image_filename``, ``num_prompts``
    and ``prompts``. Each of the first *prompts_per_image* prompts of an image
    becomes one "doublepage". Doublepages are shuffled and dealt into
    *num_storybooks* books of *pages_per_storybook* pages, mixing categories
    within a book where possible. The output is a list of
    ``{"doublepages": [...]}`` objects.

    Args:
        input_file: Path to the reformatted, reviewed prompts JSON.
        output_file: Path the storybook NFT JSON is written to.
        image_start: Currently unused; kept for backward compatibility.
        image_end: Currently unused; kept for backward compatibility.
        base_dir: Directory holding the image files. Only used to display
            images that still need review.
        num_storybooks: Number of storybooks to create.
        pages_per_storybook: Doublepages per storybook.
        prompts_per_image: Prompts required (and used) per image.
        seed: Optional seed for reproducible shuffling and selection. With
            ``None``, the global ``random`` module state is used.

    Exits the process with status 1 in these cases:
        - the input is missing or invalid;
        - an image has too few prompts (after showing it for review);
        - ``image_index`` values are not unique;
        - the doublepages run out.
    """
    inputs = _load_storybook_inputs(input_file)
    doublepages_all, image_indexes = _expand_to_doublepages(
        inputs, Path(base_dir), prompts_per_image, (200, 200)
    )

    # Sanity check: image_index must be unique so that promptIds are unique.
    if len(image_indexes) == len(set(image_indexes)):
        print("OK, all the image_index values are unique:")
    else:
        print("Houston, we got a problem. Not all image_index values are unique.")
        sys.exit(1)

    print("OK, the data is complete for creating the storybookNFTs")
    print(f"Number of available doublepages = {len(doublepages_all)}")

    # Reformat into the storybookNFTs doublepages structure.
    # See: https://github.com/onicai/Charles/blob/master/src/charlesStorybook_frontend/index.html
    category_counts = Counter(doublepage["imageCategory"] for doublepage in doublepages_all)
    print("Number of items for each image category:")
    pprint.pprint(dict(category_counts))
    print("Total number of categories:", len(category_counts))
    print("Total number of doublepages:", sum(category_counts.values()))

    rng = random if seed is None else random.Random(seed)
    storybooks = _assemble_storybooks(doublepages_all, num_storybooks, pages_per_storybook, rng)

    # Verify that all storybooks are complete.
    complete_storybooks = [sb for sb in storybooks if len(sb) == pages_per_storybook]
    if len(complete_storybooks) == num_storybooks:
        print(f"Successfully created {num_storybooks} storybooks.")
    else:
        print(f"Only {len(complete_storybooks)} storybooks were created.")

    # Print the first storybook as an example.
    for idx, storybook in enumerate(complete_storybooks[:1], 1):
        print("------------------------------")
        print(f"Storybook {idx}:")
        for page in storybook:
            pprint.pprint(page)
        print()
        print("------------------------------")

    storybookNFTs = [{"doublepages": storybook} for storybook in complete_storybooks]
    with open(output_file, 'w') as f:
        json.dump(storybookNFTs, f, indent=4)

    print(f"Saved storybookNFTs to : {output_file}")

# Run it

In [None]:
# Turn the reviewed prompts in 5-reformat.json into shuffled storybook NFTs.
input_file = './5-reformat.json'
output_file = './storybookNFTs.json'

# Image range arguments; create_storybookNFTs does not currently read them.
image_start, image_end = 0, 500

create_storybookNFTs(input_file, output_file, image_start, image_end)