# Check GPU (optional but recommended)

In [None]:
import torch

# Report whether PyTorch can see a CUDA GPU and, if so, the name of device 0.
if torch.cuda.is_available():
    print("CUDA available:", True)
    print("Device:", torch.cuda.get_device_name(0))
else:
    print("CUDA available:", False)
    print("Device:", "CPU")

# Install dependencies

In [None]:
!pip install -q transformers accelerate safetensors pillow gradio

# Imports

In [None]:
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image

# Load model & processor

In [None]:
# Hugging Face Hub checkpoint used for both preprocessing and classification.
model_name = "prithivMLmods/Geometric-Shapes-Classification"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Image preprocessor paired with this checkpoint.
processor = AutoImageProcessor.from_pretrained(model_name)

# Module.to() and Module.eval() both return the module itself, so the calls
# can be chained; eval() disables training-only behaviour such as dropout.
model = (
    SiglipForImageClassification.from_pretrained(model_name)
    .to(device)
    .eval()
)

# Label mapping

In [None]:
# Map from model class index to a human-readable display label (shape name
# plus a Unicode glyph). Indices are assigned by position in the list below.
# NOTE(review): this order presumably mirrors model.config.id2label for the
# checkpoint — confirm, since the inference code looks labels up by index.
# NOTE(review): "⬰" is U+2B30 LEFT ARROW WITH SMALL CIRCLE, not a kite shape;
# Unicode has no dedicated kite glyph.
labels = dict(enumerate([
    "Circle ◯",
    "Kite ⬰",
    "Parallelogram ▰",
    "Rectangle ▭",
    "Rhombus ◆",
    "Square ◼",
    "Trapezoid ⏢",
    "Triangle ▲",
]))

# Inference function

In [None]:
def classify_shape(image_path):
    """Classify the geometric shape shown in an image file.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        dict mapping each display label in ``labels`` to the model's softmax
        probability for that class, rounded to 4 decimal places.
    """
    # Open the file as a context manager so its handle is released once the
    # RGB copy exists; Pillow keeps the file open for multi-frame formats
    # (e.g. GIF/TIFF) when Image.open is used without closing it.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        # Take the first (only) batch element and move every class
        # probability to Python in one transfer instead of one .item()
        # device sync per class.
        probs = torch.softmax(logits, dim=1)[0].tolist()

    return {name: round(probs[i], 4) for i, name in labels.items()}


# Upload & test image

In [None]:
from google.colab import files

# Ask the user to pick a file in the browser. files.upload() returns a dict
# keyed by the uploaded file names, which are used directly as paths below.
# Only the first uploaded file is classified.
uploaded = files.upload()
if not uploaded:
    raise ValueError("No file was uploaded; re-run this cell and choose an image.")
image_path = next(iter(uploaded))

classify_shape(image_path)


# Gradio Web App

In [None]:
import gradio as gr
import numpy as np

def classify_gradio(image):
    """Gradio callback: classify the geometric shape in an uploaded image.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when the user presses Submit without uploading anything.

    Returns:
        dict mapping each display label in ``labels`` to its (unrounded)
        softmax probability, in the form ``gr.Label`` accepts; ``None`` when
        no image was provided, so the label output stays empty instead of the
        callback raising.
    """
    # Without this guard Image.fromarray(None) raises and the UI only shows a
    # generic error.
    if image is None:
        return None

    pil_image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        # Single batch element; one device-to-host transfer for all classes.
        probs = torch.softmax(logits, dim=1)[0].tolist()

    return {name: probs[i] for i, name in labels.items()}

# Build the web UI around classify_gradio. The label widget shows every class
# defined in `labels` instead of a hard-coded count, so the two stay in sync.
# share=True also creates a public share link: anyone with the URL can use
# the app while it is running. Keep the `demo` handle so the server can be
# stopped later with demo.close().
demo = gr.Interface(
    fn=classify_gradio,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=len(labels)),
    title="Geometric Shapes Classification",
    description="Upload an image containing a geometric shape",
)
demo.launch(share=True)

