In [None]:
from __future__ import annotations

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import csv
from dataclasses import dataclass

import numpy as np
import cv2
import imageio
import matplotlib.pyplot as plt

from coin_ai.fg_segmentation import SegmentationDino

In [None]:
@dataclass
class Homography:
    """Two image paths and the 3x3 homography relating them.

    Images are not cached: every access to ``image1`` / ``image2`` reads the file again.
    """

    path1: str  # path of the first ("left") image
    path2: str  # path of the second ("right") image
    H: np.ndarray  # 3x3 matrix; match_patches uses it to map image1 coordinates into image2

    def _load(self, path: str) -> np.ndarray:
        # Shared loader for both image properties.
        return imageio.imread(path)

    @property
    def image1(self) -> np.ndarray:
        """The first image, freshly read from ``path1``."""
        return self._load(self.path1)

    @property
    def image2(self) -> np.ndarray:
        """The second image, freshly read from ``path2``."""
        return self._load(self.path2)

    def inverse(self) -> Homography:
        """Return the same pair with the images swapped and ``H`` inverted."""
        inverted = np.linalg.inv(self.H)
        return Homography(path1=self.path2, path2=self.path1, H=inverted)


def parse_homography_csv(source_path: str) -> list[Homography]:
    """Load image pairs and their homographies from a CSV file.

    Every data row is unpacked as ``img1, img2, h11, ..., h33[, extra columns...]``.
    Image names are resolved relative to the directory containing the CSV. The stored
    matrix has its first two rows and columns swapped ("undo the xy flip") and is then
    inverted before being wrapped in a :class:`Homography`.

    Args:
        source_path: path of the CSV file.

    Returns:
        One ``Homography`` per non-blank data row.

    Raises:
        ValueError: if the file does not exist, is empty, or a row has fewer than 9
            matrix entries.
        AssertionError: if the header row does not start with the expected columns.
    """
    HEADERS = ['img1', 'h11', 'h12', 'h13', 'h21', 'h22', 'h23', 'h31', 'h32', 'h33']

    if not os.path.exists(source_path):
        raise ValueError(f'File {source_path} does not exist')

    # Image names inside the CSV are relative to the CSV's own directory.
    source_dir = os.path.dirname(source_path)

    pairs = []
    # newline='' is how the csv module expects files to be opened.
    with open(source_path, 'r', newline='') as f:
        reader = csv.reader(f)
        headers = next(reader, None)
        if headers is None:
            raise ValueError(f'File {source_path} is empty')
        # NOTE(review): HEADERS has no 'img2' column although every data row is unpacked as
        # (img1, img2, h11..h33); confirm this matches the header of the CSVs actually used.
        assert headers[:len(HEADERS)] == HEADERS

        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them
            # instead of failing on the unpacking below.
            if not row:
                continue
            img1, img2, *floats = row
            if len(floats) < 9:
                raise ValueError(
                    f'Expected 9 homography entries per row in {source_path}, got {len(floats)}'
                )
            H = np.array(floats[:9], dtype=np.float32).reshape(3, 3)
            H = H[[1, 0, 2]][:, [1, 0, 2]]  # undo the xy flip
            H = np.linalg.inv(H)
            pairs.append(Homography(
                os.path.join(source_dir, img1),
                os.path.join(source_dir, img2),
                H,
            ))

    return pairs

In [None]:
import torch
import numpy as np
from torch import nn, Tensor
from einops import rearrange
from torchvision.transforms import v2 as transforms
from kornia.utils import image_to_tensor


