In [None]:
import os
import torch
from transformers import CLIPModel, CLIPProcessor
import faiss
from PIL import Image, ImageFilter
import numpy as np
import pickle

# Run on the GPU when one is available, otherwise fall back to the CPU.
device     = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP ViT-B/32 weights plus the matching preprocessing pipeline.
model      = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor  = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Collect every PNG tile (extension match is case-insensitive).
# os.listdir() returns entries in arbitrary order, so the listing is sorted to
# make the row order of the embedding matrix -- and therefore the saved index
# and mapping -- reproducible across runs and machines.
tiles_dir  = "../../data/tiles/"
tile_files = sorted(
    os.path.join(tiles_dir, fname)
    for fname in os.listdir(tiles_dir)
    if fname.lower().endswith(".png")
)

# Accumulators filled by the embedding loop: one embedding array per tile and
# the tile path stored at the same position.
embeddings_list = []
tile_mapping    = []


In [None]:
# (Re)initialise the accumulators inside this cell so that re-running it does
# not append every tile a second time.
embeddings_list = []
tile_mapping    = []

for tile_path in tile_files:
    # The context manager releases the file handle deterministically once the
    # pixels have been decoded by convert().
    with Image.open(tile_path) as tile_file:
        image = tile_file.convert("RGB")

    inputs       = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        image_features = model.get_image_features(pixel_values=pixel_values)
        # Scale each embedding to unit L2 norm.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Keep embedding and path at the same list position: row i of the final
    # matrix belongs to tile_mapping[i].
    embeddings_list.append(image_features.cpu().numpy())
    tile_mapping.append(tile_path)

# np.concatenate([]) fails with an unhelpful message; say what actually went wrong.
if not embeddings_list:
    raise ValueError(f"No .png tiles found in {tiles_dir!r}; nothing to embed.")

# Stack the per-tile arrays along axis 0; d is the embedding dimensionality.
tile_embeddings = np.concatenate(embeddings_list, axis=0)
d = tile_embeddings.shape[1]

In [None]:
# Display the embedding dimensionality (number of columns of tile_embeddings).
d

In [None]:
import lpips
import torchvision.transforms as transforms
import os
from PIL import Image
import torch 
import numpy as np

# LPIPS model with the AlexNet backbone.
loss_fn      = lpips.LPIPS(net='alex')

# Resize to 256x256, convert to a tensor and shift/scale every channel with
# mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
transform    = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

image_dir    = "../../data/tiles/"
image_names  = []
feature_list = []

# sorted(): os.listdir() order is arbitrary; sorting makes the index row order
# (and the saved image_names array) reproducible.
for filename in sorted(os.listdir(image_dir)):

    # Case-insensitive match, consistent with how the CLIP cell lists tiles.
    if not filename.lower().endswith((".jpg", ".png")):
        continue

    with Image.open(os.path.join(image_dir, filename)) as img_file:
        img = img_file.convert("RGB")
    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        # The backbone yields several feature maps; each one is averaged over
        # dims 2 and 3 (presumably the spatial axes -- TODO confirm) and the
        # results are concatenated into one descriptor per image.
        features = loss_fn.net(img_tensor)
        feature_vector = torch.cat([f.mean(dim=[2, 3]).flatten() for f in features], dim=0)
        feature_vector = feature_vector.cpu().numpy().flatten()

    # Row i of the index corresponds to image_names[i].
    image_names.append(filename)
    feature_list.append(feature_vector)

# Without this guard an empty directory fails later with an IndexError on shape[1].
if not feature_list:
    raise ValueError(f"No .jpg/.png images found in {image_dir!r}; cannot build the LPIPS index.")

feature_matrix = np.array(feature_list, dtype=np.float32)

# NOTE: a later cell rebinds `index` to the CLIP index. Re-run this cell before
# the LPIPS query cell if that has already happened in the current kernel.
index = faiss.IndexFlatL2(feature_matrix.shape[1])
index.add(feature_matrix)

