# 1. Set-up

## 1.0 Installations

In [None]:
# Install easyfsl, which provides the EasySet dataset class imported in section 1.4.
!pip install easyfsl

## 1.1 Base Imports

In [None]:
import requests
import base64
import os
from google.colab import userdata

## 1.2 Helper Functions & Necessary Data

In [None]:
# OpenAI API key, read from Colab's user-data store under the name "openai_api_key";
# it is passed as the bearer token for every classification request below.
openai_api_key = userdata.get("openai_api_key")

def get_headers(api_key):
    """Build the JSON request headers for an authenticated OpenAI API call.

    Args:
        api_key: OpenAI API key, sent as a bearer token.

    Returns:
        dict containing the "Content-Type" and "Authorization" headers.
    """
    bearer = f"Bearer {api_key}"
    return {"Content-Type": "application/json", "Authorization": bearer}

def encode_image(image_path):
    """Return the raw bytes of the file at image_path as a base64 text string.

    The file is read unchanged (no decoding or resizing); the base64 output is
    returned as str so it can be embedded directly in a data URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def get_classification_payload(image_path, *, model="gpt-4-vision-preview", max_tokens=300):
    """Build the chat-completions request body asking GPT to classify one flag image.

    The prompt asks for the flag's country (or 'edge' when it is not a country) and a
    0-10 "creative complexity" score, formatted as '{Country} {Creative Complexity}.'.

    Args:
        image_path: Path to the flag image; its bytes are inlined as a base64 data URL.
        model: Chat model to query (keyword-only; defaults to the original model).
        max_tokens: Upper bound on the length of the reply (keyword-only).

    Returns:
        dict to send as the JSON body of a POST to /v1/chat/completions.
    """
    import mimetypes

    # Declare the file's real MIME type in the data URL instead of always claiming
    # JPEG; fall back to JPEG when the extension is unknown.
    mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    # NOTE: the prompt text is kept verbatim (including its unbalanced parenthesis)
    # so new runs stay comparable with previously saved predictions.
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Classify this flag's country (or 'edge' if it is not a country) and creative complexity (on a scale 0-10, with blank flags being a 0, the Indonesian flag being a 2, and the American flag being a 10. Please format the response in the form of '{Country} {Creative Complexity}.'"
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{encode_image(image_path)}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": max_tokens
    }

    return payload

def make_classification_request(image_path, api_key, *, timeout=120):
    """Send one flag image to the OpenAI chat-completions endpoint and return the JSON reply.

    Args:
        image_path: Path to the flag image to classify.
        api_key: OpenAI API key used for bearer authentication.
        timeout: Seconds to wait for the server (keyword-only). Without a timeout,
            `requests` can block indefinitely on a stalled connection; a timeout
            raises instead, which the calling loop already treats as a failed request.

    Returns:
        The decoded JSON body. The status code is not checked (no raise_for_status),
        so callers must not assume a "choices" key is present. Decoding raises if the
        body is not JSON.
    """
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=get_headers(api_key),
        json=get_classification_payload(image_path),
        timeout=timeout,
    )

    return response.json()

## 1.3 Data Loading

Note: GPT does not need to be trained for this task, and as of 12/15/2023 the model cannot be fine-tuned anyway.

In [None]:
# Root folder of the project on the mounted Google Drive.
# Edit this if your Drive layout differs.
PROJECT_BASE_PATH = "/content/drive/MyDrive/CS 229 Project"

# EasySet description file for the held-out test split, stored in the project root.
TEST_PATH = os.path.join(
    PROJECT_BASE_PATH,
    "all_complexities_easyset_test.json",
)

### a. (Run Once) Split Data
Run this step only once; repeating it is unnecessary and will likely raise an error.

Before running it, make sure all of your data files are in place so that you do not have to reset your data and repeat the split.

In [None]:
# You'll need to copy over your PROJECT_BASE_PATH to here as well
safeguard = True

if not safeguard:
  !python /content/drive/MyDrive/"CS 229 Project"/create_easyset_data.py
  !python /content/drive/MyDrive/"CS 229 Project"/create_train_test_easyset.py
else:
  print("Safeguard is active. Skipping data split.")

### b. (Optional) Verify Data Paths

In [None]:
# Echo the resolved test-set path so a wrong Drive layout is easy to spot.
print(f"TEST_PATH = {TEST_PATH!r}")

## 1.4 Generate EasySet Data

In [None]:
from easyfsl.datasets import EasySet
from torchvision import transforms

image_size = 80

# Deterministic evaluation pipeline: resize the shorter side to image_size, then take
# the central image_size x image_size crop. The previous RandomResizedCrop is a
# random training-time augmentation, which made every indexed test item random.
test_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Normalization is left disabled; re-enable if a downstream model expects it:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# The GPT requests in this notebook send the original image files (read by
# encode_image), so this transform only matters when test_set items are indexed.
test_set = EasySet(TEST_PATH, image_size=image_size, transform=test_transform)

### (Optional) Validate Data

In [None]:
# Sanity checks on the test set: how many classes it has and how many items it holds.
print(f"test_set.number_of_classes() = {test_set.number_of_classes()!r}")
print(f"len(test_set) = {len(test_set)!r}")

## 1.5 Get Test Set

In [None]:
from tqdm import tqdm

# One folder per class, laid out as <project>/data/<class_name>_test.
# list_data_instances then returns the image paths found under those folders
# together with their labels (presumably class indices in data_roots order —
# verify against easyfsl's EasySet).
data_roots = [
    os.path.join(PROJECT_BASE_PATH, "data", name + "_test")
    for name in test_set.class_names
]
image_paths, labels = test_set.list_data_instances(data_roots)

### (Optional) Check size

In [None]:
# Sanity check: the rebuilt data roots must yield exactly as many images as EasySet holds.
assert len(image_paths) == len(test_set), (
    f"Found {len(image_paths)} image paths but the test set has {len(test_set)} items"
)

# 2.0 Run Comparisons

In [None]:
from time import sleep

predictions = {}

# Keep True to skip the API calls here; the next cell then loads predictions saved
# by an earlier run. Set to False to query GPT for every test image.
RUN_SAFEGUARD = True
RETRY_WAIT_SECONDS = 60  # pause before the single retry of a failed image


def _parse_prediction(prediction):
    """Split a "<Country> <Complexity>" reply into a {"country", "complexity"} dict.

    The complexity is the last whitespace-separated token, and the country is
    everything before it. This lets multi-word names such as "United States 10"
    parse. The token goes through float() first because replies like "7." occur.

    Raises:
        ValueError: if the reply has fewer than two tokens or the last token is not numeric.
    """
    parts = prediction.strip().rsplit(maxsplit=1)
    if len(parts) != 2:
        raise ValueError(f"Unexpected reply format: {prediction!r}")
    country, complexity = parts
    return {"country": country, "complexity": int(float(complexity))}


def _classify(path):
    """Query GPT for one image and return its parsed prediction entry.

    Propagates any request, JSON, or parsing error to the caller.
    """
    response = make_classification_request(path, openai_api_key)
    return _parse_prediction(response["choices"][0]["message"]["content"])


if not RUN_SAFEGUARD:
    for path in tqdm(image_paths):
        # `except Exception` (not a bare except) so Ctrl-C still stops the loop.
        try:
            predictions[path] = _classify(path)
        except Exception as err:
            print(path, "had an error. Waiting and trying again:", repr(err))
            sleep(RETRY_WAIT_SECONDS)
            try:
                predictions[path] = _classify(path)
            except Exception as retry_err:
                print(path, "failed to resolve its error:", repr(retry_err))
else:
    print("Run time safeguard for API calls is active. Skipping cell.")

In [None]:
import json

# File in the project folder that caches GPT's predictions, so the API calls do
# not have to be repeated on later runs.
PREDICTIONS_SAVE_PATH = "chatgpt_predictions.json"
gpt_save_path = os.path.join(PROJECT_BASE_PATH, PREDICTIONS_SAVE_PATH)

if RUN_SAFEGUARD:
    # The API calls were skipped: reuse the predictions saved by an earlier run.
    with open(gpt_save_path, 'r') as in_fp:
        predictions = json.load(in_fp)
else:
    # Fresh predictions were just collected: persist them for later runs.
    with open(gpt_save_path, 'w') as out_fp:
        json.dump(predictions, out_fp, indent=4)

# 3.0 Evaluate


## 3.1 Complexity

In [None]:
# Binary complexity accuracy: a prediction counts as correct when it falls on the
# same side of COMPLEXITY_THRESHOLD as the label.
# NOTE(review): `labels` comes from list_data_instances and is compared directly
# with GPT's 0-10 score. This assumes the label value equals the complexity
# (e.g. class names "0".."10" in order); confirm against test_set.class_names.
COMPLEXITY_THRESHOLD = 2

num_correct = 0
num_counted = 0  # test images that actually have a stored prediction

# Walk every test image, not just the first len(predictions) of them, so images
# that come after a failed request are still evaluated.
for image_path, y_complexity in zip(image_paths, labels):
    entry = predictions.get(image_path)
    if entry is None:
        continue  # the request or parsing failed for this image; leave it out
    num_counted += 1
    prediction = entry["complexity"]
    if (y_complexity >= COMPLEXITY_THRESHOLD) == (prediction >= COMPLEXITY_THRESHOLD):
        num_correct += 1

print(f"Num Correct: {num_correct}/{num_counted}")
if num_counted:
    print(f"Accuracy: {num_correct / num_counted}")
else:
    print("Accuracy: undefined (no test image has a stored prediction)")

## 3.2 Country

In [None]:
# Country accuracy: case-insensitive exact match between GPT's country and the
# label stored in final_data.json, which is keyed by image file name.
with open(os.path.join(PROJECT_BASE_PATH, "final_data.json")) as in_fp:
    all_data = json.load(in_fp)

num_correct = 0
num_counted = 0  # images with both a stored prediction and a country label

# Walk every test image, not just the first len(predictions) of them.
for image_path in image_paths:
    entry = predictions.get(image_path)
    # os.path.basename also handles paths without a '/', where the old
    # reverse-find slice produced an empty key.
    record = all_data.get(os.path.basename(image_path))
    y_country = record.get("country") if record else None
    if entry is None or y_country is None:
        continue  # no prediction or no ground-truth label; leave it out
    num_counted += 1
    if y_country.lower() == entry["country"].lower():
        num_correct += 1

print(f"Num Correct: {num_correct}/{num_counted}")
if num_counted:
    print(f"Accuracy: {num_correct / num_counted}")
else:
    print("Accuracy: undefined (no test image has both a prediction and a label)")