In [None]:
import os
from tqdm import tqdm
import torch
import logging
import pickle

from utils.dataset_loader import CustomDatasetFromSlide
from torch.utils.data import DataLoader
from utils.latent_extractor import LatentExtractor, TextEmbeddingExtractor, ImageEmbeddingExtractor
from torchvision import transforms
from datasets import load_dataset

In [None]:
# Hugging Face model identifiers used to build the extractors below.
vae_base_name = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# The text and image encoders are loaded from the same CLIP checkpoint name.
text_encoder_base_name = image_encoder_base_name = "openai/clip-vit-large-patch14"

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Create the output directory for the pickled latents. exist_ok=True avoids
# the check-then-create race and keeps the cell safe to re-run.
os.makedirs("latent_files", exist_ok=True)

In [None]:
# Default image preprocessing: resize to 256x256, convert to a tensor, then
# normalize with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# NOTE(review): this call runs before the logging.basicConfig(...) cell further
# down in this notebook — confirm the message reaches the intended log handler.
logging.info("Transform: {}".format(transform))

In [None]:
# Conditioning mode for the extracted dataset:
#   "T2I" - embeddings come from the text encoder,
#   "E2I" - embeddings are taken from the dataset's own "embedding" field,
#   "I2I" - embeddings come from the image encoder.
MODE = "I2I"

# Fail fast on a mistyped mode instead of silently falling through to the
# I2I branch (and writing to the I2I output file), and do so before any
# model weights are loaded.
_SUPPORTED_MODES = ("T2I", "E2I", "I2I")
if MODE not in _SUPPORTED_MODES:
    raise ValueError(
        "Unsupported MODE {!r}; expected one of {}".format(MODE, _SUPPORTED_MODES)
    )

# NOTE(review): the extractor is built with transform=None; images are
# transformed by the dataset instead — confirm LatentExtractor does not
# expect to apply its own preprocessing.
latent_extractor = LatentExtractor(vae_name=vae_base_name, device=device, transform=None)

if MODE == "T2I":
    save_path = "./latent_files/dataset_with_latents_t2i.pkl"
    text_embedding_extractor = TextEmbeddingExtractor(text_encoder_name=text_encoder_base_name, device=device)

elif MODE == "E2I":
    save_path = "./latent_files/dataset_with_latents_e2i.pkl"

else:  # MODE == "I2I" (guaranteed by the validation above)
    save_path = "./latent_files/dataset_with_latents_i2i.pkl"
    image_embedding_extractor = ImageEmbeddingExtractor(img_encoder_name=image_encoder_base_name, device=device)
    # I2I replaces the default transform with one that omits normalization.
    # NOTE(review): the same un-normalized image tensor is also passed to the
    # latent extractor in the processing loop — confirm that is intended.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

In [None]:
# Create the directory that holds the log file. exist_ok=True avoids the
# check-then-create race and keeps the cell safe to re-run.
os.makedirs('logs', exist_ok=True)

In [None]:
log_name = "latent_extractor"

# force=True is required here: an earlier cell already calls logging.info(),
# and a module-level logging call on an unconfigured root logger implicitly
# installs a default handler. basicConfig() does nothing when the root logger
# already has handlers, so without force the log file would never be created
# and INFO records would be dropped. (Notebook front-ends may also pre-install
# root handlers, which has the same effect.)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    filename=f'logs/{log_name}.log',
                    force=True)

logging.info(f"Device: {device}")
logging.info(f"Log file: {log_name}")
logging.info(f"Save path: {save_path}")
# Record the transform that is in effect at this point; any message logged
# before this configuration ran did not reach the log file.
logging.info(f"Transform: {transform}")

In [None]:
# Load the dataset and take its "train" split.
dataset = load_dataset("Cilem/mixed-histopathology-512")
train_dset = dataset["train"]

# Directory passed to CustomDatasetFromSlide as slide_dir. It can be
# overridden with the WSI_SLIDE_DIR environment variable so the notebook is
# not tied to one machine; the default is the original hard-coded location.
slide_dir = os.environ.get("WSI_SLIDE_DIR", "/home/cilem/Lfstorage/wsis")

train_dataset = CustomDatasetFromSlide(dataset=train_dset,
                                       slide_dir=slide_dir,
                                       transform=transform)

# shuffle=False keeps the stored records in dataset order.
dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False)

# Accumulates one {"latent", "embedding_vector"} record per sample.
dataset_with_latents = []

def _batch_to_cpu(batch):
    """Return `batch` detached and moved to the CPU if it is a tensor.

    Non-tensor values (e.g. lists) are returned unchanged.
    """
    if torch.is_tensor(batch):
        return batch.detach().cpu()
    return batch


def _own_copy(sample):
    """Return a standalone copy of `sample` if it is a tensor, else `sample`.

    Indexing a batch tensor yields a view that shares the batch's storage.
    Plain pickle serializes a tensor's whole backing storage, so storing views
    would write the full batch once per sample; cloning gives each record its
    own compact storage.
    """
    if torch.is_tensor(sample):
        return sample.clone()
    return sample


def _dump_atomically(obj, path):
    """Pickle `obj` to `path` without ever leaving a half-written file there.

    The data is written to a sibling temporary file first and then moved into
    place with os.replace, so an interruption during a checkpoint cannot
    corrupt the previously saved copy.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


for i, data in enumerate(tqdm(dataloader, desc="Processing dataset")):

    image = data["image"].to(device)
    # NOTE(review): how "embedding" is batched depends on what the dataset
    # returns per sample — confirm it is indexable by sample position.
    google_embedding_vector = data["embedding"]
    organ = data["organ"]

    latent = latent_extractor.extract_latent(image=image)

    if MODE == "T2I":
        text = [f"histopathology image of {organ_name}" for organ_name in organ]
        embedding = text_embedding_extractor.extract_text_embedding(text=text)

    elif MODE == "E2I":
        embedding = google_embedding_vector

    else:
        embedding = image_embedding_extractor.extract_image_embedding(image=image)

    # Move results off the compute device once per batch so the accumulated
    # dataset lives in host memory and the pickle loads without that device.
    latent = _batch_to_cpu(latent)
    embedding = _batch_to_cpu(embedding)

    for j in range(len(latent)):
        dataset_with_latents.append({"latent": _own_copy(latent[j]),
                                     "embedding_vector": _own_copy(embedding[j])})

    # Periodic checkpoint (every 100 batches, including the first) so a crash
    # does not lose all progress.
    if i % 100 == 0:
        _dump_atomically(dataset_with_latents, save_path)

# Final save with the complete dataset.
_dump_atomically(dataset_with_latents, save_path)

# Report the output location and record count on stdout and in the log.
saved_message = "Dataset with latents saved to {}".format(save_path)
size_message = "Dataset size: {}".format(len(dataset_with_latents))

print(saved_message)
print(size_message)

logging.info(saved_message)
logging.info(size_message)