In [None]:
import torch
import clip
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
from sparse_autoencoder import SparseAutoencoder


def process_image_pipeline(image_path, sae_model_path, output_path):
    """
    Encode one image with CLIP, pass the embedding through a sparse
    autoencoder (SAE), and save the result to disk.

    Steps: load CLIP ViT-L/14, preprocess and encode the image, build a
    SparseAutoencoder and restore its weights from ``sae_model_path``, run the
    CLIP embedding through it, and ``torch.save`` the first element of the SAE
    output (moved to CPU) to ``output_path``.

    :param image_path: Path to the input image.
    :param sae_model_path: Path to the trained SAE checkpoint (a state dict).
    :param output_path: Path where the processed features are saved.
    :return: None. The result is written to ``output_path`` and a confirmation
        line is printed.
    """

    # Device selection
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # CLIP
    model, preprocess = clip.load("ViT-L/14", device=device)

    # Load and preprocess the image. The context manager closes the underlying
    # file handle as soon as the tensor has been produced.
    with Image.open(image_path) as raw_image:
        image = preprocess(raw_image).unsqueeze(0).to(device)

    # Encode the image with CLIP
    with torch.no_grad():
        image_features = model.encode_image(image)

    # CLIP runs in half precision on CUDA, while the SAE below is constructed
    # without any dtype override (presumably float32). Cast so both agree on
    # every device; this is a no-op when the features are already float32.
    image_features = image_features.float()

    # SAE
    def load_sae_model(sae_checkpoint_path):
        """Build the SAE, load its weights from the checkpoint, return it in eval mode."""
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (consider weights_only=True on
        # PyTorch versions that support it).
        state_dict = torch.load(sae_checkpoint_path, map_location=device)
        autoencoder_input_dim = 768  # CLIP ViT-L/14
        expansion_factor = 8
        n_learned_features = int(autoencoder_input_dim * expansion_factor)
        len_hook_points = 1

        sae = SparseAutoencoder(
            n_input_features=autoencoder_input_dim,
            n_learned_features=n_learned_features,
            n_components=len_hook_points
        ).to(device)

        sae.load_state_dict(state_dict)
        sae.eval()
        return sae

    # Pass the CLIP features through the SAE
    @torch.no_grad()
    def get_sae_representation(clip_features, sae_model):
        """Return the first element of the SAE forward output for the given features."""
        concepts, _ = sae_model(clip_features)
        return concepts

    sae = load_sae_model(sae_model_path)
    sae_repr = get_sae_representation(image_features, sae)

    # Save the result
    torch.save(sae_repr.cpu(), output_path)
    print(f"Zapisano: {output_path}")


    


In [25]:
# Run the pipeline on the sample image and write the SAE output next to it.
process_image_pipeline(
    image_path="dog.jpeg",
    sae_model_path="clip_ViT-L_14sparse_autoencoder_final.pt",
    output_path="dog_sae_concepts.pth",
)

Zapisano: dog_sae_concepts.pth
