In [4]:
import os
import glob
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import ViTModel, AutoImageProcessor

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, classification_report

# 1. CONFIG
# Root folder of the dataset; it is expected to hold images/{train,val} and
# labels/{train,val}. Adjust this path if the dataset lives elsewhere.
DATA_DIR = "/kaggle/input/fruit-detection-yolo/fruit-detection-dataset/fruit-detection-dataset"

# Image and label folders of each split, all resolved relative to DATA_DIR.
TRAIN_IMG_DIR, TRAIN_LABEL_DIR, VAL_IMG_DIR, VAL_LABEL_DIR = (
    os.path.join(DATA_DIR, subdir)
    for subdir in ("images/train", "labels/train", "images/val", "labels/val")
)

# Number of images per DataLoader batch during embedding extraction.
BATCH_SIZE = 16

# Run on the GPU whenever PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Display names handed to the classification report as `target_names`.
CLASS_NAMES = [
    "apple",
    "avocado",
    "banana",
    "guava",
    "kiwi",
    "mango",
    "orange",
    "peach",
    "pineapple",
]


In [5]:
# 2. BUILD CLASSIFICATION DATAFRAME
def build_df(img_dir, label_dir):
    """Pair each image in `img_dir` with a class id read from its label file.

    For every entry in `img_dir`, the label file `<label_dir>/<stem>.txt` is
    looked up. Only the FIRST line of that file is read, and its first
    whitespace-separated token is parsed as the integer class id, so an image
    with several annotated objects is labelled by the first one listed.

    Entries without a label file, or whose label file starts with a blank
    line, are skipped.

    Args:
        img_dir: Directory containing the image files.
        label_dir: Directory containing one `.txt` label file per image stem.

    Returns:
        A DataFrame with columns ``image`` (path) and ``label`` (int class id),
        ordered by image path. The columns are present even when no image
        qualifies, so downstream column access never fails on an empty frame.

    Raises:
        ValueError: If the first token of a label file is not an integer.
    """
    columns = ["image", "label"]
    rows = []
    # glob returns paths in arbitrary, filesystem-dependent order; sorting makes
    # the row order (and therefore the extracted feature order) reproducible.
    image_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
    for img_path in tqdm(image_paths):
        stem = os.path.splitext(os.path.basename(img_path))[0]
        lbl_path = os.path.join(label_dir, stem + ".txt")
        if not os.path.exists(lbl_path):
            continue
        # Only the first annotation line determines the image-level class.
        with open(lbl_path, "r") as f:
            fields = f.readline().split()
        if not fields:
            continue
        rows.append({"image": img_path, "label": int(fields[0])})
    return pd.DataFrame(rows, columns=columns)

# One (image path, class id) table per split.
train_df = build_df(TRAIN_IMG_DIR, TRAIN_LABEL_DIR)
val_df   = build_df(VAL_IMG_DIR,   VAL_LABEL_DIR)

print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")

# Fail fast: a wrong DATA_DIR yields empty tables without any error here, and
# would otherwise only surface later as a confusing downstream failure.
if len(train_df) == 0 or len(val_df) == 0:
    raise FileNotFoundError(
        f"No labelled images found (train={len(train_df)}, val={len(val_df)}). "
        f"Check that DATA_DIR points at the dataset root: {DATA_DIR!r}"
    )

