In [None]:
import sys

# Make the project's `models` / `datasets` modules importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
PROJECT_DIR = "/workdir/unsupervised_pretrain/"
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm.notebook import tqdm
import numpy as np

from models import SeriesResNet18
from datasets import SeriesEmbedDataset

# Setup: dataset and data loader #

In [None]:
# Test-set tile 16SEF, 20-step series, with the dataset's default bands.
# Alternative with a 5-band subset (presumably pairs with the `ri = 2, gi = 1, bi = 0` cell):
# ds = SeriesEmbedDataset(["/datasets/datasets/unsupervised-sentinel2/testset-16SEF/"], size=512, series_length=20, bands=[2,3,4,8,9])
test_dirs = ["/datasets/datasets/unsupervised-sentinel2/testset-16SEF/"]
ds = SeriesEmbedDataset(test_dirs, size=512, series_length=20)

In [None]:
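# Alternative dataset (Berlin tile 32UQD, 8-step series); uncomment to use it instead of the test set above.
# ds = SeriesEmbedDataset(["/datasets/datasets/berlin/32UQD/"], size=512, series_length=8)
# print(len(ds))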

In [None]:
# shuffle=False keeps dataset order, so row i of the concatenated embeddings below
# corresponds to ds[i] (the similarity cells index both with the same number).
dataloader = DataLoader(ds, batch_size=8, shuffle=False, num_workers=2)


# Load the test set, compute embeddings, save embeddings #

This only needs to be run once: the embeddings and autoencoder outputs are saved to disk and reloaded in the next section.

In [None]:
device = torch.device("cuda")

In [None]:
model = torch.load("/workdir/unsupervised_pretrain/12band-resnet34.pth", map_location=device).to(device)
model = model.eval()
autoencoder = torch.load("/workdir/unsupervised_pretrain/12band-resnet34-autoencoder.pth", map_location=device).to(device)
autoencoder = autoencoder.eval()

In [None]:
visual_embeddings = []
text_embeddings = []

# Embed the whole test set once. Both modalities are L2-normalised along dim=1 and moved to
# the CPU, so the accumulated per-batch lists do not hold device memory. The middle element
# of each batch is not used here.
with torch.inference_mode():
    for batch_imagery, _, batch_text in tqdm(dataloader):
        batch_visual = F.normalize(model(batch_imagery.to(device)), dim=1)
        batch_text = F.normalize(batch_text.to(device), dim=1)
        visual_embeddings.append(batch_visual.detach().cpu())
        text_embeddings.append(batch_text.detach().cpu())

In [None]:
text_embeddings = torch.cat(text_embeddings, dim=0)
text_embeddings.shape

In [None]:
torch.save(text_embeddings, "/workdir/unsupervised_pretrain/jupyter/text-embeddings.t")

In [None]:
visual_embeddings = torch.cat(visual_embeddings, dim=0)
visual_embeddings.shape

In [None]:
torch.save(visual_embeddings, "/workdir/unsupervised_pretrain/jupyter/visual-embeddings.t")

In [None]:
with torch.inference_mode():
    stuff = autoencoder(F.normalize(visual_embeddings.to(device), dim=1), text_embeddings.to(device))

In [None]:
torch.save(stuff, "/workdir/unsupervised_pretrain/jupyter/autoencoder-output.t")

# Load embeddings #

In [None]:
device = torch.device("cpu")

In [None]:
model = torch.load("/workdir/unsupervised_pretrain/12band-resnet34.pth", map_location=device).to(device)
model = model.eval()
autoencoder = torch.load("/workdir/unsupervised_pretrain/12band-resnet34-autoencoder.pth", map_location=device).to(device)
autoencoder = autoencoder.eval()

In [None]:
_text_embeddings = torch.load("/workdir/unsupervised_pretrain/jupyter/text-embeddings.t")
text_embeddings = _text_embeddings.detach().cpu().numpy()

In [None]:
_visual_embeddings = torch.load("/workdir/unsupervised_pretrain/jupyter/visual-embeddings.t")
visual_embeddings = _visual_embeddings.detach().cpu().numpy()

In [None]:
_stuff = torch.load("/workdir/unsupervised_pretrain/jupyter/autoencoder-output.t")
stuff = [thing.detach().cpu().numpy() for thing in _stuff]

## 2-D t-SNE projections ##

In [None]:
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

### Visual embeddings ###

Blue dots are t-SNE projections of the original visual embeddings; orange dots are projections of the autoencoder's reconstructions of them.

In [None]:
# Fit ONE t-SNE on originals and reconstructions together. Two separately fitted t-SNE maps
# live in unrelated coordinate systems, so overlaying them would show nothing meaningful.
tsne = TSNE(n_components=2, random_state=0)

n_orig = len(visual_embeddings)
joint_2d = tsne.fit_transform(np.concatenate([visual_embeddings, stuff[0]], axis=0))
data0, data1 = joint_2d[:n_orig], joint_2d[n_orig:]

# plot the result (default colour cycle: originals blue, reconstructions orange)
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(data0[:, 0], data0[:, 1], label="original")
ax.scatter(data1[:, 0], data1[:, 1], label="reconstructed")
ax.set(title="Visual embeddings (joint t-SNE)", xlabel="t-SNE feature 0", ylabel="t-SNE feature 1")
ax.legend()
plt.show()

In [None]:
np.max(np.abs(np.mean(visual_embeddings, axis=1))), np.max(np.abs(np.mean(stuff[0], axis=1)))

In [None]:
np.max(np.abs(np.mean(visual_embeddings - stuff[0], axis=1))), np.mean(np.abs(np.mean(visual_embeddings - stuff[0], axis=1)))

### Text embeddings ###

In [None]:
# Joint t-SNE of original vs. reconstructed text embeddings (one fit, so both sets share a map)
tsne = TSNE(n_components=2, random_state=0)

# Keep only samples whose text embedding has no NaN in ANY dimension: t-SNE rejects NaN
# input, and checking only column 0 would let partially-NaN rows through.
mask = ~np.isnan(text_embeddings).any(axis=1)
n_orig = int(mask.sum())
joint_2d = tsne.fit_transform(np.concatenate([text_embeddings[mask], stuff[1][mask]], axis=0))
data0, data1 = joint_2d[:n_orig], joint_2d[n_orig:]

# plot the result (default colour cycle: originals blue, reconstructions orange)
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(data0[:, 0], data0[:, 1], label="original")
ax.scatter(data1[:, 0], data1[:, 1], label="reconstructed")
ax.set(title="Text embeddings (joint t-SNE)", xlabel="t-SNE feature 0", ylabel="t-SNE feature 1")
ax.legend()
plt.show()

In [None]:
mask = ~np.isnan(text_embeddings[:, 0])

In [None]:
np.max(np.abs(np.mean(text_embeddings[mask], axis=1))), np.max(np.abs(np.mean(stuff[1][mask], axis=1)))

In [None]:
np.max(np.abs(np.mean(text_embeddings[mask] - stuff[1][mask], axis=1))), np.mean(np.abs(np.mean(text_embeddings[mask] - stuff[1][mask], axis=1)))

### Shared latent space ###

In [None]:
# Joint t-SNE of the two shared-latent outputs (stuff[2] vs stuff[3]) in one fit, so both
# sets are drawn in the same coordinate system.
tsne = TSNE(n_components=2, random_state=0)

# Same sample filter as for text: drop rows whose text embedding has any NaN
mask = ~np.isnan(text_embeddings).any(axis=1)
n_first = int(mask.sum())
joint_2d = tsne.fit_transform(np.concatenate([stuff[2][mask], stuff[3][mask]], axis=0))
data0, data1 = joint_2d[:n_first], joint_2d[n_first:]

# plot the result (default colour cycle: stuff[2] blue, stuff[3] orange)
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(data0[:, 0], data0[:, 1], label="stuff[2]")
ax.scatter(data1[:, 0], data1[:, 1], label="stuff[3]")
ax.set(title="Shared latent space (joint t-SNE)", xlabel="t-SNE feature 0", ylabel="t-SNE feature 1")
ax.legend()
plt.show()

In [None]:
mask = ~np.isnan(text_embeddings[:, 0])

In [None]:
np.max(np.abs(np.mean(stuff[2][mask], axis=1))), np.max(np.abs(np.mean(stuff[3][mask], axis=1)))

In [None]:
np.max(np.abs(np.mean(stuff[2][mask] - stuff[3][mask], axis=1))), np.mean(np.abs(np.mean(stuff[2][mask] - stuff[3][mask], axis=1)))

# Look for similarity #

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cdist
from scipy import spatial
import numpy as np
import matplotlib.pyplot as plt

## Utility functions ##

In [None]:
def top_k_query_cosine(query_vector, data, k, skip=2):
    """Return indices of the ``k`` rows of ``data`` most cosine-similar to ``query_vector``.

    Only rows whose index is a multiple of ``skip`` are candidates.

    Args:
        query_vector: 1-D query vector (anything supporting ``reshape(1, -1)``).
        data: 2-D array-like of candidate vectors, one per row.
        k: number of indices to return (fewer if there are fewer candidates).
        skip: stride of candidate row indices; must be >= 1.

    Returns:
        list[int]: candidate row indices ordered from most to least similar.

    Raises:
        ValueError: if ``skip`` < 1.
    """
    if skip < 1:
        raise ValueError("skip must be >= 1")

    # cosine similarity of every row against the query -> one value per row
    similarities = cosine_similarity(data, query_vector.reshape(1, -1)).flatten()

    # Restrict to every `skip`-th row BEFORE ranking, so exactly min(k, #candidates) indices
    # come back. Ranking the top k*skip rows first and filtering afterwards could return
    # anywhere from 0 to k*skip indices, and argpartition raised once k*skip >= len(data).
    candidates = np.arange(0, similarities.shape[0], skip)
    k = max(0, min(int(k), candidates.shape[0]))

    # stable descending sort: best match first, ties keep index order
    order = np.argsort(-similarities[candidates], kind="stable")[:k]
    return [int(idx) for idx in candidates[order]]

In [None]:
def top_k_query_l1(query_vector, data, k):
    """Return indices of the ``k`` rows of ``data`` closest to ``query_vector`` in L1 distance.

    Args:
        query_vector: 1-D query vector (anything supporting ``reshape(1, -1)``).
        data: 2-D array-like of candidate vectors, one per row.
        k: number of indices to return; clamped to ``[0, len(data)]``.

    Returns:
        np.ndarray of row indices, nearest first.
    """
    # L1 ('cityblock') distance of every row to the query -> one value per row
    l1_distances = cdist(data, query_vector.reshape(1, -1), 'cityblock').flatten()

    # argpartition(kth=k) raises when k >= len(data); clamp k and use a full stable sort,
    # which also returns neighbours nearest-first instead of in arbitrary order.
    k = max(0, min(int(k), l1_distances.shape[0]))
    return np.argsort(l1_distances, kind="stable")[:k]

In [None]:
def top_k_query_l2(query_vector, data, k):
    """Return indices of the ``k`` rows of ``data`` closest to ``query_vector`` in L2 distance.

    Args:
        query_vector: 1-D query vector (anything supporting ``reshape(1, -1)``).
        data: 2-D array-like of candidate vectors, one per row.
        k: number of indices to return; clamped to ``[0, len(data)]``.

    Returns:
        np.ndarray of row indices, nearest first.
    """
    # Euclidean (L2) distance of every row to the query -> one value per row
    l2_distances = cdist(data, query_vector.reshape(1, -1), 'euclidean').flatten()

    # argpartition(kth=k) raises when k >= len(data); clamp k and use a full stable sort,
    # which also returns neighbours nearest-first instead of in arbitrary order.
    k = max(0, min(int(k), l2_distances.shape[0]))
    return np.argsort(l2_distances, kind="stable")[:k]

In [None]:
def display_image(images, image_number, ri=3, gi=2, bi=1, vmax=2500):
    """Show one frame of an image series as an 8-bit RGB picture.

    Args:
        images: array indexed as ``images[frame, band, row, col]``.
        image_number: 0-based index of the frame to show.
        ri, gi, bi: 0-based band indices used for the red, green and blue channels.
        vmax: raw value mapped to full brightness; values are clipped to ``[0, vmax]``
            (default 2500, the scale this notebook has always used).

    Raises:
        ValueError: if ``image_number`` is outside ``[0, images.shape[0])`` or ``vmax`` <= 0.
    """
    if image_number < 0 or image_number >= images.shape[0]:
        raise ValueError('image_number must be between 0 and the number of images')
    if vmax <= 0:
        raise ValueError("vmax must be positive")

    # Pick the three bands (plain 0-based indices, no offset applied) and stack them
    # channel-last, the layout imshow expects for RGB.
    rgb = np.stack(
        [images[image_number, ri, :, :], images[image_number, gi, :, :], images[image_number, bi, :, :]],
        axis=-1,
    )

    # Clip to [0, vmax], rescale linearly to [0, 255] and convert to 8-bit for display
    rgb = np.clip(rgb, 0, vmax)
    rgb = ((rgb / vmax) * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb)
    ax.axis('off')  # hide the axes
    plt.show()

In [None]:
def display_all_images(images, ri=3, gi=2, bi=1, vmax=2500):
    """Show every frame of an image series as RGB thumbnails in a square grid.

    Args:
        images: array indexed as ``images[frame, band, row, col]``.
        ri, gi, bi: 0-based band indices used for the red, green and blue channels.
        vmax: raw value mapped to full brightness; values are clipped to ``[0, vmax]``
            (default 2500, the scale this notebook has always used).

    Raises:
        ValueError: if ``vmax`` <= 0.
    """
    if vmax <= 0:
        raise ValueError("vmax must be positive")

    n_images = images.shape[0]
    # Smallest square grid that holds every frame
    grid_size = int(np.ceil(np.sqrt(n_images)))
    if grid_size == 0:
        return  # empty series: nothing to draw (plt.subplots(0, 0) would raise)

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid (a single frame); otherwise
    # plt.subplots returns a bare Axes and axes[0, 0] fails.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):  # row-major, same order as [i // grid, i % grid]
        ax.axis('off')  # hide axes on every cell, including unused trailing ones
        if i >= n_images:
            continue

        # 0-based band indices (no offset applied), stacked channel-last for imshow
        rgb = np.stack([images[i, ri, :, :], images[i, gi, :, :], images[i, bi, :, :]], axis=-1)

        # Clip to [0, vmax], rescale to [0, 255] and convert to 8-bit for display
        rgb = np.clip(rgb, 0, vmax)
        rgb = ((rgb / vmax) * 255).astype(np.uint8)
        ax.imshow(rgb)

    plt.show()

## Visual-visual queries ##

In [None]:
# Red / green / blue band indices, presumably for the 5-band subset (bands=[2,3,4,8,9])
# from the commented-out dataset line. The next cell overrides these.
ri, gi, bi = 2, 1, 0

In [None]:
# Red / green / blue band indices for the dataset's default bands; the same values as the
# defaults of display_image / display_all_images.
ri, gi, bi = 3, 2, 1

In [None]:
center = torch.mean(_visual_embeddings, dim=0, keepdims=True)
centered_visual_embeddings = F.normalize(_visual_embeddings - center, dim=1).detach().cpu()

In [None]:
center = center.detach().cpu().numpy()
centered_visual_embeddings = centered_visual_embeddings.numpy()

### Water ###

In [None]:
query_vector = centered_visual_embeddings[440*2]
print(top_k_query_cosine(query_vector, centered_visual_embeddings, 5))

In [None]:
images_this = ds[440*2][0]
images_this = images_this.detach().cpu().numpy()
display_all_images(images_this, ri, gi, bi)

In [None]:
images_neighbor = ds[872][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
images_neighbor = ds[874][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
images_neighbor = ds[878][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
images_neighbor = ds[828][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

### Farmland(?) ###

In [None]:
query_vector = centered_visual_embeddings[330*2]
print(top_k_query_cosine(query_vector, centered_visual_embeddings, 5))

In [None]:
images_this = ds[330*2][0]
images_this = images_this.detach().cpu().numpy()
display_all_images(images_this, ri, gi, bi)

In [None]:
images_neighbor = ds[868][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
images_neighbor = ds[214][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

### Buildings ###

In [None]:
query_vector = centered_visual_embeddings[(19 + 3*21)*2]
print(top_k_query_cosine(query_vector, centered_visual_embeddings, 5))

In [None]:
images_this = ds[(19 + 3*21)*2][0]
images_this = images_this.detach().cpu().numpy()
display_all_images(images_this, ri, gi, bi)

In [None]:
images_neighbor = ds[248][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

## Text queries ##

In [None]:
# Center the autoencoder's first output (_stuff[0]) on its mean and unit-normalise each row.
# `center` is kept so text queries decoded later can be centered the same way.
center = _stuff[0].mean(dim=0, keepdim=True)
centered_stuff = F.normalize(_stuff[0] - center, dim=1).detach().cpu()
center = center.detach().cpu()

### Text similarity ###

In [None]:
from InstructorEmbedding import INSTRUCTOR

# Instructor-XL text encoder used to embed free-text queries, placed on the current device
embed_model = INSTRUCTOR("hkunlp/instructor-xl")
embed_model = embed_model.to(device)
embed_model.max_seq_length = 4096

In [None]:
def text_visual_query(query_text, instruction, center, embeddings, k: int = 5):
    """Rank ``embeddings`` by cosine similarity to a free-text query.

    The ``[instruction, query_text]`` pair is embedded with ``embed_model`` and passed through
    ``autoencoder.autoencoder_2``. The second element of its output (the latent) is
    unit-normalised and decoded with ``autoencoder.autoencoder_1.decoder``. The decoded vector
    is then centered with ``center`` and unit-normalised before the search.

    Args:
        query_text: free-text query, e.g. ``"Land use land cover: farmland."``.
        instruction: Instructor instruction prefix paired with the text.
        center: torch tensor subtracted from the decoded query (presumably the (1, D) mean
            used to build ``embeddings`` -- confirm).
        embeddings: 2-D candidate vectors, one per row.
        k: number of results to return.

    Returns:
        The indices returned by ``top_k_query_cosine`` for the decoded query.
    """
    query = embed_model.encode([[instruction, query_text]])
    query = torch.from_numpy(query).to(device)
    with torch.inference_mode():
        _, z = autoencoder.autoencoder_2(query)
        z = z / z.norm(dim=1, keepdim=True)
        query = autoencoder.autoencoder_1.decoder(z)
    query = F.normalize(query - center.to(query.device), dim=1)
    # Hand a CPU NumPy vector to the sklearn-based search. The original discarded the result
    # of this conversion, so a CUDA tensor would have reached cosine_similarity and failed.
    query = query.detach().cpu().numpy()
    return top_k_query_cosine(query, embeddings, k)

In [None]:
instruction = "Represent the geospatial data for retrieval; Input: "

In [None]:
text_visual_query("Land use land cover: farmland.", instruction, center, centered_stuff, 5)

In [None]:
text_visual_query("Buildings: ten.", instruction, center, centered_stuff, 5)

In [None]:
images_neighbor = ds[242][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
display_image(images_neighbor, 2)

### "Zero shot" classification ###

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_percentage_bars(values, labels, title='Bar Graph with Percentage Values'):
    """Draw a bar chart of fractions that sum to 1, each bar annotated with its percentage.

    Args:
        values: 1-D sequence or array of fractions; must sum to 1 (per ``np.isclose``).
        labels: one x-axis label per value.
        title: plot title (defaults to the original generic title).

    Raises:
        AssertionError: if the values do not add up to 1.
    """
    # np.asarray so plain lists work too: `list * 100` would repeat the list, not scale it
    values = np.asarray(values, dtype=float)
    assert np.isclose(np.sum(values), 1), "Values do not add up to 1"

    percentages = values * 100

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(labels, percentages, color='skyblue')

    # Value label centred on top of each bar
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2, height, f'{height:.2f}%', ha='center', va='bottom')

    ax.set_title(title)
    ax.set_ylabel('% Percentage')
    plt.xticks(rotation=45, ha="right")  # rotate long class names for readability

    plt.tight_layout()
    plt.show()

In [None]:
# Candidate land-use / land-cover classes, each paired with the shared Instructor prefix
lulc_labels = [
    "Land use land cover: agricultural.",
    "Land use land cover: forest.",
    "Land use land cover: water.",
]
lulc_strings = [[instruction, label] for label in lulc_labels]

In [None]:
with torch.inference_mode():
    # Embed every [instruction, text] pair, encode it with autoencoder_2, unit-normalise the
    # latent and decode it with autoencoder_1's decoder (same path as text_visual_query)
    lulc_text = torch.from_numpy(embed_model.encode(lulc_strings)).to(device)
    _, latent = autoencoder.autoencoder_2(lulc_text)
    latent = latent / latent.norm(dim=1, keepdim=True)
    decoded = autoencoder.autoencoder_1.decoder(latent)
    # Center with the notebook's current `center`, unit-normalise, and move to the CPU
    lulcs = F.normalize(decoded - center.to(decoded.device), dim=1).detach().cpu()

In [None]:
indxs = [80, 162, 224, 880]
query = centered_stuff[indxs]
query = query - center.to(query.device)
query = F.normalize(query, dim=1)

In [None]:
results = F.softmax((query @ lulcs.t())/1e-1, dim=1)
results

In [None]:
i = 3

In [None]:
images_neighbor = ds[indxs[i]][0]
images_neighbor = images_neighbor.detach().cpu().numpy()
display_all_images(images_neighbor, ri, gi, bi)

In [None]:
plot_percentage_bars(results[i].numpy(), [s[1] for s in lulc_strings])