class DenseDino(nn.Module):
    """DINOv2 ViT-S/14 (with registers) exposed as a dense grid of patch features.

    ``forward`` accepts a batch tensor or a numpy image (converted with kornia's
    ``image_to_tensor``), requires 3x518x518 inputs, converts them to grayscale
    replicated over 3 channels, and returns the normalised patch tokens reshaped to
    ``(b, 37, 37, c)``.
    """

    def __init__(self):
        super().__init__()
        # Loaded via torch.hub and put in eval mode; callers decide about no_grad.
        self.dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg").eval()

    @property
    def device(self) -> torch.device:
        """Device of the first parameter (i.e. where the backbone currently lives)."""
        return next(iter(self.parameters())).device

    def forward(self, x: Tensor | np.ndarray) -> Tensor:
        batch = image_to_tensor(x, keepdim=False) if isinstance(x, np.ndarray) else x

        _, channels, height, width = batch.shape
        assert (channels, height, width) == (3, 518, 518)

        # uint8 images are rescaled to [0, 1] floats; float inputs are used as given.
        if batch.dtype == torch.uint8:
            batch = transforms.ToDtype(torch.float32, scale=True)(batch)
        gray = transforms.Grayscale(num_output_channels=3)(batch)

        tokens = self.dino.forward_features(gray.to(self.device))["x_norm_patchtokens"]
        # 518 / 14 = 37 patches per side.
        return rearrange(tokens, "b (h w) c -> b h w c", h=37, w=37)

In [None]:
import torch.nn.functional as F
from kornia.feature import match_mnn, match_smnn

@dataclass
class MatchingPatches:
    """Patch features of an image pair plus the ground-truth correspondence mask."""

    left: Tensor  # (h, w, c) patch features of the first image
    right: Tensor  # (h, w, c) patch features of the second image
    is_correct_match: Tensor  # (h, w, h, w); [i, j, k, l] relates left[i, j] to right[k, l]

    def to(self, device: torch.device) -> MatchingPatches:
        """Return a copy with every tensor moved to ``device``."""
        return MatchingPatches(
            **{name: getattr(self, name).to(device) for name in ('left', 'right', 'is_correct_match')}
        )

    def flip(self) -> MatchingPatches:
        """Swap the two images, transposing the correspondence mask accordingly."""
        # axes (i, j, k, l) -> (k, l, i, j)
        swapped_mask = self.is_correct_match.permute(2, 3, 0, 1)
        return MatchingPatches(left=self.right, right=self.left, is_correct_match=swapped_mask)

    def __repr__(self) -> str:
        # Show shapes only; the tensors themselves are far too large to print.
        parts = ', '.join(
            f'{name}={tuple(getattr(self, name).shape)}'
            for name in ('left', 'right', 'is_correct_match')
        )
        return f'MatchingPatches({parts})'

In [None]:
def match_patches(dense_dino: DenseDino, segmentation: nn.Linear, image1: np.ndarray, image2: np.ndarray, H: np.ndarray, threshold: float = 1.4) -> MatchingPatches:
    """Label which DINO patches of ``image1`` and ``image2`` correspond under ``H``.

    Left patch ``(i, j)`` and right patch ``(k, l)`` are a correct match when ``H`` maps
    the centre of the left patch to within ``threshold`` (measured in patches) of the
    centre of the right patch, and ``segmentation`` scores both patches as foreground
    (score > 0).

    Args:
        dense_dino: module returning a (1, rows, cols, c) patch-feature grid per image.
        segmentation: head applied to each patch feature; a score > 0 means foreground.
        image1, image2: left and right images.
        H: 3x3 homography mapping image1 pixel coordinates to image2 pixel coordinates.
        threshold: maximum centre distance, in patch units, of a correct match.

    Returns:
        ``MatchingPatches`` with CPU features of shape (rows, cols, c) and a boolean mask
        of shape (rows1, cols1, rows2, cols2).
    """
    PATCH_SIZE = 14  # pixels per DINO patch

    # Express H in patch units: H_patch = S^-1 @ H @ S with S = diag(14, 14, 1).
    # This divides the translation column by 14 AND multiplies the projective row by 14;
    # omitting the latter is only harmless when H is affine.
    H_patch = torch.from_numpy(np.asarray(H, dtype=np.float32)).clone()
    H_patch[:2, 2] /= PATCH_SIZE
    H_patch[2, :2] *= PATCH_SIZE

    with torch.no_grad():
        left_feat = dense_dino(image1).cpu().squeeze(0)
        right_feat = dense_dino(image2).cpu().squeeze(0)

        left_fg_mask = segmentation(left_feat).cpu().squeeze(0).squeeze(-1) > 0
        right_fg_mask = segmentation(right_feat).cpu().squeeze(0).squeeze(-1) > 0

    def patch_centres(rows: int, cols: int) -> Tensor:
        # (rows, cols, 2) grid of (x, y) patch centres in patch units.
        xs, ys = torch.meshgrid(torch.arange(cols), torch.arange(rows), indexing='xy')
        return torch.stack((xs, ys), dim=-1).to(torch.float32) + 0.5

    left_grid = patch_centres(*left_feat.shape[:2])
    right_grid = patch_centres(*right_feat.shape[:2])

    left_homo = torch.cat([left_grid, torch.ones_like(left_grid[..., :1])], dim=-1)
    left_mapped = torch.einsum('ij,hwj->hwi', H_patch, left_homo)
    left_mapped = left_mapped[..., 0:2] / left_mapped[..., 2:3]

    # distances[i, j, k, l] = |H(left centre (i, j)) - right centre (k, l)|
    distances = torch.linalg.norm(
        left_mapped[:, :, None, None, :] - right_grid[None, None, :, :, :],
        dim=-1,
    )

    is_correct_match = (distances < threshold) & left_fg_mask[..., None, None] & right_fg_mask[None, None, ...]

    return MatchingPatches(
        left_feat,
        right_feat,
        is_correct_match,
    )

