In [None]:
import os
import pickle
import openslide as ops
import numpy as np
import logging
import tensorflow as tf
import keras
from huggingface_hub import from_pretrained_keras

In [None]:
# Input locations. Both default to the paths used on the original machine but
# can be overridden through environment variables, so the notebook can run
# elsewhere without editing this cell.
WSI_DIR = os.environ.get("WSI_DIR", "../../../Lfstorage/wsis_2/")
# Despite the name, this is the local directory of the saved model snapshot
# that is handed to keras.layers.TFSMLayer.
MODEL_NAME = os.environ.get(
    "PATH_FOUNDATION_MODEL_DIR",
    "/home/cilem/.cache/huggingface/hub/models--google--path-foundation/snapshots/fd6a835ceaae15be80db6abd8dcfeb86a9287e72",
)
PATCH_DIR = "./embeddings"  # directory the pickled embeddings are written to
LOG_NAME = "embedding_extractor.log"  # log file name inside ./logs
PATCH_SIZE = 512  # tile edge length passed to read_region at level 0
OVERLAP = 0  # pixels shared by neighbouring tiles (stride = PATCH_SIZE - OVERLAP)

In [None]:
# logging.basicConfig opens the log file immediately and does not create
# missing parent directories, so make sure ./logs exists first.
log_dir = "./logs"
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    filename=os.path.join(log_dir, LOG_NAME))

logging.info("Starting patch extraction...")
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(PATCH_DIR, exist_ok=True)

# Record the run configuration so the log file is self-describing.
logging.info("Extracting patches from WSI...")
logging.info("Patch size: {}".format(PATCH_SIZE))
logging.info("Overlap: {}".format(OVERLAP))
logging.info("WSI directory: {}".format(WSI_DIR))
logging.info("Patch directory: {}".format(PATCH_DIR))

