In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

from transformers import ViTModel, ViTImageProcessor
from sklearn.cluster import KMeans
import numpy as np

In [None]:
class HouseImagesDataset(Dataset):
    """Unlabeled image dataset yielding ViT-ready pixel tensors together with their file paths.

    Only regular files directly inside ``root_dir`` (no recursion) whose names end in one of
    ``extensions`` (case-insensitive) are included. Paths are sorted so the dataset order --
    and therefore the order of extracted embeddings and printed cluster assignments -- is
    reproducible across runs and filesystems.
    """

    def __init__(self, root_dir, processor, extensions=(".jpg", ".jpeg", ".png")):
        """
        root_dir: Directory with Georgetown property images.
        processor: ViTImageProcessor for preprocessing (resize, normalization, etc.).
        extensions: Filename suffixes to include; matched case-insensitively.
        """
        self.root_dir = root_dir
        self.processor = processor
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir returns entries in arbitrary, filesystem-dependent order; sort for determinism.
        # The isfile check skips sub-directories that happen to be named like images.
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root_dir, name))
        ]

    def __len__(self):
        """Number of image files discovered in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(pixel_values, img_path)`` for the image at position ``idx``."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle promptly; convert() returns a fully
        # loaded copy, so the pixel data stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Preprocess the image for ViT. With return_tensors="pt" the processor presumably
        # returns a batch of one, (1, 3, H, W); drop only that leading batch axis.
        encoding = self.processor(images=image, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze(0)  # shape: (3, H, W)

        return pixel_values, img_path

In [None]:
def create_dataloader(root_dir, processor, batch_size=8, num_workers=2):
    """Build a non-shuffling DataLoader over the images in ``root_dir``.

    Args:
        root_dir: Directory scanned by ``HouseImagesDataset``.
        processor: Image processor forwarded to ``HouseImagesDataset``.
        batch_size: Number of images per batch.
        num_workers: Worker processes used for loading. Pass 0 to load in the main
            process. This is often required when running in a notebook on platforms
            whose multiprocessing start method is "spawn" (Windows, macOS): the worker
            processes generally cannot import a Dataset class defined in the notebook's
            ``__main__``.

    Returns:
        A DataLoader yielding ``(pixel_values, img_paths)`` batches. ``shuffle=False``
        keeps batches in dataset order, so embeddings stay aligned with a stable path order.
    """
    dataset = HouseImagesDataset(root_dir, processor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
def load_vit_model(model_name="google/vit-base-patch16-224", add_pooling_layer=False):
    """Load a pretrained ViT backbone (no classification head) in eval mode.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
        add_pooling_layer: Whether to build ViT's pooler layer. Defaults to False
            because this notebook takes embeddings from the [CLS] position of
            ``last_hidden_state``, so the pooler is unused. Classification checkpoints
            such as ``google/vit-base-patch16-224`` typically do not store pooler
            weights, so building it would only add a randomly initialized layer (and a
            load-time warning).

    Returns:
        The loaded ``ViTModel`` in eval mode (dropout disabled for inference).
    """
    model = ViTModel.from_pretrained(model_name, add_pooling_layer=add_pooling_layer)
    model.eval()  # Set to eval mode for inference
    return model

In [None]:
def extract_embeddings(model, dataloader, device):
    """Extract one [CLS] embedding per image yielded by ``dataloader``.

    Args:
        model: A ViT-style model whose output exposes ``last_hidden_state`` shaped
            (batch_size, seq_len, hidden_size). It is moved to ``device`` and switched
            to eval mode, so dropout cannot make the embeddings non-deterministic even
            if the caller left it in training mode.
        dataloader: Iterable of ``(pixel_values, img_paths)`` batches.
        device: torch device on which to run the forward pass.

    Returns:
        Tuple ``(embeddings, paths)``. ``embeddings`` is a 2-D NumPy array
        (n_images, hidden_size) whose row i corresponds to ``paths[i]``. If the
        dataloader yields no batches, ``embeddings`` has shape (0, 0) and ``paths`` is empty.
    """
    model.to(device)
    model.eval()

    batch_embeddings = []
    all_paths = []

    with torch.no_grad():
        for pixel_values, img_paths in dataloader:
            outputs = model(pixel_values.to(device))

            # The [CLS] token embedding is typically at index 0 along the sequence axis.
            # Some ViT models also have a `pooler_output`, but not all.
            cls_embeddings = outputs.last_hidden_state[:, 0, :]  # (batch_size, hidden_size)

            batch_embeddings.append(cls_embeddings.cpu().numpy())
            all_paths.extend(img_paths)

    if not batch_embeddings:
        # Keep the result 2-D so callers indexing rows/columns don't trip over a (0,) array.
        return np.empty((0, 0), dtype=np.float32), all_paths
    return np.concatenate(batch_embeddings, axis=0), all_paths

In [None]:
def cluster_embeddings(embeddings, n_clusters=5, random_state=42, n_init=10):
    """Cluster embedding vectors with k-means.

    Args:
        embeddings: Array-like of shape (n_samples, n_features).
        n_clusters: Number of clusters. scikit-learn raises ValueError if this exceeds
            the number of samples.
        random_state: Seed for centroid initialization, so labels are reproducible.
        n_init: Number of k-means restarts; the run with the lowest inertia is kept.
            Passed explicitly because scikit-learn changed the default from 10 to
            "auto" in version 1.4, which would otherwise make results depend on the
            installed version. 10 preserves the historical default.

    Returns:
        The cluster assignment (an integer label) for each row of ``embeddings``.
    """
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=n_init)
    return kmeans.fit_predict(embeddings)

In [None]:
def main(root_dir=None, n_clusters=5, batch_size=8):
    """Embed every image in ``root_dir`` with ViT and print its k-means cluster.

    Args:
        root_dir: Directory with unlabeled images. If None, the HOUSE_IMAGES_DIR
            environment variable is used, falling back to a placeholder path that
            must be edited before running.
        n_clusters: Requested number of clusters. It is reduced to the number of
            images found, because KMeans cannot form more clusters than samples.
        batch_size: Images per forward pass.

    Raises:
        FileNotFoundError: If ``root_dir`` is not an existing directory.
    """
    # 1. Data directory with unlabeled images
    if root_dir is None:
        root_dir = os.environ.get("HOUSE_IMAGES_DIR", "/path/to/house/images")  # e.g. dataset/
    if not os.path.isdir(root_dir):
        raise FileNotFoundError(
            f"Image directory not found: {root_dir!r} (pass root_dir or set HOUSE_IMAGES_DIR)"
        )

    # 2. Load ViT processor (must match the checkpoint used for the model)
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    # 3. Create data loader
    dataloader = create_dataloader(root_dir, processor, batch_size)
    if len(dataloader.dataset) == 0:
        # Bail out before downloading/loading the model when there is nothing to embed.
        print(f"No images found in {root_dir!r}; nothing to cluster.")
        return

    # 4. Load the pretrained ViT model (for embeddings)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vit_model = load_vit_model()

    # 5. Extract embeddings
    embeddings, paths = extract_embeddings(vit_model, dataloader, device)

    # 6. Cluster embeddings; KMeans needs at least as many samples as clusters.
    effective_clusters = min(n_clusters, len(paths))
    cluster_labels = cluster_embeddings(embeddings, effective_clusters)

    # 7. Inspect results
    # Here, we just print out each image path with its cluster assignment
    for img_path, label in zip(paths, cluster_labels):
        print(f"{img_path} -> Cluster {label}")

    # Optionally, you could group these by cluster and analyze images together
    # to see if each cluster corresponds to a certain type of property, style, etc.


if __name__ == "__main__":
    main()