In [None]:
class KeypointExtractor(nn.Module):
    """Turn an image into foreground keypoints and unit-norm descriptors.

    Keypoints are the pixel centres (``14 * index + 7``) of the 37x37 patch grid,
    kept only where ``segmenter`` scores the patch above 0; descriptors are the
    ``descriptor`` projection of the same patches, L2-normalised.
    """

    def __init__(self, dino: DenseDino, segmenter: nn.Linear, descriptor: nn.Linear):
        super().__init__()
        self.dino = dino
        self.segmenter = segmenter
        self.descriptor = descriptor

    def forward(self, image: np.ndarray) -> tuple[Tensor, Tensor]:
        """Return ``(keypoints (n, 2) as (x, y) pixels, descriptors (n, d))``."""
        with torch.no_grad():
            features = self.dino(image).squeeze(0)
            device = features.device

            foreground = self.segmenter(features).squeeze(-1) > 0
            descriptors = F.normalize(self.descriptor(features), dim=-1)

            rows, cols = torch.meshgrid(
                torch.arange(37, device=device),
                torch.arange(37, device=device),
                indexing='ij',
            )
            # (x, y) = (column, row), scaled to the pixel centre of each 14 px patch.
            centres = torch.stack((cols, rows), dim=-1).to(torch.float32) * 14 + 7

            return centres[foreground], descriptors[foreground]

In [None]:
from matplotlib.figure import Figure, Axes

def get_matching_keypoints(kp1: Tensor, kp2: Tensor, idxs: Tensor) -> tuple[Tensor, Tensor]:
    """Gather the keypoints referenced by each match.

    Args:
        kp1: keypoints of the first image, indexed along dim 0.
        kp2: keypoints of the second image, indexed along dim 0.
        idxs: integer tensor of shape (n_matches, 2); column 0 indexes ``kp1`` and
            column 1 indexes ``kp2``.

    Returns:
        ``(matched kp1, matched kp2)``, each with ``n_matches`` rows in match order.
    """
    return kp1[idxs[:, 0]], kp2[idxs[:, 1]]

