In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (the training cell uses --push_to_hub).
# Never hardcode the token: read it from the HF_TOKEN environment variable
# (e.g. a Colab secret exported to the environment). An unset or empty value
# becomes None, which makes login() prompt for the token interactively instead
# of failing on an empty string.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# Fetch the training scripts; the training cell runs image_classification.py from this checkout.
!git clone https://github.com/docty/transformer-training.git

In [None]:
# Work from inside the cloned repo: later cells use paths relative to it
# (./images_samples, and {os.getcwd()}/image_classification.py for training).
%cd transformer-training/

In [None]:
#!git clone https://huggingface.co/Docty/{MODEL_ID}

In [None]:
!pip install -q evaluate datasets torchvision transformers hf_xet

In [None]:
import os

# Keep Weights & Biases out of the Trainer run (training also passes --report_to none).
os.environ.update(WANDB_DISABLED="true")

In [None]:
# Run configuration shared by the sampling, training, inference and demo cells.
PRETRAINED_MODEL = "google/vit-base-patch16-224-in21k"  # checkpoint passed as --model_name_or_path
DATASET_NAME = "Docty/solaices"                         # Hub dataset with "image" and "label" columns
# NOTE(review): the dataset id is spelled "solaices" but the output dir "solacies"; confirm which is intended.
OUTPUT_DIR = "./solacies/"                              # --output_dir; the pipeline cells load the model from here

In [None]:
from datasets import load_dataset
import os

def download_samples(DATASET_NAME, DATASET_SAMPLE="images_samples", split="train",
                     label_column="label", image_column="image"):
    """Save one example image per class from a Hugging Face dataset split.

    For each class, the first row in dataset order is written to
    ``<DATASET_SAMPLE>/<label id>.jpg``.

    Args:
        DATASET_NAME: Hub dataset id passed to ``load_dataset``.
        DATASET_SAMPLE: Output directory; created if it does not exist.
        split: Dataset split to sample from.
        label_column: Name of the class-label column (must expose ``.names``).
        image_column: Name of the image column.

    Returns:
        dict mapping each label id found to the path of its saved image.
    """
    dataset = load_dataset(DATASET_NAME, split=split)
    num_classes = len(dataset.features[label_column].names)
    os.makedirs(DATASET_SAMPLE, exist_ok=True)

    # Scan only the label column first: iterating whole rows would decode every
    # image, but only one image per class is actually needed.
    first_row = {}
    for idx, label in enumerate(dataset[label_column]):
        if label not in first_row:
            first_row[label] = idx
            if len(first_row) == num_classes:
                break

    saved = {}
    for label, idx in first_row.items():
        image = dataset[idx][image_column]
        # JPEG cannot store alpha or palette images (e.g. RGBA/P from PNG sources).
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        path = os.path.join(DATASET_SAMPLE, f"{label}.jpg")
        image.save(path)
        saved[label] = path
    return saved

In [None]:
# Write one sample image per class to ./images_samples for the inference and demo cells.
download_samples(DATASET_NAME, DATASET_SAMPLE="images_samples")

# Training

In [None]:
# Fine-tune PRETRAINED_MODEL on DATASET_NAME using the cloned repo's
# image_classification.py. Before running the command, IPython substitutes
# {os.getcwd()} and the $PRETRAINED_MODEL / $DATASET_NAME / $OUTPUT_DIR
# references with the Python variables from the config cell.
# NOTE(review): --remove_unused_columns False is presumably required so the
# Trainer keeps the raw image column for the script's transforms; confirm.
# Evaluation and checkpointing both run per epoch, and the best checkpoint is
# reloaded at the end. --push_to_hub uploads the result and needs the Hub login
# from the first cell.
!python {os.getcwd()}/image_classification.py \
    --model_name_or_path "$PRETRAINED_MODEL" \
    --dataset_name "$DATASET_NAME" \
    --output_dir "$OUTPUT_DIR"   \
    --remove_unused_columns False \
    --label_column_name label \
    --do_train \
    --do_eval \
    --learning_rate 2e-5 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337 \
    --push_to_hub \
    --report_to none