# Make sure the destination exists before writing the artefacts.
os.makedirs("../../data/output", exist_ok=True)
faiss.write_index(index, "../../data/output/lpips_index.faiss")
np.save("../../data/output/image_names.npy", image_names)
print("LPIPS features saved and FAISS index built!")

In [None]:

# Same preprocessing as the one used when the LPIPS feature index was built.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Number of FAISS candidates that are re-ranked with the full LPIPS metric.
num_candidates = 10

with Image.open('../../data/examples/maps/house_map.jpg') as ref_file:
    ref = ref_file.convert('RGB')
img1_tensor = transform(ref).unsqueeze(0)

with torch.no_grad():
    # Same descriptor as in the index-building cell: per-layer means over
    # dims 2 and 3, concatenated into a single vector.
    features       = loss_fn.net(img1_tensor)
    feature_vector = torch.cat([f.mean(dim=[2, 3]).flatten() for f in features], dim=0)
    feature_vector = feature_vector.cpu().numpy().flatten()

# `index` is also the name used for the CLIP index built further down; fail
# with a clear message instead of a FAISS assertion if the wrong one is bound.
if index.d != feature_vector.shape[0]:
    raise RuntimeError(
        f"`index` has dimension {index.d} but the LPIPS descriptor has "
        f"{feature_vector.shape[0]} entries; re-run the LPIPS index cell first."
    )
if index.ntotal == 0:
    raise RuntimeError("The LPIPS index is empty; there are no tiles to compare against.")

# Never ask for more neighbours than the index holds.
D, I = index.search(np.array([feature_vector], dtype=np.float32),
                    min(num_candidates, index.ntotal))

# LPIPS is a perceptual *distance*: lower means more similar, so the best
# candidate is the one with the smallest score.
best_lpips_score = float("inf")
best_tile_path   = None

for neighbour_id in I[0]:
    # FAISS pads missing results with -1; indexing a list with it would
    # silently pick the last image.
    if neighbour_id < 0:
        continue

    tile_path = image_names[neighbour_id]
    with Image.open(os.path.join(image_dir, tile_path)) as tile_file:
        tile = tile_file.convert("RGB")
    tile_tensor = transform(tile).unsqueeze(0)

    with torch.no_grad():
        lpips_score = loss_fn(tile_tensor, img1_tensor).item()

    print(f"Tile: {tile_path}, LPIPS: {lpips_score}")
    if lpips_score < best_lpips_score:
        best_lpips_score = lpips_score
        best_tile_path = tile_path

if best_tile_path is None:
    raise RuntimeError("FAISS returned no valid candidates for the reference image.")

tile      = Image.open(os.path.join(image_dir, best_tile_path)).convert("RGB")
tile.show()


In [None]:
# FAISS expects C-contiguous float32 rows; enforce that instead of relying on
# the dtype the model happened to produce.
clip_vectors = np.ascontiguousarray(tile_embeddings, dtype=np.float32)

# Flat L2 index over the CLIP tile embeddings.
# NOTE: this rebinds the global `index`, which the LPIPS cells above also use;
# re-run the LPIPS index cell before querying LPIPS again in the same kernel.
index = faiss.IndexFlatL2(clip_vectors.shape[1])
index.add(clip_vectors)

# Row i of the index must correspond to tile_mapping[i].
assert index.ntotal == len(tile_mapping), "index and tile_mapping are out of sync"

# Make sure the destination exists before writing the artefacts.
os.makedirs('../../data/output', exist_ok=True)
faiss.write_index(index, '../../data/output/my_tiles_file.index')
with open('../../data/output/tile_mapping.pkl', 'wb') as f:
    pickle.dump(tile_mapping, f)

print(f"Indexed {index.ntotal} tiles.")

