In [1]:
# Re-import edited local modules (e.g. utils, which provides get_data) before each cell runs,
# so changes to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import cv2

import matplotlib.pyplot as plt

from utils import get_data
from tqdm import tqdm
import time

import random

# Seed Python, NumPy and PyTorch RNGs so the random shuffles used to pick the few-shot
# examples and the test subsample below are repeatable across runs.
# NOTE: this does NOT make the model replies reproducible; the OpenAI calls are not seeded.
seed = 42
random.seed(seed)
np.random.seed(seed)
# Trailing ';' keeps the returned torch.Generator repr out of the cell output.
torch.manual_seed(seed);

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f7ed03b0a30>

In [3]:
import os
from openai import OpenAI
# Pull the key from the environment rather than hard-coding it in the notebook.
openai_api_key = os.environ.get("OPENAI_KEY")
client = OpenAI(api_key=openai_api_key)

In [4]:
# Only convert PIL images to tensors; no normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

mnist_root = '../data'
train_ds = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(mnist_root, train=False, transform=transform)

In [5]:
# Pick one random training image per digit (0-9) to serve as few-shot examples in the prompt.
indices = list(range(len(train_ds)))
random.shuffle(indices)

# Read labels from train_ds.targets instead of indexing train_ds[j]: indexing decodes the
# image and runs the ToTensor transform for every scanned sample. A single pass records the
# first shuffled index seen for each digit (same choice as scanning once per digit).
first_index_by_digit = {}
for j in indices:
    label = int(train_ds.targets[j])
    if label not in first_index_by_digit:
        first_index_by_digit[label] = j
        if len(first_index_by_digit) == 10:
            break

# x[d] = (img, coords, d, coords_str) for digit d, in order 0-9.
x = []
for digit in range(10):
    img, _, coords, coords_str = get_data(train_ds[first_index_by_digit[digit]])
    x.append((img, coords, digit, coords_str))

# One "(label): coords_str" line per digit.
examples = "Examples:\n" + "\n".join(
    f"({label}): {coord_text}" for _img, _coords, label, coord_text in x
)

In [None]:
def generate_message(value, examples):
    """Build the chat-completion payload asking the model to classify one digit.

    Args:
        value: the digit's pixel coordinates serialised as a string,
            e.g. "(x0,y0);(x1,y1);...".
        examples: few-shot block placed between the instruction and the input,
            e.g. "Examples:\\n(0): ...\\n(1): ...".

    Returns:
        A one-element list ``[{"role": "user", "content": prompt}]`` where prompt is
        ``instruction + "\\n\\n" + examples + "\\n\\nInput:\\n" + value``.
    """
    # Keep each sentence as its own string and join with explicit single spaces, so the
    # prompt text does not depend on trailing whitespace inside a multi-line literal.
    instruction_sentences = [
        "Given (x, y) coordinates of non-zero pixels in a 28x28 grayscale image representing a digit, classify the digit between 0 and 9.",
        'Input format: "(x0,y0);(x1,y1);(x2,y2);..." with coordinates sorted first in the x-axis and then in the y-axis.',
        'PLEASE RETURN ONLY the digit number in brackets, e.g., if the digit is 3, return "(3)".',
    ]
    instruction = " ".join(instruction_sentences)

    content = f"{instruction}\n\n{examples}\n\nInput:\n{value}"
    return [{"role": "user", "content": content}]

In [None]:
# Draw `sample_size` random test images per digit; inputs[d] holds (label, coords_str, coords).
sample_size = 10

sample_indices = list(range(len(test_ds)))
random.shuffle(sample_indices)

inputs = [[] for _ in range(10)]
for idx in sample_indices:
    img, val, coords, coords_str = get_data(test_ds[idx])
    bucket = inputs[val]
    if len(bucket) < sample_size:
        bucket.append((val, coords_str, coords))
    # Stop scanning once every digit's bucket is full.
    if all(len(b) == sample_size for b in inputs):
        break

In [None]:
val = 9
label, _coords_str, coords = inputs[val][0]
print(f"Digit: {label}")

# Rasterise the stored pixel coordinates back into a binary 28x28 image for a visual check.
img = np.zeros((28, 28))
for point in coords:
    img[point[0], point[1]] = 1
plt.imshow(img, cmap="gray")

In [None]:
# Flatten the per-digit buckets into one list of (label, coords_str) pairs, grouped by digit 0-9.
inputs_list = [inputs[digit][k][:2] for digit in range(10) for k in range(sample_size)]

In [None]:
# Raw model replies, one per prompt. Kept in its own cell so the query cell can resume from len(predictions).
predictions = list()

