<a href="https://colab.research.google.com/github/hydmic009/CS3001F-A1/blob/main/Copy_of_modAl_AL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Install dependencies
!pip install -q --upgrade git+https://github.com/modAL-python/modAL.git google-generativeai Pillow

import os
import sys
import numpy as np
from PIL import Image
import google.generativeai as genai
import time

from modAL.models import ActiveLearner
from modAL.uncertainty import uncertainty_sampling

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

from matplotlib import pyplot as plt
# A robust backend for non-interactive environments
import matplotlib
matplotlib.use('Agg')
%matplotlib inline

# 2. Configure Gemini
# The API key is read from the GEMINI_API_KEY environment variable instead of
# being hard-coded. Anyone who can read a notebook (especially one pushed to a
# public repository) can read a key embedded in it. Any key previously
# embedded here should be revoked and regenerated at
# https://aistudio.google.com/app/apikey.
# In Colab, store the key under "Secrets" and export it before running, e.g.
#   from google.colab import userdata
#   os.environ["GEMINI_API_KEY"] = userdata.get("GEMINI_API_KEY")
MY_API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY").strip()

# The model name can be overridden without editing the notebook.
MY_GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-1.5-flash")

# Fail fast with a clear message instead of an opaque auth error on the first query.
if not MY_API_KEY or "YOUR_GEMINI_API_KEY" in MY_API_KEY:
    raise ValueError(
        "No Gemini API key found. Set the GEMINI_API_KEY environment variable "
        "to your Google Gemini API key (https://aistudio.google.com/app/apikey)."
    )

genai.configure(api_key=MY_API_KEY)

# 3. Helper to query Gemini via GenerativeModel
def query_gemini_label(img_array, scale=8):
    """Ask Gemini to read the handwritten digit in ``img_array``.

    Args:
        img_array: One flattened ``load_digits`` sample with 64 intensities
            in [0, 16], shaped (64,) or (1, 64). The latter is the shape of
            the instance returned by ``learner.query`` in this notebook.
        scale: Integer nearest-neighbour upscaling factor applied before the
            image is sent. A raw 8x8 bitmap is tiny, and replicating pixels
            keeps the strokes crisp.

    Returns:
        The digit (0-9) Gemini reported, or None in three cases: the request
        failed, the reply carried no text, or no digit could be found in it.
        Each failure is printed so it shows up in the notebook output.
    """
    pil_img = Image.fromarray(digit_to_uint8(img_array, scale))
    prompt = (
        "You are an expert at reading low-resolution handwritten digits.\n"
        "Below is an 8x8 grayscale image of a digit, enlarged by pixel "
        "replication. Reply with exactly one digit (0-9)."
    )
    model = genai.GenerativeModel(MY_GEMINI_MODEL)
    try:
        response = model.generate_content(
            [prompt, pil_img],
            generation_config={"temperature": 0.0},
        )
        # The SDK's ``.text`` accessor raises ValueError when the response has
        # no text part (e.g. the candidate was blocked).
        text = response.text
    except ValueError as e:
        print(f"Gemini returned no usable text: {e}. Defaulting to None.")
        return None
    except Exception as e:  # remote-service boundary: network / quota / auth errors
        print(f"Gemini request failed: {e}. Defaulting to None.")
        return None

    label = parse_gemini_digit(text)
    if label is None:
        print(f"Error parsing Gemini response {text!r}. Defaulting to None.")
    return label


def digit_to_uint8(img_array, scale=1):
    """Convert one flattened 8x8 digit (values 0-16) into a uint8 pixel grid.

    The sample is reshaped to 8x8 and clipped to [0, 16] so out-of-range
    values cannot wrap around in uint8. It is then mapped linearly onto
    [0, 255]. When ``scale > 1`` it is enlarged by repeating every pixel
    ``scale`` times along both axes, giving an array of shape
    (8 * scale, 8 * scale).
    """
    grid = np.clip(np.asarray(img_array, dtype=float).reshape(8, 8), 0, 16)
    pixels = (grid / 16 * 255).astype(np.uint8)
    if scale > 1:
        pixels = pixels.repeat(scale, axis=0).repeat(scale, axis=1)
    return pixels


def parse_gemini_digit(text):
    """Return the first ASCII digit in ``text`` as an int, or None if there is none.

    Tolerates replies such as ``"7"``, ``"7."`` or ``"The digit is 7"``, which a
    strict ``int()`` conversion of the first line would reject.
    """
    import re

    match = re.search(r"[0-9]", text or "")
    return int(match.group()) if match else None

# 4. Load data and warm-start ActiveLearner
# A single seed drives every random step below (the split, the draw of the
# initial labelled set, and the random forest) so that results are reproducible.
RANDOM_STATE = 0

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)

# Warm-start set: n_initial distinct training samples drawn without replacement.
n_initial = 100
rng = np.random.default_rng(RANDOM_STATE)
initial_idx = rng.choice(len(X_train), n_initial, replace=False)
X_initial, y_initial = X_train[initial_idx], y_train[initial_idx]

# Everything not in the warm-start set forms the unlabelled pool to query from.
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)

learner = ActiveLearner(
    estimator=RandomForestClassifier(random_state=RANDOM_STATE),
    query_strategy=uncertainty_sampling,
    X_training=X_initial,
    y_training=y_initial,
)

# 5. Active learning + Gemini loop
n_queries = 20
# learner_acc[0] is the test-set accuracy before any query; entry k is the
# accuracy after k queries.
learner_acc = [learner.score(X_test, y_test)]
# gemini_acc[k-1] is Gemini's running accuracy over the first k queried samples.
gemini_acc = []
gemini_correct = 0

plt.style.use('default')

for i in range(n_queries):
    # Ask the learner for the pool sample it is least certain about.
    query_idx, query_inst = learner.query(X_pool)
    true_label_arr = y_pool[query_idx]
    true_label_scalar = true_label_arr[0]

    # Query Gemini and update its running accuracy (a None answer counts as wrong).
    gem_label = query_gemini_label(query_inst)
    if gem_label is not None and gem_label == true_label_scalar:
        gemini_correct += 1
    gemini_acc.append(gemini_correct / (i + 1))

    print(f"Query {i+1}/{n_queries} -> True label: {true_label_scalar} — Gemini guessed: {gem_label}")

    # The learner is taught the ground-truth label, not Gemini's guess. Gemini
    # is only scored here. Shapes (1, n_features) / (1,) match the training
    # data the learner stacks these onto.
    learner.teach(X=query_inst.reshape(1, -1), y=true_label_arr.reshape(1,))

    # Crude rate limiting between Gemini requests.
    time.sleep(1)

    # Remove the queried instance from the pool so it cannot be picked again.
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)
    learner_acc.append(learner.score(X_test, y_test))

# 6. Final comparison plot
# The two curves measure different things: the classifier is scored on the
# held-out test set after each query, while Gemini's value is its running
# accuracy on the (hard, uncertainty-selected) samples queried so far.
plt.figure(figsize=(10, 6))
plt.plot(
    range(len(learner_acc)),
    learner_acc,
    marker='o',
    label='Classifier (Active Learner) - test-set accuracy',
)
# Gemini's first point is after query 1, so its x-axis starts at 1.
plt.plot(
    range(1, len(gemini_acc) + 1),
    gemini_acc,
    marker='x',
    label='Gemini - running accuracy on queried samples',
)
plt.title('Classifier vs. Gemini Accuracy')
plt.xlabel('Number of Queries')
plt.ylabel('Accuracy')
plt.xticks(range(0, n_queries + 1, 2))
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend()
plt.show()

  Preparing metadata (setup.py) ... [?25l[?25hdone
