In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv
import rootutils
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch

import albumentations as A
from albumentations.pytorch import ToTensorV2

# adding root to python path
# The search starts from the current working directory (os.path.abspath('')),
# so the `src.*` imports that follow presumably only resolve when the kernel is
# started somewhere inside the repository — NOTE(review): confirm against the
# rootutils documentation.
rootutils.setup_root(
    os.path.abspath(''), indicator=['.git', 'pyproject.toml'], pythonpath=True
)

from src.models.components.nn_utils import weight_load
from src.data.components.utils import list_files
from src.models.components.base_model import BaseModel
from src.data.components.preprocessing.preproc_strategy_tile import sliding_window_with_coordinates
from src.models.components.mae_model import mae_vit_base_patch16_dec512d8b

# Load settings from a .env file. The data-loading cell later reads
# `lear_bad_data_path` from os.environ — presumably that variable is defined
# in the .env file; verify before running on a new machine.
load_dotenv()

In [None]:
def segment_image(image: np.ndarray, model: torch.nn.Module, device: torch.device, transform) -> np.ndarray:
    """Run the segmentation model on one image and return a binary mask.

    Args:
        image: Input image array; its first two dimensions define the size of
            the returned mask.
        model: Network called on a single-item batch. Its output is treated as
            logits (a sigmoid is applied here).
        device: Device the input tensor is moved to before the forward pass.
        transform: Callable invoked as ``transform(image=image)`` that returns a
            mapping whose ``'image'`` entry is a tensor without a batch axis.

    Returns:
        ``uint8`` array with the height and width of ``image``, holding 255
        where the sigmoid of the first output channel exceeds 0.5 and 0
        elsewhere.
    """
    target_size = image.shape[:2]
    batch = transform(image=image)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Bring the prediction back to the source resolution before thresholding.
    logits = torch.nn.functional.interpolate(
        logits, size=target_size, mode="bilinear", align_corners=False
    )
    probabilities = torch.sigmoid(logits[0])
    binary = (probabilities > 0.5).float()

    # Keep the first channel of the single batch item and scale it to {0, 255}.
    binary_np = binary.detach().cpu().numpy()
    return (binary_np[0] * 255).astype('uint8')

