In [8]:
import cv2
import numpy as np
import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from sklearn.preprocessing import StandardScaler
from PIL import Image
from sklearn.cluster import KMeans

In [9]:
image_dir = "figure_data/"

In [10]:
# Preprocessing images
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [19]:
def load_images(image_folder):
    file_names = []
    images = []
    for file in os.listdir(image_folder):
        if file.endswith(".png"):
            with Image.open(image_dir + file).convert("RGB") as img:
                img = img.resize((224,224))
                images.append(img)
                file_names.append(file)
    return np.array(images), file_names

def extract_features(images, model):  
    features = []
    for img in images:
        img_tensor = preprocess(img).unsqueeze(0)
        
        with torch.no_grad():
            feature = model(img_tensor)
            features.append(feature.numpy().flatten())

    return np.array(features)

In [20]:
# Extracting the features
model = models.resnet50(pretrained=True)
model.eval()

#images, file_names = load_images(image_dir)
features = extract_features(images, model)

scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

In [23]:
# Applying K-means clustering
num_clusters = 3
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
kmeans.fit(features_scaled)
labels = kmeans.labels_

for cluster_id in range(num_clusters):
    print(f"\nCluster {cluster_id}:")
    cluster_indices = np.where(labels == cluster_id)[0]

    counter = 0
    for idx in cluster_indices:
        print(f"  {file_names[idx]}")
        counter += 1
        if counter == 5:
            break


Cluster 0:
  2409.09882_FIG_2.png
  2409.08245_FIG_2.png
  2409.07913_FIG_2.png
  2409.09727_FIG_4.png
  2409.08161_FIG_3.png

Cluster 1:
  2409.11104_FIG_3.png
  2409.08027_FIG_4.png
  2409.09678_FIG_1.png
  2409.07723_FIG_2.png
  2409.07775_FIG_1.png

Cluster 2:
  2409.09882_FIG_5.png
  2409.08351_FIG_3.png
  2409.09285_FIG_2.png
  2409.09285_FIG_1.png
  2409.08946_FIG_6.png
