In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv
import rootutils
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torchvision
from sklearn.preprocessing import normalize

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Locate the project root (marked by .git or pyproject.toml), starting the search
# from the current working directory, and add it to the Python path so the
# project-local `src.*` imports below resolve from this notebook.
# NOTE(review): rootutils.setup_root also has `cwd`/`dotenv` options left at their
# defaults here; in rootutils `cwd` appears to default to True, which would change
# the working directory to the project root, so relative paths used later
# ('../trained_models/...', 'good_embeddings.npy', 'tiles') would resolve against
# the root rather than this notebook's folder — confirm for the installed version.
rootutils.setup_root(
    os.path.abspath(''), indicator=['.git', 'pyproject.toml'], pythonpath=True
)

from src.models.components.nn_utils import weight_load
from src.data.components.utils import list_files
from src.models.components.base_model import BaseModel
from src.data.components.preprocessing.preproc_strategy_tile import sliding_window_with_coordinates

# Load environment variables from a .env file (the data directory used further
# below is read from `lear_bad_data_path`).
load_dotenv()

In [None]:
def segment_image(
    image: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    transform,
    threshold: float = 0.5,
) -> np.ndarray:
    """Predict a binary foreground mask for ``image`` with a segmentation model.

    Args:
        image: Input image as accepted by ``transform``; the predicted mask is
            resized back to ``image.shape[:2]``.
        model: Segmentation model; channel 0 of its output is treated as logits.
        device: Device the input tensor is moved to (should match ``model``).
        transform: Albumentations-style callable invoked as ``transform(image=...)``
            that returns a dict whose ``'image'`` entry is a CHW tensor.
        threshold: Sigmoid probability above which a pixel is foreground.

    Returns:
        ``uint8`` array of shape ``image.shape[:2]``: 255 for foreground, 0 elsewhere.
    """
    image_tensor = transform(image=image)['image'].unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(image_tensor)
        # The model runs at the transform's resolution; resize the logits back to
        # the original size so the mask lines up pixel-for-pixel with `image`.
        logits = torch.nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        probs = torch.sigmoid(logits[0, 0])
    mask = (probs > threshold).to(torch.uint8) * 255
    return mask.cpu().numpy()

