In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv
from PIL import Image
import numpy as np
import rootutils
import torch
import torchvision
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import umap
from matplotlib.lines import Line2D

from lightly.data import LightlyDataset

# Locate the project root and add it to the Python path. rootutils looks for one of the
# indicator entries (.git / pyproject.toml) starting from the notebook's working directory.
rootutils.setup_root(
    os.path.abspath(''), indicator=['.git', 'pyproject.toml'], pythonpath=True
)

# Kept below setup_root() on purpose: the `src` package is presumably only importable
# once the project root has been put on the Python path.
from src.models.components.nn_utils import weight_load

# Populate os.environ from a .env file (the dataset path is read from it in the next cell).
load_dotenv()

In [None]:
# Dataset location comes from the environment (populated from .env by load_dotenv()).
# Fail early with an actionable message instead of Path(None)'s opaque TypeError.
data_path_raw = os.environ.get('lear_good_data_path')
if not data_path_raw:
    raise RuntimeError(
        "Environment variable 'lear_good_data_path' is not set; add it to your .env "
        "file or export it before running this notebook."
    )
data_path = Path(data_path_raw)
if not data_path.is_dir():
    raise FileNotFoundError(f"Image directory does not exist: {data_path}")

# Evaluation-only pipeline: convert PIL images to float tensors, no augmentation.
test_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

data_test = LightlyDataset(
    input_dir=data_path,
    transform=test_transform,
)