# 3. DATASET + DATALOADER
# Image processor matching the ViT checkpoint used for embedding extraction.
processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224-in21k")

  0%|          | 0/534 [00:00<?, ?it/s]

  0%|          | 0/134 [00:00<?, ?it/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Train samples: 534, Val samples: 134


In [6]:
class FruitClsDataset(Dataset):
    """Map-style dataset that turns a (image path, label) table into model inputs.

    Each item is obtained by opening the file at the row's ``image`` path,
    converting it to RGB and running it through the supplied image processor.
    """

    def __init__(self, df, processor):
        """Store the table and the processor.

        Args:
            df: DataFrame with an ``image`` column (file paths) and a ``label`` column.
            processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
                that returns a mapping containing ``"pixel_values"``.
        """
        # Renumber the rows 0..len-1 so the label-based lookup in __getitem__
        # matches the integer positions used to index the dataset.
        self.df = df.reset_index(drop=True)
        self.processor = processor

    def __len__(self):
        """Number of rows (samples) in the table."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for the row labelled `idx`."""
        record = self.df.loc[idx]
        rgb_image = Image.open(record["image"]).convert("RGB")
        encoded = self.processor(images=rgb_image, return_tensors="pt")
        # Presumably the processor adds a leading batch axis of size 1 for a
        # single image; squeeze(0) drops it so samples can be stacked later.
        pixel_values = encoded["pixel_values"].squeeze(0)
        return pixel_values, record["label"]

def collate_fn(batch):
    """Combine a list of ``(pixel_values, label)`` samples into batched tensors.

    Args:
        batch: Sequence of samples; element 0 of each is a tensor and
            element 1 is its integer label.

    Returns:
        A tuple ``(imgs, labels)``: the sample tensors stacked along a new
        leading dimension, and the labels as a ``torch.long`` tensor.
    """
    pixel_tensors = []
    class_ids = []
    for sample in batch:
        pixel_tensors.append(sample[0])
        class_ids.append(sample[1])
    stacked_images = torch.stack(pixel_tensors, dim=0)
    label_tensor = torch.tensor(class_ids, dtype=torch.long)
    return stacked_images, label_tensor

# extract embeddings function
_VIT_CHECKPOINT = "google/vit-large-patch16-224-in21k"
_VIT_MODELS = {}  # device string -> loaded ViTModel, so the checkpoint is loaded once per device


def _get_vit_model(device):
    """Return the ViT backbone on `device` in eval mode, loading it on first use only."""
    if device not in _VIT_MODELS:
        _VIT_MODELS[device] = ViTModel.from_pretrained(_VIT_CHECKPOINT).eval().to(device)
    return _VIT_MODELS[device]


def extract_embeddings(df, model=None):
    """Compute one ViT CLS-token embedding per row of `df`.

    Rows are processed in order (no shuffling) in batches of ``BATCH_SIZE``,
    so row ``i`` of the returned arrays corresponds to row ``i`` of `df`.

    Args:
        df: DataFrame with ``image`` and ``label`` columns, as consumed by
            ``FruitClsDataset``.
        model: Optional pre-loaded backbone. It is used as given, so it should
            already be in eval mode and on ``DEVICE``. When omitted, the
            default checkpoint is used; it is loaded once and then reused on
            later calls instead of being re-instantiated every time. Note that
            the cached model therefore stays in (GPU) memory between calls.

    Returns:
        ``(embs, labs)``: a 2-D array of embeddings (one row per sample) and a
        1-D array of the matching integer labels. For an empty `df` both
        arrays have zero rows.
    """
    if model is None:
        model = _get_vit_model(DEVICE)
    ds = FruitClsDataset(df, processor)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    all_embs = []
    all_labels = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for imgs, labels in tqdm(dl):
            imgs = imgs.to(DEVICE)
            out = model(imgs).last_hidden_state[:, 0]   # take CLS token
            all_embs.append(out.cpu().numpy())
            all_labels.append(labels.numpy())
    if not all_embs:
        # np.concatenate raises on an empty list; return well-formed empty arrays.
        # float32 assumes the model runs in its default precision.
        return (
            np.empty((0, model.config.hidden_size), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
        )
    embs = np.concatenate(all_embs, axis=0)
    labs = np.concatenate(all_labels, axis=0)
    return embs, labs

In [7]:
# 4. EXTRACT FEATURES
print("Extracting train embeddings…")
X_train, y_train = extract_embeddings(train_df)
print("Extracting val embeddings…")
X_val,   y_val   = extract_embeddings(val_df)

# 5. TRAIN & EVALUATE CLASSIFIERS
# Fixed seed for the estimators that draw random numbers (bootstrap samples,
# feature subsampling), so the reported accuracies are the same on every run.
RANDOM_STATE = 42
classifiers = {
    "LogisticRegression": LogisticRegression(max_iter=200, n_jobs=-1),
    "SVM (RBF)":          SVC(gamma="scale"),
    "RandomForest":       RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=RANDOM_STATE),
    "GradientBoosting":   GradientBoostingClassifier(n_estimators=100, random_state=RANDOM_STATE),
}

Extracting train embeddings…


  0%|          | 0/34 [00:00<?, ?it/s]



Extracting val embeddings…


  0%|          | 0/9 [00:00<?, ?it/s]

In [3]:
results = []
# Pin the report to every known class id. Without `labels`, classification_report
# infers the classes from y_val/preds and raises a ValueError when that count
# differs from len(CLASS_NAMES), e.g. when a class is absent from the validation split.
# Assumes label ids are 0..len(CLASS_NAMES)-1, indexing CLASS_NAMES — verify against the dataset.
class_ids = list(range(len(CLASS_NAMES)))
for name, clf in classifiers.items():
    print(f"\nTraining {name}…")
    clf.fit(X_train, y_train)
    preds = clf.predict(X_val)
    acc   = accuracy_score(y_val, preds)
    print(f"→ {name} validation accuracy: {acc:.4f}")
    print(classification_report(y_val, preds, labels=class_ids, target_names=CLASS_NAMES, zero_division=0))
    results.append({"model": name, "val_acc": acc})

# 6. SUMMARY
# Rank the classifiers from best to worst validation accuracy.
res_df = pd.DataFrame(results).sort_values("val_acc", ascending=False)
print("\n=== Summary ===")
print(res_df)


Extracting train embeddings…


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

  0%|          | 0/34 [00:00<?, ?it/s]



Extracting val embeddings…


  0%|          | 0/9 [00:00<?, ?it/s]


Training LogisticRegression…
→ LogisticRegression validation accuracy: 0.9627
              precision    recall  f1-score   support

       apple       1.00      0.91      0.95        22
     avocado       1.00      1.00      1.00        11
      banana       0.95      1.00      0.97        18
       guava       1.00      1.00      1.00         9
        kiwi       1.00      1.00      1.00        13
       mango       1.00      0.91      0.95        11
      orange       0.90      0.95      0.92        19
       peach       0.89      1.00      0.94        17
   pineapple       1.00      0.93      0.96        14

    accuracy                           0.96       134
   macro avg       0.97      0.97      0.97       134
weighted avg       0.97      0.96      0.96       134


Training SVM (RBF)…
→ SVM (RBF) validation accuracy: 0.9403
              precision    recall  f1-score   support

       apple       1.00      0.86      0.93        22
     avocado       1.00      0.82      0.90   