In [None]:
def crop_and_align_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Keep only the masked part of ``image`` and translate it to the image center.

    Despite the name, the result is not cropped: it has the same shape as
    ``image``. Pixels outside the mask are black, and the masked region is
    shifted so that the center of its bounding box lands on the image center.

    Args:
        image: Source image; its first two dimensions are taken as (H, W).
        mask: Single-channel mask with the same height/width as ``image``.
            Values above 127 are treated as foreground. Must be a dtype that
            ``cv2.threshold`` and the ``mask=`` argument of ``cv2.bitwise_and``
            accept (the caller in this notebook passes ``uint8`` 0/255).

    Returns:
        Array shaped like ``image`` containing the centered masked region on a
        black background. An all-background mask yields an all-black image.
    """
    # Ensure the mask is binary
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    H, W = image.shape[:2]
    output = np.zeros_like(image)

    # Compute the bounding box of the mask
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # Nothing was segmented, so there is no region to move.
        return output

    # Center of the bounding box and of the output image, as (x, y)
    bbox_cx, bbox_cy = x + w // 2, y + h // 2
    new_cx, new_cy = W // 2, H // 2

    # Pure translation that moves the bounding-box center onto the image center
    dx, dy = int(new_cx - bbox_cx), int(new_cy - bbox_cy)
    M = np.float32([[1, 0, dx], [0, 1, dy]])

    # Extract the masked region, then move it and the mask by the same offset
    masked_region = cv2.bitwise_and(image, image, mask=mask)
    moved_masked_region = cv2.warpAffine(masked_region, M, (W, H))
    moved_mask = cv2.warpAffine(mask, M, (W, H))

    # Copy only the pixels covered by the moved mask into the black canvas
    foreground = moved_mask > 0
    output[foreground] = moved_masked_region[foreground]

    return output

In [None]:
def is_background(image: np.ndarray, background_perc: float) -> bool:
    """
    Determines whether the share of black pixels in an image exceeds a threshold.

    A pixel counts as black when all of its channels are exactly 0. For a
    2-D (single-channel) image, a pixel is black when its value is 0.

    Args:
        image (np.ndarray): Image of shape (H, W) or (H, W, C).
        background_perc (float): Threshold expressed as a fraction of the
            pixel count (e.g. 0.8 for 80%), not as a 0-100 percentage.

    Returns:
        bool: True if the fraction of black pixels is strictly greater than
        the threshold, False otherwise. An image with zero area returns False.
    """
    total_pixels = image.shape[0] * image.shape[1]
    if total_pixels == 0:
        # The ratio is undefined without pixels; report "not background",
        # which is what the NaN comparison used to yield, minus the warning.
        return False

    if image.ndim == 2:
        # No channel axis: reducing over axis=-1 would collapse the width.
        black_pixels = image == 0
    else:
        black_pixels = np.all(image == 0, axis=-1)

    black_pixel_percentage = np.sum(black_pixels) / total_pixels
    return bool(black_pixel_percentage > background_perc)

def extract_tiles(image: np.array, tile_size: tuple[int, int] = (224, 224), ovelap: int = 0, background_th: float = 0.8) -> list[np.array]:
    """Cut ``image`` into sliding-window tiles and drop the mostly-black ones.

    Args:
        image: Image handed to ``sliding_window_with_coordinates``.
        tile_size: Forwarded to the helper as ``tile_size``.
        ovelap: Forwarded to the helper as ``overlap``.
        background_th: Threshold passed to ``is_background``; tiles for which
            it returns True are skipped.

    Returns:
        The remaining tiles, in the order the helper produced them.
    """
    windows = sliding_window_with_coordinates(image, tile_size=tile_size, overlap=ovelap)
    # The helper also yields window coordinates; they are not needed here.
    return [
        window
        for window, _coords in windows
        if not is_background(window, background_th)
    ]

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Segmentation network used to isolate the object before the image is tiled.
# NOTE(review): `encoder_weigths` looks like a misspelling of `encoder_weights`;
# confirm it matches the keyword BaseModel actually accepts, otherwise the
# pretrained-encoder setting may not be applied as intended.
segmentation_model = BaseModel(
    model_name = 'segmentation_models_pytorch/Segformer',
    encoder_name = 'resnet50',
    encoder_weigths = 'imagenet',
    num_classes = 1
).to(device)
# The checkpoint path is relative, so it depends on the kernel's working directory.
segmentation_weights = weight_load(
    ckpt_path='../trained_models/segformer.ckpt',
    weights_only=True,
)
segmentation_model.load_state_dict(segmentation_weights)
segmentation_model.eval()

# Preprocessing handed to segment_image: resize to height 768 / width 640,
# convert to float using max_value=255, then convert to a tensor.
segmentation_transform = A.Compose([
    A.Resize(
        height = 768,
        width = 640
        ),
    A.ToFloat(max_value=255),
    ToTensorV2(),
])

In [None]:
# Masked-autoencoder model whose reconstructions are visualised by run_one_image.
model = mae_vit_base_patch16_dec512d8b()

# The checkpoint path is relative, so it depends on the kernel's working directory.
mae_weights = weight_load(
    ckpt_path='../trained_models/mae.ckpt',
    weights_only=True,
)
# NOTE(review): strict=False tolerates key mismatches between checkpoint and
# model, and the result that lists them is discarded here — confirm the
# checkpoint keys actually line up, otherwise parts of the model may keep
# their initial weights without any visible error.
model.load_state_dict(mae_weights, strict=False)
model.eval()
model.to(device)

In [None]:
def show_image(image, title=''):
    """Draw an ``[H, W, 3]`` tensor on the current matplotlib axes.

    The values are multiplied by 255, clipped to [0, 255] and cast to int
    before display, so the input is presumably expected in the [0, 1] range.

    Args:
        image: Tensor whose third dimension has size 3.
        title: Text placed above the axes.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    pixels = torch.clip(image * 255, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def run_one_image(img, model, mask_ratio: float = 0.5):
    """Run the MAE on one image and plot input, masking, reconstruction and error.

    Five panels are drawn side by side: the original image, the image with the
    masked patches blanked, the raw reconstruction, the reconstruction pasted
    into the masked patches only, and a JET heatmap of the per-pixel absolute
    difference between the original and the pasted image.

    Args:
        img: One image laid out as ``[H, W, C]``, in any form accepted by
            ``torch.tensor``. NOTE(review): the caller scales pixels to [0, 1]
            without further normalisation — confirm that this matches how the
            checkpoint was trained.
        model: MAE model with at least one parameter. It must be callable as
            ``model(x, mask_ratio=...)`` returning ``(loss, pred, mask)`` and
            expose ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: Value forwarded to the model as ``mask_ratio``.

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.
    """
    # Use the model's own device so the input and the weights always agree,
    # instead of relying on a notebook-level global.
    model_device = next(model.parameters()).device
    x = torch.tensor(img).to(model_device)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # Inference only: skip autograd bookkeeping so no graph is kept per tile.
    with torch.no_grad():
        # run MAE
        _, y, mask = model(x.float(), mask_ratio=mask_ratio)
        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).cpu()

        # Expand the per-patch mask to per-pixel form for visualisation
        patch_size = model.patch_embed.patch_size[0]
        mask = mask.unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

    x = torch.einsum('nchw->nhwc', x).cpu()

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Absolute difference, averaged over channels and rescaled to [0, 1].
    # The epsilon keeps the division defined when the difference is constant.
    diff = torch.abs(x - im_paste).numpy()[0]
    diff_gray = np.mean(diff, axis=-1)
    diff_gray = (diff_gray - np.min(diff_gray)) / (np.max(diff_gray) - np.min(diff_gray) + 1e-6)
    # OpenCV colour maps come back as BGR; convert for matplotlib.
    diff_colormap = cv2.applyColorMap((diff_gray * 255).astype(np.uint8), cv2.COLORMAP_JET)
    diff_colormap = cv2.cvtColor(diff_colormap, cv2.COLOR_BGR2RGB)

    # Size this figure explicitly rather than changing plt.rcParams, which
    # would silently resize every later figure in the notebook.
    fig = plt.figure(figsize=(24, 24))

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        show_image(panel, title)

    plt.subplot(1, 5, 5)
    plt.imshow(diff_colormap)
    plt.axis("off")
    plt.title("Difference Heatmap")

    plt.show()
    # This function is called once per tile; release the figure explicitly.
    plt.close(fig)

