# Setup

## Verify we're in the Conda environment

In [None]:
# Print the path of the running Python interpreter so it can be compared
# against the expected Conda environment.
import sys
print(sys.executable)

## Import Python packages

In [None]:
import os
import sys
import json
import openai
from PIL import Image
import base64
import io
from dotenv import load_dotenv
import requests
from openai import OpenAI
import pprint
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import subprocess


## OpenAI API key

In [None]:
# Load the .env file first so that OPENAI_API_KEY is available through
# os.environ when the client is created.
load_dotenv()

# Passing api_key explicitly is optional: reading OPENAI_API_KEY from the
# environment is the client's default and the argument can be omitted.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


# Helper functions

## Function to base64 encode an image

In [None]:
def encode_image(image_path):
    """Return the contents of ``image_path`` as a base64 text string.

    The file is read in binary mode, base64-encoded, and the resulting bytes
    are decoded as UTF-8 so the value can be sent to OpenAI as text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')

## Function to load existing results from JSON

In [None]:
def load_existing_results(filename):
    """Load previously saved results from a JSON file.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The parsed JSON content, or an empty list when ``filename`` does not
        exist.

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid
            JSON.
    """
    if not os.path.exists(filename):
        return []

    # Read as UTF-8 explicitly instead of relying on the platform's default
    # text encoding, which is not UTF-8 everywhere.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

## Function to send the request to OpenAI API

In [None]:
# Function to send the request to OpenAI API
def get_image_description_and_prompts(prompt, base64_image, mime_type="image/jpeg"):
    """Send a text prompt plus an image to OpenAI and return the JSON reply.

    The request is made through the module-level ``client`` and asks for a
    JSON object as the response format.

    Args:
        prompt: Text part of the user message.
        base64_image: Base64-encoded image data, embedded in a data URL.
        mime_type: Media type declared in the data URL. Defaults to
            ``"image/jpeg"``; pass e.g. ``"image/png"`` for PNG data.

    Returns:
        The message content of the first choice, parsed from JSON. The parsed
        value is also pretty-printed.

    Raises:
        SystemExit: With status 1 when the first choice's ``finish_reason``
            is not ``"stop"`` or when its content is not valid JSON.
    """
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        model="gpt-4o-mini",
        response_format={"type": "json_object"},
        max_tokens=10000
    )

    choice = chat_completion.choices[0]
    if choice.finish_reason != "stop":
        print("Something went wrong during openAI call - finish_reason is not 'stop' ")
        print(chat_completion)
        sys.exit(1)

    # The message content may be missing (None); treat that like empty text
    # so it is reported through the invalid-JSON path below instead of
    # failing with an AttributeError on .strip().
    content = (choice.message.content or "").strip()

    # Convert the content to a dictionary
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        print("Error: The content is not valid JSON.")
        # Exit with a non-zero status, matching the other failure path.
        sys.exit(1)

    pprint.pprint(data)
    return data

## Function to call llama2.c

In [None]:
def generate_story(
    opening_sentence,
    *,
    run_path="../..//llama2.c/run",
    model_path="../models/out-09/model.bin",
    tokenizer_path="../models/out-09/tok4096.bin",
    temperature="0.1",
    top_p="0.9",
):
    """Run the llama2.c executable on an opening sentence and return its story.

    Args:
        opening_sentence: Text passed to the executable with ``-i``.
        run_path: Path of the llama2.c ``run`` executable.
        model_path: Path of the model file, passed as the first argument.
        tokenizer_path: Path of the tokenizer file, passed with ``-z``.
        temperature: Value passed with ``-t`` (converted with ``str``).
        top_p: Value passed with ``-p`` (converted with ``str``).

    Returns:
        The executable's standard output, stripped, with every line
        containing ``"achieved tok/s"`` removed.

    Raises:
        RuntimeError: If the executable exits with a non-zero status. The
            message contains the exit status and the captured stderr.
    """
    # Define the command as a list (no shell involved, so the sentence needs
    # no quoting or escaping).
    command = [
        run_path,
        model_path,
        "-z", tokenizer_path,
        "-t", str(temperature),
        "-p", str(top_p),
        "-i", opening_sentence
    ]

    # Run the command and capture the output
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    # A failed run must not be returned as an (empty) story: the caller would
    # store it and treat the work as done.
    if result.returncode != 0:
        raise RuntimeError(
            f"llama2.c exited with status {result.returncode}: {result.stderr.strip()}"
        )

    # Process the output to exclude the "achieved tok/s" line
    output_lines = result.stdout.strip().split('\n')
    story = "\n".join(line for line in output_lines if "achieved tok/s" not in line)

    return story

## Function to process the images

In [None]:
# Function to process N images
def process_images_with_llama2(
    num_stories_with_llama2,
    output_file,
    base_dir="/Scandisk/onicai/charles/images",
):
    """Generate llama2.c stories for the images already present in ``output_file``.

    Every ``*.png`` file found recursively under ``base_dir`` (names starting
    with ``.`` are ignored) is shown as a thumbnail. Images without an entry
    in ``output_file``, or whose entry already has
    ``"opening_sentences_with_stories"``, are skipped. For the others, each
    opening sentence stored under ``entry["response"]["response"]`` is passed
    to ``generate_story`` ``num_stories_with_llama2`` times, and the entry is
    replaced by a new one holding only the image path, the description and
    the generated stories. ``output_file`` is rewritten after every processed
    image.

    Args:
        num_stories_with_llama2: Number of stories to generate per opening
            sentence.
        output_file: JSON file that is read for existing entries and
            overwritten with the updated results.
        base_dir: Directory that is searched recursively for PNG images.
    """
    # If output file exists, load existing data (shared helper); unparsable
    # content is reported and treated as "no existing data".
    try:
        results = load_existing_results(output_file)
    except json.JSONDecodeError:
        print(f"Warning: Could not parse existing data in {output_file}. Starting fresh.")
        results = []

    # Map each processed image path to the index of its FIRST entry, so the
    # entry can be looked up and replaced without rescanning the list.
    index_by_image = {}
    for idx, entry in enumerate(results):
        index_by_image.setdefault(entry["image"], idx)

    # Collect all images below the base directory, ignoring hidden files
    image_paths = [path for path in Path(base_dir).glob("**/*.png") if not path.name.startswith(".")]
    thumbnail_size = (100, 100)

    for icount, image_path in enumerate(image_paths, start=1):
        image_key = str(image_path)
        print(f"------------------\n Processing image {icount}: {image_path}")

        # Open and display the image; the context manager closes the file and
        # the figure is closed explicitly so neither accumulates over the loop.
        with Image.open(image_path) as image:
            image.thumbnail(thumbnail_size)  # Resize the image to a thumbnail
            fig = plt.figure(figsize=(2, 2))  # Adjust figure size
            plt.imshow(image)
            plt.axis('off')  # Hide axes for better view
            plt.show()
            plt.close(fig)

        # Make sure this image has already been processed by openAI
        entry_index = index_by_image.get(image_key)
        if entry_index is None:
            print("This image was not yet processed. No opening_sentences found. Skipping...")
            continue

        existing_entry = results[entry_index]

        # check if we already generated llama2.c stories for this entry
        if "opening_sentences_with_stories" in existing_entry:
            print("Already done, found opening_sentences_with_stories for this image")
            continue

        # Extract data (format is a little strange, but that's what it is...)
        description = existing_entry["response"]["response"]["description"]
        opening_sentences = existing_entry["response"]["response"]["opening_sentences"]

        print(f"description = {description}")

        # Loop over all opening_sentences, and for each opening_sentence,
        # create num_stories_with_llama2 stories.
        # Then add these stories back into the results list.
        response_entry = {
            "image": image_key,
            "description": description,
            "opening_sentences_with_stories": []
        }
        for i, sentence in enumerate(opening_sentences):
            story_set = []
            for ii in range(num_stories_with_llama2):
                print(f"For opening_sentence {i}, calling llama2.c to generate story variant {ii}")
                story_set.append(generate_story(sentence))

            # Append the generated stories under each opening_sentence
            response_entry["opening_sentences_with_stories"].append({
                "opening_sentence": sentence,
                "story_set": story_set
            })

        # Replace the existing entry with the one that carries the stories
        print(f"Replacing entry {entry_index} in results")
        results[entry_index] = response_entry

        # Save the updated results back to the output file after every image,
        # so finished work survives an interruption later in the loop.
        print(f"Overwriting {output_file} with updated results")
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=4)

# Run it

In [None]:
# JSON file that is read for existing entries and overwritten with the
# generated stories.
output_file = './2-stories-by-llama2.json'
num_stories_with_llama2 = 2  # For each opening sentence, generate this many stories with llama2.c

process_images_with_llama2(num_stories_with_llama2, output_file)