# Globals: starting index in the dataset and number of API calls to make

In [None]:
# Index into dataset['train'] of the first row processed by the generation loop.
START = 67093
# The generation loop stops once this many Gemini API calls have been made.
NUM_API_CALLS = 750

# Mount Google Drive

In [None]:
import os
base_dir = '/content/drive/MyDrive/CSE_354_context_data'

from google.colab import drive
drive.mount('/content/drive/')

# Check if the directory exists, and create it if it doesn't
if not os.path.exists(base_dir):
    print(f"Directory '{base_dir}' does not exist. Creating it...")
    os.makedirs(base_dir)
else:
    print(f"Directory '{base_dir}' already exists.")

# Change to the directory
%cd $base_dir

# Load the Dataset

In [None]:
# Install the `datasets` package into the notebook runtime.
!pip install datasets
from datasets import load_dataset
# Load the "StonyBrookNLP/tellmewhy" dataset; later cells read its 'train' split.
dataset = load_dataset("StonyBrookNLP/tellmewhy")

In [None]:
# Print the loaded dataset object as a quick sanity check.
print(dataset)

# Set up Gemini LLM

In [None]:
!pip install -q -U google-generativeai
import google.generativeai as genai # Import the Python SDK
from google.colab import userdata   # Used to securely store your API key

GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
Gemini = genai.GenerativeModel('gemini-pro')
# Used to securely store your API key
from google.colab import userdata

GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
Gemini = genai.GenerativeModel('gemini-pro')

In [None]:
# Test to see if we have access: send one small prompt and print the reply text.
response = Gemini.generate_content("Is python a scripted programming language?")
print(response.text)

# Generate context and save it to a file

In [None]:
import time # for sleeping to not overload the API
import json # for saving data to files

In [None]:
# Prompt template sent to Gemini. `{narrative}` is its only placeholder and is
# filled in by inject_context via str.format. The template holds one worked
# example (narrative followed by context sentences) and then the request for
# 4 short context sentences about the supplied narrative.
prompt = '''Given the following narrative sentences that describe a story, produce a sequence of concise and to the point sentences that bring in commonsense information, and external world knowledge that is relevant. Be very verbose about commonsense knowledge and explain the reason why things are done.

Here is an example:
narrative: Cam ordered a pizza and took it home. He opened the box to take out a slice. Cam discovered that the store did not cut the pizza for him. He looked for his pizza cutter but did not find it. He had to use his chef knife to cut a slice.
Pizza is a food. People eat food when they are hungry. Pizza is usually already cut. Cam got the pizza from the store.

Produce context sentences to the following narrative without any formatting, just as a sequence of 4 short, simple, and single clause sentences, do NOT reason through multiple sentences, each sentence should state commonsense information related to the narrative:
{narrative}
'''
# prompt_2 = '''You are a highly knowledgeable assistant with access to vast commonsense and world knowledge. Given the narrative story provided, generate context to enhance understanding for a smaller model. The context should include:

# - Basic information about key concepts, settings, or objects mentioned in the story.
# - Relevant external world information, such as historical, cultural, or scientific facts, that clarifies the narrative's elements.
# - Commonsense assumptions or background details that a reader might intuitively understand but are not explicitly stated in the story.
# Here is the narrative story: {narrative}.

# Please provide the context in a concise and clear format, suitable for enriching the understanding of the story.
# '''

# Safety settings passed to Gemini.generate_content by inject_context: every
# listed harm category uses the "BLOCK_NONE" threshold. The thresholds are
# lowered so the program doesn't randomly crash (presumably because a blocked
# response has no text to read -- TODO confirm against the SDK docs).
safe = [
    { "category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE", },
    { "category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE", },
    { "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE", },
    { "category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE", },
]
def inject_context(datum):
  """Ask Gemini for commonsense context about a datum's narrative.

  The narrative is substituted into the module-level `prompt` template and
  sent through `Gemini.generate_content` with the `safe` safety settings.
  The text of the reply is stored on `datum` under the 'context' key.

  Args:
    datum: mapping with a 'narrative' entry; it is modified in place.

  Returns:
    The same `datum` object, now carrying a 'context' entry.
  """
  story = datum['narrative']
  reply = Gemini.generate_content(
      prompt.format(narrative=story),
      safety_settings=safe,
  )
  datum['context'] = reply.text
  return datum

In [None]:
# Build context-augmented records for rows START, START+1, ... of the train
# split, then dump them to a JSON file in the current directory.
#
# Consecutive rows that share a narrative reuse the context generated for the
# first of them, so only one API call is spent per distinct narrative.
context_data = []               # one record per processed row
quit_early = False              # set when an error interrupts the run
trimmed_data_with_context = {}  # most recently built record
current_narrative = None        # narrative of the latest API call (None = none yet,
                                # so a genuinely empty narrative is not mistaken for a repeat)
current_context = ''            # context returned for current_narrative
saved_api_calls = 0             # rows that reused an existing context
api_call_count = 0              # rows that needed a fresh API call
count = 0                       # rows processed so far

train = dataset['train']        # look the split up once instead of on every access
num_rows = len(train)

# Do NUM_API_CALLS api calls to gemini starting at START.
# Also stop cleanly at the end of the split: running past the last row is a
# normal way for the run to finish, not a failure.
while api_call_count < NUM_API_CALLS and START + count < num_rows:
    i = START + count
    try:
        row = train[i]  # fetch the row a single time per iteration
        if row['narrative'] == current_narrative:
            # Same narrative as the previous row: reuse its context.
            saved_api_calls += 1
        else:
            # New narrative: call the API. The tracking variables are only
            # updated after the call succeeded.
            current_context = inject_context(row)['context']
            current_narrative = row['narrative']
            api_call_count += 1
            time.sleep(8.0)  # we called the api, sleep to prevent sending too many requests

        trimmed_data_with_context = {
            'narrative': current_narrative,
            'question': row['question'],
            'answer': row['answer'],
            'context': current_context,
        }
        context_data.append(trimmed_data_with_context)
        count += 1

        # Whole-number percentage; scaling before rounding avoids float
        # artifacts such as 7.000000000000001 in the progress line.
        percent = round(100 * api_call_count / NUM_API_CALLS)
        print(f'\rMade ({percent:3d}%) {api_call_count:4}/{NUM_API_CALLS:4} calls to the API. Number of data modified: {count}', end='')
    except Exception as e:
        # Top-level boundary of the notebook cell: report the error, keep what
        # has been collected so far, and stop.
        print(f"Error at index {i}")
        print(e)
        quit_early = True
        break

if not quit_early and api_call_count < NUM_API_CALLS:
    # The loop ended without an error and without using up the call budget,
    # which can only mean the split ran out of rows.
    print(f'\nReached the end of the train split ({num_rows} rows) after {api_call_count} api calls.')

filename = f'context_data_starting_at_{START}_to_{START+count}.json'
if quit_early:  # give different name if it fails, still save to maybe use
    filename = f'failed_context_data_starting_at_{START}_to_{START + count}.json'

with open(filename, 'w', encoding='utf-8') as file:
    json.dump(context_data, file, indent=2)

print(f'\n\nSaved to file: {filename}')
print(f'Saved {saved_api_calls} api calls, by reusing context for the same narrative')