In [None]:
# Folder with the images to inspect, supplied through the environment.
if 'lear_bad_data_path' not in os.environ:
    # Fail with an actionable message instead of the TypeError from Path(None).
    raise RuntimeError(
        "Environment variable 'lear_bad_data_path' is not set; "
        "define it (e.g. in the .env file) before running this cell."
    )
data_path = Path(os.environ['lear_bad_data_path'])
image_paths = list_files(data_path, file_extensions=['.bmp', '.jpg', '.png'])
print(f'Found {len(image_paths)} images')

tile_size = (224, 224)
overlap = 20

for i, image_path in enumerate(image_paths):
    image = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        print(f'Skipping unreadable image: {image_path}')
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Segment the object, center it on a black canvas, then tile the result.
    mask = segment_image(image, segmentation_model, device, segmentation_transform)
    aligned_image = crop_and_align_image(image, mask)
    tiles = extract_tiles(aligned_image, tile_size=tile_size, ovelap=overlap)
    for tile in tiles:
        # Scale pixel values to [0, 1] before handing the tile to the MAE.
        tile = np.array(tile) / 255.
        assert tile.shape == (224, 224, 3)
        run_one_image(tile, model)
    # Only the first readable image is visualised.
    # NOTE(review): remove this break to process the whole folder.
    break