In [None]:
def retrieve_closest_tile(query_image_path, k=1):
    """
    Given a query image, compute its CLIP embedding and retrieve the closest tile from the FAISS index.

    Relies on the notebook-level globals `processor`, `model`, `device`,
    `index` and `tile_mapping`.

    Parameters
    ----------
    query_image_path : str
        Path of the image to look up.
    k : int, default 1
        Number of neighbours requested from the index. Only the first hit is
        returned, whatever the value of `k`.

    Returns
    -------
    tuple
        `(closest_tile_path, distance)`: the `tile_mapping` entry of the first
        hit and the distance value reported by `index.search` for it.

    Raises
    ------
    IndexError
        If the index reports no valid neighbour (FAISS uses -1 for missing
        results) or the hit does not correspond to an entry of `tile_mapping`.
    """
    # The context manager releases the file handle deterministically.
    with Image.open(query_image_path) as query_file:
        query_image = query_file.convert("RGB")

    inputs       = processor(images=query_image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        query_features = model.get_image_features(pixel_values=pixel_values)
        # Unit-normalise, matching how the tile embeddings were prepared.
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)

    query_features     = query_features.cpu().numpy()
    distances, indices = index.search(query_features, k)

    # Without this check a -1 hit would silently select tile_mapping[-1].
    best_idx = int(indices[0][0])
    if best_idx < 0 or best_idx >= len(tile_mapping):
        raise IndexError(
            f"FAISS returned neighbour id {best_idx}, which is not a valid "
            f"position in tile_mapping (size {len(tile_mapping)})."
        )

    closest_tile_path = tile_mapping[best_idx]
    return closest_tile_path, distances[0][0]

In [None]:
def retrieve_closest_label(query_image_path, text_labels, k=1):
    """
    Rank candidate text labels by CLIP similarity to a query image.

    Both the image embedding and the text embeddings are scaled to unit norm,
    so each score is the dot product of two unit vectors. Relies on the
    notebook-level globals `processor`, `model` and `device`.

    Parameters
    ----------
    query_image_path : str
        Path of the image to classify.
    text_labels : list of str
        Candidate labels; must not be empty.
    k : int, default 1
        Number of top-scoring labels to return (at most `len(text_labels)`
        are returned).

    Returns
    -------
    tuple
        `(label, score)` when `k == 1`; otherwise `(labels, scores)` as two
        lists ordered from most to least similar.

    Raises
    ------
    ValueError
        If `k` is smaller than 1 or `text_labels` is empty.
    """
    # k == 0 would make the `[-k:]` slice below return every label.
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if len(text_labels) == 0:
        raise ValueError("text_labels must contain at least one label")

    # The context manager releases the file handle deterministically.
    with Image.open(query_image_path) as query_file:
        query_image = query_file.convert("RGB")

    inputs       = processor(images=query_image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        query_features = model.get_image_features(pixel_values=pixel_values)
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy()

    inputs_text = processor(text=text_labels, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs_text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    text_embeddings = text_features.cpu().numpy()

    # One similarity per label for the single query image.
    similarities = (query_features @ text_embeddings.T)[0]

    # argsort is ascending: take the last k entries and reverse them so the
    # best match comes first.
    top_k_indices = similarities.argsort()[-k:][::-1]
    top_k_labels  = [text_labels[i] for i in top_k_indices]
    top_k_scores  = [similarities[i] for i in top_k_indices]

    if k == 1:
        return top_k_labels[0], top_k_scores[0]
    return top_k_labels, top_k_scores

In [None]:
# NOTE(review): every other path in this notebook is rooted at "../../data";
# the original "../data/..." here looked like a typo -- confirm the location.
query_path  = "../../data/examples/structures/bridge.jpg"
text_labels = ["house", "rock", "forest", "dirt", "grass", "water"]

# Nearest tile in CLIP embedding space, with the distance reported by the index.
closest_tile, distance     = retrieve_closest_tile(query_path)
# With the default k=1 the helper returns a (label, score) pair, so unpack it
# instead of printing the raw tuple.
closest_label, label_score = retrieve_closest_label(query_path, text_labels)

# The distance belongs to the tile lookup and the similarity to the label
# lookup, so report them on separate lines.
print(f"Closest tile: {closest_tile} (distance: {distance:.4f})")
print(f"Closest label: {closest_label} (similarity: {label_score:.4f})")

tile_image = Image.open(closest_tile)
tile_image.show()