In [None]:
# Query the model for every prompt that has no reply yet. Starting from len(predictions)
# lets this cell be re-run to resume after an interruption without repeating calls.
MODEL = "gpt-3.5-turbo-0125"  # also tried: gpt-4-0125-preview, gpt-4
TEMPERATURE = 1.0  # NOTE(review): 1.0 makes replies stochastic; 0 would be more repeatable

start = len(predictions)
for _label, coords_str in tqdm(inputs_list[start:]):
    messages = generate_message(coords_str, examples)
    completion = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        temperature=TEMPERATURE,
    )
    # message.content is optional in the API response; store "" instead of None so the
    # regex parsing later never receives None.
    predictions.append(completion.choices[0].message.content or "")

    # Pause between requests (presumably to stay under API rate limits).
    time.sleep(0.5)
    

In [None]:
# Sanity check before scoring: there should be exactly one reply per prompt in inputs_list.
print(f"{len(predictions)} / {len(inputs_list)} prompts answered")

In [None]:
# Parse each raw reply into a single integer prediction.
import re

# The prompt asks for the answer as "(d)"; match exactly one digit in brackets.
pattern = re.compile(r'\((\d)\)')


def parse_prediction(reply, invalid=-1):
    """Return the first bracketed single digit in ``reply`` as an int, or ``invalid`` if none."""
    match = pattern.search(reply or "")
    return int(match.group(1)) if match else invalid


# Exactly one entry per reply keeps predictions_int index-aligned with inputs_list / y_gt.
# Flattening every match across replies can produce zero or several entries for one reply,
# which silently shifts all later ones.
predictions_int = [parse_prediction(s) for s in predictions]

n_unparsed = predictions_int.count(-1)
if n_unparsed:
    print(f"Warning: {n_unparsed} of {len(predictions)} replies had no '(d)' answer; marked as -1")

In [None]:
# Ground-truth labels in inputs_list order, plus a shallow copy of the parsed predictions.
y_gt = [label for label, _coords_str in inputs_list]
y_pred = list(predictions_int)

print(y_gt[:10], y_pred[:10])

In [None]:
from sklearn.metrics import f1_score

n_classes = 10
y_true_arr = np.asarray(y_gt, dtype=int)
y_pred_arr = np.asarray(y_pred, dtype=int)
# Every metric below pairs labels and predictions by position, so the lengths must match.
assert y_true_arr.shape == y_pred_arr.shape, (
    f"{len(y_gt)} ground-truth labels but {len(y_pred)} predictions"
)

print(f"Accuracy: {np.mean(y_true_arr == y_pred_arr)}")
# Weighted F1 over the ten digit classes. A prediction outside 0-9 (e.g. an unparseable
# reply) counts as a miss for its true class instead of introducing an extra label.
print(f"F1 score: {f1_score(y_gt, y_pred, labels=list(range(n_classes)), average='weighted')}")

# Confusion matrix: rows = ground truth, columns = prediction.
confusion_matrix = np.zeros((n_classes, n_classes))
for gt, pred in zip(y_gt, y_pred):
    if 0 <= pred < n_classes:  # a negative index would otherwise wrap to the last column
        confusion_matrix[gt, pred] += 1

# Normalise each ROW by the number of samples of that digit. The divisor is a column vector so
# it broadcasts across rows; dividing by a flat (10,) vector would scale columns instead.
# Replies outside 0-9 stay in the denominator, so such rows sum to less than 100%.
row_totals = np.bincount(y_true_arr, minlength=n_classes)[:n_classes].reshape(-1, 1)
confusion_matrix = np.divide(
    confusion_matrix,
    row_totals,
    out=np.zeros_like(confusion_matrix),
    where=row_totals > 0,
)

# Plot as percentages.
fig, ax = plt.subplots()
ax.matshow(confusion_matrix, cmap='viridis')
for (i, j), val in np.ndenumerate(confusion_matrix):
    # round() rather than int(): int() truncates e.g. 0.29 * 100 == 28.999... down to 28.
    ax.text(j, i, f"{round(100 * val)}", ha='center', va='center', color='white')
ax.set_xticks(range(n_classes))
ax.set_yticks(range(n_classes))
ax.set_xlabel('Predicted')
ax.set_ylabel('Ground Truth')
ax.set_title('Confusion Matrix')
plt.show()


In [None]:
example_position = 120  # arbitrary position in the shuffled test order
img, val, coords, coords_str = get_data(test_ds[sample_indices[example_position]])
print(f"The number is {val}")

# Redraw the non-zero pixel coordinates as a binary 28x28 image.
img = np.zeros((28, 28))
for point in coords:
    img[point[0], point[1]] = 1
plt.imshow(img, cmap='gray')