In [None]:
def crop_and_align_image(image: np.array, mask: np.array) -> np.array:
    """Center the masked object inside a black canvas the size of ``image``.

    The mask is binarized (values > 127 become foreground) and the bounding box
    of the foreground is computed. The masked pixels are then translated so that
    the box center lands on the image center. Everything outside the shifted mask
    is zero. Despite the name, nothing is cropped: the output keeps the input's
    height and width.

    Args:
        image: Source image.
        mask: Mask with the same height/width as ``image`` (expected to be a
            single-channel uint8 array, as produced by ``segment_image``).

    Returns:
        Array with the same shape and dtype as ``image``.
    """
    height, width = image.shape[:2]
    binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

    left, top, box_w, box_h = cv2.boundingRect(binary_mask)
    # Integer shift that moves the bounding-box center onto the canvas center.
    shift_x = int(width // 2 - (left + box_w // 2))
    shift_y = int(height // 2 - (top + box_h // 2))
    translation = np.float32([[1, 0, shift_x], [0, 1, shift_y]])

    foreground = cv2.bitwise_and(image, image, mask=binary_mask)
    shifted_foreground = cv2.warpAffine(foreground, translation, (width, height))
    shifted_mask = cv2.warpAffine(binary_mask, translation, (width, height))

    keep = shifted_mask > 0
    centered = np.zeros_like(image)
    centered[keep] = shifted_foreground[keep]
    return centered

In [None]:
def is_background(image: np.ndarray, background_perc: float) -> bool:
    """
    Determines whether the fraction of pure-black pixels in an image exceeds
    a threshold.

    A pixel counts as black when all of its channels are 0. Channel-last color
    images (H, W, C) and single-channel images (H, W) are both supported.

    Args:
        image (np.ndarray): Image of shape (H, W) or (H, W, C).
        background_perc (float): Threshold as a fraction in [0, 1] (e.g. 0.8 for
            80%), despite "perc" in the name.

    Returns:
        bool: True if the fraction of black pixels is strictly greater than
        ``background_perc``, False otherwise. An image with no pixels is treated
        as background (True).
    """
    image = np.asarray(image)
    if image.ndim == 2:
        # Single-channel: reducing over axis -1 here would collapse the width
        # axis instead of channels, so compare pixels directly.
        black_pixels = image == 0
    else:
        # Channel-last: a pixel is black only if every channel is zero.
        black_pixels = np.all(image == 0, axis=-1)
    if black_pixels.size == 0:
        # Avoid 0/0 -> NaN; an empty tile carries no foreground.
        return True
    return bool(black_pixels.mean() > background_perc)

In [None]:
def extract_tiles(
    image: np.ndarray,
    tile_size: tuple[int, int] = (224, 224),
    ovelap: int = 0,
    background_th: float = 0.8,
    *,
    overlap: "int | None" = None,
) -> list[np.ndarray]:
    """Cut ``image`` into tiles and drop those that are mostly black background.

    Args:
        image: Image to tile (in this notebook, the output of ``crop_and_align_image``).
        tile_size: Tile size forwarded to ``sliding_window_with_coordinates``.
        ovelap: Overlap forwarded to ``sliding_window_with_coordinates``. The
            misspelled name is kept for backward compatibility; prefer ``overlap``.
        background_th: Tiles whose black-pixel fraction is greater than this value
            are discarded (see ``is_background``).
        overlap: Correctly spelled keyword alias for ``ovelap``; when given, it
            takes precedence over ``ovelap``.

    Returns:
        The kept tiles, in the order yielded by the sliding window.
    """
    window_overlap = ovelap if overlap is None else overlap
    return [
        tile
        for tile, _coordinates in sliding_window_with_coordinates(
            image, tile_size=tile_size, overlap=window_overlap
        )
        if not is_background(tile, background_th)
    ]

In [None]:
def generate_embeddings(model: torch.nn.Module, image_tiles: list[np.ndarray], device: torch.device, transform) -> np.ndarray:
    """Embed each tile with ``model`` and L2-normalize the resulting rows.

    Args:
        model: Embedding network; its output for each single-tile batch is
            flattened to one row.
        image_tiles: Tiles to embed, in any format accepted by ``transform``.
        device: Device the transformed tiles are moved to (should match ``model``).
        transform: Callable mapping one tile to a CHW tensor (e.g. torchvision ``ToTensor``).

    Returns:
        Array of shape (len(image_tiles), D) whose rows are L2-normalized with
        ``sklearn.preprocessing.normalize``. If ``image_tiles`` is empty, an empty
        ``(0, 0)`` float32 array is returned instead of raising, because the
        embedding width is unknown without running the model.
    """
    if len(image_tiles) == 0:
        # torch.cat([]) would raise; this happens when every tile was background.
        return np.empty((0, 0), dtype=np.float32)

    rows = []
    with torch.no_grad():
        for tile in image_tiles:
            batch = transform(tile).unsqueeze(0).to(device)
            rows.append(model(batch).flatten(start_dim=1).cpu())
    embeddings = torch.cat(rows, dim=0).numpy()
    return normalize(embeddings)

In [None]:
def find_nearest_neighbors(
    query_emb: np.ndarray,
    reference_emb: np.ndarray,
    verbose: bool = True,
    plot: bool = True,
) -> np.ndarray:
    """Find, for every query vector, the index of its closest reference vector.

    Distances are Euclidean. They are computed one query row at a time, so memory
    stays O(M * D) instead of materializing an (N, M, D) difference tensor.

    Args:
        query_emb: Query vectors, shape (N, D).
        reference_emb: Reference vectors, shape (M, D).
        verbose: Print the nearest index and distance for every query.
        plot: Show a line plot of each query's nearest-neighbor distance.

    Returns:
        Integer array of shape (N,) holding, for each query, the row index into
        ``reference_emb`` of its nearest neighbor (first index on ties). Empty if
        ``query_emb`` has no rows.

    Raises:
        ValueError: If ``query_emb`` has rows but ``reference_emb`` has none.
    """
    query_emb = np.asarray(query_emb)
    reference_emb = np.asarray(reference_emb)
    if query_emb.shape[0] == 0:
        return np.empty(0, dtype=np.intp)
    if reference_emb.shape[0] == 0:
        raise ValueError("reference_emb is empty; cannot find nearest neighbors")

    nearest_list = []
    distance_list = []
    for query in query_emb:
        distances = np.linalg.norm(query - reference_emb, axis=1)
        best = int(np.argmin(distances))
        nearest_list.append(best)
        distance_list.append(distances[best])
    nearest_indices = np.asarray(nearest_list, dtype=np.intp)
    min_distances = np.asarray(distance_list)

    if verbose:
        for i, (idx, dist) in enumerate(zip(nearest_indices, min_distances)):
            print(f"Vector {i} nearest neighbor is at index {idx} with distance: {dist:.4f}")

    if plot:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(range(len(min_distances)), min_distances, marker='o', linestyle='-', color='b')
        ax.set_xlabel('Query Vector Index')
        ax.set_ylabel('Distance to Nearest Neighbor')
        ax.set_title('Distance of Each Query Vector to its Nearest Neighbor')
        ax.grid(True)
        plt.show()

    return nearest_indices

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Segmentation model used by segment_image to isolate the object before tiling.
# NOTE(review): `encoder_weigths` looks like a misspelling of `encoder_weights`;
# whether BaseModel accepts, forwards or ignores it is not visible here — confirm.
# Any pretrained encoder weights are presumably overwritten by the checkpoint
# loaded below anyway.
segmentation_model = BaseModel(
    model_name = 'segmentation_models_pytorch/Segformer',
    encoder_name = 'resnet50',
    encoder_weigths = 'imagenet',
    num_classes = 1
    ).to(device)
# Load the trained weights. load_state_dict is strict by default, so a key
# mismatch with the checkpoint raises instead of being silently ignored.
segmentation_weights = weight_load(
    ckpt_path='../trained_models/segformer.ckpt',
    weights_only=True,
)
segmentation_model.load_state_dict(segmentation_weights)
# Switch to eval mode (affects layers such as dropout and batch norm).
segmentation_model.eval()

# Preprocessing used by segment_image: resize to 768 (H) x 640 (W), scale pixel
# values by 1/255 to floats, then convert the HWC array to a CHW tensor.
segmentation_transform = A.Compose([
    A.Resize(
        height = 768,
        width = 640
        ),
    A.ToFloat(max_value=255),
    ToTensorV2(),
])

In [None]:
# ResNet-18 backbone with its last child (the final fully connected layer)
# removed, so the model outputs pooled feature maps instead of class logits.
resnet = torchvision.models.resnet18()
embedding_model = torch.nn.Sequential(*list(resnet.children())[:-1]).to(device)

# Contrastive-training checkpoint; `remove_prefix` presumably strips the
# 'backbone.' prefix so keys line up with `embedding_model`'s own keys.
embedding_weights = weight_load(
    ckpt_path='../trained_models/contrastive_model.ckpt',
    weights_only=True,
    remove_prefix='backbone.'
)
# strict=False tolerates extra checkpoint entries (presumably a projection head),
# but it also silently ignores *missing* keys. If the key layouts don't match, the
# backbone would keep its random initialization and every embedding would be
# meaningless, so fail loudly in that case.
load_result = embedding_model.load_state_dict(embedding_weights, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Embedding checkpoint is missing {len(load_result.missing_keys)} backbone "
        f"parameters (e.g. {load_result.missing_keys[:5]}); check the 'backbone.' "
        "prefix and that the checkpoint layout matches the Sequential backbone."
    )
if load_result.unexpected_keys:
    print(f"Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys")
embedding_model.eval()

# Per-tile preprocessing for generate_embeddings: HWC uint8 -> CHW float in [0, 1].
# NOTE(review): no mean/std normalization is applied; presumably this matches the
# contrastive training pipeline — confirm.
embedding_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)


In [None]:
# Reference embeddings (presumably computed from defect-free "good" tiles) that
# each query tile is matched against below. The nearest-neighbor search needs a
# non-empty (N, D) matrix, so validate that up front with a readable error.
good_embeddings = np.load("good_embeddings.npy")
if good_embeddings.ndim != 2 or good_embeddings.shape[0] == 0:
    raise ValueError(
        f"Expected a non-empty (N, D) embedding matrix, got shape {good_embeddings.shape}"
    )

In [None]:
# Directory of images to analyse, configured via the .env file.
data_path_value = os.environ.get('lear_bad_data_path')
if not data_path_value:
    raise RuntimeError("Environment variable 'lear_bad_data_path' is not set (expected in .env)")
data_path = Path(data_path_value)
image_paths = list_files(data_path, file_extensions=['.bmp', '.jpg', '.png'])
print(f'Found {len(image_paths)} images')

output_path = Path('tiles')
output_path.mkdir(exist_ok=True)
tile_size = (512, 512)
overlap = 20

for image_path in image_paths:
    image_bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
    if image_bgr is None:
        # cv2.imread returns None instead of raising for unreadable/corrupt files.
        print(f"Skipping unreadable image: {image_path}")
        continue
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    mask = segment_image(image, segmentation_model, device, segmentation_transform)
    aligned_image = crop_and_align_image(image, mask)
    tiles = extract_tiles(aligned_image, tile_size=tile_size, ovelap=overlap)
    # Save tiles under the image's stem so tiles from different images do not
    # overwrite each other in the shared output folder.
    for i, tile in enumerate(tiles):
        tile_file = output_path / f"{image_path.stem}_tile_{i}.png"
        cv2.imwrite(str(tile_file), cv2.cvtColor(tile, cv2.COLOR_RGB2BGR))
    print(image_path.name)
    if not tiles:
        # Every tile was background; there is nothing to embed or compare.
        print("No foreground tiles found; skipping nearest-neighbor search")
        continue
    embeddings = generate_embeddings(embedding_model, tiles, device, embedding_transform)
    nearest_indices = find_nearest_neighbors(embeddings, good_embeddings)