In [1]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import os
import pandas as pd
from tqdm import tqdm

In [2]:
# =========================
# CONFIG
# =========================
DATASET_DIR = "../cnn_dataset"
OUTPUT_DIR = "../xgb_dataset"

IMG_SIZE = 224
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

os.makedirs(OUTPUT_DIR, exist_ok=True)

cuda


In [3]:
# =========================
# TRANSFORM (same as CNNs)
# =========================
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

In [4]:
# =========================
# LOAD MODELS
# =========================
def load_mobilenet(path, num_classes):
    m = models.mobilenet_v2(weights=None)
    m.classifier[1] = torch.nn.Linear(m.last_channel, num_classes)
    m.load_state_dict(torch.load(path, map_location=DEVICE))
    return m.eval().to(DEVICE)

def load_effnet(path, num_classes):
    m = models.efficientnet_v2_s(weights=None)
    m.classifier[1] = torch.nn.Linear(
        m.classifier[1].in_features, num_classes
    )
    m.load_state_dict(torch.load(path, map_location=DEVICE))
    return m.eval().to(DEVICE)

def load_resnet(path, num_classes):
    m = models.resnet18(weights=None)
    m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
    m.load_state_dict(torch.load(path, map_location=DEVICE))
    return m.eval().to(DEVICE)


# PATHS TO THE SAVED MODELS
mobilenet = load_mobilenet("../mobilenet_v2_finetuned_best.pth", 3)
effnet    = load_effnet("../efficientnetv2s_best_finetuned_best.pth", 3)
resnet    = load_resnet("../resnet18_best_finetuned.pth", 3)

  m.load_state_dict(torch.load(path, map_location=DEVICE))
  m.load_state_dict(torch.load(path, map_location=DEVICE))
  m.load_state_dict(torch.load(path, map_location=DEVICE))


In [5]:
# =========================
# CLASS MAPPING
# =========================
class_names = sorted(os.listdir(f"{DATASET_DIR}/train"))
class_to_idx = {c:i for i,c in enumerate(class_names)}

print("Classes:", class_names)

# =========================
# FEATURE EXTRACTION
# =========================
@torch.no_grad()
def extract_probs(model, img):
    out = model(img)
    return F.softmax(out, dim=1).cpu().numpy()[0]

def process_split(split):

    rows = []

    base_dir = os.path.join(DATASET_DIR, split)

    for cls in class_names:
        cls_dir = os.path.join(base_dir, cls)

        for img_name in tqdm(os.listdir(cls_dir), desc=f"{split}/{cls}"):

            img_path = os.path.join(cls_dir, img_name)
            img = Image.open(img_path).convert("RGB")
            img = tfm(img).unsqueeze(0).to(DEVICE)

            p_mob = extract_probs(mobilenet, img)
            p_eff = extract_probs(effnet, img)
            p_res = extract_probs(resnet, img)

            features = list(p_mob) + list(p_eff) + list(p_res)
            features.append(class_to_idx[cls])

            rows.append(features)

    cols = (
        [f"mob_{i}" for i in range(3)] +
        [f"eff_{i}" for i in range(3)] +
        [f"res_{i}" for i in range(3)] +
        ["label"]
    )

    df = pd.DataFrame(rows, columns=cols)
    save_path = f"{OUTPUT_DIR}/{split}_features.csv"
    df.to_csv(save_path, index=False)

    print(f"✅ Saved {save_path} | shape = {df.shape}")

# =========================
# RUN
# =========================
process_split("train")
process_split("val")
process_split("test")


Classes: ['1509', 'IRRI-6', 'Super White']


train/1509: 100%|██████████| 3000/3000 [02:43<00:00, 18.31it/s]
train/IRRI-6: 100%|██████████| 3000/3000 [02:37<00:00, 19.06it/s]
train/Super White: 100%|██████████| 3000/3000 [02:51<00:00, 17.47it/s]


✅ Saved ../xgb_dataset/train_features.csv | shape = (9000, 10)


val/1509: 100%|██████████| 300/300 [00:16<00:00, 17.99it/s]
val/IRRI-6: 100%|██████████| 300/300 [00:16<00:00, 17.71it/s]
val/Super White: 100%|██████████| 300/300 [00:16<00:00, 17.87it/s]


✅ Saved ../xgb_dataset/val_features.csv | shape = (900, 10)


test/1509: 100%|██████████| 383/383 [00:24<00:00, 15.56it/s]
test/IRRI-6: 100%|██████████| 614/614 [00:37<00:00, 16.22it/s]
test/Super White: 100%|██████████| 2102/2102 [01:57<00:00, 17.87it/s]

✅ Saved ../xgb_dataset/test_features.csv | shape = (3099, 10)