# Inference

In [None]:
import os
import random

from PIL import Image
from transformers import pipeline

# Sanity-check the fine-tuned model in OUTPUT_DIR on one randomly chosen sample.
# Sample files are named "<label id>.jpg", so the printed name is the expected class id.
samples_dir = "./images_samples"
item = random.choice(sorted(os.listdir(samples_dir)))
print(item)

classifier = pipeline("image-classification", model=OUTPUT_DIR)
# Open the file in a context manager so its handle is closed once predicted.
with Image.open(os.path.join(samples_dir, item)) as img:
    predictions = classifier(img)
predictions


In [None]:
import gradio as gr
from PIL import Image
import requests
from io import BytesIO
import os
from transformers import pipeline


# Load the fine-tuned model from OUTPUT_DIR; classify_image below uses this global.
classifier = pipeline("image-classification", model=OUTPUT_DIR)

def classify_image(input_img=None, img_url=None, timeout=10):
    """Classify an uploaded image or one downloaded from a URL.

    A non-blank ``img_url`` takes priority over ``input_img``.

    Args:
        input_img: Uploaded image from ``gr.Image`` (a PIL image, or an array
            accepted by ``Image.fromarray``); ``None`` when nothing is uploaded.
        img_url: Optional URL of an image to fetch.
        timeout: Seconds to wait for the URL request before failing.

    Returns:
        dict mapping each predicted label to its float score, the format
        ``gr.Label`` displays.

    Raises:
        gr.Error: when no image is provided or the URL cannot be fetched or
            decoded. Gradio shows the message in the UI, whereas an error dict
            with a string value is not a valid ``gr.Label`` payload.
    """
    url = img_url.strip() if img_url else ""
    if url:
        try:
            # Without a timeout, an unresponsive host would block the request indefinitely.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            img.load()  # force decoding here so corrupt data is reported as a URL error
        except Exception as e:
            raise gr.Error(f"Failed to load image from URL: {e}") from e
    elif input_img is not None:
        # Explicit None check: the truth value of a NumPy array is ambiguous and raises.
        img = input_img if isinstance(input_img, Image.Image) else Image.fromarray(input_img)
    else:
        raise gr.Error("No image provided.")

    # Normalize every source (uploads may be RGBA/P/L) to 3-channel RGB, as the URL path always did.
    img = img.convert("RGB")
    results = classifier(img)
    return {res["label"]: float(res["score"]) for res in results}


# Soft theme with blue/lime accents for the demo UI.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="lime",
    neutral_hue="slate",
)

# Sample images saved earlier to ./images_samples, sorted so the example
# gallery has the same order on every run (os.listdir order is arbitrary).
example_dir = "./images_samples"
example_paths = [os.path.join(example_dir, name) for name in sorted(os.listdir(example_dir))]

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## Image Classifier")
    gr.Markdown("Upload an image **or** enter an image URL to classify it using the model.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        url_input = gr.Textbox(label="Image URL (optional)", placeholder="Paste image URL here...")
        label_output = gr.Label(num_top_classes=3, label="Predictions")

    classify_btn = gr.Button("Classify Image", variant="primary")

    gr.Examples(
        examples=example_paths,
        inputs=image_input,
        outputs=label_output,
        fn=classify_image,
        cache_examples=False,
    )

    classify_btn.click(
        fn=classify_image,
        inputs=[image_input, url_input],
        outputs=label_output,
    )

# share=True asks Gradio for a temporary public link (handy when running on Colab).
demo.launch(share=True)


In [None]:
!zip -r images_samples.zip /content/transformer-training/images_samples

In [None]:
# Colab-only: send images_samples.zip (created by the zip cell) to the browser as a download.
from google.colab import files
files.download("images_samples.zip")
