In [1]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import numpy as np
import os
import logging
import pickle
import matplotlib.pyplot as plt

In [2]:
# Log to a file so long embedding runs leave a persistent record.
LOG_DIR = "logs"
LOG_FILE = os.path.join(LOG_DIR, "embedding_and_clustering.log")
# logging.FileHandler does not create missing parent directories; without this,
# Restart & Run All on a fresh checkout fails with FileNotFoundError.
os.makedirs(LOG_DIR, exist_ok=True)

# force=True: basicConfig is silently a no-op when the root logger already has
# handlers (e.g. when this cell is re-run, or if an earlier import installed one),
# in which case nothing would be written to LOG_FILE.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    filename=LOG_FILE,
                    force=True)

# Root directory searched recursively for .png patches.
PATCH_DIR = "./datasets/patches"
# Hugging Face Hub repo id passed to from_pretrained_keras.
MODEL_NAME = "google/path-foundation"

logging.info(f"Patch directory: {PATCH_DIR}")
logging.info(f"Model name: {MODEL_NAME}")

In [3]:
class GenerateEmbeddings:
    """Compute model embeddings for every PNG patch found under a directory.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub repo id passed to ``from_pretrained_keras``.
    patch_dir : str
        Root directory searched recursively for ``.png`` files (any case).
    image_size : tuple of int, optional
        ``(height, width)`` each patch is resized to before inference.
        Defaults to ``(224, 224)``.
    """

    def __init__(self, model_name, patch_dir, image_size=(224, 224)):
        logging.info(f"Loading model: {model_name}")
        self.model = from_pretrained_keras(model_name)
        self.infer = self.model.signatures["serving_default"]
        self.patch_dir = patch_dir
        self.image_size = tuple(image_size)
        # NOTE: despite its name this is a list of patch *file* paths; the
        # attribute name is kept for backward compatibility.
        self.patches_dir = self._collect_patch_paths(patch_dir)
        if not self.patches_dir:
            logging.warning(f"No .png patches found under {patch_dir}")

    @staticmethod
    def _collect_patch_paths(patch_dir):
        """Return a sorted list of every ``.png`` file (case-insensitive) under ``patch_dir``."""
        paths = []
        for root, _, files in os.walk(patch_dir):
            for file in files:
                if file.lower().endswith(".png"):
                    paths.append(os.path.join(root, file))
        # os.walk order depends on the filesystem; sort so the embedding order
        # is reproducible across runs and machines.
        return sorted(paths)

    def __len__(self):
        """Number of patch files discovered (not the number successfully embedded)."""
        return len(self.patches_dir)

    def _load_patch(self, path):
        """Read a PNG and return a ``(1, H, W, 3)`` float32 tensor scaled to [0, 1]."""
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=3)
        # tf.image.resize returns float32 but keeps the original 0-255 range,
        # so the rescale to [0, 1] happens afterwards.
        img = tf.image.resize(img, self.image_size)
        img = tf.cast(img, tf.float32) / 255.0
        return tf.expand_dims(img, 0)  # add the batch dimension

    def generate_embeddings(self):
        """Run the model on every discovered patch, one image at a time.

        Returns
        -------
        list of dict
            One entry per successfully processed patch, in ``self.patches_dir``
            order, with keys ``"patch"`` (file basename), ``"path"`` (full path,
            unambiguous when sub-directories share file names) and ``"embeds"``
            (the ``"output_0"`` array returned by the serving signature, batch
            axis included). Patches that fail are logged with a traceback and
            skipped, so the result may be shorter than ``len(self)``.
        """
        embeddings = []
        for patch in self.patches_dir:
            try:
                logging.info(f"Generating embeddings for {patch}")
                embedding = self.infer(self._load_patch(patch))["output_0"].numpy()
            except Exception as e:
                # Best effort: skip unreadable/corrupt patches but keep the traceback.
                logging.exception(f"Error generating embeddings for {patch}: {e}")
                continue
            embeddings.append({
                "patch": os.path.basename(patch),
                "path": patch,
                "embeds": embedding,
            })
        logging.info(f"Embedded {len(embeddings)}/{len(self.patches_dir)} patches")
        return embeddings

In [None]:
# Load the model, index the patches and embed them (one forward pass per patch).
embed_gen = GenerateEmbeddings(MODEL_NAME, PATCH_DIR)
embeddings = embed_gen.generate_embeddings()

# Guard: an empty result (wrong PATCH_DIR, or every patch failing) must not
# silently overwrite a previously good embeddings.pkl with an empty list.
if not embeddings:
    raise RuntimeError(
        f"No embeddings generated from {len(embed_gen)} patches under {PATCH_DIR}; "
        "see the log for per-patch errors. embeddings.pkl was not written."
    )

# NOTE: pickle files execute code on load; only reload this from a trusted source.
with open("embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)

In [6]:
# Record how many patches were embedded successfully (failures are skipped upstream).
n_embeddings = len(embeddings)
logging.info(f"Length of embeddings: {n_embeddings}")