In [None]:
class PatchEmbeddingExtractor:
    """Tile whole-slide images into patches and embed every patch with a saved model.

    Slides are discovered recursively under ``slide_root_path``. Each slide is
    read at level 0 on a regular grid, every tile is converted to RGB, resized
    to 224x224, scaled to [0, 1] and passed through the model's
    ``serving_default`` endpoint; the ``output_0`` result is stored together
    with the tile's position.
    """

    # File suffixes (matched case-sensitively) that are treated as slides.
    SLIDE_EXTENSIONS = (".svs", ".tiff", ".tif")
    # Spatial size every tile is resized to before inference.
    MODEL_INPUT_SIZE = (224, 224)

    def __init__(self, slide_root_path, model_name, patch_size, overlap):
        """Load the model and collect the slide files to process.

        Args:
            slide_root_path: Directory searched recursively for slide files.
            model_name: Path of the saved model directory given to
                ``keras.layers.TFSMLayer``.
            patch_size: ``(width, height)`` of a tile in level-0 pixels, or a
                single int for square tiles.
            overlap: ``(x, y)`` overlap between neighbouring tiles in pixels,
                or a single int applied to both axes.

        Raises:
            ValueError: If a patch dimension is not positive, or the overlap
                is so large that the grid stride would be zero or negative.
        """
        # Validate the geometry before the (expensive) model load so that a
        # bad configuration fails immediately instead of once per slide.
        self.patch_size = self._as_pair(patch_size, "patch_size")
        self.overlap = self._as_pair(overlap, "overlap")
        for size, lap in zip(self.patch_size, self.overlap):
            if size <= 0:
                raise ValueError(f"patch_size must be positive, got {self.patch_size}")
            if size - lap <= 0:
                raise ValueError(
                    f"overlap {self.overlap} must be smaller than patch_size {self.patch_size}"
                )

        self.model = keras.layers.TFSMLayer(model_name, call_endpoint='serving_default')
        # The layer itself is callable; kept as ``infer`` for readability.
        self.infer = self.model

        # Sorted so the order of the resulting records is reproducible;
        # os.walk makes no ordering guarantee.
        self.slides_path = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(slide_root_path)
            for file in files
            if file.endswith(self.SLIDE_EXTENSIONS)
        )

        # Defined up front so len(extractor) works before extraction has run.
        self.embeddings = []

    @staticmethod
    def _as_pair(value, label):
        """Return ``value`` as a 2-tuple, duplicating a single integer."""
        if isinstance(value, (int, np.integer)):
            return (int(value), int(value))
        pair = tuple(value)
        if len(pair) != 2:
            raise ValueError(f"{label} must be an int or a pair, got {value!r}")
        return pair

    def __len__(self):
        """Number of patch records collected by the last extraction run."""
        return len(self.embeddings)

    def _embed_patch(self, patch):
        """Run one tile (a PIL image from ``read_region``) through the model."""
        rgb = patch.convert("RGB")
        pixels = np.array(rgb.resize(self.MODEL_INPUT_SIZE))
        # Scale to [0, 1] and add a leading batch axis of size one.
        batch = tf.expand_dims(tf.cast(pixels, tf.float32) / 255.0, 0)
        return self.infer(batch)["output_0"].numpy()

    def extract_patch_embeddings(self):
        """Embed every tile of every discovered slide that lies fully inside it.

        Previous results are discarded. A failure while processing one slide
        is logged (with traceback) and the remaining slides are still
        processed; records gathered from the failing slide before the error
        are kept.

        Returns:
            list[dict]: One record per tile with the keys ``slide_name``,
            ``x``, ``y``, ``level``, ``patch_size``, ``resize`` and
            ``embedding_vector`` (the model's ``output_0`` for a batch of one
            tile, as returned by ``.numpy()``).
        """
        self.embeddings = []
        patch_width, patch_height = self.patch_size
        overlap_width, overlap_height = self.overlap
        stride_x = patch_width - overlap_width
        stride_y = patch_height - overlap_height

        for slide_path in self.slides_path:
            slide = None
            try:
                slide = ops.OpenSlide(slide_path)
                slide_name = os.path.basename(slide_path)
                slide_width, slide_height = slide.dimensions

                # read_region always returns an image of the requested size and
                # pads anything outside the slide, so partial border tiles
                # cannot be detected from the returned image. Bound the grid
                # instead so that every tile lies entirely within the slide.
                for y in range(0, slide_height - patch_height + 1, stride_y):
                    for x in range(0, slide_width - patch_width + 1, stride_x):
                        patch = slide.read_region(location=(x, y), level=0, size=self.patch_size)
                        embedding = self._embed_patch(patch)
                        self.embeddings.append({
                            "slide_name": slide_name,
                            "x": x,
                            "y": y,
                            "level": 0,
                            "patch_size": self.patch_size,
                            "resize": self.MODEL_INPUT_SIZE,
                            "embedding_vector": embedding
                        })
                        logging.info(f"Extracted patch embedding from {slide_path} at ({x}, {y})")
            except Exception as e:
                # Deliberately best-effort: one unreadable slide must not abort
                # the whole run. logging.exception also records the traceback.
                logging.exception(f"Error extracting patch embeddings from {slide_path}: {e}")
            finally:
                # Release the slide handle even when processing failed.
                if slide is not None:
                    slide.close()

        return self.embeddings

In [None]:
# Build the extractor from the notebook-level configuration; tiles are square,
# so the scalar settings are duplicated for both axes.
extractor = PatchEmbeddingExtractor(
    slide_root_path=WSI_DIR,
    model_name=MODEL_NAME,
    patch_size=(PATCH_SIZE, PATCH_SIZE),
    overlap=(OVERLAP, OVERLAP),
)

# Walk every discovered slide and collect one record per embedded patch.
embeddings = extractor.extract_patch_embeddings()
logging.info("Number of extracted patches: {}".format(len(embeddings)))

In [None]:
# Echo the patch count in the notebook output as well as in the log file.
patch_count_message = "Number of extracted patches: {}".format(len(embeddings))
print(patch_count_message)

In [None]:
# Persist all patch records as a single pickle inside PATCH_DIR.
# NOTE: only load this file from a trusted location; unpickling executes code.
logging.info("Saving embeddings...")
embeddings_file = os.path.join(PATCH_DIR, f"embeddings_wsis2_{PATCH_SIZE}.pkl")
with open(embeddings_file, "wb") as handle:
    pickle.dump(embeddings, handle)