# shuffle=False keeps the embedding rows in the same, reproducible order as the dataset.
dataloader_test = DataLoader(
    dataset=data_test,
    batch_size=10,
    num_workers=3,
    persistent_workers=True,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# ResNet-18 with its final fully-connected classifier removed (children()[:-1]),
# so the network outputs pooled backbone features instead of class logits.
resnet = torchvision.models.resnet18()
model = torch.nn.Sequential(*list(resnet.children())[:-1]).to(device)

weights = weight_load(
    ckpt_path='../trained_models/contrastive_model.ckpt',
    weights_only=True,
    remove_prefix='backbone.'
)

# strict=False tolerates checkpoint entries that are not part of this backbone
# (presumably e.g. projection-head weights), but it would equally accept a checkpoint
# that matches *none* of the backbone keys and silently leave the model randomly
# initialised. Fail loudly if any backbone tensor was not loaded. BatchNorm's
# `num_batches_tracked` counters are ignored: they do not affect eval-mode outputs.
load_result = model.load_state_dict(weights, strict=False)
missing_keys = [k for k in load_result.missing_keys if not k.endswith('num_batches_tracked')]
if missing_keys:
    raise RuntimeError(
        f"{len(missing_keys)} backbone tensors were not found in the checkpoint "
        f"(first few: {missing_keys[:5]}); check the checkpoint path and remove_prefix."
    )

# Inference mode for BatchNorm/Dropout; ';' suppresses the long module repr in the notebook.
model.eval();

In [None]:
def generate_embeddings(model, dataloader, device=None):
    """Compute L2-normalised embeddings for every image in ``dataloader``.

    Args:
        model: torch module mapping an image batch to features; its output is
            flattened to ``(batch, -1)``.
        dataloader: iterable yielding ``(images, targets, filenames)`` triples;
            the targets are ignored.
        device: device to run the model on. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(embeddings, filenames)``: a ``(num_images, feature_dim)`` numpy array
        whose rows have unit L2 norm (all-zero rows stay zero), and the list of
        filenames in the same order as the rows.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batch_embeddings = []
    filenames = []
    with torch.no_grad():
        for img, _, fnames in dataloader:
            emb = model(img.to(device)).flatten(start_dim=1)
            # Move each batch off the accelerator right away so device memory is
            # bounded by one batch instead of growing with the whole dataset.
            batch_embeddings.append(emb.cpu())
            filenames.extend(fnames)

    if not batch_embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")

    embeddings = torch.cat(batch_embeddings, dim=0)
    # Row-wise L2 normalisation (same result as sklearn.preprocessing.normalize).
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.numpy(), filenames

In [None]:
embeddings, filenames = generate_embeddings(model, dataloader_test)
# Save the row order next to the matrix: a reloaded embeddings array cannot be
# mapped back to its images without the matching filenames.
np.save('good_embeddings.npy', embeddings)
np.save('good_filenames.npy', np.asarray(filenames))

In [None]:
def get_image_as_np_array(filename: str):
    """Load an image file and return its pixel data as a numpy array.

    The file is opened in a ``with`` block so its handle is released
    deterministically, even when this is called for many images in a loop.

    Args:
        filename: path of the image to load (a path-like object also works).

    Returns:
        numpy array of pixel values, e.g. ``(H, W)`` for single-channel images
        or ``(H, W, C)`` for multi-channel ones.
    """
    with Image.open(filename) as img:
        # np.asarray forces PIL to decode the pixels now, while the file is still open.
        return np.asarray(img)

In [None]:
def plot_knn_examples(embeddings, filenames, n_neighbors=5, num_examples=30,
                      data_dir=None, random_state=None):
    """Plot randomly chosen query images next to their nearest neighbours.

    One figure (a single row of panels) is drawn per query. The row holds the
    query's ``n_neighbors`` nearest points in embedding space, found with
    sklearn's ``NearestNeighbors``. Because the query is part of the fitted set,
    the first panel is normally the query itself at distance 0. Each panel is
    titled with its distance to the query.

    Args:
        embeddings: array of shape ``(num_images, feature_dim)``.
        filenames: image paths relative to ``data_dir``, one per embedding row.
        n_neighbors: panels per row; clamped to the number of embeddings.
        num_examples: number of random queries; clamped to the number of embeddings.
        data_dir: directory the filenames are relative to. Defaults to the
            notebook-level ``data_path``.
        random_state: optional seed for reproducible query selection. ``None``
            uses numpy's global random state, as before.
    """
    num_points = len(embeddings)
    if num_points == 0 or num_examples <= 0:
        return
    image_dir = data_path if data_dir is None else data_dir

    # kneighbors cannot return more neighbours than there are fitted points.
    n_neighbors = min(n_neighbors, num_points)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)

    # Sampling without replacement fails if more examples than points are requested.
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    sample_indices = rng.choice(num_points, size=min(num_examples, num_points), replace=False)

    for query_idx in sample_indices:
        fig, axes = plt.subplots(1, n_neighbors, squeeze=False,
                                 figsize=(2.5 * n_neighbors, 2.5))
        for ax, neighbor_idx, dist in zip(axes[0], indices[query_idx], distances[query_idx]):
            image = get_image_as_np_array(os.path.join(image_dir, filenames[neighbor_idx]))
            # 2-D arrays are single-channel images; without a gray colormap
            # imshow would false-colour them with its default colormap.
            ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
            ax.set_title(f"d={dist:.3f}")
            ax.axis("off")
        # Show and release each figure so many rows don't pile up as open figures.
        plt.show()
        plt.close(fig)

In [None]:
plot_knn_examples(embeddings=embeddings, filenames=filenames)

In [None]:
def plot_umap_features(embeddings, filenames, num_samples=1000, random_state=None):
    """Project embeddings to 2-D with UMAP and scatter-plot them coloured by class.

    An image's class is the first path component of its filename, e.g.
    ``'scratch/img_01.png'`` or ``'scratch\\img_01.png'`` -> ``'scratch'``. Both
    separators are accepted so the plot works with filenames produced on any OS.
    A filename without a directory component is used whole as its class. The
    legend is placed outside the axes, to the right.

    Args:
        embeddings: array of shape ``(num_images, feature_dim)``.
        filenames: relative image paths, one per embedding row.
        num_samples: maximum number of points to project (sampled without replacement).
        random_state: optional seed for both the subsampling and UMAP. ``None``
            keeps the previous, non-reproducible behaviour.
    """
    num_points = len(embeddings)
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    sampled_indices = rng.choice(num_points, size=min(num_samples, num_points), replace=False)
    sampled_embeddings = embeddings[sampled_indices]
    sampled_filenames = [filenames[i] for i in sampled_indices]

    # Normalise Windows separators first so POSIX-style relative paths also yield
    # their directory name (splitting on '\\' alone made every file its own class).
    classes = [fname.replace('\\', '/').split('/')[0] for fname in sampled_filenames]
    # sorted() gives a stable class -> colour mapping; set order of str varies between runs.
    unique_classes = sorted(set(classes))
    cmap = plt.get_cmap('tab10' if len(unique_classes) <= 10 else 'tab20')
    # Give every point an explicit RGBA colour. Passing integer labels with cmap= lets
    # scatter rescale them across the whole colormap, so points would not match the
    # legend swatches. Colours repeat if there are more classes than colormap entries.
    class_to_color = {cls: cmap(i % cmap.N) for i, cls in enumerate(unique_classes)}
    point_colors = [class_to_color[cls] for cls in classes]

    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='euclidean',
                        random_state=random_state)
    umap_embeddings = reducer.fit_transform(sampled_embeddings)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=point_colors, s=5)

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markersize=5,
               markerfacecolor=class_to_color[cls])
        for cls in unique_classes
    ]
    ax.legend(handles=legend_elements, labels=unique_classes, title="Classes",
              bbox_to_anchor=(1.05, 1), loc='upper left')

    ax.set_title('UMAP visualization of image embeddings by class')
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    plt.show()

In [None]:
plot_umap_features(embeddings=embeddings, filenames=filenames, num_samples=len(filenames))