class ImageAligner:
    """Align ``image1`` onto ``image2`` by repeated matching, RANSAC homography fitting and warping.

    Args:
        extractor: callable returning ``(keypoints, descriptors)`` for an image.
        n_steps: number of match/fit/warp rounds run by :meth:`__call__`.
        with_visualization: when True, plot the matches of every step and each intermediate warp.
    """

    def __init__(self, extractor: KeypointExtractor, n_steps: int = 3, with_visualization: bool = True):
        self.extractor = extractor
        self.n_steps = n_steps
        self.with_visualization = with_visualization

    def find_matches(self, image1: np.ndarray, image2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Match descriptors of both images with ``match_smnn`` (threshold 0.90).

        Returns:
            ``(mkpts1, mkpts2)`` numpy arrays whose i-th rows form the i-th match.
        """
        kps1, descs1 = self.extractor(image1)
        kps2, descs2 = self.extractor(image2)

        _dists, idxs = match_smnn(descs1, descs2, 0.90)
        mkpts1, mkpts2 = get_matching_keypoints(kps1, kps2, idxs)

        return mkpts1.cpu().numpy(), mkpts2.cpu().numpy()

    def find_homography(self, kpts1: np.ndarray, kpts2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Fit a homography mapping ``kpts1`` onto ``kpts2`` with RANSAC.

        The 14 px reprojection threshold equals the keypoint grid spacing.

        Returns:
            ``(H, inliers)`` exactly as ``cv2.findHomography`` returns them; ``H`` is None
            when estimation fails.
        """
        Hm, inliers = cv2.findHomography(
            srcPoints=kpts1,
            dstPoints=kpts2,
            method=cv2.RANSAC,
            ransacReprojThreshold=14.,
            confidence=0.999,
            maxIters=100000,
        )

        return Hm, inliers

    def warp_perpective(self, image: np.ndarray, H: np.ndarray) -> np.ndarray:
        """Warp ``image`` with ``H`` (a map from image1 points to image2 points) into image2's frame.

        Without ``WARP_INVERSE_MAP``, ``cv2.warpPerspective`` treats its matrix as the forward
        source->destination transform and inverts it internally, so ``H`` is passed unchanged;
        passing ``inv(H)`` would warp in the opposite direction.
        """
        image_f32 = image.astype(np.float32) / 255
        warp_f32 = cv2.warpPerspective(
            image_f32,
            H,
            (image.shape[1], image.shape[0]),
            borderMode=cv2.BORDER_REPLICATE,
        )

        # Round instead of truncating so repeated warps in __call__ do not drift darker.
        return np.clip(np.rint(warp_f32 * 255), 0, 255).astype(np.uint8)

    # Correctly spelled alias; the original name is kept for existing callers.
    warp_perspective = warp_perpective

    def visualize_matches(
        self,
        image1: np.ndarray,
        image2: np.ndarray,
        mkpts1: np.ndarray,
        mkpts2: np.ndarray,
        inliers: np.ndarray,
    ) -> tuple[Figure, Axes]:
        """Draw outlier matches in red and inlier matches in green side by side."""
        def to_cv2_keypoints(kpts: np.ndarray) -> list:
            # Plain Python floats for the OpenCV bindings.
            return [cv2.KeyPoint(float(x), float(y), 14) for x, y in kpts]

        cv_kpts1 = to_cv2_keypoints(mkpts1)
        cv_kpts2 = to_cv2_keypoints(mkpts2)

        outlier_matches = [cv2.DMatch(i, i, 0) for i in range(len(mkpts1)) if not inliers[i]]
        inlier_matches = [cv2.DMatch(i, i, 0) for i in range(len(mkpts1)) if inliers[i]]

        # The images here are RGB (imageio) and shown with imshow, so colours are RGB triples.
        canvas = cv2.drawMatches(
            image1,
            cv_kpts1,
            image2,
            cv_kpts2,
            outlier_matches,
            outImg=None,
            matchColor=(255, 0, 0),
            flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
        )

        # DRAW_OVER_OUTIMG draws on top of `canvas`; DEFAULT would re-render both images and
        # wipe out the outlier lines drawn above.
        canvas = cv2.drawMatches(
            image1,
            cv_kpts1,
            image2,
            cv_kpts2,
            inlier_matches,
            outImg=canvas,
            matchColor=(0, 255, 0),
            flags=cv2.DrawMatchesFlags_DRAW_OVER_OUTIMG,
        )

        fig, ax = plt.subplots()
        ax.imshow(canvas)
        ax.axis('off')
        return fig, ax

    def step(self, image1: np.ndarray, image2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Run one match/fit/warp round.

        Returns:
            ``(image1 warped into image2's frame, estimated homography)``.

        Raises:
            RuntimeError: if fewer than 4 matches are found or RANSAC returns no homography.
        """
        mkpts1, mkpts2 = self.find_matches(image1, image2)
        if len(mkpts1) < 4:
            raise RuntimeError(f'Need at least 4 matches to fit a homography, got {len(mkpts1)}')

        Hm, inliers = self.find_homography(mkpts1, mkpts2)
        if Hm is None:
            raise RuntimeError('RANSAC failed to estimate a homography')

        if self.with_visualization:
            self.visualize_matches(image1, image2, mkpts1, mkpts2, inliers.squeeze(-1))
            plt.show()

        image1_warped = self.warp_perpective(image1, Hm)

        return image1_warped, Hm

    def __call__(self, image1: np.ndarray, image2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Apply :meth:`step` ``n_steps`` times, each round re-aligning the previous warp.

        Returns:
            ``(aligned image1, image2)``.
        """
        aligned = image1
        for _ in range(self.n_steps):
            aligned, _H = self.step(aligned, image2)
            if self.with_visualization:
                fig, ax = plt.subplots()
                ax.imshow(aligned)
                plt.show()
        return aligned, image2

In [None]:
# Root of the coin dataset; every CSV of homographies below it is loaded.
# Override with the COIN_DATA_ROOT environment variable on other machines.
DATA_ROOT = os.environ.get(
    'COIN_DATA_ROOT',
    '/Users/jatentaki/Data/archeo/coins/cropped-with-background/Krzywousty/',
)

pairs = []
for dirpath, dirnames, filenames in os.walk(DATA_ROOT):
    # os.walk order depends on the filesystem; sort so pairs (and pairs[0]) are reproducible.
    dirnames.sort()
    for filename in sorted(filenames):
        if filename.endswith('.csv'):
            pairs.extend(parse_homography_csv(os.path.join(dirpath, filename)))

print(len(pairs))
if not pairs:
    raise RuntimeError(f'No homography CSV files found under {DATA_ROOT}')

main_pair = pairs[0]
image1 = main_pair.image1
image2 = main_pair.image2

In [None]:
# Foreground segmentation model; only its `head` is used below (on DINO patch features).
sd = SegmentationDino()
# DINOv2 backbone (eval mode) producing 37x37 patch-feature grids.
dense_dino = DenseDino()

In [None]:
# Ground-truth patch correspondences for every pair. Each pair must be labelled with its
# own images: pairing main_pair's images with another pair's H yields meaningless labels.
mps = [
    match_patches(dense_dino, sd.head, pair.image1, pair.image2, pair.H, threshold=1.0)
    for pair in pairs
]
main_mp = mps[0]

# Which right-image patches count as correct matches for one left patch?
QUERY_ROW, QUERY_COL = 25, 15
fig, ax = plt.subplots()
ax.imshow(main_mp.is_correct_match.cpu().numpy()[QUERY_ROW, QUERY_COL])
ax.scatter([QUERY_COL], [QUERY_ROW], c='r')
ax.set_title(f'Correct matches for left patch (row={QUERY_ROW}, col={QUERY_COL})');

In [None]:
def _loss_fn(mp: MatchingPatches, embed_fn: nn.Module, margin: float = 0.5) -> tuple[Tensor, dict[str, float]]:
    """One-directional (left -> right) hinge loss; ``total_loss`` uses :func:`loss_fn` instead.

    Args:
        mp: left/right patch features and their ground-truth correspondence mask.
        embed_fn: module applied to both feature grids.
        margin: required gap between positive and negative cosine similarity.

    Returns:
        ``(loss, metrics)`` where metrics holds the loss value, the fraction of
        ``match_smnn`` matches that are correct, and the number of such matches.
    """
    left = rearrange(embed_fn(mp.left), "h w c -> (h w) c")
    right = rearrange(embed_fn(mp.right), "h w c -> (h w) c")
    is_correct_match = rearrange(mp.is_correct_match, "i j k l -> (i j) (k l)")

    # Matching is only used for the accuracy metric, so keep it out of the graph.
    with torch.no_grad():
        _dists, idxs = match_smnn(left, right, 0.8)
    is_correct_choice = is_correct_match[idxs[..., 0], idxs[..., 1]]

    similarity = F.cosine_similarity(left[:, None, :], right[None, :], dim=-1)
    loss_sum, n_active = _one_way(similarity, is_correct_match, margin=margin)
    # Clamp so a batch with every margin satisfied yields 0 instead of 0/0 = NaN.
    loss = loss_sum / n_active.clamp(min=1)

    accuracy = is_correct_choice.float().mean().item() if is_correct_choice.numel() else 0.0
    metrics = {
        'loss': loss.item(),
        'accuracy': accuracy,
        'n_matches': is_correct_choice.numel(),
    }

    return loss, metrics

def loss_fn(mp: MatchingPatches, embed_fn: nn.Module, margin: float = 0.5) -> tuple[Tensor, dict[str, float]]:
    """Symmetric hinge loss between embedded left and right patch features.

    The hinge terms of both directions (left -> right and right -> left, computed by
    ``_one_way``) are summed and divided by the number of active (non-zero) terms.

    Args:
        mp: left/right patch features and their ground-truth correspondence mask.
        embed_fn: module applied to both feature grids.
        margin: required gap between positive and negative cosine similarity.

    Returns:
        ``(loss, metrics)`` where metrics holds the loss value, the fraction of
        ``match_smnn`` matches that are correct, and the number of such matches.
    """
    left = rearrange(embed_fn(mp.left), "h w c -> (h w) c")
    right = rearrange(embed_fn(mp.right), "h w c -> (h w) c")
    similarity = F.cosine_similarity(left[:, None, :], right[None, :], dim=-1)
    is_correct_match = rearrange(mp.is_correct_match, "i j k l -> (i j) (k l)")

    # Matching is only used for the accuracy metric, so keep it out of the graph.
    with torch.no_grad():
        _dists, idxs = match_smnn(left, right, dm=1 - similarity, th=0.8)
    is_correct_choice = is_correct_match[idxs[..., 0], idxs[..., 1]]

    loss_lr, n_lr = _one_way(similarity, is_correct_match, margin=margin)
    loss_rl, n_rl = _one_way(similarity.T, is_correct_match.T, margin=margin)

    # Clamp: once every margin is satisfied both sums are 0, and 0/0 = NaN would
    # propagate into the gradients and corrupt the optimizer state.
    n_active = (n_lr + n_rl).clamp(min=1)
    loss = (loss_lr + loss_rl) / n_active

    accuracy = is_correct_choice.float().mean().item() if is_correct_choice.numel() else 0.0
    metrics = {
        'loss': loss.item(),
        'accuracy': accuracy,
        'n_matches': is_correct_choice.numel(),
    }

    return loss, metrics

def _one_way(similarity: Tensor, is_correct_match: Tensor, margin: float = 0.5) -> tuple[Tensor, Tensor]:
    is_any_correct = is_correct_match.any(dim=1)
    similarity = similarity[is_any_correct]
    is_correct_match = is_correct_match[is_any_correct]

    negative_similarity = torch.where(
        ~is_correct_match,
        similarity,
        torch.tensor(0.0, device=similarity.device, dtype=similarity.dtype),
    )

    positive_similarity = torch.where(
        is_correct_match,
        similarity,
        torch.tensor(-1.0, device=similarity.device, dtype=similarity.dtype),
    ).max(dim=1, keepdim=True).values

    loss = F.relu(negative_similarity - positive_similarity + margin)
    n_nonzero = (loss > 0).sum()

    return loss.sum(), n_nonzero

def total_loss(mps: list[MatchingPatches], embedder: nn.Module, **kwargs) -> tuple[Tensor, dict[str, float]]:
    """Sum ``loss_fn`` over all pairs and average every metric across pairs.

    Args:
        mps: image pairs to evaluate.
        embedder: module passed to ``loss_fn``.
        **kwargs: forwarded to ``loss_fn`` (e.g. ``margin``).

    Returns:
        ``(summed loss, per-pair mean of each metric)``.
    """
    summed = 0
    metric_sums = {}
    for loss, metrics in (loss_fn(mp, embedder, **kwargs) for mp in mps):
        summed = summed + loss
        for name, value in metrics.items():
            metric_sums[name] = metric_sums.get(name, 0) + value

    n_pairs = len(mps)
    averaged = {name: value / n_pairs for name, value in metric_sums.items()}
    return summed, averaged

In [None]:
import random
from tqdm.auto import tqdm

# Training hyper-parameters for the descriptor head.
SEED = 0
DINO_DIM = 384  # input width of the LayerNorm / Linear below
EMBED_DIM = 32
N_TRAIN_PAIRS = 2
BATCH_SIZE = 4
N_ITERS = 500
LEARNING_RATE = 1e-3
MARGIN = 0.25

# Seed both RNGs used here (random.shuffle and torch parameter init) for reproducible runs.
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer Apple's MPS backend, but fall back so the notebook also runs on CUDA / CPU machines.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

embedder = nn.Sequential(
    nn.LayerNorm(DINO_DIM, elementwise_affine=False, bias=False),
    nn.Linear(DINO_DIM, EMBED_DIM),
).to(device)
mps_train = [mp.to(device) for mp in mps[:N_TRAIN_PAIRS]]
optim = torch.optim.Adam(embedder.parameters(), lr=LEARNING_RATE)

metrics_log = {}
with tqdm(range(N_ITERS)) as pbar:
    for _ in pbar:
        random.shuffle(mps_train)
        loss, metrics = total_loss(mps_train[:BATCH_SIZE], embedder, margin=MARGIN)
        for k, v in metrics.items():
            metrics_log.setdefault(k, []).append(v)
        optim.zero_grad()
        loss.backward()
        optim.step()
        pbar.set_postfix(**metrics)

fig, loss_ax = plt.subplots()
loss_ax.plot(metrics_log['loss'])
loss_ax.set_xlabel('Iteration')
loss_ax.set_ylabel('Loss')
acc_ax = loss_ax.twinx()
acc_ax.plot(metrics_log['accuracy'], color='orange')
acc_ax.set_ylabel('Accuracy')
loss_ax.set_title('Embedder training');

In [None]:
main_mp = main_mp.to(device)


def show_similarity(point: tuple[int, int]) -> None:
    """Compare raw-DINO and embedder similarity maps for one left-image patch.

    Panels: image1 with the query patch, image2 with the embedder's best match, the raw
    feature similarity over the right 37x37 grid, and the embedder similarity.

    Args:
        point: ``(x, y)`` = (column, row) index of the query patch in the left grid.
    """
    col, row = point
    fig, (a1, a2, a3, a4) = plt.subplots(1, 4, figsize=(20, 5), constrained_layout=True)

    # Patch (col, row) is centred at pixel (14 * col + 7, 14 * row + 7), as in KeypointExtractor.
    a1.imshow(image1)
    a1.scatter([col * 14 + 7], [row * 14 + 7], c='r')
    a1.set_title('image1: query patch')

    left = main_mp.left.squeeze(0)
    right_flat = main_mp.right.squeeze(0).flatten(0, -2)

    similarity_base = F.cosine_similarity(left[row, col], right_flat, dim=-1)
    base_row, base_col = np.unravel_index(similarity_base.argmax().cpu().numpy(), (37, 37))
    a3.imshow(similarity_base.cpu().numpy().reshape(37, 37), vmin=-1, vmax=1)
    a3.scatter([base_col], [base_row], c='r')
    a3.set_title('raw DINO similarity')

    with torch.no_grad():
        similarity_embedder = F.cosine_similarity(embedder(left)[row, col], embedder(right_flat), dim=-1)
    similarity_embedder = similarity_embedder.cpu().numpy().reshape(37, 37)
    emb_row, emb_col = np.unravel_index(similarity_embedder.argmax(), similarity_embedder.shape)
    a4.imshow(similarity_embedder, vmin=-1, vmax=1)
    a4.scatter([emb_col], [emb_row], c='r')
    a4.set_title('embedder similarity')

    a2.imshow(image2)
    a2.scatter([emb_col * 14 + 7], [emb_row * 14 + 7], c='r')
    a2.set_title('image2: best match (embedder)')

    for ax in (a1, a2, a3, a4):
        ax.axis('off')


show_similarity((25, 15))
show_similarity((10, 13))
show_similarity((10, 25))

In [None]:
# Keypoint extractor = DINO backbone + segmentation head + the trained descriptor head.
kpe = KeypointExtractor(dino=dense_dino, segmenter=sd.head, descriptor=embedder).to(device)
aligner = ImageAligner(extractor=kpe, n_steps=3, with_visualization=True)

In [None]:
# One alignment step; save source / warped / target images for side-by-side inspection.
warp_1, H_step = aligner.step(image1, image2)
plt.imsave('0_src.png', image1)
plt.imsave('1_warp.png', warp_1)
plt.imsave('2_target.png', image2)

# Full iterative alignment; the aligner returns (aligned image1, unchanged image2).
aligned_1, target = aligner(image1, image2)

fig, ax = plt.subplots()
ax.imshow(aligned_1)
ax.set_title(f'image1 after {aligner.n_steps} alignment steps')
ax.axis('off');